# Page-header residue captured with this file (not part of the module):
#   shivansh-ka's picture
#   Update src/predict.py
#   3b47abb
import pandas as pd
import numpy as np
import tensorflow as tf
import plotly.express as px
import transformers
from transformers import AutoTokenizer
import os
from src.constants import *
import re
import streamlit as st
class PredictionServices:
    """Toxicity prediction helpers rendered through Streamlit.

    Wraps a tokenizer callable and a model exposing ``predict`` and offers
    single-text prediction, CSV batch prediction and a bar chart of the
    resulting probabilities.
    """

    # Column holding the text to classify; it must be present in batch input.
    _TEXT_COLUMN = "comment_text"
    # Every column of a batch CSV must belong to this set.
    _ALLOWED_COLUMNS = frozenset({"id", "comment_text"})

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by every prediction method."""
        self.model = model
        self.tokenizer = tokenizer

    @staticmethod
    def _clean_text(text, lowercase=False):
        """Replace newlines with spaces and strip surrounding whitespace.

        NOTE(review): ``single_predict`` lowercases its input while
        ``batch_predict`` does not, so the same comment can be fed to the
        model in two different forms. Both behaviours are preserved here;
        confirm which one matches how the model was trained.
        """
        cleaned = re.sub('\n', ' ', text).strip()
        return cleaned.lower() if lowercase else cleaned

    def tokenizer_fn(self, text: "str | list[str]"):
        """Tokenize one text or a list of texts into a dict of model inputs.

        Sequences are truncated/padded to ``MAX_LEN`` and returned as
        TensorFlow tensors, without token type ids.
        """
        encoded = self.tokenizer(
            text,
            max_length=MAX_LEN,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="tf",
            return_token_type_ids=False,
        )
        return dict(encoded)

    def plot(self, pred):
        """Draw a horizontal bar chart of toxic vs. non-toxic probability.

        Args:
            pred: Probability of the "toxic" class, or ``None`` when the
                prediction failed (a warning is shown instead of a chart).
        """
        if pred is None:
            # single_predict returns None on failure; avoid a TypeError here.
            st.warning("No prediction available to plot.")
            return

        # Convert to a plain float first so the rounded percentages are
        # ordinary Python floats rather than NumPy scalars.
        toxic_prob = float(pred)
        probs = [round(toxic_prob * 100, 2), round((1 - toxic_prob) * 100, 2)]
        labels = ['toxic', 'non-toxic']
        color_map = {'toxic': "red", "non-toxic": "green"}

        fig = px.bar(
            x=probs,
            y=labels,
            width=400,
            height=250,
            template="plotly_dark",
            text_auto=True,
            title="Probabilities(%)",
            color=labels,
            color_discrete_map=color_map,
            labels={'x': 'Confidence', 'y': "label"},
        )
        fig.update_traces(width=0.5, textfont_size=20, textangle=0,
                          textposition="outside")
        fig.update_layout(yaxis_title=None, xaxis_title=None, showlegend=False)
        st.plotly_chart(fig, theme="streamlit", use_container_width=True)

    def data_validation(self, data):
        """Return True when *data* has a usable column layout.

        The frame must contain ``comment_text`` and may additionally contain
        ``id``; any other column makes the validation fail.
        """
        columns = set(data.columns)
        return (self._TEXT_COLUMN in columns
                and columns <= self._ALLOWED_COLUMNS)

    def batch_predict(self, data):
        """Predict toxicity for every comment in a CSV file.

        Args:
            data: Anything ``pandas.read_csv`` accepts (path or file-like).

        Returns:
            The DataFrame with rows containing missing values dropped and two
            added columns, ``probabilities`` and ``toxic`` (1 when the
            probability exceeds 0.5, else 0); ``None`` when validation or
            prediction fails (the failure is reported through Streamlit).
        """
        try:
            df = pd.read_csv(data)
            if not self.data_validation(df):
                st.error("Data Validation Failed!! :thumbsdown:")
                return None

            with st.spinner('Please Wait!!! Prediction in process....'):
                st.success('Data Validation Successfull :thumbsup:')
                df = df.dropna()
                if df.empty:
                    st.error("No rows left to predict after dropping missing values.")
                    return None

                col = self._TEXT_COLUMN
                # Cast to str so numeric-looking comments do not break re.sub.
                df[col] = df[col].astype(str).map(self._clean_text)
                model_inputs = self.tokenizer_fn(df[col].tolist())
                preds = self.model.predict(model_inputs)
                # Keep the first output per row as a 1-D column of scores.
                df['probabilities'] = np.asarray(preds).reshape(len(df), -1)[:, 0]
                df['toxic'] = np.where(df['probabilities'] > 0.5, 1, 0)
                st.success("Prediction Process Completed!!!, :thumbsup:")
                st.info("Press download button to download prediction file")
                return df
        except Exception as e:
            # Top-level boundary: keep the console trace and tell the user too.
            print(e)
            st.error(f"Batch prediction failed: {e}")
            return None

    def single_predict(self, text: str):
        """Predict the toxicity probability of a single text.

        Returns:
            The first value of the model's first output row, or ``None`` when
            the prediction fails (the failure is reported through Streamlit).
        """
        try:
            cleaned = self._clean_text(text, lowercase=True)
            model_inputs = self.tokenizer_fn(cleaned)
            return self.model.predict(model_inputs)[0][0]
        except Exception as e:
            # Top-level boundary: keep the console trace and tell the user too.
            print(e)
            st.error(f"Prediction failed: {e}")
            return None