import torch
import torch.nn as nn
import torch.nn.functional as F

import imagenet_resnet  # Importing ResNet-18 from imagenet_resnet.py


# Feature Extraction Network
class NetFeat(nn.Module):
    """Image backbone that returns one flat feature vector per input.

    Only the ``resnet18`` architecture on the ``Clothing1M`` dataset is
    supported. Any other combination raises ``ValueError`` at construction
    time, rather than producing a module that lacks ``feat_net`` and fails
    later in ``forward``.

    Args:
        arch: Backbone name; only ``'resnet18'`` is supported.
        pretrained: Forwarded unchanged to ``imagenet_resnet.resnet18``.
        dataset: Dataset name; only ``'Clothing1M'`` is supported.

    Attributes:
        feat_dim: Length of the feature vectors produced by ``forward``.
        feat_net: ResNet trunk (conv1 .. layer4) followed by global
            average pooling to a 1x1 spatial map.
    """

    # Backbone sub-modules kept, in order. Any other attribute of the
    # torchvision-style ResNet (e.g. its final classifier) is left out.
    _RESNET_FEATURE_LAYERS = (
        'conv1', 'bn1', 'relu', 'maxpool',
        'layer1', 'layer2', 'layer3', 'layer4',
    )

    def __init__(self, arch, pretrained, dataset):
        super(NetFeat, self).__init__()
        if dataset != 'Clothing1M':
            raise ValueError(
                f"Unsupported dataset {dataset!r}; only 'Clothing1M' is supported")
        if arch != 'resnet18':
            raise ValueError(
                f"Unsupported arch {arch!r} for 'Clothing1M'; only 'resnet18' is supported")

        net = imagenet_resnet.resnet18(pretrained=pretrained)
        self.feat_dim = 512  # ResNet-18 output feature dimension
        trunk = [getattr(net, name) for name in self._RESNET_FEATURE_LAYERS]
        # Global average pooling collapses the spatial dims to 1x1 so that
        # forward() can flatten the result to (batch, channels).
        self.feat_net = nn.Sequential(*trunk, nn.AdaptiveAvgPool2d((1, 1)))

    def forward(self, x):
        """Compute pooled backbone features.

        Args:
            x: Image batch, presumably ``(N, 3, H, W)`` as expected by the
                ResNet stem -- TODO confirm against the data pipeline.

        Returns:
            Tensor of shape ``(N, C)``: the ``(N, C, 1, 1)`` pooled map
            flattened from dim 1 onward.
        """
        x = self.feat_net(x)
        return torch.flatten(x, 1)


# Classification Layer
class NetClassifier(nn.Module):
    """Bias-free linear classification head.

    The weight is stored as a ``(feat_dim, nb_cls)`` matrix (the transpose of
    ``nn.Linear``'s layout), so scores are computed as ``feature @ weight``.

    Args:
        feat_dim: Length of the input feature vectors.
        nb_cls: Number of output classes.
    """

    def __init__(self, feat_dim, nb_cls):
        super(NetClassifier, self).__init__()
        # Reuse nn.Linear's default initialisation, but store the transposed
        # weight as an independent, contiguous leaf tensor rather than a
        # strided view aliasing the temporary layer's storage.
        init_weight = nn.Linear(feat_dim, nb_cls, bias=False).weight.detach()
        self.weight = nn.Parameter(init_weight.t().contiguous(), requires_grad=True)  # feat_dim * nb_cls

    def forward(self, feature):
        """Return class scores ``feature @ weight``.

        Args:
            feature: Tensor whose last dimension is ``feat_dim``; typically
                ``(batch, feat_dim)``. Extra leading dimensions are handled by
                ``torch.matmul`` broadcasting.

        Returns:
            Scores with the last dimension replaced by ``nb_cls``.
        """
        clsScore = torch.matmul(feature, self.weight)
        return clsScore