| """ |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_scripts/interpolate.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/models/ssldtm.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/util_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/twodee_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/pytorch_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/distance_transform_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/sketchers_v1.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/interpolator_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/gridnet_v1.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/flow_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/softsplat_v0.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/rfr_new.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/extractor.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/update.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/corr.py |
| https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/utils.py |
| """ |
|
|
| import copy |
| import cv2 |
| import torch.nn.functional as F |
| import torchvision.transforms.functional as F |
| import gc |
| from PIL import Image, ImageFile, ImageFont, ImageDraw |
| import inspect |
| from scipy import interpolate |
| import kornia |
| import math |
| from argparse import Namespace |
| import torch.nn as nn |
| import numpy as np |
| import os |
| from functools import partial |
| import pathlib |
| import PIL |
| import re |
| import requests |
| from scipy.spatial.transform import Rotation |
| import scipy |
| import shutil |
| import torchvision.transforms as T |
| import time |
| import torch |
| import torchvision as tv |
| import zlib |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tqdm.auto import tqdm as std_tqdm |
| from tqdm.auto import trange as std_trange |
| from vfi_models.ops import FunctionSoftsplat, batch_edt |
| from comfy.model_management import get_torch_device |
|
|
# Module-level runtime configuration shared by the rest of this file.
device = get_torch_device()  # torch device selected by ComfyUI model management
autocast = torch.autocast  # alias used below for mixed-precision regions
tqdm = partial(std_tqdm, dynamic_ncols=True)  # progress bar that tracks terminal width
trange = partial(std_trange, dynamic_ncols=True)
ImageFile.LOAD_TRUNCATED_IMAGES = True  # let PIL decode truncated image files
|
|
|
|
def pixel_ij(x, rounding=True):
    """Normalize `x` into an (i, j) pixel pair, rounded per `rounding`.

    A scalar is duplicated into both positions; an ndarray is first
    converted to a plain list.  Each element goes through pixel_rounder.
    """
    if isinstance(x, np.ndarray):
        x = x.tolist()
    pair = x if isinstance(x, (tuple, list)) else (x, x)
    return tuple(pixel_rounder(v, rounding) for v in pair)
|
|
|
|
def rescale_dry(x, factor):
    """Compute the (h, w) a size/image would have after scaling by `factor`,
    without actually resizing anything.

    A tuple/list is read as a shape (last two entries are h, w); anything
    else is assumed to be an image wrapped by `I` and its `.size` is used.
    """
    if isinstance(x, (tuple, list)):
        h, w = x[-2:]
    else:
        h, w = I(x).size
    return (h * factor, w * factor)
|
|
|
|
def pixel_rounder(n, mode):
    """Round a pixel coordinate according to `mode`.

    True or "round" -> banker's round; "ceil"/"floor" -> ceiling/floor;
    any other mode returns `n` unchanged.
    """
    if mode in (True, "round"):
        return round(n)
    if mode == "ceil":
        return math.ceil(n)
    if mode == "floor":
        return math.floor(n)
    return n
|
|
|
|
def diam(x):
    """Return the diagonal length of a shape, `I` image, or tensor-like.

    Tuples/lists are read as shapes (last two entries are h, w); `I`
    images use `.size`; anything else is assumed to expose `.shape`.
    """
    if isinstance(x, (tuple, list)):
        h, w = x[-2:]
    elif isinstance(x, I):
        h, w = x.size
    else:
        h, w = x.shape[-2:]
    return np.sqrt(h**2 + w**2)
|
|
|
|
def pixel_logit(x, pixel_margin=1):
    """Map [0, 1] pixel values to logits, clamping away from 0/1.

    Values are squeezed into [margin/255, (255-margin)/255] first so the
    logit never hits +/- infinity.
    """
    squeezed = (x * (255 - 2 * pixel_margin) + pixel_margin) / 255
    return torch.log(squeezed / (1 - squeezed))
|
|
|
|
class InputPadder:
    """Pads images such that dimensions are divisible by 8.

    Width is padded symmetrically; height is padded on the bottom only.
    """

    def __init__(self, dims):
        self.ht, self.wd = dims[-2:]
        extra_h = (8 - self.ht % 8) % 8
        extra_w = (8 - self.wd % 8) % 8
        # [left, right, top, bottom] in F.pad's last-dim-first order
        self._pad = [extra_w // 2, extra_w - extra_w // 2, 0, extra_h]

    def pad(self, *inputs):
        """Replicate-pad every tensor in `inputs`; returns a list."""
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
|
|
|
def forward_interpolate(flow):
    """Forward-propagate a (2, H, W) flow field onto a regular grid.

    Each pixel's flow vector is splatted to its target location and the
    scattered samples are re-interpolated (cubic, zero fill) back onto the
    original pixel grid.  Returns a float CPU tensor of the same shape.
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # flattened target positions of every source pixel
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only samples landing strictly inside the frame
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata((x1, y1), dx, (x0, y0), method="cubic", fill_value=0)
    flow_y = interpolate.griddata((x1, y1), dy, (x0, y0), method="cubic", fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
|
|
|
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates.

    `coords` carries (x, y) pixel positions in its last dimension; they are
    mapped into grid_sample's [-1, 1] range (align_corners=True convention).
    With mask=True, also returns a float mask of in-bounds samples.
    """
    height, width = img.shape[-2:]
    xpix, ypix = coords.split([1, 1], dim=-1)
    # pixel -> normalized [-1, 1] coordinates
    xnorm = 2 * xpix / (width - 1) - 1
    ynorm = 2 * ypix / (height - 1) - 1

    sample_grid = torch.cat([xnorm, ynorm], dim=-1)
    sampled = F.grid_sample(img, sample_grid, align_corners=True)

    if mask:
        inside = (xnorm > -1) & (ynorm > -1) & (xnorm < 1) & (ynorm < 1)
        return sampled, inside.float()

    return sampled
|
|
|
|
def coords_grid(batch, ht, wd):
    """Build a (batch, 2, ht, wd) grid of (x, y) pixel coordinates.

    Channel 0 holds x (column) indices, channel 1 holds y (row) indices;
    the grid is repeated across the batch dimension.
    """
    # indexing="ij" reproduces the historical torch.meshgrid default and
    # silences the deprecation warning on torch >= 1.10.
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()  # reversed -> (x, y) order
    return coords[None].repeat(batch, 1, 1, 1)
|
|
|
|
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially, scaling the vectors by 8 too."""
    target = (8 * flow.shape[2], 8 * flow.shape[3])
    upsampled = F.interpolate(flow, size=target, mode=mode, align_corners=True)
    return 8 * upsampled
|
|
|
|
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid (RAFT).

    The full (h1*w1) x (h2*w2) correlation between two feature maps is
    computed once in __init__; __call__ then bilinearly samples local
    (2r+1)^2 windows around the query coordinates at every pyramid level
    and concatenates them channel-wise.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        # level 0 is full resolution; each further level halves h2 x w2
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample correlation windows around `coords` (batch, 2, h1, w1)."""
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            # indexing="ij" keeps the historical torch.meshgrid behavior and
            # avoids the deprecation warning on torch >= 1.10
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1).to(
                coords.device
            )

            # query centers shrink by 2^i to match the pooled level
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense dot-product correlation, scaled by 1/sqrt(feature dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
|
|
|
class FlowHead(nn.Module):
    """Two-layer conv head mapping a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
|
|
|
class ConvGRU(nn.Module):
    """Convolutional GRU cell with 3x3 gates."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """One GRU step: returns the updated hidden state."""
        combined = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(combined))
        reset = torch.sigmoid(self.convr(combined))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
|
|
|
|
class SepConvGRU(nn.Module):
    """GRU cell with separable gates: a horizontal (1x5) pass followed by a
    vertical (5x1) pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )
        self.convr1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )
        self.convq1 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
        )

        self.convz2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )
        self.convr2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )
        self.convq2 = nn.Conv2d(
            hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
        )

    def forward(self, h, x):
        """Apply the horizontal GRU step, then the vertical one."""
        for convz, convr, convq in (
            (self.convz1, self.convr1, self.convq1),
            (self.convz2, self.convr2, self.convq2),
        ):
            hx = torch.cat([h, x], dim=1)
            z = torch.sigmoid(convz(hx))
            r = torch.sigmoid(convr(hx))
            q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
            h = (1 - z) * h + z * q

        return h
|
|
|
|
class SmallMotionEncoder(nn.Module):
    """Encode flow + correlation into 82-channel motion features; the last
    two channels pass the raw flow through unchanged."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # channels of the sampled correlation volume
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([merged, flow], dim=1)
|
|
|
|
class BasicMotionEncoder(nn.Module):
    """Encode flow + correlation into 128-channel motion features; the last
    two channels pass the raw flow through unchanged."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # channels of the sampled correlation volume
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        merged = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([merged, flow], dim=1)
|
|
|
|
class SmallUpdateBlock(nn.Module):
    """RAFT-small update step: motion encoder + ConvGRU + flow head.

    Returns (net, None, delta_flow); the None slot is where the big model
    returns an upsampling mask.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, features], dim=1)
        net = self.gru(net, gru_input)
        return net, None, self.flow_head(net)
|
|
|
|
class BasicUpdateBlock(nn.Module):
    """RAFT update step: motion encoder + separable ConvGRU + flow head,
    plus a mask head used for convex upsampling of the predicted flow."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # predicts 9 weights per pixel for each cell of the 8x8 upsampling
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # 0.25 damping of the mask logits, kept from the original code
        up_mask = 0.25 * self.mask(net)
        return net, up_mask, delta_flow
