# count-model / tracker.py
# (Hugging Face file page: uploaded by RehamAAhmed, "Upload 8 files",
#  commit 67a0341 (verified), 31.6 kB)
"""
SAM 2 Robust Object Tracker
============================
Architecture:
- Grounding DINO → detects WHAT to track (frame 0 + recovery)
- SAM 2 Video → tracks WHERE the object is across frames
- Sliding Window → processes video in chunks so SAM 2 memory stays accurate
- Recovery Loop → if an object disappears, DINO relocalizes it
Output: Bounding Box + Label + ID per frame (no mask overlay, fast + low VRAM)
"""
import os
import cv2
import numpy as np
import torch
from PIL import Image
from typing import Optional
# ──────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────
# A mask covering fewer pixels than this is treated as "object lost".
MIN_MASK_AREA = 64
# Number of frames handled by one SAM 2 sliding window.
CHUNK_SIZE_DEFAULT = 120

HOME_DIR = os.path.expanduser("~")
# Default SAM 2.1 Hiera-Small checkpoint (torch hub cache) and its model config.
SAM2_CKPT_DEFAULT = os.path.join(HOME_DIR, ".cache", "torch", "hub",
                                 "checkpoints", "sam2.1_hiera_small.pt")
SAM2_CFG_DEFAULT = "configs/sam2.1/sam2.1_hiera_s.yaml"
# ──────────────────────────────────────────────────────────
# TrackedObject — stores per-object state between chunks
# ──────────────────────────────────────────────────────────
class TrackedObject:
    """Holds identity and last known bounding box for one tracked object.

    The tracker reassigns ``box``, ``lost`` and ``lost_frames`` between
    chunks and may attach ``max_mask_area`` / ``best_crop`` at render time.
    """

    def __init__(self, obj_id: int, label: str, box: np.ndarray):
        self.obj_id = obj_id
        self.label = label
        # [x1, y1, x2, y2] as a float32 copy; any array-like (ndarray, list,
        # tuple) is accepted, and the caller's object is never aliased.
        self.box = np.array(box, dtype=np.float32)
        self.lost = False      # True if not visible at the end of the last chunk
        self.lost_frames = 0   # frames counted as lost since last seen; grows by a
                               # whole chunk each time the object is lost at chunk end

    def __repr__(self):
        return f"TrackedObject(id={self.obj_id}, label='{self.label}', lost={self.lost})"
