# LS-EEND-ONNX / example/ls_eend_streaming_common.py
# (Hugging Face upload by GradientDescent2718 — "Upload 7 files", commit 17dd01d, verified.)
from __future__ import annotations
import numpy as np
from dataclasses import dataclass
import librosa
import feature as ls_feature
@dataclass
class StreamingUpdate:
    """Result bundle for one streaming diarization step.

    Carries a committed span of outputs plus a provisional "preview" span,
    along with timing metadata.  Field semantics below are inferred from the
    names alone — confirm against the streaming driver that builds these.
    """

    start_frame: int  # global index of the first committed frame in this update
    logits: np.ndarray  # committed raw model outputs for the span
    probabilities: np.ndarray  # committed activity probabilities (presumably sigmoid of logits — verify)
    preview_start_frame: int  # global index of the first provisional frame
    preview_logits: np.ndarray  # provisional outputs that may still be revised
    preview_probabilities: np.ndarray  # provisional probabilities matching preview_logits
    frame_hz: float  # output frame rate, frames per second
    duration_seconds: float  # audio duration covered so far
    total_emitted_frames: int  # cumulative count of committed frames emitted
class StreamingFeatureExtractor:
    """Incremental feature extraction that matches the repo's offline path.

    Audio is pushed in arbitrary-sized chunks.  The extractor keeps a rolling
    raw-audio buffer and a rolling buffer of per-STFT-frame base features, and
    emits spliced, subsampled model-input frames as soon as their full context
    window is determined.  ``finalize()`` flushes the tail with right
    zero-padding so the total output matches the offline feature pipeline
    (equivalence with the offline path is the stated intent — not re-verified
    here).
    """

    def __init__(self, config: dict) -> None:
        # NOTE(review): assumes the repo's config layout
        # (config["data"]["feat"][...], config["data"]["context_recp"], ...)
        # — confirm against the training config schema.
        feat_config = config["data"]["feat"]
        self.sample_rate = int(feat_config["sample_rate"])
        self.frame_size = int(feat_config["win_length"])
        # FFT size rounded up to the next power of two >= win_length.
        self.n_fft = 1 << (self.frame_size - 1).bit_length()
        self.hop_length = int(feat_config["hop_length"])
        self.n_mels = int(feat_config["n_mels"])
        # Number of context frames spliced on each side of the center frame.
        self.context_size = int(config["data"]["context_recp"])
        self.subsampling = int(config["data"]["subsampling"])
        self.transform_type = str(config["data"]["feat_type"])
        # Left zero-padding emulating librosa's center=True while we run
        # stft(center=False) on explicitly constructed segments.
        self.left_pad = self.n_fft // 2
        self.model_input_dim = (2 * self.context_size + 1) * self.n_mels
        # Rolling raw-audio buffer; audio_start_sample is the global sample
        # index of audio_buffer[0], so consumed samples can be dropped.
        self.audio_buffer = np.zeros(0, dtype=np.float32)
        self.audio_start_sample = 0
        self.total_samples = 0
        # Next STFT frame / model frame indices still to be produced.
        self.next_stft_frame = 0
        self.next_model_frame = 0
        # Rolling buffer of per-STFT-frame features (pre-splicing), with
        # base_feature_start the global frame index of its first row.
        self.base_feature_start = 0
        self.base_feature_buffer = np.zeros((0, self.n_mels), dtype=np.float32)
        # Running per-bin sum of log-mel rows, carried across calls for
        # incremental cumulative-mean normalization (float64 for stability).
        self.cumulative_feature_sum = np.zeros(self.n_mels, dtype=np.float64)
        self.mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
        )

    def push_audio(self, chunk: np.ndarray) -> np.ndarray:
        """Append *chunk* (mono float samples) and return newly ready frames.

        Returns an array of shape (n, model_input_dim); n may be 0 when not
        enough audio has accumulated for a new fully-contexted model frame.
        """
        if chunk.size == 0:
            return np.zeros((0, self.model_input_dim), dtype=np.float32)
        self.audio_buffer = np.concatenate([self.audio_buffer, chunk.astype(np.float32, copy=False)], axis=0)
        self.total_samples += len(chunk)
        # Before finalize(), only frames needing no right padding are stable.
        self._append_stft_frames(self._stable_stft_frame_count(), allow_right_pad=False)
        return self._emit_model_frames(final=False)

    def finalize(self) -> np.ndarray:
        """Flush the remaining frames, right-padding with zeros.

        Uses the offline frame count so streaming + finalize produces the
        same number of model frames as a single offline pass.
        """
        # Trim to a whole number of hop*subsampling blocks first.
        usable_samples = self._usable_sample_count(self.total_samples)
        total_stft_frames = self._offline_stft_frame_count(usable_samples)
        self._append_stft_frames(total_stft_frames, allow_right_pad=True, effective_total_samples=usable_samples)
        return self._emit_model_frames(final=True, total_stft_frames=total_stft_frames)

    def _usable_sample_count(self, sample_count: int) -> int:
        # Largest multiple of hop_length * subsampling not exceeding the input.
        block_size = self.hop_length * self.subsampling
        return (sample_count // block_size) * block_size

    def _stable_stft_frame_count(self) -> int:
        """Number of STFT frames computable now without right zero-padding."""
        if self.total_samples <= self.left_pad:
            return 0
        return max(0, (self.total_samples - self.left_pad) // self.hop_length + 1)

    def _offline_stft_frame_count(self, usable_samples: int) -> int:
        # NOTE(review): mirrors the repo's offline frame count
        # (usable // hop - 1) — confirm against the offline feature code.
        if usable_samples <= 0:
            return 0
        return max(0, usable_samples // self.hop_length - 1)

    def _model_frame_count(self, total_stft_frames: int) -> int:
        # Ceiling division: every started subsampling block yields one model frame.
        if total_stft_frames <= 0:
            return 0
        return (total_stft_frames + self.subsampling - 1) // self.subsampling

    def _append_stft_frames(
        self,
        target_frame_count: int,
        allow_right_pad: bool,
        effective_total_samples: int | None = None,
    ) -> None:
        """Compute STFT frames [next_stft_frame, target_frame_count) and append
        their transformed features to base_feature_buffer.

        No-op when the target is not past what has already been produced.
        """
        if target_frame_count <= self.next_stft_frame:
            return
        frame_start = self.next_stft_frame
        frame_stop = target_frame_count
        segment = self._stft_segment(
            frame_start=frame_start,
            frame_stop=frame_stop,
            allow_right_pad=allow_right_pad,
            effective_total_samples=effective_total_samples,
        )
        # center=False: the segment already carries any left/right padding,
        # keeping frame timing identical to the centered offline transform.
        stft = librosa.stft(
            segment,
            n_fft=self.n_fft,
            win_length=self.frame_size,
            hop_length=self.hop_length,
            center=False,
        ).T
        expected_frames = frame_stop - frame_start
        if stft.shape[0] < expected_frames:
            raise RuntimeError(
                f"Streaming STFT underflow: expected {expected_frames} frames, got {stft.shape[0]}"
            )
        # Drop any surplus trailing frames so global frame indices stay aligned.
        stft = stft[:expected_frames]
        transformed = self._transform_batch(stft, frame_start)
        if self.base_feature_buffer.size == 0:
            self.base_feature_buffer = transformed
        else:
            self.base_feature_buffer = np.concatenate([self.base_feature_buffer, transformed], axis=0)
        self.next_stft_frame = frame_stop
        self._drop_consumed_audio()

    def _stft_segment(
        self,
        frame_start: int,
        frame_stop: int,
        allow_right_pad: bool,
        effective_total_samples: int | None,
    ) -> np.ndarray:
        """Return the raw-audio span covering frames [frame_start, frame_stop),
        zero-padded on the left (and on the right only when allowed).

        Raises RuntimeError if needed samples were already dropped from the
        rolling buffer.
        """
        if frame_stop <= frame_start:
            return np.zeros(0, dtype=np.float32)
        total_samples = self.total_samples if effective_total_samples is None else effective_total_samples
        # Global sample range the requested frames need (center=False layout).
        global_start = frame_start * self.hop_length - self.left_pad
        global_stop = (frame_stop - 1) * self.hop_length - self.left_pad + self.n_fft
        prefix = np.zeros(max(0, -global_start), dtype=np.float32)
        suffix = np.zeros(max(0, global_stop - total_samples), dtype=np.float32) if allow_right_pad else np.zeros(0, dtype=np.float32)
        raw_start = max(0, global_start)
        raw_stop = min(total_samples, global_stop)
        if raw_start < self.audio_start_sample:
            raise RuntimeError(
                f"Audio buffer underflow: need sample {raw_start}, buffer starts at {self.audio_start_sample}"
            )
        # Convert global sample indices into rolling-buffer positions.
        local_start = raw_start - self.audio_start_sample
        local_stop = raw_stop - self.audio_start_sample
        core = self.audio_buffer[local_start:local_stop]
        if prefix.size == 0 and suffix.size == 0:
            return core
        return np.concatenate([prefix, core, suffix], axis=0)

    def _transform_batch(self, stft: np.ndarray, frame_start: int) -> np.ndarray:
        """Convert complex STFT rows into model base features.

        For "logmel23_cummn" the cumulative-mean normalization is computed
        incrementally via cumulative_feature_sum, so chunked results match a
        single offline pass.  Other transform types fall back to the repo
        helper — which must then be stateless across calls (TODO confirm).
        """
        if self.transform_type != "logmel23_cummn":
            return ls_feature.transform(stft, self.transform_type)
        magnitude = np.abs(stft)
        # Power spectrogram -> mel energies -> floored log10.
        mel = np.dot(magnitude ** 2, self.mel_basis.T)
        logmel = np.log10(np.maximum(mel, 1e-10))
        # counts[i] = total frames seen up to and including this row (global).
        counts = np.arange(frame_start + 1, frame_start + 1 + len(logmel), dtype=np.float64)[:, None]
        # Within-batch cumulative sum plus the running sum from prior calls.
        cumsum = np.cumsum(logmel, axis=0, dtype=np.float64) + self.cumulative_feature_sum[None, :]
        cummean = cumsum / counts
        # Carry the running total forward for the next batch.
        self.cumulative_feature_sum = cumsum[-1]
        return (logmel - cummean).astype(np.float32, copy=False)

    def _emit_model_frames(self, final: bool, total_stft_frames: int | None = None) -> np.ndarray:
        """Emit all spliced, subsampled model frames that are now determined.

        Streaming mode (final=False) waits until a frame's full right context
        exists; final mode emits up to the offline model-frame count, with
        missing context zero-filled inside _splice_frame.
        """
        outputs: list[np.ndarray] = []
        latest_frame = self.next_stft_frame - 1
        total_model_frames = self._model_frame_count(total_stft_frames or 0) if final else None
        while True:
            # Model frames are centered on every subsampling-th STFT frame.
            center_index = self.next_model_frame * self.subsampling
            if final:
                if total_model_frames is None or self.next_model_frame >= total_model_frames:
                    break
                max_index = (total_stft_frames or 0) - 1
            else:
                # Not enough right context yet — stop and wait for more audio.
                if center_index + self.context_size > latest_frame:
                    break
                max_index = latest_frame
            outputs.append(self._splice_frame(center_index, max_index))
            self.next_model_frame += 1
        self._drop_consumed_base_features()
        if not outputs:
            return np.zeros((0, self.model_input_dim), dtype=np.float32)
        return np.stack(outputs, axis=0).astype(np.float32, copy=False)

    def _splice_frame(self, center_index: int, max_index: int) -> np.ndarray:
        """Concatenate the context window around *center_index* into one
        (2*context_size+1)*n_mels vector, zero-filling positions outside
        [0, max_index]."""
        pieces = []
        for frame_index in range(center_index - self.context_size, center_index + self.context_size + 1):
            if frame_index < 0 or frame_index > max_index:
                pieces.append(np.zeros(self.n_mels, dtype=np.float32))
                continue
            local_index = frame_index - self.base_feature_start
            if local_index < 0 or local_index >= self.base_feature_buffer.shape[0]:
                raise RuntimeError(
                    f"Feature buffer underflow: need frame {frame_index}, buffer covers "
                    f"[{self.base_feature_start}, {self.base_feature_start + self.base_feature_buffer.shape[0] - 1}]"
                )
            pieces.append(self.base_feature_buffer[local_index])
        return np.concatenate(pieces, axis=0)

    def _drop_consumed_audio(self) -> None:
        # Drop every sample before the earliest one the next STFT frame needs.
        keep_from = max(0, self.next_stft_frame * self.hop_length - self.left_pad)
        drop = keep_from - self.audio_start_sample
        if drop <= 0:
            return
        self.audio_buffer = self.audio_buffer[drop:]
        self.audio_start_sample += drop

    def _drop_consumed_base_features(self) -> None:
        # Keep only frames still reachable as left context of the next model frame.
        keep_from = max(0, self.next_model_frame * self.subsampling - self.context_size)
        drop = keep_from - self.base_feature_start
        if drop <= 0:
            return
        self.base_feature_buffer = self.base_feature_buffer[drop:]
        self.base_feature_start += drop