# ObjectInsertion / stage1_approx.py
# Leema Krishna Murali
# Initial commit
# f3d0a26
# stage1_approx.py
import numpy as np
import torch
from utils.box_utils import interpolate_boxes
# ── Option A: Pure Linear Interpolation ─────────────────────────────
# Best for: static camera or very slow camera movement
# Worst for: fast pans, zoom, handheld footage
def stage1_linear(
    keyboxes: dict,
    num_frames: int
) -> np.ndarray:
    """
    Baseline Stage 1: fill every frame by linearly interpolating keyframe boxes.

    No image content is used, so camera motion is ignored entirely.

    Args:
        keyboxes: mapping ``{frame_idx: [x1, y1, x2, y2]}`` of user boxes.
        num_frames: length T of the output sequence.

    Returns:
        ``[T, 4]`` box sequence produced by ``interpolate_boxes`` with
        ``method="linear"`` (behaviour before the first / after the last
        keyframe is whatever that helper does — not visible here).
    """
    interpolated = interpolate_boxes(keyboxes, num_frames, method="linear")
    return interpolated
# ── Option B: DA-v3 Depth Warping ───────────────────────────────────
# Better for: moderate camera motion
# From Table 7: IoU=0.79, mAP=0.73 (vs TRACE 0.80, 0.91)
# Requires: DepthAnything-v3 + MegaSAM or RAFT optical flow
def stage1_depth_warp(
    frames: np.ndarray,  # [T, H, W, 3]
    keyboxes: dict,
    depth_model,
    flow_model=None
) -> np.ndarray:
    """
    Propagate keyframe boxes to every other frame using depth + optical flow.

    For each non-keyframe ``t`` the box of the nearest keyframe (ties go to
    the earlier keyframe) is translated by the optical flow sampled at its
    center and rescaled by the ratio of depths at the source and target
    centers.

    Args:
        frames: video as a ``[T, H, W, 3]`` array.
        keyboxes: ``{frame_idx: [x1, y1, x2, y2]}`` user-specified boxes;
            indices must lie in ``[0, T)``.
        depth_model: object whose ``infer(frame)`` returns an ``[H, W]``
            depth map.
        flow_model: optional object whose ``compute(src, dst)`` returns an
            ``[H, W, 2]`` flow field (x, y displacement). Without it the box
            center stays put and only the depth rescaling is applied.

    Returns:
        ``[T, 4]`` float array; keyframe rows are the keyboxes verbatim and
        every other row is the warped box.

    Raises:
        ValueError: if ``keyboxes`` is empty.
    """
    T, H, W, _ = frames.shape
    if not keyboxes:
        raise ValueError("stage1_depth_warp requires at least one keybox")

    def _pixel(x, y):
        # (row, col) index clamped into the image: box centers can sit on the
        # border or leave the frame after flow, and negative indices would
        # otherwise silently wrap to the opposite edge.
        return (
            min(max(int(y), 0), H - 1),
            min(max(int(x), 0), W - 1),
        )

    # Depth for every frame: each one is needed as either reference or target.
    depths = np.stack([depth_model.infer(frame) for frame in frames])  # [T, H, W]

    result_boxes = np.zeros((T, 4))
    for frame_idx, box in keyboxes.items():
        result_boxes[frame_idx] = box

    keyframe_ids = sorted(keyboxes)
    for t in range(T):
        if t in keyboxes:
            continue
        # Nearest keyframe; sorted order makes ties pick the earlier one.
        nearest_key = min(keyframe_ids, key=lambda k: abs(k - t))
        ref_box = keyboxes[nearest_key]

        cx_ref = (ref_box[0] + ref_box[2]) / 2
        cy_ref = (ref_box[1] + ref_box[3]) / 2
        ref_row, ref_col = _pixel(cx_ref, cy_ref)
        d_ref = depths[nearest_key][ref_row, ref_col]

        # Center displacement from optical flow, if a flow model is given.
        if flow_model is not None:
            flow = flow_model.compute(frames[nearest_key], frames[t])  # [H, W, 2]
            dx = flow[ref_row, ref_col, 0]
            dy = flow[ref_row, ref_col, 1]
        else:
            dx, dy = 0, 0

        cx_tgt = cx_ref + dx
        cy_tgt = cy_ref + dy
        tgt_row, tgt_col = _pixel(cx_tgt, cy_tgt)
        d_tgt = depths[t][tgt_row, tgt_col]

        # Apparent size ~ 1 / depth. NOTE(review): this assumes `infer`
        # returns depth (larger = farther); if the model outputs inverse
        # depth / disparity the ratio must be flipped — confirm for DA-v3.
        scale = d_ref / (d_tgt + 1e-6)
        bw = (ref_box[2] - ref_box[0]) * scale
        bh = (ref_box[3] - ref_box[1]) * scale
        result_boxes[t] = [
            cx_tgt - bw / 2, cy_tgt - bh / 2,
            cx_tgt + bw / 2, cy_tgt + bh / 2,
        ]

    # Every frame is now filled (keyframes copied, the rest warped). Do not
    # re-interpolate from the keyframes here: that would discard all warping.
    return result_boxes
