eumora-api / backend /src /predict.py
VivDubs's picture
refactor: move backend files into backend/ directory
9eb5faa
Raw
History Blame Contribute Delete
18.8 kB
"""Inference module for emotion prediction using our trained model."""
import math
import os
from pathlib import Path
from typing import Optional, Tuple
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
from .config import config
from .visualize import EmotionVisualizer
class EmotionPredictor:
    """
    Predict emotions from text using our custom trained model.

    Trained on dair-ai/emotion dataset with 16,000+ samples.
    Emotions: sadness, joy, love, anger, fear, surprise

    If the loaded model's label map also contains a ``sarcasm`` class, the
    sarcasm logit is shifted to account for a different deployment prior and a
    fixed probability tripwire (``SARCASM_TRIPWIRE``) can force the sarcasm
    label.
    """

    # Sarcasm probability (after prior adjustment) at or above which predict()
    # reports "sarcasm" as the dominant emotion.
    # NOTE(review): ``sarcasm_threshold`` is validated and echoed in the
    # calibration block, but predict() decides with this constant instead of
    # calling ``_apply_sarcasm_threshold``. Confirm which one is intended.
    SARCASM_TRIPWIRE = 0.08

    # Human-readable phrase per emotion used by _generate_explanation.
    _DESCRIPTORS = {
        "sadness": "melancholic and sorrowful themes",
        "joy": "uplifting and celebratory content",
        "love": "romantic and affectionate sentiments",
        "anger": "intense and confrontational language",
        "fear": "anxious and uncertain undertones",
        "surprise": "unexpected and wonder-filled expressions",
        "neutral": "balanced and observational language",
        "sarcasm": "ironic or intentionally contradictory phrasing",
    }

    # Activity suggestions per emotion used by _generate_tags.
    _ACTIVITY_TAGS = {
        "joy": ["party", "workout", "celebration"],
        "sadness": ["reflection", "rainy-day", "comfort"],
        "love": ["romantic", "date-night", "slow-dance"],
        "anger": ["workout", "release", "intensity"],
        "fear": ["thriller", "suspense", "atmospheric"],
        "surprise": ["discovery", "adventure", "exploration"],
    }

    def __init__(
        self,
        model_path: Optional[Path] = None,
        enable_viz: bool = True,
        target_sarcasm_prior: Optional[float] = config.target_sarcasm_prior,
        sarcasm_threshold: Optional[float] = config.sarcasm_threshold,
        train_sarcasm_prior: Optional[float] = None,
    ):
        """
        Initialize the emotion predictor and load the model.

        Args:
            model_path: Path to trained model directory. When omitted, the most
                recently modified ``*emotion_classifier*`` folder under
                ``config.model_dir`` is used (its ``final`` subfolder if present),
                falling back to ``emotion_classifier/final``.
            enable_viz: Whether to enable chart visualization.
            target_sarcasm_prior: Target sarcasm prevalence in deployment text,
                strictly between 0 and 1, or None to disable prior adjustment.
            sarcasm_threshold: Optional decision threshold for the sarcasm class,
                strictly between 0 and 1.
            train_sarcasm_prior: Optional override for the sarcasm prevalence
                seen in training, strictly between 0 and 1.

        Raises:
            ValueError: If a probability argument is outside the open interval (0, 1).
            FileNotFoundError: If no local model exists and HF_MODEL_REPO is unset.
        """
        if model_path:
            self.model_path = Path(model_path)
        else:
            self.model_path = self._discover_model_path()

        self.device = self._get_device()
        self.enable_viz = enable_viz
        self.target_sarcasm_prior = self._validate_probability(
            target_sarcasm_prior, "target_sarcasm_prior"
        )
        self.sarcasm_threshold = self._validate_probability(
            sarcasm_threshold, "sarcasm_threshold", allow_none=True
        )
        self.user_train_sarcasm_prior = self._validate_probability(
            train_sarcasm_prior,
            "train_sarcasm_prior",
            allow_none=True,
        )

        # Always define the attribute so callers can test it instead of
        # hitting AttributeError when visualization is disabled.
        self.visualizer = EmotionVisualizer() if self.enable_viz else None

        self._load_model()

    @staticmethod
    def _discover_model_path() -> Path:
        """Pick the newest local emotion_classifier directory, or the default path.

        A missing ``config.model_dir`` is not an error here: the default path is
        returned so that ``_load_model`` can still try the Hugging Face Hub fallback.
        """
        models_dir = Path(config.model_dir)
        default_path = models_dir / "emotion_classifier" / "final"
        if not models_dir.is_dir():
            return default_path

        candidates = [
            entry
            for entry in models_dir.iterdir()
            if entry.is_dir() and "emotion_classifier" in entry.name
        ]
        if not candidates:
            return default_path

        latest_dir = max(candidates, key=lambda entry: entry.stat().st_mtime)
        final_dir = latest_dir / "final"
        # Prefer the "final" export inside the run folder when it exists.
        return final_dir if final_dir.exists() else latest_dir

    @staticmethod
    def _validate_probability(value: Optional[float], name: str, allow_none: bool = True) -> Optional[float]:
        """Validate probability-like arguments.

        Args:
            value: Candidate value; converted with ``float()``.
            name: Argument name used in error messages.
            allow_none: Whether ``None`` is accepted and passed through.

        Returns:
            The value as a float, or None when ``value`` is None and allowed.

        Raises:
            ValueError: If ``value`` is None but not allowed, or is not strictly
                between 0 and 1.
        """
        if value is None:
            if allow_none:
                return None
            raise ValueError(f"{name} cannot be None")
        value = float(value)
        if not 0.0 < value < 1.0:
            raise ValueError(f"{name} must be in the open interval (0, 1), got {value}")
        return value

    def _get_device(self):
        """Get the best available device (CUDA, then MPS, then CPU)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Probe defensively: not every torch build exposes torch.backends.mps.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _load_model(self):
        """Load model and tokenizer, with optional Hugging Face Hub fallback.

        If the local model path doesn't exist, checks the HF_MODEL_REPO env var
        (e.g. "username/eumora-emotion-classifier") and downloads from HF Hub.
        Use HF_TOKEN for private repos.

        Raises:
            FileNotFoundError: If the local path is missing and HF_MODEL_REPO is unset.
        """
        if self.model_path.exists():
            source = str(self.model_path)
        else:
            hf_repo = os.environ.get("HF_MODEL_REPO")
            if not hf_repo:
                raise FileNotFoundError(
                    f"❌ No trained model found at {self.model_path}\n"
                    f" Either run 'python main.py train' or set HF_MODEL_REPO env var."
                )
            source = hf_repo
            print(f"Local model not found. Loading from HuggingFace Hub: {hf_repo}")

        # Treat an empty HF_TOKEN the same as an unset one.
        hf_token = os.environ.get("HF_TOKEN") or None
        print(f"Loading model from: {source}")
        self.tokenizer = AutoTokenizer.from_pretrained(source, token=hf_token)
        self.model = AutoModelForSequenceClassification.from_pretrained(source, token=hf_token)
        self.model.to(self.device)
        self.model.eval()

        # Get label mappings from model config, normalising ids to int.
        self.id2label = {int(idx): label for idx, label in self.model.config.id2label.items()}
        self.label2id = {label: int(idx) for label, idx in self.model.config.label2id.items()}
        self.sarcasm_idx = self.label2id.get("sarcasm")
        self.train_sarcasm_prior = self._resolve_train_sarcasm_prior()

        sarcasm_options_set = self.target_sarcasm_prior is not None or self.sarcasm_threshold is not None
        if self.sarcasm_idx is None and sarcasm_options_set:
            print("⚠️ Loaded model has no 'sarcasm' class. Prior adjustment and sarcasm threshold are disabled.")

    def _resolve_train_sarcasm_prior(self) -> Optional[float]:
        """Resolve training sarcasm prior from explicit override, model metadata, or config fallback.

        Returns:
            The prior to use, or None when the model has no sarcasm class and no
            override or valid model metadata is available.
        """
        if self.user_train_sarcasm_prior is not None:
            return self.user_train_sarcasm_prior

        model_prior = getattr(self.model.config, "sarcasm_train_prior", None)
        if model_prior is not None:
            try:
                return self._validate_probability(model_prior, "model.sarcasm_train_prior", allow_none=False)
            except (TypeError, ValueError):
                # Best effort: bad metadata should not block loading, but say so
                # instead of silently using a different prior.
                print(f"⚠️ Ignoring invalid sarcasm_train_prior in model config: {model_prior!r}")

        if self.sarcasm_idx is not None:
            return self._validate_probability(
                config.assumed_train_sarcasm_prior,
                "config.assumed_train_sarcasm_prior",
                allow_none=False,
            )
        return None

    def _compute_sarcasm_logit_shift(self) -> float:
        """Compute prior-shift logit adjustment for sarcasm vs non-sarcasm.

        Returns:
            ``log(target_odds) - log(train_odds)``, or 0.0 when the model has no
            sarcasm class or either prior is unavailable.
        """
        if self.sarcasm_idx is None or self.target_sarcasm_prior is None or self.train_sarcasm_prior is None:
            return 0.0
        target_odds = self.target_sarcasm_prior / (1.0 - self.target_sarcasm_prior)
        train_odds = self.train_sarcasm_prior / (1.0 - self.train_sarcasm_prior)
        return math.log(target_odds) - math.log(train_odds)

    def _apply_sarcasm_prior_adjustment(self, logits: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """Shift sarcasm logit to better match deployment prevalence.

        Args:
            logits: Class logits with the class dimension last.

        Returns:
            ``(logits, shift)``. The input tensor is returned untouched when no
            shift applies; otherwise a shifted copy is returned.
        """
        shift = self._compute_sarcasm_logit_shift()
        if self.sarcasm_idx is None or shift == 0.0:
            return logits, 0.0
        adjusted = logits.clone()
        # Index the last dimension so both single and batched logits work.
        adjusted[..., self.sarcasm_idx] += shift
        return adjusted, shift

    def _apply_sarcasm_threshold(self, probs: torch.Tensor, pred_idx: int) -> int:
        """Apply optional one-vs-rest sarcasm thresholding to final class decision.

        Not called by ``predict`` in this class; kept for callers that want
        threshold-based decisions.

        Args:
            probs: 1-D tensor of class probabilities.
            pred_idx: Current predicted class index.

        Returns:
            The sarcasm index when its probability meets the threshold; the best
            non-sarcasm class when sarcasm was predicted below the threshold;
            otherwise ``pred_idx`` unchanged.
        """
        if self.sarcasm_idx is None or self.sarcasm_threshold is None:
            return pred_idx
        sarcasm_prob = probs[self.sarcasm_idx].item()
        if sarcasm_prob >= self.sarcasm_threshold:
            return self.sarcasm_idx
        if pred_idx != self.sarcasm_idx:
            return pred_idx
        # Sarcasm won the argmax but is below threshold: mask it out and
        # pick the runner-up (-1.0 is below any valid probability).
        non_sarcasm_probs = probs.clone()
        non_sarcasm_probs[self.sarcasm_idx] = -1.0
        return torch.argmax(non_sarcasm_probs).item()

    def _apply_sarcasm_tripwire(self, probs: torch.Tensor, pred_idx: int) -> int:
        """Force the sarcasm label when its probability reaches SARCASM_TRIPWIRE.

        When sarcasm trips but was not the argmax, its probability is swapped
        in place with the top class so sarcasm shows up as the dominant emotion
        in the reported distribution.

        Returns:
            The (possibly updated) predicted class index.
        """
        if self.sarcasm_idx is None or pred_idx == self.sarcasm_idx:
            return pred_idx
        sarcasm_val = probs[self.sarcasm_idx].item()
        if not sarcasm_val >= self.SARCASM_TRIPWIRE:
            return pred_idx
        top_val = probs[pred_idx].item()
        probs[pred_idx] = sarcasm_val
        probs[self.sarcasm_idx] = top_val
        return self.sarcasm_idx

    @staticmethod
    def _round_for_display(raw_probs: list) -> list:
        """Round probabilities to 0.1% steps whose percentages total 100%.

        Any rounding drift is absorbed by the largest entry. Values are returned
        in the 0-1 range.
        """
        percents = [round(p * 100, 1) for p in raw_probs]
        drift = 100.0 - sum(percents)
        if percents and drift != 0.0:
            largest = percents.index(max(percents))
            percents[largest] += drift
        return [round(pct / 100.0, 4) for pct in percents]

    def predict(self, text: str, create_chart: bool = False, show_chart: bool = True) -> dict:
        """
        Predict emotion from text.

        Args:
            text: Input text (lyrics, sentence, etc.)
            create_chart: Whether to generate a visualization chart
            show_chart: Whether to display the chart (only if create_chart=True)

        Returns:
            dict with emotion, confidence, probabilities, music_context,
            explanation, calibration, and optional chart_path (present only when
            a chart was requested and visualization is enabled; None on failure).
        """
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=config.max_length,
            padding=True,
        )
        inputs = {key: tensor.to(self.device) for key, tensor in inputs.items()}

        # Predict
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Get probabilities (first and only item of the batch)
        logits = outputs.logits[0].detach().cpu()
        adjusted_logits, sarcasm_logit_shift = self._apply_sarcasm_prior_adjustment(logits)
        probs = torch.softmax(adjusted_logits, dim=-1)
        pred_idx = torch.argmax(probs).item()
        pred_idx = self._apply_sarcasm_tripwire(probs, pred_idx)
        confidence = probs[pred_idx].item()

        # Build results
        emotion = self.id2label[pred_idx]
        display_probs = self._round_for_display(probs.tolist())
        emotion_probs = {self.id2label[i]: value for i, value in enumerate(display_probs)}

        # Get music context
        music_context = config.emotion_to_music_mood.get(emotion, {
            "mood": emotion, "energy": "medium", "valence": "neutral"
        })

        result = {
            "emotion": emotion,
            "confidence": round(confidence, 4),
            "probabilities": emotion_probs,
            "music_context": music_context,
            "explanation": self._generate_explanation(emotion, confidence, emotion_probs, music_context),
            "calibration": {
                "target_sarcasm_prior": self.target_sarcasm_prior,
                "train_sarcasm_prior": self.train_sarcasm_prior,
                "sarcasm_logit_shift": round(sarcasm_logit_shift, 4),
                "sarcasm_threshold": self.sarcasm_threshold,
                "sarcasm_probability": round(emotion_probs.get("sarcasm", 0.0), 4),
            },
        }

        # Generate visualization if requested
        if create_chart and self.enable_viz:
            # Charting is best effort: a failure must not lose the prediction.
            try:
                chart_path = self.visualizer.create_emotion_bar_chart(
                    emotion_probs, text, show_chart=show_chart, primary_emotion=emotion
                )
                result["chart_path"] = str(chart_path)
                print(f"Chart saved to: {chart_path}")
            except Exception as e:
                print(f"Could not create chart: {e}")
                result["chart_path"] = None
        return result

    def _generate_explanation(self, emotion: str, confidence: float,
                              probs: dict, music_context: dict) -> str:
        """Generate XAI-style explanation for the prediction.

        Args:
            emotion: Predicted emotion label.
            confidence: Probability of the predicted emotion (0-1).
            probs: Mapping of emotion label to probability (0-1).
            music_context: Mapping with optional ``mood`` and ``energy`` keys.

        Returns:
            A one- or two-sentence description; the second-ranked emotion is
            mentioned only when its probability exceeds 0.15.
        """
        if confidence > 0.7:
            confidence_level = "high"
        elif confidence > 0.4:
            confidence_level = "moderate"
        else:
            confidence_level = "low"

        # Find secondary emotion (second highest probability)
        ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
        secondary = ranked[1] if len(ranked) > 1 else None

        descriptor = self._DESCRIPTORS.get(emotion, f"{emotion} emotional markers")
        explanation = (
            f"Detected {descriptor} with {confidence_level} confidence ({confidence:.1%}). "
            f"Suggests {music_context.get('mood', emotion)} music with "
            f"{music_context.get('energy', 'medium')} energy."
        )
        if secondary and secondary[1] > 0.15:
            explanation += f" Secondary: {secondary[0]} ({secondary[1]:.1%})."
        return explanation

    def predict_batch(self, texts: list) -> list:
        """Predict emotions for multiple texts, one ``predict`` call per text."""
        return [self.predict(text) for text in texts]

    def _attach_detailed_chart(self, result: dict, prediction: dict, title: str,
                               announce: bool = False) -> None:
        """Create the detailed analysis chart and record its path on ``result``.

        Does nothing when visualization is disabled. On failure a warning is
        printed and ``result["chart_path"]`` is set to None.
        """
        if not self.enable_viz:
            return
        try:
            chart_path = self.visualizer.create_detailed_analysis_chart(prediction, title)
            result["chart_path"] = str(chart_path)
            if announce:
                print(f"📊 Detailed analysis chart saved to: {chart_path}")
        except Exception as e:
            print(f"⚠️ Could not create detailed chart: {e}")
            result["chart_path"] = None

    def analyze_song(self, title: str, artist: str, lyrics: str, create_chart: bool = True) -> dict:
        """Full song analysis with metadata and optional visualization.

        Returns:
            dict with ``song`` metadata, the ``analysis`` prediction, ``tags``,
            and ``chart_path`` when a chart was requested and visualization is
            enabled.
        """
        # The simple chart is skipped; a detailed chart is created below instead.
        prediction = self.predict(lyrics, create_chart=False)
        result = {
            "song": {"title": title, "artist": artist},
            "analysis": prediction,
            "tags": self._generate_tags(prediction),
        }
        if create_chart:
            self._attach_detailed_chart(result, prediction, f"{title} by {artist}", announce=True)
        return result

    def predict_with_visualization(self, text: str, chart_type: str = "simple") -> dict:
        """
        Predict with automatic visualization.

        Args:
            text: Input text
            chart_type: Type of chart ('simple', 'detailed'); any value other
                than 'detailed' produces the simple bar chart.

        Returns:
            Prediction result with chart
        """
        if chart_type != "detailed":
            return self.predict(text, create_chart=True, show_chart=True)
        result = self.predict(text, create_chart=False)
        self._attach_detailed_chart(result, result, text)
        return result

    def _generate_tags(self, prediction: dict) -> list:
        """Generate recommendation tags from prediction.

        Args:
            prediction: Result dict containing ``emotion`` and ``music_context``.

        Returns:
            List of ``emotion:``, ``mood:``, ``energy:``, ``valence:`` and
            ``activity:`` tags; emotions without an activity mapping get
            ``activity:general``.
        """
        emotion = prediction["emotion"]
        music_ctx = prediction["music_context"]
        tags = [
            f"emotion:{emotion}",
            f"mood:{music_ctx.get('mood', emotion)}",
            f"energy:{music_ctx.get('energy', 'medium')}",
            f"valence:{music_ctx.get('valence', 'neutral')}",
        ]
        # Activity suggestions
        activities = self._ACTIVITY_TAGS.get(emotion, ["general"])
        tags.extend(f"activity:{activity}" for activity in activities)
        return tags
def demo():
    """Demo of the trained emotion predictor with visualizations.

    Runs a fixed set of sample lyrics through ``EmotionPredictor``, prints the
    top three emotions for each, and then asks the visualizer for a comparison
    chart. Returns early (after printing the error) when no model can be found.
    """
    print("\n" + "=" * 60)
    print("🎵 EUMORA - Emotion Analysis Demo (Custom Trained Model)")
    print("=" * 60)

    try:
        predictor = EmotionPredictor(enable_viz=True)
    except FileNotFoundError as e:
        print(f"\n{e}")
        return

    test_samples = [
        ("Happy lyrics", "I feel so alive today, everything is wonderful and bright!"),
        ("Sad lyrics", "My heart is broken, tears falling like rain in the dark night"),
        ("Angry lyrics", "I hate this, you betrayed me, I want to scream at the world!"),
        ("Love lyrics", "You are my everything, I want to hold you forever my darling"),
        ("Fear lyrics", "Something is watching me in the shadows, I'm scared to move"),
        ("Surprise lyrics", "I can't believe it happened! This is incredible, wow!"),
    ]
    print(f"\nAnalyzing {len(test_samples)} samples...\n")
    print("-" * 60)

    # Collect results for the comparison chart
    all_results = []
    for label, text in test_samples:
        # Individual charts are disabled; one comparison chart is drawn at the end.
        result = predictor.predict(text, create_chart=False)
        all_results.append(result)

        # Only mark the preview as truncated when it actually is.
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        print(f"\n📝 {label}:")
        print(f" \"{preview}\"")
        print(f" 🎭 Emotion: {result['emotion'].upper()} ({result['confidence']:.1%})")
        print(f" 🎸 Context: {result['music_context']}")

        # Show probability distribution (text-based), top three only
        sorted_probs = sorted(result['probabilities'].items(), key=lambda item: item[1], reverse=True)
        print(" 📊 Distribution:")
        for emo, prob in sorted_probs[:3]:
            # One block character per 5%, so the bar fits the 20-column field.
            bar = "█" * int(prob * 20)
            print(f" {emo:>10}: {bar:<20} {prob:.1%}")
        print("-" * 60)

    # Create comparison visualization (best effort)
    print("\n📊 Generating comparison chart...")
    try:
        titles = [label for label, _ in test_samples]
        comparison_path = predictor.visualizer.create_comparison_chart(
            all_results, titles, show_chart=True
        )
        print(f"📊 Comparison chart saved to: {comparison_path}")
    except Exception as e:
        print(f"⚠️ Could not create comparison chart: {e}")
    print("\n✅ Demo complete!")


# Run the demo when this module is executed directly.
if __name__ == "__main__":
    demo()