Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import math | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from torchvision import transforms | |
| import gradio as gr | |
| from argparse import Namespace | |
| import sys | |
| import time | |
| # === RAFT Setup === | |
| sys.path.append("/app/preprocess/RAFT/core") | |
| from raft import RAFT | |
| from utils.utils import InputPadder | |
# === CONFIG ===
# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# RAFT checkpoint loaded by load_raft_model().
MODEL_PATH = "/app/RAFT/raft-things.pth"

# Files produced by the tracking pipeline.
OUTPUT_VIDEO = "/app/full_tracked_output.mp4"        # frames with tracked points drawn
OUTPUT_MASK_VIDEO = "/app/mask_output.mp4"           # raw per-frame mask
STABILIZED_MASK = "/app/stabilized_mask_output.mp4"  # mask after morphological cleanup

# Scratch copy of the uploaded video with its frames in reverse order.
REVERSED_INPUT = "/app/reversed_input.mp4"
| # ========================================================== | |
| # === VIDEO UTILITIES ===================================== | |
| # ========================================================== | |
def reverse_video(input_path, output_path):
    """Write a copy of ``input_path`` with its frames in reverse order.

    All frames are decoded into memory before anything is written, so peak
    memory grows with the length of the clip.

    Args:
        input_path: Path of the video to read.
        output_path: Path of the reversed video to create (``mp4v`` codec,
            same frame rate and frame size as the input).

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If OpenCV cannot open ``input_path``.
    """
    cap = cv2.VideoCapture(input_path)
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"β Could not open video: {input_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the decoder even when opening or reading fails.
        cap.release()

    # The writer is opened only after the input has been fully read, so a
    # failed read never leaves a truncated output file behind.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        # Write reversed frames
        for frame in reversed(frames):
            out.write(frame)
    finally:
        out.release()

    print(f"π Video reversed and saved: {output_path}")
    return output_path
def reverse_video_file_inplace(path_in):
    """Reverse the frame order of the video at ``path_in``, overwriting it.

    The reversed video is first written next to the original as
    ``<name>_tmp<ext>`` and then moved over it with ``os.replace``.

    Args:
        path_in: Path of the video file to reverse in place.
    """
    # Split on the real extension instead of str.replace(".mp4", ...):
    # replace() returns the input unchanged when the path has no ".mp4"
    # (the file would then be written while it is being read) and rewrites
    # every ".mp4" occurrence, including ones inside directory names.
    root, ext = os.path.splitext(path_in)
    # Fall back to ".mp4" so the temporary file always has an extension.
    tmp_path = f"{root}_tmp{ext or '.mp4'}"
    reverse_video(path_in, tmp_path)
    os.replace(tmp_path, path_in)
| # ========================================================== | |
| # === RAFT LOADING ========================================= | |
| # ========================================================== | |
def load_raft_model(model_path):
    """Build a RAFT network and load its weights from ``model_path``.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.

    Returns:
        The unwrapped RAFT module, moved to ``DEVICE`` and set to eval mode.
    """
    raft_args = Namespace(
        **{
            "small": False,
            "mixed_precision": False,
            "alternate_corr": False,
            "dropout": 0.0,
            "max_depth": 16,
            "depth_network": False,
            "depth_residual": False,
            "depth_scale": 1.0,
        }
    )

    # The weights are loaded through a DataParallel wrapper and the inner
    # module is unwrapped afterwards.
    # NOTE(review): presumably the checkpoint keys carry the "module." prefix
    # that DataParallel adds -- confirm against the checkpoint.
    wrapped = torch.nn.DataParallel(RAFT(raft_args))
    state_dict = torch.load(model_path, map_location=DEVICE)
    wrapped.load_state_dict(state_dict)

    network = wrapped.module.to(DEVICE)
    return network.eval()
def to_tensor(image):
    """Convert ``image`` to a tensor with a leading batch axis on ``DEVICE``.

    NOTE(review): torchvision's ``ToTensor`` is understood to rescale uint8
    input to the [0, 1] range, while the reference RAFT forward pass
    normalises its inputs assuming [0, 255]. Confirm that the bundled RAFT
    expects the range produced here.
    """
    converter = transforms.ToTensor()
    batched = converter(image).unsqueeze(0)
    return batched.to(DEVICE)
