import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Type
import warnings

__all__ = ['ConvMLP', 'ConvModule']
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for the channel dimension of 2D spatial BCHW tensors."""

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize over them, then restore BCHW layout.
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight,
            self.bias, self.eps).permute(0, 3, 1, 2)
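
# A minimal sanity-check sketch (not part of the original file): LayerNorm2d
# should preserve the BCHW shape while normalizing over channels. The tensor
# sizes below are arbitrary assumptions for illustration.
def _demo_layer_norm_2d():
    x = torch.randn(2, 64, 8, 8)   # B, C, H, W
    norm = LayerNorm2d(64)
    y = norm(x)
    assert y.shape == x.shape      # shape is unchanged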
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise conv (one filter per input channel) followed by a 1x1
    pointwise conv that mixes channels."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,  # one filter per input channel
            bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, groups=groups, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
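
# A minimal sketch (assumption, not part of the original file) comparing the
# parameter count of the separable factorization against a dense nn.Conv2d
# with the same receptive field. Channel sizes are arbitrary.
def _demo_depthwise_separable_params():
    sep = DepthwiseSeparableConv2d(64, 128, kernel_size=3, padding=1)
    dense = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    n_sep = sum(p.numel() for p in sep.parameters())
    n_dense = sum(p.numel() for p in dense.parameters())
    assert n_sep < n_dense  # depthwise + pointwise uses far fewer weights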
class ConvMLP(nn.Module):
    """MLP over the channel dimension, implemented with 1x1 convolutions."""

    def __init__(self, in_channels, out_channels=None, hidden_channels=None,
                 drop=0.25):
        super().__init__()
        # Fall back to in_channels when the out/hidden widths are not given.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
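
# A minimal usage sketch (assumption): ConvMLP maps channels
# in_channels -> hidden -> out with 1x1 convs, so spatial size is preserved.
# The sizes below are illustrative only.
def _demo_conv_mlp():
    mlp = ConvMLP(in_channels=32, hidden_channels=64, out_channels=16)
    x = torch.randn(1, 32, 14, 14)
    y = mlp(x)
    assert y.shape == (1, 16, 14, 14)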
class ConvModule(nn.Module):
    """Conv/norm/activation bundle whose order is configurable."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_layer: Optional[Type[nn.Module]] = nn.Conv2d,
                 norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
                 act_layer: Optional[Type[nn.Module]] = nn.ReLU,
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super().__init__()
        official_padding_mode = ['zeros', 'circular']
        nonofficial_padding_mode = dict(
            zero=nn.ZeroPad2d,
            reflect=nn.ReflectionPad2d,
            replicate=nn.ReplicationPad2d)
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        self.with_norm = norm_layer is not None
        self.with_act = act_layer is not None
        # With bias='auto', disable the conv bias when a norm layer is present:
        # the norm's own shift makes it redundant.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias
        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        if self.with_explicit_padding:
            assert padding_mode in nonofficial_padding_mode, \
                'Unsupported padding mode: {}'.format(padding_mode)
            # Instantiate the padding layer; the bare class is not callable
            # on tensors.
            self.padding_layer = nonofficial_padding_mode[padding_mode](padding)

        # Padding is applied explicitly above, so reset it to 0 for the conv.
        conv_padding = 0 if self.with_explicit_padding else padding
        self.conv = conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Expose common conv attributes on the module itself.
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)
        # Build the normalization layer. Its width depends on whether it runs
        # before or after the conv in the requested order.
        if self.with_norm:
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm = norm_layer(norm_channels)

        # Build the activation. Tanh, PReLU and Sigmoid do not accept an
        # `inplace` argument, so only pass it to the activations that do.
        if self.with_act:
            if act_layer in [nn.Tanh, nn.PReLU, nn.Sigmoid]:
                self.act = act_layer()
            else:
                self.act = act_layer(inplace=inplace)
    def forward(self, x, activate=True, norm=True):
        # Apply conv/norm/act in the configured order; norm and act can be
        # skipped per call via the `norm` and `activate` flags.
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_act:
                x = self.act(x)
        return x
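
# A minimal usage sketch (assumption, not part of the original file): a 3x3
# conv -> BatchNorm -> ReLU block with reflection padding, exercising the
# explicit-padding branch and the default order tuple.
def _demo_conv_module():
    block = ConvModule(3, 16, kernel_size=3, padding=1, padding_mode='reflect')
    x = torch.randn(2, 3, 32, 32)
    y = block(x)
    assert y.shape == (2, 16, 32, 32)  # spatial size preserved by padding=1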