File size: 4,861 Bytes
3d981dc
 
 
 
 
 
 
 
 
 
bd9f747
3d981dc
 
bd9f747
 
3d981dc
 
b486f26
3d981dc
 
 
 
bd9f747
 
 
 
b486f26
8956519
 
 
 
bd9f747
 
8956519
bd9f747
 
8956519
 
 
bd9f747
8956519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd9f747
 
3d981dc
 
 
 
bd9f747
 
3d981dc
 
d8e4073
bd9f747
 
d8e4073
3d981dc
d8e4073
 
3d981dc
 
 
bd9f747
3d981dc
 
 
 
 
 
bd9f747
3d981dc
 
 
 
 
bd9f747
 
3d981dc
d8e4073
bd9f747
3d981dc
 
 
 
bd9f747
 
 
3d981dc
bd9f747
3d981dc
 
d8e4073
 
3d981dc
 
 
bd9f747
 
 
 
 
3d981dc
 
bd9f747
 
 
 
 
 
def50be
bd9f747
3d981dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import cv2
import shutil
import subprocess
import numpy as np
import onnxruntime as ort
import gradio as gr
from huggingface_hub import hf_hub_download

# --- MODEL SETUP ---
# Using the 2x model that requires 64x64 fixed input
def load_model():
    """Fetch the Real-ESRGAN x2 ONNX weights from the Hugging Face Hub and
    build a CPU-only onnxruntime inference session for them."""
    weights = hf_hub_download(
        repo_id="tidus2102/Real-ESRGAN", 
        filename="Real-ESRGAN_x2plus.onnx",
    )
    opts = ort.SessionOptions()
    # Keep the thread count low so the app behaves on shared/small CPU hosts.
    opts.intra_op_num_threads = 2
    return ort.InferenceSession(weights, opts, providers=['CPUExecutionProvider'])

# Module-level session created once at startup and shared by every request.
session = load_model()

