import os
import sys
import cv2
import math
import time
import torch
import numpy as np
import gradio as gr
from tqdm import tqdm
from pathlib import Path
from collections import deque
from argparse import Namespace
from torchvision import transforms

# === RAFT Setup ===
sys.path.append("/app/preprocess/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder

# === CONFIG ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_PATH = "/app/RAFT/raft-things.pth"
OUTPUT_VIDEO = "/app/full_tracked_output.mp4"
OUTPUT_MASK_VIDEO = "/app/mask_output.mp4"
STABILIZED_MASK = "/app/stabilized_mask_output.mp4"
REVERSED_INPUT = "/app/reversed_input.mp4"


# ==========================================================
# === VIDEO UTILITIES =====================================
# ==========================================================
def reverse_video(input_path, output_path):
    """
    Write the frames of ``input_path`` to ``output_path`` in reverse order.

    Every frame that ``cap.read()`` returns is kept, so the result does not
    depend on the (sometimes off-by-one) CAP_PROP_FRAME_COUNT metadata.
    All frames are held in memory at once.

    Args:
        input_path (str): Video to read.
        output_path (str): Destination, written with the 'mp4v' codec at the
            source FPS.

    Returns:
        str: ``output_path``.

    Raises:
        FileNotFoundError: If the input video cannot be opened.
        ValueError: If no frame could be read.
        IOError: If the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise FileNotFoundError(f"❌ Could not open video: {input_path}")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        cap.release()

    if not frames:
        raise ValueError("No frames read from video!")

    # Size the writer from the decoded frames, not container metadata:
    # VideoWriter silently drops frames whose size does not match.
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        raise IOError(f"❌ Could not open video writer: {output_path}")
    try:
        for frame in reversed(frames):
            out.write(frame)
    finally:
        out.release()

    print(f"✅ Reversed {len(frames)} frames → {output_path}")
    return output_path


def reverse_video_file_inplace(path_in):
    """
    Reverse a video in-place without losing frames.

    The reversed video is first written to a sibling file named
    ``<stem>_tmp<suffix>`` and then moved over ``path_in`` with os.replace.
    The temp file is removed if anything fails.
    """
    src = Path(path_in)
    # Build the temp name from the stem/suffix so it always differs from the
    # source, whatever the extension's case and whatever the directory names.
    tmp_path = str(src.with_name(f"{src.stem}_tmp{src.suffix}"))
    try:
        reverse_video(path_in, tmp_path)
        os.replace(tmp_path, path_in)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"🔁 Overwrote {path_in} with reversed version (same frame count).")


# ==========================================================
# === RAFT LOADING =========================================
# ==========================================================
def load_raft_model(model_path):
    """
    Build RAFT, load the checkpoint at ``model_path``, and return the model
    on DEVICE in eval mode.

    The model is wrapped in DataParallel before loading. Presumably this is
    because the checkpoint keys carry a ``module.`` prefix (TODO confirm).
    The unwrapped ``.module`` is returned.
    """
    # NOTE(review): the depth_* / max_depth fields are not standard RAFT
    # arguments; presumably the local RAFT fork reads them — verify.
    args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=16,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0,
    )
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    return model.module.to(DEVICE).eval()


def to_tensor(image):
    """
    Convert an HxWxC image to a 1xCxHxW float tensor on DEVICE.

    ToTensor() rescales uint8 input to the range [0, 1].

    NOTE(review): upstream RAFT's forward() normalises its inputs with
    ``2 * (x / 255.0) - 1``, i.e. it expects 0-255 values. Confirm what the
    local RAFT fork expects before trusting flow quality.
    """
    return transforms.ToTensor()(image).unsqueeze(0).to(DEVICE)


@torch.no_grad()
def compute_flow(model, img1, img2):
    """
    Return RAFT flow from ``img1`` to ``img2`` as an (H, W, 2) numpy array.

    Inputs are padded with InputPadder and the flow is unpadded back to the
    input size. run_tracking reads the last axis as (dx, dy).
    """
    t1, t2 = to_tensor(img1), to_tensor(img2)
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)
    _, flow = model(t1, t2, iters=30, test_mode=True)
    flow = padder.unpad(flow)[0]
    return flow.permute(1, 2, 0).cpu().numpy()


# ==========================================================
# === FRAME / MASK HELPERS ================================
# ==========================================================
def extract_frame(video_path, frame_number):
    """Return frame ``frame_number`` of ``video_path`` as RGB, or None if unreadable."""
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def save_mask(data):
    """
    Binarise the painted mask from the sketch widget and write it to
    ``user_mask.png``.

    Args:
        data: Either a dict with a ``"mask"`` array, or the mask array itself.
            The array may be 2-D or carry 1, 3 or 4 channels.

    Returns:
        tuple: (mask path or None, status message).
    """
    if data is None:
        return None, "⚠️ No mask data received!"
    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "⚠️ Mask missing!"

    mask = np.asarray(mask)
    if mask.ndim == 3:
        # Pick the conversion matching the channel count; RGBA2GRAY raises on
        # 3-channel input.
        channels = mask.shape[2]
        if channels == 4:
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        else:
            mask_gray = mask[..., 0]
    else:
        mask_gray = mask

    # Any non-zero (>1) stroke pixel becomes 255.
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
    mask_path = "user_mask.png"
    cv2.imwrite(mask_path, bin_mask)
    return mask_path, f"✅ Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"


# ==========================================================
# === UPDATED DYNAMIC CROP LOGIC ===========================
# ==========================================================
def compute_crop_box_from_mask_dynamic(first_frame_bgr, mask_path, pad=200):
    """
    Compute a square crop box around the painted mask region plus ``pad``.

    If ``first_frame_bgr`` is given and its size differs from the mask, the
    mask is resized (nearest neighbour) to the frame size first. The returned
    box is then in frame coordinates. The side is capped at min(W, H), and the
    box is shifted (not truncated) to stay inside the image, so width ==
    height. If the mask is empty, a centred square of side ~min(W, H) / 2 is
    returned.

    Returns:
        tuple[int, int, int, int]: (x0, y0, x1, y1), with x1/y1 exclusive
        (usable directly as slice bounds).

    Raises:
        FileNotFoundError: If ``mask_path`` cannot be read.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Mask not found: {mask_path}")

    if first_frame_bgr is not None:
        frame_h, frame_w = first_frame_bgr.shape[:2]
        if mask.shape[:2] != (frame_h, frame_w):
            mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)

    H, W = mask.shape[:2]
    ys, xs = np.nonzero(mask)

    # Fallback to center crop if no mask
    if xs.size == 0:
        cx, cy = W // 2, H // 2
        size = min(W, H) // 2
        return cx - size // 2, cy - size // 2, cx + size // 2, cy + size // 2

    # Padded bounding box; +1 makes the max bound exclusive.
    x_min = max(0, int(xs.min()) - pad)
    y_min = max(0, int(ys.min()) - pad)
    x_max = min(W, int(xs.max()) + 1 + pad)
    y_max = min(H, int(ys.max()) + 1 + pad)

    # A square can never be larger than the shorter image side.
    side = min(max(x_max - x_min, y_max - y_min), W, H)
    cx = (x_min + x_max) // 2
    cy = (y_min + y_max) // 2

    # Clamp the top-left corner so the whole square fits inside the image.
    x0 = int(np.clip(cx - side // 2, 0, W - side))
    y0 = int(np.clip(cy - side // 2, 0, H - side))
    return x0, y0, x0 + side, y0 + side


def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0, 255, 0), thickness=2):
    """Return a copy of ``frame_rgb`` with ``crop_box`` (x0, y0, x1, y1) outlined."""
    x0, y0, x1, y1 = crop_box
    frame = frame_rgb.copy()
    cv2.rectangle(frame, (x0, y0), (x1, y1), color, thickness)
    return frame


# ==========================================================
# === STABILIZATION ========================================
# ==========================================================
def stabilize_black_regions(input_video, output_path=STABILIZED_MASK, blend=0.3,
                            sample_frames=10):
    """
    Morphologically clean a black/white mask video and write the result.

    Each frame is thresholded at 127 and inverted, so dark pixels become
    foreground. The foreground then goes through:
      (A) gap bridging and filling (closings, then a small opening);
      (B) distance-transform thickness normalisation;
      (C) closing plus connected-component filtering: components below
          ``min_area`` are dilated rather than dropped;
      (D) edge reinforcement and a 3x3 median blur;
      (E) temporal smoothing with an exponential moving average, re-binarised
          at 127.

    Args:
        input_video (str): Path to input mask video (black/white).
        output_path (str): Path to save stabilized video.
        blend (float): Weight (0.0–1.0) of the running average of past frames
            in step (E). Because each frame is binary, any blend < 0.5 leaves
            the output equal to the current frame. Larger values delay
            changes by a few frames; 1.0 freezes the first frame.
        sample_frames (int): Number of initial frames sampled to estimate the
            average region thickness used for the fixed kernel size.

    Returns:
        str: ``output_path``.

    Raises:
        FileNotFoundError: If the input video cannot be opened.
        IOError: If the output writer cannot be opened.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {input_video}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        cap.release()
        raise IOError(f"Could not open video writer: {output_path}")

    try:
        # === Step 1: Estimate global morphology parameters from first N frames ===
        thickness_samples = []
        count = 0
        while count < sample_frames:
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
            dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
            if np.any(mask > 0):
                thickness_samples.append(np.mean(dist[mask > 0]))
            count += 1
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind

        avg_thickness = np.median(thickness_samples) if thickness_samples else 5
        k = int(np.clip(avg_thickness / 2.0, 3, 9))
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
        min_area = (width * height) * 0.0005
        print(f"🧠 Fixed morphology parameters — kernel={k} | min_area={min_area:.1f}")

        # Frame-independent structuring elements, built once.
        bridge_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
        fill_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        small_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

        ema = None  # float32 running average for step (E)

        # === Step 2: Process all frames ===
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)

            # --- (A) Connectivity repair: bridge gaps & fill ---
            repaired = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, bridge_kernel, iterations=2)
            filled = cv2.morphologyEx(repaired, cv2.MORPH_CLOSE, fill_kernel, iterations=2)
            filled = cv2.morphologyEx(filled, cv2.MORPH_OPEN, small_kernel, iterations=1)

            # --- (B) Edge thickness normalization ---
            dist = cv2.distanceTransform(cv2.bitwise_not(filled), cv2.DIST_L2, 3)
            normalized = (dist < avg_thickness * 1.2).astype(np.uint8) * 255
            base_clean = cv2.bitwise_not(normalized)

            # --- (C) Morphological cleanup (fixed parameters) ---
            base_clean = cv2.morphologyEx(base_clean, cv2.MORPH_CLOSE, kernel, iterations=2)
            _, labels, stats, _ = cv2.connectedComponentsWithStats(base_clean, connectivity=8)
            areas = stats[:, cv2.CC_STAT_AREA]
            is_large = areas >= min_area
            is_large[0] = False  # label 0 is the background
            is_small = ~is_large
            is_small[0] = False
            filtered_mask = np.where(is_large[labels], 255, 0).astype(np.uint8)
            if is_small.any():
                # Dilation distributes over union, so dilating all small blobs at
                # once equals OR-ing each blob's dilation ("merge small blobs softly").
                small_mask = np.where(is_small[labels], 255, 0).astype(np.uint8)
                filtered_mask = cv2.bitwise_or(
                    filtered_mask, cv2.dilate(small_mask, kernel, iterations=2))

            # --- (D) Edge reinforcement ---
            edges = cv2.morphologyEx(filtered_mask, cv2.MORPH_GRADIENT, small_kernel)
            reinforced = cv2.bitwise_or(filtered_mask, edges)
            reinforced = cv2.morphologyEx(reinforced, cv2.MORPH_CLOSE, kernel, iterations=2)
            reinforced = cv2.medianBlur(reinforced, 3)

            # --- (E) Temporal stabilization ---
            # Average against a float history rather than the previous binary
            # output. Blending two binary frames and re-binarising either
            # copies the current frame (blend < 0.5) or freezes forever
            # (blend > 0.5).
            current = reinforced.astype(np.float32)
            ema = current if ema is None else (1.0 - blend) * current + blend * ema
            reinforced = np.where(ema > 127, 255, 0).astype(np.uint8)  # re-binarize

            out.write(cv2.cvtColor(reinforced, cv2.COLOR_GRAY2BGR))
    finally:
        cap.release()
        out.release()

    print(f"✅ Visually stable and connected mask saved: {output_path}")
    return output_path
# ==========================================================
# === TRACKING =============================================
# ==========================================================
def _advect_points(points, flow, width, height):
    """
    Move crop-space (x, y) points by the flow sampled at their integer
    positions, then clamp them to [0, width-1] x [0, height-1].

    ``points`` is an (N, 2) float32 array; ``flow`` is indexed [y, x] -> (dx, dy).
    """
    if len(points) == 0:
        return points
    ix = points[:, 0].astype(np.int64)
    iy = points[:, 1].astype(np.int64)
    moved = points + flow[iy, ix]
    return np.clip(moved, 0, [width - 1, height - 1]).astype(np.float32)


def _render_tracked_frame(frame_rgb, points, origin, paint):
    """
    Render one output frame.

    Returns (BGR visualisation, single-channel uint8 mask). The mask is 255
    everywhere and 0 at each tracked point. When ``paint`` is true, each
    point (crop coordinates, shifted by ``origin`` = (x0, y0)) is also drawn
    as a green radius-1 dot.
    """
    H, W = frame_rgb.shape[:2]
    x0, y0 = origin
    vis = frame_rgb.copy()
    mask = np.full((H, W), 255, dtype=np.uint8)
    if paint and len(points):
        fx = points[:, 0].astype(np.int64) + x0
        fy = points[:, 1].astype(np.int64) + y0
        inside = (fx >= 0) & (fx < W) & (fy >= 0) & (fy < H)
        fx, fy = fx[inside], fy[inside]
        for px, py in zip(fx.tolist(), fy.tolist()):
            cv2.circle(vis, (px, py), 1, (0, 255, 0), -1)
        mask[fy, fx] = 0
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR), mask


def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
    """
    Track the painted mask pixels backwards through a video using RAFT flow.

    Processing steps:
      1. Reverse the input. The mask is expected to match frame 0 of the
         reversed video, i.e. the original's last frame; build_app annotates
         that frame.
      2. Advect the selected points frame to frame inside a square crop around
         the mask.
      3. Render every frame twice: an RGB visualisation with tracked points in
         green, and a mask that is 0 at tracked points and 255 elsewhere.
         Once HISTORY_LEN consecutive frames have no tracked point on a black
         pixel (gray < BLACK_THRESH), painting stops for the rest of the video.
      4. Stabilise the mask video and reverse all outputs back to the
         original order.

    Exactly one output frame is written per input frame, so the forward
    outputs stay frame-aligned with the input.

    Args:
        video_path: Path to the input video.
        mask_path: Path to the binary ROI mask (non-zero = selected).
        selection_mode: "All Pixels" tracks every mask pixel. Any other value
            tracks only mask pixels that are black in the first frame.

    Returns:
        tuple: (log message, OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK),
        or (error message, None, None, None) if the first frame is unreadable.
    """
    BLACK_THRESH = 1
    HISTORY_LEN = 5

    # --- Reverse input for backward tracking ---
    reversed_path = reverse_video(video_path, REVERSED_INPUT)
    cap = cv2.VideoCapture(reversed_path)

    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"🎞️ Input video: {total_frames} frames at {fps:.2f} FPS")

    ret, first_frame = cap.read()
    if not ret:
        cap.release()
        return "❌ Could not read first frame.", None, None, None

    model = load_raft_model(MODEL_PATH)
    H, W = first_frame.shape[:2]

    # --- Compute dynamic square crop from mask ---
    x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(first_frame, mask_path, pad=200)
    cw, ch = x1 - x0, y1 - y0

    full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    crop_mask = full_mask[y0:y1, x0:x1]

    if selection_mode == "All Pixels":
        ys, xs = np.nonzero(crop_mask)
    else:
        gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
        black_pixels = gray_first[y0:y1, x0:x1] < BLACK_THRESH
        ys, xs = np.nonzero((crop_mask > 0) & black_pixels)
    # (N, 2) array of (x, y) in crop coordinates.
    tracked_points = np.column_stack((xs, ys)).astype(np.float32)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
    out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)

    prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1]

    history = deque([True] * HISTORY_LEN, maxlen=HISTORY_LEN)
    stopped = False
    frame_idx = 0

    try:
        # The annotated frame itself is output frame 0, painted with the
        # user's selection. Skipping it would shift every later frame by one
        # once the outputs are reversed back.
        vis_bgr, mask_full = _render_tracked_frame(prev_full_rgb, tracked_points, (x0, y0), True)
        out_vis.write(vis_bgr)
        out_mask.write(mask_full)

        # === Main tracking loop ===
        while True:
            ret, curr_frame = cap.read()
            if not ret:
                break
            frame_idx += 1

            curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1]
            gray_crop = cv2.cvtColor(curr_crop_rgb, cv2.COLOR_RGB2GRAY)

            # --- Optical flow between prev and curr, then move points ---
            flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb)
            tracked_points = _advect_points(tracked_points, flow_crop, cw, ch)

            # --- Detect black pixels under any tracked point ---
            ix = tracked_points[:, 0].astype(np.int64)
            iy = tracked_points[:, 1].astype(np.int64)
            has_black = bool((gray_crop < BLACK_THRESH)[iy, ix].any())
            history.append(has_black)

            # --- Painting logic ---
            if stopped:
                paint = False
            elif has_black:
                paint = True
            elif not any(history):  # last N all False
                stopped = True
                paint = False
            else:
                paint = True

            vis_bgr, mask_full = _render_tracked_frame(
                curr_full_rgb, tracked_points, (x0, y0), paint)
            out_vis.write(vis_bgr)
            out_mask.write(mask_full)

            prev_crop_rgb = curr_crop_rgb

            if frame_idx % 10 == 0:
                print(f"Frame {frame_idx}: {'PAINT' if paint else 'NO-PAINT'} | has_black={has_black} | stopped={stopped}")
    finally:
        cap.release()
        out_vis.release()
        out_mask.release()

    # === Post-process: stabilization + reversal ===
    stabilize_black_regions(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(OUTPUT_VIDEO)
    reverse_video_file_inplace(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(STABILIZED_MASK)

    # === Verify output frame counts ===
    for path in [OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK]:
        cap_test = cv2.VideoCapture(path)
        n = int(cap_test.get(cv2.CAP_PROP_FRAME_COUNT))
        cap_test.release()
        print(f"✅ Verified {os.path.basename(path)} → {n} frames")

    return (
        f"✅ Tracking complete ({selection_mode}).\n"
        f"Square Crop {cw}x{ch} @ ({x0},{y0}) with padding=200\n"
        f"Painting stopped={'Yes' if stopped else 'No'} after {frame_idx} processed frames.\n"
        f"All outputs now match input frame count ({total_frames}).",
        OUTPUT_VIDEO,
        OUTPUT_MASK_VIDEO,
        STABILIZED_MASK,
    )


# ==========================================================
# === GRADIO APP ===========================================
# ==========================================================
def _video_path(video):
    """Return a filesystem path for a Gradio video value (a str or an object with ``.name``)."""
    return video.name if hasattr(video, "name") else video


def build_app():
    """
    Build the Gradio UI.

    Workflow: upload a video, load its last frame (frame 0 of the reversed
    video) for annotation, paint and save the ROI mask, optionally preview
    the crop, then run tracking.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 🎯 Pixel Tracker (Dynamic Square Crop)")
        with gr.Row():
            video_in = gr.Video(label="🎞️ Upload Video")
            frame_num = gr.Number(value=0, visible=False)
        load_btn = gr.Button("📸 Load Frame for Annotation")
        annot = gr.Image(label="🖌️ Paint ROI Mask", tool="sketch", type="numpy",
                         image_mode="RGBA", height=1000)
        save_btn = gr.Button("💾 Save Mask")
        log = gr.Textbox(label="Logs", lines=8)
        preview_btn = gr.Button("🔎 Preview Crop", visible=True)
        with gr.Row():
            preview_frame = gr.Image(label="Preview Frame", visible=False)
            preview_crop = gr.Image(label="Cropped Region", visible=True)
        selection_mode = gr.Dropdown(["All Pixels", "Only Black Pixels"],
                                     value="All Pixels", label="Selection Mode")
        run_btn = gr.Button("🚀 Run Tracking")
        with gr.Row():
            result_video = gr.Video(label="🎬 Result (Forward)")
            mask_video = gr.Video(label="⬛ Mask (Forward)")
            stabilized_video = gr.Video(label="🧱 Stabilized (Forward)")
        mask_path_state = gr.State()

        def load_reversed_frame(v, f):
            if v is None:
                return None
            reversed_path = reverse_video(_video_path(v), REVERSED_INPUT)
            return extract_frame(reversed_path, int(f))

        load_btn.click(load_reversed_frame, [video_in, frame_num], annot)
        save_btn.click(save_mask, annot, [mask_path_state, log])

        def preview_crop_fn(v):
            if v is None:
                return None, None, "⚠️ Upload a video first."
            reversed_path = reverse_video(_video_path(v), REVERSED_INPUT)
            frame0 = extract_frame(reversed_path, 0)
            if frame0 is None or not os.path.exists("user_mask.png"):
                return None, None, "⚠️ Paint and Save Mask first."
            x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(
                cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", pad=200)
            frame_box = draw_crop_preview_on_frame(frame0, (x0, y0, x1, y1))
            return frame_box, frame0[y0:y1, x0:x1], f"Square crop {x1-x0}x{y1-y0} at ({x0},{y0})"

        preview_btn.click(preview_crop_fn, video_in, [preview_frame, preview_crop, log])

        def run_btn_fn(v, m):
            if v is None:
                return "⚠️ Upload a video first.", None, None, None
            if not os.path.exists("user_mask.png"):
                return "⚠️ Save Mask first.", None, None, None
            return run_tracking(_video_path(v), "user_mask.png", m)

        run_btn.click(run_btn_fn, [video_in, selection_mode],
                      [log, result_video, mask_video, stabilized_video])

    return demo


if __name__ == "__main__":
    app = build_app()
    app.launch(server_name="0.0.0.0", server_port=7860, debug=True)