# Megatron17's picture
# Update model.py
# e4f516c
import os
import torch
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy
from torchvision import transforms
# Dataset root: taken from the PATH_DATASETS environment variable when it is
# set, otherwise the current working directory.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
class ResBlock(nn.Module):
    """Two conv + batch-norm stages wrapped by an additive skip connection.

    ``forward`` computes ``relu(block2(block1(x)) + shortcut)``, where
    ``shortcut`` is ``downsample(x)`` when a ``downsample`` module was given
    and ``x`` itself otherwise.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        stride: Stride of the first convolution. The second convolution
            always uses stride 1 so the block reduces the spatial size only
            once and its output stays addable to a shortcut that was strided
            once.
        padding: Zero padding used by both convolutions.
        downsample: Optional module applied to the block input to build the
            shortcut. Needed whenever ``stride != 1`` or
            ``in_channels != out_channels``, because the addition requires
            matching shapes.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, downsample=None):
        super().__init__()
        # NOTE: there is deliberately no activation between the two stages;
        # the only ReLU is applied after the skip addition in ``forward``.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
        )
        self.block2 = nn.Sequential(
            # Stride is fixed at 1 here: reusing ``stride`` would shrink the
            # feature map a second time and break the residual addition.
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
        )
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=False)
        self.out_channels = out_channels

    def forward(self, x):
        """Return ``relu(block2(block1(x)) + shortcut)`` for input ``x``."""
        out = self.block2(self.block1(x))
        # Explicit None check: container modules define ``__len__``, so plain
        # truthiness is not a reliable "was a module supplied" test.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = out + shortcut
        return self.relu(out)
class LightningDavidNet(LightningModule):
    """ResNet-style image classifier packaged as a LightningModule.

    Layer order in ``forward``: ``prep`` conv -> ``l1X`` (conv + 2x2 pool)
    -> ``r1`` residual block -> ``l2X`` -> ``l3X`` -> ``r2`` residual block
    -> 4x4 max-pool -> linear head -> log-softmax.

    The network expects 3-channel inputs. The three 2x2 pools followed by the
    4x4 pool reduce a 32x32 image to a 1x1 map, which is what the
    ``view(-1, 512)`` flattening in ``forward`` relies on
    (NOTE(review): presumably CIFAR-10-sized inputs -- confirm).

    Args:
        data_dir: Dataset location; only stored on the instance.
        hidden_size: Only stored on the instance; no layer reads it.
        learning_rate: Only stored on the instance; ``configure_optimizers``
            uses its own fixed values.
        kernel_size, stride, padding, downsample: Accepted for interface
            compatibility; the layer hyper-parameters are fixed below.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4, kernel_size=3, stride=1, padding=1, downsample=None):
        super().__init__()
        self.learning_rate = learning_rate
        self.data_dir = data_dir
        self.hidden_size = hidden_size

        # Dataset-specific constant: number of output classes.
        self.num_classes = 10

        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
        )
        self.l1X = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False),
        )
        self.r1 = ResBlock(128, 128, kernel_size=3, stride=1, padding=1, downsample=None)
        self.l2X = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
        )
        self.l3X = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
        )
        self.r2 = ResBlock(512, 512, kernel_size=3, stride=1, padding=1, downsample=None)
        self.maxPool = nn.MaxPool2d(kernel_size=4)
        self.fc1 = nn.Linear(512, self.num_classes)
        self.accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        x = self.prep(x)
        x = self.l1X(x)
        # NOTE(review): ResBlock already adds its input internally, so this
        # outer addition applies the skip a second time. Kept unchanged
        # because removing it would change the model's outputs -- confirm
        # intent before touching it.
        x = x + self.r1(x)
        x = self.l2X(x)
        x = self.l3X(x)
        x = x + self.r2(x)  # same double-skip note as above
        x = self.maxPool(x)
        x = x.view(-1, 512)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

    def _loss_and_accuracy(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``."""
        x, y = batch
        log_probs = self(x)
        # ``forward`` already returns log-probabilities, so the negative
        # log-likelihood is the matching loss (no second log-softmax needed).
        loss = F.nll_loss(log_probs, y)
        acc = self.accuracy(log_probs, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log epoch-level loss and accuracy."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def evaluate(self, batch, stage=None):
        """Score ``batch`` and, when ``stage`` is given, log ``{stage}_loss`` and ``{stage}_acc``."""
        loss, acc = self._loss_and_accuracy(batch)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Build Adam with a one-cycle learning-rate schedule stepped per batch.

        The schedule length comes from the attached trainer, so this method
        no longer depends on a module-level ``train_loader`` variable.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=0.03, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=5.38e-02,
            # Total number of optimizer steps the trainer expects to run.
            total_steps=self.trainer.estimated_stepping_batches,
            # Warm up for 5 epochs; clamp so runs shorter than 5 epochs do
            # not request a fraction above 1.
            pct_start=min(1.0, 5 / self.trainer.max_epochs),
            div_factor=100,
            three_phase=False,
        )
        # OneCycleLR is defined in optimizer steps, so it has to be advanced
        # every batch; Lightning's default interval is once per epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log under the ``val`` prefix."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log under the ``test`` prefix."""
        self.evaluate(batch, "test")