# AudioProtoPNet-5-BirdSet-XCL — modeling_protonet.py
import torch
import torch.nn as nn
from torch import Tensor
import math
from transformers import PreTrainedModel, ConvNextModel, ConvNextConfig
from transformers.utils import logging
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPoolingAndNoAttention
from dataclasses import dataclass
from .configuration_protonet import AudioProtoNetConfig
logger = logging.get_logger(__name__)
@dataclass
class SequenceClassifierOutputWithProtoTypeActivations(ModelOutput):
    """Classifier output extended with the prototype activation maps.

    Mirrors the usual `transformers` sequence-classifier outputs and adds the
    marginless prototype activations for interpretability.
    """

    # Per-class logits produced by the prototype head.
    logits: torch.Tensor
    # Multi-label classification loss; only set when labels are provided.
    loss: torch.Tensor = None
    # Final backbone feature map fed into the prototype head.
    last_hidden_state: torch.FloatTensor = None
    # Backbone hidden states, when requested by the caller.
    hidden_states: tuple[torch.FloatTensor, ...] = None
    # Marginless prototype activation maps, when requested by the caller.
    prototype_activations: torch.FloatTensor = None
# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
# https://github.com/huggingface/pytorch-image-models/blob/bbe798317fb26f063c18279827c038058e376479/timm/loss/asymmetric_loss.py#L6
class AsymmetricLossMultiLabel(nn.Module):
    """Asymmetric loss (ASL) for multi-label classification.

    Applies separate focusing parameters to positive and negative targets and
    optionally shifts ("clips") the negative probabilities, following
    Ridnik et al., "Asymmetric Loss for Multi-Label Classification" (ICCV 2021).

    Args:
        gamma_neg: Focusing exponent applied to negative targets.
        gamma_pos: Focusing exponent applied to positive targets.
        clip: Probability margin added to the negative probability
            (asymmetric clipping). Disabled when ``None`` or <= 0.
        eps: Small constant stabilizing the logarithms.
        disable_torch_grad_focal_loss: If True, compute the focusing weights
            with autograd disabled so no gradient flows through them.
        reduction: ``"mean"``, ``"sum"``, or anything else for no reduction.
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
        reduction="mean",
    ):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """Compute the asymmetric loss.

        Parameters
        ----------
        x: input logits, shape (batch_size, num_labels)
        y: targets (multi-label binarized vector), same shape as ``x``

        Returns
        -------
        The negated, reduced loss according to ``self.reduction``.
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid
        # Asymmetric Clipping: shift negative probabilities up by `clip`
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)
        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg
        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Use the public API rather than the private torch._C binding.
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w
        if self.reduction == "mean":
            return -loss.mean()
        if self.reduction == "sum":
            return -loss.sum()
        return -loss
class NonNegativeLinear(nn.Module):
    """
    A PyTorch module for a linear layer with non-negative weights.
    This module applies a linear transformation to the incoming data: `y = xA^T + b`.
    The weights of the transformation are constrained to be non-negative (via a
    ReLU applied at forward time), making this module particularly useful in
    models where negative weights may not be appropriate.
    Attributes:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        weight (torch.Tensor): The weight parameter of the module, constrained to be non-negative.
        bias (torch.Tensor, optional): The bias parameter of the module.
    Args:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        bias (bool, optional): If True, the layer will include a learnable bias. Default: True.
        device (optional): The device (CPU/GPU) on which to perform computations.
        dtype (optional): The data type for the parameters (e.g., float32).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # BUGFIX: parameters were previously left uninitialized (torch.empty
        # yields arbitrary memory). Initialize them like nn.Linear does.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize the parameters following the nn.Linear convention:
        Kaiming-uniform weights and uniform biases scaled by fan-in.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the NonNegativeLinear module.
        Args:
            input (torch.Tensor): The input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: The output tensor of shape (batch_size, out_features).
        """
        # ReLU clamps negative weights to zero without mutating the parameter.
        return nn.functional.linear(input, torch.relu(self.weight), self.bias)
