| """Stage 2 — Orchestrator. |
| |
| Entry point: process(stage1_output) -> Stage2Output |
| |
| Pipeline: |
| 1. Load config |
| 2. Per segment: audio emotion + text emotion + fusion |
| 3. Aggregate per speaker → SpeakerSummary |
| 4. Assemble Stage2Output + save JSON |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import time |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
|
|
| from src.common.constants import EMOTION_LABELS |
| from src.common.schemas import ( |
| EmotionResult, |
| SpeakerSummary, |
| Stage1Output, |
| Stage2Output, |
| ) |
| from src.stage2.audio_emotion import predict as audio_predict |
| from src.stage2.fusion import fuse |
| from src.stage2.text_emotion import predict as text_predict |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def _load_config(config: dict | None = None) -> dict:
    """Return the Stage 2 configuration.

    Args:
        config: Ready-made configuration. When it is not None it is returned
            unchanged and no file is read.

    Returns:
        The ``stage2`` section of ``config.yaml`` (looked up two directories
        above the directory containing this module), or an empty dict when
        the file is missing, empty, not a mapping, or has no usable
        ``stage2`` section.
    """
    if config is not None:
        return config

    config_path = Path(__file__).parent.parent.parent / "config.yaml"
    if not config_path.exists():
        logger.warning("config.yaml not found, using defaults")
        return {}

    # Read as UTF-8 explicitly instead of relying on the platform locale.
    with open(config_path, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # safe_load yields None for an empty document, and a scalar or list for a
    # file whose top level is not a mapping; neither has a usable .get().
    if not isinstance(loaded, dict):
        logger.warning("config.yaml is empty or not a mapping, using defaults")
        return {}

    # A "stage2:" key with no children parses to None; treat it as missing.
    return loaded.get("stage2") or {}
|
|
|
|
def _get_device() -> str:
    """Pick the torch device string used for inference.

    Returns:
        "cuda" when torch reports that CUDA is available, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
def _aggregate_speakers(
    emotions: list[EmotionResult],
) -> dict[str, SpeakerSummary]:
    """Aggregate per-segment emotions into per-speaker summaries.

    Segments are grouped by ``speaker_id``. For each speaker:
        - dominant_emotion: the fused emotion with the highest summed
          ``fused_confidence``, chosen from the unrounded totals.
        - emotion_distribution: each emotion's share of the speaker's total
          confidence, rounded to 4 decimals (so the shares can differ from
          1.0 in the last decimal place).
        - avg_confidence: mean ``fused_confidence`` over the speaker's
          segments, rounded to 4 decimals.

    A speaker whose confidences do not sum to a positive value gets the
    fallback distribution ``{"neutral": 1.0}`` with "neutral" as the
    dominant emotion.

    Args:
        emotions: Per-segment results. Only ``speaker_id``,
            ``fused_emotion`` and ``fused_confidence`` are read.

    Returns:
        Mapping of speaker id to its SpeakerSummary, ordered by each
        speaker's first appearance in ``emotions``. Empty input yields an
        empty dict.
    """
    by_speaker: dict[str, list[EmotionResult]] = defaultdict(list)
    for em in emotions:
        by_speaker[em.speaker_id].append(em)

    summaries: dict[str, SpeakerSummary] = {}
    for speaker_id, speaker_ems in by_speaker.items():
        emotion_weights: dict[str, float] = defaultdict(float)
        total_confidence = 0.0
        for em in speaker_ems:
            emotion_weights[em.fused_emotion] += em.fused_confidence
            total_confidence += em.fused_confidence

        total_weight = sum(emotion_weights.values())
        if total_weight > 0:
            emotion_distribution = {
                label: round(weight / total_weight, 4)
                for label, weight in emotion_weights.items()
            }
            # Choose the winner from the raw totals: after rounding, two
            # close emotions can tie and the first-inserted one would win
            # even though the other carries more weight.
            dominant_emotion = max(emotion_weights, key=emotion_weights.get)
        else:
            emotion_distribution = {"neutral": 1.0}
            dominant_emotion = "neutral"

        # Every group holds at least one segment (a key is only created by
        # the append above), so the division cannot hit zero.
        avg_confidence = total_confidence / len(speaker_ems)

        summaries[speaker_id] = SpeakerSummary(
            dominant_emotion=dominant_emotion,
            emotion_distribution=emotion_distribution,
            avg_confidence=round(avg_confidence, 4),
        )

    return summaries
|
|
|
|
def process(
    stage1_output: Stage1Output,
    config: dict | None = None,
) -> Stage2Output:
    """Stage 2 entry point.

    For every segment this runs audio emotion prediction, text emotion
    prediction and fusion, then aggregates the fused results per speaker and
    writes the assembled output as JSON.

    Args:
        stage1_output: Stage 1 output with segments containing audio paths and text.
        config: Optional config dict. If None, loads from config.yaml.

    Returns:
        Stage2Output with per-segment emotions and per-speaker summaries.

    Side effects:
        Writes the output as indented JSON to ``cfg["output_path"]``
        (default ``data/stage2_output.json``), creating the parent directory
        if needed. When ``stage1_output`` has no segments the function
        returns an empty Stage2Output early and does NOT write the file.

    Exceptions raised by the predictors, the fusion step or the file write
    are not caught here and propagate to the caller.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    cfg = _load_config(config)
    device = _get_device()

    # "or {}" (rather than a .get default) also covers sections that are
    # present in the config but null, e.g. a YAML key with no children.
    audio_cfg = cfg.get("audio_emotion") or {}
    audio_model_id = audio_cfg.get("model", "iic/emotion2vec_plus_base")
    lora_onnx_path = audio_cfg.get("lora_onnx_path")

    text_cfg = cfg.get("text_emotion") or {}
    ko_model_id = text_cfg.get("korean_model")
    en_model_id = text_cfg.get("english_model")
    ko_lora_onnx = text_cfg.get("korean_lora_onnx_path")
    ko_lora_tokenizer = text_cfg.get("korean_lora_tokenizer")

    fusion_cfg = cfg.get("fusion") or {}
    fusion_mode = fusion_cfg.get("mode", "emotion_specific")

    call_id = stage1_output.call_id
    segments = stage1_output.segments

    logger.info(
        "Stage 2 processing %s: %d segments, device=%s",
        call_id, len(segments), device,
    )

    if not segments:
        logger.warning("No segments to process for %s", call_id)
        return Stage2Output(
            call_id=call_id,
            emotions=[],
            speaker_summaries={},
        )

    emotions: list[EmotionResult] = []
    for done, seg in enumerate(segments, start=1):
        # Audio modality.
        audio_result = audio_predict(
            seg.audio_path, device=device, model_id=audio_model_id,
            lora_onnx_path=lora_onnx_path,
        )

        # Text modality: Korean segments use the Korean model, every other
        # language falls back to the English model.
        text_model_id = ko_model_id if seg.language == "ko" else en_model_id
        text_result = text_predict(
            seg.text, language=seg.language, model_id=text_model_id,
            korean_lora_onnx_path=ko_lora_onnx,
            korean_lora_tokenizer=ko_lora_tokenizer,
        )

        # Combine both modalities into one label and confidence.
        fused_result = fuse(
            audio_result, text_result, mode=fusion_mode, language=seg.language,
        )

        emotions.append(EmotionResult(
            speaker_id=seg.speaker_id,
            segment_id=seg.segment_id,
            audio_emotion=audio_result["emotion"],
            audio_confidence=audio_result["confidence"],
            text_emotion=text_result["emotion"],
            text_confidence=text_result["confidence"],
            fused_emotion=fused_result["emotion"],
            fused_confidence=fused_result["confidence"],
        ))

        # Progress heartbeat every 10 segments.
        if done % 10 == 0:
            logger.info(" %d/%d segments processed", done, len(segments))

    speaker_summaries = _aggregate_speakers(emotions)

    # Measured before the JSON write, so the logged time excludes file I/O.
    processing_time = time.perf_counter() - start_time

    output = Stage2Output(
        call_id=call_id,
        emotions=emotions,
        speaker_summaries=speaker_summaries,
    )

    # A null output_path in the config falls back to the default as well.
    output_path = cfg.get("output_path") or "data/stage2_output.json"
    # dirname is "" for a bare filename; "." keeps makedirs valid then.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(output.model_dump_json(indent=2))

    logger.info(
        "Stage 2 complete: %d emotions, %d speakers, %.1fs processing time",
        len(emotions), len(speaker_summaries), processing_time,
    )

    return output
|
|