ustwo-api/src/stage2/process.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
6.5 kB
"""Stage 2 — Orchestrator.
Entry point: process(stage1_output) -> Stage2Output
Pipeline:
1. Load config
2. Per segment: audio emotion + text emotion + fusion
3. Aggregate per speaker → SpeakerSummary
4. Assemble Stage2Output + save JSON
"""
from __future__ import annotations
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
import torch
import yaml
from src.common.constants import EMOTION_LABELS
from src.common.schemas import (
EmotionResult,
SpeakerSummary,
Stage1Output,
Stage2Output,
)
from src.stage2.audio_emotion import predict as audio_predict
from src.stage2.fusion import fuse
from src.stage2.text_emotion import predict as text_predict
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def _load_config(config: dict | None = None) -> dict:
"""Load config from yaml if not provided."""
if config is not None:
return config
config_path = Path(__file__).parent.parent.parent / "config.yaml"
if config_path.exists():
with open(config_path) as f:
return yaml.safe_load(f).get("stage2", {})
logger.warning("config.yaml not found, using defaults")
return {}
def _get_device() -> str:
"""Determine compute device."""
return "cuda" if torch.cuda.is_available() else "cpu"
def _aggregate_speakers(
    emotions: list[EmotionResult],
) -> dict[str, SpeakerSummary]:
    """Aggregate per-segment emotions into per-speaker summaries.

    Speakers appear in the result in order of first appearance in *emotions*.
    For each speaker:

    - dominant_emotion: the fused emotion with the highest total
      fused_confidence. Ties go to the emotion seen first for that speaker.
      It is "neutral" when the total weight is not positive.
    - emotion_distribution: per-emotion confidence totals normalized by the
      overall total. Each value is rounded to 4 decimals, so the sum may
      differ from 1.0 by a few 1e-4. It is ``{"neutral": 1.0}`` when the
      total weight is not positive.
    - avg_confidence: mean fused_confidence, rounded to 4 decimals.

    Args:
        emotions: Per-segment results. Each one's speaker_id, fused_emotion,
            and fused_confidence are read.

    Returns:
        Mapping of speaker_id to SpeakerSummary. It is empty when *emotions*
        is empty.
    """
    by_speaker: dict[str, list[EmotionResult]] = defaultdict(list)
    for em in emotions:
        by_speaker[em.speaker_id].append(em)

    summaries: dict[str, SpeakerSummary] = {}
    for speaker_id, speaker_ems in by_speaker.items():
        # Confidence-weighted vote per fused emotion label.
        emotion_weights: dict[str, float] = defaultdict(float)
        total_confidence = 0.0
        for em in speaker_ems:
            emotion_weights[em.fused_emotion] += em.fused_confidence
            total_confidence += em.fused_confidence

        total_weight = sum(emotion_weights.values())
        if total_weight > 0:
            emotion_distribution = {
                label: round(weight / total_weight, 4)
                for label, weight in emotion_weights.items()
            }
            # Pick the winner from the UNROUNDED weights. Rounding to 4 decimals
            # can make distinct weights compare equal and flip the tie-break.
            dominant_emotion = max(emotion_weights, key=emotion_weights.__getitem__)
        else:
            emotion_distribution = {"neutral": 1.0}
            dominant_emotion = "neutral"

        # by_speaker only ever holds non-empty lists, so this never divides by zero.
        avg_confidence = total_confidence / len(speaker_ems)
        summaries[speaker_id] = SpeakerSummary(
            dominant_emotion=dominant_emotion,
            emotion_distribution=emotion_distribution,
            avg_confidence=round(avg_confidence, 4),
        )
    return summaries
def _save_output(output: Stage2Output, output_path: str | os.PathLike) -> None:
    """Write *output* as indented UTF-8 JSON, creating parent directories as needed."""
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(output.model_dump_json(indent=2), encoding="utf-8")


