Zhen Ye Claude Opus 4.6 (1M context) committed on
Commit
61b921a
·
1 Parent(s): 3727802

feat: mission-relevance coloring — red for relevant objects, gray for non-relevant

Browse files

Backend-rendered video now uses GPT assessment verdicts to color bboxes
(detection mode) and masks (segmentation mode) red for mission-relevant
objects and gray for non-relevant. Unassessed tracks keep default colors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. inference.py +58 -8
  2. jobs/storage.py +14 -0
  3. models/isr/loop.py +7 -0
inference.py CHANGED
@@ -21,7 +21,7 @@ from models.model_loader import load_detector, load_detector_on_device
21
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
22
  from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
23
  from utils.video import StreamingVideoWriter
24
- from jobs.storage import set_track_data, store_latest_frame
25
  from inspection.masks import rle_encode
26
  import tempfile
27
  import json as json_module
@@ -136,12 +136,37 @@ def _color_for_label(label: str) -> Tuple[int, int, int]:
136
  return (blue, green, red)
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def draw_boxes(
140
  frame: np.ndarray,
141
  boxes: np.ndarray,
142
  labels: Optional[Sequence[int]] = None,
143
  queries: Optional[Sequence[str]] = None,
144
  label_names: Optional[Sequence[str]] = None,
 
145
  ) -> np.ndarray:
146
  output = frame.copy()
147
  if boxes is None:
@@ -158,7 +183,10 @@ def draw_boxes(
158
  label = f"label_{label_idx}"
159
  else:
160
  label = f"label_{idx}"
161
- color = (128, 128, 128) if not label else _color_for_label(label)
 
 
 
162
  cv2.rectangle(output, (x1, y1), (x2, y2), color, thickness=2)
163
  if label:
164
  font = cv2.FONT_HERSHEY_SIMPLEX
@@ -190,6 +218,7 @@ def draw_masks(
190
  masks: np.ndarray,
191
  alpha: float = 0.65,
192
  labels: Optional[Sequence[str]] = None,
 
193
  ) -> np.ndarray:
194
  output = frame.copy()
195
  if masks is None or len(masks) == 0:
@@ -206,9 +235,12 @@ def draw_masks(
206
  label = None
207
  if labels and idx < len(labels):
208
  label = labels[idx]
209
- # Use a fallback key for consistent color even when no label text
210
- color_key = label if label else f"object_{idx}"
211
- color = _color_for_label(color_key)
 
 
 
212
  overlay[mask_bool] = color
213
  output = cv2.addWeighted(output, 1.0, overlay, alpha, 0)
214
  contours, _ = cv2.findContours(
@@ -928,7 +960,12 @@ def run_inference(
928
  if dets:
929
  display_boxes = np.array([d['bbox'] for d in dets])
930
  display_labels = [d.get('label', 'obj') for d in dets]
931
- p_frame = draw_boxes(p_frame, display_boxes, label_names=display_labels)
 
 
 
 
 
932
 
933
  writer.write(p_frame)
934
 
@@ -1054,6 +1091,7 @@ def _gsam2_render_frame(
1054
  height: int,
1055
  width: int,
1056
  frame_store=None,
 
1057
  ) -> np.ndarray:
1058
  """Render a single GSAM2 tracking frame (masks only, no bboxes). CPU-only."""
1059
  if frame_store is not None:
@@ -1069,8 +1107,9 @@ def _gsam2_render_frame(
1069
 
1070
  masks_list: List[np.ndarray] = []
1071
  mask_labels: List[str] = []
 
1072
 
1073
- for _obj_id, obj_info in frame_objects.items():
1074
  mask = obj_info.mask
1075
  label = obj_info.class_name
1076
  if mask is not None:
@@ -1086,10 +1125,16 @@ def _gsam2_render_frame(
1086
  ).astype(bool)
1087
  masks_list.append(mask_np)
1088
  mask_labels.append(label)
 
 
 
 
 
 
1089
 
1090
  if masks_list:
1091
  # Draw masks with labels — no bboxes for segmentation mode
1092
- frame = draw_masks(frame, np.stack(masks_list), labels=mask_labels)
1093
 
1094
  return frame
1095
 
@@ -1199,10 +1244,15 @@ def run_grounded_sam2_tracking(
1199
  if _perf_metrics is not None:
1200
  _t_r = time.perf_counter()
1201
 
 
 
 
 
1202
  frm = _gsam2_render_frame(
1203
  frame_dir, frame_names, fidx, fobjs,
1204
  height, width,
1205
  frame_store=frame_store,
 
1206
  )
1207
 
1208
  if _perf_metrics is not None:
 
21
  from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
22
  from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
23
  from utils.video import StreamingVideoWriter
24
+ from jobs.storage import set_track_data, store_latest_frame, get_job_storage
25
  from inspection.masks import rle_encode
26
  import tempfile
27
  import json as json_module
 
136
  return (blue, green, red)
137
 
138
 
139
+ # Mission-relevance colors (BGR)
140
+ _COLOR_MISSION_RELEVANT = (0, 0, 255) # red
141
+ _COLOR_NOT_RELEVANT = (128, 128, 128) # gray
142
+
143
+
144
+ def _mission_colors_for_dets(
145
+ dets: list,
146
+ verdicts: Dict[str, bool],
147
+ ) -> Optional[List[Tuple[int, int, int]]]:
148
+ """Return per-detection BGR colors based on mission verdicts, or None if no verdicts."""
149
+ if not verdicts:
150
+ return None
151
+ colors = []
152
+ for d in dets:
153
+ tid = d.get("track_id")
154
+ if tid and tid in verdicts:
155
+ colors.append(_COLOR_MISSION_RELEVANT if verdicts[tid] else _COLOR_NOT_RELEVANT)
156
+ else:
157
+ # Unassessed — use default hash color
158
+ label = d.get("label", "obj")
159
+ colors.append(_color_for_label(label))
160
+ return colors
161
+
162
+
163
  def draw_boxes(
164
  frame: np.ndarray,
165
  boxes: np.ndarray,
166
  labels: Optional[Sequence[int]] = None,
167
  queries: Optional[Sequence[str]] = None,
168
  label_names: Optional[Sequence[str]] = None,
169
+ colors: Optional[Sequence[Tuple[int, int, int]]] = None,
170
  ) -> np.ndarray:
171
  output = frame.copy()
172
  if boxes is None:
 
183
  label = f"label_{label_idx}"
184
  else:
185
  label = f"label_{idx}"
186
+ if colors is not None and idx < len(colors):
187
+ color = colors[idx]
188
+ else:
189
+ color = (128, 128, 128) if not label else _color_for_label(label)
190
  cv2.rectangle(output, (x1, y1), (x2, y2), color, thickness=2)
191
  if label:
192
  font = cv2.FONT_HERSHEY_SIMPLEX
 
218
  masks: np.ndarray,
219
  alpha: float = 0.65,
220
  labels: Optional[Sequence[str]] = None,
221
+ colors: Optional[Sequence[Tuple[int, int, int]]] = None,
222
  ) -> np.ndarray:
223
  output = frame.copy()
224
  if masks is None or len(masks) == 0:
 
235
  label = None
236
  if labels and idx < len(labels):
237
  label = labels[idx]
238
+ if colors is not None and idx < len(colors):
239
+ color = colors[idx]
240
+ else:
241
+ # Use a fallback key for consistent color even when no label text
242
+ color_key = label if label else f"object_{idx}"
243
+ color = _color_for_label(color_key)
244
  overlay[mask_bool] = color
245
  output = cv2.addWeighted(output, 1.0, overlay, alpha, 0)
246
  contours, _ = cv2.findContours(
 
960
  if dets:
961
  display_boxes = np.array([d['bbox'] for d in dets])
962
  display_labels = [d.get('label', 'obj') for d in dets]
963
+ # Apply mission-relevance colors (red/gray) if verdicts available
964
+ det_colors = None
965
+ if job_id:
966
+ verdicts = get_job_storage().get_mission_verdicts(job_id)
967
+ det_colors = _mission_colors_for_dets(dets, verdicts)
968
+ p_frame = draw_boxes(p_frame, display_boxes, label_names=display_labels, colors=det_colors)
969
 
970
  writer.write(p_frame)
971
 
 
1091
  height: int,
1092
  width: int,
1093
  frame_store=None,
1094
+ mission_verdicts: Optional[Dict[str, bool]] = None,
1095
  ) -> np.ndarray:
1096
  """Render a single GSAM2 tracking frame (masks only, no bboxes). CPU-only."""
1097
  if frame_store is not None:
 
1107
 
1108
  masks_list: List[np.ndarray] = []
1109
  mask_labels: List[str] = []
1110
+ mask_colors: Optional[List[Tuple[int, int, int]]] = None if not mission_verdicts else []
1111
 
1112
+ for obj_id, obj_info in frame_objects.items():
1113
  mask = obj_info.mask
1114
  label = obj_info.class_name
1115
  if mask is not None:
 
1125
  ).astype(bool)
1126
  masks_list.append(mask_np)
1127
  mask_labels.append(label)
1128
+ if mask_colors is not None:
1129
+ tid = str(obj_id)
1130
+ if tid in mission_verdicts:
1131
+ mask_colors.append(_COLOR_MISSION_RELEVANT if mission_verdicts[tid] else _COLOR_NOT_RELEVANT)
1132
+ else:
1133
+ mask_colors.append(_color_for_label(label or f"object_{obj_id}"))
1134
 
1135
  if masks_list:
1136
  # Draw masks with labels — no bboxes for segmentation mode
1137
+ frame = draw_masks(frame, np.stack(masks_list), labels=mask_labels, colors=mask_colors)
1138
 
1139
  return frame
1140
 
 
1244
  if _perf_metrics is not None:
1245
  _t_r = time.perf_counter()
1246
 
1247
+ # Fetch mission verdicts for coloring
1248
+ seg_verdicts = None
1249
+ if job_id:
1250
+ seg_verdicts = get_job_storage().get_mission_verdicts(job_id) or None
1251
  frm = _gsam2_render_frame(
1252
  frame_dir, frame_names, fidx, fobjs,
1253
  height, width,
1254
  frame_store=frame_store,
1255
+ mission_verdicts=seg_verdicts,
1256
  )
1257
 
1258
  if _perf_metrics is not None:
jobs/storage.py CHANGED
@@ -41,6 +41,7 @@ class JobStorage:
41
  self._tracks: Dict[str, Dict[int, list]] = {} # job_id -> {frame_idx -> tracks}
42
  self._latest_frames: Dict[str, any] = {} # job_id -> np.ndarray
43
  self._mask_data: Dict[str, Dict[str, any]] = {} # job_id -> {f"{frame_idx}:{track_id}" -> rle_dict}
 
44
  self._lock = RLock()
45
 
46
  def create(self, job: JobInfo) -> None:
@@ -93,6 +94,18 @@ class JobStorage:
93
  key = f"{frame_idx}:{track_id}"
94
  return self._mask_data.get(job_id, {}).get(key)
95
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def get_all_masks_for_frame(self, job_id: str, frame_idx: int) -> dict:
97
  """Return {track_id: rle_dict} for all objects in a frame."""
98
  with self._lock:
@@ -122,6 +135,7 @@ class JobStorage:
122
  self._tracks.pop(job_id, None)
123
  self._latest_frames.pop(job_id, None)
124
  self._mask_data.pop(job_id, None)
 
125
  shutil.rmtree(get_job_directory(job_id), ignore_errors=True)
126
 
127
  def cleanup_expired(self, max_age: timedelta) -> None:
 
41
  self._tracks: Dict[str, Dict[int, list]] = {} # job_id -> {frame_idx -> tracks}
42
  self._latest_frames: Dict[str, any] = {} # job_id -> np.ndarray
43
  self._mask_data: Dict[str, Dict[str, any]] = {} # job_id -> {f"{frame_idx}:{track_id}" -> rle_dict}
44
+ self._mission_verdicts: Dict[str, Dict[str, bool]] = {} # job_id -> {track_id -> mission_relevant}
45
  self._lock = RLock()
46
 
47
  def create(self, job: JobInfo) -> None:
 
94
  key = f"{frame_idx}:{track_id}"
95
  return self._mask_data.get(job_id, {}).get(key)
96
 
97
+ def set_mission_verdict(self, job_id: str, track_id: str, relevant: bool) -> None:
98
+ """Cache a mission-relevance verdict for a track."""
99
+ with self._lock:
100
+ if job_id not in self._mission_verdicts:
101
+ self._mission_verdicts[job_id] = {}
102
+ self._mission_verdicts[job_id][track_id] = relevant
103
+
104
+ def get_mission_verdicts(self, job_id: str) -> Dict[str, bool]:
105
+ """Return {track_id: mission_relevant} for all assessed tracks."""
106
+ with self._lock:
107
+ return dict(self._mission_verdicts.get(job_id, {}))
108
+
109
  def get_all_masks_for_frame(self, job_id: str, frame_idx: int) -> dict:
110
  """Return {track_id: rle_dict} for all objects in a frame."""
111
  with self._lock:
 
135
  self._tracks.pop(job_id, None)
136
  self._latest_frames.pop(job_id, None)
137
  self._mask_data.pop(job_id, None)
138
+ self._mission_verdicts.pop(job_id, None)
139
  shutil.rmtree(get_job_directory(job_id), ignore_errors=True)
140
 
141
  def cleanup_expired(self, max_age: timedelta) -> None:
models/isr/loop.py CHANGED
@@ -117,6 +117,13 @@ async def run_isr_assessor_loop(
117
  def _merge_verdicts(storage, job_id: str, verdicts: dict, assessment_frame_idx: int) -> None:
118
  """Merge verdict data into all stored frames for matching track_ids."""
119
  with storage._lock:
 
 
 
 
 
 
 
120
  frames = storage._tracks.get(job_id, {})
121
  for frame_idx, frame_tracks in frames.items():
122
  for det in frame_tracks:
 
117
  def _merge_verdicts(storage, job_id: str, verdicts: dict, assessment_frame_idx: int) -> None:
118
  """Merge verdict data into all stored frames for matching track_ids."""
119
  with storage._lock:
120
+ # Update mission verdict cache for backend rendering
121
+ for tid, v in verdicts.items():
122
+ relevant = v.get("mission_relevant", True)
123
+ if job_id not in storage._mission_verdicts:
124
+ storage._mission_verdicts[job_id] = {}
125
+ storage._mission_verdicts[job_id][tid] = relevant
126
+
127
  frames = storage._tracks.get(job_id, {})
128
  for frame_idx, frame_tracks in frames.items():
129
  for det in frame_tracks: