# Spaces: Runtime error — hosting-platform status banner captured with this
# paste; not part of the Python module.
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionBlock(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Input and output are both (B, C, H, W); the attention result is added
    back onto the input (residual connection).

    NOTE: ``nn.GroupNorm(32, in_ch)`` requires ``in_ch`` to be divisible
    by 32.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        # 1x1 convolutions act as per-pixel linear projections for Q, K, V
        # and the output.
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)
        # BUG FIX: .view() on a permuted (non-contiguous) tensor raises
        # "view size is not compatible with input tensor's size and stride"
        # at runtime; .reshape() copies when necessary and is always safe.
        q = q.permute(0, 2, 3, 1).reshape(B, H * W, C)
        k = k.view(B, C, H * W)
        # Scaled dot-product attention weights over all H*W positions.
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        w = F.softmax(w, dim=-1)
        v = v.permute(0, 2, 3, 1).reshape(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)
        return x + h
class ResidualBlock(nn.Module):
    """Pre-activation residual block: (GN -> SiLU -> 3x3 conv) twice, plus a
    skip connection, optionally followed by a self-attention block.

    When ``in_channels != out_channels`` the skip path is a 1x1 convolution;
    otherwise it is the identity.  Dropout is applied before the second
    convolution only.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 dropout: float,
                 n_groups: int = 32,
                 has_attn: bool = False):
        super().__init__()
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=(3, 3), padding=(1, 1))
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3), padding=(1, 1))
        # 1x1 projection only when the channel count changes.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
            if in_channels != out_channels
            else nn.Identity()
        )
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        out = self.norm1(x)
        out = self.conv1(self.act1(out))
        out = self.norm2(out)
        out = self.conv2(self.dropout(self.act2(out)))
        skip = self.shortcut(x)
        return self.attn(out + skip)
class DownBlock(nn.Module):
    """Encoder-side stage: a single ResidualBlock.

    Spatial downsampling is handled separately by ``Downsample``; this
    wrapper only transforms channels/features.
    """

    # FIX: dropout was annotated `int` but receives floats (e.g. 0.1);
    # annotation corrected to `float` (no runtime behavior change).
    def __init__(self, in_channels: int, out_channels: int, has_attn: bool, dropout: float):
        super().__init__()
        self.res = ResidualBlock(
            in_channels, out_channels, dropout=dropout, has_attn=has_attn)

    def forward(self, x: torch.Tensor):
        return self.res(x)
class UpBlock(nn.Module):
    """Decoder-side stage: a single ResidualBlock.

    Spatial upsampling is handled separately by ``Upsample``; this wrapper
    only transforms channels/features (e.g. after skip concatenation).
    """

    # FIX: dropout was annotated `int` but receives floats (e.g. 0.1);
    # annotation corrected to `float` (no runtime behavior change).
    def __init__(self, in_channels: int, out_channels: int, has_attn: bool, dropout: float):
        super().__init__()
        self.res = ResidualBlock(
            in_channels, out_channels, dropout=dropout, has_attn=has_attn)

    def forward(self, x: torch.Tensor):
        return self.res(x)
class MiddleBlock(nn.Module):
    """Bottleneck: ResidualBlock with attention, then a plain ResidualBlock.

    Channel count and spatial size are unchanged (n_channels -> n_channels).
    """

    # FIX: dropout was annotated `int` but receives floats (e.g. 0.1);
    # annotation corrected to `float` (no runtime behavior change).
    def __init__(self, n_channels: int, dropout: float):
        super().__init__()
        self.res1 = ResidualBlock(
            n_channels, n_channels, dropout=dropout, has_attn=True)
        self.res2 = ResidualBlock(n_channels, n_channels, dropout=dropout)

    def forward(self, x: torch.Tensor):
        x = self.res1(x)
        x = self.res2(x)
        return x
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, n_channels):
        super().__init__()
        # stride=2 halves H and W; padding=1 keeps the 3x3 support centered.
        self.conv = nn.Conv2d(n_channels, n_channels,
                              kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        return out
class Upsample(nn.Module):
    """Double the spatial resolution: a stride-2 transposed convolution
    (exact 2x with output_padding=1) followed by a 3x3 refinement conv."""

    def __init__(self, n_channels):
        super().__init__()
        self.convT = nn.ConvTranspose2d(
            n_channels, n_channels, kernel_size=3, stride=2, padding=1,
            output_padding=1)
        self.conv = nn.Conv2d(n_channels, n_channels,
                              kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor):
        upsampled = self.convT(x)
        return self.conv(upsampled)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that subtracts (sign=-1) or adds (sign=+1)
    the dataset RGB mean, scaled by ``rgb_range``, per channel.

    The weight is a diagonal 1/std matrix and the bias is
    ``sign * rgb_range * mean / std``; all parameters are frozen.
    """

    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super().__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # Freeze every parameter of this conv (weight and bias).
        self.requires_grad_(False)
