r3gm committed
Commit 321992a · verified · 1 Parent(s): a680456

Upload 4 files

inference_video.py ADDED
@@ -0,0 +1,290 @@
+ import os
+ import cv2
+ import torch
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+ from torch.nn import functional as F
+ import warnings
+ import _thread
+ import skvideo.io
+ from queue import Queue, Empty
+ from model.pytorch_msssim import ssim_matlab
+
+ warnings.filterwarnings("ignore")
+
+ def transferAudio(sourceVideo, targetVideo):
+     import shutil
+     import moviepy.editor
+     tempAudioFileName = "./temp/audio.mkv"
+
+     # split audio from the original video file and store it in the "temp" directory
+     # clear the old "temp" directory if it exists
+     if os.path.isdir("temp"):
+         shutil.rmtree("temp")
+     # create a new "temp" directory
+     os.makedirs("temp")
+     # extract the audio stream losslessly from the source video
+     os.system('ffmpeg -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
+
+     targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
+     os.rename(targetVideo, targetNoAudio)
+     # combine the audio file and the new video file
+     os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+
+     if os.path.getsize(targetVideo) == 0:
+         # ffmpeg failed to merge the video and audio; retry with the audio transcoded to AAC
+         tempAudioFileName = "./temp/audio.m4a"
+         os.system('ffmpeg -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
+         os.system('ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+         if os.path.getsize(targetVideo) == 0:
+             # AAC is not supported by the selected container either; give up on audio
+             os.rename(targetNoAudio, targetVideo)
+             print("Audio transfer failed. Interpolated video will have no audio.")
+         else:
+             print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
+             # remove the audio-less video
+             os.remove(targetNoAudio)
+     else:
+         os.remove(targetNoAudio)
+
+     # remove the temp directory
+     shutil.rmtree("temp")
+
+ parser = argparse.ArgumentParser(description='Frame interpolation for a video or an image sequence')
+ parser.add_argument('--video', dest='video', type=str, default=None)
+ parser.add_argument('--output', dest='output', type=str, default=None)
+ parser.add_argument('--img', dest='img', type=str, default=None)
+ parser.add_argument('--montage', dest='montage', action='store_true', help='montage the original video next to the output')
+ parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
+ parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
+ parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
+ parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='try scale=0.5 for 4k video')
+ parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
+ parser.add_argument('--fps', dest='fps', type=int, default=None)
+ parser.add_argument('--png', dest='png', action='store_true', help='output frames as png images instead of a video')
+ parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension')
+ parser.add_argument('--exp', dest='exp', type=int, default=1)
+ parser.add_argument('--multi', dest='multi', type=int, default=2)
+
+ args = parser.parse_args()
+ if args.exp != 1:
+     args.multi = (2 ** args.exp)
+ assert args.video is not None or args.img is not None
+ if args.skip:
+     print("The skip flag is abandoned, please refer to issue #207.")
+ if args.UHD and args.scale == 1.0:
+     args.scale = 0.5
+ assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
+ if args.img is not None:
+     args.png = True
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.set_grad_enabled(False)
+ if torch.cuda.is_available():
+     torch.backends.cudnn.enabled = True
+     torch.backends.cudnn.benchmark = True
+     if args.fp16:
+         torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
+ from train_log.RIFE_HDv3 import Model
+ model = Model()
+ if not hasattr(model, 'version'):
+     model.version = 0
+ model.load_model(args.modelDir, -1)
+ print("Loaded 3.x/4.x HD model.")
+ model.eval()
+ model.device()
+
+ if args.video is not None:
+     videoCapture = cv2.VideoCapture(args.video)
+     fps = videoCapture.get(cv2.CAP_PROP_FPS)
+     tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+     videoCapture.release()
+     if args.fps is None:
+         fpsNotAssigned = True
+         args.fps = fps * args.multi
+     else:
+         fpsNotAssigned = False
+     videogen = skvideo.io.vreader(args.video)
+     lastframe = next(videogen)
+     fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+     video_path_wo_ext, ext = os.path.splitext(args.video)
+     print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
+     if not args.png and fpsNotAssigned:
+         print("The audio will be merged after the interpolation process.")
+     else:
+         print("Will not merge audio because the png or fps flag is used!")
+ else:
+     videogen = []
+     for f in os.listdir(args.img):
+         if 'png' in f:
+             videogen.append(f)
+     tot_frame = len(videogen)
+     videogen.sort(key=lambda x: int(x[:-4]))
+     lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+     videogen = videogen[1:]
+ h, w, _ = lastframe.shape
+ vid_out_name = None
+ vid_out = None
+ if args.png:
+     if not os.path.exists('vid_out'):
+         os.mkdir('vid_out')
+ else:
+     if args.output is not None:
+         vid_out_name = args.output
+     else:
+         vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
+     vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
+
+ def clear_write_buffer(user_args, write_buffer):
+     cnt = 0
+     while True:
+         item = write_buffer.get()
+         if item is None:
+             break
+         if user_args.png:
+             cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
+             cnt += 1
+         else:
+             vid_out.write(item[:, :, ::-1])
+
+ def build_read_buffer(user_args, read_buffer, videogen):
+     try:
+         for frame in videogen:
+             if user_args.img is not None:
+                 frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+             if user_args.montage:
+                 frame = frame[:, left: left + w]
+             read_buffer.put(frame)
+     except Exception:
+         pass
+     read_buffer.put(None)  # sentinel that tells the consumer the stream has ended
+
+ def make_inference(I0, I1, n):
+     global model
+     if model.version >= 3.9:
+         # 4.x models accept an arbitrary timestep, so the n intermediate
+         # frames can be sampled at evenly spaced positions directly
+         res = []
+         for i in range(n):
+             res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale))
+         return res
+     else:
+         # older models only predict the midpoint, so recurse on both halves
+         middle = model.inference(I0, I1, args.scale)
+         if n == 1:
+             return [middle]
+         first_half = make_inference(I0, middle, n=n//2)
+         second_half = make_inference(middle, I1, n=n//2)
+         if n % 2:
+             return [*first_half, middle, *second_half]
+         else:
+             return [*first_half, *second_half]
+
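For pre-3.9 models, the recursion above fills the gap by repeated bisection. A standalone sketch of the ordering it produces, with a hypothetical arithmetic midpoint standing in for `model.inference` (the function name here is illustrative):

    def bisect_order(a, b, n):
        # mirrors the recursive branch of make_inference above
        mid = (a + b) / 2
        if n == 1:
            return [mid]
        first_half = bisect_order(a, mid, n // 2)
        second_half = bisect_order(mid, b, n // 2)
        return [*first_half, mid, *second_half] if n % 2 else [*first_half, *second_half]

    print(bisect_order(0.0, 1.0, 3))  # [0.25, 0.5, 0.75]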
+ def pad_image(img):
+     if args.fp16:
+         return F.pad(img, padding).half()
+     else:
+         return F.pad(img, padding)
+
+ if args.montage:
+     left = w // 4
+     w = w // 2
+ # pad inputs so height and width are multiples of the model's window size,
+ # which grows as the flow is estimated at a smaller scale
+ tmp = max(128, int(128 / args.scale))
+ ph = ((h - 1) // tmp + 1) * tmp
+ pw = ((w - 1) // tmp + 1) * tmp
+ padding = (0, pw - w, 0, ph - h)
+ pbar = tqdm(total=tot_frame)
+ if args.montage:
+     lastframe = lastframe[:, left: left + w]
+ write_buffer = Queue(maxsize=500)
+ read_buffer = Queue(maxsize=500)
+ _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
+ _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
+
+ I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+ I1 = pad_image(I1)
+ temp = None  # saves lastframe when processing a static frame
+
+ while True:
+     if temp is not None:
+         frame = temp
+         temp = None
+     else:
+         frame = read_buffer.get()
+     if frame is None:
+         break
+     I0 = I1
+     I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+     I1 = pad_image(I1)
+     # compare 32x32 thumbnails to cheaply estimate frame similarity
+     I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
+     I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+     ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+     break_flag = False
+     if ssim > 0.996:
+         # nearly identical frames: skip ahead and synthesize a replacement
+         frame = read_buffer.get()  # read a new frame
+         if frame is None:
+             break_flag = True
+             frame = lastframe
+         else:
+             temp = frame
+         I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+         I1 = pad_image(I1)
+         I1 = model.inference(I0, I1, scale=args.scale)
+         I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+         frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+
+     if ssim < 0.2:
+         # likely a scene change: duplicate I0 rather than hallucinate motion
+         output = []
+         for i in range(args.multi - 1):
+             output.append(I0)
+         '''
+         output = []
+         step = 1 / args.multi
+         alpha = 0
+         for i in range(args.multi - 1):
+             alpha += step
+             beta = 1 - alpha
+             output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
+         '''
+     else:
+         output = make_inference(I0, I1, args.multi - 1)
+
+     if args.montage:
+         write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+         for mid in output:
+             mid = (mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)
+             write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
+     else:
+         write_buffer.put(lastframe)
+         for mid in output:
+             mid = (mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)
+             write_buffer.put(mid[:h, :w])
+     pbar.update(1)
+     lastframe = frame
+     if break_flag:
+         break
+
+ if args.montage:
+     write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+ else:
+     write_buffer.put(lastframe)
+ write_buffer.put(None)
+
+ import time
+ while not write_buffer.empty():
+     time.sleep(0.1)  # give the writer thread time to drain the queue
+ pbar.close()
+ if vid_out is not None:
+     vid_out.release()
+
+ # move audio to the new video file if appropriate
+ if not args.png and fpsNotAssigned and args.video is not None:
+     try:
+         transferAudio(args.video, vid_out_name)
+     except Exception:
+         print("Audio transfer failed. Interpolated video will have no audio.")
+         targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
+         os.rename(targetNoAudio, vid_out_name)
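With the flags defined above, a typical invocation would be `python3 inference_video.py --video input.mp4` for 2x interpolation with audio re-attached, or `python3 inference_video.py --video input.mp4 --multi 4 --fp16` for 4x with half-precision inference; `--UHD` halves the flow-estimation scale for 4K sources. (The input file name here is illustrative.)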
model/loss.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class EPE(nn.Module):
+     def __init__(self):
+         super(EPE, self).__init__()
+
+     def forward(self, flow, gt, loss_mask):
+         loss_map = (flow - gt.detach()) ** 2
+         loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
+         return loss_map * loss_mask
+
+
+ class Ternary(nn.Module):
+     # census-style loss: each pixel is described by the soft sign pattern of its
+     # 7x7 neighbourhood, and the two images are compared in that space
+     def __init__(self):
+         super(Ternary, self).__init__()
+         patch_size = 7
+         out_channels = patch_size * patch_size
+         self.w = np.eye(out_channels).reshape(
+             (patch_size, patch_size, 1, out_channels))
+         self.w = np.transpose(self.w, (3, 2, 0, 1))
+         self.w = torch.tensor(self.w).float().to(device)
+
+     def transform(self, img):
+         patches = F.conv2d(img, self.w, padding=3, bias=None)
+         transf = patches - img
+         transf_norm = transf / torch.sqrt(0.81 + transf**2)
+         return transf_norm
+
+     def rgb2gray(self, rgb):
+         r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
+         gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+         return gray
+
+     def hamming(self, t1, t2):
+         dist = (t1 - t2) ** 2
+         dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
+         return dist_norm
+
+     def valid_mask(self, t, padding):
+         n, _, h, w = t.size()
+         inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
+         mask = F.pad(inner, [padding] * 4)
+         return mask
+
+     def forward(self, img0, img1):
+         img0 = self.transform(self.rgb2gray(img0))
+         img1 = self.transform(self.rgb2gray(img1))
+         return self.hamming(img0, img1) * self.valid_mask(img0, 1)
+
+
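A quick sanity check of the census loss above (a sketch; the tensor values are illustrative): identical inputs should produce an all-zero distance map.

    x = torch.rand(1, 3, 64, 64).to(device)
    print(Ternary()(x, x).abs().max())  # tensor(0.): identical images match exactly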
+ class SOBEL(nn.Module):
+     def __init__(self):
+         super(SOBEL, self).__init__()
+         self.kernelX = torch.tensor([
+             [1, 0, -1],
+             [2, 0, -2],
+             [1, 0, -1],
+         ]).float()
+         self.kernelY = self.kernelX.clone().T
+         self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
+         self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
+
+     def forward(self, pred, gt):
+         N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
+         img_stack = torch.cat(
+             [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
+         sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
+         sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
+         pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
+         pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
+
+         L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y)
+         loss = L1X + L1Y
+         return loss
+
+ class MeanShift(nn.Conv2d):
+     def __init__(self, data_mean, data_std, data_range=1, norm=True):
+         c = len(data_mean)
+         super(MeanShift, self).__init__(c, c, kernel_size=1)
+         std = torch.Tensor(data_std)
+         self.weight.data = torch.eye(c).view(c, c, 1, 1)
+         if norm:
+             self.weight.data.div_(std.view(c, 1, 1, 1))
+             self.bias.data = -1 * data_range * torch.Tensor(data_mean)
+             self.bias.data.div_(std)
+         else:
+             self.weight.data.mul_(std.view(c, 1, 1, 1))
+             self.bias.data = data_range * torch.Tensor(data_mean)
+         self.requires_grad = False
+
+ class VGGPerceptualLoss(torch.nn.Module):
+     def __init__(self, rank=0):
+         super(VGGPerceptualLoss, self).__init__()
+         self.vgg_pretrained_features = models.vgg19(pretrained=True).features
+         self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, X, Y, indices=None):
+         X = self.normalize(X)
+         Y = self.normalize(Y)
+         if indices is None:  # default tap points after selected VGG19 layers
+             indices = [2, 7, 12, 21, 30]
+         weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
+         k = 0
+         loss = 0
+         for i in range(indices[-1]):
+             X = self.vgg_pretrained_features[i](X)
+             Y = self.vgg_pretrained_features[i](Y)
+             if (i + 1) in indices:
+                 loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
+                 k += 1
+         return loss
+
+ if __name__ == '__main__':
+     img0 = torch.zeros(3, 3, 256, 256).float().to(device)
+     img1 = torch.tensor(np.random.normal(
+         0, 1, (3, 3, 256, 256))).float().to(device)
+     ternary_loss = Ternary()
+     print(ternary_loss(img0, img1).shape)
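The perceptual loss can be smoke-tested the same way; a minimal sketch, assuming a CUDA device (the MeanShift normalizer is hard-coded to `.cuda()` in `__init__`) and a downloadable torchvision VGG19 checkpoint:

    vgg_loss = VGGPerceptualLoss().cuda()
    a = torch.rand(1, 3, 224, 224).cuda()
    b = torch.rand(1, 3, 224, 224).cuda()
    print(vgg_loss(a, b))  # scalar tensor that grows with perceptual difference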
model/pytorch_msssim/__init__.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+ import torch.nn.functional as F
+ from math import exp
+ import numpy as np
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def gaussian(window_size, sigma):
+     gauss = torch.Tensor([exp(-(x - window_size//2)**2 / float(2*sigma**2)) for x in range(window_size)])
+     return gauss / gauss.sum()
+
+
+ def create_window(window_size, channel=1):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
+     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+     return window
+
+ def create_window_3d(window_size, channel=1):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t())
+     _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
+     window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+     return window
+
+
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+     if val_range is None:
+         if torch.max(img1) > 128:
+             max_val = 255
+         else:
+             max_val = 1
+
+         if torch.min(img1) < -0.5:
+             min_val = -1
+         else:
+             min_val = 0
+         L = max_val - min_val
+     else:
+         L = val_range
+
+     padd = 0
+     (_, channel, height, width) = img1.size()
+     if window is None:
+         real_size = min(window_size, height, width)
+         window = create_window(real_size, channel=channel).to(img1.device)
+
+     # replicate-pad by half the default window (11 // 2 = 5) instead of zero-padding inside the conv
+     mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
+     mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
+     sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
+     sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
+
+     C1 = (0.01 * L) ** 2
+     C2 = (0.03 * L) ** 2
+
+     v1 = 2.0 * sigma12 + C2
+     v2 = sigma1_sq + sigma2_sq + C2
+     cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+     if size_average:
+         ret = ssim_map.mean()
+     else:
+         ret = ssim_map.mean(1).mean(1).mean(1)
+
+     if full:
+         return ret, cs
+     return ret
+
+
+ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+     if val_range is None:
+         if torch.max(img1) > 128:
+             max_val = 255
+         else:
+             max_val = 1
+
+         if torch.min(img1) < -0.5:
+             min_val = -1
+         else:
+             min_val = 0
+         L = max_val - min_val
+     else:
+         L = val_range
+
+     padd = 0
+     (_, _, height, width) = img1.size()
+     if window is None:
+         real_size = min(window_size, height, width)
+         window = create_window_3d(real_size, channel=1).to(img1.device)
+         # Channel is set to 1 since we consider color images as volumetric images
+
+     img1 = img1.unsqueeze(1)
+     img2 = img2.unsqueeze(1)
+
+     mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+     mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
+     sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
+     sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
+
+     C1 = (0.01 * L) ** 2
+     C2 = (0.03 * L) ** 2
+
+     v1 = 2.0 * sigma12 + C2
+     v2 = sigma1_sq + sigma2_sq + C2
+     cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+     if size_average:
+         ret = ssim_map.mean()
+     else:
+         ret = ssim_map.mean(1).mean(1).mean(1)
+
+     if full:
+         return ret, cs
+     return ret
+
+
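`ssim_matlab` is the variant inference_video.py calls on 32x32 thumbnails of consecutive frames; a minimal check of its fixed point (input shapes match that call):

    a = torch.rand(1, 3, 32, 32).to(device)
    print(ssim_matlab(a, a.clone()))  # tensor(1.): identical inputs score a perfect 1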
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+     device = img1.device
+     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
+     levels = weights.size()[0]
+     mssim = []
+     mcs = []
+     for _ in range(levels):
+         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+         mssim.append(sim)
+         mcs.append(cs)
+
+         img1 = F.avg_pool2d(img1, (2, 2))
+         img2 = F.avg_pool2d(img2, (2, 2))
+
+     mssim = torch.stack(mssim)
+     mcs = torch.stack(mcs)
+
+     # Normalize to avoid NaNs when training unstable models; not compliant with the original definition
+     if normalize:
+         mssim = (mssim + 1) / 2
+         mcs = (mcs + 1) / 2
+
+     pow1 = mcs ** weights
+     pow2 = mssim ** weights
+     # From the Matlab implementation: https://ece.uwaterloo.ca/~z70wang/research/iwssim/
+     output = torch.prod(pow1[:-1] * pow2[-1])
+     return output
+
+
+ # Classes to re-use window
+ class SSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True, val_range=None):
+         super(SSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.val_range = val_range
+
+         # Assume 3 channels for SSIM
+         self.channel = 3
+         self.window = create_window(window_size, channel=self.channel)
+
+     def forward(self, img1, img2):
+         (_, channel, _, _) = img1.size()
+
+         if channel == self.channel and self.window.dtype == img1.dtype:
+             window = self.window
+         else:
+             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
+             self.window = window
+             self.channel = channel
+
+         _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+         dssim = (1 - _ssim) / 2
+         return dssim
+
+ class MSSSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True, channel=3):
+         super(MSSSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.channel = channel
+
+     def forward(self, img1, img2):
+         return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
model/warplayer.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ import torch.nn as nn
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ backwarp_tenGrid = {}
+
+
+ def warp(tenInput, tenFlow):
+     # cache the identity sampling grid per (device, flow shape) pair
+     k = (str(tenFlow.device), str(tenFlow.size()))
+     if k not in backwarp_tenGrid:
+         tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
+             1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+         tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
+             1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+         backwarp_tenGrid[k] = torch.cat(
+             [tenHorizontal, tenVertical], 1).to(device)
+
+     # rescale the flow from pixel units to grid_sample's [-1, 1] coordinates
+     tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                          tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+     g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+     return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
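A quick property check of the backward warp above (a sketch; any 4D NCHW float tensor works): an all-zero flow should reproduce the input exactly, since the cached grid alone is the identity sampling pattern under align_corners=True.

    x = torch.arange(16., device=device).reshape(1, 1, 4, 4)
    flow = torch.zeros(1, 2, 4, 4, device=device)
    print(torch.allclose(warp(x, flow), x))  # True: zero flow is the identity warp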