import torch
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18, resnet50, ResNet50_Weights


class SimpleCNN(nn.Module):
    """Small two-block CNN for MNIST-sized inputs (default 1x28x28, 10 classes).

    Args:
        num_inputs: number of input channels.
        input_size: spatial side length of the (square) input image.
        num_classes: size of the output logit vector.
        dropout_rate: dropout probability in the classifier head.
    """

    def __init__(self, num_inputs=1, input_size=28, num_classes=10, dropout_rate=0.3):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: num_inputs -> 32 channels, spatial size halved.
            nn.Conv2d(num_inputs, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2: 32 -> 64 channels, spatial size halved again.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Two MaxPool2d(2) layers divide the spatial size by 4 overall
        # (e.g. 28 -> 14 -> 7), with 64 channels at the end.
        final_size = input_size // 4
        flatten_dim = 64 * final_size * final_size
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flatten_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Kaiming init for conv/linear weights (appropriate for ReLU nets);
        # BatchNorm layers are initialized to the identity transform.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.features(x)
        x = self.classifier(x)
        return x


class ResNet18_CIFAR(nn.Module):
    """ResNet-18 adapted for 32x32 CIFAR-style inputs, trained from scratch.

    The stock ImageNet stem (7x7 stride-2 conv followed by a max-pool) would
    shrink a 32x32 image too aggressively, so it is replaced with a
    3x3 stride-1 conv and the max-pool is disabled (Identity).
    """

    def __init__(self, num_inputs=3, num_classes=10, dropout_rate=0.0):
        super().__init__()
        # Tensor-based train-time augmentation, applied only when
        # self.training is True (see forward).
        self.aug = nn.Sequential(
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        )
        self.net = resnet18(weights=None)
        self.net.conv1 = nn.Conv2d(num_inputs, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.net.maxpool = nn.Identity()
        self.net.fc = nn.Linear(512, num_classes)
        # NOTE(review): dropout_rate is accepted for signature parity with the
        # other models but is currently unused in this architecture — confirm
        # whether a Dropout layer was intended before the final fc.

    def forward(self, x):
        """Augment (training mode only), then run the backbone."""
        if self.training:
            x = self.aug(x)
        return self.net(x)


class TransferResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 fine-tuned end-to-end for a new task.

    The whole backbone stays trainable (nothing is frozen): CIFAR-10 differs
    enough from ImageNet (resolution, categories) that full fine-tuning is
    beneficial; the training script is expected to protect the backbone with
    a very small learning rate (e.g. 1e-5) instead of freezing it.
    """

    def __init__(self, num_classes=10, dropout_rate=0.0):
        super().__init__()
        print("⬇️ Loading Pre-trained ResNet50 (ImageNet)...")
        # Load the pretrained ImageNet weights for the backbone.
        self.net = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Replace the classification head: dropout + a fresh linear layer
        # sized for the target number of classes.
        num_ftrs = self.net.fc.in_features
        self.net.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_ftrs, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        return self.net(x)