import os
import shutil
import subprocess

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download


# --- MODEL SETUP ---
# Using the 2x model that requires a fixed 64x64 input.
def load_model():
    """Download the Real-ESRGAN x2plus ONNX model and build a CPU session.

    Returns:
        ort.InferenceSession: session pinned to the CPU provider with 2
        intra-op threads (keeps CPU/memory use predictable on small hosts).
    """
    model_path = hf_hub_download(
        repo_id="tidus2102/Real-ESRGAN",
        filename="Real-ESRGAN_x2plus.onnx",
    )
    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = 2
    return ort.InferenceSession(
        model_path, sess_options, providers=["CPUExecutionProvider"]
    )


session = load_model()


def upscale_frame_tiled(frame):
    """Upscale one BGR frame 2x with Real-ESRGAN using 64x64 tiling.

    The model only accepts 64x64 inputs, so the frame is cut into tiles
    (edge tiles are reflect-padded up to 64x64), run through the network
    in batches, and stitched back into a single image.

    NOTE(review): tiles do not overlap, so faint seams may appear at tile
    boundaries on some content — confirm acceptable for the use case.

    Args:
        frame: HxWx3 uint8 BGR image (OpenCV convention).

    Returns:
        (2H)x(2W)x3 uint8 BGR upscaled image.
    """
    tile_size = 64
    h, w, c = frame.shape
    upscaled_img = np.zeros((h * 2, w * 2, c), dtype=np.uint8)
    tiles = []
    coords = []

    # 1. Collect all tiles first (converted to RGB float32 in [0, 1],
    # which is the layout the model expects).
    for y in range(0, h, tile_size):
        for x in range(0, w, tile_size):
            y_end, x_end = min(y + tile_size, h), min(x + tile_size, w)
            tile = frame[y:y_end, x:x_end]
            # Reflect-pad edge tiles up to the fixed 64x64 model input.
            if tile.shape[0] < tile_size or tile.shape[1] < tile_size:
                tile = cv2.copyMakeBorder(
                    tile,
                    0, tile_size - tile.shape[0],
                    0, tile_size - tile.shape[1],
                    cv2.BORDER_REFLECT,
                )
            tiles.append(
                cv2.cvtColor(tile, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            )
            coords.append((y, x, y_end - y, x_end - x))

    # 2. Process in batches of 32 (trades extra RAM for throughput).
    batch_size = 32
    all_outputs = []
    for i in range(0, len(tiles), batch_size):
        batch = np.array(tiles[i:i + batch_size])   # (B, 64, 64, 3)
        batch = np.transpose(batch, (0, 3, 1, 2))   # NHWC -> NCHW
        inputs = {session.get_inputs()[0].name: batch}
        output = session.run(None, inputs)[0]
        all_outputs.extend(output)

    # 3. Stitch upscaled tiles back, cropping off any padded region.
    for i, output in enumerate(all_outputs):
        y, x, actual_h, actual_w = coords[i]
        tile_out = np.clip(np.squeeze(output), 0, 1).transpose(1, 2, 0)
        tile_out = cv2.cvtColor(
            (tile_out * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR
        )
        upscaled_img[
            y * 2:y * 2 + actual_h * 2, x * 2:x * 2 + actual_w * 2
        ] = tile_out[:actual_h * 2, :actual_w * 2]

    return upscaled_img


def process_video(input_path, progress=gr.Progress()):
    """Upscale a video 2x frame-by-frame and re-mux its audio track.

    Args:
        input_path: path to the uploaded video; falsy input returns None.
        progress: Gradio progress tracker (a `gr.Progress()` default is
            the documented Gradio idiom for progress injection).

    Returns:
        Path to the upscaled mp4, or None when no input was given.
    """
    if not input_path:
        return None

    # 1. Copy to a sanitized local filename & detect FPS.
    local_input = "input_video_sanitized.mp4"
    shutil.copy(input_path, local_input)
    cap = cv2.VideoCapture(local_input)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps < 1:
        fps = 30  # Default if metadata is missing
    cap.release()

    # 2. Set up working paths.
    frames_dir, audio_path, output_video = (
        "temp_frames", "temp_audio.mp3", "upscaled_2x.mp4"
    )
    if os.path.exists(frames_dir):
        shutil.rmtree(frames_dir)
    os.makedirs(frames_dir)

    # 3. Extract audio & frames. List-form argv with shell=False avoids
    # shell-quoting/injection problems with unusual filenames.
    subprocess.run(
        ["ffmpeg", "-i", local_input, "-vn",
         "-acodec", "libmp3lame", audio_path, "-y"]
    )
    subprocess.run(
        ["ffmpeg", "-i", local_input, f"{frames_dir}/raw_%05d.png", "-y"]
    )

    raw_files = sorted(f for f in os.listdir(frames_dir) if f.startswith("raw_"))
    total = len(raw_files)

    # 4. Upscale loop (tiled AI inference with a Lanczos fallback).
    # out_idx keeps the out_%05d.png sequence gapless even if a source
    # frame is unreadable and has to be skipped.
    out_idx = 0
    for i, f_name in enumerate(raw_files):
        f_path = os.path.join(frames_dir, f_name)
        img = cv2.imread(f_path)
        if img is None:
            # Corrupt/unreadable frame: skip it instead of crashing the job.
            print(f"Warning: could not read frame {f_name}; skipping")
            os.remove(f_path)
            continue
        try:
            # THIS IS WHERE THE AI RUNS (tiled to prevent dimension error).
            res = upscale_frame_tiled(img)
        except Exception as e:
            # Best-effort fallback so one bad frame doesn't abort the video.
            print(f"Critical Error on Frame {i}: {e}")
            h, w = img.shape[:2]
            res = cv2.resize(img, (w * 2, h * 2),
                             interpolation=cv2.INTER_LANCZOS4)
        cv2.imwrite(os.path.join(frames_dir, f"out_{out_idx:05d}.png"), res)
        out_idx += 1
        os.remove(f_path)  # Conserve disk space
        if i % 2 == 0:
            progress(i / total, desc=f"AI Upscaling {i}/{total} @ {fps} FPS")

    # 5. Reassemble the video, muxing audio back in when extraction
    # produced a track.
    cmd = ["ffmpeg", "-framerate", str(fps), "-i", f"{frames_dir}/out_%05d.png"]
    if os.path.exists(audio_path):
        cmd += ["-i", audio_path]
    cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p",
            "-c:a", "aac", "-shortest", output_video, "-y"]
    subprocess.run(cmd)

    # Final cleanup of all temporaries.
    shutil.rmtree(frames_dir)
    for f in [audio_path, local_input]:
        if os.path.exists(f):
            os.remove(f)

    return output_video


# --- UI ---
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Video(label="Upscaled 2x Result"),
    title="Real-ESRGAN 2x (CPU Tiled)",
    description=(
        "Uses 64x64 tiling to bypass dimension errors. Note: This is "
        "detailed but slower on CPU.Takes About 70 Sec Per Frame For 720P"
    ),
)

if __name__ == "__main__":
    demo.launch()