File size: 2,799 Bytes
df93d13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
import torch.nn as nn
import dac
import torch.nn.functional as F

# Resolve project root relative to this file (models/codec_wrapper.py -> project root)
# NOTE(review): assumes this file lives exactly one directory below the project
# root (e.g. <root>/models/) — confirm if the file is ever moved.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

class CodecWrapper(nn.Module):
    """Frozen neural audio codec (DAC) plus a trainable projector network.

    The codec maps waveforms to/from a continuous latent space; the
    projector maps continuous flow-model predictions back onto the codec's
    latent manifold. Only the projector holds trainable parameters — the
    codec is loaded once and frozen.
    """

    def __init__(self, backend="dac", latent_dim=1024):
        """
        Args:
            backend: codec backend identifier; only "dac" is supported.
            latent_dim: channel dimension D of the codec latent space.

        Raises:
            ValueError: if ``backend`` is not a supported codec.
        """
        super().__init__()
        self.latent_dim = latent_dim

        if backend != "dac":
            # The original code silently skipped loading for unknown backends
            # and later crashed with AttributeError on self.codec; fail fast
            # with a clear message instead.
            raise ValueError(f"Unsupported codec backend: {backend!r}")

        print(f"Loading actual {backend.upper()} codec model...")
        # Prefer local weights for reproducibility and portability.
        local_path = os.path.join(_PROJECT_ROOT, "codec_pretrain", "dac_44khz.pth")
        if os.path.isfile(local_path):
            print(f"  Using local DAC weights: {local_path}")
            self.codec = dac.utils.load_model(load_path=local_path)
        else:
            print(f"  Local weights not found at {local_path}, downloading via dac library...")
            self.codec = dac.utils.load_model(tag="latest", model_type="44khz")

        # Freeze all codec parameters: the codec is a fixed encoder/vocoder,
        # so gradients flow only through the projector below.
        self.codec.eval()
        for p in self.codec.parameters():
            p.requires_grad = False

        # The Projector Network P(u) -> z_hat.
        self.projector = nn.Sequential(
            nn.Conv1d(latent_dim, latent_dim * 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(latent_dim * 2, latent_dim, kernel_size=3, padding=1)
        )

    def forward_project(self, u_hat):
        """Map a continuous flow prediction back to the codebook manifold.

        Args:
            u_hat: latent tensor of shape (B, D, T).

        Returns:
            Projected latents z_hat of shape (B, D, T).
        """
        return self.projector(u_hat)

    @torch.no_grad()
    def decode(self, z_hat):
        """Decode projected latents into a waveform using the frozen codec.

        Args:
            z_hat: latent tensor of shape (B, D, T), as consumed by
                ``dac``'s decoder.

        Returns:
            Waveform tensor produced by the codec decoder.
        """
        import warnings
        # Scope the suppression to this call instead of mutating the
        # process-global warning filters (the original filterwarnings("ignore")
        # leaked into the whole program).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return self.codec.decode(z_hat)

    @torch.no_grad()
    def encode(self, wav, sample_rate):
        """Encode a waveform into the codec latent space.

        Args:
            wav: audio tensor of shape (B, 1, T) or (B, T).
            sample_rate: sampling rate of ``wav`` in Hz.

        Returns:
            Latent tensor z of shape (B, D, T_latent).

        Raises:
            ValueError: if ``wav`` is not a 2-D or 3-D tensor.
        """
        if wav.ndim == 2:
            wav = wav.unsqueeze(1)
        if wav.ndim != 3:
            raise ValueError(f"Expected wav to be 2D or 3D tensor, got shape {tuple(wav.shape)}")

        # DAC 44k expects 44.1 kHz input. NOTE(review): linear interpolation
        # is a crude resampler (no anti-aliasing filter); a proper resampler
        # would be higher fidelity — behavior kept unchanged here.
        if sample_rate != 44100:
            target_len = int(round(wav.shape[-1] * 44100 / sample_rate))
            wav = F.interpolate(wav, size=target_len, mode="linear", align_corners=False)
        wav = self.codec.preprocess(wav, 44100)
        z, _, _, _, _ = self.codec.encode(wav)
        return z