| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from .correlation import correlation |
| |
|
def apply_offset(offset):
    """Convert a per-pixel offset field into a normalized sampling grid.

    ``offset`` is indexed as ``offset[:, dim, ...]``: axis 0 is kept as a
    leading axis, axis 1 selects which coordinate is displaced, and every
    axis from 2 onward is treated as spatial. Channel 0 displaces the last
    spatial axis, channel 1 the one before it, and so on.

    Each absolute coordinate (index + offset) is rescaled so that index ``0``
    maps to ``-1`` and index ``size - 1`` maps to ``+1``. The coordinates are
    stacked on a new trailing axis, last-spatial-axis coordinate first.

    Note: a spatial axis of length 1 makes the rescaling divisor zero.
    """
    spatial = list(offset.size()[2:])
    axes = [torch.arange(extent, device=offset.device) for extent in spatial]

    # meshgrid yields one index grid per spatial axis in axis order; flip the
    # order (and the matching extents) so the last axis comes first.
    index_grids = list(torch.meshgrid(axes))[::-1]
    extents = spatial[::-1]

    coords = []
    for dim, (index_grid, extent) in enumerate(zip(index_grids, extents)):
        absolute = index_grid.float().unsqueeze(0) + offset[:, dim, ...]
        coords.append(absolute / ((extent - 1.0) / 2.0) - 1.0)

    return torch.stack(coords, dim=-1)
| |
|
| |
|
class ResBlock(nn.Module):
    """Residual unit: two (BatchNorm -> ReLU -> 3x3 conv) stages plus a skip.

    The channel count is unchanged and the convolutions use padding 1 with
    stride 1, so the output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``block(x) + x``."""
        residual = self.block(x)
        return residual + x
