"""Multi-model sarcasm detector (Streamlit app).

Combines two HuggingFace sequence-classification models with a set of
rule-based social-media sarcasm cues, and renders an ensemble score with
per-cue text highlighting.

NOTE(review): the source this was recovered from was whitespace-mangled
(all code collapsed onto a few physical lines) and several emoji / glyph
string literals arrived as mojibake (e.g. '๐', 'โ').  Those runtime
strings are preserved byte-for-byte below because the original characters
are unrecoverable from this view -- confirm against the pristine source.
"""

import nltk
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import torch.nn.functional as F
import spacy
import re
from nltk.sentiment import SentimentIntensityAnalyzer
import emoji
import plotly.graph_objects as go
import plotly.express as px
from collections import Counter
import time
import numpy as np

# Configuration - Multiple Models (name -> HuggingFace model path)
MODELS = {
    "helinivan": "helinivan/English-sarcasm-detector",
    "distilbert": "dima806/sarcasm-detection-distilbert"
}

# Initialize NLTK VADER analyzer; fall back to None so the app still loads
# when the NLTK data directory is missing.
try:
    nltk.data.path.append('/app/nltk_data')
    sia = SentimentIntensityAnalyzer()
except Exception as e:
    st.error(f"Error downloading NLTK data: {e}")
    sia = None


# Cache multiple models & tokenizers
@st.cache_resource
def load_models():
    """Load every model/tokenizer pair in MODELS.

    Returns:
        (models, tokenizers): two dicts keyed by model name.  A failed load
        stores None for both entries so downstream code can skip that model.
    """
    models = {}
    tokenizers = {}
    for name, model_path in MODELS.items():
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path, cache_dir="/tmp/hf_cache")
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, cache_dir="/tmp/hf_cache")
            model.eval()
            models[name] = model
            tokenizers[name] = tokenizer
            st.success(f"โ Loaded {name} model successfully")
        except Exception as e:
            st.error(f"โ Failed to load {name} model: {str(e)}")
            models[name] = None
            tokenizers[name] = None
    return models, tokenizers


# Lazy-load SpaCy (optional - not used in current implementation)
def load_spacy():
    """Return the SpaCy 'en_core_web_sm' pipeline, or None if unavailable."""
    try:
        return spacy.load("en_core_web_sm")
    except OSError:
        # The original message string was broken by a raw newline (extraction
        # damage); rejoined into a single line here.
        st.warning("SpaCy model 'en_core_web_sm' not found. "
                   "Some features may be limited.")
        return None


# Pattern detection functions with highlighting info
def social_media_sarcasm_cues(text: str) -> tuple[float, list, list]:
    """Score sarcastic phrases and exaggerated positives in *text*.

    Returns:
        (boost, explanations, highlights): additive score boost, human-readable
        explanations, and highlight dicts with 'start'/'end'/'type'/'text'.
    """
    explanations = []
    highlights = []
    boost = 0.0
    text_lower = text.lower()

    # Enhanced sarcasm phrases (including Reddit-style patterns)
    sarcasm_phrases = [
        "oh sure", "yeah right", "of course", "totally", "absolutely",
        "perfect", "wonderful", "fantastic", "amazing", "brilliant",
        "great job", "well done", "nice one", "good going", "way to go",
        "real smooth", "genius move", "solid plan", "makes sense",
        "just perfect", "exactly what i needed", "this is fine",
        # Reddit-style additions
        "thanks genius", "no shit sherlock", "well duh", "captain obvious",
        "groundbreaking", "revolutionary", "what a concept", "mind blown",
        "shocking", "who would have guessed", "truly inspiring"
    ]

    for phrase in sarcasm_phrases:
        # Find all occurrences of the phrase; positions in text_lower are
        # valid for the original text since lower() preserves length here.
        for match in re.finditer(re.escape(phrase), text_lower):
            boost += 0.2
            explanations.append(f"Sarcastic phrase: '{phrase}'")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'sarcastic_phrase',
                'text': phrase
            })

    # Exaggerated expressions
    exaggerated_match = re.search(
        r'\b(SO|TOTALLY|ABSOLUTELY|REALLY|VERY)\b.*\b(great|good|perfect|amazing|helpful|useful)\b',
        text, re.IGNORECASE)
    if exaggerated_match:
        boost += 0.25
        explanations.append("Exaggerated positive expression")
        highlights.append({
            'start': exaggerated_match.start(),
            'end': exaggerated_match.end(),
            'type': 'exaggerated',
            'text': exaggerated_match.group()
        })

    return boost, explanations, highlights


