| """ |
| MIT License |
| |
| Copyright (c) 2024 Hzwer |
| |
| Permission is hereby granted, free of charge, to any person obtaining a copy |
| of this software and associated documentation files (the "Software"), to deal |
| in the Software without restriction, including without limitation the rights |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| copies of the Software, and to permit persons to whom the Software is |
| furnished to do so, subject to the following conditions: |
| |
| The above copyright notice and this permission notice shall be included in all |
| copies or substantial portions of the Software. |
| |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| SOFTWARE. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
    """Backward-warp ``tenInput`` by a per-pixel flow field.

    Args:
        tenInput: tensor sampled with ``F.grid_sample`` (N, C, H, W layout).
        tenFlow: flow whose channel 0 is the x offset and channel 1 the y
            offset, in pixels.
        tenFlow_div: indexable pair of divisors mapping pixel offsets to
            normalized grid units; ``Model._get_grid`` builds it as
            ``((W - 1) / 2, (H - 1) / 2)``.
        backwarp_tenGrid: identity sampling grid (2 channels: x then y) in
            ``[-1, 1]``, added to the normalized flow.

    Returns:
        The warped tensor, cast back to ``tenInput``'s original dtype.
    """
    original_dtype = tenInput.dtype
    # Sampling is done in float32 regardless of the caller's dtype.
    source = tenInput.float()
    flow = tenFlow.float()

    # Convert pixel displacements into grid_sample's normalized coordinates.
    flow_x = flow[:, 0:1] / tenFlow_div[0]
    flow_y = flow[:, 1:2] / tenFlow_div[1]
    sample_grid = backwarp_tenGrid + torch.cat([flow_x, flow_y], 1)
    # grid_sample expects the coordinate pair in the last dimension.
    sample_grid = sample_grid.permute(0, 2, 3, 1)

    on_mps = source.device.type == "mps"
    if on_mps:
        # Presumably MPS lacks "border" padding (TODO confirm). With
        # align_corners=True, clamping to [-1, 1] and using zero padding
        # reproduces border behaviour.
        sample_grid = sample_grid.clamp(-1, 1)

    warped = F.grid_sample(
        input=source,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="zeros" if on_mps else "border",
        align_corners=True,
    )
    return warped.to(original_dtype)
|
|
|
|
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased ``Conv2d`` followed by an in-place LeakyReLU(0.2).

    The result is an ``nn.Sequential`` whose index 0 is the convolution and
    index 1 the activation, so parameter keys are ``0.weight`` / ``0.bias``.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size,
        stride,
        padding,
        dilation,
        bias=True,
    )
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(layer, activation)
|
|
|
|
class Head(nn.Module):
    """Small convolutional encoder mapping a 3-channel image to 4 feature channels.

    ``cnn0`` halves the resolution and ``cnn3`` (a stride-2 transposed conv)
    doubles it again. For even H and W the output therefore has the input's
    spatial size.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are state_dict keys (loaded as "encode.cnnX.*").
        self.cnn0 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.cnn1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.cnn2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.cnn3 = nn.ConvTranspose2d(16, 4, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Encode ``x`` after clamping it to [0, 1]."""
        out = x.clamp(0.0, 1.0)
        for layer in (self.cnn0, self.cnn1, self.cnn2):
            out = self.relu(layer(out))
        # No activation after the final (upsampling) layer.
        return self.cnn3(out)
|
|
|
|
class ResConv(nn.Module):
    """Residual 3x3 convolution with a learnable per-channel branch scale.

    Computes ``LeakyReLU_0.2(conv(x) * beta + x)``. Padding equals the
    dilation, so a 3x3 kernel keeps the spatial size unchanged.
    """

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(
            c, c, kernel_size=3, stride=1, padding=dilation, dilation=dilation, groups=1
        )
        # Starts at 1, so the residual branch is initially unscaled.
        self.beta = nn.Parameter(torch.ones(1, c, 1, 1), requires_grad=True)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        branch = self.conv(x) * self.beta
        return self.relu(branch + x)
|
|
|
|
class IFBlock(nn.Module):
    """One level of the coarse-to-fine flow pyramid.

    The input is resized by ``1 / scale`` and encoded by two stride-2 convs
    (``conv0``) and eight residual convs (``convblock``). ``lastconv``
    decodes 13 channels, which are resized back by ``scale`` and split into
    4 flow channels, 1 mask channel and 8 feature channels.
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        # Two stride-2 convs: 1/4 of the (already rescaled) input resolution.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # Eight identical residual stages (Sequential indices 0..7).
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        # Transposed conv (x2) then PixelShuffle(2) (x2) undo the /4 and
        # reduce 4 * 13 channels to 13.
        self.lastconv = nn.Sequential(
            nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1),
            nn.PixelShuffle(2),
        )

    def forward(self, x, flow=None, scale=1):
        """Return ``(flow, mask, feat)`` at the input's resolution.

        When ``flow`` is given it is resized alongside ``x``, divided by
        ``scale`` and appended to the input channels. The output flow is
        multiplied by ``scale`` after upsampling.
        """
        shrink = 1.0 / scale
        x = F.interpolate(x, scale_factor=shrink, mode="bilinear", align_corners=False)
        if flow is not None:
            # Flow magnitudes shrink with the resolution they are expressed in.
            resized_flow = F.interpolate(
                flow, scale_factor=shrink, mode="bilinear", align_corners=False
            )
            x = torch.cat((x, resized_flow / scale), 1)

        decoded = self.lastconv(self.convblock(self.conv0(x)))
        decoded = F.interpolate(
            decoded, scale_factor=scale, mode="bilinear", align_corners=False
        )
        return decoded[:, :4] * scale, decoded[:, 4:5], decoded[:, 5:]
