VGG-CNN / model.py
AAAkater's picture
Upload folder using huggingface_hub
d10b7cf verified
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torch import Tensor, nn
from torchvision.models import VGG, VGG16_Weights
class VGG16WithCNN(nn.Module):
    """VGG16 conv backbone followed by a small conv head and an MLP classifier.

    Forward pass: ``self.vgg16.features`` -> ``self.custom_cnn`` (3x3 conv
    512->64, ReLU, BatchNorm2d, 2x2 max-pool) -> ``self.classifier`` (flatten,
    Linear -> 1024, ReLU, Dropout(0.4), Linear -> ``num_classes``).

    Only ``self.vgg16.features`` is used in ``forward``. The full torchvision
    VGG16 module is still stored under ``self.vgg16`` so that ``state_dict``
    keys stay compatible with checkpoints saved from this class.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (default), build VGG16 with
            ``VGG16_Weights.IMAGENET1K_V1``. If False, use randomly
            initialised weights (no weight download).
        freeze_backbone: If True (default), set ``requires_grad = False`` on
            every VGG16 parameter so only ``custom_cnn`` and ``classifier``
            are trained.
        classifier_in_features: Input width of the first classifier
            ``Linear``. It must equal the flattened size of ``custom_cnn``'s
            output (64 channels x its spatial size), which depends on the
            input image size. Defaults to the original hard-coded
            ``32 * 3 * 6`` (= 576).
    """

    def __init__(
        self,
        num_classes: int = 10,
        *,
        pretrained: bool = True,
        freeze_backbone: bool = True,
        classifier_in_features: int = 32 * 3 * 6,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes

        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vgg16: VGG = models.vgg16(weights=weights)
        if freeze_backbone:
            # Use the backbone as a fixed feature extractor; only the head trains.
            for param in self.vgg16.parameters():
                param.requires_grad = False

        # Trainable conv head. in_channels=512 assumes vgg16.features emits
        # 512 channels.
        self.custom_cnn = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Classifier.
        # NOTE(review): the default 32 * 3 * 6 == 576 == 64 * 3 * 3, i.e. it
        # presumably matches a 3x3 spatial map out of custom_cnn (64 channels).
        # The "32 * 3 * 6" factorisation does not match the 64 output channels;
        # confirm the intended input image size and pass classifier_in_features
        # explicitly for other sizes.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(classifier_in_features, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``(batch, num_classes)`` for batch *x*.

        Assumes *x* is a (batch, 3, H, W) image tensor whose size makes the
        flattened ``custom_cnn`` output equal ``classifier_in_features``.
        TODO confirm the expected H, W.
        """
        x = self.vgg16.features(x)
        x = self.custom_cnn(x)
        # The nn.Flatten() at the head of self.classifier does the flattening.
        # It reshapes (copying if needed), so it also handles non-contiguous /
        # channels_last activations, where a manual Tensor.view would raise.
        return self.classifier(x)