ynp3 committed on
Commit
89c7b20
·
1 Parent(s): d316bf2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ from transformers import BertTokenizer, BertForSequenceClassification
5
+
6
# Load the tokenizer and classifier exactly once per process.
# BUG FIX: Streamlit re-executes this whole script on every widget event, so
# the original module-level load re-read bert-base-uncased and model.pt on
# every interaction. st.cache_resource memoizes the heavyweight objects.
@st.cache_resource
def _load_model():
    """Return ``(tokenizer, model)`` with fine-tuned weights, in eval mode."""
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=6
    )
    model_path = "model.pt"  # fine-tuned checkpoint, resolved against the CWD
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()  # disable dropout for deterministic inference
    return tokenizer, model


# Keep the original module-level names so the rest of the file is unchanged.
tokenizer, model = _load_model()
14
+
15
# Accumulator for classified examples: one row per "Add to Results" click.
# NOTE(review): module-level state is rebuilt on every Streamlit rerun, so
# rows added here may not survive across interactions — consider moving this
# into st.session_state; verify against the deployed Streamlit version.
_RESULT_COLUMNS = (
    'Text',
    'Toxic',
    'Obscene',
    'Insult',
    'Identity_Hate',
    'Threat',
    'Severe_Toxic',
)
results_df = pd.DataFrame(columns=list(_RESULT_COLUMNS))
17
+
18
def classify_text(text):
    """Classify *text* with the module-level BERT model.

    Parameters
    ----------
    text : str
        Raw input text; truncated to BERT's 512-token limit.

    Returns
    -------
    list[float]
        Six independent sigmoid probabilities, one per label (multi-label
        output — the values do not sum to 1).
    """
    # BUG FIX: the original call had no truncation, so any input longer than
    # 512 tokens overflows BERT's position embeddings and raises at inference.
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    with torch.no_grad():  # inference only — no autograd graph needed
        outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
    # Independent sigmoids (not softmax): each label is scored on its own.
    probabilities = torch.sigmoid(outputs.logits)
    return probabilities[0].tolist()
28
+
29
def main():
    """Run the Streamlit toxicity-classification UI.

    Renders a text box; "Classify" scores the text with the BERT model and
    shows six per-label probabilities. The latest scores can be appended to
    the module-level ``results_df`` and displayed.
    """
    st.title("Toxicity Classification")
    text = st.text_area("Enter text to classify", "")

    # NOTE(review): this index→label order is taken verbatim from the original
    # code; confirm it matches the label order used when the model was trained.
    labels = ['Toxic', 'Obscene', 'Insult', 'Identity_Hate', 'Threat', 'Severe_Toxic']

    if st.button("Classify"):
        probabilities = classify_text(text)
        # Persist the latest result: each button click reruns the whole
        # script, so anything needed by the buttons below must survive reruns.
        st.session_state['last_text'] = text
        st.session_state['last_result'] = dict(zip(labels, probabilities))

    if 'last_result' in st.session_state:
        for label, prob in st.session_state['last_result'].items():
            st.write("{}: {:.2f}%".format(label.replace('_', ' '), prob * 100))

        # BUG FIX: these buttons were nested inside the "Classify" branch, so
        # they could never fire — clicking them reruns the script and the
        # outer st.button("Classify") reads False again, skipping this code.
        if st.button("Add to Results"):
            global results_df
            row = {'Text': st.session_state['last_text'],
                   **st.session_state['last_result']}
            # BUG FIX: DataFrame.append was deprecated in pandas 1.4 and
            # removed in 2.0; pd.concat is the supported replacement.
            results_df = pd.concat([results_df, pd.DataFrame([row])],
                                   ignore_index=True)
            st.success("Text added to Results DataFrame!")

        if st.button("View Results"):
            st.write(results_df)
57
+
58
# Script entry point: only launch the UI when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()