|
|
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 convs with a residual connection and configurable norm.

    When stride != 1 the identity path is replaced by a strided 1x1
    projection followed by its own norm layer.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # pick a norm factory; unknown norm_fn leaves the attributes unset,
        # matching the original behavior
        make_norm = {
            "group": lambda: nn.GroupNorm(num_groups=num_groups, num_channels=planes),
            "batch": lambda: nn.BatchNorm2d(planes),
            "instance": lambda: nn.InstanceNorm2d(planes),
            "none": nn.Sequential,
        }.get(norm_fn)
        if make_norm is not None:
            self.norm1 = make_norm()
            self.norm2 = make_norm()
            if stride != 1:
                self.norm3 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        shortcut = x
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu(shortcut + out)
|
|
|
|
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with configurable norm.

    When stride != 1 the identity path is replaced by a strided 1x1
    projection followed by its own norm layer.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        # factory keyed by channel count; unknown norm_fn leaves the
        # attributes unset, matching the original behavior
        make_norm = {
            "group": lambda ch: nn.GroupNorm(num_groups=num_groups, num_channels=ch),
            "batch": nn.BatchNorm2d,
            "instance": nn.InstanceNorm2d,
            "none": lambda ch: nn.Sequential(),
        }.get(norm_fn)
        if make_norm is not None:
            self.norm1 = make_norm(planes // 4)
            self.norm2 = make_norm(planes // 4)
            self.norm3 = make_norm(planes)
            if stride != 1:
                self.norm4 = make_norm(planes)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )

    def forward(self, x):
        shortcut = x
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.relu(self.norm3(self.conv3(out)))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu(shortcut + out)
|
|
|
|
class BasicEncoder(nn.Module):
    """RAFT feature/context encoder: 1/8-resolution features from RGB input.

    Accepts a single tensor or a tuple/list pair; a pair is concatenated
    along the batch dimension, run through the network once, and split
    back into two outputs.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # stem normalization over 64 channels; unknown norm_fn leaves
        # norm1 unset, matching the original behavior
        norm_factories = {
            "group": lambda: nn.GroupNorm(num_groups=8, num_channels=64),
            "batch": lambda: nn.BatchNorm2d(64),
            "instance": lambda: nn.InstanceNorm2d(64),
            "none": nn.Sequential,
        }
        if self.norm_fn in norm_factories:
            self.norm1 = norm_factories[self.norm_fn]()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # project to the requested output dimensionality
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # He init for convs; affine norms start at identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # two residual blocks; only the first may downsample
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # a tuple/list input is batched through together and re-split
        is_pair = isinstance(x, (tuple, list))
        if is_pair:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        out = self.relu1(self.norm1(self.conv1(x)))
        out = self.layer3(self.layer2(self.layer1(out)))
        out = self.conv2(out)

        if self.training and self.dropout is not None:
            out = self.dropout(out)

        if is_pair:
            out = torch.split(out, [batch_dim, batch_dim], dim=0)

        return out
|
|
|
|
class BasicEncoder1(nn.Module):
    """Variant of BasicEncoder taking 2-channel input (e.g. a flow field).

    Identical architecture except for the stem's input channel count.
    Accepts a single tensor or a tuple/list pair; a pair is concatenated
    along the batch dimension, run through once, and split back.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder1, self).__init__()
        self.norm_fn = norm_fn

        # stem normalization over 64 channels; unknown norm_fn leaves
        # norm1 unset, matching the original behavior
        norm_factories = {
            "group": lambda: nn.GroupNorm(num_groups=8, num_channels=64),
            "batch": lambda: nn.BatchNorm2d(64),
            "instance": lambda: nn.InstanceNorm2d(64),
            "none": nn.Sequential,
        }
        if self.norm_fn in norm_factories:
            self.norm1 = norm_factories[self.norm_fn]()

        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # project to the requested output dimensionality
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # He init for convs; affine norms start at identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # two residual blocks; only the first may downsample
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # a tuple/list input is batched through together and re-split
        is_pair = isinstance(x, (tuple, list))
        if is_pair:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        out = self.relu1(self.norm1(self.conv1(x)))
        out = self.layer3(self.layer2(self.layer1(out)))
        out = self.conv2(out)

        if self.training and self.dropout is not None:
            out = self.dropout(out)

        if is_pair:
            out = torch.split(out, [batch_dim, batch_dim], dim=0)

        return out
|
|
|
|
class SmallEncoder(nn.Module):
    """RAFT-small encoder: 1/8-resolution features via bottleneck blocks.

    Accepts a single tensor or a tuple/list pair; a pair is concatenated
    along the batch dimension, run through once, and split back.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # stem normalization over 32 channels; unknown norm_fn leaves
        # norm1 unset, matching the original behavior
        norm_factories = {
            "group": lambda: nn.GroupNorm(num_groups=8, num_channels=32),
            "batch": lambda: nn.BatchNorm2d(32),
            "instance": lambda: nn.InstanceNorm2d(32),
            "none": nn.Sequential,
        }
        if self.norm_fn in norm_factories:
            self.norm1 = norm_factories[self.norm_fn]()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # He init for convs; affine norms start at identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # two bottleneck blocks; only the first may downsample
        blocks = (
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # a tuple/list input is batched through together and re-split
        is_pair = isinstance(x, (tuple, list))
        if is_pair:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        out = self.relu1(self.norm1(self.conv1(x)))
        out = self.conv2(self.layer3(self.layer2(self.layer1(out))))

        if self.training and self.dropout is not None:
            out = self.dropout(out)

        if is_pair:
            out = torch.split(out, [batch_dim, batch_dim], dim=0)

        return out
|
|
|
|
| |
| |
| |
|
|
|
|
def backwarp(img, flow):
    """Backward-warp `img` by `flow` (channels 0/1 = x/y displacement).

    NOTE: this name is rebound to `flow_backwarp` later in the module, so
    this definition is shadowed at import time; fixed here regardless.
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))

    # build the base sampling grid on the input's device instead of
    # hard-coding .cuda(), so CPU/MPS tensors also work
    gridX = torch.tensor(gridX, requires_grad=False, device=img.device)
    gridY = torch.tensor(gridY, requires_grad=False, device=img.device)
    x = gridX.unsqueeze(0).expand_as(u).float() + u
    y = gridY.unsqueeze(0).expand_as(v).float() + v
    # normalize pixel coordinates to [-1, 1] (align_corners=True convention)
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)
    grid = torch.stack((x, y), dim=3)
    imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True)

    return imgOut
