# File: vfx-2/preprocess/track_pixels.py
# Repository page header (not code): "TaqiRaza512's picture" / "Initial commit" / 307c071
import os
import cv2
import math
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torchvision import transforms
import gradio as gr
from argparse import Namespace
import sys
# === RAFT Setup ===
sys.path.append("/app/preprocess/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder
# === CONFIG ===
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# RAFT checkpoint loaded by run_tracking() through load_raft_model().
MODEL_PATH = "/data/pretrained/RAFT/raft-things.pth"

# Visualization video written by run_tracking() (tracked points drawn on frames).
OUTPUT_VIDEO = "/app/tracked_roi_output.mp4"

# Raw per-frame mask video written by run_tracking() (black = tracked pixel).
OUTPUT_MASK_VIDEO = "/app/final_tracked_mask_video.mp4"

# Cleaned-up mask video written by stabilize_black_regions().
STABILIZED_MASK = "/app/stabilized_mask_video.mp4"
# --- Load RAFT model ---
def load_raft_model(model_path):
    """Build a RAFT network, load its checkpoint and return it ready for inference.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.

    Returns:
        The RAFT module unwrapped from its ``DataParallel`` container, moved
        to ``DEVICE`` and switched to eval mode.
    """
    raft_args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=16,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0,
    )

    # The state dict is loaded into a DataParallel wrapper and the inner
    # module is extracted afterwards.
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model (keys prefixed with "module.") -- confirm against the checkpoint.
    wrapped = torch.nn.DataParallel(RAFT(raft_args))
    state_dict = torch.load(model_path, map_location=DEVICE)
    wrapped.load_state_dict(state_dict)

    network = wrapped.module.to(DEVICE)
    return network.eval()
# --- Convert to tensor ---
def to_tensor(image):
    """Convert an image to a tensor with a leading batch axis on ``DEVICE``.

    The image goes through ``torchvision.transforms.ToTensor``, gets a new
    first dimension via ``unsqueeze(0)`` and is moved to ``DEVICE``.
    """
    converter = transforms.ToTensor()
    batched = converter(image).unsqueeze(0)
    return batched.to(DEVICE)
# --- Compute optical flow between two frames ---
@torch.no_grad()
def compute_flow(model, img1, img2):
    """Run the flow model on a pair of frames and return the flow as a NumPy array.

    Args:
        model: Callable flow network; invoked with ``iters=30`` and
            ``test_mode=True`` and expected to return a 2-tuple whose second
            element is the flow.
        img1: First frame (anything ``to_tensor`` accepts).
        img2: Second frame (anything ``to_tensor`` accepts).

    Returns:
        The first batch element of the unpadded flow, permuted so the channel
        axis comes last and copied to host memory (the original author
        annotates this as ``[H, W, 2]``).
    """
    first = to_tensor(img1)
    second = to_tensor(img2)

    # Pad both inputs with the same padder so the result can be unpadded later.
    padder = InputPadder(first.shape)
    first, second = padder.pad(first, second)

    _, flow_up = model(first, second, iters=30, test_mode=True)

    unpadded = padder.unpad(flow_up)
    channels_last = unpadded[0].permute(1, 2, 0)
    return channels_last.cpu().numpy()  # [H, W, 2]
# --- Extract a specific frame from the video ---
def extract_frame(video_path, frame_number):
    """Return one frame of a video as an RGB array, or ``None`` on failure.

    Args:
        video_path: Path of the video file. ``None`` or an empty path (for
            example when no video has been uploaded yet) yields ``None``
            instead of being passed to OpenCV.
        frame_number: Index of the frame to seek to, passed to
            ``cv2.CAP_PROP_POS_FRAMES``.

    Returns:
        The decoded frame converted from BGR to RGB, or ``None`` when the path
        is missing or no frame could be read.
    """
    # Guard first: cv2.VideoCapture cannot be given a None path.
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
    finally:
        # Always free the capture handle, even if seeking/decoding raised.
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# --- Save user mask to disk ---
def save_mask(data):
    """Binarize the painted mask and write it to ``user_mask.png``.

    Args:
        data: Either a dict carrying the painted mask under the ``"mask"``
            key, or the mask array itself. ``None`` means nothing was
            received.

    Returns:
        A ``(mask_path, message)`` tuple. ``mask_path`` is ``None`` when no
        mask was available or the file could not be written; ``message`` is a
        human-readable status line for the log box.
    """
    if data is None:
        return None, "⚠️ No mask data received!"

    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "⚠️ Mask missing!"

    # Multi-channel masks are collapsed to a single gray channel first.
    if mask.ndim == 3:
        mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
    else:
        mask_gray = mask

    # Every pixel with a value above 1 counts as painted (-> 255).
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)

    mask_path = "user_mask.png"
    # cv2.imwrite signals failure through its return value; do not report
    # success for a file that was never written.
    if not cv2.imwrite(mask_path, bin_mask):
        return None, f"❌ Failed to write mask to {mask_path}"

    return mask_path, f"✅ Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"
