""" Nepali Hate Speech Detection - Streamlit Application ===================================================== Complete application with preprocessing, prediction, and explainability (LIME/SHAP/Captum) Run with: streamlit run main_app.py """ import os import sys import streamlit as st import pandas as pd import numpy as np import torch import plotly.graph_objects as go import plotly.express as px from datetime import datetime import json import warnings warnings.filterwarnings('ignore') # Matplotlib for Nepali font support import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties, fontManager # ============================================================================ # HF SPACES COMPATIBILITY — paths and environment # ============================================================================ # Detect if running on HF Spaces IS_HF_SPACES = bool(os.environ.get('SPACE_ID')) # Use /tmp for writable storage on HF Spaces, local 'data/' otherwise DATA_DIR = '/tmp/data' if IS_HF_SPACES else 'data' os.makedirs(DATA_DIR, exist_ok=True) HISTORY_FILE = os.path.join(DATA_DIR, 'prediction_history.json') # ============================================================================ # SCRIPT PATH SETUP # ============================================================================ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) SCRIPTS_DIR = os.path.join(BASE_DIR, 'scripts') if SCRIPTS_DIR not in sys.path: sys.path.insert(0, BASE_DIR) sys.path.insert(0, SCRIPTS_DIR) # ============================================================================ # CUSTOM MODULE IMPORTS # ============================================================================ try: from scripts.transformer_data_preprocessing import ( HateSpeechPreprocessor, preprocess_text, get_script_info, get_emoji_info, EMOJI_TO_NEPALI ) from scripts.explainability import ( create_explainer_wrapper, LIMEExplainer, SHAPExplainer, check_availability as check_explainability ) from scripts.captum_explainer import ( CaptumExplainer, check_availability as check_captum_availability ) CUSTOM_MODULES_AVAILABLE = True except MemoryError: st.warning("⚠️ Captum not available due to memory constraints.") CUSTOM_MODULES_AVAILABLE = False captum_available = False except ImportError as e: st.error(f"⚠️ Custom modules not found: {e}") CUSTOM_MODULES_AVAILABLE = False # ============================================================================ # PAGE CONFIGURATION # ============================================================================ st.set_page_config( page_title="Nepali Hate Content Detector", page_icon="🛡️", layout="wide", initial_sidebar_state="expanded" ) # ============================================================================ # CUSTOM CSS # ============================================================================ st.markdown(""" """, unsafe_allow_html=True) # ============================================================================ # NEPALI FONT LOADING # ============================================================================ @st.cache_resource def load_nepali_font(): """Load Nepali font for matplotlib visualizations. Tries multiple font paths in order of preference: 1. Kalimati (primary — downloaded by Dockerfile) 2. Noto Sans Devanagari (Linux/HF Spaces fallback) 3. Other system Devanagari fonts (macOS, Windows) """ font_paths = [ # ── Kalimati (primary) ────────────────────────────────────────── # HF Spaces / Docker — downloaded by Dockerfile curl command '/app/fonts/Kalimati.ttf', # Registered system-wide by fc-cache in Dockerfile '/usr/local/share/fonts/nepali/Kalimati.ttf', # Local dev — absolute path relative to script location os.path.join(BASE_DIR, 'fonts', 'Kalimati.ttf'), # Local dev — relative path 'fonts/Kalimati.ttf', # ── Noto Sans Devanagari (Linux / HF Spaces fallback) ─────────── '/usr/share/fonts/truetype/noto/NotoSansDevanagari-Regular.ttf', '/usr/share/fonts/truetype/noto/NotoSansDevanagari[wdth,wght].ttf', '/usr/share/fonts/opentype/noto/NotoSansDevanagari-Regular.otf', '/usr/share/fonts/noto/NotoSansDevanagari-Regular.ttf', # Noto Serif Devanagari variant '/usr/share/fonts/truetype/noto/NotoSerifDevanagari-Regular.ttf', # Generic Noto fallback '/usr/share/fonts/truetype/noto/NotoSans-Regular.ttf', # ── macOS ──────────────────────────────────────────────────────── '/System/Library/Fonts/Supplemental/DevanagariSangamMN.ttc', '/System/Library/Fonts/Supplemental/DevanagariMT.ttc', '/Library/Fonts/Devanagari Sangam MN.ttc', # ── Windows ───────────────────────────────────────────────────── r'C:\Windows\Fonts\NirmalaUI.ttf', r'C:\Windows\Fonts\NirmalaUI-Bold.ttf', r'C:\Windows\Fonts\mangal.ttf', r'C:\Windows\Fonts\Aparajita.ttf', ] for font_path in font_paths: if os.path.exists(font_path): try: fontManager.addfont(font_path) fp = FontProperties(fname=font_path) return fp except Exception: continue # Silent failure — charts still render, just without Devanagari-specific glyphs return None # ============================================================================ # SESSION STATE INITIALIZATION # ============================================================================ if 'last_prediction' not in st.session_state: st.session_state.last_prediction = None if 'last_text' not in st.session_state: st.session_state.last_text = "" if 'batch_results' not in st.session_state: st.session_state.batch_results = None if 'batch_mode' not in st.session_state: st.session_state.batch_mode = None if 'csv_text_column' not in st.session_state: st.session_state.csv_text_column = None if 'explainability_results' not in st.session_state: st.session_state.explainability_results = None if 'preprocessor' not in st.session_state: st.session_state.preprocessor = None if 'model_wrapper' not in st.session_state: st.session_state.model_wrapper = None if 'nepali_font' not in st.session_state: st.session_state.nepali_font = None if 'session_predictions' not in st.session_state: st.session_state.session_predictions = 0 if 'session_class_counts' not in st.session_state: st.session_state.session_class_counts = {'NO': 0, 'OO': 0, 'OR': 0, 'OS': 0} # ============================================================================ # MODEL LOADING # ============================================================================ @st.cache_resource(show_spinner="Loading model... this may take a minute on first run.") def load_model_and_preprocessor(): """Load model, tokenizer, label encoder, and preprocessor.""" from transformers import AutoTokenizer, AutoModelForSequenceClassification import joblib hf_model_id = "UDHOV/xlm-roberta-large-nepali-hate-classification" local_model_path = 'models/saved_models/xlm_roberta_results/large_final' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Initialize default label encoder as fallback from sklearn.preprocessing import LabelEncoder le = LabelEncoder() le.fit(['NO', 'OO', 'OR', 'OS']) # Try local model first (only relevant for local dev), then HF Hub if not IS_HF_SPACES and os.path.exists(local_model_path): try: tokenizer = AutoTokenizer.from_pretrained(local_model_path) model = AutoModelForSequenceClassification.from_pretrained(local_model_path) model.to(device).eval() le_path = os.path.join(local_model_path, 'label_encoder.pkl') if os.path.exists(le_path): le = joblib.load(le_path) st.success(f"✅ Model loaded from local path on {device}") except Exception as e: st.warning(f"⚠️ Local model failed: {e}. Falling back to HuggingFace Hub...") tokenizer = AutoTokenizer.from_pretrained(hf_model_id) model = AutoModelForSequenceClassification.from_pretrained(hf_model_id) model.to(device).eval() try: from huggingface_hub import hf_hub_download le_file = hf_hub_download(repo_id=hf_model_id, filename="label_encoder.pkl") le = joblib.load(le_file) except Exception: pass # Use default label encoder st.success(f"✅ Model loaded from HuggingFace Hub on {device}") else: # HF Spaces or local path not found — load directly from Hub tokenizer = AutoTokenizer.from_pretrained(hf_model_id) model = AutoModelForSequenceClassification.from_pretrained(hf_model_id) model.to(device).eval() try: from huggingface_hub import hf_hub_download le_file = hf_hub_download(repo_id=hf_model_id, filename="label_encoder.pkl") le = joblib.load(le_file) except Exception: pass # Use default label encoder st.success(f"✅ Model loaded from HuggingFace Hub on {device}") # Initialize preprocessor if CUSTOM_MODULES_AVAILABLE: preprocessor = HateSpeechPreprocessor( model_type="xlmr", translate_english=True, cache_size=2000 ) else: preprocessor = None return model, tokenizer, le, preprocessor, device # ============================================================================ # PREDICTION FUNCTIONS # ============================================================================ def predict_text(text, model, tokenizer, label_encoder, preprocessor, max_length=256): """Make prediction with preprocessing.""" device = next(model.parameters()).device # Preprocess if preprocessor: preprocessed, emoji_features = preprocessor.preprocess(text, verbose=False) else: preprocessed = text emoji_features = {} if not preprocessed.strip(): return { 'prediction': 'NO', 'confidence': 0.0, 'probabilities': {label: 0.0 for label in label_encoder.classes_}, 'preprocessed_text': '', 'emoji_features': emoji_features, 'error': 'Empty text after preprocessing' } # Tokenize inputs = tokenizer( preprocessed, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True ) input_ids = inputs['input_ids'].to(device) attention_mask = inputs['attention_mask'].to(device) # Predict with torch.no_grad(): outputs = model(input_ids, attention_mask=attention_mask) probs = torch.softmax(outputs.logits, dim=-1)[0] probs_np = probs.cpu().numpy() pred_idx = np.argmax(probs_np) pred_label = label_encoder.classes_[pred_idx] confidence = probs_np[pred_idx] return { 'prediction': pred_label, 'confidence': float(confidence), 'probabilities': { label_encoder.classes_[i]: float(probs_np[i]) for i in range(len(label_encoder.classes_)) }, 'preprocessed_text': preprocessed, 'emoji_features': emoji_features } # ============================================================================ # VISUALIZATION FUNCTIONS # ============================================================================ def plot_probabilities(probabilities): """Create probability bar chart.""" labels = list(probabilities.keys()) probs = list(probabilities.values()) colors = { 'NO': '#28a745', 'OO': '#ffc107', 'OR': '#dc3545', 'OS': '#6f42c1' } bar_colors = [colors.get(label, '#6c757d') for label in labels] fig = go.Figure(data=[ go.Bar( x=labels, y=probs, marker_color=bar_colors, text=[f'{p:.2%}' for p in probs], textposition='outside', hovertemplate='%{x}
Probability: %{y:.4f}' ) ]) fig.update_layout( title="Class Probabilities", xaxis_title="Class", yaxis_title="Probability", yaxis_range=[0, 1.1], height=400, showlegend=False, template='plotly_white' ) return fig def get_label_description(label): """Get description for each label.""" descriptions = { 'NO': '✅ Non-Offensive: The text does not contain hate speech or offensive content.', 'OO': '⚠️ Other-Offensive: Contains general offensive language but not targeted hate.', 'OR': '🚫 Offensive-Racist: Contains hate speech targeting race, ethnicity, or religion.', 'OS': '🚫 Offensive-Sexist: Contains hate speech targeting gender or sexuality.' } return descriptions.get(label, 'Unknown category') # ============================================================================ # HISTORY MANAGEMENT # ============================================================================ def save_prediction_to_history(text, result, feedback=None): """Save prediction to history file.""" entry = { 'timestamp': datetime.now().isoformat(), 'text': text, 'prediction': result.get('prediction'), 'confidence': result.get('confidence'), 'probabilities': result.get('probabilities'), 'preprocessed_text': result.get('preprocessed_text'), 'emoji_features': result.get('emoji_features', {}), 'feedback': feedback } # Load existing history history = [] if os.path.exists(HISTORY_FILE): try: with open(HISTORY_FILE, 'r', encoding='utf-8') as f: history = json.load(f) except Exception: history = [] # Append and save history.append(entry) try: with open(HISTORY_FILE, 'w', encoding='utf-8') as f: json.dump(history, f, ensure_ascii=False, indent=2) return True except Exception as e: st.error(f"Failed to save history: {e}") return False # ============================================================================ # BATCH EXPLAINABILITY HELPER # ============================================================================ def render_batch_explainability(results_df, text_column, model, tokenizer, label_encoder, preprocessor, nepali_font, explainability_available, captum_available, mode_key="batch"): """Render explainability UI for batch results.""" if not CUSTOM_MODULES_AVAILABLE: st.warning("⚠️ Explainability not available.") return if not (explainability_available['lime'] or explainability_available['shap'] or captum_available): st.warning("⚠️ No explainability methods available.") return with st.expander("💡 Explain Individual Results", expanded=False): st.markdown("**Select a text from the batch to explain:**") text_options = [f"Row {idx}: {str(row[text_column])[:50]}..." for idx, row in results_df.iterrows()] selected_idx = st.selectbox( "Choose text:", range(len(text_options)), format_func=lambda x: text_options[x], key=f"{mode_key}_select" ) selected_text = str(results_df.iloc[selected_idx][text_column]) selected_pred = results_df.iloc[selected_idx]['Prediction'] st.write(f"**Selected:** {selected_text}") st.write(f"**Prediction:** {selected_pred}") available_methods = [] if explainability_available['lime']: available_methods.append("LIME") if explainability_available['shap']: available_methods.append("SHAP") if captum_available: available_methods.append("Captum (IG)") if not available_methods: st.warning("⚠️ No explainability methods available.") return explain_method = st.selectbox( "Explanation method:", available_methods, key=f"{mode_key}_method" ) if st.button("🔍 Generate Explanation", key=f"{mode_key}_explain_btn"): with st.spinner("Generating explanation..."): try: if st.session_state.model_wrapper is None: st.session_state.model_wrapper = create_explainer_wrapper( model, tokenizer, label_encoder, preprocessor ) wrapper = st.session_state.model_wrapper clean_selected = selected_text.replace('"', '').replace("'", '').replace('\u201c', '').replace('\u201d', '') preprocessed, emoji_features = preprocessor.preprocess(clean_selected) analysis = wrapper.predict_with_analysis(clean_selected) if explain_method == "LIME": lime_exp = LIMEExplainer(wrapper, nepali_font=nepali_font) result = lime_exp.explain_and_visualize( analysis['original_text'], analysis['preprocessed_text'], save_path=None, show=False, num_samples=200 ) st.subheader("LIME Explanation") st.pyplot(result['figure']) st.markdown("---") st.markdown("**📊 Feature Importance Details:**") word_scores = result['explanation']['word_scores'] if word_scores: df = pd.DataFrame(word_scores, columns=['Word', 'Score']) df = df.sort_values('Score', ascending=False) st.dataframe(df, hide_index=True, use_container_width=True) else: st.warning("No word scores available") elif explain_method == "SHAP": shap_exp = SHAPExplainer(wrapper, nepali_font=nepali_font) result = shap_exp.explain_and_visualize( analysis['original_text'], analysis['preprocessed_text'], save_path=None, show=False, use_fallback=True ) st.subheader("SHAP Explanation") st.pyplot(result['figure']) st.markdown("---") st.markdown("**📊 Attribution Details:**") st.write(f"**Method used:** {result['explanation']['method_used']}") word_scores = result['explanation']['word_scores'] if word_scores: df = pd.DataFrame(word_scores, columns=['Word', 'Score']) df = df.sort_values('Score', key=lambda x: abs(x), ascending=False) st.dataframe(df, hide_index=True, use_container_width=True) else: st.warning("No word scores available") elif explain_method == "Captum (IG)": try: captum_exp = CaptumExplainer( model, tokenizer, label_encoder, preprocessor, emoji_to_nepali_map=EMOJI_TO_NEPALI ) result = captum_exp.explain_and_visualize( analysis['original_text'], target=None, n_steps=50, save_dir=None, show=False, nepali_font=nepali_font ) st.subheader("Captum Integrated Gradients") col1, col2 = st.columns(2) with col1: st.markdown("**Bar Chart**") st.pyplot(result['bar_chart']) with col2: st.markdown("**Heatmap**") st.pyplot(result['heatmap']) st.markdown("---") st.markdown("**📊 Attribution Details:**") st.write(f"**Convergence Delta:** {result['explanation']['convergence_delta']:.6f}") word_attrs = result['explanation']['word_attributions'] if word_attrs: df = pd.DataFrame(word_attrs, columns=['Word', 'Abs Score', 'Signed Score']) df = df.sort_values('Abs Score', ascending=False) st.dataframe(df, hide_index=True, use_container_width=True) else: st.warning("No word attributions available") except (MemoryError, RuntimeError): st.error("❌ Captum (Integrated Gradients) requires more memory than available on this server.") st.info("💡 **Tip:** Use LIME or SHAP instead — they work on cloud deployments. Captum works on local machines with more RAM/GPU.") except Exception as e: st.error(f"❌ Explanation failed: {str(e)}") st.markdown("**🐛 Error Details:**") import traceback st.code(traceback.format_exc()) # ============================================================================ # MAIN APPLICATION # ============================================================================ def main(): """Main application.""" # Load Nepali font if st.session_state.nepali_font is None: st.session_state.nepali_font = load_nepali_font() nepali_font = st.session_state.nepali_font # Header st.markdown('

