import os

import torch
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy
from torchvision import transforms

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")


class ResBlock(nn.Module):
    """Basic residual block: two conv-BN stages with a skip connection.

    Note: there is deliberately no activation between the two stages; a single
    ReLU is applied after the residual addition.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, downsample=None):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
        )
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=False)
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.block1(x)
        out = self.block2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out


class LightningDavidNet(LightningModule):
    """DavidNet-style ResNet for 10-class 32x32 images (e.g. CIFAR-10)."""

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.data_dir = data_dir
        self.hidden_size = hidden_size

        # Hardcode some dataset-specific attributes
        self.num_classes = 10

        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
        )
        self.l1X = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
        )
        self.r1 = ResBlock(128, 128, kernel_size=3, stride=1, padding=1, downsample=None)
        self.l2X = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
        )
        self.l3X = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
        )
        self.r2 = ResBlock(512, 512, kernel_size=3, stride=1, padding=1, downsample=None)
        self.maxPool = nn.MaxPool2d(kernel_size=4)
        self.fc1 = nn.Linear(512, self.num_classes)
        self.accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        x = self.prep(x)
        x = self.l1X(x)
        residual = x
        x = self.r1(x)
        x = residual + x  # extra skip around the residual block, as in DavidNet
        x = self.l2X(x)
        x = self.l3X(x)
        residual = x
        x = self.r2(x)
        x = residual + x
        x = self.maxPool(x)
        x = x.view(-1, 512)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        # forward() already applies log_softmax, so use NLL loss here;
        # F.cross_entropy would apply log_softmax a second time.
        loss = F.nll_loss(y_pred, y)
        acc = self.accuracy(y_pred, y)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_pred = self(x)
        loss = F.nll_loss(y_pred, y)
        acc = self.accuracy(y_pred, y)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        # OneCycleLR overrides the optimizer's initial lr (max_lr / div_factor),
        # so the lr passed to Adam only matters before the first scheduler step.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=5.38e-2,
            pct_start=5 / self.trainer.max_epochs,
            # The original passed steps_per_epoch=len(train_loader), but no
            # train_loader is in scope here; Lightning exposes the total count.
            total_steps=self.trainer.estimated_stepping_batches,
            div_factor=100,
            three_phase=False,
        )
        # OneCycleLR must advance once per optimizer step, not once per epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")
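
# --- Usage sketch (illustrative, not from the original code) ---
# A minimal training run assuming CIFAR-10 via torchvision; the normalization
# stats are the standard CIFAR-10 values, and the batch size, split sizes, and
# epoch count are assumptions chosen for illustration. Guarded so importing
# this module stays side-effect free.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, random_split
    from torchvision.datasets import CIFAR10

    cifar10_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_full = CIFAR10(PATH_DATASETS, train=True, download=True, transform=cifar10_transform)
    test_set = CIFAR10(PATH_DATASETS, train=False, download=True, transform=cifar10_transform)
    train_set, val_set = random_split(train_full, [45000, 5000])

    train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=128, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=128, num_workers=2)

    # max_epochs must be >= 5 so that pct_start = 5 / max_epochs stays <= 1.
    model = LightningDavidNet()
    trainer = Trainer(max_epochs=24, accelerator="auto", devices=1)
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)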