taresh18 committed on
Commit
db933ea
·
verified ·
1 Parent(s): c82ec83

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +213 -0
model.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.parametrizations import weight_norm
5
+
6
+
7
+ # Snake activation
8
+
9
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), element-wise.

    `alpha` broadcasts over batch and time (shape [1, C, 1]); a tiny epsilon
    keeps the reciprocal finite when alpha approaches zero.
    """
    orig_shape = x.shape  # [B, C, T]
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(orig_shape)
16
+
17
+
18
class Snake1d(nn.Module):
    """Snake activation module for [B, C, T] feature maps with a learnable
    per-channel alpha."""

    def __init__(self, channels):
        super().__init__()
        # One learnable alpha per channel, broadcast over batch and time.
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
25
+
26
+
27
+ # Weight-normalized convolutions
28
+
29
+ def WNConv1d(*args, **kwargs):
30
+ return weight_norm(nn.Conv1d(*args, **kwargs))
31
+
32
+ def WNConvTranspose1d(*args, **kwargs):
33
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
34
+
35
+
36
class VQ(nn.Module):
    """Single vector-quantization stage with a factorized, L2-normalized codebook.

    Inputs are projected down to ``codebook_dim``, matched against codebook
    entries by cosine similarity, and the selected (normalized) entry is
    projected back up to ``latent_ch``. Gradients flow through the
    quantization via the straight-through estimator.
    """

    def __init__(self, latent_ch, K=1024, codebook_dim=8):
        super().__init__()
        # Factorized codes: quantize in a small space, decode back to latent space.
        self.in_proj = nn.Linear(latent_ch, codebook_dim, bias=False)
        self.out_proj = nn.Linear(codebook_dim, latent_ch, bias=False)
        self.codebook = nn.Embedding(K, codebook_dim)

    # Fix: annotate with the type `torch.Tensor`, not the factory `torch.tensor`.
    def forward(self, z: torch.Tensor):
        """Quantize flattened latent vectors.

        Args:
            z: [N, latent_ch] flattened latents (one row per time step).

        Returns:
            Tuple of (z_q_out [N, latent_ch], indices [N],
            commitment_loss scalar, codebook_loss scalar).
        """
        # project to low-dim codebook space
        z_e = self.in_proj(z)  # [N, codebook_dim]

        # L2 normalise so a dot product equals cosine similarity
        z_e_norm = F.normalize(z_e, dim=-1)  # [N, codebook_dim]
        cb_norm = F.normalize(self.codebook.weight, dim=-1)  # [K, codebook_dim]

        # euclidean distance between two unit vectors ~ cosine similarity
        sim = z_e_norm @ cb_norm.t()  # [N, K]

        # nearest codebook entry = highest similarity
        indices = sim.max(dim=1)[1]  # [N]

        # lookup normalised codebook entry
        z_q_norm = cb_norm[indices]  # [N, codebook_dim]

        # VQ-VAE losses on the normalised vectors (detach picks which side trains)
        commitment_loss = F.mse_loss(z_e_norm, z_q_norm.detach())  # push encoder direction → codebook
        codebook_loss = F.mse_loss(z_e_norm.detach(), z_q_norm)  # push codebook → encoder direction

        # straight-through estimator in normalised space: forward uses z_q_norm,
        # backward passes gradients to z_e_norm unchanged
        z_q_st = z_e_norm + (z_q_norm - z_e_norm).detach()

        # project back to full latent space
        z_q_out = self.out_proj(z_q_st)  # [N, latent_ch]

        return z_q_out, indices, commitment_loss, codebook_loss
73
+
74
+
75
class RVQ(nn.Module):
    """Residual vector quantization: each stage quantizes the error left by
    the previous stage, so the input is approximated by the sum of all
    stage outputs."""

    def __init__(self, num_levels, latent_ch, K=1024, codebook_dim=8):
        super().__init__()
        self.num_levels = num_levels
        self.levels = nn.ModuleList(
            VQ(latent_ch, K=K, codebook_dim=codebook_dim) for _ in range(num_levels)
        )

    def forward(self, z):
        """Quantize z [N, C]; returns (quantized sum, per-level indices,
        summed commitment loss, summed codebook loss)."""
        residual = z  # first level quantizes z itself
        q_total = torch.zeros_like(z)
        idx_per_level = []
        commit_total = 0
        cb_total = 0

        for vq in self.levels:
            q, idx, commit, cb = vq(residual)
            residual = residual - q.detach()  # next level quantizes the remaining error
            q_total = q_total + q  # accumulate: z ≈ q1 + q2 + q3 + ...
            idx_per_level.append(idx)
            commit_total = commit_total + commit
            cb_total = cb_total + cb

        return q_total, idx_per_level, commit_total, cb_total
100
+
101
+
102
class ResidualUnit(nn.Module):
    """Dilated residual conv block; input and output are both [B, C, T]."""

    def __init__(self, ch, dilation=1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(ch),
            # kernel 7 with padding 3*dilation leaves the time dimension unchanged
            WNConv1d(ch, ch, kernel_size=7, dilation=dilation, padding=3 * dilation),
            Snake1d(ch),
            WNConv1d(ch, ch, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        return x + y
114
+
115
+
116
class EncoderBlock(nn.Module):
    """Three increasingly-dilated residual units followed by a strided
    downsampling convolution ([B, in_ch, T] → [B, out_ch, T/stride])."""

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.res1 = ResidualUnit(in_ch, dilation=1)
        self.res2 = ResidualUnit(in_ch, dilation=3)
        self.res3 = ResidualUnit(in_ch, dilation=9)
        self.downsample = nn.Sequential(
            Snake1d(in_ch),
            WNConv1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2),
        )

    def forward(self, x):
        for stage in (self.res1, self.res2, self.res3, self.downsample):
            x = stage(x)
        return x
133
+
134
+
135
class DecoderBlock(nn.Module):
    """Strided transposed-conv upsampling followed by three increasingly-dilated
    residual units ([B, in_ch, T] → [B, out_ch, T*stride])."""

    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.upsample = nn.Sequential(
            Snake1d(in_ch),
            WNConvTranspose1d(in_ch, out_ch, kernel_size=2 * stride, stride=stride, padding=stride // 2),
        )
        self.res1 = ResidualUnit(out_ch, dilation=1)
        self.res2 = ResidualUnit(out_ch, dilation=3)
        self.res3 = ResidualUnit(out_ch, dilation=9)

    def forward(self, x):
        for stage in (self.upsample, self.res1, self.res2, self.res3):
            x = stage(x)
        return x
152
+
153
+
154
class RVQCodec(nn.Module):
    """Convolutional waveform codec: encoder → residual VQ bottleneck → decoder.

    The encoder downsamples the input 128x (strides 2·4·4·4) into ``latent_ch``
    channels; the decoder mirrors it back to the original length with a tanh
    output, so reconstructions lie in [-1, 1].
    """

    def __init__(self, in_ch=1, latent_ch=32, K=1024, num_rvq_levels=1, codebook_dim=8):
        super().__init__()
        # Encoder - [B, 1, T] → [B, D, T/128]
        # strides - 2 × 4 × 4 × 4 = 128x downsample
        self.encoder = nn.Sequential(
            WNConv1d(in_ch, 64, kernel_size=7, padding=3),  # [B, 64, T]
            EncoderBlock(64, 128, stride=2),  # [B, 128, T/2]
            EncoderBlock(128, 256, stride=4),  # [B, 256, T/8]
            EncoderBlock(256, 512, stride=4),  # [B, 512, T/32]
            EncoderBlock(512, 512, stride=4),  # [B, 512, T/128]
            Snake1d(512),
            WNConv1d(512, latent_ch, kernel_size=3, padding=1),  # [B, D, T/128]
        )
        # Decoder - [B, D, T/128] → [B, 1, T]
        # strides - 4 × 4 × 4 × 2 = 128x upsample (mirror of the encoder)
        self.decoder = nn.Sequential(
            WNConv1d(latent_ch, 512, kernel_size=7, padding=3),  # [B, 512, T/128]
            DecoderBlock(512, 512, stride=4),  # [B, 512, T/32]
            DecoderBlock(512, 256, stride=4),  # [B, 256, T/8]
            DecoderBlock(256, 128, stride=4),  # [B, 128, T/2]
            DecoderBlock(128, 64, stride=2),  # [B, 64, T]
            Snake1d(64),
            WNConv1d(64, in_ch, kernel_size=7, padding=3),  # [B, 1, T]
            nn.Tanh(),  # bound output samples to [-1, 1]
        )
        self.rvq = RVQ(num_levels=num_rvq_levels, latent_ch=latent_ch, K=K, codebook_dim=codebook_dim)

    def forward(self, x: torch.Tensor):
        """Encode, quantize, and decode a waveform batch.

        Args:
            x: [B, C=1, T] waveform. NOTE(review): T is presumably expected to
               be a multiple of 128 for exact shape round-tripping — confirm.

        Returns:
            (x_recon [B, 1, T], per-level code indices, commitment loss,
            codebook loss).
        """
        # x -> [B, C=1, T]
        z = self.encoder(x)  # [B, D, T/128]

        # flatten to 2d vector for applying rvq on channel dim
        B, C, T_128 = z.shape
        z_flat = z.permute(0, 2, 1).contiguous().view(B * T_128, C)

        # vector quantize
        z_q, all_indices, commitment_loss, codebook_loss = self.rvq(z_flat)

        # reshape back
        z_q = z_q.view(B, T_128, C).permute(0, 2, 1)  # [B, C, T_128]

        x_recon = self.decoder(z_q)  # [B, C=1, T]

        return x_recon, all_indices, commitment_loss, codebook_loss
199
+
200
+
201
if __name__ == "__main__":
    # Smoke test: push a random waveform through the codec and check shapes.
    # Fix: fall back to CPU instead of crashing on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 1, 8192)

    model = RVQCodec()
    print(model)
    print(f"params: {sum(p.numel() for p in model.parameters()):,}")

    x = x.to(device)
    model = model.to(device)

    out, _, _, _ = model(x)
    print(f"in: {x.shape} → out: {out.shape}")