|
|
from torch import nn, Tensor |
|
|
|
|
|
from typing import Union |
|
|
|
|
|
from .blocks import DepthSeparableConv2d, conv1x1, conv3x3 |
|
|
from .utils import _init_weights |
|
|
|
|
|
|
|
|
class ConvDownsample(nn.Module):
    """Residual block that halves spatial resolution and maps ``in_channels`` -> ``out_channels``.

    Main path:
        1. ``conv1``: 2x2 average pooling with stride 2 (a pooling layer despite
           its name), then ``norm1`` and ``act1``.
        2. ``conv2``: ``conv3x3`` keeping ``in_channels`` (grouped when
           ``groups > 1``, followed by a ``conv1x1`` channel mixer), then
           ``norm2`` and ``act2``.
        3. ``conv3``: ``conv3x3`` to ``out_channels`` (same grouping scheme),
           then ``norm3``.
    Shortcut path: 2x2 average pooling -> ``conv1x1`` -> normalization.
    The two paths are summed in place and passed through ``act3``.

    NOTE(review): the sum requires both ``conv3x3`` stages to preserve the
    spatial size (presumably padding 1) -- TODO confirm against ``.blocks``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        norm_layer: Factory called with a single channel count to build every
            normalization layer. A falsy value (e.g. ``None``) replaces them
            with ``nn.Identity`` and turns conv biases on. ``nn.GroupNorm``
            takes ``(num_groups, num_channels)``, so it presumably must be
            wrapped (e.g. ``functools.partial``) before being passed here.
        activation: Module instance used as ``act1``, ``act2`` and ``act3``.
            The same object is registered three times, and the default is
            created once when the function is defined, so it is shared by every
            instance that uses it. This is harmless for the stateless ReLU, but a
            parameterized activation would share its parameters.
        groups: Group count for both ``conv3x3`` stages. It must be a positive
            int dividing both channel counts. With ``groups > 1`` pointwise
            ``conv1x1`` layers follow ``conv2`` and ``conv3``.

    Raises:
        AssertionError: If ``groups`` is not a positive int or does not divide
            ``in_channels`` / ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        assert isinstance(groups, int) and groups > 0, f"Number of groups should be an integer greater than 0, but got {groups}."
        assert in_channels % groups == 0, f"Number of input channels {in_channels} should be divisible by number of groups {groups}."
        assert out_channels % groups == 0, f"Number of output channels {out_channels} should be divisible by number of groups {groups}."

        # Convs only carry a bias when no normalization layer follows them.
        use_bias = not norm_layer

        def make_norm(num_channels: int) -> nn.Module:
            # A falsy norm_layer means "no normalization".
            return norm_layer(num_channels) if norm_layer else nn.Identity()

        self.grouped_conv = groups > 1

        # Registration order below is deliberate: it fixes the parameter /
        # state_dict order and the order in which _init_weights visits modules.

        # Stage 1: pooling (named conv1) + norm + activation.
        self.conv1 = nn.AvgPool2d(kernel_size=2, stride=2)
        if self.grouped_conv:
            # Identity placeholder: the pooling stage has no grouped conv to mix.
            self.conv1_1x1 = nn.Identity()
        self.norm1 = make_norm(in_channels)
        self.act1 = activation

        # Stage 2: (grouped) 3x3 conv keeping the channel count.
        self.conv2 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=groups,
            bias=use_bias,
        )
        if self.grouped_conv:
            # Pointwise conv mixes information across the groups.
            self.conv2_1x1 = conv1x1(in_channels, in_channels, stride=1, bias=use_bias)
        self.norm2 = make_norm(in_channels)
        self.act2 = activation

        # Stage 3: (grouped) 3x3 conv projecting to out_channels.
        self.conv3 = conv3x3(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            groups=groups,
            bias=use_bias,
        )
        if self.grouped_conv:
            self.conv3_1x1 = conv1x1(out_channels, out_channels, stride=1, bias=use_bias)
        self.norm3 = make_norm(out_channels)
        self.act3 = activation

        # Shortcut: pool to the same resolution, project to out_channels.
        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv1x1(in_channels, out_channels, stride=1, bias=use_bias),
            make_norm(out_channels),
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``act3(main_path(x) + downsample(x))``."""
        shortcut_input = x

        y = self.conv1(x)
        if self.grouped_conv:
            y = self.conv1_1x1(y)
        y = self.act1(self.norm1(y))

        y = self.conv2(y)
        if self.grouped_conv:
            y = self.conv2_1x1(y)
        y = self.act2(self.norm2(y))

        y = self.conv3(y)
        if self.grouped_conv:
            y = self.conv3_1x1(y)
        y = self.norm3(y)

        # No activation before the residual sum; act3 is applied afterwards.
        y += self.downsample(shortcut_input)
        return self.act3(y)
