import os
import warnings

import dac
import torch
import torch.nn as nn
import torch.nn.functional as F
| |
|
| | |
| | _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| |
|
class CodecWrapper(nn.Module):
    """Frozen pretrained audio codec plus a trainable projector head.

    Wraps a pretrained DAC (Descript Audio Codec) model with its weights
    frozen, and adds a small trainable Conv1d projector that maps
    continuous flow predictions back toward the codec's latent manifold.

    Args:
        backend: Codec backend identifier. Only ``"dac"`` is supported.
        latent_dim: Channel dimension ``D`` of the codec latent space.

    Raises:
        ValueError: If ``backend`` is not a supported codec backend.
    """

    def __init__(self, backend="dac", latent_dim=1024):
        super().__init__()
        self.latent_dim = latent_dim

        print(f"Loading actual {backend.upper()} codec model...")
        if backend == "dac":
            # Prefer locally cached weights to avoid a network download.
            local_path = os.path.join(_PROJECT_ROOT, "codec_pretrain", "dac_44khz.pth")
            if os.path.isfile(local_path):
                print(f" Using local DAC weights: {local_path}")
                self.codec = dac.utils.load_model(load_path=local_path)
            else:
                print(f" Local weights not found at {local_path}, downloading via dac library...")
                self.codec = dac.utils.load_model(tag="latest", model_type="44khz")
        else:
            # BUG FIX: previously an unknown backend left self.codec unset,
            # which surfaced later as a confusing AttributeError. Fail fast.
            raise ValueError(f"Unsupported codec backend: {backend!r} (only 'dac' is supported)")

        # Freeze the codec: eval mode and no gradients, so only the
        # projector below is trainable.
        if self.codec is not None:
            self.codec.eval()
            for p in self.codec.parameters():
                p.requires_grad = False

        # Trainable head mapping (B, D, T) -> (B, D, T); widens to 2*D
        # internally with a GELU nonlinearity.
        self.projector = nn.Sequential(
            nn.Conv1d(latent_dim, latent_dim * 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(latent_dim * 2, latent_dim, kernel_size=3, padding=1)
        )

    def forward_project(self, u_hat):
        """
        Maps continuous flow prediction back to codebook manifold.
        u_hat: (B, D, T)
        """
        return self.projector(u_hat)

    @torch.no_grad()
    def decode(self, z_hat):
        """
        Decode projected latents into waveform using the frozen codec.
        z_hat: (B, D, T) -> dac takes (B, D, T) discrete mapping.
        """
        # BUG FIX: the previous warnings.filterwarnings("ignore") mutated
        # the *global* warning filter on every call. Scope the suppression
        # to this decode only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return self.codec.decode(z_hat)

    @torch.no_grad()
    def encode(self, wav, sample_rate):
        """
        Encode waveform to codec latent space.
        wav: (B, 1, T) or (B, T)
        Returns z with shape (B, D, T_latent)
        """
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        if wav.ndim != 3:
            raise ValueError(f"Expected wav to be 2D or 3D tensor, got shape {tuple(wav.shape)}")

        # DAC's 44khz model expects 44.1 kHz input; rescale the time axis
        # if the caller's rate differs.
        # NOTE(review): linear interpolation is not band-limited resampling
        # (it aliases); consider torchaudio.functional.resample — confirm
        # downstream quality requirements before changing.
        if sample_rate != 44100:
            target_len = int(round(wav.shape[-1] * 44100 / sample_rate))
            wav = F.interpolate(wav, size=target_len, mode="linear", align_corners=False)
        wav = self.codec.preprocess(wav, 44100)
        # dac encode returns a 5-tuple; only the continuous latents z are used.
        z, _, _, _, _ = self.codec.encode(wav)
        return z
| |
|