class LinearLayerWithoutNegativeConnections(nn.Module):
    r"""
    Linear layer in which each output class only sees its own slice of the input.

    The ``in_features`` inputs are split evenly across the ``out_features``
    classes, so every output unit is connected to exactly
    ``in_features // out_features`` inputs. Optionally the weights are clamped
    to be non-negative at forward time.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        non_negative: If ``True``, a ReLU is applied to the weights in the
            forward pass. Default: ``True``
        device: the device of the module parameters. Default: ``None``
        dtype: the data type of the module parameters. Default: ``None``

    Shape:
        - Input: :math:`(*, H_{in})` with :math:`H_{in} = \text{in_features}`.
        - Output: :math:`(*, H_{out})` with :math:`H_{out} = \text{out_features}`.

    Attributes:
        weight: learnable weights of shape
            :math:`(\text{out_features}, \text{features_per_output_class})`,
            Kaiming-uniform initialized.
        bias: learnable bias of shape :math:`(\text{out_features})`,
            uniformly initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
            with :math:`k = 1/\text{features_per_output_class}`.
    """

    __constants__ = ["in_features", "out_features", "bias"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        non_negative: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.non_negative = non_negative
        # The input must partition evenly across the output classes.
        assert (
            in_features % out_features == 0
        ), f"{in_features = } must be divisible by {out_features = }"
        self.features_per_output_class = in_features // out_features
        # Per-class weight slice and optional bias.
        weight_shape = (out_features, self.features_per_output_class)
        self.weight = nn.Parameter(torch.empty(weight_shape, **factory_kwargs))
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize parameters: Kaiming-uniform weights, fan-in-scaled
        uniform biases (same scheme as ``nn.Linear``).
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the class-partitioned linear layer.

        Args:
            input (Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            Tensor: Output tensor of shape (batch_size, out_features).
        """
        n = input.size(0)
        # Group the inputs by class: (batch, out_features, features_per_class).
        grouped = input.view(n, self.out_features, self.features_per_output_class)
        # Optionally clamp weights to be non-negative.
        effective_weight = self.weight
        if self.non_negative:
            effective_weight = torch.relu(effective_weight)
        # Per-class dot product (equivalent to einsum "bof,of->bo").
        out = (grouped * effective_weight).sum(dim=-1)
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        has_bias = self.bias is not None
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={has_bias}"
        )
class AudioProtoNetClassificationHead(nn.Module):
    """Prototype-based classification head (ProtoPNet).

    Compares backbone feature maps against learned prototype vectors using a
    cosine similarity implemented as a convolution, pools the resulting
    activation maps (top-k in training, global max in evaluation), and maps
    the pooled activations to class logits through a final linear layer.
    """

    def __init__(
        self,
        config: AudioProtoNetConfig,
    ) -> None:
        """
        Builds the prototype tensor, the optional per-frequency weights, the
        add-on layers, and the last (classification) layer.

        Args:
            config: Configuration holding prototype shape, class count and the
                ProtoPNet hyper-parameters.
        """
        super().__init__()
        self.prototypes_per_class = config.prototypes_per_class
        self.num_classes = config.num_classes
        self.num_prototypes = self.prototypes_per_class * self.num_classes
        self.num_prototypes_after_pruning = config.num_prototypes_after_pruning
        self.margin = config.margin
        self.relu_on_cos = config.relu_on_cos
        self.incorrect_class_connection = config.incorrect_class_connection
        self.correct_class_connection = config.correct_class_connection
        self.input_vector_length = config.input_vector_length
        self.n_eps_channels = config.n_eps_channels
        self.epsilon_val = config.epsilon_val
        self.topk_k = config.topk_k
        self.bias_last_layer = config.bias_last_layer
        self.non_negative_last_layer = config.non_negative_last_layer
        self.embedded_spectrogram_height = config.embedded_spectrogram_height
        self.use_bias_last_layer = config.use_bias_last_layer
        # Create a 1D tensor where each element represents the class index.
        # Prototypes are assigned to classes in contiguous groups of
        # `prototypes_per_class`. (The previous dead read of
        # config.prototype_class_identity was removed: the value was
        # immediately overwritten by this computation.)
        self.prototype_class_identity = (
            torch.arange(self.num_prototypes) // self.prototypes_per_class
        )
        self.prototype_shape = (self.num_prototypes, config.channels, config.height, config.width)
        self._setup_add_on_layers(add_on_layers_type=config.add_on_layers_type)
        self.prototype_vectors = nn.Parameter(
            torch.rand(self.prototype_shape), requires_grad=True
        )
        self.frequency_weights = None
        if self.embedded_spectrogram_height is not None:
            # Initialize the frequency weights with a large positive value of 3.0
            # so that sigmoid(frequency_weights) is close to 1.
            self.frequency_weights = nn.Parameter(
                torch.full(
                    (
                        self.num_prototypes,
                        self.embedded_spectrogram_height,
                    ),
                    3.0,
                )
            )
        if self.incorrect_class_connection:
            if self.non_negative_last_layer:
                self.last_layer = NonNegativeLinear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
            else:
                self.last_layer = nn.Linear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
        else:
            self.last_layer = LinearLayerWithoutNegativeConnections(
                in_features=self.num_prototypes,
                out_features=self.num_classes,
                non_negative=self.non_negative_last_layer,
            )

    def forward(
        self,
        features: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Forward pass of the PPNet head.

        Args:
            - features (torch.Tensor): Backbone feature map with shape
              (batch_size, num_channels, height, width).
            - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes
              that are needed when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
            - logits: A tensor containing the logits for each class in the model.
            - a list containing:
                - mean_activations: A tensor containing the mean of the top-k prototype activations.
                  (in evaluation mode k is always 1)
                - marginless_logits: A tensor containing the logits for each class in the model,
                  calculated using the marginless activations.
                - conv_features: A tensor containing the convolutional features.
                - marginless_max_activations: A tensor containing the max-pooled marginless activations.
                - marginless_activations: The full marginless activation maps.
        """
        features = self.add_on_layers(features)
        activations, additional_returns = self.prototype_activations(
            features, prototypes_of_wrong_class=prototypes_of_wrong_class
        )
        marginless_activations = additional_returns[0]
        conv_features = additional_returns[1]
        # BUGFIX: use the configured top-k while training; previously topk_k
        # was hardcoded to 1 so config.topk_k was silently ignored.
        # In evaluation mode k is always 1 (global max pooling).
        topk_k = self.topk_k if self.training else 1
        # Reshape activations to combine spatial dimensions: (batch_size, num_prototypes, height*width)
        activations = activations.view(activations.shape[0], activations.shape[1], -1)
        # Perform top-k pooling along the combined spatial dimension.
        # For topk_k=1, this is equivalent to global max pooling.
        topk_activations, _ = torch.topk(activations, topk_k, dim=-1)
        # Mean of the top-k activations per prototype: (batch_size, num_prototypes).
        mean_activations = torch.mean(topk_activations, dim=-1)
        marginless_max_activations = nn.functional.max_pool2d(
            marginless_activations,
            kernel_size=(
                marginless_activations.size()[2],
                marginless_activations.size()[3],
            ),
        )
        marginless_max_activations = marginless_max_activations.view(
            -1, self.num_prototypes
        )
        logits = self.last_layer(mean_activations)
        marginless_logits = self.last_layer(marginless_max_activations)
        return logits, [
            mean_activations,
            marginless_logits,
            conv_features,
            marginless_max_activations,
            marginless_activations,
        ]

    def cos_activation(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the cosine activation between input tensor x and prototype vectors.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor with shape (batch_size, num_channels, height, width).
        prototypes_of_wrong_class : Optional[torch.Tensor]
            Tensor containing the prototypes of the wrong class with shape
            (batch_size, num_prototypes).

        Returns:
        --------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - activations: The cosine activations with potential margin adjustments.
            - marginless_activations: The cosine activations without margin adjustments.
        """
        input_vector_length = self.input_vector_length
        # Normalizing factor sqrt(h*w) keeps the activation scale independent
        # of the prototype's spatial size.
        normalizing_factor = (
            self.prototype_shape[-2] * self.prototype_shape[-1]
        ) ** 0.5
        # Pre-allocate epsilon channels on the correct device for input tensor x.
        # The constant channels bound the cosine away from degenerate vectors.
        epsilon_channel_x = torch.full(
            (x.shape[0], self.n_eps_channels, x.shape[2], x.shape[3]),
            self.epsilon_val,
            device=x.device,
            requires_grad=False,
        )
        x = torch.cat((x, epsilon_channel_x), dim=-3)
        # Normalize x to a fixed vector length.
        x_length = torch.sqrt(torch.sum(x**2, dim=-3, keepdim=True) + self.epsilon_val)
        x_normalized = (input_vector_length * x / x_length) / normalizing_factor
        # Pre-allocate epsilon channels for prototypes on the correct device.
        epsilon_channel_p = torch.full(
            (
                self.prototype_shape[0],
                self.n_eps_channels,
                self.prototype_shape[2],
                self.prototype_shape[3],
            ),
            self.epsilon_val,
            device=self.prototype_vectors.device,
            requires_grad=False,
        )
        appended_protos = torch.cat((self.prototype_vectors, epsilon_channel_p), dim=-3)
        # Normalize prototypes to (approximately) unit length.
        prototype_vector_length = torch.sqrt(
            torch.sum(appended_protos**2, dim=-3, keepdim=True) + self.epsilon_val
        )
        normalized_prototypes = appended_protos / (
            prototype_vector_length + self.epsilon_val
        )
        normalized_prototypes /= normalizing_factor
        # Cosine similarity as a convolution of normalized inputs with
        # normalized prototypes.
        activations_dot = nn.functional.conv2d(x_normalized, normalized_prototypes)
        # The 1.01 factor keeps |cos| strictly below 1 so acos stays finite.
        marginless_activations = activations_dot / (input_vector_length * 1.01)
        if self.frequency_weights is not None:
            # Apply sigmoid to frequency weights s.t. weights are between 0 and 1.
            freq_weights = torch.sigmoid(self.frequency_weights)
            # Multiply each prototype's frequency response by the corresponding weights.
            marginless_activations = marginless_activations * freq_weights[:, :, None]
        if (
            self.margin is None
            or not self.training
            or prototypes_of_wrong_class is None
        ):
            activations = marginless_activations
        else:
            # Apply subtractive margin in angular space for wrong-class prototypes.
            wrong_class_margin = (prototypes_of_wrong_class * self.margin).view(
                x.size(0), self.prototype_vectors.size(0), 1, 1
            )
            wrong_class_margin = wrong_class_margin.expand(
                -1, -1, activations_dot.size(-2), activations_dot.size(-1)
            )
            penalized_angles = (
                torch.acos(activations_dot / (input_vector_length * 1.01))
                - wrong_class_margin
            )
            activations = torch.cos(torch.relu(penalized_angles))
        if self.relu_on_cos:
            # Apply ReLU activation on the cosine values (keep only positive similarity).
            activations = torch.relu(activations)
            marginless_activations = torch.relu(marginless_activations)
        return activations, marginless_activations

    def prototype_activations(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Compute the prototype activations for a given input tensor.

        Args:
            - x (torch.Tensor): The input feature tensor with shape
              (batch_size, num_channels, height, width).
            - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes
              that are needed when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
            - activations: A tensor containing the prototype activations.
            - a list containing:
                - marginless_activations: Activations before applying the subtractive margin.
                - conv_features: The input feature tensor (passed through unchanged).
        """
        activations, marginless_activations = self.cos_activation(
            x,
            prototypes_of_wrong_class=prototypes_of_wrong_class,
        )
        return activations, [marginless_activations, x]

    def get_prototype_orthogonalities(self, use_part_prototypes: bool = False) -> torch.Tensor:
        """
        Computes the orthogonality loss term, encouraging each prototype (or
        prototype part) of a class to be orthogonal to the others.

        Inspired by:
        https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Interpretable_Image_Recognition_by_Constructing_Transparent_Embedding_Space_ICCV_2021_paper.pdf

        Args:
            use_part_prototypes (bool): If True, treats each spatial part of the
                prototypes as a separate prototype.

        Returns:
            torch.Tensor: Per-class Gram matrices minus the identity
                (zero when prototypes are perfectly orthonormal).
        """
        # BUGFIX: these code paths previously referenced the nonexistent
        # attribute `self.num_prototypes_per_class` (AttributeError); the
        # attribute defined in __init__ is `self.prototypes_per_class`.
        if use_part_prototypes:
            # Normalize prototypes to unit length along the channel dimension.
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(self.prototype_vectors), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = self.prototype_vectors / (
                prototype_vector_length + self.epsilon_val
            )
            # Calculate total part prototypes per class.
            num_part_prototypes_per_class = (
                self.prototypes_per_class
                * self.prototype_shape[2]
                * self.prototype_shape[3]
            )
            # Reshape to match class structure.
            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.prototypes_per_class,
                self.prototype_shape[1],
                self.prototype_shape[2] * self.prototype_shape[3],
            )
            # Transpose and reshape to treat each spatial part as a separate prototype.
            normalized_prototypes = normalized_prototypes.permute(0, 1, 3, 2).reshape(
                self.num_classes, num_part_prototypes_per_class, self.prototype_shape[1]
            )
        else:
            # Normalize whole prototypes to unit length.
            prototype_vectors_reshaped = self.prototype_vectors.view(
                self.num_prototypes, -1
            )
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(prototype_vectors_reshaped), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = prototype_vectors_reshaped / (
                prototype_vector_length + self.epsilon_val
            )
            # Reshape to match class structure.
            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.prototypes_per_class,
                self.prototype_shape[1]
                * self.prototype_shape[2]
                * self.prototype_shape[3],
            )
        # Compute the Gram (orthogonality) matrix for each class.
        orthogonalities = torch.matmul(
            normalized_prototypes, normalized_prototypes.transpose(1, 2)
        )
        # Identity matrix to enforce orthogonality.
        identity_matrix = (
            torch.eye(normalized_prototypes.shape[1], device=orthogonalities.device)
            .unsqueeze(0)
            .repeat(self.num_classes, 1, 1)
        )
        # Subtract identity so that perfectly orthonormal prototypes yield zero.
        orthogonalities = orthogonalities - identity_matrix
        return orthogonalities

    def identify_prototypes_to_prune(self) -> list[int]:
        """
        Identifies the indices of prototypes that should be pruned.

        A prototype is marked for pruning when the weight connecting it to its
        own class in the last layer is exactly zero. Handles both the
        LinearLayerWithoutNegativeConnections layout (per-class weight slices)
        and a standard (num_classes, num_prototypes) weight matrix.

        Returns:
            list[int]: A list of prototype indices that should be pruned.
        """
        prototypes_to_prune = []
        # Number of prototypes assigned to each class.
        prototypes_per_class = self.num_prototypes // self.num_classes
        if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
            # Custom layer: each class only sees its own slice of features.
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                # Position of this prototype within its class's weight slice.
                index_within_class = prototype_index % prototypes_per_class
                if self.last_layer.weight[class_index, index_within_class] == 0.0:
                    prototypes_to_prune.append(prototype_index)
        else:
            # Standard linear layer: each prototype directly maps to a feature index.
            weights_to_check = self.last_layer.weight
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                if weights_to_check[class_index, prototype_index] == 0.0:
                    prototypes_to_prune.append(prototype_index)
        return prototypes_to_prune

    def prune_prototypes_by_threshold(self, threshold: float = 1e-3) -> None:
        """
        Prune last-layer weights below a threshold, then prune the prototypes
        that end up with a zero connection to their own class.

        This modifies the last layer's weights in place: weights below the
        threshold are zeroed, and the affected prototypes are pruned via
        `prune_prototypes_by_index`.

        Args:
            threshold (float): Weights below this value are set to zero. Defaults to 1e-3.
        """
        weights = self.last_layer.weight.data
        # Zero out low-value weights to remove their influence.
        weights[weights < threshold] = 0.0
        self.last_layer.weight.data.copy_(weights)
        # Prune the prototypes whose class connection is now zero.
        prototypes_to_prune = self.identify_prototypes_to_prune()
        self.prune_prototypes_by_index(prototypes_to_prune)

    def prune_prototypes_by_index(self, prototypes_to_prune: list[int]) -> None:
        """
        Prunes specified prototypes from the PPNet.

        Args:
            prototypes_to_prune (list[int]): Indices of prototypes to remove; each
                must be in the range [0, num_prototypes - 1].

        Raises:
            ValueError: If any index is out of range.
        """
        if any(
            index < 0 or index >= self.num_prototypes for index in prototypes_to_prune
        ):
            raise ValueError("Provided prototype indices are out of valid range!")
        # Bookkeeping: number of prototypes remaining after this pruning step.
        self.num_prototypes_after_pruning = self.num_prototypes - len(
            prototypes_to_prune
        )
        with torch.no_grad():
            # If frequency weights are used, drive the pruned prototypes'
            # weights far negative so sigmoid(weight) ~ 0.
            if self.frequency_weights is not None:
                self.frequency_weights.data[prototypes_to_prune, :] = -7.0
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # Zero the connection only inside each class's own weight slice.
                for class_idx in range(self.last_layer.out_features):
                    indices_for_class = [
                        idx % self.last_layer.features_per_output_class
                        for idx in prototypes_to_prune
                        if self.prototype_class_identity[idx] == class_idx
                    ]
                    self.last_layer.weight.data[class_idx, indices_for_class] = 0.0
            else:
                # Standard layer: zero the pruned prototypes' columns for all classes.
                self.last_layer.weight.data[:, prototypes_to_prune] = 0.0

    def __repr__(self) -> str:
        rep = f"""PPNet(
            prototype_shape: {self.prototype_shape},
            num_classes: {self.num_classes},
            epsilon: {self.epsilon_val})"""
        return rep

    def set_last_layer_incorrect_connection(
        self, incorrect_strength: float = None
    ) -> None:
        """
        Modifies the last layer weights to have incorrect connections with a specified strength.

        If incorrect_strength is None, initializes the weights for
        LinearLayerWithoutNegativeConnections with the correct_class_connection value.

        Args:
            - incorrect_strength (Optional[float]): The strength of the incorrect connections.
              If None, initialize without incorrect connections.

        Raises:
            ValueError: If incorrect_strength is None but the last layer is not a
                LinearLayerWithoutNegativeConnections.
        """
        if incorrect_strength is None:
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # All class-slice weights start at the correct-class strength.
                self.last_layer.weight.data.fill_(self.correct_class_connection)
            else:
                raise ValueError(
                    "last_layer is not an instance of LinearLayerWithoutNegativeConnections"
                )
        else:
            # One-hot matrix marking each prototype's own class.
            positive_one_weights_locations = torch.zeros(
                self.num_classes, self.num_prototypes
            )
            positive_one_weights_locations[
                self.prototype_class_identity,
                torch.arange(self.num_prototypes),
            ] = 1
            # Complement marks the incorrect-class connections.
            negative_one_weights_locations = 1 - positive_one_weights_locations
            correct_class_connection = self.correct_class_connection
            incorrect_class_connection = incorrect_strength
            # Correct-class entries get correct_class_connection, all others
            # get incorrect_class_connection.
            self.last_layer.weight.data.copy_(
                correct_class_connection * positive_one_weights_locations
                + incorrect_class_connection * negative_one_weights_locations
            )
            if self.last_layer.bias is not None:
                self.last_layer.bias.data.fill_(self.bias_last_layer)

    def _setup_add_on_layers(self, add_on_layers_type: str):
        """
        Configures additional layers applied to the backbone features.

        Args:
            add_on_layers_type (str): Type of additional layers. Can be
                'identity' or 'upsample'.

        Raises:
            NotImplementedError: For any other add_on_layers_type.
        """
        if add_on_layers_type == "identity":
            self.add_on_layers = nn.Sequential(nn.Identity())
        elif add_on_layers_type == "upsample":
            self.add_on_layers = nn.Upsample(scale_factor=2, mode="bilinear")
        else:
            raise NotImplementedError(
                f"The add-on layer type {add_on_layers_type} isn't implemented yet."
            )
class AudioProtoNetPreTrainedModel(PreTrainedModel):
    """Base class wiring the AudioProtoNet config and weight initialization
    into the transformers `PreTrainedModel` machinery."""

    config_class = AudioProtoNetConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Called by `PreTrainedModel.post_init()` for every submodule.
        """
        # Truncated-normal weights and zero bias for conv/linear layers.
        # (The previous implementation duplicated this block verbatim.)
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # BUGFIX: the original referenced self.incorrect_class_connection and
        # self.last_layer, which do not exist on the PreTrainedModel and would
        # raise AttributeError; read from the config and fill the module's
        # own weight instead.
        if (
            isinstance(module, LinearLayerWithoutNegativeConnections)
            and self.config.incorrect_class_connection is None
        ):
            # Without incorrect-class connections, every prototype-to-class
            # weight starts at the configured correct-class strength.
            module.weight.data.fill_(self.config.correct_class_connection)
class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
    """ConvNeXt backbone wrapper that turns input spectrograms into feature maps."""

    _auto_class = "AutoModel"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)
        # NOTE(review): this fetches the ConvNeXt configuration from the
        # HuggingFace Hub at construction time (network access required on
        # first use). num_channels=1 adapts the backbone to single-channel
        # (mono) spectrogram input.
        backbone_config = ConvNextConfig.from_pretrained(
            "facebook/convnext-base-224-22k", num_channels=1
        )
        self.backbone = ConvNextModel(backbone_config)

    def forward(
        self,
        input_values: torch.Tensor,
        output_hidden_states: bool = None
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Args:
            input_values: Input spectrogram tensor of shape
                (batch_size, 1, height, width).
            output_hidden_states: Whether the backbone should also return its
                intermediate hidden states.
        Returns:
            A `BaseModelOutputWithPoolingAndNoAttention` with:
                last_hidden_state: final backbone feature map
                pooler_output: pooled features
                hidden_states: intermediate states, when requested
        """
        # Pass by keyword so the flag can never be mistaken for the backbone's
        # `return_dict` argument (the original passed it positionally).
        return self.backbone(input_values, output_hidden_states=output_hidden_states)
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
    """AudioProtoNet with a prototype-based multi-label classification head."""

    _auto_class = "AutoModelForSequenceClassification"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)
        self.model = AudioProtoNetModel(config)
        self.head = AudioProtoNetClassificationHead(config)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: torch.Tensor = None,
        prototypes_of_wrong_class: torch.Tensor = None,
        output_hidden_states: bool = None,
        output_prototypical_activations: bool = None,
    ) -> SequenceClassifierOutputWithProtoTypeActivations:
        """
        Args:
            input_values: Input spectrogram tensor.
            labels: Multi-label binarized targets; when given, the asymmetric
                loss is computed.
            prototypes_of_wrong_class: Wrong-class prototype mask used by the
                head's subtractive margin during training.
            output_hidden_states: If truthy, include the backbone hidden states
                in the output.
            output_prototypical_activations: If truthy, include the marginless
                prototype activation maps in the output.
        Returns:
            SequenceClassifierOutputWithProtoTypeActivations with logits,
            optional loss, last hidden state, and the requested extras.
        """
        backbone_outputs = self.model(input_values, output_hidden_states)
        last_hidden_state = backbone_outputs[0]
        logits, info = self.head(last_hidden_state, prototypes_of_wrong_class)
        loss = None
        if labels is not None:
            # BUGFIX: Tensor.to() is not in-place — the returned tensor must be
            # re-assigned, otherwise labels stay on their original device.
            labels = labels.to(logits.device)
            loss_fct = AsymmetricLossMultiLabel()
            loss = loss_fct(logits, labels.float())
        hidden_states = None
        # Truthiness check: an explicit False must not trigger indexing into
        # hidden states that the backbone never returned.
        if output_hidden_states:
            hidden_states = backbone_outputs[2]
        prototype_activations = None
        if output_prototypical_activations:
            # info[4] holds the marginless activation maps from the head.
            prototype_activations = info[4]
        return SequenceClassifierOutputWithProtoTypeActivations(
            logits=logits,
            loss=loss,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            prototype_activations=prototype_activations
        )