| """ |
| Module for UNet based predictor. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class UNetPredictor(nn.Module):
    """
    UNet-based predictor model.

    A three-level encoder/decoder with skip connections. Every stage is a
    3x3 convolution (padding 1) followed by GroupNorm with 8 groups and a
    LeakyReLU. The encoder halves the spatial resolution with 2x2 max-pooling
    between stages (32 -> 64 -> 128 channels, 256 in the bottleneck). The
    decoder bilinearly resizes each deeper feature map to the exact spatial
    size of the matching skip tensor, concatenates the two along the channel
    axis, and fuses them with another conv stage. A final 3x3 convolution
    followed by ``tanh`` produces the output, so values lie in [-1, 1].

    Args:
        in_channels: Number of input channels (default 15).
        out_channels: Number of output channels (default 1).
    """

    def __init__(self, in_channels: int = 15, out_channels: int = 1):
        super().__init__()

        # Encoder path and bottleneck.
        self.enc1 = self._conv_block(in_channels, 32)
        self.enc2 = self._conv_block(32, 64)
        self.enc3 = self._conv_block(64, 128)
        self.bottleneck = self._conv_block(128, 256)

        # Decoder path: each stage consumes [resized deeper features, skip],
        # hence the summed input channel counts (384, 192, 96).
        self.dec3 = self._conv_block(256 + 128, 128)
        self.dec2 = self._conv_block(128 + 64, 64)
        self.dec1 = self._conv_block(64 + 32, 32)

        # Output head: tanh bounds predictions to [-1, 1].
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=out_channels, kernel_size=(3, 3), padding=1),
            nn.Tanh(),
        )

        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

    @staticmethod
    def _conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build one 3x3 conv -> GroupNorm(8) -> LeakyReLU stage.

        The conv sits at index 0 and the norm at index 1 of the Sequential;
        checkpoint keys such as ``enc1.0.weight`` depend on that layout.
        ``out_ch`` must be divisible by 8 for GroupNorm.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3), padding=1),
            nn.GroupNorm(num_groups=8, num_channels=out_ch),
            nn.LeakyReLU(),
        )

    @staticmethod
    def _upsample_to(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``x`` to the spatial (last two) dims of ``ref``.

        Resizing to the skip's exact size keeps the channel concatenation
        valid when H or W is not divisible by 8. When the size is exactly
        double, the result equals a fixed 2x bilinear upsample.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Either a 5-D tensor (B, in_channels, 1, H, W), whose singleton
               axis 2 is dropped, or a 4-D tensor (B, in_channels, H, W).
               H and W must be at least 8 so the bottleneck is non-empty.

        Returns:
            Tensor of shape (B, out_channels, H, W) with values in [-1, 1].

        Raises:
            ValueError: If a 5-D input has a non-singleton axis 2.
        """
        if x.dim() == 5:
            if x.shape[2] != 1:
                raise ValueError(
                    f"expected a singleton axis 2 for 5-D input, got shape {tuple(x.shape)}"
                )
            x = x.squeeze(2)

        # Encoder: keep each stage's output as a skip connection.
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        s3 = self.enc3(self.pool(s2))
        b = self.bottleneck(self.pool(s3))

        # Decoder: resize to the skip's size, concatenate on channels, fuse.
        d3 = self.dec3(torch.cat([self._upsample_to(b, s3), s3], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_to(d3, s2), s2], dim=1))
        d1 = self.dec1(torch.cat([self._upsample_to(d2, s1), s1], dim=1))

        return self.out(d1)
|
|
if __name__ == "__main__":
    # Smoke test: build the network, feed one random batch of shape
    # (2, 15, 1, 128, 128) through it, and print the resulting output shape.
    net = UNetPredictor()
    dummy_batch = torch.randn(2, 15, 1, 128, 128)
    prediction = net(dummy_batch)
    print(prediction.shape)