import einops
import torch
from torch.nn import functional as F
from torch.nn.utils.parametrizations import spectral_norm


class GBlock(torch.nn.Module):
    """Residual generator block; every convolution is spectrally normalized."""

    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        self.relu = torch.nn.ReLU()
        # 1x1 convolution that projects the shortcut when channel counts differ.
        self.conv1x1 = spectral_norm(
            torch.nn.Conv2d(in_channel, out_channel, kernel_size=1)
        )
        self.conv3x3_1 = spectral_norm(
            torch.nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        )
        self.conv3x3_2 = spectral_norm(
            torch.nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut branch: project only when the channel count changes; a plain
        # identity suffices otherwise (no clone needed, x is only rebound below).
        if x.shape[1] != self.out_channel:
            res = self.conv1x1(x)
        else:
            res = x

        # Main branch: two ReLU + 3x3 convolution stages.
        x = self.relu(x)
        x = self.conv3x3_1(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)

        return x + res
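

# Usage sketch (illustrative, not part of the original file): GBlock keeps the
# spatial size while remapping channels, e.g. GBlock(48, 96) maps a
# (N, 48, H, W) tensor to (N, 96, H, W).
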
class Up_GBlock(torch.nn.Module):
    """Residual generator block that doubles the spatial size and halves the channels."""

    def __init__(self, in_channel: int):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = in_channel // 2

        self.relu = torch.nn.ReLU()
        self.up = torch.nn.Upsample(scale_factor=2, mode='nearest')
        # 1x1 convolution that halves the channels of the shortcut branch.
        self.conv1x1 = spectral_norm(
            torch.nn.Conv2d(in_channel, self.out_channel, kernel_size=1)
        )
        self.conv3x3_1 = spectral_norm(
            torch.nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1)
        )
        self.conv3x3_2 = spectral_norm(
            torch.nn.Conv2d(
                in_channel,
                self.out_channel,
                kernel_size=3,
                padding=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut branch: upsample, then halve the channels.
        res = self.up(x)
        res = self.conv1x1(res)

        # Main branch: upsample after the first activation, then two 3x3 stages.
        x = self.relu(x)
        x = self.up(x)
        x = self.conv3x3_1(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)

        return x + res
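

# Usage sketch (illustrative): Up_GBlock(48) maps (N, 48, H, W) to (N, 24, 2H, 2W).
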
class DBlock(torch.nn.Module):
    """Residual discriminator block with optional first ReLU and downsampling."""

    def __init__(
            self,
            in_channel: int,
            out_channel: int,
            conv_type: str = '2d',
            apply_relu: bool = True,
            apply_down: bool = False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.apply_relu = apply_relu
        self.apply_down = apply_down

        if conv_type == '2d':
            self.avg_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            conv = torch.nn.Conv2d
        elif conv_type == '3d':
            # Note: 3d pooling with kernel 2 also halves the temporal dimension.
            self.avg_pool = torch.nn.AvgPool3d(kernel_size=2, stride=2)
            conv = torch.nn.Conv3d
        else:
            # Fail early instead of hitting an undefined `conv` below.
            raise ValueError(f"conv_type must be '2d' or '3d', got {conv_type!r}")

        self.relu = torch.nn.ReLU()
        self.conv1x1 = spectral_norm(
            conv(in_channel, out_channel, kernel_size=1)
        )
        self.conv3x3_1 = spectral_norm(
            conv(in_channel, out_channel, kernel_size=3, padding=1)
        )
        self.conv3x3_2 = spectral_norm(
            conv(out_channel, out_channel, kernel_size=3, padding=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut branch: project when channel counts differ, then optionally pool.
        if x.shape[1] != self.out_channel:
            res = self.conv1x1(x)
        else:
            res = x
        if self.apply_down:
            res = self.avg_pool(res)

        # Main branch: the first ReLU is skipped when this is the input block.
        if self.apply_relu:
            x = self.relu(x)
        x = self.conv3x3_1(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)
        if self.apply_down:
            x = self.avg_pool(x)

        return res + x
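

# Usage sketch (illustrative): DBlock(48, 96, apply_down=True) maps
# (N, 48, H, W) to (N, 96, H/2, W/2); with conv_type='3d' the depth halves too.
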
class LBlock(torch.nn.Module):
    """Residual latent-stack block; grows channels by concatenation (no spectral norm)."""

    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        self.relu = torch.nn.ReLU()
        conv = torch.nn.Conv2d
        # The shortcut concatenates x with this projection, so out_channel
        # must be strictly greater than in_channel.
        self.conv1x1 = conv(
            in_channel,
            out_channel - in_channel,
            kernel_size=1)

        self.conv3x3_1 = conv(in_channel, in_channel, kernel_size=3, padding=1)
        self.conv3x3_2 = conv(
            in_channel,
            out_channel,
            kernel_size=3,
            padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut branch: pad the channels up to out_channel by concatenation.
        res = torch.cat([x, self.conv1x1(x)], dim=1)

        # Main branch: two ReLU + 3x3 convolution stages.
        x = self.relu(x)
        x = self.conv3x3_1(x)
        x = self.relu(x)
        x = self.conv3x3_2(x)

        return x + res
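

# Usage sketch (illustrative): LBlock(8, 24) maps (N, 8, H, W) to (N, 24, H, W);
# the shortcut contributes 8 identity channels plus 16 projected ones.
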
def attention_einsum(q, k, v):
    """Dot-product self-attention for a single channels-last sample.

    q and k are [h, w, c_k] tensors and v is [h, w, c_v], with no batch
    dimension; returns an [h, w, c_v] tensor.
    """
    # Flatten the spatial grid of keys and values: [h, w, c] -> [L, c], L = h * w.
    k = einops.rearrange(k, "h w c -> (h w) c")
    v = einops.rearrange(v, "h w c -> (h w) c")

    # Attention weights: each query position attends over all L key positions.
    beta = F.softmax(torch.einsum("hwc, Lc->hwL", q, k), dim=-1)

    # Weighted sum of the values: [h, w, L] x [L, c] -> [h, w, c].
    out = torch.einsum("hwL, Lc->hwc", beta, v)

    return out
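

# Shape sketch (illustrative): for q, k of shape (16, 16, 6) and v of shape
# (16, 16, 6), beta is (16, 16, 256), one weight per key position, and the
# result is again (16, 16, 6).
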
class AttentionLayer(torch.nn.Module):
    """Self-attention layer with a learned residual gate (gamma), SAGAN-style."""

    def __init__(self, in_channel: int, out_channel: int, ratio_kq: int = 8, ratio_v: int = 8):
        super().__init__()
        self.ratio_kq = ratio_kq
        self.ratio_v = ratio_v
        self.in_channel = in_channel
        self.out_channel = out_channel

        # 1x1 convolutions computing queries, keys and values at reduced widths.
        self.query = torch.nn.Conv2d(
            in_channel,
            out_channel // ratio_kq,
            kernel_size=1,
            bias=False
        )
        self.key = torch.nn.Conv2d(
            in_channel,
            out_channel // ratio_kq,
            kernel_size=1,
            bias=False
        )
        self.value = torch.nn.Conv2d(
            in_channel,
            out_channel // ratio_v,
            kernel_size=1,
            bias=False
        )

        # Output projection back to out_channel. Its input width must match the
        # value width; the original hardcoded out_channel // 8, which only
        # agrees with the value convolution when ratio_v == 8.
        self.conv = torch.nn.Conv2d(
            out_channel // ratio_v,
            out_channel,
            kernel_size=1,
            bias=False
        )

        # Learned residual gate, initialized to zero so the layer starts as identity.
        self.gamma = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # attention_einsum expects channels-last [h, w, c] tensors, while Conv2d
        # produces channels-first maps, so permute each sample and permute back.
        out = []
        for i in range(x.shape[0]):
            out.append(attention_einsum(
                query[i].permute(1, 2, 0),
                key[i].permute(1, 2, 0),
                value[i].permute(1, 2, 0),
            ))
        out = torch.stack(out, dim=0).permute(0, 3, 1, 2)

        # Gated residual connection (assumes out_channel == in_channel).
        out = self.gamma * self.conv(out)
        return out + x
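

# Minimal smoke test, a sketch for illustration (not part of the original file);
# the channel counts and spatial sizes below are arbitrary assumptions.
if __name__ == "__main__":
    x = torch.randn(2, 48, 16, 16)

    assert GBlock(48, 96)(x).shape == (2, 96, 16, 16)
    assert Up_GBlock(48)(x).shape == (2, 24, 32, 32)
    assert DBlock(48, 96, apply_down=True)(x).shape == (2, 96, 8, 8)
    assert LBlock(48, 96)(x).shape == (2, 96, 16, 16)
    assert AttentionLayer(48, 48)(x).shape == (2, 48, 16, 16)
    print("All block smoke tests passed.")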
|
|
|