|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
import math |
|
|
|
|
|
from transformers import PreTrainedModel, ConvNextModel, ConvNextConfig |
|
|
from transformers.utils import logging |
|
|
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPoolingAndNoAttention |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from .configuration_protonet import AudioProtoNetConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
@dataclass
class SequenceClassifierOutputWithProtoTypeActivations(ModelOutput):
    """Model output for AudioProtoNetForSequenceClassification.

    Extends the standard HuggingFace ``ModelOutput`` with the prototype
    activation maps produced by the ProtoPNet-style classification head.
    """

    # Classification scores (raw, pre-sigmoid), shape (batch, num_classes).
    logits: torch.Tensor
    # Multi-label classification loss; only populated when labels are passed.
    loss: torch.Tensor = None
    # Backbone's last hidden state (the features fed to the prototype head).
    last_hidden_state: torch.FloatTensor = None
    # Per-stage backbone hidden states; only populated when
    # output_hidden_states is requested in the forward call.
    hidden_states: tuple[torch.FloatTensor, ...] = None
    # Marginless prototype activation maps; only populated when
    # output_prototypical_activations is requested in the forward call.
    prototype_activations: torch.FloatTensor = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsymmetricLossMultiLabel(nn.Module):
    """Asymmetric loss (ASL) for multi-label classification.

    Negatives are down-weighted more aggressively than positives via separate
    focusing parameters, and easy negatives can be discarded entirely through
    asymmetric probability shifting (``clip``).

    Args:
        gamma_neg: Focusing parameter applied to negative targets.
        gamma_pos: Focusing parameter applied to positive targets.
        clip: Probability margin added to the negative probabilities
            (asymmetric probability shifting). Disabled when ``None`` or 0.
        eps: Numerical-stability constant for the ``log`` calls.
        disable_torch_grad_focal_loss: If True, the focal weighting factor is
            computed without gradient tracking.
        reduction: ``"mean"``, ``"sum"``, or anything else for no reduction.
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
        reduction="mean",
    ):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.reduction = reduction

    def _focal_weight(self, xs_pos, xs_neg, y):
        """Return the asymmetric focusing weight ``(1 - p_t) ** gamma_t``."""
        pt = xs_pos * y + xs_neg * (1 - y)
        one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
        return torch.pow(1 - pt, one_sided_gamma)

    def forward(self, x, y):
        """Compute the asymmetric multi-label loss.

        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector), same shape as ``x``

        Returns
        -------
        A scalar tensor when ``reduction`` is "mean" or "sum", otherwise an
        element-wise loss tensor of the same shape as ``x``.
        """
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric probability shifting: push negative probabilities up by
        # `clip` so very easy negatives contribute zero loss.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Plain binary cross-entropy terms (clamped for numerical stability).
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric focusing.
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Bug fix: the original flipped the *global* grad mode via the
                # private torch._C API and unconditionally re-enabled it
                # afterwards, which could wrongly turn gradients back on inside
                # an ambient no_grad() region. A scoped context is safe.
                with torch.no_grad():
                    one_sided_w = self._focal_weight(xs_pos, xs_neg, y)
            else:
                one_sided_w = self._focal_weight(xs_pos, xs_neg, y)
            loss = loss * one_sided_w

        if self.reduction == "mean":
            return -loss.mean()
        if self.reduction == "sum":
            return -loss.sum()

        return -loss
|
|
|
|
|
|
|
|
class NonNegativeLinear(nn.Module):
    """A linear layer whose effective weights are constrained to be non-negative.

    Applies ``y = x @ relu(A)^T + b``: the raw weight parameter is
    unconstrained, but a ReLU is applied to it in every forward pass, so the
    transformation's effective weights are always >= 0. Useful in models where
    negative connections are not appropriate (e.g. prototype-to-class weights).

    Attributes:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        weight (torch.Tensor): The raw weight parameter (ReLU-clamped at use).
        bias (torch.Tensor, optional): The bias parameter of the module.

    Args:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        bias (bool, optional): If True, the layer learns an additive bias. Default: True.
        device (optional): The device (CPU/GPU) on which to place the parameters.
        dtype (optional): The data type for the parameters (e.g., float32).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # Bug fix: parameters were previously created with torch.empty and
        # never initialized, leaving uninitialized memory in the weights.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters the same way as ``nn.Linear``."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the non-negative linear transformation.

        Args:
            input (torch.Tensor): The input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, out_features).
        """
        return nn.functional.linear(input, torch.relu(self.weight), self.bias)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
|
|
|
|
|
|
|
|
class LinearLayerWithoutNegativeConnections(nn.Module): |
|
|
r""" |
|
|
Custom Linear Layer where each output class is connected to a specific subset of input features. |
|
|
|
|
|
Args: |
|
|
in_features: size of each input sample |
|
|
out_features: size of each output sample |
|
|
bias: If set to ``False``, the layer will not learn an additive bias. |
|
|
Default: ``True`` |
|
|
device: the device of the module parameters. Default: ``None`` |
|
|
dtype: the data type of the module parameters. Default: ``None`` |
|
|
|
|
|
Shape: |
|
|
- Input: :math:`(*, H_{in})` where :math:`*` means any number of |
|
|
dimensions including none and :math:`H_{in} = \text{in_features}`. |
|
|
- Output: :math:`(*, H_{out})` where all but the last dimension |
|
|
are the same shape as the input and :math:`H_{out} = \text{out_features}`. |
|
|
|
|
|
Attributes: |
|
|
weight: the learnable weights of the module of shape |
|
|
:math:`(\text{out_features}, \text{features_per_output_class})`. |
|
|
bias: the learnable bias of the module of shape :math:`(\text{out_features})`. |
|
|
If :attr:`bias` is ``True``, the values are initialized from |
|
|
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where |
|
|
:math:`k = \frac{1}{\text{features_per_output_class}}` |
|
|
""" |
|
|
|
|
|
__constants__ = ["in_features", "out_features", "bias"] |
|
|
in_features: int |
|
|
out_features: int |
|
|
weight: torch.Tensor |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
in_features: int, |
|
|
out_features: int, |
|
|
bias: bool = True, |
|
|
non_negative: bool = True, |
|
|
device: torch.device = None, |
|
|
dtype: torch.dtype = None, |
|
|
) -> None: |
|
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
super().__init__() |
|
|
self.in_features = in_features |
|
|
self.out_features = out_features |
|
|
self.non_negative = non_negative |
|
|
|
|
|
|
|
|
self.features_per_output_class = in_features // out_features |
|
|
|
|
|
|
|
|
assert ( |
|
|
in_features % out_features == 0 |
|
|
), f"{in_features = } must be divisible by {out_features = }" |
|
|
|
|
|
|
|
|
self.weight = nn.Parameter( |
|
|
torch.empty( |
|
|
(out_features, self.features_per_output_class), **factory_kwargs |
|
|
) |
|
|
) |
|
|
if bias: |
|
|
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) |
|
|
else: |
|
|
self.register_parameter("bias", None) |
|
|
|
|
|
|
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_parameters(self) -> None: |
|
|
""" |
|
|
Initialize the weights and biases. |
|
|
Weights are initialized using Kaiming uniform initialization. |
|
|
Biases are initialized using a uniform distribution. |
|
|
""" |
|
|
|
|
|
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) |
|
|
|
|
|
if self.bias is not None: |
|
|
|
|
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) |
|
|
|
|
|
|
|
|
bound = 1 / math.sqrt(fan_in) |
|
|
nn.init.uniform_(self.bias, -bound, bound) |
|
|
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Forward pass for the custom linear layer. |
|
|
|
|
|
Args: |
|
|
input (Tensor): Input tensor of shape (batch_size, in_features). |
|
|
|
|
|
Returns: |
|
|
Tensor: Output tensor of shape (batch_size, out_features). |
|
|
""" |
|
|
batch_size = input.size(0) |
|
|
|
|
|
reshaped_input = input.view( |
|
|
batch_size, self.out_features, self.features_per_output_class |
|
|
) |
|
|
|
|
|
|
|
|
weight = torch.relu(self.weight) if self.non_negative else self.weight |
|
|
|
|
|
|
|
|
output = torch.einsum("bof,of->bo", reshaped_input, weight) |
|
|
|
|
|
if self.bias is not None: |
|
|
output += self.bias |
|
|
|
|
|
return output |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" |
|
|
|
|
|
|
|
|
class AudioProtoNetClassificationHead(nn.Module):
    """Prototype-based classification head (ProtoPNet-style) for audio embeddings.

    Learns ``prototypes_per_class * num_classes`` prototype vectors, scores
    backbone feature maps by cosine similarity between spatial patches and
    prototypes, pools the activations, and maps them to class logits through a
    final linear layer.
    """

    def __init__(
        self,
        config: "AudioProtoNetConfig",
    ) -> None:
        """Build the prototype head from an AudioProtoNetConfig."""
        super().__init__()
        self.prototypes_per_class = config.prototypes_per_class
        self.num_classes = config.num_classes
        self.num_prototypes = self.prototypes_per_class * self.num_classes
        # Bug fix: get_prototype_orthogonalities() reads
        # self.num_prototypes_per_class, which was never defined and raised an
        # AttributeError. Define it as an alias of prototypes_per_class.
        self.num_prototypes_per_class = self.prototypes_per_class
        self.num_prototypes_after_pruning = config.num_prototypes_after_pruning
        self.margin = config.margin
        self.relu_on_cos = config.relu_on_cos
        self.incorrect_class_connection = config.incorrect_class_connection
        self.correct_class_connection = config.correct_class_connection
        self.input_vector_length = config.input_vector_length
        self.n_eps_channels = config.n_eps_channels
        self.epsilon_val = config.epsilon_val
        self.topk_k = config.topk_k
        self.bias_last_layer = config.bias_last_layer
        self.non_negative_last_layer = config.non_negative_last_layer
        self.embedded_spectrogram_height = config.embedded_spectrogram_height
        self.use_bias_last_layer = config.use_bias_last_layer

        # Class identity of each prototype: prototypes are allocated in
        # contiguous, equally sized groups per class. (A dead assignment from
        # config.prototype_class_identity that was immediately overwritten by
        # this computation has been removed.)
        self.prototype_class_identity = (
            torch.arange(self.num_prototypes) // self.prototypes_per_class
        )

        self.prototype_shape = (self.num_prototypes, config.channels, config.height, config.width)

        self._setup_add_on_layers(add_on_layers_type=config.add_on_layers_type)

        self.prototype_vectors = nn.Parameter(
            torch.rand(self.prototype_shape), requires_grad=True
        )

        # Optional per-prototype frequency weighting over the embedded
        # spectrogram height; sigmoid-gated in cos_activation. Initialized to
        # 3.0 so sigmoid starts near 1 (almost no attenuation).
        self.frequency_weights = None
        if self.embedded_spectrogram_height is not None:
            self.frequency_weights = nn.Parameter(
                torch.full(
                    (
                        self.num_prototypes,
                        self.embedded_spectrogram_height,
                    ),
                    3.0,
                )
            )

        # Final classifier: dense (optionally non-negative) connections when
        # incorrect-class connections are allowed, otherwise one weight per
        # (class, own-prototype) pair.
        if self.incorrect_class_connection:
            if self.non_negative_last_layer:
                self.last_layer = NonNegativeLinear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
            else:
                self.last_layer = nn.Linear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
        else:
            self.last_layer = LinearLayerWithoutNegativeConnections(
                in_features=self.num_prototypes,
                out_features=self.num_classes,
                non_negative=self.non_negative_last_layer,
            )

    def forward(
        self,
        features: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Forward pass of the prototype head.

        Args:
            - features (torch.Tensor): Backbone feature map with shape
              (batch_size, num_channels, height, width).
            - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
              when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
                - logits: A tensor containing the logits for each class in the model.
                - a list containing:
                    - mean_activations: A tensor containing the mean of the top-k prototype activations.
                    - marginless_logits: Logits computed from the marginless max-pooled activations.
                    - conv_features: A tensor containing the convolutional features.
                    - marginless_max_activations: A tensor containing the max-pooled marginless activations.
                    - marginless_activations: The full marginless activation maps.
        """
        features = self.add_on_layers(features)

        activations, additional_returns = self.prototype_activations(
            features, prototypes_of_wrong_class=prototypes_of_wrong_class
        )
        marginless_activations = additional_returns[0]
        conv_features = additional_returns[1]

        # NOTE(review): config.topk_k is stored on the module but k is
        # hard-coded to 1 here — confirm whether training was meant to use
        # self.topk_k instead.
        topk_k = 1

        # Flatten the spatial dimensions so top-k runs over all positions.
        activations = activations.view(activations.shape[0], activations.shape[1], -1)

        topk_activations, _ = torch.topk(activations, topk_k, dim=-1)

        mean_activations = torch.mean(topk_activations, dim=-1)

        # Global max pooling of the marginless activation maps.
        marginless_max_activations = nn.functional.max_pool2d(
            marginless_activations,
            kernel_size=(
                marginless_activations.size()[2],
                marginless_activations.size()[3],
            ),
        )
        marginless_max_activations = marginless_max_activations.view(
            -1, self.num_prototypes
        )

        logits = self.last_layer(mean_activations)
        marginless_logits = self.last_layer(marginless_max_activations)
        return logits, [
            mean_activations,
            marginless_logits,
            conv_features,
            marginless_max_activations,
            marginless_activations,
        ]

    def cos_activation(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the cosine activation between input tensor x and prototype vectors.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor with shape (batch_size, num_channels, height, width).
        prototypes_of_wrong_class : Optional[torch.Tensor]
            Tensor containing the prototypes of the wrong class with shape (batch_size, num_prototypes).

        Returns:
        --------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - activations: The cosine activations with potential margin adjustments.
            - marginless_activations: The cosine activations without margin adjustments.
        """
        input_vector_length = self.input_vector_length
        # Normalize by the spatial extent of a prototype so activations are
        # comparable across prototype sizes.
        normalizing_factor = (
            self.prototype_shape[-2] * self.prototype_shape[-1]
        ) ** 0.5

        # Append small constant "epsilon" channels to both the input and the
        # prototypes so zero patches still have a well-defined direction.
        epsilon_channel_x = torch.full(
            (x.shape[0], self.n_eps_channels, x.shape[2], x.shape[3]),
            self.epsilon_val,
            device=x.device,
            requires_grad=False,
        )
        x = torch.cat((x, epsilon_channel_x), dim=-3)

        # Scale each spatial vector to a fixed length (input_vector_length).
        x_length = torch.sqrt(torch.sum(x**2, dim=-3, keepdim=True) + self.epsilon_val)
        x_normalized = (input_vector_length * x / x_length) / normalizing_factor

        epsilon_channel_p = torch.full(
            (
                self.prototype_shape[0],
                self.n_eps_channels,
                self.prototype_shape[2],
                self.prototype_shape[3],
            ),
            self.epsilon_val,
            device=self.prototype_vectors.device,
            requires_grad=False,
        )
        appended_protos = torch.cat((self.prototype_vectors, epsilon_channel_p), dim=-3)

        # Unit-normalize the prototypes channel-wise.
        prototype_vector_length = torch.sqrt(
            torch.sum(appended_protos**2, dim=-3, keepdim=True) + self.epsilon_val
        )
        normalized_prototypes = appended_protos / (
            prototype_vector_length + self.epsilon_val
        )
        normalized_prototypes /= normalizing_factor

        # Sliding-window dot product == cosine similarity (both sides normalized).
        activations_dot = nn.functional.conv2d(x_normalized, normalized_prototypes)
        # The 1.01 slack keeps the ratio strictly inside [-1, 1] for acos.
        marginless_activations = activations_dot / (input_vector_length * 1.01)

        if self.frequency_weights is not None:
            # Gate each prototype's activation per frequency bin.
            freq_weights = torch.sigmoid(self.frequency_weights)
            marginless_activations = marginless_activations * freq_weights[:, :, None]

        if (
            self.margin is None
            or not self.training
            or prototypes_of_wrong_class is None
        ):
            activations = marginless_activations
        else:
            # Subtractive margin: penalize the angle to wrong-class prototypes.
            wrong_class_margin = (prototypes_of_wrong_class * self.margin).view(
                x.size(0), self.prototype_vectors.size(0), 1, 1
            )
            wrong_class_margin = wrong_class_margin.expand(
                -1, -1, activations_dot.size(-2), activations_dot.size(-1)
            )
            penalized_angles = (
                torch.acos(activations_dot / (input_vector_length * 1.01))
                - wrong_class_margin
            )
            activations = torch.cos(torch.relu(penalized_angles))

        if self.relu_on_cos:
            # Clamp negative cosine similarities to zero.
            activations = torch.relu(activations)
            marginless_activations = torch.relu(marginless_activations)

        return activations, marginless_activations

    def prototype_activations(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Compute the prototype activations for a given input tensor.

        Args:
            - x (torch.Tensor): The raw input tensor with shape (batch_size, num_channels, height, width).
            - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
              when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
                - activations: A tensor containing the prototype activations.
                - a list containing:
                    - marginless_activations: A tensor containing the activations before applying subtractive margin.
                    - conv_features: A tensor containing the convolutional features (the input ``x``).
        """
        activations, marginless_activations = self.cos_activation(
            x,
            prototypes_of_wrong_class=prototypes_of_wrong_class,
        )

        return activations, [marginless_activations, x]

    def get_prototype_orthogonalities(self, use_part_prototypes: bool = False) -> torch.Tensor:
        """
        Computes the orthogonality loss, encouraging each piece of a prototype to be orthogonal to the others.

        This method is inspired by the paper:
        https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Interpretable_Image_Recognition_by_Constructing_Transparent_Embedding_Space_ICCV_2021_paper.pdf

        Args:
            use_part_prototypes (bool): If True, treats each spatial part of the prototypes as a separate prototype.

        Returns:
            torch.Tensor: A tensor representing the orthogonalities
            (per-class Gram matrix minus the identity).
        """
        if use_part_prototypes:
            # Normalize along the channel dimension only, so every spatial
            # position of a prototype counts as its own "part prototype".
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(self.prototype_vectors), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = self.prototype_vectors / (
                prototype_vector_length + self.epsilon_val
            )

            num_part_prototypes_per_class = (
                self.num_prototypes_per_class
                * self.prototype_shape[2]
                * self.prototype_shape[3]
            )

            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.num_prototypes_per_class,
                self.prototype_shape[1],
                self.prototype_shape[2] * self.prototype_shape[3],
            )

            normalized_prototypes = normalized_prototypes.permute(0, 1, 3, 2).reshape(
                self.num_classes, num_part_prototypes_per_class, self.prototype_shape[1]
            )

        else:
            # Normalize each whole prototype (flattened) to unit length.
            prototype_vectors_reshaped = self.prototype_vectors.view(
                self.num_prototypes, -1
            )
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(prototype_vectors_reshaped), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = prototype_vectors_reshaped / (
                prototype_vector_length + self.epsilon_val
            )

            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.num_prototypes_per_class,
                self.prototype_shape[1]
                * self.prototype_shape[2]
                * self.prototype_shape[3],
            )

        # Per-class Gram matrix of the normalized prototypes.
        orthogonalities = torch.matmul(
            normalized_prototypes, normalized_prototypes.transpose(1, 2)
        )

        # Subtract the identity so perfectly orthogonal prototypes yield zero.
        identity_matrix = (
            torch.eye(normalized_prototypes.shape[1], device=orthogonalities.device)
            .unsqueeze(0)
            .repeat(self.num_classes, 1, 1)
        )

        orthogonalities = orthogonalities - identity_matrix

        return orthogonalities

    def identify_prototypes_to_prune(self) -> list[int]:
        """
        Identifies the indices of prototypes that should be pruned.

        This function iterates through the prototypes and checks if the specific weight
        connecting the prototype to its class is zero. It is specifically designed to handle
        the LinearLayerWithoutNegativeConnections where each class has a subset of features
        it connects to.

        Returns:
            list[int]: A list of prototype indices that should be pruned.
        """
        prototypes_to_prune = []

        prototypes_per_class = self.num_prototypes // self.num_classes

        if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
            # Each class row only stores weights for its own prototypes.
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                index_within_class = prototype_index % prototypes_per_class
                if self.last_layer.weight[class_index, index_within_class] == 0.0:
                    prototypes_to_prune.append(prototype_index)
        else:
            # Dense last layer: check the (class, prototype) entry directly.
            weights_to_check = self.last_layer.weight
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                if weights_to_check[class_index, prototype_index] == 0.0:
                    prototypes_to_prune.append(prototype_index)

        return prototypes_to_prune

    def prune_prototypes_by_threshold(self, threshold: float = 1e-3) -> None:
        """
        Prune the weights in the classification layer by setting weights below a specified threshold to zero.

        This method modifies the weights of the last layer of the model in-place. Weights falling below the
        threshold are set to zero, diminishing their influence in the model's decisions. It also identifies
        and prunes prototypes based on these updated weights, thereby refining the model's structure.

        Args:
            threshold (float): The threshold value below which weights will be set to zero. Defaults to 1e-3.
        """
        weights = self.last_layer.weight.data

        # Zero out every connection weaker than the threshold.
        weights[weights < threshold] = 0.0

        self.last_layer.weight.data.copy_(weights)

        prototypes_to_prune = self.identify_prototypes_to_prune()

        self.prune_prototypes_by_index(prototypes_to_prune)

    def prune_prototypes_by_index(self, prototypes_to_prune: list[int]) -> None:
        """
        Prunes specified prototypes from the PPNet.

        Args:
            prototypes_to_prune (list[int]): A list of indices indicating the prototypes to be removed.
                Each index should be in the range [0, current number of prototypes - 1].

        Raises:
            ValueError: If any index is outside the valid prototype range.

        Returns:
            None
        """
        if any(
            index < 0 or index >= self.num_prototypes for index in prototypes_to_prune
        ):
            raise ValueError("Provided prototype indices are out of valid range!")

        self.num_prototypes_after_pruning = self.num_prototypes - len(
            prototypes_to_prune
        )

        with torch.no_grad():
            # Drive the frequency gates of pruned prototypes to ~0
            # (sigmoid(-7) ≈ 0.001).
            if self.frequency_weights is not None:
                self.frequency_weights.data[prototypes_to_prune, :] = -7.0

            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # Zero each pruned prototype's slot inside its class's row.
                for class_idx in range(self.last_layer.out_features):
                    indices_for_class = [
                        idx % self.last_layer.features_per_output_class
                        for idx in prototypes_to_prune
                        if self.prototype_class_identity[idx] == class_idx
                    ]
                    self.last_layer.weight.data[class_idx, indices_for_class] = 0.0
            else:
                # Dense last layer: zero the pruned prototypes' columns.
                self.last_layer.weight.data[:, prototypes_to_prune] = 0.0

    def __repr__(self) -> str:
        rep = f"""PPNet(
            prototype_shape: {self.prototype_shape},
            num_classes: {self.num_classes},
            epsilon: {self.epsilon_val})"""

        return rep

    def set_last_layer_incorrect_connection(
        self, incorrect_strength: float = None
    ) -> None:
        """
        Modifies the last layer weights to have incorrect connections with a specified strength.
        If incorrect_strength is None, initializes the weights for LinearLayerWithoutNegativeConnections
        with correct_class_connection value.

        Args:
            - incorrect_strength (Optional[float]): The strength of the incorrect connections.
              If None, initialize without incorrect connections.

        Raises:
            ValueError: If incorrect_strength is None but the last layer is not
                a LinearLayerWithoutNegativeConnections.

        Returns:
            None
        """
        if incorrect_strength is None:
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                self.last_layer.weight.data.fill_(self.correct_class_connection)
            else:
                raise ValueError(
                    "last_layer is not an instance of LinearLayerWithoutNegativeConnections"
                )

        else:
            # Build a (num_classes, num_prototypes) mask with 1 where the
            # prototype belongs to the class, 0 elsewhere.
            positive_one_weights_locations = torch.zeros(
                self.num_classes, self.num_prototypes
            )
            positive_one_weights_locations[
                self.prototype_class_identity,
                torch.arange(self.num_prototypes),
            ] = 1

            negative_one_weights_locations = 1 - positive_one_weights_locations

            correct_class_connection = self.correct_class_connection
            incorrect_class_connection = incorrect_strength

            self.last_layer.weight.data.copy_(
                correct_class_connection * positive_one_weights_locations
                + incorrect_class_connection * negative_one_weights_locations
            )

        if self.last_layer.bias is not None:
            self.last_layer.bias.data.fill_(self.bias_last_layer)

    def _setup_add_on_layers(self, add_on_layers_type: str):
        """
        Configures additional layers based on the backbone model architecture and the specified add_on_layers_type.

        Args:
            add_on_layers_type (str): Type of additional layers to add. Can be 'identity' or 'upsample'.

        Raises:
            NotImplementedError: For any unsupported add_on_layers_type.
        """
        if add_on_layers_type == "identity":
            self.add_on_layers = nn.Sequential(nn.Identity())
        elif add_on_layers_type == "upsample":
            self.add_on_layers = nn.Upsample(scale_factor=2, mode="bilinear")
        else:
            raise NotImplementedError(
                f"The add-on layer type {add_on_layers_type} isn't implemented yet."
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioProtoNetPreTrainedModel(PreTrainedModel):
    """Base class wiring AudioProtoNet models into the HuggingFace API.

    Provides the config class, the base-model prefix used for weight loading,
    and module weight initialization.
    """

    config_class = AudioProtoNetConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Conv/linear layers get truncated-normal weights and zero biases; the
        prototype classifier's restricted linear layer is filled with the
        correct-class connection strength when no incorrect-class connections
        are configured.
        """
        # Bug fix: this init block was duplicated verbatim; one copy suffices.
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # Bug fix: the original read self.incorrect_class_connection and
        # self.last_layer, which do not exist on the PreTrainedModel itself
        # (they live on the config / the head's module) and raised an
        # AttributeError.
        if (
            isinstance(module, LinearLayerWithoutNegativeConnections)
            and self.config.incorrect_class_connection is None
        ):
            module.weight.data.fill_(self.config.correct_class_connection)
|
|
|
|
|
|
|
|
class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
    """ConvNeXt backbone wrapper that produces feature maps for the prototype head."""

    _auto_class = "AutoModel"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)
        # NOTE(review): this fetches the ConvNeXt config from the HuggingFace
        # Hub (or local cache) at construction time — confirm offline usage is
        # acceptable. num_channels=1 adapts the stem to mono spectrograms.
        backbone_config = ConvNextConfig.from_pretrained(
            "facebook/convnext-base-224-22k", num_channels=1
        )
        self.backbone = ConvNextModel(backbone_config)

    def forward(
        self,
        input_values: torch.Tensor,
        output_hidden_states: bool = None
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the ConvNeXt backbone.

        Args:
            input_values: Single-channel input batch — presumably spectrograms
                of shape (batch, 1, height, width); confirm against the
                feature extractor.
            output_hidden_states: Whether to also return per-stage hidden states.

        Returns:
            BaseModelOutputWithPoolingAndNoAttention with
            last_hidden_state, pooler_output and (optionally) hidden_states.
        """
        # Pass by keyword: relying on ConvNextModel's positional parameter
        # order is brittle across transformers versions.
        return self.backbone(input_values, output_hidden_states=output_hidden_states)
|
|
|
|
|
|
|
|
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
    """ConvNeXt backbone + ProtoPNet head for multi-label audio classification."""

    _auto_class = "AutoModelForSequenceClassification"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)

        self.model = AudioProtoNetModel(config)
        self.head = AudioProtoNetClassificationHead(config)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: torch.Tensor = None,
        prototypes_of_wrong_class: torch.Tensor = None,
        output_hidden_states: bool = None,
        output_prototypical_activations: bool = None,
    ) -> SequenceClassifierOutputWithProtoTypeActivations:
        """
        Classify audio inputs via prototype activations.

        Args:
            input_values: Input batch forwarded to the backbone.
            labels: Multi-label binarized targets; when given, an asymmetric
                multi-label loss is computed.
            prototypes_of_wrong_class: Wrong-class prototype mask for the
                subtractive margin (training only).
            output_hidden_states: If truthy, include backbone hidden states.
            output_prototypical_activations: If truthy, include the marginless
                prototype activation maps.

        Returns:
            SequenceClassifierOutputWithProtoTypeActivations with logits and
            the optional loss / hidden states / prototype activations.
        """
        backbone_outputs = self.model(
            input_values, output_hidden_states=output_hidden_states
        )

        last_hidden_state = backbone_outputs[0]

        logits, info = self.head(last_hidden_state, prototypes_of_wrong_class)

        loss = None
        if labels is not None:
            # Bug fix: Tensor.to() is not in-place — the moved tensor must be
            # re-assigned, otherwise the loss is computed with labels still on
            # the original device.
            labels = labels.to(logits.device)
            loss_fct = AsymmetricLossMultiLabel()
            loss = loss_fct(logits, labels.float())

        hidden_states = None
        # Bug fix: treat an explicit False like None — the original
        # `is not None` check indexed backbone_outputs[2] even when
        # output_hidden_states=False, where no hidden states are returned.
        if output_hidden_states:
            hidden_states = backbone_outputs[2]

        prototype_activations = None
        if output_prototypical_activations:
            # info[4] holds the marginless activation maps from the head.
            prototype_activations = info[4]

        return SequenceClassifierOutputWithProtoTypeActivations(
            logits=logits,
            loss=loss,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            prototype_activations=prototype_activations
        )
|
|
|