File size: 2,730 Bytes
30f8290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import torch
from torch import nn
from einops.layers.torch import Rearrange
def init_layer(layer):
    """Xavier-initialise a layer's weight and zero its bias (if any).

    Works for any module exposing ``weight`` (e.g. Conv2d, Linear);
    layers built with ``bias=False`` are handled gracefully.
    """
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn):
    """Reset a batch-norm layer to the identity transform.

    Affine scale -> 1, shift -> 0, and the running statistics are
    reset to mean 0 / variance 1.
    """
    for tensor, value in (
        (bn.weight, 1.0),
        (bn.bias, 0.0),
        (bn.running_mean, 0.0),
        (bn.running_var, 1.0),
    ):
        tensor.data.fill_(value)
class BiGRU(nn.Module):
    """Bidirectional GRU over non-overlapping image patches.

    The input image is split into ``patch_width`` x ``patch_height``
    patches, each flattened into a vector of ``channels * patch_height
    * patch_width`` features, and the resulting patch sequence is fed
    through a bidirectional GRU whose concatenated hidden size equals
    the patch dimension.

    Args:
        patch_size: ``(patch_width, patch_height)`` tuple; both must
            evenly divide the corresponding input spatial dims.
        channels: number of input channels.
        depth: number of stacked GRU layers.
    """

    def __init__(
        self,
        patch_size,
        channels,
        depth
    ):
        super(BiGRU, self).__init__()
        patch_width, patch_height = patch_size
        patch_dim = channels * patch_height * patch_width
        # (b, c, W, H) -> (b, num_patches, patch_dim) with patches taken
        # in row-major order over the (W // p1, H // p2) grid.
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                'b c (w p1) (h p2) -> b (w h) (p1 p2 c)',
                p1=patch_width,
                p2=patch_height
            )
        )
        # hidden = patch_dim // 2 per direction, so the bidirectional
        # output feature size is 2 * (patch_dim // 2).
        self.gru = nn.GRU(
            patch_dim,
            patch_dim // 2,
            num_layers=depth,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, x):
        """Return the GRU output sequence, shape (b, num_patches, 2*(patch_dim//2))."""
        x = self.to_patch_embedding(x)
        try:
            return self.gru(x)[0]
        except RuntimeError:
            # cuDNN's fused GRU kernel can raise RuntimeError on some
            # input/driver combinations; retry with the native
            # implementation. NOTE(review): this flips a process-wide
            # flag and slows down every subsequent cuDNN op — consider
            # restoring it after the call.
            torch.backends.cudnn.enabled = False
            return self.gru(x)[0]
class ResConvBlock(nn.Module):
    """Pre-activation residual block: two BN -> PReLU -> 3x3 conv stages.

    When the channel count changes, the skip path goes through a 1x1
    convolution so that it can be added to the main branch.

    Args:
        in_planes: input channel count.
        out_planes: output channel count.
    """

    def __init__(
        self,
        in_planes,
        out_planes
    ):
        super(ResConvBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.01)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.01)
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, (3, 3), padding=(1, 1), bias=False
        )
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, (3, 3), padding=(1, 1), bias=False
        )
        # A projection shortcut is only needed when channels change.
        self.is_shortcut = in_planes != out_planes
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_planes, out_planes, (1, 1))
        self.init_weights()

    def init_weights(self):
        """Apply Xavier init to convs and identity init to batch norms."""
        for bn in (self.bn1, self.bn2):
            init_bn(bn)
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        """Return residual sum of the conv branch and the (possibly projected) input."""
        h = self.conv1(self.act1(self.bn1(x)))
        h = self.conv2(self.act2(self.bn2(h)))
        residual = self.shortcut(x) if self.is_shortcut else x
        return h + residual