# pixels_tracking / track-pixels_gradio.py
import os
import sys
import cv2
import math
import time
import torch
import numpy as np
import gradio as gr
from tqdm import tqdm
from pathlib import Path
from collections import deque
from argparse import Namespace
from torchvision import transforms
# === RAFT Setup ===
sys.path.append("/app/preprocess/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder
# === CONFIG ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_PATH = "/app/RAFT/raft-things.pth"
OUTPUT_VIDEO = "/app/full_tracked_output.mp4"
OUTPUT_MASK_VIDEO = "/app/mask_output.mp4"
STABILIZED_MASK = "/app/stabilized_mask_output.mp4"
REVERSED_INPUT = "/app/reversed_input.mp4"
# ==========================================================
# === VIDEO UTILITIES =====================================
# ==========================================================
def reverse_video(input_path, output_path):
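    """Reverse the frames of input_path and save the result to output_path."""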
cap = cv2.VideoCapture(input_path)
if not cap.isOpened():
raise FileNotFoundError(f"❌ Could not open video: {input_path}")
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
cap.release()
for frame in reversed(frames):
out.write(frame)
out.release()
print(f"πŸ” Video reversed and saved: {output_path}")
return output_path
def reverse_video_file_inplace(path_in):
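    """Reverse an existing video file and overwrite it in place."""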
tmp_path = path_in.replace(".mp4", "_tmp.mp4")
reverse_video(path_in, tmp_path)
os.replace(tmp_path, path_in)
# ==========================================================
# === RAFT LOADING =========================================
# ==========================================================
def load_raft_model(model_path):
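    """Load RAFT checkpoint weights into a DataParallel wrapper and return the unwrapped model in eval mode."""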
args = Namespace(
small=False,
mixed_precision=False,
alternate_corr=False,
dropout=0.0,
max_depth=16,
depth_network=False,
depth_residual=False,
depth_scale=1.0
)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
return model.module.to(DEVICE).eval()
def to_tensor(image):
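    # NOTE: ToTensor scales pixel values to [0, 1]; the upstream RAFT demo feeds
    # [0, 255] float images, so this normalization is an assumption of this script.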
return transforms.ToTensor()(image).unsqueeze(0).to(DEVICE)
@torch.no_grad()
def compute_flow(model, img1, img2):
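    """Compute RAFT optical flow from img1 to img2 and return an (H, W, 2) array."""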
t1, t2 = to_tensor(img1), to_tensor(img2)
padder = InputPadder(t1.shape)
t1, t2 = padder.pad(t1, t2)
_, flow = model(t1, t2, iters=30, test_mode=True)
flow = padder.unpad(flow)[0]
return flow.permute(1, 2, 0).cpu().numpy()
# ==========================================================
# === FRAME / MASK HELPERS ================================
# ==========================================================
def extract_frame(video_path, frame_number):
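    """Return frame `frame_number` of the video as an RGB array, or None if it cannot be read."""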
cap = cv2.VideoCapture(video_path)
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
ret, frame = cap.read()
cap.release()
if not ret:
return None
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def save_mask(data):
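    """Binarize the painted sketch layer from the Gradio editor and save it as user_mask.png."""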
if data is None:
return None, "⚠️ No mask data received!"
if isinstance(data, dict):
mask = data.get("mask")
else:
mask = data
if mask is None:
return None, "⚠️ Mask missing!"
if mask.ndim == 3:
mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
else:
mask_gray = mask
_, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
mask_path = "user_mask.png"
cv2.imwrite(mask_path, bin_mask)
return mask_path, f"βœ… Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"
# ==========================================================
# === UPDATED DYNAMIC CROP LOGIC ===========================
# ==========================================================
def compute_crop_box_from_mask_dynamic(first_frame_bgr, mask_path, pad=200):
"""
Compute a square crop region based on mask region + padding.
Ensures equal width & height.
"""
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise FileNotFoundError(f"Mask not found: {mask_path}")
H, W = mask.shape[:2]
ys, xs = np.where(mask > 0)
# Fallback to center crop if no mask
if len(xs) == 0:
cx, cy = W // 2, H // 2
size = min(W, H) // 2
return cx - size // 2, cy - size // 2, cx + size // 2, cy + size // 2
x_min, x_max = np.min(xs), np.max(xs)
y_min, y_max = np.min(ys), np.max(ys)
# Add padding
x_min = max(0, x_min - pad)
y_min = max(0, y_min - pad)
x_max = min(W, x_max + pad)
y_max = min(H, y_max + pad)
# Make it square
width = x_max - x_min
height = y_max - y_min
side = max(width, height)
cx = (x_min + x_max) // 2
cy = (y_min + y_max) // 2
x_min = max(0, cx - side // 2)
y_min = max(0, cy - side // 2)
x_max = min(W, x_min + side)
y_max = min(H, y_min + side)
return int(x_min), int(y_min), int(x_max), int(y_max)
def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0,255,0), thickness=2):
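    """Draw the crop rectangle on a copy of the RGB frame for preview."""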
x0, y0, x1, y1 = crop_box
frame = frame_rgb.copy()
cv2.rectangle(frame, (x0, y0), (x1, y1), color, thickness)
return frame
# ==========================================================
# === STABILIZATION ========================================
# ==========================================================
def stabilize_black_regions(input_video):
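    """Fill holes and smooth the edges of the black regions in each frame of the mask video."""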
kernel_fill = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
cap = cv2.VideoCapture(input_video)
if not cap.isOpened():
raise FileNotFoundError(f"❌ Could not open video: {input_video}")
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(STABILIZED_MASK, fourcc, fps, (width, height))
while True:
ret, frame = cap.read()
if not ret:
break
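        # Convert to grayscale and threshold into a binary mask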
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
_, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
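        # Step 1: flood-fill from the corner to close enclosed black regions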
inv = cv2.bitwise_not(mask)
flood = inv.copy()
h, w = inv.shape
flood_mask = np.zeros((h + 2, w + 2), np.uint8)
cv2.floodFill(flood, flood_mask, (0, 0), 255)
holes = cv2.bitwise_not(flood)
filled = cv2.bitwise_or(inv, holes)
filled = cv2.bitwise_not(filled)
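        # Step 2: morphological close/open to fill small holes and smooth jagged edges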
stable = cv2.morphologyEx(filled, cv2.MORPH_CLOSE, kernel_fill, iterations=1)
stable = cv2.morphologyEx(stable, cv2.MORPH_OPEN, kernel_edge, iterations=1)
out.write(cv2.cvtColor(stable, cv2.COLOR_GRAY2BGR))
cap.release()
out.release()
print(f"βœ… Stabilized mask saved: {STABILIZED_MASK}")
return STABILIZED_MASK
# ==========================================================
# === TRACKING =============================================
# ==========================================================
def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
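    """
    Track the painted mask pixels through the video with RAFT optical flow.

    The video is reversed first, so tracking starts from the annotated (last)
    frame. Flow is computed only inside a square crop around the mask; tracked
    points are painted each frame until no point has hit a black pixel for
    HISTORY_LEN consecutive frames, after which painting stops. The visual,
    mask, and stabilized outputs are reversed back to forward order before
    returning.
    """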
BLACK_THRESH = 1
HISTORY_LEN = 5
reversed_path = reverse_video(video_path, REVERSED_INPUT)
cap = cv2.VideoCapture(reversed_path)
model = load_raft_model(MODEL_PATH)
fps = cap.get(cv2.CAP_PROP_FPS)
ret, first_frame = cap.read()
if not ret:
return "❌ Could not read first frame.", None, None, None
H, W = first_frame.shape[:2]
x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(first_frame, mask_path, pad=200)
cw, ch = x1 - x0, y1 - y0
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)
full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST)
crop_mask = full_mask[y0:y1, x0:x1]
if selection_mode == "All Pixels":
ys, xs = np.where(crop_mask > 0)
else:
gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
black_pixels = (gray_first[y0:y1, x0:x1] < BLACK_THRESH)
combined = (crop_mask > 0) & black_pixels
ys, xs = np.where(combined)
tracked_points = np.vstack((xs, ys)).T.astype(np.float32)
prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1]
history = deque([True]*HISTORY_LEN, maxlen=HISTORY_LEN)
stopped = False
frame_idx = 0
while True:
ret, curr_frame = cap.read()
if not ret:
break
frame_idx += 1
curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1]
gray_crop = cv2.cvtColor(curr_crop_rgb, cv2.COLOR_RGB2GRAY)
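        # --- Compute optical flow between the previous and current crop ---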
flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb)
vis_full = curr_full_rgb.copy()
mask_full = np.full((H, W), 255, dtype=np.uint8)
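        # --- Move tracked points along the flow field ---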
new_points = []
for pt in tracked_points:
px, py = int(pt[0]), int(pt[1])
if 0 <= px < cw and 0 <= py < ch:
dx, dy = flow_crop[py, px]
nx, ny = pt[0] + dx, pt[1] + dy
nx = np.clip(nx, 0, cw-1)
ny = np.clip(ny, 0, ch-1)
new_points.append([nx, ny])
tracked_points = np.array(new_points, dtype=np.float32)
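        # --- Check whether any tracked point still sits on a black pixel ---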
black_mask = gray_crop < BLACK_THRESH
black_indices = tracked_points.astype(int)
has_black = any(
0 <= px < cw and 0 <= py < ch and black_mask[py, px]
for px, py in black_indices
)
history.append(has_black)
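        # --- Determine painting condition ---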
if stopped:
paint = False
elif has_black:
paint = True
elif not any(history):
stopped = True
paint = False
else:
paint = True
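        # --- Paint tracked points this frame, or skip ---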
if paint:
for pt in tracked_points:
fx, fy = int(pt[0] + x0), int(pt[1] + y0)
if 0 <= fx < W and 0 <= fy < H:
cv2.circle(vis_full, (fx, fy), 1, (0,255,0), -1)
mask_full[fy, fx] = 0
out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR))
out_mask.write(mask_full)
prev_crop_rgb = curr_crop_rgb
if frame_idx % 10 == 0:
print(f"Frame {frame_idx}: {'PAINT' if paint else 'NO-PAINT'} | has_black={has_black} | stopped={stopped}")
cap.release()
out_vis.release()
out_mask.release()
stabilize_black_regions(OUTPUT_MASK_VIDEO)
reverse_video_file_inplace(OUTPUT_VIDEO)
reverse_video_file_inplace(OUTPUT_MASK_VIDEO)
reverse_video_file_inplace(STABILIZED_MASK)
return (
f"βœ… Tracking complete ({selection_mode}).\n"
f"Square Crop {cw}x{ch} @ ({x0},{y0}) with padding=100\n"
f"Painting stopped={'Yes' if stopped else 'No'} after {frame_idx} frames.\n"
"Saved outputs reversed back to forward order.",
OUTPUT_VIDEO,
OUTPUT_MASK_VIDEO,
STABILIZED_MASK
)
# ==========================================================
# === GRADIO APP ===========================================
# ==========================================================
def build_app():
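    """Build the Gradio Blocks UI: annotate a mask on the last frame, then run tracking."""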
with gr.Blocks() as demo:
gr.Markdown("# 🎯 Pixel Tracker (Dynamic Square Crop)")
with gr.Row():
video_in = gr.Video(label="🎞️ Upload Video")
frame_num = gr.Number(value=0, visible=False)
load_btn = gr.Button("πŸ“Έ Load Frame for Annotation")
annot = gr.Image(label="πŸ–ŒοΈ Paint ROI Mask", tool="sketch", type="numpy", image_mode="RGBA", height=1000)
save_btn = gr.Button("πŸ’Ύ Save Mask")
log = gr.Textbox(label="Logs", lines=8)
preview_btn = gr.Button("πŸ”Ž Preview Crop", visible=False)
with gr.Row():
preview_frame = gr.Image(label="Preview Frame", visible=False)
preview_crop = gr.Image(label="Cropped Region", visible=False)
        pixel_mode = gr.Dropdown(["All Pixels", "Only Black Pixels"], value="All Pixels", label="Pixel Selection")
        run_btn = gr.Button("πŸš€ Run Tracking")
with gr.Row():
result_video = gr.Video(label="🎬 Result (Forward)")
mask_video = gr.Video(label="⬛ Mask (Forward)")
stabilized_video = gr.Video(label="🧱 Stabilized (Forward)")
def load_reversed_frame(v, f):
reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT)
return extract_frame(reversed_path, int(f))
load_btn.click(load_reversed_frame, [video_in, frame_num], annot)
save_btn.click(save_mask, annot, [gr.State(), log])
def preview_crop_fn(v):
reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT)
frame0 = extract_frame(reversed_path, 0)
if frame0 is None or not os.path.exists("user_mask.png"):
return None, None, "⚠️ Paint and Save Mask first."
x0,y0,x1,y1 = compute_crop_box_from_mask_dynamic(cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", pad=200)
frame_box = draw_crop_preview_on_frame(frame0, (x0,y0,x1,y1))
return frame_box, frame0[y0:y1, x0:x1], f"Square crop {x1-x0}x{y1-y0} at ({x0},{y0})"
preview_btn.click(preview_crop_fn, video_in, [preview_frame, preview_crop, log])
def run_btn_fn(v, m):
if not os.path.exists("user_mask.png"):
return "⚠️ Save Mask first.", None, None, None
return run_tracking(v.name if hasattr(v, "name") else v, "user_mask.png", m)
        run_btn.click(run_btn_fn, [video_in, pixel_mode],
                      [log, result_video, mask_video, stabilized_video])
return demo
if __name__ == "__main__":
app = build_app()
app.launch(server_name="0.0.0.0", server_port=7861, debug=True)