Spaces:
Sleeping
Sleeping
import html
import re
import time
from collections import Counter

import emoji
import nltk
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import spacy
import streamlit as st
import torch
import torch.nn.functional as F
from nltk.sentiment import SentimentIntensityAnalyzer
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Configuration - Multiple Models
# Maps the short name used for ensemble weights and UI labels to the model
# identifier passed to `from_pretrained`.
MODELS = {
    "helinivan": "helinivan/English-sarcasm-detector",
    "distilbert": "dima806/sarcasm-detection-distilbert",
}


def _init_vader():
    """Build the VADER analyzer, fetching its lexicon once if it is missing.

    Returns:
        A ready-to-use ``SentimentIntensityAnalyzer``.

    Raises:
        LookupError: if the lexicon is still unavailable after the download
            attempt (for example, no network access).
    """
    nltk.data.path.append('/app/nltk_data')
    try:
        return SentimentIntensityAnalyzer()
    except LookupError:
        # The lexicon is not bundled with NLTK; try to fetch it, then retry.
        nltk.download('vader_lexicon', quiet=True)
        return SentimentIntensityAnalyzer()


# Initialize NLTK VADER analyzer
# NOTE(review): this st.error call runs before st.set_page_config further down
# the script; some Streamlit versions require set_page_config to be the first
# Streamlit command - confirm against the deployed version.
try:
    sia = _init_vader()
except Exception as e:
    st.error(f"Error downloading NLTK data: {e}")
    sia = None
# Cache multiple models & tokenizers
@st.cache_resource
def load_models():
    """Load every model/tokenizer pair listed in ``MODELS``.

    The result is cached with ``st.cache_resource`` so the weights are loaded
    once per process instead of on every script rerun.

    Returns:
        A ``(models, tokenizers)`` tuple of dicts keyed by the names in
        ``MODELS``. A pair that failed to load is stored as ``None`` in both
        dicts so callers can skip it.
    """
    models = {}
    tokenizers = {}
    for name, model_path in MODELS.items():
        try:
            model = AutoModelForSequenceClassification.from_pretrained(model_path, cache_dir="/tmp/hf_cache")
            tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/tmp/hf_cache")
            model.eval()  # inference only
        except Exception as e:
            # Best effort: one broken model must not prevent the others from loading.
            st.error(f"❌ Failed to load {name} model: {str(e)}")
            models[name] = None
            tokenizers[name] = None
        else:
            models[name] = model
            tokenizers[name] = tokenizer
            st.success(f"✅ Loaded {name} model successfully")
    return models, tokenizers
# Lazy-load SpaCy (optional - not used in current implementation)
def load_spacy():
    """Return the ``en_core_web_sm`` SpaCy pipeline, or ``None`` if it is not installed.

    When the pipeline cannot be found a Streamlit warning is shown instead of
    raising.
    """
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        st.warning("SpaCy model 'en_core_web_sm' not found. Some features may be limited.")
        nlp = None
    return nlp
# Pattern detection functions with highlighting info

# Enhanced sarcasm phrases (including Reddit-style patterns)
_SARCASM_PHRASES = (
    "oh sure", "yeah right", "of course", "totally", "absolutely",
    "perfect", "wonderful", "fantastic", "amazing", "brilliant",
    "great job", "well done", "nice one", "good going", "way to go",
    "real smooth", "genius move", "solid plan", "makes sense",
    "just perfect", "exactly what i needed", "this is fine",
    # Reddit-style additions
    "thanks genius", "no shit sherlock", "well duh", "captain obvious",
    "groundbreaking", "revolutionary", "what a concept", "mind blown",
    "shocking", "who would have guessed", "truly inspiring",
)

# Compiled once. Word boundaries stop "perfect" from firing inside "imperfect".
_SARCASM_PHRASE_PATTERNS = tuple(
    (phrase, re.compile(r'\b' + re.escape(phrase) + r'\b', re.IGNORECASE))
    for phrase in _SARCASM_PHRASES
)

_EXAGGERATION_PATTERN = re.compile(
    r'\b(SO|TOTALLY|ABSOLUTELY|REALLY|VERY)\b.*\b(great|good|perfect|amazing|helpful|useful)\b',
    re.IGNORECASE,
)


