ynp3 commited on
Commit
222e77a
·
1 Parent(s): 964ca92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -36
app.py CHANGED
@@ -1,53 +1,42 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import transformers
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
  import torch
6
 
7
- # Load pre-trained BERT model and tokenizer
8
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
9
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
10
  model.eval()
11
 
12
  # Create a persistent DataFrame to store classification results
13
- results_df = pd.DataFrame(columns=['Text', 'Toxicity'])
14
 
15
  def classify_text(text):
16
- # Tokenize input text
17
- inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
18
- input_ids = inputs['input_ids']
19
- attention_mask = inputs['attention_mask']
20
 
21
- # Perform inference with BERT model
22
- with torch.no_grad():
23
- outputs = model(input_ids, attention_mask=attention_mask)
24
- logits = outputs.logits
25
- probabilities = torch.softmax(logits, dim=1)
26
- toxicity_score = probabilities[0][1].item() # Extract toxicity score
27
-
28
- return toxicity_score
29
-
30
- def add_to_results(text, toxicity):
31
- global results_df
32
- results_df = results_df.append({'Text': text, 'Toxicity': toxicity}, ignore_index=True)
33
 
34
  # Streamlit app
35
  def main():
36
- st.title('Toxicity Classification App')
37
-
38
- # Input text box for user to enter text
39
- user_text = st.text_area('Enter text:', '')
40
-
41
- # Button to classify text
42
- if st.button('Classify'):
43
  if user_text:
44
- toxicity_score = classify_text(user_text)
45
- st.write('Toxicity Score:', toxicity_score)
46
- add_to_results(user_text, toxicity_score)
47
-
48
- # Display classification results
49
- st.header('Classification Results')
50
- st.dataframe(results_df)
51
-
52
- if _name_ == '_main_':
53
- main()
 
 
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  from transformers import BertTokenizer, BertForSequenceClassification
4
  import torch
5
 
6
+ # Load pre-trained BERT model
7
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
8
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
9
  model.eval()
10
 
11
  # Create a persistent DataFrame to store classification results
12
+ classified_data = pd.DataFrame(columns=['Text', 'Toxicity'])
13
 
14
  def classify_text(text):
15
+ # Tokenize and encode input text
16
+ input_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True)).unsqueeze(0)
 
 
17
 
18
+ # Forward pass through BERT model
19
+ outputs = model(input_ids)
20
+ logits = outputs.logits
21
+ predicted_class = torch.argmax(logits, dim=1).item()
22
+ toxicity = "Toxic" if predicted_class == 1 else "Non-Toxic"
23
+ return toxicity
 
 
 
 
 
 
24
 
25
  # Streamlit app
26
  def main():
27
+ st.title("Toxicity Classifier")
28
+ user_text = st.text_area("Enter text to classify:")
29
+ if st.button("Classify"):
 
 
 
 
30
  if user_text:
31
+ toxicity = classify_text(user_text)
32
+ st.write(f"Predicted Toxicity: {toxicity}")
33
+ # Add classification results to the persistent DataFrame
34
+ global classified_data
35
+ classified_data = classified_data.append({'Text': user_text, 'Toxicity': toxicity}, ignore_index=True)
36
+ else:
37
+ st.warning("Please enter some text.")
38
+ if st.button("View Classified Data"):
39
+ st.write(classified_data)
40
+
41
+ if __name__ == "__main__":
42
+ main()