"""RAFT-based pixel tracker with a Gradio UI.

Workflow: the user uploads a video, paints a region-of-interest mask on a
chosen frame, and RAFT optical flow propagates the painted pixels through
the remaining frames.  Outputs a visualization video, a binary mask video
(black = tracked pixels), and a temporally stabilized version of that mask.
"""

import os
import cv2
import math
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision import transforms
import gradio as gr
from argparse import Namespace
import sys

# === RAFT Setup ===
# RAFT is vendored under /app/preprocess/RAFT; its core package is not
# installed, so its directory must be on sys.path before importing.
sys.path.append("/app/preprocess/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder

# === CONFIG ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_PATH = "/data/pretrained/RAFT/raft-things.pth"
OUTPUT_VIDEO = "/app/tracked_roi_output.mp4"
OUTPUT_MASK_VIDEO = "/app/final_tracked_mask_video.mp4"
STABILIZED_MASK = "/app/stabilized_mask_video.mp4"


# --- Load RAFT model ---
def load_raft_model(model_path):
    """Load a pretrained RAFT checkpoint and return it on DEVICE in eval mode.

    The checkpoint was saved from a DataParallel wrapper, so the model is
    wrapped the same way before loading and unwrapped (.module) afterwards.

    Args:
        model_path: path to the RAFT state-dict checkpoint.

    Returns:
        The unwrapped RAFT module, moved to DEVICE and set to eval().
    """
    args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=16,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0,
    )
    model = torch.nn.DataParallel(RAFT(args))
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects -- only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    return model.module.to(DEVICE).eval()


# --- Convert to tensor ---
def to_tensor(image):
    """Convert an HxWxC uint8 RGB image to a 1xCxHxW float tensor on DEVICE."""
    return transforms.ToTensor()(image).unsqueeze(0).to(DEVICE)


# --- Compute optical flow between two frames ---
@torch.no_grad()
def compute_flow(model, img1, img2):
    """Compute dense optical flow from img1 to img2 with RAFT.

    Returns:
        [H, W, 2] numpy array of per-pixel (dx, dy) displacements.
    """
    t1, t2 = to_tensor(img1), to_tensor(img2)
    # RAFT requires spatial dims divisible by 8; InputPadder handles that.
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)
    _, flow = model(t1, t2, iters=30, test_mode=True)
    flow = padder.unpad(flow)[0]
    return flow.permute(1, 2, 0).cpu().numpy()  # [H, W, 2]


# --- Extract a specific frame from the video ---
def extract_frame(video_path, frame_number):
    """Return frame `frame_number` of the video as an RGB array, or None on failure."""
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


# --- Save user mask to disk ---
def save_mask(data):
    """Binarize the painted Gradio mask and save it to user_mask.png.

    Args:
        data: either the sketch-tool dict (with a "mask" key) or a bare
            numpy array.

    Returns:
        (mask_path, status_message); mask_path is None when no mask arrived.
    """
    if data is None:
        return None, "⚠️ No mask data received!"
    if isinstance(data, dict):
        mask = data.get("mask")
    else:
        mask = data
    if mask is None:
        return None, "⚠️ Mask missing!"
    if mask.ndim == 3:
        mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
    else:
        mask_gray = mask
    # Any painted (non-zero) pixel counts as part of the ROI.
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
    mask_path = "user_mask.png"
    cv2.imwrite(mask_path, bin_mask)
    return mask_path, f"✅ Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"


# --- Fill in the black region ---
def stabilize_black_regions(input_video: str):
    """
    Stabilize black region boundaries in a binary mask video.
    Fills small flickers and stabilizes edges without deforming boundaries.

    Args:
        input_video (str): path to input video (mask video with black & white pixels)

    Returns:
        str: path to the stabilized mask video (STABILIZED_MASK).

    Raises:
        FileNotFoundError: if the input video cannot be opened.
    """
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"❌ Could not open video: {input_video}")

    # Some containers report 0 fps, which breaks VideoWriter; fall back to 30.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(STABILIZED_MASK, fourcc, fps, (width, height))

    print(f"🎞 Processing {int(cap.get(cv2.CAP_PROP_FRAME_COUNT))} frames "
          f"({width}x{height} @ {fps:.1f} fps)...")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

        # Invert so black→white (foreground)
        inv = cv2.bitwise_not(mask)

        # Fill interior holes: flood-fill the background from the top-left
        # corner, then OR the unreachable (hole) pixels back into the
        # foreground.  NOTE(review): assumes pixel (0, 0) is background.
        flood = inv.copy()
        h, w = inv.shape
        flood_mask = np.zeros((h + 2, w + 2), np.uint8)
        cv2.floodFill(flood, flood_mask, (0, 0), 255)
        holes = cv2.bitwise_not(flood)
        filled_inv = cv2.bitwise_or(inv, holes)

        # Stabilize edges (dilate → erode)
        closed = cv2.morphologyEx(filled_inv, cv2.MORPH_CLOSE, kernel_edge, iterations=1)

        # Remove small dots
        denoised = cv2.medianBlur(closed, 3)

        # Invert back (restore black region)
        result = cv2.bitwise_not(denoised)
        out.write(cv2.cvtColor(result, cv2.COLOR_GRAY2BGR))

    cap.release()
    out.release()
    print(f"✅ Saved stabilized mask video: {STABILIZED_MASK}")
    return STABILIZED_MASK


