klemenk committed
Commit 6865807 · verified
1 Parent(s): 98fb8d5

Create modeling_wavcoch.py

Files changed (1)
  1. modeling_wavcoch.py +836 -0
modeling_wavcoch.py ADDED
@@ -0,0 +1,836 @@
"""
WavCoch: waveform-to-cochleagram encoder with an LFQ bottleneck.
Transforming waveforms to cochleagrams ("Transformation Imitation").
"""

import math
from math import log2, ceil
import tqdm
from transformers.tokenization_utils import BatchEncoding
from transformers import PreTrainedModel
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed import nn as dist_nn
from torch import einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast

from .configuration_wavcoch import WavCochConfig


########################################
###      Cochleagram Transform       ###
########################################


class CochleagramTransform:
    def __init__(
        self,
        sr: int = 16000,
        signal_size: int = 16000 * 5,  # default signal size: 5 sec @ 16 kHz
        device: str = 'cpu',
        batch_mode: bool = False,
        return_on_cpu: bool = True,
    ):
        # the chcochleagram dependency is imported lazily in _init_cochleagram_fn below

        self.sr = sr
        self.device = device
        self.batch_mode = batch_mode
        self.return_on_cpu = return_on_cpu

        self.cochleagram_fn = self._init_cochleagram_fn(signal_size=signal_size)

    def cochleagram(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Compute the cochleagram of the audio waveform.
        From Jenelle Feather: chcochleagram
        """
        # move audio to the specified device
        audio = audio.to(self.device)

        cochleagram = self.cochleagram_fn(audio)  # (batch, n_channels, n_timesteps)

        # Transpose the cochleagram to n_timesteps x n_channels
        cochleagram = cochleagram.permute(0, 2, 1)

        # Check for nan values
        if torch.isnan(cochleagram).any():
            raise ValueError('Cochleagram contains nan values')

        # Move the cochleagram back to cpu to match the semantics of the previous dataloader.
        # Maybe this can be improved in the future, but it does not seem to make a big
        # difference in terms of performance so far
        if self.return_on_cpu:
            cochleagram = cochleagram.to('cpu')

        # This is a bit silly, but if the cochleagram has a batch size of 1 we squeeze it
        # in order to match the semantics of the previous dataloaders
        if cochleagram.shape[0] == 1 and not self.batch_mode:
            cochleagram = cochleagram.squeeze(0)

        return cochleagram

    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
        return self.cochleagram(audio)

    def _init_cochleagram_fn(
        self,
        pad_factor: float = 1.5,
        use_rfft: bool = True,
        signal_size: int = 16000 * 5,  # default signal size: 5 sec @ 16 kHz
    ):
        # chcochleagram is only needed to build the cochleagram transform,
        # so import it lazily and give an install hint if it is missing
        try:
            import chcochleagram
        except ImportError as e:
            raise ImportError(
                'The chcochleagram library is required to compute cochleagrams, please install it with:\n'
                'pip install git+https://github.com/jenellefeather/chcochleagram.git'
            ) from e

        ### Define the cochlear filters using ERBCosFilters.
        # These are the arguments used for filter construction of ERBCosFilters.
        # See helpers/erb_filters.py for more documentation.
        half_cos_filter_kwargs = {
            'n': 50,  # Number of filters to evenly tile the space
            'low_lim': 50,  # Lowest center frequency for full filter (if lowpass filters are used they can be centered lower)
            'high_lim': 8000,  # Highest center frequency
            'sample_factor': 4,  # Positive integer that determines how densely the ERB function will be sampled
            'full_filter': False,  # Whether to use the full filter. Must be False if use_rfft is True.
        }

        coch_filter_kwargs = {
            'use_rfft': use_rfft,  # Whether to use rFFT or not
            'pad_factor': pad_factor,  # How much to pad the signal
            'filter_kwargs': half_cos_filter_kwargs,
        }

        ### Define an envelope extraction operation
        # Use the analytic amplitude of the Hilbert transform here. Other types of envelope extraction
        # are also implemented in envelope_extraction.py. Can use Identity if you want the raw subbands.
        envelope_extraction = chcochleagram.envelope_extraction.HilbertEnvelopeExtraction(signal_size=signal_size,
                                                                                          sr=self.sr,
                                                                                          use_rfft=use_rfft,
                                                                                          pad_factor=pad_factor)

        # This (and most) cochleagrams use ERBCosFilters, however other types of filterbanks can be
        # constructed for linearly spaced filters or different shapes. Make a new CochlearFilter class for
        # these.
        filters = chcochleagram.cochlear_filters.ERBCosFilters(signal_size=signal_size,
                                                               sr=self.sr,
                                                               **coch_filter_kwargs)

        ### Define a downsampling operation
        # Downsample the extracted envelopes. Can use Identity if you want the raw subbands.
        env_sr = 200  # Sampling rate after downsampling
        downsampling_kwargs = {'window_size': 1001}  # Parameters for the downsampling filter (see downsampling.py)
        downsampling_op = chcochleagram.downsampling.SincWithKaiserWindow(sr=self.sr, env_sr=env_sr, **downsampling_kwargs)

        ### Define a compression operation.
        compression_kwargs = {'power': 0.3,  # Power compression of 0.3
                              'offset': 1e-8,  # Offset for numerical stability in the backwards pass
                              'scale': 1,  # Optional multiplicative value applied to the envelopes before compression
                              'clip_value': 100}  # Clip the gradients for this compression for stability
        compression = chcochleagram.compression.ClippedGradPowerCompression(**compression_kwargs)

        cochleagram_fn = chcochleagram.cochleagram.Cochleagram(filter_object=filters,
                                                               envelope_extraction=envelope_extraction,
                                                               downsampling=downsampling_op,
                                                               compression=compression)
        # Move cochleagram_fn to the specified device
        cochleagram_fn = cochleagram_fn.to(self.device)

        return cochleagram_fn

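# --- Example usage (illustrative only, not part of the model definition) ---
# A minimal sketch of the transform on its own, assuming chcochleagram is
# installed and `wav` is a mono waveform of shape (1, 80000), i.e. 5 s at
# 16 kHz to match the default signal_size (the same calling convention used
# by invert_cochleagram_to_audio further below):
#
#   transform = CochleagramTransform(sr=16000, signal_size=16000 * 5, device='cpu')
#   coch = transform(wav)
#   # with batch_mode=False and a batch of 1, `coch` is squeezed to
#   # (n_timesteps, n_channels), with n_timesteps at env_sr = 200 Hz
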
########################################
###          LFQ Definition          ###
########################################


"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
Adapted from vector-quantize-pytorch https://github.com/lucidrains/vector-quantize-pytorch
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# distributed helpers

@cache
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist_nn.all_reduce(t)
    t = t / dist.get_world_size()
    return t

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(tensor: torch.Tensor, pattern: str):
    """
    Packs a single tensor by flattening all axes matched by '*' into one.
    Returns (packed_tensor, packed_shapes), where packed_shapes is a list
    of one tuple describing the original wildcard dims.
    """
    tokens = pattern.split()
    if '*' not in tokens:
        raise ValueError("Pattern must contain a '*' wildcard axis")
    idx = tokens.index('*')
    n_before = idx
    n_after = len(tokens) - idx - 1

    shape = tensor.shape
    # split the original shape into before / wildcard / after
    if n_after:
        before = shape[:n_before]
        wildcard = shape[n_before:-n_after]
        after = shape[-n_after:]
    else:
        before = shape[:n_before]
        wildcard = shape[n_before:]
        after = ()

    # compute the flattened size and reshape
    flat = 1
    for d in wildcard:
        flat *= d
    new_shape = before + (flat,) + after
    packed = tensor.reshape(new_shape)

    # return a list of shapes so unpack_one can use the same interface
    return packed, [tuple(wildcard)]

def unpack_one(packed: torch.Tensor, ps: list, pattern: str):
    """
    Reverses pack_one on a single tensor.
    `ps` should be the list of shapes returned by pack_one.
    """
    tokens = pattern.split()
    if '*' not in tokens:
        raise ValueError("Pattern must contain a '*' wildcard axis")
    idx = tokens.index('*')
    n_before = idx
    n_after = len(tokens) - idx - 1

    shape = packed.shape
    # extract the wildcard shape that was saved
    wildcard = tuple(ps[0])

    # split the packed shape into before / flat / after
    if n_after:
        before = shape[:n_before]
        after = shape[-n_after:]
    else:
        before = shape[:n_before]
        after = ()

    orig_shape = before + wildcard + after
    return packed.reshape(orig_shape)

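# --- Example (illustrative only) ---
# pack_one / unpack_one stand in for einops.pack / einops.unpack on a single
# tensor. For instance:
#
#   x = torch.randn(2, 4, 6, 8)
#   packed, ps = pack_one(x, 'b * d')            # packed: (2, 24, 8), ps: [(4, 6)]
#   restored = unpack_one(packed, ps, 'b * d')   # restored: (2, 4, 6, 8)
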
def l2norm(t):
    return F.normalize(t, dim = -1)

# entropy

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)

# cosine sim linear

class CosineSimLinear(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale

# class

class LFQ(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,            # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 1.,   # make less than 1. to only use a random fraction of the probs for per-sample entropy
        has_projections = None,
        projection_has_bias = True,
        soft_clamp_input_value = None,
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,       # how much to shift the loss before softplus
        spherical = False,              # from https://arxiv.org/abs/2406.07548
        force_quantization_f32 = True   # will force the quantization step to be full precision
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first

        self.channel_first = channel_first

        # straight through activation

        self.activation = straight_through_activation

        # whether to use BSQ (binary spherical quantization)

        self.spherical = spherical
        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a softplus (experimental, please report in discussions whether this worked or not)

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # for no auxiliary loss, during inference

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force the quantization step to be f32

        self.force_quantization_f32 = force_quantization_f32

        # codes

        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)

        self.register_buffer('codebook', codebook.float(), persistent = False)

    def bits_to_codes(self, bits):
        return bits * self.codebook_scale * 2 - self.codebook_scale

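    # --- Worked example (illustrative only) ---
    # With codebook_dim = 3 the `mask` buffer is [4, 2, 1], so index 5 (binary 101)
    # expands to bits [1, 0, 1] and, through bits_to_codes with codebook_scale = 1.,
    # to the code [+1, -1, +1]. The forward pass below inverts this mapping: every
    # positive dimension of the quantized vector contributes its power of two to the index.
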
    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            # append a singleton codebook dimension at the end
            indices = indices.unsqueeze(-1)

        # indices to codes, which are bits of either -1 or 1

        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = self.maybe_l2norm(codes)

        codes = codes.flatten(-2, -1)

        # whether to project codes out to the original dimensions,
        # if the input feature dimension was not log2(codebook size)

        if project_out:
            codes = self.project_out(codes)

        # move codes back to the original shape

        if should_transpose:
            codes = codes.movedim(-1, 1)

        return codes

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """

        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)

        if should_transpose:
            x = x.movedim(1, -1)
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # maybe soft clamp

        if exists(self.soft_clamp_input_value):
            clamp_value = self.soft_clamp_input_value
            x = (x / clamp_value).tanh() * clamp_value

        # split out the number of codebooks

        x = x.reshape(*x.shape[:2], self.num_codebooks, -1)

        # maybe l2norm

        x = self.maybe_l2norm(x)

        # whether to force the quantization step to be full precision or not

        force_f32 = self.force_quantization_f32

        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():

            if force_f32:
                orig_dtype = x.dtype
                x = x.float()

            # quantize by eq 3.

            original_input = x

            codebook_value = torch.ones_like(x) * self.codebook_scale
            quantized = torch.where(x > 0, codebook_value, -codebook_value)

            # calculate indices

            t = (quantized > 0).int() * self.mask.int()
            indices = t.sum(dim=-1)

            quantized = self.maybe_l2norm(quantized)

            # use straight-through gradients (optionally with a custom activation fn) if training

            if self.training:
                x = self.activation(x)
                x = x + (quantized - x).detach()
            else:
                x = quantized

            # entropy aux loss

            if self.training:
                codebook = self.codebook.float() if force_f32 else self.codebook

                codebook = self.maybe_l2norm(codebook)

                # whether to only use a fraction of probs, for reducing memory

                input_for_entropy = original_input

                if exists(mask):
                    input_for_entropy = original_input[mask]

                input_for_entropy = input_for_entropy.flatten(0, 1)

                if self.frac_per_sample_entropy < 1.:
                    # account for mask

                    num_tokens = input_for_entropy.size(0)
                    num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
                    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens

                    sampled_input = input_for_entropy[rand_mask]

                    sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)

                    sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = sampled_prob
                else:
                    # the same as euclidean distance up to a constant
                    distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)

                    prob = (-distance * inv_temperature).softmax(dim = -1)

                    per_sample_probs = prob

                # calculate per-sample entropy

                per_sample_entropy = entropy(per_sample_probs).mean()

                # distribution over all available tokens in the batch

                avg_prob = (per_sample_probs
                            .flatten(start_dim=0, end_dim=-3)
                            .mean(dim=0))

                avg_prob = maybe_distributed_mean(avg_prob)

                codebook_entropy = entropy(avg_prob).mean()

                # 1. per-sample entropy will be nudged to be low for each code, to encourage the network to output confident predictions
                # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
            else:
                # if not training, just return a dummy 0
                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

            # whether to make the entropy loss positive or not through a (shifted) softplus

            if self.training and self.experimental_softplus_entropy_loss:
                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)

            # commit loss

            if self.training and self.commitment_loss_weight > 0.:
                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

                if exists(mask):
                    commit_loss = commit_loss[mask]

                commit_loss = commit_loss.mean()
            else:
                commit_loss = self.zero

            # cast back to the original dtype if needed

            if force_f32:
                x = x.type(orig_dtype)

        # merge back the codebook dim

        x = x.flatten(2, 3)

        # project out to the feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if should_transpose:
            x = unpack_one(x, ps, 'b * d')
            x = x.movedim(-1, 1)

            indices = unpack_one(indices, ps, 'b * c')

        # whether to remove the single codebook dim

        if not self.keep_num_codebooks_dim:
            indices = indices.squeeze(-1)

        # complete aux loss

        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

        # returns

        ret = Return(x, indices, aux_loss)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)

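# --- Example (illustrative only) ---
# A minimal sketch of the LFQ bottleneck in isolation, assuming a feature
# dimension of 32 projected to log2(8192) = 13 code bits:
#
#   lfq = LFQ(dim=32, codebook_size=8192).eval()
#   feats = torch.randn(2, 100, 32)            # (batch, seq, dim)
#   quantized, indices, aux_loss = lfq(feats)
#   # quantized: (2, 100, 32), indices: (2, 100) with values in [0, 8192)
#   # aux_loss is the dummy zero in eval mode; training mode adds the entropy terms
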

########################################
###          Quantizer Model         ###
########################################


class WavCoch(PreTrainedModel):
    config_class = WavCochConfig

    def __init__(self, config):
        super().__init__(config)
        self.N = config.window_size
        self.hop_length = config.hop_length

        # Initial frequency transform convolutions
        self.conv_real_filters = nn.Conv1d(1, self.N // 2 + 1, kernel_size=self.N, stride=self.hop_length)
        self.conv_imag_filters = nn.Conv1d(1, self.N // 2 + 1, kernel_size=self.N, stride=self.hop_length)
        self._initialize_conv_filters()

        # Configurable encoder and decoder layers
        self.encoder = self._build_conv_block(
            in_channels=self.N // 2 + 1,
            out_channels=config.encoder_dim,
            num_layers=config.encoder_layers,
            kernel_size=config.encoder_kernel_size
        )
        self.quantizer = LFQ(
            codebook_size=config.codebook_size,
            dim=config.encoder_dim,
            num_codebooks=1,
            entropy_loss_weight=config.entropy_loss_weight,
            commitment_loss_weight=config.commit_loss_weight,
            diversity_gamma=config.diversity_gamma,
        )
        self.decoder = self._build_conv_block(
            in_channels=config.decoder_dim,
            out_channels=211,  # number of channels in the target cochleagram
            num_layers=config.decoder_layers,
            kernel_size=config.decoder_kernel_size
        )

    def _build_conv_block(self, in_channels, out_channels, num_layers, kernel_size=9):
        """Creates a stack of Conv1d + ReLU layers."""
        layers = []
        for i in range(num_layers):
            conv_layer = nn.Conv1d(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding='same'
            )
            layers.extend([
                conv_layer,
                nn.ReLU(),
            ])
        return nn.Sequential(*layers)

    def _compute_twiddle_factors(self):
        n = torch.arange(self.N).unsqueeze(1)
        k = torch.arange(self.N).unsqueeze(0)
        angles = -2 * math.pi * n * k / self.N
        return torch.cos(angles), torch.sin(angles)  # Real and imaginary parts

    def _initialize_conv_filters(self):
        twiddle_factors_real, twiddle_factors_imag = self._compute_twiddle_factors()
        twiddle_factors_real = twiddle_factors_real[:self.N // 2 + 1, :]
        twiddle_factors_imag = twiddle_factors_imag[:self.N // 2 + 1, :]
        window = torch.hann_window(self.N).view(1, 1, -1)
        conv_real_filters = twiddle_factors_real.unsqueeze(1) * window
        conv_imag_filters = twiddle_factors_imag.unsqueeze(1) * window
        self.conv_real_filters.weight = nn.Parameter(conv_real_filters)
        self.conv_imag_filters.weight = nn.Parameter(conv_imag_filters)

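    # --- Note (illustrative) ---
    # The two Conv1d layers initialised above implement a Hann-windowed DFT as a
    # strided convolution: row k of the real/imaginary filter banks holds
    # cos(-2*pi*k*n/N) and sin(-2*pi*k*n/N) multiplied by a Hann window, so for a
    # frame x of length N,
    #   real_part[k] = Re(rfft(hann * x))[k]
    #   imag_part[k] = Im(rfft(hann * x))[k]
    # with one output frame every hop_length samples.
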
    @property
    def vocab_size(self):
        return 8192  # 2**13 LFQ codes

    def forward(self, wav, coch=None, return_tensors="pt", sample_rate=16000, pad=True):

        if coch is None:
            # Tokenization path: handle all supported input formats
            if isinstance(wav, list):
                # List[Tensor[T]] → pad to [B, T], then unsqueeze to [B, 1, T]
                wav = [w.unsqueeze(0) if w.ndim == 1 else w for w in wav]  # make [1, T]
                wav = torch.nn.utils.rnn.pad_sequence(wav, batch_first=True)  # [B, T]
                wav = wav.unsqueeze(1)  # [B, 1, T]

            elif isinstance(wav, torch.Tensor):
                if wav.ndim == 1:
                    wav = wav.unsqueeze(0).unsqueeze(0)  # [1, 1, T]
                elif wav.ndim == 2:
                    wav = wav.unsqueeze(1)  # [B, T] → [B, 1, T]
                elif wav.ndim != 3:
                    raise ValueError(f"Unexpected tensor shape {wav.shape}, expected 1D, 2D or 3D.")

            else:
                raise TypeError(f"Unsupported input type: {type(wav)}")

            # pad the input waveform to correct for the cutoff performed by the cochleagram
            if pad:
                wav = F.pad(wav, (self.N - self.hop_length, 0), mode='constant', value=0)

            # quantize audio
            codes = self.quantize(wav)
            return BatchEncoding({
                "input_values": codes,
                "input_ids": codes,
            })

        # Training path: keep the fixed DFT-style analysis filters out of the autograd
        # graph; gradients flow through the encoder, quantizer and decoder below
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)

        x = real_part + imag_part
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        quantized, indices, entropy_aux_loss = self.quantizer(x)
        # decoder output: predicted cochleagram, (batch, n_frames, 211)
        mel_spectrogram = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)

        loss = F.mse_loss(mel_spectrogram, coch)
        return mel_spectrogram, loss, entropy_aux_loss

    def quantize(self, wav):
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)

            x = real_part + imag_part
            x = self.encoder(x)
            x = x.permute(0, 2, 1)
            quantized, indices, _ = self.quantizer(x)
        return indices

    def decode(self, indices):
        emb = self.quantizer.indices_to_codes(indices)
        mel_spectrogram = self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)
        return mel_spectrogram

    def wav2coch(self, wav):
        with torch.no_grad():
            real_part = self.conv_real_filters(wav)
            imag_part = self.conv_imag_filters(wav)

            x = real_part + imag_part
            x = self.encoder(x)
            x = x.permute(0, 2, 1)
            quantized, indices, _ = self.quantizer(x)
            mel_spectrogram = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
        return mel_spectrogram

    def invert_cochleagram_to_audio(
        self,
        cochleagram,
        device,
        num_optim_steps=1000,
        lr=1e-2,
        transform_cls=CochleagramTransform
    ):
        """
        Invert a cochleagram back to audio by gradient descent on the waveform.
        """
        # Initialize the transform function
        transform = transform_cls(sr=16000, signal_size=16000*5, device=device, return_on_cpu=False)
        # Initialize the audio to be optimized
        audio = torch.randn(1, 1, 16000*5).to(device).requires_grad_()
        # Define the optimizer
        optimizer = torch.optim.Adam([audio], lr=lr)
        # Define the loss function
        criterion = torch.nn.MSELoss()
        # Initialize the tqdm progress bar
        with tqdm.tqdm(total=num_optim_steps, desc="Inverting the cochleagram") as pbar:
            # Invert the cochleagram
            for _ in range(num_optim_steps):
                optimizer.zero_grad()
                # Compute the cochleagram from the audio
                pred_coch = transform(audio[0])
                # Compute the loss
                loss = criterion(pred_coch, cochleagram)
                # Backpropagate the loss
                loss.backward()
                # Update the audio
                optimizer.step()
                # Update the progress bar
                pbar.set_postfix(loss=loss.item())
                pbar.update(1)
        return audio
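
# --- End-to-end usage sketch (illustrative only) ---
# Assuming this file ships in a Hugging Face repo together with a matching
# configuration_wavcoch.py, something along these lines should work
# ("<repo-id>" is a placeholder):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True).eval()
#
#   wav = torch.randn(16000 * 5)        # stand-in for 5 s of audio at 16 kHz
#   enc = model(wav)                    # coch=None -> tokenization path
#   codes = enc["input_ids"]            # (1, n_frames) LFQ indices
#
#   coch_hat = model.decode(codes)      # (1, n_frames, 211) predicted cochleagram
#   # the (slow) gradient-based inversion can then map it back to a waveform:
#   # audio = model.invert_cochleagram_to_audio(coch_hat.squeeze(0), device="cpu")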