import os
import warnings

import dac
import torch
import torch.nn as nn
import torch.nn.functional as F

# Resolve project root relative to this file (models/codec_wrapper.py -> project root)
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class CodecWrapper(nn.Module):
    """Frozen pretrained audio codec plus a trainable latent projector.

    The codec is loaded once, switched to eval mode and has every parameter
    frozen; only ``self.projector`` carries trainable weights.
    """

    # Sample rate fed to ``codec.preprocess`` (the 44 kHz DAC checkpoint).
    _CODEC_SAMPLE_RATE = 44100
    # Backends this wrapper knows how to load.
    _SUPPORTED_BACKENDS = ("dac",)

    def __init__(self, backend: str = "dac", latent_dim: int = 1024):
        """Load the codec weights and build the projector.

        Args:
            backend: Codec implementation to load. Only ``"dac"`` is supported.
            latent_dim: Channel count ``D`` of the latents handled by the
                projector (input and output width).

        Raises:
            ValueError: If ``backend`` is not a supported codec backend.
        """
        super().__init__()
        # Fail early with a clear message: previously an unknown backend left
        # ``self.codec`` unset and crashed below with an opaque AttributeError.
        if backend not in self._SUPPORTED_BACKENDS:
            raise ValueError(
                f"Unsupported codec backend {backend!r}; "
                f"expected one of {self._SUPPORTED_BACKENDS}"
            )

        self.latent_dim = latent_dim

        print(f"Loading actual {backend.upper()} codec model...")
        # Prefer local weights for reproducibility and portability
        local_path = os.path.join(_PROJECT_ROOT, "codec_pretrain", "dac_44khz.pth")
        if os.path.isfile(local_path):
            print(f"  Using local DAC weights: {local_path}")
            self.codec = dac.utils.load_model(load_path=local_path)
        else:
            print(f"  Local weights not found at {local_path}, downloading via dac library...")
            self.codec = dac.utils.load_model(tag="latest", model_type="44khz")

        # Freeze all codec parameters
        self.codec.eval()
        for p in self.codec.parameters():
            p.requires_grad = False

        # The Projector Network P(u) -> z_hat
        # kernel_size=3 with padding=1 keeps the time dimension unchanged.
        self.projector = nn.Sequential(
            nn.Conv1d(latent_dim, latent_dim * 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(latent_dim * 2, latent_dim, kernel_size=3, padding=1),
        )

    def train(self, mode: bool = True):
        """Set the training mode of the wrapper while keeping the codec in eval.

        ``nn.Module.train`` recurses into every child, which would otherwise
        put the frozen codec back into training mode whenever the wrapper (or
        a parent model) calls ``.train()``.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        codec = getattr(self, "codec", None)
        if codec is not None:
            codec.eval()
        return self

    def forward_project(self, u_hat: torch.Tensor) -> torch.Tensor:
        """Map a continuous flow prediction back towards the codebook manifold.

        Args:
            u_hat: Tensor of shape ``(B, D, T)`` with ``D == latent_dim``.

        Returns:
            Projected latents ``z_hat`` with the same shape as ``u_hat``.
        """
        return self.projector(u_hat)

    @torch.no_grad()
    def decode(self, z_hat: torch.Tensor) -> torch.Tensor:
        """Decode projected latents into a waveform using the frozen codec.

        Args:
            z_hat: Latents of shape ``(B, D, T)``.

        Returns:
            Whatever ``self.codec.decode`` returns for ``z_hat``.
        """
        # Silence warnings only for the duration of the codec call; a bare
        # ``warnings.filterwarnings("ignore")`` would mute every later warning
        # in the whole process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return self.codec.decode(z_hat)

    @torch.no_grad()
    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Encode a waveform into the codec latent space.

        Args:
            wav: Audio of shape ``(B, 1, T)`` or ``(B, T)``.
            sample_rate: Sample rate of ``wav`` in Hz; must be positive.

        Returns:
            Latents ``z`` of shape ``(B, D, T_latent)``.

        Raises:
            ValueError: If ``wav`` is not 2D/3D or ``sample_rate`` is not
                positive.
        """
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        if wav.ndim != 3:
            raise ValueError(f"Expected wav to be 2D or 3D tensor, got shape {tuple(wav.shape)}")
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        # DAC 44k expects 44.1k input
        target_sr = self._CODEC_SAMPLE_RATE
        if sample_rate != target_sr:
            # NOTE(review): linear interpolation applies no anti-alias filter;
            # a proper resampler may be preferable when downsampling.
            target_len = int(round(wav.shape[-1] * target_sr / sample_rate))
            wav = F.interpolate(wav, size=target_len, mode="linear", align_corners=False)

        wav = self.codec.preprocess(wav, target_sr)
        # ``codec.encode`` returns a tuple whose first element is the latent z.
        return self.codec.encode(wav)[0]