| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.utils.parametrizations import weight_norm |
|
|
|
|
| |
|
|
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Snake activation: x + sin^2(alpha * x) / alpha, applied elementwise.

    `alpha` is broadcast against a (B, C, -1) view of `x`; the 1e-9 term
    guards against division by zero when alpha is 0. Output has the same
    shape as the input.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
|
|
|
|
class Snake1d(nn.Module):
    """Snake activation module with a learnable per-channel alpha.

    `alpha` is shaped (1, channels, 1) so it broadcasts over the batch
    and time axes of a (B, C, T) input; it is initialised to ones.
    """

    def __init__(self, channels):
        super().__init__()
        # One frequency parameter per channel, broadcast over batch/time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
|
|
|
|
| |
|
|
| def WNConv1d(*args, **kwargs): |
| return weight_norm(nn.Conv1d(*args, **kwargs)) |
|
|
| def WNConvTranspose1d(*args, **kwargs): |
| return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) |
|
|
|
|
class VQ(nn.Module):
    """Single vector-quantisation stage with cosine-similarity codebook lookup.

    Inputs are projected to a low-dimensional space (`codebook_dim`) before
    lookup and projected back afterwards. Both the projected encoding and
    the codebook entries are L2-normalised, so nearest-neighbour search is
    by cosine similarity and quantisation happens on the unit sphere.
    """

    def __init__(self, latent_ch, K=1024, codebook_dim=8):
        super().__init__()
        # Factorised codebook: search in a small space, decode back to latent_ch.
        self.in_proj = nn.Linear(latent_ch, codebook_dim, bias=False)
        self.out_proj = nn.Linear(codebook_dim, latent_ch, bias=False)
        self.codebook = nn.Embedding(K, codebook_dim)

    def forward(self, z: torch.Tensor):
        """Quantise `z` (..., latent_ch).

        Returns:
            z_q_out: quantised latents, same shape as `z`, with a
                straight-through gradient path back to the encoder.
            indices: chosen codebook indices, shape (...,).
            commitment_loss: pulls the encoder output toward its code.
            codebook_loss: pulls the chosen code toward the encoder output.
        """
        z_e = self.in_proj(z)

        # Unit-normalise both sides so the dot product is cosine similarity.
        z_e_norm = F.normalize(z_e, dim=-1)
        cb_norm = F.normalize(self.codebook.weight, dim=-1)

        sim = z_e_norm @ cb_norm.t()

        # argmax over the codebook axis; dim=-1 (rather than 1) keeps this
        # correct for inputs with extra leading dimensions.
        indices = sim.argmax(dim=-1)

        z_q_norm = cb_norm[indices]

        # Standard VQ-VAE losses; detach stops gradients on the target side.
        commitment_loss = F.mse_loss(z_e_norm, z_q_norm.detach())
        codebook_loss = F.mse_loss(z_e_norm.detach(), z_q_norm)

        # Straight-through estimator: forward uses z_q_norm, backward
        # passes gradients to z_e_norm unchanged.
        z_q_st = z_e_norm + (z_q_norm - z_e_norm).detach()

        z_q_out = self.out_proj(z_q_st)

        return z_q_out, indices, commitment_loss, codebook_loss
|
|
|
|
class RVQ(nn.Module):
    """Residual vector quantiser: a cascade of VQ stages.

    Each stage quantises the residual left over by the previous stages;
    the final quantised latent is the sum of all stage outputs.
    """

    def __init__(self, num_levels, latent_ch, K=1024, codebook_dim=8):
        super().__init__()
        self.num_levels = num_levels
        self.levels = nn.ModuleList(
            VQ(latent_ch, K=K, codebook_dim=codebook_dim) for _ in range(num_levels)
        )

    def forward(self, z):
        residual = z
        quantised = torch.zeros_like(z)
        codes = []
        commit_total = 0
        book_total = 0

        for stage in self.levels:
            q, idx, c_loss, b_loss = stage(residual)
            # Detach so each stage's target is fixed w.r.t. earlier stages.
            residual = residual - q.detach()
            quantised = quantised + q
            codes.append(idx)
            commit_total = commit_total + c_loss
            book_total = book_total + b_loss

        return quantised, codes, commit_total, book_total