def social_media_sarcasm_cues(text: str) -> tuple[float, list, list]:
    """Score stock sarcastic phrases and exaggerated praise in ``text``.

    Matching is case-insensitive and runs against the original string (not a
    lower-cased copy), so the reported offsets always index into ``text`` even
    when lower-casing would change its length.

    Args:
        text: The message to inspect.

    Returns:
        ``(boost, explanations, highlights)``: ``boost`` grows by 0.2 per phrase
        occurrence and by 0.25 for one exaggerated expression; ``explanations``
        holds one description per hit; ``highlights`` holds dicts with
        ``start``/``end`` offsets into ``text``, a ``type`` and the matched ``text``.
    """
    explanations = []
    highlights = []
    boost = 0.0

    for phrase, pattern in _SARCASM_PHRASE_PATTERNS:
        for match in pattern.finditer(text):
            boost += 0.2
            explanations.append(f"Sarcastic phrase: '{phrase}'")
            highlights.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'sarcastic_phrase',
                'text': phrase,
            })

    # Exaggerated expressions: intensifier followed later by a positive word.
    exaggerated_match = _EXAGGERATION_PATTERN.search(text)
    if exaggerated_match:
        boost += 0.25
        explanations.append("Exaggerated positive expression")
        highlights.append({
            'start': exaggerated_match.start(),
            'end': exaggerated_match.end(),
            'type': 'exaggerated',
            'text': exaggerated_match.group(),
        })

    return boost, explanations, highlights
# Emoji treated as sarcasm markers. Written as escapes so the set survives
# re-encoding of this file.
# NOTE(review): the original literals were damaged by a bad decode; only some
# could be recovered with certainty, the rest are a best-effort reconstruction -
# confirm against the intended list.
_SARCASTIC_EMOJIS = frozenset({
    "\U0001F644",  # face with rolling eyes
    "\U0001F60F",  # smirking face
    "\U0001F612",  # unamused face
    "\U0001F914",  # thinking face
    "\U0001F928",  # face with raised eyebrow
    "\U0001F92D",  # face with hand over mouth
    "\U0001F937",  # person shrugging
    "\U0001F643",  # upside-down face
    "\U0001F611",  # expressionless face
    "\U0001F921",  # clown face
    "\U0001F480",  # skull
    "\U0001F92F",  # exploding head
})


def emoji_punctuation_analysis(text: str) -> tuple[float, list, list]:
    """Score sarcastic emoji, runs of ``!``/``?`` and ellipses in ``text``.

    Args:
        text: The message to inspect.

    Returns:
        ``(boost, explanations, highlights)``: ``boost`` grows by 0.15 per
        sarcastic emoji, 0.1 per run of two or more ``!``/``?`` and 0.15 per
        run of three or more dots. ``highlights`` holds dicts with
        ``start``/``end`` offsets into ``text``, a ``type`` and the matched ``text``.
    """
    explanations = []
    highlights = []
    boost = 0.0

    # Extract emojis with positions. Emoji detection is optional: if the emoji
    # library is unavailable or misbehaves, fall back to punctuation cues only.
    try:
        found_emojis = emoji.emoji_list(text)
    except Exception:
        found_emojis = []

    for emoji_info in found_emojis:
        if emoji_info['emoji'] in _SARCASTIC_EMOJIS:
            boost += 0.15
            explanations.append(f"Sarcastic emoji: {emoji_info['emoji']}")
            highlights.append({
                'start': emoji_info['match_start'],
                'end': emoji_info['match_end'],
                'type': 'sarcastic_emoji',
                'text': emoji_info['emoji'],
            })

    # Excessive punctuation
    for match in re.finditer(r'[!?]{2,}', text):
        boost += 0.1
        explanations.append(f"Excessive punctuation: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'excessive_punct',
            'text': match.group(),
        })

    # Ellipsis (often sarcastic)
    for match in re.finditer(r'\.{3,}', text):
        boost += 0.15
        explanations.append(f"Trailing ellipsis: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'ellipsis',
            'text': match.group(),
        })

    return boost, explanations, highlights
