# classification/src/streamlit_app.py
# Last commit d016772 by anshu9749: "Update src/streamlit_app.py" (verified)
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit as st
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
@st.cache_resource(show_spinner=False)
def load_model(model_path="CustomModel"):
    """Load the fine-tuned BERT sequence classifier and its tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit reuses the same objects
    across script reruns instead of reloading the weights on every interaction.

    Args:
        model_path: Path or identifier handed to ``from_pretrained`` for both
            the tokenizer and the model. Defaults to ``"CustomModel"``, the
            location this app originally hard-coded.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is CUDA when
        available, otherwise CPU. The model has been moved to ``device`` and
        put in eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # transformers documents that from_pretrained returns an eval-mode model.
    # Calling eval() explicitly guarantees dropout is never active here.
    model.eval()
    return tokenizer, model, device
# Fetch the tokenizer/model/device triple; st.cache_resource on load_model
# means the expensive load only happens on the first run.
tokenizer, model, device = load_model()

_PAGE_TITLE = "Batch Toxic Comment Classifier"
_PAGE_INTRO = (
    "Upload a CSV file containing text comments and get toxicity scores for each row."
)
st.title(_PAGE_TITLE)
st.write(_PAGE_INTRO)

# Single CSV upload widget; evaluates to None until the user supplies a file.
uploaded_file = st.file_uploader(label="Choose a CSV file", type="csv")
def _score_texts(texts, batch_size=32, on_progress=None):
    """Classify *texts* with the module-level model and return per-row scores.

    Args:
        texts: List of strings to classify.
        batch_size: Number of texts tokenized and run through the model at once.
        on_progress: Optional callable. After each batch it receives the
            fraction (0.0-1.0] of rows processed so far.

    Returns:
        One ``{label: probability}`` dict per input text, in input order.
        Labels come from ``model.config.id2label``. If that mapping is missing
        or empty, the fallback is ``{0: "non-toxic", 1: "toxic"}``.
    """
    # Loop-invariant: resolve the label mapping once, not per row.
    id2label = getattr(model.config, "id2label", None) or {0: "non-toxic", 1: "toxic"}
    scores = []
    total = len(texts)
    for start in range(0, total, batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,  # pad to the longest text in this batch
            truncation=True,
            return_tensors="pt",
        ).to(device)
        # Pure inference: skip building the autograd graph (saves memory/time).
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = F.softmax(logits, dim=-1).cpu().numpy()
        scores.extend(
            {id2label.get(i, f"label_{i}"): float(p) for i, p in enumerate(row)}
            for row in probs
        )
        if on_progress is not None:
            on_progress(min(start + batch_size, total) / total)
    return scores


if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        # Show a readable message instead of a raw traceback, then halt this run.
        st.error(f"Could not read the uploaded CSV: {err}")
        st.stop()

    # Let the user pick which (string/object-dtype) column holds the text.
    text_cols = df.select_dtypes(include=["object"]).columns.tolist()
    if not text_cols:
        st.error("No text columns found in the uploaded file.")
    else:
        col = st.selectbox("Select text column to classify", text_cols)
        if st.button("Classify CSV"):
            # Without fillna, missing cells would be classified as the string "nan".
            texts = df[col].fillna("").astype(str).tolist()

            progress = st.progress(0.0)
            results = _score_texts(texts, on_progress=progress.progress)
            progress.empty()

            # Append one score column per label, aligned row-by-row with the input.
            score_df = pd.DataFrame(results)
            combined = pd.concat([df.reset_index(drop=True), score_df], axis=1)
            st.subheader("Classification Results")
            st.dataframe(combined)

            # NOTE(review): clicking the download button triggers a rerun, during
            # which st.button returns False, so the results table disappears.
            # Persist `combined` in st.session_state if that is undesirable.
            csv = combined.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download results as CSV",
                data=csv,
                file_name="classified_results.csv",
                mime="text/csv",
            )