# ── Option C: CoTracker-Assisted Warping ────────────────────────────
# Best for: fast camera, most accurate without training
# Uses background point tracks to estimate camera motion
def stage1_cotracker(
    frames: np.ndarray,  # [T, H, W, 3]
    keyboxes: dict,
    cotracker_model
) -> np.ndarray:
    """
    Propagate keyframe boxes by estimating camera motion from CoTracker tracks.

    Background points (sampled outside the earliest keyframe's box) are
    queried at frame 0 and tracked through the clip. For every non-keyframe
    ``t`` a RANSAC homography is fitted between the tracked points at the
    nearest keyframe (ties go to the earlier keyframe) and at ``t``. The
    keyframe box corners are warped through it, and the axis-aligned bounds
    of the warped corners become the box for ``t``.

    Args:
        frames: video as a ``[T, H, W, 3]`` array.
        keyboxes: ``{frame_idx: [x1, y1, x2, y2]}`` user-specified boxes.
        cotracker_model: callable taking ``(video [1, T, 3, H, W],
            queries=[1, N, 3])`` and returning ``(tracks, visibility)``, with
            tracks shaped ``[1, T, N, 2]``.

    Returns:
        ``[T, 4]`` float array of boxes. Frames whose homography cannot be
        estimated (fewer than 4 background points, or RANSAC failure) get a
        copy of the nearest keyframe box.

    Raises:
        ValueError: if ``keyboxes`` is empty.
    """
    import cv2  # OpenCV is only needed by this stage

    T, H, W, _ = frames.shape
    if not keyboxes:
        raise ValueError("stage1_cotracker requires at least one keybox")
    keyframe_ids = sorted(keyboxes)

    # Queries are placed at frame 0, so avoid the box of the earliest keyframe
    # rather than whichever key happens to come first in the dict.
    first_box = keyboxes[keyframe_ids[0]]
    # reshape keeps the array 2-D even when no point could be sampled.
    bg_points = np.asarray(
        _sample_background_points(H, W, first_box, n_points=100),
        dtype=np.float32,
    ).reshape(-1, 2)  # [N, 2] (x, y)

    # A homography needs at least 4 correspondences; with fewer points skip
    # tracking and fall back to copying keyframe boxes in the loop below.
    tracks = None
    if len(bg_points) >= 4:
        video_tensor = (
            torch.from_numpy(frames).float().permute(0, 3, 1, 2).unsqueeze(0)
        )  # [1, T, 3, H, W]
        # Query rows are (frame, x, y); the frame column stays 0.
        queries = torch.zeros(1, len(bg_points), 3)
        queries[0, :, 1:] = torch.from_numpy(bg_points)
        with torch.no_grad():
            # NOTE(review): visibility is ignored; occluded tracks are left
            # for RANSAC to reject as outliers.
            tracks, _visibility = cotracker_model(video_tensor, queries=queries)
        # .cpu() so this also works when the model returns CUDA tensors.
        tracks = tracks[0].cpu().numpy()  # [T, N, 2]

    result_boxes = np.zeros((T, 4))
    for t in range(T):
        if t in keyboxes:
            result_boxes[t] = keyboxes[t]
            continue
        nearest_key = min(keyframe_ids, key=lambda k: abs(k - t))
        ref_box = keyboxes[nearest_key]

        # Camera motion from the nearest keyframe to frame t.
        H_mat = None
        if tracks is not None:
            H_mat, _inliers = cv2.findHomography(
                tracks[nearest_key], tracks[t], cv2.RANSAC, 5.0
            )
        if H_mat is None:
            result_boxes[t] = ref_box
            continue

        # Warp the four box corners and take their axis-aligned bounds.
        corners = np.array([
            [ref_box[0], ref_box[1]],
            [ref_box[2], ref_box[1]],
            [ref_box[2], ref_box[3]],
            [ref_box[0], ref_box[3]],
        ], dtype=np.float32).reshape(-1, 1, 2)
        warped = cv2.perspectiveTransform(corners, H_mat).reshape(-1, 2)
        result_boxes[t] = [
            warped[:, 0].min(), warped[:, 1].min(),
            warped[:, 0].max(), warped[:, 1].max(),
        ]
    return result_boxes
def _sample_background_points(H, W, object_box, n_points=100):
"""Sample points outside the object bounding box"""
x1, y1, x2, y2 = object_box
points = []
attempts = 0
while len(points) < n_points and attempts < n_points * 10:
x = np.random.randint(0, W)
y = np.random.randint(0, H)
if not (x1 <= x <= x2 and y1 <= y <= y2):
points.append([x, y])
attempts += 1
return np.array(points, dtype=np.float32)