SpyC0der77's picture
Speed up
969332c verified
raw
history blame
7.57 kB
import cv2
import numpy as np
import torch
import tempfile
import gradio as gr
import time
import io
from contextlib import redirect_stdout
# Pick the compute device once at import time: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"[INFO] Using device: {device}")

# Turn on cuDNN benchmark mode when running on CUDA.
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

# Try to fetch the RAFT optical-flow model from torch.hub. Any failure while
# loading, moving to the device, or switching to eval mode leaves
# ``raft_model`` as None, which the rest of the file treats as "use OpenCV
# Farneback optical flow instead".
try:
    print("[INFO] Attempting to load RAFT model from torch.hub...")
    raft_model = torch.hub.load(
        "princeton-vl/RAFT",
        "raft_small",
        pretrained=True,
        trust_repo=True,
    ).to(device)
    raft_model.eval()
    print("[INFO] RAFT model loaded successfully.")
except Exception as load_error:
    print("[ERROR] Error loading RAFT model:", load_error)
    print("[INFO] Falling back to OpenCV Farneback optical flow.")
    raft_model = None
    gr.Warning("Falling back to OpenCV Farneback optical flow.")
def _resize(frame, w, h):
if frame.shape[1] == w and frame.shape[0] == h:
return frame
return cv2.resize(frame, (w, h), interpolation=cv2.INTER_AREA if (w < frame.shape[1] or h < frame.shape[0]) else cv2.INTER_LINEAR)
def _frame_to_raft_tensor_bgr(frame_bgr):
    """Convert a BGR frame into a batched, channel-first float tensor on ``device``.

    The frame is converted BGR -> RGB, reordered from (H, W, C) to (C, H, W),
    cast to float32, given a leading batch dimension of 1, and divided by 255.

    NOTE(review): presumably ``frame_bgr`` is an H x W x 3 uint8 image as
    returned by ``cv2.VideoCapture.read`` -- confirm against callers.
    NOTE(review): the result is scaled by 1/255; verify that the loaded RAFT
    model expects that range rather than raw 0-255 pixel values.
    """
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    channels_first = torch.from_numpy(rgb).permute(2, 0, 1).contiguous()
    batch = channels_first.float().unsqueeze(0)
    batch.div_(255.0)
    # Asynchronous host-to-device copies are only requested for CUDA targets.
    on_cuda = device.type == "cuda"
    return batch.to(device, non_blocking=on_cuda)
