"""
SAM 2 Robust Object Tracker
============================

Architecture:
  - Grounding DINO  → detects WHAT to track (frame 0 + recovery)
  - SAM 2 Video     → tracks WHERE the object is across frames
  - Sliding Window  → processes video in chunks so SAM 2 memory stays accurate
  - Recovery Loop   → if an object disappears, DINO relocalizes it

Output: Bounding Box + Label + ID per frame (no mask overlay, fast + low VRAM)
"""

import contextlib
import gc
import os
import shutil
from typing import Callable, Optional

import cv2
import numpy as np
import torch
from PIL import Image

# ──────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────
MIN_MASK_AREA = 64          # pixels² — mask smaller than this = "lost"
CHUNK_SIZE_DEFAULT = 120    # frames per SAM-2 sliding window

HOME_DIR = os.path.expanduser("~")
SAM2_CKPT_DEFAULT = os.path.join(
    HOME_DIR, ".cache", "torch", "hub", "checkpoints", "sam2.1_hiera_small.pt"
)
SAM2_CFG_DEFAULT = "configs/sam2.1/sam2.1_hiera_s.yaml"


# ──────────────────────────────────────────────────────────
# TrackedObject — stores per-object state between chunks
# ──────────────────────────────────────────────────────────
class TrackedObject:
    """Holds identity, last known bounding box and best crop for one tracked object."""

    def __init__(self, obj_id: int, label: str, box: np.ndarray):
        self.obj_id = obj_id
        self.label = label
        # [x1, y1, x2, y2] in frame pixels. np.array copies, so later updates never
        # alias the caller's array; lists/tuples of 4 numbers are accepted as well.
        self.box = np.array(box, dtype=np.float32).reshape(4)
        self.lost = False         # True if not visible at the end of the last chunk
        self.lost_frames = 0      # frames accumulated while lost (added per whole chunk)
        self.max_mask_area = 0    # largest mask area seen so far, in pixels
        self.best_crop: Optional[np.ndarray] = None  # BGR crop taken at max_mask_area

    def __repr__(self):
        return f"TrackedObject(id={self.obj_id}, label='{self.label}', lost={self.lost})"


