| import torch # for model | |
| import torch.nn as nn | |
| import torchvision.models as models #to load vgg 19 model | |
class VGGNet(nn.Module):
    """Feature extractor that returns intermediate activations of VGG-19.

    Wraps the ``.features`` sub-network of torchvision's VGG-19 and, on each
    forward pass, collects the outputs of the layers whose name (the string
    index inside that ``nn.Sequential``, e.g. ``'0'``) is listed in
    ``chosen_features``.

    Args:
        chosen_features: Indices of the layers of ``vgg19().features`` whose
            outputs are returned. Ints or strings are accepted. The default
            reproduces the original hard-coded selection.
        pretrained: If True (the default, as before), load the ImageNet
            checkpoint that ``pretrained=True`` selects; if False, build the
            network with randomly initialised weights (no download).

    Raises:
        ValueError: If a requested index does not name a layer of
            ``vgg19().features``.
    """

    def __init__(self, chosen_features=('0', '5', '10', '19', '28'), pretrained=True):
        super().__init__()
        self.vgg = self._load_vgg19_features(pretrained)
        # nn.Sequential names its children '0', '1', ...; normalise ints to
        # that form so callers may pass either.
        self.chosen_features = [str(name) for name in chosen_features]
        available = {name for name, _ in self.vgg.named_children()}
        unknown = [name for name in self.chosen_features if name not in available]
        if unknown:
            # Previously a bad index was silently ignored and simply produced
            # fewer outputs; fail loudly instead.
            raise ValueError(f"unknown VGG19 feature layer(s): {unknown}")

    @staticmethod
    def _load_vgg19_features(pretrained):
        """Return ``vgg19().features``, preferring the ``weights=`` API."""
        if hasattr(models, "VGG19_Weights"):
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
            # ``weights=``; IMAGENET1K_V1 is the checkpoint that
            # ``pretrained=True`` maps to, so the loaded weights are unchanged.
            weights = models.VGG19_Weights.IMAGENET1K_V1 if pretrained else None
            return models.vgg19(weights=weights).features
        # Older torchvision only understands the boolean flag.
        return models.vgg19(pretrained=pretrained).features

    def forward(self, x):
        """Run ``x`` through every layer and return the selected activations.

        Args:
            x: Input tensor passed to the first layer of ``self.vgg``.

        Returns:
            A list with one tensor per selected layer, in layer order (not
            the order given in ``chosen_features``).
        """
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.chosen_features:
                # NOTE(review): torchvision's VGG follows each conv with
                # ReLU(inplace=True), which overwrites this same tensor on the
                # next iteration, so the stored values are presumably
                # post-ReLU. The loop deliberately neither clones nor stops
                # early, to keep that behaviour. Confirm which activation is
                # intended.
                features.append(x)
        return features