Isi99999 committed
Commit 51416db · verified · 1 Parent(s): f381412

Adding RIFE_HDv3.py to 4.25

Files changed (1)
  1. 4.25/train_log/RIFE_HDv3.py +89 -0
4.25/train_log/RIFE_HDv3.py ADDED
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from train_log.IFNet_HDv3 import *
import torch.nn.functional as F
from model.loss import *

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        self.version = 4.25
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        if local_rank != -1:
            # Wrap the network for distributed training when a DDP rank is given.
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        def convert(param):
            # Checkpoints saved from a DDP-wrapped model prefix every key with
            # "module."; strip the prefix so the weights load into a bare IFNet.
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param
        if rank <= 0:
            # The positional False is strict=False: tolerate missing or extra keys.
            if torch.cuda.is_available():
                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path))), False)
            else:
                self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path), map_location='cpu')), False)

    def save_model(self, path, rank=0):
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, timestep=0.5, scale=1.0):
        imgs = torch.cat((img0, img1), 1)
        # Five-level coarse-to-fine pyramid; a smaller `scale` downsamples the
        # flow estimation more aggressively, which helps on high-resolution input.
        scale_list = [16/scale, 8/scale, 4/scale, 2/scale, 1/scale]
        flow, mask, merged = self.flownet(imgs, timestep, scale_list)
        return merged[-1]
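For reference, a minimal sketch of two-frame interpolation with this class, assuming it is run from inside the 4.25 directory so that the train_log and model packages resolve. The interpolate helper, the [0, 1] normalization, and the multiple-of-64 padding are illustrative assumptions (RIFE's own scripts pad to a multiple of 32 or 64 depending on version), not part of this commit:

import torch
import torch.nn.functional as F
from train_log.RIFE_HDv3 import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log', rank=-1)  # rank=-1 strips the DDP "module." prefix from the shipped checkpoint
model.eval()

def interpolate(img0, img1, timestep=0.5):
    # img0, img1: (1, 3, H, W) float tensors in [0, 1], already on `device`.
    _, _, h, w = img0.shape
    ph = ((h - 1) // 64 + 1) * 64  # pad so the coarsest pyramid level divides evenly
    pw = ((w - 1) // 64 + 1) * 64
    img0 = F.pad(img0, (0, pw - w, 0, ph - h))
    img1 = F.pad(img1, (0, pw - w, 0, ph - h))
    with torch.no_grad():
        mid = model.inference(img0, img1, timestep=timestep)
    return mid[:, :, :h, :w]  # crop the padding back off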
    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        scale = [16, 8, 4, 2, 1]
        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
        loss_l1 = (merged[-1] - gt).abs().mean()
        loss_smooth = self.sobel(flow[-1], flow[-1]*0).mean()
        # loss_vgg = self.vgg(merged[-1], gt)
        # The consistency/distillation loss is not computed in this release; keep
        # it at zero so the total loss and the return dict below stay defined.
        loss_cons = torch.tensor(0.0, device=imgs.device)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[-1], {
            'mask': mask,
            'flow': flow[-1][:, :2],
            'loss_l1': loss_l1,
            'loss_cons': loss_cons,
            'loss_smooth': loss_smooth,
            }
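A similarly hedged sketch of one optimization step through update(), with random tensors standing in for a real dataloader batch; it assumes the bundled IFNet_HDv3 accepts the scale and training keywords that update() passes through. loss_cons comes back as the zero placeholder noted above:

# Continues from the sketch above; shapes and learning rate are illustrative.
img0 = torch.rand(4, 3, 256, 256, device=device)  # batch of first frames
img1 = torch.rand(4, 3, 256, 256, device=device)  # batch of second frames
gt = torch.rand(4, 3, 256, 256, device=device)    # ground-truth middle frames

imgs = torch.cat((img0, img1), 1)  # (4, 6, 256, 256): update() slices channels 0-2 and 3-5
pred, info = model.update(imgs, gt, learning_rate=1e-6, training=True)
print(info['loss_l1'].item(), info['loss_smooth'].item())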