|
|
import os |
|
|
import torch |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from .ssim import ssim_matlab |
|
|
|
|
|
from .RIFE_HDv3 import Model |
|
|
|
|
|
def get_frame(frames, frame_no):
    """Return one frame of ``frames`` remapped by ``(x + 1) / 2`` and clipped to [0, 1].

    Args:
        frames: array/tensor whose axis 1 is the frame index.
        frame_no: index of the frame to fetch along axis 1.

    Returns:
        ``frames[:, frame_no]`` after the remap and clip, or ``None`` when
        ``frame_no`` is at or beyond the number of frames on axis 1.

    Note:
        Negative ``frame_no`` values are not rejected here; they pass the
        bound check and are handed to the indexing expression unchanged.
    """
    frame_count = frames.shape[1]
    if frame_no < frame_count:
        shifted = frames[:, frame_no] + 1
        return (shifted / 2).clip(0., 1.)
    return None
|
|
|
|
|
def add_frame(frames, frame, h, w):
    """Remap a frame by ``x * 2 - 1``, crop it, and append it to ``frames``.

    Args:
        frames: list that receives the processed frame (mutated in place).
        frame: either a 3-D ``(C, H, W)`` tensor or a 4-D tensor with a
            leading batch axis of size 1.
        h: number of rows to keep (crop removes any padding beyond it).
        w: number of columns to keep.

    Returns:
        None. The appended tensor has a new axis inserted at position 1 and
        has been moved to the CPU.
    """
    frame = ((frame * 2) - 1).clip(-1., 1.)
    # Only strip the leading axis when it is a batch axis. An unconditional
    # squeeze(0) would also remove the channel axis of a single-channel
    # (1, H, W) frame and make the 3-index crop below fail.
    if frame.dim() == 4:
        frame = frame.squeeze(0)
    frame = frame[:, :h, :w].unsqueeze(1)
    frames.append(frame.cpu())
|
|
|
|
|
def process_frames(model, device, frames, exp):
    """Insert ``2**exp - 1`` model-generated frames between consecutive frames.

    Args:
        model: object exposing ``inference(I0, I1, scale)``; called on two
            padded, batched frames and expected to return a frame between them.
        device: device each frame is moved to before inference.
        frames: 4-D tensor indexed by frame on axis 1 and ending in
            ``(H, W)``; presumably ``(C, T, H, W)`` with values in [-1, 1]
            (confirm against the caller).
        exp: non-negative integer exponent; a falsy value disables insertion.

    Returns:
        All emitted frames concatenated along dim 1 (each frame is moved to
        the CPU by ``add_frame`` before concatenation).

    Raises:
        ValueError: if ``exp`` is negative or ``frames`` holds no frames.
    """
    # A negative exponent makes 2**exp - 1 fractional, which would send
    # make_inference into unbounded recursion; reject it up front.
    if exp and exp < 0:
        raise ValueError("exp must be a non-negative integer, got {!r}".format(exp))
    # Without at least one frame get_frame() returns None and the shape
    # unpacking below would fail with an unhelpful AttributeError.
    if frames.shape[1] == 0:
        raise ValueError("frames must contain at least one frame")

    pos = 0
    output_frames = []

    lastframe = get_frame(frames, 0)
    _, h, w = lastframe.shape
    scale = 1
    fp16 = False
    # Frames inserted per gap; loop-invariant, and 0 whenever exp is falsy.
    n_mid = (2 ** exp - 1) if exp else 0

    def make_inference(I0, I1, n):
        """Recursively bisect the (I0, I1) interval, returning ``n`` frames."""
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n // 2)
        second_half = make_inference(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    # Pad the bottom/right edges so H and W become multiples of `tmp`.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    def pad_image(img):
        """Apply the bottom/right padding, casting to half when fp16 is set."""
        padded = F.pad(img, padding)
        return padded.half() if fp16 else padded

    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
    I1 = pad_image(I1)
    # Holds a frame that was read ahead and still has to be processed.
    temp = None

    while True:
        if temp is not None:
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
        I1 = pad_image(I1)
        # Compare 32x32 thumbnails of the first three channels.
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # Score above 0.996: the current frame is replaced by the model's blend
        # of I0 and the *next* frame, which is read ahead and kept in `temp`.
        # NOTE(review): `pos > 100` forces this path for every frame past
        # index 100 regardless of the score; the intent is not evident from
        # this file — confirm before changing.
        if ssim > 0.996 or pos > 100:
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                # No next frame: blend against the previous frame and stop
                # after this iteration.
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
            I1 = pad_image(I1)
            I1 = model.inference(I0, I1, scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            # Drop the batch axis and the padding to get back to (C, h, w).
            frame = I1[0][:, :h, :w]

        if ssim < 0.2:
            # Score below 0.2 (presumably a scene change): repeat the earlier
            # frame instead of asking the model to blend the two.
            output = [I0] * n_mid
        else:
            output = make_inference(I0, I1, n_mid) if n_mid else []

        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break

    # The final frame has no successor, so it is emitted after the loop.
    add_frame(output_frames, lastframe, h, w)
    return torch.cat(output_frames, dim=1)
|
|
|
|
|
def temporal_interpolation(model_path, frames, exp, device="cuda"):
    """Load a ``Model`` from ``model_path`` and run ``process_frames`` with it.

    Args:
        model_path: location passed to ``Model.load_model``.
        frames: tensor of frames; converted with ``.float()`` before processing.
        exp: exponent forwarded to ``process_frames``.
        device: device used for loading the model and for inference.

    Returns:
        Whatever ``process_frames`` returns for the given frames.
    """
    interpolator = Model()
    # The meaning of the positional -1 is defined by Model.load_model,
    # which is not visible in this file.
    interpolator.load_model(model_path, -1, device=device)

    interpolator.eval()
    interpolator.to(device=device)

    # Inference only: run with gradient tracking disabled.
    with torch.no_grad():
        return process_frames(interpolator, device, frames.float(), exp)
|
|
|