# stage1_approx.py
import numpy as np
import torch

from utils.box_utils import interpolate_boxes


# ── Option A: Pure Linear Interpolation ─────────────────────────────
# Best for: static camera or very slow camera movement
# Worst for: fast pans, zoom, handheld footage
def stage1_linear(
    keyboxes: dict,
    num_frames: int
) -> np.ndarray:
    """
    Simplest possible Stage 1 substitute.

    keyboxes: {frame_idx: [x1, y1, x2, y2]}
    Returns: [T, 4] box sequence
    """
    return interpolate_boxes(keyboxes, num_frames, method="linear")


def _pixel_index(x, y, H, W):
    """Return integer (col, row) indices for point (x, y), clamped inside an H x W image."""
    # Clamping avoids IndexError for points pushed off-image by flow, and
    # avoids numpy's silent negative-index wraparound.
    col = min(max(int(x), 0), W - 1)
    row = min(max(int(y), 0), H - 1)
    return col, row


# ── Option B: DA-v3 Depth Warping ───────────────────────────────────
# Better for: moderate camera motion
# From Table 7: IoU=0.79, mAP=0.73 (vs TRACE 0.80, 0.91)
# Requires: DepthAnything-v3 + MegaSAM or RAFT optical flow
def stage1_depth_warp(
    frames: np.ndarray,  # [T, H, W, 3]
    keyboxes: dict,
    depth_model,
    flow_model=None
) -> np.ndarray:
    """
    Project keyframe boxes to all other frames using depth + optical flow.

    For every frame without a keybox, the box from the temporally nearest
    keyframe (ties go to the earlier keyframe) is processed as follows:
    its center is shifted by the optical-flow displacement sampled at that
    center (zero shift when ``flow_model`` is None), and its width/height
    are scaled by ``depth(ref center) / depth(target center)``.

    Args:
        frames: [T, H, W, 3] video frames.
        keyboxes: {frame_idx: [x1, y1, x2, y2]}; indices are expected to lie
            in [0, T).
        depth_model: object whose ``infer(frame)`` returns an [H, W] depth map.
        flow_model: optional object whose ``compute(src, dst)`` returns an
            [H, W, 2] flow field whose channels are (dx, dy).

    Returns:
        [T, 4] box sequence. Keyframes keep their given boxes; every other
        frame holds its warped box.

    Raises:
        ValueError: if ``keyboxes`` is empty.
    """
    if not keyboxes:
        raise ValueError("keyboxes must contain at least one keyframe box")

    T, H, W, _ = frames.shape

    # Depth for every frame; each frame is used either as a reference or a target.
    depths = np.stack([depth_model.infer(frame) for frame in frames])  # [T, H, W]

    result_boxes = np.zeros((T, 4))
    for frame_idx, box in keyboxes.items():
        result_boxes[frame_idx] = box

    keyframe_ids = sorted(keyboxes)
    for t in range(T):
        if t in keyboxes:
            continue

        nearest_key = min(keyframe_ids, key=lambda k: abs(k - t))
        ref_box = keyboxes[nearest_key]

        # Depth at the box center in the reference frame.
        cx_ref = (ref_box[0] + ref_box[2]) / 2
        cy_ref = (ref_box[1] + ref_box[3]) / 2
        col_ref, row_ref = _pixel_index(cx_ref, cy_ref, H, W)
        d_ref = depths[nearest_key][row_ref, col_ref]

        # Center displacement from optical flow, if a flow model is given.
        if flow_model is not None:
            flow = flow_model.compute(frames[nearest_key], frames[t])  # [H, W, 2]
            dx = flow[row_ref, col_ref, 0]
            dy = flow[row_ref, col_ref, 1]
        else:
            dx, dy = 0, 0

        cx_tgt = cx_ref + dx
        cy_tgt = cy_ref + dy

        # Scale box size by the depth ratio between reference and target.
        # NOTE(review): this assumes depth_model returns depth (larger = farther).
        # If it returns inverse depth / disparity, the ratio must be inverted.
        col_tgt, row_tgt = _pixel_index(cx_tgt, cy_tgt, H, W)
        d_tgt = depths[t][row_tgt, col_tgt]
        scale = d_ref / (d_tgt + 1e-6)
        bw = (ref_box[2] - ref_box[0]) * scale
        bh = (ref_box[3] - ref_box[1]) * scale

        result_boxes[t] = [
            cx_tgt - bw / 2, cy_tgt - bh / 2,
            cx_tgt + bw / 2, cy_tgt + bh / 2,
        ]

    # Every frame is now filled (keyframes directly, others by warping), so
    # return the warped sequence. Re-interpolating from keyframes alone
    # would discard all of the warped boxes.
    return result_boxes


