Rhodham96 commited on
Commit
8f070f2
·
verified ·
1 Parent(s): a843013

Create model_def.py

Browse files
Files changed (1) hide show
  1. model_def.py +42 -0
model_def.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class EuroSATCNN(nn.Module):
6
+ def __init__(self, num_classes, img_height=64, img_width=64):
7
+ super(EuroSATCNN, self).__init__()
8
+ self.features = nn.Sequential(
9
+ nn.Conv2d(13, 128, kernel_size=4, padding=1),
10
+ nn.ReLU(),
11
+ nn.MaxPool2d(kernel_size=2),
12
+
13
+ nn.Conv2d(128, 64, kernel_size=4, padding=1),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(kernel_size=2),
16
+
17
+ nn.Conv2d(64, 32, kernel_size=4, padding=1),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(kernel_size=2),
20
+
21
+ nn.Conv2d(32, 16, kernel_size=4, padding=1),
22
+ nn.ReLU(),
23
+ nn.MaxPool2d(kernel_size=2),
24
+ )
25
+
26
+ with torch.no_grad():
27
+ dummy_input = torch.randn(1, 13, img_height, img_width)
28
+ out = self.features(dummy_input)
29
+ fc1_input_size = out.view(1, -1).shape[1]
30
+
31
+ self.classifier = nn.Sequential(
32
+ nn.Flatten(),
33
+ nn.Linear(fc1_input_size, 64),
34
+ nn.ReLU(),
35
+ nn.Linear(64, num_classes)
36
+
37
+ )
38
+
39
+ def forward(self, x):
40
+ x = self.features(x)
41
+ x = self.classifier(x)
42
+ return x