# pixels_tracking / track-pixels_gradio.py
# Last change: commit ac5932f by isambalghari - "updating morphological operations"
import os
import sys
import cv2
import math
import time
import torch
import numpy as np
import gradio as gr
from tqdm import tqdm
from pathlib import Path
from collections import deque
from argparse import Namespace
from torchvision import transforms
# === RAFT Setup ===
sys.path.append("/app/preprocess/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder
# === CONFIG ===
# Run RAFT on the GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Every model file and generated artefact lives under this directory.
_APP_ROOT = "/app"
# RAFT checkpoint loaded by load_raft_model.
MODEL_PATH = f"{_APP_ROOT}/RAFT/raft-things.pth"
# Fixed output locations written (and overwritten) by run_tracking on every run.
OUTPUT_VIDEO = f"{_APP_ROOT}/full_tracked_output.mp4"         # frames with tracked points drawn
OUTPUT_MASK_VIDEO = f"{_APP_ROOT}/mask_output.mp4"            # tracked pixels black on white
STABILIZED_MASK = f"{_APP_ROOT}/stabilized_mask_output.mp4"   # morphologically cleaned mask
REVERSED_INPUT = f"{_APP_ROOT}/reversed_input.mp4"            # scratch copy of the input, reversed
# ==========================================================
# === VIDEO UTILITIES =====================================
# ==========================================================
def reverse_video(input_path, output_path):
    """Write the frames of *input_path* to *output_path* in reverse order.

    Every frame OpenCV can decode is read into memory first, so the frame
    count comes from actual decoding rather than from CAP_PROP_FRAME_COUNT
    metadata (which can be off by one). The output uses the mp4v FourCC, the
    source FPS (30 if the source reports none) and the size of the decoded
    frames.

    Args:
        input_path (str): video to read.
        output_path (str): where to write the reversed video.

    Returns:
        str: output_path.

    Raises:
        FileNotFoundError: the input video cannot be opened.
        ValueError: no frame could be decoded.
        OSError: the output video writer cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise FileNotFoundError(f"❌ Could not open video: {input_path}")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        cap.release()

    if not frames:
        raise ValueError("No frames read from video!")

    # Size the writer from the decoded frames rather than from container
    # metadata: VideoWriter may silently drop frames whose size differs from
    # the one it was opened with, and the metadata can disagree with what is
    # decoded (NOTE(review): e.g. rotated phone videos - backend-dependent).
    height, width = frames[0].shape[:2]
    if not (fps > 0):  # 0 or NaN when the container reports no frame rate
        fps = 30.0

    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not out.isOpened():
        raise OSError(f"Could not open video writer: {output_path}")
    try:
        for frame in reversed(frames):
            out.write(frame)
    finally:
        out.release()

    # No cv2.destroyAllWindows() here: no window is ever created, and the call
    # can raise on headless OpenCV builds.
    print(f"βœ… Reversed {len(frames)} frames β†’ {output_path}")
    return output_path
def reverse_video_file_inplace(path_in):
    """Reverse the video at *path_in* in place, keeping every decodable frame.

    The reversed video is first written to a sibling file ``<name>_tmp<ext>``
    (``.mp4`` if *path_in* has no extension) and then moved over *path_in*
    with ``os.replace``. If reversing fails, the temporary file is removed and
    the exception is re-raised; *path_in* itself is not modified in that case.
    """
    # splitext only touches the final extension. The old str.replace(".mp4", ...)
    # also rewrote ".mp4" inside directory names, and did nothing for other
    # extensions, which made tmp_path equal path_in.
    root, ext = os.path.splitext(path_in)
    tmp_path = f"{root}_tmp{ext or '.mp4'}"
    try:
        reverse_video(path_in, tmp_path)
        os.replace(tmp_path, path_in)
    except Exception:
        # Don't leave a half-written temp file behind; let the caller see the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"πŸ” Overwrote {path_in} with reversed version (same frame count).")
# ==========================================================
# === RAFT LOADING =========================================
# ==========================================================
def load_raft_model(model_path):
    """Build a RAFT network, load the checkpoint at *model_path*, and return it in eval mode on DEVICE.

    The state dict is loaded through a ``torch.nn.DataParallel`` wrapper, and the
    wrapped (bare) module is returned. NOTE(review): presumably this is because the
    checkpoint keys carry DataParallel's ``module.`` prefix - confirm.
    """
    raft_options = {
        "small": False,
        "mixed_precision": False,
        "alternate_corr": False,
        "dropout": 0.0,
        # NOTE(review): the depth-related fields below are not read anywhere in
        # this file; presumably consumed by the vendored RAFT copy - confirm.
        "max_depth": 16,
        "depth_network": False,
        "depth_residual": False,
        "depth_scale": 1.0,
    }
    wrapper = torch.nn.DataParallel(RAFT(Namespace(**raft_options)))
    checkpoint = torch.load(model_path, map_location=DEVICE)
    wrapper.load_state_dict(checkpoint)
    network = wrapper.module.to(DEVICE)
    network.eval()
    return network
def to_tensor(image, device=None):
    """Convert an image array into a batched float tensor for RAFT.

    Pixel values are kept on their original 0..255 scale. The reference RAFT
    implementation (princeton-vl/RAFT, ``core/raft.py``) normalises its inputs
    itself with ``2 * (x / 255) - 1``. torchvision's ``ToTensor`` (used before)
    already divides uint8 images by 255, which squashed every frame into a
    near-constant input. NOTE(review): confirm the vendored RAFT under
    /app/preprocess/RAFT keeps that normalisation.

    Args:
        image: (H, W, C) or (H, W) array - presumably uint8 RGB as produced by
            cv2.cvtColor in this file; non-contiguous views are accepted.
        device: target device; defaults to the module-level DEVICE.

    Returns:
        torch.Tensor: float32 tensor of shape (1, C, H, W) (C = 1 for 2-D input)
        on *device*.
    """
    if device is None:
        device = DEVICE
    # torch.from_numpy needs positive strides; crops are often strided views.
    arr = np.ascontiguousarray(image)
    if arr.ndim == 2:
        arr = arr[:, :, None]
    tensor = torch.from_numpy(arr).permute(2, 0, 1).float()
    return tensor.unsqueeze(0).to(device)
@torch.no_grad()
def compute_flow(model, img1, img2, iters=30):
    """Estimate dense optical flow from *img1* to *img2* with RAFT.

    Both images are converted with to_tensor and padded with InputPadder before
    the forward pass. The padding is removed from the upsampled flow afterwards.

    Args:
        model: RAFT network, e.g. as returned by load_raft_model.
        img1, img2: images of equal size (presumably RGB uint8 arrays, as
            produced by the callers in this file).
        iters: number of RAFT refinement iterations (default 30, as before).

    Returns:
        np.ndarray: (H, W, 2) flow array. run_tracking reads it as (dx, dy) per pixel.
    """
    t1, t2 = to_tensor(img1), to_tensor(img2)
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)
    # In test_mode RAFT returns (low-res flow, upsampled flow); keep the upsampled one.
    _, flow_up = model(t1, t2, iters=iters, test_mode=True)
    flow = padder.unpad(flow_up)[0]
    return flow.permute(1, 2, 0).cpu().numpy()
# ==========================================================
# === FRAME / MASK HELPERS ================================
# ==========================================================
def extract_frame(video_path, frame_number):
    """Return frame *frame_number* of *video_path* as an RGB array, or None if it cannot be read.

    The position is set through CAP_PROP_POS_FRAMES; how exact that seek is
    depends on the OpenCV backend and codec.
    """
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ok, bgr = capture.read()
    capture.release()
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
def save_mask(data):
    """Binarize the painted mask from the annotation widget and save it as ``user_mask.png``.

    *data* is either a dict (its ``"mask"`` entry is used) or the mask array
    itself. 3-D masks are converted to grayscale with RGBA->GRAY. Every pixel
    with a value above 1 becomes 255, everything else 0.

    Returns:
        tuple: (mask path, status message) on success, or (None, warning message)
        when no mask is available.
    """
    if data is None:
        return None, "⚠️ No mask data received!"
    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "⚠️ Mask missing!"
    gray = mask if mask.ndim != 3 else cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    out_path = "user_mask.png"
    cv2.imwrite(out_path, binary)
    painted = np.count_nonzero(binary)
    return out_path, f"βœ… Saved mask ({painted} painted pixels)"
# ==========================================================
# === UPDATED DYNAMIC CROP LOGIC ===========================
# ==========================================================
def _fit_span(center, side, limit):
    """Return [start, end) of length *side* centred on *center*, shifted to lie inside [0, limit]."""
    if side >= limit:
        # Cannot fit: use the full extent of this axis.
        return 0, limit
    start = min(max(center - side // 2, 0), limit - side)
    return start, start + side


def compute_crop_box_from_mask_dynamic(first_frame_bgr, mask_path, pad=200):
    """Compute a square crop box around the painted mask region.

    The bounding box of all non-zero mask pixels is grown by *pad* pixels on
    each side (clamped to the image) and then widened to a square whose side is
    the larger of its width and height. The square is centred on the padded box
    and shifted, not clipped, to stay inside the image, so it always contains
    the padded box. It is square whenever that side fits in the image;
    otherwise the overflowing axis spans the whole image.

    Args:
        first_frame_bgr: the frame the box will be applied to, or None. If its
            size differs from the mask's, the mask is first resized to it
            (nearest-neighbour) so the box is in frame coordinates.
        mask_path (str): grayscale mask image; non-zero pixels mark the region.
        pad (int): padding in pixels around the mask bounding box.

    Returns:
        tuple[int, int, int, int]: (x0, y0, x1, y1) with x1/y1 exclusive, i.e.
        usable as ``frame[y0:y1, x0:x1]``. For an empty mask, a centred square
        of side ``2 * (min(W, H) // 2 // 2)`` is returned.

    Raises:
        FileNotFoundError: the mask image cannot be read.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Mask not found: {mask_path}")
    if first_frame_bgr is not None:
        frame_h, frame_w = first_frame_bgr.shape[:2]
        if mask.shape[:2] != (frame_h, frame_w):
            # Same treatment run_tracking applies to the mask before cropping.
            mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)

    H, W = mask.shape[:2]
    ys, xs = np.nonzero(mask)

    # Fallback to center crop if no mask
    if xs.size == 0:
        cx, cy = W // 2, H // 2
        size = min(W, H) // 2
        return cx - size // 2, cy - size // 2, cx + size // 2, cy + size // 2

    # Padded bounding box; "+ 1" turns the inclusive max index into an exclusive bound.
    x_min = max(0, int(xs.min()) - pad)
    y_min = max(0, int(ys.min()) - pad)
    x_max = min(W, int(xs.max()) + 1 + pad)
    y_max = min(H, int(ys.max()) + 1 + pad)

    # Make it square, shifting (not shrinking) at the image borders.
    side = max(x_max - x_min, y_max - y_min)
    x0, x1 = _fit_span((x_min + x_max) // 2, side, W)
    y0, y1 = _fit_span((y_min + y_max) // 2, side, H)
    return int(x0), int(y0), int(x1), int(y1)
def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0,255,0), thickness=2):
    """Return a copy of *frame_rgb* with the (x0, y0, x1, y1) *crop_box* outlined.

    The input frame is not modified.
    """
    left, top, right, bottom = crop_box
    annotated = frame_rgb.copy()
    cv2.rectangle(annotated, (left, top), (right, bottom), color, thickness)
    return annotated
# ==========================================================
# === STABILIZATION ========================================
# ==========================================================
def _sample_dark_thickness(cap, sample_frames):
    """Read up to *sample_frames* frames from *cap*. Return, for each one that has
    dark (<= 127) pixels, the mean distance-transform value over its dark region."""
    samples = []
    count = 0
    while count < sample_frames:
        ret, frame = cap.read()
        if not ret:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # THRESH_BINARY_INV: pixels <= 127 become 255 (foreground), the rest 0.
        _, dark = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
        if np.any(dark > 0):
            dist = cv2.distanceTransform(dark, cv2.DIST_L2, 3)
            samples.append(np.mean(dist[dark > 0]))
        count += 1
    return samples


def _clean_dark_mask(mask, avg_thickness, kernel, min_area, bridge_kernel, fill_kernel, fine_kernel):
    """Run repair steps (A)-(D) on one binary frame (255 = dark region) and return a 0/255 mask.

    The result is inverted relative to *mask* (see step B). Broadly, 255 marks
    background farther than ~1.2 x avg_thickness from the dark region, and 0
    marks the grown dark region.
    """
    # --- (A) Connectivity repair: bridge gaps & fill ---
    repaired = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, bridge_kernel, iterations=2)
    filled = cv2.morphologyEx(repaired, cv2.MORPH_CLOSE, fill_kernel, iterations=2)
    filled = cv2.morphologyEx(filled, cv2.MORPH_OPEN, fine_kernel, iterations=1)

    # --- (B) Edge thickness normalization ---
    # For every non-region pixel, distanceTransform gives its distance to the
    # region. Keeping pixels closer than 1.2 x avg_thickness grows the region by
    # a uniform margin. The result is then inverted, so from here on 255 marks
    # the background outside that margin.
    dist = cv2.distanceTransform(cv2.bitwise_not(filled), cv2.DIST_L2, 3)
    normalized = (dist < avg_thickness * 1.2).astype(np.uint8) * 255
    base_clean = cv2.bitwise_not(normalized)

    # --- (C) Morphological cleanup (fixed parameters) ---
    base_clean = cv2.morphologyEx(base_clean, cv2.MORPH_CLOSE, kernel, iterations=2)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(base_clean, connectivity=8)
    # Components with area >= min_area are kept as-is; smaller ones are dilated.
    # Dilation distributes over union, so dilating all small components at once
    # equals dilating each one and OR-ing the results, without a full-frame
    # pass per component.
    # NOTE(review): after the inversion in (B) these are components of the
    # background, so small ones (e.g. enclosed holes) are grown - confirm intent.
    areas = stats[:, cv2.CC_STAT_AREA]
    is_large = areas >= min_area
    is_small = ~is_large
    is_large[0] = is_small[0] = False  # label 0 is the zero-pixel background
    large = is_large[labels].astype(np.uint8) * 255
    small = is_small[labels].astype(np.uint8) * 255
    filtered_mask = cv2.bitwise_or(large, cv2.dilate(small, kernel, iterations=2))

    # --- (D) Edge reinforcement ---
    edges = cv2.morphologyEx(filtered_mask, cv2.MORPH_GRADIENT, fine_kernel)
    reinforced = cv2.bitwise_or(filtered_mask, edges)
    reinforced = cv2.morphologyEx(reinforced, cv2.MORPH_CLOSE, kernel, iterations=2)
    return cv2.medianBlur(reinforced, 3)


