Zhen Ye committed on
Commit
8e10ddb
·
1 Parent(s): 968c327

Fix GSAM2 GPT writer state safety and background call args

Browse files
Files changed (2) hide show
  1. inference.py +218 -9
  2. jobs/background.py +11 -7
inference.py CHANGED
@@ -1593,8 +1593,13 @@ def _gsam2_render_frame(
1593
  frame_objects: Dict,
1594
  height: int,
1595
  width: int,
 
1596
  ) -> np.ndarray:
1597
- """Render a single GSAM2 tracking frame (masks + boxes). CPU-only."""
 
 
 
 
1598
  from models.segmenters.grounded_sam2 import ObjectInfo
1599
 
1600
  frame_path = os.path.join(frame_dir, frame_names[frame_idx])
@@ -1636,8 +1641,11 @@ def _gsam2_render_frame(
1636
  box_labels.append(label)
1637
 
1638
  if masks_list:
1639
- frame = draw_masks(frame, np.stack(masks_list), labels=mask_labels)
1640
- if boxes_list:
 
 
 
1641
  frame = draw_boxes(frame, np.array(boxes_list), label_names=box_labels)
1642
 
1643
  return frame
@@ -1652,6 +1660,9 @@ def run_grounded_sam2_tracking(
1652
  job_id: Optional[str] = None,
1653
  stream_queue: Optional[Queue] = None,
1654
  step: int = 20,
 
 
 
1655
  ) -> str:
1656
  """Run Grounded-SAM-2 video tracking pipeline.
1657
 
@@ -1957,10 +1968,12 @@ def run_grounded_sam2_tracking(
1957
  frm = _gsam2_render_frame(
1958
  frame_dir, frame_names, fidx, fobjs,
1959
  height, width,
 
1960
  )
 
1961
  while True:
1962
  try:
1963
- render_out.put((fidx, frm), timeout=1.0)
1964
  break
1965
  except Full:
1966
  if render_done:
@@ -1969,7 +1982,7 @@ def run_grounded_sam2_tracking(
1969
  logging.exception("Render failed for frame %d", fidx)
1970
  blank = np.zeros((height, width, 3), dtype=np.uint8)
1971
  try:
1972
- render_out.put((fidx, blank), timeout=5.0)
1973
  except Full:
1974
  pass
1975
 
@@ -1980,10 +1993,105 @@ def run_grounded_sam2_tracking(
1980
  for t in r_workers:
1981
  t.start()
1982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1983
  def _writer_loop():
1984
  nonlocal render_done
1985
  next_idx = 0
1986
- buf: Dict[int, np.ndarray] = {}
 
 
 
 
 
 
 
 
 
 
 
 
1987
  try:
1988
  with StreamingVideoWriter(
1989
  output_video_path, fps, width, height
@@ -1998,10 +2106,104 @@ def run_grounded_sam2_tracking(
1998
  len(buf), next_idx,
1999
  )
2000
  time.sleep(0.05)
2001
- idx, frm = render_out.get(timeout=1.0)
2002
- buf[idx] = frm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2003
 
2004
- frm = buf.pop(next_idx)
2005
  writer.write(frm)
2006
 
2007
  if stream_queue:
@@ -2034,6 +2236,13 @@ def run_grounded_sam2_tracking(
2034
  continue
2035
  finally:
2036
  render_done = True
 
 
 
 
 
 
 
2037
 
2038
  writer_thread = Thread(target=_writer_loop, daemon=True)
2039
  writer_thread.start()
 
1593
  frame_objects: Dict,
1594
  height: int,
1595
  width: int,
1596
+ masks_only: bool = False,
1597
  ) -> np.ndarray:
1598
+ """Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
1599
+
1600
+ When *masks_only* is True, skip box rendering so the writer thread can
1601
+ draw boxes later with enriched (GPT) labels.
1602
+ """
1603
  from models.segmenters.grounded_sam2 import ObjectInfo
1604
 
1605
  frame_path = os.path.join(frame_dir, frame_names[frame_idx])
 
1641
  box_labels.append(label)
1642
 
1643
  if masks_list:
1644
+ frame = draw_masks(
1645
+ frame, np.stack(masks_list),
1646
+ labels=None if masks_only else mask_labels,
1647
+ )
1648
+ if boxes_list and not masks_only:
1649
  frame = draw_boxes(frame, np.array(boxes_list), label_names=box_labels)
1650
 
1651
  return frame
 
1660
  job_id: Optional[str] = None,
1661
  stream_queue: Optional[Queue] = None,
1662
  step: int = 20,
1663
+ enable_gpt: bool = False,
1664
+ mission_spec=None, # Optional[MissionSpecification]
1665
+ first_frame_gpt_results: Optional[Dict[str, Any]] = None,
1666
  ) -> str:
1667
  """Run Grounded-SAM-2 video tracking pipeline.
1668
 
 
1968
  frm = _gsam2_render_frame(
1969
  frame_dir, frame_names, fidx, fobjs,
1970
  height, width,
1971
+ masks_only=enable_gpt,
1972
  )
1973
+ payload = (fidx, frm, fobjs) if enable_gpt else (fidx, frm, {})
1974
  while True:
1975
  try:
1976
+ render_out.put(payload, timeout=1.0)
1977
  break
1978
  except Full:
1979
  if render_done:
 
1982
  logging.exception("Render failed for frame %d", fidx)
1983
  blank = np.zeros((height, width, 3), dtype=np.uint8)
1984
  try:
1985
+ render_out.put((fidx, blank, {}), timeout=5.0)
1986
  except Full:
1987
  pass
1988
 
 
1993
  for t in r_workers:
1994
  t.start()
1995
 
1996
+ # --- ObjectInfo → detection dict adapter ---
1997
+ def _objectinfo_to_dets(frame_objects_dict):
1998
+ dets = []
1999
+ for obj_id, info in frame_objects_dict.items():
2000
+ dets.append({
2001
+ "label": info.class_name,
2002
+ "bbox": [info.x1, info.y1, info.x2, info.y2],
2003
+ "score": 1.0,
2004
+ "track_id": f"T{obj_id:02d}",
2005
+ "instance_id": obj_id,
2006
+ })
2007
+ return dets
2008
+
2009
+ # --- GPT enrichment thread (when enabled) ---
2010
+ gpt_enrichment_queue: Queue = Queue(maxsize=4)
2011
+ gpt_data_by_track: Dict[str, Dict] = {}
2012
+ gpt_data_lock = RLock()
2013
+ _relevance_refined = [False]
2014
+
2015
+ def _gsam2_enrichment_thread_fn():
2016
+ while True:
2017
+ item = gpt_enrichment_queue.get()
2018
+ if item is None:
2019
+ break
2020
+ frame_idx, frame_data, gpt_dets, ms = item
2021
+ try:
2022
+ # LLM post-filter (LLM_EXTRACTED mode)
2023
+ if ms and ms.parse_mode == "LLM_EXTRACTED":
2024
+ unique_labels = list({
2025
+ d.get("label", "").lower()
2026
+ for d in gpt_dets if d.get("label")
2027
+ })
2028
+ relevant_labels = evaluate_relevance_llm(
2029
+ unique_labels, ms.operator_text
2030
+ )
2031
+ ms.relevance_criteria.required_classes = list(relevant_labels)
2032
+ _relevance_refined[0] = True
2033
+ logging.info(
2034
+ "GSAM2 enrichment: LLM post-filter frame %d: relevant=%s",
2035
+ frame_idx, relevant_labels,
2036
+ )
2037
+ for d in gpt_dets:
2038
+ decision = evaluate_relevance(d, ms.relevance_criteria)
2039
+ d["mission_relevant"] = decision.relevant
2040
+ gpt_dets = [d for d in gpt_dets if d.get("mission_relevant", True)]
2041
+
2042
+ # GPT threat assessment
2043
+ if gpt_dets:
2044
+ cached_gpt = first_frame_gpt_results
2045
+ if not cached_gpt and job_id:
2046
+ try:
2047
+ from jobs.storage import get_job_storage as _gjs
2048
+ _job = _gjs().get(job_id)
2049
+ if _job and _job.first_frame_gpt_results:
2050
+ cached_gpt = _job.first_frame_gpt_results
2051
+ except Exception:
2052
+ pass
2053
+
2054
+ if cached_gpt:
2055
+ logging.info("GSAM2 enrichment: re-using cached GPT for frame %d", frame_idx)
2056
+ gpt_res = cached_gpt
2057
+ else:
2058
+ logging.info("GSAM2 enrichment: running GPT for frame %d...", frame_idx)
2059
+ frame_b64 = encode_frame_to_b64(frame_data)
2060
+ gpt_res = estimate_threat_gpt(
2061
+ detections=gpt_dets, mission_spec=ms,
2062
+ image_b64=frame_b64,
2063
+ )
2064
+
2065
+ for d in gpt_dets:
2066
+ tid = d.get("track_id")
2067
+ if tid and tid in gpt_res:
2068
+ merged = dict(gpt_res[tid])
2069
+ merged["gpt_raw"] = gpt_res[tid]
2070
+ merged["assessment_frame_index"] = frame_idx
2071
+ merged["assessment_status"] = "ASSESSED"
2072
+ with gpt_data_lock:
2073
+ gpt_data_by_track[tid] = merged
2074
+ logging.info("GSAM2 enrichment: GPT results stored for %d tracks", len(gpt_data_by_track))
2075
+
2076
+ except Exception as e:
2077
+ logging.error("GSAM2 enrichment thread failed for frame %d: %s", frame_idx, e)
2078
+
2079
  def _writer_loop():
2080
  nonlocal render_done
2081
  next_idx = 0
2082
+ buf: Dict[int, Tuple] = {}
2083
+
2084
+ # Per-track bbox history (replaces ByteTracker for GSAM2)
2085
+ track_history: Dict[int, List] = {}
2086
+ speed_est = SpeedEstimator(fps=fps) if enable_gpt else None
2087
+ gpt_submitted = False
2088
+
2089
+ # Start enrichment thread when GPT enabled
2090
+ enrich_thread = None
2091
+ if enable_gpt:
2092
+ enrich_thread = Thread(target=_gsam2_enrichment_thread_fn, daemon=True)
2093
+ enrich_thread.start()
2094
+
2095
  try:
2096
  with StreamingVideoWriter(
2097
  output_video_path, fps, width, height
 
2106
  len(buf), next_idx,
2107
  )
2108
  time.sleep(0.05)
2109
+ idx, frm, fobjs = render_out.get(timeout=1.0)
2110
+ buf[idx] = (frm, fobjs)
2111
+
2112
+ frm, fobjs = buf.pop(next_idx)
2113
+
2114
+ # --- GPT enrichment path ---
2115
+ if enable_gpt and fobjs:
2116
+ dets = _objectinfo_to_dets(fobjs)
2117
+
2118
+ # Maintain per-track bbox history (30-frame window)
2119
+ for det in dets:
2120
+ iid = det["instance_id"]
2121
+ track_history.setdefault(iid, []).append(det["bbox"])
2122
+ if len(track_history[iid]) > 30:
2123
+ track_history[iid].pop(0)
2124
+ # Store an immutable per-frame snapshot.
2125
+ det["history"] = list(track_history[iid])
2126
+
2127
+ # Speed estimation
2128
+ if speed_est:
2129
+ speed_est.estimate(dets)
2130
+
2131
+ # Relevance gate
2132
+ if mission_spec:
2133
+ if (mission_spec.parse_mode == "LLM_EXTRACTED"
2134
+ and not _relevance_refined[0]):
2135
+ for d in dets:
2136
+ d["mission_relevant"] = True
2137
+ d["relevance_reason"] = "pending_llm_postfilter"
2138
+ gpt_dets = dets
2139
+ else:
2140
+ for d in dets:
2141
+ decision = evaluate_relevance(d, mission_spec.relevance_criteria)
2142
+ d["mission_relevant"] = decision.relevant
2143
+ d["relevance_reason"] = decision.reason
2144
+ gpt_dets = [d for d in dets if d.get("mission_relevant", True)]
2145
+ else:
2146
+ for d in dets:
2147
+ d["mission_relevant"] = None
2148
+ gpt_dets = dets
2149
+
2150
+ # GPT enrichment (one-shot, first frame with detections)
2151
+ if gpt_dets and not gpt_submitted:
2152
+ for d in gpt_dets:
2153
+ d["assessment_status"] = "PENDING_GPT"
2154
+ try:
2155
+ gpt_enrichment_queue.put(
2156
+ (
2157
+ next_idx,
2158
+ frm.copy(),
2159
+ copy.deepcopy(gpt_dets),
2160
+ mission_spec,
2161
+ ),
2162
+ timeout=1.0,
2163
+ )
2164
+ gpt_submitted = True
2165
+ logging.info("GSAM2 writer: offloaded GPT enrichment for frame %d", next_idx)
2166
+ except Full:
2167
+ logging.warning("GSAM2 GPT enrichment queue full, skipping")
2168
+
2169
+ # Merge persistent GPT data
2170
+ for det in dets:
2171
+ tid = det["track_id"]
2172
+ with gpt_data_lock:
2173
+ gpt_payload = gpt_data_by_track.get(tid)
2174
+ if gpt_payload:
2175
+ det.update(gpt_payload)
2176
+ det["assessment_status"] = "ASSESSED"
2177
+ elif "assessment_status" not in det:
2178
+ det["assessment_status"] = "UNASSESSED"
2179
+
2180
+ # Build enriched display labels
2181
+ display_labels = []
2182
+ for d in dets:
2183
+ lbl = d.get("label", "obj")
2184
+ if "track_id" in d:
2185
+ lbl = f"{d['track_id']} {lbl}"
2186
+ if d.get("gpt_distance_m") is not None:
2187
+ try:
2188
+ lbl = f"{lbl} {int(float(d['gpt_distance_m']))}m"
2189
+ except (TypeError, ValueError):
2190
+ pass
2191
+ display_labels.append(lbl)
2192
+
2193
+ # Draw boxes on mask-rendered frame
2194
+ if dets:
2195
+ boxes = np.array([d["bbox"] for d in dets])
2196
+ frm = draw_boxes(frm, boxes, label_names=display_labels)
2197
+
2198
+ # Store tracks for frontend
2199
+ if job_id:
2200
+ set_track_data(job_id, next_idx, copy.deepcopy(dets))
2201
+
2202
+ elif enable_gpt:
2203
+ # No objects this frame — still store empty track data
2204
+ if job_id:
2205
+ set_track_data(job_id, next_idx, [])
2206
 
 
2207
  writer.write(frm)
2208
 
2209
  if stream_queue:
 
2236
  continue
2237
  finally:
2238
  render_done = True
2239
+ # Shut down enrichment thread
2240
+ if enrich_thread:
2241
+ try:
2242
+ gpt_enrichment_queue.put(None, timeout=5.0)
2243
+ enrich_thread.join(timeout=30)
2244
+ except Exception:
2245
+ logging.warning("GSAM2 enrichment thread shutdown timed out")
2246
 
2247
  writer_thread = Thread(target=_writer_loop, daemon=True)
2248
  writer_thread.start()
jobs/background.py CHANGED
@@ -29,13 +29,17 @@ async def process_video_async(job_id: str) -> None:
29
  if job.mode == "segmentation":
30
  detection_path = await asyncio.to_thread(
31
  run_grounded_sam2_tracking,
32
- job.input_video_path,
33
- job.output_video_path,
34
- job.queries,
35
- None,
36
- job.segmenter_name,
37
- job_id,
38
- stream_queue,
 
 
 
 
39
  )
40
  else:
41
  detections_list = None
 
29
  if job.mode == "segmentation":
30
  detection_path = await asyncio.to_thread(
31
  run_grounded_sam2_tracking,
32
+ input_video_path=job.input_video_path,
33
+ output_video_path=job.output_video_path,
34
+ queries=job.queries,
35
+ max_frames=None,
36
+ segmenter_name=job.segmenter_name,
37
+ job_id=job_id,
38
+ stream_queue=stream_queue,
39
+ step=20,
40
+ enable_gpt=job.enable_gpt,
41
+ mission_spec=job.mission_spec,
42
+ first_frame_gpt_results=job.first_frame_gpt_results,
43
  )
44
  else:
45
  detections_list = None