File size: 3,348 Bytes
c679d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Module for UNet based predictor.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNetPredictor(nn.Module):
    """
    U-Net style predictor: three encoder levels, a bottleneck, and three
    decoder levels joined by skip connections.

    Every level is a single 3x3 convolution (padding 1, so H and W are kept)
    followed by GroupNorm and LeakyReLU. Resolution is halved by 2x2 max
    pooling on the way down and restored by bilinear upsampling on the way
    up. The final 3x3 convolution is followed by ``tanh``, so every output
    value lies in [-1, 1].

    Args:
        in_channels: Number of input channels (default 15).
        out_channels: Number of output channels (default 1).
        base_channels: Width of the first encoder level; the deeper levels
            use 2x, 4x and 8x this value (default 32, i.e. 32/64/128/256).
            Each width feeds a GroupNorm with 8 groups, so this must be a
            multiple of 8.
    """

    # Number of GroupNorm groups used by every conv stage.
    _NUM_GROUPS = 8

    def __init__(self, in_channels: int = 15, out_channels: int = 1, base_channels: int = 32):
        super().__init__()

        c1, c2, c3, c4 = base_channels, 2 * base_channels, 4 * base_channels, 8 * base_channels
        groups = self._NUM_GROUPS

        # Encoder blocks (modules are created in the original order so a seeded
        # RNG yields the same initial weights as before).
        self.enc1 = self._conv_block(in_channels, c1, groups)
        self.enc2 = self._conv_block(c1, c2, groups)
        self.enc3 = self._conv_block(c2, c3, groups)
        self.bottleneck = self._conv_block(c3, c4, groups)

        # Decoder blocks: each takes the upsampled deeper features concatenated
        # with the matching encoder skip, hence the summed input widths.
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = self._conv_block(c4 + c3, c3, groups)   # default: 384 -> 128
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec2 = self._conv_block(c3 + c2, c2, groups)   # default: 192 -> 64
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = self._conv_block(c2 + c1, c1, groups)   # default: 96 -> 32

        # Output layer
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=c1, out_channels=out_channels, kernel_size=(3, 3), padding=1),
            nn.Tanh(),
        )

        # Pooling layer (shared by all encoder transitions; has no parameters)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

    @staticmethod
    def _conv_block(in_ch: int, out_ch: int, num_groups: int) -> nn.Sequential:
        """Build one 3x3 conv -> GroupNorm -> LeakyReLU stage that preserves H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3), padding=1),
            nn.GroupNorm(num_groups=num_groups, num_channels=out_ch),
            nn.LeakyReLU(),
        )

    @staticmethod
    def _upsample_to(up: nn.Upsample, x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """
        Bilinearly resize ``x`` to the spatial size of ``ref`` using ``up``'s settings.

        When ``ref`` is exactly twice the size of ``x`` this matches ``up(x)``.
        When max pooling floored an odd size, a fixed 2x upsample would be one
        pixel short and ``torch.cat`` with the skip tensor would fail, so the
        target size is taken from ``ref`` instead.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode=up.mode, align_corners=up.align_corners)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the U-Net.

        Args:
            x: ``(B, C_in, 1, H, W)`` or ``(B, C_in, H, W)``. H and W must be at
                least 8 (three 2x2 poolings) but need not be multiples of 8.

        Returns:
            ``(B, C_out, H, W)`` tensor with values in [-1, 1].
        """
        if x.dim() == 5:
            # (B, C, 1, H, W) -> (B, C, H, W). Only 5-D input is squeezed so a
            # 4-D input with H == 1 is not mistakenly collapsed.
            x = x.squeeze(2)

        s1 = self.enc1(x)                      # (B, c1, H,    W   )  <- skip1
        s2 = self.enc2(self.pool(s1))          # (B, c2, H/2,  W/2 )  <- skip2
        s3 = self.enc3(self.pool(s2))          # (B, c3, H/4,  W/4 )  <- skip3
        b = self.bottleneck(self.pool(s3))     # (B, c4, H/8,  W/8 )   (sizes floored)

        d3 = self.dec3(torch.cat([self._upsample_to(self.up3, b, s3), s3], dim=1))   # (B, c3, H/4, W/4)
        d2 = self.dec2(torch.cat([self._upsample_to(self.up2, d3, s2), s2], dim=1))  # (B, c2, H/2, W/2)
        d1 = self.dec1(torch.cat([self._upsample_to(self.up1, d2, s1), s1], dim=1))  # (B, c1, H,   W  )

        return self.out(d1)                    # (B, C_out, H, W)
    

if __name__ == "__main__":
    # Smoke test: feed a random batch of two 15-channel 128x128 samples
    # (with a singleton third axis) through the network and report the shape.
    net = UNetPredictor()
    sample = torch.randn(2, 15, 1, 128, 128)
    prediction = net(sample)
    print(prediction.shape)  # expected: torch.Size([2, 1, 128, 128])