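"""Stage 2 — audio + text emotion fusion.

Combines audio and text emotion scores using per-class (emotion-specific)
fusion weights selected per language; the Korean weights come from a grid
search on the Korean 263 validation split. With mode="fixed" (or any mode
other than "emotion_specific") the fixed 60/40 audio/text weights are used
instead. When the text prediction has low confidence, only audio is used.
"""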
|
|
from __future__ import annotations

from typing import Any

from src.common.constants import (
    EMOTION_FUSION_WEIGHTS_EN,
    EMOTION_FUSION_WEIGHTS_KO,
    FUSION_WEIGHTS,
)
|
|
# Emotion labels the fusion operates on: fuse() reads audio/text scores for
# exactly these keys, and its output "scores" follow this order.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]
|
|
|
|
def fuse(
    audio_scores: dict[str, Any],
    text_scores: dict[str, Any],
    mode: str = "emotion_specific",
    language: str = "ko",
    text_confidence_threshold: float = 0.1,
) -> dict[str, Any]:
    """Fuse audio and text emotion predictions into one distribution.

    Each input prediction has the shape
    ``{"emotion": str, "confidence": float, "scores": dict[str, float]}``;
    only ``"scores"`` (both sides) and ``"confidence"`` (text side) are read.

    For every label in ``PROJECT_LABELS`` the fused score is
    ``audio_score * audio_weight + text_score * text_weight``. The result is
    renormalized to sum to 1, or set to uniform if every fused score is 0.

    The text branch is ignored entirely (audio weight 1.0, text weight 0.0)
    when it carries no usable signal: its confidence is at or below
    ``text_confidence_threshold``, it has no scores, or its scores are flat
    (uniform), which is what an empty transcript produces.

    Args:
        audio_scores: Audio model prediction dict (shape above).
        text_scores: Text model prediction dict (shape above). If the text
            was empty, its scores are expected to be a uniform distribution.
        mode: "emotion_specific" uses per-class weights from the language's
            weight table (labels missing from the table use ``FUSION_WEIGHTS``);
            any other value uses ``FUSION_WEIGHTS`` for every class.
        language: "en" selects the English weight table; any other value
            selects the Korean table.
        text_confidence_threshold: Text confidence at or below this value
            makes the fusion audio-only. Defaults to 0.1.

    Returns:
        {"emotion": str, "confidence": float, "scores": dict[str, float]}
        holding the top label, its fused probability, and all fused scores,
        each rounded to 4 decimals.
    """
    # `or {}` also covers a "scores" key that is present but set to None.
    a_scores = audio_scores.get("scores") or {}
    t_scores = text_scores.get("scores") or {}

    # Decide whether the text branch carries any signal. A low confidence is
    # not the only case: if the text model reports its top score as the
    # confidence, a 7-way uniform distribution (empty text) gives 1/7 ≈ 0.143,
    # which passes the default 0.1 threshold. Fusing that uninformative text
    # would bias the result toward classes with larger text weights.
    text_values = list(t_scores.values())
    text_is_flat = (
        len(text_values) > 1 and max(text_values) - min(text_values) <= 1e-9
    )
    audio_only = (
        text_scores.get("confidence", 0.0) <= text_confidence_threshold
        or not text_values
        or text_is_flat
    )

    # Resolve (audio_weight, text_weight) per label once. The choice depends
    # only on audio_only/mode/language, not on the individual scores.
    if audio_only:
        weights = {label: (1.0, 0.0) for label in PROJECT_LABELS}
    elif mode == "emotion_specific":
        table = (
            EMOTION_FUSION_WEIGHTS_EN if language == "en" else EMOTION_FUSION_WEIGHTS_KO
        )
        weights = {}
        for label in PROJECT_LABELS:
            # Labels absent from the table fall back to the fixed pair instead
            # of a second hard-coded copy of those numbers.
            w = table.get(label, FUSION_WEIGHTS)
            weights[label] = (w["audio"], w["text"])
    else:
        # Any mode other than "emotion_specific" uses the fixed weights.
        fixed = (FUSION_WEIGHTS["audio"], FUSION_WEIGHTS["text"])
        weights = dict.fromkeys(PROJECT_LABELS, fixed)

    fused = {
        label: (
            a_scores.get(label, 0.0) * audio_w + t_scores.get(label, 0.0) * text_w
        )
        for label, (audio_w, text_w) in weights.items()
    }

    # Renormalize to a probability distribution. If everything is 0 (e.g. no
    # audio scores in audio-only mode), fall back to uniform.
    total = sum(fused.values())
    if total > 0:
        fused = {label: score / total for label, score in fused.items()}
    else:
        fused = dict.fromkeys(PROJECT_LABELS, 1.0 / len(PROJECT_LABELS))

    # max() keeps the first maximal key, so ties resolve in PROJECT_LABELS order.
    top_label = max(fused, key=fused.get)
    return {
        "emotion": top_label,
        "confidence": round(fused[top_label], 4),
        "scores": {label: round(score, 4) for label, score in fused.items()},
    }
|
|