Spaces:
Build error
Build error
File size: 1,377 Bytes
127caf5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import imagenet_resnet # Importing ResNet-18 from imagenet_resnet.py
# Feature Extraction Network
class NetFeat(nn.Module):
    """Feature extractor: a ResNet backbone truncated after global average pooling.

    Only ``arch='resnet18'`` together with ``dataset='Clothing1M'`` is
    supported; the backbone is built by ``imagenet_resnet.resnet18``.

    Args:
        arch: Backbone name; must be ``'resnet18'``.
        pretrained: Forwarded unchanged to
            ``imagenet_resnet.resnet18(pretrained=...)``.
        dataset: Dataset name; must be ``'Clothing1M'``.

    Attributes:
        feat_dim: Width of the feature vector produced by ``forward`` (512 for
            ResNet-18).
        feat_net: ``nn.Sequential`` of the backbone stem (conv1, bn1, relu,
            maxpool), the four residual stages (layer1..layer4) and a
            ``(1, 1)`` adaptive average pool.

    Raises:
        ValueError: If ``arch``/``dataset`` is not a supported combination.
    """

    def __init__(self, arch, pretrained, dataset):
        super().__init__()
        # Fail fast: for any other combination ``feat_net``/``feat_dim`` would
        # never be assigned and the error would only surface in forward().
        if dataset != 'Clothing1M' or arch != 'resnet18':
            raise ValueError(
                f"Unsupported combination arch={arch!r}, dataset={dataset!r}; "
                "only arch='resnet18' with dataset='Clothing1M' is implemented.")

        net = imagenet_resnet.resnet18(pretrained=pretrained)
        self.feat_dim = 512  # ResNet-18 output feature dimension
        # Backbone submodules kept, in forward order; anything after layer4
        # (e.g. the original classification head) is dropped.
        feature_layers = ('conv1', 'bn1', 'relu', 'maxpool',
                          'layer1', 'layer2', 'layer3', 'layer4')
        modules = [getattr(net, name) for name in feature_layers]
        # Sequential indices 0..8 match the previous construction, so existing
        # state-dict keys (feat_net.0.*, ..., feat_net.8) remain valid.
        self.feat_net = nn.Sequential(*modules, nn.AdaptiveAvgPool2d((1, 1)))

    def forward(self, x):
        """Return one pooled feature vector per input sample.

        Args:
            x: Input batch for the backbone; presumably images shaped
                ``(N, 3, H, W)`` — TODO confirm against the data loader.

        Returns:
            Tensor of shape ``(N, C)``: the ``(N, C, 1, 1)`` pooled map
            flattened from dim 1 (``C`` is ``feat_dim`` for ResNet-18).
        """
        x = self.feat_net(x)
        return torch.flatten(x, 1)
# Classification Layer
class NetClassifier(nn.Module):
    """Bias-free linear classification head computing ``feature @ weight``.

    Args:
        feat_dim: Size of the input feature vectors.
        nb_cls: Number of output classes.

    Attributes:
        weight: Learnable parameter of shape ``(feat_dim, nb_cls)`` — the
            transpose of ``nn.Linear.weight``'s layout. The state-dict key is
            ``weight``.
    """

    def __init__(self, feat_dim, nb_cls):
        super().__init__()
        # Reuse nn.Linear's default initialisation (so values are identical
        # under a fixed seed), then keep an independent, contiguous
        # (feat_dim, nb_cls) copy. Wrapping the bare `.weight.T` view would give
        # a non-contiguous parameter aliasing the temporary layer's storage,
        # which e.g. DDP gradient bucketing warns about.
        init_weight = nn.Linear(feat_dim, nb_cls, bias=False).weight.detach()
        self.weight = nn.Parameter(init_weight.t().contiguous(), requires_grad=True)

    def forward(self, feature):
        """Return class scores for ``feature``.

        Args:
            feature: Tensor whose last dimension is ``feat_dim``. A 2-D
                ``(N, feat_dim)`` input yields ``(N, nb_cls)`` scores; extra
                leading dimensions are broadcast (``torch.matmul`` semantics).

        Returns:
            Tensor of scores with the last dimension replaced by ``nb_cls``.
        """
        return torch.matmul(feature, self.weight)
|