| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from raft import RAFT |
| | from nnutils import make_conv_2d, make_upscale_2d, make_downscale_2d, ResBlock2d, Identity |
| |
|
| |
|
class ImportanceWeights(torch.nn.Module):
    """Predict a single-channel map with values in (0, 1).

    The network first compresses a feature map to 16 channels with a 3x3
    convolution, concatenates it with a 6-channel input, and runs the result
    through one conv block, three residual blocks and a final 3x3 convolution
    followed by a sigmoid.

    Args:
        opt: Options object. Reads ``opt.small`` (selects 128 instead of 256
            feature channels) and ``opt.use_batch_norm`` (selects
            ``BatchNorm2d`` instead of ``Identity`` as normalization layer).
    """

    def __init__(self, opt):
        super().__init__()

        # Channel count of the incoming feature map depends on the model size.
        feature_channels = 128 if opt.small else 256
        reduced_channels = 16
        hidden_channels = 16

        # Compressed features plus 3 * 2 extra channels
        # (presumably two RGB images -- confirm against the caller).
        self.input_fn = reduced_channels + 3 * 2
        self.conv1 = torch.nn.Conv2d(
            in_channels=feature_channels,
            out_channels=reduced_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        normalization = torch.nn.BatchNorm2d if opt.use_batch_norm else Identity

        self.model = nn.Sequential(
            make_conv_2d(self.input_fn, hidden_channels, n_blocks=1, normalization=normalization),
            ResBlock2d(hidden_channels, normalization=normalization),
            ResBlock2d(hidden_channels, normalization=normalization),
            ResBlock2d(hidden_channels, normalization=normalization),
            nn.Conv2d(hidden_channels, 1, kernel_size=3, padding=1),
        )

    def forward(self, x, features):
        """Return the sigmoid-activated prediction for ``x`` and ``features``.

        Args:
            x: Tensor concatenated channel-wise after the compressed features;
                must contribute exactly ``self.input_fn - 16`` (= 6) channels.
            features: Feature map with 128 or 256 channels (see ``__init__``)
                and the same batch/spatial size as ``x``.

        Returns:
            Tensor with one channel and values in (0, 1).
        """
        compressed = self.conv1(features)
        stacked = torch.cat([compressed, x], 1)
        assert stacked.shape[1] == self.input_fn, (
            "expected %d input channels, got %d" % (self.input_fn, stacked.shape[1])
        )
        # No debug printing here: dumping tensors on every forward pass floods
        # stdout and forces a device synchronisation.
        return torch.sigmoid(self.model(stacked))
| |
|
class NeuralNRT(nn.Module):
    """Wrap a RAFT correspondence network and paste its crop-space output
    into a full-resolution 2-channel map.

    Args:
        opt: Options object. Read here: ``iters``, ``height``, ``width``,
            ``crop_height``, ``crop_width``; also forwarded to ``RAFT`` and
            ``ImportanceWeights``.
        path: Optional checkpoint path whose weights are loaded into the
            RAFT sub-network.
        device: Unused; kept for interface compatibility.
    """

    _PARALLEL_PREFIX = 'module.'

    def __init__(self, opt, path=None, device="cuda:0"):
        super().__init__()
        self.opt = opt
        self.CorresPred = RAFT(opt)
        self.ImportanceW = ImportanceWeights(opt)
        if path is not None:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            data = torch.load(path, map_location='cpu')
            if 'state_dict' in data:
                weights = data['state_dict']
            else:
                # Bare state dict, presumably saved from a DataParallel
                # wrapper: drop the leading "module." from each key.
                weights = {self._strip_parallel_prefix(k): v for k, v in data.items()}
            self.CorresPred.load_state_dict(weights)
            print("load done")

    @staticmethod
    def _strip_parallel_prefix(key):
        """Remove a leading ``module.`` from a state-dict key, if present."""
        prefix = NeuralNRT._PARALLEL_PREFIX
        # Only the prefix is removed; a plain str.replace would also mangle
        # keys that merely contain the substring (e.g. "submodule.weight").
        return key[len(prefix):] if key.startswith(prefix) else key

    @staticmethod
    def _box_origin(box):
        """Return the (x_start, y_start) corner of each box as (N, 2, 1, 1)."""
        return torch.cat((box[:, 0:1], box[:, 2:3]), 1)[:, :, None, None]

    @staticmethod
    def _box_scale(box, crop_width, crop_height):
        """Return per-axis box-extent / crop-size ratios as (N, 2, 1, 1)."""
        scale_x = (box[:, 1:2] - box[:, 0:1]).float() / crop_width
        scale_y = (box[:, 3:4] - box[:, 2:3]).float() / crop_height
        return torch.cat((scale_x, scale_y), 1)[:, :, None, None]

    def forward(self, src_im, tar_im, src_im_raw, tar_im_raw, Crop_param):
        """Compute the full-resolution forward map for a batch of crops.

        Args:
            src_im: First image batch passed to RAFT after scaling by 255
                (presumably values in [0, 1] -- confirm against the caller).
            tar_im: Second image batch, treated like ``src_im``.
            src_im_raw: Unused.
            tar_im_raw: Unused.
            Crop_param: Integer tensor indexed as ``[n, k, 0]``. For the first
                box, k = 0..3 are read as (x_start, x_end, y_start, y_end);
                k = 4..7 hold the same layout for the second box (presumably
                the target crop -- confirm against the caller).

        Returns:
            Tensor of shape ``(N, 2, opt.height, opt.width)`` that is zero
            outside each sample's first box.
        """
        opt = self.opt
        batch = src_im.shape[0]
        device = src_im.device

        flow_crop, _features = self.CorresPred(src_im * 255.0, tar_im * 255.0, iters=opt.iters)

        # Pixel-coordinate grid of the crop, channel 0 = x, channel 1 = y.
        # Built straight from the crop size so it does not depend on the crop
        # being smaller than the full resolution; broadcasts over the batch.
        xs = torch.arange(opt.crop_width, device=device).float()
        ys = torch.arange(opt.crop_height, device=device).float()
        grid_crop = torch.stack(
            (
                xs.view(1, -1).expand(opt.crop_height, opt.crop_width),
                ys.view(-1, 1).expand(opt.crop_height, opt.crop_width),
            ),
            0,
        ).unsqueeze(0)

        src_box = Crop_param[:, 0:4, 0]
        tar_box = Crop_param[:, 4:8, 0]
        shift = self._box_origin(tar_box) - self._box_origin(src_box)
        scale_src = self._box_scale(src_box, opt.crop_width, opt.crop_height)
        scale_tar = self._box_scale(tar_box, opt.crop_width, opt.crop_height)

        # Re-express the crop-space displacement in the units of the two boxes.
        flow_crop = (scale_tar - scale_src) * grid_crop + scale_tar * flow_crop

        flow_fw = torch.zeros((batch, 2, opt.height, opt.width), device=device)
        for i in range(batch):
            # Read the four bounds once instead of synchronising per element.
            x0, x1, y0, y1 = (int(v) for v in src_box[i].tolist())
            resized = F.interpolate(
                flow_crop[i:i + 1],
                (y1 - y0, x1 - x0),
                mode='bilinear',
                align_corners=True,
            )
            resized = resized + shift[i:i + 1]
            flow_fw[i, :, y0:y1, x0:x1] = resized[0]
        return flow_fw
| |
|