def compute_flow(model, img1, img2):
    """Estimate optical flow from ``img1`` to ``img2`` with ``model``.

    Args:
        model: Network called as ``model(t1, t2, iters=30, test_mode=True)``;
            the second element of its return value is used as the flow.
        img1: First image, in a form accepted by ``to_tensor``.
        img2: Second image, same size as ``img1``.

    Returns:
        NumPy array of the flow for the single batch item with the channel
        axis moved last, i.e. indexed ``flow[row, col]``.
    """
    t1, t2 = to_tensor(img1), to_tensor(img2)
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)

    # Inference only: without no_grad() every refinement iteration records
    # an autograd graph (wasted memory), and the result may require grad, in
    # which case the .numpy() call below raises.
    with torch.no_grad():
        _, flow = model(t1, t2, iters=30, test_mode=True)

    flow = padder.unpad(flow)[0]
    return flow.permute(1, 2, 0).cpu().numpy()
| # ========================================================== | |
| # === FRAME / MASK HELPERS ================================ | |
| # ========================================================== | |
def extract_frame(video_path, frame_number):
    """Return frame ``frame_number`` of ``video_path`` as an RGB image.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when no frame could
        be read at that position.
    """
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ok, bgr = capture.read()
    capture.release()
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
def save_mask(data):
    """Binarise the painted mask and save it as ``user_mask.png``.

    Args:
        data: Either the mask array itself or a dict holding it under the
            ``"mask"`` key; ``None`` when nothing was received.

    Returns:
        ``(mask_path, message)``. ``mask_path`` is ``None`` when there was
        no mask to save.
    """
    if data is None:
        return None, "β οΈ No mask data received!"

    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "β οΈ Mask missing!"

    # Multi-channel masks are collapsed to a single grey channel first.
    mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY) if mask.ndim == 3 else mask

    # Every pixel with a value above 1 counts as painted.
    bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)[1]

    mask_path = "user_mask.png"
    cv2.imwrite(mask_path, bin_mask)
    painted = np.count_nonzero(bin_mask)
    return mask_path, f"β Saved mask ({painted} painted pixels)"
