ynp3 committed on
Commit
92d0911
·
1 Parent(s): 222e77a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -27
app.py CHANGED
@@ -3,40 +3,54 @@ import pandas as pd
3
  from transformers import BertTokenizer, BertForSequenceClassification
4
  import torch
5
 
6
- # Load pre-trained BERT model
7
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
8
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
9
  model.eval()
10
 
11
- # Create a persistent DataFrame to store classification results
12
- classified_data = pd.DataFrame(columns=['Text', 'Toxicity'])
13
-
14
  def classify_text(text):
15
- # Tokenize and encode input text
16
- input_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True)).unsqueeze(0)
 
 
 
 
 
 
 
 
17
 
18
- # Forward pass through BERT model
19
- outputs = model(input_ids)
20
- logits = outputs.logits
21
- predicted_class = torch.argmax(logits, dim=1).item()
22
- toxicity = "Toxic" if predicted_class == 1 else "Non-Toxic"
23
- return toxicity
24
 
25
  # Streamlit app
26
- def main():
27
- st.title("Toxicity Classifier")
28
- user_text = st.text_area("Enter text to classify:")
 
 
 
 
 
29
  if st.button("Classify"):
30
- if user_text:
31
- toxicity = classify_text(user_text)
32
- st.write(f"Predicted Toxicity: {toxicity}")
33
- # Add classification results to the persistent DataFrame
34
- global classified_data
35
- classified_data = classified_data.append({'Text': user_text, 'Toxicity': toxicity}, ignore_index=True)
36
- else:
37
- st.warning("Please enter some text.")
38
- if st.button("View Classified Data"):
39
- st.write(classified_data)
 
 
 
 
 
 
40
 
 
41
  if __name__ == "__main__":
42
- main()
 
3
  from transformers import BertTokenizer, BertForSequenceClassification
4
  import torch
5
 
6
# Load pre-trained BERT model and tokenizer.
# NOTE(review): 'bert-base-uncased' ships no fine-tuned classification head —
# with num_labels=6 the head is randomly initialized, so predictions are
# meaningless until the model is fine-tuned on a toxicity dataset (e.g. a
# Jigsaw-toxic-comment checkpoint) — TODO confirm which weights are intended.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
# Inference-only app: switch off dropout etc.
model.eval()
10
 
11
+ # Function to classify text using the pre-trained BERT model
 
 
12
def classify_text(text):
    """Classify *text* into six toxicity categories.

    Returns a (1, 6) numpy array of independent per-label probabilities
    (Toxic, Severe Toxic, Obscene, Threat, Insult, Identity Hate),
    obtained by applying a sigmoid to the model's logits.
    """
    # Tokenize input text. BERT accepts at most 512 positions, so truncate —
    # without this, long inputs crash inside the position embeddings.
    input_ids = tokenizer.encode(
        text, add_special_tokens=True, truncation=True, max_length=512
    )
    # Convert tokenized input to a batch-of-1 tensor.
    input_tensor = torch.tensor([input_ids])
    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        logits = model(input_tensor)[0]
    # Sigmoid, not softmax: the toxicity labels are multi-label
    # (not mutually exclusive), each scored independently.
    predicted_labels = torch.sigmoid(logits).numpy()
    return predicted_labels
23
 
24
# Module-level DataFrame accumulating one row per classified text.
# NOTE(review): Streamlit re-executes the whole script on every interaction,
# re-creating this DataFrame — rows only survive within a single run. Use
# st.session_state for persistence across interactions.
results_df = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
 
 
 
 
26
 
27
  # Streamlit app
28
def app():
    """Streamlit UI: read text, show per-label toxicity scores, log results.

    Side effects: renders widgets via `st.*` and appends one row to the
    module-level `results_df` per successful classification.
    """
    st.title("Toxicity Classification App")
    st.write("Enter text below to classify its toxicity.")

    # User input
    user_input = st.text_area("Enter text here:", "", key='user_input')

    # Classification
    if st.button("Classify"):
        if user_input.strip():
            labels = classify_text(user_input)
            # Print classification results — one line per label, same
            # column order as results_df.
            st.write("Classification Results:")
            label_names = ['Toxic', 'Severe Toxic', 'Obscene',
                           'Threat', 'Insult', 'Identity Hate']
            for idx, name in enumerate(label_names):
                st.write("{}: {:.2%}".format(name, labels[0][idx]))
            # Add results to the module-level DataFrame.
            # NOTE(review): Streamlit reruns reset results_df each
            # interaction; move it into st.session_state for true
            # persistence.
            results_df.loc[len(results_df)] = (
                [user_input] + [labels[0][idx] for idx in range(6)]
            )
        else:
            # Guard against classifying empty input (the previous version
            # of this app warned here; the rewrite dropped the check).
            st.warning("Please enter some text.")

    # Show results DataFrame
    st.write("Classification Results DataFrame:")
    st.write(results_df)
53
 
54
# Run the app when the script is executed directly
# (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    app()