# --- Fill in the black region ---
def stabilize_black_regions(input_video: str):
    """Stabilize black region boundaries in a binary mask video.

    Each frame is thresholded, the black regions are treated as foreground,
    enclosed holes are filled, edges are smoothed with a morphological close,
    isolated dots are removed with a 3x3 median blur, and the result is
    written to ``STABILIZED_MASK``.

    Args:
        input_video (str): path to input video (mask video with black & white pixels)

    Returns:
        The path of the stabilized video (``STABILIZED_MASK``).

    Raises:
        FileNotFoundError: If ``input_video`` cannot be opened.
    """
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"❌ Could not open video: {input_video}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(STABILIZED_MASK, fourcc, fps, (width, height))
        print(f"🎞 Processing {int(cap.get(cv2.CAP_PROP_FRAME_COUNT))} frames "
              f"({width}x{height} @ {fps:.1f} fps)...")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

            # Invert so black→white (foreground)
            inv = cv2.bitwise_not(mask)

            # Fill interior holes. The flood fill runs on a copy padded with a
            # 1 px background border: seeding at (0, 0) of the padded image is
            # then guaranteed to start on background, and every background
            # area touching the frame border is reached through that border.
            # Without the padding, a foreground pixel in the top-left corner,
            # or foreground that cuts the background into several pieces,
            # makes real background look like a "hole" and gets it filled.
            padded = cv2.copyMakeBorder(
                inv, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0
            )
            ph, pw = padded.shape
            flood_mask = np.zeros((ph + 2, pw + 2), np.uint8)
            cv2.floodFill(padded, flood_mask, (0, 0), 255)
            flooded = np.ascontiguousarray(padded[1:-1, 1:-1])
            holes = cv2.bitwise_not(flooded)
            filled_inv = cv2.bitwise_or(inv, holes)

            # Stabilize edges (dilate → erode)
            closed = cv2.morphologyEx(
                filled_inv, cv2.MORPH_CLOSE, kernel_edge, iterations=1
            )

            # Remove small dots
            denoised = cv2.medianBlur(closed, 3)

            # Invert back (restore black region)
            result = cv2.bitwise_not(denoised)
            out.write(cv2.cvtColor(result, cv2.COLOR_GRAY2BGR))
    finally:
        # Release handles even when a frame fails to process, so the output
        # file is closed and the input is not left locked.
        cap.release()
        if out is not None:
            out.release()

    print(f"✅ Saved stabilized mask video: {STABILIZED_MASK}")
    return STABILIZED_MASK