def emoji_punctuation_analysis(text: str) -> tuple[float, list, list]:
    """Score sarcastic emojis, excessive punctuation, and ellipses."""
    explanations = []
    highlights = []
    boost = 0.0

    # Extract emojis with positions
    try:
        emojis = emoji.emoji_list(text)
        # NOTE(review): these literals arrived mojibake'd; preserved as-is.
        sarcastic_emojis = ['๐', '๐', '๐', '๐ค', '๐คจ', '๐ค', '๐คท',
                            '๐', '๐', '๐คก', '๐', '๐คฏ']
        for emoji_info in emojis:
            if emoji_info['emoji'] in sarcastic_emojis:
                boost += 0.15
                explanations.append(f"Sarcastic emoji: {emoji_info['emoji']}")
                highlights.append({
                    'start': emoji_info['match_start'],
                    'end': emoji_info['match_end'],
                    'type': 'sarcastic_emoji',
                    'text': emoji_info['emoji']
                })
    except Exception:
        # Deliberate best-effort: emoji detection is optional, so any
        # library failure degrades silently to the punctuation checks below.
        pass

    # Excessive punctuation
    for match in re.finditer(r'[!?]{2,}', text):
        boost += 0.1
        explanations.append(f"Excessive punctuation: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'excessive_punct',
            'text': match.group()
        })

    # Ellipsis (often sarcastic)
    for match in re.finditer(r'\.{3,}', text):
        boost += 0.15
        explanations.append(f"Trailing ellipsis: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'ellipsis',
            'text': match.group()
        })

    return boost, explanations, highlights


def rhetorical_questions_analysis(text: str) -> tuple[float, list, list]:
    """Score rhetorical / mock-surprise question patterns."""
    explanations = []
    highlights = []
    boost = 0.0

    rhetorical_patterns = [
        (r'what could possibly go wrong\?', "Rhetorical question"),
        (r'who would have thought\?', "Rhetorical question"),
        (r'seriously\?', "Emphatic question"),
        (r'really\?.*really\?', "Repeated question"),
        (r'no way\?', "Disbelief question"),
        (r'you don\'t say\?', "Sarcastic response"),
        (r'shocking.*\?', "Mock surprise")
    ]

    for pattern, description in rhetorical_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            boost += 0.3
            explanations.append(description)
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'rhetorical_question',
                'text': match.group()
            })

    return boost, explanations, highlights


def capitalization_analysis(text: str) -> tuple[float, list, list]:
    """Score ALL-CAPS emphasis and letter repetition (e.g. 'soooo')."""
    explanations = []
    highlights = []
    boost = 0.0

    # ALL CAPS words (common short function words are exempt)
    for match in re.finditer(r'\b[A-Z]{3,}\b', text):
        if match.group() not in ['AND', 'THE', 'FOR', 'BUT', 'YOU', 'ARE']:
            boost += 0.1
            explanations.append(f"Emphatic caps: {match.group()}")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'caps_emphasis',
                'text': match.group()
            })

    # Letter repetition
    for match in re.finditer(r'(.)\1{2,}', text):
        boost += 0.1
        explanations.append(f"Letter repetition: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'repetition',
            'text': match.group()
        })

    return boost, explanations, highlights


# Combined analysis with highlighting
def enhanced_rule_analysis(text: str) -> tuple[float, list, list]:
    """Run all rule-based analyses and merge their outputs.

    The total boost is capped at 0.8 so rules alone cannot saturate the score.
    """
    boost1, exp1, high1 = social_media_sarcasm_cues(text)
    boost2, exp2, high2 = emoji_punctuation_analysis(text)
    boost3, exp3, high3 = rhetorical_questions_analysis(text)
    boost4, exp4, high4 = capitalization_analysis(text)

    total_boost = min(boost1 + boost2 + boost3 + boost4, 0.8)
    all_explanations = exp1 + exp2 + exp3 + exp4
    all_highlights = high1 + high2 + high3 + high4

    return total_boost, all_explanations, all_highlights