🛡️ Nepali Hate Content Detector

', unsafe_allow_html=True) st.markdown("""
AI-powered hate speech detection for Nepali text with advanced explainability
XLM-RoBERTa Large fine-tuned on Nepali social media data
""", unsafe_allow_html=True) # ======================================================================== # SIDEBAR # ======================================================================== with st.sidebar: st.header("ℹ️ About") st.markdown(""" **Model**: XLM-RoBERTa Large **Task**: Multi-class hate speech detection **Language**: Nepali (Devanagari & Romanized) **Classes:** - **NO**: Non-offensive - **OO**: General offensive - **OR**: Racist/ethnic hate - **OS**: Sexist/gender hate """) st.markdown("---") st.header("🔧 Features") st.markdown(""" ✅ **Preprocessing** - Script detection - Transliteration - Translation - Emoji mapping ✅ **Explainability** - LIME - SHAP - Captum (IG) ✅ **Batch Analysis** - CSV upload - Text area input """) st.markdown("---") st.header("🎨 Font Settings") with st.expander("Nepali Font Info", expanded=False): st.markdown(f""" **Status:** {'✅ Loaded' if nepali_font else '❌ Not loaded'} **Fix squares in Devanagari:** 1. Download Kalimati.ttf 2. Create `fonts/` directory 3. Place font file there 4. Restart app """) st.markdown("---") st.header("📊 Statistics") # Session Statistics st.subheader("🔄 Current Session") if st.session_state.session_predictions > 0: st.metric("Predictions", st.session_state.session_predictions) session_counts = st.session_state.session_class_counts if any(count > 0 for count in session_counts.values()): st.write("**Session Distribution:**") for label in ['NO', 'OO', 'OR', 'OS']: count = session_counts.get(label, 0) if count > 0: pct = (count / st.session_state.session_predictions) * 100 st.write(f"• {label}: {count} ({pct:.0f}%)") else: st.info("No predictions in this session yet.") st.markdown("---") # History Statistics st.subheader("📚 All Time") if os.path.exists(HISTORY_FILE): try: with open(HISTORY_FILE, 'r', encoding='utf-8') as f: history = json.load(f) if history: st.metric("Total Saved", len(history)) pred_counts = pd.Series([h['prediction'] for h in history]).value_counts() st.write("**Distribution:**") for label, count in pred_counts.items(): st.write(f"• {label}: {count}") else: st.info("No saved predictions yet.") except Exception as e: st.warning("⚠️ History file error") with st.expander("Error details"): st.code(str(e)) else: st.info("📝 No history file\n\nEnable 'Save to history' in Tab 1 to track predictions.") st.markdown("---") st.markdown("""
Model on HuggingFace 🤗
""", unsafe_allow_html=True) # ======================================================================== # LOAD MODEL # ======================================================================== with st.spinner("Loading model..."): model, tokenizer, label_encoder, preprocessor, device = load_model_and_preprocessor() if model is None: st.error("❌ Failed to load model!") st.stop() # Check explainability availability explainability_available = check_explainability() if CUSTOM_MODULES_AVAILABLE else {'lime': False, 'shap': False} captum_available = check_captum_availability() if CUSTOM_MODULES_AVAILABLE else False # ======================================================================== # TABS # ======================================================================== tabs = st.tabs([ "🔍 Single Prediction", "💡 Explainability", "📝 Batch Analysis", "📈 History" ]) # ======================================================================== # TAB 1: SINGLE PREDICTION # ======================================================================== with tabs[0]: st.subheader("🔍 Single Text Analysis") col1, col2 = st.columns([2, 1]) with col1: text_input = st.text_area( "Enter Nepali Text", height=200, placeholder="यहाँ आफ्नो पाठ लेख्नुहोस्...\nOr enter romanized Nepali: ma khusi xu\nOr English: This is a test", help="Enter text in Devanagari, Romanized Nepali, or English." ) col_a, col_b = st.columns(2) with col_a: analyze_button = st.button("🔍 Analyze Text", type="primary", use_container_width=True) with col_b: save_to_history = st.checkbox("Save to history", value=True) with col2: st.markdown("##### 💡 Quick Info") st.info(""" **Supported:** - Devanagari: नेपाली - Romanized: ma nepali xu - English: I am Nepali - Mixed scripts - Emojis: 😀😡🙏 **Auto-processing:** - Script detection - Transliteration - Translation - Emoji → Nepali words - URL/mention removal """) if analyze_button and text_input.strip(): with st.spinner("🔄 Analyzing text..."): result = predict_text( text_input, model, tokenizer, label_encoder, preprocessor ) st.session_state.last_prediction = result st.session_state.last_text = text_input if 'prediction' in result: st.session_state.session_predictions += 1 pred_label = result['prediction'] if pred_label in st.session_state.session_class_counts: st.session_state.session_class_counts[pred_label] += 1 if save_to_history: save_prediction_to_history(text_input, result) if 'error' in result: st.warning(f"⚠️ {result['error']}") st.stop() st.markdown("---") st.subheader("📊 Analysis Results") pred_label = result['prediction'] confidence = result['confidence'] box_class = { 'NO': 'no-box', 'OO': 'oo-box', 'OR': 'or-box', 'OS': 'os-box' }.get(pred_label, 'no-box') st.markdown(f"""

