Megatron17 committed on
Commit eab3f1d · 1 Parent(s): a5e6825
Files changed (2)
  1. model.ckpt +3 -0
  2. model.py +123 -0
model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb839d75f52ec1d69e10d7e4c5cb7b703164833c04b52deabf58428e02bc8f33
+ size 78974911
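model.ckpt is stored via Git LFS, so a plain clone contains only the three-line pointer stub above; running git lfs pull fetches the actual ~79 MB weights. A minimal loading sketch, assuming the checkpoint was written by a Lightning Trainer for the LightningDavidNet module defined in model.py below:

from model import LightningDavidNet

# Assumption: model.ckpt is a Lightning-format checkpoint and the real
# file (not the LFS pointer stub) is present locally after `git lfs pull`.
model = LightningDavidNet.load_from_checkpoint("model.ckpt")
model.eval()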
model.py ADDED
@@ -0,0 +1,123 @@
+ import os
+
+ import torch
+ from pytorch_lightning import LightningModule
+ from torch import nn
+ from torch.nn import functional as F
+ from torchmetrics import Accuracy
+
+ PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
+
+
+ class ResBlock(nn.Module):
+     """Two conv-batchnorm layers with an additive skip connection."""
+
+     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, downsample=None):
+         super().__init__()
+         # Note: no ReLU between the two conv layers; the only
+         # nonlinearity is applied after the residual addition.
+         self.block1 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
+             nn.BatchNorm2d(out_channels),
+         )
+         self.block2 = nn.Sequential(
+             nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
+             nn.BatchNorm2d(out_channels),
+         )
+         self.downsample = downsample
+         self.relu = nn.ReLU(inplace=False)
+         self.out_channels = out_channels
+
+     def forward(self, x):
+         residual = x
+         out = self.block1(x)
+         out = self.block2(out)
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         out += residual
+         out = self.relu(out)
+         return out
+
+
+ class LightningDavidNet(LightningModule):
+     """DavidNet-style ResNet for 10-class 32x32 RGB images (CIFAR-10)."""
+
+     def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4):
+         super().__init__()
+         self.learning_rate = learning_rate
+         self.data_dir = data_dir
+         self.hidden_size = hidden_size
+
+         # Hardcode some dataset-specific attributes
+         self.num_classes = 10
+         self.prep = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=False),
+         )
+         self.l1X = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=False),
+         )
+         self.r1 = ResBlock(128, 128, kernel_size=3, stride=1, padding=1)
+         self.l2X = nn.Sequential(
+             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(inplace=False),
+         )
+         self.l3X = nn.Sequential(
+             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(inplace=False),
+         )
+         self.r2 = ResBlock(512, 512, kernel_size=3, stride=1, padding=1)
+         self.maxPool = nn.MaxPool2d(kernel_size=4)
+         self.fc1 = nn.Linear(512, 10)
+
+         self.accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
+
+     def forward(self, x):
+         x = self.prep(x)      # 3x32x32 -> 64x32x32
+         x = self.l1X(x)       # -> 128x16x16
+         residual = x
+         x = self.r1(x)        # ResBlock adds its own skip internally
+         x = residual + x      # plus an outer skip around the block
+         x = self.l2X(x)       # -> 256x8x8
+         x = self.l3X(x)       # -> 512x4x4
+         residual = x
+         x = self.r2(x)
+         x = residual + x
+         x = self.maxPool(x)   # -> 512x1x1
+         x = x.view(x.size(0), 512)
+         x = self.fc1(x)
+         # forward returns log-probabilities, so the losses below use nll_loss
+         return F.log_softmax(x, dim=1)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         loss = F.nll_loss(self(x), y)
+         self.log("train_loss", loss)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-4)
+         scheduler_dict = {
+             "scheduler": torch.optim.lr_scheduler.OneCycleLR(
+                 optimizer,
+                 max_lr=0.1,
+                 total_steps=self.trainer.estimated_stepping_batches,
+             ),
+             "interval": "step",
+         }
+         return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = F.nll_loss(logits, y)
+         preds = torch.argmax(logits, dim=1)
+         self.accuracy(preds, y)
+         self.log("val_loss", loss, prog_bar=True)
+         self.log("val_acc", self.accuracy, prog_bar=True)
+
+     def test_step(self, batch, batch_idx):
+         return self.validation_step(batch, batch_idx)
+
+     def predict_step(self, batch, batch_idx, dataloader_idx=0):
+         x, y = batch
+         output = self(x)
+         return x, y, output.argmax(dim=1), output
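For context, a minimal training sketch. The data pipeline here is an assumption, not part of this commit: CIFAR-10 from torchvision (the 3-channel 32x32 input the forward pass expects), standard normalization constants, and an illustrative batch size and epoch count.

import os

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

from model import LightningDavidNet

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

# Illustrative transform; this commit does not define one.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

train_set = CIFAR10(PATH_DATASETS, train=True, download=True, transform=transform)
val_set = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=512, num_workers=2)

model = LightningDavidNet()
trainer = Trainer(max_epochs=24, accelerator="auto", devices=1)
trainer.fit(model, train_loader, val_loader)

Because configure_optimizers sizes OneCycleLR with self.trainer.estimated_stepping_batches, the schedule adapts to whatever dataloader length and max_epochs the Trainer sees, and no module-level train_loader variable is required.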