Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Union | |
| import numpy as np | |
| import sphn | |
| import torch | |
| import torch.nn.functional as F | |
| from ..audio import MimiCodec | |
# Type accepted for file-path arguments in this module: a plain string or a
# pathlib.Path.
PathLike = Union[str, Path]
def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
    """Load an audio file as a mono float32 signal sampled at ``target_sr``.

    Args:
        path: Location of the audio file (string or ``Path``).
        target_sr: Sample rate, in Hz, that the returned signal should have.

    Returns:
        A float32 array. Two-dimensional decoder output is averaged along
        axis 1, and the signal is resampled whenever the file's sample rate
        differs from ``target_sr``.
    """
    filename = str(path)
    try:
        samples, sr = sphn.read_wav(filename)
    except Exception:
        # Any failure of the sphn reader falls through to soundfile.
        import soundfile as sf  # Local fallback

        samples, sr = sf.read(filename, dtype="float32", always_2d=False)

    samples = np.asarray(samples, dtype=np.float32)
    # NOTE(review): averaging over axis 1 assumes a (frames, channels) layout;
    # confirm that the sphn reader does not hand back (channels, frames).
    if samples.ndim == 2:
        samples = samples.mean(axis=1)

    if sr == target_sr:
        return samples

    # NOTE(review): the resampler is looked up by name at call time; verify
    # that "resample_audio" is the attribute the installed sphn exposes,
    # otherwise the linear fallback is always taken.
    if not hasattr(sphn, "resample_audio"):
        return _resample_linear(samples, sr, target_sr)
    return sphn.resample_audio(samples, sr, target_sr).astype(np.float32)
def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert PCM samples into a batched tensor on ``device``.

    A 1-D input of ``T`` samples becomes ``[1, 1, T]``; a 2-D input gains a
    single leading axis; inputs of any other rank keep their shape.

    Args:
        audio: PCM samples as a NumPy array.
        device: Device the resulting tensor is moved to.

    Returns:
        Tensor holding the same samples, with the dtype that
        ``torch.from_numpy`` derives from ``audio``.
    """
    # torch.from_numpy rejects arrays with negative strides (for example a
    # reversed view such as ``audio[::-1]``). Copy only in that case so the
    # common path keeps sharing memory with the input.
    if any(stride < 0 for stride in audio.strides):
        audio = audio.copy()

    tensor = torch.from_numpy(audio).to(device)
    # The two checks run in sequence on purpose: a 1-D signal is promoted
    # twice, a 2-D signal once.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)
    return tensor
def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
    """Run the Mimi encoder on PCM audio and return its codebook tokens.

    Args:
        mimi: Codec object providing ``device`` and ``encode``.
        audio: PCM samples; batched through ``audio_to_tensor`` first.

    Returns:
        Batch element 0 of the encoder's codes, cast to ``torch.long``
        (``[C, T]`` given the ``[B, num_codebooks, T]`` layout noted below).
    """
    pcm = audio_to_tensor(audio, mimi.device)
    with torch.inference_mode():
        tokens, *_ = mimi.encode(pcm, return_dict=False)
        # If the first returned element is itself a tuple/list, take its
        # first entry as the codes.
        if isinstance(tokens, (tuple, list)):
            tokens = tokens[0]
        # Mimi.encode returns [B, num_codebooks, T]; select batch 0.
        first_item = tokens[0]
        return first_item.to(torch.long)
def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample a 1-D signal with plain linear interpolation.

    No low-pass filtering is applied before interpolating, so this is a
    simple fallback rather than a high-quality resampler.

    Args:
        audio: 1-D array of samples.
        src_sr: Sample rate of ``audio``.
        dst_sr: Sample rate of the returned signal.

    Returns:
        A float32 array of ``round(len(audio) * dst_sr / src_sr)`` samples
        (at least one for non-empty input). Empty input yields an empty
        array; equal rates yield a float32 copy of the input.
    """
    if src_sr == dst_sr:
        return audio.astype(np.float32)

    length = audio.shape[0]
    if length == 0:
        # F.interpolate refuses zero-width input, and there is nothing to
        # interpolate anyway.
        return np.zeros(0, dtype=np.float32)

    new_length = max(1, int(round(length * dst_sr / src_sr)))
    # F.interpolate in "linear" mode expects a [batch, channels, width] input.
    tensor = torch.from_numpy(audio.astype(np.float32))[None, None, :]
    with torch.no_grad():
        resampled = F.interpolate(
            tensor, size=new_length, mode="linear", align_corners=False
        )
    # The tensor is already float32 on the CPU, so no further conversion is
    # needed.
    return resampled[0, 0].numpy()
# Names exported by `from <module> import *`; the underscore-prefixed
# _resample_linear helper is not listed.
__all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]