File size: 1,377 Bytes
127caf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
import torch.nn.functional as F
import imagenet_resnet  # Importing ResNet-18 from imagenet_resnet.py

# Feature Extraction Network
class NetFeat(nn.Module):
    """ResNet backbone truncated after global average pooling, used as a feature extractor.

    Args:
        arch: Backbone name. Only ``'resnet18'`` is currently supported.
        pretrained: Forwarded to ``imagenet_resnet.resnet18(pretrained=...)``.
        dataset: Dataset name. Only ``'Clothing1M'`` is currently supported.

    Raises:
        ValueError: If ``dataset`` or ``arch`` is not a supported value.

    Attributes:
        feat_dim: Length of each feature vector returned by :meth:`forward`.
        feat_net: ``nn.Sequential`` of the backbone stem, ``layer1``..``layer4``
            and a 1x1 adaptive average pool.
    """

    def __init__(self, arch, pretrained, dataset):
        super().__init__()

        # Fail fast with a clear message; previously an unsupported combination
        # left `net` unbound and crashed later with an UnboundLocalError.
        if dataset != 'Clothing1M':
            raise ValueError(
                f"Unsupported dataset {dataset!r}; only 'Clothing1M' is implemented"
            )
        if arch != 'resnet18':
            raise ValueError(
                f"Unsupported arch {arch!r} for dataset 'Clothing1M'; "
                "only 'resnet18' is implemented"
            )

        net = imagenet_resnet.resnet18(pretrained=pretrained)
        self.feat_dim = 512  # ResNet-18 output feature dimension

        # Keep everything up to and including layer4 (i.e. drop the backbone's
        # own pooling/fc head) and append a global average pool.
        feature_layers = ['conv1', 'bn1', 'relu', 'maxpool',
                          'layer1', 'layer2', 'layer3', 'layer4']
        modules = [getattr(net, name) for name in feature_layers]
        self.feat_net = nn.Sequential(*modules, nn.AdaptiveAvgPool2d((1, 1)))

    def forward(self, x):
        """Return one pooled feature vector per input sample.

        Args:
            x: Input batch for the ResNet stem; presumably (N, 3, H, W) images
                -- TODO confirm against the data pipeline.

        Returns:
            Tensor of shape (N, C), where C is the channel count of ``layer4``.
        """
        pooled = self.feat_net(x)  # (N, C, 1, 1) after the adaptive pool
        return torch.flatten(pooled, 1)

# Classification Layer
class NetClassifier(nn.Module):
    """Bias-free linear classification head computing ``feature @ weight``.

    The weight is stored with shape (feat_dim, nb_cls), the transpose of
    ``nn.Linear``'s layout. It is initialised with exactly the values a fresh
    ``nn.Linear(feat_dim, nb_cls, bias=False)`` would receive.

    Args:
        feat_dim: Size of the input feature vectors.
        nb_cls: Number of output classes.
    """

    def __init__(self, feat_dim, nb_cls):
        super().__init__()
        # Borrow nn.Linear's default initialisation, then store the transpose as an
        # independent, contiguous leaf tensor. A bare `.weight.T` would leave the
        # parameter as a strided view (stride (1, feat_dim)) over the discarded
        # Linear's storage.
        init_weight = nn.Linear(feat_dim, nb_cls, bias=False).weight.detach()
        self.weight = nn.Parameter(init_weight.t().contiguous(), requires_grad=True)  # (feat_dim, nb_cls)

    def forward(self, feature):
        """Return class scores for ``feature``.

        Args:
            feature: Tensor whose last dimension is ``feat_dim``; typically (N, feat_dim).

        Returns:
            Tensor with the same leading dimensions and a last dimension of ``nb_cls``.
        """
        # matmul equals torch.mm for 2-D input and also broadcasts over extra leading dims.
        return torch.matmul(feature, self.weight)