File size: 4,886 Bytes
eab3f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4f516c
eab3f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4f516c
 
 
 
 
 
 
eab3f1d
e4f516c
 
 
 
 
 
 
 
 
 
eab3f1d
 
 
e4f516c
eab3f1d
e4f516c
 
eab3f1d
e4f516c
 
 
 
 
eab3f1d
 
e4f516c
eab3f1d
 
e4f516c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os

import torch
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy
from torchvision import transforms

# Dataset root directory; overridable through the PATH_DATASETS environment
# variable (defaults to the current working directory).
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
class ResBlock(nn.Module):
    """Residual block: two conv->batchnorm stages plus an additive skip.

    The input runs through both stages; the (optionally downsampled) input
    is then added back and the sum is passed through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, downsample=None):
        super().__init__()
        conv_kwargs = {"kernel_size": kernel_size, "stride": stride, "padding": padding}
        # Stage 1: conv + BN. Deliberately no activation between the stages.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
        )
        # Stage 2: conv + BN.
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
        )
        # Optional projection for the identity path (e.g. when shapes differ).
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=False)
        self.out_channels = out_channels

    def forward(self, x):
        out = self.block2(self.block1(x))
        # Truthiness (not `is not None`) kept on purpose to match the
        # original semantics for arbitrary downsample objects.
        shortcut = self.downsample(x) if self.downsample else x
        return self.relu(out + shortcut)


class LightningDavidNet(LightningModule):
    """DavidNet-style CNN classifier as a LightningModule.

    Architecture: a prep conv stem, three conv+maxpool stages with two
    residual blocks interleaved, global max-pooling and a single linear
    head producing log-probabilities over ``num_classes`` classes.

    The layer sizes assume 3-channel 32x32 inputs (e.g. CIFAR-10): three
    2x pools plus the final 4x pool reduce 32x32 down to 1x1 so that the
    flatten to 512 features works. -- TODO confirm input size with caller.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4,
                 kernel_size=3, stride=1, padding=1, downsample=None):
        super().__init__()
        self.learning_rate = learning_rate
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        # NOTE(review): kernel_size/stride/padding/downsample are accepted
        # for interface compatibility but unused -- the layer hyperparameters
        # below are hard-coded. Likewise learning_rate is stored but the
        # optimizer uses a hard-coded lr (see configure_optimizers).

        # Hardcoded dataset-specific attribute (10-class problem).
        self.num_classes = 10
        # Stem: 3 -> 64 channels, spatial size preserved.
        self.prep = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(64),
                                  nn.ReLU(inplace=False))
        # Stage 1: 64 -> 128 channels, 2x spatial downsample.
        self.l1X = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                                 nn.MaxPool2d(kernel_size=2),
                                 nn.BatchNorm2d(128),
                                 nn.ReLU(inplace=False))
        self.r1 = ResBlock(128, 128, kernel_size=3, stride=1, padding=1, downsample=None)
        # Stage 2: 128 -> 256 channels, 2x spatial downsample.
        self.l2X = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                                 nn.MaxPool2d(kernel_size=2),
                                 nn.BatchNorm2d(256),
                                 nn.ReLU(inplace=False))
        # Stage 3: 256 -> 512 channels, 2x spatial downsample.
        self.l3X = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                                 nn.MaxPool2d(kernel_size=2),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU(inplace=False))
        self.r2 = ResBlock(512, 512, kernel_size=3, stride=1, padding=1, downsample=None)
        # Global pooling + linear head.
        self.maxPool = nn.MaxPool2d(kernel_size=4)
        self.fc1 = nn.Linear(512, 10)

        self.accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        x = self.prep(x)
        x = self.l1X(x)
        # NOTE(review): ResBlock already adds its own skip connection
        # internally, so adding `residual` here applies the identity twice
        # (out = 2*x + convpath(x)). Preserved as-is since it may be the
        # intended DavidNet variant -- confirm before changing.
        residual = x
        x = self.r1(x)
        x = residual + x
        x = self.l2X(x)
        x = self.l3X(x)
        residual = x
        x = self.r2(x)
        x = residual + x
        x = self.maxPool(x)
        x = x.view(-1, 512)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        """One optimisation step: log epoch-level train loss and accuracy."""
        x, y = batch
        y_pred = self(x)
        # forward() already emits log-probabilities; cross_entropy re-applies
        # log_softmax, which is idempotent on log-probs, so this equals
        # nll_loss on the same inputs. Kept for exact numerical parity.
        loss = F.cross_entropy(y_pred, y)
        acc = self.accuracy(y_pred, y)

        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def evaluate(self, batch, stage=None):
        """Shared eval logic: compute loss/accuracy and log under `stage`."""
        x, y = batch
        y_test_pred = self(x)
        loss = F.cross_entropy(y_test_pred, y)
        acc = self.accuracy(y_test_pred, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Adam with a per-batch OneCycleLR schedule.

        Bug fix: the schedule previously computed its length from an
        undefined global ``train_loader`` (a NameError at runtime). It is
        now sized with the trainer's own estimate of the total number of
        optimisation steps, and returned with ``interval='step'`` so
        Lightning advances it every batch rather than every epoch.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=0.03, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=5.38e-02,  # NOTE(review): tuned constant; self.learning_rate is unused
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=5 / self.trainer.max_epochs,  # ~5 warm-up epochs
            div_factor=100,
            verbose=False,
            three_phase=False,
        )
        # OneCycleLR is sized in batches, so it must step once per batch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")