"""
Nepali Hate Speech Detection - Streamlit Application
=====================================================
Complete application with preprocessing, prediction, and explainability (LIME/SHAP/Captum)
Run with: streamlit run main_app.py
"""
import os
import sys
import streamlit as st
import pandas as pd
import numpy as np
import torch
import plotly.graph_objects as go
import plotly.express as px
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')
# Matplotlib for Nepali font support
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties, fontManager
# ============================================================================
# HF SPACES COMPATIBILITY — paths and environment
# ============================================================================
# A non-empty SPACE_ID environment variable is taken as the sign that the app
# is running on HF Spaces.
IS_HF_SPACES = os.environ.get('SPACE_ID') not in (None, '')

# Writable storage: /tmp/data on HF Spaces, a local 'data' folder otherwise.
if IS_HF_SPACES:
    DATA_DIR = '/tmp/data'
else:
    DATA_DIR = 'data'
os.makedirs(DATA_DIR, exist_ok=True)

# JSON file that accumulates saved predictions.
HISTORY_FILE = os.path.join(DATA_DIR, 'prediction_history.json')
# ============================================================================
# SCRIPT PATH SETUP
# ============================================================================
# Folder containing this script, and its 'scripts' subfolder with the helper modules.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
SCRIPTS_DIR = os.path.join(BASE_DIR, 'scripts')

# Put both folders at the front of the import path (scripts first), but only
# when the scripts folder is not registered yet, so repeated executions of the
# script do not keep growing sys.path.
if SCRIPTS_DIR not in sys.path:
    sys.path[:0] = [SCRIPTS_DIR, BASE_DIR]
# ============================================================================
# CUSTOM MODULE IMPORTS
# ============================================================================
# Any import problem is recorded here as (level, message) and only rendered
# after st.set_page_config(), because Streamlit documents set_page_config as
# having to be the first Streamlit command of the script.
_DEFERRED_IMPORT_NOTICE = None

try:
    from scripts.transformer_data_preprocessing import (
        HateSpeechPreprocessor,
        preprocess_text,
        get_script_info,
        get_emoji_info,
        EMOJI_TO_NEPALI
    )
    from scripts.explainability import (
        create_explainer_wrapper,
        LIMEExplainer,
        SHAPExplainer,
        check_availability as check_explainability
    )
    from scripts.captum_explainer import (
        CaptumExplainer,
        check_availability as check_captum_availability
    )
    CUSTOM_MODULES_AVAILABLE = True
except MemoryError:
    _DEFERRED_IMPORT_NOTICE = (
        'warning',
        "⚠️ Captum not available due to memory constraints."
    )
    CUSTOM_MODULES_AVAILABLE = False
    captum_available = False
except ImportError as e:
    _DEFERRED_IMPORT_NOTICE = ('error', f"⚠️ Custom modules not found: {e}")
    CUSTOM_MODULES_AVAILABLE = False
