| | import torch
|
| | import torch.nn as nn
|
| | import numpy as np
|
| | from torch.optim import AdamW
|
| | import torch.optim as optim
|
| | import itertools
|
| | from torch.nn.parallel import DistributedDataParallel as DDP
|
| | from .IFNet_HDv3 import *
|
| | import torch.nn.functional as F
|
| |
|
| |
|
| |
|
| |
|
class Model:
    """Training/inference wrapper around the package-local ``IFNet`` network.

    ``IFNet`` is star-imported from ``IFNet_HDv3``; its definition is not
    visible here. Every forward call in this class passes the concatenated
    input images first and a three-element scale list second (positionally),
    and unpacks the result as ``(flow, mask, merged)``.
    """

    def __init__(self, local_rank=-1):
        """Build the network, optionally wrapped in DistributedDataParallel.

        Args:
            local_rank: device index passed to DDP; ``-1`` (the default)
                leaves the network unwrapped.
        """
        self.flownet = IFNet()
        # Optimizer and smoothness loss used by ``update``. They stay None
        # here and are created on first use (AdamW / ``_sobel_l1``) unless
        # the caller assigns its own.
        self.optimG = None
        self.sobel = None
        if local_rank != -1:
            self.flownet = DDP(
                self.flownet, device_ids=[local_rank], output_device=local_rank
            )

    def train(self):
        """Put ``flownet`` in training mode."""
        self.flownet.train()

    def eval(self):
        """Put ``flownet`` in evaluation mode."""
        self.flownet.eval()

    def to(self, device):
        """Move ``flownet`` to *device* (returns None)."""
        self.flownet.to(device)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

        Only a *leading* prefix is removed, so a key such as
        ``blocks.submodule.weight`` is left intact. Keys without the prefix
        are kept rather than dropped.
        """
        prefix = "module."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def load_model(self, path, rank=0, device="cuda"):
        """Load a state dict from the checkpoint file *path* into ``flownet``.

        Args:
            path: file passed to ``torch.load`` (unlike ``save_model``, this
                is the checkpoint file itself, not a directory).
            rank: when ``-1``, a leading ``module.`` (the prefix DDP adds to
                the wrapped module's keys) is stripped from every key before
                loading; any other value loads the keys unchanged.
            device: ``map_location`` for ``torch.load``; also stored on
                ``self.device``.
        """
        self.device = device
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location=device)
        if rank == -1:
            state_dict = self._strip_module_prefix(state_dict)
        self.flownet.load_state_dict(state_dict)

    def save_model(self, path, rank=0):
        """Save ``flownet``'s state dict to ``<path>/flownet.pkl``; only rank 0 writes."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    @staticmethod
    def _scale_list(scale):
        """Return the per-level scale list ``[4/scale, 2/scale, 1/scale]``.

        Raises:
            ValueError: if *scale* is not positive (0 would divide by zero and
                a negative value would yield a meaningless scale list).
        """
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        return [4 / scale, 2 / scale, 1 / scale]

    def inference(self, img0, img1, scale=1.0):
        """Run the network on two images and return its last merged output.

        Args:
            img0, img1: input images, concatenated along dim 1.
            scale: divisor applied to the base scale list ``[4, 2, 1]``;
                must be positive.

        Returns:
            ``merged[2]`` from the network's ``(flow, mask, merged)`` output.
        """
        imgs = torch.cat((img0, img1), 1)
        flow, mask, merged = self.flownet(imgs, self._scale_list(scale))
        return merged[2]

    def _get_optimizer(self):
        """Return ``self.optimG``, creating an AdamW over ``flownet``'s parameters if unset."""
        if getattr(self, "optimG", None) is None:
            # lr is overwritten by ``update`` before every step, so defaults suffice.
            self.optimG = AdamW(self.flownet.parameters())
        return self.optimG

    @staticmethod
    def _sobel_l1(pred, target):
        """Per-pixel L1 distance between the Sobel gradients of *pred* and *target*.

        Both tensors are read as ``(N, C, H, W)`` and filtered channel by
        channel with zero padding, so the result has the same shape as *pred*.

        NOTE(review): ``self.sobel`` was never initialised by this class; this
        is an assumed Sobel-gradient L1 loss used as the fallback. Confirm it
        matches the intended smoothness loss.
        """
        n, c, h, w = pred.shape
        kernel_x = torch.tensor(
            [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]],
            dtype=pred.dtype,
            device=pred.device,
        ).view(1, 1, 3, 3)
        kernel_y = kernel_x.transpose(2, 3)
        # Filter pred and target in one conv call by stacking them on dim 0.
        stack = torch.cat(
            (pred.reshape(n * c, 1, h, w), target.reshape(n * c, 1, h, w)), 0
        )
        grad_x = F.conv2d(stack, kernel_x, padding=1)
        grad_y = F.conv2d(stack, kernel_y, padding=1)
        split = n * c
        loss = (grad_x[:split] - grad_x[split:]).abs() + (
            grad_y[:split] - grad_y[split:]
        ).abs()
        return loss.reshape(n, c, h, w)

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass on a batch and, if *training*, one optimizer step.

        Args:
            imgs: input images concatenated along dim 1 (the original code
                sliced channels ``:3`` and ``3:``, so presumably two 3-channel
                images — confirm).
            gt: target image. It is compared against ``merged[2]`` and also
                concatenated after *imgs* along dim 1 as network input.
            learning_rate: value written into every optimizer param group.
            mul, flow_gt: accepted for interface compatibility; unused.
            training: if true, use train mode and take an optimizer step;
                otherwise use eval mode and only compute losses.

        Returns:
            ``(merged[2], info)``. ``info`` holds ``mask``, the first two
            channels of ``flow[2]``, and the ``loss_l1`` / ``loss_cons`` /
            ``loss_smooth`` tensors.
        """
        optimizer = self._get_optimizer()
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        # The scale list is passed positionally, exactly as ``inference`` does,
        # so this call does not depend on the forward parameter's name.
        # NOTE(review): the original passed it as ``scale=`` by keyword.
        scale_list = [4, 2, 1]
        flow, mask, merged = self.flownet(
            torch.cat((imgs, gt), 1), scale_list, training=training
        )
        loss_l1 = (merged[2] - gt).abs().mean()
        sobel = self.sobel if getattr(self, "sobel", None) is not None else self._sobel_l1
        # Penalise Sobel gradients of the final flow against an all-zero field.
        loss_smooth = sobel(flow[2], torch.zeros_like(flow[2])).mean()
        # NOTE(review): the original referenced an undefined ``loss_cons``
        # (NameError on every call). The only reconstruction term computed here
        # is the L1 loss, so it is used in its place. Confirm the intended objective.
        loss_cons = loss_l1
        if training:
            optimizer.zero_grad()
            loss_G = loss_cons + loss_smooth * 0.1
            loss_G.backward()
            optimizer.step()
        return merged[2], {
            'mask': mask,
            'flow': flow[2][:, :2],
            'loss_l1': loss_l1,
            'loss_cons': loss_cons,
            'loss_smooth': loss_smooth,
        }
|
| |
|