# ──────────────────────────────────────────────────────────
# VideoFrameStore — thin wrapper around a frames directory
# ──────────────────────────────────────────────────────────
class VideoFrameStore:
    """Extracts video frames to disk with optional stabilization, blur filter, resize."""

    def __init__(self, video_path: str, output_dir: str,
                 target_fps: Optional[float] = None, max_size: int = 720,
                 blur_threshold: float = 0.0, stabilize: bool = False):
        self.video_path = video_path
        self.output_dir = output_dir
        self.target_fps = target_fps
        self.max_size = max_size
        self.blur_threshold = blur_threshold
        self.stabilize = stabilize

        self.frame_paths: list[str] = []  # sorted list of extracted frame paths
        self.orig_fps = 0.0
        # Frame rate the sampled frames actually represent (orig_fps / sample
        # interval). Set by extract(); may differ from target_fps due to rounding.
        self.effective_fps = 0.0
        self.width = 0
        self.height = 0

    # ------------------------------------------------------------------
    def extract(self) -> int:
        """Extract frames into ``output_dir`` and return how many were saved.

        ``output_dir`` is wiped first. Every ``round(orig_fps / target_fps)``-th
        source frame is kept, optionally stabilized, downscaled so its longer side
        is at most ``max_size``, and dropped if its Laplacian variance is below
        ``blur_threshold``. If the blur filter rejects everything, the first raw
        frame is saved as a fallback.

        Raises:
            RuntimeError: if the video cannot be opened or its size cannot be read.
        """
        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)
        os.makedirs(self.output_dir)
        # Reset so that calling extract() twice does not duplicate paths.
        self.frame_paths = []

        self.orig_fps, raw_w, raw_h = self._probe()

        # Output resolution (keep aspect ratio, never upscale)
        scale = min(1.0, self.max_size / max(raw_w, raw_h))
        self.width = int(raw_w * scale)
        self.height = int(raw_h * scale)

        # How many original frames to skip between each saved frame
        if self.target_fps and self.target_fps > 0:
            sample_interval = max(1, round(self.orig_fps / self.target_fps))
        else:
            sample_interval = 1
        self.effective_fps = self.orig_fps / sample_interval

        stab_diff = None
        if self.stabilize:
            stab_cap = cv2.VideoCapture(self.video_path)
            try:
                stab_diff = self._compute_stabilization(stab_cap, raw_w, raw_h)
            finally:
                stab_cap.release()

        saved = 0
        cap = cv2.VideoCapture(self.video_path)
        try:
            orig_idx = -1
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                orig_idx += 1

                # --- sample at target fps ---
                if orig_idx % sample_interval != 0:
                    continue

                # --- apply stabilization warp ---
                if stab_diff is not None and orig_idx < len(stab_diff):
                    dx, dy, da = stab_diff[orig_idx]
                    warp = np.array([[np.cos(da), -np.sin(da), dx],
                                     [np.sin(da), np.cos(da), dy]], dtype=np.float32)
                    frame = cv2.warpAffine(frame, warp, (raw_w, raw_h))

                # --- resize ---
                if scale < 1.0:
                    frame = cv2.resize(frame, (self.width, self.height),
                                       interpolation=cv2.INTER_AREA)

                # --- blur filter ---
                if self.blur_threshold > 0:
                    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    if cv2.Laplacian(gray, cv2.CV_64F).var() < self.blur_threshold:
                        continue

                path = os.path.join(self.output_dir, f"{saved:05d}.jpg")
                cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
                self.frame_paths.append(path)
                saved += 1
        finally:
            cap.release()

        # Fallback: if blur filter ate everything, save at least 1 frame
        if saved == 0:
            print("[WARN] All frames were blurry — saving 1 raw frame as fallback.")
            cap = cv2.VideoCapture(self.video_path)
            try:
                ret, frame = cap.read()
            finally:
                cap.release()
            if ret:
                if scale < 1.0:
                    frame = cv2.resize(frame, (self.width, self.height),
                                       interpolation=cv2.INTER_AREA)
                path = os.path.join(self.output_dir, "00000.jpg")
                cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
                self.frame_paths.append(path)
                saved = 1

        print(f"[Extract] Saved {saved} frames → {self.output_dir}")
        return saved

    # ------------------------------------------------------------------
    def _probe(self) -> tuple[float, int, int]:
        """Return (fps, width, height) of the source video.

        Falls back to 30 fps when the container reports none, and to the first
        decoded frame's shape when it reports no size.
        """
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if width <= 0 or height <= 0:
                ok, frame = cap.read()
                if not ok:
                    raise RuntimeError(f"Cannot read frames from video: {self.video_path}")
                height, width = frame.shape[:2]
        finally:
            cap.release()
        return fps, width, height

    # ------------------------------------------------------------------
    def _compute_stabilization(self, cap, raw_w, raw_h):
        """ORB-based motion estimation → smoothed correction per frame.

        Reads every remaining frame of ``cap`` and returns an array with one row
        ``(dx, dy, d_angle)`` per frame: the offset that moves the cumulative
        camera trajectory onto its moving average (window radius ≤ 30 frames).
        dx/dy are in original-resolution pixels, d_angle in radians.
        """
        print("[Stabilize] Computing ORB motion trajectory …")
        scale = 480.0 / max(raw_w, raw_h)
        small_size = (int(raw_w * scale), int(raw_h * scale))
        # Detector and matcher are stateless between calls — build them once.
        orb = cv2.ORB_create(300)
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

        transforms = []
        prev_kp = prev_desc = None
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(cv2.resize(frame, small_size), cv2.COLOR_BGR2GRAY)
            # Features of this frame are reused as "previous" for the next one.
            kp, desc = orb.detectAndCompute(gray, None)

            dx = dy = da = 0.0
            if (prev_desc is not None and desc is not None
                    and len(prev_kp) > 5 and len(kp) > 5):
                matches = sorted(matcher.match(prev_desc, desc),
                                 key=lambda m: m.distance)[:50]
                if len(matches) >= 4:
                    pts1 = np.float32([prev_kp[m.queryIdx].pt for m in matches])
                    pts2 = np.float32([kp[m.trainIdx].pt for m in matches])
                    affine, _ = cv2.estimateAffinePartial2D(pts1, pts2)
                    if affine is not None:
                        dx = affine[0, 2] / scale
                        dy = affine[1, 2] / scale
                        da = np.arctan2(affine[1, 0], affine[0, 0])
            transforms.append((dx, dy, da))
            prev_kp, prev_desc = kp, desc

        if not transforms:
            return np.zeros((0, 3))

        traj = np.cumsum(np.asarray(transforms, dtype=np.float64), axis=0)
        radius = max(1, min(30, len(traj) // 2))
        smooth = np.empty_like(traj)
        for i in range(len(traj)):
            lo, hi = max(0, i - radius), min(len(traj), i + radius + 1)
            smooth[i] = traj[lo:hi].mean(axis=0)
        return smooth - traj  # correction per frame


# ──────────────────────────────────────────────────────────
# DinoDetector — wraps Grounding DINO for prompt detection
# ──────────────────────────────────────────────────────────
class DinoDetector:
    """Loads Grounding DINO and runs chunked prompt detection with NMS."""

    CHUNK_SIZE = 15  # max vocabulary items per DINO call (avoids token overflow)

    def __init__(self, device: torch.device):
        self.device = device
        self.processor = None
        self.model = None

    def load(self):
        """Load the Grounding DINO Base processor and model onto ``self.device``."""
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

        print("[DINO] Loading Grounding DINO Base …")
        self.processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-base"
        ).to(self.device).eval()
        print("[DINO] Loaded.")

    # ------------------------------------------------------------------
    def detect(self, image_path: str, prompt: str,
               box_threshold: float = 0.30, text_threshold: float = 0.25,
               iou_threshold: float = 0.45
               ) -> tuple[np.ndarray, np.ndarray, list[str]]:
        """Detect prompt items in one image.

        Prompt can be a multi-line string with ``#`` comments and comma- or
        dot-separated items. The vocabulary is split into groups of
        ``CHUNK_SIZE`` items, each group is run separately, and the merged
        results go through class-agnostic NMS at ``iou_threshold``.

        Returns:
            (boxes [N, 4], scores [N], labels [N]) with boxes in pixel coords,
            ordered by descending score.
        """
        with Image.open(image_path) as img:
            image_pil = img.convert("RGB")

        items = self._parse_prompt(prompt)
        if not items:
            return np.empty((0, 4)), np.array([]), []

        chunks = [items[i:i + self.CHUNK_SIZE]
                  for i in range(0, len(items), self.CHUNK_SIZE)]

        all_boxes, all_scores, all_labels = [], [], []
        for idx, chunk in enumerate(chunks):
            chunk_text = " . ".join(chunk) + " ."
            print(f" [DINO] chunk {idx+1}/{len(chunks)}: {chunk_text[:80]}…")
            inputs = self.processor(
                images=image_pil, text=chunk_text, return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            results = self._post_process(outputs, inputs.input_ids, image_pil,
                                         box_threshold, text_threshold)
            # Newer transformers return label strings under "text_labels" and
            # deprecate "labels"; older releases only provide "labels".
            labels = results["text_labels"] if "text_labels" in results else results["labels"]
            all_boxes.extend(results["boxes"].cpu().numpy())
            all_scores.extend(results["scores"].cpu().numpy())
            all_labels.extend(str(label) for label in labels)

        if not all_boxes:
            return np.empty((0, 4)), np.array([]), []

        boxes = np.asarray(all_boxes, dtype=np.float32)
        scores = np.asarray(all_scores, dtype=np.float32)
        keep = self._nms(boxes, scores, iou_threshold)
        return boxes[keep], scores[keep], [all_labels[k] for k in keep]

    # ------------------------------------------------------------------
    def _post_process(self, outputs, input_ids, image_pil, box_thresh, text_thresh):
        """Call the processor's grounded post-processing across transformers versions.

        Older releases name the box threshold ``box_threshold``; newer ones
        renamed it to ``threshold`` (the old keyword then raises TypeError).
        """
        target_sizes = [image_pil.size[::-1]]  # PIL size is (W, H); HF expects (H, W)
        try:
            return self.processor.post_process_grounded_object_detection(
                outputs, input_ids,
                box_threshold=box_thresh, text_threshold=text_thresh,
                target_sizes=target_sizes,
            )[0]
        except TypeError:
            return self.processor.post_process_grounded_object_detection(
                outputs, input_ids,
                threshold=box_thresh, text_threshold=text_thresh,
                target_sizes=target_sizes,
            )[0]

    # ------------------------------------------------------------------
    @staticmethod
    def _parse_prompt(prompt: str) -> list[str]:
        """Split a prompt into unique items, keeping first-seen order.

        Blank lines and lines starting with ``#`` are ignored; items are
        separated by commas or dots and stripped of surrounding whitespace.
        """
        items = []
        for line in prompt.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            items.extend(p.strip() for p in line.replace(".", ",").split(",") if p.strip())
        return list(dict.fromkeys(items))  # order-preserving de-duplication

    # ------------------------------------------------------------------
    @staticmethod
    def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
        """Greedy class-agnostic NMS; returns kept indices by descending score."""
        if len(boxes) == 0:
            return []
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(int(i))
            rest = order[1:]
            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])
            inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
            iou = inter / (areas[i] + areas[rest] - inter + 1e-6)
            order = rest[iou <= iou_thresh]
        return keep


