Spaces:
Paused
Paused
Zhen Ye Claude Opus 4.6 (1M context) committed on
Commit ·
d6c4e26
1
Parent(s): 8d09cca
fix: simplify review — fix mask NMS bug, remove dead code, hoist imports
Browse files- _mask_nms: fix greedy suppression bug (j<=i compared raw indices,
not sort positions — masks escaped suppression). Use order[pos+1:]
- _mask_nms: batch GPU computation (single matmul + single CPU sync
instead of O(n^2) CUDA syncs per mask pair)
- _apply_nms: hoist torch import (already module-level) and
batched_nms import to module scope
- real-backend.js: remove dead isNotRelevant (always false after
mission_relevant filter)
- explain.js: clear _explainCache on job change to prevent stale
data across different videos
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- demo/js/explain.js +8 -1
- demo/js/real-backend.js +1 -2
- inference.py +3 -6
- models/segmenters/grounded_sam2.py +14 -14
demo/js/explain.js
CHANGED
|
@@ -16,12 +16,19 @@ const LIGHTEN_MAP = {
|
|
| 16 |
};
|
| 17 |
|
| 18 |
let _explainAbort = null;
|
| 19 |
-
|
|
|
|
| 20 |
|
| 21 |
async function loadExplainability(jobId, trackId) {
|
| 22 |
const panel = document.getElementById('explainPanel');
|
| 23 |
if (!panel) return;
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
// Check cache
|
| 26 |
if (_explainCache[trackId]) {
|
| 27 |
renderExplainGraph(_explainCache[trackId], panel);
|
|
|
|
| 16 |
};
|
| 17 |
|
| 18 |
let _explainAbort = null;
|
| 19 |
+
let _explainCache = {};
|
| 20 |
+
let _explainCacheJobId = null;
|
| 21 |
|
| 22 |
async function loadExplainability(jobId, trackId) {
|
| 23 |
const panel = document.getElementById('explainPanel');
|
| 24 |
if (!panel) return;
|
| 25 |
|
| 26 |
+
// Clear cache on job change
|
| 27 |
+
if (_explainCacheJobId !== jobId) {
|
| 28 |
+
_explainCache = {};
|
| 29 |
+
_explainCacheJobId = jobId;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
// Check cache
|
| 33 |
if (_explainCache[trackId]) {
|
| 34 |
renderExplainGraph(_explainCache[trackId], panel);
|
demo/js/real-backend.js
CHANGED
|
@@ -681,10 +681,9 @@ function renderTrackListFromData(tracks) {
|
|
| 681 |
|
| 682 |
filtered.forEach((t, idx) => {
|
| 683 |
const isSelected = ISR.STATE.selectedTrackId === t.track_id;
|
| 684 |
-
const isNotRelevant = t.mission_relevant === false;
|
| 685 |
|
| 686 |
const card = document.createElement('div');
|
| 687 |
-
card.className = 'track-card' + (isSelected ? ' active' : '')
|
| 688 |
card.dataset.trackId = t.track_id;
|
| 689 |
card.style.animation = `track-card-enter 0.3s ease ${idx * 30}ms both`;
|
| 690 |
|
|
|
|
| 681 |
|
| 682 |
filtered.forEach((t, idx) => {
|
| 683 |
const isSelected = ISR.STATE.selectedTrackId === t.track_id;
|
|
|
|
| 684 |
|
| 685 |
const card = document.createElement('div');
|
| 686 |
+
card.className = 'track-card' + (isSelected ? ' active' : '');
|
| 687 |
card.dataset.trackId = t.track_id;
|
| 688 |
card.style.animation = `track-card-enter 0.3s ease ${idx * 30}ms both`;
|
| 689 |
|
inference.py
CHANGED
|
@@ -316,6 +316,7 @@ def _build_detection_records(
|
|
| 316 |
|
| 317 |
|
| 318 |
from utils.tracker import ByteTracker
|
|
|
|
| 319 |
|
| 320 |
|
| 321 |
def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) -> List[Dict[str, Any]]:
|
|
@@ -330,13 +331,9 @@ def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) ->
|
|
| 330 |
if len(detections) <= 1:
|
| 331 |
return detections
|
| 332 |
|
| 333 |
-
import torch
|
| 334 |
-
from utils.tiling import batched_nms
|
| 335 |
-
|
| 336 |
boxes = torch.tensor([d["bbox"] for d in detections], dtype=torch.float32)
|
| 337 |
scores = torch.tensor([d["score"] for d in detections], dtype=torch.float32)
|
| 338 |
-
|
| 339 |
-
label_to_id = {}
|
| 340 |
label_ids = []
|
| 341 |
for d in detections:
|
| 342 |
lbl = d["label"]
|
|
@@ -345,7 +342,7 @@ def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) ->
|
|
| 345 |
label_ids.append(label_to_id[lbl])
|
| 346 |
labels = torch.tensor(label_ids, dtype=torch.int64)
|
| 347 |
|
| 348 |
-
keep =
|
| 349 |
return [detections[i] for i in keep.tolist()]
|
| 350 |
|
| 351 |
|
|
|
|
| 316 |
|
| 317 |
|
| 318 |
from utils.tracker import ByteTracker
|
| 319 |
+
from utils.tiling import batched_nms as _batched_nms
|
| 320 |
|
| 321 |
|
| 322 |
def _apply_nms(detections: List[Dict[str, Any]], iou_threshold: float = 0.45) -> List[Dict[str, Any]]:
|
|
|
|
| 331 |
if len(detections) <= 1:
|
| 332 |
return detections
|
| 333 |
|
|
|
|
|
|
|
|
|
|
| 334 |
boxes = torch.tensor([d["bbox"] for d in detections], dtype=torch.float32)
|
| 335 |
scores = torch.tensor([d["score"] for d in detections], dtype=torch.float32)
|
| 336 |
+
label_to_id: Dict[str, int] = {}
|
|
|
|
| 337 |
label_ids = []
|
| 338 |
for d in detections:
|
| 339 |
lbl = d["label"]
|
|
|
|
| 342 |
label_ids.append(label_to_id[lbl])
|
| 343 |
labels = torch.tensor(label_ids, dtype=torch.int64)
|
| 344 |
|
| 345 |
+
keep = _batched_nms(boxes, scores, labels, iou_threshold)
|
| 346 |
return [detections[i] for i in keep.tolist()]
|
| 347 |
|
| 348 |
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -95,28 +95,28 @@ class MaskDictionary:
|
|
| 95 |
if n <= 1:
|
| 96 |
return list(range(n))
|
| 97 |
|
| 98 |
-
#
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
suppressed = [False] * n
|
| 101 |
-
|
| 102 |
-
# Sort by area descending (keep larger masks)
|
| 103 |
-
order = sorted(range(n), key=lambda i: areas[i], reverse=True)
|
| 104 |
-
|
| 105 |
keep = []
|
| 106 |
-
for i in order:
|
| 107 |
if suppressed[i]:
|
| 108 |
continue
|
| 109 |
keep.append(i)
|
| 110 |
-
for j in order:
|
| 111 |
-
if
|
| 112 |
continue
|
| 113 |
-
# Only suppress same-label masks
|
| 114 |
if labels[i] != labels[j]:
|
| 115 |
continue
|
| 116 |
-
|
| 117 |
-
inter = int((masks[i] & masks[j]).sum())
|
| 118 |
-
union = areas[i] + areas[j] - inter
|
| 119 |
-
if union > 0 and inter / union > iou_threshold:
|
| 120 |
suppressed[j] = True
|
| 121 |
|
| 122 |
return sorted(keep)
|
|
|
|
| 95 |
if n <= 1:
|
| 96 |
return list(range(n))
|
| 97 |
|
| 98 |
+
# Batch-compute areas and pairwise IoU on GPU (single sync)
|
| 99 |
+
flat = masks.view(n, -1).float()
|
| 100 |
+
areas = flat.sum(dim=1) # (n,)
|
| 101 |
+
inter_matrix = flat @ flat.T # (n, n)
|
| 102 |
+
union_matrix = areas.unsqueeze(1) + areas.unsqueeze(0) - inter_matrix
|
| 103 |
+
iou_matrix = (inter_matrix / union_matrix.clamp(min=1)).cpu().numpy()
|
| 104 |
+
areas_cpu = areas.cpu().numpy()
|
| 105 |
+
|
| 106 |
+
# Greedy suppression: sort by area descending, suppress smaller overlapping same-label masks
|
| 107 |
+
order = sorted(range(n), key=lambda i: areas_cpu[i], reverse=True)
|
| 108 |
suppressed = [False] * n
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
keep = []
|
| 110 |
+
for pos, i in enumerate(order):
|
| 111 |
if suppressed[i]:
|
| 112 |
continue
|
| 113 |
keep.append(i)
|
| 114 |
+
for j in order[pos + 1:]:
|
| 115 |
+
if suppressed[j]:
|
| 116 |
continue
|
|
|
|
| 117 |
if labels[i] != labels[j]:
|
| 118 |
continue
|
| 119 |
+
if iou_matrix[i, j] > iou_threshold:
|
|
|
|
|
|
|
|
|
|
| 120 |
suppressed[j] = True
|
| 121 |
|
| 122 |
return sorted(keep)
|