import torch import torch.nn as nn from torchvision import models class ScoresLayer(nn.Module): def __init__(self, input_dim, num_centers): super().__init__() self.input_dim = input_dim self.num_centers = num_centers self.centers = nn.Parameter(torch.zeros(num_centers, input_dim), requires_grad=True) self.logsigmas = nn.Parameter(torch.zeros(num_centers), requires_grad=True) def forward(self, x): batch_size = x.size(0) out = x.view(batch_size, self.input_dim, 1, 1) # [batch, C, 1, 1] centers = self.centers[None, :, :, None, None] # [1, K, C, 1, 1] diff = out.unsqueeze(1) - centers # [batch, K, C, 1, 1] sum_diff = torch.sum(diff, dim=2) # [batch, K, 1, 1] sign = torch.sign(sum_diff) squared_diff = torch.sum(diff ** 2, dim=2) # [batch, K, 1, 1] logsigmas = nn.functional.relu(self.logsigmas) denominator = 2 * torch.exp(2 * logsigmas) part1 = (sign * squared_diff) / denominator.view(1, -1, 1, 1) part2 = self.input_dim * logsigmas part2 = part2.view(1, -1, 1, 1) scores = part1 + part2 output = scores.sum(dim=(1, 2, 3)).view(-1, 1) # [batch, 1] return output class ImageClassifier(nn.Module): def __init__(self, settings): super().__init__() if settings.arch == 'baseline': self.backbone = models.resnet50(weights=None) self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 1) elif settings.arch == 'nodown': self.backbone = models.resnet50(weights=None) # Replace first conv layer to avoid downsampling new_conv = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False) new_conv.weight = nn.Parameter(self.backbone.conv1.weight) self.backbone.conv1 = new_conv self.backbone.fc = nn.Sequential(nn.Linear(self.backbone.fc.in_features, 128), nn.Dropout(0.5)) else: raise NotImplementedError('Model not recognized') if settings.freeze: for param in self.backbone.parameters(): param.requires_grad = False for param in self.backbone.fc.parameters(): param.requires_grad = True else: for param in self.backbone.parameters(): param.requires_grad = True self.prototype = settings.prototype if self.prototype: self.proto = ScoresLayer(input_dim=self.backbone.fc[0].out_features, num_centers=settings.num_centers) for param in self.proto.parameters(): param.requires_grad = True def forward(self, x): x = self.backbone(x) if self.prototype: x = self.proto(x) return x