"""Streamlit app that scores every row of an uploaded CSV with a BERT classifier.

The user uploads a CSV, picks the column holding the comment text, and gets
back the original table with one probability column per model label, plus a
download button for the combined result.
"""

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification

# Number of comments tokenized and scored per forward pass.
BATCH_SIZE = 32


@st.cache_resource(show_spinner=False)
def load_model():
    """Load the tokenizer and classifier from ``"CustomModel"``.

    Cached with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model has already been
        moved to ``device`` (CUDA when available, otherwise CPU).
    """
    tokenizer = BertTokenizer.from_pretrained("CustomModel")
    model = BertForSequenceClassification.from_pretrained("CustomModel")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: make sure dropout and similar layers are disabled.
    model.eval()
    return tokenizer, model, device


def _classify_texts(texts, tokenizer, model, device, batch_size=BATCH_SIZE):
    """Return one ``{label: probability}`` dict per entry of ``texts``.

    Args:
        texts: List of strings to score.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Sequence-classification model returned by ``load_model``.
        device: Torch device the model lives on.
        batch_size: How many texts to send through the model at once.

    Returns:
        A list, in input order, of dicts mapping each label name to its
        softmax probability as a plain Python float.
    """
    # Loop-invariant: resolve the label names once, not once per row.
    id2label = getattr(model.config, "id2label", None) or {
        0: "non-toxic",
        1: "toxic",
    }

    results = []
    # no_grad: we never backpropagate here, so skip building the autograd graph.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            logits = model(**inputs).logits
            batch_probs = F.softmax(logits, dim=-1).cpu().tolist()
            results.extend(
                {id2label.get(i, f"label_{i}"): p for i, p in enumerate(row)}
                for row in batch_probs
            )
    return results


tokenizer, model, device = load_model()

st.title("Batch Toxic Comment Classifier")
st.write("Upload a CSV file containing text comments and get toxicity scores for each row.")

# CSV upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # Show a readable message instead of a raw traceback for bad uploads.
        st.error(f"Could not read the uploaded CSV: {exc}")
        st.stop()

    # Let user select which column contains text
    text_cols = df.select_dtypes(include=["object"]).columns.tolist()
    if not text_cols:
        st.error("No text columns found in the uploaded file.")
    else:
        col = st.selectbox("Select text column to classify", text_cols)

        # NOTE(review): st.button is only True on the rerun triggered by its own
        # click, so pressing the download button below reruns the script and the
        # results disappear. Persisting them in st.session_state would fix that.
        if st.button("Classify CSV"):
            # Missing cells become empty strings; astype(str) alone would turn
            # them into the literal text "nan" and score that as a comment.
            texts = df[col].fillna("").astype(str).tolist()
            results = _classify_texts(texts, tokenizer, model, device)

            # Combine with original
            score_df = pd.DataFrame(results)
            # Avoid duplicate column names when the CSV already has a column
            # named like one of the model labels.
            score_df = score_df.rename(
                columns={
                    name: f"{name}_score"
                    for name in score_df.columns
                    if name in df.columns
                }
            )
            combined = pd.concat([df.reset_index(drop=True), score_df], axis=1)

            st.subheader("Classification Results")
            st.dataframe(combined)

            # Optional: download results
            csv = combined.to_csv(index=False).encode("utf-8")
            st.download_button(
                label="Download results as CSV",
                data=csv,
                file_name="classified_results.csv",
                mime="text/csv",
            )