# Source: YiMeng-SYSU — initial commit of transfer learning project files (commit e3469ed, verified)
import torch
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18,resnet50,ResNet50_Weights
class SimpleCNN(nn.Module):
    """Two-block convolutional classifier for small grayscale images.

    Defaults target 28x28 single-channel inputs (MNIST-style). Each block
    halves the spatial resolution via 2x2 max-pooling, so the flattened
    feature dimension is ``64 * (input_size // 4) ** 2``.
    """

    def __init__(self, num_inputs=1, input_size=28, num_classes=10, dropout_rate=0.3):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Conv -> BN -> ReLU -> 2x2 pool; spatial size halves per block.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        self.features = nn.Sequential(
            *conv_block(num_inputs, 32),
            *conv_block(32, 64),
        )

        # Two pooling layers shrink each spatial side by a factor of 4
        # (e.g. 28 -> 14 -> 7 with 64 channels at the end).
        side = input_size // 4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * side * side, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-init conv/linear weights; unit-gain, zero-bias batch norms."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Feature extraction followed by the fully-connected classifier head.
        return self.classifier(self.features(x))
class ResNet18_CIFAR(nn.Module):
    """ResNet-18 adapted for 32x32 inputs, trained from scratch.

    The ImageNet stem is replaced with a 3x3 / stride-1 convolution and the
    initial max-pool is removed, so small images are not downsampled too
    aggressively. Tensor-space augmentation (random flip + padded random
    crop) is applied only while the module is in training mode.
    """

    def __init__(self, num_inputs=3, num_classes=10, dropout_rate=0.0):
        # NOTE(review): dropout_rate is accepted for interface parity with the
        # sibling models but is currently unused here.
        super().__init__()
        self.aug = nn.Sequential(
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        )
        backbone = resnet18(weights=None)
        backbone.conv1 = nn.Conv2d(num_inputs, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        backbone.fc = nn.Linear(512, num_classes)
        self.net = backbone

    def forward(self, x):
        # Augment only during training; evaluation sees the raw input.
        return self.net(self.aug(x) if self.training else x)
class TransferResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 with a replaced classification head.

    The entire backbone remains trainable (nothing is frozen); the training
    script is expected to protect the pretrained weights by using a much
    smaller learning rate for the backbone than for the new head.
    """

    def __init__(self, num_classes=10, dropout_rate=0.0):
        super().__init__()
        print("⬇️ Loading Pre-trained ResNet50 (ImageNet)...")
        # Load the default ImageNet weights (download happens at construction).
        self.net = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Swap the 1000-way ImageNet head for dropout + a task-sized linear layer.
        in_features = self.net.fc.in_features
        self.net.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        return self.net(x)