mwirth7 commited on
Commit
6e22215
·
verified ·
1 Parent(s): 988745a

Update modeling_protonet.py

Browse files
Files changed (1) hide show
  1. modeling_protonet.py +953 -953
modeling_protonet.py CHANGED
@@ -1,953 +1,953 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch import Tensor
4
- import math
5
-
6
- from transformers import PreTrainedModel, ConvNextModel, ConvNextConfig
7
- from transformers.utils import logging
8
- from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPoolingAndNoAttention
9
- from dataclasses import dataclass
10
-
11
- from .configuration_protonet import AudioProtoNetConfig
12
-
13
- logger = logging.get_logger(__name__)
14
-
15
-
16
@dataclass
class SequenceClassifierOutputWithProtoTypeActivations(ModelOutput):
    """
    Output container for prototype-based sequence classification.

    Extends the standard classifier output with the raw prototype
    activations so that callers can inspect which prototypes fired.
    Only ``logits`` is required; all other fields default to ``None``.
    """

    # (batch_size, num_classes) classification scores.
    logits: torch.Tensor
    # Optional scalar training loss.
    loss: torch.Tensor | None = None
    # Backbone features before pooling.
    last_hidden_state: torch.FloatTensor | None = None
    # Per-layer hidden states, when requested.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Prototype similarity activations for interpretability.
    prototype_activations: torch.FloatTensor | None = None
