File size: 2,279 Bytes
b2816b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn


class SmallCloudNet(nn.Module):
    """Three-level convolutional encoder-decoder with skip connections.

    The encoder applies three double-convolution blocks separated by 2x2
    max-pooling; the decoder upsamples twice with stride-2 transposed
    convolutions, concatenating the matching encoder feature map after each
    upsampling step. A final 1x1 convolution maps the features to
    ``num_classes`` output channels per spatial location.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output channels produced by the 1x1 head.
        base_ch: Width of the first encoder level; deeper levels use
            ``2 * base_ch`` and ``4 * base_ch``. The default (64) gives
            64/128/256 channels.
        groups: Number of groups used by every ``GroupNorm`` layer.

    Raises:
        ValueError: If ``groups`` is not positive or ``base_ch`` is not a
            multiple of ``groups``.
    """

    def __init__(self, in_ch=3, num_classes=4, *, base_ch=64, groups=16):
        super().__init__()

        # Fail early with a clear message instead of deep inside GroupNorm.
        if groups < 1 or base_ch % groups != 0:
            raise ValueError(
                f"base_ch ({base_ch}) must be a positive multiple of "
                f"groups ({groups})"
            )

        c1, c2, c3 = base_ch, base_ch * 2, base_ch * 4

        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialisation stay unchanged.
        self.enc1 = self._block(in_ch, c1, groups)
        self.enc2 = self._block(c1, c2, groups)
        self.enc3 = self._block(c2, c3, groups)
        self.pool = nn.MaxPool2d(2)

        # Each decoder block sees the upsampled map concatenated with the
        # encoder skip of equal width, hence twice the output channel count.
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = self._block(c3, c2, groups)

        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = self._block(c2, c1, groups)

        self.head = nn.Conv2d(c1, num_classes, 1)

    @staticmethod
    def _block(cin, cout, groups):
        """Build two (3x3 conv -> GroupNorm -> ReLU) stages, size-preserving."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.GroupNorm(groups, cout),
            nn.ReLU(),
            nn.Conv2d(cout, cout, 3, padding=1),
            nn.GroupNorm(groups, cout),
            nn.ReLU(),
        )

    @staticmethod
    def _match(up, skip):
        """Crop ``skip`` to the spatial size of ``up`` if the sizes differ.

        Only the top-left ``up.shape[2] x up.shape[3]`` region of ``skip`` is
        kept. When the spatial sizes already agree, ``skip`` is returned
        unchanged.
        """
        if up.shape[2:] != skip.shape[2:]:
            skip = skip[:, :, :up.shape[2], :up.shape[3]]
        return skip

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_ch, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes, H', W')``. Because pooling
            rounds odd sizes down and skip connections are cropped to the
            upsampled size, ``H' = 4 * (H // 4)`` and ``W' = 4 * (W // 4)``;
            the output matches the input size only when H and W are
            multiples of 4.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        u2 = self.up2(e3)
        d2 = self.dec2(torch.cat([u2, self._match(u2, e2)], dim=1))

        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, self._match(u1, e1)], dim=1))

        return self.head(d1)