Keshavp08 commited on
Commit
b49b9ab
·
1 Parent(s): c75cf43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  from transformers import BertForSequenceClassification, BertTokenizer
6
 
7
  # Load pre-trained BERT model and tokenizer
8
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
9
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
 
11
  # Define function to predict toxicity using the pre-trained BERT model
@@ -16,19 +16,22 @@ def predict_toxicity(text):
16
  input_tensor = torch.tensor([input_ids])
17
  # Get model prediction
18
  outputs = model(input_tensor)[0]
19
- # Apply softmax activation function to get probability distribution
20
- probs = torch.softmax(outputs, dim=1).detach().numpy()[0]
21
- # Return probability of being toxic
22
- return probs[1]
23
 
24
  # Load existing DataFrame or create a new one
25
  try:
26
  df = pd.read_csv('toxicity_data.csv')
27
  except:
28
- df = pd.DataFrame(columns=['text', 'toxicity'])
 
 
 
29
 
30
  # Define app layout
31
- st.set_page_config(page_title='Classifier', page_icon='🤬')
32
  st.title('Toxicity Classifier')
33
  st.write('Enter some text to check its toxicity:')
34
 
@@ -38,11 +41,12 @@ text = st.text_input('Text input', value='I love coding')
38
  # Perform toxicity classification when user clicks the button
39
  if st.button('Classify'):
40
  # Predict toxicity of the input text
41
- toxicity_prob = predict_toxicity(text)
42
  # Display the result
43
- st.write(f'The toxicity probability of "{text}" is {toxicity_prob:.2f}.')
 
44
  # Add the result to the DataFrame
45
- df = df.append({'text': text, 'toxicity': toxicity_prob}, ignore_index=True)
46
  # Save the DataFrame to a CSV file
47
  df.to_csv('toxicity_data.csv', index=False)
48
  else:
@@ -53,4 +57,4 @@ else:
53
 
54
  # Show the current DataFrame of classified texts
55
  st.write('Classification history:')
56
- st.dataframe(df)
 
5
  from transformers import BertForSequenceClassification, BertTokenizer
6
 
7
  # Load pre-trained BERT model and tokenizer
8
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
9
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
 
11
  # Define function to predict toxicity using the pre-trained BERT model
 
16
  input_tensor = torch.tensor([input_ids])
17
  # Get model prediction
18
  outputs = model(input_tensor)[0]
19
+ # Apply sigmoid activation function to get probability distribution
20
+ probs = torch.sigmoid(outputs).detach().numpy()[0]
21
+ # Return probability of being toxic for each category
22
+ return probs
23
 
24
  # Load existing DataFrame or create a new one
25
  try:
26
  df = pd.read_csv('toxicity_data.csv')
27
  except:
28
+ df = pd.DataFrame(columns=['text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'])
29
+
30
+ # Load sample submission DataFrame
31
+ sample_df = pd.read_csv('sample_submission.csv')
32
 
33
  # Define app layout
34
+ st.set_page_config(page_title='Toxicity Classifier', page_icon='🤬')
35
  st.title('Toxicity Classifier')
36
  st.write('Enter some text to check its toxicity:')
37
 
 
41
  # Perform toxicity classification when user clicks the button
42
  if st.button('Classify'):
43
  # Predict toxicity of the input text
44
+ toxicity_probs = predict_toxicity(text)
45
  # Display the result
46
+ for i, col in enumerate(sample_df.columns[1:]):
47
+ st.write(f'The {col} probability of "{text}" is {toxicity_probs[i]:.2f}.')
48
  # Add the result to the DataFrame
49
+ df = df.append({'text': text, 'toxic': toxicity_probs[0], 'severe_toxic': toxicity_probs[1], 'obscene': toxicity_probs[2], 'threat': toxicity_probs[3], 'insult': toxicity_probs[4], 'identity_hate': toxicity_probs[5]}, ignore_index=True)
50
  # Save the DataFrame to a CSV file
51
  df.to_csv('toxicity_data.csv', index=False)
52
  else:
 
57
 
58
  # Show the current DataFrame of classified texts
59
  st.write('Classification history:')
60
+ st.dataframe(df)