File size: 4,472 Bytes
f523f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import torch
from torch.nn import functional as F
# from .model.pytorch_msssim import ssim_matlab
from .ssim import ssim_matlab

from .RIFE_HDv3 import Model as ModelV3
from .RIFE_V4 import Model as ModelV4

def get_frame(frames, frame_no):
    """Fetch frame `frame_no` (indexed along dim 1) normalized to [0, 1].

    Returns None when `frame_no` is past the end of the sequence.
    uint8 input is scaled by 1/255; float input is assumed to lie in
    [-1, 1] and is remapped, then clipped, to [0, 1].
    """
    if frame_no >= frames.shape[1]:
        return None
    selected = frames[:, frame_no]
    if selected.dtype == torch.uint8:
        return selected.float() / 255.0
    # Float path: map [-1, 1] -> [0, 1] and clamp stragglers.
    return ((selected + 1) / 2).clip(0.0, 1.0)

def add_frame(frames, frame, h, w):
    """Append `frame` (1, C, H, W) in [0, 1] to the list `frames`.

    The frame is rescaled to [-1, 1], cropped to h x w (removing the
    model padding), reshaped to (C, 1, h, w) and moved to the CPU.
    """
    rescaled = (frame * 2 - 1).clip(-1.0, 1.0)
    cropped = rescaled.squeeze(0)[:, :h, :w]
    frames.append(cropped.unsqueeze(1).cpu())

def process_frames(model, device, frames, exp):
    """Run RIFE interpolation over `frames`, inserting 2**exp - 1 synthetic
    frames between each consecutive pair of input frames.

    Args:
        model: RIFE model exposing .inference(); optional attributes
            `supports_timestep` (can synthesize at an arbitrary timestep)
            and `pad_mod` (spatial size multiple required by the model).
        device: torch device to run inference on.
        frames: tensor indexed along dim 1 by frame number; decoded by
            get_frame (uint8 or [-1, 1] float).
        exp: interpolation exponent; 0 disables interpolation.

    Returns:
        All output frames concatenated along dim 1, values in [-1, 1]
        (as produced by add_frame), on the CPU.
    """
    pos = 0
    output_frames = []

    lastframe = get_frame(frames, 0)
    _,  h, w = lastframe.shape
    scale = 1
    fp16 = False
    # Timestep-aware models can synthesize any in-between instant directly;
    # others only produce the midpoint, so make_inference bisects instead.
    supports_timestep = getattr(model, "supports_timestep", False)
    pad_mod = getattr(model, "pad_mod", 32)

    def make_inference(I0, I1, n):
        # Produce n evenly spaced frames between I0 and I1.
        if n <= 0:
            return []
        if supports_timestep:
            return [model.inference(I0, I1, (i + 1) / (n + 1), scale) for i in range(n)]
        # No timestep support: midpoint, then recurse into both halves.
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n//2)
        second_half = make_inference(middle, I1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half]

    # Pad spatial dims up to a multiple of pad_mod, as the model requires.
    tmp = max(pad_mod, int(pad_mod / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    def pad_image(img):
        # fp16 is currently hard-wired off; kept for parity with upstream RIFE.
        if(fp16):
            return F.pad(img, padding).half()
        else:
            return F.pad(img, padding)

    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
    I1 = pad_image(I1)
    temp = None # save lastframe when processing static frame

    while True:
        # `temp` holds a frame read ahead by the static-frame branch below;
        # consume it before advancing `pos` again.
        if temp is not None:
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
        I1 = pad_image(I1)
        # Cheap similarity estimate between consecutive frames on 32x32
        # thumbnails (first 3 channels only).
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # Near-duplicate (static) frame: skip one frame ahead and replace I1
        # with a frame synthesized between I0 and the skipped-to frame.
        # NOTE(review): the `pos > 100` clause forces this path for every
        # frame past index 100 regardless of similarity — confirm intended.
        if ssim > 0.996 or pos > 100:        
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                # Ran off the end: reuse lastframe and finish after this pass.
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
            I1 = pad_image(I1)
            if supports_timestep:
                I1 = model.inference(I0, I1, 0.5, scale)
            else:
                I1 = model.inference(I0, I1, scale)
            # Re-measure similarity against the synthesized replacement.
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]
        
        if ssim < 0.2:
            # Scene change: interpolation would ghost, so repeat I0 instead.
            output = []
            for _ in range((2 ** exp) - 1):
                output.append(I0)
        else:
            output = make_inference(I0, I1, 2**exp-1) if exp else []

        # Emit the source frame followed by its interpolated successors.
        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break

    # Flush the final pending frame.
    add_frame(output_frames, lastframe, h, w)
    return torch.cat( output_frames, dim=1)

def temporal_interpolation(model_path, frames, exp, device ="cuda", rife_version="v3"):
    """Temporally interpolate `frames` with a RIFE model loaded from disk.

    Args:
        model_path: path passed to the model's load_model().
        frames: input frames, uint8 or [-1, 1] float (see get_frame).
        exp: interpolation exponent; 2**exp - 1 frames inserted per pair.
        device: device string for model and inference (default "cuda").
        rife_version: "v4" selects ModelV4, anything else ModelV3.

    Returns:
        Interpolated frames; uint8 in [0, 255] when the input was uint8,
        otherwise float in [-1, 1].
    """
    was_uint8 = frames.dtype == torch.uint8

    model = ModelV4() if rife_version == "v4" else ModelV3()
    model.load_model(model_path, -1, device=device)
    model.eval()
    model.to(device=device)

    with torch.no_grad():
        result = process_frames(model, device, frames, exp)

    if was_uint8:
        # Map [-1, 1] back to [0, 255] in place before the dtype cast.
        result = result.add_(1.0).mul_(127.5).clamp_(0, 255).to(torch.uint8)
    return result