driftcall / cells /step_09_audio.py
saumilyajj's picture
Upload folder using huggingface_hub
b43d8da verified
"""Cell 09 — Audio pipeline (Kokoro-82M TTS + faster-whisper-small ASR).
Implements docs/modules/audio.md: TTS and ASR engines exposed at the env
boundary. Training never imports this module (docs/modules/audio.md §6.3).
Heavy deps (``kokoro``, ``faster_whisper``, ``torchaudio``, ``soundfile``)
are loaded lazily inside ``_load_*`` helpers so this cell imports cleanly
in environments where those optional packages are absent, and so tests can
monkeypatch the loaders to return fakes without ever touching the network.
"""
from __future__ import annotations

import hashlib
import io
import logging
import math
import struct
import threading
import time
import unicodedata
import wave
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Literal, cast, get_args

import numpy as np
from cachetools import LRUCache
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Public literal types (audio.md §2.1, §2.2)
# ---------------------------------------------------------------------------
LanguageCode = Literal["hi", "ta", "kn", "en", "hinglish"]
VoicePack = Literal[
    "hi_female_1",
    "hi_male_1",
    "ta_female_1",
    "kn_male_1",
    "en_indian_female_1",
]
# Runtime membership sets, derived from the Literal aliases above so the
# static types and the runtime checks can never drift apart.
_LANGUAGE_CODES: frozenset[str] = frozenset(get_args(LanguageCode))
_VOICE_PACKS_SET: frozenset[str] = frozenset(get_args(VoicePack))
# ---------------------------------------------------------------------------
# Errors (audio.md §2.3)
# ---------------------------------------------------------------------------
class AudioError(Exception):
    """Root of the audio module's exception hierarchy; catch it for any audio failure."""


class ModelLoadError(AudioError):
    """A model, a required runtime dependency, or a usable voice pack could not be set up."""


class UnsupportedLanguageError(AudioError):
    """An unregistered language code was passed to synthesis.

    ``TTSEngine.synthesize`` also raises this for a ``sample_rate_hz`` other
    than 16000.
    """


class UnsupportedVoicePackError(AudioError):
    """The requested voice pack is not listed in ``VOICE_PACKS[lang].allowed``."""


class AudioDecodeError(AudioError):
    """``transcribe()`` could not decode its input (format, sample rate, or whisper failure)."""


class AudioTooLongError(AudioError):
    """Audio exceeds ``max_duration_s``.

    Reserved: ``ASREngine.transcribe`` as written truncates over-long clips
    rather than raising this.
    """


class TTSOutOfMemoryError(AudioError):
    """TTS synthesis exhausted memory mid-call.

    Also used when Kokoro's output cannot be converted into an audio array.
    """
# ---------------------------------------------------------------------------
# Data records (audio.md §2.1, §2.2, §2.2a, §4.1, §4.2)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class VoicePackMapping:
    """Voice-pack policy for one language. audio.md §4.3.

    ``default`` is used when the caller names no pack; ``allowed`` lists every
    pack the caller may request explicitly for ``language``.
    """

    language: LanguageCode
    default: VoicePack
    allowed: tuple[VoicePack, ...]


# One row per language: (language, default pack, allowed packs). Row order is
# the insertion order of VOICE_PACKS, which warmup() iterates.
_VOICE_PACK_ROWS: tuple[tuple[LanguageCode, VoicePack, tuple[VoicePack, ...]], ...] = (
    ("hi", "hi_female_1", ("hi_female_1", "hi_male_1")),
    ("ta", "ta_female_1", ("ta_female_1",)),
    ("kn", "kn_male_1", ("kn_male_1",)),
    ("en", "en_indian_female_1", ("en_indian_female_1",)),
    ("hinglish", "en_indian_female_1", ("en_indian_female_1", "hi_female_1")),
)

VOICE_PACKS: dict[LanguageCode, VoicePackMapping] = {
    lang: VoicePackMapping(language=lang, default=default_pack, allowed=allowed_packs)
    for lang, default_pack, allowed_packs in _VOICE_PACK_ROWS
}
@dataclass(frozen=True)
class TranscriptResult:
    """Immutable ASR result handed to the env observation builder. audio.md §4.1."""

    # NFC-normalized, stripped transcript text ("" when nothing was decoded).
    text: str
    # Detected language, or "unknown" when it cannot be mapped.
    language_detected: LanguageCode | Literal["unknown"]
    # Duration-weighted confidence in [0, 1], rounded to 3 decimals.
    confidence: float
    # Seconds of audio considered, capped at the caller's max_duration_s.
    duration_s: float


@dataclass(frozen=True)
class AudioTrace:
    """Diagnostic record for one synthesize/transcribe call.

    Delivered to the optional trace sink. audio.md §2.2a, §3.8.
    """

    # Which public operation produced this record.
    op: Literal["synthesize", "transcribe"]
    # blake2b-128 hex digest of the call's input (text or audio bytes).
    input_hash: str
    # Language code passed by the caller ("unknown" when none was given).
    language: str
    # Audio length in seconds.
    duration_s: float
    # Wall-clock latency of the call in milliseconds.
    latency_ms: int
    # ASR confidence; None for synthesis.
    confidence: float | None
    # True when served from the TTS cache.
    cache_hit: bool
    # True when a fallback voice pack was used or ASR produced no text.
    degraded: bool
    # ISO-8601 timestamp in IST, millisecond precision.
    ts_ist: str


