|
|
import os |
|
|
import cv2 |
|
|
import math |
|
|
import torch |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from pathlib import Path |
|
|
from torchvision import transforms |
|
|
import gradio as gr |
|
|
from argparse import Namespace |
|
|
import sys |
|
|
|
|
|
|
|
|
sys.path.append("/app/preprocess/RAFT/core") |
|
|
from raft import RAFT |
|
|
from utils.utils import InputPadder |
|
|
|
|
|
|
|
|
# Run RAFT on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# RAFT checkpoint loaded by load_raft_model().
MODEL_PATH = "/data/pretrained/RAFT/raft-things.pth"

# Artifacts produced by the tracking pipeline.
OUTPUT_VIDEO = "/app/tracked_roi_output.mp4"  # frames with tracked points drawn in green
OUTPUT_MASK_VIDEO = "/app/final_tracked_mask_video.mp4"  # raw mask: black = tracked pixel
STABILIZED_MASK = "/app/stabilized_mask_video.mp4"  # cleaned-up copy of the raw mask
|
|
|
|
|
def load_raft_model(model_path):
    """Construct RAFT, load the checkpoint at ``model_path`` and return it for inference.

    The network is wrapped in ``torch.nn.DataParallel`` only while the weights
    are loaded (presumably because the checkpoint's keys carry DataParallel's
    ``module.`` prefix — confirm against the checkpoint), then unwrapped.

    Args:
        model_path: path to a RAFT state-dict checkpoint.

    Returns:
        The bare RAFT module, moved to ``DEVICE`` and switched to eval mode.
    """
    raft_options = {
        "small": False,
        "mixed_precision": False,
        "alternate_corr": False,
        "dropout": 0.0,
        "max_depth": 16,
        "depth_network": False,
        "depth_residual": False,
        "depth_scale": 1.0,
    }
    wrapper = torch.nn.DataParallel(RAFT(Namespace(**raft_options)))

    checkpoint = torch.load(model_path, map_location=DEVICE)
    wrapper.load_state_dict(checkpoint)

    network = wrapper.module
    network.to(DEVICE)
    network.eval()
    return network
|
|
|
|
|
|
|
|
|
|
|
def to_tensor(image):
    """Convert an HxWxC image into a 1xCxHxW float tensor on ``DEVICE``.

    Relies on torchvision's ``ToTensor``, which rescales uint8 input to [0, 1].
    """
    chw = transforms.ToTensor()(image)
    batched = chw[None]  # add the leading batch dimension
    return batched.to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def compute_flow(model, img1, img2, iters=30):
    """Estimate dense optical flow from ``img1`` to ``img2`` with RAFT.

    Args:
        model: RAFT network in eval mode (as returned by ``load_raft_model``).
        img1, img2: RGB frames of identical size; callers in this file pass
            uint8 arrays converted from OpenCV BGR frames.
        iters: number of RAFT refinement iterations (default 30).

    Returns:
        numpy array of shape (H, W, 2); ``flow[y, x]`` is the (dx, dy)
        displacement of pixel (x, y) from ``img1`` to ``img2``.
    """
    # ToTensor rescales uint8 to [0, 1], but RAFT's forward() normalises its
    # inputs with 2 * (x / 255) - 1, i.e. it expects the raw 0-255 range.
    # Without this rescale both frames collapse to ~-1 and the flow is unreliable.
    t1 = to_tensor(img1) * 255.0
    t2 = to_tensor(img2) * 255.0

    # RAFT needs spatial dims divisible by 8; pad now, crop the result back.
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)

    _, flow_up = model(t1, t2, iters=iters, test_mode=True)
    flow = padder.unpad(flow_up)[0]  # drop batch dim -> (2, H, W)
    return flow.permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
