Spaces:
Paused
Paused
"""
SAM 2 Robust Object Tracker
===========================
Architecture:
  - Grounding DINO → detects WHAT to track (frame 0 + recovery of lost
                     objects on the last frame of each chunk)
  - SAM 2 Video    → tracks WHERE each object is across frames
  - Sliding Window → processes the video in fixed-size chunks, re-seeding
                     SAM 2 at each chunk start from the last known boxes so
                     its memory bank stays small and accurate
  - Recovery Loop  → an object with no valid mask at a chunk's end is
                     re-localized by DINO before the next chunk starts
Output: Bounding Box + Label + ID per frame (no mask overlay, fast + low VRAM)
"""
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from typing import Optional | |
# ──────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────
# Tracking tunables
MIN_MASK_AREA = 64          # pixels² — a mask smaller than this counts as "lost"
CHUNK_SIZE_DEFAULT = 120    # frames per SAM 2 sliding window

# Default SAM 2.1 Hiera-Small weights (under the torch-hub checkpoint cache)
# and the matching model config.
HOME_DIR = os.path.expanduser("~")
_TORCH_HUB_CHECKPOINTS = os.path.join(HOME_DIR, ".cache", "torch", "hub", "checkpoints")
SAM2_CKPT_DEFAULT = os.path.join(_TORCH_HUB_CHECKPOINTS, "sam2.1_hiera_small.pt")
SAM2_CFG_DEFAULT = "configs/sam2.1/sam2.1_hiera_s.yaml"
| # ────────────────────────────────────────────────────────── | |
| # TrackedObject — stores per-object state between chunks | |
| # ────────────────────────────────────────────────────────── | |
class TrackedObject:
    """Identity and last known bounding box of one tracked object.

    The tracker mutates ``box``, ``lost`` and ``lost_frames`` in place
    between sliding-window chunks.
    """

    def __init__(self, obj_id: int, label: str, box: np.ndarray):
        """
        Args:
            obj_id: integer identity (also used to pick the drawing color).
            label: text label, matched against DINO labels during recovery.
            box: array-like ``[x1, y1, x2, y2]`` in pixel coordinates. Lists
                and tuples are accepted as well as ndarrays.
        """
        self.obj_id = obj_id
        self.label = label
        # np.array (not asarray) always copies, so later in-place edits of the
        # caller's array cannot leak into the tracker's state.
        self.box = np.array(box, dtype=np.float32)  # [x1, y1, x2, y2]
        self.lost = False      # True if not visible at the end of the last chunk
        self.lost_frames = 0   # frames accumulated while lost

    def __repr__(self):
        return f"TrackedObject(id={self.obj_id}, label='{self.label}', lost={self.lost})"
