tky823 commited on
Commit
aefc713
·
verified ·
1 Parent(s): 9f56050

Upload modeling_musicfm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_musicfm.py +416 -0
modeling_musicfm.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright 2023 ByteDance Inc.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
6
+ # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
10
+ #
11
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
13
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
14
+ # IN THE SOFTWARE.
15
+
16
+ import random
17
+
18
+ import torch
19
+ import torchaudio
20
+ from einops import rearrange
21
+ from torch import einsum, nn
22
+ from torch.nn.common_types import _size_2_t
23
+ from transformers import PreTrainedModel
24
+
25
+ from .configuration_musicfm import MusicFMConfig
26
+
27
+
28
+ class MusicFM25Hz(PreTrainedModel):
29
+ config_class = MusicFMConfig
30
+
31
+ def __init__(self, config: MusicFMConfig) -> None:
32
+ super().__init__(config)
33
+
34
+ # global variables
35
+ self.num_codebooks = config.num_codebooks
36
+ self.codebook_dim = config.codebook_dim
37
+ self.codebook_size = config.codebook_size
38
+ self.features = config.features
39
+ self.hop_length = config.hop_length
40
+ self.n_mels = config.n_mels
41
+ self.conv_dim = config.conv_dim
42
+ self.encoder_dim = config.encoder_dim
43
+ self.encoder_depth = config.encoder_depth
44
+ self.mask_hop = config.mask_hop
45
+ self.mask_prob = config.mask_prob
46
+ self.is_flash = config.is_flash
47
+ self.stat = config.stat
48
+
49
+ # feature extractor
50
+ self.preprocessor_melspec_2048 = MelSTFT(
51
+ n_fft=2048, hop_length=self.hop_length, is_db=True
52
+ )
53
+
54
+ # random quantizer
55
+ seed = 142
56
+ for feature in self.features:
57
+ for i in range(self.num_codebooks):
58
+ setattr(
59
+ self,
60
+ f"quantizer_{feature}_{i}",
61
+ RandomProjectionQuantizer(
62
+ self.n_mels * 4,
63
+ self.codebook_dim,
64
+ self.codebook_size,
65
+ seed=seed + i,
66
+ ),
67
+ )
68
+
69
+ # two residual convolution layers + one projection layer
70
+ self.conv = Conv2dSubsampling(
71
+ 1, self.conv_dim, self.encoder_dim, strides=[2, 2], n_bands=self.n_mels
72
+ )
73
+
74
+ # Conformer
75
+ if config.is_flash:
76
+ from .flash_conformer import (
77
+ Wav2Vec2ConformerConfig,
78
+ Wav2Vec2ConformerEncoder,
79
+ )
80
+ else:
81
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
82
+ Wav2Vec2ConformerConfig,
83
+ Wav2Vec2ConformerEncoder,
84
+ )
85
+
86
+ conformer_config = Wav2Vec2ConformerConfig.from_pretrained(
87
+ "facebook/wav2vec2-conformer-rope-large-960h-ft"
88
+ )
89
+ conformer_config.num_hidden_layers = self.encoder_depth
90
+ conformer_config.hidden_size = self.encoder_dim
91
+ self.conformer = Wav2Vec2ConformerEncoder(conformer_config)
92
+
93
+ # projection
94
+ self.linear = nn.Linear(self.encoder_dim, self.codebook_size)
95
+
96
+ # loss function
97
+ self.loss = nn.CrossEntropyLoss()
98
+
99
+ # cls token (used for sequence classification)
100
+ random.seed(seed)
101
+ self.cls_token = nn.Parameter(torch.randn(self.encoder_dim))
102
+
103
+ def masking(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.LongTensor]:
104
+ """random masking of 400ms with given probability"""
105
+ mx = x.clone()
106
+ b, t = mx.shape
107
+ len_masking_raw = int(24000 * self.mask_hop)
108
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
109
+
110
+ # get random mask indices
111
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
112
+ time_domain_masked_indices = torch.nonzero(
113
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
114
+ )
115
+ token_domain_masked_indices = torch.nonzero(
116
+ start_indices.repeat_interleave(len_masking_token, dim=1)
117
+ )
118
+
119
+ # mask with random values
120
+ masking_noise = (
121
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
122
+ ) # 0 mean 0.1 std
123
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
124
+
125
+ return mx, token_domain_masked_indices
126
+
127
+ @torch.no_grad()
128
+ def preprocessing(
129
+ self, x: torch.Tensor, features: dict[str, torch.Tensor]
130
+ ) -> dict[str, torch.Tensor]:
131
+ """extract classic audio features"""
132
+ # check precision
133
+ if x.dtype == torch.float16:
134
+ precision = 16
135
+ else:
136
+ precision = 32
137
+
138
+ out = {}
139
+ for key in features:
140
+ layer = getattr(self, "preprocessor_%s" % key)
141
+ out[key] = layer.float()(x.float())[..., :-1]
142
+ if precision == 16:
143
+ out[key] = out[key].half()
144
+ return out
145
+
146
+ def encoder(self, x: torch.Tensor) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
147
+ """2-layer conv + w2v-conformer"""
148
+ x = self.conv(x)
149
+ out = self.conformer(x, output_hidden_states=True)
150
+ hidden_emb = out["hidden_states"]
151
+ last_emb = out["last_hidden_state"]
152
+ logits = self.linear(last_emb)
153
+ logits = {
154
+ key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
155
+ for i, key in enumerate(self.features)
156
+ }
157
+ return logits, hidden_emb
158
+
159
+ @torch.no_grad()
160
+ def normalize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
161
+ """normalize the input audio to have zero mean unit variance"""
162
+ for key in x.keys():
163
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
164
+ return x
165
+
166
+ @torch.no_grad()
167
+ def rearrange(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
168
+ """rearrange the batch to flatten every 4 steps"""
169
+ for key in x.keys():
170
+ if key == "chromagram":
171
+ x[key] = rearrange(x[key], "b f t -> b t f")
172
+ else:
173
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
174
+
175
+ return x
176
+
177
+ @torch.no_grad()
178
+ def tokenize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
179
+ out = {}
180
+ for key in x.keys():
181
+ layer = getattr(self, "quantizer_%s" % key)
182
+ out[key] = layer(x[key])
183
+ return out
184
+
185
+ def get_targets(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
186
+ x = self.preprocessing(x, features=self.features)
187
+ x = self.normalize(x)
188
+ x = self.rearrange(x)
189
+ target_tokens = self.tokenize(x)
190
+
191
+ return target_tokens
192
+
193
+ def get_predictions(
194
+ self, x: torch.Tensor
195
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
196
+ # preprocessing
197
+ x = self.preprocessing(x, features=["melspec_2048"])
198
+ x = self.normalize(x)
199
+
200
+ # encoding
201
+ logits, hidden_emb = self.encoder(x["melspec_2048"])
202
+
203
+ return logits, hidden_emb
204
+
205
+ def get_latent(self, x: torch.Tensor, layer_ix: int = 12) -> torch.Tensor:
206
+ _, hidden_states = self.get_predictions(x)
207
+ emb = hidden_states[layer_ix]
208
+ return emb
209
+
210
+ def get_loss(
211
+ self,
212
+ logits: dict[str, torch.Tensor],
213
+ target_tokens: dict[str, torch.Tensor],
214
+ masked_indices: torch.LongTensor,
215
+ ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
216
+ losses = {}
217
+ accuracies = {}
218
+ for key in logits.keys():
219
+ masked_logits = logits[key][tuple(masked_indices.t())]
220
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
221
+ losses[key] = self.loss(masked_logits, masked_tokens)
222
+ accuracies[key] = (
223
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
224
+ / masked_tokens.numel()
225
+ )
226
+ return losses, accuracies
227
+
228
+ def forward(
229
+ self, x: torch.Tensor
230
+ ) -> tuple[
231
+ dict[str, torch.Tensor],
232
+ torch.Tensor,
233
+ dict[str, torch.Tensor],
234
+ dict[str, torch.Tensor],
235
+ ]:
236
+ # get target feature tokens
237
+ target_tokens = self.get_targets(x)
238
+
239
+ # masking
240
+ x, masked_indices = self.masking(x)
241
+
242
+ # forward
243
+ logits, hidden_emb = self.get_predictions(x)
244
+
245
+ # get loss
246
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
247
+
248
+ return logits, hidden_emb, losses, accuracies
249
+
250
+
251
+ class MelSTFT(nn.Module):
252
+ def __init__(
253
+ self,
254
+ sample_rate: int = 24000,
255
+ n_fft: int = 2048,
256
+ hop_length: int = 240,
257
+ n_mels: int = 128,
258
+ is_db: bool = False,
259
+ ) -> None:
260
+ super().__init__()
261
+
262
+ # spectrogram
263
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
264
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
265
+ )
266
+
267
+ # amplitude to decibel
268
+ self.is_db = is_db
269
+ if is_db:
270
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
271
+
272
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
273
+ if self.is_db:
274
+ return self.amplitude_to_db(self.mel_stft(waveform))
275
+ else:
276
+ return self.mel_stft(waveform)
277
+
278
+
279
+ class RandomProjectionQuantizer(nn.Module):
280
+ """
281
+ Random projection and codebook lookup module
282
+
283
+ Some code is borrowed from:
284
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
285
+ But I did normalization using pre-computed global mean & variance instead of using layer norm.
286
+ """
287
+
288
+ def __init__(
289
+ self,
290
+ input_dim: int,
291
+ codebook_dim: int,
292
+ codebook_size: int,
293
+ seed: int = 142,
294
+ ) -> None:
295
+ super().__init__()
296
+
297
+ # random seed
298
+ torch.manual_seed(seed)
299
+
300
+ # randomly initialized projection
301
+ random_projection = torch.empty(input_dim, codebook_dim)
302
+ nn.init.xavier_normal_(random_projection)
303
+ self.register_buffer("random_projection", random_projection)
304
+
305
+ # randomly initialized codebook
306
+ codebook = torch.empty(codebook_size, codebook_dim)
307
+ nn.init.normal_(codebook)
308
+ self.register_buffer("codebook", codebook)
309
+
310
+ def codebook_lookup(self, x: torch.Tensor) -> torch.Tensor:
311
+ # reshape
312
+ b = x.shape[0]
313
+ x = rearrange(x, "b n e -> (b n) e")
314
+
315
+ # L2 normalization
316
+ normalized_x = nn.functional.normalize(x, dim=1, p=2)
317
+ normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
318
+
319
+ # compute distances
320
+ distances = torch.cdist(normalized_codebook, normalized_x)
321
+
322
+ # get nearest
323
+ nearest_indices = torch.argmin(distances, dim=0)
324
+
325
+ # reshape
326
+ xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
327
+
328
+ return xq
329
+
330
+ @torch.no_grad()
331
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
332
+ # always eval
333
+ self.eval()
334
+
335
+ # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
336
+ x = einsum("b n d, d e -> b n e", x, self.random_projection)
337
+
338
+ # codebook lookup
339
+ xq = self.codebook_lookup(x)
340
+
341
+ return xq
342
+
343
+
344
+ class Res2dModule(nn.Module):
345
+ def __init__(self, idim: int, odim: int, stride: _size_2_t = (2, 2)) -> None:
346
+ super().__init__()
347
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
348
+ self.bn1 = nn.BatchNorm2d(odim)
349
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
350
+ self.bn2 = nn.BatchNorm2d(odim)
351
+ self.relu = nn.ReLU()
352
+
353
+ # residual
354
+ self.diff = False
355
+ if (idim != odim) or (stride[0] > 1):
356
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
357
+ self.bn3 = nn.BatchNorm2d(odim)
358
+ self.diff = True
359
+
360
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
361
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
362
+ if self.diff:
363
+ x = self.bn3(self.conv3(x))
364
+ out = x + out
365
+ out = self.relu(out)
366
+ return out
367
+
368
+
369
+ class Conv2dSubsampling(nn.Module):
370
+ """Convolutional 2D subsampling (to 1/4 length).
371
+
372
+ Args:
373
+ idim (int): Input dimension.
374
+ hdim (int): Hidden dimension.
375
+ odim (int): Output dimension.
376
+ strides (list): Sizes of strides.
377
+ n_bands (int): Number of frequency bands.
378
+
379
+ """
380
+
381
+ def __init__(
382
+ self,
383
+ idim: int,
384
+ hdim: int,
385
+ odim: int,
386
+ strides: list[int] = [2, 2],
387
+ n_bands: int = 64,
388
+ ) -> None:
389
+ """Construct an Conv2dSubsampling object."""
390
+ super().__init__()
391
+
392
+ self.conv = nn.Sequential(
393
+ Res2dModule(idim, hdim, (2, strides[0])),
394
+ Res2dModule(hdim, hdim, (2, strides[1])),
395
+ )
396
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
397
+
398
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
399
+ """Subsample x.
400
+
401
+ Args:
402
+ x (torch.Tensor): Input tensor (#batch, idim, time).
403
+
404
+ Returns:
405
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
406
+ where time' = time // 4.
407
+ """
408
+
409
+ if x.dim() == 3:
410
+ x = x.unsqueeze(1) # (b, c, f, t)
411
+
412
+ x = self.conv(x)
413
+ x = rearrange(x, "b c f t -> b t (c f)")
414
+ x = self.linear(x)
415
+
416
+ return x