jaiyeshchahar commited on
Commit
1c00e32
·
1 Parent(s): d07268b

Create custom_resnet.py

Browse files
Files changed (1) hide show
  1. custom_resnet.py +151 -0
custom_resnet.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import torch.nn as nn
3
+ from pytorch_lightning import LightningModule
4
+ from torch.optim.lr_scheduler import OneCycleLR
5
+ from torchmetrics.functional import accuracy
6
+ import torch
7
+
8
# Default dropout probability for the network.
# NOTE(review): not referenced by any module in this file — confirm whether a
# caller imports it before removing.
dropout_value = 0.1
9
+
10
class X(nn.Module):
    """Downsampling convolution block: Conv -> MaxPool -> BatchNorm -> ReLU.

    Maps ``in_channels`` to ``out_channels`` with a 3x3 same-padding
    convolution, then halves the spatial resolution with a 2x2 max-pool.
    """

    def __init__(self, in_channels, out_channels):
        super(X, self).__init__()
        # Build the layer list explicitly; order matters (pool before BN/ReLU).
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block; (N, C, H, W) -> (N, out_channels, H/2, W/2)."""
        return self.conv1(x)
23
+
24
class ResBlock(nn.Module):
    """Residual block: two Conv-BN-ReLU layers plus an identity skip.

    Spatial size is preserved (3x3 convs with padding 1), and the input is
    added element-wise to the branch output, so callers must pass
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()

        def _conv_bn_relu(c_in, c_out):
            # One 3x3 same-padding conv followed by BN and ReLU.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        # BUG FIX: the original forward() applied the SAME nn.Sequential twice,
        # so the "two" conv layers shared one set of weights. A standard
        # residual block uses two independent conv layers; ``self.conv`` keeps
        # its original name so existing attribute access still works.
        self.conv = _conv_bn_relu(in_channels, out_channels)
        self.conv2 = _conv_bn_relu(out_channels, out_channels)

    def forward(self, x):
        """Return ``f(f(x)) + x`` where each f is an independent Conv-BN-ReLU."""
        out = self.conv(x)
        out = self.conv2(out)
        return out + x
39
+
40
class Net(nn.Module):
    """Custom ResNet-9-style classifier for 3x32x32 inputs (e.g. CIFAR-10).

    Layout: prep layer (32x32) -> X1+ResBlock (16x16) -> X2 (8x8)
    -> X3+ResBlock (4x4) -> 4x4 max-pool (1x1) -> linear head.
    Returns raw logits of shape (N, 10); no softmax is applied.
    """

    def __init__(self):
        super(Net, self).__init__()

        # Prep layer: 3 -> 64 channels, spatial size kept at 32x32.
        self.preplayer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # Layer 1: downsample to 16x16, then a same-shape residual block.
        self.X1 = X(in_channels=64, out_channels=128)
        self.R1 = ResBlock(in_channels=128, out_channels=128)

        # Layer 2: downsample to 8x8.
        self.X2 = X(in_channels=128, out_channels=256)

        # Layer 3: downsample to 4x4, then a same-shape residual block.
        self.X3 = X(in_channels=256, out_channels=512)
        self.R3 = ResBlock(in_channels=512, out_channels=512)

        # 4x4 max-pool collapses the remaining 4x4 grid to 1x1.
        self.maxpool = nn.MaxPool2d(kernel_size=4, stride=1)

        # Classifier head: 512 features -> 10 classes.
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch of 3x32x32 images."""
        out = self.preplayer(x)

        # Layer 1: downsample, then add the residual branch's output.
        # (Local renamed from `X` to `x1` — the original shadowed class X.)
        x1 = self.X1(out)
        out = x1 + self.R1(x1)

        # Layer 2
        out = self.X2(out)

        # Layer 3: downsample, then add the residual branch's output.
        x3 = self.X3(out)
        out = x3 + self.R3(x3)

        out = self.maxpool(out)

        # Flatten and classify; caller applies softmax/cross-entropy as needed.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out.view(-1, 10)
97
+
98
class LitCustomResnet(LightningModule):
    """LightningModule wrapper around ``Net`` for 10-class image classification.

    Trains with SGD (momentum 0.9, weight decay 5e-4) under a per-step
    OneCycleLR schedule, and logs loss/accuracy for train/val/test stages.

    Args:
        lr: peak learning rate for the OneCycle schedule.
        batch_size: per-step batch size, used to derive scheduler step count.
    """

    def __init__(self, lr=0.05, batch_size=64):
        super().__init__()
        self.model = Net()
        self.save_hyperparameters()  # exposes self.hparams.lr / .batch_size
        self.BATCH_SIZE = batch_size

    def forward(self, x):
        """Delegate to the wrapped network; returns (N, 10) logits."""
        return self.model(x)

    def training_step(self, batch, batch_id):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # FIX: key renamed from "training loss" so it matches the
        # "{stage}_loss" pattern used by validation/test logging.
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Shared eval logic: compute loss and accuracy, log under ``stage`` prefix."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=10)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return SGD plus a per-step OneCycleLR schedule."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        # NOTE(review): 45000 assumes a 45k-sample training split (e.g.
        # CIFAR-10 with a 5k validation hold-out) — confirm against the
        # datamodule actually used.
        steps_per_epoch = 45000 // self.BATCH_SIZE
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                # BUG FIX: max_lr was hard-coded to 0.01, silently ignoring the
                # ``lr`` constructor argument; use the configured value.
                self.hparams.lr,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}