| # ────────────────────────────────────────────────────────── | |
| # VideoFrameStore — thin wrapper around a frames directory | |
| # ────────────────────────────────────────────────────────── | |
class VideoFrameStore:
    """Extract a video's frames to a directory as numbered JPEGs.

    Each kept frame goes through these steps, in order:
      1. temporal sub-sampling toward ``target_fps`` (integer frame skip);
      2. optional ORB-based stabilization;
      3. downscaling so the longer side is at most ``max_size``;
      4. an optional Laplacian-variance blur filter (``blur_threshold > 0``).
    """

    def __init__(self, video_path: str, output_dir: str,
                 target_fps: Optional[float] = None,
                 max_size: int = 720,
                 blur_threshold: float = 0.0,
                 stabilize: bool = False):
        self.video_path = video_path
        self.output_dir = output_dir
        self.target_fps = target_fps
        self.max_size = max_size
        self.blur_threshold = blur_threshold
        self.stabilize = stabilize
        self.frame_paths: list[str] = []  # extracted JPEG paths, in frame order
        self.orig_fps = 0.0
        # Rate of the sampled sequence (orig_fps / frame-skip interval), set by
        # extract(). It differs from target_fps whenever orig/target is not an integer.
        self.effective_fps = 0.0
        self.width = 0   # output frame width (after downscaling)
        self.height = 0  # output frame height (after downscaling)

    # ------------------------------------------------------------------
    def extract(self) -> int:
        """Run the extraction and return the number of frames saved.

        ``output_dir`` is deleted and recreated. ``frame_paths`` is rebuilt from
        scratch. If the blur filter rejects every frame, the first raw frame
        (resized) is saved as a fallback.

        Raises:
            RuntimeError: the video cannot be opened or decoded, or a frame
                cannot be written.
        """
        import shutil

        # Reset so repeated calls do not accumulate paths from a previous run.
        self.frame_paths = []
        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)
        os.makedirs(self.output_dir)

        self.orig_fps, raw_w, raw_h = self._probe_video()

        # Output resolution (aspect ratio kept, never upscaled)
        scale = min(1.0, self.max_size / max(raw_w, raw_h))
        self.width = int(raw_w * scale)
        self.height = int(raw_h * scale)

        interval = self._sample_interval(self.orig_fps, self.target_fps)
        self.effective_fps = self.orig_fps / interval

        stab_diff = None
        if self.stabilize:
            cap = self._open_capture()
            try:
                stab_diff = self._compute_stabilization(cap, raw_w, raw_h)
            finally:
                cap.release()

        saved = 0
        cap = self._open_capture()
        try:
            orig_idx = -1
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                orig_idx += 1
                # --- sample at target fps ---
                if orig_idx % interval:
                    continue
                # --- stabilization warp (indexed by original frame number) ---
                if stab_diff is not None and orig_idx < len(stab_diff):
                    frame = self._warp(frame, stab_diff[orig_idx], raw_w, raw_h)
                frame = self._resize(frame, scale)
                # --- blur filter ---
                if self.blur_threshold > 0 and self._sharpness(frame) < self.blur_threshold:
                    continue
                self._save(frame, saved)
                saved += 1
        finally:
            cap.release()

        # Fallback: if blur filter ate everything, save at least 1 frame
        if saved == 0:
            print("[WARN] All frames were blurry — saving 1 raw frame as fallback.")
            cap = self._open_capture()
            try:
                ok, frame = cap.read()
            finally:
                cap.release()
            if ok:
                self._save(self._resize(frame, scale), 0)
                saved = 1

        print(f"[Extract] Saved {saved} frames → {self.output_dir}")
        return saved

    # ------------------------------------------------------------------
    def _open_capture(self):
        """Open ``video_path`` or raise RuntimeError."""
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            cap.release()
            raise RuntimeError(f"Cannot open video: {self.video_path}")
        return cap

    def _probe_video(self) -> tuple[float, int, int]:
        """Return ``(fps, width, height)`` of the source video.

        Uses 30 fps when the container reports none. Reads the first frame's
        shape when the reported size is zero, which avoids a division by zero
        in the scale computation.
        """
        cap = self._open_capture()
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            if w <= 0 or h <= 0:
                ok, frame = cap.read()
                if not ok:
                    raise RuntimeError(f"Cannot read frames from video: {self.video_path}")
                h, w = frame.shape[:2]
        finally:
            cap.release()
        return fps, w, h

    @staticmethod
    def _sample_interval(orig_fps: float, target_fps: Optional[float]) -> int:
        """Number of original frames per saved frame (1 = keep every frame)."""
        if target_fps and target_fps > 0:
            return max(1, round(orig_fps / target_fps))
        return 1

    @staticmethod
    def _warp(frame: np.ndarray, correction, width: int, height: int) -> np.ndarray:
        """Apply a (dx, dy, dangle) rigid correction to a full-resolution frame."""
        dx, dy, da = correction
        cos_a, sin_a = np.cos(da), np.sin(da)
        M = np.array([[cos_a, -sin_a, dx],
                      [sin_a, cos_a, dy]], dtype=np.float32)
        return cv2.warpAffine(frame, M, (width, height))

    def _resize(self, frame: np.ndarray, scale: float) -> np.ndarray:
        """Downscale to ``(self.width, self.height)`` when ``scale < 1``."""
        if scale < 1.0:
            return cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame

    @staticmethod
    def _sharpness(frame: np.ndarray) -> float:
        """Variance of the Laplacian of the grayscale frame (low = blurry)."""
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        return cv2.Laplacian(gray, cv2.CV_64F).var()

    def _save(self, frame: np.ndarray, index: int) -> None:
        """Write ``frame`` as ``{index:05d}.jpg`` and record its path."""
        path = os.path.join(self.output_dir, f"{index:05d}.jpg")
        if not cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 95]):
            raise RuntimeError(f"Failed to write frame: {path}")
        self.frame_paths.append(path)

    # ------------------------------------------------------------------
    def _compute_stabilization(self, cap, raw_w, raw_h):
        """ORB-based motion estimation → smoothed correction per frame.

        Reads every remaining frame from ``cap``. Returns an ``(N, 3)`` array
        of ``(dx, dy, dangle)`` corrections (full-resolution pixels, radians).
        Each correction moves a frame from its raw cumulative trajectory onto
        a moving-average-smoothed one.
        """
        print("[Stabilize] Computing ORB motion trajectory …")
        scale = 480.0 / max(raw_w, raw_h)  # motion is estimated on a small copy
        small_size = (int(raw_w * scale), int(raw_h * scale))
        # Created once: neither object keeps state between these calls.
        orb = cv2.ORB_create(300)
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

        transforms = []
        prev_kp = prev_desc = None
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            gray = cv2.cvtColor(cv2.resize(frame, small_size), cv2.COLOR_BGR2GRAY)
            # Features are computed once per frame and reused as "previous" next time.
            kp, desc = orb.detectAndCompute(gray, None)
            dx = dy = da = 0.0
            if (prev_kp is not None and prev_desc is not None and desc is not None
                    and len(prev_kp) > 5 and len(kp) > 5):
                matches = sorted(matcher.match(prev_desc, desc),
                                 key=lambda m: m.distance)[:50]
                if len(matches) >= 4:
                    pts1 = np.float32([prev_kp[m.queryIdx].pt for m in matches])
                    pts2 = np.float32([kp[m.trainIdx].pt for m in matches])
                    M, _ = cv2.estimateAffinePartial2D(pts1, pts2)
                    if M is not None:
                        dx = M[0, 2] / scale
                        dy = M[1, 2] / scale
                        da = np.arctan2(M[1, 0], M[0, 0])
            transforms.append(np.array([dx, dy, da]))
            prev_kp, prev_desc = kp, desc
        return self._smoothing_correction(transforms)

    @staticmethod
    def _smoothing_correction(transforms, max_radius: int = 30) -> np.ndarray:
        """Correction ``smooth - trajectory`` for per-frame motion steps.

        The trajectory is the cumulative sum of ``transforms`` (N x 3). It is
        smoothed with a centered moving average whose radius is
        ``max(1, min(max_radius, N // 2))``, truncated at the sequence ends.
        """
        traj = np.cumsum(np.asarray(transforms, dtype=np.float64).reshape(-1, 3), axis=0)
        radius = max(1, min(max_radius, len(traj) // 2))
        smooth = np.empty_like(traj)
        for i in range(len(traj)):
            lo, hi = max(0, i - radius), min(len(traj), i + radius + 1)
            smooth[i] = traj[lo:hi].mean(axis=0)
        return smooth - traj
| # ────────────────────────────────────────────────────────── | |
| # DinoDetector — wraps Grounding DINO for prompt detection | |
| # ────────────────────────────────────────────────────────── | |
class DinoDetector:
    """Grounding DINO wrapper for text-prompted (open-vocabulary) detection.

    The prompt vocabulary is split into groups of ``CHUNK_SIZE`` items. Each
    group is detected separately, and the merged detections are then
    de-duplicated with class-agnostic NMS.
    """

    CHUNK_SIZE = 15  # max vocabulary items per DINO call (avoids token overflow)
    MODEL_ID = "IDEA-Research/grounding-dino-base"

    def __init__(self, device: torch.device):
        self.device = device
        self.processor = None  # set by load()
        self.model = None      # set by load()

    def load(self):
        """Load processor and model via ``from_pretrained``; model goes to ``self.device`` in eval mode."""
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
        print("[DINO] Loading Grounding DINO Base …")
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            self.MODEL_ID
        ).to(self.device).eval()
        print("[DINO] Loaded.")

    # ------------------------------------------------------------------
    def detect(self, image_path: str, prompt: str,
               box_threshold: float = 0.30,
               text_threshold: float = 0.25,
               iou_threshold: float = 0.45
               ) -> tuple[np.ndarray, np.ndarray, list[str]]:
        """
        Detect prompt items in one image.

        Args:
            image_path: path of the image file.
            prompt: multi-line string; ``#`` lines are comments, and items are
                separated by commas or periods.
            box_threshold / text_threshold: forwarded to the processor's
                grounded post-processing.
            iou_threshold: IoU above which lower-scoring boxes are suppressed.

        Returns:
            ``(boxes [N,4], scores [N], labels [N])`` with boxes in pixel coords.
            Loads the model on first use if ``load()`` was not called.
        """
        items = self._parse_prompt(prompt)
        if not items:
            return np.empty((0, 4)), np.array([]), []
        if self.model is None or self.processor is None:
            self.load()

        # The context manager closes the file; convert() returns an in-memory copy.
        with Image.open(image_path) as img:
            image_pil = img.convert("RGB")

        n_chunks = -(-len(items) // self.CHUNK_SIZE)  # ceil division
        all_boxes, all_scores, all_labels = [], [], []
        for idx, start in enumerate(range(0, len(items), self.CHUNK_SIZE)):
            chunk_text = " . ".join(items[start:start + self.CHUNK_SIZE]) + " ."
            print(f" [DINO] chunk {idx+1}/{n_chunks}: {chunk_text[:80]}…")
            inputs = self.processor(
                images=image_pil, text=chunk_text, return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            results = self._post_process(outputs, inputs.input_ids, image_pil,
                                         box_threshold, text_threshold)
            # Newer transformers releases return the phrase strings under
            # "text_labels" (NOTE(review): confirm against installed version).
            labels = results["text_labels"] if "text_labels" in results else results["labels"]
            all_boxes.extend(results["boxes"].cpu().numpy())
            all_scores.extend(results["scores"].cpu().numpy())
            all_labels.extend(labels)

        if not all_boxes:
            return np.empty((0, 4)), np.array([]), []
        boxes = np.array(all_boxes)
        scores = np.array(all_scores)
        keep = self._nms(boxes, scores, iou_threshold)
        return boxes[keep], scores[keep], [all_labels[k] for k in keep]

    # ------------------------------------------------------------------
    def _post_process(self, outputs, input_ids, image_pil, box_thresh, text_thresh):
        """Grounded post-processing that handles both keyword spellings.

        Older transformers versions take ``box_threshold``; newer ones take
        ``threshold``. Returns the result dict for the single image.
        """
        target_sizes = [image_pil.size[::-1]]  # PIL gives (W, H); processor wants (H, W)
        try:
            return self.processor.post_process_grounded_object_detection(
                outputs, input_ids,
                box_threshold=box_thresh, text_threshold=text_thresh,
                target_sizes=target_sizes
            )[0]
        except TypeError:
            return self.processor.post_process_grounded_object_detection(
                outputs, input_ids,
                threshold=box_thresh, text_threshold=text_thresh,
                target_sizes=target_sizes
            )[0]

    # ------------------------------------------------------------------
    @staticmethod
    def _parse_prompt(prompt: str) -> list[str]:
        """Split a prompt into unique vocabulary items (first-seen order).

        Blank lines and lines starting with ``#`` are skipped. Periods are
        treated like commas as item separators.
        """
        items = []
        for line in prompt.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            items.extend(p.strip() for p in line.replace(".", ",").split(",") if p.strip())
        return list(dict.fromkeys(items))  # de-duplicate, keep order

    # ------------------------------------------------------------------
    @staticmethod
    def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
        """Greedy class-agnostic NMS; returns kept indices by descending score."""
        if len(boxes) == 0:
            return []
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(int(i))
            rest = order[1:]
            w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]))
            h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]))
            inter = w * h
            iou = inter / (areas[i] + areas[rest] - inter + 1e-6)
            order = rest[iou <= iou_thresh]
        return keep
