import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm
# Snake activation
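# snake(x) = x + (1/alpha) * sin^2(alpha * x), a periodic activation with a learnable
# per-channel frequency alpha (Ziyin et al. 2020); it is the activation used in
# BigVGAN/DAC-style audio codecs because it helps model periodic waveform structure.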
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
shape = x.shape # [B, C, T]
x = x.reshape(shape[0], shape[1], -1) # [B, C, T]
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape) # [B, C, T]
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1)) # [1, C, 1] one for each channel
def forward(self, x):
return snake(x, self.alpha)
# Weight-normalized convolutions
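# weight_norm re-parameterises each kernel as w = g * v / ||v||, decoupling magnitude
# from direction; a common training stabiliser for convolutional audio codecs/vocoders.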
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class VQ(nn.Module):
def __init__(self, latent_ch, K=1024, codebook_dim=8):
super().__init__()
self.in_proj = nn.Linear(latent_ch, codebook_dim, bias=False)
self.out_proj = nn.Linear(codebook_dim, latent_ch, bias=False)
self.codebook = nn.Embedding(K, codebook_dim)
    def forward(self, z: torch.Tensor):
# z -> [N, C] 2d tensor flattened
# project to low-dim codebook space
z_e = self.in_proj(z) # [N, codebook_dim]
# L2 normalise for cosine similarity matching
z_e_norm = F.normalize(z_e, dim=-1) # [N, codebook_dim]
cb_norm = F.normalize(self.codebook.weight, dim=-1) # [K, codebook_dim]
# euclidean distance between two unit vectors ~ cosine similarity
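        # (for unit vectors a, b: ||a - b||^2 = 2 - 2 * a.b, so the code with the highest
        #  cosine similarity is also the nearest neighbour in euclidean distance)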
sim = z_e_norm @ cb_norm.t() # [N, K]
# nearest codebook entry = highest similarity
        indices = sim.argmax(dim=1) # [N]
# lookup normalised codebook entry
z_q_norm = cb_norm[indices] # [N, codebook_dim]
# losses on normalised vectors
commitment_loss = F.mse_loss(z_e_norm, z_q_norm.detach()) # push encoder direction → codebook
codebook_loss = F.mse_loss(z_e_norm.detach(), z_q_norm) # push codebook → encoder direction
        # straight-through estimator in normalised space: forward pass uses z_q_norm,
        # gradients flow back to z_e_norm unchanged
z_q_st = z_e_norm + (z_q_norm - z_e_norm).detach()
# project back to full latent space
z_q_out = self.out_proj(z_q_st) # [N, latent_ch]
return z_q_out, indices, commitment_loss, codebook_loss
class RVQ(nn.Module):
def __init__(self, num_levels, latent_ch, K=1024, codebook_dim=8):
super().__init__()
self.num_levels = num_levels
self.levels = nn.ModuleList([
VQ(latent_ch, K=K, codebook_dim=codebook_dim) for _ in range(num_levels)
])
def forward(self, z):
        # z -> [N, C] flattened 2-D tensor
        r = z # initialise the residual with z for the first level
quantised_sum = torch.zeros_like(z)
all_indices = []
total_commitment_loss = 0
total_codebook_loss = 0
for level in self.levels:
z_q, indices, commitment_loss, codebook_loss = level(r)
r = r - z_q.detach() # next level quantizes the error
quantised_sum = quantised_sum + z_q # accumulate: z ≈ q1 + q2 + q3 + ...
all_indices.append(indices)
total_commitment_loss = total_commitment_loss + commitment_loss
total_codebook_loss = total_codebook_loss + codebook_loss
return quantised_sum, all_indices, total_commitment_loss, total_codebook_loss
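# Each RVQ level contributes log2(K) bits per latent frame (K=1024 -> 10 bits/level),
# so with the 128x-downsampling codec below the bitrate is roughly
# num_levels * 10 * sample_rate / 128 bits per second.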
class ResidualUnit(nn.Module):
def __init__(self, ch, dilation=1):
super().__init__()
self.block = nn.Sequential(
Snake1d(ch), # [B, C, T]
            WNConv1d(ch, ch, kernel_size=7, dilation=dilation, padding=3 * dilation), # [B, C, T] k=7, padding=3*dilation keeps the length unchanged
Snake1d(ch), # [B, C, T]
WNConv1d(ch, ch, kernel_size=1), # [B, C, T]
)
def forward(self, x):
return x + self.block(x)
class EncoderBlock(nn.Module):
def __init__(self, in_ch, out_ch, stride):
super().__init__()
self.res1 = ResidualUnit(in_ch, dilation=1)
self.res2 = ResidualUnit(in_ch, dilation=3)
self.res3 = ResidualUnit(in_ch, dilation=9)
self.downsample = nn.Sequential(
Snake1d(in_ch),
            WNConv1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2), # downsamples the length by exactly stride
)
def forward(self, x):
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
x = self.downsample(x)
return x
class DecoderBlock(nn.Module):
def __init__(self, in_ch, out_ch, stride):
super().__init__()
self.upsample = nn.Sequential(
Snake1d(in_ch),
            WNConvTranspose1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2), # upsamples the length by exactly stride
)
self.res1 = ResidualUnit(out_ch, dilation=1)
self.res2 = ResidualUnit(out_ch, dilation=3)
self.res3 = ResidualUnit(out_ch, dilation=9)
def forward(self, x):
x = self.upsample(x)
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
return x
class RVQCodec(nn.Module):
def __init__(self, in_ch=1, latent_ch=32, K=1024, num_rvq_levels=1, codebook_dim=8):
super().__init__()
# Encoder - [B, 1, T] → [B, D, T/128]
# strides - 2 × 4 × 4 × 4 = 128x downsample
self.encoder = nn.Sequential(
WNConv1d(in_ch, 64, kernel_size=7, padding=3), # [B, 64, T]
EncoderBlock(64, 128, stride=2), # [B, 128, T/2]
EncoderBlock(128, 256, stride=4), # [B, 256, T/8]
EncoderBlock(256, 512, stride=4), # [B, 512, T/32]
EncoderBlock(512, 512, stride=4), # [B, 512, T/128]
Snake1d(512),
WNConv1d(512, latent_ch, kernel_size=3, padding=1), # [B, D, T/128]
)
# Decoder - [B, D, T/128] → [B, 1, T]
# strides - 4 × 4 × 4 × 2 = 128x upsample
self.decoder = nn.Sequential(
WNConv1d(latent_ch, 512, kernel_size=7, padding=3), # [B, 512, T/128]
DecoderBlock(512, 512, stride=4), # [B, 512, T/32]
DecoderBlock(512, 256, stride=4), # [B, 256, T/8]
DecoderBlock(256, 128, stride=4), # [B, 128, T/2]
DecoderBlock(128, 64, stride=2), # [B, 64, T]
Snake1d(64),
WNConv1d(64, in_ch, kernel_size=7, padding=3), # [B, 1, T]
            nn.Tanh(), # bound the reconstructed waveform to [-1, 1]
)
self.rvq = RVQ(num_levels=num_rvq_levels, latent_ch=latent_ch, K=K, codebook_dim=codebook_dim)
    def forward(self, x: torch.Tensor):
# x -> [B, C=1, T]
z = self.encoder(x) # [B, D, T/128]
# flatten to 2d vector for applying rvq on channel dim
B, C, T_128 = z.shape
z_flat = z.permute(0, 2, 1).contiguous().view(B * T_128, C)
# vector quantize
z_q, all_indices, commitment_loss, codebook_loss = self.rvq(z_flat)
# reshape back
z_q = z_q.view(B, T_128, C).permute(0, 2, 1) # [B, C, T_128]
x_recon = self.decoder(z_q) # [B, C=1, T]
return x_recon, all_indices, commitment_loss, codebook_loss
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 1, 8192)
model = RVQCodec()
print(model)
print(f"params: {sum(p.numel() for p in model.parameters()):,}")
x = x.to(device)
model = model.to(device)
out, _, _, _ = model(x)
print(f"in: {x.shape} → out: {out.shape}")