"""Neural-network building blocks: a patch-wise bidirectional GRU and a
pre-activation residual convolution block, plus their weight initializers."""

import torch
from torch import nn
from einops.layers.torch import Rearrange


def init_layer(layer):
    """Initialize a layer with Xavier-uniform weights and a zero bias.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor (e.g. ``nn.Conv2d`` or ``nn.Linear``).
    """
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with ``bias=False`` still have the attribute, set to None.
    if hasattr(layer, "bias") and layer.bias is not None:
        layer.bias.data.fill_(0.0)


def init_bn(bn):
    """Reset a batch-norm layer to the identity transform.

    Sets the affine parameters to scale 1 / shift 0 and the running
    statistics to mean 0 / variance 1.

    Args:
        bn: Batch-norm module with ``weight``, ``bias``, ``running_mean``
            and ``running_var`` tensors (affine, tracking running stats).
    """
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)
    bn.running_mean.data.fill_(0.0)
    bn.running_var.data.fill_(1.0)


class BiGRU(nn.Module):
    """Split a 4-D input into patches and run a bidirectional GRU over them.

    The input is rearranged with the einops pattern
    ``'b c (w p1) (h p2) -> b (w h) (p1 p2 c)'``: axis 2 is cut into patches
    of ``patch_width`` and axis 3 into patches of ``patch_height``; every
    patch is flattened into one vector, and the patches form the sequence
    fed to the GRU.

    Args:
        patch_size: ``(patch_width, patch_height)`` pair.
        channels: Size of axis 1 (``c``) of the input.
        depth: Number of stacked GRU layers.
    """

    def __init__(
        self,
        patch_size,
        channels,
        depth
    ):
        super().__init__()

        patch_width, patch_height = patch_size
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                'b c (w p1) (h p2) -> b (w h) (p1 p2 c)',
                p1=patch_width,
                p2=patch_height
            )
        )

        # Each direction gets ``patch_dim // 2`` hidden units, so the
        # concatenated output has ``2 * (patch_dim // 2)`` features, which
        # equals ``patch_dim`` only when ``patch_dim`` is even.
        self.gru = nn.GRU(
            patch_dim,
            patch_dim // 2,
            num_layers=depth,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, x):
        """Return the GRU output sequence for ``x``.

        Args:
            x: Tensor laid out as ``b c (w p1) (h p2)``; axes 2 and 3 must be
                divisible by ``patch_width`` and ``patch_height``.

        Returns:
            The first element of the GRU result (the per-step outputs), shaped
            ``(b, w * h, 2 * (patch_dim // 2))``.

        Raises:
            RuntimeError: If the GRU still fails after cuDNN was disabled.

        Note:
            On a ``RuntimeError`` this method sets
            ``torch.backends.cudnn.enabled = False`` for the WHOLE process
            and retries once. The flag is intentionally left disabled so
            later calls do not hit the failing path again.
        """
        x = self.to_patch_embedding(x)
        try:
            return self.gru(x)[0]
        except RuntimeError:
            # Only runtime failures trigger the fallback (assumed to be the
            # cuDNN errors the original handler targeted -- TODO confirm).
            # A bare ``except`` would also swallow KeyboardInterrupt /
            # SystemExit and flip the global flag on plain programming errors.
            torch.backends.cudnn.enabled = False
            return self.gru(x)[0]


class ResConvBlock(nn.Module):
    """Residual block of two ``BatchNorm2d -> PReLU -> 3x3 Conv2d`` stages.

    The block input is added to the output of the second convolution. When
    ``in_planes != out_planes`` the input first passes through a 1x1
    convolution so the channel counts match; otherwise it is added as is.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
    """

    def __init__(
        self,
        in_planes,
        out_planes
    ):
        super().__init__()

        self.bn1 = nn.BatchNorm2d(
            in_planes,
            momentum=0.01
        )
        self.bn2 = nn.BatchNorm2d(
            out_planes,
            momentum=0.01
        )

        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()

        # 3x3 kernels with padding 1 keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(
            in_planes,
            out_planes,
            (3, 3),
            padding=(1, 1),
            bias=False
        )
        self.conv2 = nn.Conv2d(
            out_planes,
            out_planes,
            (3, 3),
            padding=(1, 1),
            bias=False
        )

        # A projection is only needed when the channel count changes.
        self.is_shortcut = False
        if in_planes != out_planes:
            self.shortcut = nn.Conv2d(
                in_planes,
                out_planes,
                (1, 1)
            )
            self.is_shortcut = True

        self.init_weights()

    def init_weights(self):
        """Initialize the batch-norm layers, both convolutions and, when
        present, the shortcut projection."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        """Apply the two pre-activation conv stages and add the residual.

        Args:
            x: Input accepted by ``nn.BatchNorm2d(in_planes)``.

        Returns:
            Sum of the second convolution's output and ``x`` (projected
            through the 1x1 shortcut when the channel counts differ).
        """
        out = self.conv1(
            self.act1(self.bn1(x))
        )
        out = self.conv2(
            self.act2(self.bn2(out))
        )

        if self.is_shortcut:
            return self.shortcut(x) + out
        else:
            return out + x