def upscale_frame_tiled(frame):
    """Upscale a BGR frame 2x with the global ONNX session.

    The model takes fixed 64x64 inputs, so the frame is cut into 64x64
    tiles (edge tiles reflect-padded to size), inferred in batches, and
    the 128x128 outputs are stitched back with the padding cropped off.
    """
    TILE = 64
    height, width, channels = frame.shape
    result = np.zeros((height * 2, width * 2, channels), dtype=np.uint8)

    prepared = []    # normalized RGB float32 tiles, each exactly TILE x TILE
    placements = []  # (top, left, valid_h, valid_w) of each tile in the source

    # Pass 1: collect every tile, padding the right/bottom edges to TILE size.
    for top in range(0, height, TILE):
        for left in range(0, width, TILE):
            bottom = min(top + TILE, height)
            right = min(left + TILE, width)
            patch = frame[top:bottom, left:right]
            pad_bottom = TILE - patch.shape[0]
            pad_right = TILE - patch.shape[1]
            if pad_bottom or pad_right:
                patch = cv2.copyMakeBorder(patch, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT)
            prepared.append(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0)
            placements.append((top, left, bottom - top, right - left))

    # Pass 2: run the model on batches of tiles to amortize per-call overhead.
    batch_size = 32
    input_name = session.get_inputs()[0].name
    outputs = []
    for start in range(0, len(prepared), batch_size):
        chunk = np.array(prepared[start:start + batch_size])  # (B, 64, 64, 3) HWC
        chunk = np.transpose(chunk, (0, 3, 1, 2))             # (B, 3, 64, 64) NCHW
        outputs.extend(session.run(None, {input_name: chunk})[0])

    # Pass 3: paste each upscaled tile back, discarding the doubled padding.
    for (top, left, valid_h, valid_w), out in zip(placements, outputs):
        rgb = np.clip(np.squeeze(out), 0, 1).transpose(1, 2, 0)
        bgr = cv2.cvtColor((rgb * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR)
        result[top*2 : top*2 + valid_h*2, left*2 : left*2 + valid_w*2] = bgr[:valid_h*2, :valid_w*2]

    return result

def process_video(input_path, progress=gr.Progress()):
    """Upscale an uploaded video 2x, frame by frame, with Real-ESRGAN.

    Pipeline: copy the upload to a fixed local name, extract audio and PNG
    frames with ffmpeg, upscale every frame (falling back to Lanczos resize
    if the model errors on one), then reassemble video + original audio.

    Returns the output video path, or None when input is missing or frame
    extraction produced nothing.
    """
    if not input_path: return None
    
    # 1. Copy to a fixed, safe local name & detect FPS
    local_input = "input_video_sanitized.mp4"
    shutil.copy(input_path, local_input)
    
    cap = cv2.VideoCapture(local_input)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps < 1: fps = 30 # Default if metadata is missing
    cap.release()
    
    # 2. Setup Dirs
    frames_dir, audio_path, output_video = "temp_frames", "temp_audio.mp3", "upscaled_2x.mp4"
    if os.path.exists(frames_dir): shutil.rmtree(frames_dir)
    os.makedirs(frames_dir)

    # 3. Extract Audio & Frames.
    # Argument lists with shell=False (the default) avoid shell injection
    # through crafted filenames, unlike the f-string/shell=True approach.
    subprocess.run(["ffmpeg", "-i", local_input, "-vn", "-acodec", "libmp3lame", audio_path, "-y"])
    subprocess.run(["ffmpeg", "-i", local_input, os.path.join(frames_dir, "raw_%05d.png"), "-y"])
    
    raw_files = sorted(f for f in os.listdir(frames_dir) if f.startswith("raw_"))
    total = len(raw_files)
    if total == 0:
        # Extraction failed (corrupt/unsupported input): clean up and bail
        # instead of dividing by zero in the progress update below.
        shutil.rmtree(frames_dir)
        for f in [audio_path, local_input]:
            if os.path.exists(f): os.remove(f)
        return None

    # 4. Upscale Loop (tiled to satisfy the model's fixed 64x64 input).
    # out_idx counts only frames actually written so the out_%05d.png
    # sequence stays contiguous even if a raw frame is skipped.
    out_idx = 0
    for i, f_name in enumerate(raw_files):
        f_path = os.path.join(frames_dir, f_name)
        img = cv2.imread(f_path)
        if img is None:
            # Unreadable frame: skip it rather than crash the whole job.
            print(f"Skipping unreadable frame {f_name}")
            os.remove(f_path)
            continue
        
        try:
            # THIS IS WHERE THE AI RUNS (Tiled to prevent dimension error)
            res = upscale_frame_tiled(img)
        except Exception as e:
            # Fall back to classical interpolation so one bad frame
            # doesn't abort the entire video.
            print(f"Critical Error on Frame {i}: {e}")
            h, w = img.shape[:2]
            res = cv2.resize(img, (w*2, h*2), interpolation=cv2.INTER_LANCZOS4)
        cv2.imwrite(os.path.join(frames_dir, f"out_{out_idx:05d}.png"), res)
        out_idx += 1
        
        os.remove(f_path) # Conserve disk space
        if i % 2 == 0:
            progress(i/total, desc=f"AI Upscaling {i}/{total} @ {fps} FPS")

    # 5. Reassemble Video (mux the original audio track if one was extracted)
    assemble = ["ffmpeg", "-framerate", str(fps), "-i", os.path.join(frames_dir, "out_%05d.png")]
    if os.path.exists(audio_path):
        assemble += ["-i", audio_path]
    assemble += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-c:a", "aac", "-shortest", output_video, "-y"]
    subprocess.run(assemble)
    
    # Final Cleanup
    shutil.rmtree(frames_dir)
    for f in [audio_path, local_input]:
        if os.path.exists(f): os.remove(f)
        
    return output_video

# --- UI ---
# Gradio interface: one video upload in, one 2x-upscaled video out.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Video(label="Upscaled 2x Result"),
    title="Real-ESRGAN 2x (CPU Tiled)",
    description="Uses 64x64 tiling to bypass dimension errors. Note: This is detailed but slower on CPU.Takes About 70 Sec Per Frame For 720P"
)

# Launch the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()