File size: 8,587 Bytes
382c63c
8cfd70e
 
 
 
 
 
 
 
 
 
 
 
205f29e
8cfd70e
 
 
 
205f29e
8cfd70e
7ec6c48
 
8cfd70e
 
 
 
7ec6c48
8cfd70e
 
 
 
 
 
7ec6c48
8cfd70e
 
 
7ec6c48
 
 
 
 
 
 
205f29e
7ec6c48
8cfd70e
 
7ec6c48
 
 
 
8cfd70e
205f29e
8cfd70e
7ec6c48
8cfd70e
 
 
205f29e
 
 
 
8cfd70e
205f29e
8cfd70e
7ec6c48
8cfd70e
 
205f29e
 
 
 
7ec6c48
8cfd70e
 
 
 
140a4b8
205f29e
 
7ec6c48
205f29e
8cfd70e
205f29e
8cfd70e
7ec6c48
8cfd70e
 
 
7ec6c48
 
 
 
 
8cfd70e
 
7ec6c48
205f29e
8cfd70e
205f29e
7ec6c48
 
 
 
 
 
8cfd70e
7ec6c48
 
 
 
 
 
 
 
 
 
 
 
 
 
8cfd70e
7ec6c48
8cfd70e
 
 
 
205f29e
140a4b8
205f29e
 
7ec6c48
140a4b8
 
7ec6c48
140a4b8
 
 
 
 
205f29e
 
 
 
 
 
 
7ec6c48
205f29e
 
 
 
7ec6c48
140a4b8
 
205f29e
 
7ec6c48
 
205f29e
7ec6c48
 
205f29e
 
 
 
 
7ec6c48
205f29e
 
 
 
 
 
 
 
 
 
 
8cfd70e
 
205f29e
8cfd70e
205f29e
 
 
7ec6c48
205f29e
 
 
7ec6c48
205f29e
 
 
 
 
 
 
7ec6c48
205f29e
 
 
7ec6c48
205f29e
 
 
140a4b8
 
8cfd70e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import glob
import importlib.util
import os
import re
import shlex
import shutil
import subprocess

import spaces
import gradio as gr

# --- Constants ---
# All paths are anchored to the working directory at import time.
BASE_DIR = os.getcwd()
# Location of the Practical-RIFE checkout (cloned on first start-up).
RIFE_DIR = os.path.join(BASE_DIR, "Practical-RIFE")
# Archive holding the RIFE v4.26 model code (RIFE_HDv3.py, IFNet_HDv3.py)
# and weights (flownet.pkl).
MODEL_URL = (
    "https://huggingface.co/hzwer/RIFE/resolve/main/"
    "RIFEv4.26_0921.zip"
)

# --- Setup Functions ---
def run_command(command):
    """Run *command* and raise ``CalledProcessError`` if it exits non-zero.

    A ``str`` (or ``bytes``) command is executed through the shell, as before,
    so existing callers keep working. A list/tuple of arguments is executed
    directly without a shell. Passing a list with ``shell=True`` would make the
    shell treat only the first item as the command, so the shell is disabled
    for sequences. This also sidesteps quoting problems with paths that
    contain spaces or shell metacharacters.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.
    """
    use_shell = isinstance(command, (str, bytes))
    return subprocess.run(command, shell=use_shell, check=True)

def _clone_rife_repo():
    """Clone Practical-RIFE into RIFE_DIR unless the directory already exists."""
    if os.path.exists(RIFE_DIR):
        return
    print("Cloning Practical-RIFE...")
    run_command(
        "git clone https://github.com/hzwer/Practical-RIFE "
        + shlex.quote(RIFE_DIR)
    )


