# LS-EEND-ONNX / example / ls_eend_onnx_runtime.py
# Last commit by GradientDescent2718: "Fixed microphone demo" (cc40a1e, verified)
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
import librosa
import numpy as np
import onnxruntime as ort
from ls_eend_common import (
InferenceResult,
config_from_metadata,
ensure_mono,
extract_features,
frame_hz,
load_audio,
)
from ls_eend_streaming_common import StreamingFeatureExtractor, StreamingUpdate
@dataclass(frozen=True)
class ONNXStateLayout:
    """Immutable set of integer model dimensions read from the ONNX metadata JSON.

    Instances are produced by ``_layout_from_metadata``; every field is the
    ``int`` conversion of the metadata entry with the same name.
    """

    # Width of one feature frame; the streaming session's dummy flush frame is
    # shaped (1, 1, input_dim).
    input_dim: int
    # Output widths as recorded by the exporter (not consumed in this module).
    full_output_dim: int
    real_output_dim: int
    # Number of initial frames ingested with decoding switched off.
    conv_delay: int
    # Column count of the model's ``full_logits`` output.
    max_nspks: int
def _load_metadata(onnx_model_path: Path) -> dict:
metadata_path = onnx_model_path.with_suffix(".json")
if not metadata_path.exists():
raise FileNotFoundError(f"Missing ONNX metadata JSON next to model: {metadata_path}")
with open(metadata_path, "r", encoding="utf-8") as handle:
return json.load(handle)
def _layout_from_metadata(metadata: dict) -> ONNXStateLayout:
    """Build an ``ONNXStateLayout`` by ``int``-converting the matching metadata keys.

    Raises:
        KeyError: for the first of the required keys missing from *metadata*.
    """
    # Same order as the dataclass fields, so a missing key is reported in that order.
    field_names = (
        "input_dim",
        "full_output_dim",
        "real_output_dim",
        "conv_delay",
        "max_nspks",
    )
    return ONNXStateLayout(**{name: int(metadata[name]) for name in field_names})
def _initial_state_tensors(metadata: dict, dtype: np.dtype = np.float32) -> dict[str, np.ndarray]:
return {
key: np.zeros(tuple(shape), dtype=dtype)
for key, shape in metadata["state_shapes"].items()
}
def _resolve_providers(requested: list[str] | None) -> list[str]:
    """Pick the ONNX Runtime execution providers to use.

    An explicit, non-empty *requested* list is returned unchanged after checking
    that every entry is available. With ``None`` or an empty list, the CPU
    provider is chosen when present, otherwise every available provider.

    Raises:
        ValueError: if any requested provider is not available.
    """
    available = ort.get_available_providers()
    if requested:
        unavailable = [name for name in requested if name not in available]
        if unavailable:
            raise ValueError(
                f"Requested ONNX Runtime providers are unavailable: {unavailable}. "
                f"Available providers: {available}"
            )
        return requested
    return ["CPUExecutionProvider"] if "CPUExecutionProvider" in available else available