# ──────────────────────────────────────────────────────────
# SAM2Tracker — sliding-window SAM 2 tracking engine
# ──────────────────────────────────────────────────────────
class SAM2Tracker:
    """
    Proper SAM 2 video tracker with:
      1. Sliding-window propagation   — keeps memory bank fresh
      2. Automatic lost-object detection — mask area < MIN_MASK_AREA
      3. DINO re-anchor on lost objects  — relocalizes using text prompt
      4. Bbox-only rendering            — fast, VRAM-friendly
    """

    # Palette — visually distinct colors (BGR)
    PALETTE = [
        (255,  80,  80),   # blue-ish
        ( 80, 220,  80),   # green
        ( 80,  80, 255),   # red
        (  0, 220, 220),   # yellow
        (220,   0, 220),   # magenta
        (220, 220,   0),   # cyan
        (255, 160,   0),   # orange
        (160,   0, 200),   # purple
        (  0, 180, 180),   # teal
        (  0, 140, 255),   # gold
        (180, 255,   0),   # lime
        (255,   0, 150),   # pink
    ]

    def __init__(self, sam2_checkpoint: str = SAM2_CKPT_DEFAULT,
                 sam2_cfg: str = SAM2_CFG_DEFAULT,
                 device: Optional[torch.device] = None,
                 chunk_size: int = CHUNK_SIZE_DEFAULT):
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.sam2_checkpoint = sam2_checkpoint
        self.sam2_cfg = sam2_cfg
        self.chunk_size = chunk_size
        self.predictor = None

    # ------------------------------------------------------------------
    def load(self):
        """Build the SAM 2 video predictor from ``sam2_cfg`` / ``sam2_checkpoint``."""
        from sam2.build_sam import build_sam2_video_predictor
        print(f"[SAM2] Loading predictor (device={self.device}) …")
        self.predictor = build_sam2_video_predictor(
            self.sam2_cfg, self.sam2_checkpoint, device=self.device
        )
        print("[SAM2] Loaded.")

    # ------------------------------------------------------------------
    def track_video(self, frame_store: VideoFrameStore,
                    tracked_objects: list[TrackedObject],
                    dino: DinoDetector, prompt: str,
                    box_threshold: float = 0.30, text_threshold: float = 0.25,
                    iou_threshold: float = 0.45,
                    output_path: str = "output.mp4",
                    progress_cb: Optional[Callable[[int, int], None]] = None
                    ) -> list[str]:
        """
        Main entry — runs sliding-window SAM 2 tracking and writes annotated video.

        The video is split into chunks of ``chunk_size`` frames. SAM 2 is
        re-initialized per chunk from each object's last known box, masks are
        reduced to boxes, frames are rendered, objects without a mask on the
        chunk's last frame are marked lost, and DINO tries to relocalize them
        on that frame before the next chunk. Mutates ``tracked_objects`` (box,
        lost state, best crop). ``progress_cb(done, total)`` is called per frame.

        Returns list of tracked label strings.

        Raises:
            RuntimeError: no frames, no objects, or the video writer cannot open.
        """
        frame_paths = frame_store.frame_paths
        total = len(frame_paths)
        W, H = frame_store.width, frame_store.height

        if total == 0:
            raise RuntimeError("No frames to track!")
        if not tracked_objects:
            raise RuntimeError("No objects to track — run DINO detection first.")

        writer = self._open_writer(output_path, self._output_fps(frame_store), (W, H))

        chunk_starts = list(range(0, total, self.chunk_size))
        n_chunks = len(chunk_starts)
        print(f"\n[Track] {total} frames · {n_chunks} chunk(s) · "
              f"chunk_size={self.chunk_size}")

        try:
            for c_num, chunk_start in enumerate(chunk_starts):
                chunk_end = min(chunk_start + self.chunk_size, total)
                chunk_paths = frame_paths[chunk_start:chunk_end]
                chunk_len = len(chunk_paths)
                print(f"\n[Chunk {c_num+1}/{n_chunks}] "
                      f"frames {chunk_start}–{chunk_end-1} ({chunk_len} frames)")

                # 1-4. Seed SAM 2 with current boxes and propagate through the chunk.
                chunk_dir = self._make_chunk_dir(frame_store.output_dir, c_num, chunk_paths)
                try:
                    chunk_dets = self._propagate_chunk(
                        chunk_dir, tracked_objects, chunk_start, total, progress_cb)
                finally:
                    shutil.rmtree(chunk_dir, ignore_errors=True)
                    self._free_memory()

                # 5. Render while obj.box still holds the chunk-start boxes, so the
                #    "[?]" indicator tracks where each object was last seen.
                self._render_chunk(chunk_paths, chunk_dets, tracked_objects, writer, (W, H))

                # 6. Objects without a valid mask on the chunk's last frame are lost.
                self._update_lost_state(tracked_objects,
                                        chunk_dets.get(chunk_len - 1, {}), chunk_len)

                # 7. Try to relocalize lost objects with DINO on the last frame.
                self._recover_lost(tracked_objects, chunk_paths[-1], chunk_end - 1,
                                   dino, prompt, box_threshold, text_threshold,
                                   iou_threshold)

                if progress_cb:
                    progress_cb(chunk_end, total)
        finally:
            writer.release()

        print(f"[Done] Saved: {os.path.abspath(output_path)}")
        return [o.label for o in tracked_objects]

    # ------------------------------------------------------------------
    @staticmethod
    def _output_fps(frame_store) -> float:
        """FPS to write at: the rate the extracted frames really represent."""
        fps = getattr(frame_store, "effective_fps", 0.0)
        return fps or frame_store.target_fps or frame_store.orig_fps or 30.0

    # ------------------------------------------------------------------
    @staticmethod
    def _open_writer(output_path: str, fps: float, size: tuple[int, int]):
        """Open an H.264 (avc1) writer, falling back to mp4v; raise if neither opens."""
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"avc1"), fps, size)
        if not writer.isOpened():
            print("[WARN] avc1 codec not opened, falling back to mp4v.")
            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open video writer: {output_path}")
        return writer

    # ------------------------------------------------------------------
    @staticmethod
    def _make_chunk_dir(frames_dir: str, c_num: int, chunk_paths: list[str]) -> str:
        """Create a sibling dir of ``frames_dir`` whose 00000.jpg… link to the chunk frames.

        Symlinks make setup instant; files are copied where symlinks are unavailable.
        """
        parent = os.path.dirname(os.path.normpath(frames_dir))
        chunk_dir = os.path.join(parent, f"_chunk_{c_num:04d}")
        if os.path.exists(chunk_dir):
            shutil.rmtree(chunk_dir)
        os.makedirs(chunk_dir)
        for local_i, src in enumerate(chunk_paths):
            dst = os.path.join(chunk_dir, f"{local_i:05d}.jpg")
            try:
                os.symlink(os.path.abspath(src), dst)
            except (OSError, NotImplementedError):
                shutil.copy2(src, dst)
        return chunk_dir

    # ------------------------------------------------------------------
    def _propagate_chunk(self, chunk_dir, tracked_objects, chunk_start, total, progress_cb):
        """Run SAM 2 over one chunk, seeded at local frame 0 with every non-lost box.

        Returns ``{local_frame_idx: {obj_id: (box, area)}}`` holding only masks
        with at least MIN_MASK_AREA pixels. Boxes are kept instead of full masks
        so memory does not scale with resolution × chunk length × objects.
        """
        chunk_dets: dict[int, dict[int, tuple[np.ndarray, int]]] = {}
        # bfloat16 autocast on CUDA only; CPU autocast does not support float32.
        if "cuda" in str(self.device):
            amp = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            amp = contextlib.nullcontext()

        with torch.inference_mode(), amp:
            state = self.predictor.init_state(
                video_path=chunk_dir,
                # On macOS (MPS), unified memory makes CPU offloading slow. Disable it.
                offload_video_to_cpu="mps" not in str(self.device),
                offload_state_to_cpu=False,
            )
            try:
                self.predictor.reset_state(state)

                registered = 0
                for obj in tracked_objects:
                    if obj.lost:
                        print(f" [SKIP] id={obj.obj_id} '{obj.label}' is lost, "
                              f"will try DINO recovery after this chunk.")
                        continue
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=0,  # always local frame 0 of chunk
                        obj_id=obj.obj_id,
                        box=obj.box,
                    )
                    registered += 1
                print(f" Registered {registered} objects at chunk start.")

                if registered == 0:
                    # SAM 2 refuses to propagate without any prompt; all stay lost.
                    print(" [SKIP] No objects to propagate in this chunk.")
                else:
                    for local_idx, obj_ids, mask_logits in \
                            self.predictor.propagate_in_video(state):
                        frame_dets = {}
                        for i, obj_id in enumerate(obj_ids):
                            det = self._mask_to_box((mask_logits[i] > 0.0).cpu().numpy())
                            if det is not None:
                                frame_dets[int(obj_id)] = det
                        chunk_dets[int(local_idx)] = frame_dets
                        if progress_cb:
                            progress_cb(chunk_start + local_idx + 1, total)
            finally:
                self.predictor.reset_state(state)
                del state
        return chunk_dets

    # ------------------------------------------------------------------
    @staticmethod
    def _mask_to_box(mask) -> Optional[tuple[np.ndarray, int]]:
        """Return ``(box [x1, y1, x2, y2] float32, area)`` for a boolean mask.

        Returns None when the mask is not 2-D after squeezing or has fewer than
        MIN_MASK_AREA set pixels.
        """
        mask = np.squeeze(np.asarray(mask))
        if mask.ndim != 2:
            return None
        area = int(np.count_nonzero(mask))
        if area < MIN_MASK_AREA:
            return None
        ys, xs = np.nonzero(mask)
        box = np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32)
        return box, area

    # ------------------------------------------------------------------
    @staticmethod
    def _crop(frame: np.ndarray, box) -> Optional[np.ndarray]:
        """Copy the inclusive ``box`` region of ``frame``; None if it is degenerate."""
        h_f, w_f = frame.shape[:2]
        x1, y1 = max(0, int(box[0])), max(0, int(box[1]))
        x2, y2 = min(w_f - 1, int(box[2])), min(h_f - 1, int(box[3]))
        if x2 <= x1 or y2 <= y1:
            return None
        return frame[y1:y2 + 1, x1:x2 + 1].copy()

    # ------------------------------------------------------------------
    def _render_chunk(self, chunk_paths, chunk_dets, tracked_objects, writer, frame_size):
        """Annotate every frame of a chunk and append it to ``writer``.

        Updates each object's ``best_crop`` from the frame with its largest mask
        so far (taken before drawing). Objects missing from a frame are drawn
        faded at the last box they were seen at, starting from their box at the
        beginning of the chunk. Unreadable frames are replaced by black frames
        so the output keeps its timing.
        """
        last_seen = {obj.obj_id: obj.box.copy() for obj in tracked_objects}
        for local_idx, path in enumerate(chunk_paths):
            frame = cv2.imread(path)
            if frame is None:
                print(f"[WARN] Cannot read frame: {path}")
                frame = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)

            frame_dets = chunk_dets.get(local_idx, {})
            for obj in tracked_objects:
                det = frame_dets.get(obj.obj_id)
                if det is None:
                    continue
                box, area = det
                last_seen[obj.obj_id] = box
                if area > getattr(obj, "max_mask_area", 0):
                    obj.max_mask_area = area
                    crop = self._crop(frame, box)
                    if crop is not None:
                        obj.best_crop = crop

            writer.write(self._draw_boxes(frame, frame_dets, tracked_objects, last_seen))

    # ------------------------------------------------------------------
    @staticmethod
    def _update_lost_state(tracked_objects, last_frame_dets, chunk_len):
        """Adopt each object's last-frame box, or mark it lost and add ``chunk_len``."""
        for obj in tracked_objects:
            det = last_frame_dets.get(obj.obj_id)
            if det is not None:
                obj.box = det[0]
                obj.lost = False
                obj.lost_frames = 0
            else:
                obj.lost = True
                obj.lost_frames += chunk_len
                print(f" [LOST] id={obj.obj_id} '{obj.label}' — not visible at chunk end.")

    # ------------------------------------------------------------------
    def _recover_lost(self, tracked_objects, frame_path, frame_idx, dino, prompt,
                      box_threshold, text_threshold, iou_threshold):
        """Run DINO on ``frame_path`` and re-anchor lost objects that it matches."""
        lost_objects = [o for o in tracked_objects if o.lost]
        if not lost_objects:
            return
        print(f" [Recovery] Running DINO on frame {frame_idx} "
              f"for {len(lost_objects)} lost object(s) …")
        boxes, _scores, labels = dino.detect(
            frame_path, prompt, box_threshold, text_threshold, iou_threshold
        )
        recovered = self._match_lost_to_dino(lost_objects, boxes, labels, iou_threshold)
        for obj in lost_objects:
            new_box = recovered.get(obj.obj_id)
            if new_box is None:
                continue
            obj.box = new_box
            obj.lost = False
            obj.lost_frames = 0
            print(f" [Recovered] id={obj.obj_id} '{obj.label}' at chunk boundary.")

    # ------------------------------------------------------------------
    @staticmethod
    def _free_memory():
        """Run the GC and release cached CUDA/MPS allocator memory to avoid OOM."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()

    # ------------------------------------------------------------------
    def _draw_frame(self, frame: np.ndarray, masks: dict[int, np.ndarray],
                    tracked_objects: list[TrackedObject]) -> np.ndarray:
        """
        Draw bounding box + label + ID from per-object masks ({obj_id: mask}).
        No pixel mask overlay → fast and VRAM-independent. Objects without a
        valid mask are drawn faded at ``obj.box``. Returns None if frame is None.
        """
        if frame is None:
            return frame
        dets = {}
        for oid, mask in masks.items():
            det = self._mask_to_box(mask)
            if det is not None:
                dets[oid] = det
        return self._draw_boxes(frame, dets, tracked_objects)

    # ------------------------------------------------------------------
    def _draw_boxes(self, frame: np.ndarray,
                    dets: dict[int, tuple[np.ndarray, int]],
                    tracked_objects: list[TrackedObject],
                    fallback_boxes: Optional[dict[int, np.ndarray]] = None) -> np.ndarray:
        """Draw one box + label per object in place and return ``frame``.

        Visible objects (present in ``dets``) get a solid 2-px box; others get a
        1-px box and a faded "[?]" label at ``fallback_boxes[obj_id]`` (default
        ``obj.box``).
        """
        for obj in tracked_objects:
            oid = obj.obj_id
            color = self.PALETTE[oid % len(self.PALETTE)]
            det = dets.get(oid)
            if det is None:
                # Object not visible this frame — faded indicator at last known box
                box = obj.box
                if fallback_boxes is not None:
                    box = fallback_boxes.get(oid, obj.box)
                x1, y1, x2, y2 = (int(v) for v in box)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)
                self._put_label(frame, f"{obj.label} #{oid} [?]", x1, y1, color, alpha=0.4)
                continue

            x1, y1, x2, y2 = (int(v) for v in det[0])
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            self._put_label(frame, f"{obj.label} #{oid}", x1, y1, color)
        return frame

    # ------------------------------------------------------------------
    @staticmethod
    def _put_label(frame: np.ndarray, text: str, x: int, y: int,
                   color: tuple, alpha: float = 1.0):
        """Draw ``text`` in white on a ``color`` tag whose bottom-left is (x, y).

        The tag is shifted inward when it would leave the top/left of the frame.
        ``alpha`` < 0.9 blends the tag background with the frame.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        scale = 0.5
        thickness = 1
        (tw, th), _ = cv2.getTextSize(text, font, scale, thickness)
        pad = 4

        # Keep the tag inside the frame for boxes touching the top/left edge.
        x = max(0, x)
        y = max(y, th + pad * 2)
        bkg_y1 = y - th - pad * 2
        bkg_y2 = y
        bkg_x2 = x + tw + pad * 2

        # Background rectangle
        if alpha >= 0.9:
            cv2.rectangle(frame, (x, bkg_y1), (bkg_x2, bkg_y2), color, -1)
        else:
            overlay = frame.copy()
            cv2.rectangle(overlay, (x, bkg_y1), (bkg_x2, bkg_y2), color, -1)
            cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

        # Text
        cv2.putText(frame, text, (x + pad, y - pad), font, scale,
                    (255, 255, 255), thickness, cv2.LINE_AA)

    # ------------------------------------------------------------------
    @staticmethod
    def _match_lost_to_dino(lost_objects: list[TrackedObject],
                            dino_boxes: np.ndarray,
                            dino_labels: list[str],
                            iou_threshold: float = 0.20
                            ) -> dict[int, np.ndarray]:
        """
        Greedily assign DINO detections to lost objects (in list order).

        A detection is eligible for an object when both labels are non-empty and
        one contains the other (case-insensitive), and no earlier object took
        it. Among eligible detections the one with the highest IoU to the
        object's last known box wins; ties (including the all-zero-IoU case)
        go to the detection whose centre is closest. ``iou_threshold`` is
        accepted for API compatibility and is not used.

        Returns {obj_id: new_box}.
        """
        recovered: dict[int, np.ndarray] = {}
        used_dino: set[int] = set()

        for obj in lost_objects:
            obj_label = str(obj.label).strip().lower()
            if not obj_label:
                continue
            x1, y1, x2, y2 = (float(v) for v in obj.box)
            obj_area = (x2 - x1) * (y2 - y1)
            cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0

            best_idx, best_key = None, None
            for d_idx, (d_box, d_label) in enumerate(zip(dino_boxes, dino_labels)):
                if d_idx in used_dino:
                    continue
                d_lab = str(d_label).strip().lower()
                # Empty labels would substring-match everything — reject them.
                if not d_lab or (obj_label not in d_lab and d_lab not in obj_label):
                    continue

                dx1, dy1, dx2, dy2 = (float(v) for v in d_box)
                iw = max(0.0, min(x2, dx2) - max(x1, dx1))
                ih = max(0.0, min(y2, dy2) - max(y1, dy1))
                inter = iw * ih
                union = obj_area + (dx2 - dx1) * (dy2 - dy1) - inter + 1e-6
                iou = inter / union
                dist2 = ((dx1 + dx2) / 2.0 - cx) ** 2 + ((dy1 + dy2) / 2.0 - cy) ** 2

                key = (iou, -dist2)
                if best_key is None or key > best_key:
                    best_idx, best_key = d_idx, key

            if best_idx is not None:
                recovered[obj.obj_id] = np.array(dino_boxes[best_idx], dtype=np.float32)
                used_dino.add(best_idx)
        return recovered