Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| import transformers | |
| import torch | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
# Page setup comes first: Streamlit requires set_page_config to be the first
# Streamlit command of a script run, so it must precede any st.error() that
# the model-loading step below may emit.
st.set_page_config(layout="wide")
st.title('Toxicity Classification App')


@st.cache_resource
def _load_tokenizer_and_model():
    """Load the BERT tokenizer and the 6-label sequence classifier.

    Cached as a shared resource so the weights are loaded once rather than on
    every script rerun (Streamlit reruns the script on each widget
    interaction).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): 'bert-base-uncased' is the base checkpoint; presumably the
    # classification head is not fine-tuned for toxicity -- confirm which
    # checkpoint is intended.
    loaded_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    loaded_model = transformers.BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=6
    )
    return loaded_tokenizer, loaded_model


try:
    tokenizer, model = _load_tokenizer_and_model()
except Exception as e:
    # Top-level boundary: report the failure to the user and halt this run.
    # Without st.stop() the script would continue and fail later with a
    # NameError on the undefined tokenizer/model.
    st.error(f"Error loading the model: {e}")
    st.stop()

# Free-text field holding the text to classify.
text_input = st.text_input('Enter text to classify')
# Score column names, in the order of the classifier's output indices.
_LABELS = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']

# Run the classification when the button is pressed.
if st.button('Classify'):
    # Treat whitespace-only input the same as empty input.
    if not text_input or not text_input.strip():
        st.warning("Please enter text to classify.")
    else:
        # Tokenize the text into padded/truncated input IDs plus attention mask.
        encoded_text = tokenizer.encode_plus(
            text_input,
            max_length=512,
            padding='max_length',
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = model(
                input_ids=encoded_text['input_ids'],
                attention_mask=encoded_text['attention_mask'],
            )
        # output[0] holds the logits; softmax over the label dimension.
        # NOTE(review): softmax makes the six scores sum to 1. If the labels
        # are meant to be independent (multi-label), sigmoid may be the
        # intended activation -- confirm.
        probabilities = torch.nn.functional.softmax(output[0], dim=1).tolist()[0]

        # Display one score per label.
        for label, probability in zip(_LABELS, probabilities):
            st.write(f'{label}:', probability)

        # One-row table for this classification. Built directly because
        # DataFrame.append was removed in pandas 2.0.
        results_df = pd.DataFrame(
            [{'Text': text_input, **dict(zip(_LABELS, probabilities))}],
            columns=['Text', *_LABELS],
        )

        # Accumulate into the per-session results table. Skip concatenating
        # onto a missing/empty table so column dtypes come from real rows.
        previous_results = st.session_state.get('results')
        if previous_results is None or previous_results.empty:
            st.session_state['results'] = results_df
        else:
            st.session_state['results'] = pd.concat(
                [previous_results, results_df], ignore_index=True
            )
# Display the accumulated results for this session (empty table if none yet).
results = st.session_state.get('results', pd.DataFrame())
st.write('Classification Results:', results)

# Plot the distribution of scores for selected categories.
if not results.empty:
    for column in ('Toxic', 'Severe Toxic'):
        # st.pyplot needs a Figure; sns.histplot returns an Axes. Draw each
        # histogram on its own explicit figure so the plots do not overlap.
        fig, ax = plt.subplots()
        # A density estimate needs more than one observation.
        sns.histplot(data=results, x=column, kde=len(results) > 1, ax=ax)
        st.pyplot(fig)
        # Release the figure; pyplot otherwise keeps it alive across reruns.
        plt.close(fig)