from __future__ import annotations import argparse import json from pathlib import Path import numpy as np import onnxruntime as ort import soundfile as sf from scipy.signal import resample_poly MODEL_PATH = Path(__file__).with_name("model.onnx") SAMPLE_RATE = 16_000 CHUNK_FRAMES = 160_000 SUPPORTED_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".m4a"} def load_audio(path: Path) -> tuple[np.ndarray, int]: audio, sample_rate = sf.read(path, dtype="float32", always_2d=True) waveform = np.ascontiguousarray(audio.mean(axis=1), dtype=np.float32) return waveform, int(sample_rate) def resample_audio( waveform: np.ndarray, source_rate: int, target_rate: int ) -> np.ndarray: gcd = np.gcd(source_rate, target_rate) waveform = resample_poly( waveform, target_rate // gcd, source_rate // gcd ).astype(np.float32) return np.ascontiguousarray(waveform) def layer_norm(waveform: np.ndarray, eps: float = 1e-5) -> np.ndarray: mean = waveform.mean(dtype=np.float64) variance = waveform.var(dtype=np.float64) return ((waveform - mean) / np.sqrt(variance + eps)).astype(np.float32) def chunk_waveform(waveform: np.ndarray, chunk_frames: int) -> np.ndarray: if chunk_frames <= 0 or waveform.size <= chunk_frames: return waveform[None, :] chunks = [ waveform[start : start + chunk_frames] for start in range(0, waveform.size, chunk_frames) ] max_length = max(chunk.size for chunk in chunks) batch = np.zeros((len(chunks), max_length), dtype=np.float32) for index, chunk in enumerate(chunks): batch[index, : chunk.size] = chunk return batch def softmax(logits: np.ndarray) -> np.ndarray: logits = logits.astype(np.float64) probabilities = np.exp(logits - logits.max()) return probabilities / probabilities.sum() class TTSSuitabilityClassifier: def __init__( self, model_path: str | Path = MODEL_PATH, provider: str = "auto", cuda_device_id: int = 0, ) -> None: available = set(ort.get_available_providers()) if provider == "auto": provider = "cuda" if "CUDAExecutionProvider" in available else "cpu" if provider == "cuda": if "CUDAExecutionProvider" not in available: raise RuntimeError( "CUDAExecutionProvider is unavailable. Install onnxruntime-gpu " "or use provider='cpu'." ) providers = [ ("CUDAExecutionProvider", {"device_id": cuda_device_id}), "CPUExecutionProvider", ] elif provider == "cpu": providers = ["CPUExecutionProvider"] else: raise ValueError("provider must be one of: auto, cpu, cuda") self.session = ort.InferenceSession(str(model_path), providers=providers) self.input_name = self.session.get_inputs()[0].name self.output_names = [output.name for output in self.session.get_outputs()] def predict(self, audio_path: str | Path) -> dict[str, object]: path = Path(audio_path).expanduser().resolve() waveform, sample_rate = load_audio(path) if sample_rate != SAMPLE_RATE: waveform = resample_audio(waveform, sample_rate, SAMPLE_RATE) waveform = layer_norm(waveform) batch = chunk_waveform(waveform, CHUNK_FRAMES) logits = self.session.run( self.output_names, {self.input_name: batch} )[0].mean(axis=0) probabilities = softmax(logits) predicted_class = int(probabilities.argmax()) return { "path": str(path), "label": "tts" if predicted_class == 1 else "not_tts", "predicted_class": predicted_class, "p_not_tts": float(probabilities[0]), "p_tts": float(probabilities[1]), "logits": [float(value) for value in logits], } def collect_audio_paths(path: Path) -> list[Path]: path = path.expanduser().resolve() if path.is_file(): return [path] return sorted( child for child in path.rglob("*") if child.is_file() and child.suffix.lower() in SUPPORTED_EXTENSIONS ) def main() -> None: parser = argparse.ArgumentParser( description="ONNX inference for the TTS suitability classifier." ) parser.add_argument("audio", type=Path, help="Audio file or directory.") parser.add_argument( "--model", type=Path, default=MODEL_PATH, help="Path to model.onnx." ) parser.add_argument( "--provider", choices=("auto", "cpu", "cuda"), default="auto" ) parser.add_argument("--cuda-device-id", type=int, default=0) args = parser.parse_args() classifier = TTSSuitabilityClassifier( args.model, args.provider, args.cuda_device_id ) paths = collect_audio_paths(args.audio) if not paths: raise RuntimeError(f"No supported audio files found at '{args.audio}'.") for path in paths: print(json.dumps(classifier.predict(path), ensure_ascii=False)) if __name__ == "__main__": main()