| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import init |
|
|
| import einops |
|
|
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 2D convolution layer with a fixed 3x3 kernel.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride.
        padding: padding passed to ``nn.Conv2d``; the default of 1 keeps the
            spatial size unchanged when ``stride`` is 1.
        bias: whether the layer carries a learnable bias.
        groups: number of channel groups.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)
|
|
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a module that doubles the spatial resolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mode: ``'transpose'`` selects a stride-2 transposed convolution with a
            2x2 kernel; any other value selects bilinear upsampling followed
            by a 1x1 convolution.

    Returns:
        An ``nn.ConvTranspose2d`` or an ``nn.Sequential`` of upsample + conv.
    """
    if mode != 'transpose':
        # Parameter-free resize first, then a 1x1 conv to change channel count.
        resize = nn.Upsample(mode='bilinear', scale_factor=2)
        project = conv1x1(in_channels, out_channels)
        return nn.Sequential(resize, project)

    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
|
def conv1x1(in_channels, out_channels, groups=1):
    """Build a pointwise (1x1 kernel, stride 1) 2D convolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        groups: number of channel groups.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, groups=groups, stride=1)
    return pointwise
|
|
class ConvTriplane3dAware(nn.Module):
    """3D-aware triplane convolution (as described in RODIN).

    Each output plane comes from a 2D convolution applied to that plane's own
    features concatenated, along channels, with axis-averaged features of the
    two other planes.
    """

    def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
        """
        Args:
            internal_conv_f: callable ``(in_ch, out_ch)`` returning the 2D
                convolution Module used for one plane
            in_channels: channel count of every input plane
            out_channels: channel count of every output plane
            order: 'xz' if the third plane is given as (x, z), in which case it
                is transposed internally; 'zx' leaves it untouched
        """
        super(ConvTriplane3dAware, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        assert order in ['xz', 'zx']
        self.order = order

        # Every conv sees the plane itself plus two pooled neighbours,
        # hence three times the input channels.
        per_plane = []
        for _ in range(3):
            per_plane.append(internal_conv_f(3*self.in_channels, self.out_channels))
        self.plane_convs = nn.ModuleList(per_plane)

    def forward(self, triplanes_list):
        """
        Args:
            triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order
        Returns:
            out_triplanes_list: [(B,Co,H,W)]*3 in xy,yz,(zx or xz) depending on order
        """
        planes = list(triplanes_list)
        # Index of the plane perpendicular to each axis.
        zp, xp, yp = 0, 1, 2

        if self.order == 'xz':
            # Bring the third plane into zx layout so the cyclic indexing below holds.
            planes[yp] = einops.rearrange(planes[yp], 'b c x z -> b c z x')

        results = [None]*3
        for idx in (zp, xp, yp):
            own = planes[idx]
            following = planes[(idx+1) % 3]
            preceding = planes[(idx+2) % 3]

            # Average the following plane over its last axis, move the surviving
            # axis into the width slot and tile it along the height of `own`.
            pooled_j = torch.mean(following, dim=3, keepdim=True)
            pooled_j = einops.rearrange(pooled_j, 'b c k 1 -> b c 1 k')
            pooled_j = einops.repeat(pooled_j, 'b c 1 k -> b c j k', j=own.size(2))

            # Average the preceding plane over its height axis, move the surviving
            # axis into the height slot and tile it along the width of `own`.
            pooled_k = torch.mean(preceding, dim=2, keepdim=True)
            pooled_k = einops.rearrange(pooled_k, 'b c 1 j -> b c j 1')
            pooled_k = einops.repeat(pooled_k, 'b c j 1 -> b c j k', k=own.size(3))

            fused = torch.cat([own, pooled_j, pooled_k], dim=1)
            results[idx] = self.plane_convs[idx](fused)

        if self.order == 'xz':
            # Undo the transpose so the output layout matches the input layout.
            results[yp] = einops.rearrange(results[yp], 'b c z x -> b c x z')

        return results
|
|
def roll_triplanes(triplanes_list):
    """Merge three (b, c, h, w) planes into one rolled (b, c, 3*h, w) tensor.

    The planes are stacked on a new axis and that axis is then folded into
    the height axis, so the planes end up one below the other.
    """
    stacked = torch.stack((triplanes_list), dim=2)
    rolled = einops.rearrange(stacked, 'b c tri h w -> b c (tri h) w', tri=3)
    return rolled
|
|
def unroll_triplanes(rolled_triplane):
    """Split a rolled (b, c, 3*h, w) tensor back into its three planes.

    Returns:
        A tuple of three (b, c, h, w) tensors, as produced by ``torch.unbind``.
    """
    separated = einops.rearrange(rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3)
    planes = torch.unbind(separated, dim=2)
    return planes
|
|
def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
    """Build a ``ConvTriplane3dAware`` whose per-plane convolution is 1x1.

    Args:
        in_channels: channels of each input plane.
        out_channels: channels of each output plane.
        order: forwarded to ``ConvTriplane3dAware``.
        **kwargs: extra keyword arguments forwarded to every ``conv1x1`` call.
    """
    def make_plane_conv(inp, out):
        # Factory handed to ConvTriplane3dAware; it chooses the channel counts.
        return conv1x1(inp, out, **kwargs)

    return ConvTriplane3dAware(make_plane_conv, in_channels, out_channels, order=order)
|
|
def Normalize(in_channels, num_groups=32):
    """Build the GroupNorm layer used throughout the network.

    Args:
        in_channels: number of channels to normalize.
        num_groups: upper bound on the number of groups.

    Returns:
        A ``torch.nn.GroupNorm`` with ``eps=1e-6`` and learnable affine
        parameters. The number of groups is the largest value that is at most
        ``min(in_channels, num_groups)`` and divides ``in_channels``.
    """
    num_groups = min(in_channels, num_groups)
    # GroupNorm rejects a group count that does not divide the channel count
    # (e.g. 48 channels with 32 groups), so step down to the nearest divisor.
    # Channel counts that already divided evenly are unaffected.
    while num_groups > 1 and in_channels % num_groups != 0:
        num_groups -= 1
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
|
def nonlinearity(x):
    """Return ``x`` multiplied elementwise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels: channel count of the input (the conv keeps it unchanged)
            with_conv: if truthy, a 3x3 stride-1 conv is applied after resizing
        """
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            smoothing_options = dict(kernel_size=3, stride=1, padding=1)
            self.conv = torch.nn.Conv2d(in_channels, in_channels, **smoothing_options)

    def forward(self, x):
        """Double the two trailing spatial dimensions of ``x``."""
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return enlarged
        return self.conv(enlarged)
|
|
class Downsample(nn.Module):
    """2x spatial downsampling by a strided 3x3 convolution or by average pooling."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels: channel count of the input (the conv keeps it unchanged)
            with_conv: if truthy use a stride-2 conv, otherwise 2x2 average pooling
        """
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # The conv itself is unpadded; forward() pads manually instead.
            strided_options = dict(kernel_size=3, stride=2, padding=0)
            self.conv = torch.nn.Conv2d(in_channels, in_channels, **strided_options)

    def forward(self, x):
        """Halve the two trailing spatial dimensions of ``x``."""
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)

        # Zero-pad one column on the right and one row at the bottom only.
        right_bottom = (0, 1, 0, 1)
        padded = F.pad(x, right_bottom, mode="constant", value=0)
        return self.conv(padded)
