Wataru committed on
Commit
16293fd
·
verified ·
1 Parent(s): f634d4b

add non_integer stride

Browse files
Files changed (3) hide show
  1. continuous_filters.py +651 -0
  2. model.safetensors +1 -1
  3. modeling.py +1 -2
continuous_filters.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementations of latent analog filters.
2
+
3
+ Copyright (c) Tomohiko Nakamura
4
+ All rights reserved.
5
+ """
6
+
7
+ import functools
8
+ from typing import Sequence
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+ from torch import nn
14
+
15
+
16
def erb_to_hz(x):
    """Convert ERB-scale frequency to Hz.

    Inverse of ``hz_to_erb``.

    Args:
        x (numpy.ndarray or float): Frequency in ERB scale

    Return:
        numpy.ndarray or float: Frequency in Hz

    """
    # np.expm1 computes exp(x) - 1 without cancellation error for small x,
    # so low ERB values map to Hz more accurately than exp(x) - 1 would.
    return np.expm1(x / 9.265) * 24.7 * 9.265
27
+
28
+
29
def hz_to_erb(x):
    """Convert frequency in Hz to the ERB (equivalent rectangular bandwidth) scale.

    Inverse of ``erb_to_hz``.

    Args:
        x (numpy.ndarray or float): Frequency in Hz

    Return:
        numpy.ndarray or float: Frequency in ERB scale

    """
    # np.log1p computes log(1 + y) without cancellation error for small y,
    # improving accuracy for low frequencies over log(1 + y).
    return np.log1p(x / (24.7 * 9.265)) * 9.265
40
+
41
+
42
+ #############################################
43
class ModulatedGaussianFilters(nn.Module):
    r"""Modulated Gaussian filters.

    The frequency response of this filter is given by

    [
    H(\omega) = e^{-(\omega-\omega_{c})^2/(2\sigma^2)} + e^{-(\omega+\omega_{c})^2/(2\sigma^2)}.
    ]

    If one_sided is True, this frequency response is changed as

    [
    H(\omega) = e^{-(\omega-\omega_{c})^2/(2\sigma^2)}.
    ]

    """

    def __init__(
        self,
        n_filters,
        init_type="erb",
        min_bw=1.0 * 2.0 * np.pi,
        initial_freq_range=None,
        one_sided=False,
        init_sigma=100.0 * 2.0 * np.pi,
        trainable=True,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        init_type (str): Initialization type of center frequencies.
            If "erb", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the ERB scale.
            If "linear", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the linear frequency scale.
        min_bw (float): Minimum bandwidth in radian
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        one_sided (bool): If True, ignore the term in the negative frequency region. If False, the corresponding impulse response is modulated Gaussian window.
        init_sigma (float): Initial value for sigma
        trainable (bool): Whether filter parameters are trainable or not.

        Raises:
            ValueError: If init_type is neither "linear" nor "erb".

        """
        if initial_freq_range is None:
            initial_freq_range = [50.0, 32000 / 2]
        super().__init__()
        lf, hf = initial_freq_range
        if init_type == "linear":
            mus = np.linspace(lf, hf, n_filters) * 2.0 * np.pi
            sigma2s = init_sigma**2 * np.ones((n_filters,), dtype="f")
        elif init_type == "erb":
            erb_mus = np.linspace(hz_to_erb(lf), hz_to_erb(hf), n_filters)
            mus = erb_to_hz(erb_mus) * 2.0 * np.pi
            sigma2s = init_sigma**2 * np.ones((n_filters,), dtype="f")
        else:
            # FIX: the original raised a bare ValueError with no message.
            msg = f'init_type must be "linear" or "erb", but got {init_type!r}'
            raise ValueError(msg)
        # Lower bound of ln(sigma^2), enforcing the minimum bandwidth.
        self.min_ln_sigma2s = np.log(min_bw**2)

        # Center frequencies in rad/s.
        self.mus = nn.Parameter(torch.from_numpy(mus).float(), requires_grad=trainable)
        # sigma^2 is parameterized in the log domain to keep it positive.
        self._ln_sigma2s = nn.Parameter(
            torch.from_numpy(np.log(sigma2s)).float().clamp(min=self.min_ln_sigma2s),
            requires_grad=trainable,
        )
        self.phase = nn.Parameter(
            torch.zeros((n_filters,), dtype=torch.float),
            requires_grad=trainable,
        )
        # Random initial phases in [0, pi).
        self.phase.data.uniform_(0.0, np.pi)
        self.one_sided = one_sided

    @property
    def sigma2s(self):
        """Bandwidth parameters sigma^2 (n_filters,), clamped to the minimum bandwidth."""
        return self._ln_sigma2s.clamp(min=self.min_ln_sigma2s).exp()

    def get_frequency_responses(self, omega: torch.Tensor):
        """Sample frequency responses at omega.

        Args:
            omega (torch.Tensor): Angular frequencies (n_angs)

        Return:
            tuple[torch.Tensor]: Real and imaginary parts of frequency responses sampled at omega.

        """
        if self.one_sided:
            resp_abs = torch.exp(
                -(omega[None, :] - self.mus[:, None]).pow(2.0)
                / (2.0 * self.sigma2s[:, None]),
            )  # n_filters x n_angfreqs
            resp_r = resp_abs * self.phase.cos()[:, None]
            resp_i = resp_abs * self.phase.sin()[:, None]
        else:
            resp_abs = torch.exp(
                -(omega[None, :] - self.mus[:, None]).pow(2.0)
                / (2.0 * self.sigma2s[:, None]),
            )  # n_filters x n_angfreqs
            # Mirrored Gaussian centered at -mu so that the corresponding
            # impulse responses are real-valued.
            resp_abs2 = torch.exp(
                -(omega[None, :] + self.mus[:, None]).pow(2.0)
                / (2.0 * self.sigma2s[:, None]),
            )  # to ensure filters whose impulse responses are real.
            resp_r = (
                resp_abs * self.phase.cos()[:, None]
                + resp_abs2 * ((-self.phase).cos()[:, None])
            )
            resp_i = (
                resp_abs * self.phase.sin()[:, None]
                + resp_abs2 * ((-self.phase).sin()[:, None])
            )
        return resp_r, resp_i

    def extra_repr(self):
        """Summary string appended to this module's repr."""
        # BUGFIX: the original returned `s.format(**self.__dict__)`, but `s`
        # is already a fully interpolated f-string; the second `.format` call
        # was redundant and would raise KeyError if the repr ever contained
        # literal braces.
        return f"n_filters={int(self.mus.shape[0])}, one_sided={self.one_sided}"

    @property
    def device(self):
        """Device on which the filter parameters reside."""
        return self.mus.device