# Signature of a trace consumer: receives each AudioTrace, returns nothing.
TraceSink = Callable[[AudioTrace], None]
# ---------------------------------------------------------------------------
# Lazy dep loaders — patched by tests to inject fakes.
# ---------------------------------------------------------------------------
def _load_kokoro() -> Any:
    """Import ``kokoro`` on first use and hand back the module (tests swap this out)."""
    import kokoro as kokoro_module

    return kokoro_module


def _load_faster_whisper() -> Any:
    """Import ``faster_whisper`` on first use and hand back the module (tests swap this out)."""
    import faster_whisper as faster_whisper_module

    return faster_whisper_module


def _load_torchaudio_functional() -> Any:
    """Import ``torchaudio.functional`` on first use and hand it back (tests swap this out)."""
    from torchaudio import functional as torchaudio_functional

    return torchaudio_functional


def _load_torchaudio() -> Any:
    """Import top-level ``torchaudio`` on first use and hand it back (tests swap this out)."""
    import torchaudio as torchaudio_module

    return torchaudio_module


def _load_soundfile() -> Any:
    """Import ``soundfile`` on first use and hand back the module (tests swap this out)."""
    import soundfile as soundfile_module

    return soundfile_module


def _load_torch() -> Any:
    """Import ``torch`` on first use and hand back the module (tests swap this out)."""
    import torch as torch_module

    return torch_module
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Indian Standard Time: fixed UTC+05:30 offset, no DST.
_IST_TZ = timezone(timedelta(minutes=330))


def _ts_ist_now() -> str:
    """Current wall-clock time in IST as ISO-8601 with millisecond precision."""
    moment = datetime.now(tz=_IST_TZ)
    return moment.isoformat(timespec="milliseconds")
def _input_hash(payload: bytes) -> str:
    """Hex-encoded 128-bit BLAKE2b digest of ``payload`` (32 hex chars)."""
    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(payload)
    return hasher.hexdigest()
def _logprob_to_confidence(avg_logprob: float) -> float:
    """Map faster-whisper ``avg_logprob`` into [0, 1] per audio.md §3.5.

    The log-probability is clamped to [-1.5, 0.0] and exponentiated, so the
    result lies in [exp(-1.5), 1.0] ≈ [0.223, 1.0], rounded to 3 decimals.
    A NaN input is treated as the least-confident value (the -1.5 floor).
    """
    value = float(avg_logprob)
    if math.isnan(value):
        # NaN compares False against everything, so min/max would let it
        # through as 0.0 and report full confidence (exp(0) == 1.0).
        value = -1.5
    clamped = max(-1.5, min(0.0, value))
    return round(math.exp(clamped), 3)
def _riff_header_sample_rate(audio_bytes: bytes) -> int | None:
"""Return the sample-rate field from a RIFF header, or None if not RIFF."""
if len(audio_bytes) < 28:
return None
if audio_bytes[0:4] != b"RIFF" or audio_bytes[8:12] != b"WAVE":
return None
return int(struct.unpack_from("<I", audio_bytes, 24)[0])
def _pcm16_silence_wav(duration_s: float, sample_rate_hz: int = 16000) -> bytes:
    """Build a mono 16-bit PCM WAV containing only silence (used for warmup / fallback).

    At least one frame is always written, even for a zero or negative duration.
    """
    frame_count = int(duration_s * sample_rate_hz)
    if frame_count < 1:
        frame_count = 1
    out = io.BytesIO()
    with wave.open(out, "wb") as writer:
        writer.setframerate(sample_rate_hz)
        writer.setsampwidth(2)
        writer.setnchannels(1)
        writer.writeframes(bytes(2 * frame_count))
    return out.getvalue()
def _np_to_wav_bytes(pcm: np.ndarray, sample_rate_hz: int) -> bytes:
    """Encode a mono numpy array as 16-bit PCM RIFF WAV bytes.

    Used when torchaudio is unavailable or mocked — the fallback path
    produces the same byte-level contract (RIFF header + mono 16-bit).

    Input handling:
      * int16 of either byte order is written as-is.
      * Any other dtype is treated as float audio nominally in [-1, 1]:
        NaN becomes silence, +/-inf saturates, values are clipped to
        [-1, 1] and scaled by 32767 (truncating toward zero).
      * The array is flattened, so ``(n,)`` and ``(n, 1)`` encode the same.
    """
    if pcm.dtype.kind == "i" and pcm.dtype.itemsize == 2:
        # Already PCM samples; normalize to native-order int16 for wave.
        pcm_i16 = pcm.astype(np.int16, copy=False)
    else:
        # Casting NaN to int16 is undefined (numpy warns); map non-finite
        # values to silence / full scale before clipping.
        as_float = np.nan_to_num(
            pcm.astype(np.float32), nan=0.0, posinf=1.0, neginf=-1.0
        )
        clipped = np.clip(as_float, -1.0, 1.0)
        pcm_i16 = (clipped * 32767.0).astype(np.int16)
    frames = pcm_i16.reshape(-1).tobytes()
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(sample_rate_hz)
        w.writeframes(frames)
    return buf.getvalue()
