| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from source.vae_hacked import ResnetBlock |
|
|
class SuperresolutionHybrid2X(nn.Module):
    """2x super-resolution head: a resolution-preserving refinement block
    followed by a 2x-upsampling synthesis block.

    Args:
        in_channels: channels of the input map; the first 3 are RGB and, when
            ``out_channels == 4``, the last one is treated as opacity.
        out_channels: 3 for RGB output, 4 for RGB + opacity.
        padding_mode: padding mode forwarded to every convolution.
        v2: forwarded to the final block (drops the tanh and the RGB-head bias).
    """

    def __init__(self, in_channels, out_channels, padding_mode='zeros', v2=False):
        super().__init__()
        self.out_channels = out_channels
        self.block0 = SynthesisBlockNoUp(in_channels, 128, padding_mode=padding_mode)
        # When an opacity map is present, SynthesisBlock.forward concatenates
        # it onto the 128 feature channels before its first conv, so block1
        # must be built one channel wider. (With in_channels=128 the
        # out_channels==4 path crashed with a channel mismatch.)
        block1_in = 128 + (1 if out_channels == 4 else 0)
        self.block1 = SynthesisBlock(block1_in, out_channels, is_last=True,
                                     padding_mode=padding_mode, v2=v2)

    def forward(self, x):
        # Keep the low-res RGB for the residual connection inside block1.
        img = x[:, :3, :, :].clone()
        if self.out_channels == 4:
            # x[:, -1:] already keeps the channel dim -> (B, 1, H, W).
            # The previous .unsqueeze(1) produced a 5-D tensor and broke the
            # downstream interpolate/cat in SynthesisBlock.
            opa = x[:, -1:, :, :]
        else:
            opa = None
        x = self.block0(x)
        return self.block1(x, img, opa)
| |
class SynthesisBlockNoUp(torch.nn.Module):
    """Residual refinement block that keeps the spatial resolution unchanged.

    Two 3x3 conv blocks form the residual branch; a 1x1 convolution projects
    the input onto ``out_channels`` for the skip connection.
    """

    def __init__(self, in_channels, out_channels, padding_mode):
        super().__init__()
        self.conv0 = ConvBlock(in_channels=in_channels, out_channels=128,
                               padding_mode=padding_mode)
        self.conv1 = ConvBlock(in_channels=128, out_channels=out_channels,
                               padding_mode=padding_mode)
        self.skip_conv = nn.Conv2d(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=1, stride=1, padding=0,
                                   padding_mode=padding_mode)

    def forward(self, x):
        # 1x1 projection of the input is the residual shortcut.
        shortcut = self.skip_conv(x)
        out = self.conv1(self.conv0(x))
        return out + shortcut
|
|
class SynthesisBlock(torch.nn.Module):
    """2x-upsampling synthesis block with a residual RGB (+ optional opacity) head.

    Args:
        in_channels: channels of the incoming feature map. NOTE(review): when
            an opacity map is passed to forward() it is concatenated onto the
            features before conv0, so in_channels must already include that
            extra channel — confirm against the caller.
        out_channels: channels of the final 1x1 projection (3 = RGB,
            4 = RGB + opacity residual).
        padding_mode: padding mode for all convolutions.
        is_last: when True, project features to RGB and add them onto the
            upsampled image.
        v2: when True, skip the tanh squashing and drop the RGB-head bias.
    """

    def __init__(self, in_channels, out_channels, padding_mode, is_last=True, v2=False):
        super().__init__()
        self.conv0 = ConvBlock(in_channels=in_channels, out_channels=128,
                               padding_mode=padding_mode)
        self.conv1 = ConvBlock(in_channels=128, out_channels=128,
                               padding_mode=padding_mode)
        self.is_last = is_last
        self.v2 = v2
        self.toRGB = nn.Conv2d(in_channels=128, out_channels=out_channels,
                               kernel_size=1, stride=1, padding=0,
                               padding_mode=padding_mode, bias=not self.v2)
        self.out_channels = out_channels

    def forward(self, x, img, opa=None):
        """Upsample by 2x and refine residually.

        Returns ``img``, or ``(img, opa)`` when an opacity map was given.
        """
        x = F.interpolate(x, scale_factor=2, mode='bilinear',
                          align_corners=False, antialias=True)
        img = F.interpolate(img, scale_factor=2, mode='bilinear',
                            align_corners=False, antialias=True)
        if opa is not None:
            # Nearest-neighbour upsampling for the hard opacity mask.
            # 'align_corners' and 'antialias' are only valid for the
            # linear/bicubic modes and made this call raise ValueError before.
            opa = F.interpolate(opa, scale_factor=2, mode='nearest')
            x = torch.cat([x, opa], dim=1)
        x = self.conv0(x)
        x = self.conv1(x)
        if self.is_last:
            x = self.toRGB(x)
            if not self.v2:
                x = torch.tanh(x)
            # Residual update of the upsampled image (and opacity, if any),
            # clamped back into [0, 1].
            img = img + x[:, :3]
            if opa is not None:
                opa = opa + x[:, 3:]
                opa = torch.clamp(opa, 0, 1)
            img = torch.clamp(img, 0, 1)
        if opa is not None:
            return img, opa
        return img
| |
class ConvBlock(nn.Module):
    """Conv2d -> InstanceNorm2d -> optional LeakyReLU(0.2) activation."""

    def __init__(self, in_channels, out_channels, padding_mode,
                 kernel_size=3, stride=1, padding=1, if_act=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, padding_mode=padding_mode)
        self.norm = nn.InstanceNorm2d(out_channels)
        # Identity keeps the module graph uniform when activation is disabled.
        self.act = nn.LeakyReLU(0.2) if if_act else nn.Identity()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))
|
|