# ============================================================================
# PAGE CONFIGURATION
# ============================================================================
st.set_page_config(
    page_title="Nepali Hate Content Detector",
    page_icon="🛡️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Now that the page is configured it is safe to surface the import problem.
if _DEFERRED_IMPORT_NOTICE is not None:
    _notice_level, _notice_text = _DEFERRED_IMPORT_NOTICE
    if _notice_level == 'warning':
        st.warning(_notice_text)
    else:
        st.error(_notice_text)
# ============================================================================
# CUSTOM CSS
# ============================================================================
# Emit the app's custom styling block as raw HTML (unsafe_allow_html=True).
# NOTE(review): the markup string is blank in this copy of the file — presumably
# the style rules were lost; confirm against the original before relying on it.
st.markdown("""
""", unsafe_allow_html=True)
# ============================================================================
# NEPALI FONT LOADING
# ============================================================================
@st.cache_resource
def load_nepali_font():
    """Find a Devanagari-capable font file and register it with matplotlib.

    Candidate paths are probed in order of preference:
      1. Kalimati (Docker/HF Spaces locations, then local ``fonts/`` folder)
      2. Noto Devanagari fonts in common Linux locations
      3. System Devanagari fonts on macOS and Windows

    Returns:
        A ``FontProperties`` built from the first candidate that exists and
        can be added to matplotlib's font manager, or ``None`` when no
        candidate works.
    """
    candidate_fonts = (
        # ── Kalimati (primary) ──────────────────────────────────────────
        # HF Spaces / Docker — downloaded by Dockerfile curl command
        '/app/fonts/Kalimati.ttf',
        # Registered system-wide by fc-cache in Dockerfile
        '/usr/local/share/fonts/nepali/Kalimati.ttf',
        # Local dev — absolute path relative to script location
        os.path.join(BASE_DIR, 'fonts', 'Kalimati.ttf'),
        # Local dev — relative path
        'fonts/Kalimati.ttf',
        # ── Noto Sans Devanagari (Linux / HF Spaces fallback) ───────────
        '/usr/share/fonts/truetype/noto/NotoSansDevanagari-Regular.ttf',
        '/usr/share/fonts/truetype/noto/NotoSansDevanagari[wdth,wght].ttf',
        '/usr/share/fonts/opentype/noto/NotoSansDevanagari-Regular.otf',
        '/usr/share/fonts/noto/NotoSansDevanagari-Regular.ttf',
        # Noto Serif Devanagari variant
        '/usr/share/fonts/truetype/noto/NotoSerifDevanagari-Regular.ttf',
        # Generic Noto fallback
        '/usr/share/fonts/truetype/noto/NotoSans-Regular.ttf',
        # ── macOS ────────────────────────────────────────────────────────
        '/System/Library/Fonts/Supplemental/DevanagariSangamMN.ttc',
        '/System/Library/Fonts/Supplemental/DevanagariMT.ttc',
        '/Library/Fonts/Devanagari Sangam MN.ttc',
        # ── Windows ─────────────────────────────────────────────────────
        r'C:\Windows\Fonts\NirmalaUI.ttf',
        r'C:\Windows\Fonts\NirmalaUI-Bold.ttf',
        r'C:\Windows\Fonts\mangal.ttf',
        r'C:\Windows\Fonts\Aparajita.ttf',
    )

    # Lazily walk only the candidates that are present on this machine.
    present = (path for path in candidate_fonts if os.path.exists(path))
    for candidate in present:
        try:
            fontManager.addfont(candidate)
            return FontProperties(fname=candidate)
        except Exception:
            # This file could not be registered; try the next candidate.
            pass

    # Nothing usable was found: callers receive None and plot with defaults.
    return None
# ============================================================================
# SESSION STATE INITIALIZATION
# ============================================================================
# Seed every session-state slot the app reads, leaving existing values alone
# so they survive script reruns within the same browser session.
_SESSION_DEFAULTS = (
    ('last_prediction', None),
    ('last_text', ""),
    ('batch_results', None),
    ('batch_mode', None),
    ('csv_text_column', None),
    ('explainability_results', None),
    ('preprocessor', None),
    ('model_wrapper', None),
    ('nepali_font', None),
    ('session_predictions', 0),
    ('session_class_counts', {'NO': 0, 'OO': 0, 'OR': 0, 'OS': 0}),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ============================================================================
# MODEL LOADING
# ============================================================================
def _load_model_from_hub(hf_model_id, device, fallback_label_encoder):
    """Load tokenizer, model and label encoder from the HuggingFace Hub.

    Args:
        hf_model_id: Hub repository id to download from.
        device: torch device the model is moved to.
        fallback_label_encoder: Encoder returned when ``label_encoder.pkl``
            cannot be fetched or loaded from the repository.

    Returns:
        Tuple ``(tokenizer, model, label_encoder)`` with the model in eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import joblib

    tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
    model = AutoModelForSequenceClassification.from_pretrained(hf_model_id)
    model.to(device).eval()

    label_encoder = fallback_label_encoder
    try:
        from huggingface_hub import hf_hub_download
        le_file = hf_hub_download(repo_id=hf_model_id, filename="label_encoder.pkl")
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code — only point hf_model_id at a repository you trust.
        label_encoder = joblib.load(le_file)
    except Exception:
        # Deliberate best-effort: any failure here keeps the fallback encoder.
        pass

    st.success(f"✅ Model loaded from HuggingFace Hub on {device}")
    return tokenizer, model, label_encoder


@st.cache_resource(show_spinner="Loading model... this may take a minute on first run.")
def load_model_and_preprocessor():
    """Load model, tokenizer, label encoder, and preprocessor.

    Outside HF Spaces a local checkpoint directory is tried first when it
    exists; if it is absent or fails to load, the model comes from the
    HuggingFace Hub instead. A label encoder fitted on the four class codes is
    used whenever no saved ``label_encoder.pkl`` can be loaded.

    Returns:
        Tuple ``(model, tokenizer, label_encoder, preprocessor, device)``.
        ``preprocessor`` is ``None`` when the custom modules failed to import.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import joblib
    from sklearn.preprocessing import LabelEncoder

    hf_model_id = "UDHOV/xlm-roberta-large-nepali-hate-classification"
    local_model_path = 'models/saved_models/xlm_roberta_results/large_final'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Default label encoder, used when no saved encoder can be loaded.
    default_le = LabelEncoder()
    default_le.fit(['NO', 'OO', 'OR', 'OS'])

    loaded = None
    # Try local model first (only relevant for local dev), then HF Hub.
    if not IS_HF_SPACES and os.path.exists(local_model_path):
        try:
            tokenizer = AutoTokenizer.from_pretrained(local_model_path)
            model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
            model.to(device).eval()
            le = default_le
            le_path = os.path.join(local_model_path, 'label_encoder.pkl')
            if os.path.exists(le_path):
                # Same pickle caveat as for the Hub copy: trusted files only.
                le = joblib.load(le_path)
            st.success(f"✅ Model loaded from local path on {device}")
            loaded = (tokenizer, model, le)
        except Exception as e:
            st.warning(f"⚠️ Local model failed: {e}. Falling back to HuggingFace Hub...")

    if loaded is None:
        # HF Spaces, no local checkpoint, or the local load failed.
        loaded = _load_model_from_hub(hf_model_id, device, default_le)
    tokenizer, model, le = loaded

    # Initialize preprocessor
    if CUSTOM_MODULES_AVAILABLE:
        preprocessor = HateSpeechPreprocessor(
            model_type="xlmr",
            translate_english=True,
            cache_size=2000
        )
    else:
        preprocessor = None
    return model, tokenizer, le, preprocessor, device
# ============================================================================
# PREDICTION FUNCTIONS
# ============================================================================
def predict_text(text, model, tokenizer, label_encoder, preprocessor, max_length=256):
    """Classify one text and report per-class probabilities.

    Args:
        text: Raw input text.
        model: Classifier called as ``model(input_ids, attention_mask=...)``
            whose output exposes ``.logits``.
        tokenizer: Callable returning ``input_ids`` and ``attention_mask``.
        label_encoder: Object whose ``classes_`` lists the class labels in
            logit order.
        preprocessor: Optional object with ``preprocess(text, verbose=False)``
            returning ``(clean_text, emoji_features)``; when falsy the text is
            used unchanged.
        max_length: Padding/truncation length passed to the tokenizer.

    Returns:
        Dict with ``prediction``, ``confidence``, ``probabilities``,
        ``preprocessed_text`` and ``emoji_features``. When the text is blank
        after preprocessing, the dict predicts ``'NO'`` with zero
        probabilities and carries an extra ``error`` message.
    """
    target_device = next(model.parameters()).device

    # Clean the input when a preprocessor is supplied.
    if preprocessor:
        cleaned, emoji_stats = preprocessor.preprocess(text, verbose=False)
    else:
        cleaned, emoji_stats = text, {}

    # Nothing left to classify: return a neutral placeholder result.
    if not cleaned.strip():
        return {
            'prediction': 'NO',
            'confidence': 0.0,
            'probabilities': dict.fromkeys(label_encoder.classes_, 0.0),
            'preprocessed_text': '',
            'emoji_features': emoji_stats,
            'error': 'Empty text after preprocessing'
        }

    encoded = tokenizer(
        cleaned,
        return_tensors='pt',
        max_length=max_length,
        padding='max_length',
        truncation=True
    )
    ids = encoded['input_ids'].to(target_device)
    mask = encoded['attention_mask'].to(target_device)

    # Forward pass without gradient tracking; keep the first (only) row.
    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
        class_probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    best = np.argmax(class_probs)
    per_class = {
        name: float(class_probs[position])
        for position, name in enumerate(label_encoder.classes_)
    }
    return {
        'prediction': label_encoder.classes_[best],
        'confidence': float(class_probs[best]),
        'probabilities': per_class,
        'preprocessed_text': cleaned,
        'emoji_features': emoji_stats
    }
# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================
def plot_probabilities(probabilities):
    """Create a bar chart of class probabilities.

    Args:
        probabilities: Mapping of class label to probability; bars follow the
            mapping's iteration order.

    Returns:
        A Plotly ``Figure`` with one coloured bar per class.
    """
    labels = list(probabilities.keys())
    probs = list(probabilities.values())
    colors = {
        'NO': '#28a745',
        'OO': '#ffc107',
        'OR': '#dc3545',
        'OS': '#6f42c1'
    }
    # Labels without a dedicated colour fall back to grey.
    bar_colors = [colors.get(label, '#6c757d') for label in labels]
    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=probs,
            marker_color=bar_colors,
            text=[f'{p:.2%}' for p in probs],
            textposition='outside',
            # Kept on one line: a single-quoted literal cannot span lines, and
            # Plotly hover text uses <br> for the line break.
            hovertemplate='%{x}<br>Probability: %{y:.4f}'
        )
    ])
    fig.update_layout(
        title="Class Probabilities",
        xaxis_title="Class",
        yaxis_title="Probability",
        # Headroom above 1.0 so the 'outside' bar labels are not clipped.
        yaxis_range=[0, 1.1],
        height=400,
        showlegend=False,
        template='plotly_white'
    )
    return fig
