"""Stage 2 — Audio + Text Emotion Fusion. Uses emotion-specific fusion weights per class (from Korean 263 val grid search). Falls back to fixed 60/40 if mode="fixed". """ from __future__ import annotations from src.common.constants import ( EMOTION_FUSION_WEIGHTS_EN, EMOTION_FUSION_WEIGHTS_KO, FUSION_WEIGHTS, ) PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"] def fuse( audio_scores: dict[str, float], text_scores: dict[str, float], mode: str = "emotion_specific", language: str = "ko", ) -> dict: """Fuse audio and text emotion predictions. Args: audio_scores: {"emotion": str, "confidence": float, "scores": dict} text_scores: {"emotion": str, "confidence": float, "scores": dict} If text was empty, scores will be uniform distribution. mode: "emotion_specific" (per-class weights) or "fixed" (60/40) language: "ko" or "en" — selects language-specific emotion weights. Returns: {"emotion": str, "confidence": float, "scores": dict[str, float]} """ a_scores = audio_scores.get("scores", {}) t_scores = text_scores.get("scores", {}) # If text confidence is very low (empty text), use audio only text_confidence = text_scores.get("confidence", 0.0) audio_only = text_confidence <= 0.1 weights_table = EMOTION_FUSION_WEIGHTS_EN if language == "en" else EMOTION_FUSION_WEIGHTS_KO # Weighted average of score distributions fused = {} for label in PROJECT_LABELS: a = a_scores.get(label, 0.0) t = t_scores.get(label, 0.0) if audio_only: audio_w, text_w = 1.0, 0.0 elif mode == "emotion_specific": w = weights_table.get(label, {"audio": 0.6, "text": 0.4}) audio_w, text_w = w["audio"], w["text"] else: audio_w = FUSION_WEIGHTS["audio"] text_w = FUSION_WEIGHTS["text"] fused[label] = a * audio_w + t * text_w # Normalize to sum to 1.0 total = sum(fused.values()) if total > 0: fused = {k: v / total for k, v in fused.items()} else: fused = {label: 1.0 / len(PROJECT_LABELS) for label in PROJECT_LABELS} top_label = max(fused, key=fused.get) confidence = fused[top_label] return { "emotion": top_label, "confidence": round(confidence, 4), "scores": {k: round(v, 4) for k, v in fused.items()}, }