# Multi-model prediction function
def get_model_predictions_current(text: str, models: dict, tokenizers: dict,
                                  device) -> dict:
    """Return {model_name: sarcasm probability} for a single message.

    Unloadable models and runtime errors yield 0.0 for that model.
    """
    predictions = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            predictions[name] = 0.0
            continue
        try:
            inputs = tokenizers[name](
                [text], return_tensors="pt", truncation=True, padding=True
            ).to(device)
            model.to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            # Handle different output formats
            if logits.shape[-1] == 2:
                # Binary classification: probability of the positive class
                score = F.softmax(logits, dim=-1)[0, 1].item()
            else:
                # Single output logit
                score = torch.sigmoid(logits)[0, 0].item()
            predictions[name] = score
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            predictions[name] = 0.0
    return predictions


# Context-aware variant: feeds (context, reply) as a sentence pair
def get_model_predictions_experiment(context: str, reply: str, models: dict,
                                     tokenizers: dict, device) -> dict:
    """Return {model_name: sarcasm probability} for a (context, reply) pair."""
    predictions = {}
    for name, model in models.items():
        if model is None or tokenizers[name] is None:
            predictions[name] = 0.0
            continue
        try:
            # Use the tokenizer's sentence-pair interface
            inputs = tokenizers[name](
                context, reply,
                return_tensors="pt", truncation=True, padding=True
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            model.to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            if logits.shape[-1] == 2:
                # Binary classification
                score = F.softmax(logits, dim=-1)[0, 1].item()
            else:
                score = torch.sigmoid(logits)[0, 0].item()
            predictions[name] = score
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
            predictions[name] = 0.0
    return predictions


# Enhanced ensemble prediction
def ensemble_prediction(model_scores: dict, rule_boost: float,
                        weights: dict = None) -> float:
    """Weighted average of model scores plus a rule-based boost, clamped to 1.

    Args:
        model_scores: {model_name: score}; scores of exactly 0.0 are treated
            as failed predictions and excluded.
            NOTE(review): this also excludes a genuine 0.0 prediction --
            behavior preserved from the original; consider a None sentinel.
        rule_boost: additive boost from enhanced_rule_analysis.
        weights: optional per-source weights; unknown models default to 0.3.
    """
    if weights is None:
        # Default weights - adjust based on model performance
        weights = {
            'helinivan': 0.4,
            'distilbert': 0.5,  # Higher weight for Reddit-trained model
            'rules': 0.1
        }

    ensemble_score = 0.0
    total_weight = 0.0
    # Weighted average of model predictions
    for model_name, score in model_scores.items():
        if score > 0:  # Only include valid predictions
            weight = weights.get(model_name, 0.3)
            ensemble_score += score * weight
            total_weight += weight

    if total_weight > 0:
        ensemble_score = ensemble_score / total_weight

    # Apply rule-based boost
    final_score = min(ensemble_score + (rule_boost * weights.get('rules', 0.1)), 1.0)
    return final_score


def _escape_html(fragment: str) -> str:
    """Escape &, < and > for safe embedding in the highlight HTML.

    The original code's replace() calls had their entity text stripped by
    extraction (they read '.replace("&", "&")'); restored here.
    """
    return (fragment.replace("&", "&amp;")
                    .replace("<", "&lt;")
                    .replace(">", "&gt;"))


# Create highlighted text HTML
def create_highlighted_text(text: str, highlights: list) -> str:
    """Render *text* as HTML with each highlight wrapped in a colored span."""
    if not isinstance(text, str):
        return ""
    if not highlights:
        return _escape_html(text)

    # Sort highlights by start position
    sorted_highlights = sorted(highlights, key=lambda x: x['start'])

    color_map = {
        'sarcastic_phrase': '#ff6b6b',
        'sarcastic_emoji': '#4ecdc4',
        'excessive_punct': '#45b7d1',
        'rhetorical_question': '#96ceb4',
        'caps_emphasis': '#feca57',
        'repetition': '#ff9ff3',
        'exaggerated': '#54a0ff',
        'ellipsis': '#fd79a8'
    }

    result = ""
    last_end = 0
    for highlight in sorted_highlights:
        start, end = highlight['start'], highlight['end']
        highlight_type = highlight['type']
        color = color_map.get(highlight_type, '#dda0dd')

        # Add text before highlight
        if start > last_end:
            result += _escape_html(text[last_end:start])

        # Add highlighted text.  NOTE(review): the original span markup was
        # lost to extraction (line read `result += f'{safe_text}'`);
        # reconstructed with a background-color span -- confirm styling.
        safe_text = _escape_html(text[start:end])
        result += (f'<span style="background-color: {color}" '
                   f'title="{highlight_type}">{safe_text}</span>')
        last_end = end

    # Add remaining text
    if last_end < len(text):
        result += _escape_html(text[last_end:])

    return result


# Enhanced confidence gauge
def create_confidence_gauge(score: float) -> go.Figure:
    """Build a Plotly gauge showing the ensemble sarcasm score."""
    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Ensemble Sarcasm Score"},
        delta={'reference': 0.5},
        gauge={
            'axis': {'range': [None, 1]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 0.3], 'color': "lightgray"},
                {'range': [0.3, 0.6], 'color': "yellow"},
                {'range': [0.6, 1], 'color': "red"}
            ],
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': 0.7
            }
        }
    ))
    fig.update_layout(height=300)
    return fig