# --- RAFT Pixel Tracking based on user mask ---
def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
    """Track the masked pixels through a video using RAFT optical flow.

    The pixels selected from the first frame are moved frame by frame along
    the flow computed between consecutive frames. Two videos are written: a
    visualization (``OUTPUT_VIDEO``) and a white mask video with tracked
    pixels in black (``OUTPUT_MASK_VIDEO``), which is then post-processed by
    ``stabilize_black_regions``.

    Args:
        video_path: Path of the input video.
        mask_path: Path of the grayscale mask image; it is resized to the
            video's frame size.
        selection_mode: ``"All Pixels"`` tracks every non-zero mask pixel;
            any other value tracks only mask pixels whose first-frame gray
            value is below 30.

    Returns:
        A ``(message, visualization_path, stabilized_mask_path)`` tuple. On
        failure the message describes the problem and both paths are ``None``.
    """
    if not video_path:
        return "❌ Failed to open video.", None, None

    cap = cv2.VideoCapture(video_path)
    out_vis = None
    out_mask = None
    try:
        if not cap.isOpened():
            return "❌ Failed to open video.", None, None
        fps = cap.get(cv2.CAP_PROP_FPS)

        ret, first_frame = cap.read()
        if not ret:
            return "❌ Could not read first frame.", None, None
        H, W = first_frame.shape[:2]

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None for a missing/unreadable file; resizing
            # None would otherwise crash with an opaque OpenCV error.
            return f"❌ Could not read mask image: {mask_path}", None, None
        mask = cv2.resize(mask, (W, H))

        # Select points based on mode
        if selection_mode == "All Pixels":
            ys, xs = np.where(mask > 0)
        else:
            gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
            black_pixels = gray < 30  # threshold for black
            ys, xs = np.where((mask > 0) & black_pixels)
        tracked_points = np.vstack((xs, ys)).T.astype(np.float32)
        print(f"🎯 Selected {len(tracked_points)} pixels under mode: {selection_mode}")

        # Load the network only after the cheap input checks have passed.
        model = load_raft_model(MODEL_PATH)

        # Writers for visualization and mask
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(
            OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False
        )

        prev_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        while True:
            ret, curr_frame = cap.read()
            if not ret:
                break
            curr_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            flow = compute_flow(model, prev_frame, curr_rgb)
            vis = curr_rgb.copy()

            # Start with all-white background
            mask_frame = np.full((H, W), 255, dtype=np.uint8)

            new_points = []
            for pt in tracked_points:
                px, py = int(round(pt[0])), int(round(pt[1]))
                if not (0 <= px < W and 0 <= py < H):
                    # The point has left the frame; stop tracking it.
                    continue
                dx, dy = flow[py, px]
                nx, ny = pt[0] + dx, pt[1] + dy
                new_points.append([nx, ny])

                ix, iy = int(nx), int(ny)
                # Visualization (cv2.circle clips to the image by itself)
                cv2.circle(vis, (ix, iy), 1, (0, 255, 0), -1)
                # Mask (black pixel for tracked). The displaced position may
                # lie outside the frame: an index >= H/W would raise and a
                # negative one would silently wrap to the opposite edge.
                if 0 <= ix < W and 0 <= iy < H:
                    mask_frame[iy, ix] = 0

            tracked_points = np.array(new_points, dtype=np.float32)
            out_vis.write(cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
            out_mask.write(mask_frame)
            prev_frame = curr_rgb
    finally:
        # Close everything on every exit path (early return or exception);
        # the writers must also be closed before the mask video is re-read.
        cap.release()
        for writer in (out_vis, out_mask):
            if writer is not None:
                writer.release()

    stabilized_mask = stabilize_black_regions(OUTPUT_MASK_VIDEO)
    return (
        f"✅ Tracking complete ({selection_mode}).\nSaved:\n- {OUTPUT_VIDEO}\n- {OUTPUT_MASK_VIDEO}",
        OUTPUT_VIDEO,
        stabilized_mask
    )
# --- Gradio UI ---
def build_app():
    """Assemble the Gradio Blocks UI and wire its three buttons.

    The UI offers: loading a video frame into a sketch canvas, saving the
    painted mask via ``save_mask``, and running ``run_tracking`` on the
    uploaded video with the saved ``user_mask.png``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """

    def _path_of(video):
        # The video value may be an object exposing ``.name`` or a plain path.
        return video.name if hasattr(video, "name") else video

    with gr.Blocks() as demo:
        gr.Markdown("## 🎯 RAFT Pixel Tracker with Brush-based ROI, Pixel Mode, and Mask Output (Inverted Mask)")

        # NOTE(review): the extent of this Row is reconstructed (the source's
        # indentation was lost); it is assumed to hold the video and the
        # frame-number inputs -- confirm against the original layout.
        with gr.Row():
            video_input = gr.Video(label="🎞️ Upload Video")
            frame_input = gr.Number(label="Frame # to Paint", value=0, precision=0)

        load_button = gr.Button("📸 Load Frame for Annotation")
        canvas = gr.Image(
            label="🖌️ Paint ROI Mask",
            tool="sketch",
            type="numpy",
            image_mode="RGBA",
            height=480,
        )
        mode_dropdown = gr.Dropdown(
            choices=["All Pixels", "Only Black Pixels"],
            value="All Pixels",
            label="Pixel Selection Mode"
        )
        save_button = gr.Button("💾 Save Mask")
        run_button = gr.Button("🚀 Run RAFT Tracking")
        log_box = gr.Textbox(label="Logs", lines=6)
        vis_video = gr.Video(label="🎬 Visualization Video")
        tracked_mask_video = gr.Video(label="⬛ Tracked Mask Video (black = tracked pixels)")

        # Load frame: show the requested frame on the sketch canvas.
        load_button.click(
            fn=lambda v, f: extract_frame(_path_of(v), int(f)),
            inputs=[video_input, frame_input],
            outputs=canvas
        )

        # Save mask: the returned path goes into a throwaway State, the
        # status message into the log box.
        save_button.click(
            fn=save_mask,
            inputs=canvas,
            outputs=[gr.State(), log_box]
        )

        # Run tracking: always reads the mask from "user_mask.png".
        run_button.click(
            fn=lambda v, m: run_tracking(_path_of(v), "user_mask.png", m),
            inputs=[video_input, mode_dropdown],
            outputs=[log_box, vis_video, tracked_mask_video]
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces,
    # port 7860, with Gradio's debug mode enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    app = build_app()
    app.launch(**launch_options)