# ---------------------------------------------------------------------------
# TTS
# ---------------------------------------------------------------------------
# Byte budget for each TTS cache (64 MiB); the caches size entries by payload bytes.
_TTS_CACHE_MAX_BYTES: int = 64 << 20
# Entry-count cap. NOTE(review): not referenced by the caches in this file — confirm intent.
_TTS_CACHE_MAX_ENTRIES: int = 2**8
def _available_voice_packs(kokoro_module: Any) -> set[str]:
    """Probe the installed Kokoro bundle for shipped voice-pack names.

    Names are collected from ``AVAILABLE_VOICES``, ``VOICES`` and the return
    value of ``list_voices()``. Each source may be a list/tuple/set/frozenset
    of names or a mapping keyed by name; any other value (including a bare
    string) is ignored. If no source yields a name, the full canonical set is
    returned (best-effort; the per-call fallback in
    ``TTSEngine._resolve_voice_pack`` still protects against missing packs).
    """
    from collections.abc import Mapping

    def _names(value: Any) -> set[str]:
        # Mappings contribute their keys; plain collections their items.
        if isinstance(value, Mapping):
            return {str(key) for key in value}
        if isinstance(value, (list, tuple, set, frozenset)):
            return {str(item) for item in value}
        return set()

    candidates: set[str] = set()
    for attr in ("AVAILABLE_VOICES", "VOICES"):
        candidates |= _names(getattr(kokoro_module, attr, None))
    list_voices = getattr(kokoro_module, "list_voices", None)
    if callable(list_voices):
        try:
            candidates |= _names(list_voices())
        except Exception:  # pragma: no cover — probing is best-effort
            logger.debug("kokoro list_voices() raised; ignoring", exc_info=True)
    return candidates or set(_VOICE_PACKS_SET)