# Multi-model feature importance visualization
def create_model_comparison_chart(model_scores: dict, rule_boost: float,
                                  final_score: float) -> go.Figure:
    """Bar chart comparing per-model scores, the rule boost, and the ensemble."""
    models = list(model_scores.keys())
    scores = list(model_scores.values())

    # Add rule-based and final scores
    models.extend(['Rule-based', 'Final Ensemble'])
    scores.extend([rule_boost, final_score])

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    fig = go.Figure(go.Bar(
        x=models,
        y=scores,
        marker_color=colors[:len(models)],
        text=[f'{score:.3f}' for score in scores],
        textposition='auto',
    ))
    fig.update_layout(
        title="Model Comparison & Ensemble Result",
        yaxis_title="Sarcasm Score",
        height=400,
        showlegend=False
    )
    return fig


def _label_for_score(final_score: float) -> str:
    """Map an ensemble score to its display label.

    Factored out of the two analyze_text_realtime_* functions, which
    previously duplicated this ladder verbatim.
    """
    if final_score > 0.8:
        return "Extremely Sarcastic ๐คจ๐"
    elif final_score > 0.7:
        return "Highly Sarcastic ๐คจ"
    elif final_score > 0.6:
        return "Likely Sarcastic ๐"
    elif final_score > 0.4:
        return "Possibly Sarcastic ๐ค"
    elif final_score > 0.3:
        return "Probably Sincere ๐"
    else:
        return "Sincere ๐"


# Real-time analysis function with multiple models
def analyze_text_realtime_current(text: str, models: dict, tokenizers: dict,
                                  device) -> dict:
    """Analyze a single message; returns score, label, explanations, highlights."""
    if not text.strip():
        return {
            'score': 0.0,
            'label': 'Enter text to analyze',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
    try:
        # Get rule-based analysis
        rule_boost, explanations, highlights = enhanced_rule_analysis(text)
        # Get predictions from all models
        model_scores = get_model_predictions_current(text, models, tokenizers, device)
        # Ensemble prediction
        final_score = ensemble_prediction(model_scores, rule_boost)

        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': _label_for_score(final_score),
            'explanations': explanations,
            'highlights': highlights
        }
    except Exception as e:
        return {
            'score': 0.0,
            'label': f'Error: {str(e)}',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }


# Context-aware variant of analyze_text_realtime
def analyze_text_realtime_experiment(context: str, reply: str, models: dict,
                                     tokenizers: dict, device) -> dict:
    """Analyze a reply given conversational context (experimental)."""
    if not reply.strip():
        return {
            'score': 0.0,
            'label': 'Enter context and reply to analyze',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }
    try:
        # Use reply for rule-based analysis (context is not used in rules)
        rule_boost, explanations, highlights = enhanced_rule_analysis(reply)
        # Get predictions from all models using context and reply
        model_scores = get_model_predictions_experiment(
            context, reply, models, tokenizers, device)
        final_score = ensemble_prediction(model_scores, rule_boost)

        return {
            'score': final_score,
            'model_scores': model_scores,
            'rule_boost': rule_boost,
            'label': _label_for_score(final_score),
            'explanations': explanations,
            'highlights': highlights
        }
    except Exception as e:
        return {
            'score': 0.0,
            'label': f'Error: {str(e)}',
            'explanations': [],
            'highlights': [],
            'model_scores': {}
        }


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Enhanced Sarcasm Detector", page_icon="๐คจ",
                   layout="wide")
st.title("๐จ๏ธ Enhanced Multi-Model Sarcasm Detector")
st.markdown("*Combining DistilBERT (Reddit-trained) + HelinIvan + Rule-based Analysis*")

# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.markdown(f"**Device:** {device}")
with st.spinner("Loading AI models..."):
    models, tokenizers = load_models()

# Model status display
st.markdown("### ๐ค Model Status")
status_cols = st.columns(2)
with status_cols[0]:
    helinivan_status = "โ Loaded" if models.get('helinivan') else "โ Failed"
    st.markdown(f"**HelinIvan Model:** {helinivan_status}")
with status_cols[1]:
    distilbert_status = "โ Loaded" if models.get('distilbert') else "โ Failed"
    st.markdown(f"**DistilBERT Model:** {distilbert_status}")

# Sidebar with examples and tips (shared)
with st.sidebar:
    st.markdown("### ๐ก **Quick Examples**")
    example_buttons = [
        ("Oh great, more traffic ๐", "social_media"),
        ("Yeah, I just LOVE waiting in line", "emphasis"),
        ("What could possibly go wrong?", "rhetorical"),
        ("Perfect timing as always...", "timing"),
        ("Thanks for the help genius", "reddit_style"),
        ("WOW so helpful!!!", "caps_sarcasm"),
        ("No shit Sherlock ๐คก", "reddit_sarcasm"),
        ("Truly groundbreaking stuff here", "mock_praise")
    ]
    for example_text, example_type in example_buttons:
        if st.button(f"๐ {example_text[:25]}...", key=example_type):
            st.session_state.example_text = example_text

# --- Tabs for navigation ---
tab1, tab2 = st.tabs(["Single Message (Current)", "Context-Aware (Experimental)"])

# --- Tab 1: Single Message (Current) ---
with tab1:
    st.markdown("### Single Message Sarcasm Detection")
    st.markdown("Analyze sarcasm in a single message (no conversational context).")
    col1, col2 = st.columns([2, 1])
    with col1:
        # Text input with real-time analysis
        default_text = st.session_state.get('example_text', '')
        user_text = st.text_area(
            "Enter a paragraph for multi-model sarcasm analysis:",
            value=default_text,
            height=120,
            placeholder="Try: 'Oh fantastic, another meeting that could have been an email ๐ What a brilliant use of everyone's time...'"
        )
        # Real-time analysis
        if user_text:
            with st.spinner("Analyzing with multiple models..."):
                analysis = analyze_text_realtime_current(user_text, models,
                                                         tokenizers, device)
            # Display highlighted text
            st.markdown("### ๐ฏ **Analysis Results**")
            highlighted_html = create_highlighted_text(user_text,
                                                       analysis['highlights'])
            # NOTE(review): the visible source was truncated mid-statement here
            # (`st.markdown(f'` with no closing).  The original presumably
            # rendered the highlighted HTML; reconstructed minimally -- confirm
            # against the full source file.
            st.markdown(highlighted_html, unsafe_allow_html=True)