File size: 3,486 Bytes
e3469ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18,resnet50,ResNet50_Weights

class SimpleCNN(nn.Module):
    """Small two-block CNN for square images (default 28x28, e.g. MNIST).

    Each block is conv -> BatchNorm -> ReLU -> 2x max-pool, so the spatial
    size is halved twice and the flattened feature dimension is
    64 * (input_size // 4) ** 2. A BN + dropout MLP head maps to logits.
    """

    def __init__(self, num_inputs=1, input_size=28, num_classes=10, dropout_rate=0.3):
        super().__init__()

        def _block(in_ch, out_ch):
            # One conv stage: conv -> BN -> ReLU -> 2x spatial downsample.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        self.features = nn.Sequential(*_block(num_inputs, 32), *_block(32, 64))

        # Two pools of stride 2 => spatial size divided by 4 (e.g. 28 -> 7).
        reduced = input_size // 4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * reduced * reduced, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for conv/linear weights; BN to scale 1, shift 0."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return (batch, num_classes) logits for input (batch, C, H, W)."""
        return self.classifier(self.features(x))

class ResNet18_CIFAR(nn.Module):
    """ResNet-18 adapted to 32x32 CIFAR-style inputs, trained from scratch.

    The ImageNet stem (7x7/stride-2 conv followed by a max-pool) is replaced
    with a 3x3/stride-1 conv and an Identity, so small images keep their
    spatial resolution going into the residual stages. Flip/crop augmentation
    runs on the input tensor during training only.

    NOTE(review): `dropout_rate` is accepted for signature parity with the
    sibling models but is currently unused here.
    """

    def __init__(self, num_inputs=3, num_classes=10, dropout_rate=0.0):
        super().__init__()

        # Tensor-level augmentation, applied only when self.training is True.
        self.aug = nn.Sequential(
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        )

        backbone = resnet18(weights=None)  # random init: no pretraining
        # CIFAR stem: keep 32x32 resolution instead of an early 4x downsample.
        backbone.conv1 = nn.Conv2d(
            num_inputs, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        backbone.maxpool = nn.Identity()
        # Replace the 1000-way ImageNet head with the task head.
        backbone.fc = nn.Linear(512, num_classes)
        self.net = backbone

    def forward(self, x):
        """Return logits; augments the batch first when in training mode."""
        return self.net(self.aug(x) if self.training else x)

class TransferResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 with a fresh dropout + linear head.

    The entire backbone stays trainable (nothing is frozen): the intent is
    full fine-tuning, with the caller protecting the pretrained weights via
    a small backbone learning rate in the training script.
    """

    def __init__(self, num_classes=10, dropout_rate=0.0):
        super().__init__()

        print("⬇️ Loading Pre-trained ResNet50 (ImageNet)...")
        # Load the default (best available) ImageNet weights.
        self.net = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Swap the 1000-way ImageNet classifier for a task-specific head.
        feature_dim = self.net.fc.in_features
        self.net.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, num_classes),
        )

    def forward(self, x):
        """Return (batch, num_classes) logits."""
        return self.net(x)