File size: 929 Bytes
8e500b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b10a4d
8e500b2
 
 
 
 
 
 
 
 
7b10a4d
8e500b2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torch.nn.functional as F 

class modelOne(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    Pipeline: (conv -> batch norm -> ReLU -> 2x2 max pool) twice, flatten,
    dropout, then three fully connected layers ending in ``noOfClasses``
    raw scores (no softmax is applied here).

    Args:
        noOfClasses: Number of output scores produced by the last layer.
        inputSize: Spatial size of the images the model will receive, either
            a single int for square images or a ``(height, width)`` pair.
            It is only used to size ``fc1``. The default of 256 yields the
            63504 input features that were previously hard-coded, so
            existing checkpoints keep loading.

    Raises:
        ValueError: If ``inputSize`` is too small to leave at least one
            pixel after the two conv/pool stages.
    """

    def __init__(self, noOfClasses=39, inputSize=256):
        super().__init__()

        if isinstance(inputSize, int):
            height = width = inputSize
        else:
            height, width = inputSize

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.batchNorm1 = nn.BatchNorm2d(6)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(6, 16, 5, padding=2)
        self.batchNorm2 = nn.BatchNorm2d(16)

        # Derived from the layer geometry instead of a magic number so the
        # model can be built for other image sizes.
        self.fc1 = nn.Linear(self._flattenedSize(height, width), 512)
        self.dropout = nn.Dropout(0.5)

        self.fc2 = nn.Linear(512, 84)
        self.fc3 = nn.Linear(84, noOfClasses)

    @staticmethod
    def _flattenedSize(height, width):
        """Return the per-sample feature count after both conv/pool stages."""

        def reducedSide(side):
            # conv1 (5x5, no padding) trims 4 pixels; conv2 (5x5, padding 2)
            # keeps the size; each 2x2 pool halves it, rounding down.
            return ((side - 4) // 2) // 2

        outHeight = reducedSide(height)
        outWidth = reducedSide(width)
        if outHeight < 1 or outWidth < 1:
            raise ValueError(
                f"inputSize {height}x{width} is too small for this network"
            )
        # conv2 produces 16 channels.
        return 16 * outHeight * outWidth

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor with 3 channels in dimension 1 whose spatial size
                matches the ``inputSize`` given at construction.

        Returns:
            Tensor of shape ``(batch, noOfClasses)`` with unnormalised scores.
        """
        x = self.pool(F.relu(self.batchNorm1(self.conv1(x))))
        x = self.pool(F.relu(self.batchNorm2(self.conv2(x))))
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, 1)
        # Dropout is applied to the flattened conv features, before fc1.
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)