# Captured from a Hugging Face Space (status at capture time: Sleeping).
# File size: 2,592 bytes; blame commits 8c396b2 and d016772.
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit as st
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
@st.cache_resource(show_spinner=False)
def load_model(model_dir="CustomModel"):
    """Load the fine-tuned BERT classifier and its tokenizer.

    ``st.cache_resource`` keeps the returned objects across Streamlit reruns,
    so the weights are read from disk once per distinct ``model_dir``
    instead of on every widget interaction.

    Args:
        model_dir: Path (or model id) passed to both ``from_pretrained``
            calls. Defaults to the bundled ``"CustomModel"`` checkpoint.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``model`` has been moved to
        ``device`` (CUDA when available, otherwise CPU) and put in
        evaluation mode.
    """
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForSequenceClassification.from_pretrained(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference-only app: make sure dropout is disabled regardless of the
    # mode the model was loaded in.
    model.eval()
    return tokenizer, model, device
tokenizer, model, device = load_model()

st.title("Batch Toxic Comment Classifier")
st.write("Upload a CSV file containing text comments and get toxicity scores for each row.")


def _classify_texts(texts, tokenizer, model, device, batch_size=32):
    """Score ``texts`` with the sequence classifier, ``batch_size`` at a time.

    Args:
        texts: List of strings to classify.
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model already placed on ``device``.
        device: Device the tokenized inputs are moved to.
        batch_size: Number of texts tokenized and run per forward pass.

    Returns:
        One ``{label: probability}`` dict per input text, in input order.
        Probabilities are the softmax of the model's logits. Label names come
        from ``model.config.id2label``, falling back to
        ``{0: "non-toxic", 1: "toxic"}`` when that mapping is missing or empty.
    """
    # Loop-invariant: resolve label names once, not per row.
    id2label = getattr(model.config, "id2label", None) or {0: "non-toxic", 1: "toxic"}
    results = []
    # Inference only: skip autograd graph construction (saves memory/time).
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,  # pad to the longest text in this batch
                truncation=True,  # clip to the model's maximum length
                return_tensors="pt",
            ).to(device)
            probs = F.softmax(model(**inputs).logits, dim=-1).cpu().numpy()
            results.extend(
                {id2label[i]: float(p) for i, p in enumerate(row)} for row in probs
            )
    return results


# CSV upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded file as CSV: {err}")
        st.stop()

    # Let user select which column contains text
    text_cols = df.select_dtypes(include=["object"]).columns.tolist()
    if not text_cols:
        st.error("No text columns found in the uploaded file.")
    else:
        col = st.selectbox("Select text column to classify", text_cols)

        # Any later interaction (including clicking the download button)
        # reruns the script, and st.button is True only on the run right
        # after its click. Results are therefore kept in session_state, keyed
        # to this file and column so stale results are never shown.
        result_key = (uploaded_file.name, uploaded_file.size, col)

        if st.button("Classify CSV"):
            texts = df[col].astype(str).tolist()
            with st.spinner("Classifying comments..."):
                results = _classify_texts(texts, tokenizer, model, device)
            # Combine with original
            score_df = pd.DataFrame(results)
            # Avoid duplicate column names when the CSV already has a column
            # named like one of the labels.
            score_df = score_df.rename(
                columns=lambda c: f"{c}_score" if c in df.columns else c
            )
            combined = pd.concat([df.reset_index(drop=True), score_df], axis=1)
            st.session_state["classified"] = (result_key, combined)

        cached = st.session_state.get("classified")
        if cached is not None and cached[0] == result_key:
            combined = cached[1]
            st.subheader("Classification Results")
            st.dataframe(combined)
            # Optional: download results
            csv = combined.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download results as CSV",
                data=csv,
                file_name="classified_results.csv",
                mime="text/csv",
            )