# Next pack to try when a requested pack is missing from the Kokoro bundle.
# Indic packs collapse onto hi_female_1, which in turn falls back to the
# Indian-English voice; en_indian_female_1 has no successor.
_FALLBACK_CHAIN: dict[str, str] = dict(
    ta_female_1="hi_female_1",
    kn_male_1="hi_female_1",
    hi_male_1="hi_female_1",
    hi_female_1="en_indian_female_1",
)
class TTSEngine:
    """Kokoro-82M wrapper. Constructed via ``get_tts_engine()``.

    One instance per process. All heavy deps are imported lazily. Rendered
    audio is memoized in two byte-bounded LRU caches: WAV bytes for
    ``synthesize`` and float32 arrays for ``synthesize_to_gradio``.
    """

    def __init__(
        self,
        *,
        model_id: str = "hexgrad/Kokoro-82M",
        trace_sink: TraceSink | None = None,
    ) -> None:
        """Load Kokoro, build its pipeline and probe the shipped voice packs.

        Raises:
            ModelLoadError: kokoro cannot be loaded, ``KPipeline`` is missing
                or fails to construct, or neither ``hi_female_1`` nor
                ``en_indian_female_1`` is available.
        """
        self._model_id = model_id
        self._trace_sink = trace_sink
        # Guards both caches; rendering itself runs outside the lock.
        self._lock = threading.Lock()
        # Both caches are bounded by payload size in bytes, not entry count.
        self._cache: LRUCache[tuple[Any, ...], bytes] = LRUCache(
            maxsize=_TTS_CACHE_MAX_BYTES, getsizeof=len
        )
        self._numpy_cache: LRUCache[tuple[Any, ...], np.ndarray] = LRUCache(
            maxsize=_TTS_CACHE_MAX_BYTES, getsizeof=lambda a: int(a.nbytes)
        )
        # requested pack -> pack actually rendered, for requests that degraded.
        self._fallback_used: dict[str, str] = {}
        try:
            kokoro = _load_kokoro()
        except Exception as exc:  # network / disk / import failure
            raise ModelLoadError(f"failed to load kokoro: {exc}") from exc
        self._kokoro = kokoro
        try:
            pipeline_cls = getattr(kokoro, "KPipeline", None)
            if pipeline_cls is None:
                raise AttributeError("kokoro.KPipeline missing")
            # NOTE(review): confirm the pinned kokoro release accepts a
            # ``model_id`` keyword here; any failure surfaces as ModelLoadError.
            self._pipeline = pipeline_cls(model_id=model_id)
        except Exception as exc:
            raise ModelLoadError(f"failed to construct KPipeline: {exc}") from exc
        self._available_packs = _available_voice_packs(kokoro)
        self._verify_critical_packs()

    def _verify_critical_packs(self) -> None:
        """Fail fast unless at least one of the two terminal fallback packs exists."""
        if (
            "en_indian_female_1" not in self._available_packs
            and "hi_female_1" not in self._available_packs
        ):
            raise ModelLoadError("no usable voice pack for hi or en")

    def _resolve_voice_pack(
        self, requested: VoicePack
    ) -> tuple[VoicePack, bool, str | None]:
        """Walk ``_FALLBACK_CHAIN`` until a pack present in the bundle is found.

        Returns ``(resolved_pack, degraded, fallback_from)``: ``degraded`` is
        True iff at least one fallback hop was taken, in which case
        ``fallback_from`` is the originally requested pack (else None).

        Raises:
            ModelLoadError: the chain ends, or loops back on itself, before
                reaching an available pack.
        """
        if requested in self._available_packs:
            return requested, False, None
        current: str = requested
        visited: set[str] = {current}
        while current not in self._available_packs:
            successor = _FALLBACK_CHAIN.get(current)
            # A revisit means the chain is cyclic; treat it like exhaustion
            # rather than returning a pack that is not actually available.
            if successor is None or successor in visited:
                raise ModelLoadError(
                    f"no usable voice pack; chain exhausted from {requested!r}"
                )
            visited.add(successor)
            current = successor
        resolved = cast("VoicePack", current)
        self._fallback_used[requested] = resolved
        return resolved, True, requested

    def _emit_trace(self, trace: AudioTrace) -> None:
        """Deliver ``trace`` to the sink, if any; sink errors are logged and swallowed."""
        if self._trace_sink is None:
            return
        try:
            self._trace_sink(trace)
        except Exception:  # telemetry must never break production
            logger.debug("trace sink raised; swallowed", exc_info=True)

    def _render_pcm(self, text: str, voice_pack: VoicePack, seed: int) -> np.ndarray:
        """Run Kokoro with a seeded, forked torch RNG; return 24 kHz float32 mono PCM.

        Raises:
            TTSOutOfMemoryError: on ``MemoryError`` or a ``RuntimeError`` whose
                message mentions "out of memory" or "alloc".
        """
        torch = _load_torch()
        # fork_rng scopes the manual seed to this block (see torch docs).
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            try:
                result = self._pipeline(text, voice=voice_pack)
            except MemoryError as exc:
                raise TTSOutOfMemoryError(f"TTS OOM: {exc}") from exc
            except RuntimeError as exc:
                msg = str(exc).lower()
                if "out of memory" in msg or "alloc" in msg:
                    raise TTSOutOfMemoryError(f"TTS OOM: {exc}") from exc
                raise
            # Coerce while still inside the seeded context, in case the
            # pipeline returns something that is evaluated lazily.
            return _coerce_to_float32_mono(result)

    def _resample_to_16k(self, pcm_24k: np.ndarray) -> np.ndarray:
        """Downsample 24 kHz → 16 kHz via torchaudio.functional.resample."""
        try:
            F = _load_torchaudio_functional()
        except Exception as exc:  # pragma: no cover — hard runtime failure
            raise ModelLoadError(f"torchaudio.functional missing: {exc}") from exc
        torch = _load_torch()
        tensor = torch.from_numpy(pcm_24k.astype(np.float32)).unsqueeze(0)
        resampled = F.resample(
            tensor, orig_freq=24000, new_freq=16000, lowpass_filter_width=64
        )
        out = resampled.squeeze(0).cpu().numpy().astype(np.float32)
        return cast("np.ndarray", out)

    def _encode_wav(self, pcm_16k: np.ndarray, sample_rate_hz: int) -> bytes:
        """Encode float32 PCM as 16-bit mono RIFF WAV bytes.

        Uses torchaudio when it works and otherwise falls back to the stdlib
        ``wave`` encoder so the byte contract still holds.
        """
        try:
            torchaudio = _load_torchaudio()
            torch = _load_torch()
            tensor = torch.from_numpy(pcm_16k.astype(np.float32)).unsqueeze(0)
            buf = io.BytesIO()
            torchaudio.save(
                buf,
                tensor,
                sample_rate=sample_rate_hz,
                bits_per_sample=16,
                format="wav",
                encoding="PCM_S",
            )
            return buf.getvalue()
        except Exception:
            # Fall back to stdlib wave encoder so the byte contract still holds
            # even when torchaudio is unavailable.
            return _np_to_wav_bytes(pcm_16k, sample_rate_hz)

    @staticmethod
    def _validate_request(
        language: str, voice_pack: VoicePack | None, param_name: str
    ) -> VoicePack:
        """Validate a language/voice-pack pair and return the effective pack.

        ``param_name`` is the calling method's parameter name, used only in
        the error message.

        Raises:
            UnsupportedLanguageError: ``language`` is not registered.
            UnsupportedVoicePackError: the pack is not allowed for ``language``.
        """
        if language not in _LANGUAGE_CODES:
            raise UnsupportedLanguageError(f"{param_name}={language!r} unsupported")
        mapping = VOICE_PACKS[cast("LanguageCode", language)]
        pack = mapping.default if voice_pack is None else voice_pack
        if pack not in mapping.allowed:
            raise UnsupportedVoicePackError(
                f"voice_pack={pack!r} not allowed for language={language!r}"
            )
        return pack

    def synthesize(
        self,
        text: str,
        language_code: LanguageCode,
        voice_pack: VoicePack | None = None,
        *,
        seed: int = 0,
        sample_rate_hz: int = 16000,
    ) -> bytes:
        """Return 16-bit PCM mono WAV bytes. audio.md §2.1, §4.4.

        Results are memoized on (text hash, requested pack, seed, rate).

        Raises:
            UnsupportedLanguageError: ``sample_rate_hz`` is not 16000 or
                ``language_code`` is not registered.
            UnsupportedVoicePackError: ``voice_pack`` is not allowed.
            ModelLoadError: no available pack on the fallback chain.
            TTSOutOfMemoryError: synthesis ran out of memory.
        """
        if sample_rate_hz != 16000:
            raise UnsupportedLanguageError(
                f"sample_rate_hz={sample_rate_hz} unsupported; only 16000 allowed in v1"
            )
        voice_pack = self._validate_request(language_code, voice_pack, "language_code")
        text_hash = _input_hash(text.encode("utf-8"))
        cache_key = (text_hash, voice_pack, seed, sample_rate_hz, "bytes")
        start = time.perf_counter()
        with self._lock:
            cached = self._cache.get(cache_key)
        if cached is not None:
            # Emit after releasing the lock: the sink is caller-supplied code.
            latency_ms = int((time.perf_counter() - start) * 1000)
            self._emit_trace(
                AudioTrace(
                    op="synthesize",
                    input_hash=text_hash,
                    language=language_code,
                    duration_s=_wav_duration_s(cached),
                    latency_ms=latency_ms,
                    confidence=None,
                    cache_hit=True,
                    # Cached audio was rendered with whatever fallback the
                    # first call resolved to; report it consistently.
                    degraded=voice_pack in self._fallback_used,
                    ts_ist=_ts_ist_now(),
                )
            )
            return cached
        resolved_pack, degraded, _ = self._resolve_voice_pack(voice_pack)
        pcm_24k = self._render_pcm(text, resolved_pack, seed)
        pcm_16k = self._resample_to_16k(pcm_24k)
        wav_bytes = self._encode_wav(pcm_16k, sample_rate_hz)
        with self._lock:
            # cachetools rejects a single value larger than the whole cache;
            # skip memoizing it rather than failing a successful render.
            if len(wav_bytes) <= self._cache.maxsize:
                self._cache[cache_key] = wav_bytes
        latency_ms = int((time.perf_counter() - start) * 1000)
        self._emit_trace(
            AudioTrace(
                op="synthesize",
                input_hash=text_hash,
                language=language_code,
                duration_s=_wav_duration_s(wav_bytes),
                latency_ms=latency_ms,
                confidence=None,
                cache_hit=False,
                degraded=degraded,
                ts_ist=_ts_ist_now(),
            )
        )
        return wav_bytes

    def synthesize_to_gradio(
        self,
        text: str,
        language_hint: LanguageCode,
        voice_pack: VoicePack | None = None,
        *,
        seed: int = 0,
    ) -> tuple[int, np.ndarray]:
        """Return ``(sample_rate, float32 mono ndarray)``. audio.md §2.1.

        The returned array is always a fresh copy, so callers may mutate it
        without corrupting the cache.

        Raises:
            UnsupportedLanguageError: ``language_hint`` is not registered.
            UnsupportedVoicePackError: ``voice_pack`` is not allowed.
            ModelLoadError: no available pack on the fallback chain.
            TTSOutOfMemoryError: synthesis ran out of memory.
        """
        voice_pack = self._validate_request(language_hint, voice_pack, "language_hint")
        text_hash = _input_hash(text.encode("utf-8"))
        sample_rate_hz = 16000
        cache_key = (text_hash, voice_pack, seed, sample_rate_hz, "numpy")
        start = time.perf_counter()
        with self._lock:
            cached = self._numpy_cache.get(cache_key)
        if cached is not None:
            self._emit_trace(
                AudioTrace(
                    op="synthesize",
                    input_hash=text_hash,
                    language=language_hint,
                    duration_s=float(len(cached)) / sample_rate_hz,
                    latency_ms=int((time.perf_counter() - start) * 1000),
                    confidence=None,
                    cache_hit=True,
                    degraded=voice_pack in self._fallback_used,
                    ts_ist=_ts_ist_now(),
                )
            )
            return sample_rate_hz, cached.copy()
        resolved_pack, degraded, _ = self._resolve_voice_pack(voice_pack)
        pcm_24k = self._render_pcm(text, resolved_pack, seed)
        pcm_16k = self._resample_to_16k(pcm_24k)
        with self._lock:
            if pcm_16k.nbytes <= self._numpy_cache.maxsize:
                self._numpy_cache[cache_key] = pcm_16k
        self._emit_trace(
            AudioTrace(
                op="synthesize",
                input_hash=text_hash,
                language=language_hint,
                duration_s=float(len(pcm_16k)) / sample_rate_hz,
                latency_ms=int((time.perf_counter() - start) * 1000),
                confidence=None,
                cache_hit=False,
                degraded=degraded,
                ts_ist=_ts_ist_now(),
            )
        )
        return sample_rate_hz, pcm_16k.copy()

    def warmup(self) -> None:
        """Probe each voice pack; log WARN on missing Indic packs. audio.md §4.3.1.

        Then renders one short English utterance; failures there are logged
        at debug level and otherwise ignored.
        """
        for lang, mapping in VOICE_PACKS.items():
            for pack in mapping.allowed:
                if pack not in self._available_packs:
                    logger.warning(
                        "voice pack %r missing from bundle (language=%s); will fall back at synth time",
                        pack,
                        lang,
                    )
        try:
            self.synthesize("warmup", "en")
        except Exception:  # pragma: no cover — warmup best-effort
            logger.debug("warmup synthesize failed; continuing", exc_info=True)
