Spaces:
Sleeping
Sleeping
| import altair as alt | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
def load_model():
    """Load the tokenizer and sequence classifier stored under "CustomModel".

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is the CUDA device
        when ``torch.cuda.is_available()`` is true and the CPU otherwise; the
        model has already been moved onto it.
    """
    checkpoint = "CustomModel"

    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    classifier = BertForSequenceClassification.from_pretrained(checkpoint)

    # Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    classifier.to(target)

    return bert_tokenizer, classifier, target
# Number of rows tokenized and scored together in one forward pass.
BATCH_SIZE = 32


def _classify_texts(texts, tokenizer, model, device, batch_size=BATCH_SIZE):
    """Return one ``{label: probability}`` dict per entry of *texts*.

    Args:
        texts: list of strings to score.
        tokenizer: tokenizer callable that returns "pt" tensors.
        model: sequence-classification model whose output exposes ``.logits``.
        device: torch device the inputs are moved to before the forward pass.
        batch_size: how many texts are padded together and scored per pass.

    Returns:
        A list with the same length and order as *texts*. Each dict maps a
        label name to the softmax probability of that class.
    """
    # Looked up once: the mapping does not change between rows.
    id2label = getattr(model.config, "id2label", None) or {0: "non-toxic", 1: "toxic"}

    scores = []
    # Inference only, so skip building the autograd graph for each forward pass.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            probs = F.softmax(model(**inputs).logits, dim=-1).cpu().tolist()
            for row in probs:
                # Fall back to a generic name if the mapping has fewer entries
                # than the model has output classes.
                scores.append(
                    {id2label.get(i, f"LABEL_{i}"): p for i, p in enumerate(row)}
                )
    return scores


# Streamlit re-executes this script on every widget interaction. Wrapping the
# loader in st.cache_resource means the checkpoint is loaded once and the same
# objects are reused on later reruns.
_get_model = st.cache_resource(load_model)
tokenizer, model, device = _get_model()

st.title("Batch Toxic Comment Classifier")
st.write("Upload a CSV file containing text comments and get toxicity scores for each row.")

# CSV upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # Show a readable message instead of a traceback for unreadable uploads.
        st.error(f"Could not read the uploaded file as CSV: {exc}")
        st.stop()

    # Let user select which column contains text
    text_cols = df.select_dtypes(include=["object"]).columns.tolist()
    if not text_cols:
        st.error("No text columns found in the uploaded file.")
    else:
        col = st.selectbox("Select text column to classify", text_cols)
        if st.button("Classify CSV"):
            # Missing cells would otherwise be stringified to the word "nan"
            # and scored as if the user had written it.
            texts = df[col].fillna("").astype(str).tolist()

            # Batch inference
            results = _classify_texts(texts, tokenizer, model, device)

            # Combine with original
            score_df = pd.DataFrame(results)
            combined = pd.concat([df.reset_index(drop=True), score_df], axis=1)

            st.subheader("Classification Results")
            st.dataframe(combined)

            # Optional: download results
            csv = combined.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download results as CSV",
                data=csv,
                file_name="classified_results.csv",
                mime="text/csv"
            )