def get_label_description(label):
    """Return the human-readable explanation for a class code.

    Args:
        label: One of the class codes ``'NO'``, ``'OO'``, ``'OR'``, ``'OS'``.

    Returns:
        The matching description, or ``'Unknown category'`` for any other code.
    """
    explanations = {
        'NO': '✅ Non-Offensive: The text does not contain hate speech or offensive content.',
        'OO': '⚠️ Other-Offensive: Contains general offensive language but not targeted hate.',
        'OR': '🚫 Offensive-Racist: Contains hate speech targeting race, ethnicity, or religion.',
        'OS': '🚫 Offensive-Sexist: Contains hate speech targeting gender or sexuality.'
    }
    try:
        return explanations[label]
    except KeyError:
        return 'Unknown category'
# ============================================================================
# HISTORY MANAGEMENT
# ============================================================================
def save_prediction_to_history(text, result, feedback=None):
    """Append one prediction to the JSON history file.

    Args:
        text: The original input text.
        result: Prediction dict; the keys ``prediction``, ``confidence``,
            ``probabilities``, ``preprocessed_text`` and ``emoji_features``
            are copied when present.
        feedback: Optional user feedback stored with the entry.

    Returns:
        ``True`` when the history was written, ``False`` when writing failed
        (the failure is also reported with ``st.error``).
    """
    entry = {
        'timestamp': datetime.now().isoformat(),
        'text': text,
        'prediction': result.get('prediction'),
        'confidence': result.get('confidence'),
        'probabilities': result.get('probabilities'),
        'preprocessed_text': result.get('preprocessed_text'),
        'emoji_features': result.get('emoji_features', {}),
        'feedback': feedback
    }

    # Load existing history. A missing, unreadable or malformed file starts a
    # fresh history. ValueError covers invalid JSON and undecodable bytes.
    try:
        with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        loaded = []
    # Valid JSON that is not a list cannot be appended to; start fresh as well.
    history = loaded if isinstance(loaded, list) else []
    history.append(entry)

    # Write to a sibling temp file and swap it in with os.replace, so an
    # interrupted write cannot leave a truncated history file behind.
    tmp_path = f"{HISTORY_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, HISTORY_FILE)
        return True
    except Exception as e:
        # Best-effort cleanup of the partial temp file before reporting.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        st.error(f"Failed to save history: {e}")
        return False
# ============================================================================
# BATCH EXPLAINABILITY HELPER
# ============================================================================
def _batch_explain_methods(explainability_available, captum_available):
    """Return the display names of the explanation methods that can be used."""
    methods = []
    if explainability_available['lime']:
        methods.append("LIME")
    if explainability_available['shap']:
        methods.append("SHAP")
    if captum_available:
        methods.append("Captum (IG)")
    return methods


def _show_and_close_figure(fig):
    """Render a matplotlib figure in Streamlit, then close it to free memory."""
    st.pyplot(fig)
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)


def _show_score_table(rows, columns, sort_by, empty_message, by_magnitude=False):
    """Show score rows as a table sorted descending, or warn when there are none."""
    if not rows:
        st.warning(empty_message)
        return
    df = pd.DataFrame(rows, columns=columns)
    if by_magnitude:
        df = df.sort_values(sort_by, key=lambda x: abs(x), ascending=False)
    else:
        df = df.sort_values(sort_by, ascending=False)
    st.dataframe(df, hide_index=True, use_container_width=True)


def _render_lime_explanation(wrapper, analysis, nepali_font):
    """Run LIME on the analysed text and show its figure and word scores."""
    lime_exp = LIMEExplainer(wrapper, nepali_font=nepali_font)
    result = lime_exp.explain_and_visualize(
        analysis['original_text'],
        analysis['preprocessed_text'],
        save_path=None,
        show=False,
        num_samples=200
    )
    st.subheader("LIME Explanation")
    _show_and_close_figure(result['figure'])
    st.markdown("---")
    st.markdown("**📊 Feature Importance Details:**")
    _show_score_table(
        result['explanation']['word_scores'],
        ['Word', 'Score'],
        'Score',
        "No word scores available"
    )