Prediction: {pred_label}

Confidence: {confidence:.2%}

{get_label_description(pred_label)}

""", unsafe_allow_html=True) st.plotly_chart(plot_probabilities(result['probabilities']), use_container_width=True) with st.expander("🔍 Preprocessing Details", expanded=False): col1, col2, col3 = st.columns(3) with col1: st.markdown("**Original Text:**") st.code(text_input, language=None) with col2: st.markdown("**Preprocessed:**") st.code(result['preprocessed_text'], language=None) with col3: if CUSTOM_MODULES_AVAILABLE and preprocessor: script_info = get_script_info(text_input) st.markdown("**Script Detected:**") st.write(f"• Type: {script_info['script_type']}") confidence_pct = min(script_info['confidence'] * 100, 100.0) st.write(f"• Confidence: {confidence_pct:.1f}%") if result.get('emoji_features', {}).get('total_emoji_count', 0) > 0: with st.expander("😊 Emoji Analysis", expanded=False): features = result['emoji_features'] col1, col2, col3 = st.columns(3) with col1: st.metric("Total Emojis", features['total_emoji_count']) st.metric("Hate Emojis", features['hate_emoji_count']) with col2: st.metric("Positive Emojis", features['positive_emoji_count']) st.metric("Mockery Emojis", features['mockery_emoji_count']) with col3: st.metric("Sadness Emojis", features['sadness_emoji_count']) st.metric("Fear Emojis", features['fear_emoji_count']) if CUSTOM_MODULES_AVAILABLE: emoji_info = get_emoji_info(text_input) if emoji_info['emojis_found']: st.markdown("**Emojis Found:**") st.write(" ".join(emoji_info['emojis_found'])) with st.expander("📊 Detailed Probabilities", expanded=False): prob_df = pd.DataFrame({ 'Class': list(result['probabilities'].keys()), 'Probability': list(result['probabilities'].values()) }) prob_df['Probability'] = prob_df['Probability'].apply(lambda x: f"{x:.4f}") st.dataframe(prob_df, hide_index=True, use_container_width=True) # ======================================================================== # TAB 2: EXPLAINABILITY # ======================================================================== with tabs[1]: st.subheader("💡 Model Explainability") if not CUSTOM_MODULES_AVAILABLE: st.error("❌ Explainability modules not available. Please check scripts directory.") st.stop() st.info(f""" **Available Methods:** - LIME: {'✅' if explainability_available['lime'] else '❌ (install: pip install lime)'} - SHAP: {'✅' if explainability_available['shap'] else '❌ (install: pip install shap)'} - Captum: {'✅' if captum_available else '❌ (install: pip install captum)'} """) explain_text = st.text_area( "Enter text to explain", height=150, value=st.session_state.last_text if st.session_state.last_text else "", placeholder="Enter Nepali text..." ) available_methods = [] if explainability_available['lime']: available_methods.append("LIME") if explainability_available['shap']: available_methods.append("SHAP") if captum_available: available_methods.append("Captum (IG)") if not available_methods: st.warning("⚠️ No explainability methods available. Please install required packages.") st.code("pip install lime shap captum", language="bash") st.stop() method = st.selectbox("Select explanation method", available_methods) with st.expander("⚙️ Configuration", expanded=False): if method == "LIME": num_samples = st.slider("Number of samples", 100, 500, 200, 50) elif method == "SHAP": use_fallback = st.checkbox("Use fallback if SHAP fails", value=True) elif method == "Captum (IG)": n_steps = st.slider("Integration steps", 10, 100, 50, 10) explain_button = st.button("🔍 Generate Explanation", type="primary", use_container_width=True) if explain_button and explain_text.strip(): with st.spinner("Generating explanation..."): if st.session_state.model_wrapper is None: st.session_state.model_wrapper = create_explainer_wrapper( model, tokenizer, label_encoder, preprocessor ) wrapper = st.session_state.model_wrapper preprocessed, emoji_features = preprocessor.preprocess(explain_text) analysis = wrapper.predict_with_analysis(explain_text) st.success(f"**Prediction:** {analysis['predicted_label']} ({analysis['confidence']:.2%})") col1, col2 = st.columns(2) with col1: st.write("**Original:**", explain_text) with col2: st.write("**Preprocessed:**", preprocessed) st.markdown("---") try: if method == "LIME": lime_exp = LIMEExplainer(wrapper, nepali_font=nepali_font) result = lime_exp.explain_and_visualize( analysis['original_text'], analysis['preprocessed_text'], save_path=None, show=False, num_samples=num_samples ) st.subheader("LIME Explanation") st.pyplot(result['figure']) with st.expander("📊 Feature Importance Details"): word_scores = result['explanation']['word_scores'] df = pd.DataFrame(word_scores, columns=['Word', 'Score']) df = df.sort_values('Score', ascending=False) st.dataframe(df, hide_index=True, use_container_width=True) elif method == "SHAP": shap_exp = SHAPExplainer(wrapper, nepali_font=nepali_font) result = shap_exp.explain_and_visualize( analysis['original_text'], analysis['preprocessed_text'], save_path=None, show=False, use_fallback=use_fallback ) st.subheader("SHAP Explanation") st.pyplot(result['figure']) with st.expander("📊 Attribution Details"): st.write(f"**Method used:** {result['explanation']['method_used']}") word_scores = result['explanation']['word_scores'] df = pd.DataFrame(word_scores, columns=['Word', 'Score']) df = df.sort_values('Score', key=lambda x: abs(x), ascending=False) st.dataframe(df, hide_index=True, use_container_width=True) elif method == "Captum (IG)": try: captum_exp = CaptumExplainer( model, tokenizer, label_encoder, preprocessor, emoji_to_nepali_map=EMOJI_TO_NEPALI ) result = captum_exp.explain_and_visualize( analysis['original_text'], target=None, n_steps=n_steps, save_dir=None, show=False, nepali_font=nepali_font ) st.subheader("Captum Integrated Gradients") col1, col2 = st.columns(2) with col1: st.markdown("**Bar Chart**") st.pyplot(result['bar_chart']) with col2: st.markdown("**Heatmap**") st.pyplot(result['heatmap']) with st.expander("📊 Attribution Details"): st.write(f"**Convergence Delta:** {result['explanation']['convergence_delta']:.6f}") word_attrs = result['explanation']['word_attributions'] df = pd.DataFrame(word_attrs, columns=['Word', 'Abs Score', 'Signed Score']) df = df.sort_values('Abs Score', ascending=False) st.dataframe(df, hide_index=True, use_container_width=True) except (MemoryError, RuntimeError) as mem_err: st.error("❌ Captum (Integrated Gradients) requires more memory than available on this server.") st.info("💡 **Tip:** Use LIME or SHAP instead — they work great on cloud deployments. Captum works on local machines with more RAM/GPU.") except Exception as e: st.error(f"❌ Explanation failed: {str(e)}") with st.expander("🐛 Error Details"): st.exception(e) # ======================================================================== # TAB 3: BATCH ANALYSIS # ======================================================================== with tabs[2]: st.subheader("📝 Batch Analysis") st.markdown("### 📥 Download Example Files") col1, col2 = st.columns(2) with col1: example_csv_data = { 'text': [ 'यो राम्रो छ', 'तिमी मुर्ख हौ', 'मुस्लिम हरु सबै खराब छन्', 'केटीहरु घरमा बस्नु पर्छ', 'नमस्ते, कस्तो छ?' ] } example_csv = pd.DataFrame(example_csv_data).to_csv(index=False) st.download_button( label="📄 Download Example CSV", data=example_csv, file_name="example_batch.csv", mime="text/csv", use_container_width=True ) with col2: example_text = "यो राम्रो छ\nतिमी मुर्ख हौ\nमुस्लिम हरु सबै खराब छन्\nकेटीहरु घरमा बस्नु पर्छ\nनमस्ते, कस्तो छ?" st.download_button( label="📝 Download Example Text", data=example_text, file_name="example_batch.txt", mime="text/plain", use_container_width=True ) st.markdown("---") input_method = st.radio("Input method:", ["Text Area", "CSV Upload"]) # ---- TEXT AREA ---- if input_method == "Text Area": st.info("💡 Enter one text per line") batch_text = st.text_area( "Enter texts (one per line)", height=250, placeholder="यो राम्रो छ\nतिमी मुर्ख हौ\n..." ) if st.button("🚀 Analyze Batch", type="primary"): if batch_text.strip(): texts = [line.strip() for line in batch_text.split('\n') if line.strip()] with st.spinner(f"Analyzing {len(texts)} texts..."): results = [] progress_bar = st.progress(0) for idx, text in enumerate(texts): try: result = predict_text( text, model, tokenizer, label_encoder, preprocessor ) results.append({ 'Text': text[:60] + '...' if len(text) > 60 else text, 'Full_Text': text, 'Prediction': result['prediction'], 'Confidence': result['confidence'], 'Preprocessed': result['preprocessed_text'] }) except Exception as e: results.append({ 'Text': text[:60], 'Full_Text': text, 'Prediction': 'Error', 'Confidence': 0.0, 'Preprocessed': str(e) }) progress_bar.progress((idx + 1) / len(texts)) st.session_state.batch_results = pd.DataFrame(results) st.session_state.batch_mode = 'text_area' for result in results: if result['Prediction'] != 'Error': st.session_state.session_predictions += 1 pred_label = result['Prediction'] if pred_label in st.session_state.session_class_counts: st.session_state.session_class_counts[pred_label] += 1 st.rerun() else: st.warning("Please enter some texts.") # Display results outside button block if (st.session_state.batch_results is not None and st.session_state.get('batch_mode') == 'text_area'): results_df = st.session_state.batch_results st.success(f"✅ Analyzed {len(results_df)} texts!") display_df = results_df[['Text', 'Prediction', 'Confidence']].copy() display_df['Confidence'] = display_df['Confidence'].apply(lambda x: f"{x:.2%}") st.dataframe(display_df, use_container_width=True, hide_index=True, height=400) st.markdown("---") st.subheader("📊 Summary Statistics") col1, col2, col3 = st.columns(3) with col1: st.metric("Total Texts", len(results_df)) st.metric("Avg Confidence", f"{results_df['Confidence'].mean():.2%}") with col2: summary = results_df['Prediction'].value_counts() fig = px.pie( values=summary.values, names=summary.index, title="Prediction Distribution", color_discrete_sequence=px.colors.qualitative.Set2 ) st.plotly_chart(fig, use_container_width=True) with col3: st.markdown("**Class Breakdown:**") for label, count in summary.items(): pct = count / len(results_df) * 100 st.write(f"• {label}: {count} ({pct:.1f}%)") st.markdown("---") download_df = results_df[['Full_Text', 'Prediction', 'Confidence', 'Preprocessed']].copy() download_df.columns = ['Text', 'Prediction', 'Confidence', 'Preprocessed'] csv = download_df.to_csv(index=False) col_download, col_explain = st.columns(2) with col_download: st.download_button( label="📥 Download Results CSV", data=csv, file_name=f"batch_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv", mime="text/csv", use_container_width=True, key="download_batch_text" ) with col_explain: if st.button("💡 Explain Selected", use_container_width=True, key="hint_batch_text"): st.info("👇 Select a text below to explain") render_batch_explainability( results_df=results_df, text_column='Full_Text', model=model, tokenizer=tokenizer, label_encoder=label_encoder, preprocessor=preprocessor, nepali_font=nepali_font, explainability_available=explainability_available, captum_available=captum_available, mode_key="text_area" ) # ---- CSV UPLOAD ---- else: st.info("💡 Upload CSV with a 'text' column") uploaded_file = st.file_uploader( "Choose CSV file", type=['csv'], help="Max 200MB. Upload a CSV with a text column containing Nepali text." ) if uploaded_file: try: # Try multiple encodings for Nepali text compatibility try: df = pd.read_csv(uploaded_file, encoding='utf-8') except UnicodeDecodeError: uploaded_file.seek(0) df = pd.read_csv(uploaded_file, encoding='latin-1') st.write("📄 **File Preview:**") st.dataframe(df.head(10), use_container_width=True) text_column = st.selectbox("Select text column:", df.columns) if st.button("🚀 Analyze CSV", type="primary"): texts = df[text_column].astype(str).tolist() with st.spinner(f"Analyzing {len(texts)} texts..."): predictions = [] confidences = [] preprocessed_list = [] progress_bar = st.progress(0) for idx, text in enumerate(texts): try: result = predict_text( str(text), model, tokenizer, label_encoder, preprocessor ) predictions.append(result['prediction']) confidences.append(result['confidence']) preprocessed_list.append(result['preprocessed_text']) except Exception as e: predictions.append('Error') confidences.append(0.0) preprocessed_list.append(str(e)) progress_bar.progress((idx + 1) / len(texts)) df['Prediction'] = predictions df['Confidence'] = confidences df['Preprocessed'] = preprocessed_list st.session_state.batch_results = df st.session_state.batch_mode = 'csv' st.session_state.csv_text_column = text_column for pred in predictions: if pred != 'Error': st.session_state.session_predictions += 1 if pred in st.session_state.session_class_counts: st.session_state.session_class_counts[pred] += 1 st.rerun() # Display results outside button block if (st.session_state.batch_results is not None and st.session_state.get('batch_mode') == 'csv'): df_results = st.session_state.batch_results text_col = st.session_state.get('csv_text_column', text_column) st.success("✅ Analysis complete!") st.dataframe(df_results, use_container_width=True, height=400) st.markdown("---") st.subheader("📊 Summary") col1, col2 = st.columns(2) with col1: summary = df_results['Prediction'].value_counts() fig = px.bar( x=summary.index, y=summary.values, title="Prediction Distribution", labels={'x': 'Class', 'y': 'Count'}, color=summary.index, color_discrete_map={ 'NO': '#28a745', 'OO': '#ffc107', 'OR': '#dc3545', 'OS': '#6f42c1' } ) st.plotly_chart(fig, use_container_width=True) with col2: st.metric("Total Texts", len(df_results)) st.metric("Avg Confidence", f"{df_results['Confidence'].mean():.2%}") st.markdown("**Class Distribution:**") for label, count in summary.items(): st.write(f"• {label}: {count}") st.markdown("---") csv_data = df_results.to_csv(index=False) col_download, col_explain = st.columns(2) with col_download: st.download_button( label="📥 Download Results CSV", data=csv_data, file_name=f"predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv", mime="text/csv", use_container_width=True, key="download_csv_results" ) with col_explain: if st.button("💡 Explain Selected", use_container_width=True, key="csv_explain_hint"): st.info("👇 Use expander below to explain") render_batch_explainability( results_df=df_results, text_column=text_col, model=model, tokenizer=tokenizer, label_encoder=label_encoder, preprocessor=preprocessor, nepali_font=nepali_font, explainability_available=explainability_available, captum_available=captum_available, mode_key="csv" ) except Exception as e: st.error(f"❌ Error processing file: {str(e)}") with st.expander("🐛 Error Details"): st.exception(e) # ======================================================================== # TAB 4: HISTORY # ======================================================================== with tabs[3]: st.subheader("📈 Prediction History") col1, col2 = st.columns([3, 1]) with col1: st.write("View and analyze your prediction history") with col2: if st.button("🔄 Refresh", use_container_width=True): st.rerun() if os.path.exists(HISTORY_FILE): try: with open(HISTORY_FILE, 'r', encoding='utf-8') as f: history = json.load(f) if history: history_df = pd.DataFrame(history) history_df['timestamp'] = pd.to_datetime(history_df['timestamp']) st.markdown("### 📊 Overview") col1, col2, col3, col4 = st.columns(4) with col1: st.metric("Total Predictions", len(history_df)) with col2: st.metric("Avg Confidence", f"{history_df['confidence'].mean():.2%}") with col3: if 'emoji_features' in history_df.columns: total_emojis = sum( e.get('total_emoji_count', 0) for e in history_df['emoji_features'] if isinstance(e, dict) ) st.metric("Total Emojis", total_emojis) else: st.metric("Total Emojis", "N/A") with col4: most_common = history_df['prediction'].mode()[0] st.metric("Most Common", most_common) st.markdown("---") st.markdown("### 📈 Trends") col1, col2 = st.columns(2) with col1: daily_counts = history_df.groupby( history_df['timestamp'].dt.date ).size().reset_index(name='count') fig = px.line( daily_counts, x='timestamp', y='count', title="Predictions Over Time", labels={'timestamp': 'Date', 'count': 'Count'} ) st.plotly_chart(fig, use_container_width=True) with col2: class_dist = history_df['prediction'].value_counts() fig = px.pie( values=class_dist.values, names=class_dist.index, title="Class Distribution", color=class_dist.index, color_discrete_map={ 'NO': '#28a745', 'OO': '#ffc107', 'OR': '#dc3545', 'OS': '#6f42c1' } ) st.plotly_chart(fig, use_container_width=True) st.markdown("---") st.markdown("### 📋 Recent Predictions") num_to_show = st.slider("Number to show", 5, 50, 20, 5) recent = history_df.tail(num_to_show).sort_values('timestamp', ascending=False) display = recent[['timestamp', 'text', 'prediction', 'confidence']].copy() display['confidence'] = display['confidence'].apply(lambda x: f"{x:.2%}") display['text'] = display['text'].apply(lambda x: x[:80] + '...' if len(x) > 80 else x) display['timestamp'] = display['timestamp'].dt.strftime('%Y-%m-%d %H:%M:%S') st.dataframe(display, use_container_width=True, hide_index=True, height=400) st.markdown("---") col1, col2 = st.columns(2) with col1: csv = history_df.to_csv(index=False) st.download_button( label="📥 Download Full History", data=csv, file_name=f"history_{datetime.now().strftime('%Y%m%d')}.csv", mime="text/csv", use_container_width=True ) with col2: if st.button("🗑️ Clear History", type="secondary", use_container_width=True): if os.path.exists(HISTORY_FILE): os.remove(HISTORY_FILE) st.success("✅ History cleared!") st.rerun() else: st.info("📝 No predictions in history yet.") except Exception as e: st.error(f"❌ Error loading history: {str(e)}") with st.expander("🐛 Error Details"): st.exception(e) else: st.info("📝 No history file found yet.") st.markdown(""" ### How to Build History: 1. Go to **Single Prediction** tab 2. Enable "Save to history" checkbox 3. Analyze some text 4. Your predictions will appear here! """) if __name__ == "__main__": main()