File size: 37,661 Bytes
6e22215 aab4958 6e22215 9877d4a 6e22215 9877d4a 6e22215 9877d4a 6e22215 9877d4a 6e22215 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 
505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 | import torch
import torch.nn as nn
from torch import Tensor
import math
from transformers import PreTrainedModel, ConvNextModel, ConvNextConfig
from transformers.utils import logging
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPoolingAndNoAttention
from dataclasses import dataclass
from .configuration_protonet import AudioProtoNetConfig
logger = logging.get_logger(__name__)
@dataclass
class SequenceClassifierOutputWithProtoTypeActivations(ModelOutput):
    """Model output for AudioProtoNetForSequenceClassification.

    Extends HuggingFace's ``ModelOutput`` with the prototype activation maps
    produced by the ProtoPNet-style classification head.
    """

    # Classification logits, shape (batch_size, num_classes).
    logits: torch.Tensor
    # Optional multi-label classification loss (set when labels are provided).
    loss: torch.Tensor = None
    # Backbone feature map that was fed to the prototype head.
    last_hidden_state: torch.FloatTensor = None
    # Backbone hidden states, populated only when requested by the caller.
    hidden_states: tuple[torch.FloatTensor, ...] = None
    # Marginless prototype activation maps, populated only when requested.
    prototype_activations: torch.FloatTensor = None
# https://openaccess.thecvf.com/content/ICCV2021/papers/Ridnik_Asymmetric_Loss_for_Multi-Label_Classification_ICCV_2021_paper.pdf
# https://github.com/huggingface/pytorch-image-models/blob/bbe798317fb26f063c18279827c038058e376479/timm/loss/asymmetric_loss.py#L6
class AsymmetricLossMultiLabel(nn.Module):
    """Asymmetric loss (ASL) for multi-label classification.

    Applies separate focusing parameters to positive and negative targets and
    optionally shifts ("clips") the negative probabilities, following
    Ridnik et al., ICCV 2021.

    Args:
        gamma_neg: Focusing exponent applied to negative targets.
        gamma_pos: Focusing exponent applied to positive targets.
        clip: Probability margin added to negative probabilities
            (asymmetric clipping). Disabled when ``None`` or ``0``.
        eps: Numerical-stability floor used inside ``log``.
        disable_torch_grad_focal_loss: If True, compute the focal weighting
            factor with gradient tracking disabled.
        reduction: ``"mean"``, ``"sum"``, or any other value for no reduction.
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=1,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=False,
        reduction="mean",
    ):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """Compute the asymmetric loss.

        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid
        # Asymmetric Clipping: shift negative probabilities up by `clip`
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)
        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg
        # Asymmetric Focusing: down-weight easy examples, with a stronger
        # exponent on the (dominant) negatives.
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Use the public API rather than the private torch._C binding.
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w
        if self.reduction == "mean":
            return -loss.mean()
        if self.reduction == "sum":
            return -loss.sum()
        return -loss
class NonNegativeLinear(nn.Module):
    """
    A PyTorch module for a linear layer with non-negative weights.
    This module applies a linear transformation to the incoming data: `y = xA^T + b`.
    The weights of the transformation are constrained to be non-negative (via a
    ReLU applied at forward time), making this module particularly useful in
    models where negative weights may not be appropriate.
    Attributes:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        weight (torch.Tensor): The weight parameter of the module, constrained to be non-negative.
        bias (torch.Tensor, optional): The bias parameter of the module.
    Args:
        in_features (int): The number of features in the input tensor.
        out_features (int): The number of features in the output tensor.
        bias (bool, optional): If True, the layer will include a learnable bias. Default: True.
        device (optional): The device (CPU/GPU) on which to perform computations.
        dtype (optional): The data type for the parameters (e.g., float32).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # Fix: torch.empty leaves the parameters as uninitialized memory; the
        # original never initialized them (unlike the sibling
        # LinearLayerWithoutNegativeConnections, which calls reset_parameters).
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Initialize the parameters following torch.nn.Linear's scheme:
        Kaiming-uniform weights and uniformly distributed biases.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the NonNegativeLinear module.
        Args:
            input (torch.Tensor): The input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: The output tensor of shape (batch_size, out_features).
        """
        # ReLU clamps negative weights to zero, enforcing the non-negativity
        # constraint without altering the stored parameter.
        return nn.functional.linear(input, torch.relu(self.weight), self.bias)