def rhetorical_questions_analysis(text: str) -> tuple[float, list, list]:
    """Detect stock rhetorical or mock-surprised questions in ``text``.

    Each pattern is searched case-insensitively; every hit adds 0.3 to the
    boost, one explanation and one highlight (offsets index into ``text``).
    Hits are reported pattern by pattern, in table order.
    """
    cue_table = (
        (r'what could possibly go wrong\?', "Rhetorical question"),
        (r'who would have thought\?', "Rhetorical question"),
        (r'seriously\?', "Emphatic question"),
        (r'really\?.*really\?', "Repeated question"),
        (r'no way\?', "Disbelief question"),
        (r'you don\'t say\?', "Sarcastic response"),
        (r'shocking.*\?', "Mock surprise"),
    )

    hits = [
        (label, found)
        for regex, label in cue_table
        for found in re.finditer(regex, text, re.IGNORECASE)
    ]

    boost = 0.0
    for _ in hits:
        boost += 0.3

    explanations = [label for label, _ in hits]
    highlights = [
        {
            'start': found.start(),
            'end': found.end(),
            'type': 'rhetorical_question',
            'text': found.group(),
        }
        for _, found in hits
    ]
    return boost, explanations, highlights
# Short all-caps words that are too common to count as emphasis.
_CAPS_STOPWORDS = frozenset({'AND', 'THE', 'FOR', 'BUT', 'YOU', 'ARE'})

# Three or more of the same *letter* in a row. Punctuation, digits and
# whitespace are excluded so "..." or "!!!" are not also counted as repetition.
_LETTER_REPETITION = re.compile(r'([^\W\d_])\1{2,}')


def capitalization_analysis(text: str) -> tuple[float, list, list]:
    """Score emphatic ALL-CAPS words and stretched letters in ``text``.

    Args:
        text: The message to inspect.

    Returns:
        ``(boost, explanations, highlights)``: ``boost`` grows by 0.1 per
        all-caps word of three or more letters (common stopwords excluded) and
        by 0.1 per run of three or more identical letters. ``highlights`` holds
        dicts with ``start``/``end`` offsets into ``text``, a ``type`` and the
        matched ``text``.
    """
    explanations = []
    highlights = []
    boost = 0.0

    # ALL CAPS words
    for match in re.finditer(r'\b[A-Z]{3,}\b', text):
        if match.group() in _CAPS_STOPWORDS:
            continue
        boost += 0.1
        explanations.append(f"Emphatic caps: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'caps_emphasis',
            'text': match.group(),
        })

    # Letter repetition
    for match in _LETTER_REPETITION.finditer(text):
        boost += 0.1
        explanations.append(f"Letter repetition: {match.group()}")
        highlights.append({
            'start': match.start(),
            'end': match.end(),
            'type': 'repetition',
            'text': match.group(),
        })

    return boost, explanations, highlights
# Combined analysis with highlighting
def enhanced_rule_analysis(text: str) -> tuple[float, list, list]:
    """Run every rule-based detector on ``text`` and merge the results.

    Boosts are summed and capped at 0.8; explanations and highlights are
    concatenated in detector order (phrases, emoji/punctuation, rhetorical
    questions, capitalization).
    """
    detectors = (
        social_media_sarcasm_cues,
        emoji_punctuation_analysis,
        rhetorical_questions_analysis,
        capitalization_analysis,
    )

    combined_boost = 0.0
    merged_explanations = []
    merged_highlights = []
    for detect in detectors:
        boost, notes, spans = detect(text)
        combined_boost += boost
        merged_explanations += notes
        merged_highlights += spans

    # Cap the total boost
    return min(combined_boost, 0.8), merged_explanations, merged_highlights
# Multi-model prediction function
def get_model_predictions_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    """Score ``text`` with every loaded model.

    Returns a dict mapping model name to a score. A model (or tokenizer) that
    is ``None``, or whose inference raises, is reported as 0.0; inference
    errors are additionally surfaced as a Streamlit warning.
    """
    predictions = {}
    for name, model in models.items():
        # Start from the fallback; it is overwritten only on success.
        predictions[name] = 0.0
        if model is None or tokenizers[name] is None:
            continue
        try:
            encoded = tokenizers[name]([text], return_tensors="pt", truncation=True, padding=True).to(device)
            model.to(device)
            with torch.no_grad():
                logits = model(**encoded).logits
            # Handle different output formats
            two_class = logits.shape[-1] == 2
            if two_class:
                # Binary classification: take the softmax mass at index 1
                probability = F.softmax(logits, dim=-1)[0, 1].item()
            else:
                # Single output: squash the first logit
                probability = torch.sigmoid(logits)[0, 0].item()
            predictions[name] = probability
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
    return predictions
