ustwo-api / src /stage2 /fusion.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
2.47 kB
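"""Stage 2 — Audio + Text Emotion Fusion.

Combines audio and text emotion score distributions into one distribution.
By default (mode="emotion_specific"), each emotion class gets its own
audio/text weights, tuned by grid search on the Korean 263 validation set,
with separate weight tables for Korean and English. With mode="fixed", a
single 60/40 audio/text weighting is applied to every class instead.
When the text prediction is uninformative (e.g. empty text), only the audio
scores are used.
"""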
from __future__ import annotations
from src.common.constants import (
EMOTION_FUSION_WEIGHTS_EN,
EMOTION_FUSION_WEIGHTS_KO,
FUSION_WEIGHTS,
)
# Canonical emotion labels, in output order. fuse() builds its result by
# iterating this list, and max() keeps the first of several equal maxima, so
# on a perfectly flat distribution the earliest label ("neutral") is chosen.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]
def fuse(
    audio_scores: dict,
    text_scores: dict,
    mode: str = "emotion_specific",
    language: str = "ko",
    *,
    text_confidence_threshold: float = 0.1,
) -> dict:
    """Fuse audio and text emotion predictions into a single distribution.

    Args:
        audio_scores: Audio classifier output shaped like
            {"emotion": str, "confidence": float, "scores": dict[str, float]}.
            Only "scores" is read; missing labels count as 0.0.
        text_scores: Text classifier output with the same shape. "scores" and
            "confidence" are read. If the text was empty, upstream produces a
            uniform score distribution.
        mode: "emotion_specific" uses per-class weights from the selected
            language table (0.6/0.4 audio/text for labels absent from it).
            Any other value uses the global FUSION_WEIGHTS for every class.
        language: "en" selects the English weight table; any other value
            selects the Korean table.
        text_confidence_threshold: Text predictions whose confidence is at or
            below this value are ignored (audio-only fusion).

    Returns:
        {"emotion": str, "confidence": float, "scores": dict[str, float]}.
        "scores" has one entry per label in PROJECT_LABELS, normalized to sum
        to 1.0 before rounding to 4 decimals. Ties go to the label that comes
        first in PROJECT_LABELS. If no input carries any mass, the result is
        a uniform distribution.
    """
    # `or` guards against explicit None values as well as missing keys.
    a_scores = audio_scores.get("scores") or {}
    t_scores = text_scores.get("scores") or {}
    text_confidence = text_scores.get("confidence") or 0.0

    text_values = [t_scores.get(label, 0.0) for label in PROJECT_LABELS]
    # Text is uninformative when its confidence is tiny OR its distribution
    # is flat (this includes missing scores, i.e. all zeros). The flatness
    # check is needed because empty text yields a uniform distribution whose
    # confidence is 1/len(PROJECT_LABELS) (~0.143 for 7 labels). That slips
    # past a 0.1 threshold, and fusing it under per-class weights would bias
    # the result toward labels with a high text weight.
    audio_only = (
        text_confidence <= text_confidence_threshold
        or max(text_values) - min(text_values) <= 1e-6
    )

    weights_table = (
        EMOTION_FUSION_WEIGHTS_EN if language == "en" else EMOTION_FUSION_WEIGHTS_KO
    )

    # Per-label weighted sum of the two score distributions.
    fused: dict[str, float] = {}
    for label, t in zip(PROJECT_LABELS, text_values):
        a = a_scores.get(label, 0.0)
        if audio_only:
            audio_w, text_w = 1.0, 0.0
        elif mode == "emotion_specific":
            w = weights_table.get(label, {"audio": 0.6, "text": 0.4})
            audio_w, text_w = w["audio"], w["text"]
        else:
            audio_w, text_w = FUSION_WEIGHTS["audio"], FUSION_WEIGHTS["text"]
        fused[label] = a * audio_w + t * text_w

    # Per-class weights do not sum to a constant across labels, so
    # renormalize to get a proper distribution.
    total = sum(fused.values())
    if total > 0:
        fused = {k: v / total for k, v in fused.items()}
    else:
        # No signal from either modality: fall back to a uniform guess.
        fused = {label: 1.0 / len(PROJECT_LABELS) for label in PROJECT_LABELS}

    # max() returns the first maximal key, i.e. earliest in PROJECT_LABELS.
    top_label = max(fused, key=fused.get)
    return {
        "emotion": top_label,
        "confidence": round(fused[top_label], 4),
        "scores": {k: round(v, 4) for k, v in fused.items()},
    }