def compute_offsets(
    video_file,
    out_w,
    out_h,
    motion_scale=0.5,
    raft_iters=12,
    progress=gr.Progress(),
    progress_offset=0.0,
    progress_scale=0.55,
):
    """Estimate per-frame stabilizing offsets for a video.

    Every frame is resized to ``out_w`` x ``out_h`` and then to a smaller
    motion-estimation size (``motion_scale`` times the output size, at least
    64 pixels per side). The dense optical flow between consecutive small
    frames is computed with ``raft_model`` when it is loaded, otherwise with
    OpenCV's Farneback method. The median flow per axis is rescaled to output
    pixels and accumulated; the negated running sum is the offset for that frame.

    Args:
        video_file: Path passed to ``cv2.VideoCapture``.
        out_w: Output width in pixels that the offsets are expressed in.
        out_h: Output height in pixels that the offsets are expressed in.
        motion_scale: Factor applied to the output size to get the
            motion-estimation resolution.
        raft_iters: ``iters`` argument forwarded to ``raft_model``.
        progress: Callable invoked as ``progress(fraction, desc=...)``.
        progress_offset: Fraction reported at the start of this stage.
        progress_scale: Share of the overall progress range this stage covers.

    Returns:
        A list of ``(dx, dy)`` float tuples, one per frame read, starting
        with ``(0.0, 0.0)`` for the first frame.

    Raises:
        gr.Error: If the video cannot be opened or its first frame cannot be read.
    """
    cap = cv2.VideoCapture(video_file)
    # The capture is released in ``finally`` so that a failure during flow
    # estimation (model error, OpenCV error) does not leak the handle.
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open video file for motion estimation.")

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Motion is estimated at reduced resolution; sx/sy map the measured
        # flow back to output-resolution pixels.
        # NOTE(review): RAFT implementations commonly require input sizes
        # divisible by 8 -- confirm whether mw/mh need rounding for the
        # loaded model.
        mw = max(64, int(out_w * float(motion_scale)))
        mh = max(64, int(out_h * float(motion_scale)))
        sx = float(out_w) / float(mw)
        sy = float(out_h) / float(mh)

        ret, prev = cap.read()
        if not ret:
            raise gr.Error("Cannot read first frame from video.")

        prev_small = _resize(_resize(prev, out_w, out_h), mw, mh)

        use_raft = raft_model is not None
        use_amp = device.type == "cuda"
        iters = int(raft_iters)

        if use_raft:
            prev_t = _frame_to_raft_tensor_bgr(prev_small)
        else:
            prev_g = cv2.cvtColor(prev_small, cv2.COLOR_BGR2GRAY)

        offsets = [(0.0, 0.0)]
        cum_dx = 0.0
        cum_dy = 0.0
        idx = 1

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            curr_small = _resize(_resize(frame, out_w, out_h), mw, mh)

            if use_raft:
                curr_t = _frame_to_raft_tensor_bgr(curr_small)
                # Mixed precision is enabled only on CUDA; on CPU the
                # autocast context is a disabled no-op.
                with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
                    _, flow_up = raft_model(prev_t, curr_t, iters=iters, test_mode=True)
                flow = flow_up[0]
                dx = float(flow[0].median().item())
                dy = float(flow[1].median().item())
                prev_t = curr_t
            else:
                curr_g = cv2.cvtColor(curr_small, cv2.COLOR_BGR2GRAY)
                flow = cv2.calcOpticalFlowFarneback(
                    prev_g,
                    curr_g,
                    None,
                    pyr_scale=0.5,
                    levels=3,
                    winsize=15,
                    iterations=3,
                    poly_n=5,
                    poly_sigma=1.2,
                    flags=0,
                )
                dx = float(np.median(flow[..., 0]))
                dy = float(np.median(flow[..., 1]))
                prev_g = curr_g

            cum_dx += dx * sx
            cum_dy += dy * sy
            # The correction is the opposite of the accumulated motion.
            offsets.append((-cum_dx, -cum_dy))

            if total > 0 and (idx % 5 == 0 or idx == total - 1):
                # The container's frame count may under-report, so clamp the
                # fraction to keep progress inside this stage's range.
                fraction = min(1.0, idx / max(1, total - 1))
                progress(progress_offset + fraction * progress_scale, desc="Estimating Motion")
            idx += 1
    finally:
        cap.release()

    return offsets
def compute_auto_zoom(offsets, width, height):
    """Derive a zoom factor (never below 1.0) from the extent of the offsets.

    For each axis, the largest negative and largest positive offsets are
    summed into a margin; the zoom for that axis is the full extent divided
    by the extent left after removing the margin. If more than one pixel does
    not remain, that axis contributes 1.0. The larger of the two axis zooms
    is returned. An empty ``offsets`` sequence is treated as a single zero offset.
    """

    def _axis_zoom(values, extent):
        # Margin on the low side plus margin on the high side of this axis.
        low_margin = max(0.0, -min(values))
        high_margin = max(0.0, max(values))
        usable = float(extent) - (low_margin + high_margin)
        if usable > 1.0:
            return float(extent) / usable
        return 1.0

    horizontal = [pair[0] for pair in offsets]
    vertical = [pair[1] for pair in offsets]
    if not horizontal:
        horizontal = [0.0]
    if not vertical:
        vertical = [0.0]

    zoom_x = _axis_zoom(horizontal, width)
    zoom_y = _axis_zoom(vertical, height)
    return max(1.0, zoom_x, zoom_y)