# Modify get_model_predictions to accept context and reply
def get_model_predictions_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    """Score a ``(context, reply)`` pair with every loaded model.

    The two strings are passed to each tokenizer as a sentence pair. Returns a
    dict mapping model name to a score; a missing model/tokenizer or a failed
    inference yields 0.0 (failures also emit a Streamlit warning).
    """
    predictions = {}
    for name, model in models.items():
        # Start from the fallback; it is overwritten only on success.
        predictions[name] = 0.0
        if model is None or tokenizers[name] is None:
            continue
        try:
            # Use sentence-pair interface
            encoded = tokenizers[name](
                context,
                reply,
                return_tensors="pt",
                truncation=True,
                padding=True,
            )
            on_device = {key: tensor.to(device) for key, tensor in encoded.items()}
            model.to(device)
            with torch.no_grad():
                logits = model(**on_device).logits
            two_class = logits.shape[-1] == 2
            if two_class:
                # Binary classification: take the softmax mass at index 1
                probability = F.softmax(logits, dim=-1)[0, 1].item()
            else:
                probability = torch.sigmoid(logits)[0, 0].item()
            predictions[name] = probability
        except Exception as e:
            st.warning(f"Error with {name} model: {str(e)}")
    return predictions
# Enhanced ensemble prediction
def ensemble_prediction(model_scores: dict, rule_boost: float, weights: dict = None) -> float:
    """Blend per-model scores and the rule-based boost into one score.

    Scores that are not strictly positive are treated as "no prediction" and
    left out of the weighted average. The rule boost, scaled by the ``'rules'``
    weight, is then added and the result is capped at 1.0.

    Args:
        model_scores: Mapping of model name to score.
        rule_boost: Boost produced by the rule-based analysis.
        weights: Optional mapping of model name (plus ``'rules'``) to weight.
            Models without an entry use 0.3; a missing ``'rules'`` entry uses 0.1.
    """
    if weights is None:
        # Default weights - adjust based on model performance
        weights = {
            'helinivan': 0.4,
            'distilbert': 0.5,  # Higher weight for Reddit-trained model
            'rules': 0.1,
        }

    # Only include valid predictions
    usable = [
        (score, weights.get(name, 0.3))
        for name, score in model_scores.items()
        if score > 0
    ]

    weighted_sum = 0.0
    weight_total = 0.0
    for score, weight in usable:
        weighted_sum += score * weight
        weight_total += weight

    # Normalise to a weighted average when anything contributed.
    blended = weighted_sum / weight_total if weight_total > 0 else weighted_sum

    # Apply rule-based boost
    return min(blended + (rule_boost * weights.get('rules', 0.1)), 1.0)