class UNET(nn.Module):
    """U-Net-style encoder/decoder for RGB image-to-image restoration.

    Forward pipeline: scale input to [0, 255], subtract the dataset RGB
    mean (MeanShift), extract shallow features, run the encoder / bottleneck
    / decoder path with skip connections (channel concatenation after each
    Upsample), reconstruct an RGB image, add the mean back and rescale to
    [0, 1].  Assumes input is in [0, 1] — TODO confirm with callers.
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 3,
                 n_features: int = 64,
                 dropout: float = 0.1,
                 block_out_channels=[64, 128, 128, 256],
                 layers_per_block=4,
                 is_attn_layers=(False, False, True, False),
                 ):
        # NOTE(review): block_out_channels is a mutable default argument,
        # shared across all instances that use the default — prefer a tuple.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_features = n_features
        self.dropout = dropout
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.is_attn_layers = is_attn_layers
        # Mean subtraction/addition for inputs scaled to the [0, 255] range.
        self.sub_mean = MeanShift(255)
        self.add_mean = MeanShift(255, sign=1)
        self.shallow_feature_extraction = nn.Conv2d(
            in_channels, n_features, kernel_size=3, padding=1)
        # (sic) "rescontruction" — misspelling kept: it is a public
        # attribute name that checkpoints / external code may reference.
        self.image_rescontruction = nn.Conv2d(
            n_features, in_channels, kernel_size=3, padding=1)
        self.left_model = self.left_unet()
        self.middle_model = MiddleBlock(
            block_out_channels[-1], dropout=self.dropout)
        self.right_model = self.right_unet()

    def left_unet(self):
        """Build the encoder: per stage, a stack of DownBlocks followed by a
        Downsample module, as a flat nn.ModuleList."""
        left_model = []
        in_channel = out_channel = self.n_features
        for i in range(len(self.block_out_channels)):
            out_channel = self.block_out_channels[i]
            # NOTE(review): `[DownBlock(...)] * (layers_per_block - 1)`
            # repeats the SAME DownBlock object — its weights are shared and
            # applied repeatedly.  A comprehension creating fresh instances
            # was probably intended; confirm before relying on capacity.
            down_block = [DownBlock(in_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[i])] \
                + [DownBlock(out_channel, out_channel, dropout=self.dropout,
                             has_attn=self.is_attn_layers[i])] * (self.layers_per_block - 1)
            in_channel = out_channel
            left_model.append(nn.Sequential(*down_block))
            # NOTE(review): this condition is always true for
            # i in range(len(...)), so a Downsample also follows the LAST
            # stage (the bottleneck runs at an extra-downsampled resolution);
            # possibly `len(self.block_out_channels) - 1` was intended.
            if i < len(self.block_out_channels):
                left_model.append(Downsample(out_channel))
        return nn.ModuleList(left_model)

    def right_unet(self):
        """Build the decoder: alternating UpBlock stacks and Upsample modules
        (consumed in forward(), where skips are concatenated after each
        Upsample, doubling the channel count)."""
        right_unet = []
        in_channel = out_channel = self.block_out_channels[-1]
        for i in reversed(range(len(self.block_out_channels))):
            out_channel = self.block_out_channels[i]
            # NOTE(review): `is_attn_layers[i - 1]` wraps to index -1 when
            # i == 0 — looks like an off-by-one; `[i]` was probably intended.
            # The `* (layers_per_block - 1)` module-aliasing caveat from
            # left_unet applies here as well.
            up_block = [UpBlock(in_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[i - 1])] \
                + [UpBlock(out_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[i - 1])
                   ] * (self.layers_per_block - 1)
            # After the skip concatenation the next stage sees 2x channels.
            in_channel = out_channel * 2
            right_unet.append(nn.Sequential(*up_block))
            right_unet.append(Upsample(out_channel))
        # Final stage: last skip concat (block_out_channels[0] * 2) back down
        # to n_features for image reconstruction.
        in_channel, out_channel = self.block_out_channels[0] * \
            2, self.n_features
        up_block = [UpBlock(in_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[0])] \
            + [UpBlock(out_channel, out_channel, dropout=self.dropout, has_attn=self.is_attn_layers[0])
               ] * (self.layers_per_block - 1)
        right_unet.append(nn.Sequential(*up_block))
        return nn.ModuleList(right_unet)

    def forward(self, x):
        # Work internally in the [0, 255] range expected by MeanShift.
        x = x * 255
        x = self.sub_mean(x)
        feature_maps = self.shallow_feature_extraction(x)
        # Skip-connection store; the shallow features are the first entry.
        feature_x = [feature_maps]
        # print(feature_maps.shape)
        feature_block = feature_maps
        # Encoder: record each residual stage's output as a skip feature;
        # Downsample outputs are intentionally not recorded.
        for block in self.left_model:
            feature_block = block(feature_block)
            if not isinstance(block, Downsample):
                # print(feature_block.shape)
                feature_x.append(feature_block)
        bottleneck = self.middle_model(feature_block)
        # Consume skips deepest-first on the way back up.
        feature_x.reverse()
        # print('Middle::: ', feature_maps.shape)
        recover = bottleneck
        d = 0  # index of the next skip feature to concatenate
        for block in self.right_model:
            if isinstance(block, Upsample):
                # print('UP-CAT::: ', recover.shape)
                recover = block(recover)
                # print('UP-CAT-END::: ', recover.shape, feature_x[d].shape)
                # Concatenate the matching encoder feature along channels.
                recover = torch.cat([recover, feature_x[d]], 1)
                # print('UP-CAT-END::: ', recover.shape, feature_x[d].shape)
                d += 1
            else:
                recover = block(recover)
                # print('UP-RES::: ', recover.shape)
        recover = self.image_rescontruction(recover)
        # Undo the mean shift and rescale back to [0, 1].
        recover = self.add_mean(recover) / 255
        return recover