# cfm_svc/models/codec_wrapper.py
# Hector Li
# Initial commit for Hugging Face
# df93d13
import os
import torch
import torch.nn as nn
import dac
import torch.nn.functional as F
# The project root is the parent of the directory that holds this file
# (i.e. <root>/models/codec_wrapper.py -> <root>).
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(_MODULE_DIR)
class CodecWrapper(nn.Module):
    """Frozen pretrained audio codec plus a trainable latent projector.

    The codec is loaded through ``dac.utils.load_model``, switched to eval
    mode and has ``requires_grad`` disabled on all of its parameters. The
    projector is a small ``Conv1d -> GELU -> Conv1d`` stack that maps a
    ``(B, latent_dim, T)`` tensor to a tensor of the same shape.
    """

    # Sample rate (Hz) that audio is converted to before being encoded.
    _CODEC_SAMPLE_RATE = 44100

    def __init__(self, backend="dac", latent_dim=1024):
        """Load the codec weights, freeze them, and build the projector.

        Args:
            backend: Codec implementation to load. Only ``"dac"`` is
                supported.
            latent_dim: Channel count of the projector's input and output.

        Raises:
            ValueError: If ``backend`` is not a supported codec.
        """
        super().__init__()
        # Fail fast with a clear message; without this check an unknown
        # backend would leave ``self.codec`` unset and surface later as an
        # opaque AttributeError.
        if backend != "dac":
            raise ValueError(
                f"Unsupported codec backend {backend!r}; only 'dac' is available."
            )
        self.latent_dim = latent_dim

        print(f"Loading actual {backend.upper()} codec model...")
        # Prefer local weights for reproducibility and portability.
        local_path = os.path.join(_PROJECT_ROOT, "codec_pretrain", "dac_44khz.pth")
        if os.path.isfile(local_path):
            print(f" Using local DAC weights: {local_path}")
            self.codec = dac.utils.load_model(load_path=local_path)
        else:
            print(f" Local weights not found at {local_path}, downloading via dac library...")
            self.codec = dac.utils.load_model(tag="latest", model_type="44khz")

        # Freeze all codec parameters.
        self.codec.eval()
        for param in self.codec.parameters():
            param.requires_grad = False

        # The Projector Network P(u) -> z_hat
        self.projector = nn.Sequential(
            nn.Conv1d(latent_dim, latent_dim * 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(latent_dim * 2, latent_dim, kernel_size=3, padding=1),
        )

    def train(self, mode: bool = True):
        """Set train/eval mode on the wrapper but keep the codec in eval mode.

        ``nn.Module.train`` recurses into every child module, which would
        otherwise put the frozen codec back into training mode whenever the
        wrapper (or a parent model) is switched to training.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module, matching ``nn.Module.train``.
        """
        super().train(mode)
        codec = getattr(self, "codec", None)
        if codec is not None:
            codec.eval()
        return self

    def forward_project(self, u_hat):
        """Map a continuous flow prediction back to the codebook manifold.

        Args:
            u_hat: Tensor of shape ``(B, D, T)`` with ``D == latent_dim``.

        Returns:
            The projector output, a tensor of shape ``(B, D, T)``.
        """
        return self.projector(u_hat)

    @torch.no_grad()
    def decode(self, z_hat):
        """Decode projected latents into a waveform using the frozen codec.

        Runs without gradient tracking. Warnings raised while the codec
        decodes are suppressed for the duration of this call only; the
        process-wide warning filters are left untouched.

        Args:
            z_hat: Latent tensor of shape ``(B, D, T)``.

        Returns:
            Whatever ``self.codec.decode`` returns for ``z_hat``.
        """
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return self.codec.decode(z_hat)

    @torch.no_grad()
    def encode(self, wav, sample_rate):
        """Encode a waveform into the codec latent space.

        Runs without gradient tracking. Audio whose ``sample_rate`` differs
        from 44100 Hz is first resized along the time axis with linear
        interpolation.

        Args:
            wav: Waveform tensor of shape ``(B, 1, T)`` or ``(B, T)``.
            sample_rate: Sample rate of ``wav`` in Hz; must be positive.

        Returns:
            The latent ``z`` (first element of the codec's encode output),
            documented by the original code as shape ``(B, D, T_latent)``.

        Raises:
            ValueError: If ``wav`` is not 2D/3D or ``sample_rate`` is not
                positive.
        """
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        if wav.ndim != 3:
            raise ValueError(f"Expected wav to be 2D or 3D tensor, got shape {tuple(wav.shape)}")
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        target_rate = self._CODEC_SAMPLE_RATE
        # DAC 44k expects 44.1k input
        if sample_rate != target_rate:
            target_len = int(round(wav.shape[-1] * target_rate / sample_rate))
            wav = F.interpolate(wav, size=target_len, mode="linear", align_corners=False)

        wav = self.codec.preprocess(wav, target_rate)
        # The codec returns a tuple; only its first element (z) is used here.
        return self.codec.encode(wav)[0]