# Create highlighted text HTML
def create_highlighted_text(text: str, highlights: list) -> str:
    """Render ``text`` as HTML with each highlight wrapped in a coloured ``<span>``.

    All text taken from ``text`` is HTML-escaped (``&``, ``<``, ``>``) because
    callers render the result with ``unsafe_allow_html=True``. Highlights are
    processed left to right; where two overlap, the part already emitted is not
    repeated, so the output always contains each character of ``text`` exactly once.

    Args:
        text: The original message. Anything that is not a ``str`` yields ``""``.
        highlights: Dicts with ``start``/``end`` offsets into ``text`` and a
            ``type`` used to pick the colour (unknown types get a default colour).

    Returns:
        An HTML fragment safe to embed in the page.
    """
    if not isinstance(text, str):
        return ""
    if not highlights:
        return html.escape(text, quote=False)

    color_map = {
        'sarcastic_phrase': '#ff6b6b',
        'sarcastic_emoji': '#4ecdc4',
        'excessive_punct': '#45b7d1',
        'rhetorical_question': '#96ceb4',
        'caps_emphasis': '#feca57',
        'repetition': '#ff9ff3',
        'exaggerated': '#54a0ff',
        'ellipsis': '#fd79a8',
    }

    pieces = []
    cursor = 0  # index of the first character not yet emitted
    # Sort by start position; for equal starts the longest span wins.
    for highlight in sorted(highlights, key=lambda h: (h['start'], -h['end'])):
        start = max(highlight['start'], cursor)  # clip the overlap with earlier spans
        end = min(highlight['end'], len(text))
        if end <= start:
            continue  # fully covered by a previous highlight (or empty)

        # Add text before highlight
        pieces.append(html.escape(text[cursor:start], quote=False))

        # Add highlighted text
        color = color_map.get(highlight['type'], '#dda0dd')
        safe_text = html.escape(text[start:end], quote=False)
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px 4px; border-radius: 3px; color: black;">{safe_text}</span>'
        )
        cursor = end

    # Add remaining text
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
# Enhanced confidence gauge
def create_confidence_gauge(score: float) -> go.Figure:
    """Build a 300px-high Plotly gauge showing ``score``.

    The gauge has three coloured bands (0-0.3, 0.3-0.6, 0.6-1), a threshold
    line at 0.7 and a delta computed against 0.5.
    """
    colour_bands = [
        {'range': [0, 0.3], 'color': "lightgray"},
        {'range': [0.3, 0.6], 'color': "yellow"},
        {'range': [0.6, 1], 'color': "red"},
    ]
    threshold_marker = {
        'line': {'color': "red", 'width': 4},
        'thickness': 0.75,
        'value': 0.7,
    }
    gauge_spec = {
        'axis': {'range': [None, 1]},
        'bar': {'color': "darkblue"},
        'steps': colour_bands,
        'threshold': threshold_marker,
    }

    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=score,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Ensemble Sarcasm Score"},
        delta={'reference': 0.5},
        gauge=gauge_spec,
    )
    figure = go.Figure(indicator)
    figure.update_layout(height=300)
    return figure
# Multi-model feature importance visualization
def create_model_comparison_chart(model_scores: dict, rule_boost: float, final_score: float) -> go.Figure:
    """Build a bar chart comparing each model's score, the rule boost and the final score.

    Bars appear in ``model_scores`` order, followed by 'Rule-based' and
    'Final Ensemble'; each bar is labelled with its value to three decimals.
    """
    bar_names = [*model_scores.keys(), 'Rule-based', 'Final Ensemble']
    bar_values = [*model_scores.values(), rule_boost, final_score]
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    bars = go.Bar(
        x=bar_names,
        y=bar_values,
        marker_color=palette[:len(bar_names)],
        text=[f'{value:.3f}' for value in bar_values],
        textposition='auto',
    )
    figure = go.Figure(bars)
    figure.update_layout(
        title="Model Comparison & Ensemble Result",
        yaxis_title="Sarcasm Score",
        height=400,
        showlegend=False,
    )
    return figure
# Score thresholds (exclusive lower bounds) and their labels, highest first.
# NOTE(review): the emoji in these labels were damaged by a bad decode and have
# been reconstructed; confirm them against the intended labels.
_SCORE_LABELS = (
    (0.8, "Extremely Sarcastic 🤨🙄"),
    (0.7, "Highly Sarcastic 🤨"),
    (0.6, "Likely Sarcastic 😏"),
    (0.4, "Possibly Sarcastic 🤔"),
    (0.3, "Probably Sincere 🙂"),
)
_LOWEST_LABEL = "Sincere 😊"


def _label_for_score(score: float) -> str:
    """Return the label of the highest band whose threshold ``score`` exceeds."""
    for threshold, label in _SCORE_LABELS:
        if score > threshold:
            return label
    return _LOWEST_LABEL


def _placeholder_result(label: str) -> dict:
    """Return a zero-score result carrying ``label`` (used for empty input and errors)."""
    return {
        'score': 0.0,
        'model_scores': {},
        'rule_boost': 0.0,
        'label': label,
        'explanations': [],
        'highlights': [],
    }


def _build_result(model_scores: dict, rule_boost: float, explanations: list, highlights: list) -> dict:
    """Combine model scores and rule output into the dict consumed by the UI."""
    final_score = ensemble_prediction(model_scores, rule_boost)
    return {
        'score': final_score,
        'model_scores': model_scores,
        'rule_boost': rule_boost,
        'label': _label_for_score(final_score),
        'explanations': explanations,
        'highlights': highlights,
    }