def _coerce_to_float32_mono(result: Any) -> np.ndarray:
    """Turn whatever Kokoro returned into a 1-D float32 numpy array.

    Accepted forms of ``result``:
      * a tensor-like object (has ``cpu`` and ``numpy``), converted via
        ``detach().cpu().numpy()``;
      * a tuple whose first element is the audio (tensor-like or array-like);
      * a numpy array;
      * anything else, which goes through ``torch.as_tensor``.
    Multi-dimensional audio is flattened in row-major order.

    Raises:
        TTSOutOfMemoryError: the generic path cannot convert ``result``.
            NOTE(review): not an OOM; the type is kept for callers that
            already catch it.
    """
    if hasattr(result, "cpu") and hasattr(result, "numpy"):
        arr = result.detach().cpu().numpy()
    elif isinstance(result, tuple):
        audio_like = result[0]
        if hasattr(audio_like, "cpu") and hasattr(audio_like, "numpy"):
            arr = audio_like.detach().cpu().numpy()
        else:
            arr = np.asarray(audio_like)
    elif isinstance(result, np.ndarray):
        arr = result
    else:
        # torch is only needed for this generic path; loading it here keeps
        # numpy / tensor-like results independent of the torch import.
        torch = _load_torch()
        try:
            tensor = torch.as_tensor(result)
            arr = tensor.detach().cpu().numpy()
        except Exception as exc:  # pragma: no cover — defensive
            raise TTSOutOfMemoryError(
                f"unexpected Kokoro return type: {type(result)!r}: {exc}"
            ) from exc
    return np.asarray(arr, dtype=np.float32).reshape(-1)
