Adding RIFE_HDv3.py to 4.25
Browse files- 4.25/train_log/RIFE_HDv3.py +89 -0
4.25/train_log/RIFE_HDv3.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch.optim import AdamW
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
import itertools
|
| 7 |
+
from model.warplayer import warp
|
| 8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 9 |
+
from train_log.IFNet_HDv3 import *
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from model.loss import *
|
| 12 |
+
|
| 13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
|
| 15 |
+
class Model:
|
| 16 |
+
def __init__(self, local_rank=-1):
|
| 17 |
+
self.flownet = IFNet()
|
| 18 |
+
self.device()
|
| 19 |
+
self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
|
| 20 |
+
self.epe = EPE()
|
| 21 |
+
self.version = 4.25
|
| 22 |
+
# self.vgg = VGGPerceptualLoss().to(device)
|
| 23 |
+
self.sobel = SOBEL()
|
| 24 |
+
if local_rank != -1:
|
| 25 |
+
self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
|
| 26 |
+
|
| 27 |
+
def train(self):
|
| 28 |
+
self.flownet.train()
|
| 29 |
+
|
| 30 |
+
def eval(self):
|
| 31 |
+
self.flownet.eval()
|
| 32 |
+
|
| 33 |
+
def device(self):
|
| 34 |
+
self.flownet.to(device)
|
| 35 |
+
|
| 36 |
+
def load_model(self, path, rank=0):
|
| 37 |
+
def convert(param):
|
| 38 |
+
if rank == -1:
|
| 39 |
+
return {
|
| 40 |
+
k.replace("module.", ""): v
|
| 41 |
+
for k, v in param.items()
|
| 42 |
+
if "module." in k
|
| 43 |
+
}
|
| 44 |
+
else:
|
| 45 |
+
return param
|
| 46 |
+
if rank <= 0:
|
| 47 |
+
if torch.cuda.is_available():
|
| 48 |
+
self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))), False)
|
| 49 |
+
else:
|
| 50 |
+
self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location ='cpu')), False)
|
| 51 |
+
|
| 52 |
+
def save_model(self, path, rank=0):
|
| 53 |
+
if rank == 0:
|
| 54 |
+
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
| 55 |
+
|
| 56 |
+
def inference(self, img0, img1, timestep=0.5, scale=1.0):
|
| 57 |
+
imgs = torch.cat((img0, img1), 1)
|
| 58 |
+
scale_list = [16/scale, 8/scale, 4/scale, 2/scale, 1/scale]
|
| 59 |
+
flow, mask, merged = self.flownet(imgs, timestep, scale_list)
|
| 60 |
+
return merged[-1]
|
| 61 |
+
|
| 62 |
+
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
|
| 63 |
+
for param_group in self.optimG.param_groups:
|
| 64 |
+
param_group['lr'] = learning_rate
|
| 65 |
+
img0 = imgs[:, :3]
|
| 66 |
+
img1 = imgs[:, 3:]
|
| 67 |
+
if training:
|
| 68 |
+
self.train()
|
| 69 |
+
else:
|
| 70 |
+
self.eval()
|
| 71 |
+
scale = [16, 8, 4, 2, 1]
|
| 72 |
+
flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
|
| 73 |
+
loss_l1 = (merged[-1] - gt).abs().mean()
|
| 74 |
+
loss_smooth = self.sobel(flow[-1], flow[-1]*0).mean()
|
| 75 |
+
# loss_vgg = self.vgg(merged[-1], gt)
|
| 76 |
+
if training:
|
| 77 |
+
self.optimG.zero_grad()
|
| 78 |
+
loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
|
| 79 |
+
loss_G.backward()
|
| 80 |
+
self.optimG.step()
|
| 81 |
+
else:
|
| 82 |
+
flow_teacher = flow[2]
|
| 83 |
+
return merged[-1], {
|
| 84 |
+
'mask': mask,
|
| 85 |
+
'flow': flow[-1][:, :2],
|
| 86 |
+
'loss_l1': loss_l1,
|
| 87 |
+
'loss_cons': loss_cons,
|
| 88 |
+
'loss_smooth': loss_smooth,
|
| 89 |
+
}
|