156
+
157
+
158
class TDModulatedGaussianFilters(ModulatedGaussianFilters):
    """Time-domain variant of ``ModulatedGaussianFilters``.

    Samples the modulated-Gaussian impulse responses directly on a discrete
    time grid for an arbitrary target sampling frequency.
    """

    def __init__(
        self,
        n_filters,
        train_sample_rate,
        init_type="erb",
        min_bw=1.0 * 2.0 * np.pi,
        initial_freq_range=None,
        one_sided=False,
        init_sigma=100.0 * 2.0 * np.pi,
        trainable=True,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        train_sample_rate (float): Trained sampling frequency
        init_type (str): Initialization type of center frequencies.
            If "erb", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the ERB scale.
            If "linear", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the linear frequency scale.
        min_bw (float): Minimum bandwidth in radian
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        one_sided (bool): If True, ignore the term in the negative frequency region. If False, the corresponding impulse response is modulated Gaussian window.
        init_sigma (float): Initial value for sigma
        trainable (bool): Whether filter parameters are trainable or not.

        """
        freq_range = (
            [50.0, 32000 / 2] if initial_freq_range is None else initial_freq_range
        )
        super().__init__(
            n_filters=n_filters,
            init_type=init_type,
            min_bw=min_bw,
            initial_freq_range=freq_range,
            one_sided=one_sided,
            init_sigma=init_sigma,
            trainable=trainable,
        )
        # Stored as a buffer so it moves with the module and is checkpointed.
        self.register_buffer(
            "train_sample_rate",
            torch.tensor(float(train_sample_rate)),
        )

    def get_impulse_responses(self, sample_rate: int, tap_size: int):
        """Sample impulse responses.

        Args:
            sample_rate (int): Target sampling frequency
            tap_size (int): Tap size

        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)

        """
        if self.one_sided:
            raise NotImplementedError
        center_freqs_in_hz = self.mus / (2.0 * np.pi)
        # Zero-centered time grid in seconds, shape 1 x tap_size.
        grid = torch.arange(0.0, tap_size, 1).type_as(center_freqs_in_hz) / sample_rate
        grid = (grid - grid.mean()).unsqueeze(0)
        # Gaussian envelope scaled so its spectrum matches the analog design.
        sigma2 = self.sigma2s.unsqueeze(1)
        envelope = (
            2.0
            * (2.0 * np.pi * sigma2).sqrt()
            * torch.exp(-sigma2 * grid.pow(2) / 2.0)
        )
        carrier = torch.cos(self.mus.unsqueeze(1) * grid + self.phase.unsqueeze(1))
        filter_coeffs = envelope * carrier  # n_filters x tap_size
        # When the target rate is below the training rate, silence filters
        # whose center frequency exceeds the target Nyquist frequency.
        if self.train_sample_rate > sample_rate:
            below_nyquist = center_freqs_in_hz <= sample_rate / 2
            filter_coeffs = filter_coeffs * below_nyquist.unsqueeze(1)
        # Time-reverse so the coefficients act as convolution weights.
        return torch.flip(filter_coeffs, dims=[1])
231
+
232
+
233
+ #############################################
234
class MultiPhaseGammaToneFilters(nn.Module):
    """Multiphase gamma tone filters.

    Remark:
        This class includes the creation of Hilbert transform pairs.

    [2] D. Ditter and T. Gerkmann, ``A multi-phase gammatone filterbank for speech separation via TasNet,'' in Proceedings of IEEE International Conference on Acoustics, Speech, and Signal Processing, 2020, pp. 36--40.
    """

    def __init__(
        self,
        n_filters,
        train_sample_rate,
        initial_freq_range=None,
        n_center_freqs=24,
        trainable=False,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        train_sample_rate (float): Trained sampling frequency
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        n_center_freqs (int): Number of center frequencies
        trainable (bool): Whether filter parameters are trainable or not.

        """
        if initial_freq_range is None:
            initial_freq_range = [100.0, 16000 / 2]
        super().__init__()
        self.register_buffer(
            "train_sample_rate",
            torch.tensor(float(train_sample_rate)),
        )
        self.n_filters = n_filters
        # Filters are doubled later to form Hilbert pairs, so only
        # n_filters // 2 distinct (frequency, phase) combinations exist.
        assert n_filters // 2 >= n_center_freqs
        ## Ditter's initialization method: center frequencies equally spaced
        ## on the ERB scale (deduplicated from the original two copies).
        center_freqs = torch.from_numpy(
            erb_to_hz(
                np.linspace(
                    hz_to_erb(initial_freq_range[0]),
                    hz_to_erb(initial_freq_range[1]),
                    n_center_freqs,
                ),
            ).astype("f"),
        ).float()  # [Hz]
        if trainable:
            self.center_freqs_in_hz = nn.Parameter(
                center_freqs,
                requires_grad=trainable,
            )
        else:
            self.register_buffer("center_freqs_in_hz", center_freqs)
        ### Distribute the n_filters//2 phase variations as evenly as
        ### possible across the center frequencies.
        n_phase_variations_list = (
            np.ones(n_center_freqs) * np.floor(self.n_filters / 2 / n_center_freqs)
        ).astype("i")
        remaining_phases = int(self.n_filters // 2 - n_phase_variations_list.sum())
        if remaining_phases > 0:
            n_phase_variations_list[:remaining_phases] += 1
        n_phase_variations_list = [int(_) for _ in n_phase_variations_list]
        self.register_buffer(
            "n_phase_variations",
            torch.tensor(n_phase_variations_list),
        )
        ### Phases equally spaced in [0, pi] per center frequency.
        phases = []
        for N in n_phase_variations_list:
            phases.append(np.linspace(0.0, np.pi, N))
        phases = np.concatenate(phases, axis=0)
        ##
        if trainable:
            self.phases = nn.Parameter(
                torch.from_numpy(phases).float(),
                requires_grad=trainable,
            )  # n_filters//2
        else:
            self.register_buffer("phases", torch.from_numpy(phases).float())

    def compute_gammatone_impulse_response(self, center_freqs_in_hz, phases, t):
        """Compute gammatone impulse responses.

        Args:
            center_freqs_in_hz (torch.Tensor): Center frequencies in Hz
            phases (torch.Tensor): Phases
            t (torch.Tensor): Time instants in seconds (1 x tap_size)

        Return:
            torch.Tensor: Sampled impulse response (n_center_freqs x tap_size)

        """
        # BUGFIX: the original used `np.math.factorial`; `np.math` was a
        # private alias removed in NumPy >= 1.25/2.0. Use the stdlib
        # factorial (local import keeps the file's top-level imports intact).
        from math import factorial

        center_freqs_in_hz = center_freqs_in_hz[:, None]
        n = 2  # gammatone filter order
        b = (24.7 + center_freqs_in_hz / 9.265) / (
            (np.pi * factorial(2 * n - 2) * np.power(2, float(-(2 * n - 2))))
            / np.square(factorial(n - 1))
        )  # equivalent rectangular bandwidth
        a = 1.0  # amplitude
        return (
            a
            * (t ** (n - 1))
            * torch.exp(-2 * np.pi * b * t)
            * torch.cos(2 * np.pi * center_freqs_in_hz * t + phases[:, None])
        )  # n_center_freqs x tap_size

    def normalize_filters(self, filter_coeffs):
        """Normalize filter coefficients so every filter shares the RMS of the loudest one.

        Args:
            filter_coeffs (torch.Tensor): Filter coefficients (n_filters x tap_size)

        Return:
            torch.Tensor: Normalized filter coefficients (n_filters x tap_size)

        """
        rms_per_filter = (filter_coeffs**2).mean(dim=1).sqrt()
        C = 1.0 / (rms_per_filter / rms_per_filter.max())
        return filter_coeffs * C[:, None]

    def get_impulse_responses(self, sample_rate: int, tap_size: int):
        """Sample impulse responses.

        Args:
            sample_rate (int): Target sampling frequency
            tap_size (int): Tap size

        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)

        """
        # Second half of the phases is shifted by pi to form Hilbert pairs.
        phases = torch.cat((self.phases, self.phases + np.pi), dim=0)  # n_filters
        center_freqs_in_hz = self.center_freqs_in_hz.repeat_interleave(
            self.n_phase_variations,
            dim=0,
        )
        center_freqs_in_hz = center_freqs_in_hz.repeat(2)  # doubles for Hilbert pairs
        # check whether the center frequencies are below Nyquist rate
        if self.train_sample_rate > sample_rate:
            mask = center_freqs_in_hz <= sample_rate / 2
        ###
        if tap_size % 2 == 0:
            # even: exclude the origin
            t = (
                torch.arange(1.0, tap_size + 1, 1).type_as(center_freqs_in_hz)
                / sample_rate
            )[None, :]
        else:
            # odd: include the origin
            t = (
                torch.arange(0.0, tap_size, 1).type_as(center_freqs_in_hz) / sample_rate
            )[None, :]
        filter_coeffs = self.compute_gammatone_impulse_response(
            center_freqs_in_hz,
            phases,
            t,
        ).type_as(center_freqs_in_hz)  # n_filters x tap_size
        filter_coeffs = self.normalize_filters(filter_coeffs).type_as(
            center_freqs_in_hz,
        )
        if self.train_sample_rate > sample_rate:
            filter_coeffs = filter_coeffs * mask[:, None]
        # Time-reverse so the coefficients act as convolution weights.
        return filter_coeffs[:, torch.arange(tap_size - 1, -1, -1)]
404
+
405
+
406
class RFFTimeDomainImplicitFilter(nn.Module):
    """Implicit time-domain filter: an MLP maps (random-Fourier-featured) normalized time to filter taps."""

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),  # FIX: tuple default avoids the mutable-default pitfall
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.
        init_kernel_size: Initial kernel size.
        init_sample_rate: Initial sample rate.
        ch_list: Channel list of MLP.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity
        train_RFF (bool): If True, train RFFs.
        use_layer_norm (bool): If True, use layer norm.
        """
        super().__init__()
        self.n_filters = n_filters
        # Stored as buffers so they move with the module and are checkpointed.
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())

        # nonlinearity
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        # MLP built from 1x1 convolutions (per-time-step linear layers).
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        for (i, in_ch), out_ch in zip(enumerate(in_ch_list), out_ch_list):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:
                if use_layer_norm:
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)

        def init_weights(m) -> None:
            # Small-gain Xavier init keeps the initial filter magnitudes tiny.
            if isinstance(m, nn.Conv1d):
                torch.nn.init.xavier_uniform_(m.weight, gain=1e-3)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

        self.implicit_filter.apply(init_weights)

        if n_RFFs > 0:
            # Random Fourier feature frequencies.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m) -> None:
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @staticmethod
    def normalize_filters(filter_coeffs):
        """Normalize filters so every filter shares the RMS of the loudest one."""
        rms_per_filter = (filter_coeffs**2).mean(dim=1).sqrt()
        C = 1.0 / (rms_per_filter / rms_per_filter.max())
        return filter_coeffs * C[:, None]

    @property
    def device(self):
        """Device of the first MLP layer's weights."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time):
        """Get impulse response.

        Args:
            normalized_time (torch.Tensor): Normalized time (time).

        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)

        """
        if self.RFF_param is not None:
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]  # n_RFFs x time
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)  # n_RFFs*2 x time
            ir = self.implicit_filter(RFF[None, :, :])  # 1 x n_filters x time
        else:
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )  # 1 x n_filters x time
        return ir.view(*(ir.shape[1:]))

    def get_impulse_responses(self, sample_rate: int, kernel_size):
        """Calculate discrete-time impulse responses.

        Corresponding to the weights of the convolutional layer from MLP.
        """
        # Oversampling is an inference-time option enabled by setting the
        # `use_oversampling` attribute externally; it never fires in training.
        use_oversampling = False
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling

        if use_oversampling:
            ir = self.get_impulse_responses_oversampling(sample_rate)
        else:
            normalized_time = torch.linspace(
                -1.0,
                1.0,
                kernel_size,
                device=self.device,
                requires_grad=False,
            )  # time
            ir = self._get_ir(normalized_time)
        return ir

    def get_impulse_responses_oversampling(self, sample_rate: int):
        """Calculate discrete-time impulse responses from MLP with oversampling for anti-aliasing.

        First, calculate the discrete-time impulse responses with the trained sample
        rate.

        Then, resample the calculated discrete-time impulse responses at the input
        sample rate.
        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            # BUGFIX: torch.linspace requires an integer step count, but
            # `.item()` on the float buffer returns a Python float.
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )  # time
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )  # resampling
        return resampled_ir.float().to(self.device)
555
+
556
+
557
class FrequencyDomainRFFImplicitFilter(nn.Module):
    """Neural analog filter (NAF) for frequency-domain sampling-frequency-independent convolutional layer in [1].

    [1] Kanami Imamura, Tomohiko Nakamura, Kohei Yatabe, and Hiroshi Saruwatari, ``Neural analog filter for sampling-frequency-independent convolutional layer," APSIPA Transactions on Signal and Information Processing, vol. 13, no. 1, e28, Nov. 2024.
    """

    def __init__(
        self,
        n_filters: int,
        max_freq: int,
        ch_list: Sequence[int] = (224, 224),  # FIX: tuple default avoids the mutable-default pitfall
        n_rffs: int = 128,
        nonlinearity: str = "relu",
        train_rff: bool = True,
        use_layer_norm: bool = True,
    ):
        """Initialize FrequencyDomainRFFImplicitFilter.

        Args:
            n_filters (int): Number of filters
            max_freq (float): Max. of frequency (i.e., Nyquist frequency of training data)
            ch_list (Sequence[int]): Channel list of MLP
            n_rffs (int): # of RFFs. If equal to or less than 0, RFFs are not used.
            nonlinearity (str): Nonlinearity
            train_rff (bool): If True, train RFFs.
            use_layer_norm (bool): If True, use layer norm.
        """
        super().__init__()
        self.use_RFFs = n_rffs > 0

        # nonlinearity
        if nonlinearity == "relu":
            nonlinearity = functools.partial(nn.ReLU, inplace=True)
        elif nonlinearity == "none":
            # nn.Identity ignores all constructor arguments.
            nonlinearity = functools.partial(nn.Identity)
        else:
            raise NotImplementedError

        self.n_filters = n_filters
        # Angular frequency used to normalize MLP inputs to [0, 1] in training.
        self.register_buffer("max_ang_freq", torch.tensor(max_freq * 2.0 * np.pi))
        # MLP built from 1x1 convolutions; output is n_filters*2 channels
        # (real and imaginary parts stacked).
        layers = []
        in_ch_list = [n_rffs * 2 if self.use_RFFs else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters * 2]
        for (i, in_ch), out_ch in zip(enumerate(in_ch_list), out_ch_list):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:
                if use_layer_norm:
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(nonlinearity())
        self.implicit_filter = nn.Sequential(*layers)

        if self.use_RFFs:
            # Random Fourier feature frequencies.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_rffs,), dtype=torch.float).normal_(
                    0.0, 2.0 * np.pi * 10.0
                ),
                requires_grad=train_rff,
            )

        def set_zero_bias(m):
            if isinstance(m, nn.Conv1d):
                # CONSISTENCY FIX: guard against bias-less layers with an
                # explicit error, matching RFFTimeDomainImplicitFilter.
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)
        self.use_ideal_low_pass_filter = True

    @property
    def device(self):
        """Device."""
        return self.implicit_filter[0].weight.device

    def get_frequency_responses(self, omega: torch.Tensor):
        """Calculating frequency responses from MLP.

        Args:
            omega (torch.Tensor): (Unnormalized) angular frequencies (n_angfreqs)

        Return:
            Tuple[torch.Tensor,torch.Tensor]: Real and imaginary parts of frequency characteristics (pair of n_filters x n_angfreqs as tuple)
        """
        omega = omega / self.max_ang_freq  # n_angfreqs
        if self.use_RFFs:
            x = self.RFF_param[:, None] @ omega[None, :]  # n_RFFs x n_angfreqs
            x = torch.cat((x.cos(), x.sin()), dim=0)  # n_RFFs*2 x n_angfreqs
        else:
            x = omega[None, :]  # 1 x n_angfreqs
        freq_resps = self.implicit_filter(
            x[None, :, :]
        )  # 1 x n_RFFs*2 (or 1 (ang. freq.)) x n_angfreqs -> 1 x n_filters*2 x n_angfreqs

        # Apply ideal low pass filter: at inference, zero out responses above
        # the trained Nyquist frequency (normalized omega > 1).
        if not self.training and omega.max() > 1.0 and self.use_ideal_low_pass_filter:
            freq_resps *= (omega <= 1.0).float()[None, None, :]

        return freq_resps[0, : self.n_filters, :], freq_resps[0, self.n_filters :, :]
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ff1edfa2e81e0c2b346d837fe0300dcc7273cf5b2d4c66b1ad26f8878a513fc6
3
  size 378849388
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7eb46af981d43c8568d802676c421e0b375808904c6b5788390d8d97122134d
3
  size 378849388
modeling.py CHANGED
@@ -1,9 +1,8 @@
1
  from transformers import HubertModel
2
  from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder
3
 
4
- from sfi_ssl.model.hubert.continuous_filters import FrequencyDomainRFFImplicitFilter
5
-
6
  from .configuration import SfiHuBERTConfig
 
7
  from .conv_any_stride import FreqRespSampConv1d
8
 
9
 
 
1
  from transformers import HubertModel
2
  from transformers.models.hubert.modeling_hubert import HubertFeatureEncoder
3
 
 
 
4
  from .configuration import SfiHuBERTConfig
5
+ from .continuous_filters import FrequencyDomainRFFImplicitFilter
6
  from .conv_any_stride import FreqRespSampConv1d
7
 
8