def stabilize_black_regions(input_video, output_path=STABILIZED_MASK, blend=0.3, sample_frames=10):
    """Clean up and temporally smooth a black-on-white mask video.

    1. The first *sample_frames* frames are sampled to estimate the typical
       thickness of the dark (<= 127) regions. From this the morphology kernel
       size (clamped to 3..9 px) is fixed for the whole video; the minimum
       component area is fixed at 0.05% of the frame.
    2. Each frame is repaired (gaps bridged, dark region grown by
       ~1.2 x thickness, components cleaned, edges reinforced). It is then
       blended with the previous output frame (weight *blend*) and
       re-binarized to reduce flicker.

    Because step (B) of the repair inverts the mask, the written frames show
    the grown dark regions black on white, like the input.

    Args:
        input_video (str): path to the input mask video (black/white).
        output_path (str): where to write the stabilized video (mp4v).
        blend (float): weight of the previous output frame in the blend (0.0-1.0).
        sample_frames (int): number of initial frames used for parameter estimation.

    Returns:
        str: output_path.

    Raises:
        FileNotFoundError: the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {input_video}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # === Step 1: Estimate global morphology parameters from first N frames ===
    try:
        thickness_samples = _sample_dark_thickness(cap, sample_frames)
    finally:
        cap.release()
    avg_thickness = np.median(thickness_samples) if thickness_samples else 5
    k = int(np.clip(avg_thickness / 2.0, 3, 9))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
    min_area = (width * height) * 0.0005
    print(f"🧠 Fixed morphology parameters β€” kernel={k} | min_area={min_area:.1f}")

    # Fixed structuring elements shared by every frame (hoisted out of the loop).
    bridge_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    fill_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    fine_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # === Step 2: Process all frames ===
    # Re-open instead of seeking back with CAP_PROP_POS_FRAMES. That seek is not
    # guaranteed on every backend, and a failed seek would silently drop the
    # sampled frames from the output.
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {input_video}")
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    prev_mask = None
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
            reinforced = _clean_dark_mask(mask, avg_thickness, kernel, min_area,
                                          bridge_kernel, fill_kernel, fine_kernel)
            # --- (E) Temporal stabilization ---
            if prev_mask is not None:
                reinforced = cv2.addWeighted(reinforced, 1 - blend, prev_mask, blend, 0)
                reinforced = (reinforced > 127).astype(np.uint8) * 255  # re-binarize
            prev_mask = reinforced  # never mutated in place, so no copy needed
            out.write(cv2.cvtColor(reinforced, cv2.COLOR_GRAY2BGR))
    finally:
        cap.release()
        out.release()
    print(f"βœ… Visually stable and connected mask saved: {output_path}")
    return output_path
# ==========================================================
# === TRACKING =============================================
# ==========================================================
def _advect_points(points, flow_crop, cw, ch):
    """Move crop-space (x, y) points along the flow at their truncated pixel positions.

    Points whose truncated position lies outside the cw x ch crop are dropped.
    Moved points are clamped to [0, cw-1] x [0, ch-1]. Returns float32 (N, 2).
    """
    points = points.reshape(-1, 2)
    ix = points[:, 0].astype(int)
    iy = points[:, 1].astype(int)
    inside = (ix >= 0) & (ix < cw) & (iy >= 0) & (iy < ch)
    moved = points[inside] + flow_crop[iy[inside], ix[inside]]
    moved[:, 0] = np.clip(moved[:, 0], 0, cw - 1)
    moved[:, 1] = np.clip(moved[:, 1], 0, ch - 1)
    return moved.astype(np.float32)


def _points_hit_black(points, black_mask):
    """True if any point's truncated pixel position (inside the mask) is True in *black_mask*."""
    points = points.reshape(-1, 2)
    if points.shape[0] == 0:
        return False
    h, w = black_mask.shape[:2]
    ix = points[:, 0].astype(int)
    iy = points[:, 1].astype(int)
    inside = (ix >= 0) & (ix < w) & (iy >= 0) & (iy < h)
    return bool(black_mask[iy[inside], ix[inside]].any())


