from __future__ import annotations

from dataclasses import dataclass

import librosa
import numpy as np

import feature as ls_feature


@dataclass
class StreamingUpdate:
    """Plain record bundling the arrays and counters of one streaming update.

    NOTE(review): nothing in this module constructs or reads this class, so
    the precise meaning of each field is defined by its producer — confirm
    against the caller before relying on the names alone.
    """

    start_frame: int
    logits: np.ndarray
    probabilities: np.ndarray
    preview_start_frame: int
    preview_logits: np.ndarray
    preview_probabilities: np.ndarray
    frame_hz: float
    duration_seconds: float
    total_emitted_frames: int


class StreamingFeatureExtractor:
    """Incremental feature extraction that matches the repo's offline path.

    Audio is fed in chunks through :meth:`push_audio`, which returns the
    context-spliced, subsampled model-input frames that became computable.
    :meth:`finalize` flushes the frames that need zero padding on the right.

    Internally two sliding buffers are kept: raw audio samples that future
    STFT frames still need, and per-STFT-frame base features that future
    spliced frames still need.  Both are trimmed as frames are consumed.
    """

    def __init__(self, config: dict) -> None:
        """Read the feature settings from *config* and reset streaming state.

        Args:
            config: Nested mapping.  Reads ``config["data"]["feat"]`` keys
                ``sample_rate``, ``win_length``, ``hop_length``, ``n_mels``
                and ``config["data"]`` keys ``context_recp``, ``subsampling``
                and ``feat_type``.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``hop_length`` or ``subsampling`` is not positive.
        """
        data_config = config["data"]
        feat_config = data_config["feat"]

        self.sample_rate = int(feat_config["sample_rate"])
        self.frame_size = int(feat_config["win_length"])
        # FFT size is the window length rounded up to the next power of two.
        self.n_fft = 1 << (self.frame_size - 1).bit_length()
        self.hop_length = int(feat_config["hop_length"])
        self.n_mels = int(feat_config["n_mels"])
        self.context_size = int(data_config["context_recp"])
        self.subsampling = int(data_config["subsampling"])
        self.transform_type = str(data_config["feat_type"])

        # Fail early: a zero hop divides by zero in the frame-count helpers,
        # and a non-positive subsampling factor never advances the emit loop.
        if self.hop_length <= 0:
            raise ValueError(f"hop_length must be positive, got {self.hop_length}")
        if self.subsampling <= 0:
            raise ValueError(f"subsampling must be positive, got {self.subsampling}")

        # Number of virtual zero samples placed before the first real sample.
        self.left_pad = self.n_fft // 2
        self.model_input_dim = (2 * self.context_size + 1) * self.n_mels

        # Sliding audio window; audio_start_sample is the global index of
        # audio_buffer[0].
        self.audio_buffer = np.zeros(0, dtype=np.float32)
        self.audio_start_sample = 0
        self.total_samples = 0

        # Index of the next STFT frame to compute / next model frame to emit.
        self.next_stft_frame = 0
        self.next_model_frame = 0

        # Sliding feature window; base_feature_start is the global STFT frame
        # index of base_feature_buffer[0].
        self.base_feature_start = 0
        self.base_feature_buffer = np.zeros((0, self.n_mels), dtype=np.float32)
        # Running per-bin sum of log-mel values for the cumulative mean.
        self.cumulative_feature_sum = np.zeros(self.n_mels, dtype=np.float64)

        self.mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
        )

    def push_audio(self, chunk: np.ndarray) -> np.ndarray:
        """Append *chunk* to the stream and return newly available frames.

        Only STFT frames whose window lies fully inside the received audio
        are computed here; nothing is padded on the right.

        Args:
            chunk: Audio samples to append (cast to float32).

        Returns:
            A float32 array with one row per emitted model frame; an empty
            ``(0, model_input_dim)`` array when no frame is ready yet.
        """
        if chunk.size == 0:
            return np.zeros((0, self.model_input_dim), dtype=np.float32)

        samples = chunk.astype(np.float32, copy=False)
        self.audio_buffer = np.concatenate([self.audio_buffer, samples], axis=0)
        self.total_samples += len(chunk)

        self._append_stft_frames(self._stable_stft_frame_count(), allow_right_pad=False)
        return self._emit_model_frames(final=False)

    def finalize(self) -> np.ndarray:
        """Flush the stream and return the remaining model frames.

        The received sample count is first truncated to a whole number of
        ``hop_length * subsampling`` blocks; STFT frames still missing up to
        the resulting frame count are computed with zero right padding.

        Returns:
            A float32 array with one row per emitted model frame; an empty
            ``(0, model_input_dim)`` array when nothing is left to emit.
        """
        usable_samples = self._usable_sample_count(self.total_samples)
        total_stft_frames = self._offline_stft_frame_count(usable_samples)
        self._append_stft_frames(
            total_stft_frames,
            allow_right_pad=True,
            effective_total_samples=usable_samples,
        )
        return self._emit_model_frames(final=True, total_stft_frames=total_stft_frames)

    def _usable_sample_count(self, sample_count: int) -> int:
        """Round *sample_count* down to a multiple of hop_length * subsampling."""
        block_size = self.hop_length * self.subsampling
        return (sample_count // block_size) * block_size

    def _stable_stft_frame_count(self) -> int:
        """Return how many STFT frames are computable without right padding."""
        available = self.total_samples - self.left_pad
        if available <= 0:
            return 0
        return max(0, available // self.hop_length + 1)

    def _offline_stft_frame_count(self, usable_samples: int) -> int:
        """Return the total STFT frame count used when finalizing.

        NOTE(review): the value is ``usable_samples // hop_length - 1``;
        presumably this mirrors the offline extractor's frame count — confirm
        against the offline code.
        """
        if usable_samples <= 0:
            return 0
        return max(0, usable_samples // self.hop_length - 1)

    def _model_frame_count(self, total_stft_frames: int) -> int:
        """Return ceil(total_stft_frames / subsampling), or 0 if non-positive."""
        if total_stft_frames <= 0:
            return 0
        return (total_stft_frames + self.subsampling - 1) // self.subsampling

    def _append_stft_frames(
        self,
        target_frame_count: int,
        allow_right_pad: bool,
        effective_total_samples: int | None = None,
    ) -> None:
        """Compute STFT frames up to *target_frame_count* and buffer features.

        Frames ``[next_stft_frame, target_frame_count)`` are computed in one
        batch, transformed, appended to ``base_feature_buffer``, and the
        audio that no later frame needs is dropped.  Does nothing when the
        target has already been reached.

        Args:
            target_frame_count: Exclusive end index of the frames to have.
            allow_right_pad: Whether missing trailing samples are zero-filled.
            effective_total_samples: Overrides ``total_samples`` as the end
                of the usable audio when given.

        Raises:
            RuntimeError: If the STFT yields fewer frames than requested, or
                if required audio has already been dropped.
        """
        if target_frame_count <= self.next_stft_frame:
            return

        frame_start = self.next_stft_frame
        frame_stop = target_frame_count
        segment = self._stft_segment(
            frame_start=frame_start,
            frame_stop=frame_stop,
            allow_right_pad=allow_right_pad,
            effective_total_samples=effective_total_samples,
        )
        # center=False: the segment already carries the explicit padding.
        stft = librosa.stft(
            segment,
            n_fft=self.n_fft,
            win_length=self.frame_size,
            hop_length=self.hop_length,
            center=False,
        ).T

        expected_frames = frame_stop - frame_start
        if stft.shape[0] < expected_frames:
            raise RuntimeError(
                f"Streaming STFT underflow: expected {expected_frames} frames, got {stft.shape[0]}"
            )
        stft = stft[:expected_frames]

        transformed = self._transform_batch(stft, frame_start)
        if self.base_feature_buffer.size == 0:
            self.base_feature_buffer = transformed
        else:
            self.base_feature_buffer = np.concatenate(
                [self.base_feature_buffer, transformed], axis=0
            )

        self.next_stft_frame = frame_stop
        self._drop_consumed_audio()

    def _stft_segment(
        self,
        frame_start: int,
        frame_stop: int,
        allow_right_pad: bool,
        effective_total_samples: int | None,
    ) -> np.ndarray:
        """Return the padded audio covering STFT frames [frame_start, frame_stop).

        Frame ``k`` spans global samples ``[k * hop_length - left_pad,
        k * hop_length - left_pad + n_fft)``.  Positions before sample 0 are
        zero-filled; positions past the end of the audio are zero-filled only
        when *allow_right_pad* is true (otherwise the segment is just shorter).

        Raises:
            RuntimeError: If the first required sample was already dropped
                from ``audio_buffer``.
        """
        if frame_stop <= frame_start:
            return np.zeros(0, dtype=np.float32)

        if effective_total_samples is None:
            total_samples = self.total_samples
        else:
            total_samples = effective_total_samples

        global_start = frame_start * self.hop_length - self.left_pad
        global_stop = (frame_stop - 1) * self.hop_length - self.left_pad + self.n_fft

        prefix = np.zeros(max(0, -global_start), dtype=np.float32)
        if allow_right_pad:
            suffix = np.zeros(max(0, global_stop - total_samples), dtype=np.float32)
        else:
            suffix = np.zeros(0, dtype=np.float32)

        raw_start = max(0, global_start)
        raw_stop = min(total_samples, global_stop)
        if raw_start < self.audio_start_sample:
            raise RuntimeError(
                f"Audio buffer underflow: need sample {raw_start}, buffer starts at {self.audio_start_sample}"
            )

        # Translate global sample indices into audio_buffer offsets.
        local_start = raw_start - self.audio_start_sample
        local_stop = raw_stop - self.audio_start_sample
        core = self.audio_buffer[local_start:local_stop]

        if prefix.size == 0 and suffix.size == 0:
            return core
        return np.concatenate([prefix, core, suffix], axis=0)

    def _transform_batch(self, stft: np.ndarray, frame_start: int) -> np.ndarray:
        """Turn a batch of STFT frames into base features.

        For ``"logmel23_cummn"`` this computes log10 mel power and subtracts
        the running mean over all frames seen so far, carrying the running
        sum across calls.  Any other transform type is delegated to
        ``ls_feature.transform`` on this batch alone.

        Args:
            stft: Complex STFT with one frame per row.
            frame_start: Global index of the first frame in *stft*.

        Returns:
            The transformed features, one row per input frame.
        """
        if self.transform_type != "logmel23_cummn":
            return ls_feature.transform(stft, self.transform_type)

        magnitude = np.abs(stft)
        mel = np.dot(magnitude ** 2, self.mel_basis.T)
        logmel = np.log10(np.maximum(mel, 1e-10))

        # Row i is global frame frame_start + i, so its mean spans
        # frame_start + i + 1 frames.
        counts = np.arange(
            frame_start + 1, frame_start + 1 + len(logmel), dtype=np.float64
        )[:, None]
        cumsum = np.cumsum(logmel, axis=0, dtype=np.float64) + self.cumulative_feature_sum[None, :]
        cummean = cumsum / counts

        # Copy: cumsum[-1] is a view and would keep the whole batch alive.
        self.cumulative_feature_sum = cumsum[-1].copy()
        return (logmel - cummean).astype(np.float32, copy=False)

    def _emit_model_frames(self, final: bool, total_stft_frames: int | None = None) -> np.ndarray:
        """Splice and return every model frame that can be produced now.

        While streaming, a frame is emitted once its full right context has
        been computed.  When *final* is true, frames are emitted up to
        ``ceil(total_stft_frames / subsampling)`` and context positions past
        the last STFT frame are zero-filled.

        Args:
            final: Whether this is the flush at end of stream.
            total_stft_frames: Total STFT frame count; used only when *final*.

        Returns:
            A float32 array with one row per emitted model frame; an empty
            ``(0, model_input_dim)`` array when none was emitted.
        """
        outputs: list[np.ndarray] = []

        if final:
            stft_total = total_stft_frames or 0
            model_total = self._model_frame_count(stft_total)
            max_index = stft_total - 1
            while self.next_model_frame < model_total:
                center_index = self.next_model_frame * self.subsampling
                outputs.append(self._splice_frame(center_index, max_index))
                self.next_model_frame += 1
        else:
            latest_frame = self.next_stft_frame - 1
            while True:
                center_index = self.next_model_frame * self.subsampling
                if center_index + self.context_size > latest_frame:
                    break
                outputs.append(self._splice_frame(center_index, latest_frame))
                self.next_model_frame += 1

        self._drop_consumed_base_features()

        if not outputs:
            return np.zeros((0, self.model_input_dim), dtype=np.float32)
        return np.stack(outputs, axis=0).astype(np.float32, copy=False)

    def _splice_frame(self, center_index: int, max_index: int) -> np.ndarray:
        """Concatenate the features of *center_index* and its context frames.

        Context positions below 0 or above *max_index* contribute a zero
        vector of length ``n_mels``.

        Raises:
            RuntimeError: If a required frame is not in ``base_feature_buffer``.
        """
        first = center_index - self.context_size
        last = center_index + self.context_size

        pieces = []
        for frame_index in range(first, last + 1):
            if frame_index < 0 or frame_index > max_index:
                pieces.append(np.zeros(self.n_mels, dtype=np.float32))
                continue

            local_index = frame_index - self.base_feature_start
            if local_index < 0 or local_index >= self.base_feature_buffer.shape[0]:
                raise RuntimeError(
                    f"Feature buffer underflow: need frame {frame_index}, buffer covers "
                    f"[{self.base_feature_start}, {self.base_feature_start + self.base_feature_buffer.shape[0] - 1}]"
                )
            pieces.append(self.base_feature_buffer[local_index])

        return np.concatenate(pieces, axis=0)

    def _drop_consumed_audio(self) -> None:
        """Discard buffered audio that precedes the next STFT frame's window."""
        keep_from = max(0, self.next_stft_frame * self.hop_length - self.left_pad)
        drop = keep_from - self.audio_start_sample
        if drop <= 0:
            return
        self.audio_buffer = self.audio_buffer[drop:]
        self.audio_start_sample += drop

    def _drop_consumed_base_features(self) -> None:
        """Discard features that precede the next model frame's left context."""
        keep_from = max(0, self.next_model_frame * self.subsampling - self.context_size)
        drop = keep_from - self.base_feature_start
        if drop <= 0:
            return
        self.base_feature_buffer = self.base_feature_buffer[drop:]
        self.base_feature_start += drop