Valerii Sielikhov
Refactor import statements in htr_convtext.py, layer.py, and mv_block.py for consistency; remove unused imports in resnet18.py and improve code formatting.
561759b

from typing import Optional, Tuple, Union

from torch import nn


class ConvLayer2d(nn.Module):
    """Conv2d followed by optional normalization and activation layers."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        use_norm: bool = True,
        use_act: bool = True,
        norm_layer: Optional[nn.Module] = None,
        act_layer: Optional[nn.Module] = None,
    ):
        super().__init__()
        layers = []
        layers.append(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        )
        if use_norm:
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d(out_channels)
            layers.append(norm_layer)
        if use_act:
            if act_layer is None:
                act_layer = nn.ReLU(inplace=True)
            layers.append(act_layer)
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
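A quick usage sketch, not part of the commit (shapes and argument values are assumed for illustration): the default configuration yields Conv2d + BatchNorm2d + ReLU, while `use_norm=False, use_act=False` reduces the block to a bare convolution.

import torch

# Default block: 3 -> 64 channels, stride-2 downsampling, BN + ReLU.
layer = ConvLayer2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1)
x = torch.randn(2, 3, 32, 128)  # (batch, channels, height, width)
print(layer(x).shape)           # torch.Size([2, 64, 16, 64])

# Norm-free, linear variant (e.g. a 1x1 projection).
proj = ConvLayer2d(3, 64, kernel_size=1, bias=True, use_norm=False, use_act=False)
print(proj(x).shape)            # torch.Size([2, 64, 32, 128])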
# PEG (Positional Encoding Generator) from CPVT: https://arxiv.org/abs/2102.10882
class PosCNN(nn.Module):
    def __init__(self, in_chans, embed_dim=None, s=1):
        super().__init__()
        if embed_dim is None:
            # Default to in_chans: with groups=embed_dim the conv is depthwise,
            # and nn.Conv2d would reject out_channels=None.
            embed_dim = in_chans
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim),
        )
        self.s = s

    def forward(self, x, H, W):
        B, N, C = x.shape
        feat_token = x
        # Restore the 2D feature map from the flattened (B, N, C) token sequence.
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        if self.s == 1:
            # Residual connection keeps the original features when stride is 1.
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        # Flatten back to a (B, N, C) token sequence.
        x = x.flatten(2).transpose(1, 2)
        return x

    def no_weight_decay(self):
        # Parameter names to exclude from weight decay; the pattern defensively
        # covers up to four proj sublayers, though only proj.0 exists here.
        return ["proj.%d.weight" % i for i in range(4)]