| # ────────────────────────────────────────────────────────── | |
| # SAM2Tracker — sliding-window SAM 2 tracking engine | |
| # ────────────────────────────────────────────────────────── | |
class SAM2Tracker:
    """
    Sliding-window SAM 2 video tracker:
      1. Sliding-window propagation — SAM 2 is re-seeded at the start of every
         chunk from each object's last known box, keeping its memory bank small
      2. Automatic lost-object detection — mask area < MIN_MASK_AREA
      3. DINO re-anchor on lost objects — relocalizes them from the text
         prompt on the last frame of each chunk
      4. Bbox-only rendering — boxes and labels, no mask overlay
    """

    # Palette — visually distinct colors (BGR)
    PALETTE = [
        (255, 80, 80),   # blue-ish
        (80, 220, 80),   # green
        (80, 80, 255),   # red
        (0, 220, 220),   # yellow
        (220, 0, 220),   # magenta
        (220, 220, 0),   # cyan
        (255, 160, 0),   # orange
        (160, 0, 200),   # purple
        (0, 180, 180),   # teal
        (0, 140, 255),   # gold
        (180, 255, 0),   # lime
        (255, 0, 150),   # pink
    ]

    def __init__(self,
                 sam2_checkpoint: str = SAM2_CKPT_DEFAULT,
                 sam2_cfg: str = SAM2_CFG_DEFAULT,
                 device: Optional[torch.device] = None,
                 chunk_size: int = CHUNK_SIZE_DEFAULT):
        """
        Args:
            sam2_checkpoint: path of the SAM 2 weights file.
            sam2_cfg: config name passed to ``build_sam2_video_predictor``.
            device: torch device; defaults to CUDA when available, else CPU.
            chunk_size: frames per sliding window (must be >= 1).
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.sam2_checkpoint = sam2_checkpoint
        self.sam2_cfg = sam2_cfg
        self.chunk_size = chunk_size
        self.predictor = None

    # ------------------------------------------------------------------
    def load(self):
        """Build the SAM 2 video predictor on ``self.device``."""
        from sam2.build_sam import build_sam2_video_predictor
        print(f"[SAM2] Loading predictor (device={self.device}) …")
        self.predictor = build_sam2_video_predictor(
            self.sam2_cfg, self.sam2_checkpoint, device=self.device
        )
        print("[SAM2] Loaded.")

    # ------------------------------------------------------------------
    def track_video(self,
                    frame_store: VideoFrameStore,
                    tracked_objects: list[TrackedObject],
                    dino: DinoDetector,
                    prompt: str,
                    box_threshold: float = 0.30,
                    text_threshold: float = 0.25,
                    iou_threshold: float = 0.45,
                    output_path: str = "output.mp4",
                    progress_cb=None) -> list[str]:
        """
        Run sliding-window SAM 2 tracking and write the annotated video.

        For each chunk of ``chunk_size`` frames:
          1. link its frames into a temporary directory;
          2. seed SAM 2 at the chunk's first frame with every non-lost box;
          3. propagate masks through the chunk;
          4. mark objects with no valid mask on the last frame as lost;
          5. try to re-localize lost objects with DINO on that last frame;
          6. draw boxes/labels and append the frames to ``output_path``.

        ``tracked_objects`` is updated in place (``box``, ``lost``,
        ``lost_frames``, plus ``max_mask_area``/``best_crop`` for objects
        that get a valid mask). ``progress_cb(done, total)`` is called with
        frame counts. The predictor is loaded on first use.

        Returns:
            The labels of ``tracked_objects``, in order.
        Raises:
            RuntimeError: no frames, no objects, or no video writer could be opened.
        """
        import shutil

        frame_paths = frame_store.frame_paths
        total = len(frame_paths)
        if total == 0:
            raise RuntimeError("No frames to track!")
        if not tracked_objects:
            raise RuntimeError("No objects to track — run DINO detection first.")
        if self.predictor is None:
            self.load()

        W, H = frame_store.width, frame_store.height
        # Frames are sampled with an integer skip, so their true rate is
        # orig_fps / skip (exposed as effective_fps when available).
        # Writing at target_fps would change the playback speed.
        fps = (getattr(frame_store, "effective_fps", 0.0)
               or frame_store.target_fps or frame_store.orig_fps)
        writer = self._open_writer(output_path, fps, (W, H))

        chunk_starts = list(range(0, total, self.chunk_size))
        n_chunks = len(chunk_starts)
        print(f"\n[Track] {total} frames · {n_chunks} chunk(s) · "
              f"chunk_size={self.chunk_size}")
        try:
            for c_num, chunk_start in enumerate(chunk_starts):
                chunk_end = min(chunk_start + self.chunk_size, total)
                chunk_paths = frame_paths[chunk_start:chunk_end]
                chunk_len = len(chunk_paths)
                print(f"\n[Chunk {c_num+1}/{n_chunks}] "
                      f"frames {chunk_start}–{chunk_end-1} ({chunk_len} frames)")

                chunk_dir = self._make_chunk_dir(frame_store.output_dir, c_num, chunk_paths)
                try:
                    chunk_masks, chunk_boxes = self._propagate_chunk(
                        chunk_dir, chunk_start, chunk_len, tracked_objects,
                        W, H, total, progress_cb)
                    self._update_after_chunk(tracked_objects, chunk_masks,
                                             chunk_boxes, chunk_len)
                    self._recover_lost(tracked_objects, dino, prompt,
                                       chunk_paths[-1], chunk_end - 1,
                                       box_threshold, text_threshold, iou_threshold)
                    self._render_chunk(writer, chunk_paths, chunk_masks,
                                       chunk_boxes, tracked_objects, W, H)
                finally:
                    shutil.rmtree(chunk_dir, ignore_errors=True)
                    self._release_memory()

                if progress_cb:
                    progress_cb(chunk_end, total)
        finally:
            writer.release()

        print(f"[Done] Saved: {os.path.abspath(output_path)}")
        return [o.label for o in tracked_objects]

    # ------------------------------------------------------------------
    @staticmethod
    def _open_writer(output_path: str, fps: float, size: tuple):
        """Open an H.264 (avc1) writer, falling back to mp4v; raise if neither opens."""
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"avc1"), fps, size)
        if not writer.isOpened():
            print("[WARN] avc1 codec not opened, falling back to mp4v.")
            writer.release()
            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open video writer for: {output_path}")
        return writer

    @staticmethod
    def _make_chunk_dir(frames_dir: str, c_num: int, chunk_paths: list[str]) -> str:
        """Create a fresh ``_chunk_NNNN`` directory next to ``frames_dir``.

        The chunk's frames are placed there, renumbered from ``00000.jpg`` so
        local frame 0 is the chunk's first frame. Symlinks are used, with a
        copy as the fallback when a symlink cannot be created.
        """
        import shutil
        # abspath first so a relative or trailing-slash output_dir still gives a sibling dir
        parent = os.path.dirname(os.path.abspath(frames_dir))
        chunk_dir = os.path.join(parent, f"_chunk_{c_num:04d}")
        if os.path.exists(chunk_dir):
            shutil.rmtree(chunk_dir)
        os.makedirs(chunk_dir)
        for local_i, src in enumerate(chunk_paths):
            dst = os.path.join(chunk_dir, f"{local_i:05d}.jpg")
            try:
                os.symlink(os.path.abspath(src), dst)
            except (OSError, NotImplementedError):
                shutil.copy2(src, dst)
        return chunk_dir

    def _propagate_chunk(self, chunk_dir, chunk_start, chunk_len,
                         tracked_objects, W, H, total, progress_cb):
        """Seed SAM 2 at local frame 0 with every non-lost object and propagate.

        Returns:
            ``(chunk_masks, chunk_boxes)``. ``chunk_masks[local_idx][obj_id]``
            is the boolean mask. ``chunk_boxes[local_idx][obj_id]`` is the
            object's last known box as of that frame (carried forward from the
            chunk-start box when the mask is missing or too small).
        """
        import contextlib

        boxes_now = {obj.obj_id: obj.box for obj in tracked_objects}
        chunk_masks: dict[int, dict[int, np.ndarray]] = {}
        chunk_boxes: dict[int, dict[int, np.ndarray]] = {}

        active = []
        for obj in tracked_objects:
            if obj.lost:
                print(f" [SKIP] id={obj.obj_id} '{obj.label}' is lost, "
                      f"will try DINO recovery after this chunk.")
            else:
                active.append(obj)

        if not active:
            # SAM 2 refuses to propagate a state with no prompts, so skip it
            # and just carry the old boxes through the chunk.
            print(" Registered 0 objects at chunk start.")
            for local_idx in range(chunk_len):
                chunk_boxes[local_idx] = dict(boxes_now)
            return chunk_masks, chunk_boxes

        # bfloat16 autocast on CUDA only; CPU autocast does not support float32.
        if "cuda" in str(self.device):
            amp = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            amp = contextlib.nullcontext()

        with torch.inference_mode(), amp:
            # On macOS (MPS), unified memory makes CPU offloading slow. Disable it.
            state = self.predictor.init_state(
                video_path=chunk_dir,
                offload_video_to_cpu="mps" not in str(self.device),
                offload_state_to_cpu=False,
            )
            try:
                for obj in active:
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=0,  # always local frame 0 of chunk
                        obj_id=obj.obj_id,
                        box=obj.box,
                    )
                print(f" Registered {len(active)} objects at chunk start.")

                for local_idx, obj_ids, mask_logits in \
                        self.predictor.propagate_in_video(state):
                    frame_masks: dict[int, np.ndarray] = {}
                    for i, obj_id in enumerate(obj_ids):
                        mask = (mask_logits[i] > 0.0).cpu().numpy().squeeze()
                        if mask.ndim == 0:
                            mask = np.zeros((H, W), dtype=bool)
                        frame_masks[int(obj_id)] = mask
                        box = self._mask_to_box(mask)
                        if box is not None:
                            boxes_now[int(obj_id)] = box
                    chunk_masks[local_idx] = frame_masks
                    chunk_boxes[local_idx] = dict(boxes_now)
                    if progress_cb:
                        progress_cb(chunk_start + local_idx + 1, total)
            finally:
                self.predictor.reset_state(state)
        return chunk_masks, chunk_boxes

    def _update_after_chunk(self, tracked_objects, chunk_masks, chunk_boxes, chunk_len):
        """Mark objects lost or found based on the chunk's last frame.

        An object with a valid mask there takes that mask's box. Otherwise it
        keeps the last box seen in this chunk and is marked lost.
        """
        last_masks = chunk_masks.get(chunk_len - 1, {})
        end_boxes = chunk_boxes.get(chunk_len - 1, {})
        for obj in tracked_objects:
            box = self._mask_to_box(last_masks.get(obj.obj_id))
            if box is not None:
                obj.box = box
                obj.lost = False
                obj.lost_frames = 0
                continue
            obj.box = end_boxes.get(obj.obj_id, obj.box)
            obj.lost = True
            obj.lost_frames += chunk_len
            print(f" [LOST] id={obj.obj_id} '{obj.label}' — not visible at chunk end.")

    def _recover_lost(self, tracked_objects, dino, prompt, frame_path, frame_idx,
                      box_threshold, text_threshold, iou_threshold):
        """Run DINO on ``frame_path`` and re-anchor lost objects it matches."""
        lost_objects = [o for o in tracked_objects if o.lost]
        if not lost_objects:
            return
        print(f" [Recovery] Running DINO on frame {frame_idx} "
              f"for {len(lost_objects)} lost object(s) …")
        boxes, _scores, labels = dino.detect(
            frame_path, prompt,
            box_threshold, text_threshold, iou_threshold
        )
        recovered = self._match_lost_to_dino(lost_objects, boxes, labels, iou_threshold)
        for obj in tracked_objects:
            if obj.obj_id in recovered:
                obj.box = recovered[obj.obj_id]
                obj.lost = False
                obj.lost_frames = 0
                print(f" [Recovered] id={obj.obj_id} '{obj.label}' "
                      f"at chunk boundary.")

    def _render_chunk(self, writer, chunk_paths, chunk_masks, chunk_boxes,
                      tracked_objects, W, H):
        """Update best crops, draw annotations and write every frame of the chunk."""
        for local_idx, path in enumerate(chunk_paths):
            frame = cv2.imread(path)
            if frame is None:
                # Keep the output timing intact even if a frame can't be read.
                print(f"[WARN] Could not read {path} — writing a blank frame.")
                frame = np.zeros((H, W, 3), dtype=np.uint8)
            masks_here = chunk_masks.get(local_idx, {})
            # Crops are taken before drawing so they contain no annotations.
            for obj in tracked_objects:
                self._update_best_crop(obj, masks_here.get(obj.obj_id), frame)
            frame = self._draw_frame(frame, masks_here, tracked_objects,
                                     chunk_boxes.get(local_idx))
            writer.write(frame)

    @staticmethod
    def _release_memory():
        """Run the GC and empty the CUDA/MPS allocator caches to limit OOM risk."""
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif hasattr(torch, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()

    # ------------------------------------------------------------------
    @staticmethod
    def _mask_to_box(mask: Optional[np.ndarray]) -> Optional[np.ndarray]:
        """Tight ``[x1, y1, x2, y2]`` float32 box of a mask, or None if it is missing or too small."""
        if mask is None or mask.sum() < MIN_MASK_AREA:
            return None
        ys, xs = np.where(mask)
        return np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32)

    @staticmethod
    def _update_best_crop(obj, mask, frame):
        """Keep on ``obj`` the crop of the frame where its mask was largest.

        Sets ``obj.max_mask_area`` and ``obj.best_crop`` lazily.
        """
        if mask is None:
            return
        area = mask.sum()
        if area < MIN_MASK_AREA:
            return
        if not hasattr(obj, "max_mask_area"):
            obj.max_mask_area = 0
            obj.best_crop = None
        if area <= obj.max_mask_area:
            return
        obj.max_mask_area = area
        ys, xs = np.where(mask)
        h_f, w_f = frame.shape[:2]
        bx1, bx2 = max(0, int(xs.min())), min(w_f - 1, int(xs.max()))
        by1, by2 = max(0, int(ys.min())), min(h_f - 1, int(ys.max()))
        if bx2 > bx1 and by2 > by1:
            obj.best_crop = frame[by1:by2 + 1, bx1:bx2 + 1].copy()

    # ------------------------------------------------------------------
    def _draw_frame(self,
                    frame: np.ndarray,
                    masks: dict[int, np.ndarray],
                    tracked_objects: list[TrackedObject],
                    fallback_boxes: Optional[dict] = None) -> np.ndarray:
        """
        Draw bounding box + label + ID for every object (modifies ``frame``).

        Visible objects get a solid box derived from their mask. Missing
        objects get a thin box and a faded "[?]" label at
        ``fallback_boxes[obj_id]`` (their last known box for this frame) when
        given, else at ``obj.box``. No pixel mask overlay is drawn.
        """
        if frame is None:
            return frame
        for obj in tracked_objects:
            oid = obj.obj_id
            color = self.PALETTE[oid % len(self.PALETTE)]
            box = self._mask_to_box(masks.get(oid))
            if box is None:
                ghost = obj.box
                if fallback_boxes is not None:
                    ghost = fallback_boxes.get(oid, obj.box)
                x1, y1, x2, y2 = (int(v) for v in ghost)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)
                self._put_label(frame, f"{obj.label} #{oid} [?]",
                                x1, y1, color, alpha=0.4)
                continue
            bx1, by1, bx2, by2 = (int(v) for v in box)
            cv2.rectangle(frame, (bx1, by1), (bx2, by2), color, 2)
            self._put_label(frame, f"{obj.label} #{oid}", bx1, by1, color)
        return frame

    # ------------------------------------------------------------------
    @staticmethod
    def _put_label(frame: np.ndarray, text: str,
                   x: int, y: int, color: tuple,
                   alpha: float = 1.0):
        """Draw white ``text`` on a filled ``color`` tag whose bottom-left is (x, y).

        A tag that would run off the top of the image is pushed down so the
        text stays visible. With ``alpha < 0.9`` the tag is blended into the
        frame; only the tag region is touched. Modifies ``frame`` in place.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        scale = 0.5
        thickness = 1
        pad = 4
        (tw, th), _ = cv2.getTextSize(text, font, scale, thickness)
        tag_h = th + pad * 2
        x = int(x)
        y = max(int(y), tag_h)  # keep the tag and its text inside the frame
        top, bottom, right = y - tag_h, y, x + tw + pad * 2
        # Background rectangle
        if alpha >= 0.9:
            cv2.rectangle(frame, (x, top), (right, bottom), color, -1)
        else:
            img_h, img_w = frame.shape[:2]
            rx0, rx1 = max(0, x), min(img_w, right + 1)
            ry0, ry1 = max(0, top), min(img_h, bottom + 1)
            if rx1 > rx0 and ry1 > ry0:
                roi = frame[ry0:ry1, rx0:rx1]
                tint = np.empty_like(roi)
                tint[:] = color
                frame[ry0:ry1, rx0:rx1] = cv2.addWeighted(tint, alpha, roi, 1 - alpha, 0)
        # Text
        cv2.putText(frame, text, (x + pad, y - pad),
                    font, scale, (255, 255, 255), thickness, cv2.LINE_AA)

    # ------------------------------------------------------------------
    @staticmethod
    def _match_lost_to_dino(lost_objects: list[TrackedObject],
                            dino_boxes: np.ndarray,
                            dino_labels: list[str],
                            iou_threshold: float = 0.20
                            ) -> dict[int, np.ndarray]:
        """
        Greedily assign DINO detections to lost objects.

        A detection is a candidate for an object when its non-empty label and
        the object's label are substring-related (case-insensitive). Among the
        candidates, the one with the highest IoU against the object's last
        known box wins. Ties (e.g. all IoU 0) go to the detection whose center
        is closest. Each detection is used at most once.

        ``iou_threshold`` is accepted for backward compatibility. It is not a
        hard gate: a lost object has usually moved away from its last box.

        Returns:
            ``{obj_id: new_box}`` for recovered objects.
        """
        recovered = {}
        used_dino = set()
        for obj in lost_objects:
            name = obj.label.lower()
            ox1, oy1, ox2, oy2 = (float(v) for v in obj.box)
            ocx, ocy = (ox1 + ox2) / 2, (oy1 + oy2) / 2
            best_idx, best_key = None, None
            for d_idx, (d_box, d_label) in enumerate(zip(dino_boxes, dino_labels)):
                if d_idx in used_dino:
                    continue
                cand = str(d_label).strip().lower()
                # An empty label would be a substring of every name — reject it.
                if not cand or not (name in cand or cand in name):
                    continue
                dx1, dy1, dx2, dy2 = (float(v) for v in d_box)
                iw = max(0.0, min(ox2, dx2) - max(ox1, dx1))
                ih = max(0.0, min(oy2, dy2) - max(oy1, dy1))
                inter = iw * ih
                union = (ox2 - ox1) * (oy2 - oy1) + (dx2 - dx1) * (dy2 - dy1) - inter + 1e-6
                iou = inter / union
                dist = (((dx1 + dx2) / 2 - ocx) ** 2 + ((dy1 + dy2) / 2 - ocy) ** 2) ** 0.5
                key = (iou, -dist)
                if best_key is None or key > best_key:
                    best_key, best_idx = key, d_idx
            if best_idx is not None:
                recovered[obj.obj_id] = np.array(dino_boxes[best_idx], dtype=np.float32)
                used_dino.add(best_idx)
        return recovered