import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm
# Snake activation
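# snake(x) = x + (1/alpha) * sin^2(alpha * x), the periodic activation of
# Ziyin et al. (2020) used in BigVGAN/DAC-style audio models; the 1e-9 term
# below guards against division by zero when alpha is 0.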
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
shape = x.shape # [B, C, T]
x = x.reshape(shape[0], shape[1], -1) # [B, C, T]
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape) # [B, C, T]
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1)) # [1, C, 1] one for each channel
def forward(self, x):
return snake(x, self.alpha)
# Weight-normalized convolutions
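# weight_norm reparameterises each weight as w = g * v / ||v||, decoupling
# magnitude from direction; standard practice in DAC/HiFi-GAN-style vocoders.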
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class VQ(nn.Module):
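    """A single vector-quantisation stage with factorised, L2-normalised codes.

    Latents are projected into a low-dimensional codebook space before lookup
    and projected back afterwards; both sides are unit-normalised so the
    nearest-neighbour search amounts to cosine similarity. This follows the
    codebook design popularised by Improved VQGAN and used in the Descript
    Audio Codec.
    """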
def __init__(self, latent_ch, K=1024, codebook_dim=8):
super().__init__()
self.in_proj = nn.Linear(latent_ch, codebook_dim, bias=False)
self.out_proj = nn.Linear(codebook_dim, latent_ch, bias=False)
self.codebook = nn.Embedding(K, codebook_dim)
    def forward(self, z: torch.Tensor):
        # z: [N, C] latents flattened across batch and time
# project to low-dim codebook space
z_e = self.in_proj(z) # [N, codebook_dim]
# L2 normalise for cosine similarity matching
z_e_norm = F.normalize(z_e, dim=-1) # [N, codebook_dim]
cb_norm = F.normalize(self.codebook.weight, dim=-1) # [K, codebook_dim]
        # for unit vectors, the nearest neighbour in Euclidean distance is also the most cosine-similar
sim = z_e_norm @ cb_norm.t() # [N, K]
# nearest codebook entry = highest similarity
        indices = sim.argmax(dim=1)  # [N]
# lookup normalised codebook entry
z_q_norm = cb_norm[indices] # [N, codebook_dim]
# losses on normalised vectors
commitment_loss = F.mse_loss(z_e_norm, z_q_norm.detach()) # push encoder direction → codebook
codebook_loss = F.mse_loss(z_e_norm.detach(), z_q_norm) # push codebook → encoder direction
        # straight-through estimator: forward pass uses z_q_norm, gradients flow to z_e_norm as identity
z_q_st = z_e_norm + (z_q_norm - z_e_norm).detach()
# project back to full latent space
z_q_out = self.out_proj(z_q_st) # [N, latent_ch]
return z_q_out, indices, commitment_loss, codebook_loss
class RVQ(nn.Module):
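    """Residual vector quantisation: a stack of VQ stages in which each stage
    quantises the residual error left by the previous one, so the sum of the
    stage outputs approximates the input latent increasingly well.
    """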
def __init__(self, num_levels, latent_ch, K=1024, codebook_dim=8):
super().__init__()
self.num_levels = num_levels
self.levels = nn.ModuleList([
VQ(latent_ch, K=K, codebook_dim=codebook_dim) for _ in range(num_levels)
])
    def forward(self, z):
        # z: [N, C] flattened latent vectors
        r = z  # initialise the residual with z for the first level
quantised_sum = torch.zeros_like(z)
all_indices = []
total_commitment_loss = 0
total_codebook_loss = 0
for level in self.levels:
z_q, indices, commitment_loss, codebook_loss = level(r)
r = r - z_q.detach() # next level quantizes the error
quantised_sum = quantised_sum + z_q # accumulate: z ≈ q1 + q2 + q3 + ...
all_indices.append(indices)
total_commitment_loss = total_commitment_loss + commitment_loss
total_codebook_loss = total_codebook_loss + codebook_loss
return quantised_sum, all_indices, total_commitment_loss, total_codebook_loss
class ResidualUnit(nn.Module):
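    """Pre-activation residual block: Snake -> dilated 7-tap conv -> Snake ->
    1x1 conv, with an identity skip connection.
    """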
def __init__(self, ch, dilation=1):
super().__init__()
self.block = nn.Sequential(
Snake1d(ch), # [B, C, T]
            WNConv1d(ch, ch, kernel_size=7, dilation=dilation, padding=3 * dilation),  # [B, C, T] kernel=7 with padding 3*dilation keeps the length unchanged
Snake1d(ch), # [B, C, T]
WNConv1d(ch, ch, kernel_size=1), # [B, C, T]
)
def forward(self, x):
return x + self.block(x)
class EncoderBlock(nn.Module):
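    """Three residual units with dilations 1, 3 and 9 (growing receptive
    field), followed by a strided convolution that downsamples time by `stride`.
    """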
def __init__(self, in_ch, out_ch, stride):
super().__init__()
self.res1 = ResidualUnit(in_ch, dilation=1)
self.res2 = ResidualUnit(in_ch, dilation=3)
self.res3 = ResidualUnit(in_ch, dilation=9)
self.downsample = nn.Sequential(
Snake1d(in_ch),
WNConv1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2),
)
def forward(self, x):
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
x = self.downsample(x)
return x
class DecoderBlock(nn.Module):
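    """Mirror of EncoderBlock: a transposed convolution upsamples time by
    `stride`, followed by three residual units with dilations 1, 3 and 9.
    """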
def __init__(self, in_ch, out_ch, stride):
super().__init__()
self.upsample = nn.Sequential(
Snake1d(in_ch),
WNConvTranspose1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2),
)
self.res1 = ResidualUnit(out_ch, dilation=1)
self.res2 = ResidualUnit(out_ch, dilation=3)
self.res3 = ResidualUnit(out_ch, dilation=9)
def forward(self, x):
x = self.upsample(x)
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
return x
class RVQCodec(nn.Module):
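    """Waveform codec: a 128x-downsampling convolutional encoder, an RVQ
    bottleneck over the latent channels, and a mirrored decoder ending in
    tanh so reconstructions lie in [-1, 1].
    """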
def __init__(self, in_ch=1, latent_ch=32, K=1024, num_rvq_levels=1, codebook_dim=8):
super().__init__()
# Encoder - [B, 1, T] → [B, D, T/128]
# strides - 2 × 4 × 4 × 4 = 128x downsample
self.encoder = nn.Sequential(
WNConv1d(in_ch, 64, kernel_size=7, padding=3), # [B, 64, T]
EncoderBlock(64, 128, stride=2), # [B, 128, T/2]
EncoderBlock(128, 256, stride=4), # [B, 256, T/8]
EncoderBlock(256, 512, stride=4), # [B, 512, T/32]
EncoderBlock(512, 512, stride=4), # [B, 512, T/128]
Snake1d(512),
WNConv1d(512, latent_ch, kernel_size=3, padding=1), # [B, D, T/128]
)
# Decoder - [B, D, T/128] → [B, 1, T]
# strides - 4 × 4 × 4 × 2 = 128x upsample
self.decoder = nn.Sequential(
WNConv1d(latent_ch, 512, kernel_size=7, padding=3), # [B, 512, T/128]
DecoderBlock(512, 512, stride=4), # [B, 512, T/32]
DecoderBlock(512, 256, stride=4), # [B, 256, T/8]
DecoderBlock(256, 128, stride=4), # [B, 128, T/2]
DecoderBlock(128, 64, stride=2), # [B, 64, T]
Snake1d(64),
WNConv1d(64, in_ch, kernel_size=7, padding=3), # [B, 1, T]
nn.Tanh(),
)
self.rvq = RVQ(num_levels=num_rvq_levels, latent_ch=latent_ch, K=K, codebook_dim=codebook_dim)
    def forward(self, x: torch.Tensor):
# x -> [B, C=1, T]
z = self.encoder(x) # [B, D, T/128]
        # flatten to 2D so RVQ quantises each timestep's channel vector
B, C, T_128 = z.shape
z_flat = z.permute(0, 2, 1).contiguous().view(B * T_128, C)
# vector quantize
z_q, all_indices, commitment_loss, codebook_loss = self.rvq(z_flat)
# reshape back
z_q = z_q.view(B, T_128, C).permute(0, 2, 1) # [B, C, T_128]
x_recon = self.decoder(z_q) # [B, C=1, T]
return x_recon, all_indices, commitment_loss, codebook_loss
if __name__ == "__main__":
device = "cuda"
x = torch.randn(1, 1, 8192)
model = RVQCodec()
print(model)
print(f"params: {sum(p.numel() for p in model.parameters()):,}")
x = x.to(device)
model = model.to(device)
out, _, _, _ = model(x)
print(f"in: {x.shape} → out: {out.shape}")