File size: 2,592 Bytes
8c396b2
 
 
 
d016772
 
 
 
 
8c396b2
d016772
 
 
 
 
 
 
 
8c396b2
d016772
8c396b2
d016772
 
8c396b2
d016772
 
 
 
8c396b2
d016772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c396b2
d016772
 
 
8c396b2
d016772
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit as st
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification

@st.cache_resource(show_spinner=False)
def load_model():
    """Load the fine-tuned BERT classifier and its tokenizer once per session.

    Returns:
        tuple: (tokenizer, model, device) — the model is moved to GPU when
        available and switched to eval mode for inference.
    """
    # Load your fine-tuned model and tokenizer
    tokenizer = BertTokenizer.from_pretrained("CustomModel")
    model = BertForSequenceClassification.from_pretrained("CustomModel")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference-only app: disable dropout / use running batch-norm stats.
    # Without this, predictions are stochastic across reruns.
    model.eval()
    return tokenizer, model, device

tokenizer, model, device = load_model()

st.title("Batch Toxic Comment Classifier")
st.write("Upload a CSV file containing text comments and get toxicity scores for each row.")

# CSV upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    df = pd.read_csv(uploaded_file)

    # Let user select which column contains text
    text_cols = df.select_dtypes(include=["object"]).columns.tolist()
    if not text_cols:
        st.error("No text columns found in the uploaded file.")
    else:
        col = st.selectbox("Select text column to classify", text_cols)
        if st.button("Classify CSV"):
            texts = df[col].astype(str).tolist()
            # Label mapping is a model-level constant — resolve it once,
            # not once per row as before.
            id2label = model.config.id2label if hasattr(model.config, "id2label") else {0: "non-toxic", 1: "toxic"}
            results = []

            # Batch inference: tokenize and forward BATCH_SIZE rows per pass
            # (the old code ran one forward pass per row despite the label).
            # no_grad() skips autograd graph construction — pure inference,
            # much lower memory use on large CSVs.
            BATCH_SIZE = 32
            with torch.no_grad():
                for start in range(0, len(texts), BATCH_SIZE):
                    batch = texts[start:start + BATCH_SIZE]
                    inputs = tokenizer(
                        batch,
                        padding=True,
                        truncation=True,
                        return_tensors="pt"
                    ).to(device)
                    outputs = model(**inputs)
                    # (batch, num_labels) probabilities for this chunk
                    probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()
                    # record per-row scores
                    for row_probs in probs:
                        results.append({id2label[i]: float(p) for i, p in enumerate(row_probs)})

            # Combine with original
            score_df = pd.DataFrame(results)
            combined = pd.concat([df.reset_index(drop=True), score_df], axis=1)

            st.subheader("Classification Results")
            st.dataframe(combined)

            # Optional: download results
            csv = combined.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download results as CSV",
                data=csv,
                file_name="classified_results.csv",
                mime="text/csv"
            )