23
-
24
-
25
# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
# https://github.com/huggingface/pytorch-image-models/blob/bbe798317fb26f063c18279827c038058e376479/timm/loss/asymmetric_loss.py#L6
class AsymmetricLossMultiLabel(nn.Module):
    """
    Asymmetric loss (ASL) for multi-label classification (Ridnik et al., ICCV 2021).

    Applies different focusing exponents to positive and negative targets and
    optionally shifts negative probabilities by ``clip`` (asymmetric clipping)
    before the cross-entropy terms are computed.

    Args:
        gamma_neg: Focusing parameter for negative targets. Defaults to 4.
        gamma_pos: Focusing parameter for positive targets. Defaults to 1.
        clip: Probability margin added to the negative probabilities
            (clamped to at most 1). Disabled when ``None`` or <= 0. Defaults to 0.05.
        eps: Lower clamp applied inside the logs for numerical stability.
        disable_torch_grad_focal_loss: If True, compute the focal weighting
            with gradients globally disabled (as in the timm reference).
        reduction: ``"mean"``, ``"sum"``, or any other value for no reduction.
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
        reduction="mean",
    ):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """
        Compute the asymmetric loss.

        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """
        # Calculating probabilities.
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric clipping: shift negatives up so easy negatives contribute ~0 loss.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic binary cross-entropy terms (clamped for log stability).
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric focusing: down-weight easy examples with per-sign gammas.
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # BUGFIX: use the public torch.set_grad_enabled instead of the
                # private torch._C binding; semantics are identical.
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        if self.reduction == "mean":
            return -loss.mean()
        if self.reduction == "sum":
            return -loss.sum()

        # No reduction: return the element-wise (negated) loss.
        return -loss
87
-
88
-
89
class NonNegativeLinear(nn.Module):
    """
    A PyTorch module for a linear layer with non-negative weights.

    This module applies a linear transformation to the incoming data: `y = xA^T + b`.
    The weights of the transformation are constrained to be non-negative (via a ReLU
    applied to the weight at forward time), making this module particularly useful
    in models where negative weights may not be appropriate.

    Attributes:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        weight (torch.Tensor): The weight parameter, rectified to be non-negative in forward.
        bias (torch.Tensor, optional): The bias parameter of the module.

    Args:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        bias (bool, optional): If True, the layer will include a learnable bias. Default: True.
        device (optional): The device (CPU/GPU) on which to perform computations.
        dtype (optional): The data type for the parameters (e.g., float32).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # BUGFIX: parameters were allocated with torch.empty but never
        # initialized, leaving uninitialized memory in freshly-built layers.
        # Initialize like nn.Linear (and like LinearLayerWithoutNegativeConnections).
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize the weights and biases.
        Weights use Kaiming uniform initialization; biases use a uniform
        distribution bounded by 1/sqrt(fan_in), matching nn.Linear.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the NonNegativeLinear module.

        Args:
            input (torch.Tensor): The input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, out_features).
        """
        # ReLU on the weight enforces the non-negativity constraint.
        return nn.functional.linear(input, torch.relu(self.weight), self.bias)
142
-
143
-
144
class LinearLayerWithoutNegativeConnections(nn.Module):
    r"""
    Custom Linear Layer where each output class is connected to a specific subset of input features.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        non_negative: If ``True``, the weights are rectified (ReLU) at forward
            time so all connections are non-negative. Default: ``True``
        device: the device of the module parameters. Default: ``None``
        dtype: the data type of the module parameters. Default: ``None``

    Raises:
        ValueError: if ``in_features`` is not divisible by ``out_features``.

    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out_features}, \text{features_per_output_class})`.
        bias: the learnable bias of the module of shape :math:`(\text{out_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{features_per_output_class}}`
    """

    __constants__ = ["in_features", "out_features", "bias"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        non_negative: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.non_negative = non_negative

        # Validate before deriving features_per_output_class.
        # BUGFIX: was an `assert`, which is stripped under `python -O`;
        # input validation must raise explicitly.
        if in_features % out_features != 0:
            raise ValueError(
                f"{in_features = } must be divisible by {out_features = }"
            )

        # Number of input features wired to each output class.
        self.features_per_output_class = in_features // out_features

        # Define weights and biases
        self.weight = nn.Parameter(
            torch.empty(
                (out_features, self.features_per_output_class), **factory_kwargs
            )
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize the weights and biases.
        Weights are initialized using Kaiming uniform initialization.
        Biases are initialized using a uniform distribution.
        """
        # Kaiming uniform initialization for the weights
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            # Calculate fan-in and fan-out values
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)

            # Uniform initialization for the biases
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            input (Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            Tensor: Output tensor of shape (batch_size, out_features).
        """
        batch_size = input.size(0)
        # Reshape input to (batch_size, out_features, features_per_output_class)
        reshaped_input = input.view(
            batch_size, self.out_features, self.features_per_output_class
        )

        # Apply ReLU to weights if the non-negativity constraint is enabled.
        weight = torch.relu(self.weight) if self.non_negative else self.weight

        # Per-class dot product between each input slice and its class weights.
        output = torch.einsum("bof,of->bo", reshaped_input, weight)

        if self.bias is not None:
            output += self.bias

        return output

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
259
-
260
-
261
class AudioProtoNetClassificationHead(nn.Module):
    """
    Prototype-based classification head in the style of ProtoPNet
    (Prototypical Part Network), adapted for audio spectrograms.

    Logits are produced by comparing backbone features against learned
    prototype vectors via cosine similarity (optionally with subtractive
    margins during training) followed by a final linear layer whose
    connectivity depends on the config.
    """

    def __init__(
        self,
        config: AudioProtoNetConfig,
    ) -> None:
        """
        Build the head from an AudioProtoNetConfig.

        Args:
            config: Configuration object supplying prototype counts, shapes,
                margin/connection settings, and add-on layer type.
        """
        super().__init__()
        self.prototypes_per_class = config.prototypes_per_class
        self.num_classes = config.num_classes
        self.num_prototypes = self.prototypes_per_class * self.num_classes
        self.num_prototypes_after_pruning = config.num_prototypes_after_pruning
        self.margin = config.margin
        self.relu_on_cos = config.relu_on_cos
        self.incorrect_class_connection = config.incorrect_class_connection
        self.correct_class_connection = config.correct_class_connection
        self.input_vector_length = config.input_vector_length
        self.n_eps_channels = config.n_eps_channels
        self.epsilon_val = config.epsilon_val
        self.topk_k = config.topk_k
        self.bias_last_layer = config.bias_last_layer
        self.non_negative_last_layer = config.non_negative_last_layer
        self.embedded_spectrogram_height = config.embedded_spectrogram_height
        self.use_bias_last_layer = config.use_bias_last_layer

        # 1D tensor mapping each prototype index to its class index.
        # NOTE: the original code first copied config.prototype_class_identity
        # and then immediately overwrote it with this tensor; the dead
        # assignment has been removed.
        self.prototype_class_identity = (
            torch.arange(self.num_prototypes) // self.prototypes_per_class
        )

        self.prototype_shape = (self.num_prototypes, config.channels, config.height, config.width)

        self._setup_add_on_layers(add_on_layers_type=config.add_on_layers_type)

        self.prototype_vectors = nn.Parameter(
            torch.rand(self.prototype_shape), requires_grad=True
        )

        self.frequency_weights = None
        if self.embedded_spectrogram_height is not None:
            # Initialize the frequency weights with a large positive value of 3.0
            # so that sigmoid(frequency_weights) is close to 1 (no damping).
            self.frequency_weights = nn.Parameter(
                torch.full(
                    (
                        self.num_prototypes,
                        self.embedded_spectrogram_height,
                    ),
                    3.0,
                )
            )

        if self.incorrect_class_connection:
            if self.non_negative_last_layer:
                self.last_layer = NonNegativeLinear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
            else:
                self.last_layer = nn.Linear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
        else:
            self.last_layer = LinearLayerWithoutNegativeConnections(
                in_features=self.num_prototypes,
                out_features=self.num_classes,
                non_negative=self.non_negative_last_layer,
            )

    def forward(
        self,
        features: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Forward pass of the PPNet head.

        Args:
            - features (torch.Tensor): Backbone features with shape (batch_size, num_channels, height, width).
            - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
              when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
                - logits: A tensor containing the logits for each class in the model.
                - a list containing:
                    - mean_activations: mean of the top-k prototype activations
                      (in evaluation mode k is always 1).
                    - marginless_logits: logits computed from the marginless activations.
                    - conv_features: the convolutional features.
                    - marginless_max_activations: max-pooled marginless activations.
                    - marginless_activations: the full marginless activation maps.
        """
        features = self.add_on_layers(features)

        activations, additional_returns = self.prototype_activations(
            features, prototypes_of_wrong_class=prototypes_of_wrong_class
        )
        marginless_activations = additional_returns[0]
        conv_features = additional_returns[1]

        # Set topk_k based on training mode: use the configured value if
        # training, else 1 for evaluation.
        # BUGFIX: topk_k was hard-coded to 1, contradicting this comment and
        # leaving config.topk_k unused; evaluation behavior is unchanged.
        topk_k = self.topk_k if self.training else 1

        # Reshape activations to combine spatial dimensions: (batch_size, num_prototypes, height*width)
        activations = activations.view(activations.shape[0], activations.shape[1], -1)

        # Perform top-k pooling along the combined spatial dimension.
        # For topk_k=1, this is equivalent to global max pooling.
        topk_activations, _ = torch.topk(activations, topk_k, dim=-1)

        # Mean of the top-k activations per prototype: (batch_size, num_prototypes).
        # If topk_k=1, this mean is a no-op since there's only one value.
        mean_activations = torch.mean(topk_activations, dim=-1)

        marginless_max_activations = nn.functional.max_pool2d(
            marginless_activations,
            kernel_size=(
                marginless_activations.size()[2],
                marginless_activations.size()[3],
            ),
        )
        marginless_max_activations = marginless_max_activations.view(
            -1, self.num_prototypes
        )

        logits = self.last_layer(mean_activations)
        marginless_logits = self.last_layer(marginless_max_activations)
        return logits, [
            mean_activations,
            marginless_logits,
            conv_features,
            marginless_max_activations,
            marginless_activations,
        ]

    def cos_activation(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the cosine activation between input tensor x and prototype vectors.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor with shape (batch_size, num_channels, height, width).
        prototypes_of_wrong_class : Optional[torch.Tensor]
            Tensor containing the prototypes of the wrong class with shape (batch_size, num_prototypes).

        Returns:
        --------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - activations: The cosine activations with potential margin adjustments.
            - marginless_activations: The cosine activations without margin adjustments.
        """
        input_vector_length = self.input_vector_length
        normalizing_factor = (
            self.prototype_shape[-2] * self.prototype_shape[-1]
        ) ** 0.5

        # Pre-allocate epsilon channels on the correct device for input tensor x
        epsilon_channel_x = torch.full(
            (x.shape[0], self.n_eps_channels, x.shape[2], x.shape[3]),
            self.epsilon_val,
            device=x.device,
            requires_grad=False,
        )
        x = torch.cat((x, epsilon_channel_x), dim=-3)

        # Normalize x
        x_length = torch.sqrt(torch.sum(x**2, dim=-3, keepdim=True) + self.epsilon_val)
        x_normalized = (input_vector_length * x / x_length) / normalizing_factor

        # Pre-allocate epsilon channels for prototypes on the correct device
        epsilon_channel_p = torch.full(
            (
                self.prototype_shape[0],
                self.n_eps_channels,
                self.prototype_shape[2],
                self.prototype_shape[3],
            ),
            self.epsilon_val,
            device=self.prototype_vectors.device,
            requires_grad=False,
        )
        appended_protos = torch.cat((self.prototype_vectors, epsilon_channel_p), dim=-3)

        # Normalize prototypes
        prototype_vector_length = torch.sqrt(
            torch.sum(appended_protos**2, dim=-3, keepdim=True) + self.epsilon_val
        )
        normalized_prototypes = appended_protos / (
            prototype_vector_length + self.epsilon_val
        )
        normalized_prototypes /= normalizing_factor

        # Compute activations using convolution
        activations_dot = nn.functional.conv2d(x_normalized, normalized_prototypes)
        marginless_activations = activations_dot / (input_vector_length * 1.01)

        if self.frequency_weights is not None:
            # Apply sigmoid to frequency weights, s.t. weights are between 0 and 1.
            freq_weights = torch.sigmoid(self.frequency_weights)

            # Multiply each prototype's frequency response by the corresponding weights
            marginless_activations = marginless_activations * freq_weights[:, :, None]

        if (
            self.margin is None
            or not self.training
            or prototypes_of_wrong_class is None
        ):
            activations = marginless_activations
        else:
            # Apply margin adjustment for wrong class prototypes
            wrong_class_margin = (prototypes_of_wrong_class * self.margin).view(
                x.size(0), self.prototype_vectors.size(0), 1, 1
            )
            wrong_class_margin = wrong_class_margin.expand(
                -1, -1, activations_dot.size(-2), activations_dot.size(-1)
            )
            penalized_angles = (
                torch.acos(activations_dot / (input_vector_length * 1.01))
                - wrong_class_margin
            )
            activations = torch.cos(torch.relu(penalized_angles))

        if self.relu_on_cos:
            # Apply ReLU activation on the cosine values
            activations = torch.relu(activations)
            marginless_activations = torch.relu(marginless_activations)

        return activations, marginless_activations

    def prototype_activations(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Compute the prototype activations for a given input tensor.

        Args:
            - x (torch.Tensor): The raw input tensor with shape (batch_size, num_channels, height, width).
            - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
              when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
                - activations: A tensor containing the prototype activations.
                - a list containing:
                    - marginless_activations: activations before applying subtractive margin.
                    - conv_features: the convolutional features (the input x itself).
        """
        # Compute cosine activations
        activations, marginless_activations = self.cos_activation(
            x,
            prototypes_of_wrong_class=prototypes_of_wrong_class,
        )

        return activations, [marginless_activations, x]

    def get_prototype_orthogonalities(self, use_part_prototypes: bool = False) -> torch.Tensor:
        """
        Computes the orthogonality loss, encouraging each piece of a prototype to be orthogonal to the others.

        This method is inspired by the paper:
        https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Interpretable_Image_Recognition_by_Constructing_Transparent_Embedding_Space_ICCV_2021_paper.pdf

        Args:
            use_part_prototypes (bool): If True, treats each spatial part of the prototypes as a separate prototype.

        Returns:
            torch.Tensor: A tensor representing the orthogonalities.
        """
        # BUGFIX: this method referenced the nonexistent attribute
        # `self.num_prototypes_per_class` (AttributeError at runtime); the
        # attribute defined in __init__ is `self.prototypes_per_class`.
        if use_part_prototypes:
            # Normalize prototypes to unit length
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(self.prototype_vectors), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = self.prototype_vectors / (
                prototype_vector_length + self.epsilon_val
            )

            # Calculate total part prototypes per class
            num_part_prototypes_per_class = (
                self.prototypes_per_class
                * self.prototype_shape[2]
                * self.prototype_shape[3]
            )

            # Reshape to match class structure
            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.prototypes_per_class,
                self.prototype_shape[1],
                self.prototype_shape[2] * self.prototype_shape[3],
            )

            # Transpose and reshape to treat each spatial part as a separate prototype
            normalized_prototypes = normalized_prototypes.permute(0, 1, 3, 2).reshape(
                self.num_classes, num_part_prototypes_per_class, self.prototype_shape[1]
            )

        else:
            # Normalize prototypes to unit length
            prototype_vectors_reshaped = self.prototype_vectors.view(
                self.num_prototypes, -1
            )
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(prototype_vectors_reshaped), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = prototype_vectors_reshaped / (
                prototype_vector_length + self.epsilon_val
            )

            # Reshape to match class structure
            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.prototypes_per_class,
                self.prototype_shape[1]
                * self.prototype_shape[2]
                * self.prototype_shape[3],
            )

        # Compute orthogonality matrix for each class
        orthogonalities = torch.matmul(
            normalized_prototypes, normalized_prototypes.transpose(1, 2)
        )

        # Identity matrix to enforce orthogonality
        identity_matrix = (
            torch.eye(normalized_prototypes.shape[1], device=orthogonalities.device)
            .unsqueeze(0)
            .repeat(self.num_classes, 1, 1)
        )

        # Subtract identity to focus on orthogonality
        orthogonalities = orthogonalities - identity_matrix

        return orthogonalities

    def identify_prototypes_to_prune(self) -> list[int]:
        """
        Identifies the indices of prototypes that should be pruned.

        This function iterates through the prototypes and checks if the specific weight
        connecting the prototype to its class is zero. It is specifically designed to handle
        the LinearLayerWithoutNegativeConnections where each class has a subset of features
        it connects to.

        Returns:
            list[int]: A list of prototype indices that should be pruned.
        """
        prototypes_to_prune = []

        # Calculate the number of prototypes assigned to each class
        prototypes_per_class = self.num_prototypes // self.num_classes

        if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
            # Custom layer mapping prototypes to a subset of input features for each output class
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                # Calculate the specific index within the 'features_per_output_class' for this prototype
                index_within_class = prototype_index % prototypes_per_class
                # Check if the specific weight connecting the prototype to its class is zero
                if self.last_layer.weight[class_index, index_within_class] == 0.0:
                    prototypes_to_prune.append(prototype_index)
        else:
            # Standard linear layer: each prototype directly maps to a feature index
            weights_to_check = self.last_layer.weight
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                if weights_to_check[class_index, prototype_index] == 0.0:
                    prototypes_to_prune.append(prototype_index)

        return prototypes_to_prune

    def prune_prototypes_by_threshold(self, threshold: float = 1e-3) -> None:
        """
        Prune the weights in the classification layer by setting weights below a specified threshold to zero.

        This method modifies the weights of the last layer of the model in-place. Weights falling below the
        threshold are set to zero, diminishing their influence in the model's decisions. It also identifies
        and prunes prototypes based on these updated weights, thereby refining the model's structure.

        Args:
            threshold (float): The threshold value below which weights will be set to zero. Defaults to 1e-3.
        """
        # Access the weights of the last layer
        weights = self.last_layer.weight.data

        # Set weights below the threshold to zero
        # This step reduces the influence of low-value weights in the model's decision-making process
        weights[weights < threshold] = 0.0

        # Update the weights in the last layer to reflect the pruning
        self.last_layer.weight.data.copy_(weights)

        # Identify prototypes that need to be pruned based on the updated weights
        prototypes_to_prune = self.identify_prototypes_to_prune()

        # Execute the pruning of identified prototypes
        self.prune_prototypes_by_index(prototypes_to_prune)

    def prune_prototypes_by_index(self, prototypes_to_prune: list[int]) -> None:
        """
        Prunes specified prototypes from the PPNet.

        Args:
            prototypes_to_prune (list[int]): A list of indices indicating the prototypes to be removed.
                Each index should be in the range [0, current number of prototypes - 1].

        Returns:
            None
        """
        # Validate the provided indices to ensure they are within the valid range
        if any(
            index < 0 or index >= self.num_prototypes for index in prototypes_to_prune
        ):
            raise ValueError("Provided prototype indices are out of valid range!")

        # Calculate the new number of prototypes after pruning
        self.num_prototypes_after_pruning = self.num_prototypes - len(
            prototypes_to_prune
        )

        # Remove the prototype vectors that are no longer needed
        with torch.no_grad():
            # If frequency_weights are being used, set the weights of pruned prototypes to -7
            # (sigmoid(-7) ~ 0, effectively silencing those prototypes)
            if self.frequency_weights is not None:
                self.frequency_weights.data[prototypes_to_prune, :] = -7.0

            # Adjust the weights in the last layer depending on its type
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # For LinearLayerWithoutNegativeConnections, set the connection weights to zero
                # only for the pruned prototypes related to their specific classes
                for class_idx in range(self.last_layer.out_features):
                    # Identify prototypes belonging to the current class
                    indices_for_class = [
                        idx % self.last_layer.features_per_output_class
                        for idx in prototypes_to_prune
                        if self.prototype_class_identity[idx] == class_idx
                    ]
                    self.last_layer.weight.data[class_idx, indices_for_class] = 0.0
            else:
                # For other layer types, set the weights of pruned prototypes to zero
                self.last_layer.weight.data[:, prototypes_to_prune] = 0.0

    def __repr__(self) -> str:
        rep = f"""PPNet(
\tprototype_shape: {self.prototype_shape},
\tnum_classes: {self.num_classes},
\tepsilon: {self.epsilon_val})"""

        return rep

    def set_last_layer_incorrect_connection(
        self, incorrect_strength: float = None
    ) -> None:
        """
        Modifies the last layer weights to have incorrect connections with a specified strength.
        If incorrect_strength is None, initializes the weights for LinearLayerWithoutNegativeConnections
        with correct_class_connection value.

        Args:
            - incorrect_strength (Optional[float]): The strength of the incorrect connections.
              If None, initialize without incorrect connections.

        Returns:
            None
        """
        if incorrect_strength is None:
            # Handle LinearLayerWithoutNegativeConnections initialization
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # Initialize all weights to the correct_class_connection value
                self.last_layer.weight.data.fill_(self.correct_class_connection)
            else:
                raise ValueError(
                    "last_layer is not an instance of LinearLayerWithoutNegativeConnections"
                )

        else:
            # Create a one-hot matrix for correct connections
            positive_one_weights_locations = torch.zeros(
                self.num_classes, self.num_prototypes
            )
            positive_one_weights_locations[
                self.prototype_class_identity,
                torch.arange(self.num_prototypes),
            ] = 1

            # Create a matrix for incorrect connections
            negative_one_weights_locations = 1 - positive_one_weights_locations

            # This variable represents the strength of the connection for correct class
            correct_class_connection = self.correct_class_connection

            # This variable represents the strength of the connection for incorrect class
            incorrect_class_connection = incorrect_strength

            # Modify weights to have correct and incorrect connections
            self.last_layer.weight.data.copy_(
                correct_class_connection * positive_one_weights_locations
                + incorrect_class_connection * negative_one_weights_locations
            )

        if self.last_layer.bias is not None:
            # Initialize all biases to bias_last_layer value
            self.last_layer.bias.data.fill_(self.bias_last_layer)

    def _setup_add_on_layers(self, add_on_layers_type: str):
        """
        Configures additional layers based on the backbone model architecture and the specified add_on_layers_type.

        Args:
            add_on_layers_type (str): Type of additional layers to add. Can be 'identity' or 'upsample'.

        Raises:
            NotImplementedError: for any unsupported add_on_layers_type.
        """
        if add_on_layers_type == "identity":
            self.add_on_layers = nn.Sequential(nn.Identity())
        elif add_on_layers_type == "upsample":
            self.add_on_layers = nn.Upsample(scale_factor=2, mode="bilinear")
        else:
            raise NotImplementedError(
                f"The add-on layer type {add_on_layers_type} isn't implemented yet."
            )
842
-
843
-
844
class AudioProtoNetPreTrainedModel(PreTrainedModel):
    """Base class wiring the AudioProtoNet config into HF weight init/loading."""

    config_class = AudioProtoNetConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """
        Initialize the weights of a single submodule.

        Called once per submodule by `PreTrainedModel`. Conv/linear layers get
        truncated-normal weights and zero biases. The previous implementation
        duplicated that block verbatim and then dereferenced `self.last_layer`
        / `self.incorrect_class_connection`, which do not exist on this class
        (they belong to the classification head) and would raise
        AttributeError; the special last layer is now handled as a `module`
        case instead.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, LinearLayerWithoutNegativeConnections):
            # Mirrors set_last_layer_incorrect_connection: when no
            # incorrect-class connections are configured, every prototype's
            # connection to its own class is set to the configured strength.
            if self.config.incorrect_class_connection is None:
                module.weight.data.fill_(self.config.correct_class_connection)
860
-
861
-
862
class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
    """ConvNeXt feature extractor used as the AudioProtoNet backbone."""

    _auto_class = "AutoModel"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)
        # ConvNeXt-Base configured for single-channel (spectrogram) input.
        backbone_config = ConvNextConfig.from_pretrained(
            "facebook/convnext-base-224-22k", num_channels=1
        )
        self.backbone = ConvNextModel(backbone_config)

    def forward(
        self,
        input_values: torch.Tensor,
        output_hidden_states: bool = None,
        return_dict: bool = None,
    ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the backbone on a batch of spectrograms.

        Args:
            input_values: input spectrogram batch.
            output_hidden_states: whether to also return per-stage features.
            return_dict: whether to return a ModelOutput instead of a tuple.

        Returns:
            The backbone output (last_hidden_state, pooler_output, and
            optionally hidden_states), as a tuple or ModelOutput.
        """
        # Forward by keyword so the binding stays correct even if the
        # backbone's signature gains parameters.
        return self.backbone(
            input_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
889
-
890
-
891
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
    """
    ProtoPNet-style audio classifier: ConvNeXt backbone plus a prototype-based
    classification head, trained with an asymmetric multi-label loss.

    Note: base class changed from bare `PreTrainedModel` to
    `AudioProtoNetPreTrainedModel` for consistency with `AudioProtoNetModel`
    (same `config_class` / weight init).
    """

    _auto_class = "AutoModelForSequenceClassification"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)

        self.model = AudioProtoNetModel(config)
        self.head = AudioProtoNetClassificationHead(config)

    def freeze_backbone(self):
        """Disable gradients for all backbone parameters (was a no-op)."""
        for param in self.model.parameters():
            param.requires_grad = False

    def int2str(self):  # TODO: map class indices to label names
        pass

    def forward(
        self,
        input_values: torch.Tensor,
        labels: torch.Tensor = None,
        prototypes_of_wrong_class: torch.Tensor = None,
        output_hidden_states: bool = None,
        output_prototypical_activations: bool = None,
        return_dict: bool = None,
    ) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
        """
        Args:
            input_values: input spectrogram batch.
            labels: multi-hot label matrix; when given, an asymmetric
                multi-label loss is computed.
            prototypes_of_wrong_class: wrong-class prototype mask used by the
                head's subtractive-margin activation (training only).
            output_hidden_states: also return the backbone's per-stage features.
            output_prototypical_activations: also return the head's marginless
                prototype activation maps.
            return_dict: return a `ModelOutput` (True) or a plain tuple (False).

        Returns:
            `SequenceClassifierOutputWithProtoTypeActivations` or a tuple of
            (logits, [loss], last_hidden_state, [hidden_states],
            [prototype_activations]).
        """
        backbone_outputs = self.model(input_values, output_hidden_states, return_dict)

        last_hidden_state = backbone_outputs[0]

        logits, info = self.head(last_hidden_state, prototypes_of_wrong_class)

        loss = None
        if labels is not None:
            # Bug fix: Tensor.to() is not in-place — the result must be
            # rebound, otherwise labels may sit on the wrong device.
            labels = labels.to(logits.device)
            loss_fct = AsymmetricLossMultiLabel()
            loss = loss_fct(logits, labels.float())

        hidden_states = None
        if output_hidden_states is not None:
            # Backbone output layout: (last_hidden_state, pooler_output, hidden_states).
            hidden_states = backbone_outputs[2]

        prototype_activations = None
        if output_prototypical_activations is not None:
            # info[4] is the marginless activation map returned by the head.
            prototype_activations = info[4]

        # Bug fix: HF convention is tuple when return_dict is falsy; the
        # previous code returned the tuple when return_dict was truthy.
        if not return_dict:
            output = (logits,)
            output += (loss,) if loss is not None else ()
            output += (last_hidden_state,)
            output += (hidden_states,) if hidden_states is not None else ()
            output += (prototype_activations,) if prototype_activations is not None else ()
            return output

        return SequenceClassifierOutputWithProtoTypeActivations(
            logits=logits,
            loss=loss,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            prototype_activations=prototype_activations,
        )
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+ import math
5
+
6
+ from transformers import PreTrainedModel, ConvNextModel, ConvNextConfig
7
+ from transformers.utils import logging
8
+ from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPoolingAndNoAttention
9
+ from dataclasses import dataclass
10
+
11
+ from .configuration_protonet import AudioProtoNetConfig
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
@dataclass
class SequenceClassifierOutputWithProtoTypeActivations(ModelOutput):
    """
    Model output for prototype-based sequence classification.

    Extends the usual classifier output with the raw prototype activation
    maps so they can be inspected for interpretability.
    """

    # Class scores produced by the prototype head.
    logits: torch.Tensor
    # Present only when labels were supplied to forward().
    loss: torch.Tensor = None
    # Backbone feature map fed to the head.
    last_hidden_state: torch.FloatTensor = None
    # Per-stage backbone features, if requested.
    hidden_states: tuple[torch.FloatTensor, ...] = None
    # Marginless prototype activation maps, if requested.
    prototype_activations: torch.FloatTensor = None
23
+
24
+
25
# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
# https://github.com/huggingface/pytorch-image-models/blob/bbe798317fb26f063c18279827c038058e376479/timm/loss/asymmetric_loss.py#L6
class AsymmetricLossMultiLabel(nn.Module):
    """
    Asymmetric loss (ASL) for multi-label classification.

    Applies a BCE-style loss with separate focusing parameters for positive
    and negative targets, plus probability shifting ("clipping") for negatives
    (Ridnik et al., ICCV 2021).
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
        reduction="mean",
    ):
        """
        Args:
            gamma_neg: focusing exponent for negative targets.
            gamma_pos: focusing exponent for positive targets.
            clip: probability margin added to negatives before the log
                (probability shifting); disabled when None or <= 0.
            eps: clamp floor inside the logs for numerical stability.
            disable_torch_grad_focal_loss: if True, compute the focusing
                weight without tracking gradients through it.
            reduction: 'mean', 'sum', or anything else for element-wise loss.
        """
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.reduction = reduction

    def _focal_weight(self, xs_pos, xs_neg, y):
        """Per-element asymmetric focusing weight (1 - pt) ** gamma."""
        pt = xs_pos * y + xs_neg * (1 - y)  # pt = p if t > 0 else 1 - p
        one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
        return torch.pow(1 - pt, one_sided_gamma)

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping (probability shifting for negatives)
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Fix: use the public no-grad context instead of the private
                # torch._C.set_grad_enabled toggle, which also unconditionally
                # re-enabled grad and could clobber the caller's grad mode.
                with torch.no_grad():
                    one_sided_w = self._focal_weight(xs_pos, xs_neg, y)
            else:
                one_sided_w = self._focal_weight(xs_pos, xs_neg, y)
            loss *= one_sided_w

        if self.reduction == "mean":
            return -loss.mean()
        if self.reduction == "sum":
            return -loss.sum()

        return -loss
87
+
88
+
89
class NonNegativeLinear(nn.Module):
    """
    A PyTorch module for a linear layer with non-negative weights.

    This module applies a linear transformation to the incoming data:
    `y = x @ relu(A)^T + b`. The weights are passed through ReLU at call time,
    so the effective transformation is always non-negative — useful in models
    where negative connections are not appropriate.

    Attributes:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        weight (torch.Tensor): The weight parameter, rectified on use.
        bias (torch.Tensor, optional): The bias parameter of the module.

    Args:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        bias (bool, optional): If True, include a learnable bias. Default: True.
        device (optional): The device on which to allocate the parameters.
        dtype (optional): The data type for the parameters.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        # Fix: the parameters were previously left as uninitialized
        # torch.empty memory (the sibling LinearLayerWithoutNegativeConnections
        # does initialize). Initialize them like nn.Linear does.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weight (Kaiming uniform) and bias, matching nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # fan_in of a (out_features, in_features) weight is in_features.
            bound = 1 / math.sqrt(self.in_features) if self.in_features > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the NonNegativeLinear module.

        Args:
            input (torch.Tensor): The input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, out_features).
        """
        return nn.functional.linear(input, torch.relu(self.weight), self.bias)
142
+
143
+
144
+ class LinearLayerWithoutNegativeConnections(nn.Module):
145
+ r"""
146
+ Custom Linear Layer where each output class is connected to a specific subset of input features.
147
+
148
+ Args:
149
+ in_features: size of each input sample
150
+ out_features: size of each output sample
151
+ bias: If set to ``False``, the layer will not learn an additive bias.
152
+ Default: ``True``
153
+ device: the device of the module parameters. Default: ``None``
154
+ dtype: the data type of the module parameters. Default: ``None``
155
+
156
+ Shape:
157
+ - Input: :math:`(*, H_{in})` where :math:`*` means any number of
158
+ dimensions including none and :math:`H_{in} = \text{in_features}`.
159
+ - Output: :math:`(*, H_{out})` where all but the last dimension
160
+ are the same shape as the input and :math:`H_{out} = \text{out_features}`.
161
+
162
+ Attributes:
163
+ weight: the learnable weights of the module of shape
164
+ :math:`(\text{out_features}, \text{features_per_output_class})`.
165
+ bias: the learnable bias of the module of shape :math:`(\text{out_features})`.
166
+ If :attr:`bias` is ``True``, the values are initialized from
167
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
168
+ :math:`k = \frac{1}{\text{features_per_output_class}}`
169
+ """
170
+
171
+ __constants__ = ["in_features", "out_features", "bias"]
172
+ in_features: int
173
+ out_features: int
174
+ weight: torch.Tensor
175
+
176
+ def __init__(
177
+ self,
178
+ in_features: int,
179
+ out_features: int,
180
+ bias: bool = True,
181
+ non_negative: bool = True,
182
+ device: torch.device = None,
183
+ dtype: torch.dtype = None,
184
+ ) -> None:
185
+ factory_kwargs = {"device": device, "dtype": dtype}
186
+ super().__init__()
187
+ self.in_features = in_features
188
+ self.out_features = out_features
189
+ self.non_negative = non_negative
190
+
191
+ # Calculate the number of features per output class
192
+ self.features_per_output_class = in_features // out_features
193
+
194
+ # Ensure input size is divisible by the output size
195
+ assert (
196
+ in_features % out_features == 0
197
+ ), f"{in_features = } must be divisible by {out_features = }"
198
+
199
+ # Define weights and biases
200
+ self.weight = nn.Parameter(
201
+ torch.empty(
202
+ (out_features, self.features_per_output_class), **factory_kwargs
203
+ )
204
+ )
205
+ if bias:
206
+ self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
207
+ else:
208
+ self.register_parameter("bias", None)
209
+
210
+ # Initialize weights and biases
211
+ self.reset_parameters()
212
+
213
+ def reset_parameters(self) -> None:
214
+ """
215
+ Initialize the weights and biases.
216
+ Weights are initialized using Kaiming uniform initialization.
217
+ Biases are initialized using a uniform distribution.
218
+ """
219
+ # Kaiming uniform initialization for the weights
220
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
221
+
222
+ if self.bias is not None:
223
+ # Calculate fan-in and fan-out values
224
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
225
+
226
+ # Uniform initialization for the biases
227
+ bound = 1 / math.sqrt(fan_in)
228
+ nn.init.uniform_(self.bias, -bound, bound)
229
+
230
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
231
+ """
232
+ Forward pass for the custom linear layer.
233
+
234
+ Args:
235
+ input (Tensor): Input tensor of shape (batch_size, in_features).
236
+
237
+ Returns:
238
+ Tensor: Output tensor of shape (batch_size, out_features).
239
+ """
240
+ batch_size = input.size(0)
241
+ # Reshape input to (batch_size, out_features, features_per_output_class)
242
+ reshaped_input = input.view(
243
+ batch_size, self.out_features, self.features_per_output_class
244
+ )
245
+
246
+ # Apply ReLU to weights if non_negative_last_layer is True
247
+ weight = torch.relu(self.weight) if self.non_negative else self.weight
248
+
249
+ # Perform batch matrix multiplication and add bias
250
+ output = torch.einsum("bof,of->bo", reshaped_input, weight)
251
+
252
+ if self.bias is not None:
253
+ output += self.bias
254
+
255
+ return output
256
+
257
+ def extra_repr(self) -> str:
258
+ return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
259
+
260
+
261
+ class AudioProtoNetClassificationHead(nn.Module):
262
    def __init__(
        self,
        config: AudioProtoNetConfig,
    ) -> None:
        """
        Prototype-based classification head (ProtoPNet-style).

        Copies all hyper-parameters from the config, builds the add-on
        layers, the learnable prototype vectors, the optional per-prototype
        frequency gates, and the last (prototype -> class) linear layer.
        """

        super().__init__()
        self.prototypes_per_class = config.prototypes_per_class
        self.num_classes = config.num_classes
        # Total prototype count: a fixed number per class.
        self.num_prototypes = self.prototypes_per_class * self.num_classes
        self.num_prototypes_after_pruning = config.num_prototypes_after_pruning
        # Subtractive margin for cos_activation (None disables it).
        self.margin = config.margin
        self.relu_on_cos = config.relu_on_cos
        self.incorrect_class_connection = config.incorrect_class_connection
        self.correct_class_connection = config.correct_class_connection
        self.input_vector_length = config.input_vector_length
        self.n_eps_channels = config.n_eps_channels
        self.epsilon_val = config.epsilon_val
        # k used for top-k pooling of prototype activations.
        self.topk_k = config.topk_k
        self.bias_last_layer = config.bias_last_layer
        self.non_negative_last_layer = config.non_negative_last_layer
        self.embedded_spectrogram_height = config.embedded_spectrogram_height
        self.use_bias_last_layer = config.use_bias_last_layer
        # NOTE(review): this config value is immediately overwritten by the
        # computed tensor below — confirm whether the config field is needed.
        self.prototype_class_identity = config.prototype_class_identity

        # Create a 1D tensor where each element represents the class index
        # of the corresponding prototype (prototypes are grouped by class).
        self.prototype_class_identity = (
            torch.arange(self.num_prototypes) // self.prototypes_per_class
        )

        # (num_prototypes, channels, height, width) of each prototype kernel.
        self.prototype_shape = (self.num_prototypes, config.channels, config.height, config.width)

        self._setup_add_on_layers(add_on_layers_type=config.add_on_layers_type)

        # Learnable prototype kernels, randomly initialized in [0, 1).
        self.prototype_vectors = nn.Parameter(
            torch.rand(self.prototype_shape), requires_grad=True
        )

        self.frequency_weights = None
        if self.embedded_spectrogram_height is not None:
            # Initialize the frequency weights with a large positive value of 3.0 so that sigmoid(frequency_weights) is close to 1.
            self.frequency_weights = nn.Parameter(
                torch.full(
                    (
                        self.num_prototypes,
                        self.embedded_spectrogram_height,
                    ),
                    3.0,
                )
            )

        # Last layer: maps pooled prototype activations to class logits.
        # With incorrect-class connections enabled a dense (optionally
        # non-negative) linear layer is used; otherwise each class connects
        # only to its own prototypes.
        if self.incorrect_class_connection:
            if self.non_negative_last_layer:
                self.last_layer = NonNegativeLinear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
            else:
                self.last_layer = nn.Linear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
        else:
            self.last_layer = LinearLayerWithoutNegativeConnections(
                in_features=self.num_prototypes,
                out_features=self.num_classes,
                non_negative=self.non_negative_last_layer,
            )
331
+
332
+ def forward(
333
+ self,
334
+ features: torch.Tensor,
335
+ prototypes_of_wrong_class: torch.Tensor = None,
336
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
337
+ """
338
+ Forward pass of the PPNet model.
339
+
340
+ Args:
341
+ - x (torch.Tensor): Input tensor with shape (batch_size, num_channels, height, width).
342
+ - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
343
+ when using subtractive margins. Defaults to None.
344
+
345
+ Returns:
346
+ Tuple[torch.Tensor, List[torch.Tensor]]:
347
+ - logits: A tensor containing the logits for each class in the model.
348
+ - a list containing:
349
+ - mean_activations: A tensor containing the mean of the top-k prototype activations.
350
+ (in evaluation mode k is always 1)
351
+ - marginless_logits: A tensor containing the logits for each class in the model, calculated using the
352
+ marginless activations.
353
+ - conv_features: A tensor containing the convolutional features.
354
+ - marginless_max_activations: A tensor containing the max-pooled marginless activations.
355
+
356
+ """
357
+
358
+ features = self.add_on_layers(features)
359
+
360
+ activations, additional_returns = self.prototype_activations(
361
+ features, prototypes_of_wrong_class=prototypes_of_wrong_class
362
+ )
363
+ marginless_activations = additional_returns[0]
364
+ conv_features = additional_returns[1]
365
+
366
+ # Set topk_k based on training mode: use predefined value if training, else 1 for evaluation
367
+ topk_k = 1
368
+
369
+ # Reshape activations to combine spatial dimensions: (batch_size, num_prototypes, height*width)
370
+ activations = activations.view(activations.shape[0], activations.shape[1], -1)
371
+
372
+ # Perform top-k pooling along the combined spatial dimension
373
+ # For topk_k=1, this is equivalent to global max pooling
374
+ topk_activations, _ = torch.topk(activations, topk_k, dim=-1)
375
+
376
+ # Calculate the mean of the top-k activations for each channel: (batch_size, num_channels)
377
+ # If topk_k=1, this mean operation does nothing since there's only one value.
378
+ mean_activations = torch.mean(topk_activations, dim=-1)
379
+
380
+ marginless_max_activations = nn.functional.max_pool2d(
381
+ marginless_activations,
382
+ kernel_size=(
383
+ marginless_activations.size()[2],
384
+ marginless_activations.size()[3],
385
+ ),
386
+ )
387
+ marginless_max_activations = marginless_max_activations.view(
388
+ -1, self.num_prototypes
389
+ )
390
+
391
+ logits = self.last_layer(mean_activations)
392
+ marginless_logits = self.last_layer(marginless_max_activations)
393
+ return logits, [
394
+ mean_activations,
395
+ marginless_logits,
396
+ conv_features,
397
+ marginless_max_activations,
398
+ marginless_activations,
399
+ ]
400
+
401
+ # def conv_features(self, x: torch.Tensor) -> torch.Tensor:
402
+ # """
403
+ # Takes an input tensor and passes it through the backbone model to extract features.
404
+ # Then, it passes them through the additional layers to produce the output tensor.
405
+ #
406
+ # Args:
407
+ # x (torch.Tensor): The input tensor.
408
+ #
409
+ # Returns:
410
+ # torch.Tensor: The output tensor after passing through the backbone model and additional layers.
411
+ # """
412
+ # # Extract features using the backbone model
413
+ # features = self.backbone_model(x)
414
+ #
415
+ # # The features must be a 4D tensor of shape (batch size, channels, height, width)
416
+ # if features.dim() == 3:
417
+ # features.unsqueeze_(0)
418
+ #
419
+ # # Pass the features through additional layers
420
+ # output = self.add_on_layers(features)
421
+ #
422
+ # return output
423
+
424
    def cos_activation(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the cosine activation between input tensor x and prototype vectors.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor with shape (batch_size, num_channels, height, width).
        prototypes_of_wrong_class : Optional[torch.Tensor]
            Tensor containing the prototypes of the wrong class with shape
            (batch_size, num_prototypes); only used with subtractive margins
            during training.

        Returns:
        --------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - activations: The cosine activations with potential margin adjustments.
            - marginless_activations: The cosine activations without margin adjustments.
        """
        input_vector_length = self.input_vector_length
        # Scale by sqrt(prototype area) so the activation magnitude does not
        # grow with prototype height * width.
        normalizing_factor = (
            self.prototype_shape[-2] * self.prototype_shape[-1]
        ) ** 0.5

        # Append constant epsilon channels to x so near-zero feature vectors
        # still yield a well-defined direction for the cosine.
        epsilon_channel_x = torch.full(
            (x.shape[0], self.n_eps_channels, x.shape[2], x.shape[3]),
            self.epsilon_val,
            device=x.device,
            requires_grad=False,
        )
        x = torch.cat((x, epsilon_channel_x), dim=-3)

        # Normalize x to a fixed vector length (then scale by the factor).
        x_length = torch.sqrt(torch.sum(x**2, dim=-3, keepdim=True) + self.epsilon_val)
        x_normalized = (input_vector_length * x / x_length) / normalizing_factor

        # Same epsilon padding for the prototypes.
        epsilon_channel_p = torch.full(
            (
                self.prototype_shape[0],
                self.n_eps_channels,
                self.prototype_shape[2],
                self.prototype_shape[3],
            ),
            self.epsilon_val,
            device=self.prototype_vectors.device,
            requires_grad=False,
        )
        appended_protos = torch.cat((self.prototype_vectors, epsilon_channel_p), dim=-3)

        # Normalize prototypes to (approximately) unit length.
        prototype_vector_length = torch.sqrt(
            torch.sum(appended_protos**2, dim=-3, keepdim=True) + self.epsilon_val
        )
        normalized_prototypes = appended_protos / (
            prototype_vector_length + self.epsilon_val
        )
        normalized_prototypes /= normalizing_factor

        # Dot product of the normalized tensors via conv2d == scaled cosine
        # similarity at every spatial position.
        activations_dot = nn.functional.conv2d(x_normalized, normalized_prototypes)
        # The 1.01 headroom keeps the ratio strictly inside [-1, 1] so the
        # acos below stays finite.
        marginless_activations = activations_dot / (input_vector_length * 1.01)

        if self.frequency_weights is not None:
            # Apply sigmoid to frequency weights, s.t. weights are between 0 and 1.
            freq_weights = torch.sigmoid(self.frequency_weights)

            # Gate each prototype's response per frequency row.
            marginless_activations = marginless_activations * freq_weights[:, :, None]

        if (
            self.margin is None
            or not self.training
            or prototypes_of_wrong_class is None
        ):
            # No margin in eval mode or when margins are disabled.
            activations = marginless_activations
        else:
            # Subtractive margin: penalize the angle to wrong-class prototypes.
            wrong_class_margin = (prototypes_of_wrong_class * self.margin).view(
                x.size(0), self.prototype_vectors.size(0), 1, 1
            )
            wrong_class_margin = wrong_class_margin.expand(
                -1, -1, activations_dot.size(-2), activations_dot.size(-1)
            )
            penalized_angles = (
                torch.acos(activations_dot / (input_vector_length * 1.01))
                - wrong_class_margin
            )
            # relu clamps negative angles to 0 so cos stays <= 1.
            activations = torch.cos(torch.relu(penalized_angles))

        if self.relu_on_cos:
            # Keep only non-negative similarities.
            activations = torch.relu(activations)
            marginless_activations = torch.relu(marginless_activations)

        return activations, marginless_activations
524
+
525
+ def prototype_activations(
526
+ self,
527
+ x: torch.Tensor,
528
+ prototypes_of_wrong_class: torch.Tensor = None,
529
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
530
+ """
531
+ Compute the prototype activations for a given input tensor.
532
+
533
+ Args:
534
+ - x (torch.Tensor): The raw input tensor with shape (batch_size, num_channels, height, width).
535
+ - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
536
+ when using subtractive margins. Defaults to None.
537
+
538
+ Returns:
539
+ Tuple[torch.Tensor, List[torch.Tensor]]:
540
+ - activations: A tensor containing the prototype activations.
541
+ - a list containing:
542
+ - marginless_activations: A tensor containing the activations before applying subtractive margin.
543
+ - conv_features: A tensor containing the convolutional features.
544
+ """
545
+ # Compute cosine activations
546
+ activations, marginless_activations = self.cos_activation(
547
+ x,
548
+ prototypes_of_wrong_class=prototypes_of_wrong_class,
549
+ )
550
+
551
+ return activations, [marginless_activations, x]
552
+
553
+ def get_prototype_orthogonalities(self, use_part_prototypes: bool = False) -> torch.Tensor:
554
+ """
555
+ Computes the orthogonality loss, encouraging each piece of a prototype to be orthogonal to the others.
556
+
557
+ This method is inspired by the paper:
558
+ https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Interpretable_Image_Recognition_by_Constructing_Transparent_Embedding_Space_ICCV_2021_paper.pdf
559
+
560
+ Args:
561
+ use_part_prototypes (bool): If True, treats each spatial part of the prototypes as a separate prototype.
562
+
563
+ Returns:
564
+ torch.Tensor: A tensor representing the orthogonalities.
565
+ """
566
+
567
+ if use_part_prototypes:
568
+ # Normalize prototypes to unit length
569
+ prototype_vector_length = torch.sqrt(
570
+ torch.sum(torch.square(self.prototype_vectors), dim=1, keepdim=True)
571
+ + self.epsilon_val
572
+ )
573
+ normalized_prototypes = self.prototype_vectors / (
574
+ prototype_vector_length + self.epsilon_val
575
+ )
576
+
577
+ # Calculate total part prototypes per class
578
+ num_part_prototypes_per_class = (
579
+ self.num_prototypes_per_class
580
+ * self.prototype_shape[2]
581
+ * self.prototype_shape[3]
582
+ )
583
+
584
+ # Reshape to match class structure
585
+ normalized_prototypes = normalized_prototypes.view(
586
+ self.num_classes,
587
+ self.num_prototypes_per_class,
588
+ self.prototype_shape[1],
589
+ self.prototype_shape[2] * self.prototype_shape[3],
590
+ )
591
+
592
+ # Transpose and reshape to treat each spatial part as a separate prototype
593
+ normalized_prototypes = normalized_prototypes.permute(0, 1, 3, 2).reshape(
594
+ self.num_classes, num_part_prototypes_per_class, self.prototype_shape[1]
595
+ )
596
+
597
+ else:
598
+ # Normalize prototypes to unit length
599
+ prototype_vectors_reshaped = self.prototype_vectors.view(
600
+ self.num_prototypes, -1
601
+ )
602
+ prototype_vector_length = torch.sqrt(
603
+ torch.sum(torch.square(prototype_vectors_reshaped), dim=1, keepdim=True)
604
+ + self.epsilon_val
605
+ )
606
+ normalized_prototypes = prototype_vectors_reshaped / (
607
+ prototype_vector_length + self.epsilon_val
608
+ )
609
+
610
+ # Reshape to match class structure
611
+ normalized_prototypes = normalized_prototypes.view(
612
+ self.num_classes,
613
+ self.num_prototypes_per_class,
614
+ self.prototype_shape[1]
615
+ * self.prototype_shape[2]
616
+ * self.prototype_shape[3],
617
+ )
618
+
619
+ # Compute orthogonality matrix for each class
620
+ orthogonalities = torch.matmul(
621
+ normalized_prototypes, normalized_prototypes.transpose(1, 2)
622
+ )
623
+
624
+ # Identity matrix to enforce orthogonality
625
+ identity_matrix = (
626
+ torch.eye(normalized_prototypes.shape[1], device=orthogonalities.device)
627
+ .unsqueeze(0)
628
+ .repeat(self.num_classes, 1, 1)
629
+ )
630
+
631
+ # Subtract identity to focus on orthogonality
632
+ orthogonalities = orthogonalities - identity_matrix
633
+
634
+ return orthogonalities
635
+
636
+ def identify_prototypes_to_prune(self) -> list[int]:
637
+ """
638
+ Identifies the indices of prototypes that should be pruned.
639
+
640
+ This function iterates through the prototypes and checks if the specific weight
641
+ connecting the prototype to its class is zero. It is specifically designed to handle
642
+ the LinearLayerWithoutNegativeConnections where each class has a subset of features
643
+ it connects to.
644
+
645
+ Returns:
646
+ list[int]: A list of prototype indices that should be pruned.
647
+ """
648
+ prototypes_to_prune = []
649
+
650
+ # Calculate the number of prototypes assigned to each class
651
+ prototypes_per_class = self.num_prototypes // self.num_classes
652
+
653
+ if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
654
+ # Custom layer mapping prototypes to a subset of input features for each output class
655
+ for prototype_index in range(self.num_prototypes):
656
+ class_index = self.prototype_class_identity[prototype_index]
657
+ # Calculate the specific index within the 'features_per_output_class' for this prototype
658
+ index_within_class = prototype_index % prototypes_per_class
659
+ # Check if the specific weight connecting the prototype to its class is zero
660
+ if self.last_layer.weight[class_index, index_within_class] == 0.0:
661
+ prototypes_to_prune.append(prototype_index)
662
+ else:
663
+ # Standard linear layer: each prototype directly maps to a feature index
664
+ weights_to_check = self.last_layer.weight
665
+ for prototype_index in range(self.num_prototypes):
666
+ class_index = self.prototype_class_identity[prototype_index]
667
+ if weights_to_check[class_index, prototype_index] == 0.0:
668
+ prototypes_to_prune.append(prototype_index)
669
+
670
+ return prototypes_to_prune
671
+
672
+ def prune_prototypes_by_threshold(self, threshold: float = 1e-3) -> None:
673
+ """
674
+ Prune the weights in the classification layer by setting weights below a specified threshold to zero.
675
+
676
+ This method modifies the weights of the last layer of the model in-place. Weights falling below the
677
+ threshold are set to zero, diminishing their influence in the model's decisions. It also identifies
678
+ and prunes prototypes based on these updated weights, thereby refining the model's structure.
679
+
680
+ Args:
681
+ threshold (float): The threshold value below which weights will be set to zero. Defaults to 1e-3.
682
+ """
683
+ # Access the weights of the last layer
684
+ weights = self.last_layer.weight.data
685
+
686
+ # Set weights below the threshold to zero
687
+ # This step reduces the influence of low-value weights in the model's decision-making process
688
+ weights[weights < threshold] = 0.0
689
+
690
+ # Update the weights in the last layer to reflect the pruning
691
+ self.last_layer.weight.data.copy_(weights)
692
+
693
+ # Identify prototypes that need to be pruned based on the updated weights
694
+ prototypes_to_prune = self.identify_prototypes_to_prune()
695
+
696
+ # Execute the pruning of identified prototypes
697
+ self.prune_prototypes_by_index(prototypes_to_prune)
698
+
699
    def prune_prototypes_by_index(self, prototypes_to_prune: list[int]) -> None:
        """
        Prunes specified prototypes from the PPNet.

        Pruning is "soft": prototype vectors are kept, but their influence is
        removed by zeroing the corresponding last-layer weights (and, when
        frequency gates are used, driving their pre-sigmoid weights to -7,
        i.e. sigmoid ~= 0.001).

        Args:
            prototypes_to_prune (list[int]): A list of indices indicating the prototypes to be removed.
                Each index should be in the range [0, current number of prototypes - 1].

        Raises:
            ValueError: if any index is outside the valid range.

        Returns:
            None
        """

        # Validate the provided indices to ensure they are within the valid range.
        if any(
            index < 0 or index >= self.num_prototypes for index in prototypes_to_prune
        ):
            raise ValueError("Provided prototype indices are out of valid range!")

        # NOTE(review): this overwrites rather than decrements the running
        # count, so repeated calls with disjoint index sets do not accumulate
        # — confirm this is intended.
        self.num_prototypes_after_pruning = self.num_prototypes - len(
            prototypes_to_prune
        )

        # All parameter surgery happens outside autograd tracking.
        with torch.no_grad():
            # Silence the frequency gates of pruned prototypes (sigmoid(-7) ~ 0).
            if self.frequency_weights is not None:
                self.frequency_weights.data[prototypes_to_prune, :] = -7.0

            # Adjust the weights in the last layer depending on its type.
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # Weight is (num_classes, features_per_output_class): zero only
                # the columns of pruned prototypes belonging to each class.
                for class_idx in range(self.last_layer.out_features):
                    # Identify prototypes belonging to the current class.
                    indices_for_class = [
                        idx % self.last_layer.features_per_output_class
                        for idx in prototypes_to_prune
                        if self.prototype_class_identity[idx] == class_idx
                    ]
                    self.last_layer.weight.data[class_idx, indices_for_class] = 0.0
            else:
                # Dense last layer: zero the whole column of each pruned prototype.
                self.last_layer.weight.data[:, prototypes_to_prune] = 0.0
744
    def __repr__(self) -> str:
        """Return a concise, human-readable summary of the network.

        Only the prototype shape, the number of classes and the epsilon value
        are included; backbone and layer details are omitted.
        """
        # NOTE(review): the leading whitespace of the continuation lines below
        # is part of the emitted string — confirm it matches the original file.
        rep = f"""PPNet(
        prototype_shape: {self.prototype_shape},
        num_classes: {self.num_classes},
        epsilon: {self.epsilon_val})"""

        return rep
751
+
752
+ def set_last_layer_incorrect_connection(
753
+ self, incorrect_strength: float = None
754
+ ) -> None:
755
+ """
756
+ Modifies the last layer weights to have incorrect connections with a specified strength.
757
+ If incorrect_strength is None, initializes the weights for LinearLayerWithoutNegativeConnections
758
+ with correct_class_connection value.
759
+
760
+ Args:
761
+ - incorrect_strength (Optional[float]): The strength of the incorrect connections.
762
+ If None, initialize without incorrect connections.
763
+
764
+ Returns:
765
+ None
766
+ """
767
+ if incorrect_strength is None:
768
+ # Handle LinearLayerWithoutNegativeConnections initialization
769
+ if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
770
+ # Initialize all weights to the correct_class_connection value
771
+ self.last_layer.weight.data.fill_(self.correct_class_connection)
772
+ else:
773
+ raise ValueError(
774
+ "last_layer is not an instance of LinearLayerWithoutNegativeConnections"
775
+ )
776
+
777
+ else:
778
+ # Create a one-hot matrix for correct connections
779
+ positive_one_weights_locations = torch.zeros(
780
+ self.num_classes, self.num_prototypes
781
+ )
782
+ positive_one_weights_locations[
783
+ self.prototype_class_identity,
784
+ torch.arange(self.num_prototypes),
785
+ ] = 1
786
+
787
+ # Create a matrix for incorrect connections
788
+ negative_one_weights_locations = 1 - positive_one_weights_locations
789
+
790
+ # This variable represents the strength of the connection for correct class
791
+ correct_class_connection = self.correct_class_connection
792
+
793
+ # This variable represents the strength of the connection for incorrect class
794
+ incorrect_class_connection = incorrect_strength
795
+
796
+ # Modify weights to have correct and incorrect connections
797
+ self.last_layer.weight.data.copy_(
798
+ correct_class_connection * positive_one_weights_locations
799
+ + incorrect_class_connection * negative_one_weights_locations
800
+ )
801
+
802
+ if self.last_layer.bias is not None:
803
+ # Initialize all biases to bias_last_layer value
804
+ self.last_layer.bias.data.fill_(self.bias_last_layer)
805
+
806
+ def _setup_add_on_layers(self, add_on_layers_type: str):
807
+ """
808
+ Configures additional layers based on the backbone model architecture and the specified add_on_layers_type.
809
+
810
+ Args:
811
+ add_on_layers_type (str): Type of additional layers to add. Can be 'identity' or 'upsample'.
812
+ """
813
+
814
+ if add_on_layers_type == "identity":
815
+ self.add_on_layers = nn.Sequential(nn.Identity())
816
+ elif add_on_layers_type == "upsample":
817
+ self.add_on_layers = nn.Upsample(scale_factor=2, mode="bilinear")
818
+ else:
819
+ raise NotImplementedError(
820
+ f"The add-on layer type {add_on_layers_type} isn't implemented yet."
821
+ )
822
+
823
+ # TODO
824
+ # def _initialize_weights(self) -> None:
825
+ # """
826
+ # Initializes the weights of the add-on layers of the network and the last layer with incorrect connections.
827
+ #
828
+ # Returns:
829
+ # None
830
+ # """
831
+ #
832
+ # for m in self.add_on_layers.modules():
833
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
834
+ # nn.init.trunc_normal_(m.weight, std=0.02)
835
+ # if m.bias is not None:
836
+ # nn.init.zeros_(m.bias)
837
+ #
838
+ # # Initialize the last layer with incorrect connections using specified incorrect class connection strength
839
+ # self.set_last_layer_incorrect_connection(
840
+ # incorrect_strength=self.incorrect_class_connection
841
+ # )
842
+
843
+
844
class AudioProtoNetPreTrainedModel(PreTrainedModel):
    """Base class wiring AudioProtoNet models into the HF pretrained-model API."""

    config_class = AudioProtoNetConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialise the weights of a single submodule.

        Conv and linear weights get a truncated-normal init (std 0.02) with
        zero bias. When the model exposes a
        ``LinearLayerWithoutNegativeConnections`` last layer and no incorrect
        class connection strength is configured, that layer is instead filled
        with the correct-class connection value.
        """
        # BUG FIX: the original repeated this isinstance/init block twice
        # verbatim; the duplicate has been removed.
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # NOTE(review): the original read self.incorrect_class_connection and
        # self.last_layer unconditionally, which raises AttributeError for
        # subclasses that do not define them (it was marked "TODO missing
        # initilization"); guard the access with getattr.
        last_layer = getattr(self, "last_layer", None)
        if (
            getattr(self, "incorrect_class_connection", None) is None
            and isinstance(last_layer, LinearLayerWithoutNegativeConnections)
        ):
            # Initialize all weights to the correct_class_connection value.
            last_layer.weight.data.fill_(self.correct_class_connection)
860
+
861
+
862
class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
    """ConvNeXt backbone wrapped as the AudioProtoNet base model."""

    _auto_class = "AutoModel"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)
        # Single-channel ConvNeXt backbone; the architecture configuration is
        # fetched from the Hub at construction time.
        backbone_config = ConvNextConfig.from_pretrained(
            "facebook/convnext-base-224-22k", num_channels=1
        )
        self.backbone = ConvNextModel(backbone_config)

    def forward(
        self,
        input_values: torch.Tensor,
        output_hidden_states: bool = None,
        return_dict: bool = None
    ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the backbone on a batch of inputs.

        Args:
            input_values: Input tensor forwarded to the ConvNeXt backbone
                (presumably a single-channel spectrogram batch — TODO confirm).
            output_hidden_states: Whether the backbone should also return all
                intermediate hidden states.
            return_dict: Whether the backbone returns a ``ModelOutput``
                instead of a plain tuple.

        Returns:
            The backbone output: last hidden state, pooled output and,
            optionally, all hidden states.
        """
        return self.backbone(
            input_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
889
+
890
+
891
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
    """AudioProtoNet backbone plus a prototype-based multi-label classification head."""

    _auto_class = "AutoModelForSequenceClassification"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)

        self.model = AudioProtoNetModel(config)
        self.head = AudioProtoNetClassificationHead(config)

    def freeze_backbone(self):
        """Freeze all backbone parameters so only the head is trained.

        NOTE(review): the original was an empty stub; this now disables
        gradients for every parameter of ``self.model``.
        """
        for param in self.model.parameters():
            param.requires_grad_(False)

    def int2str(self):  # TODO: map class indices to label names (not implemented yet)
        pass

    def forward(
        self,
        input_values: torch.Tensor,
        labels: torch.Tensor = None,
        prototypes_of_wrong_class: torch.Tensor = None,
        output_hidden_states: bool = None,
        output_prototypical_activations: bool = None,
        return_dict: bool = None,
    ) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
        """
        Classify a batch of inputs and optionally compute the multi-label loss.

        Args:
            input_values: Input tensor forwarded to the backbone.
            labels: Optional multi-label binarized targets; when given, an
                asymmetric multi-label loss is computed.
            prototypes_of_wrong_class: Optional tensor forwarded to the head.
            output_hidden_states: Whether to also return all backbone hidden
                states.
            output_prototypical_activations: Whether to also return the
                prototype activations produced by the head.
            return_dict: Whether to return a
                ``SequenceClassifierOutputWithProtoTypeActivations`` instead of
                a plain tuple. Defaults to the config's ``use_return_dict``.
        """
        backbone_outputs = self.model(input_values, output_hidden_states, return_dict)

        last_hidden_state = backbone_outputs[0]

        logits, info = self.head(last_hidden_state, prototypes_of_wrong_class)

        loss = None
        if labels is not None:
            # BUG FIX: Tensor.to is not in-place; the original discarded the
            # moved tensor, leaving labels on their original device.
            labels = labels.to(logits.device)
            loss_fct = AsymmetricLossMultiLabel()
            loss = loss_fct(logits, labels.float())

        hidden_states = None
        # BUG FIX: the original tested `is not None`, so an explicit False
        # still indexed the (absent) hidden states.
        if output_hidden_states:
            hidden_states = backbone_outputs[2]

        prototype_activations = None
        if output_prototypical_activations:
            # Prototype activations are the fifth entry of the head's info
            # tuple.
            prototype_activations = info[4]

        # BUG FIX: the original returned the tuple when return_dict was True
        # and the dataclass otherwise — inverted relative to the HF
        # convention. Resolve the default from the config so that the None
        # default still yields the dataclass output.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if not return_dict:
            output = (logits,)
            output += (loss,) if loss is not None else ()
            output += (last_hidden_state,)
            output += (hidden_states,) if hidden_states is not None else ()
            output += (prototype_activations,) if prototype_activations is not None else ()
            return output

        return SequenceClassifierOutputWithProtoTypeActivations(
            logits=logits,
            loss=loss,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            prototype_activations=prototype_activations
        )