def _wav_duration_s(wav_bytes: bytes) -> float:
    """Best-effort length in seconds (3 decimals) of a RIFF WAV payload; 0.0 if unreadable."""
    try:
        with wave.open(io.BytesIO(wav_bytes), "rb") as reader:
            n_frames = reader.getnframes()
            frame_rate = reader.getframerate()
    except Exception:
        return 0.0
    if frame_rate > 0:
        return round(n_frames / frame_rate, 3)
    return 0.0
# ---------------------------------------------------------------------------
# ASR
# ---------------------------------------------------------------------------
def _map_language(code: str | None) -> LanguageCode | Literal["unknown"]:
    """Narrow a whisper language code to a registered LanguageCode, else "unknown"."""
    return cast("LanguageCode", code) if code in _LANGUAGE_CODES else "unknown"
def _nfc(text: str) -> str:
    """Compose ``text`` to Unicode NFC and trim surrounding whitespace."""
    composed = unicodedata.normalize("NFC", text)
    return composed.strip()
class ASREngine:
    """faster-whisper-small wrapper. Constructed via ``get_asr_engine()``.

    audio.md §2.2. Heavy deps loaded lazily; the model is built on CPU.
    """

    def __init__(
        self,
        *,
        model_id: str = "Systran/faster-whisper-small",
        compute_type: Literal["int8", "int8_float16"] = "int8",
        trace_sink: TraceSink | None = None,
    ) -> None:
        """Load faster-whisper and construct the CPU ``WhisperModel``.

        Raises:
            ModelLoadError: faster_whisper cannot be loaded, ``WhisperModel``
                is missing, or the model fails to construct.
        """
        self._model_id = model_id
        self._compute_type = compute_type
        self._trace_sink = trace_sink
        # NOTE(review): not acquired anywhere in this class; confirm whether
        # concurrent transcribe() calls are meant to be serialized.
        self._lock = threading.Lock()
        try:
            fw = _load_faster_whisper()
        except Exception as exc:
            raise ModelLoadError(f"failed to load faster_whisper: {exc}") from exc
        model_cls = getattr(fw, "WhisperModel", None)
        if model_cls is None:
            raise ModelLoadError("faster_whisper.WhisperModel missing")
        try:
            self._model = model_cls(model_id, compute_type=compute_type, device="cpu")
        except Exception as exc:
            raise ModelLoadError(f"failed to construct WhisperModel: {exc}") from exc

    def _emit_trace(self, trace: AudioTrace) -> None:
        """Deliver ``trace`` to the sink, if any; sink errors are logged and swallowed."""
        if self._trace_sink is None:
            return
        try:
            self._trace_sink(trace)
        except Exception:
            logger.debug("trace sink raised; swallowed", exc_info=True)

    def transcribe(
        self,
        audio_bytes: bytes,
        language_hint: LanguageCode | None,
        *,
        beam_size: int = 1,
        vad_filter: bool = True,
        max_duration_s: float = 30.0,
    ) -> TranscriptResult:
        """Decode WAV/PCM bytes into text. audio.md §2.2, §3.5, §4.4.

        Clips longer than ``max_duration_s`` are truncated, not rejected.
        A ``"hinglish"`` hint is passed to whisper as ``"hi"``; the
        code-mixed label is re-derived from the decoded text afterwards.

        Raises:
            AudioDecodeError: the bytes cannot be decoded or whisper fails.
        """
        start = time.perf_counter()
        pcm, clip_duration = self._decode_input(audio_bytes)
        if clip_duration > max_duration_s:
            pcm = pcm[: int(max_duration_s * 16000)]
            clip_duration = max_duration_s
        language_for_whisper: str | None
        if language_hint == "hinglish":
            language_for_whisper = "hi"
        elif language_hint is None:
            language_for_whisper = None
        else:
            language_for_whisper = language_hint
        segments, info = self._run_whisper(
            pcm,
            language=language_for_whisper,
            beam_size=beam_size,
            vad_filter=vad_filter,
        )
        segments_list = list(segments)
        detected_code = _map_language(getattr(info, "language", None))
        vad_dropped_all = getattr(info, "vad_dropped_all_segments", None)
        if vad_dropped_all is None:
            # No explicit flag from the backend: infer it from an empty result.
            vad_dropped_all = len(segments_list) == 0 and vad_filter
        combined_text = _nfc("".join(getattr(s, "text", "") for s in segments_list))
        duration_s = round(min(float(clip_duration), float(max_duration_s)), 3)
        degraded = False
        if combined_text == "":
            confidence = 0.0
            if vad_dropped_all:
                detected: LanguageCode | Literal["unknown"] = "unknown"
            else:
                # Speech survived VAD but produced no text: flag as degraded.
                detected = detected_code
                degraded = True
        else:
            confidence = _duration_weighted_confidence(segments_list)
            detected = _infer_hinglish(detected_code, combined_text, language_hint)
        result = TranscriptResult(
            text=combined_text,
            language_detected=detected,
            confidence=confidence,
            duration_s=duration_s,
        )
        latency_ms = int((time.perf_counter() - start) * 1000)
        self._emit_trace(
            AudioTrace(
                op="transcribe",
                input_hash=_input_hash(audio_bytes),
                language=language_hint or "unknown",
                duration_s=duration_s,
                latency_ms=latency_ms,
                confidence=confidence,
                cache_hit=False,
                degraded=degraded,
                ts_ist=_ts_ist_now(),
            )
        )
        return result

    def _decode_input(self, audio_bytes: bytes) -> tuple[np.ndarray, float]:
        """Return (float32 mono @ 16 kHz, duration_s); raise AudioDecodeError on mismatch.

        Accepts a 16 kHz RIFF WAV (multi-channel files are downmixed to mono)
        or a raw float32 PCM payload assumed to be 16 kHz.
        """
        if len(audio_bytes) >= 3 and audio_bytes[:3] == b"ID3":
            raise AudioDecodeError("MP3 / ID3-tagged inputs are not supported (no ffmpeg in image)")
        rate = _riff_header_sample_rate(audio_bytes)
        if rate is not None:
            if rate != 16000:
                raise AudioDecodeError("input must be 16 kHz mono; caller must pre-resample")
            try:
                sf = _load_soundfile()
                data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=False)
            except Exception as exc:
                raise AudioDecodeError(f"soundfile failed to decode RIFF WAV: {exc}") from exc
            if sr != 16000:
                raise AudioDecodeError("input must be 16 kHz mono; caller must pre-resample")
            arr = np.asarray(data, dtype=np.float32)
            if arr.ndim > 1:
                # Multi-channel data comes back as (frames, channels);
                # flattening would interleave channels and double the
                # duration, so average the channels into mono instead.
                arr = arr.mean(axis=1, dtype=np.float32)
            arr = arr.reshape(-1)
            duration = float(len(arr)) / 16000.0
            return arr, duration
        # Raw float32 PCM path (demo mic input). 16 kHz assumed. We only accept
        # payloads that look like plausible audio — ≥ 0.25 s of float32 samples
        # (4000 × 4 = 16000 bytes), all finite, with peak magnitude ≤ 2.0
        # (Gradio emits normalized [-1, 1] audio; the bound leaves headroom
        # for overshoot). Anything else is rejected so arbitrary random bytes
        # do not slip through.
        min_raw_pcm_bytes = 4000 * 4
        if len(audio_bytes) >= min_raw_pcm_bytes and len(audio_bytes) % 4 == 0:
            pcm = np.frombuffer(audio_bytes, dtype=np.float32).copy()
            if pcm.size and np.all(np.isfinite(pcm)) and np.max(np.abs(pcm)) <= 2.0:
                duration = float(pcm.size) / 16000.0
                return pcm, duration
        raise AudioDecodeError("input is not a valid 16 kHz RIFF WAV or float32 PCM payload")

    def _run_whisper(
        self,
        pcm: np.ndarray,
        *,
        language: str | None,
        beam_size: int,
        vad_filter: bool,
    ) -> tuple[Any, Any]:
        """Run the model and return ``(segments as a list, info)``.

        Raises:
            AudioDecodeError: the model raised, either when called or while
                its segments were being produced.
        """
        try:
            segments, info = self._model.transcribe(
                pcm,
                language=language,
                beam_size=beam_size,
                vad_filter=vad_filter,
            )
            # faster-whisper yields segments lazily, so decoding actually
            # happens during iteration; materialize here so those failures
            # are wrapped too.
            segments = list(segments)
        except Exception as exc:
            raise AudioDecodeError(f"whisper decode failed: {exc}") from exc
        return segments, info

    def warmup(self) -> None:
        """Run one transcribe() on 0.5 s of silence to force load. audio.md §2.2."""
        silence = _pcm16_silence_wav(0.5)
        try:
            self.transcribe(silence, "en")
        except Exception:  # pragma: no cover — warmup best-effort
            logger.debug("warmup transcribe failed; continuing", exc_info=True)
