# justice / audio_cache_io.py
# Uploaded by hidude561 with huggingface_hub (commit df5e685, verified).
"""Opus-based audio cache I/O for v86 feature cache.
Encode: float32 24kHz mono ndarray → Opus@96kbps bytes (~45 KB / 4 sec).
Decode: bytes → float32 24kHz mono ndarray (exact length 96000).
Opus decodes at 48 kHz, so we resample back to 24 kHz on decode. No
pre-padding is trimmed (OPUS_PRE_DELAY_24K = 0): the decoded length already
matches the original, and the residual ~47-sample (~2 ms) lag is tolerated.
"""
from __future__ import annotations
import io
import numpy as np
import av
import scipy.signal as sps
# Sample rates in Hz: the cache stores 24 kHz audio; Opus decodes at 48 kHz.
SR_TARGET = 24_000
SR_OPUS = 48_000

# Every cached clip is exactly 4 seconds long at SR_TARGET (96 000 samples).
EXPECTED_SAMPLES = 4 * SR_TARGET

# Opus encoder target bit rate in bits per second (96 kbps).
OPUS_BIT_RATE = 96_000

# Leading samples (at 24 kHz) dropped after decoding. Empirically the decoded
# length already equals EXPECTED_SAMPLES, so nothing is trimmed; the residual
# lag of ~47 samples (~2 ms) is acceptable for instrument audio because the
# codec/LM windows are 4 s each.
OPUS_PRE_DELAY_24K = 0
def encode_opus(audio_24k_f32: np.ndarray, bit_rate: int = OPUS_BIT_RATE) -> bytes:
    """Encode a mono 24 kHz float32 signal as Ogg/Opus bytes.

    Args:
        audio_24k_f32: 1-D float32 array sampled at ``SR_TARGET`` Hz,
            nominally ``EXPECTED_SAMPLES`` long (length is not enforced).
        bit_rate: Target Opus bit rate in bits per second.

    Returns:
        The complete Ogg container as ``bytes``.

    Raises:
        ValueError: if the input is not 1-D.
        TypeError: if the input dtype is not float32.
    """
    # Explicit checks rather than ``assert`` so validation survives
    # ``python -O``; without the ndim check a (2, N) array would be silently
    # reshaped below into a single 2N-sample mono stream.
    if audio_24k_f32.ndim != 1:
        raise ValueError(f"expected 1-D mono audio, got shape {audio_24k_f32.shape}")
    if audio_24k_f32.dtype != np.float32:
        raise TypeError(f"expected float32 audio, got {audio_24k_f32.dtype}")

    # "flt" is FFmpeg's *packed* float32 format; packed frames are passed as a
    # single row of shape (1, samples * channels), i.e. (1, samples) for mono.
    audio_2d = np.ascontiguousarray(audio_24k_f32).reshape(1, -1)

    buf = io.BytesIO()
    # The context manager closes the container even if encoding fails. The
    # container must be closed before ``buf`` is read so all muxed data is flushed.
    with av.open(buf, mode="w", format="ogg") as container:
        stream = container.add_stream("libopus", rate=SR_TARGET)
        stream.bit_rate = bit_rate
        stream.layout = "mono"
        frame = av.AudioFrame.from_ndarray(audio_2d, format="flt", layout="mono")
        frame.sample_rate = SR_TARGET
        for packet in stream.encode(frame):
            container.mux(packet)
        for packet in stream.encode():  # flush packets still buffered in the encoder
            container.mux(packet)
    return buf.getvalue()