|
|
|
|
|
|
|
|
class LightConvDownsample(nn.Module):
    """Lightweight residual downsampling block built from ``DepthSeparableConv2d`` layers.

    Main path:
        1. ``conv1``: ``DepthSeparableConv2d`` with kernel 2, stride 2, padding 0
           (keeps ``in_channels``), then ``norm1`` and ``act1``.
        2. ``conv2``: ``DepthSeparableConv2d`` with kernel 3, stride 1, padding 1
           (``in_channels`` -> ``out_channels``), then ``norm2`` and ``act2``.
        3. ``conv3``: ``DepthSeparableConv2d`` with kernel 3, stride 1, padding 1
           (keeps ``out_channels``), then ``norm3``.
    Shortcut path: 2x2 average pooling with stride 2 -> ``conv1x1`` -> normalization.
    The two paths are summed in place and passed through ``act3``.

    NOTE(review): the residual sum presumes ``conv1`` halves H and W like the
    ``AvgPool2d(2, 2)`` shortcut does. This depends on ``DepthSeparableConv2d``
    semantics -- TODO confirm against ``.blocks``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        norm_layer: Factory called with a single channel count to build every
            normalization layer. A falsy value (e.g. ``None``) replaces them
            with ``nn.Identity`` and turns conv biases on. ``nn.GroupNorm``
            takes ``(num_groups, num_channels)``, so it presumably must be
            wrapped (e.g. ``functools.partial``) before being passed here.
        activation: Module instance used as ``act1``, ``act2`` and ``act3``.
            The same object is registered three times, and the default is
            created once when the function is defined, so it is shared by every
            instance that uses it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # Convs only carry a bias when no normalization layer follows them.
        self.conv1 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=not norm_layer,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        self.conv2 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not norm_layer,
        )
        self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        self.conv3 = DepthSeparableConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not norm_layer,
        )
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        # Shortcut: pool to half resolution, project to out_channels.
        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``act3(main_path(x) + downsample(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.norm3(out)

        # No activation before the residual sum; act3 is applied afterwards.
        out += self.downsample(identity)
        out = self.act3(out)

        # Return the computed block output, not the untouched input ``x``.
        return out
|
|
|
|
|
|
|
|
class LighterConvDownsample(nn.Module):
    """Even lighter residual downsampling block (depthwise 3x3 + pointwise 1x1).

    Main path:
        1. ``conv1``: ``DepthSeparableConv2d`` with kernel 2, stride 2, padding 0
           (keeps ``in_channels``), then ``norm1`` and ``act1``.
        2. ``conv2``: depthwise ``conv3x3`` (``groups=in_channels``, keeps
           ``in_channels``), then ``norm2`` and ``act2``.
        3. ``conv3``: pointwise ``conv1x1`` (``in_channels`` -> ``out_channels``),
           then ``norm3``.
    Shortcut path: 2x2 average pooling with stride 2 -> ``conv1x1`` -> normalization.
    The two paths are summed in place and passed through ``act3``.

    NOTE(review): the residual sum presumes ``conv1`` halves H and W and
    ``conv3x3`` preserves them -- TODO confirm against ``.blocks``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        norm_layer: Factory called with a single channel count to build every
            normalization layer. A falsy value (e.g. ``None``) replaces them
            with ``nn.Identity`` and turns conv biases on. ``nn.GroupNorm``
            takes ``(num_groups, num_channels)``, so it presumably must be
            wrapped (e.g. ``functools.partial``) before being passed here.
        activation: Module instance used as ``act1``, ``act2`` and ``act3``.
            The same object is registered three times, and the default is
            created once when the function is defined, so it is shared by every
            instance that uses it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # Convs only carry a bias when no normalization layer follows them.
        self.conv1 = DepthSeparableConv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=not norm_layer,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        # Depthwise: one 3x3 filter per input channel.
        self.conv2 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=in_channels,
            bias=not norm_layer,
        )
        self.norm2 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        # Pointwise projection to out_channels.
        self.conv3 = conv1x1(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            bias=not norm_layer,
        )
        self.norm3 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act3 = activation

        # Shortcut: pool to half resolution, project to out_channels.
        self.downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
            norm_layer(out_channels) if norm_layer else nn.Identity(),
        )

        # Use the same project-wide weight initialization as the sibling
        # downsampling blocks (previously missing here, which left these layers
        # at PyTorch's default initialization).
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``act3(main_path(x) + downsample(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.norm3(out)

        # No activation before the residual sum; act3 is applied afterwards.
        out += self.downsample(identity)
        out = self.act3(out)

        return out
|
|
|