# Source: ustwo-api/src/stage1/process.py (deployed from GitHub 2026-04-23T03:56:31Z, commit c857b85, 11.6 kB)
"""Stage 1 — Orchestrator.
Entry point: process(audio_path) -> Stage1Output
Pipeline:
1. Audio validation & preprocessing
2. Speaker diarization (pyannote)
3. ASR transcription (Faster-Whisper via WhisperX)
4. Forced alignment (WhisperX wav2vec2)
5. Diarization-transcript merge
6. Language detection (SenseVoice-Small)
7. Segment audio extraction
8. Output assembly
"""
from __future__ import annotations
import logging
import os
import time
import uuid
from pathlib import Path
import torch
import torchaudio
import yaml
from src.common.schemas import (
Models,
ProcessingInfo,
Segment,
Stage1Output,
)
from src.stage1.diarization import diarize
from src.stage1.language_id import decide_language_from_text, detect_languages
from src.stage1.transcription import align, assign_speakers, transcribe
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lowercase file suffixes (with leading dot) that _validate_audio accepts.
SUPPORTED_EXTENSIONS = {".wav", ".mp3", ".m4a", ".ogg"}
def _load_config(config: dict | None = None) -> dict:
    """Return the Stage 1 configuration dict.

    Args:
        config: Pre-built configuration. When it is not ``None`` it is
            returned unchanged (the same object) and no file is read.

    Returns:
        ``config`` if given. Otherwise the ``stage1`` section of the
        ``config.yaml`` found at ``Path(__file__).parent.parent.parent``, or
        an empty dict when that file is missing, empty, not a mapping, or
        has no usable ``stage1`` entry.
    """
    if config is not None:
        return config

    config_path = Path(__file__).parent.parent.parent / "config.yaml"
    if not config_path.exists():
        logger.warning("config.yaml not found, using defaults")
        return {}

    with open(config_path, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # safe_load yields None for an empty document, and a bare `stage1:` key
    # maps to None; fall back to {} so callers can always call .get() on it.
    if not isinstance(loaded, dict):
        return {}
    return loaded.get("stage1") or {}
def _get_device() -> str:
    """Pick the torch device string: ``"cuda"`` when available, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def _validate_audio(audio_path: str, config: dict) -> None:
    """Step 1a: Check that the input exists and has a supported extension.

    Args:
        audio_path: Path to the input audio file.
        config: Stage 1 config; currently not read by this function.

    Raises:
        FileNotFoundError: Nothing exists at ``audio_path``.
        ValueError: The lower-cased suffix is not in ``SUPPORTED_EXTENSIONS``.
    """
    candidate = Path(audio_path)
    if not candidate.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    extension = candidate.suffix
    if extension.lower() in SUPPORTED_EXTENSIONS:
        return
    raise ValueError(
        f"Unsupported audio format: {extension}. "
        f"Supported: {SUPPORTED_EXTENSIONS}"
    )
def _preprocess_audio(audio_path: str, config: dict) -> tuple[str, float]:
    """Step 1b: Convert to 16kHz mono WAV, return (preprocessed_path, duration).

    Args:
        audio_path: Path to the input audio file.
        config: Stage 1 config. Reads ``preprocessing.target_sample_rate``
            (default 16000), ``preprocessing.min_duration_sec`` (default 3),
            ``preprocessing.max_duration_sec`` (default 300),
            ``preprocessing.target_peak`` (default 0.95) and ``segments_dir``
            (default ``"data/segments"``).

    Returns:
        ``(preprocessed_path, duration)``. ``duration`` is the length of the
        input file in seconds, measured before resampling.
        ``preprocessed_path`` is ``audio_path`` itself only when the input is
        a ``.wav`` that needed no resampling, downmixing or gain change;
        otherwise it is a newly written ``_preprocessed_<8 hex>.wav`` inside
        ``segments_dir``.

    Raises:
        ValueError: Duration is below the minimum or above the maximum.
    """
    preprocess_cfg = config.get("preprocessing", {})
    target_sr = preprocess_cfg.get("target_sample_rate", 16000)
    min_dur = preprocess_cfg.get("min_duration_sec", 3)
    max_dur = preprocess_cfg.get("max_duration_sec", 300)
    target_peak = preprocess_cfg.get("target_peak", 0.95)

    waveform, sr = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sr
    if duration < min_dur:
        raise ValueError(f"Audio too short: {duration:.1f}s (min: {min_dur}s)")
    if duration > max_dur:
        raise ValueError(f"Audio too long: {duration:.1f}s (max: {max_dur}s)")

    # Track whether the samples were changed. Previously only a sample-rate
    # or container mismatch triggered a save, so the downmix and the peak
    # normalization of a 16 kHz WAV were computed and then thrown away.
    modified = False

    # Resample if needed
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
        modified = True

    # Stereo (or multi-channel) to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
        modified = True

    # Peak normalization — handle volume differences across devices.
    # max() raises on an empty tensor, so treat "no samples" as silence.
    peak = float(waveform.abs().max()) if waveform.numel() else 0.0
    if peak > 0 and peak != target_peak:
        waveform = waveform * (target_peak / peak)
        modified = True

    # The input can be passed through untouched only if nothing changed and
    # it is already a WAV container.
    if not modified and Path(audio_path).suffix.lower() == ".wav":
        return audio_path, duration

    # Save preprocessed WAV
    preprocessed_dir = Path(config.get("segments_dir", "data/segments"))
    preprocessed_dir.mkdir(parents=True, exist_ok=True)
    preprocessed_path = str(
        preprocessed_dir / f"_preprocessed_{uuid.uuid4().hex[:8]}.wav"
    )
    torchaudio.save(preprocessed_path, waveform, target_sr)
    return preprocessed_path, duration
def _extract_segment_audio(
    waveform: torch.Tensor,
    sr: int,
    segments: list[dict],
    call_id: str,
    segments_dir: str,
) -> list[str]:
    """Step 7: Write one WAV clip per segment under ``segments_dir/call_id``.

    Args:
        waveform: Audio tensor; it is sliced as ``[:, start:end]``, so
            dimension 1 is the sample axis.
        sr: Sample rate used to turn seconds into sample offsets and passed
            to ``torchaudio.save`` for each clip.
        segments: Dicts providing ``segment_id``, ``speaker_id``, ``start``
            and ``end`` (the last two are multiplied by ``sr``).
        call_id: Name of the per-call subdirectory.
        segments_dir: Root directory for extracted clips.

    Returns:
        Clip paths in segment order. An empty string marks a segment whose
        clamped sample range is empty; no file is written for it.
    """
    out_dir = Path(segments_dir) / call_id
    out_dir.mkdir(parents=True, exist_ok=True)

    clip_paths: list[str] = []
    for segment in segments:
        seg_id = segment["segment_id"]
        speaker_label = segment["speaker_id"]

        # Seconds -> sample offsets, clamped to the bounds of the waveform.
        first = max(0, int(segment["start"] * sr))
        last = min(waveform.shape[1], int(segment["end"] * sr))

        if last > first:
            destination = str(out_dir / f"seg_{seg_id:03d}_{speaker_label}.wav")
            torchaudio.save(destination, waveform[:, first:last], sr)
            clip_paths.append(destination)
        else:
            # Empty or inverted range: keep the slot so indices stay aligned.
            clip_paths.append("")
    return clip_paths
def _segment_confidence(words: list[dict] | None, text: str) -> float:
    """Average the per-word ``score`` values of one segment, clamped to [0, 1]."""
    if not text:
        # Nothing was transcribed, so report zero confidence.
        return 0.0
    scores = [w["score"] for w in (words or []) if "score" in w]
    # 0.5 is the neutral fallback when no word carries a score.
    confidence = sum(scores) / len(scores) if scores else 0.5
    return round(min(max(confidence, 0.0), 1.0), 3)


def _run_pipeline(
    audio_path: str,
    preprocessed_path: str,
    duration: float,
    cfg: dict,
    device: str,
    start_time: float,
) -> Stage1Output:
    """Run Steps 2-8 on an already preprocessed file and assemble the output."""
    call_id = f"call_{time.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:4]}"
    logger.info("Processing %s (%.1fs) as %s", audio_path, duration, call_id)

    # --- Step 2: Speaker Diarization ---
    # Work on a copy: setdefault() on the caller's dict would leak the
    # resolved device back into the config object passed to process().
    diar_cfg = dict(cfg.get("diarization") or {})
    diar_cfg.setdefault("device", device)
    diar_segments = diarize(preprocessed_path, diar_cfg)
    if not diar_segments:
        logger.warning("No speech segments detected in %s", call_id)
        # NOTE(review): unlike the normal path, this early return does not
        # write the JSON file at `output_path` and does not round values.
        return Stage1Output(
            call_id=call_id,
            duration=duration,
            speakers=["speaker_0", "speaker_1"],
            audio_path=audio_path,
            segments=[],
            processing_info=ProcessingInfo(
                processing_time_sec=time.perf_counter() - start_time,
                models=Models(
                    diarization=diar_cfg.get("model", "pyannote/speaker-diarization-3.1"),
                    asr=cfg.get("asr", {}).get("model", "large-v3-turbo"),
                    language_id=cfg.get("language_id", {}).get("model", "none"),
                    alignment="none",
                ),
                device=device,
            ),
        )

    # --- Step 3: ASR Transcription ---
    asr_cfg = cfg.get("asr", {})
    asr_result = transcribe(preprocessed_path, asr_cfg, device)
    whisper_language = asr_result.get("language", "ko")

    # --- Step 3.5: Text-based global language decision ---
    # Whisper's auto-detected language is unreliable on code-switched or
    # noisy calls. Decide alignment language from the transcript text
    # instead, applying the same 50% Korean-character majority rule we
    # use per-segment in language_id.py. Falls back to whisper_language
    # when the transcript is too short to be statistically meaningful.
    full_text = " ".join(
        seg.get("text", "") for seg in asr_result.get("segments", [])
    )
    align_language, korean_ratio = decide_language_from_text(
        full_text, fallback=whisper_language,
    )
    if align_language != whisper_language:
        logger.warning(
            "Language disagreement: whisper=%s, text-based=%s "
            "(korean_ratio=%.3f, chars=%d)",
            whisper_language, align_language, korean_ratio,
            len(full_text.replace(" ", "")),
        )

    # --- Step 4: Forced Alignment ---
    align_enabled = cfg.get("alignment", {}).get("enabled", True)
    aligned_result = align(
        asr_result, preprocessed_path, device, align_enabled,
        language_code=align_language,
    )

    # --- Step 5: Diarization-Transcript Merge ---
    merged_result = assign_speakers(aligned_result, diar_segments)
    merged_segments = merged_result.get("segments", [])

    # --- Step 6: Language Detection ---
    lid_cfg = cfg.get("language_id", {})
    languages = detect_languages(
        merged_segments, preprocessed_path, lid_cfg, device, whisper_language,
    )

    # --- Step 7: Segment Audio Extraction ---
    segments_dir = cfg.get("segments_dir", "data/segments")
    segment_dicts: list[dict] = [
        {
            "segment_id": idx,
            "speaker_id": seg.get("speaker", "speaker_0"),
            "start": seg.get("start", 0.0),
            "end": seg.get("end", 0.0),
            "text": seg.get("text", "").strip(),
            # Fall back to the Whisper language if detection returned fewer
            # entries than there are segments.
            "language": languages[idx] if idx < len(languages) else whisper_language,
        }
        for idx, seg in enumerate(merged_segments)
    ]
    # Load the waveform only now so it is not held in memory during inference.
    waveform, sr = torchaudio.load(preprocessed_path)
    audio_paths = _extract_segment_audio(
        waveform, sr, segment_dicts, call_id, segments_dir,
    )

    # --- Step 8: Output Assembly ---
    all_speakers = sorted({s["speaker_id"] for s in segment_dicts})
    if len(all_speakers) < 2:
        all_speakers = ["speaker_0", "speaker_1"]

    # segment_dicts, merged_segments and audio_paths are index-aligned.
    segments: list[Segment] = [
        Segment(
            segment_id=seg_dict["segment_id"],
            speaker_id=seg_dict["speaker_id"],
            start=round(seg_dict["start"], 3),
            end=round(seg_dict["end"], 3),
            text=seg_dict["text"],
            language=seg_dict["language"],
            audio_path=seg_audio_path,
            confidence=_segment_confidence(orig_seg.get("words", []), seg_dict["text"]),
        )
        for seg_dict, orig_seg, seg_audio_path in zip(
            segment_dicts, merged_segments, audio_paths
        )
    ]

    processing_time = time.perf_counter() - start_time
    output = Stage1Output(
        call_id=call_id,
        duration=round(duration, 3),
        speakers=all_speakers,
        audio_path=audio_path,
        segments=segments,
        processing_info=ProcessingInfo(
            processing_time_sec=round(processing_time, 3),
            models=Models(
                diarization=diar_cfg.get("model", "pyannote/speaker-diarization-3.1"),
                asr=f"whisperx/{asr_cfg.get('model', 'large-v3-turbo')}-{asr_cfg.get('compute_type', 'int8')}",
                language_id=lid_cfg.get("model", "none") if lid_cfg.get("enabled", True) else "none",
                alignment="whisperx/wav2vec2-forced-alignment" if align_enabled else "none",
            ),
            device=device,
            language_chosen=align_language,
            korean_ratio=round(korean_ratio, 4),
        ),
    )

    # Save to JSON
    output_path = Path(cfg.get("output_path", "data/stage1_output.json"))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(output.model_dump_json(indent=2), encoding="utf-8")

    logger.info(
        "Stage 1 complete: %d segments, %.1fs processing time",
        len(segments), processing_time,
    )
    return output


def process(audio_path: str, config: dict | None = None) -> Stage1Output:
    """Stage 1 entry point.

    Validates and preprocesses the audio, then runs diarization, ASR,
    forced alignment, speaker assignment, language detection and per-segment
    clip extraction. On the normal path the result is also written as JSON
    to ``output_path`` from the config (default ``data/stage1_output.json``).
    The caller's ``config`` dict is not modified.

    Args:
        audio_path: Path to input audio file (.wav/.mp3/.m4a/.ogg)
        config: Optional config dict. If None, loads from config.yaml.

    Returns:
        Stage1Output with call_id, duration, and speaker-segmented transcript.
        When diarization finds no speech, ``segments`` is empty.

    Raises:
        FileNotFoundError: Audio file does not exist.
        ValueError: Unsupported format, too short, or too long.
        RuntimeError: Audio decode or model inference failure.
    """
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments during a long run.
    start_time = time.perf_counter()
    cfg = _load_config(config)
    device = _get_device()

    # --- Step 1: Validate & Preprocess ---
    _validate_audio(audio_path, cfg)
    preprocessed_path, duration = _preprocess_audio(audio_path, cfg)

    try:
        return _run_pipeline(
            audio_path, preprocessed_path, duration, cfg, device, start_time,
        )
    finally:
        # The intermediate WAV is not referenced by the output, so remove it
        # (success or failure) instead of letting one pile up per call.
        if preprocessed_path != audio_path:
            try:
                os.remove(preprocessed_path)
            except FileNotFoundError:
                pass
            except OSError as exc:
                logger.warning(
                    "Could not remove temporary file %s: %s",
                    preprocessed_path, exc,
                )