# Author: shivansh-ka
# Commit fb5fab5: "Update app2.py"
import warnings
warnings.filterwarnings("ignore")
import streamlit as st
import pandas as pd
import plotly.express as px
from src import *
import os
# Module-level prediction service; ``single_predict`` and ``batch_predict``
# below read ``prediction`` as a module global.


@st.cache_resource
def model_obj():
    """Load the model/tokenizer and wrap them in a prediction service.

    ``st.cache_resource`` keeps the returned object across Streamlit reruns,
    so ``ModelLoader()`` runs only on the first call rather than on every
    interaction with the page.

    Returns:
        A ``PredictionServices`` instance built from the loader's ``Model``
        and ``Tokenizer`` attributes.
    """
    loader = ModelLoader()
    return PredictionServices(loader.Model, loader.Tokenizer)


prediction = model_obj()

# Header image. Rendered here, outside the cached loader, so the cached
# resource function has no UI side effects; it still appears above the title.
st.image(os.path.join("img", "toxic.jpg"))
def single_predict(text, threshold=0.5):
    """Classify one comment and render the verdict plus a score plot.

    Args:
        text: The comment text to classify.
        threshold: Scores strictly below this value are reported as
            non-toxic; scores at or above it are reported as toxic.
            Defaults to 0.5.

    Returns:
        None. Results are written directly to the Streamlit page.
    """
    # NOTE(review): presumably a toxicity score in [0, 1] — confirm against
    # PredictionServices.single_predict.
    score = prediction.single_predict(text)
    if score < threshold:
        st.success('Non Toxic Comment!!! :thumbsup:')
    else:
        st.error('Toxic Comment!!! :thumbsdown:')
    # The score plot is drawn for every prediction, whichever verdict was shown.
    prediction.plot(score)
def batch_predict(data):
    """Score an uploaded CSV and return the results as downloadable bytes.

    Args:
        data: The uploaded CSV file object, handed unchanged to
            ``prediction.batch_predict``.

    Returns:
        The prediction table serialized as CSV (index column omitted) and
        encoded as UTF-8 bytes, suitable for ``st.download_button``.
    """
    # presumably a pandas DataFrame, given the ``to_csv`` call — verify.
    scored_table = prediction.batch_predict(data)
    csv_text = scored_table.to_csv(index=False)
    return csv_text.encode('utf-8')
# ---- Page layout -----------------------------------------------------------
st.title('Toxic Comment Classifier')
st.write("This application will help to classify any comment or text in any language into 'TOXIC' or 'NON-TOXIC'")
tab1, tab2 = st.tabs(["Single Value Prediction", "Batch Prediction"])

# Tab 1: classify a single comment typed into a form.
with tab1:
    st.subheader("Prediction")
    with st.form("comment_form", clear_on_submit=True):
        comment = st.text_area(label="Enter your comment")
        button = st.form_submit_button(label="Predict")
    if button:
        # Don't spend a model call on an empty or whitespace-only submission.
        if not comment.strip():
            st.warning("Please enter a comment before predicting.")
        else:
            with st.spinner('Please Wait!!! Prediction in process....'):
                single_predict(comment)

# Tab 2: score a whole uploaded CSV and offer the results for download.
with tab2:
    st.subheader("Batch Prediction")
    csv_file = st.file_uploader("Upload File", type=['csv'])
    if csv_file is not None:
        # NOTE(review): Streamlit re-executes this script on every widget
        # interaction, so the uploaded file is re-scored on each rerun.
        # Consider caching if batches are large.
        csv = batch_predict(csv_file)
        st.download_button(
            label="Download",
            data=csv,
            file_name='prediction.csv',
            mime='text/csv',
        )