Zhen Ye Claude Opus 4.6 (1M context) committed on
Commit
d6c4e26
·
1 Parent(s): 8d09cca

fix: simplify review — fix mask NMS bug, remove dead code, hoist imports

Browse files

- _mask_nms: fix greedy suppression bug (j<=i compared raw indices,
not sort positions — masks escaped suppression). Use order[pos+1:]
- _mask_nms: batch GPU computation (single matmul + single CPU sync
instead of O(n^2) CUDA syncs per mask pair)
- _apply_nms: hoist torch import (already module-level) and
batched_nms import to module scope
- real-backend.js: remove dead isNotRelevant (always false after
mission_relevant filter)
- explain.js: clear _explainCache on job change to prevent stale
data across different videos

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

demo/js/explain.js CHANGED
@@ -16,12 +16,19 @@ const LIGHTEN_MAP = {
16
  };
17
 
18
  let _explainAbort = null;
19
- const _explainCache = {};
 
20
 
21
  async function loadExplainability(jobId, trackId) {
22
  const panel = document.getElementById('explainPanel');
23
  if (!panel) return;
24
 
 
 
 
 
 
 
25
  // Check cache
26
  if (_explainCache[trackId]) {
27
  renderExplainGraph(_explainCache[trackId], panel);
 
16
  };
17
 
18
  let _explainAbort = null;
19
+ let _explainCache = {};
20
+ let _explainCacheJobId = null;
21
 
22
  async function loadExplainability(jobId, trackId) {
23
  const panel = document.getElementById('explainPanel');
24
  if (!panel) return;
25
 
26
+ // Clear cache on job change
27
+ if (_explainCacheJobId !== jobId) {
28
+ _explainCache = {};
29
+ _explainCacheJobId = jobId;
30
+ }
31
+
32
  // Check cache
33
  if (_explainCache[trackId]) {
34
  renderExplainGraph(_explainCache[trackId], panel);
demo/js/real-backend.js CHANGED
@@ -681,10 +681,9 @@ function renderTrackListFromData(tracks) {
681
 
682
  filtered.forEach((t, idx) => {
683
  const isSelected = ISR.STATE.selectedTrackId === t.track_id;
684
- const isNotRelevant = t.mission_relevant === false;
685
 
686
  const card = document.createElement('div');
687
- card.className = 'track-card' + (isSelected ? ' active' : '') + (isNotRelevant ? ' not-relevant' : '');
688
  card.dataset.trackId = t.track_id;
689
  card.style.animation = `track-card-enter 0.3s ease ${idx * 30}ms both`;
690
 
 
681
 
682
  filtered.forEach((t, idx) => {
683
  const isSelected = ISR.STATE.selectedTrackId === t.track_id;
 
684
 
685
  const card = document.createElement('div');
686
+ card.className = 'track-card' + (isSelected ? ' active' : '');
687
  card.dataset.trackId = t.track_id;
688
  card.style.animation = `track-card-enter 0.3s ease ${idx * 30}ms both`;
689
 
inference.py CHANGED
@@ -316,6 +316,7 @@ def _build_detection_records(
316
 
317
 
318
  from utils.tracker import ByteTracker
 
319
 
320
 
321
  def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) -> List[Dict[str, Any]]:
@@ -330,13 +331,9 @@ def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) ->
330
  if len(detections) <= 1:
331
  return detections
332
 
333
- import torch
334
- from utils.tiling import batched_nms
335
-
336
  boxes = torch.tensor([d["bbox"] for d in detections], dtype=torch.float32)
337
  scores = torch.tensor([d["score"] for d in detections], dtype=torch.float32)
338
- # Map labels to integer class IDs for batched NMS
339
- label_to_id = {}
340
  label_ids = []
341
  for d in detections:
342
  lbl = d["label"]
@@ -345,7 +342,7 @@ def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) ->
345
  label_ids.append(label_to_id[lbl])
346
  labels = torch.tensor(label_ids, dtype=torch.int64)
347
 
348
- keep = batched_nms(boxes, scores, labels, iou_threshold)
349
  return [detections[i] for i in keep.tolist()]
350
 
351
 
 
316
 
317
 
318
  from utils.tracker import ByteTracker
319
+ from utils.tiling import batched_nms as _batched_nms
320
 
321
 
322
  def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) -> List[Dict[str, Any]]:
 
331
  if len(detections) <= 1:
332
  return detections
333
 
 
 
 
334
  boxes = torch.tensor([d["bbox"] for d in detections], dtype=torch.float32)
335
  scores = torch.tensor([d["score"] for d in detections], dtype=torch.float32)
336
+ label_to_id: Dict[str, int] = {}
 
337
  label_ids = []
338
  for d in detections:
339
  lbl = d["label"]
 
342
  label_ids.append(label_to_id[lbl])
343
  labels = torch.tensor(label_ids, dtype=torch.int64)
344
 
345
+ keep = _batched_nms(boxes, scores, labels, iou_threshold)
346
  return [detections[i] for i in keep.tolist()]
347
 
348
 
models/segmenters/grounded_sam2.py CHANGED
@@ -95,28 +95,28 @@ class MaskDictionary:
95
  if n <= 1:
96
  return list(range(n))
97
 
98
- # Compute mask areas
99
- areas = [int(masks[i].sum()) for i in range(n)]
 
 
 
 
 
 
 
 
100
  suppressed = [False] * n
101
-
102
- # Sort by area descending (keep larger masks)
103
- order = sorted(range(n), key=lambda i: areas[i], reverse=True)
104
-
105
  keep = []
106
- for i in order:
107
  if suppressed[i]:
108
  continue
109
  keep.append(i)
110
- for j in order:
111
- if j <= i or suppressed[j]:
112
  continue
113
- # Only suppress same-label masks
114
  if labels[i] != labels[j]:
115
  continue
116
- # Compute mask IoU
117
- inter = int((masks[i] & masks[j]).sum())
118
- union = areas[i] + areas[j] - inter
119
- if union > 0 and inter / union > iou_threshold:
120
  suppressed[j] = True
121
 
122
  return sorted(keep)
 
95
  if n <= 1:
96
  return list(range(n))
97
 
98
+ # Batch-compute areas and pairwise IoU on GPU (single sync)
99
+ flat = masks.view(n, -1).float()
100
+ areas = flat.sum(dim=1) # (n,)
101
+ inter_matrix = flat @ flat.T # (n, n)
102
+ union_matrix = areas.unsqueeze(1) + areas.unsqueeze(0) - inter_matrix
103
+ iou_matrix = (inter_matrix / union_matrix.clamp(min=1)).cpu().numpy()
104
+ areas_cpu = areas.cpu().numpy()
105
+
106
+ # Greedy suppression: sort by area descending, suppress smaller overlapping same-label masks
107
+ order = sorted(range(n), key=lambda i: areas_cpu[i], reverse=True)
108
  suppressed = [False] * n
 
 
 
 
109
  keep = []
110
+ for pos, i in enumerate(order):
111
  if suppressed[i]:
112
  continue
113
  keep.append(i)
114
+ for j in order[pos + 1:]:
115
+ if suppressed[j]:
116
  continue
 
117
  if labels[i] != labels[j]:
118
  continue
119
+ if iou_matrix[i, j] > iou_threshold:
 
 
 
120
  suppressed[j] = True
121
 
122
  return sorted(keep)