# Real-time analysis function with multiple models
def analyze_text_realtime_current(text: str, models: dict, tokenizers: dict, device) -> dict:
    """Analyze a single message with the rules and every model.

    Returns:
        A dict with ``score``, ``model_scores``, ``rule_boost``, ``label``,
        ``explanations`` and ``highlights``. Blank input and any exception
        produce a zero-score result whose ``label`` explains what happened
        (``model_scores`` is then empty).
    """
    if not text or not text.strip():
        return _placeholder_result('Enter text to analyze')
    try:
        # Get rule-based analysis
        rule_boost, explanations, highlights = enhanced_rule_analysis(text)
        # Get predictions from all models
        model_scores = get_model_predictions_current(text, models, tokenizers, device)
        return _build_result(model_scores, rule_boost, explanations, highlights)
    except Exception as e:
        return _placeholder_result(f'Error: {str(e)}')


# Modify analyze_text_realtime to accept context and reply
def analyze_text_realtime_experiment(context: str, reply: str, models: dict, tokenizers: dict, device) -> dict:
    """Analyze ``reply`` given the preceding ``context`` message.

    The rules look at ``reply`` only; the models receive ``context`` and
    ``reply`` as a sentence pair. The returned dict has the same keys as
    ``analyze_text_realtime_current``; highlights index into ``reply``.
    """
    if not reply or not reply.strip():
        return _placeholder_result('Enter context and reply to analyze')
    try:
        # Use reply for rule-based analysis (context is not used in rules)
        rule_boost, explanations, highlights = enhanced_rule_analysis(reply)
        # Get predictions from all models using context and reply
        model_scores = get_model_predictions_experiment(context, reply, models, tokenizers, device)
        return _build_result(model_scores, rule_boost, explanations, highlights)
    except Exception as e:
        return _placeholder_result(f'Error: {str(e)}')
# Streamlit UI
# NOTE(review): emoji in this section were damaged by a bad decode and have been
# reconstructed; confirm the icons against the intended design.
st.set_page_config(page_title="Enhanced Sarcasm Detector", page_icon="🤨", layout="wide")
st.title("🗨️ Enhanced Multi-Model Sarcasm Detector")
st.markdown("*Combining DistilBERT (Reddit-trained) + HelinIvan + Rule-based Analysis*")

# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.markdown(f"**Device:** {device}")
with st.spinner("Loading AI models..."):
    models, tokenizers = load_models()

# Model status display
st.markdown("### 🤖 Model Status")
status_cols = st.columns(2)
status_entries = (
    ("helinivan", "HelinIvan Model"),
    ("distilbert", "DistilBERT Model"),
)
for status_col, (model_key, model_title) in zip(status_cols, status_entries):
    with status_col:
        # load_models stores None for a model that failed to load.
        model_status = "✅ Loaded" if models.get(model_key) is not None else "❌ Failed"
        st.markdown(f"**{model_title}:** {model_status}")
# Sidebar with examples and tips (shared)
# NOTE(review): emoji here were damaged by a bad decode and have been
# reconstructed; confirm against the intended examples.
with st.sidebar:
    st.markdown("### 💡 **Quick Examples**")
    # (example text, unique widget key)
    example_buttons = [
        ("Oh great, more traffic 🙄", "social_media"),
        ("Yeah, I just LOVE waiting in line", "emphasis"),
        ("What could possibly go wrong?", "rhetorical"),
        ("Perfect timing as always...", "timing"),
        ("Thanks for the help genius", "reddit_style"),
        ("WOW so helpful!!!", "caps_sarcasm"),
        ("No shit Sherlock 🤡", "reddit_sarcasm"),
        ("Truly groundbreaking stuff here", "mock_praise"),
    ]
    for example_text, example_type in example_buttons:
        # Clicking stores the example; the text areas below read it as their default value.
        if st.button(f"📝 {example_text[:25]}...", key=example_type):
            st.session_state.example_text = example_text
# --- Tabs for navigation ---
# NOTE(review): emoji in this section were damaged by a bad decode and have been
# reconstructed; confirm the icons against the intended design.

