| from __future__ import annotations |
|
|
| import numpy as np |
| from dataclasses import dataclass |
|
|
| import librosa |
|
|
| import feature as ls_feature |
|
|
|
|
@dataclass
class StreamingUpdate:
    """One incremental result bundle from a streaming inference pipeline.

    NOTE(review): this type is never constructed or consumed in this chunk;
    the per-field notes below are inferred from the field names — confirm
    against the producer before relying on them.
    """

    # Absolute frame index where `logits`/`probabilities` start — TODO confirm.
    start_frame: int
    # Stable (committed) model outputs for this update.
    logits: np.ndarray
    probabilities: np.ndarray
    # Presumably a provisional tail that may be revised later — verify with caller.
    preview_start_frame: int
    preview_logits: np.ndarray
    preview_probabilities: np.ndarray
    # Output frame rate in Hz — presumably sample_rate / (hop * subsampling).
    frame_hz: float
    # Seconds of audio processed so far — TODO confirm.
    duration_seconds: float
    # Running count of frames emitted across all updates — TODO confirm.
    total_emitted_frames: int
|
|
|
|
class StreamingFeatureExtractor:
    """Incremental feature extraction that matches the repo's offline path.

    Audio arrives in arbitrary-sized chunks via :meth:`push_audio`; spliced,
    subsampled model-input frames are returned as soon as their full right
    context exists.  :meth:`finalize` zero-pads the signal tail and emits
    the remainder so the total frame count agrees with the offline
    computation (see :meth:`_offline_stft_frame_count`).

    State is tracked with monotonically advancing cursors:

    * ``next_stft_frame``  — first STFT/base-feature frame not yet computed
    * ``next_model_frame`` — first spliced model frame not yet emitted
    * ``audio_start_sample`` / ``base_feature_start`` — absolute index of the
      first element still held in the corresponding trimmed buffer

    which lets both the raw-audio and the per-frame feature buffers be
    trimmed as soon as their prefixes can no longer be referenced.
    """


    def __init__(self, config: dict) -> None:
        """Read extraction parameters and initialize empty streaming state.

        ``config`` is assumed to follow the repo's config schema
        (``config["data"]["feat"]`` etc.) — TODO confirm against the loader.
        """
        feat_config = config["data"]["feat"]
        self.sample_rate = int(feat_config["sample_rate"])
        self.frame_size = int(feat_config["win_length"])
        # FFT size: smallest power of two >= the analysis window length.
        self.n_fft = 1 << (self.frame_size - 1).bit_length()
        self.hop_length = int(feat_config["hop_length"])
        self.n_mels = int(feat_config["n_mels"])
        # Context frames spliced on EACH side of the center frame.
        self.context_size = int(config["data"]["context_recp"])
        self.subsampling = int(config["data"]["subsampling"])
        self.transform_type = str(config["data"]["feat_type"])
        # Zeros conceptually prepended to the signal so frame 0 is centered
        # on sample 0 (emulates librosa's center=True padding of n_fft // 2).
        self.left_pad = self.n_fft // 2
        self.model_input_dim = (2 * self.context_size + 1) * self.n_mels


        # Trimmed raw-audio buffer: audio_buffer[0] is absolute sample
        # `audio_start_sample`; `total_samples` counts all samples pushed.
        self.audio_buffer = np.zeros(0, dtype=np.float32)
        self.audio_start_sample = 0
        self.total_samples = 0


        # Cursors: first STFT frame / model frame not yet produced.
        self.next_stft_frame = 0
        self.next_model_frame = 0


        # Per-frame base features (n_mels each) retained for context
        # splicing; `cumulative_feature_sum` carries the running log-mel sum
        # used by the cumulative-mean normalization in _transform_batch.
        self.base_feature_start = 0
        self.base_feature_buffer = np.zeros((0, self.n_mels), dtype=np.float32)
        self.cumulative_feature_sum = np.zeros(self.n_mels, dtype=np.float64)


        # Mel filterbank, shape (n_mels, 1 + n_fft // 2).
        self.mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
        )


    def push_audio(self, chunk: np.ndarray) -> np.ndarray:
        """Append `chunk` (1-D samples) and return newly ready model frames.

        Returns an array of shape (k, model_input_dim), possibly empty:
        only frames whose full right context has arrived are emitted here;
        the rest wait for more audio or for :meth:`finalize`.
        """
        if chunk.size == 0:
            return np.zeros((0, self.model_input_dim), dtype=np.float32)
        self.audio_buffer = np.concatenate([self.audio_buffer, chunk.astype(np.float32, copy=False)], axis=0)
        self.total_samples += len(chunk)
        # Compute only the STFT frames that need no right zero-padding yet.
        self._append_stft_frames(self._stable_stft_frame_count(), allow_right_pad=False)
        return self._emit_model_frames(final=False)


    def finalize(self) -> np.ndarray:
        """Flush the tail: right-pad with zeros and emit remaining frames.

        The sample count is truncated to whole hop*subsampling blocks so the
        final frame count matches the offline path.  Idempotent: a second
        call finds no new frames and returns an empty array.
        """
        usable_samples = self._usable_sample_count(self.total_samples)
        total_stft_frames = self._offline_stft_frame_count(usable_samples)
        self._append_stft_frames(total_stft_frames, allow_right_pad=True, effective_total_samples=usable_samples)
        return self._emit_model_frames(final=True, total_stft_frames=total_stft_frames)


    def _usable_sample_count(self, sample_count: int) -> int:
        """Largest multiple of hop_length * subsampling <= `sample_count`."""
        block_size = self.hop_length * self.subsampling
        return (sample_count // block_size) * block_size


    def _stable_stft_frame_count(self) -> int:
        """Number of STFT frames computable without right zero-padding.

        Frame i spans samples [i*hop - left_pad, i*hop - left_pad + n_fft);
        since n_fft is a power of two, n_fft - left_pad == left_pad, so the
        last fully covered frame is (total_samples - left_pad) // hop_length.
        """
        if self.total_samples <= self.left_pad:
            return 0
        return max(0, (self.total_samples - self.left_pad) // self.hop_length + 1)


    def _offline_stft_frame_count(self, usable_samples: int) -> int:
        """Total frame count the offline pipeline yields for `usable_samples`.

        NOTE(review): the `usable // hop - 1` formula is taken on faith as
        matching the repo's offline extractor — confirm against ls_feature.
        """
        if usable_samples <= 0:
            return 0
        return max(0, usable_samples // self.hop_length - 1)


    def _model_frame_count(self, total_stft_frames: int) -> int:
        """Model frames for `total_stft_frames`: ceil division by subsampling."""
        if total_stft_frames <= 0:
            return 0
        return (total_stft_frames + self.subsampling - 1) // self.subsampling


    def _append_stft_frames(
        self,
        target_frame_count: int,
        allow_right_pad: bool,
        effective_total_samples: int | None = None,
    ) -> None:
        """Extend the base-feature buffer up to `target_frame_count` frames.

        Args:
            target_frame_count: absolute STFT frame index to stop before
                (exclusive); no-op if already reached.
            allow_right_pad: permit zero-padding past the signal end
                (finalize only).
            effective_total_samples: overrides `total_samples` as the signal
                end when the tail has been truncated to whole blocks.

        Raises:
            RuntimeError: if librosa yields fewer frames than requested.
        """
        if target_frame_count <= self.next_stft_frame:
            return
        frame_start = self.next_stft_frame
        frame_stop = target_frame_count
        segment = self._stft_segment(
            frame_start=frame_start,
            frame_stop=frame_stop,
            allow_right_pad=allow_right_pad,
            effective_total_samples=effective_total_samples,
        )
        # center=False: the segment already carries explicit padding, so
        # frame f of this STFT is absolute frame `frame_start + f`.
        stft = librosa.stft(
            segment,
            n_fft=self.n_fft,
            win_length=self.frame_size,
            hop_length=self.hop_length,
            center=False,
        ).T
        expected_frames = frame_stop - frame_start
        if stft.shape[0] < expected_frames:
            raise RuntimeError(
                f"Streaming STFT underflow: expected {expected_frames} frames, got {stft.shape[0]}"
            )
        # librosa may return one extra trailing frame; keep exactly the span.
        stft = stft[:expected_frames]
        transformed = self._transform_batch(stft, frame_start)
        if self.base_feature_buffer.size == 0:
            self.base_feature_buffer = transformed
        else:
            self.base_feature_buffer = np.concatenate([self.base_feature_buffer, transformed], axis=0)
        self.next_stft_frame = frame_stop
        # Raw audio older than the next frame's window can now be released.
        self._drop_consumed_audio()


    def _stft_segment(
        self,
        frame_start: int,
        frame_stop: int,
        allow_right_pad: bool,
        effective_total_samples: int | None,
    ) -> np.ndarray:
        """Return the exact sample span (with zero padding) for a frame range.

        Covers absolute samples
        [frame_start*hop - left_pad, (frame_stop-1)*hop - left_pad + n_fft);
        zeros substitute for samples left of 0 and — only when
        `allow_right_pad` — for samples past the signal end.

        Raises:
            RuntimeError: if required samples were already trimmed away.
        """
        if frame_stop <= frame_start:
            return np.zeros(0, dtype=np.float32)
        total_samples = self.total_samples if effective_total_samples is None else effective_total_samples
        global_start = frame_start * self.hop_length - self.left_pad
        global_stop = (frame_stop - 1) * self.hop_length - self.left_pad + self.n_fft


        prefix = np.zeros(max(0, -global_start), dtype=np.float32)
        suffix = np.zeros(max(0, global_stop - total_samples), dtype=np.float32) if allow_right_pad else np.zeros(0, dtype=np.float32)


        # Clamp to the real signal, then map to local buffer coordinates.
        raw_start = max(0, global_start)
        raw_stop = min(total_samples, global_stop)
        if raw_start < self.audio_start_sample:
            raise RuntimeError(
                f"Audio buffer underflow: need sample {raw_start}, buffer starts at {self.audio_start_sample}"
            )
        local_start = raw_start - self.audio_start_sample
        local_stop = raw_stop - self.audio_start_sample
        core = self.audio_buffer[local_start:local_stop]
        if prefix.size == 0 and suffix.size == 0:
            return core
        return np.concatenate([prefix, core, suffix], axis=0)


    def _transform_batch(self, stft: np.ndarray, frame_start: int) -> np.ndarray:
        """Turn raw STFT frames into base features, shape (frames, n_mels).

        For "logmel23_cummn", log-mel frames are normalized by the cumulative
        mean over ALL frames seen so far (running state in
        `cumulative_feature_sum`); other types are delegated to
        ls_feature.transform.
        """
        if self.transform_type != "logmel23_cummn":
            # NOTE(review): assumes ls_feature.transform is stateless per
            # batch — confirm for any stateful feat_type.
            return ls_feature.transform(stft, self.transform_type)
        magnitude = np.abs(stft)
        # Power spectrogram projected onto the mel filterbank.
        mel = np.dot(magnitude ** 2, self.mel_basis.T)
        # Floor at 1e-10 to avoid log10 of zero-energy frames.
        logmel = np.log10(np.maximum(mel, 1e-10))
        # Absolute frame k (0-based) has k + 1 frames in its history.
        counts = np.arange(frame_start + 1, frame_start + 1 + len(logmel), dtype=np.float64)[:, None]
        cumsum = np.cumsum(logmel, axis=0, dtype=np.float64) + self.cumulative_feature_sum[None, :]
        cummean = cumsum / counts
        self.cumulative_feature_sum = cumsum[-1]
        return (logmel - cummean).astype(np.float32, copy=False)


    def _emit_model_frames(self, final: bool, total_stft_frames: int | None = None) -> np.ndarray:
        """Emit spliced model frames; returns shape (k, model_input_dim).

        Streaming mode (final=False) emits only frames whose right context
        already exists; final mode emits every frame up to the offline
        count, letting _splice_frame zero-fill context outside
        [0, total_stft_frames).
        """
        outputs: list[np.ndarray] = []
        latest_frame = self.next_stft_frame - 1
        total_model_frames = self._model_frame_count(total_stft_frames or 0) if final else None
        while True:
            # Center STFT frame of the next model frame (subsampled grid).
            center_index = self.next_model_frame * self.subsampling
            if final:
                if total_model_frames is None or self.next_model_frame >= total_model_frames:
                    break
                max_index = (total_stft_frames or 0) - 1
            else:
                # Right context incomplete — wait for more audio.
                if center_index + self.context_size > latest_frame:
                    break
                max_index = latest_frame
            outputs.append(self._splice_frame(center_index, max_index))
            self.next_model_frame += 1
        # Base features left of the next frame's context can be released.
        self._drop_consumed_base_features()
        if not outputs:
            return np.zeros((0, self.model_input_dim), dtype=np.float32)
        return np.stack(outputs, axis=0).astype(np.float32, copy=False)


    def _splice_frame(self, center_index: int, max_index: int) -> np.ndarray:
        """Concatenate 2*context_size + 1 base frames around `center_index`.

        Context positions outside [0, max_index] are zero-filled.

        Raises:
            RuntimeError: if an in-range frame was already trimmed away.
        """
        pieces = []
        for frame_index in range(center_index - self.context_size, center_index + self.context_size + 1):
            if frame_index < 0 or frame_index > max_index:
                pieces.append(np.zeros(self.n_mels, dtype=np.float32))
                continue
            local_index = frame_index - self.base_feature_start
            if local_index < 0 or local_index >= self.base_feature_buffer.shape[0]:
                raise RuntimeError(
                    f"Feature buffer underflow: need frame {frame_index}, buffer covers "
                    f"[{self.base_feature_start}, {self.base_feature_start + self.base_feature_buffer.shape[0] - 1}]"
                )
            pieces.append(self.base_feature_buffer[local_index])
        return np.concatenate(pieces, axis=0)


    def _drop_consumed_audio(self) -> None:
        """Trim raw audio no frame >= next_stft_frame can still reference."""
        # First sample the next STFT frame's window can touch.
        keep_from = max(0, self.next_stft_frame * self.hop_length - self.left_pad)
        drop = keep_from - self.audio_start_sample
        if drop <= 0:
            return
        self.audio_buffer = self.audio_buffer[drop:]
        self.audio_start_sample += drop


    def _drop_consumed_base_features(self) -> None:
        """Trim base features left of the next model frame's context window."""
        # Earliest STFT frame the next model frame's left context can touch.
        keep_from = max(0, self.next_model_frame * self.subsampling - self.context_size)
        drop = keep_from - self.base_feature_start
        if drop <= 0:
            return
        self.base_feature_buffer = self.base_feature_buffer[drop:]
        self.base_feature_start += drop
|