def process(
    stage1_output: Stage1Output,
    config: dict | None = None,
) -> Stage2Output:
    """Stage 2 entry point.

    For each Stage 1 segment, predict an emotion from the segment's audio and
    from its text, then fuse the two predictions. The per-segment results are
    aggregated per speaker. The assembled output is written as JSON to
    ``cfg["output_path"]`` (default ``data/stage2_output.json``).

    The file is written even when there are no segments. This ensures a stale
    result from a previous call is never left at that path.

    Args:
        stage1_output: Stage 1 output with segments containing audio paths and text.
        config: Optional Stage 2 config dict. If None, the ``stage2`` section
            of config.yaml is loaded.

    Returns:
        Stage2Output with per-segment emotions and per-speaker summaries.
    """
    start_time = time.perf_counter()  # monotonic clock for measuring duration
    cfg = _load_config(config)
    device = _get_device()

    # `or {}` also covers sections present in the YAML but left empty (None).
    audio_cfg = cfg.get("audio_emotion") or {}
    audio_model_id = audio_cfg.get("model", "iic/emotion2vec_plus_base")
    lora_onnx_path = audio_cfg.get("lora_onnx_path")

    text_cfg = cfg.get("text_emotion") or {}
    ko_model_id = text_cfg.get("korean_model")
    en_model_id = text_cfg.get("english_model")
    ko_lora_onnx = text_cfg.get("korean_lora_onnx_path")
    ko_lora_tokenizer = text_cfg.get("korean_lora_tokenizer")

    fusion_cfg = cfg.get("fusion") or {}
    fusion_mode = fusion_cfg.get("mode", "emotion_specific")

    call_id = stage1_output.call_id
    segments = stage1_output.segments
    total = len(segments)
    logger.info(
        "Stage 2 processing %s: %d segments, device=%s",
        call_id, total, device,
    )
    if not segments:
        # Fall through instead of returning early, so the (empty) result is
        # still saved and the completion log is still emitted.
        logger.warning("No segments to process for %s", call_id)

    # Per-segment emotion analysis
    emotions: list[EmotionResult] = []
    for done, seg in enumerate(segments, start=1):
        # Audio emotion (LoRA ONNX preferred if available)
        audio_result = audio_predict(
            seg.audio_path, device=device, model_id=audio_model_id,
            lora_onnx_path=lora_onnx_path,
        )
        # Text emotion: "ko" segments use the Korean model id and all others
        # use the English one. The Korean LoRA ONNX paths are always passed through.
        text_model_id = ko_model_id if seg.language == "ko" else en_model_id
        text_result = text_predict(
            seg.text, language=seg.language, model_id=text_model_id,
            korean_lora_onnx_path=ko_lora_onnx,
            korean_lora_tokenizer=ko_lora_tokenizer,
        )
        # Fusion (emotion-specific by default, per-language weights)
        fused_result = fuse(audio_result, text_result, mode=fusion_mode, language=seg.language)
        emotions.append(EmotionResult(
            speaker_id=seg.speaker_id,
            segment_id=seg.segment_id,
            audio_emotion=audio_result["emotion"],
            audio_confidence=audio_result["confidence"],
            text_emotion=text_result["emotion"],
            text_confidence=text_result["confidence"],
            fused_emotion=fused_result["emotion"],
            fused_confidence=fused_result["confidence"],
        ))
        if done % 10 == 0:
            logger.info(" %d/%d segments processed", done, total)

    # Speaker aggregation (returns {} when there are no emotions)
    speaker_summaries = _aggregate_speakers(emotions)
    output = Stage2Output(
        call_id=call_id,
        emotions=emotions,
        speaker_summaries=speaker_summaries,
    )

    _save_output(output, cfg.get("output_path", "data/stage2_output.json"))

    logger.info(
        "Stage 2 complete: %d emotions, %d speakers, %.1fs processing time",
        len(emotions), len(speaker_summaries), time.perf_counter() - start_time,
    )
    return output