| """Stage 1 — Orchestrator. |
| |
| Entry point: process(audio_path) -> Stage1Output |
| |
| Pipeline: |
| 1. Audio validation & preprocessing |
| 2. Speaker diarization (pyannote) |
| 3. ASR transcription (Faster-Whisper via WhisperX) |
| 4. Forced alignment (WhisperX wav2vec2) |
| 5. Diarization-transcript merge |
| 6. Language detection (SenseVoice-Small) |
| 7. Segment audio extraction |
| 8. Output assembly |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import time |
| import uuid |
| from pathlib import Path |
|
|
| import torch |
| import torchaudio |
| import yaml |
|
|
| from src.common.schemas import ( |
| Models, |
| ProcessingInfo, |
| Segment, |
| Stage1Output, |
| ) |
| from src.stage1.diarization import diarize |
| from src.stage1.language_id import decide_language_from_text, detect_languages |
| from src.stage1.transcription import align, assign_speakers, transcribe |
|
|
# Logger for this module; the logger name is the module's import path.
logger = logging.getLogger(name=__name__)

# Input file extensions accepted by _validate_audio. Compared against the
# lower-cased suffix, so entries are lower-case and include the leading dot.
SUPPORTED_EXTENSIONS = {
    ".wav",
    ".mp3",
    ".m4a",
    ".ogg",
}
|
|
|
|
| def _load_config(config: dict | None = None) -> dict: |
| """Load config from yaml if not provided.""" |
| if config is not None: |
| return config |
|
|
| config_path = Path(__file__).parent.parent.parent / "config.yaml" |
| if config_path.exists(): |
| with open(config_path) as f: |
| return yaml.safe_load(f).get("stage1", {}) |
|
|
| logger.warning("config.yaml not found, using defaults") |
| return {} |
|
|
|
|
| def _get_device() -> str: |
| """Determine compute device.""" |
| return "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
| def _validate_audio(audio_path: str, config: dict) -> None: |
| """Step 1a: Validate audio file.""" |
| path = Path(audio_path) |
|
|
| if not path.exists(): |
| raise FileNotFoundError(f"Audio file not found: {audio_path}") |
|
|
| if path.suffix.lower() not in SUPPORTED_EXTENSIONS: |
| raise ValueError( |
| f"Unsupported audio format: {path.suffix}. " |
| f"Supported: {SUPPORTED_EXTENSIONS}" |
| ) |
|
|
|
|
def _preprocess_audio(audio_path: str, config: dict) -> tuple[str, float]:
    """Step 1b: Produce a resampled, mono, peak-normalized WAV for the pipeline.

    Args:
        audio_path: Input audio file (already validated).
        config: Stage 1 config; reads ``preprocessing.target_sample_rate``
            (default 16000), ``preprocessing.min_duration_sec`` (3),
            ``preprocessing.max_duration_sec`` (300),
            ``preprocessing.target_peak`` (0.95) and ``segments_dir``.

    Returns:
        ``(preprocessed_path, duration)``. *duration* is in seconds, measured
        on the decoded input. If the input is already a ``.wav`` and needed no
        resampling, downmixing or gain change, *preprocessed_path* is
        *audio_path* itself. Otherwise it is a newly written
        ``_preprocessed_<hex>.wav`` under ``segments_dir``.

    Raises:
        ValueError: Duration is outside [min_duration_sec, max_duration_sec].
    """
    preprocess_cfg = config.get("preprocessing") or {}
    target_sr = preprocess_cfg.get("target_sample_rate", 16000)
    min_dur = preprocess_cfg.get("min_duration_sec", 3)
    max_dur = preprocess_cfg.get("max_duration_sec", 300)
    target_peak = preprocess_cfg.get("target_peak", 0.95)

    waveform, sr = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sr

    if duration < min_dur:
        raise ValueError(f"Audio too short: {duration:.1f}s (min: {min_dur}s)")
    if duration > max_dur:
        raise ValueError(f"Audio too long: {duration:.1f}s (max: {max_dur}s)")

    # True once the in-memory audio differs from the file on disk. Downstream
    # stages re-read the returned path, so any change must be written out;
    # previously a 16 kHz stereo WAV silently skipped downmix/normalization.
    modified = False

    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
        modified = True

    # Downmix multi-channel audio to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
        modified = True

    # Peak-normalize; guarded because max() raises on an empty tensor.
    if waveform.numel() > 0:
        peak = float(waveform.abs().max())
        if peak > 0 and peak != target_peak:
            waveform = waveform * (target_peak / peak)
            modified = True

    if not modified and Path(audio_path).suffix.lower() == ".wav":
        return audio_path, duration

    preprocessed_dir = Path(config.get("segments_dir", "data/segments"))
    preprocessed_dir.mkdir(parents=True, exist_ok=True)
    preprocessed_path = str(preprocessed_dir / f"_preprocessed_{uuid.uuid4().hex[:8]}.wav")
    torchaudio.save(preprocessed_path, waveform, target_sr)
    return preprocessed_path, duration
|
|
|
|
| def _extract_segment_audio( |
| waveform: torch.Tensor, |
| sr: int, |
| segments: list[dict], |
| call_id: str, |
| segments_dir: str, |
| ) -> list[str]: |
| """Step 7: Extract audio clips for each segment.""" |
| seg_dir = Path(segments_dir) / call_id |
| seg_dir.mkdir(parents=True, exist_ok=True) |
|
|
| paths: list[str] = [] |
| for seg in segments: |
| idx = seg["segment_id"] |
| speaker = seg["speaker_id"] |
| start_sample = int(seg["start"] * sr) |
| end_sample = int(seg["end"] * sr) |
|
|
| |
| start_sample = max(0, start_sample) |
| end_sample = min(waveform.shape[1], end_sample) |
|
|
| if end_sample <= start_sample: |
| paths.append("") |
| continue |
|
|
| clip = waveform[:, start_sample:end_sample] |
| filename = f"seg_{idx:03d}_{speaker}.wav" |
| clip_path = str(seg_dir / filename) |
| torchaudio.save(clip_path, clip, sr) |
| paths.append(clip_path) |
|
|
| return paths |
|
|
|
|
def _models_info(cfg: dict, diar_cfg: dict, *, alignment_used: bool) -> Models:
    """Build the Models record so every return path of process() reports identical names."""
    asr_cfg = cfg.get("asr") or {}
    lid_cfg = cfg.get("language_id") or {}
    asr_name = (
        f"whisperx/{asr_cfg.get('model', 'large-v3-turbo')}"
        f"-{asr_cfg.get('compute_type', 'int8')}"
    )
    return Models(
        diarization=diar_cfg.get("model", "pyannote/speaker-diarization-3.1"),
        asr=asr_name,
        language_id=lid_cfg.get("model", "none") if lid_cfg.get("enabled", True) else "none",
        alignment="whisperx/wav2vec2-forced-alignment" if alignment_used else "none",
    )


def _segment_confidence(words: list[dict], text: str) -> float:
    """Return a segment confidence in [0, 1], rounded to 3 decimals.

    Empty *text* gives 0.0. If no word carries a usable (non-None, finite)
    ``score``, the neutral value 0.5 is used. Otherwise the result is the
    mean word score, clamped to [0, 1].
    """
    import math

    if not text:
        return 0.0
    scores: list[float] = []
    for word in words:
        score = word.get("score")
        if score is None:
            continue
        score = float(score)
        # A NaN would survive the clamp below and end up in the output JSON.
        if math.isfinite(score):
            scores.append(score)
    mean = sum(scores) / len(scores) if scores else 0.5
    return round(min(max(mean, 0.0), 1.0), 3)


def _speaker_list(segment_dicts: list[dict]) -> list[str]:
    """Return sorted unique speaker ids, padded with default labels up to two.

    Observed labels are kept, so ``speakers`` always contains every segment's
    ``speaker_id``. With no segments the result is ``["speaker_0", "speaker_1"]``.
    """
    speakers = {s["speaker_id"] for s in segment_dicts}
    for default in ("speaker_0", "speaker_1"):
        if len(speakers) >= 2:
            break
        speakers.add(default)
    return sorted(speakers)


def _write_output(output: Stage1Output, output_path: str) -> None:
    """Serialize *output* as indented JSON to *output_path*, creating parent dirs."""
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(output.model_dump_json(indent=2))


def _remove_preprocessed(preprocessed_path: str, audio_path: str) -> None:
    """Delete the temporary WAV written by _preprocess_audio, if it wrote one.

    _preprocess_audio returns *audio_path* itself when it wrote nothing. That
    case is a user file and must never be removed.
    """
    if preprocessed_path == audio_path:
        return
    try:
        os.remove(preprocessed_path)
    except OSError:
        logger.warning("Could not remove preprocessed file %s", preprocessed_path, exc_info=True)


def process(audio_path: str, config: dict | None = None) -> Stage1Output:
    """Stage 1 entry point.

    Args:
        audio_path: Path to input audio file (.wav/.mp3/.m4a/.ogg)
        config: Optional config dict. If None, loads from config.yaml.
            The dict is not modified.

    Returns:
        Stage1Output with call_id, duration, and speaker-segmented transcript.

    Raises:
        FileNotFoundError: Audio file does not exist.
        ValueError: Unsupported format, too short, or too long.
        RuntimeError: Audio decode or model inference failure.

    Side effects:
        Writes per-segment WAV clips under ``<segments_dir>/<call_id>/``.
        Writes the serialized output to ``output_path`` on every successful
        return, including when no speech is found. Removes any temporary
        preprocessed WAV before returning or raising.
    """
    start_time = time.perf_counter()
    cfg = _load_config(config)
    device = _get_device()
    output_path = cfg.get("output_path", "data/stage1_output.json")

    # Step 1: validation & preprocessing.
    _validate_audio(audio_path, cfg)
    preprocessed_path, duration = _preprocess_audio(audio_path, cfg)
    call_id = f"call_{time.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:4]}"
    logger.info("Processing %s (%.1fs) as %s", audio_path, duration, call_id)

    try:
        # Step 2: speaker diarization. Copy the sub-config so setdefault()
        # does not write "device" back into the caller's config dict.
        diar_cfg = dict(cfg.get("diarization") or {})
        diar_cfg.setdefault("device", device)
        diar_segments = diarize(preprocessed_path, diar_cfg)

        if not diar_segments:
            logger.warning("No speech segments detected in %s", call_id)
            output = Stage1Output(
                call_id=call_id,
                duration=round(duration, 3),
                speakers=_speaker_list([]),
                audio_path=audio_path,
                segments=[],
                processing_info=ProcessingInfo(
                    processing_time_sec=round(time.perf_counter() - start_time, 3),
                    models=_models_info(cfg, diar_cfg, alignment_used=False),
                    device=device,
                ),
            )
            # Written on this path too, so output_path never keeps a stale
            # result from a previous call.
            _write_output(output, output_path)
            return output

        # Step 3: ASR transcription.
        asr_cfg = cfg.get("asr") or {}
        asr_result = transcribe(preprocessed_path, asr_cfg, device)
        whisper_language = asr_result.get("language", "ko")

        # Pick the alignment language from the transcript text, falling back
        # to Whisper's detected language. Disagreements are logged.
        full_text = " ".join(seg.get("text", "") for seg in asr_result.get("segments", []))
        align_language, korean_ratio = decide_language_from_text(
            full_text, fallback=whisper_language,
        )
        if align_language != whisper_language:
            logger.warning(
                "Language disagreement: whisper=%s, text-based=%s "
                "(korean_ratio=%.3f, chars=%d)",
                whisper_language, align_language, korean_ratio,
                len(full_text.replace(" ", "")),
            )

        # Step 4: forced alignment (may be disabled in config).
        align_enabled = (cfg.get("alignment") or {}).get("enabled", True)
        aligned_result = align(
            asr_result, preprocessed_path, device, align_enabled,
            language_code=align_language,
        )

        # Step 5: attach diarization speakers to transcript segments.
        merged_result = assign_speakers(aligned_result, diar_segments)
        merged_segments = merged_result.get("segments", [])

        # Step 6: per-segment language detection.
        lid_cfg = cfg.get("language_id") or {}
        languages = detect_languages(
            merged_segments, preprocessed_path, lid_cfg, device, whisper_language,
        )

        segment_dicts: list[dict] = [
            {
                "segment_id": idx,
                "speaker_id": seg.get("speaker") or "speaker_0",
                "start": seg.get("start", 0.0),
                "end": seg.get("end", 0.0),
                "text": (seg.get("text") or "").strip(),
                "language": languages[idx] if idx < len(languages) else whisper_language,
            }
            for idx, seg in enumerate(merged_segments)
        ]

        # Step 7: per-segment audio clips. The waveform is loaded only now,
        # so it is not held in memory during the model stages above.
        waveform, sr = torchaudio.load(preprocessed_path)
        audio_paths = _extract_segment_audio(
            waveform, sr, segment_dicts, call_id,
            cfg.get("segments_dir", "data/segments"),
        )

        # Step 8: output assembly. segment_id == index into merged_segments,
        # so the three sequences are aligned element-for-element.
        segments: list[Segment] = [
            Segment(
                segment_id=seg_dict["segment_id"],
                speaker_id=seg_dict["speaker_id"],
                start=round(seg_dict["start"], 3),
                end=round(seg_dict["end"], 3),
                text=seg_dict["text"],
                language=seg_dict["language"],
                audio_path=seg_audio_path,
                confidence=_segment_confidence(orig_seg.get("words") or [], seg_dict["text"]),
            )
            for seg_dict, seg_audio_path, orig_seg in zip(
                segment_dicts, audio_paths, merged_segments,
            )
        ]

        processing_time = time.perf_counter() - start_time
        output = Stage1Output(
            call_id=call_id,
            duration=round(duration, 3),
            speakers=_speaker_list(segment_dicts),
            audio_path=audio_path,
            segments=segments,
            processing_info=ProcessingInfo(
                processing_time_sec=round(processing_time, 3),
                models=_models_info(cfg, diar_cfg, alignment_used=align_enabled),
                device=device,
                language_chosen=align_language,
                korean_ratio=round(korean_ratio, 4),
            ),
        )

        _write_output(output, output_path)
        logger.info(
            "Stage 1 complete: %d segments, %.1fs processing time",
            len(segments), processing_time,
        )
        return output
    finally:
        _remove_preprocessed(preprocessed_path, audio_path)
|
|