# --- RAFT Pixel Tracking based on user mask ---
def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
    """Track every masked pixel through the video using RAFT optical flow.

    Args:
        video_path: input video to track through.
        mask_path: binary ROI mask image painted over the first frame.
        selection_mode: "All Pixels" tracks every masked pixel; any other
            value restricts tracking to masked pixels that are also dark
            (gray < 30) in the first frame.

    Returns:
        (log_message, visualization_video_path, stabilized_mask_video_path);
        the video paths are None on failure.
    """
    model = load_raft_model(MODEL_PATH)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "❌ Failed to open video.", None, None

    # Some containers report 0 fps, which breaks VideoWriter; fall back to 30.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    ret, first_frame = cap.read()
    if not ret:
        cap.release()  # avoid leaking the capture on the early exit
        return "❌ Could not read first frame.", None, None
    H, W = first_frame.shape[:2]

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # A missing/unreadable mask would otherwise crash inside cv2.resize.
        cap.release()
        return f"❌ Could not read mask: {mask_path}", None, None
    mask = cv2.resize(mask, (W, H))

    # Select points based on mode
    if selection_mode == "All Pixels":
        ys, xs = np.where(mask > 0)
    else:
        gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
        black_pixels = (gray < 30)  # threshold for black
        combined = (mask > 0) & black_pixels
        ys, xs = np.where(combined)

    tracked_points = np.vstack((xs, ys)).T.astype(np.float32)
    print(f"🎯 Selected {len(tracked_points)} pixels under mode: {selection_mode}")

    # Writers for visualization and mask
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
    # Write the mask as 3-channel BGR: mp4v with isColor=False silently
    # produces unreadable output on some OpenCV builds, and the downstream
    # stabilizer already converts BGR→GRAY on read.
    out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H))

    prev_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

    while True:
        ret, curr_frame = cap.read()
        if not ret:
            break
        curr_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
        flow = compute_flow(model, prev_frame, curr_rgb)

        vis = curr_rgb.copy()
        # Start with all-white background
        mask_frame = np.full((H, W), 255, dtype=np.uint8)

        new_points = []
        for pt in tracked_points:
            px, py = int(round(pt[0])), int(round(pt[1]))
            if 0 <= px < W and 0 <= py < H:
                dx, dy = flow[py, px]
                nx, ny = pt[0] + dx, pt[1] + dy
                new_points.append([nx, ny])

                # Visualization (cv2.circle clips out-of-frame centers itself)
                cv2.circle(vis, (int(nx), int(ny)), 1, (0, 255, 0), -1)

                # Mask (black pixel for tracked).  Guard the write: a point
                # that drifted outside the frame would raise IndexError, and
                # a negative coordinate would silently wrap around.
                ix, iy = int(nx), int(ny)
                if 0 <= ix < W and 0 <= iy < H:
                    mask_frame[iy, ix] = 0

        # reshape keeps the (N, 2) layout even when every point is lost
        tracked_points = np.array(new_points, dtype=np.float32).reshape(-1, 2)
        out_vis.write(cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
        out_mask.write(cv2.cvtColor(mask_frame, cv2.COLOR_GRAY2BGR))
        prev_frame = curr_rgb

    cap.release()
    out_vis.release()
    out_mask.release()

    stabilize_mask = stabilize_black_regions(OUTPUT_MASK_VIDEO)
    return (
        f"✅ Tracking complete ({selection_mode}).\nSaved:\n- {OUTPUT_VIDEO}\n- {OUTPUT_MASK_VIDEO}",
        OUTPUT_VIDEO,
        stabilize_mask,
    )


# --- Gradio UI ---
def build_app():
    """Assemble the Gradio UI wiring the annotate → save-mask → track flow."""
    with gr.Blocks() as demo:
        gr.Markdown("## 🎯 RAFT Pixel Tracker with Brush-based ROI, Pixel Mode, and Mask Output (Inverted Mask)")

        with gr.Row():
            video_in = gr.Video(label="🎞️ Upload Video")
            frame_num = gr.Number(label="Frame # to Paint", value=0, precision=0)

        load_btn = gr.Button("📸 Load Frame for Annotation")
        annot = gr.Image(
            label="🖌️ Paint ROI Mask",
            tool="sketch",
            type="numpy",
            image_mode="RGBA",
            height=480,
        )
        pixel_mode = gr.Dropdown(
            choices=["All Pixels", "Only Black Pixels"],
            value="All Pixels",
            label="Pixel Selection Mode"
        )
        save_btn = gr.Button("💾 Save Mask")
        run_btn = gr.Button("🚀 Run RAFT Tracking")
        log = gr.Textbox(label="Logs", lines=6)
        result_video = gr.Video(label="🎬 Visualization Video")
        mask_video = gr.Video(label="⬛ Tracked Mask Video (black = tracked pixels)")

        # Load frame: gr.Video may hand back a tempfile object or a bare path.
        load_btn.click(
            fn=lambda v, f: extract_frame(v.name if hasattr(v, "name") else v, int(f)),
            inputs=[video_in, frame_num],
            outputs=annot
        )

        # Save mask
        save_btn.click(
            fn=save_mask,
            inputs=annot,
            outputs=[gr.State(), log]
        )

        # Run tracking (always reads the mask last saved to user_mask.png)
        run_btn.click(
            fn=lambda v, m: run_tracking(v.name if hasattr(v, "name") else v, "user_mask.png", m),
            inputs=[video_in, pixel_mode],
            outputs=[log, result_video, mask_video]
        )

    return demo


if __name__ == "__main__":
    app = build_app()
    app.launch(server_name="0.0.0.0", server_port=7860, debug=True)