# ── Option C: CoTracker-Assisted Warping ────────────────────────────
# Best for: fast camera, most accurate without training
# Uses background point tracks to estimate camera motion
def stage1_cotracker(
    frames: np.ndarray,  # [T, H, W, 3]
    keyboxes: dict,
    cotracker_model
) -> np.ndarray:
    """
    Warp keyboxes using camera motion estimated from CoTracker point tracks.

    Background points (outside the earliest keybox) are queried at frame 0
    and tracked through the whole video. For each non-key frame t, a
    RANSAC homography is fitted from the tracks at the nearest keyframe to
    the tracks at t. The keybox corners are then warped through it, and the
    axis-aligned bounds of the warped corners become the box for frame t.
    When no homography can be estimated, the nearest keybox is held.

    Args:
        frames: [T, H, W, 3] video frames.
        keyboxes: {frame_idx: [x1, y1, x2, y2]}.
        cotracker_model: callable as ``model(video, queries=...)`` that
            returns ``(tracks, visibility)`` with tracks shaped [1, T, N, 2].

    Returns:
        [T, 4] box sequence.

    Raises:
        ValueError: if ``keyboxes`` is empty.
    """
    import cv2  # imported here so only this option depends on OpenCV

    if not keyboxes:
        raise ValueError("keyboxes must contain at least one keyframe box")

    T, H, W, _ = frames.shape
    keyframe_ids = sorted(keyboxes)

    result_boxes = np.zeros((T, 4))
    for frame_idx in keyframe_ids:
        result_boxes[frame_idx] = keyboxes[frame_idx]

    def _nearest_keybox(t):
        return keyboxes[min(keyframe_ids, key=lambda k: abs(k - t))]

    # Queries are placed at frame 0, so mask out the earliest keybox
    # (the one closest to frame 0) when sampling background points.
    first_box = keyboxes[keyframe_ids[0]]
    bg_points = _sample_background_points(
        H, W, first_box, n_points=100
    )  # [N, 2] (x, y)

    # A homography needs at least 4 correspondences; with fewer points
    # nothing can be estimated, so hold the nearest keybox everywhere.
    if len(bg_points) < 4:
        for t in range(T):
            if t not in keyboxes:
                result_boxes[t] = _nearest_keybox(t)
        return result_boxes

    # Track background points across all frames.
    video_tensor = torch.from_numpy(frames).float()
    video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0)  # [1, T, 3, H, W]

    queries = torch.zeros(1, len(bg_points), 3)
    queries[0, :, 0] = 0  # query at frame 0
    queries[0, :, 1] = torch.from_numpy(bg_points[:, 0])  # x
    queries[0, :, 2] = torch.from_numpy(bg_points[:, 1])  # y

    with torch.no_grad():
        tracks, visibility = cotracker_model(video_tensor, queries=queries)
    # tracks: [1, T, N_points, 2]; .cpu() makes this work for GPU models too.
    # NOTE(review): `visibility` is ignored, so occluded points still enter
    # the RANSAC fit. Consider filtering by it once its shape is confirmed.
    tracks = tracks[0].detach().cpu().numpy()  # [T, N, 2]

    for t in range(T):
        if t in keyboxes:
            continue

        ref_box = _nearest_keybox(t)
        nearest_key = min(keyframe_ids, key=lambda k: abs(k - t))

        # Camera motion from the nearest keyframe to frame t.
        homography, _ = cv2.findHomography(
            tracks[nearest_key], tracks[t], cv2.RANSAC, 5.0
        )
        if homography is None:
            result_boxes[t] = ref_box
            continue

        # Warp the box corners and take their axis-aligned bounds.
        corners = np.array([
            [ref_box[0], ref_box[1]],
            [ref_box[2], ref_box[1]],
            [ref_box[2], ref_box[3]],
            [ref_box[0], ref_box[3]],
        ], dtype=np.float32).reshape(-1, 1, 2)
        warped = cv2.perspectiveTransform(corners, homography).reshape(-1, 2)

        result_boxes[t] = [
            warped[:, 0].min(), warped[:, 1].min(),
            warped[:, 0].max(), warped[:, 1].max(),
        ]

    return result_boxes


def _sample_background_points(H, W, object_box, n_points=100):
    """
    Sample up to ``n_points`` random integer pixel positions outside a box.

    Points are drawn uniformly from [0, W) x [0, H) with the global numpy
    RNG and rejected if they fall inside ``object_box`` (inclusive bounds).
    Sampling gives up after ``n_points * 10`` draws, so fewer points may be
    returned when the box covers most of the image.

    Returns:
        float32 array of shape [N, 2] holding (x, y), with N <= n_points.
        Even when N == 0 the array keeps this 2-D shape.
    """
    x1, y1, x2, y2 = object_box
    points = []
    attempts = 0
    while len(points) < n_points and attempts < n_points * 10:
        x = np.random.randint(0, W)
        y = np.random.randint(0, H)
        if not (x1 <= x <= x2 and y1 <= y <= y2):
            points.append([x, y])
        attempts += 1
    # reshape keeps the result 2-D even when empty, so callers can index [:, 0].
    return np.array(points, dtype=np.float32).reshape(-1, 2)