"""Stage 1 — Orchestrator.

Entry point: process(audio_path) -> Stage1Output

Pipeline:
    1. Audio validation & preprocessing
    2. Speaker diarization (pyannote)
    3. ASR transcription (Faster-Whisper via WhisperX)
    4. Forced alignment (WhisperX wav2vec2)
    5. Diarization-transcript merge
    6. Language detection (SenseVoice-Small)
    7. Segment audio extraction
    8. Output assembly
"""

from __future__ import annotations

import logging
import math
import os
import time
import uuid
from pathlib import Path

import torch
import torchaudio
import yaml

from src.common.schemas import (
    Models,
    ProcessingInfo,
    Segment,
    Stage1Output,
)
from src.stage1.diarization import diarize
from src.stage1.language_id import decide_language_from_text, detect_languages
from src.stage1.transcription import align, assign_speakers, transcribe

logger = logging.getLogger(__name__)

SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".m4a", ".ogg"}


def _load_config(config: dict | None = None) -> dict:
    """Return the Stage 1 config.

    If *config* is given it is returned as-is. Otherwise the ``stage1``
    section of ``config.yaml`` (three directories above this module) is
    loaded. A missing file, an empty file, or a missing/null ``stage1``
    section all yield ``{}``.
    """
    if config is not None:
        return config
    config_path = Path(__file__).parent.parent.parent / "config.yaml"
    if not config_path.exists():
        logger.warning("config.yaml not found, using defaults")
        return {}
    with open(config_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document, and a bare
        # "stage1:" key parses to None; both would break a .get() chain.
        data = yaml.safe_load(f) or {}
    return data.get("stage1") or {}


def _get_device() -> str:
    """Return "cuda" when a CUDA device is available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _section(cfg: dict, key: str) -> dict:
    """Return a shallow copy of ``cfg[key]``, or ``{}`` if missing/null.

    Copying keeps later ``setdefault`` calls from mutating the caller's
    config dict, which may be reused across ``process`` calls.
    """
    return dict(cfg.get(key) or {})


def _validate_audio(audio_path: str, config: dict) -> None:
    """Step 1a: Validate that the audio file exists and has a supported extension.

    Raises:
        FileNotFoundError: The path does not exist.
        ValueError: The extension is not in SUPPORTED_EXTENSIONS.
    """
    path = Path(audio_path)
    if not path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")
    if path.suffix.lower() not in SUPPORTED_EXTENSIONS:
        raise ValueError(
            f"Unsupported audio format: {path.suffix}. "
            f"Supported: {SUPPORTED_EXTENSIONS}"
        )


def _preprocess_audio(audio_path: str, config: dict) -> tuple[str, float]:
    """Step 1b: Convert to 16kHz mono, peak-normalized WAV.

    Returns:
        (preprocessed_path, duration) where duration is in seconds, measured
        on the original audio. A new WAV is written under ``segments_dir``
        whenever the waveform was modified (resampled, down-mixed or
        normalized) or the input is not already a WAV; otherwise the
        original ``audio_path`` is returned.

    Raises:
        ValueError: Duration is outside [min_duration_sec, max_duration_sec].
    """
    preprocess_cfg = config.get("preprocessing") or {}
    target_sr = preprocess_cfg.get("target_sample_rate", 16000)
    min_dur = preprocess_cfg.get("min_duration_sec", 3)
    max_dur = preprocess_cfg.get("max_duration_sec", 300)

    waveform, sr = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sr
    if duration < min_dur:
        raise ValueError(f"Audio too short: {duration:.1f}s (min: {min_dur}s)")
    if duration > max_dur:
        raise ValueError(f"Audio too long: {duration:.1f}s (max: {max_dur}s)")

    # Track whether the in-memory waveform differs from the file on disk.
    # Previously only resampling/format triggered a save, so downmix and
    # normalization were silently dropped for 16 kHz .wav inputs.
    modified = False

    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
        modified = True

    # Stereo to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
        modified = True

    # Peak normalization — handle volume differences across devices
    peak = waveform.abs().max()
    if peak > 0:
        target_peak = preprocess_cfg.get("target_peak", 0.95)
        waveform = waveform * (target_peak / peak)
        modified = True

    if not modified and Path(audio_path).suffix.lower() == ".wav":
        return audio_path, duration

    preprocessed_dir = Path(config.get("segments_dir", "data/segments"))
    preprocessed_dir.mkdir(parents=True, exist_ok=True)
    preprocessed_path = str(
        preprocessed_dir / f"_preprocessed_{uuid.uuid4().hex[:8]}.wav"
    )
    torchaudio.save(preprocessed_path, waveform, target_sr)
    return preprocessed_path, duration


def _extract_segment_audio(
    waveform: torch.Tensor,
    sr: int,
    segments: list[dict],
    call_id: str,
    segments_dir: str,
) -> list[str]:
    """Step 7: Write one WAV clip per segment under ``segments_dir/call_id``.

    Returns one path per input segment, in order. Segments whose clamped
    sample range is empty get ``""`` instead of a path.
    """
    seg_dir = Path(segments_dir) / call_id
    seg_dir.mkdir(parents=True, exist_ok=True)

    paths: list[str] = []
    for seg in segments:
        idx = seg["segment_id"]
        speaker = seg["speaker_id"]
        # Clamp to the valid sample range of the waveform.
        start_sample = max(0, int(seg["start"] * sr))
        end_sample = min(waveform.shape[1], int(seg["end"] * sr))
        if end_sample <= start_sample:
            paths.append("")
            continue
        clip = waveform[:, start_sample:end_sample]
        clip_path = str(seg_dir / f"seg_{idx:03d}_{speaker}.wav")
        torchaudio.save(clip_path, clip, sr)
        paths.append(clip_path)
    return paths


def _build_models(cfg: dict, diar_cfg: dict, align_enabled: bool) -> Models:
    """Describe the models used, identically for the normal and early-exit paths."""
    asr_cfg = cfg.get("asr") or {}
    lid_cfg = cfg.get("language_id") or {}
    return Models(
        diarization=diar_cfg.get("model", "pyannote/speaker-diarization-3.1"),
        asr=f"whisperx/{asr_cfg.get('model', 'large-v3-turbo')}-{asr_cfg.get('compute_type', 'int8')}",
        language_id=lid_cfg.get("model", "none") if lid_cfg.get("enabled", True) else "none",
        alignment="whisperx/wav2vec2-forced-alignment" if align_enabled else "none",
    )


def _segment_confidence(seg: dict, text: str) -> float:
    """Return a segment confidence in [0, 1], rounded to 3 decimals.

    Empty text gives 0.0. Otherwise the mean of the finite numeric word
    ``score`` values is used, falling back to 0.5 when none are available.
    Non-numeric or NaN/inf scores are ignored so they cannot crash ``sum``
    or leak NaN through the clamp.
    """
    if not text:
        return 0.0
    scores = []
    for word in seg.get("words") or []:
        score = word.get("score")
        if isinstance(score, (int, float)) and math.isfinite(score):
            scores.append(score)
    confidence = sum(scores) / len(scores) if scores else 0.5
    return round(min(max(confidence, 0.0), 1.0), 3)


def _save_output(output: Stage1Output, cfg: dict) -> None:
    """Write *output* as indented JSON to ``cfg["output_path"]``."""
    output_path = Path(cfg.get("output_path", "data/stage1_output.json"))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(output.model_dump_json(indent=2), encoding="utf-8")


def _run_pipeline(
    audio_path: str,
    preprocessed_path: str,
    duration: float,
    cfg: dict,
    device: str,
    start_time: float,
) -> Stage1Output:
    """Run Steps 2-8 on an already validated and preprocessed file."""
    call_id = f"call_{time.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:4]}"
    logger.info("Processing %s (%.1fs) as %s", audio_path, duration, call_id)

    # Load the preprocessed waveform once; it is sliced in Step 7.
    waveform, sr = torchaudio.load(preprocessed_path)

    # --- Step 2: Speaker Diarization ---
    diar_cfg = _section(cfg, "diarization")
    diar_cfg.setdefault("device", device)
    diar_segments = diarize(preprocessed_path, diar_cfg)

    if not diar_segments:
        logger.warning("No speech segments detected in %s", call_id)
        output = Stage1Output(
            call_id=call_id,
            duration=round(duration, 3),
            speakers=["speaker_0", "speaker_1"],
            audio_path=audio_path,
            segments=[],
            processing_info=ProcessingInfo(
                processing_time_sec=round(time.time() - start_time, 3),
                # Alignment never ran on this path.
                models=_build_models(cfg, diar_cfg, align_enabled=False),
                device=device,
            ),
        )
        # Save here too so a stale output from a previous call is not left behind.
        _save_output(output, cfg)
        return output

    # --- Step 3: ASR Transcription ---
    asr_cfg = _section(cfg, "asr")
    asr_result = transcribe(preprocessed_path, asr_cfg, device)
    whisper_language = asr_result.get("language", "ko")

    # --- Step 3.5: Text-based global language decision ---
    # Whisper's auto-detected language is unreliable on code-switched or
    # noisy calls. Decide the alignment language from the transcript text
    # instead, applying the same 50% Korean-character majority rule used
    # per-segment in language_id.py. Falls back to whisper_language when
    # the transcript is too short to be statistically meaningful.
    full_text = " ".join(seg.get("text", "") for seg in asr_result.get("segments", []))
    align_language, korean_ratio = decide_language_from_text(
        full_text,
        fallback=whisper_language,
    )
    if align_language != whisper_language:
        logger.warning(
            "Language disagreement: whisper=%s, text-based=%s "
            "(korean_ratio=%.3f, chars=%d)",
            whisper_language,
            align_language,
            korean_ratio,
            len(full_text.replace(" ", "")),
        )

    # --- Step 4: Forced Alignment ---
    align_enabled = (cfg.get("alignment") or {}).get("enabled", True)
    aligned_result = align(
        asr_result,
        preprocessed_path,
        device,
        align_enabled,
        language_code=align_language,
    )

    # --- Step 5: Diarization-Transcript Merge ---
    merged_result = assign_speakers(aligned_result, diar_segments)

    # --- Step 6: Language Detection ---
    lid_cfg = _section(cfg, "language_id")
    merged_segments = merged_result.get("segments", [])
    languages = detect_languages(
        merged_segments,
        preprocessed_path,
        lid_cfg,
        device,
        whisper_language,
    )

    # --- Step 7: Segment Audio Extraction ---
    segments_dir = cfg.get("segments_dir", "data/segments")
    segment_dicts: list[dict] = [
        {
            "segment_id": idx,
            "speaker_id": seg.get("speaker", "speaker_0"),
            "start": seg.get("start", 0.0),
            "end": seg.get("end", 0.0),
            "text": (seg.get("text") or "").strip(),
            # detect_languages may return fewer entries than segments.
            "language": languages[idx] if idx < len(languages) else whisper_language,
        }
        for idx, seg in enumerate(merged_segments)
    ]
    audio_paths = _extract_segment_audio(
        waveform,
        sr,
        segment_dicts,
        call_id,
        segments_dir,
    )

    # --- Step 8: Output Assembly ---
    all_speakers = sorted({s["speaker_id"] for s in segment_dicts})
    if len(all_speakers) < 2:
        all_speakers = ["speaker_0", "speaker_1"]

    # segment_dicts was built 1:1 from merged_segments, so they zip in lockstep.
    segments: list[Segment] = [
        Segment(
            segment_id=seg_dict["segment_id"],
            speaker_id=seg_dict["speaker_id"],
            start=round(seg_dict["start"], 3),
            end=round(seg_dict["end"], 3),
            text=seg_dict["text"],
            language=seg_dict["language"],
            audio_path=seg_audio_path,
            confidence=_segment_confidence(orig_seg, seg_dict["text"]),
        )
        for seg_dict, orig_seg, seg_audio_path in zip(
            segment_dicts, merged_segments, audio_paths
        )
    ]

    processing_time = time.time() - start_time
    output = Stage1Output(
        call_id=call_id,
        duration=round(duration, 3),
        speakers=all_speakers,
        audio_path=audio_path,
        segments=segments,
        processing_info=ProcessingInfo(
            processing_time_sec=round(processing_time, 3),
            models=_build_models(cfg, diar_cfg, align_enabled),
            device=device,
            language_chosen=align_language,
            korean_ratio=round(korean_ratio, 4),
        ),
    )

    _save_output(output, cfg)

    logger.info(
        "Stage 1 complete: %d segments, %.1fs processing time",
        len(segments),
        processing_time,
    )
    return output


def process(audio_path: str, config: dict | None = None) -> Stage1Output:
    """Stage 1 entry point.

    Args:
        audio_path: Path to input audio file (.wav/.mp3/.m4a/.ogg)
        config: Optional config dict. If None, loads from config.yaml.

    Returns:
        Stage1Output with call_id, duration, and speaker-segmented transcript.
        The same object is also written as JSON to ``config["output_path"]``.

    Raises:
        FileNotFoundError: Audio file does not exist.
        ValueError: Unsupported format, too short, or too long.
        RuntimeError: Audio decode or model inference failure.
    """
    start_time = time.time()
    cfg = _load_config(config)
    device = _get_device()

    # --- Step 1: Validate & Preprocess ---
    _validate_audio(audio_path, cfg)
    preprocessed_path, duration = _preprocess_audio(audio_path, cfg)

    try:
        return _run_pipeline(
            audio_path, preprocessed_path, duration, cfg, device, start_time
        )
    finally:
        # The preprocessed WAV is an intermediate artifact; remove it even
        # on failure so segments_dir does not accumulate temp files.
        if preprocessed_path != audio_path:
            try:
                os.remove(preprocessed_path)
            except OSError:
                logger.warning(
                    "Could not remove preprocessed file %s", preprocessed_path
                )