|
|
|
|
class ResidualUnit(nn.Module):
    """Dilated residual block: Snake -> 7x1 dilated conv -> Snake -> 1x1 conv, plus skip."""

    def __init__(self, ch, dilation=1):
        super().__init__()
        # padding = 3 * dilation keeps the time length unchanged for kernel 7.
        same_pad = 3 * dilation
        self.block = nn.Sequential(
            Snake1d(ch),
            WNConv1d(ch, ch, kernel_size=7, dilation=dilation, padding=same_pad),
            Snake1d(ch),
            WNConv1d(ch, ch, kernel_size=1),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual
|
|
|
|
class EncoderBlock(nn.Module):
    """Three dilated residual units followed by a strided downsampling conv."""

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        # Increasing dilations (1, 3, 9) widen the receptive field.
        self.res1 = ResidualUnit(in_ch, dilation=1)
        self.res2 = ResidualUnit(in_ch, dilation=3)
        self.res3 = ResidualUnit(in_ch, dilation=9)
        self.downsample = nn.Sequential(
            Snake1d(in_ch),
            WNConv1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2),
        )

    def forward(self, x):
        for stage in (self.res1, self.res2, self.res3, self.downsample):
            x = stage(x)
        return x
|
|
|
|
class DecoderBlock(nn.Module):
    """Strided transposed-conv upsampling followed by three dilated residual units."""

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.upsample = nn.Sequential(
            Snake1d(in_ch),
            WNConvTranspose1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2),
        )
        # Mirror of EncoderBlock: dilations 1, 3, 9 after upsampling.
        self.res1 = ResidualUnit(out_ch, dilation=1)
        self.res2 = ResidualUnit(out_ch, dilation=3)
        self.res3 = ResidualUnit(out_ch, dilation=9)

    def forward(self, x):
        for stage in (self.upsample, self.res1, self.res2, self.res3):
            x = stage(x)
        return x
|
|
|
|
class RVQCodec(nn.Module):
    """Convolutional audio codec with residual vector quantisation.

    The encoder downsamples the waveform by a factor of 128 (strides
    2*4*4*4) into `latent_ch` channels, the latents are quantised by an
    RVQ stack, and a mirrored decoder reconstructs the waveform (Tanh
    keeps the output in [-1, 1]).
    """

    def __init__(self, in_ch=1, latent_ch=32, K=1024, num_rvq_levels=1, codebook_dim=8):
        super().__init__()

        self.encoder = nn.Sequential(
            WNConv1d(in_ch, 64, kernel_size=7, padding=3),
            EncoderBlock(64, 128, stride=2),
            EncoderBlock(128, 256, stride=4),
            EncoderBlock(256, 512, stride=4),
            EncoderBlock(512, 512, stride=4),
            Snake1d(512),
            WNConv1d(512, latent_ch, kernel_size=3, padding=1),
        )

        self.decoder = nn.Sequential(
            WNConv1d(latent_ch, 512, kernel_size=7, padding=3),
            DecoderBlock(512, 512, stride=4),
            DecoderBlock(512, 256, stride=4),
            DecoderBlock(256, 128, stride=4),
            DecoderBlock(128, 64, stride=2),
            Snake1d(64),
            WNConv1d(64, in_ch, kernel_size=7, padding=3),
            nn.Tanh(),
        )
        self.rvq = RVQ(num_levels=num_rvq_levels, latent_ch=latent_ch, K=K, codebook_dim=codebook_dim)

    def forward(self, x: torch.Tensor):
        """Encode, quantise, and decode `x` of shape (B, in_ch, T).

        Returns (x_recon, all_indices, commitment_loss, codebook_loss).
        T should be divisible by 128 so the decoder output matches the
        input length — TODO confirm against callers.
        """
        z = self.encoder(x)  # (B, latent_ch, T // 128)

        # The RVQ operates on individual vectors, so flatten
        # (B, C, T') -> (B*T', C).
        b, c, t = z.shape
        z_flat = z.permute(0, 2, 1).contiguous().view(b * t, c)

        z_q, all_indices, commitment_loss, codebook_loss = self.rvq(z_flat)

        # Restore (B, C, T') layout for the convolutional decoder.
        z_q = z_q.view(b, t, c).permute(0, 2, 1)

        x_recon = self.decoder(z_q)

        return x_recon, all_indices, commitment_loss, codebook_loss
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run a random waveform through the codec.
    # Fall back to CPU so this also runs on machines without CUDA
    # (the original hard-coded "cuda" and crashed on CPU-only hosts).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 1, 8192)  # (batch, channels, samples)

    model = RVQCodec()
    print(model)
    print(f"params: {sum(p.numel() for p in model.parameters()):,}")

    x = x.to(device)
    model = model.to(device)

    out, _, _, _ = model(x)
    print(f"in: {x.shape} → out: {out.shape}")
|
|