import gc
import re
import streamlit as st
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral, predict_hatespeech_from_file_batched
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
import time
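# Streamlit app that compares a base and an enhanced (rationale-aware) hate speech
# classifier side by side, supporting both single-text and CSV batch analysis.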
is_file_uploader_visible = False
# Page configuration
st.set_page_config(
page_title="🛡️ Hate Speech Detector",
page_icon="🛡️",
layout="wide"
)
# Cached model loading function
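# st.cache_resource keeps the loaded model and tokenizers in memory across Streamlit
# reruns, so the weights are downloaded and initialised only once per process.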
@st.cache_resource
def load_cached_model(model_type="altered"):
model, tokenizer_hatebert, tokenizer_rationale, config, device = load_model_from_hf(model_type=model_type)
return {
"model": model,
"tokenizer_hatebert": tokenizer_hatebert,
"tokenizer_rationale": tokenizer_rationale,
"config": config,
"device": device
}
def clean_user_input(text):
"""Remove URLs and special characters (except exclamation points) from text"""
# Remove URLs
text = re.sub(r'https?://\S+|www\.\S+', '', text)
# Remove special characters except exclamation points
text = re.sub(r'[^a-zA-Z0-9\s!]', '', text)
# Remove extra whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
# Custom CSS
st.markdown("""
<style>
.main-header {
font-size: 3rem;
font-weight: bold;
text-align: center;
color: #1f77b4;
margin-bottom: 1rem;
}
.sub-header {
font-size: 1.2rem;
text-align: center;
color: #555;
margin-bottom: 2rem;
}
.prediction-box {
padding: 2rem;
border-radius: 10px;
text-align: center;
font-size: 1.5rem;
font-weight: bold;
margin: 1rem 0;
}
.hate-speech {
background-color: #ffebee;
color: #c62828;
border: 2px solid #ef5350;
}
.not-hate-speech {
background-color: #e8f5e9;
color: #2e7d32;
border: 2px solid #66bb6a;
}
</style>
""", unsafe_allow_html=True)
# Header
st.markdown('<div class="main-header">π‘οΈ Hate Speech Detection System</div>', unsafe_allow_html=True)
st.markdown('<div class="sub-header">Comparing Base vs Enhanced models with explainable AI for detecting hate speech</div>', unsafe_allow_html=True)
# Load both models with spinner
with st.spinner('🔄 Loading models... This may take a moment on first run.'):
try:
base_data = load_cached_model("base")
enhanced_data = load_cached_model("altered")
base_model = base_data["model"]
base_tokenizer_hatebert = base_data["tokenizer_hatebert"]
base_tokenizer_rationale = base_data["tokenizer_rationale"]
base_config = base_data["config"]
base_device = base_data["device"]
enhanced_model = enhanced_data["model"]
enhanced_tokenizer_hatebert = enhanced_data["tokenizer_hatebert"]
enhanced_tokenizer_rationale = enhanced_data["tokenizer_rationale"]
enhanced_config = enhanced_data["config"]
enhanced_device = enhanced_data["device"]
st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
except Exception as e:
st.error(f"β Error loading models: {str(e)}")
st.stop()
# Sidebar
with st.sidebar:
st.header("βοΈ Settings")
st.divider()
st.markdown(f"**Device:** CPU")
st.markdown(f"**Max Length:** 128")
st.markdown(f"**CNN Filters:** 128")
st.divider()
st.subheader("π File Upload")
is_file_uploader_visible = st.checkbox("Enable File Upload", value=is_file_uploader_visible)
st.divider()
show_rationale_viz = st.checkbox("Show Token Importance", value=True)
show_probabilities = st.checkbox("Show Probability Distribution", value=True)
show_details = st.checkbox("Show Technical Details", value=False)
st.divider()
st.subheader("π‘ About")
st.markdown("""
This model uses:
- **HateBERT** for hate speech understanding
- **Multi-Scale CNN** for feature extraction
- **Attention mechanisms** for interpretability
""")
# Main interface
col1, col2 = st.columns([2, 1])
with col1:
if is_file_uploader_visible:
user_input = None
st.subheader("π Upload File")
uploaded_file = st.file_uploader(
"Choose a text file (.csv) to analyze:",
type=["csv"],
help="Upload a text file containing the content you want to analyze for hate speech"
)
if uploaded_file is not None:
try:
file_content = pd.read_csv(uploaded_file, usecols=['text', 'CF_Rationales', 'label'])
st.success("β
File loaded successfully! Scroll down to analyze.")
except Exception as e:
st.error(f"β Error reading file: {str(e)}")
user_input = ""
else:
st.subheader("π Input Text/File")
user_input = st.text_area(
"Enter text to analyze:",
placeholder="Type or paste text here to check for hate speech...",
height=150,
help="Enter any text and the model will classify it as hate speech or not"
)
optional_rationale = st.text_area(
"Optional: Provide context or rationale (leave empty to use main text):",
placeholder="Why might this be hate speech? (optional)",
height=80
)
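# Note: optional_rationale is currently not passed to the models; the rationale used
# for prediction is generated by Mistral in the analysis step below.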
with col2:
st.subheader("π Quick Stats")
if user_input:
word_count = len(user_input.split())
char_count = len(user_input)
st.metric("Words", word_count)
st.metric("Characters", char_count)
if is_file_uploader_visible and uploaded_file is not None:
st.markdown(f"**Filename:** {uploaded_file.name}")
st.markdown(f"**Size:** {uploaded_file.size / 1024:.2f} KB")
file_rows = len(file_content)
st.metric("Rows in File", file_rows)
else:
st.info("Enter text/file to see statistics")
# Classification button
classify_button = st.button("🔍 Analyze Text", type="primary", use_container_width=True)
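# Single-text flow: clean the input, request a rationale from Mistral, run both models
# on the (text, rationale) pair, then render the results side by side.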
if classify_button:
if user_input and user_input.strip():
# Clean the input text
cleaned_input = clean_user_input(user_input)
with st.spinner('🔄 Generating rationale from Mistral AI...'):
# --- Step 1: Get rationale from Mistral ---
try:
raw_rationale = get_rationale_from_mistral(cleaned_input)
cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
print(f"Raw rationale from Mistral: {raw_rationale}")
except Exception as e:
st.error(f"β Error generating/processing rationale: {str(e)}")
cleaned_rationale = cleaned_input # fallback to cleaned input
with st.spinner('🔄 Analyzing text with models...'):
# Run enhanced model
enhanced_start = time.time()
enhanced_model_result = predict_hatespeech(
text=cleaned_input,
rationale=cleaned_rationale, # use cleaned rationale
model=enhanced_model,
tokenizer_hatebert=enhanced_tokenizer_hatebert,
tokenizer_rationale=enhanced_tokenizer_rationale,
config=enhanced_config,
device=enhanced_device,
model_type="altered"
)
enhanced_end = time.time()
# Run base model
base_start = time.time()
base_model_result = predict_hatespeech(
text=cleaned_input,
rationale=cleaned_rationale, # use cleaned rationale
model=base_model,
tokenizer_hatebert=base_tokenizer_hatebert,
tokenizer_rationale=base_tokenizer_rationale,
config=base_config,
device=base_device,
model_type="base"
)
base_end = time.time()
gc.collect() # Clean up memory after inference
# Extract results for both models
base_prediction = base_model_result['prediction']
base_confidence = base_model_result['confidence']
base_probabilities = base_model_result['probabilities']
base_processing_time = base_end - base_start
enhanced_prediction = enhanced_model_result['prediction']
enhanced_confidence = enhanced_model_result['confidence']
enhanced_probabilities = enhanced_model_result['probabilities']
enhanced_rationale_scores = enhanced_model_result['rationale_scores']
enhanced_tokens = enhanced_model_result['tokens']
enhanced_processing_time = enhanced_end - enhanced_start
# Display results
st.divider()
st.header("π Analysis Results")
# Side-by-side results columns
base_col, enhanced_col = st.columns(2)
# === BASE MODEL RESULTS (LEFT) ===
with base_col:
st.subheader("π΅ Base Shield Results")
# Prediction box
if base_prediction == 1:
st.markdown(f'<div class="prediction-box hate-speech">π¨ HATE SPEECH DETECTED</div>',
unsafe_allow_html=True)
else:
st.markdown(f'<div class="prediction-box not-hate-speech">β
NOT HATE SPEECH</div>',
unsafe_allow_html=True)
# Metrics
st.metric("Confidence", f"{base_confidence:.1%}")
base_m1, base_m2 = st.columns(2)
with base_m1:
st.metric("Not Hate Speech", f"{base_probabilities[0]:.1%}")
with base_m2:
st.metric("Hate Speech", f"{base_probabilities[1]:.1%}")
st.metric("Processing Time", f"{base_processing_time:.3f}s")
# Probability distribution chart
if show_probabilities:
st.markdown("**π Probability Distribution**")
fig_base = go.Figure(data=[
go.Bar(
x=['Not Hate Speech', 'Hate Speech'],
y=base_probabilities,
marker_color=['#66bb6a', '#ef5350'],
text=[f"{p:.1%}" for p in base_probabilities],
textposition='auto',
)
])
fig_base.update_layout(
yaxis_title="Probability",
yaxis_range=[0, 1],
height=300,
showlegend=False
)
st.plotly_chart(fig_base, use_container_width=True)
# Technical details for base
if show_details:
with st.expander("View Base Model Outputs"):
st.json({
'prediction': int(base_prediction),
'confidence': float(base_confidence),
'probability_not_hate': float(base_probabilities[0]),
'probability_hate': float(base_probabilities[1]),
'device': 'cpu',
'model_config': {
'max_length': '128',
}
})
# === ENHANCED MODEL RESULTS (RIGHT) ===
with enhanced_col:
st.subheader("π’ Enhanced Shield Results")
# Prediction box
if enhanced_prediction == 1:
st.markdown(f'<div class="prediction-box hate-speech">π¨ HATE SPEECH DETECTED</div>',
unsafe_allow_html=True)
else:
st.markdown(f'<div class="prediction-box not-hate-speech">β
NOT HATE SPEECH</div>',
unsafe_allow_html=True)
# Metrics
st.metric("Confidence", f"{enhanced_confidence:.1%}")
enh_m1, enh_m2 = st.columns(2)
with enh_m1:
st.metric("Not Hate Speech", f"{enhanced_probabilities[0]:.1%}")
with enh_m2:
st.metric("Hate Speech", f"{enhanced_probabilities[1]:.1%}")
st.metric("Processing Time", f"{enhanced_processing_time:.3f}s")
# Probability distribution chart
if show_probabilities:
st.markdown("**π Probability Distribution**")
fig_enhanced = go.Figure(data=[
go.Bar(
x=['Not Hate Speech', 'Hate Speech'],
y=enhanced_probabilities,
marker_color=['#66bb6a', '#ef5350'],
text=[f"{p:.1%}" for p in enhanced_probabilities],
textposition='auto',
)
])
fig_enhanced.update_layout(
yaxis_title="Probability",
yaxis_range=[0, 1],
height=300,
showlegend=False
)
st.plotly_chart(fig_enhanced, use_container_width=True)
# Token importance visualization (only for enhanced)
if show_rationale_viz:
st.markdown("**π Token Importance Analysis**")
st.caption("Highlighted words show which parts influenced the prediction")
# Filter out special tokens and create visualization
token_importance = []
html_output = "<div style='font-size: 16px; line-height: 2.2; padding: 15px; background-color: #f9f7f7; border-radius: 10px; color: #333;'>"
for token, score in zip(enhanced_tokens, enhanced_rationale_scores):
if token not in ['[CLS]', '[SEP]', '[PAD]']:
# Clean token
display_token = token.replace('##', '')
token_importance.append({'Token': display_token, 'Importance': score})
# Color intensity based on score and prediction
alpha = min(score * 1.5, 1.0) # Scale up visibility
color = f"rgba(239, 83, 80, {alpha:.2f})" # Red for hate speech influence
# Use green for non-hate speech, red for hate speech
# if enhanced_prediction == 1: # Hate speech
# color = f"rgba(239, 83, 80, {alpha:.2f})"
# else: # Not hate speech
# color = f"rgba(102, 187, 106, {alpha:.2f})"
html_output += f"<span style='background-color: {color}; padding: 3px 6px; margin: 1px; border-radius: 4px; display: inline-block;'>{display_token}</span> "
html_output += "</div>"
st.markdown(html_output, unsafe_allow_html=True)
# Dynamic caption based on prediction
if enhanced_prediction == 1:
st.caption("π΄ Darker red = More influence on hate speech detection.")
# Top important tokens
st.markdown("**π Top Important Tokens**")
df_importance = pd.DataFrame(token_importance)
df_importance = df_importance.sort_values('Importance', ascending=False).head(10)
df_importance['Importance'] = df_importance['Importance'].apply(lambda x: f"{x:.4f}")
st.dataframe(
df_importance,
use_container_width=True,
hide_index=True
)
# Technical details for enhanced
if show_details:
with st.expander("View Enhanced Model Outputs"):
st.json({
'prediction': int(enhanced_prediction),
'confidence': float(enhanced_confidence),
'probability_not_hate': float(enhanced_probabilities[0]),
'probability_hate': float(enhanced_probabilities[1]),
'num_tokens': len([t for t in enhanced_tokens if t not in ['[CLS]', '[SEP]', '[PAD]']]),
'device': 'cpu',
'model_config': {
'max_length': '128',
'cnn_filters': '128',
}
})
if is_file_uploader_visible and uploaded_file is not None:
st.markdown("**Preview:**")
st.dataframe(file_content.head(3), use_container_width=True)
with st.spinner('🔄 Analyzing file with both models... This may take a while for large files.'):
# Run both models on the file
# base_result = predict_hatespeech_from_file(...) # Base model
# enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
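# The enhanced model runs through the batched file-prediction helper below, while the
# base model uses the per-row helper.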
enhanced_result = predict_hatespeech_from_file_batched(
text_list=file_content['text'].tolist(),
rationale_list=file_content['CF_Rationales'].tolist(),
true_label=file_content['label'].tolist(),
model=enhanced_model,
tokenizer_hatebert=enhanced_tokenizer_hatebert,
tokenizer_rationale=enhanced_tokenizer_rationale,
config=enhanced_config,
device=enhanced_device,
model_type="altered"
)
base_result = predict_hatespeech_from_file(
text_list=file_content['text'].tolist(),
rationale_list=file_content['CF_Rationales'].tolist(),
true_label=file_content['label'].tolist(),
model=base_model,
tokenizer_hatebert=base_tokenizer_hatebert,
tokenizer_rationale=base_tokenizer_rationale,
config=base_config,
device=base_device,
model_type="base"
)
st.success("β
File analysis complete for both models!")
st.divider()
st.header("π Analysis Results - Model Comparison")
gc.collect() # Clean up memory after file inference
# Side-by-side results columns
base_file_col, enhanced_file_col = st.columns(2)
# === BASE MODEL FILE RESULTS (LEFT) ===
with base_file_col:
st.subheader("π΅ Base Shield Results")
# Performance Metrics
st.markdown("**π Classification Metrics**")
base_fm1, base_fm2 = st.columns(2)
with base_fm1:
st.metric("F1 Score", f"{base_result['f1_score']:.4f}")
st.metric("Precision", f"{base_result['precision']:.4f}")
with base_fm2:
st.metric("Accuracy", f"{base_result['accuracy']:.4f}")
st.metric("Recall", f"{base_result['recall']:.4f}")
# Confusion Matrix Visualization
st.markdown("**π― Confusion Matrix**")
base_cm = base_result['confusion_matrix']
fig_base_cm = go.Figure(data=go.Heatmap(
z=base_cm,
x=['Pred Not Hate', 'Pred Hate'],
y=['True Not Hate', 'True Hate'],
colorscale='Blues',
text=base_cm,
texttemplate='%{text}',
textfont={"size": 14},
showscale=False
))
fig_base_cm.update_layout(height=300)
st.plotly_chart(fig_base_cm, use_container_width=True)
# Resource Usage
st.markdown("**βοΈ Resource Usage**")
base_cpu_col, base_mem_col = st.columns(2)
with base_cpu_col:
st.metric("Avg CPU", f"{base_result['cpu_usage']:.2f}%")
st.metric("Peak CPU", f"{base_result['peak_cpu_usage']:.2f}%")
with base_mem_col:
st.metric("Avg Memory", f"{base_result['memory_usage']:.2f} MB")
st.metric("Peak Memory", f"{base_result['peak_memory_usage']:.2f} MB")
# Runtime
st.markdown("**β±οΈ Performance**")
st.metric("Total Runtime", f"{base_result['runtime']:.2f}s")
st.metric("Avg Time/Sample", f"{base_result['runtime']/file_rows:.3f}s")
# === ENHANCED MODEL FILE RESULTS (RIGHT) ===
with enhanced_file_col:
st.subheader("π’ Enhanced Shield Results")
st.markdown("**π Classification Metrics**")
enh_fm1, enh_fm2 = st.columns(2)
with enh_fm1:
st.metric("F1 Score", f"{enhanced_result['f1_score']:.4f}")
st.metric("Precision", f"{enhanced_result['precision']:.4f}")
with enh_fm2:
st.metric("Accuracy", f"{enhanced_result['accuracy']:.4f}")
st.metric("Recall", f"{enhanced_result['recall']:.4f}")
# Confusion Matrix Visualization
st.markdown("**π― Confusion Matrix**")
enhanced_cm = enhanced_result['confusion_matrix']
fig_enhanced_cm = go.Figure(data=go.Heatmap(
z=enhanced_cm,
x=['Pred Not Hate', 'Pred Hate'],
y=['True Not Hate', 'True Hate'],
colorscale='Greens',
text=enhanced_cm,
texttemplate='%{text}',
textfont={"size": 14},
showscale=False
))
fig_enhanced_cm.update_layout(height=300)
st.plotly_chart(fig_enhanced_cm, use_container_width=True)
st.markdown("**βοΈ Resource Usage**")
enh_cpu_col, enh_mem_col = st.columns(2)
with enh_cpu_col:
st.metric("Avg CPU", f"{enhanced_result['cpu_usage']:.2f}%")
st.metric("Peak CPU", f"{enhanced_result['peak_cpu_usage']:.2f}%")
with enh_mem_col:
st.metric("Avg Memory", f"{enhanced_result['memory_usage']:.2f} MB")
st.metric("Peak Memory", f"{enhanced_result['peak_memory_usage']:.2f} MB")
# Runtime
st.markdown("**β±οΈ Performance**")
st.metric("Total Runtime", f"{enhanced_result['runtime']:.2f}s")
st.metric("Avg Time/Sample", f"{enhanced_result['runtime']/file_rows:.3f}s")
else:
st.warning("β οΈ Please enter some text to analyze.")
# Examples section
st.divider()
st.subheader("π‘ Try Example Texts")
col1, col2, col3 = st.columns(3)
with col1:
if st.button("Example: Hate Speech", use_container_width=True):
st.session_state.example_text = "You people are worthless and should leave this country!"
st.rerun()
with col2:
if st.button("Example: Not Hate Speech", use_container_width=True):
st.session_state.example_text = "I disagree with your opinion, but I respect your right to express it."
st.rerun()
with col3:
if st.button("Example: Borderline", use_container_width=True):
st.session_state.example_text = "This policy is terrible and will hurt everyone involved."
st.rerun()
if 'example_text' in st.session_state:
st.info(f"**Example loaded:** {st.session_state.example_text}")
st.caption("β Copy this text to the input box above and click 'Analyze Text'")
# Footer
st.divider()
st.markdown("""
<div style='text-align: center; color: gray; padding: 20px;'>
<p><b>Hate Speech Detection Model Comparison</b></p>
<p>Base Shield vs Enhanced Shield (HateBERT + Multi-Scale CNN + Attention)</p>
<p>Side-by-side comparison for performance evaluation</p>
</div>
""", unsafe_allow_html=True)