mwirth7 committed
Commit ebaddb5 · verified · 1 Parent(s): 3be8907
config.json ADDED
The diff for this file is too large to render. See raw diff
 
configuration_protonet.py ADDED
@@ -0,0 +1,51 @@
+ from transformers import PretrainedConfig
+
+
+ class AudioProtoNetConfig(PretrainedConfig):
+     _auto_class = "AutoConfig"
+     model_type = "AudioProtoNet"
+
+     def __init__(
+         self,
+         prototypes_per_class: int = 1,
+         channels: int = 1024,
+         height: int = 1,
+         width: int = 1,
+         num_classes: int = 9736,
+         topk_k: int = 1,
+         margin: float | None = None,
+         add_on_layers_type: str = "upsample",
+         incorrect_class_connection: float | None = None,
+         correct_class_connection: float = 1.0,
+         bias_last_layer: float = -2.0,
+         non_negative_last_layer: bool = True,
+         embedded_spectrogram_height: int | None = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.prototypes_per_class = prototypes_per_class
+         self.num_prototypes_after_pruning = None
+         self.channels = channels
+         self.height = height
+         self.width = width
+         self.num_classes = num_classes
+         self.topk_k = topk_k
+         self.margin = margin
+         self.relu_on_cos = True
+         self.add_on_layers_type = add_on_layers_type
+         self.incorrect_class_connection = incorrect_class_connection
+         self.correct_class_connection = correct_class_connection
+         self.input_vector_length = 64
+         self.n_eps_channels = 2
+         self.epsilon_val = 1e-4
+         self.bias_last_layer = bias_last_layer
+         self.non_negative_last_layer = non_negative_last_layer
+         self.embedded_spectrogram_height = embedded_spectrogram_height
+
+         # A non-zero bias value (the default is -2.0) enables the bias in the last layer.
+         self.use_bias_last_layer = bool(self.bias_last_layer)
+
+         self.prototype_class_identity = None
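A minimal usage sketch, assuming the commit is pushed to a Hub repo (the repo id below is a placeholder):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("mwirth7/audio-protonet", trust_remote_code=True)
    print(config.num_classes)                                 # 9736
    print(config.prototypes_per_class * config.num_classes)   # total number of prototypes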
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93eaf3eed2413ea93032d0965254ef193a93f5c186553c0f9bbc32f801344e5d
+ size 1148684416
modeling_protonet.py ADDED
@@ -0,0 +1,953 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ from transformers import PreTrainedModel, ConvNextModel, ConvNextConfig
+ from transformers.utils import logging
+ from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPoolingAndNoAttention
+ from dataclasses import dataclass
+
+ from .configuration_protonet import AudioProtoNetConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class SequenceClassifierOutputWithProtoTypeActivations(ModelOutput):
+     logits: torch.Tensor
+     loss: torch.Tensor = None
+     last_hidden_state: torch.FloatTensor = None
+     hidden_states: tuple[torch.FloatTensor, ...] = None
+     prototype_activations: torch.FloatTensor = None
+
+
+ # https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
+ # https://github.com/huggingface/pytorch-image-models/blob/bbe798317fb26f063c18279827c038058e376479/timm/loss/asymmetric_loss.py#L6
+ class AsymmetricLossMultiLabel(nn.Module):
+     def __init__(
+         self,
+         gamma_neg=4,
+         gamma_pos=1,
+         clip=0.05,
+         eps=1e-8,
+         disable_torch_grad_focal_loss=False,
+         reduction="mean",
+     ):
+         super().__init__()
+
+         self.gamma_neg = gamma_neg
+         self.gamma_pos = gamma_pos
+         self.clip = clip
+         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+         self.eps = eps
+         self.reduction = reduction
+
+     def forward(self, x, y):
+         """
+         Parameters
+         ----------
+         x: input logits
+         y: targets (multi-label binarized vector)
+         """
+
+         # Calculating Probabilities
+         x_sigmoid = torch.sigmoid(x)
+         xs_pos = x_sigmoid
+         xs_neg = 1 - x_sigmoid
+
+         # Asymmetric Clipping
+         if self.clip is not None and self.clip > 0:
+             xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+         # Basic CE calculation
+         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+         loss = los_pos + los_neg
+
+         # Asymmetric Focusing
+         if self.gamma_neg > 0 or self.gamma_pos > 0:
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(False)
+             pt0 = xs_pos * y
+             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
+             pt = pt0 + pt1
+             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(True)
+             loss *= one_sided_w
+
+         if self.reduction == "mean":
+             return -loss.mean()
+         if self.reduction == "sum":
+             return -loss.sum()
+
+         return -loss
+
+
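A quick sanity check of the loss (a sketch, not part of the committed files): with both gammas and the clip set to zero, AsymmetricLossMultiLabel reduces to plain BCE-with-logits:

    logits = torch.tensor([[1.5, -0.5], [0.2, 2.0]])
    targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

    asl = AsymmetricLossMultiLabel(gamma_neg=0, gamma_pos=0, clip=0)
    bce = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
    assert torch.allclose(asl(logits, targets), bce, atol=1e-6)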
+ class NonNegativeLinear(nn.Module):
+     """
+     A PyTorch module for a linear layer with non-negative weights.
+
+     This module applies a linear transformation to the incoming data: `y = xA^T + b`.
+     The weights of the transformation are constrained to be non-negative, making this
+     module particularly useful in models where negative weights may not be appropriate.
+
+     Attributes:
+         in_features (int): The number of features in the input tensor.
+         out_features (int): The number of features in the output tensor.
+         weight (torch.Tensor): The weight parameter of the module, constrained to be non-negative.
+         bias (torch.Tensor, optional): The bias parameter of the module.
+
+     Args:
+         in_features (int): The number of features in the input tensor.
+         out_features (int): The number of features in the output tensor.
+         bias (bool, optional): If True, the layer will include a learnable bias. Default: True.
+         device (optional): The device (CPU/GPU) on which to perform computations.
+         dtype (optional): The data type for the parameters (e.g., float32).
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         device=None,
+         dtype=None,
+     ) -> None:
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.weight = nn.Parameter(
+             torch.empty((out_features, in_features), **factory_kwargs)
+         )
+         if bias:
+             self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
+         else:
+             self.register_parameter("bias", None)
+         # Initialize parameters; torch.empty would otherwise leave them uninitialized.
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         """Initialize weights (Kaiming uniform) and biases (uniform), as in nn.Linear."""
+         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+         if self.bias is not None:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+             bound = 1 / math.sqrt(fan_in)
+             nn.init.uniform_(self.bias, -bound, bound)
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         """
+         Defines the forward pass of the NonNegativeLinear module.
+
+         Args:
+             input (torch.Tensor): The input tensor of shape (batch_size, in_features).
+
+         Returns:
+             torch.Tensor: The output tensor of shape (batch_size, out_features).
+         """
+         return nn.functional.linear(input, torch.relu(self.weight), self.bias)
+
+
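A minimal sketch of the constraint: non-negativity is enforced in the forward pass via relu(weight), not on the stored parameter:

    layer = NonNegativeLinear(4, 2, bias=False)
    with torch.no_grad():
        layer.weight.fill_(-1.0)           # all stored weights negative
    x = torch.randn(3, 4)
    assert torch.all(layer(x) == 0.0)      # relu(weight) zeroes every connection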
+ class LinearLayerWithoutNegativeConnections(nn.Module):
+     r"""
+     Custom Linear Layer where each output class is connected to a specific subset of input features.
+
+     Args:
+         in_features: size of each input sample
+         out_features: size of each output sample
+         bias: If set to ``False``, the layer will not learn an additive bias.
+             Default: ``True``
+         device: the device of the module parameters. Default: ``None``
+         dtype: the data type of the module parameters. Default: ``None``
+
+     Shape:
+         - Input: :math:`(*, H_{in})` where :math:`*` means any number of
+           dimensions including none and :math:`H_{in} = \text{in_features}`.
+         - Output: :math:`(*, H_{out})` where all but the last dimension
+           are the same shape as the input and :math:`H_{out} = \text{out_features}`.
+
+     Attributes:
+         weight: the learnable weights of the module of shape
+             :math:`(\text{out_features}, \text{features_per_output_class})`.
+         bias: the learnable bias of the module of shape :math:`(\text{out_features})`.
+             If :attr:`bias` is ``True``, the values are initialized from
+             :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+             :math:`k = \frac{1}{\text{features_per_output_class}}`
+     """
+
+     __constants__ = ["in_features", "out_features", "bias"]
+     in_features: int
+     out_features: int
+     weight: torch.Tensor
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         non_negative: bool = True,
+         device: torch.device | None = None,
+         dtype: torch.dtype | None = None,
+     ) -> None:
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.non_negative = non_negative
+
+         # Calculate the number of features per output class
+         self.features_per_output_class = in_features // out_features
+
+         # Ensure input size is divisible by the output size
+         assert (
+             in_features % out_features == 0
+         ), f"{in_features = } must be divisible by {out_features = }"
+
+         # Define weights and biases
+         self.weight = nn.Parameter(
+             torch.empty(
+                 (out_features, self.features_per_output_class), **factory_kwargs
+             )
+         )
+         if bias:
+             self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
+         else:
+             self.register_parameter("bias", None)
+
+         # Initialize weights and biases
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         """
+         Initialize the weights and biases.
+         Weights are initialized using Kaiming uniform initialization.
+         Biases are initialized using a uniform distribution.
+         """
+         # Kaiming uniform initialization for the weights
+         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+         if self.bias is not None:
+             # Calculate fan-in and fan-out values
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+
+             # Uniform initialization for the biases
+             bound = 1 / math.sqrt(fan_in)
+             nn.init.uniform_(self.bias, -bound, bound)
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for the custom linear layer.
+
+         Args:
+             input (Tensor): Input tensor of shape (batch_size, in_features).
+
+         Returns:
+             Tensor: Output tensor of shape (batch_size, out_features).
+         """
+         batch_size = input.size(0)
+         # Reshape input to (batch_size, out_features, features_per_output_class)
+         reshaped_input = input.view(
+             batch_size, self.out_features, self.features_per_output_class
+         )
+
+         # Apply ReLU to weights if non_negative is True
+         weight = torch.relu(self.weight) if self.non_negative else self.weight
+
+         # Perform batch matrix multiplication and add bias
+         output = torch.einsum("bof,of->bo", reshaped_input, weight)
+
+         if self.bias is not None:
+             output += self.bias
+
+         return output
+
+     def extra_repr(self) -> str:
+         return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
+
+
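A minimal sketch of the block-diagonal behaviour: each class's logit depends only on its own contiguous slice of the input:

    layer = LinearLayerWithoutNegativeConnections(6, 2, bias=False, non_negative=False)
    x = torch.randn(1, 6)
    x_perturbed = x.clone()
    x_perturbed[0, 3:] += 1.0                                          # touch only class 1's slice
    assert torch.allclose(layer(x)[0, 0], layer(x_perturbed)[0, 0])    # class 0's logit is unchanged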
+ class AudioProtoNetClassificationHead(nn.Module):
+     def __init__(
+         self,
+         config: AudioProtoNetConfig,
+     ) -> None:
+         """
+         Classification head implementing the Prototypical Part Network (ProtoPNet)
+         for prototype-based classification.
+         """
+
+         super().__init__()
+         self.prototypes_per_class = config.prototypes_per_class
+         self.num_classes = config.num_classes
+         self.num_prototypes = self.prototypes_per_class * self.num_classes
+         self.num_prototypes_after_pruning = config.num_prototypes_after_pruning
+         self.margin = config.margin
+         self.relu_on_cos = config.relu_on_cos
+         self.incorrect_class_connection = config.incorrect_class_connection
+         self.correct_class_connection = config.correct_class_connection
+         self.input_vector_length = config.input_vector_length
+         self.n_eps_channels = config.n_eps_channels
+         self.epsilon_val = config.epsilon_val
+         self.topk_k = config.topk_k
+         self.bias_last_layer = config.bias_last_layer
+         self.non_negative_last_layer = config.non_negative_last_layer
+         self.embedded_spectrogram_height = config.embedded_spectrogram_height
+         self.use_bias_last_layer = config.use_bias_last_layer
+
+         # Create a 1D tensor where each element represents the class index
+         self.prototype_class_identity = (
+             torch.arange(self.num_prototypes) // self.prototypes_per_class
+         )
+
+         self.prototype_shape = (self.num_prototypes, config.channels, config.height, config.width)
+
+         self._setup_add_on_layers(add_on_layers_type=config.add_on_layers_type)
+
+         self.prototype_vectors = nn.Parameter(
+             torch.rand(self.prototype_shape), requires_grad=True
+         )
+
+         self.frequency_weights = None
+         if self.embedded_spectrogram_height is not None:
+             # Initialize the frequency weights with a large positive value of 3.0 so that sigmoid(frequency_weights) is close to 1.
+             self.frequency_weights = nn.Parameter(
+                 torch.full(
+                     (
+                         self.num_prototypes,
+                         self.embedded_spectrogram_height,
+                     ),
+                     3.0,
+                 )
+             )
+
+         if self.incorrect_class_connection:
+             if self.non_negative_last_layer:
+                 self.last_layer = NonNegativeLinear(
+                     self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
+                 )
+             else:
+                 self.last_layer = nn.Linear(
+                     self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
+                 )
+         else:
+             self.last_layer = LinearLayerWithoutNegativeConnections(
+                 in_features=self.num_prototypes,
+                 out_features=self.num_classes,
+                 non_negative=self.non_negative_last_layer,
+             )
+
+     def forward(
+         self,
+         features: torch.Tensor,
+         prototypes_of_wrong_class: torch.Tensor = None,
+     ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+         """
+         Forward pass of the ProtoPNet classification head.
+
+         Args:
+         - features (torch.Tensor): Input tensor with shape (batch_size, num_channels, height, width).
+         - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
+           when using subtractive margins. Defaults to None.
+
+         Returns:
+             Tuple[torch.Tensor, List[torch.Tensor]]:
+             - logits: A tensor containing the logits for each class in the model.
+             - a list containing:
+                 - mean_activations: A tensor containing the mean of the top-k prototype activations.
+                   (in evaluation mode k is always 1)
+                 - marginless_logits: A tensor containing the logits for each class in the model, calculated using the
+                   marginless activations.
+                 - conv_features: A tensor containing the convolutional features.
+                 - marginless_max_activations: A tensor containing the max-pooled marginless activations.
+                 - marginless_activations: A tensor containing the full spatial map of marginless activations.
+         """
+
+         features = self.add_on_layers(features)
+
+         activations, additional_returns = self.prototype_activations(
+             features, prototypes_of_wrong_class=prototypes_of_wrong_class
+         )
+         marginless_activations = additional_returns[0]
+         conv_features = additional_returns[1]
+
+         # Set topk_k based on training mode: use the configured value if training, else 1 for evaluation
+         topk_k = self.topk_k if self.training else 1
+
+         # Reshape activations to combine spatial dimensions: (batch_size, num_prototypes, height*width)
+         activations = activations.view(activations.shape[0], activations.shape[1], -1)
+
+         # Perform top-k pooling along the combined spatial dimension
+         # For topk_k=1, this is equivalent to global max pooling
+         topk_activations, _ = torch.topk(activations, topk_k, dim=-1)
+
+         # Calculate the mean of the top-k activations for each prototype: (batch_size, num_prototypes)
+         # If topk_k=1, this mean operation does nothing since there's only one value.
+         mean_activations = torch.mean(topk_activations, dim=-1)
+
+         marginless_max_activations = nn.functional.max_pool2d(
+             marginless_activations,
+             kernel_size=(
+                 marginless_activations.size()[2],
+                 marginless_activations.size()[3],
+             ),
+         )
+         marginless_max_activations = marginless_max_activations.view(
+             -1, self.num_prototypes
+         )
+
+         logits = self.last_layer(mean_activations)
+         marginless_logits = self.last_layer(marginless_max_activations)
+         return logits, [
+             mean_activations,
+             marginless_logits,
+             conv_features,
+             marginless_max_activations,
+             marginless_activations,
+         ]
+
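A minimal sketch of the top-k pooling step above; for k=1 it coincides with global max pooling over the spatial grid:

    acts = torch.randn(2, 5, 4, 4)                   # (batch, prototypes, H, W)
    flat = acts.view(2, 5, -1)
    topk_vals, _ = torch.topk(flat, k=3, dim=-1)
    pooled = topk_vals.mean(dim=-1)                  # (batch, prototypes)
    max_pooled = torch.topk(flat, k=1, dim=-1).values.squeeze(-1)
    assert torch.equal(max_pooled, flat.max(dim=-1).values)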
+     # def conv_features(self, x: torch.Tensor) -> torch.Tensor:
+     #     """
+     #     Takes an input tensor and passes it through the backbone model to extract features.
+     #     Then, it passes them through the additional layers to produce the output tensor.
+     #
+     #     Args:
+     #         x (torch.Tensor): The input tensor.
+     #
+     #     Returns:
+     #         torch.Tensor: The output tensor after passing through the backbone model and additional layers.
+     #     """
+     #     # Extract features using the backbone model
+     #     features = self.backbone_model(x)
+     #
+     #     # The features must be a 4D tensor of shape (batch size, channels, height, width)
+     #     if features.dim() == 3:
+     #         features.unsqueeze_(0)
+     #
+     #     # Pass the features through additional layers
+     #     output = self.add_on_layers(features)
+     #
+     #     return output
+
+     def cos_activation(
+         self,
+         x: torch.Tensor,
+         prototypes_of_wrong_class: torch.Tensor = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Compute the cosine activation between input tensor x and prototype vectors.
+
+         Parameters:
+         -----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, num_channels, height, width).
+         prototypes_of_wrong_class : Optional[torch.Tensor]
+             Tensor containing the prototypes of the wrong class with shape (batch_size, num_prototypes).
+
+         Returns:
+         --------
+         Tuple[torch.Tensor, torch.Tensor]
+             A tuple containing:
+             - activations: The cosine activations with potential margin adjustments.
+             - marginless_activations: The cosine activations without margin adjustments.
+         """
+         input_vector_length = self.input_vector_length
+         normalizing_factor = (
+             self.prototype_shape[-2] * self.prototype_shape[-1]
+         ) ** 0.5
+
+         # Pre-allocate epsilon channels on the correct device for input tensor x
+         epsilon_channel_x = torch.full(
+             (x.shape[0], self.n_eps_channels, x.shape[2], x.shape[3]),
+             self.epsilon_val,
+             device=x.device,
+             requires_grad=False,
+         )
+         x = torch.cat((x, epsilon_channel_x), dim=-3)
+
+         # Normalize x
+         x_length = torch.sqrt(torch.sum(x**2, dim=-3, keepdim=True) + self.epsilon_val)
+         x_normalized = (input_vector_length * x / x_length) / normalizing_factor
+
+         # Pre-allocate epsilon channels for prototypes on the correct device
+         epsilon_channel_p = torch.full(
+             (
+                 self.prototype_shape[0],
+                 self.n_eps_channels,
+                 self.prototype_shape[2],
+                 self.prototype_shape[3],
+             ),
+             self.epsilon_val,
+             device=self.prototype_vectors.device,
+             requires_grad=False,
+         )
+         appended_protos = torch.cat((self.prototype_vectors, epsilon_channel_p), dim=-3)
+
+         # Normalize prototypes
+         prototype_vector_length = torch.sqrt(
+             torch.sum(appended_protos**2, dim=-3, keepdim=True) + self.epsilon_val
+         )
+         normalized_prototypes = appended_protos / (
+             prototype_vector_length + self.epsilon_val
+         )
+         normalized_prototypes /= normalizing_factor
+
+         # Compute activations using convolution
+         activations_dot = nn.functional.conv2d(x_normalized, normalized_prototypes)
+         marginless_activations = activations_dot / (input_vector_length * 1.01)
+
+         if self.frequency_weights is not None:
+             # Apply sigmoid to the frequency weights so that they lie between 0 and 1.
+             freq_weights = torch.sigmoid(self.frequency_weights)
+
+             # Multiply each prototype's frequency response by the corresponding weights
+             marginless_activations = marginless_activations * freq_weights[:, :, None]
+
+         if (
+             self.margin is None
+             or not self.training
+             or prototypes_of_wrong_class is None
+         ):
+             activations = marginless_activations
+         else:
+             # Apply margin adjustment for wrong class prototypes
+             wrong_class_margin = (prototypes_of_wrong_class * self.margin).view(
+                 x.size(0), self.prototype_vectors.size(0), 1, 1
+             )
+             wrong_class_margin = wrong_class_margin.expand(
+                 -1, -1, activations_dot.size(-2), activations_dot.size(-1)
+             )
+             penalized_angles = (
+                 torch.acos(activations_dot / (input_vector_length * 1.01))
+                 - wrong_class_margin
+             )
+             activations = torch.cos(torch.relu(penalized_angles))
+
+         if self.relu_on_cos:
+             # Apply ReLU activation on the cosine values
+             activations = torch.relu(activations)
+             marginless_activations = torch.relu(marginless_activations)
+
+         return activations, marginless_activations
+
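A minimal sketch of the trick used in cos_activation: with L2-normalized feature columns and 1x1 prototype kernels, conv2d computes the cosine similarity at every spatial location:

    feats = torch.randn(1, 8, 3, 3)
    protos = torch.randn(5, 8, 1, 1)
    f = feats / feats.norm(dim=1, keepdim=True)
    p = protos / protos.norm(dim=1, keepdim=True)
    sims = nn.functional.conv2d(f, p)                # (1, 5, 3, 3), values in [-1, 1]
    manual = (f[0, :, 0, 0] * p[2, :, 0, 0]).sum()   # cosine of one feature/prototype pair
    assert torch.allclose(sims[0, 2, 0, 0], manual, atol=1e-6)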
+     def prototype_activations(
+         self,
+         x: torch.Tensor,
+         prototypes_of_wrong_class: torch.Tensor = None,
+     ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+         """
+         Compute the prototype activations for a given input tensor.
+
+         Args:
+         - x (torch.Tensor): The raw input tensor with shape (batch_size, num_channels, height, width).
+         - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
+           when using subtractive margins. Defaults to None.
+
+         Returns:
+             Tuple[torch.Tensor, List[torch.Tensor]]:
+             - activations: A tensor containing the prototype activations.
+             - a list containing:
+                 - marginless_activations: A tensor containing the activations before applying subtractive margin.
+                 - conv_features: A tensor containing the convolutional features.
+         """
+         # Compute cosine activations
+         activations, marginless_activations = self.cos_activation(
+             x,
+             prototypes_of_wrong_class=prototypes_of_wrong_class,
+         )
+
+         return activations, [marginless_activations, x]
+
+     def get_prototype_orthogonalities(self, use_part_prototypes: bool = False) -> torch.Tensor:
+         """
+         Computes the orthogonality loss, encouraging each piece of a prototype to be orthogonal to the others.
+
+         This method is inspired by the paper:
+         https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Interpretable_Image_Recognition_by_Constructing_Transparent_Embedding_Space_ICCV_2021_paper.pdf
+
+         Args:
+             use_part_prototypes (bool): If True, treats each spatial part of the prototypes as a separate prototype.
+
+         Returns:
+             torch.Tensor: A tensor representing the orthogonalities.
+         """
+
+         if use_part_prototypes:
+             # Normalize prototypes to unit length
+             prototype_vector_length = torch.sqrt(
+                 torch.sum(torch.square(self.prototype_vectors), dim=1, keepdim=True)
+                 + self.epsilon_val
+             )
+             normalized_prototypes = self.prototype_vectors / (
+                 prototype_vector_length + self.epsilon_val
+             )
+
+             # Calculate total part prototypes per class
+             num_part_prototypes_per_class = (
+                 self.prototypes_per_class
+                 * self.prototype_shape[2]
+                 * self.prototype_shape[3]
+             )
+
+             # Reshape to match class structure
+             normalized_prototypes = normalized_prototypes.view(
+                 self.num_classes,
+                 self.prototypes_per_class,
+                 self.prototype_shape[1],
+                 self.prototype_shape[2] * self.prototype_shape[3],
+             )
+
+             # Transpose and reshape to treat each spatial part as a separate prototype
+             normalized_prototypes = normalized_prototypes.permute(0, 1, 3, 2).reshape(
+                 self.num_classes, num_part_prototypes_per_class, self.prototype_shape[1]
+             )
+
+         else:
+             # Normalize prototypes to unit length
+             prototype_vectors_reshaped = self.prototype_vectors.view(
+                 self.num_prototypes, -1
+             )
+             prototype_vector_length = torch.sqrt(
+                 torch.sum(torch.square(prototype_vectors_reshaped), dim=1, keepdim=True)
+                 + self.epsilon_val
+             )
+             normalized_prototypes = prototype_vectors_reshaped / (
+                 prototype_vector_length + self.epsilon_val
+             )
+
+             # Reshape to match class structure
+             normalized_prototypes = normalized_prototypes.view(
+                 self.num_classes,
+                 self.prototypes_per_class,
+                 self.prototype_shape[1]
+                 * self.prototype_shape[2]
+                 * self.prototype_shape[3],
+             )
+
+         # Compute orthogonality matrix for each class
+         orthogonalities = torch.matmul(
+             normalized_prototypes, normalized_prototypes.transpose(1, 2)
+         )
+
+         # Identity matrix to enforce orthogonality
+         identity_matrix = (
+             torch.eye(normalized_prototypes.shape[1], device=orthogonalities.device)
+             .unsqueeze(0)
+             .repeat(self.num_classes, 1, 1)
+         )
+
+         # Subtract identity to focus on orthogonality
+         orthogonalities = orthogonalities - identity_matrix
+
+         return orthogonalities
+
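A sketch of how the orthogonality matrix is typically used (an assumption, not spelled out in this commit): collapsed to its squared Frobenius norm and added to the training loss, as in the TesNet paper linked above. The small config is hypothetical:

    head = AudioProtoNetClassificationHead(
        AudioProtoNetConfig(num_classes=4, prototypes_per_class=2, channels=16)
    )
    ortho = head.get_prototype_orthogonalities()     # (num_classes, P, P)
    ortho_loss = torch.norm(ortho) ** 2              # zero iff each class's prototypes are orthonormal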
+     def identify_prototypes_to_prune(self) -> list[int]:
+         """
+         Identifies the indices of prototypes that should be pruned.
+
+         This function iterates through the prototypes and checks if the specific weight
+         connecting the prototype to its class is zero. It is specifically designed to handle
+         the LinearLayerWithoutNegativeConnections where each class has a subset of features
+         it connects to.
+
+         Returns:
+             list[int]: A list of prototype indices that should be pruned.
+         """
+         prototypes_to_prune = []
+
+         # Calculate the number of prototypes assigned to each class
+         prototypes_per_class = self.num_prototypes // self.num_classes
+
+         if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
+             # Custom layer mapping prototypes to a subset of input features for each output class
+             for prototype_index in range(self.num_prototypes):
+                 class_index = self.prototype_class_identity[prototype_index]
+                 # Calculate the specific index within the 'features_per_output_class' for this prototype
+                 index_within_class = prototype_index % prototypes_per_class
+                 # Check if the specific weight connecting the prototype to its class is zero
+                 if self.last_layer.weight[class_index, index_within_class] == 0.0:
+                     prototypes_to_prune.append(prototype_index)
+         else:
+             # Standard linear layer: each prototype directly maps to a feature index
+             weights_to_check = self.last_layer.weight
+             for prototype_index in range(self.num_prototypes):
+                 class_index = self.prototype_class_identity[prototype_index]
+                 if weights_to_check[class_index, prototype_index] == 0.0:
+                     prototypes_to_prune.append(prototype_index)
+
+         return prototypes_to_prune
+
+     def prune_prototypes_by_threshold(self, threshold: float = 1e-3) -> None:
+         """
+         Prune the weights in the classification layer by setting weights below a specified threshold to zero.
+
+         This method modifies the weights of the last layer of the model in-place. Weights falling below the
+         threshold are set to zero, diminishing their influence in the model's decisions. It also identifies
+         and prunes prototypes based on these updated weights, thereby refining the model's structure.
+
+         Args:
+             threshold (float): The threshold value below which weights will be set to zero. Defaults to 1e-3.
+         """
+         # Access the weights of the last layer
+         weights = self.last_layer.weight.data
+
+         # Set weights below the threshold to zero
+         # This step reduces the influence of low-value weights in the model's decision-making process
+         weights[weights < threshold] = 0.0
+
+         # Update the weights in the last layer to reflect the pruning
+         self.last_layer.weight.data.copy_(weights)
+
+         # Identify prototypes that need to be pruned based on the updated weights
+         prototypes_to_prune = self.identify_prototypes_to_prune()
+
+         # Execute the pruning of identified prototypes
+         self.prune_prototypes_by_index(prototypes_to_prune)
+
+     def prune_prototypes_by_index(self, prototypes_to_prune: list[int]) -> None:
+         """
+         Prunes specified prototypes from the PPNet.
+
+         Args:
+             prototypes_to_prune (list[int]): A list of indices indicating the prototypes to be removed.
+                 Each index should be in the range [0, current number of prototypes - 1].
+
+         Returns:
+             None
+         """
+
+         # Validate the provided indices to ensure they are within the valid range
+         if any(
+             index < 0 or index >= self.num_prototypes for index in prototypes_to_prune
+         ):
+             raise ValueError("Provided prototype indices are out of valid range!")
+
+         # Calculate the new number of prototypes after pruning
+         self.num_prototypes_after_pruning = self.num_prototypes - len(
+             prototypes_to_prune
+         )
+
+         # Remove the influence of the prototypes that are no longer needed
+         with torch.no_grad():
+             # If frequency_weights are being used, set the weights of pruned prototypes to -7.0
+             # so that sigmoid(frequency_weights) is close to 0 and they are effectively silenced
+             if self.frequency_weights is not None:
+                 self.frequency_weights.data[prototypes_to_prune, :] = -7.0
+
+             # Adjust the weights in the last layer depending on its type
+             if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
+                 # For LinearLayerWithoutNegativeConnections, set the connection weights to zero
+                 # only for the pruned prototypes related to their specific classes
+                 for class_idx in range(self.last_layer.out_features):
+                     # Identify prototypes belonging to the current class
+                     indices_for_class = [
+                         idx % self.last_layer.features_per_output_class
+                         for idx in prototypes_to_prune
+                         if self.prototype_class_identity[idx] == class_idx
+                     ]
+                     self.last_layer.weight.data[class_idx, indices_for_class] = 0.0
+             else:
+                 # For other layer types, set the weights of pruned prototypes to zero
+                 self.last_layer.weight.data[:, prototypes_to_prune] = 0.0
+
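Continuing the sketch above: threshold pruning zeroes near-zero last-layer weights and records how many prototypes survive:

    head.prune_prototypes_by_threshold(threshold=1e-3)
    print(head.num_prototypes_after_pruning, "of", head.num_prototypes, "prototypes kept")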
+     def __repr__(self) -> str:
+         rep = f"""PPNet(
+         prototype_shape: {self.prototype_shape},
+         num_classes: {self.num_classes},
+         epsilon: {self.epsilon_val})"""
+
+         return rep
+
+     def set_last_layer_incorrect_connection(
+         self, incorrect_strength: float = None
+     ) -> None:
+         """
+         Modifies the last layer weights to have incorrect connections with a specified strength.
+         If incorrect_strength is None, initializes the weights for LinearLayerWithoutNegativeConnections
+         with the correct_class_connection value.
+
+         Args:
+         - incorrect_strength (Optional[float]): The strength of the incorrect connections.
+           If None, initialize without incorrect connections.
+
+         Returns:
+             None
+         """
+         if incorrect_strength is None:
+             # Handle LinearLayerWithoutNegativeConnections initialization
+             if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
+                 # Initialize all weights to the correct_class_connection value
+                 self.last_layer.weight.data.fill_(self.correct_class_connection)
+             else:
+                 raise ValueError(
+                     "last_layer is not an instance of LinearLayerWithoutNegativeConnections"
+                 )
+
+         else:
+             # Create a one-hot matrix for correct connections
+             positive_one_weights_locations = torch.zeros(
+                 self.num_classes, self.num_prototypes
+             )
+             positive_one_weights_locations[
+                 self.prototype_class_identity,
+                 torch.arange(self.num_prototypes),
+             ] = 1
+
+             # Create a matrix for incorrect connections
+             negative_one_weights_locations = 1 - positive_one_weights_locations
+
+             # This variable represents the strength of the connection for the correct class
+             correct_class_connection = self.correct_class_connection
+
+             # This variable represents the strength of the connection for the incorrect class
+             incorrect_class_connection = incorrect_strength
+
+             # Modify weights to have correct and incorrect connections
+             self.last_layer.weight.data.copy_(
+                 correct_class_connection * positive_one_weights_locations
+                 + incorrect_class_connection * negative_one_weights_locations
+             )
+
+         if self.last_layer.bias is not None:
+             # Initialize all biases to the bias_last_layer value
+             self.last_layer.bias.data.fill_(self.bias_last_layer)
+
+     def _setup_add_on_layers(self, add_on_layers_type: str):
+         """
+         Configures additional layers based on the backbone model architecture and the specified add_on_layers_type.
+
+         Args:
+             add_on_layers_type (str): Type of additional layers to add. Can be 'identity' or 'upsample'.
+         """
+
+         if add_on_layers_type == "identity":
+             self.add_on_layers = nn.Sequential(nn.Identity())
+         elif add_on_layers_type == "upsample":
+             self.add_on_layers = nn.Upsample(scale_factor=2, mode="bilinear")
+         else:
+             raise NotImplementedError(
+                 f"The add-on layer type {add_on_layers_type} isn't implemented yet."
+             )
+
+     # TODO
+     # def _initialize_weights(self) -> None:
+     #     """
+     #     Initializes the weights of the add-on layers of the network and the last layer with incorrect connections.
+     #
+     #     Returns:
+     #         None
+     #     """
+     #
+     #     for m in self.add_on_layers.modules():
+     #         if isinstance(m, (nn.Conv2d, nn.Linear)):
+     #             nn.init.trunc_normal_(m.weight, std=0.02)
+     #             if m.bias is not None:
+     #                 nn.init.zeros_(m.bias)
+     #
+     #     # Initialize the last layer with incorrect connections using specified incorrect class connection strength
+     #     self.set_last_layer_incorrect_connection(
+     #         incorrect_strength=self.incorrect_class_connection
+     #     )
+
+
+ class AudioProtoNetPreTrainedModel(PreTrainedModel):
+     config_class = AudioProtoNetConfig
+     base_model_prefix = "model"
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Conv2d, nn.Linear)):
+             nn.init.trunc_normal_(module.weight, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         if (
+             self.config.incorrect_class_connection is None
+             and isinstance(module, LinearLayerWithoutNegativeConnections)
+         ):
+             # Initialize all weights to the correct_class_connection value
+             module.weight.data.fill_(self.config.correct_class_connection)
+
+
+ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
+     _auto_class = "AutoModel"
+
+     def __init__(self, config: AudioProtoNetConfig):
+         super().__init__(config)
+         backbone_config = ConvNextConfig.from_pretrained("facebook/convnext-base-224-22k", num_channels=1)
+         self.backbone = ConvNextModel(backbone_config)
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         output_hidden_states: bool | None = None,
+         return_dict: bool | None = None,
+     ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
+         """
+         Args:
+             input_values: Mel-spectrogram batch of shape (batch_size, 1, n_mels, time).
+             output_hidden_states: Whether to return the hidden states of all stages.
+             return_dict: Whether to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             last_hidden_state: torch.FloatTensor
+             pooler_output: torch.FloatTensor
+             hidden_states: Optional[Tuple[torch.FloatTensor, ...]]
+         """
+         return self.backbone(
+             input_values,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+
+ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
+     _auto_class = "AutoModelForSequenceClassification"
+
+     def __init__(self, config: AudioProtoNetConfig):
+         super().__init__(config)
+
+         self.model = AudioProtoNetModel(config)
+         self.head = AudioProtoNetClassificationHead(config)
+
+     def freeze_backbone(self):
+         # Freeze the ConvNext backbone so only the prototype head is trained.
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     def int2str(self):  # TODO
+         pass
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         labels: torch.Tensor = None,
+         prototypes_of_wrong_class: torch.Tensor = None,
+         output_hidden_states: bool | None = None,
+         output_prototypical_activations: bool | None = None,
+         return_dict: bool | None = None,
+     ) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         backbone_outputs = self.model(input_values, output_hidden_states, return_dict)
+
+         last_hidden_state = backbone_outputs[0]
+
+         logits, info = self.head(last_hidden_state, prototypes_of_wrong_class)
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             loss_fct = AsymmetricLossMultiLabel()
+             loss = loss_fct(logits, labels.float())
+
+         hidden_states = None
+         if output_hidden_states:
+             hidden_states = backbone_outputs[2]
+
+         prototype_activations = None
+         if output_prototypical_activations:
+             prototype_activations = info[4]
+
+         if not return_dict:
+             output = (logits,)
+             output += (loss,) if loss is not None else ()
+             output += (last_hidden_state,)
+             output += (hidden_states,) if hidden_states is not None else ()
+             output += (prototype_activations,) if prototype_activations is not None else ()
+             return output
+
+         return SequenceClassifierOutputWithProtoTypeActivations(
+             logits=logits,
+             loss=loss,
+             last_hidden_state=last_hidden_state,
+             hidden_states=hidden_states,
+             prototype_activations=prototype_activations,
+         )
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "auto_map": {
+     "AutoFeatureExtractor": "processing_protonet.AudioProtoNetFeatureExtractor"
+   },
+   "db_scale": null,
+   "feature_extractor_type": "AudioProtoNetFeatureExtractor",
+   "feature_size": 1,
+   "hop_length": 256,
+   "mean": -13.369,
+   "mel_scale": null,
+   "n_fft": 2048,
+   "n_mels": 256,
+   "n_stft": 1025,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "power": 2.0,
+   "return_attention_mask": true,
+   "sampling_rate": 32000,
+   "spec_transform": null,
+   "std": 13.162,
+   "stype": "power",
+   "top_db": 80
+ }
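The auto_map entry above is what lets AutoFeatureExtractor resolve the custom class when loading with trust_remote_code; a sketch with a placeholder repo id:

    from transformers import AutoFeatureExtractor

    fe = AutoFeatureExtractor.from_pretrained("mwirth7/audio-protonet", trust_remote_code=True)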
processing_protonet.py ADDED
@@ -0,0 +1,96 @@
+ from transformers import SequenceFeatureExtractor
+ from transformers.utils import PaddingStrategy
+ from transformers.feature_extraction_utils import BatchFeature
+ from torchaudio import transforms
+ from typing import Union
+ import numpy as np
+ import torch
+
+
+ class AudioProtoNetFeatureExtractor(SequenceFeatureExtractor):
+     _auto_class = "AutoFeatureExtractor"
+     model_input_names = ["input_values"]
+
+     def __init__(self,
+                  # spectrogram
+                  n_fft: int = 2048,
+                  feature_size: int = 1,
+                  hop_length: int = 256,
+                  power: float = 2.0,
+
+                  # mel scale
+                  n_mels: int = 256,
+                  sampling_rate: int = 32_000,
+                  n_stft: int = 1025,
+
+                  # power to db
+                  stype: str = "power",
+                  top_db: int = 80,
+
+                  # normalization
+                  mean: float = -13.369,
+                  std: float = 13.162,
+                  padding_value: float = 0.0,
+
+                  return_attention_mask: bool = True,
+                  **kwargs,
+                  ):
+         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+         # Store parameters for serialization
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.power = power
+         self.n_mels = n_mels
+         self.sampling_rate = sampling_rate
+         self.n_stft = n_stft
+         self.stype = stype
+         self.top_db = top_db
+         self.mean = mean
+         self.std = std
+         self.padding_value = padding_value
+         self.return_attention_mask = return_attention_mask
+         # The torchaudio transforms are created lazily in _init_transforms(),
+         # keeping this object JSON-serializable.
+         self.spec_transform = None
+         self.mel_scale = None
+         self.db_scale = None
+
+     def _init_transforms(self):  # TODO post init method?
+         self.spec_transform = transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=self.power)
+         self.mel_scale = transforms.MelScale(n_mels=self.n_mels, sample_rate=self.sampling_rate, n_stft=self.n_stft)
+         self.db_scale = transforms.AmplitudeToDB(stype=self.stype, top_db=self.top_db)
+
+     def __call__(self,
+                  waveform_batch: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+                  padding: Union[bool, str, PaddingStrategy] = "longest",
+                  max_length: int | None = None,
+                  truncation: bool = True,
+                  return_tensors: str = "pt"
+                  ):
+         if self.spec_transform is None:
+             self._init_transforms()
+         clip_duration = 5  # TODO this is the clip duration used in training
+         max_length = max_length or int(int(self.sampling_rate) * clip_duration)
+
+         # Wrap a single waveform into a batch of one
+         if isinstance(waveform_batch, (list, np.ndarray)) and not isinstance(waveform_batch[0], (list, np.ndarray)):
+             waveform_batch = [waveform_batch]
+
+         waveform_batch = BatchFeature({"input_values": waveform_batch})
+
+         waveform_batch = self.pad(
+             waveform_batch,
+             padding=padding,
+             max_length=max_length,
+             truncation=truncation,
+             return_attention_mask=self.return_attention_mask
+         )
+         # Note: only the padded input values are used; the attention mask is not returned.
+         waveform_batch = waveform_batch["input_values"]
+         audio_tensor = torch.as_tensor(waveform_batch)
+
+         # Waveform -> power spectrogram -> mel scale -> dB -> z-normalization
+         spec_gram = self.spec_transform(audio_tensor)
+         mel_spec = self.mel_scale(spec_gram)
+         mel_spec = self.db_scale(mel_spec)
+         mel_spec_norm = (mel_spec - self.mean) / self.std
+
+         # Add a channel dimension: (batch, 1, n_mels, time)
+         return mel_spec_norm.unsqueeze(1)
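An end-to-end sketch tying the pieces together (placeholder repo id; shapes follow from the configs above):

    import numpy as np
    import torch
    from transformers import AutoFeatureExtractor, AutoModelForSequenceClassification

    repo = "mwirth7/audio-protonet"  # placeholder
    fe = AutoFeatureExtractor.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(repo, trust_remote_code=True)

    waveform = np.random.randn(32000 * 5).astype(np.float32)    # 5 s at 32 kHz
    inputs = fe(waveform)                                       # (1, 1, 256, time) mel spectrogram
    with torch.no_grad():
        outputs = model(inputs)
    print(outputs.logits.shape)                                 # (1, 9736)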