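# Hosting-page residue captured with this file (not part of the program):
#   author avatar text: "James040's picture"
#   commit message: "Update app.py" -- commit def50be (verified)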
import os
import cv2
import shutil
import subprocess
import numpy as np
import onnxruntime as ort
import gradio as gr
from huggingface_hub import hf_hub_download
# --- MODEL SETUP ---
# Using the 2x model that requires 64x64 fixed input
def load_model():
    """Fetch the Real-ESRGAN x2 ONNX weights and open a CPU inference session.

    The weights file is obtained through ``hf_hub_download`` from the
    ``tidus2102/Real-ESRGAN`` repository. The session is created with the
    CPU execution provider only and two intra-op threads.

    Returns:
        An ``onnxruntime.InferenceSession`` built from the downloaded file.
    """
    options = ort.SessionOptions()
    options.intra_op_num_threads = 2

    weights_file = hf_hub_download(
        repo_id="tidus2102/Real-ESRGAN",
        filename="Real-ESRGAN_x2plus.onnx",
    )

    return ort.InferenceSession(
        weights_file,
        options,
        providers=['CPUExecutionProvider'],
    )
# Module-level inference session: created once when this file is loaded and
# read by upscale_frame_tiled for every batch of tiles.
session = load_model()
def upscale_frame_tiled(frame, *, tile_size=64, batch_size=32, scale=2):
    """Upscale one BGR frame by running the ONNX model over fixed-size tiles.

    The frame is cut into ``tile_size`` x ``tile_size`` tiles (edge tiles are
    reflect-padded up to full size), the tiles are pushed through the
    module-level ``session`` in batches, and the valid part of every
    upscaled tile is pasted back into the output image.

    Args:
        frame: ``H x W x C`` uint8 image in BGR channel order (it is passed
            through ``cv2.COLOR_BGR2RGB`` before inference).
        tile_size: Side length of the square tiles fed to the model.
        batch_size: Number of tiles sent to the model per ``session.run``.
        scale: Upscaling factor. The stitching assumes each model output tile
            is ``tile_size * scale`` pixels on a side -- NOTE(review): confirm
            this against the model if ``tile_size`` or ``scale`` are changed.

    Returns:
        ``(H * scale) x (W * scale) x C`` uint8 image in BGR channel order.

    Raises:
        ValueError: If ``frame`` is ``None`` or is not a 3-dimensional array.
    """
    if frame is None or getattr(frame, "ndim", None) != 3:
        raise ValueError("frame must be an H x W x C image array")

    h, w, c = frame.shape
    upscaled_img = np.zeros((h * scale, w * scale, c), dtype=np.uint8)

    # 1. Collect all tiles first, remembering where each one came from and
    #    how much of it is real image (as opposed to padding).
    tiles = []
    coords = []
    for y in range(0, h, tile_size):
        for x in range(0, w, tile_size):
            y_end, x_end = min(y + tile_size, h), min(x + tile_size, w)
            tile = frame[y:y_end, x:x_end]

            # Edge tiles are smaller than the model's input; reflect-pad them
            # on the bottom/right so every tile has the same shape.
            pad_bottom = tile_size - tile.shape[0]
            pad_right = tile_size - tile.shape[1]
            if pad_bottom or pad_right:
                tile = cv2.copyMakeBorder(
                    tile, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT
                )

            rgb = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)
            tiles.append(rgb.astype(np.float32) / 255.0)
            coords.append((y, x, y_end - y, x_end - x))

    # 2. Run the model batch by batch. The input name never changes, so look
    #    it up once instead of once per batch.
    input_name = session.get_inputs()[0].name
    all_outputs = []
    for start in range(0, len(tiles), batch_size):
        batch = np.stack(tiles[start:start + batch_size])  # (N, tile, tile, 3)
        # Channels-first layout for the model: (N, 3, tile, tile).
        batch = np.ascontiguousarray(np.transpose(batch, (0, 3, 1, 2)))
        all_outputs.extend(session.run(None, {input_name: batch})[0])

    # 3. Stitch back, dropping the part of each output that came from padding.
    for (y, x, actual_h, actual_w), output in zip(coords, all_outputs):
        tile_out = np.clip(output, 0, 1).transpose(1, 2, 0)
        # Round to nearest before the uint8 cast; a bare astype() truncates
        # and would darken every pixel by up to one level.
        tile_out = np.rint(tile_out * 255.0).astype(np.uint8)
        tile_out = cv2.cvtColor(tile_out, cv2.COLOR_RGB2BGR)

        out_h, out_w = actual_h * scale, actual_w * scale
        top, left = y * scale, x * scale
        upscaled_img[top:top + out_h, left:left + out_w] = tile_out[:out_h, :out_w]

    return upscaled_img
def process_video(input_path, progress=gr.Progress()):
    """Upscale a video 2x frame by frame and return the path of the result.

    Steps: copy the upload to a fixed local name, read its frame rate,
    extract the audio track and every frame with ffmpeg, upscale each frame
    with ``upscale_frame_tiled`` (falling back to Lanczos resize if that
    raises), then re-encode the frames, plus the audio if any was extracted,
    into an H.264 MP4.

    NOTE: all intermediate and output files use fixed names in the current
    working directory, so overlapping calls would overwrite each other's
    files.

    Args:
        input_path: Path of the uploaded video. A falsy value short-circuits
            and returns ``None``.
        progress: Gradio progress tracker; called on every second frame with
            the completed fraction and a description.

    Returns:
        Path of the encoded video (``"upscaled_2x.mp4"``), or ``None`` when
        no input was given.

    Raises:
        RuntimeError: If no frames could be extracted, an extracted frame
            cannot be read back, or an upscaled frame cannot be written.
        subprocess.CalledProcessError: If ffmpeg fails while extracting
            frames or encoding the final video.
    """
    if not input_path:
        return None

    local_input = "input_video_sanitized.mp4"
    frames_dir, audio_path, output_video = "temp_frames", "temp_audio.mp3", "upscaled_2x.mp4"

    try:
        # 1. Sanitize Filename & Detect FPS
        shutil.copy(input_path, local_input)
        cap = cv2.VideoCapture(local_input)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()
        # "not >=" rather than "<" so a NaN frame rate also gets the default.
        if not (fps >= 1):
            fps = 30  # Default if metadata is missing

        # 2. Setup Dirs
        if os.path.exists(frames_dir):
            shutil.rmtree(frames_dir)
        os.makedirs(frames_dir)
        # Drop leftovers from an earlier run so a failed ffmpeg call below can
        # never cause stale audio to be muxed in or a stale video returned.
        for stale in (audio_path, output_video):
            if os.path.exists(stale):
                os.remove(stale)

        # 3. Extract Audio & Frames
        # Audio is best-effort: the input may simply have no audio stream.
        audio_run = subprocess.run(
            ["ffmpeg", "-y", "-i", local_input, "-vn", "-acodec", "libmp3lame", audio_path]
        )
        has_audio = (
            audio_run.returncode == 0
            and os.path.exists(audio_path)
            and os.path.getsize(audio_path) > 0
        )
        # Frames are mandatory, so a failure here is an error.
        subprocess.run(
            ["ffmpeg", "-y", "-i", local_input, f"{frames_dir}/raw_%05d.png"],
            check=True,
        )

        raw_files = sorted(f for f in os.listdir(frames_dir) if f.startswith("raw_"))
        total = len(raw_files)
        if total == 0:
            raise RuntimeError("No frames could be extracted from the input video")

        # 4. Upscale Loop (tiled inference, Lanczos as a per-frame fallback)
        for i, f_name in enumerate(raw_files):
            f_path = os.path.join(frames_dir, f_name)
            out_path = os.path.join(frames_dir, f"out_{i:05d}.png")

            img = cv2.imread(f_path)
            if img is None:
                # Without this check the fallback below would fail on
                # img.shape with an unrelated-looking AttributeError.
                raise RuntimeError(f"Could not read extracted frame {f_name}")

            try:
                res = upscale_frame_tiled(img)
            except Exception as e:
                # Deliberate best-effort: keep the video going with a plain
                # resize when the model fails on a frame.
                print(f"Critical Error on Frame {i}: {e}")
                h, w = img.shape[:2]
                res = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

            # A missing frame would leave a gap in the numbered sequence that
            # the encoder reads, so fail loudly instead of continuing.
            if not cv2.imwrite(out_path, res):
                raise RuntimeError(f"Could not write upscaled frame {out_path}")

            os.remove(f_path)  # Conserve disk space
            if i % 2 == 0:
                progress(i / total, desc=f"AI Upscaling {i}/{total} @ {fps} FPS")

        # 5. Reassemble Video
        ffmpeg_cmd = ["ffmpeg", "-y", "-framerate", str(fps), "-i", f"{frames_dir}/out_%05d.png"]
        if has_audio:
            ffmpeg_cmd += ["-i", audio_path]
        ffmpeg_cmd += [
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            "-c:a", "aac",
            "-shortest",
            output_video,
        ]
        subprocess.run(ffmpeg_cmd, check=True)

        return output_video
    finally:
        # Final Cleanup -- runs on success and on failure alike.
        shutil.rmtree(frames_dir, ignore_errors=True)
        for leftover in (audio_path, local_input):
            if os.path.exists(leftover):
                os.remove(leftover)
# --- UI ---
# Gradio front end: a single video input is passed to process_video and the
# path it returns is shown in a single video output.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Input Video"),
    outputs=gr.Video(label="Upscaled 2x Result"),
    title="Real-ESRGAN 2x (CPU Tiled)",
    description="Uses 64x64 tiling to bypass dimension errors. Note: This is detailed but slower on CPU.Takes About 70 Sec Per Frame For 720P"
)
# Launch the interface only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()