# Source: clothing1m/model/model.py
# Commit cbf524d (prshanthreddy): Add initial project setup with requirements,
# README updates, and image classification functionality
import torchvision.models as models
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn.init as init
import math
import model.cifar_resnet as cifar
import model.imagenet_resnet as imagenet
class NetFeat(nn.Module):
    """ResNet backbone that turns an image batch into flat feature vectors.

    Only the convolutional stages of the chosen ResNet are kept (its own
    classification head is not used); they are followed by a pooling layer
    and the result is flattened to ``(batch, -1)`` in :meth:`forward`.

    Supported configurations:
      * ``dataset`` containing ``'CIFAR'`` with ``arch='resnet18'``
        (``feat_dim = 512``, adaptive average pooling to 1x1);
      * ``dataset == 'Clothing1M'`` with ``arch='resnet18'``
        (``feat_dim = 512``) or ``arch='resnet50'`` (``feat_dim = 2048``),
        followed by ``AvgPool2d(7, stride=1)``.

    Args:
        arch: Name of the backbone architecture (see above).
        pretrained: Forwarded to the ``imagenet`` ResNet constructors; it is
            not used for the CIFAR backbone.
        dataset: Dataset name used to select the backbone family.

    Raises:
        ValueError: If the ``arch`` / ``dataset`` combination is not one of
            the supported configurations listed above.
    """

    def __init__(self, arch, pretrained, dataset):
        super().__init__()

        if 'CIFAR' in dataset:
            # Validate before building anything so that an unsupported arch
            # fails with a clear message instead of an unbound ``net``.
            if arch != 'resnet18':
                raise ValueError(
                    "Unsupported arch %r for dataset %r; expected 'resnet18'"
                    % (arch, dataset)
                )
            net = cifar.resnet18()
            stage_names = ['conv1', 'conv2_x', 'conv3_x', 'conv4_x', 'conv5_x']
            pool = nn.AdaptiveAvgPool2d((1, 1))
            self.feat_dim = 512
        elif dataset == 'Clothing1M':
            if arch == 'resnet50':
                net = imagenet.resnet50(pretrained=pretrained)
                self.feat_dim = 2048
            elif arch == 'resnet18':
                net = imagenet.resnet18(pretrained=pretrained)
                self.feat_dim = 512
            else:
                raise ValueError(
                    "Unsupported arch %r for dataset %r; expected 'resnet18' "
                    "or 'resnet50'" % (arch, dataset)
                )
            stage_names = ['conv1', 'bn1', 'relu', 'maxpool',
                           'layer1', 'layer2', 'layer3', 'layer4']
            # NOTE(review): a fixed 7x7 window only collapses the map to 1x1
            # when layer4 emits a 7x7 map (presumably 224x224 inputs) --
            # confirm against the data pipeline.
            pool = nn.AvgPool2d(7, stride=1)
        else:
            raise ValueError(
                "Unsupported dataset %r; expected a name containing 'CIFAR' "
                "or 'Clothing1M'" % (dataset,)
            )

        # Submodules keep the same order (and therefore the same indices in
        # the Sequential) as before: backbone stages first, pooling last.
        stages = [getattr(net, name) for name in stage_names]
        self.feat_net = nn.Sequential(*stages, pool)

    def train(self, mode=True, freeze_bn=False):
        """Set training mode, optionally freezing every ``BatchNorm2d`` layer.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.
            freeze_bn: When true, every ``BatchNorm2d`` submodule is put in
                eval mode and its affine parameters (if it has any) get
                ``requires_grad = False``. The flag is stored on
                ``self.freeze_bn``.

        Returns:
            ``self``, matching ``nn.Module.train`` so that calls such as
            ``net = net.train()`` or ``net = net.eval()`` keep working.
        """
        super().train(mode)
        self.freeze_bn = freeze_bn
        if self.freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    # weight / bias are None when the layer was built with
                    # affine=False, so skip those instead of crashing.
                    for param in (m.weight, m.bias):
                        if param is not None:
                            param.requires_grad = False
        return self

    def forward(self, x):
        """Run the backbone and pooling on ``x``, then flatten per sample."""
        x = self.feat_net(x)
        x = torch.flatten(x, 1)
        return x
class NetClassifier(nn.Module):
    """Bias-free linear classifier computing ``feature @ weight``.

    Args:
        feat_dim: Dimensionality of the input feature vectors.
        nb_cls: Number of output classes.

    Attributes:
        weight: Learnable matrix of shape ``(feat_dim, nb_cls)``.
    """

    def __init__(self, feat_dim, nb_cls):
        super().__init__()
        # Borrow nn.Linear's default initialisation, then transpose so the
        # parameter is stored as (feat_dim, nb_cls).
        init_layer = nn.Linear(feat_dim, nb_cls, bias=False)
        self.weight = torch.nn.Parameter(init_layer.weight.T, requires_grad=True)

    def getWeight(self):
        """Return ``(weight, bias, scale_cls)``.

        This class defines neither a bias nor a class scale, so those two
        entries are ``None`` unless an attribute of that name has been set
        on the instance. The 3-tuple shape is kept for callers that unpack
        three values.
        """
        return self.weight, getattr(self, 'bias', None), getattr(self, 'scale_cls', None)

    def forward(self, feature):
        """Compute class scores for a 2-D batch of features.

        Args:
            feature: Tensor of shape ``(batch, feat_dim)``.

        Returns:
            Tensor of shape ``(batch, nb_cls)``.

        Raises:
            ValueError: If ``feature`` is not 2-dimensional.
        """
        if feature.dim() != 2:
            raise ValueError(
                "expected a 2-D (batch, feat_dim) tensor, got %d dimension(s)"
                % feature.dim()
            )
        return torch.mm(feature, self.weight)
if __name__ == '__main__':
    # Smoke test: push a random CIFAR-sized batch through the feature
    # extractor and the classifier on the CPU and print both output sizes.
    device = "cpu"
    batch = torch.randn(3, 3, 32, 32).to(device)

    backbone = NetFeat(arch='resnet18', pretrained=False, dataset='CIFAR100')
    head = NetClassifier(backbone.feat_dim, 10)
    backbone.to(device)
    head.to(device)

    features = backbone(batch)
    print(features.shape)

    scores = head(features)
    print(scores.shape)