Spaces:
Sleeping
Sleeping
Add initial project setup with requirements, README updates, and image classification functionality
cbf524d
| import torchvision.models as models | |
| import torch.nn as nn | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| import torch.nn.init as init | |
| import math | |
| import model.cifar_resnet as cifar | |
| import model.imagenet_resnet as imagenet | |
class NetFeat(nn.Module):
    """Backbone that maps images to flat feature vectors.

    The backbone is a ResNet truncated after its last residual stage and
    followed by global average pooling, so ``forward`` returns a tensor of
    shape ``(batch, self.feat_dim)``.

    Args:
        arch: Backbone name. ``'resnet18'`` for CIFAR datasets;
            ``'resnet18'`` or ``'resnet50'`` for ``'Clothing1M'``.
        pretrained: Forwarded to the ``model.imagenet_resnet`` constructors
            for ``'Clothing1M'``; ignored for CIFAR datasets.
        dataset: Any name containing ``'CIFAR'`` selects the
            ``model.cifar_resnet`` backbone; ``'Clothing1M'`` selects the
            ``model.imagenet_resnet`` backbone.

    Raises:
        ValueError: If ``dataset`` is not supported, or ``arch`` is not
            supported for that dataset.
    """

    # Supported backbones per dataset family, mapped to output feature width.
    _CIFAR_FEAT_DIMS = {'resnet18': 512}
    _CLOTHING1M_FEAT_DIMS = {'resnet18': 512, 'resnet50': 2048}

    def __init__(self, arch, pretrained, dataset):
        super(NetFeat, self).__init__()
        # Defined up front so the attribute exists before train() is called.
        self.freeze_bn = False
        # Validate arch/dataset before touching the backbone modules, so bad
        # combinations fail here instead of later with AttributeError/NameError.
        if 'CIFAR' in dataset:
            self.feat_dim = self._feat_dim_for(arch, dataset, self._CIFAR_FEAT_DIMS)
            net = cifar.resnet18()
            layer_names = ['conv1', 'conv2_x', 'conv3_x', 'conv4_x', 'conv5_x']
        elif dataset == 'Clothing1M':
            self.feat_dim = self._feat_dim_for(arch, dataset, self._CLOTHING1M_FEAT_DIMS)
            if arch == 'resnet50':
                net = imagenet.resnet50(pretrained=pretrained)
            else:
                net = imagenet.resnet18(pretrained=pretrained)
            layer_names = ['conv1', 'bn1', 'relu', 'maxpool',
                           'layer1', 'layer2', 'layer3', 'layer4']
        else:
            raise ValueError(
                f"unsupported dataset {dataset!r}; "
                f"expected a CIFAR variant or 'Clothing1M'"
            )
        # Keep every stage up to and including the last residual stage (the
        # backbone's own head is not in layer_names), then pool globally.
        # Adaptive pooling always yields 1x1 maps, so the flattened output is
        # exactly feat_dim wide. For 7x7 layer4 maps, which the previous
        # AvgPool2d(7, stride=1) assumed, the result is identical.
        stages = [getattr(net, name) for name in layer_names]
        self.feat_net = nn.Sequential(*stages, nn.AdaptiveAvgPool2d((1, 1)))

    @staticmethod
    def _feat_dim_for(arch, dataset, feat_dims):
        """Return the feature width of ``arch``, or raise ValueError if unsupported."""
        try:
            return feat_dims[arch]
        except KeyError:
            raise ValueError(
                f"unsupported arch {arch!r} for dataset {dataset!r}; "
                f"expected one of {sorted(feat_dims)}"
            ) from None

    def train(self, mode=True, freeze_bn=False):
        """Set train/eval mode, optionally freezing every BatchNorm2d layer.

        Overrides ``nn.Module.train``. When ``freeze_bn`` is true, every
        ``nn.BatchNorm2d`` is put in eval mode (running statistics are not
        updated) and its affine parameters, if any, get
        ``requires_grad = False``.

        NOTE: calling again with ``freeze_bn=False`` does not restore
        ``requires_grad`` on BN parameters that were frozen earlier.

        Args:
            mode: Training mode forwarded to ``nn.Module.train``.
            freeze_bn: Whether to freeze the BatchNorm2d layers.

        Returns:
            ``self``, matching ``nn.Module.train`` (so ``eval()`` also
            returns the module).
        """
        super(NetFeat, self).train(mode)
        self.freeze_bn = freeze_bn
        if freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    # requires_grad_ skips absent params (affine=False BN has
                    # weight/bias set to None).
                    m.requires_grad_(False)
        return self

    def forward(self, x):
        """Run the backbone and flatten everything after the batch dimension."""
        x = self.feat_net(x)
        return torch.flatten(x, 1)
class NetClassifier(nn.Module):
    """Bias-free linear classifier over backbone features.

    The weight is stored as a ``(feat_dim, nb_cls)`` matrix, the transpose of
    ``nn.Linear``'s layout, so class scores are ``feature @ weight``.

    Args:
        feat_dim: Width of the input feature vectors.
        nb_cls: Number of output classes.
    """

    def __init__(self, feat_dim, nb_cls):
        super(NetClassifier, self).__init__()
        # Reuse nn.Linear's default initialization, then store the transposed
        # matrix as a contiguous tensor. A bare ``.T`` view is non-contiguous,
        # so e.g. ``weight.view(-1)`` would fail. Values are unchanged.
        init_weight = nn.Linear(feat_dim, nb_cls, bias=False).weight.detach()
        self.weight = nn.Parameter(init_weight.t().contiguous(), requires_grad=True)  # feat_dim x nb_cls

    def getWeight(self):
        """Return ``(weight, bias, scale_cls)``.

        This classifier defines neither a bias nor a scale, so those entries
        are ``None`` unless the attributes are set on the instance or by a
        subclass.
        """
        return self.weight, getattr(self, 'bias', None), getattr(self, 'scale_cls', None)

    def forward(self, feature):
        """Return class scores of shape ``(batch, nb_cls)``.

        Args:
            feature: 2-D tensor of shape ``(batch, feat_dim)``.

        Raises:
            ValueError: If ``feature`` is not 2-D.
        """
        if feature.dim() != 2:
            raise ValueError(
                f'expected 2-D (batch, feat_dim) features, got shape {tuple(feature.shape)}'
            )
        return torch.mm(feature, self.weight)
if __name__ == '__main__':
    # Smoke test: push a random CIFAR-sized batch through the backbone and the
    # classifier on CPU and print the resulting shapes.
    device = torch.device('cpu')
    nb_cls = 100  # match the CIFAR100 dataset requested below
    data = torch.randn(3, 3, 32, 32, device=device)
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='CIFAR100').to(device)
    net_cls = NetClassifier(net_feat.feat_dim, nb_cls).to(device)
    # Shape check only; no gradients needed.
    with torch.no_grad():
        feat = net_feat(data)
        print(feat.size())
        score = net_cls(feat)
        print(score.size())