class ONNXStreamingSession:
    """Stateful, chunk-by-chunk LS-EEND decoding on top of an ``ONNXLSEENDInferenceEngine``.

    Audio is converted to feature frames by a ``StreamingFeatureExtractor``, and each
    frame is run through the engine's single-step ONNX model. Decoding is disabled for
    the first ``layout.conv_delay`` frames, so committed output lags the ingested
    frames by up to that many frames.

    Each ``push_audio`` call also returns a *preview* of those pending frames, computed
    by flushing a copy of the recurrent state. ``finalize`` flushes the real state and
    commits them.

    ``full_*`` arrays hold every model output column. The plain ``logits`` /
    ``probabilities`` drop the first and last column (presumably non-speaker tracks;
    TODO confirm against the model export).
    """

    def __init__(self, engine: "ONNXLSEENDInferenceEngine", input_sample_rate: int) -> None:
        """Create a fresh session bound to *engine*.

        Raises:
            ValueError: if *input_sample_rate* differs from the engine's target rate;
                this session performs no resampling.
        """
        self.engine = engine
        self.input_sample_rate = int(input_sample_rate)
        if self.input_sample_rate != self.engine.target_sample_rate:
            raise ValueError(
                "Stateful LS-EEND streaming expects audio at "
                f"{self.engine.target_sample_rate} Hz, got {self.input_sample_rate} Hz."
            )
        self.feature_extractor = StreamingFeatureExtractor(self.engine.config)
        # Recurrent model state; replaced by the step model's outputs after every frame.
        self.state = _initial_state_tensors(self.engine.metadata, dtype=np.float32)
        # Dummy input used while flushing (ingest=0), shaped (1, 1, input_dim).
        self.zero_frame = np.zeros((1, 1, self.engine.layout.input_dim), dtype=np.float32)
        self.total_input_samples = 0  # audio samples received so far
        self.total_feature_frames = 0  # feature frames run through the model
        self.emitted_frames = 0  # committed output frames, including the flushed tail
        self.full_logit_chunks: list[np.ndarray] = []  # committed full_logits, in order
        self.finalized = False

    @staticmethod
    def _sigmoid(logits: np.ndarray) -> np.ndarray:
        """Elementwise logistic function.

        For large negative logits ``exp`` overflows to ``inf``. The quotient is still
        the correct 0.0, so the overflow warning is silenced rather than changing
        the numerics.
        """
        with np.errstate(over="ignore"):
            return 1.0 / (1.0 + np.exp(-logits))

    def _empty_full_logits(self) -> np.ndarray:
        """Return a zero-row ``full_logits`` array with the model's column count."""
        return np.zeros((0, self.engine.decode_max_nspks), dtype=np.float32)

    def push_audio(self, chunk: np.ndarray) -> StreamingUpdate | None:
        """Feed one audio chunk and return the newly committed frames plus a preview.

        Returns ``None`` for an empty chunk, or when there is neither committed nor
        preview output.

        Raises:
            RuntimeError: if the session has already been finalized.
        """
        if self.finalized:
            raise RuntimeError("Streaming session is already finalized.")
        chunk = ensure_mono(chunk)
        if chunk.size == 0:
            return None
        self.total_input_samples += len(chunk)
        features = self.feature_extractor.push_audio(chunk)
        committed = self._ingest_features(features)
        return self._build_update(committed, include_preview=True)

    def finalize(self) -> StreamingUpdate | None:
        """Flush remaining audio and pending frames; return the final update.

        Subsequent calls return ``None``.
        """
        if self.finalized:
            return None
        features = self.feature_extractor.finalize()
        committed = self._ingest_features(features)
        pending = self.total_feature_frames - self.emitted_frames
        tail = self._flush_tail(self.state, pending) if pending > 0 else self._empty_full_logits()
        if tail.size == 0:
            merged = committed
        else:
            merged = np.concatenate([committed, tail], axis=0)
            # The flushed frames are emitted too. Counting them keeps
            # _build_update's start_frame / total_emitted_frames correct and
            # consistent with snapshot() (and with infer_audio's bookkeeping).
            self.emitted_frames += tail.shape[0]
        self.finalized = True
        return self._build_update(merged, include_preview=False)

    def snapshot(self) -> InferenceResult:
        """Return every committed frame so far as an ``InferenceResult``."""
        if self.full_logit_chunks:
            full_logits = np.concatenate(self.full_logit_chunks, axis=0)
        else:
            full_logits = self._empty_full_logits()
        full_probabilities = self._sigmoid(full_logits)
        return InferenceResult(
            logits=full_logits[:, 1:-1],
            probabilities=full_probabilities[:, 1:-1],
            full_logits=full_logits,
            full_probabilities=full_probabilities,
            frame_hz=self.engine.model_frame_hz,
            duration_seconds=float(self.total_input_samples / max(self.input_sample_rate, 1)),
        )

    def _ingest_features(self, features: np.ndarray) -> np.ndarray:
        """Run feature frames through the model and return the decoded ``full_logits``.

        Every frame advances ``self.state``. Decoding is switched off until
        ``conv_delay`` frames have been ingested, so early frames yield no rows.
        """
        if features.size == 0:
            return self._empty_full_logits()
        conv_delay = self.engine.layout.conv_delay
        outputs: list[np.ndarray] = []
        for frame in features:
            should_decode = 1.0 if self.total_feature_frames >= conv_delay else 0.0
            prediction = self.engine._predict_step(
                frame=frame.reshape(1, 1, -1).astype(np.float32, copy=False),
                state=self.state,
                ingest=1.0,
                decode=should_decode,
            )
            self.state = self.engine._next_state(prediction)
            self.total_feature_frames += 1
            if should_decode:
                outputs.append(prediction["full_logits"].reshape(1, -1).astype(np.float32, copy=False))
                self.emitted_frames += 1
        return np.concatenate(outputs, axis=0) if outputs else self._empty_full_logits()

    def _flush_tail(self, state: dict[str, np.ndarray], pending: int) -> np.ndarray:
        """Decode *pending* frames from *state* using zero input with ingest disabled.

        The advanced state is kept local; the caller's ``state`` binding is not updated.
        """
        outputs: list[np.ndarray] = []
        for _ in range(pending):
            prediction = self.engine._predict_step(
                frame=self.zero_frame,
                state=state,
                ingest=0.0,
                decode=1.0,
            )
            state = self.engine._next_state(prediction)
            outputs.append(prediction["full_logits"].reshape(1, -1).astype(np.float32, copy=False))
        return np.concatenate(outputs, axis=0) if outputs else self._empty_full_logits()

    def _build_update(self, committed_full_logits: np.ndarray, include_preview: bool) -> StreamingUpdate | None:
        """Record committed frames and package them, plus an optional preview, as an update.

        ``committed_full_logits`` must be the rows most recently counted in
        ``self.emitted_frames``; their start index is derived from that counter.
        Returns ``None`` when there is neither committed nor preview output.
        """
        start_frame = self.emitted_frames - committed_full_logits.shape[0]
        if committed_full_logits.size > 0:
            self.full_logit_chunks.append(committed_full_logits.astype(np.float32, copy=False))
            committed_full_probabilities = self._sigmoid(committed_full_logits)
            committed_logits = committed_full_logits[:, 1:-1]
            committed_probabilities = committed_full_probabilities[:, 1:-1]
        else:
            committed_logits = np.zeros((0, max(self.engine.decode_max_nspks - 2, 0)), dtype=np.float32)
            committed_probabilities = committed_logits.copy()

        if include_preview:
            # Decode the pending frames on a copy of the state, as finalize() would
            # if the stream ended now; self.state is left untouched.
            pending = self.total_feature_frames - self.emitted_frames
            preview_state = {key: value.copy() for key, value in self.state.items()}
            preview_full_logits = self._flush_tail(preview_state, pending)
        else:
            preview_full_logits = self._empty_full_logits()

        if preview_full_logits.size > 0:
            preview_full_probabilities = self._sigmoid(preview_full_logits)
            preview_logits = preview_full_logits[:, 1:-1]
            preview_probabilities = preview_full_probabilities[:, 1:-1]
        else:
            preview_logits = np.zeros((0, committed_logits.shape[1]), dtype=np.float32)
            preview_probabilities = preview_logits.copy()

        if committed_logits.size == 0 and preview_logits.size == 0:
            return None
        return StreamingUpdate(
            start_frame=start_frame,
            logits=committed_logits,
            probabilities=committed_probabilities,
            preview_start_frame=self.emitted_frames,
            preview_logits=preview_logits,
            preview_probabilities=preview_probabilities,
            frame_hz=self.engine.model_frame_hz,
            duration_seconds=float(self.total_input_samples / max(self.input_sample_rate, 1)),
            total_emitted_frames=self.emitted_frames,
        )
