anshu9749 commited on
Commit
d016772
·
verified ·
1 Parent(s): c3779ec

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +58 -29
src/streamlit_app.py CHANGED
@@ -2,39 +2,68 @@ import altair as alt
2
  import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
 
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
 
22
 
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
 
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
 
11
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the fine-tuned BERT classifier and its tokenizer once per session.

    Cached by ``st.cache_resource`` so Streamlit reruns reuse the same
    model instead of reloading it from disk on every interaction.

    Returns:
        tuple: ``(tokenizer, model, device)`` — the model is moved to GPU
        when one is available and switched to eval mode for inference.
    """
    # "CustomModel" is a local fine-tuned checkpoint directory — TODO confirm path.
    tokenizer = BertTokenizer.from_pretrained("CustomModel")
    model = BertForSequenceClassification.from_pretrained("CustomModel")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: eval() disables dropout so scores are deterministic.
    model.eval()
    return tokenizer, model, device
19
 
20
# One-time model setup (cached across Streamlit reruns by @st.cache_resource).
tokenizer, model, device = load_model()

st.title("Batch Toxic Comment Classifier")
st.write("Upload a CSV file containing text comments and get toxicity scores for each row.")

# CSV upload
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    df = pd.read_csv(uploaded_file)

    # Let the user select which column contains text (object dtype ~ strings).
    text_cols = df.select_dtypes(include=["object"]).columns.tolist()
    if not text_cols:
        st.error("No text columns found in the uploaded file.")
    else:
        col = st.selectbox("Select text column to classify", text_cols)
        if st.button("Classify CSV"):
            texts = df[col].astype(str).tolist()

            # Label names are loop-invariant: resolve them once, not per row.
            id2label = model.config.id2label if hasattr(model.config, "id2label") else {0: "non-toxic", 1: "toxic"}

            results = []
            # no_grad: inference needs no autograd graph (saves memory and time);
            # it also makes the per-row .detach() unnecessary.
            with torch.no_grad():
                for text in texts:
                    inputs = tokenizer(
                        text,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                    ).to(device)
                    outputs = model(**inputs)
                    # Softmax over the label dimension; [0] drops the batch axis.
                    probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
                    # Record per-row scores keyed by human-readable label name.
                    results.append({id2label[i]: float(p) for i, p in enumerate(probs)})

            # Combine scores with the original rows (reset_index aligns positions).
            score_df = pd.DataFrame(results)
            combined = pd.concat([df.reset_index(drop=True), score_df], axis=1)

            st.subheader("Classification Results")
            st.dataframe(combined)

            # Optional: download results
            csv = combined.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="Download results as CSV",
                data=csv,
                file_name="classified_results.csv",
                mime="text/csv"
            )