def _download_rife_weights():
    """Fetch the v4.26 archive and lay it out as train_log/ (code) + HDv3/ (weights)."""
    hdv3_dir = os.path.join(RIFE_DIR, "HDv3")
    # Key on the weights file itself rather than on the HDv3 directory:
    # an interrupted earlier run can leave an empty HDv3/ behind.
    if os.path.exists(os.path.join(hdv3_dir, "flownet.pkl")):
        return

    print("Downloading RIFE v4.26 model weights...")
    zip_path = os.path.join(RIFE_DIR, "RIFEv4.26_0921.zip")
    run_command(f"wget -O {shlex.quote(zip_path)} {shlex.quote(MODEL_URL)}")
    run_command(f"unzip -o {shlex.quote(zip_path)} -d {shlex.quote(RIFE_DIR)}")

    # Fix directory structure: the archive unpacks into RIFEv4.26_0921/.
    # Model code goes to train_log/ (made importable via __init__.py) and
    # the weights go to HDv3/.
    extract_folder = os.path.join(RIFE_DIR, "RIFEv4.26_0921")
    train_log_dir = os.path.join(RIFE_DIR, "train_log")
    os.makedirs(train_log_dir, exist_ok=True)
    os.makedirs(hdv3_dir, exist_ok=True)

    for name, dest_dir in (
        ("RIFE_HDv3.py", train_log_dir),
        ("IFNet_HDv3.py", train_log_dir),
        ("flownet.pkl", hdv3_dir),
    ):
        src = os.path.join(extract_folder, name)
        if os.path.exists(src):
            # Pass the full destination path so a leftover copy from an
            # earlier run is overwritten instead of raising shutil.Error.
            shutil.move(src, os.path.join(dest_dir, name))
        else:
            print(f"Warning: {name} not found in the extracted model archive.")

    with open(os.path.join(train_log_dir, "__init__.py"), "w"):
        pass

    # The archive layout may differ from what is expected; only remove
    # what actually exists.
    if os.path.isdir(extract_folder):
        shutil.rmtree(extract_folder)
    if os.path.exists(zip_path):
        os.remove(zip_path)


def _patch_skvideo_numpy_aliases():
    """Replace np.float / np.int in skvideo.io.abstract with the builtins."""
    try:
        spec = importlib.util.find_spec("skvideo.io.abstract")
        if not (spec and spec.origin):
            return
        with open(spec.origin, encoding="utf-8") as f:
            content = f.read()
        # Word boundaries and an escaped dot leave np.float32, np.int64,
        # np.integer etc. untouched. A plain substring replace would turn
        # them into undefined names such as `float32`.
        patched = re.sub(r"\bnp\.(float|int)\b", r"\1", content)
        if patched != content:
            with open(spec.origin, "w", encoding="utf-8") as f:
                f.write(patched)
    except Exception as e:
        # Best effort: a failed patch should not prevent the app from starting.
        print(f"Warning: skvideo patch failed: {e}")


def _patch_inference_encoder():
    """Make RIFE's inference_video.py encode with libx264 instead of mpeg4."""
    inference_script = os.path.join(RIFE_DIR, "inference_video.py")
    if not os.path.exists(inference_script):
        return
    with open(inference_script, encoding="utf-8") as f:
        content = f.read()
    if "libx264" in content:
        return  # already patched
    old_args = "-c:v', 'mpeg4', '-qscale:v', '1'"
    new_args = "-c:v', 'libx264', '-preset', 'medium', '-crf', '23'"
    if old_args not in content:
        print("Warning: mpeg4 encoder arguments not found in inference_video.py; not patched.")
        return
    with open(inference_script, "w", encoding="utf-8") as f:
        f.write(content.replace(old_args, new_args))


def setup_environment():
    """Set up RIFE, download weights, and patch libraries.

    Every step first checks whether its work is already done, so this can
    run on each start-up. Steps, in order:

    1. Clone Practical-RIFE into ``RIFE_DIR`` if missing.
    2. Download and unpack the RIFE v4.26 weights if ``HDv3/flownet.pkl``
       is missing.
    3. Rewrite ``np.float`` / ``np.int`` (aliases removed in newer NumPy)
       in scikit-video's ``skvideo/io/abstract.py``.
    4. Switch ``inference_video.py`` from mpeg4 to libx264 output.

    Raises:
        subprocess.CalledProcessError: if git, wget or unzip fails.
    """
    print("--- Starting Environment Setup ---")
    _clone_rife_repo()
    _download_rife_weights()
    _patch_skvideo_numpy_aliases()
    _patch_inference_encoder()
    print("--- Setup Complete ---")

setup_environment()

