Omarrr7 commited on
Commit
276bda5
·
verified ·
1 Parent(s): e100393

Upload CNN_model.py

Browse files
Files changed (1) hide show
  1. CNN_model.py +35 -0
CNN_model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class BasicCNN(nn.Module):
6
+ def __init__(self, num_classes=39):
7
+ super().__init__()
8
+
9
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
10
+ self.bn1 = nn.BatchNorm2d(32)
11
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
12
+ self.bn2 = nn.BatchNorm2d(64)
13
+ self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
14
+ self.bn3 = nn.BatchNorm2d(128)
15
+ self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
16
+ self.bn4 = nn.BatchNorm2d(256)
17
+
18
+ self.gap = nn.AdaptiveAvgPool2d((1, 1)) #global average pooling
19
+ self.fc = nn.Linear(256, num_classes) #fc classifier
20
+ self.dropout = nn.Dropout(0.3) #regularise
21
+
22
+ def forward(self, x):
23
+ x = F.relu(self.bn1(self.conv1(x)))
24
+ x = F.max_pool2d(x, 2)
25
+ x = F.relu(self.bn2(self.conv2(x)))
26
+ x = F.max_pool2d(x, 2)
27
+ x = F.relu(self.bn3(self.conv3(x)))
28
+ x = F.max_pool2d(x, 2)
29
+ x = F.relu(self.bn4(self.conv4(x)))
30
+ x = F.max_pool2d(x, 2)
31
+ x = self.gap(x)
32
+ x = torch.flatten(x, 1)
33
+ x = self.dropout(x)
34
+ x = self.fc(x)
35
+ return x