| from torch.utils.model_zoo import load_url as load_state_dict_from_url |
| import cv2 |
| import math |
| import numpy as np |
| import PIL |
| import PIL.Image |
| import re |
| import sys |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import torch.optim.lr_scheduler as lrs |
| from vfi_models.ops import FunctionCorrelation, FunctionAdaCoF, ModuleSoftsplat |
| from vfi_utils import get_ckpt_container_path |
| import pathlib |
| MODEL_TYPE = pathlib.Path(__file__).parent.name |
|
|
| |
| def identity(x): |
| return x |
|
|
|
|
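| # Backward-warp tenInput with tenFlow via grid_sample, following the |
| # pytorch-pwc reference implementation: a constant ones-channel is warped |
| # together with the input and thresholded into a validity mask, so pixels |
| # that sample from outside the source frame are zeroed. The sampling grids |
| # are cached per flow shape in the module-level dicts below. |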
| backwarp_tenGrid = {} |
| backwarp_tenPartial = {} |
|
|
| def backwarp(tenInput, tenFlow): |
| if str(tenFlow.shape) not in backwarp_tenGrid: |
| tenHor = ( |
| torch.linspace( |
| -1.0 + (1.0 / tenFlow.shape[3]), |
| 1.0 - (1.0 / tenFlow.shape[3]), |
| tenFlow.shape[3], |
| ) |
| .view(1, 1, 1, -1) |
| .expand(-1, -1, tenFlow.shape[2], -1) |
| ) |
| tenVer = ( |
| torch.linspace( |
| -1.0 + (1.0 / tenFlow.shape[2]), |
| 1.0 - (1.0 / tenFlow.shape[2]), |
| tenFlow.shape[2], |
| ) |
| .view(1, 1, -1, 1) |
| .expand(-1, -1, -1, tenFlow.shape[3]) |
| ) |
|
|
| backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() |
| |
|
|
| if str(tenFlow.shape) not in backwarp_tenPartial: |
| backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones( |
| [tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]] |
| ) |
| |
|
|
| tenFlow = torch.cat( |
| [ |
| tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), |
| tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), |
| ], |
| 1, |
| ) |
| tenInput = torch.cat([tenInput, backwarp_tenPartial[str(tenFlow.shape)]], 1) |
|
|
| tenOutput = torch.nn.functional.grid_sample( |
| input=tenInput, |
| grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), |
| mode="bilinear", |
| padding_mode="zeros", |
| align_corners=False, |
| ) |
|
|
| tenMask = tenOutput[:, -1:, :, :] |
| tenMask[tenMask > 0.999] = 1.0 |
| tenMask[tenMask < 1.0] = 0.0 |
|
|
| return tenOutput[:, :-1, :, :] * tenMask |
|
|
|
|
| |
|
|
| |
|
|
|
|
| class PWCNet(torch.nn.Module): |
| def __init__(self): |
| super(PWCNet, self).__init__() |
|
|
| class Extractor(torch.nn.Module): |
| def __init__(self): |
| super(Extractor, self).__init__() |
|
|
| self.netOne = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=3, |
| out_channels=16, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=16, |
| out_channels=16, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=16, |
| out_channels=16, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netTwo = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=16, |
| out_channels=32, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netThr = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=64, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFou = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=96, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFiv = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=128, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netSix = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=196, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=196, |
| out_channels=196, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=196, |
| out_channels=196, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| |
|
|
| def forward(self, tenInput): |
| tenOne = self.netOne(tenInput) |
| tenTwo = self.netTwo(tenOne) |
| tenThr = self.netThr(tenTwo) |
| tenFou = self.netFou(tenThr) |
| tenFiv = self.netFiv(tenFou) |
| tenSix = self.netSix(tenFiv) |
|
|
| return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix] |
|
|
| |
|
|
| |
|
|
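| # Coarse-to-fine decoder: at each pyramid level a cost volume (correlation) |
| # between the first feature map and the backwarped second feature map is |
| # decoded into a flow estimate, with dense connections feeding every |
| # sub-layer's output back into the running feature tensor. |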
| class Decoder(torch.nn.Module): |
| def __init__(self, intLevel): |
| super(Decoder, self).__init__() |
|
|
| intPrevious = [ |
| None, |
| None, |
| 81 + 32 + 2 + 2, |
| 81 + 64 + 2 + 2, |
| 81 + 96 + 2 + 2, |
| 81 + 128 + 2 + 2, |
| 81, |
| None, |
| ][intLevel + 1] |
| intCurrent = [ |
| None, |
| None, |
| 81 + 32 + 2 + 2, |
| 81 + 64 + 2 + 2, |
| 81 + 96 + 2 + 2, |
| 81 + 128 + 2 + 2, |
| 81, |
| None, |
| ][intLevel + 0] |
|
|
| if intLevel < 6: |
| self.netUpflow = torch.nn.ConvTranspose2d( |
| in_channels=2, |
| out_channels=2, |
| kernel_size=4, |
| stride=2, |
| padding=1, |
| ) |
| if intLevel < 6: |
| self.netUpfeat = torch.nn.ConvTranspose2d( |
| in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, |
| out_channels=2, |
| kernel_size=4, |
| stride=2, |
| padding=1, |
| ) |
| if intLevel < 6: |
| self.fltBackwarp = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][ |
| intLevel + 1 |
| ] |
|
|
| self.netOne = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netTwo = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netThr = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFou = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128 + 96, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFiv = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128 + 96 + 64, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netSix = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, |
| out_channels=2, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ) |
| ) |
|
|
| |
|
|
| def forward(self, tenFirst, tenSecond, objPrevious): |
| tenFlow = None |
| tenFeat = None |
|
|
| if objPrevious is None: |
|
|
| tenVolume = torch.nn.functional.leaky_relu( |
| input=FunctionCorrelation( |
| tenFirst=tenFirst, tenSecond=tenSecond |
| ), |
| negative_slope=0.1, |
| inplace=False, |
| ) |
|
|
| tenFeat = tenVolume  # cat of a single tensor is a no-op |
|
|
| else: |
| tenFlow = self.netUpflow(objPrevious["tenFlow"]) |
| tenFeat = self.netUpfeat(objPrevious["tenFeat"]) |
|
|
| tenVolume = torch.nn.functional.leaky_relu( |
| input=FunctionCorrelation( |
| tenFirst=tenFirst, |
| tenSecond=backwarp( |
| tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp |
| ), |
| ), |
| negative_slope=0.1, |
| inplace=False, |
| ) |
|
|
| tenFeat = torch.cat([tenVolume, tenFirst, tenFlow, tenFeat], 1) |
|
|
| |
|
|
| tenFeat = torch.cat([self.netOne(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netTwo(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netThr(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netFou(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netFiv(tenFeat), tenFeat], 1) |
|
|
| tenFlow = self.netSix(tenFeat) |
|
|
| return {"tenFlow": tenFlow, "tenFeat": tenFeat} |
|
|
| |
|
|
| |
|
|
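| # Refiner: a stack of dilated 3x3 convolutions (dilations 1, 2, 4, 8, 16, |
| # then 1, 1) that post-processes the dense decoder features into a residual |
| # flow correction. |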
| class Refiner(torch.nn.Module): |
| def __init__(self): |
| super(Refiner, self).__init__() |
|
|
| self.netMain = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| dilation=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=2, |
| dilation=2, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=4, |
| dilation=4, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=8, |
| dilation=8, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=16, |
| dilation=16, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| dilation=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=2, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| dilation=1, |
| ), |
| ) |
|
|
| |
|
|
| def forward(self, tenInput): |
| return self.netMain(tenInput) |
|
|
| |
|
|
| |
|
|
| self.netExtractor = Extractor() |
|
|
| self.netTwo = Decoder(2) |
| self.netThr = Decoder(3) |
| self.netFou = Decoder(4) |
| self.netFiv = Decoder(5) |
| self.netSix = Decoder(6) |
|
|
| self.netRefiner = Refiner() |
|
|
| self.load_state_dict( |
| { |
| strKey.replace("module", "net"): tenWeight |
| for strKey, tenWeight in torch.hub.load_state_dict_from_url( |
| url="http://content.sniklaus.com/github/pytorch-pwc/network-" |
| + "default" |
| + ".pytorch", |
| model_dir=get_ckpt_container_path(MODEL_TYPE) |
| ).items() |
| } |
| ) |
|
|
| |
|
|
| def forward(self, tenFirst, tenSecond, *args): |
| |
| if len(args) == 0: |
| tenFirst = self.netExtractor(tenFirst) |
| tenSecond = self.netExtractor(tenSecond) |
| else: |
| tenFirst, tenSecond = args |
|
|
| objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) |
| objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) |
| objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) |
| objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) |
| objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) |
|
|
| return objEstimate["tenFlow"] + self.netRefiner(objEstimate["tenFeat"]) |
|
|
| |
|
|
| def extract_pyramid(self, tenFirst, tenSecond): |
| return self.netExtractor(tenFirst), self.netExtractor(tenSecond) |
|
|
| def extract_pyramid_single(self, tenFirst): |
| return self.netExtractor(tenFirst) |
|
|
|
|
| |
|
|
| netNetwork = None |
|
|
| |
|
|
|
|
| def estimate(tenFirst, tenSecond): |
| global netNetwork |
|
|
| if netNetwork is None: |
| netNetwork = Network().cuda().eval() |
| |
|
|
| assert tenFirst.shape[1] == tenSecond.shape[1] |
| assert tenFirst.shape[2] == tenSecond.shape[2] |
|
|
| intWidth = tenFirst.shape[2] |
| intHeight = tenFirst.shape[1] |
|
|
| # The pretrained weights target 1024x436 (Sintel-sized) frames; there is |
| # no guarantee of correctness at other resolutions, so only remove these |
| # asserts if you accept that. |
| assert intWidth == 1024 |
| assert intHeight == 436 |
|
|
| tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) |
| tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth) |
|
|
| intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) |
| intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) |
|
|
| tenPreprocessedFirst = torch.nn.functional.interpolate( |
| input=tenPreprocessedFirst, |
| size=(intPreprocessedHeight, intPreprocessedWidth), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| tenPreprocessedSecond = torch.nn.functional.interpolate( |
| input=tenPreprocessedSecond, |
| size=(intPreprocessedHeight, intPreprocessedWidth), |
| mode="bilinear", |
| align_corners=False, |
| ) |
|
|
| tenFlow = 20.0 * torch.nn.functional.interpolate( |
| input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), |
| size=(intHeight, intWidth), |
| mode="bilinear", |
| align_corners=False, |
| ) |
|
|
| tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) |
| tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) |
|
|
| return tenFlow[0, :, :, :].cpu() |
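|
|
| # Minimal usage sketch (assumes a CUDA device; frame0/frame1 are hypothetical |
| # 3x436x1024 float tensors in [0, 1]): |
| # flow = estimate(frame0, frame1)  # -> 2x436x1024 flow tensor on the CPU |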
|
|
|
|
| |
|
|
|
|
| class Upsampler_8tap(nn.Module): |
| def __init__(self): |
| super(Upsampler_8tap, self).__init__() |
| filt_8tap = torch.tensor([[-1, 4, -11, 40, 40, -11, 4, -1]]).div(64) |
| self.filter = nn.Parameter(filt_8tap.repeat(3, 1, 1, 1), requires_grad=False) |
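| # filt_8tap is the standard HEVC-style 8-tap half-pel interpolation kernel, |
| # normalised by 64 and applied depthwise (groups=3) in forward() below. |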
|
|
| def forward(self, im): |
| b, c, h, w = im.shape |
| im_up = torch.zeros(b, c, h * 2, w * 2).to(im.device) |
| im_up[:, :, ::2, ::2] = im |
|
|
| p = (8 - 1) // 2 |
| im_up_row = F.conv2d( |
| F.pad(im, pad=(p, p + 1, 0, 0), mode="reflect"), self.filter, groups=3 |
| ) |
| im_up[:, :, 0::2, 1::2] = im_up_row |
| im_up_col = torch.transpose( |
| F.conv2d( |
| F.pad(torch.transpose(im, 2, 3), pad=(p, p + 1, 0, 0), mode="reflect"), |
| self.filter, |
| groups=3, |
| ), |
| 2, |
| 3, |
| ) |
| im_up[:, :, 1::2, 0::2] = im_up_col |
| im_up_cross = F.conv2d( |
| F.pad(im_up[:, :, 1::2, ::2], pad=(p, p + 1, 0, 0), mode="reflect"), |
| self.filter, |
| groups=3, |
| ) |
| im_up[:, :, 1::2, 1::2] = im_up_cross |
| return im_up |
|
|
| |
|
|
|
|
| model_urls = { |
| "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth", |
| "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", |
| "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", |
| } |
|
|
|
|
| class Conv3DSimple(nn.Conv3d): |
| def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1): |
| super(Conv3DSimple, self).__init__( |
| in_channels=in_planes, |
| out_channels=out_planes, |
| kernel_size=(3, 3, 3), |
| stride=stride, |
| padding=padding, |
| bias=False, |
| ) |
|
|
| @staticmethod |
| def get_downsample_stride(stride, temporal_stride): |
| if temporal_stride: |
| return (temporal_stride, stride, stride) |
| else: |
| return (stride, stride, stride) |
|
|
|
|
| class Conv2Plus1D(nn.Sequential): |
| def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1): |
| super(Conv2Plus1D, self).__init__( |
| nn.Conv3d( |
| in_planes, |
| midplanes, |
| kernel_size=(1, 3, 3), |
| stride=(1, stride, stride), |
| padding=(0, padding, padding), |
| bias=False, |
| ), |
| batchnorm(midplanes), |
| nn.ReLU(inplace=True), |
| nn.Conv3d( |
| midplanes, |
| out_planes, |
| kernel_size=(3, 1, 1), |
| stride=(stride, 1, 1), |
| padding=(padding, 0, 0), |
| bias=False, |
| ), |
| ) |
|
|
| @staticmethod |
| def get_downsample_stride(stride, temporal_stride=None): |
| # Accept the temporal_stride argument that _make_layer passes; fall |
| # back to a uniform stride when it is not set. |
| if temporal_stride: |
| return (temporal_stride, stride, stride) |
| return (stride, stride, stride) |
|
|
|
|
| class Conv3DNoTemporal(nn.Conv3d): |
| def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1): |
| super(Conv3DNoTemporal, self).__init__( |
| in_channels=in_planes, |
| out_channels=out_planes, |
| kernel_size=(1, 3, 3), |
| stride=(1, stride, stride), |
| padding=(0, padding, padding), |
| bias=False, |
| ) |
|
|
| @staticmethod |
| def get_downsample_stride(stride, temporal_stride=None): |
| # temporal_stride is accepted for signature compatibility with |
| # _make_layer but ignored: this builder never strides over time. |
| return 1, stride, stride |
|
|
|
|
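| # Squeeze-and-excitation style gating: global average pooling to 1x1x1 plus |
| # a 1x1x1 conv and sigmoid yield per-channel weights that rescale the input. |
| # (`reduction` is kept for interface compatibility but is not used.) |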
| class SEGating(nn.Module): |
| def __init__(self, inplanes, reduction=16): |
| super().__init__() |
|
|
| self.pool = nn.AdaptiveAvgPool3d(1) |
| self.attn_layer = nn.Sequential( |
| nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True), |
| nn.Sigmoid(), |
| ) |
|
|
| def forward(self, x): |
| out = self.pool(x) |
| y = self.attn_layer(out) |
| return x * y |
|
|
|
|
| class BasicBlock(nn.Module): |
| expansion = 1 |
|
|
| def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): |
| midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) |
|
|
| super(BasicBlock, self).__init__() |
| self.conv1 = nn.Sequential( |
| conv_builder(inplanes, planes, midplanes, stride), |
| batchnorm(planes), |
| nn.ReLU(inplace=True), |
| ) |
| self.conv2 = nn.Sequential( |
| conv_builder(planes, planes, midplanes), batchnorm(planes) |
| ) |
| self.fg = SEGating(planes) |
| self.relu = nn.ReLU(inplace=True) |
| self.downsample = downsample |
| self.stride = stride |
|
|
| def forward(self, x): |
| residual = x |
| out = self.conv1(x) |
| out = self.conv2(out) |
| out = self.fg(out) |
| if self.downsample is not None: |
| residual = self.downsample(x) |
|
|
| out += residual |
| out = self.relu(out) |
|
|
| return out |
|
|
|
|
| class Bottleneck(nn.Module): |
| expansion = 4 |
|
|
| def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): |
| super(Bottleneck, self).__init__() |
| midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) |
|
|
| |
| self.conv1 = nn.Sequential( |
| nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), |
| batchnorm(planes), |
| nn.ReLU(inplace=True), |
| ) |
| |
| self.conv2 = nn.Sequential( |
| conv_builder(planes, planes, midplanes, stride), |
| batchnorm(planes), |
| nn.ReLU(inplace=True), |
| ) |
|
|
| |
| self.conv3 = nn.Sequential( |
| nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), |
| batchnorm(planes * self.expansion), |
| ) |
| self.relu = nn.ReLU(inplace=True) |
| self.downsample = downsample |
| self.stride = stride |
|
|
| def forward(self, x): |
| residual = x |
|
|
| out = self.conv1(x) |
| out = self.conv2(out) |
| out = self.conv3(out) |
|
|
| if self.downsample is not None: |
| residual = self.downsample(x) |
|
|
| out += residual |
| out = self.relu(out) |
|
|
| return out |
|
|
|
|
| class BasicStem(nn.Sequential): |
| """The default conv-batchnorm-relu stem""" |
|
|
| def __init__(self, outplanes=32): |
| super(BasicStem, self).__init__( |
| nn.Conv3d( |
| 3, |
| outplanes, |
| kernel_size=(3, 7, 7), |
| stride=(1, 2, 2), |
| padding=(1, 3, 3), |
| bias=False, |
| ), |
| batchnorm(outplanes), |
| nn.ReLU(inplace=True), |
| ) |
|
|
|
|
| class R2Plus1dStem(nn.Sequential): |
| """R(2+1)D stem is different than the default one as it uses separated 3D convolution""" |
|
|
| def __init__(self): |
| super(R2Plus1dStem, self).__init__( |
| nn.Conv3d( |
| 3, |
| 45, |
| kernel_size=(1, 7, 7), |
| stride=(1, 2, 2), |
| padding=(0, 3, 3), |
| bias=False, |
| ), |
| batchnorm(45), |
| nn.ReLU(inplace=True), |
| nn.Conv3d( |
| 45, |
| 64, |
| kernel_size=(3, 1, 1), |
| stride=(1, 1, 1), |
| padding=(1, 0, 0), |
| bias=False, |
| ), |
| batchnorm(64), |
| nn.ReLU(inplace=True), |
| ) |
|
|
|
|
| class VideoResNet(nn.Module): |
| def __init__( |
| self, |
| block, |
| conv_makers, |
| layers, |
| stem, |
| zero_init_residual=False, |
| channels=[32, 64, 96, 128], |
| ): |
| """Generic resnet video generator. |
| |
| Args: |
| block (nn.Module): resnet building block |
| conv_makers (list(functions)): generator function for each layer |
| layers (List[int]): number of blocks per layer |
| stem (callable): constructor for the ResNet stem, e.g. BasicStem or R2Plus1dStem |
| zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. |
| channels (List[int], optional): widths of the four residual stages; the stem must produce channels[0] features. |
| """ |
| super(VideoResNet, self).__init__() |
| self.inplanes = channels[0] |
|
|
| self.stem = stem() |
|
|
| self.layer1 = self._make_layer( |
| block, conv_makers[0], channels[0], layers[0], stride=1 |
| ) |
| self.layer2 = self._make_layer( |
| block, conv_makers[1], channels[1], layers[1], stride=2, temporal_stride=1 |
| ) |
| self.layer3 = self._make_layer( |
| block, conv_makers[2], channels[2], layers[2], stride=2, temporal_stride=1 |
| ) |
| self.layer4 = self._make_layer( |
| block, conv_makers[3], channels[3], layers[3], stride=1, temporal_stride=1 |
| ) |
|
|
| |
| self._initialize_weights() |
|
|
| if zero_init_residual: |
| for m in self.modules(): |
| # conv3 is a Sequential of (conv, norm) and this Bottleneck has no |
| # bn3 attribute, so zero the norm weight via conv3[1] instead. |
| if isinstance(m, Bottleneck) and isinstance(m.conv3[1], nn.BatchNorm3d): |
| nn.init.constant_(m.conv3[1].weight, 0) |
|
|
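| # Unlike torchvision's video ResNet classifier, this backbone returns the |
| # full feature pyramid: the stem output plus all four residual stage outputs. |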
| def forward(self, x): |
| tensorConv0 = self.stem(x) |
| tensorConv1 = self.layer1(tensorConv0) |
| tensorConv2 = self.layer2(tensorConv1) |
| tensorConv3 = self.layer3(tensorConv2) |
| tensorConv4 = self.layer4(tensorConv3) |
| return tensorConv0, tensorConv1, tensorConv2, tensorConv3, tensorConv4 |
|
|
| def _make_layer( |
| self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None |
| ): |
| downsample = None |
|
|
| if stride != 1 or self.inplanes != planes * block.expansion: |
| ds_stride = conv_builder.get_downsample_stride(stride, temporal_stride) |
| downsample = nn.Sequential( |
| nn.Conv3d( |
| self.inplanes, |
| planes * block.expansion, |
| kernel_size=1, |
| stride=ds_stride, |
| bias=False, |
| ), |
| batchnorm(planes * block.expansion), |
| ) |
| stride = ds_stride |
|
|
| layers = [] |
| layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) |
|
|
| self.inplanes = planes * block.expansion |
| for i in range(1, blocks): |
| layers.append(block(self.inplanes, planes, conv_builder)) |
|
|
| return nn.Sequential(*layers) |
|
|
| def _initialize_weights(self): |
| for m in self.modules(): |
| if isinstance(m, nn.Conv3d): |
| nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") |
| if m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.BatchNorm3d): |
| nn.init.constant_(m.weight, 1) |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.Linear): |
| nn.init.normal_(m.weight, 0, 0.01) |
| nn.init.constant_(m.bias, 0) |
|
|
|
|
| def _video_resnet(arch, pretrained=False, progress=True, **kwargs): |
| model = VideoResNet(**kwargs) |
|
|
| if pretrained: |
| state_dict = load_state_dict_from_url(model_urls[arch], progress=progress, model_dir=get_ckpt_container_path(MODEL_TYPE)) |
| model.load_state_dict(state_dict) |
| return model |
|
|
|
|
| def r3d_18(bn=False, pretrained=False, progress=True, **kwargs): |
| """Construct 18 layer Resnet3D model as in |
| https://arxiv.org/abs/1711.11248 |
| |
| Args: |
| pretrained (bool): If True, returns a model pre-trained on Kinetics-400 |
| progress (bool): If True, displays a progress bar of the download to stderr |
| |
| Returns: |
| nn.Module: R3D-18 network |
| """ |
|
|
| global batchnorm |
| if bn: |
| batchnorm = nn.BatchNorm3d |
| else: |
| # nn.Identity ignores its constructor arguments, so batchnorm(planes) |
| # still yields a module; the bare identity() function would return the |
| # int `planes` and break nn.Sequential. |
| batchnorm = nn.Identity |
|
|
| return _video_resnet( |
| "r3d_18", |
| pretrained, |
| progress, |
| block=BasicBlock, |
| conv_makers=[Conv3DSimple] * 4, |
| layers=[2, 2, 2, 2], |
| stem=BasicStem, |
| **kwargs, |
| ) |
|
|
|
|
| def mc3_18(bn=False, pretrained=False, progress=True, **kwargs): |
| """Constructor for 18 layer Mixed Convolution network as in |
| https://arxiv.org/abs/1711.11248 |
| |
| Args: |
| pretrained (bool): If True, returns a model pre-trained on Kinetics-400 |
| progress (bool): If True, displays a progress bar of the download to stderr |
| |
| Returns: |
| nn.Module: MC3 Network definition |
| """ |
| global batchnorm |
| if bn: |
| batchnorm = nn.BatchNorm3d |
| else: |
| batchnorm = nn.Identity  # module-valued stand-in; see r3d_18 |
|
|
| return _video_resnet( |
| "mc3_18", |
| pretrained, |
| progress, |
| block=BasicBlock, |
| conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, |
| layers=[2, 2, 2, 2], |
| stem=BasicStem, |
| **kwargs, |
| ) |
|
|
|
|
| def r2plus1d_18(bn=False, pretrained=False, progress=True, **kwargs): |
| """Constructor for the 18 layer deep R(2+1)D network as in |
| https://arxiv.org/abs/1711.11248 |
| |
| Args: |
| pretrained (bool): If True, returns a model pre-trained on Kinetics-400 |
| progress (bool): If True, displays a progress bar of the download to stderr |
| |
| Returns: |
| nn.Module: R(2+1)D-18 network |
| """ |
|
|
| global batchnorm |
| if bn: |
| batchnorm = nn.BatchNorm3d |
| else: |
| batchnorm = nn.Identity  # module-valued stand-in; see r3d_18 |
|
|
| return _video_resnet( |
| "r2plus1d_18", |
| pretrained, |
| progress, |
| block=BasicBlock, |
| conv_makers=[Conv2Plus1D] * 4, |
| layers=[2, 2, 2, 2], |
| stem=R2Plus1dStem, |
| **kwargs, |
| ) |
|
|
|
|
| class upConv3D(nn.Module): |
| def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose"): |
| super().__init__() |
| self.upmode = upmode |
| if self.upmode == "transpose": |
| self.upconv = nn.ModuleList( |
| [ |
| nn.ConvTranspose3d( |
| in_ch, |
| out_ch, |
| kernel_size=kernel_size, |
| stride=stride, |
| padding=padding, |
| ), |
| SEGating(out_ch), |
| batchnorm(out_ch), |
| ] |
| ) |
| else: |
| self.upconv = nn.ModuleList( |
| [ |
| nn.Upsample( |
| mode="trilinear", scale_factor=(1, 2, 2), align_corners=False |
| ), |
| nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1), |
| SEGating(out_ch), |
| batchnorm(out_ch), |
| ] |
| ) |
| self.upconv = nn.Sequential(*self.upconv) |
|
|
| def forward(self, x): |
| return self.upconv(x) |
|
|
|
|
| class Conv_3d(nn.Module): |
| def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True): |
| super().__init__() |
| self.conv = nn.Sequential( |
| nn.Conv3d( |
| in_ch, |
| out_ch, |
| kernel_size=kernel_size, |
| stride=stride, |
| padding=padding, |
| bias=bias, |
| ), |
| SEGating(out_ch), |
| batchnorm(out_ch), |
| ) |
|
|
| def forward(self, x): |
| return self.conv(x) |
|
|
|
|
| def make_optimizer(args, my_model): |
| trainable = filter(lambda x: x.requires_grad, my_model.parameters()) |
|
|
| if args.optimizer == "SGD": |
| optimizer_function = optim.SGD |
| kwargs = {"momentum": 0.9} |
| elif args.optimizer == "ADAM": |
| optimizer_function = optim.Adam |
| kwargs = {"betas": (0.9, 0.999), "eps": 1e-08} |
| elif args.optimizer == "ADAMax": |
| optimizer_function = optim.Adamax |
| kwargs = {"betas": (0.9, 0.999), "eps": 1e-08} |
| elif args.optimizer == "RMSprop": |
| optimizer_function = optim.RMSprop |
| kwargs = {"eps": 1e-08} |
|
|
| kwargs["lr"] = args.lr |
| kwargs["weight_decay"] = args.weight_decay |
|
|
| return optimizer_function(trainable, **kwargs) |
|
|
|
|
| def make_scheduler(args, my_optimizer): |
| if args.decay_type == "step": |
| scheduler = lrs.StepLR(my_optimizer, step_size=args.lr_decay, gamma=args.gamma) |
| elif args.decay_type.find("step") >= 0: |
| milestones = args.decay_type.split("_") |
| milestones.pop(0) |
| milestones = list(map(lambda x: int(x), milestones)) |
| scheduler = lrs.MultiStepLR( |
| my_optimizer, milestones=milestones, gamma=args.gamma |
| ) |
| elif args.decay_type == "plateau": |
| scheduler = lrs.ReduceLROnPlateau( |
| my_optimizer, |
| mode="max", |
| factor=args.gamma, |
| patience=args.patience, |
| threshold=0.01, |
| threshold_mode="abs", |
| verbose=True, |
| ) |
| else: |
| raise ValueError(f"unsupported decay type: {args.decay_type}") |
|
|
| return scheduler |
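|
|
| # Parsing sketch for the multi-step spelling handled above (hypothetical |
| # value): args.decay_type = "multistep_100_150" yields MultiStepLR with |
| # milestones [100, 150]. |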
|
|
|
|
| def gaussian_kernel(sz, sigma): |
| k = torch.arange(-(sz - 1) / 2, (sz + 1) / 2) |
| k = torch.exp(-1.0 / (2 * sigma**2) * k**2) |
| k = k.reshape(-1, 1) * k.reshape(1, -1) |
| k = k / torch.sum(k) |
| return k |
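|
|
| # Example: gaussian_kernel(5, 1.0) returns a 5x5 kernel that sums to 1, |
| # built as the outer product of a sampled 1-D Gaussian with itself. |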
|
|
|
|
| def moduleNormalize(frame): |
| return torch.cat( |
| [ |
| (frame[:, 0:1, :, :] - 0.4631), |
| (frame[:, 1:2, :, :] - 0.4352), |
| (frame[:, 2:3, :, :] - 0.3990), |
| ], |
| 1, |
| ) |
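|
|
| # The constants above are per-channel RGB means (presumably of the training |
| # data); subtracting them zero-centres each colour channel. |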
|
|
|
|
| class FoldUnfold: |
| """ |
| Class to handle folding tensor frame into batch of patches and back to frame again |
| Thanks to Charlie Tan (charlie.tan.2019@bristol.ac.uk) for the earier version. |
| """ |
|
|
| def __init__(self, height, width, patch_size, overlap): |
| if height % 2 or width % 2 or patch_size % 2 or overlap % 2: |
| print( |
| "only defined for even values of height, width, patch_size size and overlap, odd values will reconstruct incorrectly" |
| ) |
| return |
|
|
| self.height = height |
| self.width = width |
|
|
| self.patch_size = patch_size |
| self.overlap = overlap |
| self.stride = patch_size - overlap |
|
|
| def fold_to_patches(self, *frames): |
| """ |
| args: frames -- list of (1,3,H,W) tensors |
| returns: list of (B,3,h,w) image patches |
| """ |
|
|
| |
| n_blocks_h = (self.height // (self.stride)) + 1 |
| n_blocks_w = (self.width // (self.stride)) + 1 |
|
|
| |
| self.pad_h = (self.stride * n_blocks_h + self.overlap - self.height) // 2 |
| self.pad_w = (self.stride * n_blocks_w + self.overlap - self.width) // 2 |
| self.height_pad = self.height + 2 * self.pad_h |
| self.width_pad = self.width + 2 * self.pad_w |
|
|
| |
| patches_list = [] |
| for i in range(len(frames)): |
| padded = F.pad( |
| frames[i], |
| (self.pad_w, self.pad_w, self.pad_h, self.pad_h), |
| mode="reflect", |
| ) |
| unfolded = F.unfold(padded, self.patch_size, stride=self.stride) |
| patches = unfolded.permute(2, 1, 0).reshape( |
| -1, 3, self.patch_size, self.patch_size |
| ) |
| patches_list.append(patches) |
|
|
| return patches_list |
|
|
| def unfold_to_frame(self, patches): |
| """ |
| args: patches -- tensor of shape (B,3,h,w) |
| returns: frame -- tensor of shape (1,3,H,W) |
| """ |
|
|
| |
| frame_unfold = patches.reshape(-1, 3 * self.patch_size**2, 1).permute(2, 1, 0) |
|
|
| |
| frame_fold = F.fold( |
| frame_unfold, |
| (self.height_pad, self.width_pad), |
| self.patch_size, |
| stride=self.stride, |
| ) |
|
|
| |
| |
| ones = torch.ones_like(frame_fold) |
| ones_unfold = F.unfold(ones, self.patch_size, stride=self.stride) |
|
|
| |
| |
| divisor = F.fold( |
| ones_unfold, |
| (self.height_pad, self.width_pad), |
| self.patch_size, |
| stride=self.stride, |
| ) |
|
|
| |
| frame_div = frame_fold / divisor |
|
|
| |
| frame_crop = frame_div[ |
| :, :, self.pad_h : -self.pad_h, self.pad_w : -self.pad_w |
| ].clone() |
|
|
| return frame_crop |
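|
|
| # Round-trip sketch (hypothetical sizes; overlapping regions are averaged via |
| # the ones-tensor divisor built in unfold_to_frame): |
| # fu = FoldUnfold(height=436, width=1024, patch_size=256, overlap=32) |
| # (patches,) = fu.fold_to_patches(frame)   # frame: (1, 3, 436, 1024) |
| # frame_rec = fu.unfold_to_frame(patches)  # same shape as `frame` |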
|
|
|
|
| def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt="420"): |
| if pix_fmt == "420": |
| multiplier = 1 |
| uv_factor = 2 |
| elif pix_fmt == "444": |
| multiplier = 2 |
| uv_factor = 1 |
| else: |
| raise ValueError("Pixel format {} is not supported".format(pix_fmt)) |
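|
|
| # Sample layout: the Y plane first, then U and V at (w/uv_factor, |
| # h/uv_factor); higher bit depths store every sample as two bytes |
| # (uint16 below). |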
|
|
| if bit_depth == 8: |
| datatype = np.uint8 |
| # 8-bit frames occupy width*height*multiplier*3/2 bytes; keep the offset |
| # integral, since seek() rejects floats. |
| stream.seek(iFrame * width * height * multiplier * 3 // 2) |
| Y = np.fromfile(stream, dtype=datatype, count=width * height).reshape( |
| (height, width) |
| ) |
|
|
| |
| U = np.fromfile( |
| stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
| ).reshape((height // uv_factor, width // uv_factor)) |
| V = np.fromfile( |
| stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
| ).reshape((height // uv_factor, width // uv_factor)) |
|
|
| else: |
| datatype = np.uint16 |
| stream.seek(iFrame * 3 * width * height * multiplier) |
| Y = np.fromfile(stream, dtype=datatype, count=width * height).reshape( |
| (height, width) |
| ) |
|
|
| U = np.fromfile( |
| stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
| ).reshape((height // uv_factor, width // uv_factor)) |
| V = np.fromfile( |
| stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
| ).reshape((height // uv_factor, width // uv_factor)) |
|
|
| if pix_fmt == "420": |
| yuv = np.empty((height * 3 // 2, width), dtype=datatype) |
| yuv[0:height, :] = Y |
|
|
| yuv[height : height + height // 4, :] = U.reshape(-1, width) |
| yuv[height + height // 4 :, :] = V.reshape(-1, width) |
|
|
| if bit_depth != 8: |
| yuv = (yuv / (2**bit_depth - 1) * 255).astype(np.uint8) |
|
|
| |
| rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420) |
|
|
| else: |
| yvu = np.stack([Y, V, U], axis=2) |
| if bit_depth != 8: |
| yvu = (yvu / (2**bit_depth - 1) * 255).astype(np.uint8) |
| rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB) |
|
|
| return rgb |
|
|
|
|
| def quantize(imTensor): |
| return imTensor.clamp(0.0, 1.0).mul(255).round() |
|
|
|
|
| def tensor2rgb(tensor): |
| """ |
| Convert GPU Tensor to RGB image (numpy array) |
| """ |
| out = [] |
| for b in range(tensor.shape[0]): |
| out.append( |
| np.moveaxis(quantize(tensor[b]).cpu().detach().numpy(), 0, 2).astype( |
| np.uint8 |
| ) |
| ) |
| return np.array(out) |
|
|
|
|
| class Identity(nn.Module): |
| def __init__(self, *args): |
| super(Identity, self).__init__() |
|
|
| def forward(self, x): |
| return x |
|
|
|
|
| class SEBlock(nn.Module): |
| def __init__(self, input_dim, reduction=16): |
| super(SEBlock, self).__init__() |
| mid = int(input_dim / reduction) |
| self.avg_pool = nn.AdaptiveAvgPool2d(1) |
| self.fc = nn.Sequential( |
| nn.Linear(input_dim, mid), |
| nn.ReLU(inplace=True), |
| nn.Linear(mid, input_dim), |
| nn.Sigmoid(), |
| ) |
|
|
| def forward(self, x): |
| b, c, _, _ = x.size() |
| y = self.avg_pool(x).view(b, c) |
| y = self.fc(y).view(b, c, 1, 1) |
| return x * y |
|
|
|
|
| class ResNextBlock(nn.Module): |
| def __init__( |
| self, down, cin, cout, ks, stride=1, groups=32, base_width=4, norm_layer=None |
| ): |
| super(ResNextBlock, self).__init__() |
| if norm_layer is None or norm_layer == "batch": |
| norm_layer = nn.BatchNorm2d |
| elif norm_layer == "identity": |
| norm_layer = Identity |
| width = int(cout * (base_width / 64.0)) * groups |
| |
| self.conv1 = nn.Conv2d(cin, width, kernel_size=1, stride=1, bias=False) |
| self.bn1 = norm_layer(width) |
| if down: |
| self.conv2 = nn.Conv2d( |
| width, |
| width, |
| kernel_size=ks, |
| stride=stride, |
| padding=(ks - 1) // 2, |
| groups=groups, |
| bias=False, |
| ) |
| else: |
| self.conv2 = nn.ConvTranspose2d( |
| width, |
| width, |
| kernel_size=ks, |
| stride=stride, |
| padding=(ks - stride) // 2, |
| groups=groups, |
| bias=False, |
| ) |
| self.bn2 = norm_layer(width) |
| self.conv3 = nn.Conv2d(width, cout, kernel_size=1, stride=1, bias=False) |
| self.bn3 = norm_layer(cout) |
| self.relu = nn.ReLU(inplace=True) |
| self.downsample = None |
| if stride != 1 or cin != cout: |
| if down: |
| self.downsample = nn.Sequential( |
| nn.Conv2d(cin, cout, kernel_size=1, stride=stride, bias=False), |
| norm_layer(cout), |
| ) |
| else: |
| self.downsample = nn.Sequential( |
| |
| nn.ConvTranspose2d( |
| cin, cout, kernel_size=2, stride=stride, bias=False |
| ), |
| norm_layer(cout), |
| ) |
| self.stride = stride |
|
|
| def forward(self, x): |
| identity = x |
|
|
| out = self.conv1(x) |
| out = self.bn1(out) |
| out = self.relu(out) |
|
|
| out = self.conv2(out) |
| out = self.bn2(out) |
| out = self.relu(out) |
|
|
| out = self.conv3(out) |
| out = self.bn3(out) |
|
|
| if self.downsample is not None: |
| identity = self.downsample(x) |
|
|
| out += identity |
| out = self.relu(out) |
|
|
| return out |
|
|
|
|
| class MultiScaleResNextBlock(nn.Module): |
| def __init__(self, down, cin, cout, ks_s, ks_l, stride, norm_layer): |
| super(MultiScaleResNextBlock, self).__init__() |
| self.resnext_small = ResNextBlock( |
| down, cin, cout // 2, ks_s, stride, norm_layer=norm_layer |
| ) |
| self.resnext_large = ResNextBlock( |
| down, cin, cout // 2, ks_l, stride, norm_layer=norm_layer |
| ) |
| self.attention = SEBlock(cout) |
|
|
| def forward(self, tensorCombine): |
| out_small = self.resnext_small(tensorCombine) |
| out_large = self.resnext_large(tensorCombine) |
| out = torch.cat([out_small, out_large], 1) |
| out = self.attention(out) |
| return out |
|
|
|
|
| class UMultiScaleResNext(nn.Module): |
| def __init__( |
| self, channels=[64, 128, 256, 512], norm_layer="batch", inplanes=6, **kwargs |
| ): |
| super(UMultiScaleResNext, self).__init__() |
| self.conv1 = MultiScaleResNextBlock( |
| True, inplanes, channels[0], ks_s=3, ks_l=7, stride=2, norm_layer=norm_layer |
| ) |
| self.conv2 = MultiScaleResNextBlock( |
| True, |
| channels[0], |
| channels[1], |
| ks_s=3, |
| ks_l=7, |
| stride=2, |
| norm_layer=norm_layer, |
| ) |
| self.conv3 = MultiScaleResNextBlock( |
| True, |
| channels[1], |
| channels[2], |
| ks_s=3, |
| ks_l=5, |
| stride=2, |
| norm_layer=norm_layer, |
| ) |
| self.conv4 = MultiScaleResNextBlock( |
| True, |
| channels[2], |
| channels[3], |
| ks_s=3, |
| ks_l=5, |
| stride=2, |
| norm_layer=norm_layer, |
| ) |
|
|
| self.deconv4 = MultiScaleResNextBlock( |
| True, |
| channels[3], |
| channels[3], |
| ks_s=3, |
| ks_l=5, |
| stride=1, |
| norm_layer=norm_layer, |
| ) |
| self.deconv3 = MultiScaleResNextBlock( |
| False, |
| channels[3], |
| channels[2], |
| ks_s=4, |
| ks_l=6, |
| stride=2, |
| norm_layer=norm_layer, |
| ) |
| self.deconv2 = MultiScaleResNextBlock( |
| False, |
| channels[2], |
| channels[1], |
| ks_s=4, |
| ks_l=8, |
| stride=2, |
| norm_layer=norm_layer, |
| ) |
| self.deconv1 = MultiScaleResNextBlock( |
| False, |
| channels[1], |
| channels[0], |
| ks_s=4, |
| ks_l=8, |
| stride=2, |
| norm_layer=norm_layer, |
| ) |
|
|
| def forward(self, im0, im2): |
| tensorJoin = torch.cat([im0, im2], 1) |
|
|
| tensorConv1 = self.conv1(tensorJoin) |
| tensorConv2 = self.conv2(tensorConv1) |
| tensorConv3 = self.conv3(tensorConv2) |
| tensorConv4 = self.conv4(tensorConv3) |
|
|
| tensorDeconv4 = self.deconv4(tensorConv4) |
| tensorDeconv3 = self.deconv3(tensorDeconv4 + tensorConv4) |
| tensorDeconv2 = self.deconv2(tensorDeconv3 + tensorConv3) |
| tensorDeconv1 = self.deconv1(tensorDeconv2 + tensorConv2) |
|
|
| return tensorDeconv1 |
|
|
|
|
| class MultiInputGridNet(nn.Module): |
| def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6): |
| super(MultiInputGridNet, self).__init__() |
|
|
| self.n_row = n_row |
| self.n_col = n_col |
| self.n_chs = grid_chs |
| assert ( |
| len(grid_chs) == self.n_row |
| ), "should give num channels for each row (scale stream)" |
| assert ( |
| len(in_chs) == self.n_row |
| ), "should give input channels for each row (scale stream)" |
|
|
| for r, n_ch in enumerate(self.n_chs): |
| setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch)) |
| for c in range(1, self.n_col): |
| setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
| self.lateral_final = LateralBlock(self.n_chs[0], out_chs) |
|
|
| def forward(self, *args): |
| assert len(args) == self.n_row |
|
|
| |
| cur_col = list(args) |
| for c in range(int(self.n_col / 2)): |
| for r in range(self.n_row): |
| cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
| if r != 0: |
| cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) |
|
|
| for c in range(int(self.n_col / 2), self.n_col): |
| for r in range(self.n_row - 1, -1, -1): |
| cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
| if r != self.n_row - 1: |
| cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( |
| cur_col[r + 1] |
| ) |
|
|
| return self.lateral_final(cur_col[0]) |
|
|
|
|
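| # Multi-input multi-output GridNet: one input per scale row, with `outrow` |
| # selecting which rows emit outputs through their own final lateral blocks. |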
| class MIMOGridNet(nn.Module): |
| def __init__( |
| self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6, outrow=(0, 1, 2) |
| ): |
| super(MIMOGridNet, self).__init__() |
|
|
| self.n_row = n_row |
| self.n_col = n_col |
| self.n_chs = grid_chs |
| self.outrow = outrow |
| assert ( |
| len(grid_chs) == self.n_row |
| ), "should give num channels for each row (scale stream)" |
| assert ( |
| len(in_chs) == self.n_row |
| ), "should give input channels for each row (scale stream)" |
| assert len(out_chs) == len( |
| self.outrow |
| ), "should give out channels for each output row (scale stream)" |
|
|
| for r, n_ch in enumerate(self.n_chs): |
| setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch)) |
| for c in range(1, self.n_col): |
| setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
| for i, r in enumerate(outrow): |
| setattr(self, f"lateral_final_{r}", LateralBlock(self.n_chs[r], out_chs[i])) |
|
|
| def forward(self, *args): |
| assert len(args) == self.n_row |
|
|
| |
| cur_col = list(args) |
| for c in range(int(self.n_col / 2)): |
| for r in range(self.n_row): |
| cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
| if r != 0: |
| cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) |
|
|
| for c in range(int(self.n_col / 2), self.n_col): |
| for r in range(self.n_row - 1, -1, -1): |
| cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
| if r != self.n_row - 1: |
| cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( |
| cur_col[r + 1] |
| ) |
|
|
| out = [] |
| for r in self.outrow: |
| out.append(getattr(self, f"lateral_final_{r}")(cur_col[r])) |
|
|
| return out |
|
|
|
|
| class GeneralGridNet(nn.Module): |
| def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6): |
| super(GeneralGridNet, self).__init__() |
|
|
| self.n_row = n_row |
| self.n_col = n_col |
| self.n_chs = grid_chs |
| assert ( |
| len(grid_chs) == self.n_row |
| ), "should give num channels for each row (scale stream)" |
|
|
| for r, n_ch in enumerate(self.n_chs): |
| if r == 0: |
| setattr(self, f"lateral_{r}_0", LateralBlock(in_chs, n_ch)) |
| for c in range(1, self.n_col): |
| setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
| self.lateral_final = LateralBlock(self.n_chs[0], out_chs) |
|
|
| def forward(self, x): |
| cur_col = [x] + [None] * (self.n_row - 1) |
| for c in range(int(self.n_col / 2)): |
| for r in range(self.n_row): |
| if cur_col[r] is not None: |
| cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
| else: |
| cur_col[r] = 0.0 |
| if r != 0: |
| cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) |
|
|
| for c in range(int(self.n_col / 2), self.n_col): |
| for r in range(self.n_row - 1, -1, -1): |
| cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
| if r != self.n_row - 1: |
| cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( |
| cur_col[r + 1] |
| ) |
|
|
| return self.lateral_final(cur_col[0]) |
|
|
|
|
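| # Fixed 3-row / 6-column GridNet for synthesis: rows are scale streams, the |
| # first three columns downsample across rows and the last three upsample. |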
| class GridNet(nn.Module): |
| def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96)): |
| super(GridNet, self).__init__() |
|
|
| self.n_row = 3 |
| self.n_col = 6 |
| self.n_chs = grid_chs |
| assert ( |
| len(grid_chs) == self.n_row |
| ), "should give num channels for each row (scale stream)" |
|
|
| self.lateral_init = LateralBlock(in_chs, self.n_chs[0]) |
|
|
| for r, n_ch in enumerate(self.n_chs): |
| for c in range(self.n_col - 1): |
| setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
| for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
| for c in range(int(self.n_col / 2)): |
| setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
| self.lateral_final = LateralBlock(self.n_chs[0], out_chs) |
|
|
| def forward(self, x): |
| state_00 = self.lateral_init(x) |
| state_10 = self.down_0_0(state_00) |
| state_20 = self.down_1_0(state_10) |
|
|
| state_01 = self.lateral_0_0(state_00) |
| state_11 = self.down_0_1(state_01) + self.lateral_1_0(state_10) |
| state_21 = self.down_1_1(state_11) + self.lateral_2_0(state_20) |
|
|
| state_02 = self.lateral_0_1(state_01) |
| state_12 = self.down_0_2(state_02) + self.lateral_1_1(state_11) |
| state_22 = self.down_1_2(state_12) + self.lateral_2_1(state_21) |
|
|
| state_23 = self.lateral_2_2(state_22) |
| state_13 = self.up_1_0(state_23) + self.lateral_1_2(state_12) |
| state_03 = self.up_0_0(state_13) + self.lateral_0_2(state_02) |
|
|
| state_24 = self.lateral_2_3(state_23) |
| state_14 = self.up_1_1(state_24) + self.lateral_1_3(state_13) |
| state_04 = self.up_0_1(state_14) + self.lateral_0_3(state_03) |
|
|
| state_25 = self.lateral_2_4(state_24) |
| state_15 = self.up_1_2(state_25) + self.lateral_1_4(state_14) |
| state_05 = self.up_0_2(state_15) + self.lateral_0_4(state_04) |
|
|
| return self.lateral_final(state_05) |
|
|
|
|
| class LateralBlock(nn.Module): |
| def __init__(self, ch_in, ch_out): |
| super(LateralBlock, self).__init__() |
| self.f = nn.Sequential( |
| nn.PReLU(), |
| nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), |
| nn.PReLU(), |
| nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), |
| ) |
| if ch_in != ch_out: |
| self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1) |
|
|
| def forward(self, x): |
| fx = self.f(x) |
| if fx.shape[1] != x.shape[1]: |
| x = self.conv(x) |
| return fx + x |
|
|
|
|
| class DownSamplingBlock(nn.Module): |
| def __init__(self, ch_in, ch_out): |
| super(DownSamplingBlock, self).__init__() |
| self.f = nn.Sequential( |
| nn.PReLU(), |
| nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1), |
| nn.PReLU(), |
| nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), |
| ) |
|
|
| def forward(self, x): |
| return self.f(x) |
|
|
|
|
| class UpSamplingBlock(nn.Module): |
| def __init__(self, ch_in, ch_out): |
| super(UpSamplingBlock, self).__init__() |
| self.f = nn.Sequential( |
| nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), |
| nn.PReLU(), |
| nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), |
| nn.PReLU(), |
| nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), |
| ) |
|
|
| def forward(self, x): |
| return self.f(x) |
|
|
| |
|
|
|
|
| class Network(torch.nn.Module): |
| def __init__(self): |
| super(Network, self).__init__() |
|
|
| class Extractor(torch.nn.Module): |
| def __init__(self): |
| super(Extractor, self).__init__() |
|
|
| self.netOne = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=3, |
| out_channels=16, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=16, |
| out_channels=16, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=16, |
| out_channels=16, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netTwo = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=16, |
| out_channels=32, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netThr = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=64, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFou = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=96, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFiv = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=128, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netSix = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=196, |
| kernel_size=3, |
| stride=2, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=196, |
| out_channels=196, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=196, |
| out_channels=196, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| |
|
|
| def forward(self, tenInput): |
| tenOne = self.netOne(tenInput) |
| tenTwo = self.netTwo(tenOne) |
| tenThr = self.netThr(tenTwo) |
| tenFou = self.netFou(tenThr) |
| tenFiv = self.netFiv(tenFou) |
| tenSix = self.netSix(tenFiv) |
|
|
| return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix] |
|
|
| |
|
|
| |
|
|
| class Decoder(torch.nn.Module): |
| def __init__(self, intLevel): |
| super(Decoder, self).__init__() |
|
|
| intPrevious = [ |
| None, |
| None, |
| 81 + 32 + 2 + 2, |
| 81 + 64 + 2 + 2, |
| 81 + 96 + 2 + 2, |
| 81 + 128 + 2 + 2, |
| 81, |
| None, |
| ][intLevel + 1] |
| intCurrent = [ |
| None, |
| None, |
| 81 + 32 + 2 + 2, |
| 81 + 64 + 2 + 2, |
| 81 + 96 + 2 + 2, |
| 81 + 128 + 2 + 2, |
| 81, |
| None, |
| ][intLevel + 0] |
|
|
| if intLevel < 6: |
| self.netUpflow = torch.nn.ConvTranspose2d( |
| in_channels=2, |
| out_channels=2, |
| kernel_size=4, |
| stride=2, |
| padding=1, |
| ) |
| if intLevel < 6: |
| self.netUpfeat = torch.nn.ConvTranspose2d( |
| in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, |
| out_channels=2, |
| kernel_size=4, |
| stride=2, |
| padding=1, |
| ) |
| if intLevel < 6: |
| self.fltBackwarp = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][ |
| intLevel + 1 |
| ] |
|
|
| self.netOne = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netTwo = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netThr = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFou = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128 + 96, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netFiv = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128 + 96 + 64, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| ) |
|
|
| self.netSix = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, |
| out_channels=2, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ) |
| ) |
|
|
| |
|
|
| def forward(self, tenFirst, tenSecond, objPrevious): |
| tenFlow = None |
| tenFeat = None |
|
|
| if objPrevious is None: |
| tenFlow = None |
| tenFeat = None |
|
|
| tenVolume = torch.nn.functional.leaky_relu( |
| input=FunctionCorrelation( |
| tenFirst=tenFirst, tenSecond=tenSecond |
| ), |
| negative_slope=0.1, |
| inplace=False, |
| ) |
|
|
| tenFeat = torch.cat([tenVolume], 1) |
|
|
| elif objPrevious is not None: |
| tenFlow = self.netUpflow(objPrevious["tenFlow"]) |
| tenFeat = self.netUpfeat(objPrevious["tenFeat"]) |
|
|
| tenVolume = torch.nn.functional.leaky_relu( |
| input=FunctionCorrelation( |
| tenFirst=tenFirst, |
| tenSecond=backwarp( |
| tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp |
| ), |
| ), |
| negative_slope=0.1, |
| inplace=False, |
| ) |
|
|
| tenFeat = torch.cat([tenVolume, tenFirst, tenFlow, tenFeat], 1) |
|
|
| |
|
|
| tenFeat = torch.cat([self.netOne(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netTwo(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netThr(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netFou(tenFeat), tenFeat], 1) |
| tenFeat = torch.cat([self.netFiv(tenFeat), tenFeat], 1) |
|
|
| tenFlow = self.netSix(tenFeat) |
|
|
| return {"tenFlow": tenFlow, "tenFeat": tenFeat} |
|
|
| |
|
|
| |
|
|
| class Refiner(torch.nn.Module): |
| def __init__(self): |
| super(Refiner, self).__init__() |
|
|
| self.netMain = torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| dilation=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=2, |
| dilation=2, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=128, |
| kernel_size=3, |
| stride=1, |
| padding=4, |
| dilation=4, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=128, |
| out_channels=96, |
| kernel_size=3, |
| stride=1, |
| padding=8, |
| dilation=8, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=96, |
| out_channels=64, |
| kernel_size=3, |
| stride=1, |
| padding=16, |
| dilation=16, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=64, |
| out_channels=32, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| dilation=1, |
| ), |
| torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
| torch.nn.Conv2d( |
| in_channels=32, |
| out_channels=2, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| dilation=1, |
| ), |
| ) |
|
|
| |
|
|
| def forward(self, tenInput): |
| return self.netMain(tenInput) |
|
|
| |
|
|
| |
|
|
| self.netExtractor = Extractor() |
|
|
| self.netTwo = Decoder(2) |
| self.netThr = Decoder(3) |
| self.netFou = Decoder(4) |
| self.netFiv = Decoder(5) |
| self.netSix = Decoder(6) |
|
|
| self.netRefiner = Refiner() |
|
|
|         self.load_state_dict( |
|             { |
|                 strKey.replace("module", "net"): tenWeight |
|                 for strKey, tenWeight in torch.hub.load_state_dict_from_url( |
|                     url="http://content.sniklaus.com/github/pytorch-pwc/network-default.pytorch", |
|                     model_dir=get_ckpt_container_path(MODEL_TYPE), |
|                 ).items() |
|             } |
|         ) |
|
|
| |
|
|
| def forward(self, tenFirst, tenSecond, *args): |
| |
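|         # feature pyramids may be passed in through *args (computed via |
|         # extract_pyramid below) so they are not re-extracted when both |
|         # flow directions are needed |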
| if len(args) == 0: |
| tenFirst = self.netExtractor(tenFirst) |
| tenSecond = self.netExtractor(tenSecond) |
| else: |
| tenFirst, tenSecond = args |
|
|
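|         # coarse-to-fine: start at the coarsest pyramid level (six) and |
|         # refine the estimate down to level two |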
| objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) |
| objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) |
| objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) |
| objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) |
| objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) |
|
|
| return objEstimate["tenFlow"] + self.netRefiner(objEstimate["tenFeat"]) |
|
|
| |
|
|
| def extract_pyramid(self, tenFirst, tenSecond): |
| return self.netExtractor(tenFirst), self.netExtractor(tenSecond) |
|
|
| def extract_pyramid_single(self, tenFirst): |
| return self.netExtractor(tenFirst) |
|
|
|
|
| |
|
|
| netNetwork = None |
|
|
| |
|
|
|
|
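| # Convenience wrapper for one-off flow estimation. A minimal usage sketch |
| # (3 x H x W float tensors in [0, 1] at the resolution enforced below): |
| # |
| #     tenFirst = torch.rand(3, 436, 1024) |
| #     tenSecond = torch.rand(3, 436, 1024) |
| #     with torch.no_grad(): |
| #         tenFlow = estimate(tenFirst, tenSecond)  # 2 x 436 x 1024, pixels |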
| def estimate(tenFirst, tenSecond): |
| global netNetwork |
|
|
| if netNetwork is None: |
|         netNetwork = PWCNet().cuda().eval() |
| |
|
|
| assert tenFirst.shape[1] == tenSecond.shape[1] |
| assert tenFirst.shape[2] == tenSecond.shape[2] |
|
|
| intWidth = tenFirst.shape[2] |
| intHeight = tenFirst.shape[1] |
|
|
|     # only the 436 x 1024 (Sintel) resolution is supported here |
|     assert intWidth == 1024 |
|     assert intHeight == 436 |
|
|
| tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) |
| tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth) |
|
|
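|     # round the input size up to the next multiple of 64, as required by |
|     # the six-level feature pyramid |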
| intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) |
| intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) |
|
|
| tenPreprocessedFirst = torch.nn.functional.interpolate( |
| input=tenPreprocessedFirst, |
| size=(intPreprocessedHeight, intPreprocessedWidth), |
| mode="bilinear", |
| align_corners=False, |
| ) |
| tenPreprocessedSecond = torch.nn.functional.interpolate( |
| input=tenPreprocessedSecond, |
| size=(intPreprocessedHeight, intPreprocessedWidth), |
| mode="bilinear", |
| align_corners=False, |
| ) |
|
|
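|     # PWC-Net regresses flow at 1/20 magnitude, so scale its output back |
|     # up and resize to the original resolution |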
| tenFlow = 20.0 * torch.nn.functional.interpolate( |
| input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), |
| size=(intHeight, intWidth), |
| mode="bilinear", |
| align_corners=False, |
| ) |
|
|
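|     # account for the resize: scale each flow component by the ratio of |
|     # the original to the preprocessed dimension |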
| tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) |
| tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) |
|
|
| return tenFlow[0, :, :, :].cpu() |
|
|
|
|
| |
|
|
|
|
| class UNet3d_18(nn.Module): |
|     def __init__(self, channels=(32, 64, 96, 128), bn=True): |
| super(UNet3d_18, self).__init__() |
| growth = 2 |
| upmode = "transpose" |
|
|
| self.channels = channels |
|
|
| self.lrelu = nn.LeakyReLU(0.2, True) |
|
|
| self.encoder = r3d_18(bn=bn, channels=channels) |
|
|
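|         # decoder mirroring the r3d_18 encoder: upConv3D blocks upsample |
|         # only the spatial dims (stride (1, 2, 2)), keeping all five |
|         # temporal slices, and growth = 2 accounts for the channel doubling |
|         # caused by the skip concatenations in forward() |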
| self.decoder = nn.Sequential( |
| Conv_3d( |
| channels[::-1][0], |
| channels[::-1][1], |
| kernel_size=3, |
| padding=1, |
| bias=True, |
| ), |
| upConv3D( |
| channels[::-1][1] * growth, |
| channels[::-1][2], |
| kernel_size=(3, 4, 4), |
| stride=(1, 2, 2), |
| padding=(1, 1, 1), |
| upmode=upmode, |
| ), |
| upConv3D( |
| channels[::-1][2] * growth, |
| channels[::-1][3], |
| kernel_size=(3, 4, 4), |
| stride=(1, 2, 2), |
| padding=(1, 1, 1), |
| upmode=upmode, |
| ), |
| Conv_3d( |
| channels[::-1][3] * growth, |
| channels[::-1][3], |
| kernel_size=3, |
| padding=1, |
| bias=True, |
| ), |
| upConv3D( |
| channels[::-1][3] * growth, |
| channels[::-1][3], |
| kernel_size=(3, 4, 4), |
| stride=(1, 2, 2), |
| padding=(1, 1, 1), |
| upmode=upmode, |
| ), |
| ) |
|
|
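|         # 1x1 conv fusing the five temporal slices after forward() flattens |
|         # them into the channel dimension |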
| self.feature_fuse = nn.Sequential( |
| *( |
| [ |
| nn.Conv2d( |
| channels[::-1][3] * 5, |
| channels[::-1][3], |
| kernel_size=1, |
| stride=1, |
| bias=False, |
| ) |
| ] |
|                 + [nn.BatchNorm2d(channels[::-1][3]) if bn else nn.Identity()] |
| ) |
| ) |
|
|
| self.outconv = nn.Sequential( |
| nn.ReflectionPad2d(3), |
| nn.Conv2d(channels[::-1][3], 3, kernel_size=7, stride=1, padding=0), |
| ) |
|
|
| def forward(self, im1, im3, im5, im7, im4_tilde): |
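|         # stack the four context frames and the initial estimate along a |
|         # new temporal axis, with the estimate in the middle position |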
| images = torch.stack((im1, im3, im4_tilde, im5, im7), dim=2) |
|
|
| x_0, x_1, x_2, x_3, x_4 = self.encoder(images) |
|
|
| dx_3 = self.lrelu(self.decoder[0](x_4)) |
| dx_3 = torch.cat([dx_3, x_3], dim=1) |
|
|
| dx_2 = self.lrelu(self.decoder[1](dx_3)) |
| dx_2 = torch.cat([dx_2, x_2], dim=1) |
|
|
| dx_1 = self.lrelu(self.decoder[2](dx_2)) |
| dx_1 = torch.cat([dx_1, x_1], dim=1) |
|
|
| dx_0 = self.lrelu(self.decoder[3](dx_1)) |
| dx_0 = torch.cat([dx_0, x_0], dim=1) |
|
|
| dx_out = self.lrelu(self.decoder[4](dx_0)) |
| dx_out = torch.cat(torch.unbind(dx_out, 2), 1) |
|
|
| out = self.lrelu(self.feature_fuse(dx_out)) |
| out = self.outconv(out) |
|
|
| return out |
|
|
|
|
| class KernelEstimation(torch.nn.Module): |
| def __init__(self, kernel_size): |
| super(KernelEstimation, self).__init__() |
| self.kernel_size = kernel_size |
|
|
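|         # Each subnet regresses per-pixel AdaCoF parameters with ks = |
|         # kernel_size**2 output channels (one per kernel tap): Subnet_weight |
|         # produces softmax-normalized kernel weights, Subnet_offset the |
|         # per-tap spatial sample offsets (alpha/beta). The plain variants |
|         # upsample 2x to full resolution, *_ds stays at feature resolution |
|         # for the half-scale branch, and *_us upsamples 4x for the |
|         # double-scale branch. |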
| def Subnet_offset(ks): |
| return torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), |
| torch.nn.Conv2d( |
| in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| ) |
|
|
| def Subnet_weight(ks): |
| return torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), |
| torch.nn.Conv2d( |
| in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.Softmax(dim=1), |
| ) |
|
|
| def Subnet_offset_ds(ks): |
| return torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| ) |
|
|
| def Subnet_weight_ds(ks): |
| return torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.Softmax(dim=1), |
| ) |
|
|
| def Subnet_offset_us(ks): |
| return torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True), |
| torch.nn.Conv2d( |
| in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| ) |
|
|
| def Subnet_weight_us(ks): |
| return torch.nn.Sequential( |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Conv2d( |
| in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.ReLU(inplace=False), |
| torch.nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True), |
| torch.nn.Conv2d( |
| in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1 |
| ), |
| torch.nn.Softmax(dim=1), |
| ) |
|
|
| self.moduleWeight1_ds = Subnet_weight_ds(self.kernel_size**2) |
| self.moduleAlpha1_ds = Subnet_offset_ds(self.kernel_size**2) |
| self.moduleBeta1_ds = Subnet_offset_ds(self.kernel_size**2) |
| self.moduleWeight2_ds = Subnet_weight_ds(self.kernel_size**2) |
| self.moduleAlpha2_ds = Subnet_offset_ds(self.kernel_size**2) |
| self.moduleBeta2_ds = Subnet_offset_ds(self.kernel_size**2) |
|
|
| self.moduleWeight1 = Subnet_weight(self.kernel_size**2) |
| self.moduleAlpha1 = Subnet_offset(self.kernel_size**2) |
| self.moduleBeta1 = Subnet_offset(self.kernel_size**2) |
| self.moduleWeight2 = Subnet_weight(self.kernel_size**2) |
| self.moduleAlpha2 = Subnet_offset(self.kernel_size**2) |
| self.moduleBeta2 = Subnet_offset(self.kernel_size**2) |
|
|
| self.moduleWeight1_us = Subnet_weight_us(self.kernel_size**2) |
| self.moduleAlpha1_us = Subnet_offset_us(self.kernel_size**2) |
| self.moduleBeta1_us = Subnet_offset_us(self.kernel_size**2) |
| self.moduleWeight2_us = Subnet_weight_us(self.kernel_size**2) |
| self.moduleAlpha2_us = Subnet_offset_us(self.kernel_size**2) |
| self.moduleBeta2_us = Subnet_offset_us(self.kernel_size**2) |
|
|
| def forward(self, tensorCombine): |
| |
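|         # kernel weights and sample offsets for the first input frame, at |
|         # all three scales |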
| Weight1_ds = self.moduleWeight1_ds(tensorCombine) |
| Weight1 = self.moduleWeight1(tensorCombine) |
| Weight1_us = self.moduleWeight1_us(tensorCombine) |
| Alpha1_ds = self.moduleAlpha1_ds(tensorCombine) |
| Alpha1 = self.moduleAlpha1(tensorCombine) |
| Alpha1_us = self.moduleAlpha1_us(tensorCombine) |
| Beta1_ds = self.moduleBeta1_ds(tensorCombine) |
| Beta1 = self.moduleBeta1(tensorCombine) |
| Beta1_us = self.moduleBeta1_us(tensorCombine) |
|
|
| |
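|         # the same set of parameters for the second input frame |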
| Weight2_ds = self.moduleWeight2_ds(tensorCombine) |
| Weight2 = self.moduleWeight2(tensorCombine) |
| Weight2_us = self.moduleWeight2_us(tensorCombine) |
| Alpha2_ds = self.moduleAlpha2_ds(tensorCombine) |
| Alpha2 = self.moduleAlpha2(tensorCombine) |
| Alpha2_us = self.moduleAlpha2_us(tensorCombine) |
| Beta2_ds = self.moduleBeta2_ds(tensorCombine) |
| Beta2 = self.moduleBeta2(tensorCombine) |
| Beta2_us = self.moduleBeta2_us(tensorCombine) |
|
|
| return ( |
| Weight1_ds, |
| Alpha1_ds, |
| Beta1_ds, |
| Weight2_ds, |
| Alpha2_ds, |
| Beta2_ds, |
| Weight1, |
| Alpha1, |
| Beta1, |
| Weight2, |
| Alpha2, |
| Beta2, |
| Weight1_us, |
| Alpha1_us, |
| Beta1_us, |
| Weight2_us, |
| Alpha2_us, |
| Beta2_us, |
| ) |
|
|
|
|
| class STMFNet_Model(torch.nn.Module): |
| def __init__(self): |
| super(STMFNet_Model, self).__init__() |
|
|
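|         # Photometric confidence used by softmax splatting: a learned |
|         # (negative) scale times the per-pixel L1 backwarp error. |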
| class Metric(torch.nn.Module): |
| def __init__(self): |
| super(Metric, self).__init__() |
| self.paramScale = torch.nn.Parameter(-torch.ones(1, 1, 1, 1)) |
|
|
| def forward(self, tenFirst, tenSecond, tenFlow): |
| return self.paramScale * F.l1_loss( |
| input=tenFirst, |
| target=backwarp(tenSecond, tenFlow), |
| reduction="none", |
| ).mean(1, True) |
|
|
| self.kernel_size = 5 |
| self.dilation = 1 |
| self.featc = [64, 128, 256, 512] |
| self.featnorm = "batch" |
| self.finetune_pwc = False |
|
|
|         self.kernel_pad = ((self.kernel_size - 1) * self.dilation) // 2 |
|
|
| self.feature_extractor = UMultiScaleResNext( |
| self.featc, norm_layer=self.featnorm |
| ) |
|
|
| self.get_kernel = KernelEstimation(self.kernel_size) |
|
|
| self.modulePad = torch.nn.ReplicationPad2d( |
| [self.kernel_pad, self.kernel_pad, self.kernel_pad, self.kernel_pad] |
| ) |
|
|
| self.moduleAdaCoF = FunctionAdaCoF.apply |
|
|
| self.gauss_kernel = torch.nn.Parameter( |
| gaussian_kernel(5, 0.5).repeat(3, 1, 1, 1), requires_grad=False |
| ) |
|
|
| self.upsampler = Upsampler_8tap() |
|
|
| self.scale_synthesis = MIMOGridNet( |
| (6, 6 + 6, 6), (3,), grid_chs=(32, 64, 96), n_row=3, n_col=4, outrow=(1,) |
| ) |
|
|
| self.flow_estimator = PWCNet() |
|
|
| self.softsplat = ModuleSoftsplat(strType="softmax") |
|
|
| self.metric = Metric() |
|
|
|         # featnorm is the string "batch", which is truthy, so this enables |
|         # batch norm inside the 3D U-Net |
|         self.dyntex_generator = UNet3d_18(bn=self.featnorm) |
|
|
| |
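|         # freeze the flow estimator unless it is being fine-tuned |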
| if not self.finetune_pwc: |
| for param in self.flow_estimator.parameters(): |
| param.requires_grad = False |
|
|
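|     # Interpolates the frame midway between I1 and I2, using I0 and I3 as |
|     # extra temporal context: AdaCoF warping at three scales, softmax |
|     # splatting guided by PWC-Net flow, GridNet synthesis across scales, |
|     # and a dynamic-texture residual from the 3D U-Net. |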
|     def forward(self, I0, I1, I2, I3): |
|         h0, w0 = I1.shape[2], I1.shape[3] |
|         h2, w2 = I2.shape[2], I2.shape[3] |
|         if h0 != h2 or w0 != w2: |
|             raise ValueError("Frame sizes do not match") |
|
|
| h_padded = False |
| w_padded = False |
| if h0 % 128 != 0: |
| pad_h = 128 - (h0 % 128) |
| I0 = F.pad(I0, (0, 0, 0, pad_h), mode="reflect") |
| I1 = F.pad(I1, (0, 0, 0, pad_h), mode="reflect") |
| I2 = F.pad(I2, (0, 0, 0, pad_h), mode="reflect") |
| I3 = F.pad(I3, (0, 0, 0, pad_h), mode="reflect") |
| h_padded = True |
|
|
| if w0 % 128 != 0: |
| pad_w = 128 - (w0 % 128) |
| I0 = F.pad(I0, (0, pad_w, 0, 0), mode="reflect") |
| I1 = F.pad(I1, (0, pad_w, 0, 0), mode="reflect") |
| I2 = F.pad(I2, (0, pad_w, 0, 0), mode="reflect") |
| I3 = F.pad(I3, (0, pad_w, 0, 0), mode="reflect") |
| w_padded = True |
|
|
| feats = self.feature_extractor(moduleNormalize(I1), moduleNormalize(I2)) |
| kernelest = self.get_kernel(feats) |
| Weight1_ds, Alpha1_ds, Beta1_ds, Weight2_ds, Alpha2_ds, Beta2_ds = kernelest[:6] |
| Weight1, Alpha1, Beta1, Weight2, Alpha2, Beta2 = kernelest[6:12] |
| Weight1_us, Alpha1_us, Beta1_us, Weight2_us, Alpha2_us, Beta2_us = kernelest[ |
| 12: |
| ] |
|
|
| |
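|         # AdaCoF warping at the original scale |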
| tensorAdaCoF1 = ( |
| self.moduleAdaCoF(self.modulePad(I1), Weight1, Alpha1, Beta1, self.dilation) |
| * 1.0 |
| ) |
| tensorAdaCoF2 = ( |
| self.moduleAdaCoF(self.modulePad(I2), Weight2, Alpha2, Beta2, self.dilation) |
| * 1.0 |
| ) |
|
|
| |
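|         # half-scale branch: Gaussian blur, bilinear downsample, AdaCoF |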
| c, h, w = I1.shape[1:] |
| p = (self.gauss_kernel.shape[-1] - 1) // 2 |
| I1_blur = F.conv2d( |
| F.pad(I1, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c |
| ) |
| I2_blur = F.conv2d( |
| F.pad(I2, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c |
| ) |
| I1_ds = F.interpolate( |
| I1_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False |
| ) |
| I2_ds = F.interpolate( |
| I2_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False |
| ) |
| tensorAdaCoF1_ds = ( |
| self.moduleAdaCoF( |
| self.modulePad(I1_ds), Weight1_ds, Alpha1_ds, Beta1_ds, self.dilation |
| ) |
| * 1.0 |
| ) |
| tensorAdaCoF2_ds = ( |
| self.moduleAdaCoF( |
| self.modulePad(I2_ds), Weight2_ds, Alpha2_ds, Beta2_ds, self.dilation |
| ) |
| * 1.0 |
| ) |
|
|
| |
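|         # double-scale branch: 8-tap upsampling, then AdaCoF |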
| I1_us = self.upsampler(I1) |
| I2_us = self.upsampler(I2) |
| tensorAdaCoF1_us = ( |
| self.moduleAdaCoF( |
| self.modulePad(I1_us), Weight1_us, Alpha1_us, Beta1_us, self.dilation |
| ) |
| * 1.0 |
| ) |
| tensorAdaCoF2_us = ( |
| self.moduleAdaCoF( |
| self.modulePad(I2_us), Weight2_us, Alpha2_us, Beta2_us, self.dilation |
| ) |
| * 1.0 |
| ) |
|
|
| |
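|         # bidirectional PWC-Net flow plus confidence metrics drive softmax |
|         # splatting of both frames to the temporal midpoint |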
| pyramid0, pyramid2 = self.flow_estimator.extract_pyramid(I1, I2) |
| flow_0_2 = 20 * self.flow_estimator(I1, I2, pyramid0, pyramid2) |
| flow_0_2 = F.interpolate( |
| flow_0_2, size=(h, w), mode="bilinear", align_corners=False |
| ) |
| flow_2_0 = 20 * self.flow_estimator(I2, I1, pyramid2, pyramid0) |
| flow_2_0 = F.interpolate( |
| flow_2_0, size=(h, w), mode="bilinear", align_corners=False |
| ) |
| metric_0_2 = self.metric(I1, I2, flow_0_2) |
| metric_2_0 = self.metric(I2, I1, flow_2_0) |
| tensorSoftsplat0 = self.softsplat(I1, 0.5 * flow_0_2, metric_0_2) |
| tensorSoftsplat2 = self.softsplat(I2, 0.5 * flow_2_0, metric_2_0) |
|
|
| |
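|         # fuse the three scales with the GridNet synthesis network |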
| tensorCombine_us = torch.cat([tensorAdaCoF1_us, tensorAdaCoF2_us], dim=1) |
| tensorCombine = torch.cat( |
| [tensorAdaCoF1, tensorAdaCoF2, tensorSoftsplat0, tensorSoftsplat2], dim=1 |
| ) |
| tensorCombine_ds = torch.cat([tensorAdaCoF1_ds, tensorAdaCoF2_ds], dim=1) |
| output_tilde = self.scale_synthesis( |
| tensorCombine_us, tensorCombine, tensorCombine_ds |
| )[0] |
|
|
| |
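|         # add the dynamic-texture residual predicted from the four inputs |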
| dyntex = self.dyntex_generator(I0, I1, I2, I3, output_tilde) |
| output = output_tilde + dyntex |
|
|
| if h_padded: |
| output = output[:, :, 0:h0, :] |
| if w_padded: |
| output = output[:, :, :, 0:w0] |
|
|
| if self.training: |
| return {"frame1": output} |
| else: |
| return output |
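|
|
| # Minimal usage sketch (illustrative sizes; each frame is a 1 x 3 x H x W |
| # tensor in [0, 1] and I0..I3 are four consecutive frames): |
| # |
| #     model = STMFNet_Model().cuda().eval() |
| #     frames = [torch.rand(1, 3, 256, 448).cuda() for _ in range(4)] |
| #     with torch.no_grad(): |
| #         mid = model(*frames)  # frame between frames[1] and frames[2] |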