File size: 1,165 Bytes
d10b7cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torch import Tensor, nn
from torchvision.models import VGG, VGG16_Weights


class VGG16WithCNN(nn.Module):
    """Frozen ImageNet-pretrained VGG16 backbone with a small trainable head.

    The input is passed through ``vgg16.features`` (all of whose parameters
    are frozen), then through a Conv-ReLU-BatchNorm-MaxPool block, and finally
    through a two-layer fully connected classifier that emits ``num_classes``
    logits.

    Args:
        num_classes: Number of output logits produced by the classifier.

    Note:
        The first linear layer expects 576 flattened features, i.e. a
        64-channel 3x3 map coming out of ``custom_cnn``.
        NOTE(review): this presumably corresponds to 224x224 inputs (VGG16
        features -> 512x7x7 -> head -> 64x3x3); confirm against the data
        pipeline. Other resolutions fail with a shape mismatch in the
        classifier.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.num_classes = num_classes
        # Only ``self.vgg16.features`` is used in ``forward``.
        self.vgg16: VGG = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

        # Freeze every pretrained parameter; only the layers defined below
        # are left trainable.
        for param in self.vgg16.parameters():
            param.requires_grad = False

        head_channels = 64  # output channels of the custom convolution
        head_spatial = 3  # height/width expected after the 2x2 max-pool

        self.custom_cnn = nn.Sequential(
            nn.Conv2d(512, head_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(head_channels),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier. ``in_features`` is derived from the head's real output
        # shape (64 * 3 * 3 = 576, the same value as before) instead of an
        # unrelated factorisation of 576.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(head_channels * head_spatial * head_spatial, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for a batch of images.

        Args:
            x: Input batch fed to ``vgg16.features``.

        Returns:
            The output of ``self.classifier`` for each sample in the batch.
        """
        x = self.vgg16.features(x)
        x = self.custom_cnn(x)
        # No manual ``x.view(batch, -1)`` here: the classifier starts with
        # ``nn.Flatten``, which also copes with non-contiguous feature maps
        # (e.g. channels_last), whereas ``view`` raises on them.
        return self.classifier(x)