# Streamlit app that runs a "Base Shield" and an "Enhanced Shield" hate speech
# model side by side, either on a single piece of text or on an uploaded CSV.
import gc
import html
import re
import time

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from hatespeech_model import (
    get_rationale_from_mistral,
    load_model_from_hf,
    predict_hatespeech,
    predict_hatespeech_from_file,
    predict_hatespeech_from_file_batched,
    preprocess_rationale_mistral,
)

# Tokens that are never shown in the token-importance view.
SPECIAL_TOKENS = frozenset({'[CLS]', '[SEP]', '[PAD]'})
# Columns read from an uploaded CSV.
REQUIRED_COLUMNS = ['text', 'CF_Rationales', 'label']
CLASS_LABELS = ['Not Hate Speech', 'Hate Speech']
CLASS_COLORS = ['#66bb6a', '#ef5350']
# (button label, example text) pairs for the "Try Example Texts" section.
EXAMPLES = [
    ("Example: Hate Speech",
     "You people are worthless and should leave this country!"),
    ("Example: Not Hate Speech",
     "I disagree with your opinion, but I respect your right to express it."),
    ("Example: Borderline",
     "This policy is terrible and will hurt everyone involved."),
]

is_file_uploader_visible = False

# Page configuration
st.set_page_config(
    page_title="🛡️ Hate Speech Detector",
    page_icon="🛡️",
    layout="wide"
)


# Cached model loading function
@st.cache_resource
def load_cached_model(model_type="altered"):
    """Load one model variant and return its parts as a dict.

    The dict keys (``model``, ``tokenizer_hatebert``, ``tokenizer_rationale``,
    ``config``, ``device``) are the keyword names the prediction functions are
    called with in this script, so the dict can be splatted into those calls.
    """
    model, tokenizer_hatebert, tokenizer_rationale, config, device = load_model_from_hf(
        model_type=model_type
    )
    return {
        "model": model,
        "tokenizer_hatebert": tokenizer_hatebert,
        "tokenizer_rationale": tokenizer_rationale,
        "config": config,
        "device": device,
    }


