Claude committed on
Commit
9aa33d8
·
unverified ·
1 Parent(s): 88a545a

fix: forensic code trace fixes across all inspection modules

Browse files

Critical:
- masks.py: fix RLE encode double leading zero that corrupted masks
starting with foreground pixels; vectorize RLE loop with numpy
- router.py: add _parse_track_id() to prevent unhandled ValueError
crashes on malformed track IDs (7 locations)

High:
- attention.py: remove double inference in YOLO saliency, remove
dead code (PIL/torchvision imports), rename GradCAMExtractor to
ActivationSaliencyExtractor (no gradients were ever computed),
remove unused backward hook, add query parameter for Grounding
DINO, switch forward pass to torch.no_grad()
- router.py: add _find_track() helper fixing instance_id=0 being
skipped due to falsy `or` fallback

Medium:
- depth.py, attention.py, superres.py, pointcloud.py: convert all
caches from dict to OrderedDict with move_to_end for LRU eviction
- depth.py, superres.py, sam2_mask.py: store (model, lock) tuples
instead of monkey-patching .lock attribute onto model instances
- pointcloud.py: add depth_map/color_image shape validation, bbox
validation, and efficient bbox-scoped meshgrid allocation
- router.py: add format validation for mask endpoint, sam2_size
validation, type checks on POST body numeric fields, fix mutable
default body={} to body=None, deduplicate TRACK_COLORS to
module-level constant

Low:
- frames.py: change FileNotFoundError to ValueError for corrupt
videos, add bbox validation in crop_frame
- superres.py: enable tiling (tile=256) to prevent OOM on large crops

https://claude.ai/code/session_01XQ1edVcrdcMErbKF53r1aF

inspection/attention.py CHANGED
@@ -1,19 +1,18 @@
1
- """GradCAM-style attention heatmap generation for detector models.
2
 
3
  Produces per-object attention maps showing which regions of the input
4
  image the detector model focused on when detecting a particular object.
5
 
6
- For Transformers-based detectors (DETR, Grounding DINO) we use true
7
- GradCAM by hooking the backbone's last feature layer. For Ultralytics
8
- YOLO models we generate an activation-based saliency map from the
9
- model's internal feature maps (no gradient needed since YOLO doesn't
10
- easily support GradCAM due to its anchor-free detection head).
11
 
12
  Model instances are cached per-device for multi-GPU round-robin,
13
  matching the pattern used in inference.py.
14
  """
15
 
16
  import base64
 
17
  import logging
18
  import threading
19
  from typing import Dict, Optional, Tuple
@@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
26
 
27
  # ── In-memory attention cache ────────────────────────────────────
28
  # Key: (job_id, frame_idx, track_id_str) Value: heatmap (HxW float32 0-1)
29
- _attention_cache: Dict[Tuple[str, int, str], np.ndarray] = {}
30
  _cache_lock = threading.RLock()
31
  _MAX_CACHE_ENTRIES = 200
32
 
@@ -34,9 +33,13 @@ _MAX_CACHE_ENTRIES = 200
34
  def get_cached_attention(
35
  job_id: str, frame_idx: int, track_id: str
36
  ) -> Optional[np.ndarray]:
37
- """Return cached attention heatmap or None."""
38
  with _cache_lock:
39
- return _attention_cache.get((job_id, frame_idx, track_id))
 
 
 
 
40
 
41
 
