Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d, optionally followed by ReLU and a 2x upsample.

    Args:
        in_channels: Number of channels fed to the convolution.
        out_channels: Number of channels produced by the convolution (and
            normalised by the batch-norm layer).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero-padding.
        dilation: Convolution dilation.
        groups: Convolution group count.
        bias: Whether the convolution carries a bias term.
        add_relu: Apply the (in-place) ReLU after batch norm when true.
        interpolate: When true, the result is bilinearly upsampled by a
            factor of 2 with ``align_corners=True`` before being returned.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        add_relu: bool = True,
        interpolate: bool = False,
    ):
        super().__init__()

        # Sub-modules are registered in the order conv -> bn -> activation so
        # the state_dict layout stays exactly as checkpoints expect it.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # Flags consulted by forward().
        self.add_relu = add_relu
        self.interpolate = interpolate

        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run conv + batch norm, then the optional ReLU / upsample stages."""
        out = self.bn(self.conv(x))

        if self.add_relu:
            out = self.activation(out)

        if not self.interpolate:
            return out

        return F.interpolate(
            out, scale_factor=2, mode="bilinear", align_corners=True
        )
class FPABlock(nn.Module):
    """Pyramid attention block applied to the deepest encoder feature map.

    (The name presumably abbreviates "Feature Pyramid Attention" from the
    Pyramid Attention Network paper -- confirm against project docs.)

    Three paths are combined in ``forward``:

    * a global-pooling branch (``branch1``) that is resized back to the
      input resolution,
    * a 1x1 "middle" branch (``mid``) producing ``out_channels`` maps,
    * a three-level max-pool pyramid (7x7 / 5x5 / 3x3 convolutions) that
      yields a single-channel map used to scale the middle branch.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        upscale_mode: ``mode`` passed to ``F.interpolate`` for every resize
            in this block.
    """

    # F.interpolate raises ValueError if align_corners is anything but None
    # for these modes.
    _NON_INTERPOLATING_MODES = ("nearest", "nearest-exact", "area")

    def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
        super().__init__()

        self.upscale_mode = upscale_mode
        self.align_corners = self._resolve_align_corners(upscale_mode)

        # global pooling branch
        self.branch1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        # middle branch
        self.mid = nn.Sequential(
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

        # Pyramid, downward path: each level halves the resolution and
        # produces a single-channel map.
        self.down1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=1,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2
            ),
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
            ),
            ConvBnRelu(
                in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
            ),
        )

        # Pyramid, lateral convolutions applied before merging levels.
        self.conv2 = ConvBnRelu(
            in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2
        )
        self.conv1 = ConvBnRelu(
            in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3
        )

    @staticmethod
    def _resolve_align_corners(upscale_mode):
        """Return the ``align_corners`` value that is valid for ``upscale_mode``.

        ``True`` for bilinear, ``None`` for modes where ``F.interpolate``
        rejects an explicit value (nearest / nearest-exact / area), and
        ``False`` for every other mode.
        """
        if upscale_mode == "bilinear":
            return True
        if upscale_mode in FPABlock._NON_INTERPOLATING_MODES:
            return None
        return False

    def forward(self, x):
        """Apply the block to a 4-D feature map.

        Args:
            x: Tensor indexed as (batch, channels, height, width). Height and
                width go through three stride-2 max-pools, so each needs to
                be at least 8 for the deepest level to be non-empty.

        Returns:
            Tensor with ``out_channels`` channels at the spatial size of
            ``x``: ``pyramid_attention * mid(x) + global_branch(x)``.
        """
        h, w = x.size(2), x.size(3)
        upscale_parameters = dict(
            mode=self.upscale_mode, align_corners=self.align_corners
        )

        # Global context: 1x1 pooled descriptor stretched to (h, w).
        b1 = self.branch1(x)
        b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)

        mid = self.mid(x)

        # Downward pass of the pyramid: 1/2, 1/4 and 1/8 of the input size.
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Upward pass: resize the coarser map to the next finer level and add
        # that level's laterally-convolved map.
        x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)
        x2 = self.conv2(x2)
        x = x2 + x3

        x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)
        x1 = self.conv1(x1)
        x = x + x1

        x = F.interpolate(x, size=(h, w), **upscale_parameters)

        # The single-channel pyramid map broadcasts over mid's channels.
        x = torch.mul(x, mid)
        x = x + b1
        return x
class GAUBlock(nn.Module):
    """Fuse a low-level feature map with a high-level one via channel gating.

    The high-level map is globally pooled and turned into per-channel
    sigmoid weights (``conv1``); those weights scale the 3x3-convolved
    low-level map (``conv2``); the resized high-level map is then added.

    Args:
        in_channels: Channels of the low-level input ``x``.
        out_channels: Channels of the high-level input ``y`` and of the
            returned tensor.
        upscale_mode: ``mode`` handed to ``F.interpolate`` when resizing
            ``y``. ``align_corners`` is ``True`` for ``"bilinear"`` and
            ``None`` for anything else.
    """

    def __init__(
        self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
    ):
        super().__init__()

        self.upscale_mode = upscale_mode
        if upscale_mode == "bilinear":
            self.align_corners = True
        else:
            self.align_corners = None

        # Channel gate computed from the high-level feature:
        # global average pool -> 1x1 conv + BN (no ReLU) -> sigmoid.
        channel_gate = [
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(out_channels, out_channels, kernel_size=1, add_relu=False),
            nn.Sigmoid(),
        ]
        self.conv1 = nn.Sequential(*channel_gate)

        # 3x3 conv that brings the low-level feature to out_channels.
        self.conv2 = ConvBnRelu(
            in_channels, out_channels, kernel_size=3, padding=1
        )

    def forward(self, x, y):
        """Combine the two feature maps.

        Args:
            x: low level feature; its height/width define the output size.
            y: high level feature; resized to the spatial size of ``x``.

        Returns:
            ``resize(y) + conv2(x) * conv1(y)``.
        """
        target_size = (x.size(2), x.size(3))
        y_up = F.interpolate(
            y,
            size=target_size,
            mode=self.upscale_mode,
            align_corners=self.align_corners,
        )

        low = self.conv2(x)
        gate = self.conv1(y)

        return y_up + torch.mul(low, gate)
class PANDecoder(nn.Module):
    """Decoder that chains one FPABlock with three GAUBlocks.

    Args:
        encoder_channels: Sequence of encoder channel counts; only the last
            four entries are read (``[-1]`` feeds the FPA block, ``[-2]``,
            ``[-3]`` and ``[-4]`` feed ``gau3``, ``gau2`` and ``gau1``).
        decoder_channels: Channel count produced by every stage.
        upscale_mode: Passed to the three GAU blocks. The FPA block is built
            without it and therefore uses its own default.
    """

    def __init__(
        self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
    ):
        super().__init__()

        def build_gau(level_channels):
            # All GAU stages share the decoder width and upscale mode.
            return GAUBlock(
                in_channels=level_channels,
                out_channels=decoder_channels,
                upscale_mode=upscale_mode,
            )

        self.fpa = FPABlock(
            in_channels=encoder_channels[-1], out_channels=decoder_channels
        )
        self.gau3 = build_gau(encoder_channels[-2])
        self.gau2 = build_gau(encoder_channels[-3])
        self.gau1 = build_gau(encoder_channels[-4])

    def forward(self, *features):
        """Decode encoder features, deepest first.

        Args:
            *features: Encoder feature maps; the last four are used, with
                ``features[-1]`` being the bottleneck.

        Returns:
            Output of ``gau1``, i.e. a tensor at the spatial size of
            ``features[-4]``.
        """
        # Scale hints (1/32 ... 1/4) are carried over from the original
        # comments; presumably relative to the network input -- confirm.
        out = self.fpa(features[-1])  # 1/32
        out = self.gau3(features[-2], out)  # 1/16
        out = self.gau2(features[-3], out)  # 1/8
        return self.gau1(features[-4], out)  # 1/4