class LinearLayerWithoutNegativeConnections(nn.Module):
r"""
Custom Linear Layer where each output class is connected to a specific subset of input features.
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
device: the device of the module parameters. Default: ``None``
dtype: the data type of the module parameters. Default: ``None``
Shape:
- Input: :math:`(*, H_{in})` where :math:`*` means any number of
dimensions including none and :math:`H_{in} = \text{in_features}`.
- Output: :math:`(*, H_{out})` where all but the last dimension
are the same shape as the input and :math:`H_{out} = \text{out_features}`.
Attributes:
weight: the learnable weights of the module of shape
:math:`(\text{out_features}, \text{features_per_output_class})`.
bias: the learnable bias of the module of shape :math:`(\text{out_features})`.
If :attr:`bias` is ``True``, the values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{1}{\text{features_per_output_class}}`
"""
__constants__ = ["in_features", "out_features", "bias"]
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
non_negative: bool = True,
device: torch.device = None,
dtype: torch.dtype = None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.non_negative = non_negative
# Calculate the number of features per output class
self.features_per_output_class = in_features // out_features
# Ensure input size is divisible by the output size
assert (
in_features % out_features == 0
), f"{in_features = } must be divisible by {out_features = }"
# Define weights and biases
self.weight = nn.Parameter(
torch.empty(
(out_features, self.features_per_output_class), **factory_kwargs
)
)
if bias:
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter("bias", None)
# Initialize weights and biases
self.reset_parameters()
def reset_parameters(self) -> None:
"""
Initialize the weights and biases.
Weights are initialized using Kaiming uniform initialization.
Biases are initialized using a uniform distribution.
"""
# Kaiming uniform initialization for the weights
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
# Calculate fan-in and fan-out values
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
# Uniform initialization for the biases
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.
Args:
input (Tensor): Input tensor of shape (batch_size, in_features).
Returns:
Tensor: Output tensor of shape (batch_size, out_features).
"""
batch_size = input.size(0)
# Reshape input to (batch_size, out_features, features_per_output_class)
reshaped_input = input.view(
batch_size, self.out_features, self.features_per_output_class
)
# Apply ReLU to weights if non_negative_last_layer is True
weight = torch.relu(self.weight) if self.non_negative else self.weight
# Perform batch matrix multiplication and add bias
output = torch.einsum("bof,of->bo", reshaped_input, weight)
if self.bias is not None:
output += self.bias
return output
def extra_repr(self) -> str:
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
class AudioProtoNetClassificationHead(nn.Module):
    """Prototypical-part (ProtoPNet-style) classification head.

    Compares backbone feature maps against learned prototype vectors via a
    cosine-similarity convolution, pools the resulting activation maps
    (top-k during training, global max during evaluation), and maps the pooled
    activations to class logits through a final linear layer that can be
    constrained to non-negative and/or class-exclusive connections.
    """

    def __init__(
        self,
        config: AudioProtoNetConfig,
    ) -> None:
        """
        PPNet is a class that implements the Prototypical Part Network (ProtoPNet) for prototype-based classification.
        """
        super().__init__()
        self.prototypes_per_class = config.prototypes_per_class
        self.num_classes = config.num_classes
        self.num_prototypes = self.prototypes_per_class * self.num_classes
        self.num_prototypes_after_pruning = config.num_prototypes_after_pruning
        self.margin = config.margin
        self.relu_on_cos = config.relu_on_cos
        self.incorrect_class_connection = config.incorrect_class_connection
        self.correct_class_connection = config.correct_class_connection
        self.input_vector_length = config.input_vector_length
        self.n_eps_channels = config.n_eps_channels
        self.epsilon_val = config.epsilon_val
        self.topk_k = config.topk_k
        self.bias_last_layer = config.bias_last_layer
        self.non_negative_last_layer = config.non_negative_last_layer
        self.embedded_spectrogram_height = config.embedded_spectrogram_height
        self.use_bias_last_layer = config.use_bias_last_layer
        # 1D tensor mapping each prototype index to its class index.
        # (A previous assignment from config.prototype_class_identity was dead
        # code — it was immediately overwritten here — and has been removed.)
        self.prototype_class_identity = (
            torch.arange(self.num_prototypes) // self.prototypes_per_class
        )
        self.prototype_shape = (self.num_prototypes, config.channels, config.height, config.width)
        self._setup_add_on_layers(add_on_layers_type=config.add_on_layers_type)
        self.prototype_vectors = nn.Parameter(
            torch.rand(self.prototype_shape), requires_grad=True
        )
        self.frequency_weights = None
        if self.embedded_spectrogram_height is not None:
            # Initialize the frequency weights with a large positive value of 3.0 so that sigmoid(frequency_weights) is close to 1.
            self.frequency_weights = nn.Parameter(
                torch.full(
                    (
                        self.num_prototypes,
                        self.embedded_spectrogram_height,
                    ),
                    3.0,
                )
            )
        if self.incorrect_class_connection:
            if self.non_negative_last_layer:
                self.last_layer = NonNegativeLinear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
            else:
                self.last_layer = nn.Linear(
                    self.num_prototypes, self.num_classes, bias=self.use_bias_last_layer
                )
        else:
            # Without incorrect-class connections, each class only sees its
            # own prototypes.
            self.last_layer = LinearLayerWithoutNegativeConnections(
                in_features=self.num_prototypes,
                out_features=self.num_classes,
                non_negative=self.non_negative_last_layer,
            )

    def forward(
        self,
        features: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Forward pass of the PPNet head.

        Args:
        - features (torch.Tensor): Backbone feature tensor with shape
          (batch_size, num_channels, height, width).
        - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
          when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
            - logits: A tensor containing the logits for each class in the model.
            - a list containing:
                - mean_activations: A tensor containing the mean of the top-k prototype activations
                  (in evaluation mode k is always 1).
                - marginless_logits: A tensor containing the logits for each class in the model, calculated using the
                  marginless activations.
                - conv_features: A tensor containing the convolutional features.
                - marginless_max_activations: A tensor containing the max-pooled marginless activations.
                - marginless_activations: The full marginless activation maps.
        """
        features = self.add_on_layers(features)
        activations, additional_returns = self.prototype_activations(
            features, prototypes_of_wrong_class=prototypes_of_wrong_class
        )
        marginless_activations = additional_returns[0]
        conv_features = additional_returns[1]
        # Set topk_k based on training mode: use the configured value while
        # training, else 1 (global max pooling) for evaluation.
        # Fix: the original hard-coded 1 in both modes, contradicting this
        # comment and leaving config.topk_k unused.
        topk_k = self.topk_k if self.training else 1
        # Reshape activations to combine spatial dimensions: (batch_size, num_prototypes, height*width)
        activations = activations.view(activations.shape[0], activations.shape[1], -1)
        # Perform top-k pooling along the combined spatial dimension.
        # For topk_k=1, this is equivalent to global max pooling.
        topk_activations, _ = torch.topk(activations, topk_k, dim=-1)
        # Mean of the top-k activations per prototype: (batch_size, num_prototypes).
        # If topk_k=1, the mean is a no-op since there's only one value.
        mean_activations = torch.mean(topk_activations, dim=-1)
        marginless_max_activations = nn.functional.max_pool2d(
            marginless_activations,
            kernel_size=(
                marginless_activations.size()[2],
                marginless_activations.size()[3],
            ),
        )
        marginless_max_activations = marginless_max_activations.view(
            -1, self.num_prototypes
        )
        logits = self.last_layer(mean_activations)
        marginless_logits = self.last_layer(marginless_max_activations)
        return logits, [
            mean_activations,
            marginless_logits,
            conv_features,
            marginless_max_activations,
            marginless_activations,
        ]

    def cos_activation(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the cosine activation between input tensor x and prototype vectors.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor with shape (batch_size, num_channels, height, width).
        prototypes_of_wrong_class : Optional[torch.Tensor]
            Tensor containing the prototypes of the wrong class with shape (batch_size, num_prototypes).

        Returns:
        --------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing:
            - activations: The cosine activations with potential margin adjustments.
            - marginless_activations: The cosine activations without margin adjustments.
        """
        input_vector_length = self.input_vector_length
        # Normalization by sqrt(h*w) so the conv over the spatial extent stays
        # a proper cosine similarity.
        normalizing_factor = (
            self.prototype_shape[-2] * self.prototype_shape[-1]
        ) ** 0.5
        # Append constant epsilon channels so zero feature vectors still have
        # non-zero length (numerical stability of the normalization).
        epsilon_channel_x = torch.full(
            (x.shape[0], self.n_eps_channels, x.shape[2], x.shape[3]),
            self.epsilon_val,
            device=x.device,
            requires_grad=False,
        )
        x = torch.cat((x, epsilon_channel_x), dim=-3)
        # Normalize x to length `input_vector_length`
        x_length = torch.sqrt(torch.sum(x**2, dim=-3, keepdim=True) + self.epsilon_val)
        x_normalized = (input_vector_length * x / x_length) / normalizing_factor
        # Same epsilon-channel treatment for the prototypes
        epsilon_channel_p = torch.full(
            (
                self.prototype_shape[0],
                self.n_eps_channels,
                self.prototype_shape[2],
                self.prototype_shape[3],
            ),
            self.epsilon_val,
            device=self.prototype_vectors.device,
            requires_grad=False,
        )
        appended_protos = torch.cat((self.prototype_vectors, epsilon_channel_p), dim=-3)
        # Normalize prototypes to unit length
        prototype_vector_length = torch.sqrt(
            torch.sum(appended_protos**2, dim=-3, keepdim=True) + self.epsilon_val
        )
        normalized_prototypes = appended_protos / (
            prototype_vector_length + self.epsilon_val
        )
        normalized_prototypes /= normalizing_factor
        # Cosine similarity as a convolution of normalized inputs with
        # normalized prototypes; 1.01 keeps |cos| strictly below 1 so acos
        # below is well-defined.
        activations_dot = nn.functional.conv2d(x_normalized, normalized_prototypes)
        marginless_activations = activations_dot / (input_vector_length * 1.01)
        if self.frequency_weights is not None:
            # Apply sigmoid to frequency weights s.t. weights are between 0 and 1.
            freq_weights = torch.sigmoid(self.frequency_weights)
            # Multiply each prototype's frequency response by the corresponding weights
            marginless_activations = marginless_activations * freq_weights[:, :, None]
        if (
            self.margin is None
            or not self.training
            or prototypes_of_wrong_class is None
        ):
            activations = marginless_activations
        else:
            # Subtractive margin: penalize wrong-class prototypes in angle space
            wrong_class_margin = (prototypes_of_wrong_class * self.margin).view(
                x.size(0), self.prototype_vectors.size(0), 1, 1
            )
            wrong_class_margin = wrong_class_margin.expand(
                -1, -1, activations_dot.size(-2), activations_dot.size(-1)
            )
            penalized_angles = (
                torch.acos(activations_dot / (input_vector_length * 1.01))
                - wrong_class_margin
            )
            activations = torch.cos(torch.relu(penalized_angles))
        if self.relu_on_cos:
            # Keep only positive cosine similarities
            activations = torch.relu(activations)
            marginless_activations = torch.relu(marginless_activations)
        return activations, marginless_activations

    def prototype_activations(
        self,
        x: torch.Tensor,
        prototypes_of_wrong_class: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Compute the prototype activations for a given input tensor.

        Args:
        - x (torch.Tensor): Feature tensor with shape (batch_size, num_channels, height, width).
        - prototypes_of_wrong_class (Optional[torch.Tensor]): The prototypes of the wrong classes that are needed
          when using subtractive margins. Defaults to None.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]:
            - activations: A tensor containing the prototype activations.
            - a list containing:
                - marginless_activations: A tensor containing the activations before applying subtractive margin.
                - conv_features: A tensor containing the convolutional features (the input x unchanged).
        """
        activations, marginless_activations = self.cos_activation(
            x,
            prototypes_of_wrong_class=prototypes_of_wrong_class,
        )
        return activations, [marginless_activations, x]

    def get_prototype_orthogonalities(self, use_part_prototypes: bool = False) -> torch.Tensor:
        """
        Computes the orthogonality loss, encouraging each piece of a prototype to be orthogonal to the others.

        This method is inspired by the paper:
        https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Interpretable_Image_Recognition_by_Constructing_Transparent_Embedding_Space_ICCV_2021_paper.pdf

        Args:
            use_part_prototypes (bool): If True, treats each spatial part of the prototypes as a separate prototype.

        Returns:
            torch.Tensor: A tensor representing the orthogonalities
            (per-class Gram matrix minus identity).
        """
        # Fix: the original referenced self.num_prototypes_per_class, an
        # attribute that is never set (__init__ defines prototypes_per_class),
        # which raised AttributeError at runtime.
        if use_part_prototypes:
            # Normalize prototypes to unit length along the channel dimension
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(self.prototype_vectors), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = self.prototype_vectors / (
                prototype_vector_length + self.epsilon_val
            )
            # Calculate total part prototypes per class
            num_part_prototypes_per_class = (
                self.prototypes_per_class
                * self.prototype_shape[2]
                * self.prototype_shape[3]
            )
            # Reshape to match class structure
            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.prototypes_per_class,
                self.prototype_shape[1],
                self.prototype_shape[2] * self.prototype_shape[3],
            )
            # Transpose and reshape to treat each spatial part as a separate prototype
            normalized_prototypes = normalized_prototypes.permute(0, 1, 3, 2).reshape(
                self.num_classes, num_part_prototypes_per_class, self.prototype_shape[1]
            )
        else:
            # Normalize whole prototypes to unit length
            prototype_vectors_reshaped = self.prototype_vectors.view(
                self.num_prototypes, -1
            )
            prototype_vector_length = torch.sqrt(
                torch.sum(torch.square(prototype_vectors_reshaped), dim=1, keepdim=True)
                + self.epsilon_val
            )
            normalized_prototypes = prototype_vectors_reshaped / (
                prototype_vector_length + self.epsilon_val
            )
            # Reshape to match class structure
            normalized_prototypes = normalized_prototypes.view(
                self.num_classes,
                self.prototypes_per_class,
                self.prototype_shape[1]
                * self.prototype_shape[2]
                * self.prototype_shape[3],
            )
        # Compute orthogonality (Gram) matrix for each class
        orthogonalities = torch.matmul(
            normalized_prototypes, normalized_prototypes.transpose(1, 2)
        )
        # Identity matrix to enforce orthogonality
        identity_matrix = (
            torch.eye(normalized_prototypes.shape[1], device=orthogonalities.device)
            .unsqueeze(0)
            .repeat(self.num_classes, 1, 1)
        )
        # Subtract identity to focus on orthogonality
        orthogonalities = orthogonalities - identity_matrix
        return orthogonalities

    def identify_prototypes_to_prune(self) -> list[int]:
        """
        Identifies the indices of prototypes that should be pruned.

        This function iterates through the prototypes and checks if the specific weight
        connecting the prototype to its class is zero. It is specifically designed to handle
        the LinearLayerWithoutNegativeConnections where each class has a subset of features
        it connects to.

        Returns:
            list[int]: A list of prototype indices that should be pruned.
        """
        prototypes_to_prune = []
        # Number of prototypes assigned to each class
        prototypes_per_class = self.num_prototypes // self.num_classes
        if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
            # Custom layer mapping prototypes to a subset of input features for each output class
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                # Position of this prototype within its class's weight slice
                index_within_class = prototype_index % prototypes_per_class
                # Zero weight means the prototype no longer contributes to its class
                if self.last_layer.weight[class_index, index_within_class] == 0.0:
                    prototypes_to_prune.append(prototype_index)
        else:
            # Standard linear layer: each prototype directly maps to a feature index
            weights_to_check = self.last_layer.weight
            for prototype_index in range(self.num_prototypes):
                class_index = self.prototype_class_identity[prototype_index]
                if weights_to_check[class_index, prototype_index] == 0.0:
                    prototypes_to_prune.append(prototype_index)
        return prototypes_to_prune

    def prune_prototypes_by_threshold(self, threshold: float = 1e-3) -> None:
        """
        Prune the weights in the classification layer by setting weights below a specified threshold to zero.

        This method modifies the weights of the last layer of the model in-place. Weights falling below the
        threshold are set to zero, diminishing their influence in the model's decisions. It also identifies
        and prunes prototypes based on these updated weights, thereby refining the model's structure.

        Args:
            threshold (float): The threshold value below which weights will be set to zero. Defaults to 1e-3.
        """
        # Access the weights of the last layer
        weights = self.last_layer.weight.data
        # Set weights below the threshold to zero
        weights[weights < threshold] = 0.0
        # Update the weights in the last layer to reflect the pruning
        self.last_layer.weight.data.copy_(weights)
        # Identify prototypes that need to be pruned based on the updated weights
        prototypes_to_prune = self.identify_prototypes_to_prune()
        # Execute the pruning of identified prototypes
        self.prune_prototypes_by_index(prototypes_to_prune)

    def prune_prototypes_by_index(self, prototypes_to_prune: list[int]) -> None:
        """
        Prunes specified prototypes from the PPNet.

        Args:
            prototypes_to_prune (list[int]): A list of indices indicating the prototypes to be removed.
                Each index should be in the range [0, current number of prototypes - 1].

        Returns:
            None

        Raises:
            ValueError: If any provided index is outside the valid range.
        """
        # Validate the provided indices to ensure they are within the valid range
        if any(
            index < 0 or index >= self.num_prototypes for index in prototypes_to_prune
        ):
            raise ValueError("Provided prototype indices are out of valid range!")
        # Calculate the new number of prototypes after pruning
        self.num_prototypes_after_pruning = self.num_prototypes - len(
            prototypes_to_prune
        )
        with torch.no_grad():
            # If frequency_weights are used, drive the pruned prototypes'
            # weights to -7 so sigmoid(-7) ~ 0 silences them.
            if self.frequency_weights is not None:
                self.frequency_weights.data[prototypes_to_prune, :] = -7.0
            # Adjust the weights in the last layer depending on its type
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # Zero only the per-class slice positions belonging to pruned prototypes
                for class_idx in range(self.last_layer.out_features):
                    # Identify pruned prototypes belonging to the current class
                    indices_for_class = [
                        idx % self.last_layer.features_per_output_class
                        for idx in prototypes_to_prune
                        if self.prototype_class_identity[idx] == class_idx
                    ]
                    self.last_layer.weight.data[class_idx, indices_for_class] = 0.0
            else:
                # For other layer types, zero the full column of each pruned prototype
                self.last_layer.weight.data[:, prototypes_to_prune] = 0.0

    def __repr__(self) -> str:
        rep = f"""PPNet(
\tprototype_shape: {self.prototype_shape},
\tnum_classes: {self.num_classes},
\tepsilon: {self.epsilon_val})"""
        return rep

    def set_last_layer_incorrect_connection(
        self, incorrect_strength: float = None
    ) -> None:
        """
        Modifies the last layer weights to have incorrect connections with a specified strength.

        If incorrect_strength is None, initializes the weights for LinearLayerWithoutNegativeConnections
        with correct_class_connection value.

        Args:
        - incorrect_strength (Optional[float]): The strength of the incorrect connections.
          If None, initialize without incorrect connections.

        Returns:
            None
        """
        if incorrect_strength is None:
            # Handle LinearLayerWithoutNegativeConnections initialization
            if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections):
                # Initialize all weights to the correct_class_connection value
                self.last_layer.weight.data.fill_(self.correct_class_connection)
            else:
                raise ValueError(
                    "last_layer is not an instance of LinearLayerWithoutNegativeConnections"
                )
        else:
            # One-hot matrix marking each prototype's own class
            positive_one_weights_locations = torch.zeros(
                self.num_classes, self.num_prototypes
            )
            positive_one_weights_locations[
                self.prototype_class_identity,
                torch.arange(self.num_prototypes),
            ] = 1
            # Complement marks the incorrect-class connections
            negative_one_weights_locations = 1 - positive_one_weights_locations
            # Connection strength for a prototype's own class
            correct_class_connection = self.correct_class_connection
            # Connection strength toward all other classes
            incorrect_class_connection = incorrect_strength
            # Blend both strengths into the last-layer weight matrix
            self.last_layer.weight.data.copy_(
                correct_class_connection * positive_one_weights_locations
                + incorrect_class_connection * negative_one_weights_locations
            )
            if self.last_layer.bias is not None:
                # Initialize all biases to bias_last_layer value
                self.last_layer.bias.data.fill_(self.bias_last_layer)

    def _setup_add_on_layers(self, add_on_layers_type: str):
        """
        Configures additional layers based on the backbone model architecture and the specified add_on_layers_type.

        Args:
            add_on_layers_type (str): Type of additional layers to add. Can be 'identity' or 'upsample'.

        Raises:
            NotImplementedError: For any other add_on_layers_type.
        """
        if add_on_layers_type == "identity":
            self.add_on_layers = nn.Sequential(nn.Identity())
        elif add_on_layers_type == "upsample":
            self.add_on_layers = nn.Upsample(scale_factor=2, mode="bilinear")
        else:
            raise NotImplementedError(
                f"The add-on layer type {add_on_layers_type} isn't implemented yet."
            )