| |
|
| |
|
class DownSample(nn.Module):
    """BatchNorm -> ReLU -> stride-2 3x3 convolution.

    Maps ``in_channels`` to ``out_channels``; the stride-2 convolution with
    padding 1 reduces each spatial dimension to ``ceil(size / 2)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        norm = nn.BatchNorm2d(in_channels)
        act = nn.ReLU(inplace=True)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.block = nn.Sequential(norm, act, conv)

    def forward(self, x):
        """Apply the normalize/activate/strided-conv stack to ``x``."""
        out = self.block(x)
        return out
| |
|
| |
|
| |
|
class FeatureEncoder(nn.Module):
    """Stack of downsampling encoder stages that returns every stage's output.

    Each stage is ``DownSample -> ResBlock -> ResBlock``. The first stage maps
    ``in_channels`` to ``chns[0]``; stage ``i`` maps ``chns[i - 1]`` to
    ``chns[i]``.

    Args:
        in_channels: channel count of the input tensor.
        chns: output channel count of each stage, in order. The default is an
            immutable tuple so it cannot be mutated across instances.
    """

    def __init__(self, in_channels, chns=(64, 128, 256, 256, 256)):
        super().__init__()
        stages = []
        prev_chns = in_channels
        for out_chns in chns:
            stages.append(nn.Sequential(
                DownSample(prev_chns, out_chns),
                ResBlock(out_chns),
                ResBlock(out_chns),
            ))
            # The next stage consumes what this stage produced.
            prev_chns = out_chns

        self.encoders = nn.ModuleList(stages)

    def forward(self, x):
        """Run ``x`` through every stage and return the list of stage outputs.

        The list is ordered from the first stage to the last.
        """
        encoder_features = []
        for encoder in self.encoders:
            x = encoder(x)
            encoder_features.append(x)
        return encoder_features
| |
|
class RefinePyramid(nn.Module):
    """Top-down feature pyramid over a list of encoder features.

    Each input level is projected to ``fpn_dim`` channels with a 1x1
    convolution, summed with the upsampled result of the previous (coarser)
    level, and then passed through a 3x3 convolution.

    Args:
        chns: channel count of each input level, ordered as the inputs are
            passed to ``forward`` (the last entry is processed first).
        fpn_dim: channel count of every output level.
    """

    def __init__(self, chns=(64, 128, 256, 256, 256), fpn_dim=256):
        super().__init__()
        # Store a private copy so mutating ``self.chns`` can never alter the
        # caller's sequence or the shared default.
        self.chns = list(chns)

        # 1x1 lateral projections, indexed in processing order (last level first).
        self.adaptive = nn.ModuleList(
            [nn.Conv2d(in_chns, fpn_dim, kernel_size=1) for in_chns in reversed(self.chns)]
        )

        # One 3x3 output convolution per level.
        self.smooth = nn.ModuleList(
            [nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1) for _ in self.chns]
        )

    def forward(self, x):
        """Build the pyramid from ``x``, a sequence of per-level feature maps.

        Levels are processed from the last element of ``x`` to the first.
        Returns a tuple of refined features in the same order as ``x``.
        """
        feature_list = []
        last_feature = None
        for i, conv_ftr in enumerate(reversed(x)):
            feature = self.adaptive[i](conv_ftr)

            if last_feature is not None:
                # Resize to the lateral feature's own spatial size rather than
                # a fixed 2x factor: identical for exact 2x pyramids, and it
                # also works when a stride-2 stage rounded an odd size up.
                feature = feature + F.interpolate(
                    last_feature, size=feature.shape[2:], mode='nearest')

            feature = self.smooth[i](feature)
            last_feature = feature
            feature_list.append(feature)

        return tuple(reversed(feature_list))
| |
|
| |
|
class AFlowNet(nn.Module):
    """Coarse-to-fine flow estimator over two feature pyramids.

    For every pyramid level the module owns two convolutional heads that each
    emit a 2-channel offset field: ``netMain[i]`` consumes a correlation
    volume and ``netRefine[i]`` consumes the concatenated warped/condition
    features.

    Args:
        num_pyramid: number of pyramid levels (heads per list).
        fpn_dim: channel count of each pyramid feature; the refine head takes
            ``2 * fpn_dim`` input channels.
    """

    @staticmethod
    def _flow_head(in_channels):
        """Build a conv stack ``in_channels -> 128 -> 64 -> 32 -> 2``.

        All convolutions are 3x3, stride 1, padding 1, with LeakyReLU(0.1)
        between consecutive convolutions (none after the last).
        """
        widths = (in_channels, 128, 64, 32, 2)
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            if layers:
                layers.append(torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
            layers.append(torch.nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                          kernel_size=3, stride=1, padding=1))
        return torch.nn.Sequential(*layers)

    def __init__(self, num_pyramid, fpn_dim=256):
        super().__init__()
        main_heads = []
        refine_heads = []
        for _ in range(num_pyramid):
            # 49 input channels is hard-coded; presumably the channel count of
            # the correlation output -- TODO confirm against the correlation module.
            main_heads.append(self._flow_head(49))
            refine_heads.append(self._flow_head(2 * fpn_dim))

        self.netMain = nn.ModuleList(main_heads)
        self.netRefine = nn.ModuleList(refine_heads)

    def forward(self, x, x_warps, x_conds, warp_feature=True):
        """Estimate a flow field level by level and warp ``x`` with it.

        Levels are visited from the last element of ``x_warps`` / ``x_conds``
        to the first. Head ``i`` is used for the ``i``-th visited level.

        Args:
            x: tensor sampled with the final flow.
            x_warps: per-level features to be warped.
            x_conds: per-level condition features, indexed like ``x_warps``.
            warp_feature: when true, pre-warp each level's feature with the
                (detached) flow from the previous level before correlation.

        Returns:
            ``(x_warp, last_flow)``: ``x`` sampled with the final flow, and
            that flow (channels-first, upsampled 2x after the last level).
        """
        def sample(source, grid):
            # Every warp in this method uses the same sampling settings.
            return F.grid_sample(source, grid, mode='bilinear', padding_mode='border')

        num_levels = len(x_warps)
        last_flow = None

        for step in range(num_levels):
            level = num_levels - 1 - step
            feat_warp = x_warps[level]
            feat_cond = x_conds[level]

            # Optionally pre-warp with the previous level's flow; gradients do
            # not propagate into that flow through this path.
            if last_flow is None or not warp_feature:
                feat_warp_after = feat_warp
            else:
                feat_warp_after = sample(feat_warp, last_flow.detach().permute(0, 2, 3, 1))

            corr = correlation.FunctionCorrelation(
                tenFirst=feat_warp_after, tenSecond=feat_cond, intStride=1)
            corr = F.leaky_relu(input=corr, negative_slope=0.1, inplace=False)

            # Main head: offsets -> sampling grid (channels-last).
            main_grid = apply_offset(self.netMain[step](corr))
            if last_flow is None:
                flow = main_grid.permute(0, 3, 1, 2)
            else:
                # Compose with the previous flow by sampling it on the new grid.
                flow = sample(last_flow, main_grid)
            last_flow = flow

            # Refine head works on the feature warped with the current flow.
            warped = sample(feat_warp, flow.permute(0, 2, 3, 1))
            stacked = torch.cat([warped, feat_cond], 1)
            refine_grid = apply_offset(self.netRefine[step](stacked))
            flow = sample(last_flow, refine_grid)

            # Upsample for the next (finer) level.
            last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')

        x_warp = sample(x, last_flow.permute(0, 2, 3, 1))
        return x_warp, last_flow
| |
|
| |
|
class AFWM(nn.Module):
    """Warping module: two encoder/pyramid branches feeding an ``AFlowNet``.

    The image branch encodes a 3-channel input and the condition branch
    encodes an ``input_nc``-channel input. Both use the same per-level channel
    counts.

    Args:
        opt: accepted for interface compatibility; not used in this class.
        input_nc: channel count of the condition input.
    """

    def __init__(self, opt, input_nc):
        super().__init__()
        pyramid_chns = [64, 128, 256, 256, 256]
        self.image_features = FeatureEncoder(3, pyramid_chns)
        self.cond_features = FeatureEncoder(input_nc, pyramid_chns)
        self.image_FPN = RefinePyramid(pyramid_chns)
        self.cond_FPN = RefinePyramid(pyramid_chns)
        self.aflow_net = AFlowNet(len(pyramid_chns))

    def forward(self, cond_input, image_input):
        """Warp ``image_input`` using flow estimated from both inputs.

        Returns:
            ``(x_warp, last_flow)`` exactly as produced by ``self.aflow_net``.
        """
        cond_encoded = self.cond_features(cond_input)
        cond_pyramids = self.cond_FPN(cond_encoded)

        image_encoded = self.image_features(image_input)
        image_pyramids = self.image_FPN(image_encoded)

        warped, flow = self.aflow_net(image_input, image_pyramids, cond_pyramids)
        return warped, flow
| |
|
| |
|