def clean_user_input(text):
    """Remove URLs and special characters (except exclamation points) from text"""
    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Remove special characters except exclamation points
    text = re.sub(r'[^a-zA-Z0-9\s!]', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def timed_prediction(text, rationale, bundle, model_type):
    """Run ``predict_hatespeech`` with a model bundle; return (result, seconds)."""
    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the system clock is adjusted mid-run.
    started = time.perf_counter()
    result = predict_hatespeech(
        text=text,
        rationale=rationale,
        model_type=model_type,
        **bundle,
    )
    return result, time.perf_counter() - started


def render_prediction_summary(prediction, confidence, probabilities,
                              processing_time, show_chart):
    """Render the verdict banner, headline metrics and optional probability chart."""
    # Prediction box
    if prediction == 1:
        st.markdown('\n🚨 HATE SPEECH DETECTED\n', unsafe_allow_html=True)
    else:
        st.markdown('\n✅ NOT HATE SPEECH\n', unsafe_allow_html=True)

    # Metrics
    st.metric("Confidence", f"{confidence:.1%}")
    not_hate_col, hate_col = st.columns(2)
    with not_hate_col:
        st.metric("Not Hate Speech", f"{probabilities[0]:.1%}")
    with hate_col:
        st.metric("Hate Speech", f"{probabilities[1]:.1%}")
    st.metric("Processing Time", f"{processing_time:.3f}s")

    # Probability distribution chart
    if show_chart:
        st.markdown("**📊 Probability Distribution**")
        fig = go.Figure(data=[
            go.Bar(
                x=CLASS_LABELS,
                y=probabilities,
                marker_color=CLASS_COLORS,
                text=[f"{p:.1%}" for p in probabilities],
                textposition='auto',
            )
        ])
        fig.update_layout(
            yaxis_title="Probability",
            yaxis_range=[0, 1],
            height=300,
            showlegend=False
        )
        st.plotly_chart(fig, use_container_width=True)


def render_token_importance(tokens, scores, prediction):
    """Render highlighted tokens and a table of the ten highest-scoring tokens."""
    st.markdown("**🔍 Token Importance Analysis**")
    st.caption("Highlighted words show which parts influenced the prediction")

    token_importance = []
    highlighted = []
    for token, score in zip(tokens, scores):
        if token in SPECIAL_TOKENS:
            continue
        # Clean token
        display_token = token.replace('##', '')
        token_importance.append({'Token': display_token, 'Importance': score})

        # Color intensity based on score
        alpha = min(score * 1.5, 1.0)  # Scale up visibility
        color = f"rgba(239, 83, 80, {alpha:.2f})"
        # The markdown below is rendered with unsafe_allow_html=True, so the
        # token text is escaped before being placed inside the span.
        highlighted.append(
            f'<span style="background-color: {color}; padding: 2px 4px; '
            f'border-radius: 3px;">{html.escape(display_token)}</span> '
        )

    html_output = "\n" + "".join(highlighted) + "\n"
    st.markdown(html_output, unsafe_allow_html=True)

    # Dynamic caption based on prediction
    if prediction == 1:
        st.caption("🔴 Darker red = More influence on hate speech detection.")

    # Top important tokens
    st.markdown("**📋 Top Important Tokens**")
    # An empty list would give a DataFrame with no 'Importance' column and
    # sort_values would raise, so the table is skipped in that case.
    if token_importance:
        df_importance = pd.DataFrame(token_importance)
        df_importance = df_importance.sort_values('Importance', ascending=False).head(10)
        df_importance['Importance'] = df_importance['Importance'].apply(lambda x: f"{x:.4f}")
        st.dataframe(
            df_importance,
            use_container_width=True,
            hide_index=True
        )


def render_file_metrics(result, colorscale, n_rows):
    """Render classification metrics, confusion matrix, resources and runtime.

    ``n_rows`` is the number of samples analysed and must be greater than zero;
    it is used as the divisor for the per-sample time.
    """
    # Performance Metrics
    st.markdown("**📈 Classification Metrics**")
    left, right = st.columns(2)
    with left:
        st.metric("F1 Score", f"{result['f1_score']:.4f}")
        st.metric("Precision", f"{result['precision']:.4f}")
    with right:
        st.metric("Accuracy", f"{result['accuracy']:.4f}")
        st.metric("Recall", f"{result['recall']:.4f}")

    # Confusion Matrix Visualization
    st.markdown("**🎯 Confusion Matrix**")
    confusion = result['confusion_matrix']
    fig = go.Figure(data=go.Heatmap(
        z=confusion,
        x=['Pred Not Hate', 'Pred Hate'],
        y=['True Not Hate', 'True Hate'],
        colorscale=colorscale,
        text=confusion,
        texttemplate='%{text}',
        textfont={"size": 14},
        showscale=False
    ))
    fig.update_layout(height=300)
    st.plotly_chart(fig, use_container_width=True)

    # Resource Usage
    st.markdown("**⚙️ Resource Usage**")
    cpu_col, mem_col = st.columns(2)
    with cpu_col:
        st.metric("Avg CPU", f"{result['cpu_usage']:.2f}%")
        st.metric("Peak CPU", f"{result['peak_cpu_usage']:.2f}%")
    with mem_col:
        st.metric("Avg Memory", f"{result['memory_usage']:.2f} MB")
        st.metric("Peak Memory", f"{result['peak_memory_usage']:.2f} MB")

    # Runtime
    st.markdown("**⏱️ Performance**")
    st.metric("Total Runtime", f"{result['runtime']:.2f}s")
    st.metric("Avg Time/Sample", f"{result['runtime'] / n_rows:.3f}s")


def analyze_text(raw_text, base_bundle, enhanced_bundle,
                 show_probabilities, show_rationale_viz, show_details):
    """Clean the text, obtain a rationale, run both models and render results."""
    # Clean the input text
    cleaned_input = clean_user_input(raw_text)
    if not cleaned_input:
        # Cleaning can remove everything (e.g. input made only of symbols);
        # there is nothing meaningful to send to the rationale service or models.
        st.warning("⚠️ Nothing left to analyze after removing URLs and special characters.")
        return

    with st.spinner('🔄 Generating rationale from Mistral AI...'):
        # --- Step 1: Get rationale from Mistral ---
        try:
            raw_rationale = get_rationale_from_mistral(cleaned_input)
            cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
            print(f"Raw rationale from Mistral: {raw_rationale}")
        except Exception as e:
            st.error(f"❌ Error generating/processing rationale: {e}")
            cleaned_rationale = cleaned_input  # fallback to cleaned input

    with st.spinner('🔄 Analyzing text with models...'):
        enhanced_result, enhanced_time = timed_prediction(
            cleaned_input, cleaned_rationale, enhanced_bundle, "altered"
        )
        base_result, base_time = timed_prediction(
            cleaned_input, cleaned_rationale, base_bundle, "base"
        )
        gc.collect()  # Clean up memory after inference

    # Display results
    st.divider()
    st.header("📈 Analysis Results")

    # Side-by-side results columns
    base_col, enhanced_col = st.columns(2)

    # === BASE MODEL RESULTS (LEFT) ===
    with base_col:
        st.subheader("🔵 Base Shield Results")
        base_probabilities = base_result['probabilities']
        render_prediction_summary(
            base_result['prediction'],
            base_result['confidence'],
            base_probabilities,
            base_time,
            show_probabilities,
        )

        # Technical details for base
        if show_details:
            with st.expander("View Base Model Outputs"):
                st.json({
                    'prediction': int(base_result['prediction']),
                    'confidence': float(base_result['confidence']),
                    'probability_not_hate': float(base_probabilities[0]),
                    'probability_hate': float(base_probabilities[1]),
                    'device': 'cpu',
                    'model_config': {
                        'max_length': '128',
                    }
                })

    # === ENHANCED MODEL RESULTS (RIGHT) ===
    with enhanced_col:
        st.subheader("🟢 Enhanced Shield Results")
        enhanced_probabilities = enhanced_result['probabilities']
        enhanced_tokens = enhanced_result['tokens']
        render_prediction_summary(
            enhanced_result['prediction'],
            enhanced_result['confidence'],
            enhanced_probabilities,
            enhanced_time,
            show_probabilities,
        )

        # Token importance visualization (only for enhanced)
        if show_rationale_viz:
            render_token_importance(
                enhanced_tokens,
                enhanced_result['rationale_scores'],
                enhanced_result['prediction'],
            )

        # Technical details for enhanced
        if show_details:
            with st.expander("View Enhanced Model Outputs"):
                st.json({
                    'prediction': int(enhanced_result['prediction']),
                    'confidence': float(enhanced_result['confidence']),
                    'probability_not_hate': float(enhanced_probabilities[0]),
                    'probability_hate': float(enhanced_probabilities[1]),
                    'num_tokens': sum(1 for t in enhanced_tokens if t not in SPECIAL_TOKENS),
                    'device': 'cpu',
                    'model_config': {
                        'max_length': '128',
                        'cnn_filters': '128',
                    }
                })


def analyze_file(frame, base_bundle, enhanced_bundle):
    """Run both models over every row of ``frame`` and render the comparison."""
    n_rows = len(frame)
    if n_rows == 0:
        # Nothing to score, and the per-sample time would divide by zero.
        st.warning("⚠️ The uploaded file contains no rows to analyze.")
        return

    st.markdown("**Preview:**")
    st.dataframe(frame.head(3), use_container_width=True)

    texts = frame['text'].tolist()
    rationales = frame['CF_Rationales'].tolist()
    labels = frame['label'].tolist()

    with st.spinner('🔄 Analyzing file with both models... This may take a while for large files.'):
        # NOTE(review): the enhanced model uses the batched entry point while
        # the base model uses the non-batched one, so the reported runtimes
        # may not be directly comparable — confirm this is intended.
        enhanced_result = predict_hatespeech_from_file_batched(
            text_list=texts,
            rationale_list=rationales,
            true_label=labels,
            model_type="altered",
            **enhanced_bundle,
        )
        base_result = predict_hatespeech_from_file(
            text_list=texts,
            rationale_list=rationales,
            true_label=labels,
            model_type="base",
            **base_bundle,
        )

    st.success("✅ File analysis complete for both models!")
    st.divider()
    st.header("📊 Analysis Results - Model Comparison")
    gc.collect()  # Clean up memory after file inference

    # Side-by-side results columns
    base_file_col, enhanced_file_col = st.columns(2)

    # === BASE MODEL FILE RESULTS (LEFT) ===
    with base_file_col:
        st.subheader("🔵 Base Shield Results")
        render_file_metrics(base_result, 'Blues', n_rows)

    # === ENHANCED MODEL FILE RESULTS (RIGHT) ===
    with enhanced_file_col:
        st.subheader("🟢 Enhanced Shield Results")
        render_file_metrics(enhanced_result, 'Greens', n_rows)


# Custom CSS
# NOTE(review): this style block is empty and the banner strings in this file
# carry no markup even though they are rendered with unsafe_allow_html=True;
# presumably the original HTML/CSS was lost — restore it from the upstream copy.
st.markdown(""" """, unsafe_allow_html=True)

# Header
st.markdown('\n🛡️ Hate Speech Detection System\n', unsafe_allow_html=True)
st.markdown(
    '\nComparing Base vs Enhanced models with explainable AI for detecting hate speech\n',
    unsafe_allow_html=True,
)

# Load both models with spinner
with st.spinner('🔄 Loading models... This may take a moment on first run.'):
    try:
        base_data = load_cached_model("base")
        enhanced_data = load_cached_model("altered")
        st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
    except Exception as e:
        st.error(f"❌ Error loading models: {e}")
        st.stop()

# Sidebar
with st.sidebar:
    st.header("⚙️ Settings")
    st.divider()
    st.markdown("**Device:** CPU")
    st.markdown("**Max Length:** 128")
    st.markdown("**CNN Filters:** 128")
    st.divider()
    st.subheader("🔍 File Upload")
    is_file_uploader_visible = st.checkbox(
        "Enable File Upload", value=is_file_uploader_visible
    )
    st.divider()
    show_rationale_viz = st.checkbox("Show Token Importance", value=True)
    show_probabilities = st.checkbox("Show Probability Distribution", value=True)
    show_details = st.checkbox("Show Technical Details", value=False)
    st.divider()
    st.subheader("💡 About")
    st.markdown("""
    This model uses:
    - **HateBERT** for hate speech understanding
    - **Multi-Scale CNN** for feature extraction
    - **Attention mechanisms** for interpretability
    """)

# Main interface
# These are bound up front so every later branch can test them safely, even
# when the uploader is hidden or the CSV could not be parsed.
user_input = None
uploaded_file = None
file_content = None

col1, col2 = st.columns([2, 1])

with col1:
    if is_file_uploader_visible:
        st.subheader("📂 Upload File")
        uploaded_file = st.file_uploader(
            "Choose a text file (.csv) to analyze:",
            type=["csv"],
            help="Upload a text file containing the content you want to analyze for hate speech"
        )
        if uploaded_file is not None:
            try:
                file_content = pd.read_csv(uploaded_file, usecols=REQUIRED_COLUMNS)
                st.success("✅ File loaded successfully! Scroll down to analyze.")
            except Exception as e:
                # file_content stays None, so the file is treated as absent below.
                st.error(f"❌ Error reading file: {e}")
    else:
        st.subheader("📝 Input Text/File")
        user_input = st.text_area(
            "Enter text to analyze:",
            placeholder="Type or paste text here to check for hate speech...",
            height=150,
            help="Enter any text and the model will classify it as hate speech or not"
        )
        # NOTE(review): this value is collected but not passed to either model;
        # the rationale used for prediction comes from get_rationale_from_mistral.
        optional_rationale = st.text_area(
            "Optional: Provide context or rationale (leave empty to use main text):",
            placeholder="Why might this be hate speech? (optional)",
            height=80
        )

with col2:
    st.subheader("📊 Quick Stats")
    if user_input:
        st.metric("Words", len(user_input.split()))
        st.metric("Characters", len(user_input))
    elif file_content is not None:
        st.markdown(f"**Filename:** {uploaded_file.name}")
        st.markdown(f"**Size:** {uploaded_file.size / 1024:.2f} KB")
        st.metric("Rows in File", len(file_content))
    else:
        st.info("Enter text/file to see statistics")

# Classification button
classify_button = st.button("🔍 Analyze Text", type="primary", use_container_width=True)

if classify_button:
    # Exactly one branch runs, so the warning only appears when there is
    # neither usable text nor a successfully parsed file.
    if user_input and user_input.strip():
        analyze_text(
            user_input,
            base_data,
            enhanced_data,
            show_probabilities,
            show_rationale_viz,
            show_details,
        )
    elif file_content is not None:
        analyze_file(file_content, base_data, enhanced_data)
    else:
        st.warning("⚠️ Please enter some text to analyze.")

# Examples section
st.divider()
st.subheader("💡 Try Example Texts")

for example_col, (button_label, example_text) in zip(st.columns(3), EXAMPLES):
    with example_col:
        if st.button(button_label, use_container_width=True):
            st.session_state.example_text = example_text
            st.rerun()

if 'example_text' in st.session_state:
    st.info(f"**Example loaded:** {st.session_state.example_text}")
    st.caption("↑ Copy this text to the input box above and click 'Analyze Text'")

# Footer
st.divider()
st.markdown("""

Hate Speech Detection Model Comparison

Base Shield vs Enhanced Shield (HateBERT + Multi-Scale CNN + Attention)

Side-by-side comparison for performance evaluation

""", unsafe_allow_html=True)