|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
from PIL import Image |
|
|
from torch import Tensor, nn |
|
|
from torchvision.models import VGG, VGG16_Weights |
|
|
|
|
|
|
|
|
class VGG16WithCNN(nn.Module):
    """Frozen ImageNet-pretrained VGG16 backbone with a small trainable CNN head.

    Pipeline: ``vgg16.features`` -> ``vgg16.avgpool`` (adaptive, 7x7) ->
    Conv2d(512->64) + ReLU + BatchNorm2d + MaxPool2d(2) -> MLP classifier.

    Every VGG16 parameter has ``requires_grad = False``; only ``custom_cnn``
    and ``classifier`` are trainable.
    """

    def __init__(self, num_classes: int = 10):
        """Build the backbone and the trainable head.

        Args:
            num_classes: Number of logits produced by ``forward``.
        """
        super().__init__()
        self.num_classes = num_classes

        # ImageNet-pretrained weights loaded through torchvision (may download on first use).
        self.vgg16: VGG = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

        # Freeze the entire backbone so only the head below is trained.
        # NOTE: vgg16.classifier is left in place (keeps state_dict keys unchanged)
        # but is never called by forward().
        for param in self.vgg16.parameters():
            param.requires_grad = False

        head_channels = 64
        # vgg16.avgpool fixes the map at 7x7; MaxPool2d(2, 2) reduces it to floor(7 / 2) = 3.
        head_spatial = 3

        self.custom_cnn = nn.Sequential(
            nn.Conv2d(512, head_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(head_channels),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 64 * 3 * 3 == 576, identical to the previous hard-coded 32 * 3 * 6,
            # so existing checkpoints still load.
            nn.Linear(head_channels * head_spatial * head_spatial, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Image batch, assumed (N, 3, H, W) as VGG16 expects and presumably
                normalized like the IMAGENET1K_V1 training data -- TODO confirm
                against the data pipeline.

        Returns:
            Logits of shape (N, num_classes).
        """
        x = self.vgg16.features(x)  # (N, 512, h, w)
        # Adaptive pool to 7x7 so the head's Linear size holds for any input
        # resolution; for 224x224 inputs the features are already 7x7 and this
        # is an exact no-op.
        x = self.vgg16.avgpool(x)
        x = self.custom_cnn(x)  # (N, 64, 3, 3)
        # classifier starts with nn.Flatten(), so no manual view/reshape is needed.
        return self.classifier(x)
|
|
|