@spaces.GPU(required=True)
def interpolate_video(input_video_path, multi_factor):
    """Raise a clip's frame rate with RIFE, then re-encode it for browsers.

    Args:
        input_video_path: Path of the uploaded video; None/empty when the
            Gradio input is empty.
        multi_factor: Frame multiplier such as "2", "4" or "8". Any "x"
            characters are stripped, so "2x" is accepted too.

    Returns:
        Path to ``final_interpolated.mp4`` in BASE_DIR, or None if there was
        no input or any step failed (the error is printed).
    """
    if not input_video_path:
        return None

    factor = str(multi_factor).replace("x", "").strip()
    output_path = os.path.join(BASE_DIR, "output_rife.mp4")
    final_output_path = os.path.join(BASE_DIR, "final_interpolated.mp4")
    # If audio transfer fails, RIFE's script leaves its result under a
    # "_noaudio" suffix. Build that name from the stem only, so a ".mp4"
    # elsewhere in the path is never rewritten.
    stem, ext = os.path.splitext(output_path)
    expected_no_audio = f"{stem}_noaudio{ext}"

    # Clear results of earlier runs so a stale file is never returned.
    for stale_path in (output_path, final_output_path, expected_no_audio):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    try:
        print(f"Running RIFE with {factor}x on {input_video_path}")
        cmd = [
            'python3', 'inference_video.py',
            '--video', input_video_path,
            '--output', output_path,
            '--multi', factor,
            '--model', 'HDv3',
        ]
        # Run the script inside RIFE_DIR via cwd= instead of os.chdir().
        # chdir would change the working directory of the whole server process.
        subprocess.run(cmd, check=True, cwd=RIFE_DIR)

        # RIFE writes either the normal or the "_noaudio" output, depending
        # on whether the audio transfer succeeded.
        if os.path.exists(output_path):
            source_to_encode = output_path
        elif os.path.exists(expected_no_audio):
            source_to_encode = expected_no_audio
            print("Audio transfer failed inside RIFE, using no-audio version.")
        else:
            print("Error: Output video file not found after inference.")
            return None

        # Re-encode to H.264 / yuv420p with +faststart for web playback.
        print(f"Re-encoding {source_to_encode} to {final_output_path}...")
        subprocess.run(
            [
                'ffmpeg', '-i', source_to_encode,
                '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
                '-movflags', '+faststart',
                '-y', final_output_path,
            ],
            check=True,
        )
        return final_output_path

    except Exception as e:
        # UI boundary: report the failure and return an empty result.
        print(f"Error during interpolation: {e}")
        return None

# --- Stitching Logic ---
def _parse_resolution(resolution_choice, default=("1920", "1080")):
    """Split "WIDTHxHEIGHT" into two digit strings; fall back to *default*."""
    try:
        parts = resolution_choice.split("x")
    except AttributeError:  # e.g. None
        return default
    if len(parts) != 2:
        return default
    width, height = parts[0].strip(), parts[1].strip()
    if not (width.isdigit() and height.isdigit()):
        return default
    return width, height


def _has_audio_stream(path):
    """Return True if ffprobe reports at least one audio stream in *path*."""
    probe = subprocess.run(
        ['ffprobe', '-v', 'error', '-select_streams', 'a',
         '-show_entries', 'stream=index', '-of', 'csv=p=0', path],
        capture_output=True, text=True,
    )
    return bool(probe.stdout.strip())


def _normalize_clip(src, dst, width, height, fps=60, sample_rate=44100):
    """Re-encode *src* into *dst* with uniform size, fps, codecs and audio layout."""
    # Scale to fit, letterbox to the exact target size, and force square
    # pixels so every segment has the same sample aspect ratio.
    video_filter = (
        f"scale={width}:{height}:force_original_aspect_ratio=decrease,"
        f"pad={width}:{height}:(ow-iw)/2:(oh-ih)/2,setsar=1"
    )
    cmd = ['ffmpeg', '-i', src]
    if _has_audio_stream(src):
        cmd += ['-map', '0:v:0', '-map', '0:a:0']
    else:
        # The concat demuxer needs every segment to carry the same streams.
        # Give silent clips a generated silent track (cut to video length).
        cmd += [
            '-f', 'lavfi',
            '-i', f'anullsrc=channel_layout=stereo:sample_rate={sample_rate}',
            '-map', '0:v:0', '-map', '1:a:0', '-shortest',
        ]
    cmd += [
        '-r', str(fps),
        '-vf', video_filter,
        '-c:v', 'libx264', '-crf', '23', '-pix_fmt', 'yuv420p',
        '-c:a', 'aac', '-ar', str(sample_rate), '-ac', '2',
        '-y', dst,
    ]
    subprocess.run(cmd, check=True)