_MODEL_DISPLAY_NAMES = {
    'helinivan': 'HelinIvan Model',
    'distilbert': 'DistilBERT Model',
}

# (legend label, colour) - colours mirror the highlight colours used for the text.
_LEGEND_ITEMS = [
    ("Sarcastic Phrases", "#ff6b6b"),
    ("Emojis", "#4ecdc4"),
    ("Punctuation", "#45b7d1"),
    ("Questions", "#96ceb4"),
    ("Emphasis", "#feca57"),
    ("Repetition", "#ff9ff3"),
    ("Exaggeration", "#54a0ff"),
    ("Ellipsis", "#fd79a8"),
]


def _render_prediction(source_text: str, analysis: dict) -> None:
    """Show the highlighted text, the predicted label and the score bar."""
    st.markdown("### 🎯 **Analysis Results**")
    highlighted_html = create_highlighted_text(source_text, analysis['highlights'])
    st.markdown(f'<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">{highlighted_html}</div>', unsafe_allow_html=True)
    # Prediction and confidence
    st.markdown(f"### **Prediction: {analysis['label']}**")
    score = analysis['score']
    progress_color = "🔴" if score > 0.7 else "🟡" if score > 0.4 else "🟢"
    st.write(f"**Ensemble Score: {score:.3f}** {progress_color}")
    st.progress(score)


def _render_gauge(analysis: dict) -> None:
    """Show the confidence gauge for the ensemble score."""
    st.markdown("### 📊 **Confidence Gauge**")
    st.plotly_chart(create_confidence_gauge(analysis['score']), use_container_width=True)


def _render_model_breakdown(analysis: dict) -> None:
    """Show per-model scores, detected patterns and the comparison chart.

    Renders nothing when no model scores are available (empty input or error).
    """
    if not analysis['model_scores']:
        return
    st.markdown("### 🔍 **Multi-Model Analysis**")
    scores_col, chart_col = st.columns([1, 1])
    with scores_col:
        st.markdown("#### 📈 **Individual Model Scores:**")
        for model_name, score in analysis['model_scores'].items():
            display_name = _MODEL_DISPLAY_NAMES.get(model_name, model_name)
            st.write(f"• **{display_name}:** {score:.3f}")
        st.write(f"• **Rule-based boost:** +{analysis['rule_boost']:.3f}")
        st.write(f"• **🎯 Final ensemble:** {analysis['score']:.3f}")
        if analysis['explanations']:
            st.markdown("#### 🔎 **Detected Patterns:**")
            for i, explanation in enumerate(analysis['explanations'], 1):
                st.write(f"{i}. {explanation}")
    with chart_col:
        st.markdown("#### 📊 **Model Comparison:**")
        comparison_fig = create_model_comparison_chart(
            analysis['model_scores'],
            analysis['rule_boost'],
            analysis['score'],
        )
        st.plotly_chart(comparison_fig, use_container_width=True)


