| | import os |
| | import cv2 |
| | import math |
| | import torch |
| | import numpy as np |
| | from tqdm import tqdm |
| | from pathlib import Path |
| | from torchvision import transforms |
| | import gradio as gr |
| | from argparse import Namespace |
| | import sys |
| |
|
| | |
| | sys.path.append("/app/preprocess/RAFT/core") |
| | from raft import RAFT |
| | from utils.utils import InputPadder |
| |
|
| | |
# Run inference on GPU when available; fall back to CPU otherwise.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pretrained RAFT checkpoint (filename indicates the "raft-things" weights).
MODEL_PATH = "/data/pretrained/RAFT/raft-things.pth"
# Output artifacts written by run_tracking / stabilize_black_regions.
OUTPUT_VIDEO = "/app/tracked_roi_output.mp4"
OUTPUT_MASK_VIDEO = "/app/final_tracked_mask_video.mp4"
STABILIZED_MASK = "/app/stabilized_mask_video.mp4"
| | |
def load_raft_model(model_path):
    """Load a pretrained RAFT model onto DEVICE and return it in eval mode.

    The checkpoint was saved from a DataParallel-wrapped model (its state
    dict keys carry a ``module.`` prefix), so RAFT is wrapped in
    DataParallel before loading and unwrapped via ``.module`` afterwards.
    """
    raft_args = Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=16,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0,
    )
    wrapped = torch.nn.DataParallel(RAFT(raft_args))
    state = torch.load(model_path, map_location=DEVICE)
    wrapped.load_state_dict(state)
    return wrapped.module.to(DEVICE).eval()
| |
|
| |
|
| | |
def to_tensor(image):
    """Convert an HxWxC image array to a 1xCxHxW float tensor on DEVICE.

    torchvision's ToTensor also rescales uint8 input from [0, 255]
    to [0.0, 1.0].
    """
    tensor = transforms.ToTensor()(image)
    return tensor.unsqueeze(0).to(DEVICE)
| |
|
| |
|
| | |
@torch.no_grad()
def compute_flow(model, img1, img2):
    """Run RAFT on a pair of RGB frames; return dense flow as an HxWx2 numpy array."""
    first, second = to_tensor(img1), to_tensor(img2)
    # RAFT needs dimensions divisible by 8; pad, infer, then unpad.
    padder = InputPadder(first.shape)
    first, second = padder.pad(first, second)
    _, flow_pred = model(first, second, iters=30, test_mode=True)
    unpadded = padder.unpad(flow_pred)[0]
    return unpadded.permute(1, 2, 0).cpu().numpy()
| |
|
| |
|
| | |
def extract_frame(video_path, frame_number):
    """Grab a single frame from a video and return it as RGB, or None on failure."""
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ok, bgr = capture.read()
    capture.release()
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
| |
|
| |
|
| | |
def save_mask(data):
    """Persist the painted ROI mask from the Gradio sketch tool to user_mask.png.

    Args:
        data: Either the raw mask array or a dict carrying a "mask" key
            (the sketch tool returns a dict).

    Returns:
        (mask_path, message) tuple; mask_path is None when no usable mask
        was provided.
    """
    if data is None:
        return None, "⚠️ No mask data received!"
    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "⚠️ Mask missing!"
    # Collapse RGBA paint strokes to grayscale before thresholding.
    if mask.ndim == 3:
        mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
    else:
        mask_gray = mask
    # Any painted (non-zero) pixel becomes 255 in the binary mask.
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
    mask_path = "user_mask.png"
    cv2.imwrite(mask_path, bin_mask)
    return mask_path, f"✅ Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"
| |
|
| |
|
| | |
def stabilize_black_regions(input_video: str) -> str:
    """
    Stabilize black region boundaries in a binary mask video.

    Fills small flickers (enclosed holes) and smooths edges without
    deforming boundaries, writing the result to STABILIZED_MASK.

    Args:
        input_video (str): path to input video (mask video with black & white pixels)

    Returns:
        str: path of the stabilized mask video (STABILIZED_MASK).

    Raises:
        FileNotFoundError: if the input video cannot be opened.
    """
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"❌ Could not open video: {input_video}")

    # Some containers report 0 fps; fall back so the writer stays valid.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(STABILIZED_MASK, fourcc, fps, (width, height))

    print(f"🔄 Processing {int(cap.get(cv2.CAP_PROP_FRAME_COUNT))} frames "
          f"({width}x{height} @ {fps:.1f} fps)...")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

        # Invert so the tracked (black) region becomes the foreground.
        inv = cv2.bitwise_not(mask)

        # Flood-fill from the top-left corner to find enclosed holes,
        # then OR them back in to fill them.
        # NOTE(review): assumes (0, 0) belongs to the background — confirm
        # the ROI never touches that corner.
        flood = inv.copy()
        h, w = inv.shape
        flood_mask = np.zeros((h + 2, w + 2), np.uint8)
        cv2.floodFill(flood, flood_mask, (0, 0), 255)
        holes = cv2.bitwise_not(flood)
        filled_inv = cv2.bitwise_or(inv, holes)

        # Close small gaps along edges, then remove salt-and-pepper noise.
        closed = cv2.morphologyEx(filled_inv, cv2.MORPH_CLOSE, kernel_edge, iterations=1)
        denoised = cv2.medianBlur(closed, 3)

        # Invert back so tracked pixels are black again.
        result = cv2.bitwise_not(denoised)
        out.write(cv2.cvtColor(result, cv2.COLOR_GRAY2BGR))

    cap.release()
    out.release()
    print(f"✅ Saved stabilized mask video: {STABILIZED_MASK}")
    return STABILIZED_MASK
