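"""Streamlit app for the Toxic Comment Classifier.

Loads a BERT-based multi-label classifier and scores input text across
six toxicity categories, rendering the results as gauge charts.

Run from the project root (assuming this file is saved as app.py, so
that `src.models` is importable):

    streamlit run app.py
"""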

import os
from typing import Dict

import plotly.graph_objects as go
import streamlit as st
import torch
from transformers import AutoTokenizer

from src.models.toxic_classifier import ToxicClassifier
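
# Note: ToxicClassifier is assumed to follow the usual torch.nn.Module
# pattern, with forward(input_ids, attention_mask) returning raw logits
# of shape (batch, 6), one per category listed below. This is inferred
# from how the model is called in predict(), not from its definition.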


class ToxicPredictor:
    """Wraps the trained model and tokenizer for single-text scoring."""

    def __init__(self, model_path: str):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = ToxicClassifier().to(self.device)

        try:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, guarding against arbitrary code in the checkpoint.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)

            # Training checkpoints may wrap the weights in a dict under
            # 'model_state_dict'; accept either that or a bare state dict.
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
            if missing_keys:
                st.warning(f"Missing keys in state dict: {missing_keys}")
            if unexpected_keys:
                st.warning(f"Unexpected keys in state dict: {unexpected_keys}")

            self.model.eval()
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            raise

        self.categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    def predict(self, text: str) -> Dict[str, float]:
        """Predict toxicity scores for a single text."""
        try:
            encoding = self.tokenizer(
                text,
                add_special_tokens=True,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            input_ids = encoding['input_ids'].to(self.device)
            attention_mask = encoding['attention_mask'].to(self.device)

            # Multi-label task: a sigmoid per category rather than a softmax
            # across categories, since the labels are not mutually exclusive.
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask)
                probabilities = torch.sigmoid(outputs).cpu().numpy()[0]

            return {
                category: float(prob)
                for category, prob in zip(self.categories, probabilities)
            }
        except Exception as e:
            st.error(f"Error during prediction: {str(e)}")
            raise


def create_gauge_chart(value: float, title: str) -> go.Figure:
    """Create a gauge chart for a toxicity score in [0, 1]."""
    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=value * 100,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 33], 'color': "lightgreen"},
                {'range': [33, 66], 'color': "yellow"},
                {'range': [66, 100], 'color': "red"}
            ],
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': 50
            }
        }
    ))
    fig.update_layout(height=200)
    return fig


def main():
    st.set_page_config(
        page_title="Toxic Comment Classifier",
        page_icon="🔍",
        layout="wide"
    )

    st.title("💬 Toxic Comment Classifier")
    # Markdown bodies are kept flush-left: st.markdown does not dedent,
    # and indented lines would render as Markdown code blocks.
    st.markdown("""
This app uses a BERT-based model to detect toxic comments.
Enter your text below to analyze it for different types of toxicity.
""")

    model_path = os.path.join("models", "saved", "best_model.pt")
    if not os.path.exists(model_path):
        st.error("Model file not found! Please train the model first.")
        return

    try:
        # st.cache_resource keeps a single predictor across reruns, so the
        # model is loaded once rather than on every widget interaction.
        @st.cache_resource(show_spinner=False)
        def load_predictor():
            with st.spinner("Loading model..."):
                return ToxicPredictor(model_path)

        predictor = load_predictor()

        text = st.text_area(
            "Enter text to analyze:",
            height=100,
            placeholder="Type or paste your text here..."
        )

        if st.button("Analyze", type="primary"):
            if not text.strip():
                st.warning("Please enter some text to analyze.")
                return

            with st.spinner("Analyzing text..."):
                try:
                    predictions = predictor.predict(text)

                    st.markdown("### Analysis Results")

                    # One gauge per category, two per column.
                    col1, col2, col3 = st.columns(3)
                    with col1:
                        st.plotly_chart(create_gauge_chart(predictions['toxic'], "Toxic"), use_container_width=True)
                        st.plotly_chart(create_gauge_chart(predictions['obscene'], "Obscene"), use_container_width=True)
                    with col2:
                        st.plotly_chart(create_gauge_chart(predictions['severe_toxic'], "Severe Toxic"), use_container_width=True)
                        st.plotly_chart(create_gauge_chart(predictions['threat'], "Threat"), use_container_width=True)
                    with col3:
                        st.plotly_chart(create_gauge_chart(predictions['insult'], "Insult"), use_container_width=True)
                        st.plotly_chart(create_gauge_chart(predictions['identity_hate'], "Identity Hate"), use_container_width=True)

                    st.markdown("### Overall Assessment")
                    max_category, max_toxicity = max(predictions.items(), key=lambda item: item[1])

                    if max_toxicity > 0.5:
                        st.error(f"⚠️ This text may be toxic (highest score: {max_toxicity:.2%} for {max_category})")
                    else:
                        st.success(f"✅ This text appears to be non-toxic (highest score: {max_toxicity:.2%})")
                except Exception as e:
                    st.error(f"Error analyzing text: {str(e)}")

        with st.expander("ℹ️ About the Toxicity Categories"):
            st.markdown("""
The model analyzes text for six types of toxicity:

* **Toxic**: General category for unpleasant content
* **Severe Toxic**: Extreme cases of toxicity
* **Obscene**: Explicit or vulgar content
* **Threat**: Expressions of intent to harm
* **Insult**: Disrespectful or demeaning language
* **Identity Hate**: Prejudiced language against protected characteristics

Scores range from 0% to 100%, where higher scores indicate a stronger presence of that category.
""")

        st.markdown("---")
        st.markdown(
            "Built with ❤️ using Streamlit and BERT. "
            "Model trained on the Toxic Comment Classification Dataset."
        )
    except Exception as e:
        st.error(f"Application error: {str(e)}")


if __name__ == "__main__":
    main()