ynp3 commited on
Commit
0548248
·
1 Parent(s): bbdda78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -90
app.py CHANGED
@@ -1,94 +1,56 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import numpy as np
4
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, BertForSequenceClassification, DistilBertModel
5
  import torch
6
- from torch import cuda
7
- from torch.utils.data import Dataset, DataLoader
8
- import finetuning
9
- from finetuning import CustomDistilBertClass
10
 
11
- # device = 'cuda' if cuda.is_available() else 'cpu'
12
- # Load pretrained models
13
- model_map = {
14
- 'BERT': 'bert-base-uncased',
15
- 'RoBERTa': 'roberta-base',
16
- 'DistilBERT': 'distilbert-base-uncased'
17
- }
18
-
19
- # Load dropdown options
20
- model_options = list(model_map.keys())
21
-
22
- # Load dataset
23
- train_df = pd.read_csv('train.csv')
24
- train_df = train_df.sample(n=256)
25
- label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
26
-
27
- @st.cache_resource
28
- def load_model(model_name):
29
- """Load pretrained BERT model."""
30
- path = "finetuned_model.pt"
31
- model = torch.load(path)
32
- tokenizer = AutoTokenizer.from_pretrained(model_map[model_name])
33
- return model, tokenizer
34
-
35
- def classify_text(model, tokenizer, text):
36
- """Classify text using pretrained BERT model."""
37
- inputs = tokenizer.encode_plus(
38
- text,
39
- add_special_tokens=True,
40
- max_length=512,
41
- padding='max_length',
42
- return_tensors='pt',
43
- truncation=True
44
- )
45
- print(inputs)
46
- with torch.no_grad():
47
- logits = model(inputs['input_ids'],inputs['attention_mask'])[0]
48
- probabilities = torch.softmax(logits, dim=1)[0]
49
- pred_class = torch.argmax(probabilities, dim=0)
50
- print(f"pred class: {pred_class}")
51
- print(probabilities[0].tolist())
52
- return label_cols[pred_class], round(probabilities[0].tolist(),2)
53
-
54
- # Set up streamlit app
55
- st.title('Toxic Comment Classifier')
56
- model_name = st.sidebar.selectbox('Select a model', model_options)
57
- st.sidebar.write('Selected:', model_name)
58
- model, tokenizer = load_model(model_name)
59
- print(type(model))
60
-
61
- # Define input text area
62
- st.subheader('Enter comment below:')
63
- text_input = st.text_area(label='', height=100, max_chars=500)
64
-
65
- # Make prediction when user clicks 'Classify' button
66
- if st.button('Classify Toxicity'):
67
- if not text_input:
68
- st.write('Please enter comment')
69
- else:
70
- class_label, class_prob = classify_text(model, tokenizer, text_input)
71
- st.subheader('Results')
72
- st.write('Tweet:', text_input)
73
- st.write('Highest Toxicity Class:', class_label)
74
- st.write('Probability:', class_prob)
75
-
76
- # Display table of results
77
- st.subheader('Toxic Classification Results')
78
- if 'classification_results' not in st.session_state:
79
- st.session_state.classification_results = pd.DataFrame(columns=['tweet', 'toxicity_class', 'probability'])
80
- if st.button('Add to Results'):
81
- if not text_input:
82
- st.write('Please enter comment')
83
- else:
84
- class_label, class_prob = classify_text(model, tokenizer, text_input)
85
- st.subheader('Results')
86
- st.write('Tweet:', text_input)
87
- st.write('Highest Toxicity Class:', class_label)
88
- st.write('Probability:', class_prob)
89
- st.session_state.classification_results = st.session_state.classification_results.append({
90
- 'tweet': text_input,
91
- 'toxicity_class': class_label,
92
- 'probability': class_prob
93
- }, ignore_index=True)
94
- st.write(st.session_state.classification_results)
 
1
  import streamlit as st
2
  import pandas as pd
3
+ from transformers import BertTokenizer, BertForSequenceClassification
 
4
  import torch
 
 
 
 
5
 
6
# Load the pre-trained BERT tokenizer/model once per process.
# @st.cache_resource keeps the expensive weight download/load out of every
# Streamlit rerun — without it the model is reloaded on each user interaction.
@st.cache_resource
def _load_model():
    """Return (tokenizer, model) for 'bert-base-uncased' set up for 6-label multi-label output."""
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    # num_labels=6 matches the six toxicity columns the app displays; the
    # default config has only 2 labels, so indexing preds[5] would crash.
    # NOTE(review): this classification head is untrained (random weights) —
    # presumably a fine-tuned toxicity checkpoint is intended; confirm and
    # load it here if one exists.
    mdl = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=6,
        problem_type='multi_label_classification',
    )
    mdl.eval()  # inference mode: disables dropout
    return tok, mdl

tokenizer, model = _load_model()

# DataFrame accumulating one row per classified text.
# Stored in st.session_state so results survive Streamlit's script reruns
# (a plain module-level DataFrame would be recreated empty on every rerun).
# The module-level name aliases the same object, so in-place .loc writes
# made elsewhere in this file persist across reruns.
if 'classification_results_df' not in st.session_state:
    st.session_state.classification_results_df = pd.DataFrame(
        columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']
    )
classification_results_df = st.session_state.classification_results_df
13
+
14
def classify_text(text):
    """Classify *text* into the six toxicity categories.

    Args:
        text: Raw comment string to classify.

    Returns:
        list[int]: Six 0/1 flags, one per label, in the order
        [Toxic, Severe Toxic, Obscene, Threat, Insult, Identity Hate].
    """
    # Tokenize (truncated/padded) into model-ready tensors.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # no_grad: pure inference — skip autograd graph construction and its memory cost.
    with torch.no_grad():
        outputs = model(**inputs)
    # Independent per-label probabilities (multi-label), thresholded at 0.5.
    probs = torch.sigmoid(outputs.logits)
    preds = (probs > 0.5).int().tolist()[0]
    return preds
24
+
25
def add_classification_to_df(text, preds):
    """Append one result row (the text plus its six 0/1 label flags) to the shared results table."""
    next_row = len(classification_results_df)
    classification_results_df.loc[next_row] = [text, *preds]
28
+
29
# Streamlit app entry point: renders the UI and drives classification.
def main():
    """Render the toxicity-classification UI and handle user actions."""
    st.title("Toxicity Classification with BERT")
    # Free-text input from the user.
    text = st.text_area("Enter text for classification", "")
    if st.button("Classify"):
        if not text.strip():
            # Guard: whitespace-only input gets a warning, no model call.
            st.warning("Please enter some text for classification.")
        else:
            preds = classify_text(text)
            st.subheader("Classification Results:")
            # One line per label, in the same order classify_text returns them.
            labels = ("Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate")
            for label, flag in zip(labels, preds):
                st.write(f"{label}: ", flag)
            # Record this result in the shared results table.
            add_classification_to_df(text, preds)
    if st.button("View Classification Results"):
        st.subheader("All Classification Results:")
        st.write(classification_results_df)

if __name__ == '__main__':
    main()