# Pixel tracker: RAFT optical-flow point tracking with a Gradio annotation UI.
# NOTE: everything below, up to the live imports, is a commented-out earlier
# revision of this file kept for reference; consider deleting the dead code.
# import os
| # import sys | |
| # import cv2 | |
| # import math | |
| # import time | |
| # import torch | |
| # import numpy as np | |
| # import gradio as gr | |
| # from tqdm import tqdm | |
| # from pathlib import Path | |
| # from collections import deque | |
| # from argparse import Namespace | |
| # from torchvision import transforms | |
| # # === RAFT Setup === | |
| # sys.path.append("/app/preprocess/RAFT/core") | |
| # from raft import RAFT | |
| # from utils.utils import InputPadder | |
| # # === CONFIG === | |
| # DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| # MODEL_PATH = "/app/RAFT/raft-things.pth" | |
| # OUTPUT_VIDEO = "/app/full_tracked_output.mp4" | |
| # OUTPUT_MASK_VIDEO = "/app/mask_output.mp4" | |
| # STABILIZED_MASK = "/app/stabilized_mask_output.mp4" | |
| # REVERSED_INPUT = "/app/reversed_input.mp4" | |
| # # ========================================================== | |
| # # === VIDEO UTILITIES ===================================== | |
| # # ========================================================== | |
| # def reverse_video(input_path, output_path): | |
| # """Reverse frames of input video and save as output.""" | |
| # cap = cv2.VideoCapture(input_path) | |
| # if not cap.isOpened(): | |
| # raise FileNotFoundError(f"β Could not open video: {input_path}") | |
| # fps = cap.get(cv2.CAP_PROP_FPS) | |
| # width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| # height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| # fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| # out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) | |
| # frames = [] | |
| # while True: | |
| # ret, frame = cap.read() | |
| # if not ret: | |
| # break | |
| # frames.append(frame) | |
| # cap.release() | |
| # # Write reversed frames | |
| # for frame in reversed(frames): | |
| # out.write(frame) | |
| # out.release() | |
| # print(f"π Video reversed and saved: {output_path}") | |
| # return output_path | |
| # def reverse_video_file_inplace(path_in): | |
| # """Reverse an existing video and overwrite it.""" | |
| # tmp_path = path_in.replace(".mp4", "_tmp.mp4") | |
| # reverse_video(path_in, tmp_path) | |
| # os.replace(tmp_path, path_in) | |
| # # ========================================================== | |
| # # === RAFT LOADING ========================================= | |
| # # ========================================================== | |
| # def load_raft_model(model_path): | |
| # args = Namespace( | |
| # small=False, | |
| # mixed_precision=False, | |
| # alternate_corr=False, | |
| # dropout=0.0, | |
| # max_depth=16, | |
| # depth_network=False, | |
| # depth_residual=False, | |
| # depth_scale=1.0 | |
| # ) | |
| # model = torch.nn.DataParallel(RAFT(args)) | |
| # model.load_state_dict(torch.load(model_path, map_location=DEVICE)) | |
| # return model.module.to(DEVICE).eval() | |
| # def to_tensor(image): | |
| # return transforms.ToTensor()(image).unsqueeze(0).to(DEVICE) | |
| # @torch.no_grad() | |
| # def compute_flow(model, img1, img2): | |
| # t1, t2 = to_tensor(img1), to_tensor(img2) | |
| # padder = InputPadder(t1.shape) | |
| # t1, t2 = padder.pad(t1, t2) | |
| # _, flow = model(t1, t2, iters=30, test_mode=True) | |
| # flow = padder.unpad(flow)[0] | |
| # return flow.permute(1, 2, 0).cpu().numpy() | |
| # # ========================================================== | |
| # # === FRAME / MASK HELPERS ================================ | |
| # # ========================================================== | |
| # def extract_frame(video_path, frame_number): | |
| # cap = cv2.VideoCapture(video_path) | |
| # cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) | |
| # ret, frame = cap.read() | |
| # cap.release() | |
| # if not ret: | |
| # return None | |
| # return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| # def save_mask(data): | |
| # if data is None: | |
| # return None, "β οΈ No mask data received!" | |
| # if isinstance(data, dict): | |
| # mask = data.get("mask") | |
| # else: | |
| # mask = data | |
| # if mask is None: | |
| # return None, "β οΈ Mask missing!" | |
| # if mask.ndim == 3: | |
| # mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY) | |
| # else: | |
| # mask_gray = mask | |
| # _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY) | |
| # mask_path = "user_mask.png" | |
| # cv2.imwrite(mask_path, bin_mask) | |
| # return mask_path, f"β Saved mask ({np.count_nonzero(bin_mask)} painted pixels)" | |
| # # ========================================================== | |
| # # === CROP HELPERS ========================================= | |
| # # ========================================================== | |
| # def get_mask_center(mask_path): | |
| # mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) | |
| # if mask is None: | |
| # raise FileNotFoundError("Mask not found: " + mask_path) | |
| # ys, xs = np.where(mask > 0) | |
| # h, w = mask.shape[:2] | |
| # if len(xs) == 0: | |
| # return w // 2, h // 2 | |
| # return int(np.mean(xs)), int(np.mean(ys)) | |
| # def clamp_crop(x0, y0, cw, ch, W, H): | |
| # x0 = max(0, min(x0, W - 1)) | |
| # y0 = max(0, min(y0, H - 1)) | |
| # x1 = x0 + cw | |
| # y1 = y0 + ch | |
| # if x1 > W: | |
| # x0 -= (x1 - W) | |
| # x1 = W | |
| # if y1 > H: | |
| # y0 -= (y1 - H) | |
| # y1 = H | |
| # return x0, y0, x1, y1 | |
| # def compute_crop_box_from_mask(first_frame_bgr, mask_path, crop_w=400, crop_h=400): | |
| # H, W = first_frame_bgr.shape[:2] | |
| # cx, cy = get_mask_center(mask_path) | |
| # x0 = cx - crop_w // 2 | |
| # y0 = cy - crop_h // 2 | |
| # return clamp_crop(x0, y0, crop_w, crop_h, W, H) | |
| # def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0,255,0), thickness=2): | |
| # x0, y0, x1, y1 = crop_box | |
| # frame = frame_rgb.copy() | |
| # cv2.rectangle(frame, (x0, y0), (x1, y1), color, thickness) | |
| # return frame | |
| # # ========================================================== | |
| # # === STABILIZATION ======================================== | |
| # # ========================================================== | |
| # def stabilize_black_regions(input_video): | |
| # # === Define kernels === | |
| # kernel_fill = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) | |
| # kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) | |
| # cap = cv2.VideoCapture(input_video) | |
| # if not cap.isOpened(): | |
| # raise FileNotFoundError(f"β Could not open video: {input_video}") | |
| # fps = cap.get(cv2.CAP_PROP_FPS) | |
| # width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| # height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| # fourcc = cv2.VideoWriter_fourcc(*"mp4v") | |
| # out = cv2.VideoWriter(STABILIZED_MASK, fourcc, fps, (width, height)) | |
| # while True: | |
| # ret, frame = cap.read() | |
| # if not ret: | |
| # break | |
| # # Convert to grayscale and threshold | |
| # gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | |
| # _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY) | |
| # # === Step 1: Fill black regions === | |
| # inv = cv2.bitwise_not(mask) | |
| # flood = inv.copy() | |
| # h, w = inv.shape | |
| # flood_mask = np.zeros((h + 2, w + 2), np.uint8) | |
| # cv2.floodFill(flood, flood_mask, (0, 0), 255) | |
| # holes = cv2.bitwise_not(flood) | |
| # filled = cv2.bitwise_or(inv, holes) | |
| # filled = cv2.bitwise_not(filled) | |
| # # === Step 2: Morphological stabilization === | |
| # # Fill small black holes and unify mask | |
| # stable = cv2.morphologyEx(filled, cv2.MORPH_CLOSE, kernel_fill, iterations=1) | |
| # # Smooth jagged edges | |
| # stable = cv2.morphologyEx(stable, cv2.MORPH_OPEN, kernel_edge, iterations=1) | |
| # # Write result | |
| # out.write(cv2.cvtColor(stable, cv2.COLOR_GRAY2BGR)) | |
| # cap.release() | |
| # out.release() | |
| # print(f"β Stabilized mask saved: {STABILIZED_MASK}") | |
| # return STABILIZED_MASK | |
| # # ========================================================== | |
| # # === TRACKING ============================================= | |
| # # ========================================================== | |
| # def run_tracking(video_path, mask_path, selection_mode="All Pixels", crop_w=400, crop_h=400): | |
| # BLACK_THRESH = 1 | |
| # HISTORY_LEN = 5 | |
| # reversed_path = reverse_video(video_path, REVERSED_INPUT) | |
| # cap = cv2.VideoCapture(reversed_path) | |
| # model = load_raft_model(MODEL_PATH) | |
| # fps = cap.get(cv2.CAP_PROP_FPS) | |
| # ret, first_frame = cap.read() | |
| # if not ret: | |
| # return "β Could not read first frame.", None, None, None | |
| # H, W = first_frame.shape[:2] | |
| # x0, y0, x1, y1 = compute_crop_box_from_mask(first_frame, mask_path, crop_w, crop_h) | |
| # cw, ch = x1 - x0, y1 - y0 | |
| # fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| # out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H)) | |
| # out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False) | |
| # full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) | |
| # full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST) | |
| # crop_mask = full_mask[y0:y1, x0:x1] | |
| # if selection_mode == "All Pixels": | |
| # ys, xs = np.where(crop_mask > 0) | |
| # else: | |
| # gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY) | |
| # black_pixels = (gray_first[y0:y1, x0:x1] < BLACK_THRESH) | |
| # combined = (crop_mask > 0) & black_pixels | |
| # ys, xs = np.where(combined) | |
| # tracked_points = np.vstack((xs, ys)).T.astype(np.float32) | |
| # prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) | |
| # prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1] | |
| # # Frame-level history (deque of last 5 black-region detections) | |
| # history = deque([True]*HISTORY_LEN, maxlen=HISTORY_LEN) | |
| # stopped = False | |
| # frame_idx = 0 | |
| # while True: | |
| # ret, curr_frame = cap.read() | |
| # if not ret: | |
| # break | |
| # frame_idx += 1 | |
| # curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB) | |
| # curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1] | |
| # gray_crop = cv2.cvtColor(curr_crop_rgb, cv2.COLOR_RGB2GRAY) | |
| # # --- Compute optical flow --- | |
| # flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb) | |
| # vis_full = curr_full_rgb.copy() | |
| # mask_full = np.full((H, W), 255, dtype=np.uint8) | |
| # # --- Move points --- | |
| # new_points = [] | |
| # for pt in tracked_points: | |
| # px, py = int(pt[0]), int(pt[1]) | |
| # if 0 <= px < cw and 0 <= py < ch: | |
| # dx, dy = flow_crop[py, px] | |
| # nx, ny = pt[0] + dx, pt[1] + dy | |
| # nx = np.clip(nx, 0, cw-1) | |
| # ny = np.clip(ny, 0, ch-1) | |
| # new_points.append([nx, ny]) | |
| # tracked_points = np.array(new_points, dtype=np.float32) | |
| # # --- Check black presence --- | |
| # black_mask = gray_crop < BLACK_THRESH | |
| # black_indices = tracked_points.astype(int) | |
| # has_black = False | |
| # for (px, py) in black_indices: | |
| # if 0 <= px < cw and 0 <= py < ch: | |
| # if black_mask[py, px]: | |
| # has_black = True | |
| # break | |
| # history.append(has_black) | |
| # # --- Determine painting condition --- | |
| # if stopped: | |
| # paint = False | |
| # elif has_black: | |
| # paint = True | |
| # elif not any(history): # last 5 all False | |
| # stopped = True | |
| # paint = False | |
| # else: | |
| # paint = True | |
| # # --- Paint or skip --- | |
| # if paint: | |
| # for pt in tracked_points: | |
| # fx, fy = int(pt[0] + x0), int(pt[1] + y0) | |
| # if 0 <= fx < W and 0 <= fy < H: | |
| # cv2.circle(vis_full, (fx, fy), 1, (0,255,0), -1) | |
| # mask_full[fy, fx] = 0 | |
| # else: | |
| # # no painting this frame | |
| # pass | |
| # out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR)) | |
| # out_mask.write(mask_full) | |
| # prev_crop_rgb = curr_crop_rgb | |
| # # Optional: progress log | |
| # if frame_idx % 10 == 0: | |
| # print(f"Frame {frame_idx}: {'PAINT' if paint else 'NO-PAINT'} | has_black={has_black} | stopped={stopped}") | |
| # cap.release() | |
| # out_vis.release() | |
| # out_mask.release() | |
| # stabilize_black_regions(OUTPUT_MASK_VIDEO) | |
| # reverse_video_file_inplace(OUTPUT_VIDEO) | |
| # reverse_video_file_inplace(OUTPUT_MASK_VIDEO) | |
| # reverse_video_file_inplace(STABILIZED_MASK) | |
| # return ( | |
| # f"β Tracking complete ({selection_mode}).\n" | |
| # f"Crop {cw}x{ch} @ ({x0},{y0})\n" | |
| # f"Painting stopped={'Yes' if stopped else 'No'} after {frame_idx} frames.\n" | |
| # "Saved outputs reversed back to forward order.", | |
| # OUTPUT_VIDEO, | |
| # OUTPUT_MASK_VIDEO, | |
| # STABILIZED_MASK | |
| # ) | |
| # # ========================================================== | |
| # # === GRADIO APP =========================================== | |
| # # ========================================================== | |
| # def build_app(): | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("# π― Pixel Tracker ") | |
| # with gr.Row(): | |
| # video_in = gr.Video(label="ποΈ Upload Video") | |
| # frame_num = gr.Number(value=0, visible=False) | |
| # load_btn = gr.Button("πΈ Load Frame for Annotation") | |
| # annot = gr.Image(label="ποΈ Paint ROI Mask", tool="sketch", type="numpy", image_mode="RGBA", height=1000) | |
| # save_btn = gr.Button("πΎ Save Mask") | |
| # log = gr.Textbox(label="Logs", lines=8) | |
| # with gr.Row(): | |
| # pixel_mode = gr.Dropdown(["All Pixels", "Only Black Pixels"], value="All Pixels") | |
| # crop_w = gr.Number(value=400, label="Crop Width") | |
| # crop_h = gr.Number(value=400, label="Crop Height") | |
| # preview_btn = gr.Button("π Preview Crop") | |
| # with gr.Row(): | |
| # preview_frame = gr.Image(label="Preview Frame") | |
| # preview_crop = gr.Image(label="Cropped Region") | |
| # run_btn = gr.Button("π Run Tracking") | |
| # with gr.Row(): | |
| # result_video = gr.Video(label="π¬ Result (Forward)") | |
| # mask_video = gr.Video(label="β¬ Mask (Forward)") | |
| # stabilized_video = gr.Video(label="π§± Stabilized (Forward)") | |
| # # Load reversed frame for painting | |
| # def load_reversed_frame(v, f): | |
| # reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT) | |
| # return extract_frame(reversed_path, int(f)) | |
| # load_btn.click(load_reversed_frame, [video_in, frame_num], annot) | |
| # save_btn.click(save_mask, annot, [gr.State(), log]) | |
| # def preview_crop_fn(v, cw, ch): | |
| # reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT) | |
| # frame0 = extract_frame(reversed_path, 0) | |
| # if frame0 is None or not os.path.exists("user_mask.png"): | |
| # return None, None, "β οΈ Paint and Save Mask first." | |
| # x0,y0,x1,y1 = compute_crop_box_from_mask(cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", int(cw), int(ch)) | |
| # frame_box = draw_crop_preview_on_frame(frame0, (x0,y0,x1,y1)) | |
| # return frame_box, frame0[y0:y1, x0:x1], f"Crop {cw}x{ch} at ({x0},{y0})" | |
| # preview_btn.click(preview_crop_fn, [video_in, crop_w, crop_h], [preview_frame, preview_crop, log]) | |
| # def run_btn_fn(v, m, cw, ch): | |
| # if not os.path.exists("user_mask.png"): | |
| # return "β οΈ Save Mask first.", None, None, None | |
| # return run_tracking(v.name if hasattr(v, "name") else v, "user_mask.png", m, int(cw), int(ch)) | |
| # run_btn.click(run_btn_fn, [video_in, pixel_mode, crop_w, crop_h], | |
| # [log, result_video, mask_video, stabilized_video]) | |
| # return demo | |
| # if __name__ == "__main__": | |
| # app = build_app() | |
| # app.launch(server_name="0.0.0.0", server_port=7861, debug=True) | |
| import os | |
| import sys | |
| import cv2 | |
| import math | |
| import time | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from collections import deque | |
| from argparse import Namespace | |
| from torchvision import transforms | |
# === RAFT Setup ===
# Make the vendored RAFT implementation importable (expects the repo
# checked out under /app/preprocess/RAFT).
sys.path.append("/app/preprocess/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder
# === CONFIG ===
# Run on GPU when available; RAFT inference on CPU is very slow.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pretrained RAFT checkpoint (raft-things weights).
MODEL_PATH = "/app/RAFT/raft-things.pth"
# Output artifacts written by the tracking pipeline.
OUTPUT_VIDEO = "/app/full_tracked_output.mp4"        # tracked points overlaid on frames
OUTPUT_MASK_VIDEO = "/app/mask_output.mp4"           # binary mask video (tracked pixels = black)
STABILIZED_MASK = "/app/stabilized_mask_output.mp4"  # morphologically cleaned mask video
REVERSED_INPUT = "/app/reversed_input.mp4"           # temporally reversed copy of the input
# ==========================================================
# === VIDEO UTILITIES ======================================
# ==========================================================
def reverse_video(input_path, output_path):
    """Write a temporally reversed copy of *input_path* to *output_path*.

    Returns *output_path*. Raises FileNotFoundError when the input cannot
    be opened. All frames are buffered in memory before writing.
    """
    reader = cv2.VideoCapture(input_path)
    if not reader.isOpened():
        raise FileNotFoundError(f"β Could not open video: {input_path}")
    fps = reader.get(cv2.CAP_PROP_FPS)
    size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
    buffered = []
    ok, frame = reader.read()
    while ok:
        buffered.append(frame)
        ok, frame = reader.read()
    reader.release()
    # Popping from the tail emits the frames in reverse temporal order.
    while buffered:
        writer.write(buffered.pop())
    writer.release()
    print(f"π Video reversed and saved: {output_path}")
    return output_path
def reverse_video_file_inplace(path_in):
    """Reverse an existing video in place (overwrite via a temporary file).

    Fix: the previous ``path_in.replace(".mp4", "_tmp.mp4")`` produced
    ``tmp_path == path_in`` for any file not ending in ".mp4", making
    reverse_video read from and write to the same file. Build the temp
    name from the split extension so it is always distinct.
    """
    root, ext = os.path.splitext(path_in)
    tmp_path = f"{root}_tmp{ext or '.mp4'}"
    reverse_video(path_in, tmp_path)
    # os.replace overwrites the destination atomically on POSIX.
    os.replace(tmp_path, path_in)
# ==========================================================
# === RAFT LOADING =========================================
# ==========================================================
def load_raft_model(model_path):
    """Load a pretrained RAFT checkpoint and return the model in eval mode on DEVICE."""
    # Namespace mimics RAFT's argparse arguments; the extra depth_* fields
    # are presumably consumed by this fork of RAFT — TODO confirm against
    # the vendored implementation.
    args = Namespace(
        small=False,              # use the full (non-small) RAFT architecture
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=16,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0
    )
    # The checkpoint was saved from a DataParallel wrapper ("module."-prefixed
    # keys), so wrap before load_state_dict and unwrap via .module afterwards.
    model = torch.nn.DataParallel(RAFT(args))
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    return model.module.to(DEVICE).eval()
def to_tensor(image):
    """Convert an HxWxC image array into a 1xCxHxW float tensor on DEVICE."""
    tensor = transforms.ToTensor()(image)
    return tensor.unsqueeze(0).to(DEVICE)
@torch.no_grad()
def compute_flow(model, img1, img2):
    """Compute dense optical flow img1 -> img2 with RAFT.

    Returns an (H, W, 2) float numpy array of per-pixel (dx, dy)
    displacements.

    Fix: restored the @torch.no_grad() decorator that the earlier revision
    of this file had — without it autograd records the entire RAFT forward
    pass for every frame pair, wasting a large amount of memory during
    tracking for gradients that are never used.
    """
    t1, t2 = to_tensor(img1), to_tensor(img2)
    # RAFT requires spatial dims divisible by 8: pad, run, then unpad.
    padder = InputPadder(t1.shape)
    t1, t2 = padder.pad(t1, t2)
    _, flow = model(t1, t2, iters=30, test_mode=True)
    flow = padder.unpad(flow)[0]
    return flow.permute(1, 2, 0).cpu().numpy()
# ==========================================================
# === FRAME / MASK HELPERS =================================
# ==========================================================
def extract_frame(video_path, frame_number):
    """Return frame *frame_number* of the video as an RGB array, or None if unreadable."""
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    success, bgr = capture.read()
    capture.release()
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if success else None
def save_mask(data):
    """Binarize a painted Gradio mask and save it to user_mask.png.

    Accepts either the raw image array or a dict with a "mask" key (the
    gr.Image sketch-tool payload). Returns (path, log_message) on success,
    or (None, warning) when no mask data is present.

    Fix: the original assumed every 3-D array was RGBA; a 3-channel RGB
    array would make COLOR_RGBA2GRAY fail. Choose the conversion by the
    actual channel count instead.
    """
    if data is None:
        return None, "β οΈ No mask data received!"
    mask = data.get("mask") if isinstance(data, dict) else data
    if mask is None:
        return None, "β οΈ Mask missing!"
    if mask.ndim == 3:
        # RGBA from the sketch tool is expected, but tolerate plain RGB too.
        code = cv2.COLOR_RGBA2GRAY if mask.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        mask_gray = cv2.cvtColor(mask, code)
    else:
        mask_gray = mask
    # Any painted (non-zero) pixel becomes 255 in the binary mask.
    _, bin_mask = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
    mask_path = "user_mask.png"
    cv2.imwrite(mask_path, bin_mask)
    return mask_path, f"β Saved mask ({np.count_nonzero(bin_mask)} painted pixels)"
# ==========================================================
# === UPDATED DYNAMIC CROP LOGIC ===========================
# ==========================================================
def compute_crop_box_from_mask_dynamic(first_frame_bgr, mask_path, pad=200):
    """Compute a square crop box (x0, y0, x1, y1) around the painted mask.

    The box is the mask's bounding box grown by *pad* pixels on every side,
    then squared off around its center. Falls back to a centered crop when
    the mask is empty. *first_frame_bgr* is unused but kept for interface
    compatibility.

    Fix: the original clamped x_max/y_max to the image edge AFTER squaring,
    so crops near a border were not actually square (contradicting the
    documented contract). Shift the box back inside the image, capping the
    side at min(W, H), instead of truncating it.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Mask not found: {mask_path}")
    H, W = mask.shape[:2]
    ys, xs = np.where(mask > 0)
    if len(xs) == 0:
        # No painted pixels: centered square fallback (half the short side).
        cx, cy = W // 2, H // 2
        size = min(W, H) // 2
        return cx - size // 2, cy - size // 2, cx + size // 2, cy + size // 2
    x_min, x_max = np.min(xs), np.max(xs)
    y_min, y_max = np.min(ys), np.max(ys)
    # Grow the bounding box by the padding, clipped to the image.
    x_min = max(0, x_min - pad)
    y_min = max(0, y_min - pad)
    x_max = min(W, x_max + pad)
    y_max = min(H, y_max + pad)
    # Square the box around its center; the side can never exceed the image.
    side = min(max(x_max - x_min, y_max - y_min), W, H)
    cx = (x_min + x_max) // 2
    cy = (y_min + y_max) // 2
    # Shift inward so the full square stays inside the image bounds.
    x_min = max(0, min(cx - side // 2, W - side))
    y_min = max(0, min(cy - side // 2, H - side))
    return int(x_min), int(y_min), int(x_min + side), int(y_min + side)
def draw_crop_preview_on_frame(frame_rgb, crop_box, color=(0,255,0), thickness=2):
    """Return a copy of *frame_rgb* with *crop_box* outlined as a rectangle."""
    x0, y0, x1, y1 = crop_box
    preview = frame_rgb.copy()
    cv2.rectangle(preview, (x0, y0), (x1, y1), color, thickness)
    return preview
# ==========================================================
# === STABILIZATION ========================================
# ==========================================================
def stabilize_black_regions(input_video):
    """Clean the black regions of a mask video frame by frame.

    Reads *input_video*, binarizes each frame, fills enclosed holes via a
    flood fill from the corner, smooths the result morphologically, and
    writes the cleaned mask video to the module-level STABILIZED_MASK path,
    which is returned.
    """
    # Small rect kernel closes pinholes; larger ellipse smooths jagged edges.
    kernel_fill = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise FileNotFoundError(f"β Could not open video: {input_video}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(STABILIZED_MASK, fourcc, fps, (width, height))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        # Flood-fill the inverted mask from (0, 0) — assumes the top-left
        # corner is background; any white area NOT reached is an enclosed hole.
        inv = cv2.bitwise_not(mask)
        flood = inv.copy()
        h, w = inv.shape
        # floodFill requires a scratch mask 2 px larger than the image.
        flood_mask = np.zeros((h + 2, w + 2), np.uint8)
        cv2.floodFill(flood, flood_mask, (0, 0), 255)
        holes = cv2.bitwise_not(flood)
        filled = cv2.bitwise_or(inv, holes)   # inverted mask with holes filled
        filled = cv2.bitwise_not(filled)      # flip back to original polarity
        # Close remaining pinholes, then open to smooth the contours.
        stable = cv2.morphologyEx(filled, cv2.MORPH_CLOSE, kernel_fill, iterations=1)
        stable = cv2.morphologyEx(stable, cv2.MORPH_OPEN, kernel_edge, iterations=1)
        out.write(cv2.cvtColor(stable, cv2.COLOR_GRAY2BGR))
    cap.release()
    out.release()
    print(f"β Stabilized mask saved: {STABILIZED_MASK}")
    return STABILIZED_MASK
# ==========================================================
# === TRACKING =============================================
# ==========================================================
def run_tracking(video_path, mask_path, selection_mode="All Pixels", pad=200):
    """Track the user-painted pixels backwards through the video.

    The input is reversed so tracking starts from the annotated (last)
    frame; RAFT optical flow propagates the selected points frame to frame
    inside a square crop around the mask. Painting stops permanently once
    no tracked point has landed on a near-black pixel for HISTORY_LEN
    consecutive frames. All outputs are re-reversed to forward order.

    Args:
        video_path: input video file.
        mask_path: grayscale PNG of painted pixels.
        selection_mode: "All Pixels" tracks every painted pixel; anything
            else restricts to painted pixels that are near-black in frame 0.
        pad: padding (px) around the mask bounding box for the crop. New
            parameter; the default preserves previous behavior.

    Returns:
        (log_message, OUTPUT_VIDEO, OUTPUT_MASK_VIDEO, STABILIZED_MASK),
        with the three paths set to None on failure.
    """
    BLACK_THRESH = 1      # grayscale values below this count as "black"
    HISTORY_LEN = 5       # consecutive no-black frames before painting stops
    reversed_path = reverse_video(video_path, REVERSED_INPUT)
    cap = cv2.VideoCapture(reversed_path)
    model = load_raft_model(MODEL_PATH)
    fps = cap.get(cv2.CAP_PROP_FPS)
    ret, first_frame = cap.read()
    if not ret:
        return "β Could not read first frame.", None, None, None
    H, W = first_frame.shape[:2]
    x0, y0, x1, y1 = compute_crop_box_from_mask_dynamic(first_frame, mask_path, pad=pad)
    cw, ch = x1 - x0, y1 - y0
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_vis = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (W, H))
    out_mask = cv2.VideoWriter(OUTPUT_MASK_VIDEO, fourcc, fps, (W, H), isColor=False)
    full_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if full_mask is None:
        # Fix: a missing/unreadable mask previously crashed later inside
        # cv2.resize with an opaque error; fail with a clear message.
        return f"β Could not read mask: {mask_path}", None, None, None
    full_mask = cv2.resize(full_mask, (W, H), interpolation=cv2.INTER_NEAREST)
    crop_mask = full_mask[y0:y1, x0:x1]
    if selection_mode == "All Pixels":
        ys, xs = np.where(crop_mask > 0)
    else:
        # Restrict the seeds to painted pixels that are near-black in frame 0.
        gray_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
        black_pixels = (gray_first[y0:y1, x0:x1] < BLACK_THRESH)
        combined = (crop_mask > 0) & black_pixels
        ys, xs = np.where(combined)
    tracked_points = np.vstack((xs, ys)).T.astype(np.float32)
    prev_full_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    prev_crop_rgb = prev_full_rgb[y0:y1, x0:x1]
    # Seeded True so tracking cannot stop before HISTORY_LEN real frames.
    history = deque([True] * HISTORY_LEN, maxlen=HISTORY_LEN)
    stopped = False
    frame_idx = 0
    while True:
        ret, curr_frame = cap.read()
        if not ret:
            break
        frame_idx += 1
        curr_full_rgb = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
        curr_crop_rgb = curr_full_rgb[y0:y1, x0:x1]
        gray_crop = cv2.cvtColor(curr_crop_rgb, cv2.COLOR_RGB2GRAY)
        flow_crop = compute_flow(model, prev_crop_rgb, curr_crop_rgb)
        vis_full = curr_full_rgb.copy()
        mask_full = np.full((H, W), 255, dtype=np.uint8)
        # Advance each point by the flow sampled at its truncated position;
        # points that drifted outside the crop are dropped.
        new_points = []
        for pt in tracked_points:
            px, py = int(pt[0]), int(pt[1])
            if 0 <= px < cw and 0 <= py < ch:
                dx, dy = flow_crop[py, px]
                nx, ny = pt[0] + dx, pt[1] + dy
                nx = np.clip(nx, 0, cw - 1)
                ny = np.clip(ny, 0, ch - 1)
                new_points.append([nx, ny])
        tracked_points = np.array(new_points, dtype=np.float32)
        # Does any tracked point currently sit on a near-black pixel?
        black_mask = gray_crop < BLACK_THRESH
        black_indices = tracked_points.astype(int)
        has_black = any(
            0 <= px < cw and 0 <= py < ch and black_mask[py, px]
            for px, py in black_indices
        )
        history.append(has_black)
        # Paint until HISTORY_LEN consecutive frames without black; once
        # stopped, painting never resumes.
        if stopped:
            paint = False
        elif has_black:
            paint = True
        elif not any(history):
            stopped = True
            paint = False
        else:
            paint = True
        if paint:
            for pt in tracked_points:
                fx, fy = int(pt[0] + x0), int(pt[1] + y0)
                if 0 <= fx < W and 0 <= fy < H:
                    cv2.circle(vis_full, (fx, fy), 1, (0,255,0), -1)
                    mask_full[fy, fx] = 0
        out_vis.write(cv2.cvtColor(vis_full, cv2.COLOR_RGB2BGR))
        out_mask.write(mask_full)
        prev_crop_rgb = curr_crop_rgb
        if frame_idx % 10 == 0:
            print(f"Frame {frame_idx}: {'PAINT' if paint else 'NO-PAINT'} | has_black={has_black} | stopped={stopped}")
    cap.release()
    out_vis.release()
    out_mask.release()
    stabilize_black_regions(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(OUTPUT_VIDEO)
    reverse_video_file_inplace(OUTPUT_MASK_VIDEO)
    reverse_video_file_inplace(STABILIZED_MASK)
    return (
        f"β Tracking complete ({selection_mode}).\n"
        # Fix: this line previously hard-coded "padding=100" while the crop
        # actually used pad=200; report the value that was really used.
        f"Square Crop {cw}x{ch} @ ({x0},{y0}) with padding={pad}\n"
        f"Painting stopped={'Yes' if stopped else 'No'} after {frame_idx} frames.\n"
        "Saved outputs reversed back to forward order.",
        OUTPUT_VIDEO,
        OUTPUT_MASK_VIDEO,
        STABILIZED_MASK
    )
# ==========================================================
# === GRADIO APP ===========================================
# ==========================================================
def build_app():
    """Build and return the Gradio Blocks UI for the pixel tracker."""
    with gr.Blocks() as demo:
        gr.Markdown("# π― Pixel Tracker (Dynamic Square Crop)")
        with gr.Row():
            video_in = gr.Video(label="ποΈ Upload Video")
            frame_num = gr.Number(value=0, visible=False)
        load_btn = gr.Button("πΈ Load Frame for Annotation")
        annot = gr.Image(label="ποΈ Paint ROI Mask", tool="sketch", type="numpy", image_mode="RGBA", height=1000)
        save_btn = gr.Button("πΎ Save Mask")
        log = gr.Textbox(label="Logs", lines=8)
        # Fix: the selection-mode dropdown was previously instantiated inline
        # inside run_btn.click(...), so it was never placed in the layout and
        # users could not change it. Declare it in the UI instead.
        pixel_mode = gr.Dropdown(["All Pixels", "Only Black Pixels"], value="All Pixels", label="Pixel Selection")
        preview_btn = gr.Button("π Preview Crop", visible=False)
        with gr.Row():
            preview_frame = gr.Image(label="Preview Frame", visible=False)
            preview_crop = gr.Image(label="Cropped Region", visible=False)
        run_btn = gr.Button("π Run Tracking")
        with gr.Row():
            result_video = gr.Video(label="π¬ Result (Forward)")
            mask_video = gr.Video(label="β¬ Mask (Forward)")
            stabilized_video = gr.Video(label="π§± Stabilized (Forward)")
        def load_reversed_frame(v, f):
            # Reverse first so frame 0 of the annotation view is the video's
            # LAST frame (tracking runs backwards in time).
            reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT)
            return extract_frame(reversed_path, int(f))
        load_btn.click(load_reversed_frame, [video_in, frame_num], annot)
        save_btn.click(save_mask, annot, [gr.State(), log])
        def preview_crop_fn(v):
            reversed_path = reverse_video(v.name if hasattr(v, "name") else v, REVERSED_INPUT)
            frame0 = extract_frame(reversed_path, 0)
            if frame0 is None or not os.path.exists("user_mask.png"):
                return None, None, "β οΈ Paint and Save Mask first."
            x0,y0,x1,y1 = compute_crop_box_from_mask_dynamic(cv2.cvtColor(frame0, cv2.COLOR_RGB2BGR), "user_mask.png", pad=200)
            frame_box = draw_crop_preview_on_frame(frame0, (x0,y0,x1,y1))
            return frame_box, frame0[y0:y1, x0:x1], f"Square crop {x1-x0}x{y1-y0} at ({x0},{y0})"
        preview_btn.click(preview_crop_fn, video_in, [preview_frame, preview_crop, log])
        def run_btn_fn(v, m):
            if not os.path.exists("user_mask.png"):
                return "β οΈ Save Mask first.", None, None, None
            return run_tracking(v.name if hasattr(v, "name") else v, "user_mask.png", m)
        run_btn.click(run_btn_fn, [video_in, pixel_mode],
                      [log, result_video, mask_video, stabilized_video])
    return demo
# Entry point: build the UI and serve it on all interfaces, port 7861.
if __name__ == "__main__":
    app = build_app()
    app.launch(server_name="0.0.0.0", server_port=7861, debug=True)