Spaces:
Running
Running
| from collections import OrderedDict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
class VGGInputNormalization(torch.nn.Module):
    """Scale a 0-255 image tensor to [0, 1], then apply ImageNet mean/std normalization.

    Computes ``(x / 255 - mean) / std`` per channel using the ImageNet
    statistics expected by torchvision's pretrained VGG weights. ``mean`` and
    ``std`` are registered as ``(3, 1, 1)`` buffers. They therefore follow
    the module across ``.to()`` / ``.cuda()`` calls, and they broadcast over
    inputs whose trailing dimensions are ``(3, H, W)``.

    Args:
        inplace: If True (default), overwrite the input tensor rather than
            allocating a new one. Inputs that cannot safely be modified in
            place are always processed out of place: non-floating dtypes
            (they cannot hold the fractional result) and tensors that
            require grad.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace
        # Built directly as float32 (numpy's default would be float64). A
        # float32 image is then not normalized in double precision, which is
        # especially slow on GPUs.
        self.register_buffer(
            'mean',
            torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1),
        )
        self.register_buffer(
            'std',
            torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1),
        )

    def forward(self, tensor):
        """Return ``(tensor / 255 - mean) / std``.

        When ``self.inplace`` is set and the input is a floating tensor that
        does not require grad, the input itself is modified and returned.
        """
        # In-place division of an integer tensor fails (float result, int
        # storage). In-place edits of autograd-tracked tensors can break
        # backward(). In both cases, work on a fresh tensor instead.
        can_modify_input = tensor.is_floating_point() and not tensor.requires_grad
        if self.inplace and can_modify_input:
            tensor /= 255.0
        else:
            tensor = tensor / 255.0
        tensor -= self.mean
        tensor /= self.std
        return tensor
class VGG19BNNamedFeatures(torch.nn.Sequential):
    """Pretrained VGG19-BN convolutional stack with human-readable layer names.

    The first child, ``normalize``, is a ``VGGInputNormalization``. It is
    followed by every layer of ``torchvision.models.vgg19_bn(...).features``,
    in order. Those layers are renamed ``conv{b}_{l}``, ``bn{b}_{l}``,
    ``relu{b}_{l}`` and ``pool{b}`` (1-based block / layer indices). Every
    max-pool layer is switched to ``ceil_mode=True``. The classifier head is
    not kept.

    Raises:
        RuntimeError: if the torchvision feature stack does not have the
            expected number of layers (so names would not line up).
    """

    # Number of conv layers in each of VGG19's five blocks.
    _BLOCK_SIZES = (2, 2, 4, 4, 4)

    @classmethod
    def _layer_names(cls):
        """Return the layer names for ``vgg19_bn().features``, in module order."""
        names = []
        for block, block_size in enumerate(cls._BLOCK_SIZES, start=1):
            for layer in range(1, block_size + 1):
                names.extend(
                    (f'conv{block}_{layer}', f'bn{block}_{layer}', f'relu{block}_{layer}')
                )
            names.append(f'pool{block}')
        return names

    def __init__(self):
        names = self._layer_names()

        weights_enum = getattr(torchvision.models, 'VGG19_BN_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: the explicit equivalent of the deprecated
            # `pretrained=True` flag.
            vgg = torchvision.models.vgg19_bn(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision only understands the legacy flag.
            vgg = torchvision.models.vgg19_bn(pretrained=True)
        vgg_features = vgg.features
        # Only the feature stack is kept. Dropping the model reference lets
        # the classifier head be freed right away.
        del vgg

        # Explicit check (not `assert`): it must survive `python -O`, because
        # `zip` below would otherwise silently truncate and mis-name layers.
        if len(names) != len(vgg_features):
            raise RuntimeError(
                f'expected {len(names)} VGG19-BN feature layers, '
                f'got {len(vgg_features)}'
            )

        named_features = OrderedDict({'normalize': VGGInputNormalization()})
        for name, feature in zip(names, vgg_features):
            if isinstance(feature, nn.MaxPool2d):
                # ceil_mode rounds the pooled size up, so a trailing partial
                # window on odd-sized inputs is pooled rather than dropped.
                feature.ceil_mode = True
            named_features[name] = feature
        super().__init__(named_features)
class VGG19NamedFeatures(torch.nn.Sequential):
    """Pretrained VGG19 convolutional stack with human-readable layer names.

    The first child, ``normalize``, is a ``VGGInputNormalization``. It is
    followed by every layer of ``torchvision.models.vgg19(...).features``,
    in order. Those layers are renamed ``conv{b}_{l}``, ``relu{b}_{l}`` and
    ``pool{b}`` (1-based block / layer indices). Every max-pool layer is
    switched to ``ceil_mode=True``. The classifier head is not kept.

    Raises:
        RuntimeError: if the torchvision feature stack does not have the
            expected number of layers (so names would not line up).
    """

    # Number of conv layers in each of VGG19's five blocks.
    _BLOCK_SIZES = (2, 2, 4, 4, 4)

    @classmethod
    def _layer_names(cls):
        """Return the layer names for ``vgg19().features``, in module order."""
        names = []
        for block, block_size in enumerate(cls._BLOCK_SIZES, start=1):
            for layer in range(1, block_size + 1):
                names.extend((f'conv{block}_{layer}', f'relu{block}_{layer}'))
            names.append(f'pool{block}')
        return names

    def __init__(self):
        names = self._layer_names()

        weights_enum = getattr(torchvision.models, 'VGG19_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: the explicit equivalent of the deprecated
            # `pretrained=True` flag.
            vgg = torchvision.models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision only understands the legacy flag.
            vgg = torchvision.models.vgg19(pretrained=True)
        vgg_features = vgg.features
        # Only the feature stack is kept. Dropping the model reference lets
        # the classifier head be freed right away.
        del vgg

        # Explicit check (not `assert`): it must survive `python -O`, because
        # `zip` below would otherwise silently truncate and mis-name layers.
        if len(names) != len(vgg_features):
            raise RuntimeError(
                f'expected {len(names)} VGG19 feature layers, '
                f'got {len(vgg_features)}'
            )

        named_features = OrderedDict({'normalize': VGGInputNormalization()})
        for name, feature in zip(names, vgg_features):
            if isinstance(feature, nn.MaxPool2d):
                # ceil_mode rounds the pooled size up, so a trailing partial
                # window on odd-sized inputs is pooled rather than dropped.
                feature.ceil_mode = True
            named_features[name] = feature
        super().__init__(named_features)