def _render_shap_explanation(wrapper, analysis, nepali_font):
    """Run SHAP on the analysed text and show its figure and word scores."""
    shap_exp = SHAPExplainer(wrapper, nepali_font=nepali_font)
    result = shap_exp.explain_and_visualize(
        analysis['original_text'],
        analysis['preprocessed_text'],
        save_path=None,
        show=False,
        use_fallback=True
    )
    st.subheader("SHAP Explanation")
    _show_and_close_figure(result['figure'])
    st.markdown("---")
    st.markdown("**📊 Attribution Details:**")
    st.write(f"**Method used:** {result['explanation']['method_used']}")
    # Rank by absolute score so strong negative words are not pushed to the end.
    _show_score_table(
        result['explanation']['word_scores'],
        ['Word', 'Score'],
        'Score',
        "No word scores available",
        by_magnitude=True
    )


def _render_captum_explanation(model, tokenizer, label_encoder, preprocessor,
                               analysis, nepali_font):
    """Run Captum Integrated Gradients and show its charts and attributions.

    MemoryError and RuntimeError are reported with a hint to use LIME or SHAP
    rather than being propagated.
    """
    try:
        captum_exp = CaptumExplainer(
            model, tokenizer, label_encoder, preprocessor,
            emoji_to_nepali_map=EMOJI_TO_NEPALI
        )
        result = captum_exp.explain_and_visualize(
            analysis['original_text'],
            target=None,
            n_steps=50,
            save_dir=None,
            show=False,
            nepali_font=nepali_font
        )
        st.subheader("Captum Integrated Gradients")
        col1, col2 = st.columns(2)
        with col1:
            st.markdown("**Bar Chart**")
            _show_and_close_figure(result['bar_chart'])
        with col2:
            st.markdown("**Heatmap**")
            _show_and_close_figure(result['heatmap'])
        st.markdown("---")
        st.markdown("**📊 Attribution Details:**")
        st.write(f"**Convergence Delta:** {result['explanation']['convergence_delta']:.6f}")
        _show_score_table(
            result['explanation']['word_attributions'],
            ['Word', 'Abs Score', 'Signed Score'],
            'Abs Score',
            "No word attributions available"
        )
    except (MemoryError, RuntimeError):
        st.error("❌ Captum (Integrated Gradients) requires more memory than available on this server.")
        st.info("💡 **Tip:** Use LIME or SHAP instead — they work on cloud deployments. Captum works on local machines with more RAM/GPU.")


def render_batch_explainability(results_df, text_column, model, tokenizer, label_encoder,
                                preprocessor, nepali_font, explainability_available,
                                captum_available, mode_key="batch"):
    """Render the explainability UI for batch results.

    Shows an expander where the user picks one row of ``results_df`` and an
    explanation method, then renders that explanation on demand.

    Args:
        results_df: Batch results; must contain ``text_column`` and a
            ``'Prediction'`` column.
        text_column: Name of the column holding the full text to explain.
        model, tokenizer, label_encoder, preprocessor: Objects handed to the
            explainer wrapper and to the Captum explainer.
        nepali_font: Font passed through to the explainers.
        explainability_available: Mapping with boolean ``'lime'`` and
            ``'shap'`` entries.
        captum_available: Whether Captum can be offered.
        mode_key: Prefix that keeps widget keys unique per batch mode.
    """
    if not CUSTOM_MODULES_AVAILABLE:
        st.warning("⚠️ Explainability not available.")
        return
    available_methods = _batch_explain_methods(explainability_available, captum_available)
    if not available_methods:
        st.warning("⚠️ No explainability methods available.")
        return
    if len(results_df) == 0:
        # No rows to choose from; the selectbox below would yield None.
        return

    with st.expander("💡 Explain Individual Results", expanded=False):
        st.markdown("**Select a text from the batch to explain:**")
        text_options = [f"Row {idx}: {str(row[text_column])[:50]}..." for idx, row in results_df.iterrows()]
        selected_idx = st.selectbox(
            "Choose text:",
            range(len(text_options)),
            format_func=lambda x: text_options[x],
            key=f"{mode_key}_select"
        )
        selected_row = results_df.iloc[selected_idx]
        selected_text = str(selected_row[text_column])
        selected_pred = selected_row['Prediction']
        st.write(f"**Selected:** {selected_text}")
        st.write(f"**Prediction:** {selected_pred}")

        explain_method = st.selectbox(
            "Explanation method:",
            available_methods,
            key=f"{mode_key}_method"
        )
        if not st.button("🔍 Generate Explanation", key=f"{mode_key}_explain_btn"):
            return

        with st.spinner("Generating explanation..."):
            try:
                # The wrapper is built once and reused from session state.
                if st.session_state.model_wrapper is None:
                    st.session_state.model_wrapper = create_explainer_wrapper(
                        model, tokenizer, label_encoder, preprocessor
                    )
                wrapper = st.session_state.model_wrapper

                # Drop straight and curly quote characters before analysis.
                clean_selected = selected_text.translate(
                    str.maketrans('', '', '"\'\u201c\u201d')
                )
                analysis = wrapper.predict_with_analysis(clean_selected)

                if explain_method == "LIME":
                    _render_lime_explanation(wrapper, analysis, nepali_font)
                elif explain_method == "SHAP":
                    _render_shap_explanation(wrapper, analysis, nepali_font)
                elif explain_method == "Captum (IG)":
                    _render_captum_explanation(
                        model, tokenizer, label_encoder, preprocessor,
                        analysis, nepali_font
                    )
            except Exception as e:
                st.error(f"❌ Explanation failed: {str(e)}")
                st.markdown("**🐛 Error Details:**")
                import traceback
                st.code(traceback.format_exc())
