Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import Union, List, Dict, Any, cast | |
| import torchvision | |
| import torch.nn.functional as F | |
class VGG(torch.nn.Module):
    """VGG network that also exposes its intermediate activations.

    Wraps a torchvision VGG model and regroups its ``features`` stack into five
    stages, ``layer1`` .. ``layer5``, so that ``forward`` can return every
    stage's output together with the penultimate activations and the logits.

    Stage boundaries are the ``nn.MaxPool2d`` layers of ``features``:

    * ``layer1`` covers everything before the first pool.
    * Each of ``layer2`` .. ``layer5`` starts at a pool and runs up to (not
      including) the next one.
    * The last pool is stored separately as ``max_pool``.

    For ``vgg11`` this gives the index ranges [0,2), [2,5), [5,10), [10,15),
    [15,20) with ``max_pool = features[20]``. For ``vgg19`` it gives [0,4),
    [4,9), [9,18), [18,27), [27,36) with ``max_pool = features[36]``.

    Args:
        arch_type (str): name of a torchvision VGG constructor; must be one of
            ``SUPPORTED_ARCHS``.
        pretrained (bool): forwarded to the torchvision constructor (if True,
            ImageNet pre-trained weights are requested).
        progress (bool): forwarded to the torchvision constructor (if True, a
            download progress bar is shown).

    Raises:
        NotImplementedError: if ``arch_type`` is not in ``SUPPORTED_ARCHS``.
    """

    # torchvision VGG constructors accepted by ``arch_type``. This wrapper relies
    # on each of them having five max-pool-delimited stages in ``features`` and
    # Linear layers at ``classifier[0]``, ``[3]`` and ``[6]``.
    SUPPORTED_ARCHS = (
        'vgg11', 'vgg11_bn',
        'vgg13', 'vgg13_bn',
        'vgg16', 'vgg16_bn',
        'vgg19', 'vgg19_bn',
    )

    def __init__(self, arch_type, pretrained, progress):
        super().__init__()
        self.layer1 = torch.nn.Sequential()
        self.layer2 = torch.nn.Sequential()
        self.layer3 = torch.nn.Sequential()
        self.layer4 = torch.nn.Sequential()
        self.layer5 = torch.nn.Sequential()
        if arch_type not in self.SUPPORTED_ARCHS:
            raise NotImplementedError(
                f"unsupported arch_type {arch_type!r}; expected one of "
                f"{self.SUPPORTED_ARCHS}")
        # arch_type is whitelisted above, so this lookup cannot reach
        # arbitrary attributes of torchvision.models.
        official_vgg = getattr(torchvision.models, arch_type)(
            pretrained=pretrained, progress=progress)
        features = official_vgg.features
        blocks, last_idx = self._split_features(features)

        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        for stage, (start, stop) in zip(stages, blocks):
            for idx in range(start, stop):
                # Keep the torchvision index as the child name so that
                # state_dict keys look like 'layer1.0.weight'.
                stage.add_module(str(idx), features[idx])

        self.max_pool = features[last_idx]
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc1 = official_vgg.classifier[0]
        self.fc2 = official_vgg.classifier[3]
        self.fc3 = official_vgg.classifier[6]
        # A single dropout module (nn.Dropout default p=0.5) is reused after
        # both hidden fully-connected layers in forward().
        self.dropout = nn.Dropout()

    @staticmethod
    def _split_features(features):
        """Compute the five stage ranges of a VGG ``features`` stack.

        Args:
            features: iterable of modules (e.g. an ``nn.Sequential``).

        Returns:
            tuple: ``(blocks, last_idx)``.

            * ``blocks`` is a list of five half-open ``(start, stop)`` index
              pairs. Stage 0 starts at 0; every later stage starts at a max-pool.
            * ``last_idx`` is the index of the final ``nn.MaxPool2d``.

        Raises:
            ValueError: if ``features`` does not contain exactly five
                ``nn.MaxPool2d`` layers.
        """
        pool_idxs = [i for i, module in enumerate(features)
                     if isinstance(module, nn.MaxPool2d)]
        if len(pool_idxs) != 5:
            raise ValueError(
                f"expected 5 MaxPool2d layers in VGG features, found {len(pool_idxs)}")
        starts = [0] + pool_idxs[:-1]
        return list(zip(starts, pool_idxs)), pool_idxs[-1]

    def forward(self, x):
        """Run the network and collect intermediate activations.

        Args:
            x (torch.Tensor): input batch. Presumably (N, 3, H, W) images, as
                torchvision VGG expects — TODO confirm against callers.

        Returns:
            dict: keys and their values:

            * ``'f0'`` .. ``'f4'``: outputs of ``layer1`` .. ``layer5``.
            * ``'penultimate'``: ReLU output of ``fc2``, taken before dropout.
            * ``'logits'``: output of ``fc3``.
        """
        out = {}
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        for i, stage in enumerate(stages):
            x = stage(x)
            out[f'f{i}'] = x
        x = self.max_pool(x)
        x = self.avgpool(x)
        # flatten(x, 1) instead of view(-1, 512*7*7) for two reasons:
        # - it works on non-contiguous tensors (e.g. channels_last);
        # - it keeps the batch dimension intact, so a feature-size mismatch
        #   surfaces in fc1 instead of being folded into the batch.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        out['penultimate'] = x
        x = self.dropout(x)
        x = self.fc3(x)
        out['logits'] = x
        return out
def vgg11(pretrained=False, progress=True):
    r"""Construct the VGG 11-layer model (configuration "A").

    Architecture from `"Very Deep Convolutional Networks For Large-Scale Image
    Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, use weights pre-trained on ImageNet.
        progress (bool): If True, show a download progress bar on stderr.

    Returns:
        VGG: the feature-exposing wrapper built around torchvision's ``vgg11``.
    """
    model = VGG(arch_type='vgg11', pretrained=pretrained, progress=progress)
    return model
def vgg19(pretrained=False, progress=True):
    r"""Construct the VGG 19-layer model (configuration "E").

    Architecture from `"Very Deep Convolutional Networks For Large-Scale Image
    Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, use weights pre-trained on ImageNet.
        progress (bool): If True, show a download progress bar on stderr.

    Returns:
        VGG: the feature-exposing wrapper built around torchvision's ``vgg19``.
    """
    model = VGG(arch_type='vgg19', pretrained=pretrained, progress=progress)
    return model