Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import imagenet_resnet # Importing ResNet-18 from imagenet_resnet.py | |
# Feature Extraction Network
class NetFeat(nn.Module):
    """ResNet trunk that maps a batch of images to one flat feature vector each.

    The stages ``conv1 .. layer4`` of the backbone are kept, and the
    backbone's remaining attributes (e.g. its classification head) are not
    used. A global ``AdaptiveAvgPool2d((1, 1))`` is appended so ``forward``
    returns a ``(batch, feat_dim)`` tensor.

    Args:
        arch: backbone name; only ``'resnet18'`` is supported.
        pretrained: forwarded unchanged to ``imagenet_resnet.resnet18``.
        dataset: dataset name; only ``'Clothing1M'`` is supported.

    Attributes:
        feat_dim: width of the output feature vector (512 for ResNet-18).
        feat_net: the ``nn.Sequential`` trunk + global average pool.

    Raises:
        ValueError: if ``(arch, dataset)`` is not a supported combination.
    """

    def __init__(self, arch, pretrained, dataset):
        super().__init__()
        # Fail fast: without this check an unsupported configuration used to
        # build a module with no ``feat_net``/``feat_dim``, which only
        # crashed later (AttributeError) on the first forward pass.
        if dataset != 'Clothing1M' or arch != 'resnet18':
            raise ValueError(
                f"Unsupported configuration arch={arch!r}, dataset={dataset!r}; "
                "only arch='resnet18' with dataset='Clothing1M' is implemented."
            )

        net = imagenet_resnet.resnet18(pretrained=pretrained)
        self.feat_dim = 512  # ResNet-18 output feature dimension
        # Backbone attributes kept as the feature trunk, in execution order.
        feature_layers = ('conv1', 'bn1', 'relu', 'maxpool',
                          'layer1', 'layer2', 'layer3', 'layer4')
        self.feat_net = nn.Sequential(
            *(getattr(net, name) for name in feature_layers),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        """Return pooled features flattened to shape ``(batch, channels)``.

        NOTE(review): assumes ``x`` is an image batch ``(N, C, H, W)`` matching
        the backbone's ``conv1`` input — confirm against the data pipeline.
        """
        x = self.feat_net(x)
        # Pool output is (N, C, 1, 1); drop the two singleton spatial dims.
        return torch.flatten(x, 1)
# Classification Layer
class NetClassifier(nn.Module):
    """Bias-free linear classification head computing ``feature @ weight``.

    Args:
        feat_dim: size of each input feature vector.
        nb_cls: number of output classes.

    Attributes:
        weight: learnable parameter of shape ``(feat_dim, nb_cls)``.
    """

    def __init__(self, feat_dim, nb_cls):
        super().__init__()
        # Reuse nn.Linear's default initialisation (same RNG draws as before),
        # then store the weight transposed to (feat_dim, nb_cls) so forward
        # is a plain matmul. ``.contiguous()`` gives the parameter its own
        # packed storage instead of a strided view into the discarded
        # Linear's weight.
        init_weight = (
            nn.Linear(feat_dim, nb_cls, bias=False).weight.detach().t().contiguous()
        )
        self.weight = nn.Parameter(init_weight, requires_grad=True)  # feat_dim * nb_cls

    def forward(self, feature):
        """Return class scores ``(batch, nb_cls)`` for 2-D ``feature`` ``(batch, feat_dim)``."""
        return torch.mm(feature, self.weight)