import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L


class BasicBlock(nn.Module):
    expansion = 1  # ResNet-18/34 use expansion=1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 downsample projection when the residual shape does not match
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet18_CIFAR10(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Replace the stem with a CIFAR-10-friendly 3x3 conv and drop the maxpool
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # ResNet stages
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)   # 32x32 -> 16x16
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)  # 16x16 -> 8x8
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)  # 8x8 -> 4x4

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(512 * BasicBlock.expansion, num_classes)
        )

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        layers = []
        layers.append(BasicBlock(in_c, out_c, stride))
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c, stride=1))  # later blocks keep stride=1
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # note the ReLU after the stem conv
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)     # [B, 512, 1, 1]
        out = torch.flatten(out, 1)  # [B, 512]
        out = self.fc(out)           # [B, num_classes]
        return out
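
# Optional sanity check (a hypothetical helper, not part of the original code):
# runs the backbone stage by stage to confirm the 32x32 -> 4x4 spatial
# progression annotated above, and prints the parameter count (~11M for this
# CIFAR-style ResNet-18).
def _backbone_sanity_check():
    net = ResNet18_CIFAR10(num_classes=10)
    net.eval()
    feats = torch.randn(2, 3, 32, 32)
    with torch.no_grad():
        feats = F.relu(net.bn1(net.conv1(feats)))
        for name in ("layer1", "layer2", "layer3", "layer4"):
            feats = getattr(net, name)(feats)
            print(f"{name}: {tuple(feats.shape)}")
    print(f"parameters: {sum(p.numel() for p in net.parameters()):,}")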

class CIFARCNN(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.example_input_array = torch.Tensor(64, 3, 32, 32)
        self.net = ResNet18_CIFAR10(num_classes=10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):  # batch_idx is unused here
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log("train_loss", loss, on_step=True, prog_bar=True)  # log at every step
        self.log("train_acc", acc, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        # Validation-specific logging: prog_bar=True shows val_loss on the
        # Lightning progress bar; sync_dist=True synchronizes the metric
        # across devices during distributed training.
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc", acc, prog_bar=True, sync_dist=True)
        return {"val_loss": loss, "val_acc": acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return {"test_loss": loss, "test_acc": acc}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(), lr=self.hparams.lr,
            momentum=0.9, weight_decay=5e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


if __name__ == "__main__":
    # Quick forward-pass smoke test
    model = CIFARCNN()
    x = torch.randn(4, 3, 32, 32).to(model.device)
    logits = model(x)
    print(logits.shape)  # torch.Size([4, 10])
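
# Optional end-to-end training sketch (an illustrative example, not part of the
# original script): assumes torchvision is installed; the data path "./data",
# batch sizes, lr=0.1, and the normalization statistics are commonly used
# CIFAR-10 choices, not values taken from the code above. Call train_sketch()
# manually to run it.
def train_sketch(max_epochs=20):
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    normalize = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)
    )
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # standard CIFAR-10 augmentation
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
    val_set = datasets.CIFAR10("./data", train=False, download=True, transform=test_tf)

    model = CIFARCNN(lr=0.1)  # a typical starting lr for SGD + cosine annealing
    trainer = L.Trainer(max_epochs=max_epochs, accelerator="auto")
    trainer.fit(
        model,
        DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4),
        DataLoader(val_set, batch_size=256, num_workers=4),
    )
    trainer.test(model, DataLoader(val_set, batch_size=256, num_workers=4))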