Spaces:
Runtime error
Runtime error
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| from pytorch_lightning import LightningModule | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from torchmetrics.functional import accuracy | |
| import torch | |
| dropout_value = 0.1 | |
| class X(nn.Module): | |
| def __init__(self, in_channels, out_channels): | |
| super(X, self).__init__() | |
| self.conv1 = nn.Sequential( | |
| nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1, padding=1,bias = False), | |
| nn.MaxPool2d(kernel_size=2,stride=2), | |
| nn.BatchNorm2d(out_channels), | |
| nn.ReLU() | |
| ) | |
| def forward(self, x): | |
| return self.conv1(x) | |
| class ResBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels): | |
| super(ResBlock, self).__init__() | |
| self.conv = nn.Sequential( | |
| nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1, padding=1,bias = False), | |
| nn.BatchNorm2d(out_channels), | |
| nn.ReLU() | |
| ) | |
| def forward(self, x): | |
| out = self.conv(x) | |
| out = self.conv(out) | |
| out = out + x | |
| return out | |
| class Net(nn.Module): | |
| def __init__(self): | |
| super(Net, self).__init__() | |
| # Prep Layer | |
| self.preplayer = nn.Sequential( | |
| nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,stride=1, padding=1,bias=False), | |
| nn.BatchNorm2d(64), | |
| nn.ReLU() | |
| ) ## 32x32 | |
| # Layer 1 | |
| self.X1 = X(in_channels=64,out_channels=128) # 16x16 | |
| self.R1 = ResBlock(in_channels=128,out_channels=128) # 32x32 | |
| # Layer 2 | |
| self.X2 = X(in_channels=128,out_channels=256) | |
| # Layer 3 | |
| self.X3 = X(in_channels=256,out_channels=512) | |
| self.R3 = ResBlock(in_channels=512,out_channels=512) | |
| # Max Pool | |
| self.maxpool = nn.MaxPool2d(kernel_size=4, stride=1) | |
| # FC | |
| self.fc = nn.Linear(512,10) | |
| def forward(self, x): | |
| batch_size = x.shape[0] | |
| out = self.preplayer(x) | |
| # Layer 1 | |
| X = self.X1(out) ## 16x16 | |
| R1 = self.R1(X) | |
| out = X + R1 | |
| # Layer 2 | |
| out = self.X2(out) | |
| # Layer 3 | |
| X = self.X3(out) | |
| R2 = self.R3(X) | |
| out = X + R2 | |
| out = self.maxpool(out) | |
| # FC | |
| out = out.view(out.size(0),-1) | |
| out = self.fc(out) | |
| # return F.log_softmax(out, dim=-1) | |
| return out.view(-1, 10) | |
| class LitCustomResnet(LightningModule): | |
| def __init__(self, lr = 0.05,batch_size=64): | |
| super().__init__() | |
| self.model = Net() | |
| self.save_hyperparameters() | |
| self.BATCH_SIZE=batch_size | |
| def forward(self,x): | |
| return self.model(x) | |
| def training_step(self,batch,batch_id): | |
| x,y = batch | |
| logits = self(x) | |
| loss = F.cross_entropy(logits,y) | |
| self.log("training loss", loss) | |
| return loss | |
| def evaluate(self, batch, stage=None): | |
| x,y = batch | |
| logits = self(x) | |
| loss = F.cross_entropy(logits, y) | |
| preds = torch.argmax(logits, dim=1) | |
| # print(preds.shape,y.shape) | |
| acc= accuracy(preds,y, task = "multiclass", num_classes=10) | |
| if stage: | |
| self.log(f"{stage}_loss", loss, prog_bar=True) | |
| self.log(f"{stage}_acc", acc, prog_bar=True) | |
| def validation_step(self, batch, batch_idx): | |
| self.evaluate(batch, "val") | |
| def test_step(self, batch, batch_idx): | |
| self.evaluate(batch, "test") | |
| def configure_optimizers(self): | |
| optimizer = torch.optim.SGD( | |
| self.parameters(), | |
| lr=self.hparams.lr, | |
| momentum=0.9, | |
| weight_decay=5e-4, | |
| ) | |
| steps_per_epoch = 45000 // self.BATCH_SIZE | |
| scheduler_dict = { | |
| "scheduler": OneCycleLR( | |
| optimizer, | |
| 0.01, | |
| epochs=self.trainer.max_epochs, | |
| steps_per_epoch=steps_per_epoch, | |
| ), | |
| "interval": "step", | |
| } | |
| return {"optimizer": optimizer, "lr_scheduler": scheduler_dict} | |