| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import librosa |
| import numpy as np |
| import onnxruntime as ort |
|
|
| from ls_eend_common import ( |
| InferenceResult, |
| config_from_metadata, |
| ensure_mono, |
| extract_features, |
| frame_hz, |
| load_audio, |
| ) |
| from ls_eend_streaming_common import StreamingFeatureExtractor, StreamingUpdate |
|
|
|
|
@dataclass(frozen=True)
class ONNXStateLayout:
    """Immutable bundle of integer dimensions describing an exported ONNX model.

    All five values are plain ints; instances cannot be modified after
    construction because the dataclass is frozen.
    """

    # Size of the last axis of a single feature frame fed to the model.
    input_dim: int
    # NOTE(review): presumably the width of the model's raw output — confirm.
    full_output_dim: int
    # NOTE(review): presumably the width of the usable part of the output — confirm.
    real_output_dim: int
    # Number of feature frames that are ingested before decoding starts.
    conv_delay: int
    # Number of columns in the full logits produced per decoded frame.
    max_nspks: int
|
|
|
|
def _load_metadata(onnx_model_path: Path) -> dict:
    """Load the metadata JSON stored beside an ONNX model file.

    The metadata file is expected at the model path with its suffix replaced
    by ``.json`` (``model.onnx`` -> ``model.json``) and is read as UTF-8.

    Args:
        onnx_model_path: Path of the ONNX model file.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If no metadata file exists next to the model.
    """
    metadata_path = onnx_model_path.with_suffix(".json")
    if metadata_path.exists():
        return json.loads(metadata_path.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Missing ONNX metadata JSON next to model: {metadata_path}")
|
|
|
|
def _layout_from_metadata(metadata: dict) -> ONNXStateLayout:
    """Build an ``ONNXStateLayout`` from the model metadata mapping.

    Each layout field is read from the metadata key of the same name and
    coerced with ``int``.

    Args:
        metadata: Parsed metadata; must contain every key listed below.

    Returns:
        The populated, immutable layout.

    Raises:
        KeyError: If a required key is absent from ``metadata``.
    """
    # Kept in the dataclass's field order so lookups happen in the same sequence.
    field_names = (
        "input_dim",
        "full_output_dim",
        "real_output_dim",
        "conv_delay",
        "max_nspks",
    )
    dimensions = {name: int(metadata[name]) for name in field_names}
    return ONNXStateLayout(**dimensions)
|
|
|
|
def _initial_state_tensors(metadata: dict, dtype: np.dtype = np.float32) -> dict[str, np.ndarray]:
    """Create an all-zero tensor for every entry of ``metadata["state_shapes"]``.

    Args:
        metadata: Parsed metadata whose ``"state_shapes"`` entry maps a state
            name to a sequence of dimension sizes.
        dtype: Element type of the created arrays.

    Returns:
        A dict mapping each state name to a zero-filled array of its shape,
        in the iteration order of ``metadata["state_shapes"]``.
    """
    tensors = {}
    for name, dims in metadata["state_shapes"].items():
        tensors[name] = np.zeros(tuple(dims), dtype=dtype)
    return tensors
|
|
|
|
def _resolve_providers(requested: list[str] | None) -> list[str]:
    """Pick the ONNX Runtime execution providers to use for a session.

    Args:
        requested: Provider names chosen by the caller, or ``None`` / an empty
            list to let this function choose.

    Returns:
        ``requested`` itself when every requested provider is available.
        With no request, ``["CPUExecutionProvider"]`` if that provider is
        available, otherwise every provider ONNX Runtime reports.

    Raises:
        ValueError: If any requested provider is not available.
    """
    available = ort.get_available_providers()
    if requested:
        missing = [name for name in requested if name not in available]
        if missing:
            raise ValueError(
                f"Requested ONNX Runtime providers are unavailable: {missing}. "
                f"Available providers: {available}"
            )
        return requested
    # No explicit request: default to the CPU provider when it is present.
    if "CPUExecutionProvider" in available:
        return ["CPUExecutionProvider"]
    return available
|
|
|
|
class ONNXStreamingSession:
    """Stateful streaming front end for one audio stream on an ONNX LS-EEND engine.

    Audio chunks are turned into feature frames by a ``StreamingFeatureExtractor``
    and fed to the model one frame at a time. The model is only asked to decode
    once ``engine.layout.conv_delay`` frames have been ingested, so the session
    tracks ingested versus emitted frames and produces the outstanding frames by
    running decode-only steps on an all-zero frame (for previews and for the
    final flush).
    """

    def __init__(self, engine: "ONNXLSEENDInferenceEngine", input_sample_rate: int) -> None:
        """Create a session bound to ``engine``.

        Args:
            engine: Engine that owns the ONNX session, config and metadata.
            input_sample_rate: Sample rate of the audio that will be pushed.

        Raises:
            ValueError: If ``input_sample_rate`` differs from the engine's
                target sample rate.
        """
        self.engine = engine
        self.input_sample_rate = int(input_sample_rate)
        if self.input_sample_rate != self.engine.target_sample_rate:
            raise ValueError(
                "Stateful LS-EEND streaming expects audio at "
                f"{self.engine.target_sample_rate} Hz, got {self.input_sample_rate} Hz."
            )
        self.feature_extractor = StreamingFeatureExtractor(self.engine.config)
        self.state = _initial_state_tensors(self.engine.metadata, dtype=np.float32)
        # Placeholder input used for decode-only steps that ingest nothing.
        self.zero_frame = np.zeros((1, 1, self.engine.layout.input_dim), dtype=np.float32)
        self.total_input_samples = 0
        self.total_feature_frames = 0
        self.emitted_frames = 0
        self.full_logit_chunks: list[np.ndarray] = []
        self.finalized = False

    def push_audio(self, chunk: np.ndarray) -> StreamingUpdate | None:
        """Feed one chunk of audio and return the resulting update.

        Args:
            chunk: Audio samples; passed through ``ensure_mono`` first.

        Returns:
            A ``StreamingUpdate`` with newly committed frames and a preview of
            the still-pending frames, or ``None`` when the chunk is empty or
            produced neither committed nor preview frames.

        Raises:
            RuntimeError: If the session has already been finalized.
        """
        if self.finalized:
            raise RuntimeError("Streaming session is already finalized.")
        chunk = ensure_mono(chunk)
        if chunk.size == 0:
            return None
        self.total_input_samples += len(chunk)
        features = self.feature_extractor.push_audio(chunk)
        committed = self._ingest_features(features)
        return self._build_update(committed, include_preview=True)

    def finalize(self) -> StreamingUpdate | None:
        """Flush the feature extractor and the delayed tail, then close the session.

        Returns:
            A ``StreamingUpdate`` holding every frame committed by this call
            (no preview), or ``None`` if the session was already finalized or
            nothing was produced.
        """
        if self.finalized:
            return None
        features = self.feature_extractor.finalize()
        committed = self._ingest_features(features)
        pending = self.total_feature_frames - self.emitted_frames
        tail = self._flush_tail(self.state, pending) if pending > 0 else self._empty_full_logits()
        merged = committed if tail.size == 0 else np.concatenate([committed, tail], axis=0)
        # The flushed tail is committed output too. Count it before building the
        # update; otherwise start_frame lands `pending` frames too early and
        # total_emitted_frames under-reports the stream length.
        self.emitted_frames += tail.shape[0]
        self.finalized = True
        return self._build_update(merged, include_preview=False)

    def snapshot(self) -> InferenceResult:
        """Return everything committed so far as a single ``InferenceResult``.

        Probabilities are the logistic sigmoid of the logits. ``logits`` and
        ``probabilities`` are the full arrays without their first and last
        columns.
        """
        if self.full_logit_chunks:
            full_logits = np.concatenate(self.full_logit_chunks, axis=0)
        else:
            full_logits = self._empty_full_logits()
        full_probabilities = self._sigmoid(full_logits)
        return InferenceResult(
            logits=full_logits[:, 1:-1],
            probabilities=full_probabilities[:, 1:-1],
            full_logits=full_logits,
            full_probabilities=full_probabilities,
            frame_hz=self.engine.model_frame_hz,
            duration_seconds=self._duration_seconds(),
        )

    def _empty_full_logits(self) -> np.ndarray:
        """Return a float32 array with zero rows and one column per full-logit slot."""
        return np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32)

    @staticmethod
    def _sigmoid(logits: np.ndarray) -> np.ndarray:
        """Apply the logistic function element-wise."""
        return 1.0 / (1.0 + np.exp(-logits))

    def _duration_seconds(self) -> float:
        """Return the pushed audio length in seconds (guards a zero sample rate)."""
        return float(self.total_input_samples / max(self.input_sample_rate, 1))

    def _ingest_features(self, features: np.ndarray) -> np.ndarray:
        """Run the model on each feature frame and collect the decoded logits.

        Updates ``self.state``, ``self.total_feature_frames`` and
        ``self.emitted_frames``. Frames ingested before ``conv_delay`` frames
        have been seen advance the state but produce no output row.

        Args:
            features: Feature frames, iterated along the first axis.

        Returns:
            Full logits with one row per decoded frame (possibly zero rows).
        """
        if features.size == 0:
            return self._empty_full_logits()
        outputs: list[np.ndarray] = []
        conv_delay = self.engine.layout.conv_delay
        for frame in features:
            decoding = self.total_feature_frames >= conv_delay
            prediction = self.engine._predict_step(
                frame=frame.reshape(1, 1, -1).astype(np.float32, copy=False),
                state=self.state,
                ingest=1.0,
                decode=1.0 if decoding else 0.0,
            )
            self.state = self.engine._next_state(prediction)
            self.total_feature_frames += 1
            if decoding:
                outputs.append(prediction["full_logits"].reshape(1, -1).astype(np.float32, copy=False))
                self.emitted_frames += 1
        if outputs:
            return np.concatenate(outputs, axis=0)
        return self._empty_full_logits()

    def _flush_tail(self, state: dict[str, np.ndarray], pending: int) -> np.ndarray:
        """Run ``pending`` decode-only steps on the zero frame.

        ``state`` is only rebound locally; neither ``self.state`` nor any
        frame counter is reassigned here.

        Args:
            state: State tensors to start from.
            pending: Number of decode-only steps to run.

        Returns:
            Full logits with one row per step (zero rows if ``pending <= 0``).
        """
        outputs: list[np.ndarray] = []
        for _ in range(pending):
            prediction = self.engine._predict_step(
                frame=self.zero_frame,
                state=state,
                ingest=0.0,
                decode=1.0,
            )
            state = self.engine._next_state(prediction)
            outputs.append(prediction["full_logits"].reshape(1, -1).astype(np.float32, copy=False))
        if outputs:
            return np.concatenate(outputs, axis=0)
        return self._empty_full_logits()

    def _build_update(self, committed_full_logits: np.ndarray, include_preview: bool) -> StreamingUpdate | None:
        """Record newly committed frames and package them as a ``StreamingUpdate``.

        ``self.emitted_frames`` must already include the rows of
        ``committed_full_logits``; the update's ``start_frame`` is derived by
        subtracting that row count.

        Args:
            committed_full_logits: Full logits of the frames committed by the
                caller; appended to ``self.full_logit_chunks`` when non-empty.
            include_preview: When true, the pending frames are decoded from a
                copy of the current state and returned as a preview.

        Returns:
            The update, or ``None`` when there are neither committed nor
            preview frames.
        """
        start_frame = self.emitted_frames - committed_full_logits.shape[0]
        if committed_full_logits.size > 0:
            self.full_logit_chunks.append(committed_full_logits.astype(np.float32, copy=False))
            committed_full_probabilities = self._sigmoid(committed_full_logits)
            committed_logits = committed_full_logits[:, 1:-1]
            committed_probabilities = committed_full_probabilities[:, 1:-1]
        else:
            committed_logits = np.zeros((0, max(self.engine.decode_max_nspks - 2, 0)), dtype=np.float32)
            committed_probabilities = committed_logits.copy()

        if include_preview:
            pending = self.total_feature_frames - self.emitted_frames
            # Preview from copies of the state tensors; self.state is left untouched.
            preview_state = {key: value.copy() for key, value in self.state.items()}
            preview_full_logits = self._flush_tail(preview_state, pending)
        else:
            preview_full_logits = self._empty_full_logits()

        if preview_full_logits.size > 0:
            preview_full_probabilities = self._sigmoid(preview_full_logits)
            preview_logits = preview_full_logits[:, 1:-1]
            preview_probabilities = preview_full_probabilities[:, 1:-1]
        else:
            preview_logits = np.zeros((0, committed_logits.shape[1]), dtype=np.float32)
            preview_probabilities = preview_logits.copy()

        if committed_logits.size == 0 and preview_logits.size == 0:
            return None

        return StreamingUpdate(
            start_frame=start_frame,
            logits=committed_logits,
            probabilities=committed_probabilities,
            preview_start_frame=self.emitted_frames,
            preview_logits=preview_logits,
            preview_probabilities=preview_probabilities,
            frame_hz=self.engine.model_frame_hz,
            duration_seconds=self._duration_seconds(),
            total_emitted_frames=self.emitted_frames,
        )