def extract_frame(video_path, frame_number):
    """Read a single frame from a video and return it as an RGB array.

    Args:
        video_path: path to the video file. A falsy value (e.g. nothing
            uploaded yet) yields None instead of raising.
        frame_number: zero-based index of the frame to read; negative values
            are clamped to 0.

    Returns:
        The frame as an RGB array, or None when the video cannot be opened or
        the requested frame cannot be read.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(int(frame_number), 0))
        ret, frame = cap.read()
    finally:
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
|
|
def save_mask(data):
    """Binarise the ROI painted in the annotation widget and save it as a PNG.

    Args:
        data: value of the sketch ``gr.Image`` — either a dict holding the
            painted layer under ``"mask"`` or a bare numpy array.

    Returns:
        ``(mask_path, status_message)``. ``mask_path`` is None when nothing
        could be saved.
    """
    if data is None:
        return None, "⚠️ No mask data received!"

    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "⚠️ Mask missing!"

    mask = np.asarray(mask)
    if mask.ndim == 3:
        # The channel count follows the widget's image_mode; do not assume
        # RGBA (COLOR_RGBA2GRAY raises on a 3-channel array).
        channels = mask.shape[2]
        if channels == 4:
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        else:
            mask_gray = np.ascontiguousarray(mask[..., 0])
    else:
        mask_gray = mask

    # Any value above 1 counts as painted.
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)

    mask_path = "user_mask.png"
    if not cv2.imwrite(mask_path, bin_mask):
        return None, f"❌ Failed to write mask to {mask_path}"

    painted = np.count_nonzero(bin_mask)
    if painted == 0:
        return mask_path, "⚠️ Saved mask is empty (0 painted pixels); paint the ROI before tracking."
    return mask_path, f"✅ Saved mask ({painted} painted pixels)"
|
|
|
|
|
|
|
|
|
|
|
def _stabilize_mask_frame(frame, kernel_edge):
    """Return the cleaned single-channel mask (black = region) for one BGR mask frame."""
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    # Re-binarise: video compression leaves intermediate grey values.
    _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

    # Work on the inverse so the black regions become foreground (255).
    inv = cv2.bitwise_not(mask)

    # Hole filling: flood the background from outside the image. The 1-px zero
    # border guarantees the seed is background and links every border-touching
    # background area. Seeding at (0, 0) of the raw frame failed when that pixel
    # was black: the flood was a no-op, every background pixel counted as a
    # "hole", and the whole frame turned black.
    padded = cv2.copyMakeBorder(inv, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
    flood_mask = np.zeros((padded.shape[0] + 2, padded.shape[1] + 2), np.uint8)
    cv2.floodFill(padded, flood_mask, (0, 0), 255)
    holes = cv2.bitwise_not(padded[1:-1, 1:-1])  # background not reachable from the border
    filled_inv = cv2.bitwise_or(inv, holes)

    # Closing bridges small gaps between nearby black pixels.
    closed = cv2.morphologyEx(filled_inv, cv2.MORPH_CLOSE, kernel_edge, iterations=1)
    # A 3x3 median removes isolated speckles.
    denoised = cv2.medianBlur(closed, 3)
    return cv2.bitwise_not(denoised)


def stabilize_black_regions(input_video: str, output_video=None) -> str:
    """Clean up the black regions of a binary mask video, frame by frame.

    For every frame the black regions are processed in three steps:
      1. Hole-filled: white areas not connected to the frame border become black.
      2. Closed with a 5x5 elliptical kernel.
      3. Median-filtered (3x3).

    Frames are processed independently, with no temporal smoothing. Closing and
    median filtering can move boundaries by a pixel or two.

    Args:
        input_video (str): path to the mask video (black & white pixels).
        output_video: destination path; defaults to ``STABILIZED_MASK``.

    Returns:
        The path of the written video.

    Raises:
        FileNotFoundError: if ``input_video`` cannot be opened.
        OSError: if the output video writer cannot be opened.
    """
    output_path = output_video or STABILIZED_MASK
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"❌ Could not open video: {input_video}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            if not out.isOpened():
                raise OSError(f"❌ Could not open video writer: {output_path}")

            print(f"🎞 Processing {int(cap.get(cv2.CAP_PROP_FRAME_COUNT))} frames "
                  f"({width}x{height} @ {fps:.1f} fps)...")

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                result = _stabilize_mask_frame(frame, kernel_edge)
                out.write(cv2.cvtColor(result, cv2.COLOR_GRAY2BGR))
        finally:
            out.release()
    finally:
        cap.release()

    print(f"✅ Saved stabilized mask video: {output_path}")
    return output_path
|
|
|
|
|
|
|
|
|
|
|
def _select_seed_points(mask, first_frame_bgr, selection_mode):
    """Return an (N, 2) float32 array of (x, y) seed pixels taken from the ROI mask."""
    selected = mask > 0
    if selection_mode != "All Pixels":
        # Any other mode ("Only Black Pixels") keeps only dark masked pixels.
        gray = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2GRAY)
        selected &= gray < 30
    ys, xs = np.nonzero(selected)
    return np.column_stack((xs, ys)).astype(np.float32)


def _advect_points(points, flow):
    """Move (x, y) points by the flow at their nearest pixel; drop points already off-frame."""
    height, width = flow.shape[:2]
    px = np.rint(points[:, 0]).astype(np.int64)
    py = np.rint(points[:, 1]).astype(np.int64)
    inside = (px >= 0) & (px < width) & (py >= 0) & (py < height)
    # flow[y, x] holds (dx, dy), which adds directly onto an (x, y) row.
    return points[inside] + flow[py[inside], px[inside]]


def _render_points(points, frame_rgb, dot_kernel):
    """Return (BGR visualisation frame, uint8 mask frame) for the current points."""
    height, width = frame_rgb.shape[:2]
    rx = np.rint(points[:, 0]).astype(np.int64)
    ry = np.rint(points[:, 1]).astype(np.int64)
    # Advected points may lie outside the frame. Indexing with them raised
    # IndexError (too large) or wrapped to the opposite edge (negative).
    visible = (rx >= 0) & (rx < width) & (ry >= 0) & (ry < height)

    mask_frame = np.full((height, width), 255, dtype=np.uint8)
    mask_frame[ry[visible], rx[visible]] = 0  # black = tracked pixel

    # Green dots: every tracked pixel dilated with a 3x3 cross.
    dots = cv2.dilate(cv2.bitwise_not(mask_frame), dot_kernel) > 0
    vis = frame_rgb.copy()
    vis[dots] = (0, 255, 0)
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR), mask_frame


def run_tracking(video_path, mask_path, selection_mode="All Pixels", start_frame=0):
    """Track the painted ROI pixels through a video with RAFT optical flow.

    Seed pixels are the nonzero pixels of the mask at ``mask_path``. With any
    mode other than "All Pixels", they are further restricted to dark pixels
    (gray < 30) of the seed frame.

    On every following frame, each point moves by the RAFT flow (previous ->
    current frame) sampled at its nearest pixel. Points whose position has left
    the frame are dropped on the next step.

    Outputs:
      - ``OUTPUT_VIDEO``: frames with tracked points drawn in green.
      - ``OUTPUT_MASK_VIDEO``: white frames with tracked pixels in black.
      - A cleaned copy of the mask video, from ``stabilize_black_regions``.

    Output frame k corresponds to input frame ``start_frame + k``; the seed
    frame itself is written first.

    Args:
        video_path: path of the input video.
        mask_path: path of the saved ROI mask image (nonzero = selected).
        selection_mode: "All Pixels" or "Only Black Pixels".
        start_frame: index of the frame the mask was painted on (default 0).

    Returns:
        ``(log_message, visualisation_video_path, stabilized_mask_video_path)``.
        Both paths are None on failure.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return f"❌ Could not read mask: {mask_path}. Save a mask first.", None, None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return "❌ Failed to open video.", None, None

        start = max(int(start_frame or 0), 0)
        if start:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        # CAP_PROP_FPS may be reported as 0; give the writers a usable rate.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        ret, first_frame = cap.read()
        if not ret:
            return "❌ Could not read first frame.", None, None

        H, W = first_frame.shape[:2]
        # Nearest-neighbour keeps the resized ROI strictly binary.
        mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

        tracked_points = _select_seed_points(mask, first_frame, selection_mode)
        print(f"🎯 Selected {len(tracked_points)} pixels under mode: {selection_mode}")
        if len(tracked_points) == 0:
            return f"⚠️ No pixels selected under mode: {selection_mode}.", None, None

        model = load_raft_model(MODEL_PATH)
        dot_kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
        out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)
        try:
            if not (out_vis.isOpened() and out_mask.isOpened()):
                return "❌ Could not open output video writers.", None, None

            prev_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

            # Write the seed frame too, so the outputs stay frame-aligned with the input.
            vis_bgr, mask_frame = _render_points(tracked_points, prev_rgb, dot_kernel)
            out_vis.write(vis_bgr)
            out_mask.write(mask_frame)

            while True:
                ret, curr_frame = cap.read()
                if not ret:
                    break
                curr_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
                flow = compute_flow(model, prev_rgb, curr_rgb)

                tracked_points = _advect_points(tracked_points, flow)
                vis_bgr, mask_frame = _render_points(tracked_points, curr_rgb, dot_kernel)
                out_vis.write(vis_bgr)
                out_mask.write(mask_frame)

                prev_rgb = curr_rgb
        finally:
            out_vis.release()
            out_mask.release()
    finally:
        cap.release()

    stabilized_mask = stabilize_black_regions(OUTPUT_MASK_VIDEO)

    return (
        f"✅ Tracking complete ({selection_mode}).\nSaved:\n- {OUTPUT_VIDEO}\n- {OUTPUT_MASK_VIDEO}\n- {stabilized_mask}",
        OUTPUT_VIDEO,
        stabilized_mask,
    )
