kreyesp commited on
Commit
b2816b6
·
verified ·
1 Parent(s): 3b9d78e

Upload model definition

Browse files
Files changed (1) hide show
  1. model.py +71 -0
model.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class SmallCloudNet(nn.Module):
6
+ def __init__(self, in_ch=3, num_classes=4):
7
+ super().__init__()
8
+
9
+ def block(cin, cout):
10
+ # return nn.Sequential(
11
+ # nn.Conv2d(cin, cout, 3, padding=1),
12
+ # nn.GroupNorm(8, cout),
13
+ # nn.ReLU(),
14
+ # nn.Conv2d(cout, cout, 3, padding=1),
15
+ # nn.GroupNorm(8, cout),
16
+ # nn.ReLU(),
17
+ # )
18
+ return nn.Sequential(
19
+ nn.Conv2d(cin, cout, 3, padding=1),
20
+ nn.GroupNorm(16, cout), # bumped from 8 to 16 groups
21
+ nn.ReLU(),
22
+ nn.Conv2d(cout, cout, 3, padding=1),
23
+ nn.GroupNorm(16, cout),
24
+ nn.ReLU(),
25
+ )
26
+
27
+ # self.enc1 = block(in_ch, 32)
28
+ # self.enc2 = block(32, 64)
29
+ # self.enc3 = block(64, 128)
30
+ # self.pool = nn.MaxPool2d(2)
31
+
32
+ # self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
33
+ # self.dec2 = block(128, 64)
34
+
35
+ # self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
36
+ # self.dec1 = block(64, 32)
37
+
38
+ # self.head = nn.Conv2d(32, num_classes, 1)
39
+ self.enc1 = block(in_ch, 64)
40
+ self.enc2 = block(64, 128)
41
+ self.enc3 = block(128, 256)
42
+ self.pool = nn.MaxPool2d(2)
43
+
44
+
45
+ self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
46
+ self.dec2 = block(256, 128)
47
+
48
+ self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
49
+ self.dec1 = block(128, 64)
50
+
51
+ self.head = nn.Conv2d(64, num_classes, 1)
52
+
53
+ @staticmethod
54
+ def _match(up, skip):
55
+ """Crop skip connection to match upsampled tensor if sizes differ."""
56
+ if up.shape != skip.shape:
57
+ skip = skip[:, :, :up.shape[2], :up.shape[3]]
58
+ return skip
59
+
60
+ def forward(self, x):
61
+ e1 = self.enc1(x)
62
+ e2 = self.enc2(self.pool(e1))
63
+ e3 = self.enc3(self.pool(e2))
64
+
65
+ u2 = self.up2(e3)
66
+ d2 = self.dec2(torch.cat([u2, self._match(u2, e2)], dim=1))
67
+
68
+ u1 = self.up1(d2)
69
+ d1 = self.dec1(torch.cat([u1, self._match(u1, e1)], dim=1))
70
+
71
+ return self.head(d1)