|
|
|
|
class ErrorAttention(nn.Module):
    """A three-layer network for predicting mask.

    conv2's 32 output channels are concatenated with the raw input before
    conv3, so conv3's 38 input channels imply a 6-channel input.
    """

    def __init__(self, input, output):
        super(ErrorAttention, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(38, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        feat = self.prelu1(self.conv1(x1))
        feat = self.prelu2(torch.cat([self.conv2(feat), x1], dim=1))
        return self.conv3(feat)
|
|
|
|
class RFR(nn.Module):
    """RAFT-based recurrent flow refinement network (RFR).

    A single BasicEncoder (`fnet`) serves both as the feature network (on
    the image pair) and as the context network (on image1, split into
    hidden/context halves).  When an initial flow is supplied, an
    ErrorAttention module can down-weight it by its backward-warp error.
    """

    def __init__(self, args):
        super(RFR, self).__init__()
        # 6 input channels: 3 (image) + 1 (warp error) + 2 (initial flow)
        self.attention2 = ErrorAttention(6, 1)
        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0
        self.args = args

        # 256 output channels: split later into 128 hidden + 128 context
        self.fnet = BasicEncoder(output_dim=256, norm_fn="none", dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        # put every BatchNorm layer into eval mode (freeze running stats)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
        coords1 = coords_grid(N, H // 8, W // 8).to(img.device)

        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # 9 softmax weights per output pixel over a 3x3 neighborhood
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(
        self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
    ):
        # work at the largest multiple-of-8 resolution <= the input size
        H, W = image1.size()[2:4]
        H8 = H // 8 * 8
        W8 = W // 8 * 8

        if flow_init is not None:
            # resize the external initial flow to 1/8 resolution and rescale
            # its vector magnitudes to the new spatial extent
            flow_init_resize = F.interpolate(
                flow_init, size=(H8 // 8, W8 // 8), mode="nearest"
            )

            flow_init_resize[:, :1] = (
                flow_init_resize[:, :1].clone() * (W8 // 8 * 1.0) / flow_init.size()[3]
            )
            flow_init_resize[:, 1:] = (
                flow_init_resize[:, 1:].clone() * (H8 // 8 * 1.0) / flow_init.size()[2]
            )

            if not hasattr(self.args, "not_use_rfr_mask") or (
                hasattr(self.args, "not_use_rfr_mask")
                and (not self.args.not_use_rfr_mask)
            ):
                im18 = F.interpolate(image1, size=(H8 // 8, W8 // 8), mode="bilinear")
                im28 = F.interpolate(image2, size=(H8 // 8, W8 // 8), mode="bilinear")

                # NOTE(review): `backwarp` resolves to flow_backwarp here —
                # the module rebinds the name after this class is defined.
                warp21 = backwarp(im28, flow_init_resize)
                error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
                # attenuate the initial flow where the warp error is large
                # NOTE(review): f12init appears to be computed but never
                # used afterwards — confirm against upstream.
                f12init = (
                    torch.exp(
                        -self.attention2(
                            torch.cat([im18, error21, flow_init_resize], dim=1)
                        )
                        ** 2
                    )
                    * flow_init_resize
                )
        else:
            flow_init_resize = None
            # NOTE(review): hard-coded .cuda() — fails on CPU-only setups
            flow_init = torch.zeros(
                image1.size()[0], 2, image1.size()[2] // 8, image1.size()[3] // 8
            ).cuda()
            error21 = torch.zeros(
                image1.size()[0], 1, image1.size()[2] // 8, image1.size()[3] // 8
            ).cuda()

            f12_init = flow_init

        image1 = F.interpolate(image1, size=(H8, W8), mode="bilinear")
        image2 = F.interpolate(image2, size=(H8, W8), mode="bilinear")

        f12s, f12, f12_init = self.forward_pred(
            image1, image2, iters, flow_init_resize, upsample, test_mode
        )

        if hasattr(self.args, "requires_sq_flow") and self.args.requires_sq_flow:
            # return the full iteration sequence, rescaled to input size
            for ii in range(len(f12s)):
                f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode="bilinear")
                f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0 * W8) * W
                f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0 * H8) * H
            if self.training:
                return f12s
            else:
                return [f12s[-1]], f12_init
        else:
            # rescale vector magnitudes, then resize back to (H, W)
            f12[:, :1] = f12[:, :1].clone() / (1.0 * W8) * W
            f12[:, 1:] = f12[:, 1:].clone() / (1.0 * H8) * H

            f12 = F.interpolate(f12, size=(H, W), mode="bilinear")

            return (
                f12,
                f12_init,
                error21,
            )

    def forward_pred(
        self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
    ):
        """Estimate optical flow between pair of frames"""

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # feature network on both frames (shared weights, batched together)
        with autocast(device.type, enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # context network on frame 1 only; split into GRU state + context
        with autocast(device.type, enabled=self.args.mixed_precision):
            cnet = self.fnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            if itr == 0:
                # NOTE(review): flow_init was already added above, so the
                # first iteration appears to apply it twice — kept as-is;
                # confirm against the upstream implementation.
                if flow_init is not None:
                    coords1 = coords1 + flow_init
            corr = corr_fn(coords1)

            flow = coords1 - coords0
            with autocast(device.type, enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            coords1 = coords1 + delta_flow

            # upsample 1/8-resolution flow to full resolution
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        return flow_predictions, flow_up, flow_init
|
|
| |
|
|
|
|
| |
| |
| |
def flow_backwarp(
    img, flow, resample="bilinear", padding_mode="border", align_corners=False
):
    """Backward-warp `img` by `flow` given in (row, col) pixel units.

    Inputs are promoted to 4D if needed; the flow is converted to
    normalized [-1, 1] offsets, added to an identity grid, and fed to
    grid_sample (with the channel order flipped to grid_sample's (x, y)).
    """
    if len(img.shape) != 4:
        img = img[None,]
    if len(flow.shape) != 4:
        flow = flow[None,]

    # pixel displacements -> normalized offsets (2/extent per pixel)
    extent = torch.tensor(
        [
            flow.shape[-2],
            flow.shape[-1],
        ],
        device=flow.device,
        dtype=torch.float,
    )
    offsets = 2 * flow / extent[None, :, None, None]

    # identity sampling grid in normalized coordinates
    base_grid = torch.stack(
        torch.meshgrid(
            torch.linspace(-1, 1, flow.shape[-2]),
            torch.linspace(-1, 1, flow.shape[-1]),
        )
    )[None,].to(flow.device)
    q = offsets + base_grid

    if img.dtype != q.dtype:
        img = img.type(q.dtype)

    return nn.functional.grid_sample(
        img,
        q.flip(dims=(1,)).permute(0, 2, 3, 1),
        mode=resample,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )
|
|
|
|
# Rebind the module-level name `backwarp` to flow_backwarp (this shadows the
# earlier grid_sample-based backwarp above); `flow_warp` is an extra alias.
backwarp = flow_warp = flow_backwarp
|
|
|
|
| |
| |
| |
| |
| |
def flow_forewarp(
    img, flow, mode="average", metric=None, mask=False, retain_device=True
):
    """Forward-warp (splat) `img` along `flow` via FunctionSoftsplat.

    mode aliases "lin"/"linear" and "sm"/"softmax" need a `metric` tensor
    and fall back to "avg" without one.  The softsplat kernel runs on CUDA,
    so CPU inputs are moved to the GPU (and the result moved back when
    retain_device is True).  With mask=True an all-ones channel is appended
    so the output carries a splat-coverage mask.
    """
    # resolve mode aliases; metric-weighted modes need a metric tensor
    if mode in ["lin", "linear"]:
        mode = "linear" if metric is not None else "avg"
    elif mode in ["sm", "softmax"]:
        mode = "soft" if metric is not None else "avg"

    # promote everything to 4D
    if len(img.shape) != 4:
        img = img[None,]
    if len(flow.shape) != 4:
        flow = flow[None,]
    if metric is not None and len(metric.shape) != 4:
        metric = metric[None,]

    # flip flow channel order before splatting (inverse of the (row, col)
    # convention used by the backwarp helpers in this file)
    flow = flow.flip(dims=(1,))

    # the CUDA kernel expects float32 inputs
    if img.dtype != torch.float32:
        img = img.type(torch.float32)
    if flow.dtype != torch.float32:
        flow = flow.type(torch.float32)
    if metric is not None and metric.dtype != torch.float32:
        metric = metric.type(torch.float32)

    assert img.device == flow.device
    if metric is not None:
        assert img.device == metric.device
    was_cpu = img.device.type == "cpu"
    if was_cpu:
        img = img.to("cuda")
        flow = flow.to("cuda")
        if metric is not None:
            metric = metric.to("cuda")

    if mask:
        bs, ch, h, w = img.shape
        ones = torch.ones(bs, 1, h, w, dtype=img.dtype, device=img.device)
        img = torch.cat([img, ones], dim=1)

    ans = FunctionSoftsplat(img, flow, metric, mode)
    if was_cpu and retain_device:
        ans = ans.cpu()
    return ans
|
|
|
|
# Short alias for the forward-warp (splatting) helper.
forewarp = flow_forewarp
|
|
|
|
| |
def flow_resize(flow, size, mode="nearest", align_corners=False):
    """Resize a flow field to `size`, rescaling vector magnitudes to match.

    `size` is normalized through pixel_ij; 3D input gets a batch dim.
    Returns the input unchanged (as float, 4D) when already at `size`.
    """
    size = pixel_ij(size, rounding=True)
    if flow.dtype != torch.float:
        flow = flow.float()
    if len(flow.shape) == 3:
        flow = flow[None,]
    if flow.shape[-2:] == size:
        return flow

    resized = nn.functional.interpolate(
        flow,
        size=size,
        mode=mode,
        align_corners=align_corners if mode != "nearest" else None,
    )
    # per-channel ratio of new extent to old extent
    ratios = torch.tensor(
        [new / old for old, new in zip(flow.shape[-2:], size)],
        device=flow.device,
    )
    return resized * ratios[None, :, None, None]
|
|
|
|
| |
|
|
| |
# Classical OpenCV optical-flow backends used by cv2flow() below.  Each maps
# two cv2 images -> a (1, 2, H, W) numpy flow array: np.moveaxis brings the
# two flow channels to the front and [None,] adds a batch dimension.
# NOTE(review): these need the opencv-contrib build (cv2.optflow); the
# stateful estimators (_dtvl1_, _pca_, _deepflow_) are constructed once at
# import time and reused across calls.
_lucaskanade = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowSparseToDense(
        a,
        b,
    ),
    2,
    0,
)[
    None,
]
_farneback = lambda a, b: np.moveaxis(
    cv2.calcOpticalFlowFarneback(
        a,
        b,
        None,
        0.6,
        3,
        25,
        7,
        5,
        1.2,
        cv2.OPTFLOW_FARNEBACK_GAUSSIAN,
    ),
    2,
    0,
)[
    None,
]
_dtvl1_ = cv2.optflow.createOptFlow_DualTVL1()
_dtvl1 = lambda a, b: np.moveaxis(
    _dtvl1_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
_simple = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowSF(
        a,
        b,
        3,
        5,
        5,
    ),
    2,
    0,
)[
    None,
]
_pca_ = cv2.optflow.createOptFlow_PCAFlow()
_pca = lambda a, b: np.moveaxis(
    _pca_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
_drlof = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowDenseRLOF(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
_deepflow_ = cv2.optflow.createOptFlow_DeepFlow()
_deepflow = lambda a, b: np.moveaxis(
    _deepflow_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
|
|
|
|
def cv2flow(a, b, method="lucaskanade", back=False):
    """Compute optical flow between two images with an OpenCV backend.

    Parameters
    ----------
    a, b : images accepted by I(); converted to the colorspace the chosen
        backend expects ("L" grayscale or "RGB").
    method : one of "lucaskanade", "farneback", "dtvl1", "simple", "pca",
        "drlof", "deepflow".
    back : if True, also compute the reverse flow and stack it along the
        batch axis (index 0: forward, index 1: backward).

    Returns
    -------
    torch.Tensor of shape (1 or 2, 2, H, W); the two flow channels are
    flipped (the channel axis is reversed before returning).

    Raises
    ------
    ValueError
        If `method` is unknown (previously a bare `assert 0`, which is
        stripped under `python -O`).
    """
    # backend function and the colorspace it requires
    backends = {
        "lucaskanade": (_lucaskanade, "L"),
        "farneback": (_farneback, "L"),
        "dtvl1": (_dtvl1, "L"),
        "simple": (_simple, "RGB"),
        "pca": (_pca, "L"),
        "drlof": (_drlof, "RGB"),
        "deepflow": (_deepflow, "L"),
    }
    try:
        f, space = backends[method]
    except KeyError:
        raise ValueError(f"unknown optical flow method: {method!r}") from None
    a = a.convert(space).cv2()
    b = b.convert(space).cv2()
    # note the reversed argument order: the "forward" result is f(b, a)
    ans = f(b, a)
    if back:
        ans = np.concatenate(
            [
                ans,
                f(a, b),
            ]
        )
    return torch.tensor(ans).flip(dims=(1,))
|
|
|
|
| |
|
|
|
|
def flownet2(img_a, img_b, mode="shm", back=False):
    """Query a local flownet2 server for the flow between two images.

    Images are handed to the server through /dev/shm ("shm" is the only
    implemented transport); the server replies with the filename of a
    pickled flow array.

    Returns a dict with the raw "response", and on HTTP 200 additionally
    "time" and "output" ({"flow": tensor}).

    Raises
    ------
    NotImplementedError for mode="net"; ValueError for any other mode
    (previously these fell through and crashed with a NameError).
    """
    url = "http://localhost:8109/get-flow"
    if mode == "shm":
        # unique scratch directory in shared memory, keyed by wall time
        t = time.time()
        fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png"))
        fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png"))
    elif mode == "net":
        raise NotImplementedError("network transport is not implemented")
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    try:
        resp = requests.get(
            url,
            params={
                "img_a": fn_a,
                "img_b": fn_b,
                "mode": mode,
                "back": back,
            },
        )
        ans = {"response": resp}
        if resp.status_code == 200:
            j = resp.json()
            ans["time"] = j["time"]
            ans["output"] = {
                "flow": torch.tensor(load(j["fn_flow"])),
            }
    finally:
        # always remove the shared-memory scratch dir, even if the request failed
        shutil.rmtree(f"/dev/shm/_flownet2/{t}")
    return ans
|
|
|
|
| |
|
|
|
|
class Gridnet(nn.Module):
    """Stack of GridnetEncoder/GridnetDecoder pairs over a 3-scale pyramid.

    Input/output is a list of three feature maps (one per scale); whole-sample
    dropout is applied between stages (but not before the first encoder).
    """

    def __init__(self, channels_0, channels_1, channels_2, total_dropout_p, depth):
        super().__init__()
        self.channels_0 = channels_0
        self.channels_1 = channels_1
        self.channels_2 = channels_2
        self.total_dropout_p = total_dropout_p
        self.depth = depth
        chans = (channels_0, channels_1, channels_2)
        self.encoders = nn.ModuleList(GridnetEncoder(*chans) for _ in range(depth))
        self.decoders = nn.ModuleList(GridnetDecoder(*chans) for _ in range(depth))
        self.total_dropout = GridnetTotalDropout(total_dropout_p)

    def forward(self, x):
        feats = x
        for idx, encoder in enumerate(self.encoders):
            if idx != 0:
                feats = [self.total_dropout(f) for f in feats]
            feats = encoder(feats)
        for decoder in self.decoders:
            feats = [self.total_dropout(f) for f in feats]
            feats = decoder(feats)
        return feats
|
|
|
|
class GridnetEncoder(nn.Module):
    """One encoder column: refine each scale, then mix information downward."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = channels_0
        self.channels_1 = channels_1
        self.channels_2 = channels_2
        self.resnet_0 = GridnetResnet(channels_0)
        self.resnet_1 = GridnetResnet(channels_1)
        self.resnet_2 = GridnetResnet(channels_2)
        self.downsample_01 = GridnetDownsample(channels_0, channels_1)
        self.downsample_12 = GridnetDownsample(channels_1, channels_2)

    def forward(self, x):
        # finest scale first; each coarser scale adds a downsampled copy
        # of the scale above it
        top = self.resnet_0(x[0])
        mid = self.resnet_1(x[1]) + self.downsample_01(top)
        bot = self.resnet_2(x[2]) + self.downsample_12(mid)
        return [top, mid, bot]
|
|
|
|
class GridnetDecoder(nn.Module):
    """One decoder column: refine each scale, then mix information upward."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = channels_0
        self.channels_1 = channels_1
        self.channels_2 = channels_2
        self.resnet_0 = GridnetResnet(channels_0)
        self.resnet_1 = GridnetResnet(channels_1)
        self.resnet_2 = GridnetResnet(channels_2)
        self.upsample_10 = GridnetUpsample(channels_1, channels_0)
        self.upsample_21 = GridnetUpsample(channels_2, channels_1)

    def forward(self, x):
        # coarsest scale first; each finer scale adds an upsampled copy
        # of the scale below it
        bot = self.resnet_2(x[2])
        mid = self.resnet_1(x[1]) + self.upsample_21(bot)
        top = self.resnet_0(x[0]) + self.upsample_10(mid)
        return [top, mid, bot]
|
|
|
|
class GridnetConverter(nn.Module):
    """Per-scale 1x1 conv stacks mapping channel counts cin[i] -> cout[i]."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.nets = nn.ModuleList(
            nn.Sequential(
                nn.PReLU(cin),
                nn.Conv2d(cin, cout, kernel_size=1, padding=0),
                nn.BatchNorm2d(cout),
            )
            for cin, cout in zip(channels_in, channels_out)
        )

    def forward(self, x):
        # apply the i-th converter to the i-th scale
        return [net(feat) for net, feat in zip(self.nets, x)]
|
|
|
|
class GridnetResnet(nn.Module):
    """Two-conv residual block: y = x + f(x); channel count is preserved."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        layers = []
        for _ in range(2):
            layers += [
                nn.PReLU(channels),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.net(x)
|
|
|
|
class GridnetDownsample(nn.Module):
    """Halves spatial resolution (stride-2 conv) then remaps channel count."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.net = nn.Sequential(
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_in, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels_in),
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
class GridnetUpsample(nn.Module):
    """Doubles spatial resolution (nearest upsample) then remaps channel count."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
            nn.PReLU(channels_out),
            nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
class GridnetTotalDropout(nn.Module):
    """Drops whole samples (all channels and pixels at once) with probability p.

    Surviving samples are rescaled by 1/(1-p) so the expected value is
    unchanged. Active only in training mode unless `force_drop` overrides.
    """

    def __init__(self, p):
        super().__init__()
        assert 0 <= p < 1
        self.p = p
        self.weight = 1 / (1 - p)

    def get_drop(self, x):
        """Per-sample keep mask, broadcastable over (B, C, H, W), pre-scaled."""
        keep = (torch.rand(len(x))[:, None, None, None] >= self.p).float()
        return keep.to(x.device) * self.weight

    def forward(self, x, force_drop=None):
        # force_drop=True/False overrides; None defers to self.training
        if force_drop is None:
            force_drop = self.training
        return x * self.get_drop(x) if force_drop else x
|
|
|
|
class Interpolator(nn.Module):
    """Resizes (B, C, H, W) or (B, K, C, H, W) tensors to a fixed target size.

    With is_flow=True the resized values are additionally multiplied by the
    per-axis ratio new/old so displacement vectors stay valid at the new
    resolution (matching flow_resize above).
    """

    def __init__(self, size, mode="bilinear"):
        super().__init__()
        self.size = size  # (height, width) target
        self.mode = mode

    def forward(self, x, is_flow=False):
        # BUGFIX: compare the full (h, w) pair against the target size.
        # The old check `x.shape[-2] == self.size` compared an int to a
        # tuple, was always False, and so the no-op shortcut never fired.
        if tuple(x.shape[-2:]) == tuple(self.size):
            return x
        if len(x.shape) == 4:
            bs, ch, h, w = x.shape
            ans = nn.functional.interpolate(
                x,
                size=self.size,
                mode=self.mode,
                # align_corners is only legal for non-nearest modes
                align_corners=None if self.mode == "nearest" else False,
            )
            if is_flow:
                scale = torch.tensor(
                    [b / a for a, b in zip((h, w), self.size)],
                    device=ans.device,
                )
                ans = ans * scale[None, :, None, None]
            return ans
        elif len(x.shape) == 5:
            # fold the extra axis into the batch, resize, then unfold
            bs, k, ch, h, w = x.shape
            return self.forward(
                x.view(bs * k, ch, h, w),
                is_flow=is_flow,
            ).view(bs, k, ch, *self.size)
        else:
            raise ValueError(f"expected a 4D or 5D tensor, got {len(x.shape)}D")
|
|
|
|
| |
|
|
|
|
def canny(img, a=100, b=200):
    """Canny edge map of `img` with fixed low (`a`) and high (`b`) thresholds."""
    grey = I(img).convert("L")
    edges = cv2.Canny(grey.cv2(), a, b)
    return I(edges)
|
|
|
|
| |
def canny_pis(img, sigma=0.33):
    """Canny with thresholds derived from the median intensity.

    The low/high thresholds are a +/- `sigma` band around the image median,
    clamped to valid 8-bit range.
    """
    img = I(img).convert("L").uint8(ch_last=False)
    med = np.median(img)
    lo = int(max(0, (1.0 - sigma) * med))
    hi = int(min(255, (1.0 + sigma) * med))
    return I(cv2.Canny(img[0], lo, hi))
|
|
|
|
| |
def canny_otsu(img):
    """Canny with the high threshold picked by Otsu's method (low = high / 2)."""
    img = I(img).convert("L").uint8(ch_last=False)
    hi, _ = cv2.threshold(img[0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return I(cv2.Canny(img[0], 0.5 * hi, hi))
|
|
|
|
def xdog(img, t=1.0, epsilon=0.04, phi=100, sigma=3, k=1.6):
    """Extended difference-of-Gaussians (XDoG) sketch of `img` as a numpy array."""
    img = I(img).convert("L").uint8(ch_last=False)
    grey = np.asarray(img, dtype=np.float32)
    narrow = scipy.ndimage.gaussian_filter(grey, sigma)
    wide = scipy.ndimage.gaussian_filter(grey, sigma * k)
    diff = (narrow - t * wide) / 255
    # soft threshold: 1 above epsilon, tanh ramp below it
    return 1 + np.tanh(phi * (diff - epsilon)) * (diff < epsilon)
|
|
|
|
def dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
    """Difference-of-Gaussians edge response of a single image as a numpy array."""
    img = I(img).convert("L").tensor()[None]
    # odd kernel sizes proportional to sigma, floored at 3
    kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
    kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
    blur_narrow = kornia.filters.gaussian_blur2d(
        img,
        (kern0, kern0),
        (sigma, sigma),
        border_type="replicate",
    )
    blur_wide = kornia.filters.gaussian_blur2d(
        img,
        (kern1, kern1),
        (sigma * k, sigma * k),
        border_type="replicate",
    )
    ans = 0.5 + t * (blur_wide - blur_narrow) - epsilon
    if clip:
        ans = ans.clip(0, 1)
    return ans[0].numpy()
|
|
|
|
| |
| |
def batch_dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
    """Batched difference-of-Gaussians on (B, C, H, W) tensors (C in {1, 3, 4})."""
    bs, ch, h, w = img.shape
    if ch in [3, 4]:
        # drop a possible alpha channel, then convert to grayscale
        img = kornia.color.rgb_to_grayscale(img[:, :3])
    else:
        assert ch == 1
    # odd kernel sizes proportional to sigma, floored at 3
    kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
    kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
    blur_narrow = kornia.filters.gaussian_blur2d(
        img,
        (kern0, kern0),
        (sigma, sigma),
        border_type="replicate",
    )
    blur_wide = kornia.filters.gaussian_blur2d(
        img,
        (kern1, kern1),
        (sigma * k, sigma * k),
        border_type="replicate",
    )
    ans = 0.5 + t * (blur_wide - blur_narrow) - epsilon
    if clip:
        ans = ans.clip(0, 1)
    return ans
|
|
|
|
| |
|
|
| |
| |
| |
|
|
|
|
| |
| |
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    """Symmetric chamfer distance: mean of the two one-sided distances."""
    forward = batch_chamfer_distance_t(gt, pred, block=block)
    backward = batch_chamfer_distance_p(gt, pred, block=block)
    return (forward + backward) / 2
|
|
|
|
def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    """One-sided chamfer distance: gt pixels against pred's distance transform.

    Normalized by the image diagonal; a trailing singleton channel axis is
    squeezed away.
    """
    assert gt.device == pred.device and gt.shape == pred.shape
    h, w = gt.shape[-2], gt.shape[-1]
    dpred = batch_edt(pred, block=block)
    cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if len(cd.shape) == 2:
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd
|
|
|
|
def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    """One-sided chamfer distance: pred pixels against gt's distance transform.

    Normalized by the image diagonal; a trailing singleton channel axis is
    squeezed away.
    """
    assert gt.device == pred.device and gt.shape == pred.shape
    h, w = gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if len(cd.shape) == 2:
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd
|
|
|
|
| |
| |
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    """Symmetric Hausdorff distance via distance transforms, diagonal-normalized."""
    assert gt.device == pred.device and gt.shape == pred.shape
    h, w = gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    dpred = batch_edt(pred, block=block)
    # max over both directions of the farthest matched pixel
    one_sided = torch.stack(
        [
            (dgt * pred).amax(dim=(-2, -1)),
            (dpred * gt).amax(dim=(-2, -1)),
        ]
    )
    hd = one_sided.amax(dim=0).float() / np.sqrt(h**2 + w**2)
    if len(hd.shape) == 2:
        assert hd.shape[1] == 1
        hd = hd.squeeze(1)
    return hd
|
|
|
|
| |
|
|
|
|
def reset_parameters(model):
    """Re-initialize every direct child layer that supports reset_parameters.

    NOTE(review): only immediate children are visited; layers nested inside
    containers (e.g. Sequential-in-Sequential) are not reached — confirm
    this is intended before relying on it for deep models.
    """
    for child in model.children():
        reset = getattr(child, "reset_parameters", None)
        if reset is not None:
            reset()
    return model
|
|
|
|
def channel_squeeze(x, dim=1):
    """Merge dims `dim` and `dim+1` of `x` into one (inverse of channel_unsqueeze)."""
    lead = x.shape[:dim]
    tail = x.shape[dim + 2 :]
    return x.reshape(*lead, -1, *tail)
|
|
|
|
def channel_unsqueeze(x, shape, dim=1):
    """Expand dim `dim` of `x` into the given `shape` (inverse of channel_squeeze)."""
    lead = x.shape[:dim]
    tail = x.shape[dim + 1 :]
    return x.reshape(*lead, *shape, *tail)
|
|
|
|
def default_collate(items, device=None):
    """Batch `items` with torch's default collator, then move the result to `device`."""
    collated = torch.utils.data.dataloader.default_collate(items)
    return to(dict(collated), device)
|
|
|
|
def to(x, device):
    """Move `x` (tensor, ndarray, or dict of values) to `device`.

    device=None is a no-op and returns `x` unchanged. Dict values that are
    not tensors pass through untouched; ndarrays are converted to tensors
    first. Always returns a plain dict for mapping inputs.

    Raises
    ------
    TypeError
        If `x` is none of the supported kinds (previously a bare `assert`,
        which is stripped under `python -O`).
    """
    if device is None:
        return x
    if isinstance(x, dict):
        return {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in x.items()
        }
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, np.ndarray):
        return torch.tensor(x).to(device)
    raise TypeError(f"data not understood: {type(x)!r}")
|
|
|
|
| |
|
|
| from argparse import Namespace |
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
def read_filter(fn, cast=None, sort=True, sort_key=None):
    """Read `fn`, split on newlines, drop empty lines, optionally cast and sort."""
    lines = (line for line in read(fn).split("\n") if line != "")
    if cast is not None:
        lines = (cast(line) for line in lines)
    ans = list(lines)
    return sorted(ans, key=sort_key) if sort else ans
|
|
|
|
| |
|
|
|
|
def mkfile(fn, parents=True, exist_ok=True):
    """Ensure the parent directory of `fn` exists; return `fn` unchanged."""
    parent = "/".join(fn.split("/")[:-1])
    mkdir(parent, parents=parents, exist_ok=exist_ok)
    return fn
|
|
|
|
def mkdir(dn, parents=True, exist_ok=True):
    """Create directory `dn` (and parents); return it without a trailing slash.

    The root path "/" is returned as-is.
    """
    pathlib.Path(dn).mkdir(parents=parents, exist_ok=exist_ok)
    if dn.endswith("/") and dn != "/":
        return dn[:-1]
    return dn
|
|
|
|
def fstrip(fn, return_more=False):
    """Strip directory and extension from a path, returning the basename.

    With return_more=True, return a Namespace carrying dn (directory, "."
    if none), fn (filename), path, bn_path (path without extension), bn
    (basename), and ext (extension without the dot, "" if none).
    """
    parts = fn.split("/")
    dn = "/".join(parts[:-1]) if len(parts) > 1 else "."
    fn = parts[-1]
    if "." in fn:
        bn, _, ext = fn.rpartition(".")
    else:
        bn, ext = fn, ""
    if not return_more:
        return bn
    return Namespace(
        dn=dn,
        fn=fn,
        path=f"{dn}/{fn}",
        bn_path=f"{dn}/{bn}",
        bn=bn,
        ext=ext,
    )
|
|
|
|
def read(fn, mode="r"):
    """Read and return the entire contents of file `fn`."""
    with open(fn, mode) as f:
        return f.read()
|
|
|
|
def write(text, fn, mode="w"):
    """Write `text` to `fn`, creating parent directories as needed."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as f:
        return f.write(text)
|
|
|
|
| import pickle |
|
|
|
|
def dump(obj, fn, mode="wb"):
    """Pickle `obj` into `fn`, creating parent directories as needed."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as f:
        return pickle.dump(obj, f)
|
|
|
|
def load(fn, mode="rb"):
    """Unpickle and return the object stored in `fn`.

    NOTE: pickle is unsafe on untrusted files; only load trusted data.
    """
    with open(fn, mode) as f:
        return pickle.load(f)
|
|
|
|
| import json |
|
|
|
|
def jwrite(x, fn, mode="w", indent="\t", sort_keys=False):
    """Serialize `x` as JSON into `fn`, creating parent directories as needed."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as f:
        return json.dump(x, f, indent=indent, sort_keys=sort_keys)
|
|
|
|
def jread(fn, mode="r"):
    """Parse the JSON file `fn` and return the resulting object."""
    with open(fn, mode) as f:
        return json.load(f)
|
|
|
|
try:
    import yaml

    def ywrite(x, fn, mode="w", default_flow_style=False):
        """Serialize `x` as YAML into `fn`, creating parent directories as needed."""
        mkfile(fn, parents=True, exist_ok=True)
        with open(fn, mode) as handle:
            return yaml.dump(x, handle, default_flow_style=default_flow_style)

    def yread(fn, mode="r"):
        """Parse the YAML file `fn` with the safe loader and return the object."""
        with open(fn, mode) as handle:
            return yaml.safe_load(handle)

# yaml is optional: without it the y* helpers simply don't exist.
# (Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
except ImportError:
    pass
|
|
try:
    import pyunpack  # optional archive-extraction dependency
# narrowed from a bare `except:` so only a missing package is ignored
except ImportError:
    pass
|
|
try:
    # optional database dependency
    import mysql
    import mysql.connector
# narrowed from a bare `except:` so only a missing package is ignored
except ImportError:
    pass
|
|
|
|
| |
|
|
# Default sample image path: prefer ./env, fall back to ./__env__.
hakase = "./env/__hakase__.jpg"
if not os.path.isfile(hakase):
    hakase = "./__env__/__hakase__.jpg"
|
|
|
|
def mem(units="m"):
    """Resident memory of this process, scaled by the first letter of `units`
    (b/k/m/g/t for bytes through terabytes, decimal)."""
    divisor = {
        "b": 1,
        "k": 1e3,
        "m": 1e6,
        "g": 1e9,
        "t": 1e12,
    }[units[0].lower()]
    return psProcess(os.getpid()).memory_info().rss / divisor
|
|
|
|
def chunk(array, length, colwise=True):
    """Split `array` into chunks.

    colwise=True: consecutive chunks of size `length` (last may be shorter).
    colwise=False: about `length` chunks, sized as evenly as possible.

    An empty array yields []. (Previously colwise=False on an empty array
    raised ValueError, because the recursively computed chunk size was 0
    and range() rejects a zero step.)
    """
    if not len(array):
        return []
    if colwise:
        return [array[i : i + length] for i in range(0, len(array), length)]
    return chunk(array, int(math.ceil(len(array) / length)), colwise=True)
|
|
|
|
def classtree(x):
    """Return inspect's nested class-hierarchy list for `x`'s MRO."""
    mro = inspect.getmro(x)
    return inspect.getclasstree(mro)
|
|
|
|
| |
|
|
|
|
class Table:
    """Monospace text-table renderer driven by per-cell orientation specs.

    Each cell is stored as ``(text, spec)`` where ``spec`` is the 12-tuple
    produced by ``Table._spec`` from a spec string such as ``"br"`` or
    ``"r:.2f"``. Index meanings (inferred from ``_spec``/``render``/``parse``
    — confirm against call sites before relying on them):

        [0]  'r'  right-justify (else left)
        [1]  't'  top-align (else bottom)
        [2]  any of ".<>^v": filler / propagation cell (also used for
             delimiter cells, whose default spec is "._" + orientation)
        [3]  '^'  propagate this cell's text upward
        [4]  'v'  propagate downward
        [5]  '>'  propagate rightward
        [6]  '<'  propagate leftward
        [7]  '+'  expand the cell's value as a 2-D array of cells
        [8]  '-'  expand as a horizontal run of cells
        [9]  '|'  expand as a vertical run of cells
        [10] '_'  filler whose width is ignored when sizing columns
        [11] optional ":<format-spec>" suffix, applied via an f-string
    """

    def __init__(
        self,
        table,
        delimiter=" ",
        orientation="br",
        double_colon=True,
    ):
        # table: nested iterable of cells; each cell may be a bare value,
        # a (value, spec-string) pair, or "value::spec" when double_colon.
        self.delimiter = delimiter
        # orientation: default spec string applied to unannotated cells
        self.orientation = orientation
        self.t = Table.parse(table, delimiter, orientation, double_colon)
        return

    def __str__(self):
        return self.render()

    def __repr__(self):
        return self.render()

    def render(self):
        """Render the parsed cell grid to a single multi-line string."""
        # blank cell carrying the default spec
        empty = ("", Table._spec(self.orientation, transpose=False))

        # work on a copy; rendering mutates cell contents in place
        t = copy.deepcopy(self.t)
        totalrows = len(t)
        totalcols = [len(r) for r in t]
        # the parsed grid must be rectangular
        assert min(totalcols) == max(totalcols)
        totalcols = totalcols[0]

        # formatting pass: apply any ":<format-spec>" suffix, then stringify.
        # NOTE(review): this eval()s cell content into an f-string — do not
        # feed untrusted data through Table.
        for i in range(totalrows):
            for j in range(totalcols):
                x, s = t[i][j]
                sp = s[11]
                if sp:
                    x = eval(f'f"{{{x}{sp}}}"')
                Table._put((str(x), s), t, (i, j), empty)

        # spec used for cells overwritten by propagation: keep justification
        # and expansion bits, clear the directional bits, set the '_' flag
        _repl = (
            lambda s: s[:2] + (1, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:]
            if s[2]
            else s[:2] + (0, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:]
        )
        for i, row in enumerate(t):
            for j, (x, s_own) in enumerate(row):
                # '^': copy this cell's text upward until a blocking cell
                if s_own[3]:
                    u, v = i, j
                    while 0 <= u:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        u -= 1

                # 'v': copy downward
                if s_own[4]:
                    u, v = i, j
                    while u < totalrows:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        u += 1

                # '>': copy rightward
                if s_own[5]:
                    u, v = i, j
                    while v < totalcols:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        v += 1

                # '<': copy leftward
                if s_own[6]:
                    u, v = i, j
                    while 0 <= v:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        v -= 1

        # measure column widths and row heights
        widths = [
            0,
        ] * totalcols
        heights = [
            0,
        ] * totalrows
        for i, row in enumerate(t):
            for j, (x, s) in enumerate(row):
                # newline count = line count - 1; corrected below
                heights[i] = max(heights[i], x.count("\n"))

                # '_'-flagged non-delimiter cells don't influence width
                if s[2] or not s[10]:
                    w = max(len(q) for q in x.split("\n"))
                    widths[j] = max(widths[j], w)
        # convert newline counts to line counts
        heights = [h + 1 for h in heights]

        # paint each cell into the output character grid
        rend = []
        roff = 0  # running row offset of the current table row
        for i, row in enumerate(t):
            for j, (x, s) in enumerate(row):
                w, h = widths[j], heights[i]

                # filler cells are tiled/stretched to fill their w x h box
                if s[2] or s[10]:
                    xs = x.split("\n")
                    xw0 = min(len(l) for l in xs)
                    xw1 = max(len(l) for l in xs)
                    xh = len(xs)
                    if (xw0 == xw1 == w) and (xh == h):
                        pass  # already the right size
                    elif xw0 == xw1 == w:
                        # right width: repeat the first line vertically
                        x = "\n".join(
                            [
                                xs[0],
                            ]
                            * h
                        )
                    elif xh == h:
                        # right height: repeat each line's first char across
                        x = "\n".join([(l[0] if l else "") * w for l in xs])
                    else:
                        # neither: tile the first character over the box
                        x = x[0] if x else " "
                        x = "\n".join(
                            [
                                x * w,
                            ]
                            * h
                        )

                # horizontal justification ('r' = right)
                x = [l.rjust(w) if s[0] else l.ljust(w) for l in x.split("\n")]

                # vertical alignment: pad above (bottom-align) unless 't'
                plus = [
                    " " * w,
                ] * (h - len(x))
                x = plus + x if not s[1] else x + plus

                # write the cell's lines into the output grid
                for r, xline in enumerate(x):
                    Table._put(xline, rend, (roff + r, j), None)
            roff += h

        return "\n".join(["".join(r) for r in rend])

    # static helper (called as Table._spec, never on an instance):
    # parse a spec string into the 12-tuple documented on the class.
    # transpose swaps the row/column meanings of the directional characters.
    def _spec(s, transpose=False):
        if ":" in s:
            i = s.index(":")
            sp = s[i:]  # trailing ":<format-spec>" kept verbatim
            s = s[:i]
        else:
            sp = ""
        s = s.lower()
        return (
            int("r" in s),
            int("t" in s),
            int(any([i in s for i in [".", "<", ">", "^", "v"]])),
            int("^" in s if not transpose else "<" in s),
            int("v" in s if not transpose else ">" in s),
            int(">" in s if not transpose else "v" in s),
            int("<" in s if not transpose else "^" in s),
            int("+" in s),
            int("-" in s if not transpose else "|" in s),
            int("|" in s if not transpose else "-" in s),
            int("_" in s),
            sp,
        )

    # static helper: write obj at t[i][j], growing the ragged 2-D list with
    # `empty` cells as needed.
    def _put(obj, t, ij, empty):
        i, j = ij
        while i >= len(t):
            t.append([])
        while j >= len(t[i]):
            t[i].append(empty)
        t[i][j] = obj
        return

    # static helper: normalize raw user input into the rectangular
    # (text, spec) grid that render() consumes.
    def parse(
        table,
        delimiter=" ",
        orientation="br",
        double_colon=True,
    ):
        # transposition is currently hard-disabled
        transpose = False

        empty = ("", Table._spec(orientation, transpose))

        # pass 1: attach a parsed spec to every cell
        t = []
        for i, row in enumerate(table):
            for j, item in enumerate(row):
                ij = (i, j) if not transpose else (j, i)
                if type(item) == tuple and len(item) == 2 and type(item[1]) == str:
                    item = (item[0], Table._spec(item[1], transpose))
                elif double_colon and type(item) == str and "::" in item:
                    x, s = item.split("::")
                    item = (x, Table._spec(s, transpose))
                else:
                    item = (item, Table._spec(orientation, transpose))
                Table._put(item, t, ij, empty)

        # pass 2: find the full extent, accounting for '+'/'-'/'|' expansions
        maxcol = 0
        maxrow = len(t)
        for i, row in enumerate(t):
            # count non-filler cells (NOTE: the comprehension's `i` shadows
            # the outer loop variable; harmless because `i` is re-bound by
            # the enumerate on the next iteration)
            maxcol = max(maxcol, len([i for i in row if not i[1][2]]))

            for j, (x, s) in enumerate(row):
                if s[7]:
                    # '+': x is a 2-D array occupying rows i.. and cols j..
                    r = len(x)
                    maxrow = max(maxrow, i + r)
                    c = max(len(q) for q in x)
                    maxcol = max(maxcol, j + c)
                elif s[8]:
                    # '-': horizontal run
                    c = len(x)
                    maxcol = max(maxcol, j + c)
                elif s[9]:
                    # '|': vertical run
                    r = len(x)
                    maxrow = max(maxrow, i + r)
        # every content column gets a delimiter column on each side
        totalcols = 2 * maxcol + 1
        totalrows = maxrow
        t += [[]] * (totalrows - len(t))
        # pass 3: interleave delimiter cells between content cells
        newt = []
        delim = (delimiter, Table._spec("._" + orientation, transpose))
        for i, row in enumerate(t):
            wasd = False  # was the previous emitted cell a delimiter?
            tcount = 0  # index of the next source cell to consume
            for j in range(totalcols):
                item = t[i][tcount] if tcount < len(t[i]) else empty
                isd = item[1][2]
                if wasd and isd:
                    Table._put(empty, newt, (i, j), empty)
                    wasd = False
                elif wasd and not isd:
                    Table._put(item, newt, (i, j), empty)
                    tcount += 1
                    wasd = False
                elif not wasd and isd:
                    Table._put(item, newt, (i, j), empty)
                    tcount += 1
                    wasd = True
                elif not wasd and not isd:
                    Table._put(delim, newt, (i, j), empty)
                    wasd = True
        t = newt

        # trailing sentinel cell per row
        for row in t:
            row.append(empty)

        # pass 4: repeatedly expand the first remaining '+'/'-'/'|' cell
        # until none are left
        delim_cols = [i for i in range(totalcols) if i % 2 == 0]
        while True:
            # locate the next expandable cell (k = which spec bit fired)
            ij = None
            for i, row in enumerate(t):
                for j, item in enumerate(row):
                    st, s = item
                    if s[7]:
                        ij = i, j, 7, st, s
                        break
                    elif s[8]:
                        ij = i, j, 8, st, s
                        break
                    elif s[9]:
                        ij = i, j, 9, st, s
                        break
                if ij is not None:
                    break
            if ij is None:
                break

            # clear the expansion bits so the written cells are terminal
            i, j, k, st, s = ij
            s = list(s)
            s[7] = s[8] = s[9] = 0
            s = tuple(s)

            # write the expanded values; column stride is 2 to skip the
            # interleaved delimiter columns
            if k == 7:
                for x, row in enumerate(st):
                    for y, obj in enumerate(row):
                        a = i + x if not transpose else i + y
                        b = j + 2 * y if not transpose else j + 2 * x
                        Table._put((obj, s), t, (a, b), None)
            if k == 8:
                for y, obj in enumerate(st):
                    Table._put((obj, s), t, (i, j + 2 * y), None)
            if k == 9:
                for x, obj in enumerate(st):
                    Table._put((obj, s), t, (i + x, j), None)

        return t
|
|
|
|
class Resnet(nn.Module):
    """Two-conv residual block: y = x + f(x); channel count is preserved."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        layers = []
        for _ in range(2):
            layers += [
                nn.PReLU(channels),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.net(x)
|
|
|
|
class Synthesizer(nn.Module):
    """Fuses warped images, flow magnitudes, masks, and features into a frame.

    Predicts a residual in logit (pre-sigmoid) space on top of the average
    of the two endpoint images, then squashes back through a sigmoid.

    NOTE(review): forward's locals() are part of the return value when
    return_more=True, so local variable names here are load-bearing.
    """

    def __init__(
        self, size, channels_image, channels_flow, channels_mask, channels_feature
    ):
        # size: (height, width) all inputs are resampled to
        super().__init__()
        self.size = size
        self.diam = diam(self.size)  # used to normalize flow magnitudes
        self.channels_image = cimg = channels_image
        self.channels_flow = cflow = channels_flow
        self.channels_mask = cmask = channels_mask
        self.channels_feature = cfeat = channels_feature
        # cflow // 2 because each 2-channel flow collapses to a 1-channel norm
        self.channels = ch = cimg + cflow // 2 + cmask + cfeat
        self.interpolator = Interpolator(self.size, mode="bilinear")
        # ch + 3: forward() prepends the 3-channel average of the two images
        self.net = nn.Sequential(
            nn.Conv2d(ch + 3, 64, kernel_size=1, padding=0),
            Resnet(64),
            nn.Sequential(
                nn.PReLU(64),
                nn.Conv2d(64, 32, kernel_size=3, padding=1),
                nn.BatchNorm2d(32),
            ),
            Resnet(32),
            nn.Sequential(
                nn.PReLU(32),
                nn.Conv2d(32, 16, kernel_size=3, padding=1),
                nn.BatchNorm2d(16),
            ),
            Resnet(16),
            nn.Sequential(
                nn.PReLU(16),
                nn.Conv2d(16, 3, kernel_size=3, padding=1),
            ),
        )
        return

    def forward(self, images, flows, masks, features, return_more=False):
        itp = self.interpolator
        # prepend the average of the first two images; its logit is also the
        # base the predicted residual is added onto
        images = [
            (images[0] + images[1]) / 2,
        ] + images
        # only the first 3 (RGB) channels of each image are used
        logimgs = [itp(pixel_logit(i[:, :3])) for i in images]
        cat = torch.cat(
            [
                *logimgs,
                # flow magnitude, normalized by the frame diagonal
                *[itp(f).norm(dim=1, keepdim=True) / self.diam for f in flows],
                *[itp(m) for m in masks],
                *[itp(f) for f in features],
            ],
            dim=1,
        )
        residual = self.net(cat)
        # 0.5 damps the residual before returning to pixel space
        return torch.sigmoid(logimgs[0] + 0.5 * residual), (
            locals() if return_more else None
        )
|
|
|
|
class FlowZMetric(nn.Module):
    """Importance ("z") metric for softmax splatting from photometric error.

    Pixels whose flow-warped counterpart matches poorly (large Lab-space
    color difference) receive a more negative score, down-weighting them
    in the subsequent softmax splat.
    """

    def __init__(self):
        super().__init__()
        return

    def forward(self, img0, img1, flow0, flow1, return_more=False):
        # returns [z0, z1]: one score map per direction, shape (B, 1, H, W)



        # compare in Lab color space; alpha/extra channels are dropped
        img0 = kornia.color.rgb_to_lab(img0[:, :3])
        img1 = kornia.color.rgb_to_lab(img1[:, :3])
        # -0.1 scales the Lab distance into a splatting-friendly range
        return [
            -0.1 * (img1 - flow_backwarp(img0, flow0)).norm(dim=1, keepdim=True),
            -0.1 * (img0 - flow_backwarp(img1, flow1)).norm(dim=1, keepdim=True),
        ], (locals() if return_more else None)
|
|
|
|
class NEDT(nn.Module):
    """Normalized Euclidean distance transform of difference-of-Gaussians edges.

    Thresholds a DoG response at 0.5, takes its distance transform, and maps
    it through 1 - exp(-d * scale) so values saturate toward 1 away from
    edges. Computed without gradients.
    """

    def __init__(self):
        super().__init__()
        return

    def forward(
        self,
        img,
        t=2.0,
        sigma_factor=1 / 540,   # DoG sigma as a fraction of image height
        k=1.6,
        epsilon=0.01,
        kernel_factor=4,
        exp_factor=540 / 15,    # decay rate of the exponential falloff
        return_more=False,
    ):
        with torch.no_grad():
            dog = batch_dog(
                img,
                t=t,
                # sigma scales with input height so behavior is
                # resolution-independent
                sigma=img.shape[-2] * sigma_factor,
                k=k,
                epsilon=epsilon,
                kernel_factor=kernel_factor,
                clip=False,
            )
            edt = batch_edt((dog > 0.5).float())
            # exponential falloff, normalized by the larger spatial dim
            ans = 1 - (-edt * exp_factor / max(edt.shape[-2:])).exp()
        return ans, (locals() if return_more else None)
|
|
|
|
class HalfWarper(nn.Module):
    """Forward-warps both endpoint frames toward an intermediate time t.

    Produces splatted images, the scaled flows, and (morphologically opened)
    validity masks; channel counts of the concatenated outputs are exposed
    for downstream sizing.

    NOTE(review): forward's locals() are part of the return value when
    return_more=True, so local variable names here are load-bearing
    (SoftsplatLite/DTM read them by name).
    """

    def __init__(self):
        super().__init__()
        self.channels_image = 4 * 3  # 4 output images x 3 channels
        self.channels_flow = 2 * 2   # 2 scaled flows x 2 channels
        self.channels_mask = 2 * 1   # 2 validity masks x 1 channel
        self.channels = self.channels_image + self.channels_flow + self.channels_mask

    def morph_open(self, x, k):
        # morphological opening with a k x k structuring element removes
        # small speckles from the splat-coverage mask; k == 0 disables it
        if k == 0:
            return x
        else:
            with torch.no_grad():
                return kornia.morphology.opening(x, torch.ones(k, k, device=x.device))

    def forward(self, img0, img1, flow0, flow1, z0, z1, k, t=0.5, return_more=False):
        # scale each flow by how far the target time t is from its frame
        flow0_ = (1 - t) * flow0
        flow1_ = t * flow1
        # softmax-splat (mode "sm") with mask channel appended.
        # NOTE(review): img0 is splatted along flow1_ (and img1 along
        # flow0_) — presumably due to the flow direction convention
        # upstream; confirm against the flow estimator before changing.
        f01 = forewarp(img0, flow1_, mode="sm", metric=z1, mask=True)
        f10 = forewarp(img1, flow0_, mode="sm", metric=z0, mask=True)
        # split off the appended mask channel; clean it up by opening
        f01i, f01m = f01[:, :-1], self.morph_open(f01[:, -1:], k=k)
        f10i, f10m = f10[:, :-1], self.morph_open(f10[:, -1:], k=k)

        # fill each warp's holes with the other warp's pixels
        base0 = f01m * f01i + (1 - f01m) * f10i
        base1 = f10m * f10i + (1 - f10m) * f01i
        # [[images], [flows], [masks]] — ordering matches channels_* above
        ans = [
            [
                base0,
                base1,
                f01i,
                f10i,
            ],
            [
                flow0_,
                flow1_,
            ],
            [
                f01m,
                f10m,
            ],
        ]
        return ans, (locals() if return_more else None)
|
|
|
|
class ResnetFeatureExtractor(nn.Module):
    """First stages of a ResNet-50 exposed as a 3-scale feature pyramid.

    inferserve_query selects the weights: ("torchvision", ...) loads the
    pretrained torchvision resnet50; anything else loads a project-served
    checkpoint via userving. Features are taken after the stem (64 ch),
    layer1 (256 ch), and layer2 (512 ch).
    """

    def __init__(self, inferserve_query, size_in=None):
        # size_in: optional (h, w) of expected inputs, used to precompute
        # the pyramid's output sizes
        super().__init__()
        self.inferserve_query = iq = inferserve_query
        self.size_in = si = size_in
        if iq[0] == "torchvision":
            self.base_hparams = None
            resnet = tv.models.resnet50(pretrained=True)

            self.resize = T.Resize(256)
            # standard ImageNet normalization
            self.resnet_preprocess = T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
            self.conv1 = resnet.conv1
            self.bn1 = resnet.bn1
            self.relu = resnet.relu
            self.maxpool = resnet.maxpool
            self.layer1 = resnet.layer1
            self.layer2 = resnet.layer2
        else:
            # project-served checkpoint; reuse its preprocessing
            base = userving.infer_model_load(*iq).eval()
            self.base_hparams = base.hparams

            self.resize = T.Resize(base.hparams.largs.size)
            self.resnet_preprocess = base.resnet_preprocess
            self.conv1 = base.resnet.conv1
            self.bn1 = base.resnet.bn1
            self.relu = base.resnet.relu
            self.maxpool = base.resnet.maxpool
            self.layer1 = base.resnet.layer1
            self.layer2 = base.resnet.layer2
        if self.size_in is None:
            # unknown input size: forward() fills sizes_out on first call
            self.sizes_out = None
        else:
            # pyramid output sizes at strides 2 / 4 / 8 of the resized input
            s = self.resize.size
            self.sizes_out = [
                pixel_ij(
                    rescale_dry(si, (s // 2) / si[0]), rounding="ceil"
                ),
                pixel_ij(
                    rescale_dry(si, (s // 4) / si[0]), rounding="ceil"
                ),
                pixel_ij(
                    rescale_dry(si, (s // 8) / si[0]), rounding="ceil"
                ),
            ]
        # channel counts of the three returned feature maps
        self.channels = [
            64,
            256,
            512,
        ]
        return

    def forward(self, x, force_sizes_out=False, return_more=False):
        # returns [stem, layer1, layer2] features for the RGB channels of x
        ans = []
        x = x[:, :3]
        x = self.resize(x)
        x = self.resnet_preprocess(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        ans.append(x)  # stride-2 stem features
        x = self.maxpool(x)
        x = self.layer1(x)
        ans.append(x)  # stride-4 features
        x = self.layer2(x)
        ans.append(x)  # stride-8 features
        # side effect: record actual output sizes when unknown (or forced)
        if force_sizes_out or (self.sizes_out is None):
            self.sizes_out = [tuple(q.shape[-2:]) for q in ans]
        return ans, (locals() if return_more else None)
|
|
|
|
class NetNedt(nn.Module):
    """Predicts the NEDT (edge distance) map of the interpolated frame.

    Conditions on the base prediction, its NEDT, and the half-warped
    images/masks; operates in logit space and returns a sigmoid output.
    """

    def __init__(self):
        super().__init__()
        # input channels: base RGB (3) + base NEDT (1) + two 4-channel
        # warped images + two 1-channel masks
        chin = 3 + 1 + 4 + 4 + 1 + 1
        ch = 16
        chout = 1
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, chout, kernel_size=3, padding=1),
        )
        return

    def forward(self, out_base, out_base_nedt, hw_imgs, hw_masks, return_more=False):
        cat = torch.cat(
            [
                out_base,
                out_base_nedt,
                hw_imgs[0],
                hw_imgs[1],
                hw_masks[0],
                hw_masks[1],
            ],
            dim=1,
        )
        # clip to (0, 1) before the logit to keep it finite
        log = pixel_logit(cat.clip(0, 1))
        ans = torch.sigmoid(self.net(log))
        return ans, (locals() if return_more else None)
|
|
|
|
class NetTail(nn.Module):
    """Refines the base frame using its NEDT and the predicted NEDT.

    Predicts a residual in logit space on top of the base image's logit,
    then squashes back through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # input channels: base RGB (3) + base NEDT (1) + predicted NEDT (1)
        chin = 3 + 1 + 1
        ch = 16
        chout = 3
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, chout, kernel_size=3, padding=1),
        )
        return

    def forward(self, out_base, out_base_nedt, pred_nedt, return_more=False):
        cat = torch.cat(
            [
                out_base,
                out_base_nedt,
                pred_nedt,
            ],
            dim=1,
        )
        # clip to (0, 1) before the logit to keep it finite
        log = pixel_logit(cat.clip(0, 1))
        # residual on the RGB logits (the first 3 channels of the input)
        ans = torch.sigmoid(log[:, :3] + self.net(log))
        return ans, (locals() if return_more else None)
|
|
|
|
class SoftsplatLite(nn.Module):
    """Softmax-splatting frame interpolator (base stage, before DTM refinement).

    Pipeline: z-metric from photometric error -> NEDT channel appended to
    each input frame -> half-warp of images and of ResNet pyramid features
    -> gridnet fusion -> synthesizer producing the intermediate frame.

    NOTE(review): forward's locals() are the second return value when
    return_more=True and DTM reads "hw_imgs"/"hw_masks" out of them by
    name — do not rename locals here.
    """

    def __init__(self):
        super().__init__()
        # fixed working resolution of (540, 960)
        self.feature_extractor = ResnetFeatureExtractor(
            ("torchvision", "resnet50"),
            (540, 960),
        )
        self.z_metric = FlowZMetric()
        # NOTE(review): plain list, not nn.ModuleList — these Interpolators
        # are parameter-free, so only .to()/device traversal is affected
        self.flow_downsamplers = [
            Interpolator(s, mode="bilinear") for s in self.feature_extractor.sizes_out
        ]
        self.gridnet_converter = GridnetConverter(
            self.feature_extractor.channels,
            [32, 64, 128],
        )
        self.gridnet = Gridnet(
            *[32, 64, 128],
            total_dropout_p=0.0,
            depth=1,
        )
        self.nedt = NEDT()
        self.half_warper = HalfWarper()
        self.synthesizer = Synthesizer(
            (540, 960),
            self.half_warper.channels_image,
            self.half_warper.channels_flow,
            self.half_warper.channels_mask,
            self.gridnet.channels_0,
        )
        return

    def forward(self, x, t=0.5, k=5, return_more=False):
        # x: dict with "flows" (B, 2, 2, H, W) and "images" (B, >=2, C, H, W);
        # t: interpolation time; k: mask-opening kernel size
        rm = return_more
        flow0, flow1 = x["flows"].swapaxes(0, 1)
        img0, img1 = x["images"][:, 0], x["images"][:, -1]
        (z0, z1), locs_z = self.z_metric(img0, img1, flow0, flow1, return_more=rm)
        # append each frame's NEDT as an extra channel
        img0 = torch.cat([img0, self.nedt(img0)[0]], dim=1)
        img1 = torch.cat([img1, self.nedt(img1)[0]], dim=1)

        # warp the full-resolution images toward time t
        (hw_imgs, hw_flows, hw_masks), locs_hw = self.half_warper(
            img0,
            img1,
            flow0,
            flow1,
            z0,
            z1,
            k,
            t=t,
            return_more=rm,
        )

        # warp each pyramid level with matching downsampled flows/metrics,
        # averaging the two directions' warped features
        feats0, locs_fe0 = self.feature_extractor(img0, return_more=rm)
        feats1, locs_fe1 = self.feature_extractor(img1, return_more=rm)
        warps = []
        for ft0, ft1, ds in zip(feats0, feats1, self.flow_downsamplers):
            (w, _, _), _ = self.half_warper(
                ft0,
                ft1,
                ds(flow0, 1),  # positional 1 == is_flow=True (rescales vectors)
                ds(flow1, 1),
                ds(z0),
                ds(z1),
                k,
                t=t,
            )
            warps.append((w[0] + w[1]) / 2)
        feats = self.gridnet(self.gridnet_converter(warps))

        # fuse everything into the output frame (finest gridnet scale only)
        pred, locs_synth = self.synthesizer(
            hw_imgs,
            hw_flows,
            hw_masks,
            [
                feats[0],
            ],
            return_more=rm,
        )
        return pred, (locals() if rm else None)
|
|
|
|
class DTM(nn.Module):
    """Distance-Transform Module.

    Refines the SoftsplatLite output using NEDT (edge distance-transform)
    guidance: first predicts the target frame's NEDT map, then applies a
    residual RGB refinement conditioned on it.
    """

    def __init__(self):
        super().__init__()
        self.net_nedt = NetNedt()  # predicts the target frame's NEDT map
        self.net_tail = NetTail()  # residual RGB refinement head
        self.nedt = NEDT()  # edge distance transform of a frame
        return

    def forward(self, x, out_base, locs_base, return_more=False):
        """Refine out_base; returns refined RGB concatenated with the
        predicted NEDT channel.

        Args:
            x: unused here; kept for interface parity with the base model.
            out_base: base prediction from SoftsplatLite.
            locs_base: SoftsplatLite's returned locals() dict — the
                "hw_imgs"/"hw_masks" intermediates are reused from it.
            return_more: when True, also return this frame's locals().
        """
        rm = return_more
        # NEDT of the base prediction is treated as a fixed input (no grads).
        with torch.no_grad():
            out_base_nedt, locs_base_nedt = self.nedt(out_base, return_more=rm)
        hw_imgs, hw_masks = locs_base["hw_imgs"], locs_base["hw_masks"]
        pred_nedt, locs_nedt = self.net_nedt(
            out_base, out_base_nedt, hw_imgs, hw_masks, return_more=rm
        )
        # clone().detach(): the tail must not backprop into the NEDT predictor.
        pred, locs_tail = self.net_tail(
            out_base, out_base_nedt, pred_nedt.clone().detach(), return_more=rm
        )
        # Refined RGB plus the predicted NEDT channel(s), stacked on dim 1.
        return torch.cat([pred, pred_nedt], dim=1), (locals() if rm else None)
|
|
|
|
class RAFT(nn.Module):
    """Wrapper around the RFR optical-flow network with checkpoint loading.

    Loads weights from the AnimeInterp checkpoint (pass path=None to skip
    loading and keep randomly initialized weights).
    """

    def __init__(self, path="/workspace/tensorrt/models/anime_interp_full.ckpt"):
        super().__init__()
        self.raft = RFR(
            Namespace(
                small=False,
                mixed_precision=False,
            )
        )
        if path is not None:
            # map_location="cpu": the checkpoint was presumably saved from a
            # CUDA run; without this, torch.load fails on CPU-only hosts.
            # load_state_dict then copies tensors onto the module's own
            # parameters regardless of source device.
            sd = torch.load(path, map_location="cpu")["model_state_dict"]
            self.raft.load_state_dict(
                {
                    # The checkpoint was saved from a wrapped model; strip the
                    # "module.flownet." prefix to match this module's names.
                    k[len("module.flownet.") :]: v
                    for k, v in sd.items()
                    if k.startswith("module.flownet.")
                },
                # strict=False: the checkpoint does not cover every parameter
                # of this RFR instance (and vice versa).
                strict=False,
            )
        return

    def forward(self, img0, img1, flow0=None, iters=12, return_more=False):
        """Estimate flow between img0 and img1.

        Args:
            img0, img1: input frame pair.
            flow0: optional initial flow estimate, in this wrapper's
                channel order (flipped before being handed to RFR).
            iters: number of RAFT refinement iterations.
            return_more: when True, also return this frame's locals()
                (local names are part of that contract).
        """
        if flow0 is not None:
            # RFR apparently uses the opposite flow channel order; flip the
            # 2-channel dim on the way in and again on the way out.
            # NOTE(review): confirm the (x,y) vs (y,x) convention against RFR.
            flow0 = flow0.flip(dims=(1,))
        # Argument order (img1, img0) is swapped — presumably RFR's own
        # convention; verify against RFR's forward signature.
        out = self.raft(img1, img0, iters=iters, flow_init=flow0)
        return out[0].flip(dims=(1,)), (locals() if return_more else None)
|
|