| |
|
| |
|
| | |
def run_tracking(video_path, mask_path, selection_mode="All Pixels"):
    """Track every masked pixel through a video using RAFT optical flow.

    Writes a visualization video (green dots on tracked points) and an
    inverted mask video (tracked pixels drawn black on white), then runs
    stabilize_black_regions on the mask video.

    Args:
        video_path: source video to track through.
        mask_path: path to a grayscale ROI mask image (non-zero = selected).
        selection_mode: "All Pixels" tracks every masked pixel;
            "Only Black Pixels" restricts to masked pixels that are dark
            (gray < 30) in the first frame.

    Returns:
        (status_message, visualization_path, stabilized_mask_path);
        the paths are None on failure.
    """
    model = load_raft_model(MODEL_PATH)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return "❌ Failed to open video.", None, None

    # Some containers report 0 fps; fall back so the writers stay valid.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    ret, first_frame = cap.read()
    if not ret:
        cap.release()
        return "❌ Could not read first frame.", None, None

    H, W = first_frame.shape[:2]
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        cap.release()
        return "❌ Could not read mask image.", None, None
    # Nearest-neighbor keeps the mask strictly binary when rescaling.
    mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

    # Choose the seed pixels to track.
    if selection_mode == "All Pixels":
        ys, xs = np.where(mask > 0)
    else:
        gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
        combined = (mask > 0) & (gray < 30)
        ys, xs = np.where(combined)

    tracked_points = np.vstack((xs, ys)).T.astype(np.float32)
    print(f"🎯 Selected {len(tracked_points)} pixels under mode: {selection_mode}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
    out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)

    prev_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    frame_idx = 1

    while True:
        ret, curr_frame = cap.read()
        if not ret:
            break
        curr_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
        flow = compute_flow(model, prev_frame, curr_rgb)

        vis = curr_rgb.copy()
        # White canvas; tracked pixels are painted black (inverted mask).
        mask_frame = np.full((H, W), 255, dtype=np.uint8)

        new_points = []
        for pt in tracked_points:
            px, py = int(round(pt[0])), int(round(pt[1]))
            if 0 <= px < W and 0 <= py < H:
                dx, dy = flow[py, px]
                nx, ny = pt[0] + dx, pt[1] + dy
                new_points.append([nx, ny])
                ix, iy = int(nx), int(ny)
                # Flow can push a point out of frame: writing mask_frame at
                # an out-of-range index would raise IndexError (or wrap on
                # negatives), so only draw when the new position is valid.
                if 0 <= ix < W and 0 <= iy < H:
                    cv2.circle(vis, (ix, iy), 1, (0, 255, 0), -1)
                    mask_frame[iy, ix] = 0

        tracked_points = np.array(new_points, dtype=np.float32)
        out_vis.write(cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
        out_mask.write(mask_frame)

        prev_frame = curr_rgb
        frame_idx += 1

    cap.release()
    out_vis.release()
    out_mask.release()
    stabilized_mask = stabilize_black_regions(OUTPUT_MASK_VIDEO)

    return (
        f"✅ Tracking complete ({selection_mode}).\nSaved:\n- {OUTPUT_VIDEO}\n- {OUTPUT_MASK_VIDEO}",
        OUTPUT_VIDEO,
        stabilized_mask,
    )
| |
|
| |
|
| | |
def build_app():
    """Assemble the Gradio UI: upload a video, paint an ROI, run RAFT tracking."""
    with gr.Blocks() as demo:
        gr.Markdown("## 🎯 RAFT Pixel Tracker with Brush-based ROI, Pixel Mode, and Mask Output (Inverted Mask)")

        with gr.Row():
            video_in = gr.Video(label="🎞️ Upload Video")
            frame_num = gr.Number(label="Frame # to Paint", value=0, precision=0)

        load_btn = gr.Button("📸 Load Frame for Annotation")
        annot = gr.Image(
            label="🖌️ Paint ROI Mask",
            tool="sketch",
            type="numpy",
            image_mode="RGBA",
            height=480,
        )

        pixel_mode = gr.Dropdown(
            choices=["All Pixels", "Only Black Pixels"],
            value="All Pixels",
            label="Pixel Selection Mode",
        )

        save_btn = gr.Button("💾 Save Mask")
        run_btn = gr.Button("🚀 Run RAFT Tracking")
        log = gr.Textbox(label="Logs", lines=6)
        result_video = gr.Video(label="🎬 Visualization Video")
        mask_video = gr.Video(label="⬛ Tracked Mask Video (black = tracked pixels)")

        # Gradio may hand back a tempfile-like object or a plain path string.
        load_btn.click(
            fn=lambda v, f: extract_frame(v.name if hasattr(v, "name") else v, int(f)),
            inputs=[video_in, frame_num],
            outputs=annot,
        )

        # save_mask's first output (the path) is parked in a throwaway State;
        # run_tracking re-reads "user_mask.png" from disk instead.
        save_btn.click(
            fn=save_mask,
            inputs=annot,
            outputs=[gr.State(), log],
        )

        run_btn.click(
            fn=lambda v, m: run_tracking(v.name if hasattr(v, "name") else v, "user_mask.png", m),
            inputs=[video_in, pixel_mode],
            outputs=[log, result_video, mask_video],
        )

    return demo
| |
|
| |
|
if __name__ == "__main__":
    # Bind to all interfaces so the UI is reachable from outside a container.
    demo = build_app()
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
| |
|