def stitch_videos(video_files, resolution_choice):
    """Normalize several clips to one format and concatenate them.

    Each clip is re-encoded to H.264/AAC at 60 fps, letterboxed into the
    chosen resolution, so that the concat demuxer can join them without
    re-encoding.

    Args:
        video_files: Uploaded clips. These are plain paths, or objects
            exposing ``.name`` as older Gradio ``gr.File`` values do.
        resolution_choice: "WIDTHxHEIGHT" string; invalid values fall back
            to 1920x1080.

    Returns:
        Path to ``final_stitched.mp4`` in BASE_DIR, or None if no files were
        given.

    Raises:
        subprocess.CalledProcessError: if any ffmpeg step fails.
    """
    if not video_files:
        return None

    target_w, target_h = _parse_resolution(resolution_choice)
    print(f"Stitching {len(video_files)} videos into {target_w}x{target_h}...")

    stitch_list_path = os.path.join(BASE_DIR, "stitch_list.txt")
    output_stitched = os.path.join(BASE_DIR, "final_stitched.mp4")
    temp_dir = os.path.join(BASE_DIR, "temp_stitch")

    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir)

    try:
        # 1. Normalize
        normalized_files = []
        for i, video in enumerate(video_files):
            src = getattr(video, "name", video)
            dst = os.path.join(temp_dir, f"norm_{i}.mp4")
            _normalize_clip(src, dst, target_w, target_h)
            normalized_files.append(dst)

        # 2. Create List (concat-demuxer syntax; a ' inside a quoted path
        # is written as '\'' : close quote, escaped quote, reopen).
        with open(stitch_list_path, 'w', encoding='utf-8') as f:
            for path in normalized_files:
                escaped = path.replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")

        # 3. Concatenate (stream copy works because step 1 made every
        # segment's streams identical).
        print("Concatenating...")
        subprocess.run(
            ['ffmpeg', '-f', 'concat', '-safe', '0', '-i', stitch_list_path,
             '-c', 'copy', '-movflags', '+faststart', '-y', output_stitched],
            check=True,
        )
    finally:
        # Clean up intermediates even when an ffmpeg step fails.
        shutil.rmtree(temp_dir, ignore_errors=True)
        if os.path.exists(stitch_list_path):
            os.remove(stitch_list_path)

    return output_stitched

# --- Gradio Interface ---
# Two tabs: single-clip RIFE interpolation, and multi-clip stitching.
with gr.Blocks() as demo:
    gr.Markdown("# 🎞️ RIFE: Interpolate & Stitch")

    with gr.Tabs():
        # Tab 1: raise the frame rate of one uploaded clip.
        with gr.TabItem("1. Smooth Video (Interpolate)"):
            gr.Markdown("Upload a single video to increase its framerate.")
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Input Video")
                    multi_select = gr.Dropdown(
                        choices=["2", "4", "8"],
                        value="2",
                        label="Multiplier",
                    )
                    interp_btn = gr.Button("Interpolate", variant="primary")
                with gr.Column():
                    video_output = gr.Video(label="Smoothed Output")
            interp_btn.click(
                fn=interpolate_video,
                inputs=[video_input, multi_select],
                outputs=video_output,
            )

        # Tab 2: normalize several clips and join them into one video.
        with gr.TabItem("2. Stitch Videos"):
            gr.Markdown("Upload multiple videos. They will be normalized to **60fps**.")
            with gr.Row():
                with gr.Column():
                    stitch_inputs = gr.File(
                        label="Upload Clips",
                        file_count="multiple",
                    )
                    res_select = gr.Dropdown(
                        choices=["1920x1080", "1280x1280", "1024x1024"],
                        value="1920x1080",
                        label="Resolution",
                    )
                    stitch_btn = gr.Button("Stitch Videos", variant="primary")
                with gr.Column():
                    stitch_output = gr.Video(label="Stitched Result")

            stitch_btn.click(
                fn=stitch_videos,
                inputs=[stitch_inputs, res_select],
                outputs=stitch_output,
            )

demo.launch()