class ONNXLSEENDInferenceEngine:
    """Load an exported LS-EEND step model and run it offline or as a simulated stream.

    The ONNX graph processes one feature frame per call and threads its recurrent
    state explicitly: each state input listed in ``_STATE_IO`` is fed in, and its
    updated value is read back from the paired ``*_out`` output. Model dimensions
    and feature settings come from the JSON metadata stored next to the model file.
    """

    # (graph input, graph output carrying its updated value) for every recurrent
    # state tensor, in the order they are fed to the session.
    _STATE_IO = (
        ("enc_ret_kv", "enc_ret_kv_out"),
        ("enc_ret_scale", "enc_ret_scale_out"),
        ("enc_conv_cache", "enc_conv_cache_out"),
        ("dec_ret_kv", "dec_ret_kv_out"),
        ("dec_ret_scale", "dec_ret_scale_out"),
        ("top_buffer", "top_buffer_out"),
    )

    def __init__(
        self,
        onnx_model_path: Path,
        providers: list[str] | None = None,
    ) -> None:
        """Read the model metadata and open an ONNX Runtime session.

        Args:
            onnx_model_path: path to the ``.onnx`` file; the metadata JSON with the
                same stem must sit beside it.
            providers: execution providers to request; ``None`` or empty selects
                the CPU provider when available (see ``_resolve_providers``).
        """
        self.onnx_model_path = Path(onnx_model_path)
        metadata = _load_metadata(self.onnx_model_path)
        self.metadata = metadata
        self.config = config_from_metadata(metadata)
        self.layout = _layout_from_metadata(metadata)
        self.decode_max_nspks = self.layout.max_nspks
        self.target_sample_rate = int(metadata["sample_rate"])
        self.model_frame_hz = frame_hz(self.config)

        # Streaming latency in samples:
        #   n_fft // 2 + context_recp * hop_length + conv_delay * subsampling * hop_length
        half_window = int(metadata["n_fft"]) // 2
        context_frames = int(metadata["context_recp"])
        hop = int(metadata["hop_length"])
        delay_frames = int(metadata["conv_delay"])
        subsampling = int(metadata["subsampling"])
        latency_samples = half_window + context_frames * hop + delay_frames * subsampling * hop
        self.streaming_latency_seconds = latency_samples / self.target_sample_rate

        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.providers = _resolve_providers(providers)
        self.session = ort.InferenceSession(
            str(self.onnx_model_path),
            sess_options=options,
            providers=self.providers,
        )
        self.output_names = [meta.name for meta in self.session.get_outputs()]

    def create_session(self, input_sample_rate: int) -> ONNXStreamingSession:
        """Start a new streaming session; *input_sample_rate* must equal the target rate."""
        return ONNXStreamingSession(self, input_sample_rate)

    def infer_audio(self, audio: np.ndarray, sample_rate: int) -> InferenceResult:
        """Run a whole clip through the step model and return all frames at once.

        Features are extracted with the clip's own *sample_rate*. The reported
        duration uses the sample count rescaled to the target rate.
        """
        features = extract_features(audio, sample_rate, self.config)
        stream = self.create_session(self.target_sample_rate)
        if sample_rate == self.target_sample_rate:
            stream.total_input_samples = len(audio)
        else:
            rate_ratio = self.target_sample_rate / max(sample_rate, 1)
            stream.total_input_samples = int(round(len(audio) * rate_ratio))

        committed = stream._ingest_features(features)
        remaining = stream.total_feature_frames - stream.emitted_frames
        if remaining > 0:
            tail = stream._flush_tail(stream.state, remaining)
        else:
            tail = np.zeros((0, self.decode_max_nspks), dtype=np.float32)
        full_logits = np.concatenate([committed, tail], axis=0) if tail.size else committed

        if full_logits.size:
            stream.full_logit_chunks = [full_logits.astype(np.float32, copy=False)]
        stream.emitted_frames = full_logits.shape[0]
        return stream.snapshot()

    def simulate_streaming_file(
        self,
        audio_path: Path,
        chunk_seconds: float,
    ) -> tuple[InferenceResult, list[dict]]:
        """Replay an audio file through a streaming session in fixed-size chunks.

        The audio is resampled to the target rate with librosa when needed.

        Returns:
            The final snapshot, plus one progress dict per pushed chunk. An extra
            entry with ``"flush": True`` is appended when ``finalize`` yields output.

        Raises:
            ValueError: if no output frames were produced.
        """
        audio, sample_rate = load_audio(audio_path)
        if sample_rate != self.target_sample_rate:
            audio = librosa.resample(
                audio, orig_sr=sample_rate, target_sr=self.target_sample_rate
            ).astype(np.float32, copy=False)
            sample_rate = self.target_sample_rate

        stream = self.create_session(sample_rate)
        step = max(1, int(round(chunk_seconds * sample_rate)))
        num_samples = len(audio)
        progress: list[dict] = []
        for chunk_index, begin in enumerate(range(0, num_samples, step), start=1):
            end = min(num_samples, begin + step)
            update = stream.push_audio(audio[begin:end])
            new_frames = 0 if update is None else update.probabilities.shape[0]
            progress.append(
                {
                    "chunk_index": chunk_index,
                    "buffer_seconds": round(end / sample_rate, 3),
                    "num_frames_emitted": int(new_frames),
                    "total_frames_emitted": int(stream.emitted_frames),
                }
            )

        final_update = stream.finalize()
        if final_update is not None:
            progress.append(
                {
                    "chunk_index": len(progress) + 1,
                    "buffer_seconds": round(num_samples / sample_rate, 3),
                    "num_frames_emitted": int(final_update.probabilities.shape[0]),
                    "total_frames_emitted": int(stream.emitted_frames),
                    "flush": True,
                }
            )

        result = stream.snapshot()
        if result.full_probabilities.shape[0] == 0:
            raise ValueError(f"No audio found in {audio_path}")
        return result, progress

    def _predict_step(
        self,
        frame: np.ndarray,
        state: dict[str, np.ndarray],
        ingest: float,
        decode: float,
    ) -> dict[str, np.ndarray]:
        """Run the ONNX graph for one frame and return its outputs keyed by name.

        *ingest* and *decode* are passed to the graph as one-element float32
        arrays under the ``ingest`` / ``decode`` inputs.
        """
        feeds: dict[str, np.ndarray] = {"frame": frame}
        for input_name, _ in self._STATE_IO:
            feeds[input_name] = state[input_name]
        feeds["ingest"] = np.array([ingest], dtype=np.float32)
        feeds["decode"] = np.array([decode], dtype=np.float32)
        outputs = self.session.run(self.output_names, feeds)
        return dict(zip(self.output_names, outputs))

    @staticmethod
    def _next_state(prediction: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Collect the updated state tensors from a step's outputs, keyed by input name."""
        return {
            input_name: prediction[output_name]
            for input_name, output_name in ONNXLSEENDInferenceEngine._STATE_IO
        }