Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
# Load pre-trained BERT model and tokenizer.
# NOTE(review): 'bert-base-uncased' is a pre-training checkpoint without a
# sequence-classification head, so the NUM_LABELS-way head created here is
# freshly initialized and its scores are not meaningful toxicity estimates.
# Point MODEL_NAME at a checkpoint fine-tuned on these labels — confirm.
MODEL_NAME = 'bert-base-uncased'
NUM_LABELS = 6


# Streamlit re-executes this whole script on every widget interaction; caching
# keeps one model/tokenizer pair per server process instead of reloading the
# weights on each rerun. show_spinner=False so no element is rendered before
# the page config is set.
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer(model_name=MODEL_NAME, num_labels=NUM_LABELS):
    """Load and return the ``(model, tokenizer)`` pair for *model_name*.

    Args:
        model_name: Hugging Face checkpoint name or local path.
        num_labels: Number of outputs of the sequence-classification head.

    Returns:
        A ``(BertForSequenceClassification, BertTokenizer)`` tuple.
    """
    loaded_model = BertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
    # Inference only: make the dropout-free mode explicit.
    loaded_model.eval()
    loaded_tokenizer = BertTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


model, tokenizer = load_model_and_tokenizer()
# Define function to predict toxicity using the pre-trained BERT model
def predict_toxicity(text):
    """Score *text* with the module-level ``model`` and ``tokenizer``.

    Args:
        text: The string to classify.

    Returns:
        A 1-D NumPy array with one sigmoid score in [0, 1] per classifier
        output of ``model``. Each label is scored independently, so the
        values do not sum to 1.
    """
    # Truncate to the tokenizer's maximum length: BERT's position embeddings
    # cap the sequence length, and longer inputs would make the model raise.
    # Calling the tokenizer also yields the attention mask and token type ids.
    inputs = tokenizer(text, truncation=True, return_tensors='pt')
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        logits = model(**inputs)[0]
    # Sigmoid (not softmax) because the labels are not mutually exclusive.
    # [0] drops the batch dimension of the single-example batch.
    return torch.sigmoid(logits)[0].cpu().numpy()
# Load existing DataFrame or create a new one
HISTORY_COLUMNS = [
    'text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate',
]


def load_history(path='toxicity_data.csv'):
    """Return the saved classification history, or an empty frame if there is none.

    Args:
        path: CSV file holding previously classified texts.

    Returns:
        The DataFrame read from *path*. If the file does not exist or is
        empty, a new empty DataFrame with ``HISTORY_COLUMNS`` is returned.

    Raises:
        Any other read error (malformed CSV, permission problem, ...)
        propagates. Silently starting a fresh history would let a later save
        to the same path overwrite the existing file.
    """
    try:
        return pd.read_csv(path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return pd.DataFrame(columns=HISTORY_COLUMNS)


df = load_history()
# Load sample submission DataFrame.
# Header-only stand-in used when the submission file is absent.
# NOTE(review): 'id' is a placeholder name for the leading column; only the
# six label columns after it carry meaning.
_FALLBACK_SUBMISSION_COLUMNS = [
    'id', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate',
]


def load_sample_submission(path='sample_submission.csv'):
    """Return the sample-submission DataFrame, or a header-only stand-in.

    Args:
        path: CSV file whose columns after the first are the label names.

    Returns:
        The DataFrame read from *path*. If the file does not exist, an empty
        DataFrame with an id column followed by the six label columns is
        returned, so the app can still start.
    """
    try:
        return pd.read_csv(path)
    except FileNotFoundError:
        return pd.DataFrame(columns=_FALLBACK_SUBMISSION_COLUMNS)


sample_df = load_sample_submission()
# Define app layout
st.set_page_config(page_title='Toxicity Classifier', page_icon='🤬')
st.title('Toxicity Classifier')
st.write('Enter some text to check its toxicity:')

# Label for each position of the score vector returned by predict_toxicity;
# these are also the per-label column names written to the history frame.
TOXICITY_LABELS = ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')
SAMPLE_INPUTS = ['I love coding', 'I hate coding', 'This is a great product!', 'Your service sucks.']

# The sample selectbox is drawn further down but stores its choice under
# st.session_state['sample_input']. Reading it before the text box exists lets
# a newly chosen sample pre-fill the box (a changed `value` re-creates the
# widget), so the chosen sample is what "Classify" sends to the model.
selected_sample = st.session_state.get('sample_input', SAMPLE_INPUTS[0])
text = st.text_input('Text input', value=selected_sample)

# Perform toxicity classification when user clicks the button
if st.button('Classify'):
    if not text.strip():
        st.warning('Please enter some text to classify.')
    else:
        toxicity_probs = predict_toxicity(text)
        # Pair each score with its label in a single place, so the displayed
        # and the stored label names always agree.
        scores = dict(zip(TOXICITY_LABELS, toxicity_probs))
        for label, prob in scores.items():
            st.write(f'The {label} probability of "{text}" is {prob:.2f}.')
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
        new_row = pd.DataFrame([{'text': text, **scores}])
        df = pd.concat([df, new_row], ignore_index=True)
        # Save the DataFrame so the next load of 'toxicity_data.csv' sees it.
        df.to_csv('toxicity_data.csv', index=False)

# Offer sample inputs; the choice pre-fills the text box on the next rerun.
st.selectbox('Or select a sample input:', SAMPLE_INPUTS, key='sample_input')

# Show the current DataFrame of classified texts
st.write('Classification history:')
st.dataframe(df)