| # ========================================================== | |
| # === CROP HELPERS ========================================= | |
| # ========================================================== | |
def get_mask_center(mask_path):
    """Return the ``(x, y)`` centroid of the non-zero pixels of a mask image.

    Args:
        mask_path: Path of the mask image, read as greyscale.

    Returns:
        Integer ``(x, y)`` mean position of the pixels greater than zero, or
        the image centre when the mask has no such pixels.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError("Mask not found: " + mask_path)

    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0:
        # Empty mask: fall back to the middle of the image.
        height, width = mask.shape[:2]
        return width // 2, height // 2
    return int(cols.mean()), int(rows.mean())
def _clamp_axis(start, size, limit):
    """Fit the window ``[start, start + size)`` inside ``[0, limit)`` on one axis."""
    start = max(0, min(start, limit - 1))
    end = start + size
    if end > limit:
        # Slide the window back inside the frame. When the window is larger
        # than the frame the slide would go negative, so pin it at 0 and let
        # the window shrink to the frame size instead.
        start = max(0, start - (end - limit))
        end = limit
    return start, end


def clamp_crop(x0, y0, cw, ch, W, H):
    """Clamp a crop rectangle so that it lies entirely inside a ``W`` x ``H`` frame.

    Args:
        x0: Requested left edge of the crop.
        y0: Requested top edge of the crop.
        cw: Requested crop width.
        ch: Requested crop height.
        W: Frame width.
        H: Frame height.

    Returns:
        ``(x0, y0, x1, y1)`` with ``0 <= x0 <= x1 <= W`` and
        ``0 <= y0 <= y1 <= H``. The crop keeps its requested size whenever it
        fits; a crop larger than the frame is reduced to the full frame
        extent on that axis rather than yielding a negative origin, which
        would wrap around when used as a NumPy slice.
    """
    x0, x1 = _clamp_axis(x0, cw, W)
    y0, y1 = _clamp_axis(y0, ch, H)
    return x0, y0, x1, y1
def compute_crop_box_from_mask(first_frame_bgr, mask_path, crop_w=400, crop_h=400):
    """Return a crop box centred on the mask and clamped to the frame.

    Args:
        first_frame_bgr: Frame whose height and width bound the crop.
        mask_path: Path of the mask image whose centre anchors the crop.
        crop_w: Requested crop width.
        crop_h: Requested crop height.

    Returns:
        ``(x0, y0, x1, y1)`` as produced by ``clamp_crop``.

    NOTE(review): the mask centre is used as-is in frame coordinates, which
    presumably assumes the mask has the same resolution as the frame --
    confirm against the caller.
    """
    frame_h, frame_w = first_frame_bgr.shape[:2]
    center_x, center_y = get_mask_center(mask_path)
    half_w, half_h = crop_w // 2, crop_h // 2
    return clamp_crop(center_x - half_w, center_y - half_h, crop_w, crop_h, frame_w, frame_h)
def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0,255,0), thickness=2):
    """Return a copy of ``frame_rgb`` with ``crop_box`` outlined.

    Args:
        frame_rgb: Image to annotate; it is copied, not modified.
        crop_box: ``(x0, y0, x1, y1)`` corners of the rectangle.
        color: Outline colour passed to ``cv2.rectangle``.
        thickness: Outline thickness passed to ``cv2.rectangle``.
    """
    left, top, right, bottom = crop_box
    annotated = frame_rgb.copy()
    cv2.rectangle(annotated, (left, top), (right, bottom), color, thickness)
    return annotated
| # ========================================================== | |
| # === STABILIZATION ======================================== | |
| # ========================================================== | |
def stabilize_black_regions(input_video):
    """Clean up a black-on-white mask video and save it to ``STABILIZED_MASK``.

    Each frame is binarised at 127, white areas that are not connected to
    the top-left corner are turned black (so black regions become solid),
    and the result is smoothed by a morphological close followed by an open.

    Args:
        input_video: Path of the mask video to process.

    Returns:
        ``STABILIZED_MASK``, the path of the written video.

    Raises:
        FileNotFoundError: If OpenCV cannot open ``input_video``.
    """
    reader = cv2.VideoCapture(input_video)
    if not reader.isOpened():
        raise FileNotFoundError(f"β Could not open video: {input_video}")

    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    frame_rate = reader.get(cv2.CAP_PROP_FPS)
    frame_size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(
        STABILIZED_MASK, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size
    )

    ok, frame = reader.read()
    while ok:
        grey = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        binary = cv2.threshold(grey, 127, 255, cv2.THRESH_BINARY)[1]

        # Work on the inverse so the black regions are the foreground (255).
        dark = cv2.bitwise_not(binary)

        # Flood the background from the top-left corner; whatever is still 0
        # afterwards is enclosed by foreground, i.e. a hole.
        flooded = dark.copy()
        rows, cols = dark.shape
        cv2.floodFill(flooded, np.zeros((rows + 2, cols + 2), np.uint8), (0, 0), 255)
        enclosed = cv2.bitwise_not(flooded)

        # Merge the holes into the foreground, then flip back to black-on-white.
        solid = cv2.bitwise_not(cv2.bitwise_or(dark, enclosed))

        # Close with the 3x3 rectangle, then open with the 5x5 ellipse.
        smoothed = cv2.morphologyEx(solid, cv2.MORPH_CLOSE, close_kernel, iterations=1)
        smoothed = cv2.morphologyEx(smoothed, cv2.MORPH_OPEN, open_kernel, iterations=1)

        writer.write(cv2.cvtColor(smoothed, cv2.COLOR_GRAY2BGR))
        ok, frame = reader.read()

    reader.release()
    writer.release()
    print(f"β Stabilized mask saved: {STABILIZED_MASK}")
    return STABILIZED_MASK
| # ========================================================== | |
| # === TRACKING ============================================= | |
| # ========================================================== | |
def run_tracking(video_path, mask_path, selection_mode="All Pixels", crop_w=400, crop_h=400):
    """Track the masked pixels through the video with RAFT optical flow.

    The input is reversed first, the pixels selected by the mask are tracked
    frame to frame inside a fixed crop around the mask centre, and the
    resulting videos are reversed again so they play in the original
    direction. The first frame of the reversed video only seeds the tracker;
    output frames are written from the second frame on.

    Args:
        video_path: Path of the input video.
        mask_path: Path of the binary mask image selecting the pixels.
        selection_mode: ``"All Pixels"`` tracks every masked pixel; any other
            value tracks only masked pixels whose grey value in the first
            frame is below 30.
        crop_w: Requested crop width.
        crop_h: Requested crop height.

    Returns:
        ``(message, OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK)`` on
        success, or ``(message, None, None, None)`` when the first frame
        cannot be read.
    """
    reversed_path = reverse_video(video_path, REVERSED_INPUT)
    cap = cv2.VideoCapture(reversed_path)
    out_vis = None
    out_mask = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        ret, first_frame = cap.read()
        if not ret:
            return "β Could not read first frame.", None, None, None

        # Load the network only once the input is known to be readable.
        model = load_raft_model(MODEL_PATH)

        H, W = first_frame.shape[:2]
        x0, y0, x1, y1 = compute_crop_box_from_mask(first_frame, mask_path, crop_w, crop_h)
        cw, ch = x1 - x0, y1 - y0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)

        # Bring the mask to frame resolution, then cut out the crop region.
        full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST)
        crop_mask = full_mask[y0:y1, x0:x1]

        if selection_mode == "All Pixels":
            ys, xs = np.where(crop_mask > 0)
        else:
            gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
            black_pixels = (gray_first[y0:y1, x0:x1] < 30)
            combined = (crop_mask > 0) & black_pixels
            ys, xs = np.where(combined)

        # Tracked points are (x, y) positions in crop coordinates.
        tracked_points = np.vstack((xs, ys)).T.astype(np.float32)

        prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1]

        while True:
            ret, curr_frame = cap.read()
            if not ret:
                break
            curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1]
            flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb)

            vis_full = curr_full_rgb.copy()
            # Mask convention: white background, tracked pixels black.
            mask_full = np.full((H, W), 255, dtype=np.uint8)

            new_points = []
            for pt in tracked_points:
                px, py = int(pt[0]), int(pt[1])
                if 0 <= px < cw and 0 <= py < ch:
                    # Move the point by the flow sampled at its pixel and
                    # keep it inside the crop.
                    dx, dy = flow_crop[py, px]
                    nx, ny = pt[0] + dx, pt[1] + dy
                    nx = np.clip(nx, 0, cw - 1)
                    ny = np.clip(ny, 0, ch - 1)
                    new_points.append([nx, ny])
                    # Back to full-frame coordinates for drawing.
                    fx, fy = int(nx + x0), int(ny + y0)
                    if 0 <= fx < W and 0 <= fy < H:
                        cv2.circle(vis_full, (fx, fy), 1, (0, 255, 0), -1)
                        mask_full[fy, fx] = 0
            tracked_points = np.array(new_points, dtype=np.float32)

            out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR))
            out_mask.write(mask_full)
            prev_crop_rgb = curr_crop_rgb
    finally:
        # Release the capture and any writer that was opened on every exit
        # path, including the early return and exceptions.
        cap.release()
        if out_vis is not None:
            out_vis.release()
        if out_mask is not None:
            out_mask.release()

    stabilize_black_regions(OUTPUT_MASK_VIDEO)

    # Reverse outputs back to forward direction
    reverse_video_file_inplace(OUTPUT_VIDEO)
    reverse_video_file_inplace(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(STABILIZED_MASK)

    return (
        f"β Tracking complete ({selection_mode}).\nCrop {cw}x{ch} @ ({x0},{y0})\nSaved outputs reversed back to forward order.",
        OUTPUT_VIDEO,
        OUTPUT_MASK_VIDEO,
        STABILIZED_MASK
    )
| # ========================================================== | |
| # === GRADIO APP =========================================== | |
| # ========================================================== | |
def build_app():
    """Build the Gradio interface for painting a mask and running the tracker.

    Returns:
        The ``gr.Blocks`` app, ready to be launched.

    NOTE(review): the original indentation of this function was lost, so
    which components sit inside each ``gr.Row`` is reconstructed (input
    widgets in the first row, the three result videos in the second) --
    confirm against the intended layout.
    """

    def _video_path(v):
        """Return the file path behind a video input value, or None if empty."""
        if not v:
            return None
        return v.name if hasattr(v, "name") else v

    with gr.Blocks() as demo:
        gr.Markdown("## π― RAFT Pixel Tracker (Reversed Input Pipeline, Forward Outputs)")
        with gr.Row():
            video_in = gr.Video(label="ποΈ Upload Video")
            frame_num = gr.Number(value=0, visible=False)
        load_btn = gr.Button("πΈ Load Frame for Annotation")
        annot = gr.Image(label="ποΈ Paint ROI Mask", tool="sketch", type="numpy", image_mode="RGBA", height=480)
        save_btn = gr.Button("πΎ Save Mask")
        log = gr.Textbox(label="Logs", lines=8)
        pixel_mode = gr.Dropdown(["All Pixels", "Only Black Pixels"], value="All Pixels")
        crop_w = gr.Number(value=400, label="Crop Width")
        crop_h = gr.Number(value=400, label="Crop Height")
        preview_btn = gr.Button("π Preview Crop")
        preview_frame = gr.Image(label="Preview Frame")
        preview_crop = gr.Image(label="Cropped Region")
        run_btn = gr.Button("π Run Tracking")
        with gr.Row():
            result_video = gr.Video(label="π¬ Result (Forward)")
            mask_video = gr.Video(label="β¬ Mask (Forward)")
            stabilized_video = gr.Video(label="π§± Stabilized (Forward)")

        # Load reversed frame for painting
        def load_reversed_frame(v, f):
            path = _video_path(v)
            if path is None:
                # Nothing uploaded yet: leave the annotation canvas empty
                # instead of failing inside OpenCV.
                return None
            reversed_path = reverse_video(path, REVERSED_INPUT)
            return extract_frame(reversed_path, int(f or 0))

        load_btn.click(load_reversed_frame, [video_in, frame_num], annot)
        save_btn.click(save_mask, annot, [gr.State(), log])

        def preview_crop_fn(v, cw, ch):
            path = _video_path(v)
            if path is None:
                return None, None, "Upload a video first."
            # Check for the mask before the expensive video reversal.
            if not os.path.exists("user_mask.png"):
                return None, None, "β οΈ Paint and Save Mask first."
            reversed_path = reverse_video(path, REVERSED_INPUT)
            frame0 = extract_frame(reversed_path, 0)
            if frame0 is None:
                return None, None, "β οΈ Paint and Save Mask first."
            x0, y0, x1, y1 = compute_crop_box_from_mask(
                cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", int(cw), int(ch)
            )
            frame_box = draw_crop_preview_on_frame(frame0, (x0, y0, x1, y1))
            return frame_box, frame0[y0:y1, x0:x1], f"Crop {cw}x{ch} at ({x0},{y0})"

        preview_btn.click(preview_crop_fn, [video_in, crop_w, crop_h], [preview_frame, preview_crop, log])

        def run_btn_fn(v, m, cw, ch):
            if not os.path.exists("user_mask.png"):
                return "β οΈ Save Mask first.", None, None, None
            path = _video_path(v)
            if path is None:
                return "Upload a video first.", None, None, None
            return run_tracking(path, "user_mask.png", m, int(cw), int(ch))

        run_btn.click(run_btn_fn, [video_in, pixel_mode, crop_w, crop_h],
                      [log, result_video, mask_video, stabilized_video])
    return demo
if __name__ == "__main__":
    # Serve the UI on every network interface, port 7861, with Gradio's
    # debug mode switched on.
    build_app().launch(server_name="0.0.0.0", server_port=7861, debug=True)