def _render_legend(analysis: dict) -> None:
    """Show the colour legend, but only when something was highlighted."""
    if not analysis['highlights']:
        return
    st.markdown("### 🎨 **Highlighting Legend**")
    legend_cols = st.columns(4)
    for i, (label, color) in enumerate(_LEGEND_ITEMS):
        with legend_cols[i % 4]:
            # Labels are constants defined above, so no escaping is needed.
            st.markdown(f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; color: black; font-size: 12px;">{label}</span>', unsafe_allow_html=True)


tab1, tab2 = st.tabs(["Single Message (Current)", "Context-Aware (Experimental)"])

# --- Tab 1: Single Message (Current) ---
with tab1:
    st.markdown("### Single Message Sarcasm Detection")
    st.markdown("Analyze sarcasm in a single message (no conversational context).")
    col1, col2 = st.columns([2, 1])
    analysis = None  # stays None until the user has entered text
    with col1:
        # Text input with real-time analysis
        default_text = st.session_state.get('example_text', '')
        user_text = st.text_area(
            "Enter a paragraph for multi-model sarcasm analysis:",
            value=default_text,
            height=120,
            placeholder="Try: 'Oh fantastic, another meeting that could have been an email 🙄 What a brilliant use of everyone's time...'",
        )
        # Real-time analysis
        if user_text:
            with st.spinner("Analyzing with multiple models..."):
                analysis = analyze_text_realtime_current(user_text, models, tokenizers, device)
            _render_prediction(user_text, analysis)
    with col2:
        if analysis is not None:
            _render_gauge(analysis)
    if analysis is not None:
        # Multi-model analysis section
        _render_model_breakdown(analysis)
        # Pattern legend
        _render_legend(analysis)

# --- Tab 2: Context-Aware (Experimental) ---
with tab2:
    st.markdown("### Context-Aware Sarcasm Detection (Experimental)")
    st.info("This feature is experimental. The models are **not yet trained** on context+reply pairs. Predictions are based on formatting the input as a sentence pair, but results may not be reliable.")
    col1, col2 = st.columns([2, 1])
    analysis_ctx = None  # stays None until the user has entered a reply
    with col1:
        context_text = st.text_area(
            "Context (previous message):",
            value=st.session_state.get('context_text', ''),
            height=68,
            placeholder="e.g. 'Can you finish this by today?'",
        )
        reply_text = st.text_area(
            "Reply (current message):",
            value=st.session_state.get('reply_text', st.session_state.get('example_text', '')),
            height=80,
            placeholder="e.g. 'Oh sure, because I have nothing else to do.'",
        )
        if reply_text:
            with st.spinner("Analyzing with experimental context-aware input..."):
                analysis_ctx = analyze_text_realtime_experiment(context_text, reply_text, models, tokenizers, device)
            _render_prediction(reply_text, analysis_ctx)
    with col2:
        if analysis_ctx is not None:
            _render_gauge(analysis_ctx)
    if analysis_ctx is not None:
        _render_model_breakdown(analysis_ctx)
        _render_legend(analysis_ctx)
# --- Shared tutorial and advanced settings ---
# NOTE(review): emoji in this text were damaged by a bad decode and have been
# reconstructed; confirm against the intended copy.
with st.expander("📚 **Multi-Model Sarcasm Detection Guide**"):
    st.markdown("""
### How the Enhanced Detection Works:
1. **🤖 HelinIvan Model**: General English sarcasm detection
2. **🤖 DistilBERT Model**: Specialized Reddit-trained sarcasm detector
3. **📝 Rule-Based Analysis**: Linguistic patterns and social media cues
4. **🎯 Ensemble Method**: Combines all approaches with weighted averaging
### Model Advantages:
- **HelinIvan**: Good for formal and general sarcasm
- **DistilBERT (Reddit)**: Excellent for informal, social media style sarcasm
- **Rule-based**: Catches obvious patterns and cultural references
### Why Ensemble Works Better:
- **Robustness**: Multiple models reduce individual model weaknesses
- **Coverage**: Different training data covers different sarcasm styles
- **Confidence**: Agreement between models increases reliability
### Try These Reddit-Style Examples:
- "No shit Sherlock 🤡"
- "Thanks Captain Obvious"
- "Groundbreaking discovery there genius"
- "What a concept... mind blown 🤯"
""")
# Model weights adjustment (advanced users)
# NOTE(review): the slider values below are only echoed back in the info box.
# The analysis calls earlier in the script use ensemble_prediction's default
# weights, so moving a slider does not change any score. Wiring them in would
# require creating the sliders before the analysis and passing the weights through.
with st.expander("⚙️ **Advanced: Adjust Model Weights**"):
    st.markdown("Fine-tune the ensemble by adjusting model importance:")
    col_w1, col_w2, col_w3 = st.columns(3)
    with col_w1:
        helinivan_weight = st.slider("HelinIvan Weight", 0.0, 1.0, 0.4, 0.1)
    with col_w2:
        distilbert_weight = st.slider("DistilBERT Weight", 0.0, 1.0, 0.5, 0.1)
    with col_w3:
        rules_weight = st.slider("Rules Weight", 0.0, 0.5, 0.1, 0.05)
    st.info(f"Weights - HelinIvan: {helinivan_weight}, DistilBERT: {distilbert_weight}, Rules: {rules_weight}")
# Clear session state
# Drops the stored sidebar example (if any) and reruns the script so the text
# areas fall back to their empty default.
if st.button("🔄 Clear Text", key="clear_main"):
    st.session_state.pop('example_text', None)
    st.rerun()