"""Inference module for emotion prediction using our trained model."""

import math
import os
from pathlib import Path
from typing import Optional, Tuple

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F

from .config import config
from .visualize import EmotionVisualizer


class EmotionPredictor:
    """
    Predict emotions from text using our custom trained model.

    Trained on dair-ai/emotion dataset with 16,000+ samples.
    Base emotions: sadness, joy, love, anger, fear, surprise. The label set
    used at inference time is read from the loaded model's config
    (``id2label``), so fine-tuned variants may add classes such as "sarcasm".

    When the loaded model has a "sarcasm" class, ``predict`` applies two
    calibration steps:

    1. Prior shift: the sarcasm logit is shifted by the log-odds difference
       between the deployment prior and the training prior.
    2. Decision threshold: sarcasm is chosen whenever its probability reaches
       ``sarcasm_threshold`` (or ``DEFAULT_SARCASM_TRIPWIRE`` when no
       threshold is configured). With an explicit threshold, a sarcasm argmax
       that falls below it is demoted to the best non-sarcasm class.
    """

    # Sarcasm probability at/above which sarcasm is reported when no explicit
    # ``sarcasm_threshold`` is configured (the previous hard-coded 8% rule).
    DEFAULT_SARCASM_TRIPWIRE = 0.08

    def __init__(
        self,
        model_path: Optional[Path] = None,
        enable_viz: bool = True,
        target_sarcasm_prior: Optional[float] = config.target_sarcasm_prior,
        sarcasm_threshold: Optional[float] = config.sarcasm_threshold,
        train_sarcasm_prior: Optional[float] = None,
    ):
        """
        Initialize the emotion predictor.

        Args:
            model_path: Path to trained model directory. When omitted, the most
                recently modified ``*emotion_classifier*`` folder under
                ``config.model_dir`` is used (its ``final`` subfolder if present),
                falling back to ``models/emotion_classifier/final``.
            enable_viz: Whether to enable chart visualization
            target_sarcasm_prior: Target sarcasm prevalence in deployment text (0-1)
            sarcasm_threshold: Optional decision threshold for sarcasm class (0-1).
                When None, the 8% ``DEFAULT_SARCASM_TRIPWIRE`` is used instead.
            train_sarcasm_prior: Optional override for sarcasm prevalence seen in training

        Raises:
            ValueError: If a probability argument is outside the open interval (0, 1).
            FileNotFoundError: If no local model exists and HF_MODEL_REPO is unset.
        """
        self.model_path = Path(model_path) if model_path else self._find_default_model_path()

        self.device = self._get_device()
        self.enable_viz = enable_viz
        self.target_sarcasm_prior = self._validate_probability(target_sarcasm_prior, "target_sarcasm_prior")
        self.sarcasm_threshold = self._validate_probability(sarcasm_threshold, "sarcasm_threshold", allow_none=True)
        self.user_train_sarcasm_prior = self._validate_probability(
            train_sarcasm_prior,
            "train_sarcasm_prior",
            allow_none=True,
        )

        # Initialize visualizer if enabled
        if self.enable_viz:
            self.visualizer = EmotionVisualizer()

        self._load_model()

    @staticmethod
    def _find_default_model_path() -> Path:
        """Return the newest ``*emotion_classifier*`` model directory under ``config.model_dir``.

        Prefers a ``final`` subfolder of the most recently modified match; falls
        back to ``<model_dir>/emotion_classifier/final`` when nothing matches.
        """
        models_dir = Path(config.model_dir)
        fallback = models_dir / "emotion_classifier" / "final"

        # A fresh checkout/deployment may have no models/ directory at all.
        # iterdir() would raise FileNotFoundError here and pre-empt the
        # HF_MODEL_REPO fallback that _load_model offers for missing paths.
        if not models_dir.is_dir():
            return fallback

        subdirs = [d for d in models_dir.iterdir() if d.is_dir() and "emotion_classifier" in d.name]
        if not subdirs:
            return fallback

        latest_dir = max(subdirs, key=lambda d: d.stat().st_mtime)
        final_dir = latest_dir / "final"
        return final_dir if final_dir.exists() else latest_dir

    @staticmethod
    def _validate_probability(value: Optional[float], name: str, allow_none: bool = True) -> Optional[float]:
        """Validate a probability-like argument.

        Args:
            value: Candidate value; converted with ``float()``.
            name: Argument name used in error messages.
            allow_none: Whether ``None`` is accepted (and returned unchanged).

        Returns:
            The value as a float, or None when ``value`` is None and allowed.

        Raises:
            ValueError: If ``value`` is None while not allowed, or not strictly
                between 0 and 1.
        """
        if value is None:
            if allow_none:
                return None
            raise ValueError(f"{name} cannot be None")

        value = float(value)
        # Open interval: 0 and 1 would make the odds in the logit shift 0 or infinite.
        if not 0.0 < value < 1.0:
            raise ValueError(f"{name} must be in the open interval (0, 1), got {value}")
        return value

    def _get_device(self):
        """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _load_model(self):
        """Load model and tokenizer, with optional Hugging Face Hub fallback.

        If the local model path doesn't exist, checks the HF_MODEL_REPO env var
        (e.g. "username/eumora-emotion-classifier") and downloads from HF Hub.
        Use HF_TOKEN for private repos.

        Also populates ``id2label``/``label2id`` from the model config, locates
        the sarcasm class (if any) and resolves the training sarcasm prior.

        Raises:
            FileNotFoundError: If neither a local model nor HF_MODEL_REPO is available.
        """
        if self.model_path.exists():
            source = str(self.model_path)
        else:
            hf_repo = os.environ.get("HF_MODEL_REPO")
            if hf_repo:
                source = hf_repo
                print(f"Local model not found. Loading from HuggingFace Hub: {hf_repo}")
            else:
                raise FileNotFoundError(
                    f"❌ No trained model found at {self.model_path}\n"
                    f" Either run 'python main.py train' or set HF_MODEL_REPO env var."
                )

        # Treat an empty HF_TOKEN as "no token".
        hf_token = os.environ.get("HF_TOKEN") or None

        print(f"Loading model from: {source}")
        self.tokenizer = AutoTokenizer.from_pretrained(source, token=hf_token)
        self.model = AutoModelForSequenceClassification.from_pretrained(source, token=hf_token)
        self.model.to(self.device)
        self.model.eval()

        # Get label mappings from model config (keys may be serialized as strings).
        self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
        self.label2id = {k: int(v) for k, v in self.model.config.label2id.items()}
        self.sarcasm_idx = self.label2id.get("sarcasm")
        self.train_sarcasm_prior = self._resolve_train_sarcasm_prior()

        if self.sarcasm_idx is None and (self.target_sarcasm_prior is not None or self.sarcasm_threshold is not None):
            print("⚠️ Loaded model has no 'sarcasm' class. Prior adjustment and sarcasm threshold are disabled.")

    def _resolve_train_sarcasm_prior(self) -> Optional[float]:
        """Resolve the training sarcasm prior.

        Precedence: explicit constructor override, then a valid
        ``sarcasm_train_prior`` attribute on the model config, then
        ``config.assumed_train_sarcasm_prior`` (only when the model has a
        sarcasm class). Returns None otherwise.
        """
        if self.user_train_sarcasm_prior is not None:
            return self.user_train_sarcasm_prior

        model_prior = getattr(self.model.config, "sarcasm_train_prior", None)
        if model_prior is not None:
            try:
                return self._validate_probability(model_prior, "model.sarcasm_train_prior", allow_none=False)
            except ValueError:
                # Invalid metadata is ignored on purpose; fall through to the config default.
                pass

        if self.sarcasm_idx is not None:
            return self._validate_probability(
                config.assumed_train_sarcasm_prior,
                "config.assumed_train_sarcasm_prior",
                allow_none=False,
            )
        return None

    def _compute_sarcasm_logit_shift(self) -> float:
        """Return log(target odds) - log(train odds) for sarcasm, or 0.0 when not applicable."""
        if self.sarcasm_idx is None or self.target_sarcasm_prior is None or self.train_sarcasm_prior is None:
            return 0.0

        target_odds = self.target_sarcasm_prior / (1.0 - self.target_sarcasm_prior)
        train_odds = self.train_sarcasm_prior / (1.0 - self.train_sarcasm_prior)
        return math.log(target_odds) - math.log(train_odds)

    def _apply_sarcasm_prior_adjustment(self, logits: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """Shift the sarcasm logit to better match deployment prevalence.

        Args:
            logits: Per-class logits for one input (indexed by class id).

        Returns:
            ``(adjusted_logits, shift)``; the input tensor is not modified.
        """
        shift = self._compute_sarcasm_logit_shift()
        if self.sarcasm_idx is None or shift == 0.0:
            return logits, 0.0

        adjusted = logits.clone()
        adjusted[self.sarcasm_idx] = adjusted[self.sarcasm_idx] + shift
        return adjusted, shift

    def _effective_sarcasm_threshold(self) -> float:
        """Return the configured sarcasm threshold, or the 8% default tripwire when unset."""
        if self.sarcasm_threshold is not None:
            return self.sarcasm_threshold
        return self.DEFAULT_SARCASM_TRIPWIRE

    def _apply_sarcasm_threshold(self, probs: torch.Tensor, pred_idx: int, threshold: Optional[float] = None) -> int:
        """Apply one-vs-rest sarcasm thresholding to the final class decision.

        Args:
            probs: Per-class probabilities for one input.
            pred_idx: Current (argmax) class index.
            threshold: Decision threshold; defaults to ``self.sarcasm_threshold``.

        Returns:
            The sarcasm index if its probability reaches the threshold; the best
            non-sarcasm index if sarcasm was predicted but falls below it;
            otherwise ``pred_idx`` unchanged.
        """
        if threshold is None:
            threshold = self.sarcasm_threshold
        if self.sarcasm_idx is None or threshold is None:
            return pred_idx

        sarcasm_prob = probs[self.sarcasm_idx].item()
        if sarcasm_prob >= threshold:
            return self.sarcasm_idx

        if pred_idx == self.sarcasm_idx:
            # Mask sarcasm below any real probability so argmax picks the runner-up.
            non_sarcasm_probs = probs.clone()
            non_sarcasm_probs[self.sarcasm_idx] = -1.0
            return int(torch.argmax(non_sarcasm_probs).item())
        return pred_idx

    def _decide_label(self, probs: torch.Tensor) -> Tuple[int, torch.Tensor]:
        """Pick the final class index and the probabilities to report.

        When the sarcasm threshold overrides the argmax, the probabilities of
        the argmax class and the chosen class are swapped so the reported label
        also carries the largest displayed value.

        Returns:
            ``(pred_idx, display_probs)``.
        """
        argmax_idx = int(torch.argmax(probs).item())
        pred_idx = self._apply_sarcasm_threshold(probs, argmax_idx, self._effective_sarcasm_threshold())
        if pred_idx == argmax_idx:
            return pred_idx, probs

        # NOTE(review): after this swap the reported distribution no longer equals
        # the model output for these two classes; kept deliberately so charts and
        # explanations show the chosen label as dominant.
        swapped = probs.clone()
        swapped[argmax_idx] = probs[pred_idx].item()
        swapped[pred_idx] = probs[argmax_idx].item()
        return pred_idx, swapped

    def _to_display_probabilities(self, probs: torch.Tensor) -> dict:
        """Map labels to probabilities rounded to 0.1 percentage points that total 100%.

        Any rounding drift is added to the largest entry.
        """
        rounded_percents = [round(p * 100, 1) for p in probs.tolist()]

        # Round the drift so float noise alone (e.g. 99.99999999) is not "corrected".
        drift = round(100.0 - sum(rounded_percents), 1)
        if drift:
            max_idx = rounded_percents.index(max(rounded_percents))
            rounded_percents[max_idx] += drift

        # Convert back to probability format (0-1 range)
        return {self.id2label[i]: round(pct / 100.0, 4) for i, pct in enumerate(rounded_percents)}

    def predict(self, text: str, create_chart: bool = False, show_chart: bool = True) -> dict:
        """
        Predict emotion from text.

        Args:
            text: Input text (lyrics, sentence, etc.)
            create_chart: Whether to generate a visualization chart
            show_chart: Whether to display the chart (only if create_chart=True)

        Returns:
            dict with emotion, confidence, probabilities, music_context,
            explanation, calibration and (when a chart was requested and
            visualization is enabled) chart_path
        """
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=config.max_length,
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Logits row for the single input text.
        logits = outputs.logits[0].detach().cpu()
        adjusted_logits, sarcasm_logit_shift = self._apply_sarcasm_prior_adjustment(logits)
        probs = torch.softmax(adjusted_logits, dim=-1)
        pred_idx, probs = self._decide_label(probs)

        confidence = probs[pred_idx].item()
        emotion = self.id2label[pred_idx]
        emotion_probs = self._to_display_probabilities(probs)

        # Get music context
        music_context = config.emotion_to_music_mood.get(emotion, {
            "mood": emotion,
            "energy": "medium",
            "valence": "neutral",
        })

        result = {
            "emotion": emotion,
            "confidence": round(confidence, 4),
            "probabilities": emotion_probs,
            "music_context": music_context,
            "explanation": self._generate_explanation(emotion, confidence, emotion_probs, music_context),
            "calibration": {
                "target_sarcasm_prior": self.target_sarcasm_prior,
                "train_sarcasm_prior": self.train_sarcasm_prior,
                "sarcasm_logit_shift": round(sarcasm_logit_shift, 4),
                "sarcasm_threshold": self.sarcasm_threshold,
                "sarcasm_probability": round(emotion_probs.get("sarcasm", 0.0), 4),
            },
        }

        # Generate visualization if requested
        if create_chart and self.enable_viz:
            try:
                chart_path = self.visualizer.create_emotion_bar_chart(
                    emotion_probs,
                    text,
                    show_chart=show_chart,
                    primary_emotion=emotion,
                )
                result["chart_path"] = str(chart_path)
                print(f"Chart saved to: {chart_path}")
            except Exception as e:
                # Charting is best-effort; a failure must not lose the prediction.
                print(f"Could not create chart: {e}")
                result["chart_path"] = None

        return result

    def _generate_explanation(self, emotion: str, confidence: float, probs: dict, music_context: dict) -> str:
        """Generate XAI-style explanation for the prediction.

        Confidence is bucketed as high (> 0.7), moderate (> 0.4) or low; the
        runner-up label is mentioned when its probability exceeds 0.15.
        """
        confidence_level = "high" if confidence > 0.7 else "moderate" if confidence > 0.4 else "low"

        # Find secondary emotion
        sorted_emotions = sorted(probs.items(), key=lambda x: x[1], reverse=True)
        secondary = sorted_emotions[1] if len(sorted_emotions) > 1 else None

        # Emotion descriptors
        descriptors = {
            "sadness": "melancholic and sorrowful themes",
            "joy": "uplifting and celebratory content",
            "love": "romantic and affectionate sentiments",
            "anger": "intense and confrontational language",
            "fear": "anxious and uncertain undertones",
            "surprise": "unexpected and wonder-filled expressions",
            "neutral": "balanced and observational language",
            "sarcasm": "ironic or intentionally contradictory phrasing",
        }

        descriptor = descriptors.get(emotion, f"{emotion} emotional markers")

        explanation = (
            f"Detected {descriptor} with {confidence_level} confidence ({confidence:.1%}). "
            f"Suggests {music_context.get('mood', emotion)} music with "
            f"{music_context.get('energy', 'medium')} energy."
        )

        if secondary and secondary[1] > 0.15:
            explanation += f" Secondary: {secondary[0]} ({secondary[1]:.1%})."

        return explanation

    def predict_batch(self, texts: list) -> list:
        """Predict emotions for multiple texts (one forward pass per text, no charts)."""
        return [self.predict(text) for text in texts]

    def analyze_song(self, title: str, artist: str, lyrics: str, create_chart: bool = True) -> dict:
        """Full song analysis with metadata and optional visualization.

        Returns:
            dict with ``song`` (title/artist), ``analysis`` (the ``predict``
            result), ``tags`` and, when a chart was attempted, ``chart_path``.
        """
        prediction = self.predict(lyrics, create_chart=False)  # We'll create a detailed chart instead

        result = {
            "song": {"title": title, "artist": artist},
            "analysis": prediction,
            "tags": self._generate_tags(prediction),
        }

        # Generate detailed visualization for song analysis
        if create_chart and self.enable_viz:
            try:
                chart_path = self.visualizer.create_detailed_analysis_chart(
                    prediction,
                    f"{title} by {artist}",
                )
                result["chart_path"] = str(chart_path)
                print(f"📊 Detailed analysis chart saved to: {chart_path}")
            except Exception as e:
                print(f"⚠️ Could not create detailed chart: {e}")
                result["chart_path"] = None

        return result

    def predict_with_visualization(self, text: str, chart_type: str = "simple") -> dict:
        """
        Predict with automatic visualization.

        Args:
            text: Input text
            chart_type: Type of chart ('simple', 'detailed'); any other value
                behaves like 'simple'

        Returns:
            Prediction result with chart
        """
        if chart_type == "detailed":
            result = self.predict(text, create_chart=False)
            if self.enable_viz:
                try:
                    chart_path = self.visualizer.create_detailed_analysis_chart(result, text)
                    result["chart_path"] = str(chart_path)
                except Exception as e:
                    print(f"⚠️ Could not create detailed chart: {e}")
                    result["chart_path"] = None
        else:
            result = self.predict(text, create_chart=True, show_chart=True)

        return result

    def _generate_tags(self, prediction: dict) -> list:
        """Generate recommendation tags (emotion, mood, energy, valence, activities) from a prediction."""
        tags = []
        emotion = prediction["emotion"]
        music_ctx = prediction["music_context"]

        tags.append(f"emotion:{emotion}")
        tags.append(f"mood:{music_ctx.get('mood', emotion)}")
        tags.append(f"energy:{music_ctx.get('energy', 'medium')}")
        tags.append(f"valence:{music_ctx.get('valence', 'neutral')}")

        # Activity suggestions; labels without an entry get "general".
        activity_map = {
            "joy": ["party", "workout", "celebration"],
            "sadness": ["reflection", "rainy-day", "comfort"],
            "love": ["romantic", "date-night", "slow-dance"],
            "anger": ["workout", "release", "intensity"],
            "fear": ["thriller", "suspense", "atmospheric"],
            "surprise": ["discovery", "adventure", "exploration"],
        }

        activities = activity_map.get(emotion, ["general"])
        tags.extend([f"activity:{a}" for a in activities])

        return tags


def demo():
    """Demo of the trained emotion predictor with visualizations."""
    print("\n" + "=" * 60)
    print("🎵 EUMORA - Emotion Analysis Demo (Custom Trained Model)")
    print("=" * 60)

    try:
        predictor = EmotionPredictor(enable_viz=True)
    except FileNotFoundError as e:
        print(f"\n{e}")
        return

    test_samples = [
        ("Happy lyrics", "I feel so alive today, everything is wonderful and bright!"),
        ("Sad lyrics", "My heart is broken, tears falling like rain in the dark night"),
        ("Angry lyrics", "I hate this, you betrayed me, I want to scream at the world!"),
        ("Love lyrics", "You are my everything, I want to hold you forever my darling"),
        ("Fear lyrics", "Something is watching me in the shadows, I'm scared to move"),
        ("Surprise lyrics", "I can't believe it happened! This is incredible, wow!"),
    ]

    print(f"\nAnalyzing {len(test_samples)} samples...\n")
    print("-" * 60)

    # Create comparison charts
    all_results = []
    for label, text in test_samples:
        result = predictor.predict(text, create_chart=False)  # Individual charts disabled for comparison
        all_results.append(result)

        print(f"\n📝 {label}:")
        print(f" \"{text[:50]}...\"")
        print(f" 🎭 Emotion: {result['emotion'].upper()} ({result['confidence']:.1%})")
        print(f" 🎸 Context: {result['music_context']}")

        # Show probability distribution (text-based)
        sorted_probs = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
        print(" 📊 Distribution:")
        for emo, prob in sorted_probs[:3]:
            bar = "█" * int(prob * 20)
            print(f" {emo:>10}: {bar:<20} {prob:.1%}")

    print("-" * 60)

    # Create comparison visualization
    print("\n📊 Generating comparison chart...")
    try:
        titles = [label for label, _ in test_samples]
        comparison_path = predictor.visualizer.create_comparison_chart(
            all_results,
            titles,
            show_chart=True,
        )
        print(f"📊 Comparison chart saved to: {comparison_path}")
    except Exception as e:
        print(f"⚠️ Could not create comparison chart: {e}")

    print("\n✅ Demo complete!")


if __name__ == "__main__":
    demo()