def decode_opus(opus_bytes: bytes,
                target_samples: int = EXPECTED_SAMPLES) -> np.ndarray:
    """Decode Ogg/Opus bytes to a fixed-length mono 24 kHz float32 signal.

    Args:
        opus_bytes: Container bytes, e.g. as produced by ``encode_opus``.
        target_samples: Exact length of the returned array. Shorter decodes
            are zero-padded at the end; longer ones are truncated.

    Returns:
        A ``(target_samples,)`` float32 array sampled at ``SR_TARGET`` Hz.

    Raises:
        ValueError: if the stream yields no audio frames.
    """
    chunks = []
    decoded_rate = SR_OPUS  # fallback if a frame does not report its rate
    with av.open(io.BytesIO(opus_bytes), mode="r") as container:
        for frame in container.decode(audio=0):
            # to_ndarray() returns (channels, samples) for planar formats; keep
            # channel 0. encode_opus writes mono, so nothing is discarded there.
            arr = frame.to_ndarray()
            if arr.ndim == 2:
                arr = arr[0]
            chunks.append(arr.astype(np.float32))
            # Resample from the rate the decoder actually produced instead of
            # assuming 48 kHz.
            decoded_rate = frame.sample_rate or decoded_rate

    if not chunks:
        # np.concatenate([]) would fail with an opaque message; say what happened.
        raise ValueError("decode_opus: no audio frames decoded")
    audio = np.concatenate(chunks)

    if decoded_rate != SR_TARGET:
        # Rational up/down factors; for 48 kHz -> 24 kHz this is (1, 2).
        g = int(np.gcd(SR_TARGET, decoded_rate))
        audio = sps.resample_poly(
            audio, SR_TARGET // g, decoded_rate // g
        ).astype(np.float32)

    # Drop codec pre-padding (currently 0 samples; see OPUS_PRE_DELAY_24K).
    if len(audio) > OPUS_PRE_DELAY_24K:
        audio = audio[OPUS_PRE_DELAY_24K:]

    # Force the exact output length: truncate, then zero-pad any shortfall.
    audio = audio[:target_samples]
    if len(audio) < target_samples:
        audio = np.pad(audio, (0, target_samples - len(audio)))
    return audio
def _smoke_test() -> None:
    """Round-trip a synthetic 4 s clip through Opus and print size/quality stats."""
    import time

    sr = SR_TARGET
    # Fixed seed so repeated runs print comparable numbers.
    rng = np.random.default_rng(0)
    t = np.arange(EXPECTED_SAMPLES) / sr
    # Real-ish test signal: two harmonics plus a little noise.
    audio = (
        0.3 * np.sin(2 * np.pi * 440 * t)
        + 0.2 * np.sin(2 * np.pi * 880 * t)
        + 0.05 * rng.standard_normal(EXPECTED_SAMPLES)
    ).astype(np.float32)
    # encode_opus requires float32, so make sure clipping keeps that dtype.
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    enc = encode_opus(audio)
    enc_ms = (time.perf_counter() - t0) * 1000
    t0 = time.perf_counter()
    dec = decode_opus(enc)
    dec_ms = (time.perf_counter() - t0) * 1000

    print(f"Input: {audio.shape} {audio.dtype} max={np.abs(audio).max():.3f}")
    print(f"Encoded: {len(enc)} bytes (encode took {enc_ms:.1f} ms)")
    print(f"Decoded: {dec.shape} {dec.dtype} (decode took {dec_ms:.1f} ms)")
    print(f"Compression vs float32: {audio.nbytes / len(enc):.1f}x")

    diff = audio - dec
    print(f"Time-domain max abs diff: {np.abs(diff).max():.4f}")
    print(f" L1 mean diff: {np.abs(diff).mean():.4f}")
    # The sample-wise diff includes any residual codec lag. Report the lag so a
    # large diff can be told apart from a genuine quality loss.
    xcorr = sps.correlate(dec, audio, mode="full")
    lag = int(np.argmax(xcorr)) - (len(audio) - 1)
    print(f" Estimated lag: {lag} samples (positive = decoded lags input)")

    # Log-spectral distance: RMS difference of the Welch power spectra, in dB.
    _, pxx_in = sps.welch(audio, fs=sr, nperseg=2048)
    _, pxx_out = sps.welch(dec, fs=sr, nperseg=2048)
    eps = 1e-10  # avoid log10(0) in empty bins
    diff_db = 10 * np.log10(pxx_in + eps) - 10 * np.log10(pxx_out + eps)
    lsd_db = float(np.sqrt(np.mean(diff_db ** 2)))
    print(f" Log-spectral distance (dB): {lsd_db:.2f}")


if __name__ == "__main__":
    _smoke_test()