42
  def set_cached_attention(
@@ -91,11 +94,11 @@ def _get_detector(detector_name: str, device: str):
91
  return detector
92
 
93
 
94
- # ── GradCAM for HF Transformers models (DETR, Grounding DINO) ───
95
 
96
 
97
  def _find_target_layer(model: torch.nn.Module) -> Optional[torch.nn.Module]:
98
- """Find the last convolutional or attention layer suitable for GradCAM.
99
 
100
  Tries several strategies in order:
101
  1. DETR ResNet backbone: model.model.backbone.conv_encoder.model.layer4
@@ -136,11 +139,17 @@ def _find_target_layer(model: torch.nn.Module) -> Optional[torch.nn.Module]:
136
  return last_conv
137
 
138
 
139
- class GradCAMExtractor:
140
- """Extract GradCAM heatmaps from a PyTorch model.
 
 
 
 
 
 
141
 
142
  Usage:
143
- extractor = GradCAMExtractor(model, target_layer)
144
  heatmap = extractor.generate(input_tensor, target_bbox)
145
  extractor.release() # remove hooks
146
  """
@@ -149,11 +158,9 @@ class GradCAMExtractor:
149
  self.model = model
150
  self.target_layer = target_layer
151
  self._activations: Optional[torch.Tensor] = None
152
- self._gradients: Optional[torch.Tensor] = None
153
 
154
- # Register hooks
155
  self._fwd_hook = target_layer.register_forward_hook(self._save_activation)
156
- self._bwd_hook = target_layer.register_full_backward_hook(self._save_gradient)
157
 
158
  def _save_activation(self, module, input, output):
159
  if isinstance(output, torch.Tensor):
@@ -161,12 +168,6 @@ class GradCAMExtractor:
161
  elif isinstance(output, (tuple, list)) and len(output) > 0:
162
  self._activations = output[0].detach()
163
 
164
- def _save_gradient(self, module, grad_input, grad_output):
165
- if isinstance(grad_output, (tuple, list)) and len(grad_output) > 0:
166
- self._gradients = grad_output[0].detach()
167
- elif isinstance(grad_output, torch.Tensor):
168
- self._gradients = grad_output.detach()
169
-
170
  def generate(
171
  self,
172
  input_tensor: torch.Tensor,
@@ -174,10 +175,14 @@ class GradCAMExtractor:
174
  frame_h: int,
175
  frame_w: int,
176
  ) -> np.ndarray:
177
- """Generate a GradCAM heatmap for a target bounding box.
 
 
 
 
178
 
179
  Args:
180
- input_tensor: Preprocessed model input tensor.
181
  target_bbox: [x1, y1, x2, y2] in original frame pixel coords.
182
  frame_h: Original frame height.
183
  frame_w: Original frame width.
@@ -186,20 +191,17 @@ class GradCAMExtractor:
186
  HxW float32 array normalized to [0, 1], at the model's
187
  feature map resolution (upscaled to frame size).
188
  """
189
- self.model.zero_grad()
190
  self._activations = None
191
- self._gradients = None
192
 
193
- # Enable gradients temporarily
194
  was_training = self.model.training
195
  self.model.eval()
196
 
197
- # Forward pass with gradients enabled on input
198
- with torch.enable_grad():
199
  outputs = self.model(**{k: v for k, v in input_tensor.items()})
200
 
201
  if self._activations is None:
202
- logger.warning("GradCAM: no activations captured; returning uniform map")
203
  return np.ones((frame_h, frame_w), dtype=np.float32) * 0.5
204
 
205
  # Use the activation map directly as a saliency proxy when
@@ -281,7 +283,6 @@ class GradCAMExtractor:
281
  def release(self):
282
  """Remove hooks from the model."""
283
  self._fwd_hook.remove()
284
- self._bwd_hook.remove()
285
 
286
 
287
  # ── YOLO saliency (activation-based, no gradients) ──────────────
@@ -295,7 +296,7 @@ def _yolo_saliency(
295
  """Generate an activation-based saliency map from a YOLO model.
296
 
297
  Uses the model's internal feature pyramid activations as a proxy
298
- for attention. This avoids the complexity of GradCAM with YOLO's
299
  anchor-free heads.
300
 
301
  Args:
@@ -308,17 +309,7 @@ def _yolo_saliency(
308
  """
309
  frame_h, frame_w = frame.shape[:2]
310
 
311
- # Run inference to get internal features
312
- results = yolo_model.predict(
313
- source=frame,
314
- device=yolo_model.device if hasattr(yolo_model, 'device') else None,
315
- conf=0.1,
316
- imgsz=640,
317
- verbose=False,
318
- )
319
-
320
- # Try to extract feature maps from the model internals
321
- # Ultralytics stores intermediate outputs during forward pass
322
  cam = None
323
 
324
  try:
@@ -339,15 +330,10 @@ def _yolo_saliency(
339
  def hook_fn(module, inp, out, store=activation):
340
  store["out"] = out.detach()
341
 
 
342
  handle = layer.register_forward_hook(hook_fn)
343
 
344
- # Re-run forward pass to capture activations
345
- rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
346
- from PIL import Image
347
- import torchvision.transforms as T
348
-
349
- img = Image.fromarray(rgb)
350
- # Use the same preprocessing as YOLO
351
  yolo_model.predict(
352
  source=frame,
353
  device=yolo_model.device if hasattr(yolo_model, 'device') else None,
@@ -428,6 +414,7 @@ def generate_attention_map(
428
  frame_idx: int,
429
  track_id: str,
430
  device: str = None,
 
431
  ) -> np.ndarray:
432
  """Generate an attention heatmap for a detected object.
433
 
@@ -443,6 +430,8 @@ def generate_attention_map(
443
  track_id: Track ID string (for caching).
444
  device: GPU device string (e.g. 'cuda:0'). If None, uses
445
  round-robin selection via next_device().
 
 
446
 
447
  Returns:
448
  HxW float32 heatmap normalized to [0, 1].
@@ -470,7 +459,7 @@ def generate_attention_map(
470
  logger.warning("YOLO saliency generation failed: %s", e)
471
 
472
  elif detector_name in ("detr_resnet50", "grounding_dino"):
473
- # Transformers models — use GradCAM on backbone
474
  try:
475
  detector = _get_detector(detector_name, device)
476
  with detector.lock:
@@ -478,14 +467,14 @@ def generate_attention_map(
478
  target_layer = _find_target_layer(model)
479
 
480
  if target_layer is not None:
481
- extractor = GradCAMExtractor(model, target_layer)
482
  try:
483
  # Prepare input
484
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
485
  processor = detector.processor
486
  if detector_name == "grounding_dino":
487
  inputs = processor(
488
- images=frame_rgb, text="object.", return_tensors="pt"
489
  )
490
  else:
491
  inputs = processor(images=frame_rgb, return_tensors="pt")
@@ -500,7 +489,7 @@ def generate_attention_map(
500
  "No suitable target layer found for %s", detector_name
501
  )
502
  except Exception as e:
503
- logger.warning("GradCAM generation failed for %s: %s", detector_name, e)
504
 
505
  # Fallback: Gaussian heatmap centered on bbox
506
  if heatmap is None:
@@ -518,8 +507,9 @@ def generate_attention_map(
518
 
519
 
520
  def heatmap_to_base64(heatmap: np.ndarray) -> str:
521
- """Encode heatmap as base64 float32 bytes."""
522
- raw = heatmap.astype(np.float32).tobytes()
 
523
  return base64.b64encode(raw).decode("ascii")
524
 
525
 
 
1
+ """Activation-based saliency heatmap generation for detector models.
2
 
3
  Produces per-object attention maps showing which regions of the input
4
  image the detector model focused on when detecting a particular object.
5
 
6
+ For all detector architectures we compute activation L2 norms from a
7
+ hooked backbone layer as a spatial saliency proxy. No gradients are
8
+ computed.
 
 
9
 
10
  Model instances are cached per-device for multi-GPU round-robin,
11
  matching the pattern used in inference.py.
12
  """
13
 
14
  import base64
15
+ import collections
16
  import logging
17
  import threading
18
  from typing import Dict, Optional, Tuple
 
25
 
26
  # ── In-memory attention cache ────────────────────────────────────
27
  # Key: (job_id, frame_idx, track_id_str) Value: heatmap (HxW float32 0-1)
28
+ _attention_cache: collections.OrderedDict[Tuple[str, int, str], np.ndarray] = collections.OrderedDict()
29
  _cache_lock = threading.RLock()
30
  _MAX_CACHE_ENTRIES = 200
31
 
 
33
  def get_cached_attention(
34
  job_id: str, frame_idx: int, track_id: str
35
  ) -> Optional[np.ndarray]:
36
+ """Return cached attention heatmap or None (LRU: moves hit to end)."""
37
  with _cache_lock:
38
+ key = (job_id, frame_idx, track_id)
39
+ val = _attention_cache.get(key)
40
+ if val is not None:
41
+ _attention_cache.move_to_end(key) # LRU behavior
42
+ return val
43
 
44
 
45
  def set_cached_attention(
 
94
  return detector
95
 
96
 
97
+ # ── Activation saliency for HF Transformers models (DETR, Grounding DINO) ──
98
 
99
 
100
  def _find_target_layer(model: torch.nn.Module) -> Optional[torch.nn.Module]:
101
+ """Find the last convolutional or attention layer suitable for saliency extraction.
102
 
103
  Tries several strategies in order:
104
  1. DETR ResNet backbone: model.model.backbone.conv_encoder.model.layer4
 
139
  return last_conv
140
 
141
 
142
+ class ActivationSaliencyExtractor:
143
+ """Extract activation-based saliency heatmaps from a PyTorch model.
144
+
145
+ Computes channel-wise L2 norm of the target layer's activations as
146
+ a saliency proxy. No gradients are computed — this is purely
147
+ activation-based. The approach works well for object detection
148
+ architectures where gradient-based targeting is unreliable due to
149
+ complex target matching in the loss function.
150
 
151
  Usage:
152
+ extractor = ActivationSaliencyExtractor(model, target_layer)
153
  heatmap = extractor.generate(input_tensor, target_bbox)
154
  extractor.release() # remove hooks
155
  """
 
158
  self.model = model
159
  self.target_layer = target_layer
160
  self._activations: Optional[torch.Tensor] = None
 
161
 
162
+ # Register forward hook to capture activations
163
  self._fwd_hook = target_layer.register_forward_hook(self._save_activation)
 
164
 
165
  def _save_activation(self, module, input, output):
166
  if isinstance(output, torch.Tensor):
 
168
  elif isinstance(output, (tuple, list)) and len(output) > 0:
169
  self._activations = output[0].detach()
170
 
 
 
 
 
 
 
171
  def generate(
172
  self,
173
  input_tensor: torch.Tensor,
 
175
  frame_h: int,
176
  frame_w: int,
177
  ) -> np.ndarray:
178
+ """Generate an activation-norm saliency map for a target bounding box.
179
+
180
+ Runs a forward pass through the model and uses the L2 norm of
181
+ the captured activations (channel dimension) as a spatial saliency
182
+ map. No gradients are computed.
183
 
184
  Args:
185
+ input_tensor: Preprocessed model input dict (from processor).
186
  target_bbox: [x1, y1, x2, y2] in original frame pixel coords.
187
  frame_h: Original frame height.
188
  frame_w: Original frame width.
 
191
  HxW float32 array normalized to [0, 1], at the model's
192
  feature map resolution (upscaled to frame size).
193
  """
 
194
  self._activations = None
 
195
 
 
196
  was_training = self.model.training
197
  self.model.eval()
198
 
199
+ # Forward pass (no gradients needed)
200
+ with torch.no_grad():
201
  outputs = self.model(**{k: v for k, v in input_tensor.items()})
202
 
203
  if self._activations is None:
204
+ logger.warning("Saliency: no activations captured; returning uniform map")
205
  return np.ones((frame_h, frame_w), dtype=np.float32) * 0.5
206
 
207
  # Use the activation map directly as a saliency proxy when
 
283
  def release(self):
284
  """Remove hooks from the model."""
285
  self._fwd_hook.remove()
 
286
 
287
 
288
  # ── YOLO saliency (activation-based, no gradients) ──────────────
 
296
  """Generate an activation-based saliency map from a YOLO model.
297
 
298
  Uses the model's internal feature pyramid activations as a proxy
299
+ for attention. This avoids the complexity of gradient-based methods with YOLO's
300
  anchor-free heads.
301
 
302
  Args:
 
309
  """
310
  frame_h, frame_w = frame.shape[:2]
311
 
312
+ # Extract feature maps via a forward hook on the model internals
 
 
 
 
 
 
 
 
 
 
313
  cam = None
314
 
315
  try:
 
330
  def hook_fn(module, inp, out, store=activation):
331
  store["out"] = out.detach()
332
 
333
+ # Register hook BEFORE the single predict call
334
  handle = layer.register_forward_hook(hook_fn)
335
 
336
+ # Run predict once to capture activations
 
 
 
 
 
 
337
  yolo_model.predict(
338
  source=frame,
339
  device=yolo_model.device if hasattr(yolo_model, 'device') else None,
 
414
  frame_idx: int,
415
  track_id: str,
416
  device: str = None,
417
+ query: str = "object.",
418
  ) -> np.ndarray:
419
  """Generate an attention heatmap for a detected object.
420
 
 
430
  track_id: Track ID string (for caching).
431
  device: GPU device string (e.g. 'cuda:0'). If None, uses
432
  round-robin selection via next_device().
433
+ query: Text query for open-vocabulary detectors (Grounding DINO).
434
+ Defaults to "object." for backward compatibility.
435
 
436
  Returns:
437
  HxW float32 heatmap normalized to [0, 1].
 
459
  logger.warning("YOLO saliency generation failed: %s", e)
460
 
461
  elif detector_name in ("detr_resnet50", "grounding_dino"):
462
+ # Transformers models — use activation saliency on backbone
463
  try:
464
  detector = _get_detector(detector_name, device)
465
  with detector.lock:
 
467
  target_layer = _find_target_layer(model)
468
 
469
  if target_layer is not None:
470
+ extractor = ActivationSaliencyExtractor(model, target_layer)
471
  try:
472
  # Prepare input
473
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
474
  processor = detector.processor
475
  if detector_name == "grounding_dino":
476
  inputs = processor(
477
+ images=frame_rgb, text=query, return_tensors="pt"
478
  )
479
  else:
480
  inputs = processor(images=frame_rgb, return_tensors="pt")
 
489
  "No suitable target layer found for %s", detector_name
490
  )
491
  except Exception as e:
492
+ logger.warning("Activation saliency failed for %s: %s", detector_name, e)
493
 
494
  # Fallback: Gaussian heatmap centered on bbox
495
  if heatmap is None:
 
507
 
508
 
509
  def heatmap_to_base64(heatmap: np.ndarray) -> str:
510
+ """Encode heatmap as base64 uint8 bytes (quantized from float32 [0,1])."""
511
+ quantized = (heatmap.clip(0, 1) * 255).astype(np.uint8)
512
+ raw = quantized.tobytes()
513
  return base64.b64encode(raw).decode("ascii")
514
 
515
 
inspection/depth.py CHANGED
@@ -9,9 +9,10 @@ matching the pattern used in inference.py.
9
  """
10
 
11
  import base64
 
12
  import logging
13
  import threading
14
- from typing import Dict, Optional, Tuple
15
 
16
  import cv2
17
  import numpy as np
@@ -20,7 +21,7 @@ logger = logging.getLogger(__name__)
20
 
21
  # ── In-memory depth cache ────────────────────────────────────────
22
  # Key: (job_id, frame_idx) Value: depth_map (HxW float32)
23
- _depth_cache: Dict[Tuple[str, int], np.ndarray] = {}
24
  _cache_lock = threading.RLock()
25
 
26
  # Limit cache size to avoid OOM
@@ -34,7 +35,11 @@ def _cache_key(job_id: str, frame_idx: int) -> Tuple[str, int]:
34
  def get_cached_depth(job_id: str, frame_idx: int) -> Optional[np.ndarray]:
35
  """Return cached depth map or None."""
36
  with _cache_lock:
37
- return _depth_cache.get(_cache_key(job_id, frame_idx))
 
 
 
 
38
 
39
 
40
  def set_cached_depth(job_id: str, frame_idx: int, depth_map: np.ndarray) -> None:
@@ -60,7 +65,7 @@ def clear_depth_cache(job_id: Optional[str] = None) -> None:
60
 
61
  # ── Per-device model cache ───────────────────────────────────────
62
 
63
- _estimators: Dict[str, object] = {}
64
  _load_lock = threading.Lock()
65
 
66
 
@@ -81,10 +86,9 @@ def _get_estimator(device: str):
81
  from models.depth_estimators.model_loader import load_depth_estimator_on_device
82
 
83
  estimator = load_depth_estimator_on_device("depth", device)
84
- estimator.lock = threading.RLock()
85
- _estimators[device] = estimator
86
  logger.info("Depth estimator loaded on %s", device)
87
- return estimator
88
 
89
 
90
  # ── Core inference ────────────────────────────────────────────────
@@ -115,8 +119,8 @@ def run_depth_on_frame(
115
  from inspection.gpu import next_device
116
  device = next_device()
117
 
118
- estimator = _get_estimator(device)
119
- with estimator.lock:
120
  result = estimator.predict(frame)
121
  depth_map = result.depth_map # HxW float32
122
 
 
9
  """
10
 
11
  import base64
12
+ import collections
13
  import logging
14
  import threading
15
+ from typing import Optional, Tuple
16
 
17
  import cv2
18
  import numpy as np
 
21
 
22
  # ── In-memory depth cache ────────────────────────────────────────
23
  # Key: (job_id, frame_idx) Value: depth_map (HxW float32)
24
+ _depth_cache: collections.OrderedDict = collections.OrderedDict()
25
  _cache_lock = threading.RLock()
26
 
27
  # Limit cache size to avoid OOM
 
35
  def get_cached_depth(job_id: str, frame_idx: int) -> Optional[np.ndarray]:
36
  """Return cached depth map or None."""
37
  with _cache_lock:
38
+ key = _cache_key(job_id, frame_idx)
39
+ value = _depth_cache.get(key)
40
+ if value is not None:
41
+ _depth_cache.move_to_end(key)
42
+ return value
43
 
44
 
45
  def set_cached_depth(job_id: str, frame_idx: int, depth_map: np.ndarray) -> None:
 
65
 
66
  # ── Per-device model cache ───────────────────────────────────────
67
 
68
+ _estimators: dict = {}
69
  _load_lock = threading.Lock()
70
 
71
 
 
86
  from models.depth_estimators.model_loader import load_depth_estimator_on_device
87
 
88
  estimator = load_depth_estimator_on_device("depth", device)
89
+ _estimators[device] = (estimator, threading.RLock())
 
90
  logger.info("Depth estimator loaded on %s", device)
91
+ return _estimators[device]
92
 
93
 
94
  # ── Core inference ────────────────────────────────────────────────
 
119
  from inspection.gpu import next_device
120
  device = next_device()
121
 
122
+ estimator, lock = _get_estimator(device)
123
+ with lock:
124
  result = estimator.predict(frame)
125
  depth_map = result.depth_map # HxW float32
126
 
inspection/frames.py CHANGED
@@ -29,7 +29,7 @@ def extract_frame(video_path: str, frame_idx: int) -> np.ndarray:
29
  """
30
  cap = cv2.VideoCapture(video_path)
31
  if not cap.isOpened():
32
- raise FileNotFoundError(f"Cannot open video: {video_path}")
33
 
34
  try:
35
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -50,7 +50,7 @@ def get_video_info(video_path: str) -> dict:
50
  """Return video metadata (total_frames, fps, width, height)."""
51
  cap = cv2.VideoCapture(video_path)
52
  if not cap.isOpened():
53
- raise FileNotFoundError(f"Cannot open video: {video_path}")
54
  try:
55
  return {
56
  "total_frames": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
@@ -77,8 +77,12 @@ def crop_frame(
77
  Returns:
78
  Cropped HxWx3 BGR numpy array.
79
  """
80
- h, w = frame.shape[:2]
81
  x1, y1, x2, y2 = bbox
 
 
 
 
 
82
 
83
  bw = x2 - x1
84
  bh = y2 - y1
@@ -103,6 +107,8 @@ def frame_to_jpeg(frame: np.ndarray, quality: int = 90) -> bytes:
103
  Returns:
104
  JPEG bytes.
105
  """
 
 
106
  encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
107
  success, buffer = cv2.imencode(".jpg", frame, encode_param)
108
  if not success:
 
29
  """
30
  cap = cv2.VideoCapture(video_path)
31
  if not cap.isOpened():
32
+ raise ValueError(f"Cannot open video file: {video_path}")
33
 
34
  try:
35
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
50
  """Return video metadata (total_frames, fps, width, height)."""
51
  cap = cv2.VideoCapture(video_path)
52
  if not cap.isOpened():
53
+ raise ValueError(f"Cannot open video file: {video_path}")
54
  try:
55
  return {
56
  "total_frames": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
 
77
  Returns:
78
  Cropped HxWx3 BGR numpy array.
79
  """
 
80
  x1, y1, x2, y2 = bbox
81
+ if x2 <= x1 or y2 <= y1:
82
+ raise ValueError(
83
+ f"Invalid bbox: [{x1}, {y1}, {x2}, {y2}] β€” must have x2 > x1 and y2 > y1"
84
+ )
85
+ h, w = frame.shape[:2]
86
 
87
  bw = x2 - x1
88
  bh = y2 - y1
 
107
  Returns:
108
  JPEG bytes.
109
  """
110
+ if frame.dtype != np.uint8:
111
+ frame = frame.astype(np.uint8)
112
  encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
113
  success, buffer = cv2.imencode(".jpg", frame, encode_param)
114
  if not success:
inspection/masks.py CHANGED
@@ -27,21 +27,14 @@ def rle_encode(mask: np.ndarray) -> Dict:
27
  # Flatten in column-major (Fortran) order per COCO convention
28
  flat = mask.astype(np.uint8).ravel(order="F")
29
 
30
- # Compute run lengths
31
- counts: List[int] = []
32
- prev = 0
33
- run = 0
34
- for val in flat:
35
- if val == prev:
36
- run += 1
37
- else:
38
- counts.append(run)
39
- run = 1
40
- prev = val
41
- counts.append(run)
42
 
43
  # Ensure counts starts with a run of 0s (COCO convention)
44
- if len(counts) > 0 and flat[0] == 1:
45
  counts.insert(0, 0)
46
 
47
  return {"counts": counts, "size": [h, w]}
 
27
  # Flatten in column-major (Fortran) order per COCO convention
28
  flat = mask.astype(np.uint8).ravel(order="F")
29
 
30
+ # Compute run lengths using vectorized numpy operations
31
+ changes = np.diff(flat)
32
+ change_indices = np.where(changes != 0)[0] + 1
33
+ boundaries = np.concatenate(([0], change_indices, [len(flat)]))
34
+ counts: List[int] = np.diff(boundaries).tolist()
 
 
 
 
 
 
 
35
 
36
  # Ensure counts starts with a run of 0s (COCO convention)
37
+ if flat[0] == 1:
38
  counts.insert(0, 0)
39
 
40
  return {"counts": counts, "size": [h, w]}
inspection/pointcloud.py CHANGED
@@ -7,9 +7,10 @@ efficient frontend consumption.
7
  """
8
 
9
  import base64
 
10
  import logging
11
  import threading
12
- from typing import Dict, Optional, Tuple
13
 
14
  import cv2
15
  import numpy as np
@@ -19,7 +20,7 @@ logger = logging.getLogger(__name__)
19
  # ── In-memory point cloud cache ──────────────────────────────────
20
  # Key: (job_id, frame_idx, track_id_str, max_points)
21
  # Value: dict with positions, colors, num_points, bbox_3d
22
- _pointcloud_cache: Dict[Tuple[str, int, str, int], dict] = {}
23
  _cache_lock = threading.RLock()
24
  _MAX_CACHE_ENTRIES = 100
25
 
@@ -29,7 +30,11 @@ def get_cached_pointcloud(
29
  ) -> Optional[dict]:
30
  """Return cached point cloud data or None."""
31
  with _cache_lock:
32
- return _pointcloud_cache.get((job_id, frame_idx, track_id, max_points))
 
 
 
 
33
 
34
 
35
  def set_cached_pointcloud(
@@ -104,8 +109,20 @@ def depth_to_pointcloud(
104
  - positions: Nx3 float32 array of XYZ coordinates
105
  - colors: Nx3 uint8 array of RGB colors
106
  """
 
 
 
 
 
107
  h, w = depth_map.shape[:2]
108
 
 
 
 
 
 
 
 
109
  if focal_length is None:
110
  focal_length = estimate_focal_length(w, h)
111
 
@@ -113,32 +130,36 @@ def depth_to_pointcloud(
113
  cx = w / 2.0
114
  cy = h / 2.0
115
 
116
- # Create pixel coordinate grids
117
- u_coords, v_coords = np.meshgrid(np.arange(w), np.arange(h))
118
-
119
- # Determine which pixels to include
120
- valid = np.ones((h, w), dtype=bool)
121
-
122
  if mask is not None:
123
- valid &= mask.astype(bool)
 
 
 
 
 
 
 
124
  elif bbox is not None:
125
- x1, y1, x2, y2 = bbox
126
- x1 = max(0, int(x1))
127
- y1 = max(0, int(y1))
128
- x2 = min(w, int(x2))
129
- y2 = min(h, int(y2))
130
- bbox_mask = np.zeros((h, w), dtype=bool)
131
- bbox_mask[y1:y2, x1:x2] = True
132
- valid &= bbox_mask
133
-
134
- # Exclude zero/NaN depth
135
- valid &= depth_map > 0
136
- valid &= np.isfinite(depth_map)
137
-
138
- # Extract valid pixel coordinates and depth values
139
- v_valid = v_coords[valid]
140
- u_valid = u_coords[valid]
141
- z_valid = depth_map[valid].astype(np.float32)
 
 
 
142
 
143
  if len(z_valid) == 0:
144
  return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8)
 
7
  """
8
 
9
  import base64
10
+ import collections
11
  import logging
12
  import threading
13
+ from typing import Optional, Tuple
14
 
15
  import cv2
16
  import numpy as np
 
20
  # ── In-memory point cloud cache ──────────────────────────────────
21
  # Key: (job_id, frame_idx, track_id_str, max_points)
22
  # Value: dict with positions, colors, num_points, bbox_3d
23
+ _pointcloud_cache: collections.OrderedDict = collections.OrderedDict()
24
  _cache_lock = threading.RLock()
25
  _MAX_CACHE_ENTRIES = 100
26
 
 
30
  ) -> Optional[dict]:
31
  """Return cached point cloud data or None."""
32
  with _cache_lock:
33
+ key = (job_id, frame_idx, track_id, max_points)
34
+ value = _pointcloud_cache.get(key)
35
+ if value is not None:
36
+ _pointcloud_cache.move_to_end(key)
37
+ return value
38
 
39
 
40
  def set_cached_pointcloud(
 
109
  - positions: Nx3 float32 array of XYZ coordinates
110
  - colors: Nx3 uint8 array of RGB colors
111
  """
112
+ if depth_map.shape[:2] != color_image.shape[:2]:
113
+ raise ValueError(
114
+ f"Shape mismatch: depth_map {depth_map.shape[:2]} vs color_image {color_image.shape[:2]}"
115
+ )
116
+
117
  h, w = depth_map.shape[:2]
118
 
119
+ if bbox is not None:
120
+ x1_raw, y1_raw, x2_raw, y2_raw = bbox
121
+ if x2_raw <= x1_raw or y2_raw <= y1_raw:
122
+ raise ValueError(
123
+ f"Invalid bbox: must have x2 > x1 and y2 > y1, got ({x1_raw}, {y1_raw}, {x2_raw}, {y2_raw})"
124
+ )
125
+
126
  if focal_length is None:
127
  focal_length = estimate_focal_length(w, h)
128
 
 
130
  cx = w / 2.0
131
  cy = h / 2.0
132
 
 
 
 
 
 
 
133
  if mask is not None:
134
+ # Full-frame meshgrid needed for arbitrary mask shapes
135
+ u_coords, v_coords = np.meshgrid(np.arange(w), np.arange(h))
136
+ valid = mask.astype(bool)
137
+ valid &= depth_map > 0
138
+ valid &= np.isfinite(depth_map)
139
+ v_valid = v_coords[valid]
140
+ u_valid = u_coords[valid]
141
+ z_valid = depth_map[valid].astype(np.float32)
142
  elif bbox is not None:
143
+ # Efficient bbox-scoped meshgrid: only allocate for the bbox region
144
+ x1 = max(0, int(bbox[0]))
145
+ y1 = max(0, int(bbox[1]))
146
+ x2 = min(w, int(bbox[2]))
147
+ y2 = min(h, int(bbox[3]))
148
+ u_coords_1d = np.arange(x1, x2)
149
+ v_coords_1d = np.arange(y1, y2)
150
+ u_grid, v_grid = np.meshgrid(u_coords_1d, v_coords_1d)
151
+ depth_region = depth_map[y1:y2, x1:x2]
152
+ valid_region = (depth_region > 0) & np.isfinite(depth_region)
153
+ v_valid = v_grid[valid_region]
154
+ u_valid = u_grid[valid_region]
155
+ z_valid = depth_region[valid_region].astype(np.float32)
156
+ else:
157
+ # Full-frame: no mask or bbox
158
+ u_coords, v_coords = np.meshgrid(np.arange(w), np.arange(h))
159
+ valid = (depth_map > 0) & np.isfinite(depth_map)
160
+ v_valid = v_coords[valid]
161
+ u_valid = u_coords[valid]
162
+ z_valid = depth_map[valid].astype(np.float32)
163
 
164
  if len(z_valid) == 0:
165
  return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8)
inspection/router.py CHANGED
@@ -18,6 +18,33 @@ logger = logging.getLogger(__name__)
18
 
19
  router = APIRouter(prefix="/inspect", tags=["inspection"])
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def _get_job_or_404(job_id: str):
23
  """Retrieve a job from storage or raise 404."""
@@ -79,14 +106,9 @@ async def get_frame(
79
  from jobs.storage import get_track_data
80
 
81
  tracks = get_track_data(job_id, frame_idx)
82
- target = None
83
  # Parse "T01" -> 1 for instance_id matching
84
- instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
85
- for t in tracks:
86
- tid = t.get("instance_id") or t.get("track_id")
87
- if tid == instance_id or tid == track_id:
88
- target = t
89
- break
90
  if target and "bbox" in target:
91
  frame = crop_frame(frame, target["bbox"], padding=padding)
92
  else:
@@ -128,6 +150,9 @@ async def get_mask(
128
  from jobs.storage import get_mask_data, get_track_data
129
  from inspection.masks import mask_area, rle_decode, mask_to_png_bytes
130
 
 
 
 
131
  job = _get_job_or_404(job_id)
132
  if job.mode != "segmentation":
133
  raise HTTPException(
@@ -136,7 +161,7 @@ async def get_mask(
136
  )
137
 
138
  # Parse track_id: accept "T01" or "1", store as int internally
139
- instance_id = int(track_id.replace("T", "")) if isinstance(track_id, str) and track_id.startswith("T") else int(track_id)
140
 
141
  rle = get_mask_data(job_id, frame_idx, instance_id)
142
  if rle is None:
@@ -163,13 +188,7 @@ async def get_mask(
163
 
164
  h, w = rle["size"]
165
 
166
- # Deterministic color per track ID
167
- TRACK_COLORS = [
168
- [255, 0, 128], [0, 255, 128], [128, 0, 255], [255, 128, 0],
169
- [0, 128, 255], [128, 255, 0], [255, 0, 0], [0, 255, 0],
170
- [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255],
171
- ]
172
- color = TRACK_COLORS[instance_id % len(TRACK_COLORS)]
173
 
174
  return JSONResponse({
175
  "track_id": track_id,
@@ -236,7 +255,7 @@ async def generate_mask(
236
  job_id: str,
237
  frame_idx: int,
238
  track_id: str,
239
- body: dict = {},
240
  ):
241
  """Generate a segmentation mask on-demand using SAM2 with bbox prompt.
242
 
@@ -250,6 +269,9 @@ async def generate_mask(
250
  from inspection.masks import rle_encode, mask_area
251
  from jobs.storage import get_track_data, set_mask_data, get_mask_data
252
 
 
 
 
253
  job = _get_job_or_404(job_id)
254
  input_path = job.input_video_path
255
  if not input_path or not Path(input_path).exists():
@@ -258,28 +280,25 @@ async def generate_mask(
258
  _validate_frame_idx(input_path, frame_idx)
259
 
260
  # Parse track_id
261
- instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
 
 
 
 
 
 
262
 
263
  # Check if mask already exists (cached)
264
  existing = get_mask_data(job_id, frame_idx, instance_id)
265
  if existing:
266
  # Return cached mask
267
  h, w = existing["size"]
268
- TRACK_COLORS = [
269
- [255, 0, 128], [0, 255, 128], [128, 0, 255], [255, 128, 0],
270
- [0, 128, 255], [128, 255, 0], [255, 0, 0], [0, 255, 0],
271
- [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255],
272
- ]
273
- color = TRACK_COLORS[instance_id % len(TRACK_COLORS)]
274
 
275
  tracks = get_track_data(job_id, frame_idx)
276
- label = ""
277
- bbox = None
278
- for t in tracks:
279
- if t.get("instance_id") == instance_id or t.get("track_id") == track_id:
280
- label = t.get("label", "")
281
- bbox = t.get("bbox")
282
- break
283
 
284
  return JSONResponse({
285
  "track_id": track_id,
@@ -297,17 +316,11 @@ async def generate_mask(
297
 
298
  # Get track bbox
299
  tracks = get_track_data(job_id, frame_idx)
300
- target = None
301
- for t in tracks:
302
- tid = t.get("instance_id") or t.get("track_id")
303
- if tid == instance_id or tid == track_id:
304
- target = t
305
- break
306
  if not target or "bbox" not in target:
307
  raise HTTPException(status_code=404, detail=f"Track {track_id} not found in frame {frame_idx}.")
308
 
309
  bbox = target["bbox"]
310
- sam2_size = body.get("sam2_size", "large")
311
 
312
  # Extract frame and run SAM2 (in thread pool — GPU work)
313
  device = next_device()
@@ -319,12 +332,7 @@ async def generate_mask(
319
  set_mask_data(job_id, frame_idx, instance_id, rle)
320
 
321
  h, w = rle["size"]
322
- TRACK_COLORS = [
323
- [255, 0, 128], [0, 255, 128], [128, 0, 255], [255, 128, 0],
324
- [0, 128, 255], [128, 255, 0], [255, 0, 0], [0, 255, 0],
325
- [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255],
326
- ]
327
- color = TRACK_COLORS[instance_id % len(TRACK_COLORS)]
328
 
329
  return JSONResponse({
330
  "track_id": track_id,
@@ -392,13 +400,8 @@ async def get_depth(
392
  from jobs.storage import get_track_data
393
 
394
  tracks = get_track_data(job_id, frame_idx)
395
- instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
396
- target = None
397
- for t in tracks:
398
- tid = t.get("instance_id") or t.get("track_id")
399
- if tid == instance_id or tid == track_id:
400
- target = t
401
- break
402
  if target and "bbox" in target:
403
  depth_map = crop_depth_to_bbox(depth_map, target["bbox"])
404
  else:
@@ -493,13 +496,8 @@ async def get_attention(
493
  from jobs.storage import get_track_data
494
 
495
  tracks = get_track_data(job_id, frame_idx)
496
- instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
497
- target = None
498
- for t in tracks:
499
- tid = t.get("instance_id") or t.get("track_id")
500
- if tid == instance_id or tid == track_id:
501
- target = t
502
- break
503
 
504
  if not target or "bbox" not in target:
505
  raise HTTPException(
@@ -532,7 +530,7 @@ async def get_attention(
532
  "width": w,
533
  "height": h,
534
  "data_b64": data_b64,
535
- "format": "float32",
536
  })
537
 
538
  # format == "overlay"
@@ -548,7 +546,7 @@ async def get_attention(
548
  async def super_resolve(
549
  job_id: str,
550
  frame_idx: int,
551
- body: dict = {},
552
  ):
553
  """Super-resolve a track's cropped region using Real-ESRGAN (or Lanczos4 fallback).
554
 
@@ -565,15 +563,22 @@ async def super_resolve(
565
  from inspection.frames import extract_frame
566
  from inspection.superres import superresolve_crop, image_to_png
567
 
 
 
 
568
  track_id = body.get("track_id")
569
  if not track_id:
570
  raise HTTPException(status_code=400, detail="track_id is required in request body.")
571
 
572
  scale = body.get("scale", 4)
 
 
573
  if scale not in (2, 4):
574
  raise HTTPException(status_code=400, detail="scale must be 2 or 4.")
575
 
576
  padding = body.get("padding", 0.15)
 
 
577
  if not (0.0 <= padding <= 2.0):
578
  raise HTTPException(status_code=400, detail="padding must be between 0.0 and 2.0.")
579
 
@@ -588,13 +593,8 @@ async def super_resolve(
588
  from jobs.storage import get_track_data
589
 
590
  tracks = get_track_data(job_id, frame_idx)
591
- instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
592
- target = None
593
- for t in tracks:
594
- tid = t.get("instance_id") or t.get("track_id")
595
- if tid == instance_id or tid == track_id:
596
- target = t
597
- break
598
 
599
  if not target or "bbox" not in target:
600
  raise HTTPException(
@@ -640,7 +640,7 @@ async def super_resolve(
640
  async def get_pointcloud(
641
  job_id: str,
642
  frame_idx: int,
643
- body: dict = {},
644
  ):
645
  """Generate a 3D point cloud for a tracked object.
646
 
@@ -661,11 +661,16 @@ async def get_pointcloud(
661
  from inspection.depth import run_depth_on_frame
662
  from inspection.pointcloud import generate_pointcloud
663
 
 
 
 
664
  track_id = body.get("track_id")
665
  if not track_id:
666
  raise HTTPException(status_code=400, detail="track_id is required in request body.")
667
 
668
  max_points = body.get("max_points", 50000)
 
 
669
  if max_points < 1 or max_points > 500000:
670
  raise HTTPException(status_code=400, detail="max_points must be between 1 and 500000.")
671
 
@@ -680,13 +685,8 @@ async def get_pointcloud(
680
  from jobs.storage import get_track_data, get_mask_data
681
 
682
  tracks = get_track_data(job_id, frame_idx)
683
- instance_id = int(track_id.replace("T", "")) if track_id.startswith("T") else int(track_id)
684
- target = None
685
- for t in tracks:
686
- tid = t.get("instance_id") or t.get("track_id")
687
- if tid == instance_id or tid == track_id:
688
- target = t
689
- break
690
 
691
  if not target or "bbox" not in target:
692
  raise HTTPException(
 
18
 
19
  router = APIRouter(prefix="/inspect", tags=["inspection"])
20
 
21
+ # Deterministic color palette for track visualization
22
+ _TRACK_COLORS = [
23
+ [255, 0, 128], [0, 255, 128], [128, 0, 255], [255, 128, 0],
24
+ [0, 128, 255], [128, 255, 0], [255, 0, 0], [0, 255, 0],
25
+ [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255],
26
+ ]
27
+
28
+
29
+ def _parse_track_id(track_id: str) -> int:
30
+ """Parse track ID string (e.g. 'T03' or '3') to integer instance_id."""
31
+ raw = track_id.lstrip("T") if track_id.startswith("T") else track_id
32
+ try:
33
+ return int(raw)
34
+ except ValueError:
35
+ raise HTTPException(status_code=400, detail=f"Invalid track_id '{track_id}'. Expected format: 'T01' or '1'.")
36
+
37
+
38
+ def _find_track(tracks: list, instance_id: int, track_id: str):
39
+ """Find a track by instance_id or track_id string."""
40
+ for t in tracks:
41
+ tid = t.get("instance_id")
42
+ if tid is not None and tid == instance_id:
43
+ return t
44
+ if tid is None and t.get("track_id") == track_id:
45
+ return t
46
+ return None
47
+
48
 
49
  def _get_job_or_404(job_id: str):
50
  """Retrieve a job from storage or raise 404."""
 
106
  from jobs.storage import get_track_data
107
 
108
  tracks = get_track_data(job_id, frame_idx)
 
109
  # Parse "T01" -> 1 for instance_id matching
110
+ instance_id = _parse_track_id(track_id)
111
+ target = _find_track(tracks, instance_id, track_id)
 
 
 
 
112
  if target and "bbox" in target:
113
  frame = crop_frame(frame, target["bbox"], padding=padding)
114
  else:
 
150
  from jobs.storage import get_mask_data, get_track_data
151
  from inspection.masks import mask_area, rle_decode, mask_to_png_bytes
152
 
153
+ if format not in ("json", "png"):
154
+ raise HTTPException(status_code=400, detail=f"Invalid format '{format}'. Must be 'json' or 'png'.")
155
+
156
  job = _get_job_or_404(job_id)
157
  if job.mode != "segmentation":
158
  raise HTTPException(
 
161
  )
162
 
163
  # Parse track_id: accept "T01" or "1", store as int internally
164
+ instance_id = _parse_track_id(track_id)
165
 
166
  rle = get_mask_data(job_id, frame_idx, instance_id)
167
  if rle is None:
 
188
 
189
  h, w = rle["size"]
190
 
191
+ color = _TRACK_COLORS[instance_id % len(_TRACK_COLORS)]
 
 
 
 
 
 
192
 
193
  return JSONResponse({
194
  "track_id": track_id,
 
255
  job_id: str,
256
  frame_idx: int,
257
  track_id: str,
258
+ body: Optional[dict] = None,
259
  ):
260
  """Generate a segmentation mask on-demand using SAM2 with bbox prompt.
261
 
 
269
  from inspection.masks import rle_encode, mask_area
270
  from jobs.storage import get_track_data, set_mask_data, get_mask_data
271
 
272
+ if body is None:
273
+ body = {}
274
+
275
  job = _get_job_or_404(job_id)
276
  input_path = job.input_video_path
277
  if not input_path or not Path(input_path).exists():
 
280
  _validate_frame_idx(input_path, frame_idx)
281
 
282
  # Parse track_id
283
+ instance_id = _parse_track_id(track_id)
284
+
285
+ # Validate sam2_size early
286
+ sam2_size = body.get("sam2_size", "large")
287
+ valid_sizes = ("small", "base", "large")
288
+ if sam2_size not in valid_sizes:
289
+ raise HTTPException(status_code=400, detail=f"Invalid sam2_size '{sam2_size}'. Must be one of: {valid_sizes}")
290
 
291
  # Check if mask already exists (cached)
292
  existing = get_mask_data(job_id, frame_idx, instance_id)
293
  if existing:
294
  # Return cached mask
295
  h, w = existing["size"]
296
+ color = _TRACK_COLORS[instance_id % len(_TRACK_COLORS)]
 
 
 
 
 
297
 
298
  tracks = get_track_data(job_id, frame_idx)
299
+ target = _find_track(tracks, instance_id, track_id)
300
+ label = target.get("label", "") if target else ""
301
+ bbox = target.get("bbox") if target else None
 
 
 
 
302
 
303
  return JSONResponse({
304
  "track_id": track_id,
 
316
 
317
  # Get track bbox
318
  tracks = get_track_data(job_id, frame_idx)
319
+ target = _find_track(tracks, instance_id, track_id)
 
 
 
 
 
320
  if not target or "bbox" not in target:
321
  raise HTTPException(status_code=404, detail=f"Track {track_id} not found in frame {frame_idx}.")
322
 
323
  bbox = target["bbox"]
 
324
 
325
  # Extract frame and run SAM2 (in thread pool — GPU work)
326
  device = next_device()
 
332
  set_mask_data(job_id, frame_idx, instance_id, rle)
333
 
334
  h, w = rle["size"]
335
+ color = _TRACK_COLORS[instance_id % len(_TRACK_COLORS)]
 
 
 
 
 
336
 
337
  return JSONResponse({
338
  "track_id": track_id,
 
400
  from jobs.storage import get_track_data
401
 
402
  tracks = get_track_data(job_id, frame_idx)
403
+ instance_id = _parse_track_id(track_id)
404
+ target = _find_track(tracks, instance_id, track_id)
 
 
 
 
 
405
  if target and "bbox" in target:
406
  depth_map = crop_depth_to_bbox(depth_map, target["bbox"])
407
  else:
 
496
  from jobs.storage import get_track_data
497
 
498
  tracks = get_track_data(job_id, frame_idx)
499
+ instance_id = _parse_track_id(track_id)
500
+ target = _find_track(tracks, instance_id, track_id)
 
 
 
 
 
501
 
502
  if not target or "bbox" not in target:
503
  raise HTTPException(
 
530
  "width": w,
531
  "height": h,
532
  "data_b64": data_b64,
533
+ "format": "uint8",
534
  })
535
 
536
  # format == "overlay"
 
546
  async def super_resolve(
547
  job_id: str,
548
  frame_idx: int,
549
+ body: Optional[dict] = None,
550
  ):
551
  """Super-resolve a track's cropped region using Real-ESRGAN (or Lanczos4 fallback).
552
 
 
563
  from inspection.frames import extract_frame
564
  from inspection.superres import superresolve_crop, image_to_png
565
 
566
+ if body is None:
567
+ body = {}
568
+
569
  track_id = body.get("track_id")
570
  if not track_id:
571
  raise HTTPException(status_code=400, detail="track_id is required in request body.")
572
 
573
  scale = body.get("scale", 4)
574
+ if not isinstance(scale, int):
575
+ raise HTTPException(status_code=400, detail="scale must be an integer.")
576
  if scale not in (2, 4):
577
  raise HTTPException(status_code=400, detail="scale must be 2 or 4.")
578
 
579
  padding = body.get("padding", 0.15)
580
+ if not isinstance(padding, (int, float)):
581
+ raise HTTPException(status_code=400, detail="padding must be a number.")
582
  if not (0.0 <= padding <= 2.0):
583
  raise HTTPException(status_code=400, detail="padding must be between 0.0 and 2.0.")
584
 
 
593
  from jobs.storage import get_track_data
594
 
595
  tracks = get_track_data(job_id, frame_idx)
596
+ instance_id = _parse_track_id(track_id)
597
+ target = _find_track(tracks, instance_id, track_id)
 
 
 
 
 
598
 
599
  if not target or "bbox" not in target:
600
  raise HTTPException(
 
640
  async def get_pointcloud(
641
  job_id: str,
642
  frame_idx: int,
643
+ body: Optional[dict] = None,
644
  ):
645
  """Generate a 3D point cloud for a tracked object.
646
 
 
661
  from inspection.depth import run_depth_on_frame
662
  from inspection.pointcloud import generate_pointcloud
663
 
664
+ if body is None:
665
+ body = {}
666
+
667
  track_id = body.get("track_id")
668
  if not track_id:
669
  raise HTTPException(status_code=400, detail="track_id is required in request body.")
670
 
671
  max_points = body.get("max_points", 50000)
672
+ if not isinstance(max_points, int):
673
+ raise HTTPException(status_code=400, detail="max_points must be an integer.")
674
  if max_points < 1 or max_points > 500000:
675
  raise HTTPException(status_code=400, detail="max_points must be between 1 and 500000.")
676
 
 
685
  from jobs.storage import get_track_data, get_mask_data
686
 
687
  tracks = get_track_data(job_id, frame_idx)
688
+ instance_id = _parse_track_id(track_id)
689
+ target = _find_track(tracks, instance_id, track_id)
 
 
 
 
 
690
 
691
  if not target or "bbox" not in target:
692
  raise HTTPException(
inspection/sam2_mask.py CHANGED
@@ -17,7 +17,7 @@ import torch
17
  logger = logging.getLogger(__name__)
18
 
19
  # ── Per-device SAM2 predictor cache ──────────────────────────────
20
- # Key: (sam2_size, device) Value: SAM2ImagePredictor with RLock
21
  _predictor_cache: Dict[Tuple[str, str], object] = {}
22
  _pred_load_lock = threading.Lock()
23
 
@@ -53,10 +53,9 @@ def _get_predictor(sam2_size: str = "large", device: str = None):
53
 
54
  sam2_model = build_sam2(cfg, ckpt, device=device)
55
  predictor = SAM2ImagePredictor(sam2_model)
56
- predictor.lock = threading.RLock()
57
- _predictor_cache[key] = predictor
58
  logger.info("SAM2 (%s) predictor loaded on %s", sam2_size, device)
59
- return predictor
60
 
61
 
62
  def generate_mask_from_bbox(
@@ -85,9 +84,9 @@ def generate_mask_from_bbox(
85
 
86
  # SAM2 expects RGB
87
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
88
- predictor = _get_predictor(sam2_size, device)
89
 
90
- with predictor.lock:
91
  with torch.inference_mode():
92
  predictor.set_image(rgb)
93
  input_box = np.array(bbox)
 
17
  logger = logging.getLogger(__name__)
18
 
19
  # ── Per-device SAM2 predictor cache ──────────────────────────────
20
+ # Key: (sam2_size, device) Value: (SAM2ImagePredictor, RLock) tuple
21
  _predictor_cache: Dict[Tuple[str, str], object] = {}
22
  _pred_load_lock = threading.Lock()
23
 
 
53
 
54
  sam2_model = build_sam2(cfg, ckpt, device=device)
55
  predictor = SAM2ImagePredictor(sam2_model)
56
+ _predictor_cache[key] = (predictor, threading.RLock())
 
57
  logger.info("SAM2 (%s) predictor loaded on %s", sam2_size, device)
58
+ return _predictor_cache[key]
59
 
60
 
61
  def generate_mask_from_bbox(
 
84
 
85
  # SAM2 expects RGB
86
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
87
+ predictor, lock = _get_predictor(sam2_size, device)
88
 
89
+ with lock:
90
  with torch.inference_mode():
91
  predictor.set_image(rgb)
92
  input_box = np.array(bbox)
inspection/superres.py CHANGED
@@ -10,9 +10,10 @@ Model instances are cached per-device for multi-GPU round-robin,
10
  matching the pattern used in inference.py.
11
  """
12
 
 
13
  import logging
14
  import threading
15
- from typing import Dict, Optional, Tuple
16
 
17
  import cv2
18
  import numpy as np
@@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
21
 
22
  # ── In-memory super-resolution cache ─────────────────────────────
23
  # Key: (job_id, frame_idx, track_id_str, scale) Value: upscaled BGR uint8 ndarray
24
- _superres_cache: Dict[Tuple[str, int, str, int], np.ndarray] = {}
25
  _cache_lock = threading.RLock()
26
  _MAX_CACHE_ENTRIES = 100
27
 
@@ -31,7 +32,11 @@ def get_cached_superres(
31
  ) -> Optional[np.ndarray]:
32
  """Return cached super-resolved image or None."""
33
  with _cache_lock:
34
- return _superres_cache.get((job_id, frame_idx, track_id, scale))
 
 
 
 
35
 
36
 
37
  def set_cached_superres(
@@ -58,7 +63,7 @@ def clear_superres_cache(job_id: Optional[str] = None) -> None:
58
 
59
  # ── Per-device Real-ESRGAN model cache ───────────────────────────
60
 
61
- _realesrgan_models: Dict[str, object] = {}
62
  _realesrgan_load_lock = threading.Lock()
63
  _realesrgan_available: Optional[bool] = None
64
 
@@ -118,16 +123,15 @@ def _get_realesrgan_model(device: str):
118
  scale=4,
119
  model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
120
  model=rrdb_model,
121
- tile=0, # No tiling for small crops
122
  tile_pad=10,
123
  pre_pad=0,
124
  half=device.startswith("cuda"),
125
  device=device,
126
  )
127
- model.lock = threading.RLock()
128
- _realesrgan_models[device] = model
129
  logger.info("Real-ESRGAN x4plus loaded on %s", device)
130
- return model
131
  except Exception as e:
132
  logger.warning("Failed to load Real-ESRGAN on %s: %s", device, e)
133
  return None
@@ -158,11 +162,12 @@ def upscale_image(
158
  from inspection.gpu import next_device
159
  device = next_device()
160
 
161
- model = _get_realesrgan_model(device)
162
- if model is not None:
163
  try:
 
164
  # Real-ESRGAN expects BGR uint8 input
165
- with model.lock:
166
  output, _ = model.enhance(image, outscale=scale)
167
  return output, "realesrgan"
168
  except Exception as e:
 
10
  matching the pattern used in inference.py.
11
  """
12
 
13
+ import collections
14
  import logging
15
  import threading
16
+ from typing import Optional, Tuple
17
 
18
  import cv2
19
  import numpy as np
 
22
 
23
  # ── In-memory super-resolution cache ─────────────────────────────
24
  # Key: (job_id, frame_idx, track_id_str, scale) Value: upscaled BGR uint8 ndarray
25
+ _superres_cache: collections.OrderedDict = collections.OrderedDict()
26
  _cache_lock = threading.RLock()
27
  _MAX_CACHE_ENTRIES = 100
28
 
 
32
  ) -> Optional[np.ndarray]:
33
  """Return cached super-resolved image or None."""
34
  with _cache_lock:
35
+ key = (job_id, frame_idx, track_id, scale)
36
+ value = _superres_cache.get(key)
37
+ if value is not None:
38
+ _superres_cache.move_to_end(key)
39
+ return value
40
 
41
 
42
  def set_cached_superres(
 
63
 
64
  # ── Per-device Real-ESRGAN model cache ───────────────────────────
65
 
66
+ _realesrgan_models: dict = {}
67
  _realesrgan_load_lock = threading.Lock()
68
  _realesrgan_available: Optional[bool] = None
69
 
 
123
  scale=4,
124
  model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
125
  model=rrdb_model,
126
+ tile=256, # Enable tiling to prevent OOM on large crops
127
  tile_pad=10,
128
  pre_pad=0,
129
  half=device.startswith("cuda"),
130
  device=device,
131
  )
132
+ _realesrgan_models[device] = (model, threading.RLock())
 
133
  logger.info("Real-ESRGAN x4plus loaded on %s", device)
134
+ return _realesrgan_models[device]
135
  except Exception as e:
136
  logger.warning("Failed to load Real-ESRGAN on %s: %s", device, e)
137
  return None
 
162
  from inspection.gpu import next_device
163
  device = next_device()
164
 
165
+ model_tuple = _get_realesrgan_model(device)
166
+ if model_tuple is not None:
167
  try:
168
+ model, lock = model_tuple
169
  # Real-ESRGAN expects BGR uint8 input
170
+ with lock:
171
  output, _ = model.enhance(image, outscale=scale)
172
  return output, "realesrgan"
173
  except Exception as e: