"""Stage 2 — Orchestrator. Entry point: process(stage1_output) -> Stage2Output Pipeline: 1. Load config 2. Per segment: audio emotion + text emotion + fusion 3. Aggregate per speaker → SpeakerSummary 4. Assemble Stage2Output + save JSON """ from __future__ import annotations import logging import os import time from collections import defaultdict from pathlib import Path import torch import yaml from src.common.constants import EMOTION_LABELS from src.common.schemas import ( EmotionResult, SpeakerSummary, Stage1Output, Stage2Output, ) from src.stage2.audio_emotion import predict as audio_predict from src.stage2.fusion import fuse from src.stage2.text_emotion import predict as text_predict logger = logging.getLogger(__name__) def _load_config(config: dict | None = None) -> dict: """Load config from yaml if not provided.""" if config is not None: return config config_path = Path(__file__).parent.parent.parent / "config.yaml" if config_path.exists(): with open(config_path) as f: return yaml.safe_load(f).get("stage2", {}) logger.warning("config.yaml not found, using defaults") return {} def _get_device() -> str: """Determine compute device.""" return "cuda" if torch.cuda.is_available() else "cpu" def _aggregate_speakers( emotions: list[EmotionResult], ) -> dict[str, SpeakerSummary]: """Aggregate per-segment emotions into per-speaker summaries. For each speaker: - dominant_emotion: emotion with highest total weight - emotion_distribution: normalized weights (sum = 1.0) - avg_confidence: mean fused_confidence """ speaker_emotions: dict[str, list[EmotionResult]] = defaultdict(list) for em in emotions: speaker_emotions[em.speaker_id].append(em) summaries: dict[str, SpeakerSummary] = {} for speaker_id, speaker_ems in speaker_emotions.items(): # Count emotion occurrences weighted by confidence emotion_weights: dict[str, float] = defaultdict(float) total_confidence = 0.0 for em in speaker_ems: emotion_weights[em.fused_emotion] += em.fused_confidence total_confidence += em.fused_confidence # Normalize to sum to 1.0 total_weight = sum(emotion_weights.values()) if total_weight > 0: emotion_distribution = { k: round(v / total_weight, 4) for k, v in emotion_weights.items() } else: emotion_distribution = {"neutral": 1.0} dominant_emotion = max(emotion_distribution, key=emotion_distribution.get) avg_confidence = total_confidence / len(speaker_ems) if speaker_ems else 0.0 summaries[speaker_id] = SpeakerSummary( dominant_emotion=dominant_emotion, emotion_distribution=emotion_distribution, avg_confidence=round(avg_confidence, 4), ) return summaries def process( stage1_output: Stage1Output, config: dict | None = None, ) -> Stage2Output: """Stage 2 entry point. Args: stage1_output: Stage 1 output with segments containing audio paths and text. config: Optional config dict. If None, loads from config.yaml. Returns: Stage2Output with per-segment emotions and per-speaker summaries. """ start_time = time.time() cfg = _load_config(config) device = _get_device() audio_cfg = cfg.get("audio_emotion", {}) audio_model_id = audio_cfg.get("model", "iic/emotion2vec_plus_base") lora_onnx_path = audio_cfg.get("lora_onnx_path") text_cfg = cfg.get("text_emotion", {}) ko_model_id = text_cfg.get("korean_model") en_model_id = text_cfg.get("english_model") ko_lora_onnx = text_cfg.get("korean_lora_onnx_path") ko_lora_tokenizer = text_cfg.get("korean_lora_tokenizer") fusion_cfg = cfg.get("fusion", {}) fusion_mode = fusion_cfg.get("mode", "emotion_specific") call_id = stage1_output.call_id segments = stage1_output.segments logger.info( "Stage 2 processing %s: %d segments, device=%s", call_id, len(segments), device, ) if not segments: logger.warning("No segments to process for %s", call_id) return Stage2Output( call_id=call_id, emotions=[], speaker_summaries={}, ) # Per-segment emotion analysis emotions: list[EmotionResult] = [] for i, seg in enumerate(segments): # Audio emotion (LoRA ONNX preferred if available) audio_result = audio_predict( seg.audio_path, device=device, model_id=audio_model_id, lora_onnx_path=lora_onnx_path, ) # Text emotion (Korean uses LoRA ONNX if available) text_model_id = ko_model_id if seg.language == "ko" else en_model_id text_result = text_predict( seg.text, language=seg.language, model_id=text_model_id, korean_lora_onnx_path=ko_lora_onnx, korean_lora_tokenizer=ko_lora_tokenizer, ) # Fusion (emotion-specific by default, per-language weights) fused_result = fuse(audio_result, text_result, mode=fusion_mode, language=seg.language) emotions.append(EmotionResult( speaker_id=seg.speaker_id, segment_id=seg.segment_id, audio_emotion=audio_result["emotion"], audio_confidence=audio_result["confidence"], text_emotion=text_result["emotion"], text_confidence=text_result["confidence"], fused_emotion=fused_result["emotion"], fused_confidence=fused_result["confidence"], )) if (i + 1) % 10 == 0: logger.info(" %d/%d segments processed", i + 1, len(segments)) # Speaker aggregation speaker_summaries = _aggregate_speakers(emotions) processing_time = time.time() - start_time output = Stage2Output( call_id=call_id, emotions=emotions, speaker_summaries=speaker_summaries, ) # Save to JSON output_path = cfg.get("output_path", "data/stage2_output.json") os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: f.write(output.model_dump_json(indent=2)) logger.info( "Stage 2 complete: %d emotions, %d speakers, %.1fs processing time", len(emotions), len(speaker_summaries), processing_time, ) return output