ynp3 committed on
Commit
0b7e7ec
·
1 Parent(s): 296cbab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import transformers
4
+ from transformers import BertTokenizer, BertForSequenceClassification
5
+ import torch
6
+
7
# Checkpoint shared by the tokenizer and the classifier so the two cannot
# drift apart.
MODEL_NAME = 'bert-base-uncased'

# st.cache_resource is only available on newer Streamlit releases; fall back
# to an uncached loader so the app still starts on older ones.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_tokenizer_and_model():
    """Load the BERT tokenizer and sequence classifier in inference mode.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids reloading the weights on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has ``eval()`` applied.
    """
    loaded_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    loaded_model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
    loaded_model.eval()  # inference only: disables dropout
    return loaded_tokenizer, loaded_model


# NOTE(review): 'bert-base-uncased' is a base checkpoint, so the sequence
# classification head is presumably newly initialised rather than trained for
# toxicity -- confirm, and swap in a fine-tuned checkpoint if so.
tokenizer, model = _load_tokenizer_and_model()

# In-memory table of classification results. Streamlit re-executes this script
# on every interaction, so this module-level object is re-created on each
# rerun unless it is restored from st.session_state.
results_df = pd.DataFrame(columns=['Text', 'Toxicity'])
14
+
15
def classify_text(text):
    """Score *text* with the module-level BERT classifier.

    The text is tokenized (with truncation) into PyTorch tensors, run through
    ``model`` without gradient tracking, and the logits are converted to
    probabilities with a softmax over the class dimension.

    Returns:
        The probability of class index 1 for the first (only) input, as a
        Python float. The calling code treats this as the toxicity score.
    """
    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

    # Gradients are not needed for inference.
    with torch.no_grad():
        result = model(
            encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )
        class_probs = torch.softmax(result.logits, dim=1)

    # Row 0 is the single input text; column 1 is the class reported as toxic.
    return class_probs[0][1].item()
29
+
30
def add_to_results(text, toxicity):
    """Append one classification result to the module-level ``results_df``.

    ``DataFrame.append`` was removed in pandas 2.0, so the new row is built as
    a one-row frame and concatenated instead.

    Args:
        text: The text that was classified.
        toxicity: The score computed for ``text``.
    """
    global results_df
    new_row = pd.DataFrame([{'Text': text, 'Toxicity': toxicity}])
    if results_df.empty:
        # Concatenating with an empty frame triggers a pandas FutureWarning
        # about dtype inference; the first row simply becomes the table.
        results_df = new_row
    else:
        results_df = pd.concat([results_df, new_row], ignore_index=True)
33
+
34
+ # Streamlit app
35
def main():
    """Render the toxicity classification page.

    Shows a text area and a 'Classify' button. When the button is pressed with
    non-blank text, the text is scored with ``classify_text``, the score is
    displayed, and the result is recorded. The table of recorded results is
    always shown below.
    """
    global results_df

    # Streamlit re-executes the whole script on every interaction, which
    # resets module-level state; restore the history kept for this session.
    if 'results_df' in st.session_state:
        results_df = st.session_state['results_df']

    st.title('Toxicity Classification App')

    # Input text box for the user to enter text
    user_text = st.text_area('Enter text:', '')

    # Button to classify text
    if st.button('Classify'):
        if user_text.strip():
            toxicity_score = classify_text(user_text)
            st.write('Toxicity Score:', toxicity_score)
            add_to_results(user_text, toxicity_score)
            # Persist the updated table so it survives the next rerun.
            st.session_state['results_df'] = results_df
        else:
            # Tell the user why nothing happened instead of silently ignoring
            # an empty or whitespace-only submission.
            st.warning('Please enter some text to classify.')

    # Display classification results
    st.header('Classification Results')
    st.dataframe(results_df)


if __name__ == '__main__':
    main()