File size: 9,363 Bytes
17dd01d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 | from __future__ import annotations
import numpy as np
from dataclasses import dataclass
import librosa
import feature as ls_feature
@dataclass()
class StreamingUpdate:
    """Plain record bundling the outputs of one streaming step.

    NOTE(review): nothing in this module constructs this record, so the field
    meanings below are inferred from their names — confirm against the producer.
    """

    # Primary output window: index of its first frame plus per-frame logits and
    # probabilities (presumably indexed in model frames — confirm).
    start_frame: int
    logits: np.ndarray
    probabilities: np.ndarray

    # Preview window with the same layout as the primary one (presumably
    # provisional output that may still change — confirm).
    preview_start_frame: int
    preview_logits: np.ndarray
    preview_probabilities: np.ndarray

    # Bookkeeping: presumably the output frame rate, the audio duration seen so
    # far, and a running count of emitted frames — confirm.
    frame_hz: float
    duration_seconds: float
    total_emitted_frames: int
class StreamingFeatureExtractor:
"""Incremental feature extraction that matches the repo's offline path."""
def __init__(self, config: dict) -> None:
feat_config = config["data"]["feat"]
self.sample_rate = int(feat_config["sample_rate"])
self.frame_size = int(feat_config["win_length"])
self.n_fft = 1 << (self.frame_size - 1).bit_length()
self.hop_length = int(feat_config["hop_length"])
self.n_mels = int(feat_config["n_mels"])
self.context_size = int(config["data"]["context_recp"])
self.subsampling = int(config["data"]["subsampling"])
self.transform_type = str(config["data"]["feat_type"])
self.left_pad = self.n_fft // 2
self.model_input_dim = (2 * self.context_size + 1) * self.n_mels
self.audio_buffer = np.zeros(0, dtype=np.float32)
self.audio_start_sample = 0
self.total_samples = 0
self.next_stft_frame = 0
self.next_model_frame = 0
self.base_feature_start = 0
self.base_feature_buffer = np.zeros((0, self.n_mels), dtype=np.float32)
self.cumulative_feature_sum = np.zeros(self.n_mels, dtype=np.float64)
self.mel_basis = librosa.filters.mel(
sr=self.sample_rate,
n_fft=self.n_fft,
n_mels=self.n_mels,
)
def push_audio(self, chunk: np.ndarray) -> np.ndarray:
if chunk.size == 0:
return np.zeros((0, self.model_input_dim), dtype=np.float32)
self.audio_buffer = np.concatenate([self.audio_buffer, chunk.astype(np.float32, copy=False)], axis=0)
self.total_samples += len(chunk)
self._append_stft_frames(self._stable_stft_frame_count(), allow_right_pad=False)
return self._emit_model_frames(final=False)
def finalize(self) -> np.ndarray:
usable_samples = self._usable_sample_count(self.total_samples)
total_stft_frames = self._offline_stft_frame_count(usable_samples)
self._append_stft_frames(total_stft_frames, allow_right_pad=True, effective_total_samples=usable_samples)
return self._emit_model_frames(final=True, total_stft_frames=total_stft_frames)
def _usable_sample_count(self, sample_count: int) -> int:
block_size = self.hop_length * self.subsampling
return (sample_count // block_size) * block_size
def _stable_stft_frame_count(self) -> int:
if self.total_samples <= self.left_pad:
return 0
return max(0, (self.total_samples - self.left_pad) // self.hop_length + 1)
def _offline_stft_frame_count(self, usable_samples: int) -> int:
if usable_samples <= 0:
return 0
return max(0, usable_samples // self.hop_length - 1)
def _model_frame_count(self, total_stft_frames: int) -> int:
if total_stft_frames <= 0:
return 0
return (total_stft_frames + self.subsampling - 1) // self.subsampling
def _append_stft_frames(
self,
target_frame_count: int,
allow_right_pad: bool,
effective_total_samples: int | None = None,
) -> None:
if target_frame_count <= self.next_stft_frame:
return
frame_start = self.next_stft_frame
frame_stop = target_frame_count
segment = self._stft_segment(
frame_start=frame_start,
frame_stop=frame_stop,
allow_right_pad=allow_right_pad,
effective_total_samples=effective_total_samples,
)
stft = librosa.stft(
segment,
n_fft=self.n_fft,
win_length=self.frame_size,
hop_length=self.hop_length,
center=False,
).T
expected_frames = frame_stop - frame_start
if stft.shape[0] < expected_frames:
raise RuntimeError(
f"Streaming STFT underflow: expected {expected_frames} frames, got {stft.shape[0]}"
)
stft = stft[:expected_frames]
transformed = self._transform_batch(stft, frame_start)
if self.base_feature_buffer.size == 0:
self.base_feature_buffer = transformed
else:
self.base_feature_buffer = np.concatenate([self.base_feature_buffer, transformed], axis=0)
self.next_stft_frame = frame_stop
self._drop_consumed_audio()
def _stft_segment(
self,
frame_start: int,
frame_stop: int,
allow_right_pad: bool,
effective_total_samples: int | None,
) -> np.ndarray:
if frame_stop <= frame_start:
return np.zeros(0, dtype=np.float32)
total_samples = self.total_samples if effective_total_samples is None else effective_total_samples
global_start = frame_start * self.hop_length - self.left_pad
global_stop = (frame_stop - 1) * self.hop_length - self.left_pad + self.n_fft
prefix = np.zeros(max(0, -global_start), dtype=np.float32)
suffix = np.zeros(max(0, global_stop - total_samples), dtype=np.float32) if allow_right_pad else np.zeros(0, dtype=np.float32)
raw_start = max(0, global_start)
raw_stop = min(total_samples, global_stop)
if raw_start < self.audio_start_sample:
raise RuntimeError(
f"Audio buffer underflow: need sample {raw_start}, buffer starts at {self.audio_start_sample}"
)
local_start = raw_start - self.audio_start_sample
local_stop = raw_stop - self.audio_start_sample
core = self.audio_buffer[local_start:local_stop]
if prefix.size == 0 and suffix.size == 0:
return core
return np.concatenate([prefix, core, suffix], axis=0)
def _transform_batch(self, stft: np.ndarray, frame_start: int) -> np.ndarray:
if self.transform_type != "logmel23_cummn":
return ls_feature.transform(stft, self.transform_type)
magnitude = np.abs(stft)
mel = np.dot(magnitude ** 2, self.mel_basis.T)
logmel = np.log10(np.maximum(mel, 1e-10))
counts = np.arange(frame_start + 1, frame_start + 1 + len(logmel), dtype=np.float64)[:, None]
cumsum = np.cumsum(logmel, axis=0, dtype=np.float64) + self.cumulative_feature_sum[None, :]
cummean = cumsum / counts
self.cumulative_feature_sum = cumsum[-1]
return (logmel - cummean).astype(np.float32, copy=False)
def _emit_model_frames(self, final: bool, total_stft_frames: int | None = None) -> np.ndarray:
outputs: list[np.ndarray] = []
latest_frame = self.next_stft_frame - 1
total_model_frames = self._model_frame_count(total_stft_frames or 0) if final else None
while True:
center_index = self.next_model_frame * self.subsampling
if final:
if total_model_frames is None or self.next_model_frame >= total_model_frames:
break
max_index = (total_stft_frames or 0) - 1
else:
if center_index + self.context_size > latest_frame:
break
max_index = latest_frame
outputs.append(self._splice_frame(center_index, max_index))
self.next_model_frame += 1
self._drop_consumed_base_features()
if not outputs:
return np.zeros((0, self.model_input_dim), dtype=np.float32)
return np.stack(outputs, axis=0).astype(np.float32, copy=False)
def _splice_frame(self, center_index: int, max_index: int) -> np.ndarray:
pieces = []
for frame_index in range(center_index - self.context_size, center_index + self.context_size + 1):
if frame_index < 0 or frame_index > max_index:
pieces.append(np.zeros(self.n_mels, dtype=np.float32))
continue
local_index = frame_index - self.base_feature_start
if local_index < 0 or local_index >= self.base_feature_buffer.shape[0]:
raise RuntimeError(
f"Feature buffer underflow: need frame {frame_index}, buffer covers "
f"[{self.base_feature_start}, {self.base_feature_start + self.base_feature_buffer.shape[0] - 1}]"
)
pieces.append(self.base_feature_buffer[local_index])
return np.concatenate(pieces, axis=0)
def _drop_consumed_audio(self) -> None:
keep_from = max(0, self.next_stft_frame * self.hop_length - self.left_pad)
drop = keep_from - self.audio_start_sample
if drop <= 0:
return
self.audio_buffer = self.audio_buffer[drop:]
self.audio_start_sample += drop
def _drop_consumed_base_features(self) -> None:
keep_from = max(0, self.next_model_frame * self.subsampling - self.context_size)
drop = keep_from - self.base_feature_start
if drop <= 0:
return
self.base_feature_buffer = self.base_feature_buffer[drop:]
self.base_feature_start += drop
|