MadhurGarg commited on
Commit
f70f71e
·
1 Parent(s): 0c2b813

Upload custom_resnet.py

Browse files
Files changed (1) hide show
  1. custom_resnet.py +85 -0
custom_resnet.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class LitResnet(LightningModule):
2
+ def __init__(self, lr=0.01, drop= 0.05, norm='BN',groupsize=1):
3
+ super().__init__()
4
+
5
+ self.save_hyperparameters()
6
+ self.num_classes =10
7
+ self.lr = lr
8
+
9
+ self.convblock1 = nn.Sequential(
10
+ nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), padding=1, bias=False),
11
+ nn.ReLU(),
12
+ self.user_norm(norm,64,groupsize),
13
+ nn.Dropout(drop),
14
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), padding=1, bias=False),
15
+ nn.MaxPool2d(2,2),
16
+ nn.ReLU(),
17
+ self.user_norm(norm,128,groupsize),
18
+ nn.Dropout(drop))
19
+ self.res1 = nn.Sequential(
20
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=1, bias=False),
21
+ nn.ReLU(),
22
+ self.user_norm(norm,128,groupsize),
23
+ nn.Dropout(drop),
24
+ nn.Conv2d(in_channels = 128, out_channels=128, kernel_size=(3, 3), padding=1, bias=False),
25
+ nn.ReLU(),
26
+ self.user_norm(norm,128,groupsize),
27
+ nn.Dropout(drop)
28
+ )
29
+
30
+ # CONVOLUTION BLOCK 2
31
+ # Layer 2
32
+ self.convblock2 = nn.Sequential(
33
+ nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), padding=1, bias=False),
34
+ nn.ReLU(),
35
+ self.user_norm(norm,256,groupsize),
36
+ nn.Dropout(drop),
37
+ nn.MaxPool2d(2,2),
38
+ nn.ReLU(),
39
+ self.user_norm(norm,256,groupsize),
40
+ nn.Dropout(drop)
41
+
42
+ )
43
+
44
+ # CONVOLUTION BLOCK 3
45
+ self.convblock3 = nn.Sequential(
46
+ nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), padding=1, bias=False),
47
+ nn.MaxPool2d(2,2),
48
+ nn.ReLU(),
49
+ self.user_norm(norm,512,groupsize),
50
+ nn.Dropout(drop)
51
+ )
52
+ self.res2 = nn.Sequential(
53
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), padding=1, bias=False),
54
+ nn.ReLU(),
55
+ self.user_norm(norm,512,groupsize),
56
+ nn.Dropout(drop),
57
+ nn.Conv2d(in_channels = 512, out_channels=512, kernel_size=(3, 3), padding=1, bias=False),
58
+ nn.ReLU(),
59
+ self.user_norm(norm,512,groupsize),
60
+ nn.Dropout(drop))
61
+
62
+ # CONVOLUTION BLOCK 4
63
+ self.convblock4 = nn.Sequential(
64
+ nn.MaxPool2d(kernel_size=4),
65
+ nn.Conv2d(in_channels=512, out_channels=10, kernel_size=(1, 1), padding=0, bias=False))
66
+ self.model = nn.Sequential(self.convblock1 ,self.convblock2, self.convblock3, self.convblock4)
67
+
68
+ def user_norm(self, norm, channels,groupsize=1):
69
+ if norm == 'BN':
70
+ return nn.BatchNorm2d(channels)
71
+ elif norm == 'LN':
72
+ return nn.GroupNorm(1,channels) #(equivalent with LayerNorm)
73
+ elif norm == 'GN':
74
+ return nn.GroupNorm(groupsize,channels) #groups=2
75
+
76
+ def forward(self, x):
77
+
78
+ x = self.convblock1(x)
79
+ x = x + self.res1 (x)
80
+ x = self.convblock2(x)
81
+ x = self.convblock3(x)
82
+ x = x + self.res2 (x)
83
+ x = self.convblock4(x)
84
+ x = x.view(-1, 10)
85
+ return F.log_softmax(x, dim=-1)