# Sat3DGen / source/sr_module.py
# Uploaded by qian43 ("Upload 115 files"), commit 874cec4 (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from source.vae_hacked import ResnetBlock
class SuperresolutionHybrid2X(nn.Module):
    """2x super-resolution head that refines a base image with a predicted residual.

    The first three channels of the input are taken as the base RGB image.
    When ``out_channels == 4``, the last channel is also taken as an opacity map.
    ``block0`` transforms the full input at its own resolution. ``block1``
    upsamples by 2x and adds a predicted residual to the upsampled image
    (and opacity).

    Args:
        in_channels: Number of channels of the input tensor ``x``.
        out_channels: 3 for RGB output, 4 for RGB + opacity output.
        padding_mode: ``padding_mode`` forwarded to the convolutions.
        v2: Forwarded to ``SynthesisBlock`` (``toRGB`` without bias and
            no ``tanh`` on the residual).
    """

    def __init__(self, in_channels, out_channels, padding_mode='zeros', v2=False):
        super().__init__()
        self.out_channels = out_channels
        self.block0 = SynthesisBlockNoUp(in_channels, 128, padding_mode=padding_mode)
        # With opacity output, SynthesisBlock concatenates the 1-channel opacity
        # map onto block0's 128 features before its first conv. That conv must
        # therefore accept one extra channel; with 128 it failed on the
        # 129-channel input.
        block1_in_channels = 128 + (1 if out_channels == 4 else 0)
        self.block1 = SynthesisBlock(
            block1_in_channels,
            out_channels,
            is_last=True,
            padding_mode=padding_mode,
            v2=v2,
        )

    def forward(self, x):
        """Upsample ``x`` by 2x.

        Args:
            x: 4-D tensor, presumably (B, C, H, W). Channels ``0:3`` are the
                base RGB image. When ``out_channels == 4``, the last channel
                is the opacity map.

        Returns:
            The refined image, or an ``(img, opa)`` tuple when
            ``out_channels == 4``.
        """
        img = x[:, :3, :, :].clone()
        if self.out_channels == 4:
            # Slicing with ``-1:`` already keeps the channel axis: (B, 1, H, W).
            # The former ``.unsqueeze(1)`` produced a 5-D tensor that could not
            # be concatenated with the 4-D feature map.
            opa = x[:, -1:, :, :]
        else:
            opa = None
        x = self.block0(x)
        return self.block1(x, img, opa)
class SynthesisBlockNoUp(torch.nn.Module):
    """Two stacked ConvBlocks plus a 1x1 projection shortcut, no resampling.

    Computes ``conv1(conv0(x)) + skip_conv(x)``. With the ConvBlock defaults
    (3x3, stride 1, padding 1) and the 1x1 shortcut, the spatial size is unchanged.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        padding_mode: ``padding_mode`` forwarded to every convolution.
    """

    def __init__(self, in_channels, out_channels, padding_mode):
        super().__init__()
        hidden_channels = 128
        self.conv0 = ConvBlock(
            in_channels=in_channels,
            out_channels=hidden_channels,
            padding_mode=padding_mode,
        )
        self.conv1 = ConvBlock(
            in_channels=hidden_channels,
            out_channels=out_channels,
            padding_mode=padding_mode,
        )
        # 1x1 projection so the shortcut has ``out_channels`` channels.
        self.skip_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            padding_mode=padding_mode,
        )

    def forward(self, x):
        shortcut = self.skip_conv(x)
        main = self.conv1(self.conv0(x))
        return main + shortcut
class SynthesisBlock(torch.nn.Module):
    """2x upsampling block that adds a predicted residual to an upsampled image.

    ``forward`` does the following, in order:
    - Bilinearly upsamples the features ``x`` and the base image ``img`` by 2x.
    - Optionally upsamples the opacity ``opa`` (nearest) and concatenates it
      onto the features.
    - Runs two ConvBlocks and, when ``is_last``, the 1x1 ``toRGB`` projection.
    - Adds channels ``0:3`` of the result to the image and channels ``3:`` to
      the opacity.
    - Clamps both to [0, 1].

    Args:
        in_channels: Channels fed to ``conv0``: the channels of ``x``, plus one
            when ``opa`` is passed to ``forward``.
        out_channels: Output channels of ``toRGB`` (3 = RGB, 4 = RGB + opacity).
        padding_mode: ``padding_mode`` forwarded to every convolution.
        is_last: When True, apply ``toRGB`` (and ``tanh`` unless ``v2``).
        v2: When True, ``toRGB`` has no bias and its output skips ``tanh``.
    """

    def __init__(self, in_channels, out_channels, padding_mode, is_last=True, v2=False):
        super().__init__()
        self.conv0 = ConvBlock(in_channels=in_channels, out_channels=128, padding_mode=padding_mode)
        self.conv1 = ConvBlock(in_channels=128, out_channels=128, padding_mode=padding_mode)
        self.is_last = is_last
        self.v2 = v2
        self.toRGB = nn.Conv2d(
            in_channels=128,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            padding_mode=padding_mode,
            bias=not self.v2,
        )
        self.out_channels = out_channels

    def forward(self, x, img, opa=None):
        """Upsample by 2x and refine ``img`` (and ``opa``).

        Args:
            x: Feature map (4-D, presumably (B, C, H, W)).
            img: Base image, presumably (B, 3, H, W) — TODO confirm.
            opa: Optional opacity map, presumably (B, 1, H, W). When given it
                is concatenated onto the upsampled features.

        Returns:
            ``img`` clamped to [0, 1], or ``(img, opa)`` when ``opa`` is given.
        """
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False, antialias=True)
        img = F.interpolate(img, scale_factor=2, mode='bilinear', align_corners=False, antialias=True)
        if opa is not None:
            # 'nearest' accepts neither ``align_corners`` nor ``antialias``.
            # Passing them made F.interpolate raise ValueError on every call.
            opa = F.interpolate(opa, scale_factor=2, mode='nearest')
            x = torch.cat([x, opa], dim=1)
        x = self.conv0(x)
        x = self.conv1(x)
        if self.is_last:
            x = self.toRGB(x)
            if not self.v2:
                x = torch.tanh(x)
        # NOTE(review): only is_last=True is used in this file. With
        # is_last=False the first three raw 128-channel features are added
        # to the image — confirm that is intended.
        img = img + x[:, :3]
        if opa is not None:
            opa = torch.clamp(opa + x[:, 3:], 0, 1)
        img = torch.clamp(img, 0, 1)
        if opa is not None:
            return img, opa
        return img
class ConvBlock(nn.Module):
    """Conv2d -> InstanceNorm2d -> LeakyReLU(0.2), with the activation optional.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution (and of the norm).
        padding_mode: ``padding_mode`` forwarded to ``nn.Conv2d``.
        kernel_size, stride, padding: Convolution geometry. The defaults
            (3, 1, 1) keep the spatial size unchanged.
        if_act: When False, the LeakyReLU is replaced by ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, padding_mode, kernel_size=3, stride=1, padding=1, if_act=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            padding_mode=padding_mode,
        )
        # Constructed with PyTorch defaults: no learnable affine, no running stats.
        self.norm = nn.InstanceNorm2d(out_channels)
        if if_act:
            self.act = nn.LeakyReLU(0.2)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))