|
|
|
|
|
|
|
|
|
|
|
def build_app():
    """Build the Gradio UI for painting an ROI and tracking it with RAFT.

    The wiring below gives this workflow:
      1. Load a frame into the sketch widget.
      2. Paint the ROI.
      3. "Save Mask" writes ``user_mask.png`` via ``save_mask``.
      4. "Run RAFT Tracking" calls ``run_tracking`` on that file.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """

    def _video_path(video):
        # The Video component may hand back a plain path or a tempfile-like
        # object exposing ``.name``; None means nothing was uploaded.
        if video is None:
            return None
        return video.name if hasattr(video, "name") else video

    def _load_frame(video, frame):
        path = _video_path(video)
        if not path:
            return None
        # gr.Number can be cleared to None; treat that as frame 0.
        return extract_frame(path, int(frame or 0))

    def _run(video, mode):
        path = _video_path(video)
        if not path:
            return "❌ Please upload a video first.", None, None
        # NOTE(review): run_tracking is called without a start frame, so
        # tracking seeds from the video's first frame regardless of
        # "Frame # to Paint". A mask painted on a later frame will be
        # misaligned. TODO: wire frame_num through.
        return run_tracking(path, "user_mask.png", mode)

    with gr.Blocks() as demo:
        gr.Markdown("## 🎯 RAFT Pixel Tracker with Brush-based ROI, Pixel Mode, and Mask Output (Inverted Mask)")

        with gr.Row():
            video_in = gr.Video(label="🎞️ Upload Video")
            frame_num = gr.Number(label="Frame # to Paint", value=0, precision=0)

        load_btn = gr.Button("📸 Load Frame for Annotation")
        # NOTE: tool="sketch" is the Gradio 3.x API (removed in Gradio 4).
        annot = gr.Image(
            label="🖌️ Paint ROI Mask",
            tool="sketch",
            type="numpy",
            image_mode="RGBA",
            height=480,
        )

        pixel_mode = gr.Dropdown(
            choices=["All Pixels", "Only Black Pixels"],
            value="All Pixels",
            label="Pixel Selection Mode"
        )

        save_btn = gr.Button("💾 Save Mask")
        run_btn = gr.Button("🚀 Run RAFT Tracking")
        log = gr.Textbox(label="Logs", lines=6)
        result_video = gr.Video(label="🎬 Visualization Video")
        mask_video = gr.Video(label="⬛ Tracked Mask Video (black = tracked pixels)")

        load_btn.click(
            fn=_load_frame,
            inputs=[video_in, frame_num],
            outputs=annot
        )

        # The saved path goes into a throwaway State; _run reads the fixed
        # "user_mask.png" that save_mask writes.
        save_btn.click(
            fn=save_mask,
            inputs=annot,
            outputs=[gr.State(), log]
        )

        run_btn.click(
            fn=_run,
            inputs=[video_in, pixel_mode],
            outputs=[log, result_video, mask_video]
        )

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 so the UI is reachable from other hosts.
    build_app().launch(server_name="0.0.0.0", server_port=7860, debug=True)
|
|
|