# ============================================================================
# MAIN APPLICATION
# ============================================================================
def main():
"""Main application."""
# Load Nepali font
if st.session_state.nepali_font is None:
st.session_state.nepali_font = load_nepali_font()
nepali_font = st.session_state.nepali_font
# Header
st.markdown('
🛡️ Nepali Hate Content Detector
', unsafe_allow_html=True)
st.markdown("""
""", unsafe_allow_html=True)
# ========================================================================
# SIDEBAR
# ========================================================================
with st.sidebar:
st.header("ℹ️ About")
st.markdown("""
**Model**: XLM-RoBERTa Large
**Task**: Multi-class hate speech detection
**Language**: Nepali (Devanagari & Romanized)
**Classes:**
- **NO**: Non-offensive
- **OO**: General offensive
- **OR**: Racist/ethnic hate
- **OS**: Sexist/gender hate
""")
st.markdown("---")
st.header("🔧 Features")
st.markdown("""
✅ **Preprocessing**
- Script detection
- Transliteration
- Translation
- Emoji mapping
✅ **Explainability**
- LIME
- SHAP
- Captum (IG)
✅ **Batch Analysis**
- CSV upload
- Text area input
""")
st.markdown("---")
st.header("🎨 Font Settings")
with st.expander("Nepali Font Info", expanded=False):
st.markdown(f"""
**Status:** {'✅ Loaded' if nepali_font else '❌ Not loaded'}
**Fix squares in Devanagari:**
1. Download Kalimati.ttf
2. Create `fonts/` directory
3. Place font file there
4. Restart app
""")
st.markdown("---")
st.header("📊 Statistics")
# Session Statistics
st.subheader("🔄 Current Session")
if st.session_state.session_predictions > 0:
st.metric("Predictions", st.session_state.session_predictions)
session_counts = st.session_state.session_class_counts
if any(count > 0 for count in session_counts.values()):
st.write("**Session Distribution:**")
for label in ['NO', 'OO', 'OR', 'OS']:
count = session_counts.get(label, 0)
if count > 0:
pct = (count / st.session_state.session_predictions) * 100
st.write(f"• {label}: {count} ({pct:.0f}%)")
else:
st.info("No predictions in this session yet.")
st.markdown("---")
# History Statistics
st.subheader("📚 All Time")
if os.path.exists(HISTORY_FILE):
try:
with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
history = json.load(f)
if history:
st.metric("Total Saved", len(history))
pred_counts = pd.Series([h['prediction'] for h in history]).value_counts()
st.write("**Distribution:**")
for label, count in pred_counts.items():
st.write(f"• {label}: {count}")
else:
st.info("No saved predictions yet.")
except Exception as e:
st.warning("⚠️ History file error")
with st.expander("Error details"):
st.code(str(e))
else:
st.info("📝 No history file\n\nEnable 'Save to history' in Tab 1 to track predictions.")
st.markdown("---")
st.markdown("""
""", unsafe_allow_html=True)
# ========================================================================
# LOAD MODEL
# ========================================================================
with st.spinner("Loading model..."):
model, tokenizer, label_encoder, preprocessor, device = load_model_and_preprocessor()
if model is None:
st.error("❌ Failed to load model!")
st.stop()
# Check explainability availability
explainability_available = check_explainability() if CUSTOM_MODULES_AVAILABLE else {'lime': False, 'shap': False}
captum_available = check_captum_availability() if CUSTOM_MODULES_AVAILABLE else False
# ========================================================================
# TABS
# ========================================================================
tabs = st.tabs([
"🔍 Single Prediction",
"💡 Explainability",
"📝 Batch Analysis",
"📈 History"
])
# ========================================================================
# TAB 1: SINGLE PREDICTION
# ========================================================================
with tabs[0]:
st.subheader("🔍 Single Text Analysis")
col1, col2 = st.columns([2, 1])
with col1:
text_input = st.text_area(
"Enter Nepali Text",
height=200,
placeholder="यहाँ आफ्नो पाठ लेख्नुहोस्...\nOr enter romanized Nepali: ma khusi xu\nOr English: This is a test",
help="Enter text in Devanagari, Romanized Nepali, or English."
)
col_a, col_b = st.columns(2)
with col_a:
analyze_button = st.button("🔍 Analyze Text", type="primary", use_container_width=True)
with col_b:
save_to_history = st.checkbox("Save to history", value=True)
with col2:
st.markdown("##### 💡 Quick Info")
st.info("""
**Supported:**
- Devanagari: नेपाली
- Romanized: ma nepali xu
- English: I am Nepali
- Mixed scripts
- Emojis: 😀😡🙏
**Auto-processing:**
- Script detection
- Transliteration
- Translation
- Emoji → Nepali words
- URL/mention removal
""")
if analyze_button and text_input.strip():
with st.spinner("🔄 Analyzing text..."):
result = predict_text(
text_input, model, tokenizer,
label_encoder, preprocessor
)
st.session_state.last_prediction = result
st.session_state.last_text = text_input
if 'prediction' in result:
st.session_state.session_predictions += 1
pred_label = result['prediction']
if pred_label in st.session_state.session_class_counts:
st.session_state.session_class_counts[pred_label] += 1
if save_to_history:
save_prediction_to_history(text_input, result)
if 'error' in result:
st.warning(f"⚠️ {result['error']}")
st.stop()
st.markdown("---")
st.subheader("📊 Analysis Results")
pred_label = result['prediction']
confidence = result['confidence']
box_class = {
'NO': 'no-box',
'OO': 'oo-box',
'OR': 'or-box',
'OS': 'os-box'
}.get(pred_label, 'no-box')
st.markdown(f"""
Prediction: {pred_label}
Confidence: {confidence:.2%}
{get_label_description(pred_label)}
""", unsafe_allow_html=True)
st.plotly_chart(plot_probabilities(result['probabilities']), use_container_width=True)
with st.expander("🔍 Preprocessing Details", expanded=False):
col1, col2, col3 = st.columns(3)
with col1:
st.markdown("**Original Text:**")
st.code(text_input, language=None)
with col2:
st.markdown("**Preprocessed:**")
st.code(result['preprocessed_text'], language=None)
with col3:
if CUSTOM_MODULES_AVAILABLE and preprocessor:
script_info = get_script_info(text_input)
st.markdown("**Script Detected:**")
st.write(f"• Type: {script_info['script_type']}")
confidence_pct = min(script_info['confidence'] * 100, 100.0)
st.write(f"• Confidence: {confidence_pct:.1f}%")
if result.get('emoji_features', {}).get('total_emoji_count', 0) > 0:
with st.expander("😊 Emoji Analysis", expanded=False):
features = result['emoji_features']
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Emojis", features['total_emoji_count'])
st.metric("Hate Emojis", features['hate_emoji_count'])
with col2:
st.metric("Positive Emojis", features['positive_emoji_count'])
st.metric("Mockery Emojis", features['mockery_emoji_count'])
with col3:
st.metric("Sadness Emojis", features['sadness_emoji_count'])
st.metric("Fear Emojis", features['fear_emoji_count'])
if CUSTOM_MODULES_AVAILABLE:
emoji_info = get_emoji_info(text_input)
if emoji_info['emojis_found']:
st.markdown("**Emojis Found:**")
st.write(" ".join(emoji_info['emojis_found']))
with st.expander("📊 Detailed Probabilities", expanded=False):
prob_df = pd.DataFrame({
'Class': list(result['probabilities'].keys()),
'Probability': list(result['probabilities'].values())
})
prob_df['Probability'] = prob_df['Probability'].apply(lambda x: f"{x:.4f}")
st.dataframe(prob_df, hide_index=True, use_container_width=True)
# ========================================================================
# TAB 2: EXPLAINABILITY
# ========================================================================
with tabs[1]:
st.subheader("💡 Model Explainability")
if not CUSTOM_MODULES_AVAILABLE:
st.error("❌ Explainability modules not available. Please check scripts directory.")
st.stop()
st.info(f"""
**Available Methods:**
- LIME: {'✅' if explainability_available['lime'] else '❌ (install: pip install lime)'}
- SHAP: {'✅' if explainability_available['shap'] else '❌ (install: pip install shap)'}
- Captum: {'✅' if captum_available else '❌ (install: pip install captum)'}
""")
explain_text = st.text_area(
"Enter text to explain",
height=150,
value=st.session_state.last_text if st.session_state.last_text else "",
placeholder="Enter Nepali text..."
)
available_methods = []
if explainability_available['lime']:
available_methods.append("LIME")
if explainability_available['shap']:
available_methods.append("SHAP")
if captum_available:
available_methods.append("Captum (IG)")
if not available_methods:
st.warning("⚠️ No explainability methods available. Please install required packages.")
st.code("pip install lime shap captum", language="bash")
st.stop()
method = st.selectbox("Select explanation method", available_methods)
with st.expander("⚙️ Configuration", expanded=False):
if method == "LIME":
num_samples = st.slider("Number of samples", 100, 500, 200, 50)
elif method == "SHAP":
use_fallback = st.checkbox("Use fallback if SHAP fails", value=True)
elif method == "Captum (IG)":
n_steps = st.slider("Integration steps", 10, 100, 50, 10)
explain_button = st.button("🔍 Generate Explanation", type="primary", use_container_width=True)
if explain_button and explain_text.strip():
with st.spinner("Generating explanation..."):
if st.session_state.model_wrapper is None:
st.session_state.model_wrapper = create_explainer_wrapper(
model, tokenizer, label_encoder, preprocessor
)
wrapper = st.session_state.model_wrapper
preprocessed, emoji_features = preprocessor.preprocess(explain_text)
analysis = wrapper.predict_with_analysis(explain_text)
st.success(f"**Prediction:** {analysis['predicted_label']} ({analysis['confidence']:.2%})")
col1, col2 = st.columns(2)
with col1:
st.write("**Original:**", explain_text)
with col2:
st.write("**Preprocessed:**", preprocessed)
st.markdown("---")
try:
if method == "LIME":
lime_exp = LIMEExplainer(wrapper, nepali_font=nepali_font)
result = lime_exp.explain_and_visualize(
analysis['original_text'],
analysis['preprocessed_text'],
save_path=None,
show=False,
num_samples=num_samples
)
st.subheader("LIME Explanation")
st.pyplot(result['figure'])
with st.expander("📊 Feature Importance Details"):
word_scores = result['explanation']['word_scores']
df = pd.DataFrame(word_scores, columns=['Word', 'Score'])
df = df.sort_values('Score', ascending=False)
st.dataframe(df, hide_index=True, use_container_width=True)
elif method == "SHAP":
shap_exp = SHAPExplainer(wrapper, nepali_font=nepali_font)
result = shap_exp.explain_and_visualize(
analysis['original_text'],
analysis['preprocessed_text'],
save_path=None,
show=False,
use_fallback=use_fallback
)
st.subheader("SHAP Explanation")
st.pyplot(result['figure'])
with st.expander("📊 Attribution Details"):
st.write(f"**Method used:** {result['explanation']['method_used']}")
word_scores = result['explanation']['word_scores']
df = pd.DataFrame(word_scores, columns=['Word', 'Score'])
df = df.sort_values('Score', key=lambda x: abs(x), ascending=False)
st.dataframe(df, hide_index=True, use_container_width=True)
elif method == "Captum (IG)":
try:
captum_exp = CaptumExplainer(
model, tokenizer, label_encoder, preprocessor,
emoji_to_nepali_map=EMOJI_TO_NEPALI
)
result = captum_exp.explain_and_visualize(
analysis['original_text'],
target=None,
n_steps=n_steps,
save_dir=None,
show=False,
nepali_font=nepali_font
)
st.subheader("Captum Integrated Gradients")
col1, col2 = st.columns(2)
with col1:
st.markdown("**Bar Chart**")
st.pyplot(result['bar_chart'])
with col2:
st.markdown("**Heatmap**")
st.pyplot(result['heatmap'])
with st.expander("📊 Attribution Details"):
st.write(f"**Convergence Delta:** {result['explanation']['convergence_delta']:.6f}")
word_attrs = result['explanation']['word_attributions']
df = pd.DataFrame(word_attrs, columns=['Word', 'Abs Score', 'Signed Score'])
df = df.sort_values('Abs Score', ascending=False)
st.dataframe(df, hide_index=True, use_container_width=True)
except (MemoryError, RuntimeError) as mem_err:
st.error("❌ Captum (Integrated Gradients) requires more memory than available on this server.")
st.info("💡 **Tip:** Use LIME or SHAP instead — they work great on cloud deployments. Captum works on local machines with more RAM/GPU.")
except Exception as e:
st.error(f"❌ Explanation failed: {str(e)}")
with st.expander("🐛 Error Details"):
st.exception(e)
# ========================================================================
# TAB 3: BATCH ANALYSIS
# ========================================================================
with tabs[2]:
st.subheader("📝 Batch Analysis")
st.markdown("### 📥 Download Example Files")
col1, col2 = st.columns(2)
with col1:
example_csv_data = {
'text': [
'यो राम्रो छ',
'तिमी मुर्ख हौ',
'मुस्लिम हरु सबै खराब छन्',
'केटीहरु घरमा बस्नु पर्छ',
'नमस्ते, कस्तो छ?'
]
}
example_csv = pd.DataFrame(example_csv_data).to_csv(index=False)
st.download_button(
label="📄 Download Example CSV",
data=example_csv,
file_name="example_batch.csv",
mime="text/csv",
use_container_width=True
)
with col2:
example_text = "यो राम्रो छ\nतिमी मुर्ख हौ\nमुस्लिम हरु सबै खराब छन्\nकेटीहरु घरमा बस्नु पर्छ\nनमस्ते, कस्तो छ?"
st.download_button(
label="📝 Download Example Text",
data=example_text,
file_name="example_batch.txt",
mime="text/plain",
use_container_width=True
)
st.markdown("---")
input_method = st.radio("Input method:", ["Text Area", "CSV Upload"])
# ---- TEXT AREA ----
if input_method == "Text Area":
st.info("💡 Enter one text per line")
batch_text = st.text_area(
"Enter texts (one per line)",
height=250,
placeholder="यो राम्रो छ\nतिमी मुर्ख हौ\n..."
)
if st.button("🚀 Analyze Batch", type="primary"):
if batch_text.strip():
texts = [line.strip() for line in batch_text.split('\n') if line.strip()]
with st.spinner(f"Analyzing {len(texts)} texts..."):
results = []
progress_bar = st.progress(0)
for idx, text in enumerate(texts):
try:
result = predict_text(
text, model, tokenizer,
label_encoder, preprocessor
)
results.append({
'Text': text[:60] + '...' if len(text) > 60 else text,
'Full_Text': text,
'Prediction': result['prediction'],
'Confidence': result['confidence'],
'Preprocessed': result['preprocessed_text']
})
except Exception as e:
results.append({
'Text': text[:60],
'Full_Text': text,
'Prediction': 'Error',
'Confidence': 0.0,
'Preprocessed': str(e)
})
progress_bar.progress((idx + 1) / len(texts))
st.session_state.batch_results = pd.DataFrame(results)
st.session_state.batch_mode = 'text_area'
for result in results:
if result['Prediction'] != 'Error':
st.session_state.session_predictions += 1
pred_label = result['Prediction']
if pred_label in st.session_state.session_class_counts:
st.session_state.session_class_counts[pred_label] += 1
st.rerun()
else:
st.warning("Please enter some texts.")
# Display results outside button block
if (st.session_state.batch_results is not None and
st.session_state.get('batch_mode') == 'text_area'):
results_df = st.session_state.batch_results
st.success(f"✅ Analyzed {len(results_df)} texts!")
display_df = results_df[['Text', 'Prediction', 'Confidence']].copy()
display_df['Confidence'] = display_df['Confidence'].apply(lambda x: f"{x:.2%}")
st.dataframe(display_df, use_container_width=True, hide_index=True, height=400)
st.markdown("---")
st.subheader("📊 Summary Statistics")
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Texts", len(results_df))
st.metric("Avg Confidence", f"{results_df['Confidence'].mean():.2%}")
with col2:
summary = results_df['Prediction'].value_counts()
fig = px.pie(
values=summary.values,
names=summary.index,
title="Prediction Distribution",
color_discrete_sequence=px.colors.qualitative.Set2
)
st.plotly_chart(fig, use_container_width=True)
with col3:
st.markdown("**Class Breakdown:**")
for label, count in summary.items():
pct = count / len(results_df) * 100
st.write(f"• {label}: {count} ({pct:.1f}%)")
st.markdown("---")
download_df = results_df[['Full_Text', 'Prediction', 'Confidence', 'Preprocessed']].copy()
download_df.columns = ['Text', 'Prediction', 'Confidence', 'Preprocessed']
csv = download_df.to_csv(index=False)
col_download, col_explain = st.columns(2)
with col_download:
st.download_button(
label="📥 Download Results CSV",
data=csv,
file_name=f"batch_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
mime="text/csv",
use_container_width=True,
key="download_batch_text"
)
with col_explain:
if st.button("💡 Explain Selected", use_container_width=True, key="hint_batch_text"):
st.info("👇 Select a text below to explain")
render_batch_explainability(
results_df=results_df,
text_column='Full_Text',
model=model,
tokenizer=tokenizer,
label_encoder=label_encoder,
preprocessor=preprocessor,
nepali_font=nepali_font,
explainability_available=explainability_available,
captum_available=captum_available,
mode_key="text_area"
)
# ---- CSV UPLOAD ----
else:
st.info("💡 Upload CSV with a 'text' column")
uploaded_file = st.file_uploader(
"Choose CSV file",
type=['csv'],
help="Max 200MB. Upload a CSV with a text column containing Nepali text."
)
if uploaded_file:
try:
# Try multiple encodings for Nepali text compatibility
try:
df = pd.read_csv(uploaded_file, encoding='utf-8')
except UnicodeDecodeError:
uploaded_file.seek(0)
df = pd.read_csv(uploaded_file, encoding='latin-1')
st.write("📄 **File Preview:**")
st.dataframe(df.head(10), use_container_width=True)
text_column = st.selectbox("Select text column:", df.columns)
if st.button("🚀 Analyze CSV", type="primary"):
texts = df[text_column].astype(str).tolist()
with st.spinner(f"Analyzing {len(texts)} texts..."):
predictions = []
confidences = []
preprocessed_list = []
progress_bar = st.progress(0)
for idx, text in enumerate(texts):
try:
result = predict_text(
str(text), model, tokenizer,
label_encoder, preprocessor
)
predictions.append(result['prediction'])
confidences.append(result['confidence'])
preprocessed_list.append(result['preprocessed_text'])
except Exception as e:
predictions.append('Error')
confidences.append(0.0)
preprocessed_list.append(str(e))
progress_bar.progress((idx + 1) / len(texts))
df['Prediction'] = predictions
df['Confidence'] = confidences
df['Preprocessed'] = preprocessed_list
st.session_state.batch_results = df
st.session_state.batch_mode = 'csv'
st.session_state.csv_text_column = text_column
for pred in predictions:
if pred != 'Error':
st.session_state.session_predictions += 1
if pred in st.session_state.session_class_counts:
st.session_state.session_class_counts[pred] += 1
st.rerun()
# Display results outside button block
if (st.session_state.batch_results is not None and
st.session_state.get('batch_mode') == 'csv'):
df_results = st.session_state.batch_results
text_col = st.session_state.get('csv_text_column', text_column)
st.success("✅ Analysis complete!")
st.dataframe(df_results, use_container_width=True, height=400)
st.markdown("---")
st.subheader("📊 Summary")
col1, col2 = st.columns(2)
with col1:
summary = df_results['Prediction'].value_counts()
fig = px.bar(
x=summary.index,
y=summary.values,
title="Prediction Distribution",
labels={'x': 'Class', 'y': 'Count'},
color=summary.index,
color_discrete_map={
'NO': '#28a745',
'OO': '#ffc107',
'OR': '#dc3545',
'OS': '#6f42c1'
}
)
st.plotly_chart(fig, use_container_width=True)
with col2:
st.metric("Total Texts", len(df_results))
st.metric("Avg Confidence", f"{df_results['Confidence'].mean():.2%}")
st.markdown("**Class Distribution:**")
for label, count in summary.items():
st.write(f"• {label}: {count}")
st.markdown("---")
csv_data = df_results.to_csv(index=False)
col_download, col_explain = st.columns(2)
with col_download:
st.download_button(
label="📥 Download Results CSV",
data=csv_data,
file_name=f"predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
mime="text/csv",
use_container_width=True,
key="download_csv_results"
)
with col_explain:
if st.button("💡 Explain Selected", use_container_width=True, key="csv_explain_hint"):
st.info("👇 Use expander below to explain")
render_batch_explainability(
results_df=df_results,
text_column=text_col,
model=model,
tokenizer=tokenizer,
label_encoder=label_encoder,
preprocessor=preprocessor,
nepali_font=nepali_font,
explainability_available=explainability_available,
captum_available=captum_available,
mode_key="csv"
)
except Exception as e:
st.error(f"❌ Error processing file: {str(e)}")
with st.expander("🐛 Error Details"):
st.exception(e)
# ========================================================================
# TAB 4: HISTORY
# ========================================================================
with tabs[3]:
st.subheader("📈 Prediction History")
col1, col2 = st.columns([3, 1])
with col1:
st.write("View and analyze your prediction history")
with col2:
if st.button("🔄 Refresh", use_container_width=True):
st.rerun()
if os.path.exists(HISTORY_FILE):
try:
with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
history = json.load(f)
if history:
history_df = pd.DataFrame(history)
history_df['timestamp'] = pd.to_datetime(history_df['timestamp'])
st.markdown("### 📊 Overview")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Predictions", len(history_df))
with col2:
st.metric("Avg Confidence", f"{history_df['confidence'].mean():.2%}")
with col3:
if 'emoji_features' in history_df.columns:
total_emojis = sum(
e.get('total_emoji_count', 0)
for e in history_df['emoji_features']
if isinstance(e, dict)
)
st.metric("Total Emojis", total_emojis)
else:
st.metric("Total Emojis", "N/A")
with col4:
most_common = history_df['prediction'].mode()[0]
st.metric("Most Common", most_common)
st.markdown("---")
st.markdown("### 📈 Trends")
col1, col2 = st.columns(2)
with col1:
daily_counts = history_df.groupby(
history_df['timestamp'].dt.date
).size().reset_index(name='count')
fig = px.line(
daily_counts,
x='timestamp',
y='count',
title="Predictions Over Time",
labels={'timestamp': 'Date', 'count': 'Count'}
)
st.plotly_chart(fig, use_container_width=True)
with col2:
class_dist = history_df['prediction'].value_counts()
fig = px.pie(
values=class_dist.values,
names=class_dist.index,
title="Class Distribution",
color=class_dist.index,
color_discrete_map={
'NO': '#28a745',
'OO': '#ffc107',
'OR': '#dc3545',
'OS': '#6f42c1'
}
)
st.plotly_chart(fig, use_container_width=True)
st.markdown("---")
st.markdown("### 📋 Recent Predictions")
num_to_show = st.slider("Number to show", 5, 50, 20, 5)
recent = history_df.tail(num_to_show).sort_values('timestamp', ascending=False)
display = recent[['timestamp', 'text', 'prediction', 'confidence']].copy()
display['confidence'] = display['confidence'].apply(lambda x: f"{x:.2%}")
display['text'] = display['text'].apply(lambda x: x[:80] + '...' if len(x) > 80 else x)
display['timestamp'] = display['timestamp'].dt.strftime('%Y-%m-%d %H:%M:%S')
st.dataframe(display, use_container_width=True, hide_index=True, height=400)
st.markdown("---")
col1, col2 = st.columns(2)
with col1:
csv = history_df.to_csv(index=False)
st.download_button(
label="📥 Download Full History",
data=csv,
file_name=f"history_{datetime.now().strftime('%Y%m%d')}.csv",
mime="text/csv",
use_container_width=True
)
with col2:
if st.button("🗑️ Clear History", type="secondary", use_container_width=True):
if os.path.exists(HISTORY_FILE):
os.remove(HISTORY_FILE)
st.success("✅ History cleared!")
st.rerun()
else:
st.info("📝 No predictions in history yet.")
except Exception as e:
st.error(f"❌ Error loading history: {str(e)}")
with st.expander("🐛 Error Details"):
st.exception(e)
else:
st.info("📝 No history file found yet.")
st.markdown("""
### How to Build History:
1. Go to **Single Prediction** tab
2. Enable "Save to history" checkbox
3. Analyze some text
4. Your predictions will appear here!
""")
# Script entry point: call main() only when this file is executed directly
# (module name is "__main__"), so importing the module does not start the app.
if __name__ == "__main__":
main()