Clothpredict / model.py
Saahil-doryu's picture
Upload 11 files
127caf5 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import imagenet_resnet # Importing ResNet-18 from imagenet_resnet.py
# Feature Extraction Network
class NetFeat(nn.Module):
    """Feature extractor assembled from a ResNet-18 backbone.

    The backbone is created by ``imagenet_resnet.resnet18``. Only its
    ``conv1`` ... ``layer4`` submodules are used; they are chained in that
    order and followed by a global average pool to a 1x1 spatial size.

    Args:
        arch: Backbone name. Only ``'resnet18'`` is supported.
        pretrained: Forwarded unchanged to ``imagenet_resnet.resnet18``.
        dataset: Dataset name. Only ``'Clothing1M'`` is supported.

    Raises:
        ValueError: If ``dataset`` or ``arch`` is not a supported value.
            Failing here avoids building a module that has no ``feat_net``
            and would only break later, inside ``forward``.
    """

    # Backbone submodules reused for feature extraction, in execution order.
    _FEATURE_LAYERS = (
        'conv1', 'bn1', 'relu', 'maxpool',
        'layer1', 'layer2', 'layer3', 'layer4',
    )

    def __init__(self, arch, pretrained, dataset):
        super().__init__()

        # Validate first so an unsupported configuration is reported clearly
        # instead of leaving the instance half-initialised.
        if dataset != 'Clothing1M':
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected 'Clothing1M'."
            )
        if arch != 'resnet18':
            raise ValueError(
                f"Unsupported arch {arch!r}; expected 'resnet18'."
            )

        net = imagenet_resnet.resnet18(pretrained=pretrained)
        self.feat_dim = 512  # ResNet-18 output feature dimension

        # Keep the same child order (indices 0..8) as before so parameter
        # names inside ``feat_net`` do not change.
        backbone_modules = [getattr(net, name) for name in self._FEATURE_LAYERS]
        self.feat_net = nn.Sequential(
            *backbone_modules,
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and flatten everything after dim 0.

        For a batched ``(N, C, H, W)`` input the pooled map is ``(N, D, 1, 1)``
        and the returned tensor is ``(N, D)``; D is expected to equal
        ``feat_dim``.
        """
        x = self.feat_net(x)
        return torch.flatten(x, 1)
# Classification Layer
class NetClassifier(nn.Module):
    """Bias-free linear classifier whose weight is stored as (feat_dim, nb_cls).

    Args:
        feat_dim: Size of each input feature vector.
        nb_cls: Number of output classes.
    """

    def __init__(self, feat_dim, nb_cls):
        super().__init__()
        # Reuse nn.Linear's default weight initialisation, then keep the
        # transposed weight so forward can right-multiply features directly.
        linear = nn.Linear(feat_dim, nb_cls, bias=False)
        self.weight = nn.Parameter(linear.weight.T, requires_grad=True)  # feat_dim * nb_cls

    def forward(self, feature):
        """Return class scores: the 2-D matrix product ``feature x weight``."""
        return feature.mm(self.weight)