File size: 6,299 Bytes
f3d0a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# stage1_approx.py
import numpy as np
import torch
from utils.box_utils import interpolate_boxes

# ── Option A: Pure Linear Interpolation ─────────────────────────────
# Best for: static camera or very slow camera movement
# Worst for: fast pans, zoom, handheld footage

def stage1_linear(
    keyboxes: dict,
    num_frames: int
) -> np.ndarray:
    """
    Baseline Stage 1: linearly interpolate the keyframe boxes over the clip.

    keyboxes:   {frame_idx: [x1, y1, x2, y2]}
    num_frames: total number of frames T to produce boxes for
    Returns:    [T, 4] box sequence as produced by ``interpolate_boxes``
    """
    interpolated = interpolate_boxes(keyboxes, num_frames, method="linear")
    return interpolated


# ── Option B: DA-v3 Depth Warping ───────────────────────────────────
# Better for: moderate camera motion
# From Table 7: IoU=0.79, mAP=0.73 (vs TRACE 0.80, 0.91)
# Requires: DepthAnything-v3 + MegaSAM or RAFT optical flow

def stage1_depth_warp(
    frames: np.ndarray,   # [T, H, W, 3]
    keyboxes: dict,
    depth_model,
    flow_model=None
) -> np.ndarray:
    """
    Propagate keyframe boxes to every other frame using depth (+ optional flow).

    For each non-keyframe ``t`` the nearest keyframe (ties -> earlier keyframe)
    is used as reference:
      * the box center is displaced by the optical flow sampled at the
        reference center (no displacement when ``flow_model`` is None);
      * box width/height are scaled by ``d_ref / d_tgt``, the ratio of depth
        values sampled at the reference center and at the displaced center.

    keyboxes:    {frame_idx: [x1, y1, x2, y2]}
    depth_model: object with ``infer(frame) -> [H, W]`` depth map
    flow_model:  optional object with ``compute(src, dst) -> [H, W, 2]``
                 per-pixel (dx, dy) displacement
    Returns:     [T, 4] box sequence; keyframe boxes are copied verbatim and
                 every other frame holds its warped box.
    Raises:      ValueError if ``keyboxes`` is empty.
    """
    if not keyboxes:
        raise ValueError("keyboxes must contain at least one keyframe")

    T, H, W, _ = frames.shape

    def _px(value, size):
        # Clamp a float coordinate to a valid pixel index. Without this, boxes
        # touching the border or flow pushing the center off-frame would wrap
        # around (negative index) or raise IndexError.
        return min(max(int(value), 0), size - 1)

    # Depth for all frames: [T, H, W]
    depths = np.stack([depth_model.infer(frame) for frame in frames])

    # Keyframes are kept exactly as given.
    result_boxes = np.zeros((T, 4))
    for frame_idx, box in keyboxes.items():
        result_boxes[frame_idx] = box

    # Sorted so that ties in distance resolve to the earlier keyframe.
    keyframe_ids = sorted(keyboxes)

    for t in range(T):
        if t in keyboxes:
            continue

        nearest_key = min(keyframe_ids, key=lambda k: abs(k - t))
        ref_box = keyboxes[nearest_key]
        ref_depth = depths[nearest_key]
        tgt_depth = depths[t]

        # Depth at the box center in the reference frame
        cx_ref = (ref_box[0] + ref_box[2]) / 2
        cy_ref = (ref_box[1] + ref_box[3]) / 2
        cx_ref_i, cy_ref_i = _px(cx_ref, W), _px(cy_ref, H)
        d_ref = ref_depth[cy_ref_i, cx_ref_i]

        # Use optical flow, if available, for the center displacement
        if flow_model is not None:
            flow = flow_model.compute(
                frames[nearest_key], frames[t]
            )  # [H, W, 2]
            dx = flow[cy_ref_i, cx_ref_i, 0]
            dy = flow[cy_ref_i, cx_ref_i, 1]
        else:
            dx, dy = 0, 0

        cx_tgt = cx_ref + dx
        cy_tgt = cy_ref + dy

        # Scale box size by the depth ratio. NOTE(review): this assumes
        # ``infer`` returns depth (larger = farther); if it returns inverse
        # depth / disparity the ratio should be inverted — confirm.
        d_tgt = tgt_depth[_px(cy_tgt, H), _px(cx_tgt, W)]
        scale = d_ref / (d_tgt + 1e-6)
        bw = (ref_box[2] - ref_box[0]) * scale
        bh = (ref_box[3] - ref_box[1]) * scale

        result_boxes[t] = [
            cx_tgt - bw / 2, cy_tgt - bh / 2,
            cx_tgt + bw / 2, cy_tgt + bh / 2
        ]

    # Every frame is now either a keyframe or a warped frame; re-interpolating
    # from the keyframes alone would throw the warped boxes away.
    return result_boxes