def _duration_weighted_confidence(segments: list[Any]) -> float:
    """Average per-segment confidence weighted by segment duration (3 decimals).

    A segment with zero (or missing / negative) duration is given weight 1.0.
    Missing ``start``/``end``/``avg_logprob`` attributes default to 0.0.
    Returns 0.0 for an empty segment list.
    """
    if not segments:
        return 0.0
    weight_total = 0.0
    score_total = 0.0
    for segment in segments:
        seg_start = float(getattr(segment, "start", 0.0) or 0.0)
        seg_end = float(getattr(segment, "end", 0.0) or 0.0)
        span = max(0.0, seg_end - seg_start)
        seg_logprob = float(getattr(segment, "avg_logprob", 0.0) or 0.0)
        seg_conf = _logprob_to_confidence(seg_logprob)
        if span != 0.0:
            weight_total += span
            score_total += seg_conf * span
        else:
            weight_total += 1.0
            score_total += seg_conf
    if weight_total == 0.0:
        return 0.0
    return round(score_total / weight_total, 3)
def _infer_hinglish(
    detected: LanguageCode | Literal["unknown"],
    text: str,
    hint: LanguageCode | None,
) -> LanguageCode | Literal["unknown"]:
    """Downgrade ``hi`` to ``hinglish`` when the decoded text is code-mixed.

    Heuristic per audio.md §3.6: only when the caller hinted ``hinglish`` and
    whisper detected ``hi``, return ``hinglish`` if the text has ≥ 2 ASCII
    words and at least one Devanagari character (U+0900–U+097F). Surrounding
    ASCII punctuation is stripped before a token is tested, so "hello," and
    "world!" count as words. Otherwise ``detected`` is returned unchanged.
    """
    import string

    if hint != "hinglish" or detected != "hi":
        return detected
    stripped = (tok.strip(string.punctuation) for tok in text.split())
    ascii_words = [word for word in stripped if word.isascii() and word.isalpha()]
    has_devanagari = any("\u0900" <= ch <= "\u097f" for ch in text)
    if len(ascii_words) >= 2 and has_devanagari:
        return "hinglish"
    return detected
