ynp3 commited on
Commit
3313c97
·
1 Parent(s): bc69c2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -48
app.py CHANGED
@@ -1,63 +1,71 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from transformers import BertTokenizer, BertForSequenceClassification
4
  import torch
 
5
 
6
- # Load pre-trained BERT model
7
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
8
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
 
9
  model.eval()
10
 
11
- # Create a DataFrame to store classification results
12
- classification_results_df = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
13
 
14
  def classify_text(text):
15
  # Tokenize text
16
- inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
17
- # Forward pass through the BERT model
18
- outputs = model(**inputs)
19
- # Get predicted probabilities for each class
20
- probs = torch.sigmoid(outputs.logits)
21
- # Round probabilities to 0 or 1 to get binary predictions
22
- preds = (probs > 0.5).int().tolist()[0]
23
- return preds
24
-
25
- def add_classification_to_df(text, preds):
26
- # Add classification results to the DataFrame
27
- classification_results_df.loc[len(classification_results_df)] = [text] + preds
 
 
 
 
 
 
 
 
28
 
29
  # Streamlit app
30
  def main():
31
- st.title("Toxicity Classification with BERT")
32
- # Input text from user
33
- text = st.text_area("Enter text for classification", "")
34
- if st.button("Classify"):
35
- if text.strip() == "":
36
- st.warning("Please enter some text for classification.")
37
- else:
38
- # Perform classification
39
- preds = classify_text(text)
40
- # Display classification results
41
- st.subheader("Classification Results:")
42
- # Check if preds has enough elements
43
- if len(preds) >= 6:
44
- st.write("Toxic: ", preds[0])
45
- st.write("Severe Toxic: ", preds[1])
46
- st.write("Obscene: ", preds[2])
47
- st.write("Threat: ", preds[3])
48
- st.write("Insult: ", preds[4])
49
- st.write("Identity Hate: ", preds[5])
50
- # Add classification results to DataFrame
51
- add_classification_to_df(text, preds)
52
- else:
53
- st.error("Error: Classification results are incomplete.")
54
- # Debug statements
55
- st.write("preds:", preds)
56
- st.write("len(preds):", len(preds))
57
- if st.button("View Classification Results"):
58
- # Display classification results DataFrame
59
- st.subheader("All Classification Results:")
60
- st.write(classification_results_df)
61
 
62
  if __name__ == '__main__':
63
  main()
 
import streamlit as st
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Load the pre-trained BERT tokenizer and a 6-way sequence-classification head.
# NOTE(review): 'bert-base-uncased' ships no fine-tuned toxicity head, so the
# classification layer is randomly initialised — predictions will not be
# meaningful until a fine-tuned checkpoint is substituted. Confirm the
# intended model name.
MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=6)
model.eval()  # inference only — disables dropout

# Accumulates one row per classified text: the text plus one boolean column
# per toxicity label.
df_results = pd.DataFrame(
    columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat',
             'Insult', 'Identity Hate']
)
14
 
15
  def classify_text(text):
16
  # Tokenize text
17
+ tokens = tokenizer.encode_plus(
18
+ text,
19
+ max_length=512,
20
+ truncation=True,
21
+ padding=True,
22
+ return_attention_mask=True,
23
+ return_tensors='pt'
24
+ )
25
+
26
+ # Get model's predictions
27
+ with torch.no_grad():
28
+ outputs = model(**tokens)
29
+ logits = outputs.logits
30
+ probabilities = torch.softmax(logits, dim=1).tolist()[0]
31
+
32
+ # Extract predicted labels
33
+ labels = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']
34
+ predicted_labels = [labels[i] for i, prob in enumerate(probabilities) if prob > 0.5]
35
+
36
+ return predicted_labels
37
 
38
  # Streamlit app
39
  def main():
40
+ st.title('Toxicity Classification')
41
+
42
+ # User input
43
+ text = st.text_area('Enter text:', max_chars=512)
44
+
45
+ # Perform classification
46
+ if st.button('Classify'):
47
+ predicted_labels = classify_text(text)
48
+ st.write('Predicted Labels:', predicted_labels)
49
+
50
+ # Allow user to add classification results to DataFrame
51
+ if st.button('Add to Results'):
52
+ global df_results
53
+ df_results = df_results.append({
54
+ 'Text': text,
55
+ 'Toxic': 'Toxic' in predicted_labels,
56
+ 'Severe Toxic': 'Severe Toxic' in predicted_labels,
57
+ 'Obscene': 'Obscene' in predicted_labels,
58
+ 'Threat': 'Threat' in predicted_labels,
59
+ 'Insult': 'Insult' in predicted_labels,
60
+ 'Identity Hate': 'Identity Hate' in predicted_labels
61
+ }, ignore_index=True)
62
+ st.success('Classification results added to DataFrame.')
63
+
64
+ # Show DataFrame with classification results
65
+ if not df_results.empty:
66
+ st.subheader('Classification Results')
67
+ st.dataframe(df_results)
68
+
 
69
 
70
  if __name__ == '__main__':
71
  main()