import os
import torch
from torch.nn import functional as F
# from .model.pytorch_msssim import ssim_matlab
from .ssim import ssim_matlab
from .RIFE_HDv3 import Model as ModelV3
from .RIFE_V4 import Model as ModelV4
def get_frame(frames, frame_no):
    """Fetch frame `frame_no` from `frames` normalized to [0, 1], or None past the end.

    Dim 1 of `frames` indexes frames. uint8 data is scaled by 1/255; any other
    dtype is assumed to lie in [-1, 1] and is remapped, then clamped to [0, 1].
    """
    if frame_no >= frames.shape[1]:
        return None  # ran past the last frame
    picked = frames[:, frame_no]
    if picked.dtype == torch.uint8:
        normalized = picked.float() / 255.0
    else:
        normalized = (picked + 1) * 0.5
    return normalized.clip(0.0, 1.0)
def add_frame(frames, frame, h, w):
    """Map `frame` from [0, 1] back to [-1, 1], crop padding to (h, w), append to `frames`.

    The leading batch dim (if size 1) is squeezed away, a singleton time dim is
    inserted at position 1, and the result is moved to the CPU before appending.
    """
    denormalized = frame * 2 - 1
    denormalized = denormalized.clip(-1.0, 1.0)
    cropped = denormalized.squeeze(0)[:, :h, :w]
    frames.append(cropped.unsqueeze(1).cpu())
def process_frames(model, device, frames, exp):
    """Run RIFE interpolation over every consecutive pair of frames.

    For each input pair, 2**exp - 1 intermediate frames are synthesized, so the
    output has roughly 2**exp times as many frames.  Frames are read with
    get_frame (normalized to [0, 1]) and written with add_frame (back to
    [-1, 1]); the concatenated result tensor is returned.
    """
    pos = 0
    output_frames = []
    lastframe = get_frame(frames, 0)
    _, h, w = lastframe.shape
    scale = 1
    fp16 = False
    # v4-style models take an explicit timestep argument; older models only
    # produce the midpoint and need recursive bisection (see make_inference).
    supports_timestep = getattr(model, "supports_timestep", False)
    pad_mod = getattr(model, "pad_mod", 32)
    def make_inference(I0, I1, n):
        # Produce n evenly spaced in-between frames for the pair (I0, I1).
        if n <= 0:
            return []
        if supports_timestep:
            # Timestep-capable model: ask for each fractional time directly.
            return [model.inference(I0, I1, (i + 1) / (n + 1), scale) for i in range(n)]
        # Midpoint-only model: recurse on each half around the middle frame.
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n//2)
        second_half = make_inference(middle, I1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half]
    # Pad H and W up to the next multiple of tmp so the model's internal
    # downscaling divides evenly; add_frame crops the padding off again.
    tmp = max(pad_mod, int(pad_mod / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)
    def pad_image(img):
        # Right/bottom-pad to (ph, pw); cast to half precision when fp16 is on.
        if(fp16):
            return F.pad(img, padding).half()
        else:
            return F.pad(img, padding)
    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
    I1 = pad_image(I1)
    temp = None # save lastframe when processing static frame
    while True:
        # Either consume the frame stashed by the static-frame branch below,
        # or advance to the next input frame.
        if temp is not None:
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
        I1 = pad_image(I1)
        # SSIM on 32x32 thumbnails (first 3 channels) to cheaply classify the
        # pair as near-duplicate (very high) or scene cut (very low).
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
        break_flag = False
        if ssim > 0.996 or pos > 100:
            # Near-static pair: skip ahead one frame and synthesize a midpoint
            # between I0 and the skipped-to frame to stand in for the duplicate.
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                break_flag = True
                frame = lastframe
            else:
                temp = frame  # stash: this frame starts the next iteration
            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
            I1 = pad_image(I1)
            if supports_timestep:
                I1 = model.inference(I0, I1, 0.5, scale)
            else:
                I1 = model.inference(I0, I1, scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]
        if ssim < 0.2:
            # Scene cut: interpolation would ghost, so repeat I0 instead.
            output = []
            for _ in range((2 ** exp) - 1):
                output.append(I0)
        else:
            output = make_inference(I0, I1, 2**exp-1) if exp else []
        # Emit the leading frame followed by its interpolated in-betweens.
        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break
    # Flush the final frame, then join along the time dimension.
    add_frame(output_frames, lastframe, h, w)
    return torch.cat( output_frames, dim=1)
def temporal_interpolation(model_path, frames, exp, device="cuda", rife_version="v3"):
    """Load a RIFE model and interpolate `frames` by a factor of 2**exp.

    `rife_version` selects between the v3 and v4 model classes; the checkpoint
    at `model_path` is loaded onto `device`.  If the input tensor is uint8 the
    output is converted back to uint8 in [0, 255], otherwise the [-1, 1] float
    result from process_frames is returned as-is.
    """
    wants_uint8_out = frames.dtype == torch.uint8
    # Pick the architecture matching the requested RIFE version.
    model = ModelV4() if rife_version == "v4" else ModelV3()
    model.load_model(model_path, -1, device=device)
    model.eval()
    model.to(device=device)
    with torch.no_grad():
        result = process_frames(model, device, frames, exp)
    if wants_uint8_out:
        # process_frames emits [-1, 1] floats; map back to [0, 255] bytes.
        result = result.add_(1.0).mul_(127.5).clamp_(0, 255).to(torch.uint8)
    return result