# ResNet-9 style CIFAR-10 classifier implemented with PyTorch Lightning.
| import torch | |
| import torch.nn as nn | |
| import lightning as L | |
| from torchmetrics import Accuracy | |
| from typing import Any | |
| from utils.common import one_cycle_lr | |
class ResidualBlock(L.LightningModule):
    """Identity residual block: x + F(x), where F is two 3x3 conv/BN/ReLU stages.

    The channel count is preserved end to end, so the skip connection adds
    the input directly to the transformed output without any projection.
    """

    def __init__(self, channels):
        super().__init__()
        # Both convolutions share the same shape-preserving configuration.
        conv_cfg = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.residual_block = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_cfg),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, **conv_cfg),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        # Skip connection: output = input + transformed input.
        return x + self.residual_block(x)
class ResNet(L.LightningModule):
    """ResNet-9 style classifier for 32x32 RGB inputs (e.g. CIFAR-10).

    Architecture: prep conv -> layer1 (conv+pool + residual) -> layer2
    (conv+pool) -> layer3 (conv+pool + residual) -> 4x4 max pool ->
    bias-free linear head over 10 classes. Trained with Adam plus a
    one-cycle LR schedule stepped every optimizer step.

    Args:
        batch_size: data-loader batch size (consumed by an external loader).
        shuffle: whether the training loader should shuffle.
        num_workers: data-loader worker count.
        learning_rate: base Adam learning rate.
        scheduler_steps: steps per epoch handed to `one_cycle_lr`.
        maxlr: one-cycle peak LR; falls back to `learning_rate` when None.
        epochs: total epochs handed to `one_cycle_lr`.
    """

    def __init__(
        self,
        batch_size=512,
        shuffle=True,
        num_workers=4,
        learning_rate=0.003,
        scheduler_steps=None,
        maxlr=None,
        epochs=None,
    ):
        super().__init__()
        # Data-loading configuration (read by external data plumbing).
        self.data_dir = "./data"
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        # Optimisation hyper-parameters; peak LR defaults to the base LR.
        self.learning_rate = learning_rate
        self.scheduler_steps = scheduler_steps
        self.maxlr = maxlr if maxlr is not None else learning_rate
        self.epochs = epochs

        def conv_block(in_ch, out_ch, pool=False):
            # 3x3 conv -> (optional 2x2 max-pool) -> BN -> ReLU, preserving
            # the original child ordering so nn.Sequential indices (and the
            # state-dict keys derived from them) are unchanged.
            layers = [
                nn.Conv2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            ]
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=2))
            layers += [nn.BatchNorm2d(out_ch), nn.ReLU()]
            return layers

        self.prep = nn.Sequential(*conv_block(3, 64))
        self.layer1 = nn.Sequential(
            *conv_block(64, 128, pool=True), ResidualBlock(channels=128)
        )
        self.layer2 = nn.Sequential(*conv_block(128, 256, pool=True))
        self.layer3 = nn.Sequential(
            *conv_block(256, 512, pool=True), ResidualBlock(channels=512)
        )
        self.pool = nn.MaxPool2d(kernel_size=4)
        self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
        # Kept for callers that want probabilities; forward returns raw logits
        # because CrossEntropyLoss applies log-softmax internally.
        self.softmax = nn.Softmax(dim=-1)
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        # Loss modules constructed once here instead of on every step.
        # Val/test keep the original "sum" reduction; training uses the mean.
        self.train_criterion = nn.CrossEntropyLoss()
        self.eval_criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        # flatten(1) (rather than view(-1, 512)) raises instead of silently
        # re-batching when the spatial dimensions are not the expected 1x1.
        x = torch.flatten(x, 1)
        return self.fc(x)

    def configure_optimizers(self) -> Any:
        """Adam (weight decay 1e-4) with a per-step one-cycle LR schedule."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-4
        )
        scheduler = one_cycle_lr(
            optimizer=optimizer,
            maxlr=self.maxlr,
            steps=self.scheduler_steps,
            epochs=self.epochs,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def _shared_step(self, batch, stage):
        # Common forward/loss/accuracy logic; `stage` prefixes the logged keys
        # and selects the loss reduction (mean for train, sum for val/test).
        X, y = batch
        logits = self(X)
        criterion = self.train_criterion if stage == "train" else self.eval_criterion
        loss = criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log_dict({f"{stage}_loss": loss, f"{stage}_acc": acc}, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        # Fixed: now returns the loss, consistent with train/val steps.
        return self._shared_step(batch, "test")