| | import os
|
| | import torch
|
| | from torch.nn import functional as F
|
| |
|
| | from .ssim import ssim_matlab
|
| |
|
| | from .RIFE_HDv3 import Model as ModelV3
|
| | from .RIFE_V4 import Model as ModelV4
|
| |
|
def get_frame(frames, frame_no):
    """Fetch frame ``frame_no`` from ``frames`` as floats clipped to [0, 1].

    ``frames`` is indexed along dim 1 (``frames[:, frame_no]``). uint8 input is
    divided by 255; any other dtype is mapped with ``(x + 1) / 2`` (i.e. it is
    expected to be in [-1, 1]). Returns ``None`` once ``frame_no`` is at or past
    ``frames.shape[1]``.
    """
    if frame_no < frames.shape[1]:
        picked = frames[:, frame_no]
        if picked.dtype != torch.uint8:
            return ((picked + 1) / 2).clip(0., 1.)
        return picked.float().div_(255.0).clip(0., 1.)
    return None
|
| |
|
def add_frame(frames, frame, h, w):
    """Rescale ``frame`` to [-1, 1], crop it, and append a CPU copy to ``frames``.

    ``frame`` is expected in [0, 1]; it is mapped with ``x * 2 - 1`` and clipped.
    A leading size-1 dim is squeezed away, dims 1 and 2 of the result are cropped
    to ``h`` x ``w`` (dropping any padding), and a size-1 dim is inserted at
    position 1 so the appended tensors can be concatenated along dim 1.
    """
    rescaled = frame.mul(2).sub(1).clip(-1., 1.)
    cropped = rescaled.squeeze(0)[:, :h, :w]
    frames.append(cropped.unsqueeze(1).cpu())
|
| |
|
def process_frames(model, device, frames, exp):
    """Insert ``2**exp - 1`` interpolated frames between consecutive input frames.

    Args:
        model: Loaded interpolation model. ``model.inference`` is called as
            ``inference(I0, I1, scale)``, or as ``inference(I0, I1, t, scale)``
            when ``model.supports_timestep`` is truthy. ``model.pad_mod``
            (default 32) sets the multiple the spatial size is padded up to.
        device: Device frames are moved to before inference.
        frames: Tensor indexed as ``frames[:, t]`` (see ``get_frame``); uint8
            in [0, 255] or float in [-1, 1]. Each ``frames[:, t]`` must be 3-D.
        exp: Every gap between two kept frames is filled with ``2**exp - 1``
            generated frames (none when ``exp`` is 0).

    Returns:
        CPU tensor in [-1, 1]: all output frames concatenated along dim 1.

    Raises:
        ValueError: If ``frames`` contains no frames.
    """
    lastframe = get_frame(frames, 0)
    if lastframe is None:
        # Otherwise the shape unpack below fails with an opaque AttributeError.
        raise ValueError("frames must contain at least one frame")
    _, h, w = lastframe.shape
    scale = 1
    fp16 = False
    supports_timestep = getattr(model, "supports_timestep", False)
    pad_mod = getattr(model, "pad_mod", 32)

    def make_inference(I0, I1, n):
        """Return ``n`` generated frames between ``I0`` and ``I1``, in time order."""
        if n <= 0:
            return []
        if supports_timestep:
            return [model.inference(I0, I1, (i + 1) / (n + 1), scale) for i in range(n)]
        # Without a timestep argument the model yields only the middle frame,
        # so recursively bisect each half.
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n // 2)
        second_half = make_inference(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    # Pad H and W up to a multiple of `tmp` (presumably required by the model's
    # downsampling); add_frame crops the padding off again via h and w.
    tmp = max(pad_mod, int(pad_mod / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    def pad_image(img):
        padded = F.pad(img, padding)
        return padded.half() if fp16 else padded

    def to_input(frame):
        # 3-D frame -> padded, batched (1, ...) tensor on `device`.
        return pad_image(frame.to(device, non_blocking=True).unsqueeze(0))

    def downsample(img):
        # Small thumbnail used only for the SSIM similarity checks.
        return F.interpolate(img, (32, 32), mode='bilinear', align_corners=False)

    output_frames = []
    pos = 0
    I1 = to_input(lastframe)
    temp = None  # frame read ahead by the duplicate-frame branch; consumed next iteration

    while True:
        if temp is not None:
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = to_input(frame)
        I0_small = downsample(I0)
        I1_small = downsample(I1)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # Near-duplicate frame: replace I1 with the middle frame between I0 and
        # the next input frame, which is read ahead into `temp`.
        # NOTE(review): `pos > 100` forces this branch for every frame after the
        # 100th, so those input frames are replaced by model outputs — confirm
        # this is intentional.
        if ssim > 0.996 or pos > 100:
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = to_input(frame)
            if supports_timestep:
                I1 = model.inference(I0, I1, 0.5, scale)
            else:
                I1 = model.inference(I0, I1, scale)
            I1_small = downsample(I1)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]

        if ssim < 0.2:
            # Very dissimilar pair (presumably a scene cut): repeat I0 rather
            # than interpolating across it.
            output = [I0] * ((2 ** exp) - 1)
        else:
            output = make_inference(I0, I1, 2 ** exp - 1) if exp else []

        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break

    add_frame(output_frames, lastframe, h, w)
    return torch.cat(output_frames, dim=1)
|
| |
|
def temporal_interpolation(model_path, frames, exp, device="cuda", rife_version="v3"):
    """Load a RIFE model and temporally upsample ``frames``.

    Args:
        model_path: Forwarded to the model's ``load_model``.
        frames: uint8 frames in [0, 255] or float frames in [-1, 1], indexed as
            ``frames[:, t]`` (see ``process_frames``).
        exp: ``2**exp - 1`` frames are generated between consecutive frames.
        device: Device used for the model and for the frames during inference.
        rife_version: ``"v4"`` selects ``ModelV4``; any other value falls back
            to ``ModelV3``.

    Returns:
        The interpolated frames from ``process_frames``. When the input was
        uint8 they are converted back to uint8 in [0, 255]; otherwise they stay
        float in [-1, 1].
    """
    input_was_uint8 = frames.dtype == torch.uint8
    model = ModelV4() if rife_version == "v4" else ModelV3()
    # NOTE(review): the meaning of the second load_model argument (-1) is
    # defined by the model class — confirm there.
    model.load_model(model_path, -1, device=device)
    model.eval()
    model.to(device=device)

    with torch.no_grad():
        output = process_frames(model, device, frames, exp)

    if input_was_uint8:
        # Round before casting: .to(torch.uint8) truncates, and the float round
        # trip v / 255 -> [-1, 1] -> [0, 255] can land just below the original
        # integer (e.g. 0.99999 -> 0), shifting pixel values by one.
        output = output.add_(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
    return output
|
| |
|