jasonsfmeitian commited on
Commit
8d99de9
·
verified ·
1 Parent(s): f28876e

Update models/jason_cnn.py

Browse files
Files changed (1) hide show
  1. models/jason_cnn.py +13 -4
models/jason_cnn.py CHANGED
@@ -1,14 +1,17 @@
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
 
4
- class CNN(nn.Module):
5
- def __init__(self, in_channels=1, num_classes=7, filters=(16, 32), dropout=0.25):
6
- super(CNN, self).__init__()
 
7
  self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[0], kernel_size=3, padding=1)
8
  self.conv2 = nn.Conv2d(in_channels=filters[0], out_channels=filters[1], kernel_size=3, padding=1)
 
 
9
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
10
  self.dropout = nn.Dropout(dropout)
11
- self.fc1 = nn.Linear(filters[1] * 12 * 12, num_classes)
12
 
13
  def forward(self, x):
14
  x = F.relu(self.conv1(x))
@@ -17,6 +20,12 @@ class CNN(nn.Module):
17
  x = F.relu(self.conv2(x))
18
  x = self.pool(x)
19
  x = self.dropout(x)
 
 
 
 
 
 
20
  x = x.reshape(x.shape[0], -1)
21
  x = self.fc1(x)
22
  return x
 
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
 
4
+ class CNN_4Layer(nn.Module):
5
+ def __init__(self, in_channels, num_classes, filters=(16,32,64,128), dropout=0.25):
6
+ super(CNN_4Layer, self).__init__()
7
+
8
  self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=filters[0], kernel_size=3, padding=1)
9
  self.conv2 = nn.Conv2d(in_channels=filters[0], out_channels=filters[1], kernel_size=3, padding=1)
10
+ self.conv3 = nn.Conv2d(in_channels=filters[1], out_channels=filters[2], kernel_size=3, padding=1)
11
+ self.conv4 = nn.Conv2d(in_channels=filters[2], out_channels=filters[3], kernel_size=3, padding=1)
12
  self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
13
  self.dropout = nn.Dropout(dropout)
14
+ self.fc1 = nn.Linear(filters[3] * 3 * 3, num_classes)
15
 
16
  def forward(self, x):
17
  x = F.relu(self.conv1(x))
 
20
  x = F.relu(self.conv2(x))
21
  x = self.pool(x)
22
  x = self.dropout(x)
23
+ x = F.relu(self.conv3(x))
24
+ x = self.pool(x)
25
+ x = self.dropout(x)
26
+ x = F.relu(self.conv4(x))
27
+ x = self.pool(x)
28
+ x = self.dropout(x)
29
  x = x.reshape(x.shape[0], -1)
30
  x = self.fc1(x)
31
  return x