Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def ConvBlock(in_channels, out_channels, pool=False, *, pool_size=4):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack, optionally followed by max-pooling.

    The 3x3 convolution uses ``padding=1`` with the default stride of 1, so it
    keeps the spatial size unchanged; only the optional pool downsamples.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        pool: If True, append ``nn.MaxPool2d(pool_size)`` after the ReLU.
        pool_size: Kernel size (and, by MaxPool2d's default, stride) of the
            optional pool. Defaults to 4, the value this block always used.
            It is keyword-only so the existing positional signature is unchanged.

    Returns:
        An ``nn.Sequential`` laid out as (0) conv, (1) batch-norm, (2) ReLU and,
        when ``pool`` is set, (3) max-pool. This is the same index layout as
        before, so previously saved state_dicts still load.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        # In-place is safe here: it only overwrites the batch-norm output.
        nn.ReLU(inplace=True),
    ]
    if pool:
        layers.append(nn.MaxPool2d(pool_size))
    return nn.Sequential(*layers)
class ImageClassificationBase(nn.Module):
    """Base module that adds training/validation helpers to a classifier.

    Subclasses provide ``forward``. Each helper runs ``self(images)`` and scores
    the result against integer labels with ``F.cross_entropy``.
    """

    def training_step(self, batch):
        """Return the (differentiable) cross-entropy loss for an ``(images, labels)`` batch."""
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    def validation_step(self, batch):
        """Score one ``(images, labels)`` batch.

        Returns:
            A dict with ``"val_loss"`` (detached cross-entropy) and
            ``"val_accuracy"`` (fraction of rows whose argmax equals the label).
        """
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = F.cross_entropy(logits, targets).detach()
        correct = logits.argmax(dim=1) == targets
        return {"val_loss": batch_loss, "val_accuracy": correct.float().mean()}

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts produced by ``validation_step``.

        Each metric is the unweighted mean of the per-batch values, so a short
        final batch counts as much as a full one. ``outputs`` is iterated once
        per metric, and an empty sequence makes ``torch.stack`` raise.
        """
        def epoch_mean(key):
            return torch.stack([step[key] for step in outputs]).mean()

        mean_loss = epoch_mean("val_loss")
        mean_accuracy = epoch_mean("val_accuracy")
        return {"val_loss": mean_loss, "val_accuracy": mean_accuracy}

    def epoch_end(self, epoch, result):
        """Print a one-line summary of the train/validation metrics for ``epoch``."""
        summary = (
            f"Epoch [{epoch}], "
            f"train_loss: {result['train_loss']:.4f}, "
            f"val_loss: {result['val_loss']:.4f}, "
            f"val_acc: {result['val_accuracy']:.4f}"
        )
        print(summary)
class ResNet9(ImageClassificationBase):
    """ResNet-9 style image classifier assembled from ``ConvBlock`` units.

    Pipeline: conv1 -> conv2 (pooled) -> res1 (+ identity skip) -> conv3 (pooled)
    -> conv4 (pooled) -> res2 (+ identity skip) -> classifier (pool, flatten, linear).

    Each ``pool=True`` ConvBlock ends in ``nn.MaxPool2d(4)``, and the classifier
    starts with another ``nn.MaxPool2d(4)``. Height and width are therefore
    divided by 4 four times (256 overall), and ``nn.Linear(512, ...)`` needs the
    final 512-channel map to be 1x1, e.g. for 256x256 inputs.
    NOTE(review): other resolutions will likely break in the pool or the linear
    layer; confirm the intended input size against the data pipeline.

    Args:
        in_channels: Channels of the input images.
        num_classes: Number of output scores (logits) per image.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        def residual_pair(channels):
            # Two shape-preserving ConvBlocks; forward() adds an identity skip around them.
            return nn.Sequential(ConvBlock(channels, channels), ConvBlock(channels, channels))

        # Keep this exact registration order: it fixes parameter-initialisation
        # order (RNG draws) and the state_dict key layout.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = residual_pair(128)
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = residual_pair(512)
        self.classifier = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )

    def forward(self, xb):
        """Map a batch of images to per-class scores (logits)."""
        feats = self.conv2(self.conv1(xb))
        feats = self.res1(feats) + feats
        feats = self.conv4(self.conv3(feats))
        feats = self.res2(feats) + feats
        return self.classifier(feats)