class AudioProtoNetPreTrainedModel(PreTrainedModel):
    """Base class wiring AudioProtoNetConfig into HuggingFace's PreTrainedModel
    machinery (weight init, save/load)."""

    config_class = AudioProtoNetConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Called by HF's `post_init`/`init_weights` for every submodule.
        Fixes two defects in the original: the conv/linear branch was
        duplicated verbatim, and the last branch read
        `self.incorrect_class_connection` / `self.last_layer` — attributes
        that do not exist on the model (AttributeError); the values live on
        `self.config` and the target layer is `module`.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Truncated-normal weights, zero bias (ConvNext-style init).
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, LinearLayerWithoutNegativeConnections):
            # Mirrors AudioProtoNetClassificationHead.set_last_layer_incorrect_connection
            # for the incorrect_strength=None case: start every
            # prototype-to-own-class weight at the correct-class strength.
            if self.config.incorrect_class_connection is None:
                module.weight.data.fill_(self.config.correct_class_connection)
class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
    """ConvNext backbone wrapper that extracts features from single-channel
    (spectrogram) inputs."""

    # Registers this class for AutoModel save/load round-trips.
    _auto_class = "AutoModel"

    def __init__(self, config: AudioProtoNetConfig):
        # NOTE(review): the backbone config is fetched from the
        # "facebook/convnext-base-224-22k" hub checkpoint at construction time
        # (requires network access) rather than being derived from `config`;
        # only num_channels is overridden. Confirm this is intentional.
        super().__init__(config)
        backbone_config = ConvNextConfig.from_pretrained("facebook/convnext-base-224-22k", num_channels=1)
        self.backbone = ConvNextModel(backbone_config)

    def forward(
        self,
        input_values: torch.Tensor,
        output_hidden_states: bool = None
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        """
        Run the ConvNext backbone on the input spectrogram batch.

        Args:
            input_values: Single-channel input tensor passed to the backbone
                as its pixel values.
            output_hidden_states: Forwarded positionally to the backbone;
                when truthy, the backbone returns its per-stage hidden states.
        Returns:
            last_hidden_state: torch.FloatTensor = None
            pooler_output: torch.FloatTensor = None
            hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
        """
        return self.backbone(input_values, output_hidden_states)
class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
    """ConvNext backbone + ProtoPNet head for multi-label audio classification."""

    # Registers this class for AutoModelForSequenceClassification round-trips.
    _auto_class = "AutoModelForSequenceClassification"

    def __init__(self, config: AudioProtoNetConfig):
        super().__init__(config)
        self.model = AudioProtoNetModel(config)
        self.head = AudioProtoNetClassificationHead(config)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: torch.Tensor = None,
        prototypes_of_wrong_class: torch.Tensor = None,
        output_hidden_states: bool = None,
        output_prototypical_activations: bool = None,
    ) -> SequenceClassifierOutputWithProtoTypeActivations:
        """
        Run the backbone and prototype head; optionally compute the loss.

        Args:
            input_values: Input spectrogram tensor fed to the backbone.
            labels: Multi-label binarized targets; when given, the asymmetric
                multi-label loss is computed on the logits.
            prototypes_of_wrong_class: Wrong-class prototype mask forwarded to
                the head for subtractive-margin training.
            output_hidden_states: When truthy, backbone hidden states are
                included in the output.
            output_prototypical_activations: When truthy, the marginless
                prototype activation maps are included in the output.
        """
        backbone_outputs = self.model(input_values, output_hidden_states)
        last_hidden_state = backbone_outputs[0]
        logits, info = self.head(last_hidden_state, prototypes_of_wrong_class)
        loss = None
        if labels is not None:
            # Fix: Tensor.to() is not in-place — the original discarded the
            # result, leaving labels on their original device.
            labels = labels.to(logits.device)
            loss_fct = AsymmetricLossMultiLabel()
            loss = loss_fct(logits, labels.float())
        hidden_states = None
        # Fix: check truthiness, not `is not None` — an explicit False used to
        # trigger indexing hidden states the backbone never produced.
        if output_hidden_states:
            # Index 2 of the backbone output holds the hidden states.
            hidden_states = backbone_outputs[2]
        prototype_activations = None
        if output_prototypical_activations:
            # info[4] is the marginless activation map returned by the head.
            prototype_activations = info[4]
        return SequenceClassifierOutputWithProtoTypeActivations(
            logits=logits,
            loss=loss,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            prototype_activations=prototype_activations,
        )
|