|
|
|
|
class ONNXLSEENDInferenceEngine:
    """LS-EEND inference backed by an ONNX Runtime session.

    The engine loads the model and its sidecar metadata JSON, and offers
    offline inference (``infer_audio``), simulated chunked streaming over a
    file (``simulate_streaming_file``) and raw streaming sessions
    (``create_session``).
    """

    def __init__(
        self,
        onnx_model_path: Path,
        providers: list[str] | None = None,
    ) -> None:
        """Load metadata and open the ONNX Runtime session.

        Args:
            onnx_model_path: Path of the ONNX model; the metadata JSON is
                looked up next to it.
            providers: Execution providers to use, or ``None`` to pick a
                default.

        Raises:
            FileNotFoundError: If the metadata JSON is missing.
            ValueError: If a requested provider is unavailable.
        """
        self.onnx_model_path = Path(onnx_model_path)
        self.metadata = _load_metadata(self.onnx_model_path)
        self.config = config_from_metadata(self.metadata)
        self.layout = _layout_from_metadata(self.metadata)
        self.decode_max_nspks = self.layout.max_nspks
        self.target_sample_rate = int(self.metadata["sample_rate"])
        self.model_frame_hz = frame_hz(self.config)

        # Latency in samples: n_fft // 2, plus context_recp hops, plus
        # conv_delay * subsampling hops.
        meta = self.metadata
        half_window = int(meta["n_fft"]) // 2
        context_frames = int(meta["context_recp"])
        hop = int(meta["hop_length"])
        delay_frames = int(meta["conv_delay"])
        subsampling = int(meta["subsampling"])
        latency_samples = half_window + context_frames * hop + delay_frames * subsampling * hop
        self.streaming_latency_seconds = latency_samples / self.target_sample_rate

        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.providers = _resolve_providers(providers)
        self.session = ort.InferenceSession(
            str(self.onnx_model_path),
            sess_options=options,
            providers=self.providers,
        )
        self.output_names = [node.name for node in self.session.get_outputs()]

    def create_session(self, input_sample_rate: int) -> ONNXStreamingSession:
        """Return a new streaming session bound to this engine."""
        return ONNXStreamingSession(self, input_sample_rate)

    def infer_audio(self, audio: np.ndarray, sample_rate: int) -> InferenceResult:
        """Run offline inference over a complete audio buffer.

        Features are extracted in one pass, pushed through a fresh session
        frame by frame, and the delayed tail is flushed at the end.

        Args:
            audio: Audio samples.
            sample_rate: Sample rate of ``audio``.

        Returns:
            The session snapshot covering all decoded frames.
        """
        features = extract_features(audio, sample_rate, self.config)
        session = self.create_session(self.target_sample_rate)

        # Express the input length in samples at the target rate.
        if sample_rate == self.target_sample_rate:
            session.total_input_samples = len(audio)
        else:
            rate_ratio = self.target_sample_rate / max(sample_rate, 1)
            session.total_input_samples = int(round(len(audio) * rate_ratio))

        decoded = session._ingest_features(features)
        remaining = session.total_feature_frames - session.emitted_frames
        if remaining > 0:
            flushed = session._flush_tail(session.state, remaining)
        else:
            flushed = np.zeros((0, self.decode_max_nspks), dtype=np.float32)

        if flushed.size == 0:
            full_logits = decoded
        else:
            full_logits = np.concatenate([decoded, flushed], axis=0)

        if full_logits.size > 0:
            session.full_logit_chunks = [full_logits.astype(np.float32, copy=False)]
            session.emitted_frames = full_logits.shape[0]
        return session.snapshot()

    def simulate_streaming_file(
        self,
        audio_path: Path,
        chunk_seconds: float,
    ) -> tuple[InferenceResult, list[dict]]:
        """Stream an audio file through a session in fixed-size chunks.

        The audio is resampled to the target rate when needed, pushed chunk by
        chunk, and the session is finalized at the end.

        Args:
            audio_path: File to load.
            chunk_seconds: Chunk length in seconds (at least one sample is used).

        Returns:
            The final snapshot and one progress dict per pushed chunk, plus a
            trailing entry marked ``"flush"`` when finalizing yields an update.

        Raises:
            ValueError: If no frames were produced for the file.
        """
        audio, sample_rate = load_audio(audio_path)
        if sample_rate != self.target_sample_rate:
            resampled = librosa.resample(audio, orig_sr=sample_rate, target_sr=self.target_sample_rate)
            audio = resampled.astype(np.float32, copy=False)
            sample_rate = self.target_sample_rate

        session = self.create_session(sample_rate)
        chunk_size = max(1, int(round(chunk_seconds * sample_rate)))
        total_samples = len(audio)

        updates = []
        for chunk_index, start in enumerate(range(0, total_samples, chunk_size), start=1):
            stop = min(total_samples, start + chunk_size)
            update = session.push_audio(audio[start:stop])
            frames_now = 0 if update is None else update.probabilities.shape[0]
            entry = {
                "chunk_index": chunk_index,
                "buffer_seconds": round(stop / sample_rate, 3),
                "num_frames_emitted": int(frames_now),
                "total_frames_emitted": int(session.emitted_frames),
            }
            updates.append(entry)

        final_update = session.finalize()
        if final_update is not None:
            flush_entry = {
                "chunk_index": len(updates) + 1,
                "buffer_seconds": round(total_samples / sample_rate, 3),
                "num_frames_emitted": int(final_update.probabilities.shape[0]),
                "total_frames_emitted": int(session.emitted_frames),
                "flush": True,
            }
            updates.append(flush_entry)

        result = session.snapshot()
        if result.full_probabilities.shape[0] == 0:
            raise ValueError(f"No audio found in {audio_path}")
        return result, updates

    def _predict_step(
        self,
        frame: np.ndarray,
        state: dict[str, np.ndarray],
        ingest: float,
        decode: float,
    ) -> dict[str, np.ndarray]:
        """Run one model step and return its outputs keyed by output name.

        Args:
            frame: Input frame tensor.
            state: Current state tensors, keyed by model input name.
            ingest: Flag passed to the model as a one-element float32 array.
            decode: Flag passed to the model as a one-element float32 array.

        Returns:
            A dict mapping every model output name to its value.
        """
        feeds = {
            "frame": frame,
            "enc_ret_kv": state["enc_ret_kv"],
            "enc_ret_scale": state["enc_ret_scale"],
            "enc_conv_cache": state["enc_conv_cache"],
            "dec_ret_kv": state["dec_ret_kv"],
            "dec_ret_scale": state["dec_ret_scale"],
            "top_buffer": state["top_buffer"],
            "ingest": np.array([ingest], dtype=np.float32),
            "decode": np.array([decode], dtype=np.float32),
        }
        values = self.session.run(self.output_names, feeds)
        return dict(zip(self.output_names, values))

    @staticmethod
    def _next_state(prediction: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Extract the next-step state tensors from a prediction.

        Each ``*_out`` output is stored under the matching state input name;
        any other entry in ``prediction`` is ignored.
        """
        renames = (
            ("enc_ret_kv", "enc_ret_kv_out"),
            ("enc_ret_scale", "enc_ret_scale_out"),
            ("enc_conv_cache", "enc_conv_cache_out"),
            ("dec_ret_kv", "dec_ret_kv_out"),
            ("dec_ret_scale", "dec_ret_scale_out"),
            ("top_buffer", "top_buffer_out"),
        )
        return {state_key: prediction[output_key] for state_key, output_key in renames}
|
|