# ──────────────────────────────────────────────────────────
# VideoFrameStore — thin wrapper around a frames directory
# ──────────────────────────────────────────────────────────
class VideoFrameStore:
    """Extract video frames to disk as sequentially numbered JPEGs.

    For every source frame, ``extract`` applies in order: temporal
    subsampling toward ``target_fps``, optional ORB-based stabilization,
    downscaling so the longer side is at most ``max_size`` (never upscaled),
    and an optional Laplacian-variance blur filter (``blur_threshold`` > 0).
    """

    def __init__(self, video_path: str, output_dir: str,
                 target_fps: Optional[float] = None,
                 max_size: int = 720,
                 blur_threshold: float = 0.0,
                 stabilize: bool = False):
        self.video_path = video_path
        self.output_dir = output_dir          # deleted and recreated by extract()
        self.target_fps = target_fps          # None or <= 0 keeps every frame
        self.max_size = max_size
        self.blur_threshold = blur_threshold  # 0 disables the blur filter
        self.stabilize = stabilize
        self.frame_paths: list[str] = []      # saved frame paths, in order
        self.orig_fps = 0.0                   # source fps, set by extract()
        self.width = 0                        # saved frame width, set by extract()
        self.height = 0                       # saved frame height, set by extract()

    # ------------------------------------------------------------------
    @staticmethod
    def _sample_interval(orig_fps: float, target_fps: Optional[float]) -> int:
        """Return the stride between kept source frames (always >= 1)."""
        if not target_fps or target_fps <= 0:
            return 1
        return max(1, round(orig_fps / target_fps))

    # ------------------------------------------------------------------
    @staticmethod
    def _iter_frames(cap):
        """Yield ``(index, frame)`` for every frame ``cap`` can still read."""
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                return
            yield idx, frame
            idx += 1

    # ------------------------------------------------------------------
    def _save_frame(self, frame) -> None:
        """Write ``frame`` as the next numbered JPEG and record its path.

        Raises:
            RuntimeError: if OpenCV reports the write failed.
        """
        path = os.path.join(self.output_dir, f"{len(self.frame_paths):05d}.jpg")
        if not cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 95]):
            raise RuntimeError(f"Cannot write frame: {path}")
        self.frame_paths.append(path)

    # ------------------------------------------------------------------
    def extract(self) -> int:
        """
        Extract frames into ``output_dir`` (deleted and recreated first).

        Sets ``orig_fps`` (30.0 when the container reports none), ``width``
        and ``height``, and refills ``frame_paths`` with ``00000.jpg``,
        ``00001.jpg``, … saved at JPEG quality 95.

        Returns:
            Number of frames saved. If the blur filter rejected every frame,
            the first source frame is saved instead (when one can be read).

        Raises:
            RuntimeError: the video cannot be opened, reports no frame size,
                or a frame cannot be written.
        """
        import shutil

        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)
        os.makedirs(self.output_dir)
        # The files behind any previous frame_paths were just deleted.
        self.frame_paths = []

        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")
        try:
            self.orig_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            raw_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            raw_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if raw_w <= 0 or raw_h <= 0:
                raise RuntimeError(
                    f"Cannot determine frame size of video: {self.video_path}")

            # Output resolution: fit the longer side into max_size, keep aspect.
            scale = min(1.0, self.max_size / max(raw_w, raw_h))
            self.width = int(raw_w * scale)
            self.height = int(raw_h * scale)
            sample_interval = self._sample_interval(self.orig_fps, self.target_fps)

            stab_diff = None
            if self.stabilize:
                stab_diff = self._compute_stabilization(cap, raw_w, raw_h)
                # The trajectory pass consumed the stream; reopen to start over.
                cap.release()
                cap = cv2.VideoCapture(self.video_path)

            for orig_idx, frame in self._iter_frames(cap):
                # --- sample at target fps ---
                if orig_idx % sample_interval != 0:
                    continue
                # --- apply stabilization warp ---
                if stab_diff is not None and orig_idx < len(stab_diff):
                    dx, dy, da = stab_diff[orig_idx]
                    M = np.array([[np.cos(da), -np.sin(da), dx],
                                  [np.sin(da), np.cos(da), dy]], dtype=np.float32)
                    frame = cv2.warpAffine(frame, M, (raw_w, raw_h))
                # --- resize ---
                if scale < 1.0:
                    frame = cv2.resize(frame, (self.width, self.height),
                                       interpolation=cv2.INTER_AREA)
                # --- blur filter (variance of the Laplacian as sharpness) ---
                if self.blur_threshold > 0:
                    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    if cv2.Laplacian(gray, cv2.CV_64F).var() < self.blur_threshold:
                        continue
                self._save_frame(frame)
        finally:
            cap.release()

        # Fallback: if blur filter ate everything, save at least 1 frame
        if not self.frame_paths:
            print("[WARN] All frames were blurry — saving 1 raw frame as fallback.")
            cap = cv2.VideoCapture(self.video_path)
            try:
                ret, frame = cap.read()
            finally:
                cap.release()
            if ret:
                if scale < 1.0:
                    # Same interpolation/quality as the main path.
                    frame = cv2.resize(frame, (self.width, self.height),
                                       interpolation=cv2.INTER_AREA)
                self._save_frame(frame)

        saved = len(self.frame_paths)
        print(f"[Extract] Saved {saved} frames → {self.output_dir}")
        return saved

    # ------------------------------------------------------------------
    @staticmethod
    def _smooth_correction(transforms, max_radius: int = 30) -> np.ndarray:
        """
        Return ``moving_average(trajectory) - trajectory`` per frame.

        ``transforms`` is a sequence of (dx, dy, da) rows; the trajectory is
        their cumulative sum. Each trajectory point is averaged over
        ±radius neighbours (window truncated at the ends), where
        radius = clamp(len // 2, 1, max_radius). Result shape is (N, 3).
        """
        traj = np.cumsum(np.asarray(transforms, dtype=np.float64).reshape(-1, 3), axis=0)
        n = len(traj)
        radius = max(1, min(max_radius, n // 2))
        smooth = np.empty_like(traj)
        for i in range(n):
            smooth[i] = traj[max(0, i - radius):i + radius + 1].mean(axis=0)
        return smooth - traj

    # ------------------------------------------------------------------
    def _compute_stabilization(self, cap, raw_w, raw_h):
        """
        Estimate camera motion with ORB and return a per-frame correction.

        Reads ``cap`` to the end. Motion (dx, dy, angle) between consecutive
        frames is estimated on copies whose longer side is 480 px;
        translations are rescaled to source pixels. Frames where estimation
        fails contribute zero motion. Returns an (N, 3) array of
        ``smoothed - raw`` trajectory corrections, one row per frame read.
        """
        print("[Stabilize] Computing ORB motion trajectory …")
        scale = 480.0 / max(raw_w, raw_h)
        small_size = (int(raw_w * scale), int(raw_h * scale))
        # Detector and matcher hold no per-frame state: build them once.
        orb = cv2.ORB_create(300)
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

        transforms = []
        prev_kp = prev_desc = None
        for _, frame in self._iter_frames(cap):
            gray = cv2.cvtColor(cv2.resize(frame, small_size), cv2.COLOR_BGR2GRAY)
            # Computed once per frame and reused as "previous" next iteration.
            kp, desc = orb.detectAndCompute(gray, None)
            dx = dy = da = 0.0
            if (prev_desc is not None and desc is not None
                    and len(prev_kp) > 5 and len(kp) > 5):
                matches = sorted(matcher.match(prev_desc, desc),
                                 key=lambda m: m.distance)[:50]
                if len(matches) >= 4:
                    pts1 = np.float32([prev_kp[m.queryIdx].pt for m in matches])
                    pts2 = np.float32([kp[m.trainIdx].pt for m in matches])
                    M, _ = cv2.estimateAffinePartial2D(pts1, pts2)
                    if M is not None:
                        dx = M[0, 2] / scale
                        dy = M[1, 2] / scale
                        da = np.arctan2(M[1, 0], M[0, 0])
            transforms.append((dx, dy, da))
            prev_kp, prev_desc = kp, desc
        return self._smooth_correction(transforms)
# ──────────────────────────────────────────────────────────
# DinoDetector — wraps Grounding DINO for prompt detection
# ──────────────────────────────────────────────────────────
class DinoDetector:
    """Loads Grounding DINO and runs chunked prompt detection with NMS."""

    CHUNK_SIZE = 15  # max vocabulary items per DINO call (avoids token overflow)

    def __init__(self, device: torch.device):
        self.device = device
        self.processor = None  # set by load()
        self.model = None      # set by load()

    def load(self):
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
        print("[DINO] Loading Grounding DINO Base …")
        self.processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-base"
        ).to(self.device).eval()
        print("[DINO] Loaded.")

    # ------------------------------------------------------------------
    def detect(self, image_path: str, prompt: str,
               box_threshold: float = 0.30,
               text_threshold: float = 0.25,
               iou_threshold: float = 0.45
               ) -> tuple[np.ndarray, np.ndarray, list[str]]:
        """
        Detect every prompt item in one image.

        The prompt may be multi-line, with ``#`` comment lines and items
        separated by commas or periods. Items are sent to DINO in batches of
        ``CHUNK_SIZE``; detections from all batches are merged with
        class-agnostic NMS at ``iou_threshold``.

        Returns:
            (boxes [N,4], scores [N], labels [N]) in pixel coords. An empty
            prompt returns empty results without touching the image.

        Raises:
            RuntimeError: if ``load()`` has not been called.
        """
        items = self._parse_prompt(prompt)
        if not items:
            return np.empty((0, 4)), np.array([]), []
        if self.model is None or self.processor is None:
            raise RuntimeError("Grounding DINO is not loaded — call load() first.")

        # convert() returns an independent image, so the file can close now.
        with Image.open(image_path) as img:
            image_pil = img.convert("RGB")

        # Split vocabulary into chunks of CHUNK_SIZE
        chunks = [items[i:i + self.CHUNK_SIZE]
                  for i in range(0, len(items), self.CHUNK_SIZE)]

        all_boxes, all_scores, all_labels = [], [], []
        for idx, chunk in enumerate(chunks):
            chunk_text = " . ".join(chunk) + " ."
            print(f" [DINO] chunk {idx+1}/{len(chunks)}: {chunk_text[:80]}…")
            inputs = self.processor(
                images=image_pil, text=chunk_text, return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            results = self._post_process(outputs, inputs.input_ids, image_pil,
                                         box_threshold, text_threshold)
            all_boxes.extend(results["boxes"].cpu().numpy())
            all_scores.extend(results["scores"].cpu().numpy())
            # Newer transformers put the label strings under "text_labels"
            # ("labels" is deprecated there); older ones only have "labels".
            if "text_labels" in results:
                all_labels.extend(results["text_labels"])
            else:
                all_labels.extend(results["labels"])

        if not all_boxes:
            return np.empty((0, 4)), np.array([]), []

        all_boxes = np.array(all_boxes)
        all_scores = np.array(all_scores)
        keep = self._nms(all_boxes, all_scores, iou_threshold)
        return all_boxes[keep], all_scores[keep], [all_labels[k] for k in keep]

    # ------------------------------------------------------------------
    def _post_process(self, outputs, input_ids, image_pil, box_thresh, text_thresh):
        """Post-process one image's DINO outputs across transformers versions.

        The box-score keyword is ``box_threshold`` in older transformers
        releases and ``threshold`` in newer ones; a TypeError from the first
        spelling triggers a retry with the second.
        """
        target_sizes = [image_pil.size[::-1]]  # PIL size is (W, H); HF wants (H, W)
        try:
            return self.processor.post_process_grounded_object_detection(
                outputs, input_ids,
                box_threshold=box_thresh, text_threshold=text_thresh,
                target_sizes=target_sizes
            )[0]
        except TypeError:
            return self.processor.post_process_grounded_object_detection(
                outputs, input_ids,
                threshold=box_thresh, text_threshold=text_thresh,
                target_sizes=target_sizes
            )[0]

    # ------------------------------------------------------------------
    @staticmethod
    def _parse_prompt(prompt: str) -> list[str]:
        """Split a prompt into unique, order-preserving vocabulary items.

        Blank lines and lines starting with ``#`` are ignored; ``,`` and ``.``
        both separate items.
        """
        items = []
        for line in prompt.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            for part in line.replace(".", ",").split(","):
                p = part.strip()
                if p:
                    items.append(p)
        # dict preserves insertion order: first occurrence wins.
        return list(dict.fromkeys(items))

    # ------------------------------------------------------------------
    @staticmethod
    def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
        """Greedy class-agnostic NMS; returns kept indices, best score first."""
        if len(boxes) == 0:
            return []
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(int(i))
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1)
            h = np.maximum(0.0, yy2 - yy1)
            iou = (w * h) / (areas[i] + areas[order[1:]] - w * h + 1e-6)
            order = order[np.where(iou <= iou_thresh)[0] + 1]
        return keep
# ──────────────────────────────────────────────────────────
# SAM2Tracker — sliding-window SAM 2 tracking engine
# ──────────────────────────────────────────────────────────
class SAM2Tracker:
    """
    Sliding-window SAM 2 video tracker:
      1. Sliding-window propagation — SAM 2 is re-initialised for each chunk
         of ``chunk_size`` frames from every object's last known box, which
         keeps its memory bank small.
      2. Lost-object detection — an object is lost when its mask on the last
         frame of a chunk has fewer than MIN_MASK_AREA pixels.
      3. DINO re-anchoring — lost objects are matched against Grounding DINO
         detections on the chunk's last frame and re-enter the next chunk.
      4. Bbox-only rendering — boxes + labels are drawn, no mask overlay.
    """

    # Palette — visually distinct colors (BGR)
    PALETTE = [
        (255, 80, 80),   # blue-ish
        (80, 220, 80),   # green
        (80, 80, 255),   # red
        (0, 220, 220),   # yellow
        (220, 0, 220),   # magenta
        (220, 220, 0),   # cyan
        (255, 160, 0),   # orange
        (160, 0, 200),   # purple
        (0, 180, 180),   # teal
        (0, 140, 255),   # gold
        (180, 255, 0),   # lime
        (255, 0, 150),   # pink
    ]

    def __init__(self,
                 sam2_checkpoint: str = SAM2_CKPT_DEFAULT,
                 sam2_cfg: str = SAM2_CFG_DEFAULT,
                 device: Optional[torch.device] = None,
                 chunk_size: int = CHUNK_SIZE_DEFAULT):
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.sam2_checkpoint = sam2_checkpoint
        self.sam2_cfg = sam2_cfg
        self.chunk_size = chunk_size
        self.predictor = None  # set by load()

    # ------------------------------------------------------------------
    def load(self):
        from sam2.build_sam import build_sam2_video_predictor
        print(f"[SAM2] Loading predictor (device={self.device}) …")
        self.predictor = build_sam2_video_predictor(
            self.sam2_cfg, self.sam2_checkpoint, device=self.device
        )
        print("[SAM2] Loaded.")

    # ------------------------------------------------------------------
    def track_video(self,
                    frame_store: VideoFrameStore,
                    tracked_objects: list[TrackedObject],
                    dino: DinoDetector,
                    prompt: str,
                    box_threshold: float = 0.30,
                    text_threshold: float = 0.25,
                    iou_threshold: float = 0.45,
                    output_path: str = "output.mp4",
                    progress_cb=None) -> list[str]:
        """
        Run sliding-window SAM 2 tracking and write the annotated video.

        Per chunk: propagate SAM 2 from the objects' current boxes, mark
        objects missing on the chunk's last frame as lost, try to re-anchor
        lost objects with DINO on that frame, then render and write the chunk.

        Args:
            frame_store: extracted frames (``frame_paths``, ``width``,
                ``height``, ``orig_fps``, ``target_fps``, ``output_dir``).
            tracked_objects: objects to track; ``box``, ``lost``,
                ``lost_frames`` and the crop fields are updated in place.
            dino: loaded detector used for recovery.
            prompt: DINO prompt used for recovery.
            box_threshold, text_threshold, iou_threshold: forwarded to
                ``dino.detect``; ``iou_threshold`` also ranks recovery matches.
            output_path: destination video file.
            progress_cb: optional ``callable(done_frames, total_frames)``.

        Returns:
            The label of every tracked object, in input order.

        Raises:
            RuntimeError: no frames, no objects, predictor not loaded, or the
                output video cannot be opened.
        """
        import shutil

        frame_paths = frame_store.frame_paths
        total = len(frame_paths)
        W, H = frame_store.width, frame_store.height
        if total == 0:
            raise RuntimeError("No frames to track!")
        if not tracked_objects:
            raise RuntimeError("No objects to track — run DINO detection first.")
        if self.predictor is None:
            raise RuntimeError("SAM 2 predictor is not loaded — call load() first.")

        writer = self._open_writer(output_path, self._output_fps(frame_store), W, H)

        chunk_starts = list(range(0, total, self.chunk_size))
        n_chunks = len(chunk_starts)
        print(f"\n[Track] {total} frames · {n_chunks} chunk(s) · "
              f"chunk_size={self.chunk_size}")
        # Chunk dirs are siblings of the frames dir; abspath() drops a trailing
        # separator so they never land inside the frames dir itself.
        chunk_root = os.path.dirname(os.path.abspath(frame_store.output_dir))

        try:
            for c_num, chunk_start in enumerate(chunk_starts):
                chunk_end = min(chunk_start + self.chunk_size, total)
                chunk_paths = frame_paths[chunk_start:chunk_end]
                chunk_len = len(chunk_paths)
                print(f"\n[Chunk {c_num+1}/{n_chunks}] "
                      f"frames {chunk_start}{chunk_end-1} ({chunk_len} frames)")

                # Where each object was last seen, as of this chunk's start;
                # updated while rendering and used for the faded "[?]" box.
                shown_boxes = {o.obj_id: o.box.copy() for o in tracked_objects}

                chunk_dir = os.path.join(chunk_root, f"_chunk_{c_num:04d}")
                try:
                    self._link_chunk_frames(chunk_paths, chunk_dir)
                    chunk_masks, last_box = self._propagate_chunk(
                        chunk_dir, tracked_objects, H, W,
                        chunk_start, total, progress_cb)
                finally:
                    # The chunk's SAM 2 state is gone by now; drop the links
                    # and release cached device memory before DINO runs.
                    shutil.rmtree(chunk_dir, ignore_errors=True)
                    self._empty_device_cache()

                self._mark_lost(tracked_objects, chunk_masks, last_box, chunk_len)
                self._recover_lost(tracked_objects, dino, chunk_paths[-1], chunk_end - 1,
                                   prompt, box_threshold, text_threshold, iou_threshold)
                self._render_chunk(writer, chunk_paths, chunk_masks,
                                   tracked_objects, shown_boxes, W, H)

                if progress_cb:
                    progress_cb(chunk_end, total)
        finally:
            writer.release()

        print(f"[Done] Saved: {os.path.abspath(output_path)}")
        return [o.label for o in tracked_objects]

    # ------------------------------------------------------------------
    @staticmethod
    def _output_fps(frame_store) -> float:
        """Frame rate the extracted frames actually represent.

        Extraction keeps every ``round(orig / target)``-th frame, so the real
        rate is ``orig / stride`` rather than ``target_fps`` itself; writing at
        ``target_fps`` would change the video's duration.
        """
        orig = frame_store.orig_fps or 30.0
        target = frame_store.target_fps
        if target and target > 0:
            return orig / max(1, round(orig / target))
        return orig

    # ------------------------------------------------------------------
    @staticmethod
    def _open_writer(output_path: str, fps: float, W: int, H: int):
        """Open an H.264 (avc1) writer, falling back to mp4v.

        Raises:
            RuntimeError: if neither codec can be opened.
        """
        # Use browser-friendly H.264 (avc1) codec if possible, fallback to mp4v
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"avc1"), fps, (W, H))
        if not writer.isOpened():
            print("[WARN] avc1 codec not opened, falling back to mp4v.")
            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (W, H))
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open video writer: {output_path}")
        return writer

    # ------------------------------------------------------------------
    @staticmethod
    def _link_chunk_frames(chunk_paths: list[str], chunk_dir: str) -> None:
        """(Re)create ``chunk_dir`` with ``00000.jpg``… pointing at ``chunk_paths``.

        Symlinks avoid copying; files are copied where symlinking fails.
        """
        import shutil

        if os.path.exists(chunk_dir):
            shutil.rmtree(chunk_dir)
        os.makedirs(chunk_dir)
        for local_i, src in enumerate(chunk_paths):
            dst = os.path.join(chunk_dir, f"{local_i:05d}.jpg")
            try:
                os.symlink(os.path.abspath(src), dst)
            except (OSError, NotImplementedError):
                shutil.copy2(src, dst)

    # ------------------------------------------------------------------
    @staticmethod
    def _empty_device_cache() -> None:
        """Run the GC and release cached CUDA/MPS allocator memory (limits OOMs)."""
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch, 'mps') and torch.backends.mps.is_available():
            torch.mps.empty_cache()

    # ------------------------------------------------------------------
    def _propagate_chunk(self, chunk_dir, tracked_objects, H, W,
                         frame_offset, total, progress_cb=None):
        """
        Run SAM 2 over one chunk, prompting each non-lost object with its box
        on local frame 0.

        Returns:
            (chunk_masks, last_box): ``chunk_masks[local_idx][obj_id]`` is a
            boolean mask; ``last_box[obj_id]`` is the [x1, y1, x2, y2] box of
            that object's most recent mask with >= MIN_MASK_AREA pixels. Both
            are empty when no object can be registered — SAM 2 is then not
            run, because it refuses to propagate without any prompt.
        """
        import contextlib

        chunk_masks: dict[int, dict[int, np.ndarray]] = {}
        last_box: dict[int, np.ndarray] = {}

        active = []
        for obj in tracked_objects:
            if obj.lost:
                print(f" [SKIP] id={obj.obj_id} '{obj.label}' is lost, "
                      f"will try DINO recovery after this chunk.")
            else:
                active.append(obj)
        if not active:
            print(" Registered 0 objects at chunk start — skipping SAM 2 for this chunk.")
            return chunk_masks, last_box

        if "cuda" in str(self.device):
            amp = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            # CPU autocast does not support float32 (torch warns and disables
            # it), so other devices simply run in default precision.
            amp = contextlib.nullcontext()
        # On macOS (MPS), unified memory makes CPU offloading slow. Disable it.
        offload_video = "mps" not in str(self.device)

        with torch.inference_mode(), amp:
            state = self.predictor.init_state(
                video_path=chunk_dir,
                offload_video_to_cpu=offload_video,
                offload_state_to_cpu=False,
            )
            try:
                self.predictor.reset_state(state)
                for obj in active:
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=0,  # always local frame 0 of chunk
                        obj_id=obj.obj_id,
                        box=obj.box,
                    )
                print(f" Registered {len(active)} objects at chunk start.")

                for local_idx, obj_ids, mask_logits in \
                        self.predictor.propagate_in_video(state):
                    frame_masks: dict[int, np.ndarray] = {}
                    for i, obj_id in enumerate(obj_ids):
                        oid = int(obj_id)
                        mask = (mask_logits[i] > 0.0).cpu().numpy().squeeze()
                        if mask.ndim == 0:
                            mask = np.zeros((H, W), dtype=bool)
                        frame_masks[oid] = mask
                        if mask.sum() >= MIN_MASK_AREA:
                            ys, xs = np.where(mask)
                            last_box[oid] = np.array(
                                [xs.min(), ys.min(), xs.max(), ys.max()],
                                dtype=np.float32
                            )
                    chunk_masks[local_idx] = frame_masks
                    if progress_cb:
                        progress_cb(frame_offset + local_idx + 1, total)
            finally:
                self.predictor.reset_state(state)
        return chunk_masks, last_box

    # ------------------------------------------------------------------
    @staticmethod
    def _mark_lost(tracked_objects, chunk_masks, last_box, chunk_len) -> None:
        """Refresh each object's box from the chunk's last frame, else mark it lost."""
        final_masks = chunk_masks.get(chunk_len - 1, {})
        for obj in tracked_objects:
            mask = final_masks.get(obj.obj_id)
            if (mask is not None and mask.sum() >= MIN_MASK_AREA
                    and obj.obj_id in last_box):
                obj.box = last_box[obj.obj_id]
                obj.lost = False
                obj.lost_frames = 0
            else:
                obj.lost = True
                obj.lost_frames += chunk_len
                print(f" [LOST] id={obj.obj_id} '{obj.label}' — not visible at chunk end.")

    # ------------------------------------------------------------------
    def _recover_lost(self, tracked_objects, dino, frame_path, frame_idx, prompt,
                      box_threshold, text_threshold, iou_threshold) -> None:
        """Re-anchor lost objects on DINO detections in ``frame_path``."""
        lost_objects = [o for o in tracked_objects if o.lost]
        if not lost_objects:
            return
        print(f" [Recovery] Running DINO on frame {frame_idx} "
              f"for {len(lost_objects)} lost object(s) …")
        boxes, _scores, labels = dino.detect(
            frame_path, prompt,
            box_threshold, text_threshold, iou_threshold
        )
        recovered = self._match_lost_to_dino(lost_objects, boxes, labels, iou_threshold)
        for obj in lost_objects:
            if obj.obj_id not in recovered:
                continue
            obj.box = recovered[obj.obj_id]
            obj.lost = False
            obj.lost_frames = 0
            print(f" [Recovered] id={obj.obj_id} '{obj.label}' "
                  f"at chunk boundary.")

    # ------------------------------------------------------------------
    def _render_chunk(self, writer, chunk_paths, chunk_masks, tracked_objects,
                      shown_boxes, W, H) -> None:
        """Annotate and write every frame of one chunk.

        Best crops are taken before anything is drawn on the frame.
        ``shown_boxes`` is updated in place with each visible object's box.
        """
        for local_idx, path in enumerate(chunk_paths):
            frame = cv2.imread(path)
            readable = frame is not None
            if not readable:
                # Keep the output's frame count (and timing) intact.
                print(f" [WARN] Could not read {path} — writing a blank frame.")
                frame = np.zeros((H, W, 3), dtype=np.uint8)
            masks_here = chunk_masks.get(local_idx, {})

            for obj in tracked_objects:
                mask = masks_here.get(obj.obj_id)
                if mask is None:
                    continue
                area = int(mask.sum())
                if area < MIN_MASK_AREA:
                    continue
                ys, xs = np.where(mask)
                box = np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32)
                shown_boxes[obj.obj_id] = box
                if readable:
                    self._update_best_crop(obj, frame, box, area)

            frame = self._draw_frame(frame, masks_here, tracked_objects,
                                     fallback_boxes=shown_boxes)
            writer.write(frame)

    # ------------------------------------------------------------------
    @staticmethod
    def _update_best_crop(obj, frame, box, area) -> None:
        """Keep on ``obj`` the un-annotated crop taken where its mask was largest."""
        if not hasattr(obj, 'max_mask_area'):
            obj.max_mask_area = 0
            obj.best_crop = None
        if area <= obj.max_mask_area:
            return
        obj.max_mask_area = area
        h_f, w_f = frame.shape[:2]
        bx1, bx2 = max(0, int(box[0])), min(w_f - 1, int(box[2]))
        by1, by2 = max(0, int(box[1])), min(h_f - 1, int(box[3]))
        if bx2 > bx1 and by2 > by1:
            obj.best_crop = frame[by1:by2 + 1, bx1:bx2 + 1].copy()

    # ------------------------------------------------------------------
    def _draw_frame(self,
                    frame: np.ndarray,
                    masks: dict[int, np.ndarray],
                    tracked_objects: list[TrackedObject],
                    fallback_boxes: Optional[dict[int, np.ndarray]] = None) -> np.ndarray:
        """
        Draw bounding box + label + ID for every tracked object (in place).

        Visible objects (mask >= MIN_MASK_AREA px) get a solid box fitted to
        the mask. Others get a thin box and faded "[?]" label at
        ``fallback_boxes[obj_id]`` when present, else at ``obj.box``.
        No pixel mask overlay. ``None`` frames are returned unchanged.
        """
        if frame is None:
            return frame
        for obj in tracked_objects:
            oid = obj.obj_id
            color = self.PALETTE[oid % len(self.PALETTE)]
            mask = masks.get(oid)
            if mask is None or mask.sum() < MIN_MASK_AREA:
                # Object not visible this frame — faded indicator at the last
                # known box location.
                box = obj.box
                if fallback_boxes is not None and oid in fallback_boxes:
                    box = fallback_boxes[oid]
                x1, y1, x2, y2 = (int(v) for v in box)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)
                self._put_label(frame, f"{obj.label} #{oid} [?]",
                                x1, y1, color, alpha=0.4)
                continue
            # Derive tight bounding box from the mask
            ys, xs = np.where(mask)
            bx1, bx2 = int(xs.min()), int(xs.max())
            by1, by2 = int(ys.min()), int(ys.max())
            cv2.rectangle(frame, (bx1, by1), (bx2, by2), color, 2)
            self._put_label(frame, f"{obj.label} #{oid}", bx1, by1, color)
        return frame

    # ------------------------------------------------------------------
    @staticmethod
    def _put_label(frame: np.ndarray, text: str,
                   x: int, y: int, color: tuple,
                   alpha: float = 1.0):
        """Draw ``text`` on a filled ``color`` tag whose bottom-left is (x, y).

        ``alpha`` < 0.9 blends the tag background with the frame.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        scale = 0.5
        thickness = 1
        (tw, th), _ = cv2.getTextSize(text, font, scale, thickness)
        pad = 4
        bkg_y1 = max(0, y - th - pad * 2)
        bkg_y2 = y
        bkg_x2 = x + tw + pad * 2
        # Background rectangle
        if alpha >= 0.9:
            cv2.rectangle(frame, (x, bkg_y1), (bkg_x2, bkg_y2), color, -1)
        else:
            overlay = frame.copy()
            cv2.rectangle(overlay, (x, bkg_y1), (bkg_x2, bkg_y2), color, -1)
            cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)
        # Text
        cv2.putText(frame, text, (x + pad, y - pad),
                    font, scale, (255, 255, 255), thickness, cv2.LINE_AA)

    # ------------------------------------------------------------------
    @staticmethod
    def _match_lost_to_dino(lost_objects: list[TrackedObject],
                            dino_boxes: np.ndarray,
                            dino_labels: list[str],
                            iou_threshold: float = 0.20
                            ) -> dict[int, np.ndarray]:
        """
        Greedily assign DINO detections to lost objects, in list order.

        A detection is a candidate for an object when the labels match as
        case-insensitive substrings in either direction; empty detection
        labels never match. Candidates whose IoU with the object's last known
        box is >= ``iou_threshold`` win (highest IoU first). Otherwise the
        candidate whose centre is closest to the last known box centre wins.
        Each detection is used at most once.

        Returns {obj_id: new_box} with float32 copies of the boxes.
        """
        recovered: dict[int, np.ndarray] = {}
        used_dino: set[int] = set()
        for obj in lost_objects:
            name = str(obj.label).lower()
            x1, y1, x2, y2 = (float(v) for v in obj.box)
            cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
            obj_area = (x2 - x1) * (y2 - y1)

            best_idx, best_key = None, None
            for d_idx, (d_box, d_label) in enumerate(zip(dino_boxes, dino_labels)):
                if d_idx in used_dino:
                    continue
                d_name = str(d_label).lower()
                # "" is a substring of everything — never treat it as a match.
                if not d_name or (name not in d_name and d_name not in name):
                    continue
                dx1, dy1, dx2, dy2 = (float(v) for v in d_box)
                iw = max(0.0, min(x2, dx2) - max(x1, dx1))
                ih = max(0.0, min(y2, dy2) - max(y1, dy1))
                inter = iw * ih
                iou = inter / (obj_area + (dx2 - dx1) * (dy2 - dy1) - inter + 1e-6)
                if iou >= iou_threshold:
                    key = (1, iou)
                else:
                    dist2 = ((dx1 + dx2) / 2.0 - cx) ** 2 + ((dy1 + dy2) / 2.0 - cy) ** 2
                    key = (0, -dist2)
                if best_key is None or key > best_key:
                    best_idx, best_key = d_idx, key

            if best_idx is not None:
                recovered[obj.obj_id] = np.array(dino_boxes[best_idx], dtype=np.float32)
                used_dino.add(best_idx)
        return recovered