|
|
class ResnetBlock3dAware(nn.Module):
    """Residual block on rolled triplanes with a 3D-aware 1x1 conv in the middle.

    Layout: norm -> nonlinearity -> 3x3 conv, then norm -> nonlinearity ->
    3D-aware 1x1 triplane conv, then norm -> nonlinearity -> 3x3 conv, added to
    a shortcut (a 1x1 conv when the channel count changes, identity otherwise).
    """

    def __init__(self, in_channels, out_channels=None):
        """
        Args:
            in_channels: channel count of the input
            out_channels: channel count of the output; defaults to in_channels
        """
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # First 3x3 stage; this is where the channel count changes.
        self.norm1 = Normalize(in_channels)
        self.conv1 = conv3x3(self.in_channels, self.out_channels)

        # Cross-plane stage.
        self.norm_mid = Normalize(out_channels)
        self.conv_3daware = conv1x1triplane3daware(self.out_channels, self.out_channels)

        # Second 3x3 stage.
        self.norm2 = Normalize(out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        # Projection so the skip path matches the output channel count.
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply the block to a rolled tensor and return a rolled tensor."""
        h = self.conv1(nonlinearity(self.norm1(x)))

        # The 3D-aware conv works on a list of planes, so unroll around it.
        h = nonlinearity(self.norm_mid(h))
        h = roll_triplanes(self.conv_3daware(unroll_triplanes(h)))

        h = self.conv2(nonlinearity(self.norm2(h)))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.nin_shortcut(x)
        return skip + h
|
|
class DownConv3dAware(nn.Module):
    """
    Encoder stage: one 3D-aware residual block followed by an optional
    2x downsampling of each plane.
    """
    def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
        """
        Args:
            in_channels: channel count of the rolled input
            out_channels: channel count produced by the residual block
            downsample: whether forward() downsamples after the block
            with_conv: forwarded to Downsample (strided conv instead of avg pool)
        """
        super(DownConv3dAware, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.block = ResnetBlock3dAware(in_channels=in_channels,
                                        out_channels=out_channels)

        self.do_downsample = downsample
        self.downsample = Downsample(out_channels, with_conv=with_conv)

    def forward(self, x):
        """
        rolled input, rolled output
        Args:
            x: rolled (b c (tri*h) w)
        Returns:
            (x, before_pool): the (possibly downsampled) rolled output and the
            rolled block output taken before downsampling
        """
        x = self.block(x)
        before_pool = x

        if self.do_downsample:
            # Fold the three planes into the batch axis so each plane is
            # downsampled on its own and still has `out_channels` channels.
            # Folding them into the channel axis would hand 3*out_channels
            # channels to a Downsample conv built for out_channels and fail
            # when with_conv=True. Average pooling is per-sample and
            # per-channel, so its result is the same either way.
            x = einops.rearrange(x, 'b c (tri h) w -> (b tri) c h w', tri=3)
            x = self.downsample(x)
            # Roll the planes back along the height axis.
            x = einops.rearrange(x, '(b tri) c h w -> b c (tri h) w', tri=3)
        return x, before_pool
|
|
class UpConv3dAware(nn.Module):
    """
    Decoder stage: upsample the decoder tensor, merge it with the encoder
    skip tensor, normalize, then apply one 3D-aware residual block.
    """
    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', with_conv=False):
        """
        Args:
            in_channels: channel count of the tensor coming up the decoder
            out_channels: channel count produced by this stage
            merge_mode: 'concat' concatenates along channels; any other value
                adds the two tensors elementwise
            with_conv: forwarded to Upsample
        """
        super(UpConv3dAware, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode

        self.upsample = Upsample(in_channels, with_conv)

        # Concatenation stacks the skip tensor's channels on top of the
        # upsampled ones; addition leaves the channel count as it is.
        if self.merge_mode == 'concat':
            merged_channels = in_channels+out_channels
        else:
            merged_channels = in_channels
        self.norm1 = Normalize(merged_channels)
        self.block = ResnetBlock3dAware(in_channels=merged_channels,
                                        out_channels=out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        rolled inputs, rolled output
        rolled (b c (tri*h) w)
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway (upsampled here)
        """
        upsampled = self.upsample(from_up)
        if self.merge_mode == 'concat':
            merged = torch.cat((upsampled, from_down), 1)
        else:
            # Elementwise sum: both tensors need broadcast-compatible shapes.
            merged = upsampled + from_down

        return self.block(self.norm1(merged))
|
|
class UNetTriplane3dAware(nn.Module):
    """U-Net over rolled triplanes built from 3D-aware residual stages."""

    def __init__(self, out_channels, in_channels=3, depth=5,
                 start_filts=64,
                 use_initial_conv=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            out_channels: int, number of channels in the output tensor.
            in_channels: int, number of channels in the input tensor.
            depth: int, number of encoder stages (must be >= 1); every stage
                except the last one downsamples.
            start_filts: int, number of channels produced by the first stage;
                doubled at every subsequent encoder stage.
            use_initial_conv: if True, a 1x1 conv first maps in_channels to
                start_filts.
            merge_mode: forwarded to every UpConv3dAware.
            **kwargs: accepted for call-site compatibility and ignored.
        Raises:
            ValueError: if depth is smaller than 1.
        """
        super(UNetTriplane3dAware, self).__init__()

        # With no encoder stage there is no channel count to build the output
        # head from, so fail with an explicit message.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.use_initial_conv = use_initial_conv
        if use_initial_conv:
            self.conv_initial = conv1x1(self.in_channels, self.start_filts)

        down_convs = []
        up_convs = []

        # Encoder: channel count doubles per stage, the last stage keeps resolution.
        outs = self.start_filts if use_initial_conv else self.in_channels
        for i in range(depth):
            ins = outs
            outs = self.start_filts*(2**i)
            downsamp_it = i < depth-1

            down_convs.append(DownConv3dAware(ins, outs, downsample=downsamp_it))

        # Decoder: channel count halves per stage.
        for _ in range(depth-1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv3dAware(ins, outs, merge_mode=merge_mode))

        # Register the stages so their parameters are tracked.
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

        self.norm_out = Normalize(outs)
        self.conv_final = conv1x1(outs, self.out_channels)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for Conv2d weights; biases (if any) are zeroed."""
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            # Convs built with bias=False have no bias tensor to initialise.
            if m.bias is not None:
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply weight_init to every submodule, in ``self.modules()`` order."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, x):
        """
        Args:
            x: Stacked triplane expected to be in (B,3,C,H,W)
        Returns:
            Stacked triplane (B,3,out_channels,H',W')
        """
        # Roll the three planes along the height axis.
        x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)

        if self.use_initial_conv:
            x = self.conv_initial(x)

        # Encoder pathway; keep every pre-downsample output for the skips.
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # Decoder pathway; the last encoder output is x itself, so the skips
        # start from the second-to-last one.
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        x = self.norm_out(x)
        x = self.conv_final(nonlinearity(x))

        # Unroll back into the stacked layout.
        x = einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
        return x
|
|
| |
def setup_unet(output_channels, input_channels, unet_cfg):
    """Build the U-Net described by ``unet_cfg``.

    Args:
        output_channels: channel count of the network output.
        input_channels: channel count of the network input.
        unet_cfg: mapping read for the keys 'use_3d_aware', 'rolled', 'depth',
            'use_initial_conv' and 'start_hidden_channels'.

    Returns:
        A ``UNetTriplane3dAware`` instance.

    Raises:
        NotImplementedError: if ``unet_cfg['use_3d_aware']`` is falsy.
    """
    # Only the 3D-aware variant exists.
    if not unet_cfg['use_3d_aware']:
        raise NotImplementedError

    assert(unet_cfg['rolled'])
    return UNetTriplane3dAware(
        out_channels=output_channels,
        in_channels=input_channels,
        depth=unet_cfg['depth'],
        use_initial_conv=unet_cfg['use_initial_conv'],
        start_filts=unet_cfg['start_hidden_channels'],
    )
|
|
|
|