|
|
|
|
class IFNet(nn.Module):
    """Five-level coarse-to-fine flow network for frame interpolation.

    Each level refines a 4-channel flow (``flow[:, :2]`` warps img0/f0,
    ``flow[:, 2:4]`` warps img1/f1), a mask logit and an 8-channel feature
    map. The output blends the two warped images using ``sigmoid(mask)``.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        # Level 0 input: img0, img1, f0, f1, timestep (3 + 3 + 4 + 4 + 1 = 15,
        # assuming 3-channel images).
        self.block0 = IFBlock(7 + 8, c=192)
        # Later levels: warped images (6), warped features (8), timestep (1),
        # mask (1), feat (8) = 24, plus 4 flow channels appended in IFBlock.
        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8 + 8, c=64)
        self.block4 = IFBlock(8 + 4 + 8 + 8, c=32)
        # Downsampling factor used by each level, coarsest first.
        self.scaleList = [level / scale for level in (16, 8, 4, 2, 1)]
        # Plain list for indexing; parameters are registered via block0..block4.
        self.blocks = [self.block0, self.block1, self.block2, self.block3, self.block4]

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the interpolated frame between ``img0`` and ``img1``.

        ``f0``/``f1`` are encoder features of the two images. ``timestep`` is
        concatenated as an input plane. ``tenFlow_div`` and
        ``backwarp_tenGrid`` are passed through to ``warp``.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)

        def warp_both(first, second, flow):
            # Warp `first` with the first flow pair and `second` with the second.
            return (
                warp(first, flow[:, :2], tenFlow_div, backwarp_tenGrid),
                warp(second, flow[:, 2:4], tenFlow_div, backwarp_tenGrid),
            )

        # Coarsest level: estimate flow, mask and features from scratch.
        flow, mask, feat = self.blocks[0](
            torch.cat((img0, img1, f0, f1, timestep), 1),
            None,
            scale=self.scaleList[0],
        )
        warped_img0, warped_img1 = warp_both(img0, img1, flow)

        # Finer levels: predict a flow residual from the current warp.
        for level in range(1, 5):
            wf0, wf1 = warp_both(f0, f1, flow)
            level_input = torch.cat(
                (warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1
            )
            flow_delta, mask, feat = self.blocks[level](
                level_input, flow, scale=self.scaleList[level]
            )
            flow = flow + flow_delta
            warped_img0, warped_img1 = warp_both(img0, img1, flow)

        weight = torch.sigmoid(mask)
        return warped_img0 * weight + warped_img1 * (1 - weight)
|
|
|
|
class Model:
    """Inference wrapper bundling the ``IFNet`` flow network and ``Head`` encoder.

    Owns both networks, a cache of per-(size, device) sampling grids for
    ``warp``, and the checkpoint-loading logic.
    """

    def __init__(self):
        self.flownet = IFNet()
        self.encode = Head()
        # Not used inside this class; presumably callers pad H/W to a
        # multiple of this value before inference — TODO confirm.
        self.pad_mod = 64
        # Capability flag; presumably read by callers — not used here.
        self.supports_timestep = True
        # (height, width, device.type, device.index) -> (tenFlow_div, backwarp_tenGrid)
        self._grid_cache = {}
        # Set by load_model(); None until then.
        self.device = None

    def train(self):
        """Put both sub-networks into training mode."""
        self.flownet.train()
        self.encode.train()

    def eval(self):
        """Put both sub-networks into evaluation mode."""
        self.flownet.eval()
        self.encode.eval()

    def to(self, device):
        """Move both sub-networks to ``device`` (``self.device`` is not updated)."""
        self.flownet.to(device)
        self.encode.to(device)

    @staticmethod
    def _scale_list(scale):
        """Return IFNet's five per-level downsampling factors for ``scale``."""
        return [base / scale for base in (16, 8, 4, 2, 1)]

    def _get_grid(self, height, width, device):
        """Return cached ``(tenFlow_div, backwarp_tenGrid)`` for this size/device.

        ``tenFlow_div`` is ``[(W - 1) / 2, (H - 1) / 2]``, which converts pixel
        offsets to normalized units. ``backwarp_tenGrid`` has shape
        (1, 2, H, W): channel 0 holds x and channel 1 holds y, each spanning
        [-1, 1].
        """
        key = (height, width, device.type, device.index)
        cached = self._grid_cache.get(key)
        if cached is not None:
            return cached
        tenFlow_div = torch.tensor(
            [(width - 1.0) / 2.0, (height - 1.0) / 2.0],
            dtype=torch.float32,
            device=device,
        )
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, width, dtype=torch.float32, device=device)
            .view(1, 1, 1, width)
            .expand(1, 1, height, width)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, height, dtype=torch.float32, device=device)
            .view(1, 1, height, 1)
            .expand(1, 1, height, width)
        )
        backwarp_tenGrid = torch.cat([tenHorizontal, tenVertical], 1)
        self._grid_cache[key] = (tenFlow_div, backwarp_tenGrid)
        return tenFlow_div, backwarp_tenGrid

    def load_model(self, path, rank=0, device="cuda"):
        """Load weights from ``path`` and move both networks to ``device``.

        Accepts a bare state dict or a dict wrapping one under
        ``"state_dict"`` or ``"flownet"``. Every ``"module."`` substring is
        removed from keys (presumably left by DataParallel wrapping). Keys
        starting with ``"encode."`` are loaded strictly into ``Head``; the
        remaining keys are loaded non-strictly into ``IFNet``.

        ``rank`` is accepted for interface compatibility but is not used.
        """
        self.device = device
        # NOTE(review): torch.load unpickles arbitrary objects on older
        # PyTorch versions; only load trusted checkpoints.
        state_dict = torch.load(path, map_location=device)
        if isinstance(state_dict, dict):
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
            elif "flownet" in state_dict:
                state_dict = state_dict["flownet"]
        state_dict = {
            k.replace("module.", ""): v for k, v in state_dict.items()
        }
        head_state = {
            k.replace("encode.", ""): v
            for k, v in state_dict.items()
            if k.startswith("encode.")
        }
        if head_state:
            self.encode.load_state_dict(head_state, strict=True)
        flow_state = {
            k: v for k, v in state_dict.items() if not k.startswith("encode.")
        }
        self.flownet.load_state_dict(flow_state, strict=False)
        self.to(device)

    def inference(self, img0, img1, timestep=0.5, scale=1.0):
        """Synthesize the frame at ``timestep`` between ``img0`` and ``img1``.

        Args:
            img0, img1: image batches indexed as (N, C, H, W).
            timestep: scalar interpolation position, broadcast to one
                (1, H, W) plane per sample.
            scale: resolution multiplier; the flow pyramid uses factors
                ``[16, 8, 4, 2, 1] / scale``.
        """
        # Always reset the pyramid. Previously this was skipped for
        # scale == 1.0, which left an earlier call's pyramid in place.
        self.flownet.scaleList = self._scale_list(scale)
        f0 = self.encode(img0)
        f1 = self.encode(img1)
        batch = img0.shape[0]
        height = img0.shape[2]
        width = img0.shape[3]
        tenFlow_div, backwarp_tenGrid = self._get_grid(height, width, img0.device)
        # One timestep plane per sample: IFNet concatenates it along dim 1,
        # so its batch size must match the images'.
        timestep_tensor = torch.full(
            (batch, 1, height, width),
            float(timestep),
            dtype=img0.dtype,
            device=img0.device,
        )
        return self.flownet(
            img0, img1, timestep_tensor, tenFlow_div, backwarp_tenGrid, f0, f1
        )
|
|