Spaces:
Sleeping
Sleeping
| """Inference module for emotion prediction using our trained model.""" | |
| import math | |
| import os | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import torch.nn.functional as F | |
| from .config import config | |
| from .visualize import EmotionVisualizer | |
class EmotionPredictor:
    """
    Predict emotions from text using our custom trained model.

    Trained on dair-ai/emotion dataset with 16,000+ samples.
    Emotions: sadness, joy, love, anger, fear, surprise

    If the loaded model's ``label2id`` contains a ``"sarcasm"`` class, the
    predictor additionally applies a prior-shift correction to the sarcasm
    logit and a low-probability "tripwire" that forces the sarcasm label
    (see ``predict``). Models without that class skip both steps.
    """

    # Sarcasm probability at or above which ``predict`` reports "sarcasm" as
    # the winning label, regardless of which class had the highest probability.
    SARCASM_TRIPWIRE = 0.08

    def __init__(
        self,
        model_path: Optional[Path] = None,
        enable_viz: bool = True,
        target_sarcasm_prior: Optional[float] = config.target_sarcasm_prior,
        sarcasm_threshold: Optional[float] = config.sarcasm_threshold,
        train_sarcasm_prior: Optional[float] = None,
    ):
        """
        Initialize the emotion predictor and load the model.

        Args:
            model_path: Path to trained model directory. When falsy, the most
                recently modified ``*emotion_classifier*`` directory under
                ``config.model_dir`` is used (its ``final`` subfolder if
                present), falling back to
                ``config.model_dir / "emotion_classifier" / "final"``.
            enable_viz: Whether to enable chart visualization
            target_sarcasm_prior: Target sarcasm prevalence in deployment text (0-1)
            sarcasm_threshold: Optional decision threshold for sarcasm class (0-1)
            train_sarcasm_prior: Optional override for sarcasm prevalence seen in training

        Raises:
            ValueError: If a probability argument is outside the open interval (0, 1).
            FileNotFoundError: If no local model exists and HF_MODEL_REPO is unset.
        """
        if model_path:
            self.model_path = Path(model_path)
        else:
            self.model_path = self._discover_model_path()

        self.device = self._get_device()
        self.enable_viz = enable_viz
        self.target_sarcasm_prior = self._validate_probability(target_sarcasm_prior, "target_sarcasm_prior")
        self.sarcasm_threshold = self._validate_probability(sarcasm_threshold, "sarcasm_threshold", allow_none=True)
        self.user_train_sarcasm_prior = self._validate_probability(
            train_sarcasm_prior,
            "train_sarcasm_prior",
            allow_none=True,
        )

        # Always define the attribute so code that inspects it does not hit
        # AttributeError when visualization is disabled.
        self.visualizer = EmotionVisualizer() if self.enable_viz else None

        self._load_model()

    @staticmethod
    def _discover_model_path() -> Path:
        """Pick the default model directory under ``config.model_dir``.

        Returns the newest (by mtime) sub-directory whose name contains
        "emotion_classifier", preferring its ``final`` child when that exists.
        If the models directory is missing or has no matching sub-directory,
        returns ``config.model_dir / "emotion_classifier" / "final"`` without
        checking that it exists, so ``_load_model`` can still fall back to
        the Hugging Face Hub.
        """
        models_dir = config.model_dir
        default_path = models_dir / "emotion_classifier" / "final"

        # iterdir() raises FileNotFoundError on a missing directory, which
        # would abort construction before the Hub fallback gets a chance.
        if not models_dir.is_dir():
            return default_path

        candidates = [d for d in models_dir.iterdir() if d.is_dir() and "emotion_classifier" in d.name]
        if not candidates:
            return default_path

        latest_dir = max(candidates, key=lambda d: d.stat().st_mtime)
        final_dir = latest_dir / "final"
        return final_dir if final_dir.exists() else latest_dir

    @staticmethod
    def _validate_probability(value: Optional[float], name: str, allow_none: bool = True) -> Optional[float]:
        """Validate probability-like arguments.

        Args:
            value: Candidate value; converted with ``float()``.
            name: Argument name used in error messages.
            allow_none: Whether ``None`` is accepted and passed through.

        Returns:
            The value as a float, or ``None`` when allowed.

        Raises:
            ValueError: If ``value`` is ``None`` while ``allow_none`` is False,
                or lies outside the open interval (0, 1).
        """
        if value is None:
            if allow_none:
                return None
            raise ValueError(f"{name} cannot be None")
        value = float(value)
        # Endpoints are excluded: the prior shift takes log-odds, which are
        # undefined at exactly 0 or 1.
        if not 0.0 < value < 1.0:
            raise ValueError(f"{name} must be in the open interval (0, 1), got {value}")
        return value

    def _get_device(self):
        """Get the best available device (CUDA, then MPS, then CPU)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Guard the attribute lookup: torch builds without the MPS backend
        # module would otherwise raise AttributeError here.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _load_model(self):
        """Load model and tokenizer, with optional Hugging Face Hub fallback.

        If the local model path doesn't exist, checks the HF_MODEL_REPO env var
        (e.g. "username/eumora-emotion-classifier") and downloads from HF Hub.
        Use HF_TOKEN for private repos.

        Also populates ``id2label``, ``label2id``, ``sarcasm_idx`` and
        ``train_sarcasm_prior`` from the loaded model's config.

        Raises:
            FileNotFoundError: If the local path is missing and HF_MODEL_REPO is unset.
        """
        if self.model_path.exists():
            source = str(self.model_path)
        else:
            hf_repo = os.environ.get("HF_MODEL_REPO")
            if not hf_repo:
                raise FileNotFoundError(
                    f"β No trained model found at {self.model_path}\n"
                    f" Either run 'python main.py train' or set HF_MODEL_REPO env var."
                )
            source = hf_repo
            print(f"Local model not found. Loading from HuggingFace Hub: {hf_repo}")

        # An empty HF_TOKEN is treated the same as an unset one.
        hf_token = os.environ.get("HF_TOKEN") or None
        print(f"Loading model from: {source}")
        self.tokenizer = AutoTokenizer.from_pretrained(source, token=hf_token)
        self.model = AutoModelForSequenceClassification.from_pretrained(source, token=hf_token)
        self.model.to(self.device)
        self.model.eval()

        # Get label mappings from model config (ids are normalised to int)
        self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
        self.label2id = {k: int(v) for k, v in self.model.config.label2id.items()}
        self.sarcasm_idx = self.label2id.get("sarcasm")
        self.train_sarcasm_prior = self._resolve_train_sarcasm_prior()

        if self.sarcasm_idx is None and (self.target_sarcasm_prior is not None or self.sarcasm_threshold is not None):
            print("β οΈ Loaded model has no 'sarcasm' class. Prior adjustment and sarcasm threshold are disabled.")

    def _resolve_train_sarcasm_prior(self) -> Optional[float]:
        """Resolve training sarcasm prior from explicit override, model metadata, or config fallback.

        Precedence: the constructor's ``train_sarcasm_prior``; then the model
        config's ``sarcasm_train_prior`` attribute if it is a valid
        probability; then ``config.assumed_train_sarcasm_prior`` when the
        model has a sarcasm class. Returns ``None`` otherwise.
        """
        if self.user_train_sarcasm_prior is not None:
            return self.user_train_sarcasm_prior

        model_prior = getattr(self.model.config, "sarcasm_train_prior", None)
        if model_prior is not None:
            try:
                return self._validate_probability(model_prior, "model.sarcasm_train_prior", allow_none=False)
            except (TypeError, ValueError):
                # Unusable metadata (out of range or not numeric): deliberately
                # fall through to the config-level assumption below.
                pass

        if self.sarcasm_idx is not None:
            return self._validate_probability(
                config.assumed_train_sarcasm_prior,
                "config.assumed_train_sarcasm_prior",
                allow_none=False,
            )
        return None

    def _compute_sarcasm_logit_shift(self) -> float:
        """Compute prior-shift logit adjustment for sarcasm vs non-sarcasm.

        Returns ``log(target_odds) - log(train_odds)``, or 0.0 when the model
        has no sarcasm class or either prior is unavailable.
        """
        if self.sarcasm_idx is None or self.target_sarcasm_prior is None or self.train_sarcasm_prior is None:
            return 0.0
        target_odds = self.target_sarcasm_prior / (1.0 - self.target_sarcasm_prior)
        train_odds = self.train_sarcasm_prior / (1.0 - self.train_sarcasm_prior)
        return math.log(target_odds) - math.log(train_odds)

    def _apply_sarcasm_prior_adjustment(self, logits: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """Shift sarcasm logit to better match deployment prevalence.

        Returns:
            ``(logits, shift)``. The input tensor is returned unchanged (and
            not copied) when no shift applies; otherwise a clone with the
            shift added at ``sarcasm_idx``.
        """
        shift = self._compute_sarcasm_logit_shift()
        if self.sarcasm_idx is None or shift == 0.0:
            return logits, 0.0
        adjusted = logits.clone()
        adjusted[self.sarcasm_idx] = adjusted[self.sarcasm_idx] + shift
        return adjusted, shift

    def _apply_sarcasm_threshold(self, probs: torch.Tensor, pred_idx: int) -> int:
        """Apply optional one-vs-rest sarcasm thresholding to final class decision.

        NOTE(review): nothing in this class calls this helper; ``predict``
        uses the fixed ``SARCASM_TRIPWIRE`` instead of ``sarcasm_threshold``.

        Returns:
            ``sarcasm_idx`` when the sarcasm probability reaches the
            threshold; the best non-sarcasm class when sarcasm was the argmax
            but fell below the threshold; otherwise ``pred_idx`` unchanged.
        """
        if self.sarcasm_idx is None or self.sarcasm_threshold is None:
            return pred_idx
        sarcasm_prob = probs[self.sarcasm_idx].item()
        if sarcasm_prob >= self.sarcasm_threshold:
            return self.sarcasm_idx
        if pred_idx == self.sarcasm_idx:
            # Mask sarcasm with a value below any probability so argmax picks
            # the runner-up class.
            non_sarcasm_probs = probs.clone()
            non_sarcasm_probs[self.sarcasm_idx] = -1.0
            return torch.argmax(non_sarcasm_probs).item()
        return pred_idx

    def predict(self, text: str, create_chart: bool = False, show_chart: bool = True) -> dict:
        """
        Predict emotion from text.

        Args:
            text: Input text (lyrics, sentence, etc.)
            create_chart: Whether to generate a visualization chart
            show_chart: Whether to display the chart (only if create_chart=True)

        Returns:
            dict with emotion, confidence, probabilities, music_context,
            explanation, calibration, and optional chart_path. The
            "probabilities" values are rounded so that, as percentages with
            one decimal, they total 100.
        """
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=config.max_length,
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Get probabilities from the first row of the batch
        logits = outputs.logits[0].detach().cpu()
        adjusted_logits, sarcasm_logit_shift = self._apply_sarcasm_prior_adjustment(logits)
        probs = torch.softmax(adjusted_logits, dim=-1)
        pred_idx = torch.argmax(probs).item()

        # Sarcasm tripwire: once the sarcasm probability reaches the fixed
        # SARCASM_TRIPWIRE, sarcasm becomes the reported label.
        if self.sarcasm_idx is not None and probs[self.sarcasm_idx].item() >= self.SARCASM_TRIPWIRE:
            if pred_idx != self.sarcasm_idx:
                # Swap the max probability with sarcasm so it shows up as the
                # dominant emotion visually. NOTE: this means the returned
                # distribution no longer matches the model's softmax output
                # for these two classes.
                original_max_val = probs[pred_idx].item()
                sarcasm_val = probs[self.sarcasm_idx].item()
                probs[pred_idx] = sarcasm_val
                probs[self.sarcasm_idx] = original_max_val
                pred_idx = self.sarcasm_idx

        confidence = probs[pred_idx].item()

        # Build results
        emotion = self.id2label[pred_idx]

        # Convert to percentages rounded to one decimal, then push any
        # rounding residue onto the largest entry so the total is 100%.
        rounded_percents = [round(p * 100, 1) for p in probs.tolist()]
        total_percent = sum(rounded_percents)
        if total_percent != 100.0:
            max_idx = rounded_percents.index(max(rounded_percents))
            rounded_percents[max_idx] += (100.0 - total_percent)

        # Convert back to probability format (0-1 range); the final round()
        # also strips float noise left by the adjustment above.
        emotion_probs = {
            self.id2label[i]: round(percent / 100.0, 4)
            for i, percent in enumerate(rounded_percents)
        }

        # Get music context, with a neutral default for unmapped emotions
        music_context = config.emotion_to_music_mood.get(emotion, {
            "mood": emotion, "energy": "medium", "valence": "neutral"
        })

        result = {
            "emotion": emotion,
            "confidence": round(confidence, 4),
            "probabilities": emotion_probs,
            "music_context": music_context,
            "explanation": self._generate_explanation(emotion, confidence, emotion_probs, music_context),
            "calibration": {
                "target_sarcasm_prior": self.target_sarcasm_prior,
                "train_sarcasm_prior": self.train_sarcasm_prior,
                "sarcasm_logit_shift": round(sarcasm_logit_shift, 4),
                "sarcasm_threshold": self.sarcasm_threshold,
                "sarcasm_probability": round(emotion_probs.get("sarcasm", 0.0), 4),
            },
        }

        # Generate visualization if requested; chart failures are reported
        # but never fail the prediction itself.
        if create_chart and self.enable_viz:
            try:
                chart_path = self.visualizer.create_emotion_bar_chart(
                    emotion_probs, text, show_chart=show_chart, primary_emotion=emotion
                )
                result["chart_path"] = str(chart_path)
                print(f"Chart saved to: {chart_path}")
            except Exception as e:
                print(f"Could not create chart: {e}")
                result["chart_path"] = None

        return result

    def _generate_explanation(self, emotion: str, confidence: float,
                              probs: dict, music_context: dict) -> str:
        """Generate XAI-style explanation for the prediction.

        Args:
            emotion: Predicted label.
            confidence: Probability of the predicted label (0-1).
            probs: Mapping of label to probability.
            music_context: Mapping that may carry "mood" and "energy" keys.

        Returns:
            A one- or two-sentence summary; the second-ranked label is
            mentioned only when its probability exceeds 0.15.
        """
        if confidence > 0.7:
            confidence_level = "high"
        elif confidence > 0.4:
            confidence_level = "moderate"
        else:
            confidence_level = "low"

        # Find secondary emotion (second-highest probability)
        sorted_emotions = sorted(probs.items(), key=lambda x: x[1], reverse=True)
        secondary = sorted_emotions[1] if len(sorted_emotions) > 1 else None

        # Emotion descriptors
        descriptors = {
            "sadness": "melancholic and sorrowful themes",
            "joy": "uplifting and celebratory content",
            "love": "romantic and affectionate sentiments",
            "anger": "intense and confrontational language",
            "fear": "anxious and uncertain undertones",
            "surprise": "unexpected and wonder-filled expressions",
            "neutral": "balanced and observational language",
            "sarcasm": "ironic or intentionally contradictory phrasing",
        }
        descriptor = descriptors.get(emotion, f"{emotion} emotional markers")

        explanation = (
            f"Detected {descriptor} with {confidence_level} confidence ({confidence:.1%}). "
            f"Suggests {music_context.get('mood', emotion)} music with "
            f"{music_context.get('energy', 'medium')} energy."
        )
        if secondary and secondary[1] > 0.15:
            explanation += f" Secondary: {secondary[0]} ({secondary[1]:.1%})."
        return explanation

    def predict_batch(self, texts: list) -> list:
        """Predict emotions for multiple texts (one ``predict`` call per text)."""
        return [self.predict(text) for text in texts]

    def analyze_song(self, title: str, artist: str, lyrics: str, create_chart: bool = True) -> dict:
        """Full song analysis with metadata and optional visualization.

        Returns:
            dict with "song" (title/artist), "analysis" (the ``predict``
            result for the lyrics), "tags", and "chart_path" when a chart was
            attempted (``None`` if chart creation failed).
        """
        prediction = self.predict(lyrics, create_chart=False)  # We'll create a detailed chart instead

        result = {
            "song": {"title": title, "artist": artist},
            "analysis": prediction,
            "tags": self._generate_tags(prediction),
        }

        # Generate detailed visualization for song analysis
        if create_chart and self.enable_viz:
            try:
                chart_path = self.visualizer.create_detailed_analysis_chart(
                    prediction, f"{title} by {artist}"
                )
                result["chart_path"] = str(chart_path)
                print(f"π Detailed analysis chart saved to: {chart_path}")
            except Exception as e:
                print(f"β οΈ Could not create detailed chart: {e}")
                result["chart_path"] = None

        return result

    def predict_with_visualization(self, text: str, chart_type: str = "simple") -> dict:
        """
        Predict with automatic visualization.

        Args:
            text: Input text
            chart_type: Type of chart ('simple', 'detailed'); any value other
                than 'detailed' produces the simple bar chart.

        Returns:
            Prediction result with chart
        """
        if chart_type != "detailed":
            return self.predict(text, create_chart=True, show_chart=True)

        result = self.predict(text, create_chart=False)
        if self.enable_viz:
            try:
                chart_path = self.visualizer.create_detailed_analysis_chart(result, text)
                result["chart_path"] = str(chart_path)
            except Exception as e:
                print(f"β οΈ Could not create detailed chart: {e}")
                result["chart_path"] = None
        return result

    def _generate_tags(self, prediction: dict) -> list:
        """Generate recommendation tags from prediction.

        Args:
            prediction: A ``predict`` result; its "emotion" and
                "music_context" entries are read.

        Returns:
            "emotion:", "mood:", "energy:" and "valence:" tags followed by
            one "activity:" tag per suggested activity ("activity:general"
            for emotions without a mapping).
        """
        emotion = prediction["emotion"]
        music_ctx = prediction["music_context"]

        tags = [
            f"emotion:{emotion}",
            f"mood:{music_ctx.get('mood', emotion)}",
            f"energy:{music_ctx.get('energy', 'medium')}",
            f"valence:{music_ctx.get('valence', 'neutral')}",
        ]

        # Activity suggestions
        activity_map = {
            "joy": ["party", "workout", "celebration"],
            "sadness": ["reflection", "rainy-day", "comfort"],
            "love": ["romantic", "date-night", "slow-dance"],
            "anger": ["workout", "release", "intensity"],
            "fear": ["thriller", "suspense", "atmospheric"],
            "surprise": ["discovery", "adventure", "exploration"],
        }
        activities = activity_map.get(emotion, ["general"])
        tags.extend(f"activity:{a}" for a in activities)
        return tags
def demo():
    """Run the trained predictor over six canned lyric snippets.

    Prints the label, confidence, music context and top-three probability
    bars for each sample, then asks the predictor's visualizer for one
    comparison chart. Returns early (after printing the error) when the
    predictor cannot be built because no model is found.
    """
    banner = "=" * 60
    divider = "-" * 60

    print("\n" + banner)
    print("π΅ EUMORA - Emotion Analysis Demo (Custom Trained Model)")
    print(banner)

    try:
        predictor = EmotionPredictor(enable_viz=True)
    except FileNotFoundError as missing_model:
        print(f"\n{missing_model}")
        return

    test_samples = [
        ("Happy lyrics", "I feel so alive today, everything is wonderful and bright!"),
        ("Sad lyrics", "My heart is broken, tears falling like rain in the dark night"),
        ("Angry lyrics", "I hate this, you betrayed me, I want to scream at the world!"),
        ("Love lyrics", "You are my everything, I want to hold you forever my darling"),
        ("Fear lyrics", "Something is watching me in the shadows, I'm scared to move"),
        ("Surprise lyrics", "I can't believe it happened! This is incredible, wow!"),
    ]

    print(f"\nAnalyzing {len(test_samples)} samples...\n")
    print(divider)

    # Collect every result so the comparison chart can be drawn afterwards
    collected = []
    for sample_label, sample_text in test_samples:
        # Per-sample charts stay off; only the combined chart is produced
        outcome = predictor.predict(sample_text, create_chart=False)
        collected.append(outcome)

        print(f"\nπ {sample_label}:")
        print(f" \"{sample_text[:50]}...\"")
        print(f" π Emotion: {outcome['emotion'].upper()} ({outcome['confidence']:.1%})")
        print(f" πΈ Context: {outcome['music_context']}")

        # Text-based view of the three most probable labels
        ranked = sorted(outcome['probabilities'].items(), key=lambda pair: pair[1], reverse=True)
        print(" π Distribution:")
        for emo, prob in ranked[:3]:
            # Bar length is the probability scaled to 20 characters
            bar = "β" * int(prob * 20)
            print(f" {emo:>10}: {bar:<20} {prob:.1%}")
        print(divider)

    # Create comparison visualization
    print("\nπ Generating comparison chart...")
    try:
        titles = [sample_label for sample_label, _ in test_samples]
        comparison_path = predictor.visualizer.create_comparison_chart(
            collected, titles, show_chart=True
        )
        print(f"π Comparison chart saved to: {comparison_path}")
    except Exception as chart_error:
        print(f"β οΈ Could not create comparison chart: {chart_error}")

    print("\nβ Demo complete!")
# Script entry point: run the demo only when this module is executed directly,
# not when it is imported. Because this module uses package-relative imports
# (`from .config import ...`), it must be run as part of its package, e.g.
# with `python -m <package>.<module>`.
if __name__ == "__main__":
    demo()