def _paint_points(vis_full, mask_full, points, x0, y0):
    """Draw crop-space *points* (offset by x0, y0) onto the full-frame images in place.

    Each point gets a filled green radius-1 circle in *vis_full* (RGB) and a 0
    pixel in *mask_full*. Points outside the frame are skipped.
    """
    H, W = mask_full.shape[:2]
    points = points.reshape(-1, 2)
    fx = (points[:, 0] + x0).astype(int)
    fy = (points[:, 1] + y0).astype(int)
    ok = (fx >= 0) & (fx < W) & (fy >= 0) & (fy < H)
    fx, fy = fx[ok], fy[ok]
    for x, y in zip(fx.tolist(), fy.tolist()):
        cv2.circle(vis_full, (x, y), 1, (0, 255, 0), -1)
    mask_full[fy, fx] = 0


def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
    """Track the painted mask pixels backwards through a video with RAFT optical flow.

    The input is reversed into REVERSED_INPUT. Frame 0 of the reversed clip (the
    last frame of the original, which is what the UI offers for painting)
    carries the user's mask. Seed points are the mask pixels (or, in any mode
    other than "All Pixels", only mask pixels that are pure black, gray value 0,
    in that frame) inside a square crop around the mask. They are moved frame to
    frame with the flow computed on that crop.

    Each output frame paints the current point positions: green dots in
    OUTPUT_VIDEO, black pixels on white in OUTPUT_MASK_VIDEO. Once
    HISTORY_LEN consecutive frames pass without any point on a pure-black
    pixel, painting stops for the rest of the video. The mask video is then
    stabilized into STABILIZED_MASK, and all three outputs are reversed back to
    forward time.

    Frame 0 is written with the seed points. Every output therefore has one
    frame per decoded input frame and stays frame-aligned with the input after
    the final reversal. (Previously frame 0 was skipped and the last frame
    duplicated, so every output frame lagged the input by one.)

    Args:
        video_path (str): input video.
        mask_path (str): grayscale mask painted on reversed frame 0; non-zero = selected.
        selection_mode (str): "All Pixels" or another value for black pixels only.

    Returns:
        tuple: (log message, OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK), or
        (error message, None, None, None) if the first frame cannot be read.
    """
    BLACK_THRESH = 1   # gray < 1, i.e. only value-0 pixels count as black
    HISTORY_LEN = 5    # painting stops after this many consecutive frames with no black hit

    # --- Reverse input for backward tracking ---
    reversed_path = reverse_video(video_path, REVERSED_INPUT)
    cap = cv2.VideoCapture(reversed_path)
    out_vis = out_mask = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"🎞️ Input video: {total_frames} frames at {fps:.2f} FPS")
        ret, first_frame = cap.read()
        if not ret:
            return "❌ Could not read first frame.", None, None, None
        model = load_raft_model(MODEL_PATH)
        H, W = first_frame.shape[:2]

        # --- Compute dynamic square crop from mask ---
        x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(first_frame, mask_path, pad=200)
        prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1]
        # Use the real slice size. It equals (x1 - x0, y1 - y0) whenever the box
        # lies inside the frame, and keeps point bounds in sync with the flow if not.
        ch, cw = prev_crop_rgb.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)

        full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST)
        crop_mask = full_mask[y0:y1, x0:x1]
        if selection_mode == "All Pixels":
            ys, xs = np.where(crop_mask > 0)
        else:
            gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
            black_pixels = gray_first[y0:y1, x0:x1] < BLACK_THRESH
            ys, xs = np.where((crop_mask > 0) & black_pixels)
        tracked_points = np.vstack((xs, ys)).T.astype(np.float32)

        # --- Frame 0 (the annotated frame): write it with the seed points ---
        # The stop history is not updated here, so the painting rule for later
        # frames is unchanged.
        vis_full = prev_full_rgb.copy()
        mask_full = np.full((H, W), 255, dtype=np.uint8)
        _paint_points(vis_full, mask_full, tracked_points, x0, y0)
        out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR))
        out_mask.write(mask_full)

        history = deque([True] * HISTORY_LEN, maxlen=HISTORY_LEN)
        stopped = False
        frame_idx = 0

        # === Main tracking loop ===
        while True:
            ret, curr_frame = cap.read()
            if not ret:
                break
            frame_idx += 1
            curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1]
            gray_crop = cv2.cvtColor(curr_crop_rgb, cv2.COLOR_RGB2GRAY)

            # --- Optical flow between prev and curr, then move the points ---
            flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb)
            tracked_points = _advect_points(tracked_points, flow_crop, cw, ch)

            # --- Detect black pixels under the moved points ---
            has_black = _points_hit_black(tracked_points, gray_crop < BLACK_THRESH)
            history.append(has_black)

            # --- Painting logic ---
            if stopped:
                paint = False
            elif has_black:
                paint = True
            elif not any(history):  # last HISTORY_LEN frames all without black
                stopped = True
                paint = False
            else:
                paint = True

            # --- Paint or skip ---
            vis_full = curr_full_rgb.copy()
            mask_full = np.full((H, W), 255, dtype=np.uint8)
            if paint:
                _paint_points(vis_full, mask_full, tracked_points, x0, y0)
            out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR))
            out_mask.write(mask_full)

            prev_crop_rgb = curr_crop_rgb
            if frame_idx % 10 == 0:
                print(f"Frame {frame_idx}: {'PAINT' if paint else 'NO-PAINT'} | has_black={has_black} | stopped={stopped}")
    finally:
        cap.release()
        if out_vis is not None:
            out_vis.release()
        if out_mask is not None:
            out_mask.release()

    # === Post-process: stabilization + reversal ===
    stabilize_black_regions(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(OUTPUT_VIDEO)
    reverse_video_file_inplace(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(STABILIZED_MASK)

    # === Verify output frame counts ===
    for path in [OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK]:
        cap_test = cv2.VideoCapture(path)
        n = int(cap_test.get(cv2.CAP_PROP_FRAME_COUNT))
        cap_test.release()
        print(f"βœ… Verified {os.path.basename(path)} β†’ {n} frames")

    return (
        f"βœ… Tracking complete ({selection_mode}).\n"
        f"Square Crop {cw}x{ch} @ ({x0},{y0}) with padding=200\n"
        f"Painting stopped={'Yes' if stopped else 'No'} after {frame_idx} processed frames.\n"
        f"All outputs now match input frame count ({total_frames}).",
        OUTPUT_VIDEO,
        OUTPUT_MASK_VIDEO,
        STABILIZED_MASK
    )
# ==========================================================
# === GRADIO APP ===========================================
# ==========================================================
def build_app():
    """Build (but do not launch) the Gradio Blocks UI for the tracker.

    Workflow:
    1. Upload a video.
    2. Load frame 0 of the reversed video (the last frame of the original) for painting.
    3. Save the mask to ``user_mask.png``.
    4. Optionally preview the crop.
    5. Run tracking.

    Handlers return a message instead of raising when no video is uploaded.
    """

    def _video_path(v):
        # The Video component may deliver a path string or an object with a .name path.
        return v.name if hasattr(v, "name") else v

    with gr.Blocks() as demo:
        gr.Markdown("# 🎯 Pixel Tracker (Dynamic Square Crop)")
        with gr.Row():
            video_in = gr.Video(label="🎞️ Upload Video")
            frame_num = gr.Number(value=0, visible=False)
        load_btn = gr.Button("πŸ“Έ Load Frame for Annotation")
        annot = gr.Image(label="πŸ–ŒοΈ Paint ROI Mask", tool="sketch", type="numpy", image_mode="RGBA", height=1000)
        save_btn = gr.Button("πŸ’Ύ Save Mask")
        log = gr.Textbox(label="Logs", lines=8)
        preview_btn = gr.Button("πŸ”Ž Preview Crop", visible=True)
        with gr.Row():
            preview_frame = gr.Image(label="Preview Frame", visible=False)
            preview_crop = gr.Image(label="Cropped Region", visible=True)
        run_btn = gr.Button("πŸš€ Run Tracking")
        with gr.Row():
            result_video = gr.Video(label="🎬 Result (Forward)")
            mask_video = gr.Video(label="⬛ Mask (Forward)")
            stabilized_video = gr.Video(label="🧱 Stabilized (Forward)")

        def load_reversed_frame(v, f):
            # run_tracking starts from frame 0 of the reversed video, so that is
            # the frame offered for painting (f is 0 via the hidden Number).
            if v is None:
                return None
            reversed_path = reverse_video(_video_path(v), REVERSED_INPUT)
            return extract_frame(reversed_path, int(f))

        load_btn.click(load_reversed_frame, [video_in, frame_num], annot)
        save_btn.click(save_mask, annot, [gr.State(), log])

        def preview_crop_fn(v):
            if v is None:
                return None, None, "Upload a video first."
            # Check the cheap precondition before reversing the whole video.
            if not os.path.exists("user_mask.png"):
                return None, None, "⚠️ Paint and Save Mask first."
            reversed_path = reverse_video(_video_path(v), REVERSED_INPUT)
            frame0 = extract_frame(reversed_path, 0)
            if frame0 is None:
                return None, None, "Could not read frame 0 of the video."
            x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(
                cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", pad=200)
            frame_box = draw_crop_preview_on_frame(frame0, (x0, y0, x1, y1))
            return frame_box, frame0[y0:y1, x0:x1], f"Square crop {x1-x0}x{y1-y0} at ({x0},{y0})"

        preview_btn.click(preview_crop_fn, video_in, [preview_frame, preview_crop, log])

        def run_btn_fn(v, m):
            if v is None:
                return "Upload a video first.", None, None, None
            if not os.path.exists("user_mask.png"):
                return "⚠️ Save Mask first.", None, None, None
            return run_tracking(_video_path(v), "user_mask.png", m)

        # Created at the same point as before (inline in the click call), so the layout is unchanged.
        selection_mode = gr.Dropdown(["All Pixels", "Only Black Pixels"], value="All Pixels")
        run_btn.click(run_btn_fn, [video_in, selection_mode],
                      [log, result_video, mask_video, stabilized_video])
    return demo
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (e.g. so the app is reachable from outside a container).
    build_app().launch(server_name="0.0.0.0", server_port=7860, debug=True)