# ---------------------------------------------------------------------------
# Singletons
# ---------------------------------------------------------------------------
# Process-wide engine instances, each created lazily under its own lock.
_tts_engine: TTSEngine | None = None
_asr_engine: ASREngine | None = None
_tts_lock = threading.Lock()
_asr_lock = threading.Lock()


def get_tts_engine(
    *, trace_sink: TraceSink | None = None, model_id: str = "hexgrad/Kokoro-82M"
) -> TTSEngine:
    """Return the process-wide TTSEngine singleton. audio.md §3.2, §3.8.

    Arguments only take effect on the first call; a different non-None
    ``trace_sink`` passed later is ignored with a warning.
    """
    global _tts_engine
    with _tts_lock:
        engine = _tts_engine
        if engine is None:
            engine = _tts_engine = TTSEngine(model_id=model_id, trace_sink=trace_sink)
        elif trace_sink is not None and trace_sink is not engine._trace_sink:
            logger.warning("get_tts_engine: different sink passed after construction; ignoring")
        return engine


def get_asr_engine(
    *,
    trace_sink: TraceSink | None = None,
    model_id: str = "Systran/faster-whisper-small",
    compute_type: Literal["int8", "int8_float16"] = "int8",
) -> ASREngine:
    """Return the process-wide ASREngine singleton. audio.md §3.2, §3.8.

    Arguments only take effect on the first call; a different non-None
    ``trace_sink`` passed later is ignored with a warning.
    """
    global _asr_engine
    with _asr_lock:
        engine = _asr_engine
        if engine is None:
            engine = _asr_engine = ASREngine(
                model_id=model_id, compute_type=compute_type, trace_sink=trace_sink
            )
        elif trace_sink is not None and trace_sink is not engine._trace_sink:
            logger.warning("get_asr_engine: different sink passed after construction; ignoring")
        return engine


def _reset_singletons_for_tests() -> None:
    """Drop both singletons. Tests only. audio.md §3.2 "Unload. Never." exemption."""
    global _tts_engine, _asr_engine
    with _tts_lock:
        _tts_engine = None
    with _asr_lock:
        _asr_engine = None
__all__ = [
    # Errors
    "AudioError",
    "AudioDecodeError",
    "AudioTooLongError",
    "ModelLoadError",
    "TTSOutOfMemoryError",
    "UnsupportedLanguageError",
    "UnsupportedVoicePackError",
    # Types, records and registries
    "AudioTrace",
    "LanguageCode",
    "TraceSink",
    "TranscriptResult",
    "VOICE_PACKS",
    "VoicePack",
    "VoicePackMapping",
    # Engines and singleton accessors
    "ASREngine",
    "TTSEngine",
    "get_asr_engine",
    "get_tts_engine",
]