# ── Option C: CoTracker-Assisted Warping ────────────────────────────
# Best for: fast camera, most accurate without training
# Uses background point tracks to estimate camera motion

def stage1_cotracker(
    frames: np.ndarray,   # [T, H, W, 3]
    keyboxes: dict,
    cotracker_model
) -> np.ndarray:
    """
    Use CoTracker point tracks to estimate camera motion,
    then warp keyboxes accordingly.

    Background points (sampled outside the earliest keyframe's box) are
    queried at frame 0 and tracked through the clip. For every non-keyframe
    ``t``, a RANSAC homography is fitted from the tracks at the nearest
    keyframe (ties -> earlier keyframe) to the tracks at ``t``. The keyframe
    box corners are warped through it, and the axis-aligned bounds of the
    warped corners become the box for ``t``.

    keyboxes:        {frame_idx: [x1, y1, x2, y2]}
    cotracker_model: callable ``(video [1, T, 3, H, W], queries=[1, N, 3])``
                     returning ``(tracks, visibility)`` with tracks of shape
                     ``[1, T, N, 2]``
    Returns:         [T, 4] box sequence. Frames where no homography can be
                     estimated keep the nearest keyframe's box unchanged.
    Raises:          ValueError if ``keyboxes`` is empty.
    """
    if not keyboxes:
        raise ValueError("keyboxes must contain at least one keyframe")

    import cv2  # local import: only this option depends on OpenCV

    T, H, W, _ = frames.shape
    # Sorted so that nearest-keyframe ties resolve to the earlier keyframe.
    keyframe_ids = sorted(keyboxes)

    def _nearest_key(t):
        return min(keyframe_ids, key=lambda k: abs(k - t))

    # Start from a static-camera estimate: every frame gets its nearest
    # keyframe's box (keyframes map to themselves). Homography warps below
    # overwrite this where they can be computed.
    result_boxes = np.zeros((T, 4))
    for t in range(T):
        result_boxes[t] = keyboxes[_nearest_key(t)]

    # Queries are placed at frame 0, so avoid the earliest keyframe's box
    # rather than whichever key happened to be inserted first.
    first_box = keyboxes[keyframe_ids[0]]
    bg_points = _sample_background_points(
        H, W, first_box, n_points=100
    ).reshape(-1, 2)  # [N, 2] (x, y); reshape guards an empty (0,) result

    # A homography needs at least 4 correspondences; with fewer background
    # points (object covers nearly the whole frame) keep the static estimate.
    if len(bg_points) < 4:
        return result_boxes

    # Track the points across all frames
    video_tensor = torch.from_numpy(frames).float()
    video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0)
    # [1, T, 3, H, W]

    queries = torch.zeros(1, len(bg_points), 3)
    queries[0, :, 0] = 0  # query at frame 0
    queries[0, :, 1] = torch.from_numpy(bg_points[:, 0])  # x
    queries[0, :, 2] = torch.from_numpy(bg_points[:, 1])  # y

    with torch.no_grad():
        tracks, _visibility = cotracker_model(
            video_tensor, queries=queries
        )
    # NOTE(review): visibility is not used; occluded background points still
    # feed the homography (RANSAC absorbs some of this) — consider filtering.
    # .cpu() so this also works when the model returns CUDA tensors.
    tracks = tracks[0].detach().cpu().numpy()  # [T, N, 2]

    for t in range(T):
        if t in keyboxes:
            continue

        nearest_key = _nearest_key(t)
        ref_box = keyboxes[nearest_key]

        # Transformation from the nearest keyframe to frame t
        src_pts = tracks[nearest_key]  # [N, 2]
        dst_pts = tracks[t]            # [N, 2]
        H_mat, _mask = cv2.findHomography(
            src_pts, dst_pts, cv2.RANSAC, 5.0
        )
        if H_mat is None:
            continue  # keep the nearest keyframe's box

        # Warp the box corners through the homography
        corners = np.array([
            [ref_box[0], ref_box[1]],
            [ref_box[2], ref_box[1]],
            [ref_box[2], ref_box[3]],
            [ref_box[0], ref_box[3]]
        ], dtype=np.float32).reshape(-1, 1, 2)

        warped = cv2.perspectiveTransform(corners, H_mat).reshape(-1, 2)

        result_boxes[t] = [
            warped[:, 0].min(), warped[:, 1].min(),
            warped[:, 0].max(), warped[:, 1].max()
        ]

    return result_boxes


def _sample_background_points(H, W, object_box, n_points=100):
    """Sample points outside the object bounding box"""
    x1, y1, x2, y2 = object_box
    points = []
    attempts = 0
    while len(points) < n_points and attempts < n_points * 10:
        x = np.random.randint(0, W)
        y = np.random.randint(0, H)
        if not (x1 <= x <= x2 and y1 <= y <= y2):
            points.append([x, y])
        attempts += 1
    return np.array(points, dtype=np.float32)