import math
import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from pydantic import BaseModel
from scipy.signal import resample as scipy_resample
from torch import nn
from torch.nn.utils import weight_norm
from huggingface_hub import hf_hub_download

# =========================================================
# AudioVAE model definition (single-file, standalone)
# =========================================================


def WNConv1d(*args, **kwargs):
    """Return an ``nn.Conv1d`` wrapped with weight normalization."""
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    """Return an ``nn.ConvTranspose1d`` wrapped with weight normalization."""
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class CausalConv1d(nn.Conv1d):
    """1-D convolution that only pads on the left (causal).

    ``padding`` and ``output_padding`` are intercepted here and NOT forwarded to
    ``nn.Conv1d`` (which therefore runs with zero padding); instead
    ``2 * padding - output_padding`` zeros are prepended in :meth:`forward`.
    The name-mangled attributes (``_CausalConv1d__padding`` /
    ``_CausalConv1d__output_padding``) are read by :class:`StreamingVAEDecoder`.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
        return super().forward(x_pad)


class CausalTransposeConv1d(nn.ConvTranspose1d):
    """Transposed 1-D convolution that trims its right edge to stay causal.

    As with :class:`CausalConv1d`, ``padding`` / ``output_padding`` are not
    forwarded to the base class; ``2 * padding - output_padding`` samples are
    dropped from the end of the output instead.
    """

    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        out = super().forward(x)
        trim = self.__padding * 2 - self.__output_padding
        # ``out[..., :-0]`` would be an EMPTY tensor, so only slice when there
        # is actually something to trim.
        return out[..., :-trim] if trim > 0 else out


def WNCausalConv1d(*args, **kwargs):
    """Return a weight-normalized :class:`CausalConv1d`."""
    return weight_norm(CausalConv1d(*args, **kwargs))


def WNCausalTransposeConv1d(*args, **kwargs):
    """Return a weight-normalized :class:`CausalTransposeConv1d`."""
    return weight_norm(CausalTransposeConv1d(*args, **kwargs))


# Snake activation: x + sin^2(alpha * x) / alpha, with alpha broadcast per
# channel. The input is flattened to 3-D for the computation and restored.
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    """Snake activation with a learnable per-channel ``alpha`` (initialized to 1)."""

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


class CausalResidualUnit(nn.Module):
    """Snake -> dilated causal conv -> Snake -> 1x1 conv, added back to the input.

    The length-preserving padding assumes an odd ``kernel``: the causal conv
    prepends ``2 * pad == (kernel - 1) * dilation`` samples only in that case.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        # Derive the padding from ``kernel`` (it was previously hard-coded to 7,
        # which silently broke any non-default kernel size).
        pad = ((kernel - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        # The causal padding is expected to preserve length exactly.
        assert pad == 0
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class CausalEncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9) followed by a strided downsampling conv."""

    def __init__(self, output_dim: int = 16, input_dim: Optional[int] = None, stride: int = 1, groups=1):
        super().__init__()
        input_dim = input_dim or output_dim // 2
        self.block = nn.Sequential(
            CausalResidualUnit(input_dim, dilation=1, groups=groups),
            CausalResidualUnit(input_dim, dilation=3, groups=groups),
            CausalResidualUnit(input_dim, dilation=9, groups=groups),
            Snake1d(input_dim),
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        )

    def forward(self, x):
        return self.block(x)


class CausalEncoder(nn.Module):
    """Waveform encoder: stem conv, one downsampling block per stride, then mu/logvar heads.

    The channel count doubles at every block; ``forward`` returns a dict with
    ``hidden_state``, ``mu`` and ``logvar``.
    """

    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: Sequence[int] = (2, 4, 8, 8),  # tuple: avoid a shared mutable default
        depthwise: bool = False,
    ):
        super().__init__()
        blocks = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            blocks.append(CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups))
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        # Registered after the heads, matching the original module order.
        self.block = nn.Sequential(*blocks)
        self.enc_dim = d_model

    def forward(self, x):
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }


class NoiseBlock(nn.Module):
    """Add Gaussian noise (one random value per time step, shared across channels),
    scaled per-channel by a learned 1x1 projection of the input."""

    def __init__(self, dim):
        super().__init__()
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        B, C, T = x.shape
        noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
        h = self.linear(x)
        n = noise * h
        return x + n


class CausalDecoderBlock(nn.Module):
    """Snake + strided causal transposed conv (upsampling), optional noise,
    then three residual units (dilations 1, 3, 9)."""

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        layers = [
            Snake1d(input_dim),
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        ]
        if use_noise_block:
            layers.append(NoiseBlock(output_dim))
        layers.extend(
            [
                CausalResidualUnit(output_dim, dilation=1, groups=groups),
                CausalResidualUnit(output_dim, dilation=3, groups=groups),
                CausalResidualUnit(output_dim, dilation=9, groups=groups),
            ]
        )
        self.block = nn.Sequential(*layers)
        # Read by CausalDecoder to size the matching SampleRateConditionLayer.
        self.input_channels = input_dim

    def forward(self, x):
        return self.block(x)


class TransposeLastTwoDim(torch.nn.Module):
    """Swap the last two dimensions of the input."""

    def forward(self, x):
        return torch.transpose(x, -1, -2)


class SampleRateConditionLayer(nn.Module):
    """Condition a feature map on a sample-rate bucket index.

    ``cond_type`` selects how:
      * ``"scale_bias"`` / ``"scale_bias_init"``: ``x * scale[idx] + bias[idx]``
        (identity-initialized vs. randomly initialized),
      * ``"add"``: ``x + embed[idx]``,
      * ``"concat"``: concatenate ``embed[idx]`` along channels (needs ``out_layer``
        to project back to ``input_dim``).
    """

    def __init__(
        self,
        input_dim: int,
        sr_bin_buckets: Optional[int] = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        out_layer: bool = False,
    ):
        super().__init__()
        self.cond_type, out_layer_in_dim = cond_type, input_dim

        if cond_type == "scale_bias":
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.ones_(self.scale_embed.weight)
            nn.init.zeros_(self.bias_embed.weight)
        elif cond_type == "scale_bias_init":
            self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
            self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.scale_embed.weight, mean=1)
            nn.init.normal_(self.bias_embed.weight)
        elif cond_type == "add":
            self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
            nn.init.normal_(self.cond_embed.weight)
        elif cond_type == "concat":
            self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
            assert out_layer, "out_layer must be True for concat cond_type"
            out_layer_in_dim = input_dim + cond_dim
        else:
            raise ValueError(f"Invalid cond_type: {cond_type}")

        if out_layer:
            self.out_layer = nn.Sequential(
                Snake1d(out_layer_in_dim),
                WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
            )
        else:
            self.out_layer = nn.Identity()

    def forward(self, x, sr_cond):
        if self.cond_type in ("scale_bias", "scale_bias_init"):
            x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
        elif self.cond_type == "add":
            x = x + self.cond_embed(sr_cond).unsqueeze(-1)
        elif self.cond_type == "concat":
            x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        return self.out_layer(x)


class CausalDecoder(nn.Module):
    """Latent-to-waveform decoder.

    Stem conv, one upsampling :class:`CausalDecoderBlock` per rate (halving the
    channels each time), then Snake + conv + ``tanh``. When ``sr_bin_boundaries``
    is given, each decoder block is preceded by a
    :class:`SampleRateConditionLayer` driven by ``torch.bucketize`` of the
    requested output sample rate.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
        sr_bin_boundaries: Optional[List[int]] = None,
        cond_type: str = "scale_bias",
        cond_dim: int = 128,
        cond_out_layer: bool = False,
    ):
        super().__init__()
        if depthwise:
            layers = [
                WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]
        else:
            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Defined up front so an empty ``rates`` does not leave it unbound below.
        output_dim = channels
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            groups = output_dim if depthwise else 1
            layers.append(
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=groups,
                    use_noise_block=use_noise_block,
                )
            )

        layers += [
            Snake1d(output_dim),
            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        if sr_bin_boundaries is None:
            self.model = nn.Sequential(*layers)
            self.sr_bin_boundaries = None
        else:
            self.model = nn.ModuleList(layers)
            self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
            self.sr_bin_buckets = len(sr_bin_boundaries) + 1
            # One entry per layer in ``self.model`` (None = no conditioning) so
            # both lists can be zipped in forward().
            cond_layers = []
            for layer in self.model:
                if isinstance(layer, CausalDecoderBlock):
                    cond_layers.append(
                        SampleRateConditionLayer(
                            input_dim=layer.input_channels,
                            sr_bin_buckets=self.sr_bin_buckets,
                            cond_type=cond_type,
                            cond_dim=cond_dim,
                            out_layer=cond_out_layer,
                        )
                    )
                else:
                    cond_layers.append(None)
            self.sr_cond_model = nn.ModuleList(cond_layers)

    def get_sr_idx(self, sr):
        """Map sample rate(s) to bucket indices via ``sr_bin_boundaries``."""
        return torch.bucketize(sr, self.sr_bin_boundaries)

    def forward(self, x, sr_cond=None):
        if self.sr_bin_boundaries is not None:
            sr_cond = self.get_sr_idx(sr_cond)
            for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
                if sr_cond_layer is not None:
                    x = sr_cond_layer(x, sr_cond)
                x = layer(x)
            return x
        return self.model(x)


class AudioVAEConfig(BaseModel):
    """Hyper-parameters for :class:`AudioVAE`.

    With the defaults the encoder hop is 2*5*8*8 = 640 samples at 16 kHz and the
    decoder upsamples by 8*6*5*2*2*2 = 1920 samples per frame at 48 kHz, so both
    sides run at 25 latent frames per second.
    """

    encoder_dim: int = 128
    encoder_rates: List[int] = [2, 5, 8, 8]
    latent_dim: int = 64
    decoder_dim: int = 2048
    decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
    depthwise: bool = True
    sample_rate: int = 16000
    out_sample_rate: int = 48000
    use_noise_block: bool = False
    sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
    cond_type: str = "scale_bias"
    cond_dim: int = 128
    cond_out_layer: bool = False


class AudioVAE(nn.Module):
    """Causal audio VAE: ``encode`` returns the latent mean, ``decode`` synthesizes audio."""

    def __init__(self, config: Optional[AudioVAEConfig] = None):
        if config is None:
            config = AudioVAEConfig()

        super().__init__()

        self.encoder_dim = config.encoder_dim
        self.encoder_rates = config.encoder_rates
        self.decoder_dim = config.decoder_dim
        self.decoder_rates = config.decoder_rates
        self.depthwise = config.depthwise
        self.use_noise_block = config.use_noise_block

        latent_dim = config.latent_dim
        if latent_dim is None:
            latent_dim = config.encoder_dim * (2 ** len(config.encoder_rates))

        self.latent_dim = latent_dim
        self.hop_length = int(np.prod(config.encoder_rates))

        self.encoder = CausalEncoder(
            config.encoder_dim,
            latent_dim,
            config.encoder_rates,
            depthwise=config.depthwise,
        )

        self.decoder = CausalDecoder(
            latent_dim,
            config.decoder_dim,
            config.decoder_rates,
            depthwise=config.depthwise,
            use_noise_block=config.use_noise_block,
            sr_bin_boundaries=config.sr_bin_boundaries,
            cond_type=config.cond_type,
            cond_dim=config.cond_dim,
            cond_out_layer=config.cond_out_layer,
        )

        self.sample_rate = config.sample_rate
        self.out_sample_rate = config.out_sample_rate
        self.sr_bin_boundaries = config.sr_bin_boundaries
        self.chunk_size = math.prod(config.encoder_rates)
        self.decode_chunk_size = math.prod(config.decoder_rates)

    def preprocess(self, audio_data, sample_rate):
        """Right-pad ``audio_data`` to a multiple of ``hop_length``.

        Raises:
            ValueError: if ``sample_rate`` differs from the model's input rate.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if sample_rate != self.sample_rate:
            raise ValueError(f"Expected input sample rate {self.sample_rate} Hz, got {sample_rate} Hz")

        pad_to = self.hop_length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data

    def decode(self, z: torch.Tensor, sr_cond: Optional[torch.Tensor] = None):
        """Decode latents ``z``; when sample-rate conditioning is enabled and no
        ``sr_cond`` is given, condition on ``out_sample_rate``."""
        if self.sr_bin_boundaries is not None and sr_cond is None:
            sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
        return self.decoder(z, sr_cond)

    def streaming_decode(self):
        """Return a :class:`StreamingVAEDecoder` context manager for chunked decoding."""
        return StreamingVAEDecoder(self)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """Encode audio and return the latent mean (``mu``) from the encoder.

        A 2-D input gets a channel axis inserted at dim 1.
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        audio_data = self.preprocess(audio_data, sample_rate)
        return self.encoder(audio_data)["mu"]


class StreamingVAEDecoder:
    """Context manager that makes ``AudioVAE.decode`` stateful across chunks.

    While active, every padded causal conv and every transposed conv in the
    decoder has its ``forward`` replaced by a version that prepends the tail of
    the previous chunk's input instead of zeros, so consecutive
    ``decode_chunk`` calls continue one signal. Original ``forward``s are
    restored on exit. Not used by the Gradio app below.
    """

    def __init__(self, vae: AudioVAE):
        self._vae = vae
        # id(module) -> trailing input frames carried over from the previous chunk.
        self._states: dict = {}
        # (module, value of module.__dict__["forward"] before patching, or None).
        self._originals: list = []

    def __enter__(self):
        self._states.clear()
        self._install()
        return self

    def __exit__(self, *exc):
        self._restore()
        self._states.clear()

    def decode_chunk(self, z_chunk: torch.Tensor) -> torch.Tensor:
        """Decode one latent chunk, continuing from the previous chunk's state."""
        return self._vae.decode(z_chunk)

    def _install(self):
        for mod in self._vae.decoder.modules():
            if isinstance(mod, CausalConv1d):
                pad = mod._CausalConv1d__padding * 2 - mod._CausalConv1d__output_padding
                if pad > 0:
                    self._patch_causal_conv(mod, pad)
            elif isinstance(mod, CausalTransposeConv1d):
                trim = mod._CausalTransposeConv1d__padding * 2 - mod._CausalTransposeConv1d__output_padding
                ctx = (mod.kernel_size[0] - 1) // mod.stride[0]
                if ctx > 0:
                    self._patch_transpose_conv(mod, ctx, trim)

    def _set_forward(self, mod, fwd):
        """Install ``fwd`` as ``mod.forward``, remembering what to restore."""
        self._originals.append((mod, mod.__dict__.get("forward")))
        mod.forward = fwd

    def _patch_causal_conv(self, mod, pad_size):
        states = self._states
        key = id(mod)

        def fwd(x, _k=key, _p=pad_size, _m=mod):
            prev = states.get(_k)
            if prev is None:
                # First chunk: same zeros that CausalConv1d.forward would prepend.
                prev = x.new_zeros(x.shape[0], x.shape[1], _p)
            x_pad = torch.cat([prev, x], dim=-1)
            # Keep the last ``_p`` frames of history (works for chunks shorter than _p).
            states[_k] = x_pad[:, :, -_p:].detach()
            return nn.Conv1d.forward(_m, x_pad)

        self._set_forward(mod, fwd)

    def _patch_transpose_conv(self, mod, ctx, trim):
        states = self._states
        key = id(mod)

        def fwd(x, _k=key, _c=ctx, _t=trim, _m=mod):
            prev = states.get(_k)
            if prev is None:
                prev = x.new_zeros(x.shape[0], x.shape[1], _c)
            x_full = torch.cat([prev, x], dim=-1)
            # Take the context from prev+x so it is always exactly ``_c`` frames,
            # even when the chunk itself is shorter than ``_c``; the left trim
            # below assumes exactly ``_c`` context frames.
            states[_k] = x_full[:, :, -_c:].detach()
            out = nn.ConvTranspose1d.forward(_m, x_full)
            left = _c * _m.stride[0]
            return out[..., left:-_t] if _t > 0 else out[..., left:]

        self._set_forward(mod, fwd)

    def _restore(self):
        for mod, orig in reversed(self._originals):
            if orig is None:
                # Drop the instance attribute so the class's forward is used again.
                mod.__dict__.pop("forward", None)
            else:
                mod.forward = orig
        self._originals.clear()


# =========================================================
# Loading utilities
# =========================================================

REPO_ID = os.environ.get("AUDIOVAE_REPO", "openbmb/VoxCPM2")
WEIGHTS_NAME = os.environ.get("AUDIOVAE_WEIGHTS", "audiovae.pth")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Not referenced in this file; the model's own ``sample_rate`` is used instead.
TARGET_SR = 16000


@dataclass
class LoadedCodec:
    """An :class:`AudioVAE` paired with the device it lives on."""

    model: AudioVAE
    device: str

    @property
    def sample_rate(self) -> int:
        """Sample rate the encoder expects."""
        return int(self.model.sample_rate)

    @property
    def out_sample_rate(self) -> int:
        """Sample rate of the decoder output."""
        return int(self.model.out_sample_rate)

    @property
    def hop_length(self) -> int:
        """Input samples per latent frame."""
        return int(self.model.hop_length)

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.model.encode(wav, self.sample_rate)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.model.decode(z)


def _pick_state_dict(obj):
    """Return the first nested dict under a common wrapper key, else ``obj`` itself."""
    if isinstance(obj, dict):
        for key in ("state_dict", "model", "vae", "audio_vae", "module"):
            if key in obj and isinstance(obj[key], dict):
                return obj[key]
    return obj


@torch.inference_mode()
def load_codec(repo_id: str = REPO_ID, filename: str = WEIGHTS_NAME, device: str = DEVICE) -> LoadedCodec:
    """Download ``filename`` from the Hub repo and load it into a default :class:`AudioVAE`.

    Loading is non-strict; only the counts of missing/unexpected keys are printed.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY NOTE(review): ``torch.load`` may unpickle arbitrary objects from a
    # downloaded file depending on the torch version's ``weights_only`` default.
    # Consider ``weights_only=True`` once the checkpoint is confirmed to hold
    # only tensors / plain containers.
    ckpt = torch.load(path, map_location="cpu")
    state = _pick_state_dict(ckpt)

    model = AudioVAE()
    missing, unexpected = model.load_state_dict(state, strict=False)
    model.to(device).eval()

    print(f"[load] repo={repo_id} file={filename} device={device}")
    if missing:
        print(f"[load] missing keys: {len(missing)}")
    if unexpected:
        print(f"[load] unexpected keys: {len(unexpected)}")

    return LoadedCodec(model=model, device=device)


codec = load_codec()


# =========================================================
# Audio helpers
# =========================================================

def load_audio_file(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file as float32, down-mixing multi-channel audio to mono."""
    audio, sr = sf.read(path, dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    return audio.astype(np.float32), int(sr)


def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """FFT-resample ``audio`` from ``orig_sr`` to ``target_sr`` (no-op if equal or empty)."""
    if orig_sr == target_sr or len(audio) == 0:
        return audio
    num_samples = int(round(len(audio) * target_sr / orig_sr))
    return scipy_resample(audio, num_samples).astype(np.float32)


def to_tensor(audio: np.ndarray, device: str) -> torch.Tensor:
    """Wrap a 1-D array as a ``(1, 1, T)`` tensor on ``device``."""
    return torch.from_numpy(audio).unsqueeze(0).unsqueeze(0).to(device)


def save_wav_temp(wav: np.ndarray, sr: int) -> str:
    """Write ``wav`` to a new temporary ``.wav`` file and return its path."""
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(path, wav.astype(np.float32), sr)
    return path


def fmt_stats(kv: dict) -> str:
    """Render a dict as a two-column Markdown table."""
    lines = ["| Property | Value |", "|---|---|"]
    for k, v in kv.items():
        lines.append(f"| {k} | `{v}` |")
    return "\n".join(lines)


# =========================================================
# Encode / Decode
# =========================================================

def encode_audio(file_path):
    """Gradio handler: encode an uploaded file.

    Returns ``(latent_state, latent_preview, stats_markdown)``; the latent is a
    nested list of shape ``(frames, latent_dim)``.
    """
    if file_path is None:
        return None, None, "Upload an audio file first."

    audio, sr = load_audio_file(file_path)
    orig_len = len(audio)
    audio = resample_audio(audio, sr, codec.sample_rate)
    # An empty file would otherwise fail deep inside the convolution stack.
    if len(audio) == 0:
        return None, None, "The uploaded audio file contains no samples."
    wav = to_tensor(audio, codec.device)

    with torch.inference_mode():
        z = codec.encode(wav)  # (B, D, T)
        z_btd = z.transpose(1, 2).contiguous()  # (B, T, D)

    latent = z_btd.squeeze(0).detach().cpu().numpy()

    stats = {
        "Original SR": f"{sr} Hz",
        "Model input SR": f"{codec.sample_rate} Hz",
        "Model output SR": f"{codec.out_sample_rate} Hz",
        "Original samples": f"{orig_len:,}",
        "Resampled samples": f"{len(audio):,}",
        "Latent shape": str(tuple(latent.shape)),
        "Latent dim": f"{latent.shape[-1]}",
        "Frames": f"{latent.shape[0]}",
        "Hop length": f"{codec.hop_length} samples",
        "Approx duration": f"{latent.shape[0] * codec.hop_length / codec.sample_rate:.4f} s",
        "Latent min/max": f"{latent.min():.4f} / {latent.max():.4f}",
        "Latent mean/std": f"{latent.mean():.4f} / {latent.std():.4f}",
    }

    return latent.tolist(), latent.tolist(), fmt_stats(stats)


def decode_audio(latent_list, current_stats):
    """Gradio handler: decode the stored latent.

    Returns ``((out_sr, wav), stats_markdown)`` or ``(None, stats_markdown)``
    on error. Messages are appended to ``current_stats`` so the encode stats stay visible.
    """
    previous = current_stats or ""
    if latent_list is None:
        return None, previous + "\n\nNo latent found. Encode first."

    try:
        z = torch.tensor(latent_list, dtype=torch.float32, device=codec.device)
        if z.ndim == 2:
            z = z.unsqueeze(0)  # (B, T, D)
        # Validate here so malformed latents get a readable message instead of
        # an uncaught error from inside the decoder.
        if z.ndim != 3 or z.shape[1] == 0 or z.shape[-1] != codec.model.latent_dim:
            raise ValueError(
                f"expected shape (T, {codec.model.latent_dim}) or (B, T, {codec.model.latent_dim}) "
                f"with T > 0, got {tuple(z.shape)}"
            )
        z = z.transpose(1, 2).contiguous()  # (B, D, T)
    except Exception as e:
        return None, previous + f"\n\nInvalid latent: {e}"

    with torch.inference_mode():
        audio = codec.decode(z)

    # reshape(-1) instead of squeeze(): squeeze() turns a 1-sample output into a
    # 0-d array, on which len() fails. Assumes a single decoded item (B == 1).
    wav = audio.reshape(-1).detach().cpu().numpy()
    wav = np.nan_to_num(wav)
    wav = np.clip(wav, -1.0, 1.0)

    # The decoder emits audio at out_sample_rate (48 kHz by default), not at
    # the 16 kHz input rate: it upsamples each latent frame by prod(decoder_rates).
    out_sr = codec.out_sample_rate

    stats = {
        "Decoded samples": f"{len(wav):,}",
        "Output SR": f"{out_sr} Hz",
        "Duration": f"{len(wav) / out_sr:.4f} s",
        "Wave min/max": f"{wav.min():.4f} / {wav.max():.4f}",
    }

    merged = previous + "\n\n### Decode Stats\n" + fmt_stats(stats)
    # Gradio's numpy audio format is (sample_rate, samples).
    return (out_sr, wav), merged


# =========================================================
# UI
# =========================================================

CSS = """
body, .gradio-container { background: #0d0d0d !important; color: #eaeaea !important; }
h1, h2, h3 { color: #00e5a0 !important; }
.gr-button { background: #00e5a0 !important; color: #000 !important; font-weight: 700 !important; border: none !important; }
.gr-box, .gr-panel { background: #151515 !important; border: 1px solid #2a2a2a !important; }
code { background: #1e1e1e; padding: 2px 6px; border-radius: 2px; }
"""

with gr.Blocks(css=CSS, title="AudioVAE Encode / Decode") as demo:
    gr.Markdown(
        f"""
# AudioVAE Encode / Decode

Standalone one-file app for `audiovae.pth`.

**Repo:** `{REPO_ID}`

**Model input SR:** `{codec.sample_rate} Hz`

**Model output SR:** `{codec.out_sample_rate} Hz`

**Hop length:** `{codec.hop_length}`
"""
    )

    latent_state = gr.State()

    with gr.Row():
        audio_in = gr.Audio(type="filepath", label="Input Audio")
        audio_out = gr.Audio(label="Reconstructed Audio", interactive=False)

    with gr.Row():
        encode_btn = gr.Button("Encode")
        decode_btn = gr.Button("Decode")

    stats_out = gr.Markdown(value="Upload an audio file and press Encode.")

    with gr.Accordion("Raw latent preview", open=False):
        latent_preview = gr.JSON(label="Latent JSON")

    encode_btn.click(
        fn=encode_audio,
        inputs=audio_in,
        outputs=[latent_state, latent_preview, stats_out],
    )

    decode_btn.click(
        fn=decode_audio,
        inputs=[latent_state, stats_out],
        outputs=[audio_out, stats_out],
    )


if __name__ == "__main__":
    demo.launch()