def stabilize_stream(
    video_file,
    offsets,
    zoom=1.0,
    vertical_only=False,
    out_w=None,
    out_h=None,
    progress=gr.Progress(),
    progress_offset=0.55,
    progress_scale=0.45,
    output_file=None,
):
    """Write a stabilized copy of a video by translating (and zooming) each frame.

    Each frame is resized to the output size and warped with an affine
    matrix that scales by ``zoom`` about the frame center and translates by
    that frame's offset. Border pixels are filled with
    ``cv2.BORDER_REPLICATE``. Processing stops at whichever runs out first:
    the video's frames or ``offsets``.

    Args:
        video_file: Path passed to ``cv2.VideoCapture``.
        offsets: Sequence of ``(dx, dy)`` translations, one per frame.
        zoom: Scale factor applied about the frame center.
        vertical_only: When true, the horizontal offset is ignored.
        out_w: Output width; defaults to the input video's width.
        out_h: Output height; defaults to the input video's height.
        progress: Callable invoked as ``progress(fraction, desc=...)``.
        progress_offset: Fraction reported at the start of this stage.
        progress_scale: Share of the overall progress range this stage covers.
        output_file: Destination path; a temporary ``.mp4`` file is created
            (and not deleted) when omitted.

    Returns:
        The path of the written video file.

    Raises:
        gr.Error: If the input video or the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(video_file)
    out = None
    # Both the capture and the writer are released in ``finally`` so that an
    # exception while warping or writing does not leak handles or leave the
    # output file unflushed.
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open video file for stabilization.")

        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 or NaN; fall back to 30 fps rather than
        # handing an unusable rate to the writer.
        if not (0.0 < fps < float("inf")):
            fps = 30.0

        if out_w is None:
            out_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        if out_h is None:
            out_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        size = (int(out_w), int(out_h))

        if output_file is None:
            # Only the name is needed; the handle is closed before OpenCV
            # opens the path for writing.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                output_file = temp_file.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_file, fourcc, fps, size)
        if not out.isOpened():
            raise gr.Error("Could not open output video file for writing.")

        # Rotation by 0 degrees: the matrix is a pure scale about the center.
        center = (float(out_w) / 2.0, float(out_h) / 2.0)
        base = cv2.getRotationMatrix2D(center, 0.0, float(zoom))

        total = len(offsets)
        for i, (dx, dy) in enumerate(offsets):
            ret, frame = cap.read()
            if not ret:
                break

            frame_out = _resize(frame, size[0], size[1])
            if vertical_only:
                dx = 0.0

            # Add this frame's translation on top of the shared zoom matrix.
            M = base.copy()
            M[0, 2] += float(dx)
            M[1, 2] += float(dy)
            stabilized = cv2.warpAffine(frame_out, M, size, borderMode=cv2.BORDER_REPLICATE)
            out.write(stabilized)

            if i % 5 == 0 or i == total - 1:
                progress(progress_offset + (i / max(1, total - 1)) * progress_scale, desc="Stabilizing Video")
    finally:
        cap.release()
        if out is not None:
            out.release()

    return output_file
def process_video_ai(
video_file,
zoom,
vertical_only,
compress_mode,
target_width,
target_height,
auto_zoom,
progress=gr.Progress(track_tqdm=True),
):
gr.Info("Starting AI-powered video processing...")
log_buffer = io.StringIO()
with redirect_stdout(log_buffer):
if isinstance(video_file, dict):
video_file = video_file.get("name", None)
if video_file is None:
raise gr.Error("Please upload a video file.")
cap = cv2.VideoCapture(video_file)
if not cap.isOpened():
raise gr.Error("Could not open uploaded video.")
in_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
in_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
if compress_mode:
out_w = int(target_width)
out_h = int(target_height)
else:
out_w = in_w
out_h = in_h
offsets = compute_offsets(
video_file,
out_w,
out_h,
motion_scale=0.5,
raft_iters=12,
progress=progress,
progress_offset=0.0,
progress_scale=0.55,
)
gr.Info("Motion estimated successfully.")
if auto_zoom:
z = compute_auto_zoom(offsets, out