from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path import librosa import numpy as np import onnxruntime as ort from ls_eend_common import ( InferenceResult, config_from_metadata, ensure_mono, extract_features, frame_hz, load_audio, ) from ls_eend_streaming_common import StreamingFeatureExtractor, StreamingUpdate @dataclass(frozen=True) class ONNXStateLayout: input_dim: int full_output_dim: int real_output_dim: int conv_delay: int max_nspks: int def _load_metadata(onnx_model_path: Path) -> dict: metadata_path = onnx_model_path.with_suffix(".json") if not metadata_path.exists(): raise FileNotFoundError(f"Missing ONNX metadata JSON next to model: {metadata_path}") with open(metadata_path, "r", encoding="utf-8") as handle: return json.load(handle) def _layout_from_metadata(metadata: dict) -> ONNXStateLayout: return ONNXStateLayout( input_dim=int(metadata["input_dim"]), full_output_dim=int(metadata["full_output_dim"]), real_output_dim=int(metadata["real_output_dim"]), conv_delay=int(metadata["conv_delay"]), max_nspks=int(metadata["max_nspks"]), ) def _initial_state_tensors(metadata: dict, dtype: np.dtype = np.float32) -> dict[str, np.ndarray]: return { key: np.zeros(tuple(shape), dtype=dtype) for key, shape in metadata["state_shapes"].items() } def _resolve_providers(requested: list[str] | None) -> list[str]: available = ort.get_available_providers() if not requested: if "CPUExecutionProvider" in available: return ["CPUExecutionProvider"] return available missing = [provider for provider in requested if provider not in available] if missing: raise ValueError( f"Requested ONNX Runtime providers are unavailable: {missing}. " f"Available providers: {available}" ) return requested class ONNXStreamingSession: def __init__(self, engine: "ONNXLSEENDInferenceEngine", input_sample_rate: int) -> None: self.engine = engine self.input_sample_rate = int(input_sample_rate) if self.input_sample_rate != self.engine.target_sample_rate: raise ValueError( "Stateful LS-EEND streaming expects audio at " f"{self.engine.target_sample_rate} Hz, got {self.input_sample_rate} Hz." ) self.feature_extractor = StreamingFeatureExtractor(self.engine.config) self.state = _initial_state_tensors(self.engine.metadata, dtype=np.float32) self.zero_frame = np.zeros((1, 1, self.engine.layout.input_dim), dtype=np.float32) self.total_input_samples = 0 self.total_feature_frames = 0 self.emitted_frames = 0 self.full_logit_chunks: list[np.ndarray] = [] self.finalized = False def push_audio(self, chunk: np.ndarray) -> StreamingUpdate | None: if self.finalized: raise RuntimeError("Streaming session is already finalized.") chunk = ensure_mono(chunk) if chunk.size == 0: return None self.total_input_samples += len(chunk) features = self.feature_extractor.push_audio(chunk) committed = self._ingest_features(features) return self._build_update(committed, include_preview=True) def finalize(self) -> StreamingUpdate | None: if self.finalized: return None features = self.feature_extractor.finalize() committed = self._ingest_features(features) pending = self.total_feature_frames - self.emitted_frames tail = self._flush_tail(self.state, pending) if pending > 0 else np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32) merged = committed if tail.size == 0 else np.concatenate([committed, tail], axis=0) self.finalized = True return self._build_update(merged, include_preview=False) def snapshot(self) -> InferenceResult: if self.full_logit_chunks: full_logits = np.concatenate(self.full_logit_chunks, axis=0) else: full_logits = np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32) full_probabilities = 1.0 / (1.0 + np.exp(-full_logits)) logits = full_logits[:, 1:-1] probabilities = full_probabilities[:, 1:-1] return InferenceResult( logits=logits, probabilities=probabilities, full_logits=full_logits, full_probabilities=full_probabilities, frame_hz=self.engine.model_frame_hz, duration_seconds=float(self.total_input_samples / max(self.input_sample_rate, 1)), ) def _ingest_features(self, features: np.ndarray) -> np.ndarray: if features.size == 0: return np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32) outputs: list[np.ndarray] = [] for frame in features: should_decode = 1.0 if self.total_feature_frames >= self.engine.layout.conv_delay else 0.0 prediction = self.engine._predict_step( frame=frame.reshape(1, 1, -1).astype(np.float32, copy=False), state=self.state, ingest=1.0, decode=should_decode, ) self.state = self.engine._next_state(prediction) self.total_feature_frames += 1 if should_decode: outputs.append(prediction["full_logits"].reshape(1, -1).astype(np.float32, copy=False)) self.emitted_frames += 1 if outputs: return np.concatenate(outputs, axis=0) return np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32) def _flush_tail(self, state: dict[str, np.ndarray], pending: int) -> np.ndarray: outputs: list[np.ndarray] = [] for _ in range(pending): prediction = self.engine._predict_step( frame=self.zero_frame, state=state, ingest=0.0, decode=1.0, ) state = self.engine._next_state(prediction) outputs.append(prediction["full_logits"].reshape(1, -1).astype(np.float32, copy=False)) if outputs: return np.concatenate(outputs, axis=0) return np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32) def _build_update(self, committed_full_logits: np.ndarray, include_preview: bool) -> StreamingUpdate | None: start_frame = self.emitted_frames - committed_full_logits.shape[0] if committed_full_logits.size > 0: self.full_logit_chunks.append(committed_full_logits.astype(np.float32, copy=False)) committed_full_probabilities = 1.0 / (1.0 + np.exp(-committed_full_logits)) committed_logits = committed_full_logits[:, 1:-1] committed_probabilities = committed_full_probabilities[:, 1:-1] else: committed_logits = np.zeros((0, max(self.engine.decode_max_nspks - 2, 0)), dtype=np.float32) committed_probabilities = committed_logits.copy() if include_preview: pending = self.total_feature_frames - self.emitted_frames preview_state = {key: value.copy() for key, value in self.state.items()} preview_full_logits = self._flush_tail(preview_state, pending) else: preview_full_logits = np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32) if preview_full_logits.size > 0: preview_full_probabilities = 1.0 / (1.0 + np.exp(-preview_full_logits)) preview_logits = preview_full_logits[:, 1:-1] preview_probabilities = preview_full_probabilities[:, 1:-1] else: preview_logits = np.zeros((0, committed_logits.shape[1]), dtype=np.float32) preview_probabilities = preview_logits.copy() if committed_logits.size == 0 and preview_logits.size == 0: return None return StreamingUpdate( start_frame=start_frame, logits=committed_logits, probabilities=committed_probabilities, preview_start_frame=self.emitted_frames, preview_logits=preview_logits, preview_probabilities=preview_probabilities, frame_hz=self.engine.model_frame_hz, duration_seconds=float(self.total_input_samples / max(self.input_sample_rate, 1)), total_emitted_frames=self.emitted_frames, ) class ONNXLSEENDInferenceEngine: def __init__( self, onnx_model_path: Path, providers: list[str] | None = None, ) -> None: self.onnx_model_path = Path(onnx_model_path) self.metadata = _load_metadata(self.onnx_model_path) self.config = config_from_metadata(self.metadata) self.layout = _layout_from_metadata(self.metadata) self.decode_max_nspks = self.layout.max_nspks self.target_sample_rate = int(self.metadata["sample_rate"]) self.model_frame_hz = frame_hz(self.config) self.streaming_latency_seconds = ( (int(self.metadata["n_fft"]) // 2) + (int(self.metadata["context_recp"]) * int(self.metadata["hop_length"])) + (int(self.metadata["conv_delay"]) * int(self.metadata["subsampling"]) * int(self.metadata["hop_length"])) ) / self.target_sample_rate session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL self.providers = _resolve_providers(providers) self.session = ort.InferenceSession( str(self.onnx_model_path), sess_options=session_options, providers=self.providers, ) self.output_names = [output_meta.name for output_meta in self.session.get_outputs()] def create_session(self, input_sample_rate: int) -> ONNXStreamingSession: return ONNXStreamingSession(self, input_sample_rate) def infer_audio(self, audio: np.ndarray, sample_rate: int) -> InferenceResult: features = extract_features(audio, sample_rate, self.config) session = self.create_session(self.target_sample_rate) session.total_input_samples = len(audio) if sample_rate == self.target_sample_rate else int( round(len(audio) * (self.target_sample_rate / max(sample_rate, 1))) ) committed = session._ingest_features(features) pending = session.total_feature_frames - session.emitted_frames tail = session._flush_tail(session.state, pending) if pending > 0 else np.zeros((0, self.decode_max_nspks), dtype=np.float32) full_logits = committed if tail.size == 0 else np.concatenate([committed, tail], axis=0) if full_logits.size > 0: session.full_logit_chunks = [full_logits.astype(np.float32, copy=False)] session.emitted_frames = full_logits.shape[0] return session.snapshot() def simulate_streaming_file( self, audio_path: Path, chunk_seconds: float, ) -> tuple[InferenceResult, list[dict]]: audio, sample_rate = load_audio(audio_path) if sample_rate != self.target_sample_rate: audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=self.target_sample_rate).astype(np.float32, copy=False) sample_rate = self.target_sample_rate session = self.create_session(sample_rate) chunk_size = max(1, int(round(chunk_seconds * sample_rate))) updates = [] for chunk_index, start in enumerate(range(0, len(audio), chunk_size), start=1): stop = min(len(audio), start + chunk_size) update = session.push_audio(audio[start:stop]) updates.append( { "chunk_index": chunk_index, "buffer_seconds": round(stop / sample_rate, 3), "num_frames_emitted": int(0 if update is None else update.probabilities.shape[0]), "total_frames_emitted": int(session.emitted_frames), } ) final_update = session.finalize() if final_update is not None: updates.append( { "chunk_index": len(updates) + 1, "buffer_seconds": round(len(audio) / sample_rate, 3), "num_frames_emitted": int(final_update.probabilities.shape[0]), "total_frames_emitted": int(session.emitted_frames), "flush": True, } ) result = session.snapshot() if result.full_probabilities.shape[0] == 0: raise ValueError(f"No audio found in {audio_path}") return result, updates def _predict_step( self, frame: np.ndarray, state: dict[str, np.ndarray], ingest: float, decode: float, ) -> dict[str, np.ndarray]: prediction = self.session.run( self.output_names, { "frame": frame, "enc_ret_kv": state["enc_ret_kv"], "enc_ret_scale": state["enc_ret_scale"], "enc_conv_cache": state["enc_conv_cache"], "dec_ret_kv": state["dec_ret_kv"], "dec_ret_scale": state["dec_ret_scale"], "top_buffer": state["top_buffer"], "ingest": np.array([ingest], dtype=np.float32), "decode": np.array([decode], dtype=np.float32), }, ) return {name: value for name, value in zip(self.output_names, prediction)} @staticmethod def _next_state(prediction: dict[str, np.ndarray]) -> dict[str, np.ndarray]: return { "enc_ret_kv": prediction["enc_ret_kv_out"], "enc_ret_scale": prediction["enc_ret_scale_out"], "enc_conv_cache": prediction["enc_conv_cache_out"], "dec_ret_kv": prediction["dec_ret_kv_out"], "dec_ret_scale": prediction["dec_ret_scale_out"], "top_buffer": prediction["top_buffer_out"], }