Spaces:
Paused
Paused
Zhen Ye
committed on
Commit
·
8e10ddb
1
Parent(s):
968c327
Fix GSAM2 GPT writer state safety and background call args
Browse files
- inference.py +218 -9
- jobs/background.py +11 -7
inference.py
CHANGED
|
@@ -1593,8 +1593,13 @@ def _gsam2_render_frame(
|
|
| 1593 |
frame_objects: Dict,
|
| 1594 |
height: int,
|
| 1595 |
width: int,
|
|
|
|
| 1596 |
) -> np.ndarray:
|
| 1597 |
-
"""Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1598 |
from models.segmenters.grounded_sam2 import ObjectInfo
|
| 1599 |
|
| 1600 |
frame_path = os.path.join(frame_dir, frame_names[frame_idx])
|
|
@@ -1636,8 +1641,11 @@ def _gsam2_render_frame(
|
|
| 1636 |
box_labels.append(label)
|
| 1637 |
|
| 1638 |
if masks_list:
|
| 1639 |
-
frame = draw_masks(
|
| 1640 |
-
|
|
|
|
|
|
|
|
|
|
| 1641 |
frame = draw_boxes(frame, np.array(boxes_list), label_names=box_labels)
|
| 1642 |
|
| 1643 |
return frame
|
|
@@ -1652,6 +1660,9 @@ def run_grounded_sam2_tracking(
|
|
| 1652 |
job_id: Optional[str] = None,
|
| 1653 |
stream_queue: Optional[Queue] = None,
|
| 1654 |
step: int = 20,
|
|
|
|
|
|
|
|
|
|
| 1655 |
) -> str:
|
| 1656 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1657 |
|
|
@@ -1957,10 +1968,12 @@ def run_grounded_sam2_tracking(
|
|
| 1957 |
frm = _gsam2_render_frame(
|
| 1958 |
frame_dir, frame_names, fidx, fobjs,
|
| 1959 |
height, width,
|
|
|
|
| 1960 |
)
|
|
|
|
| 1961 |
while True:
|
| 1962 |
try:
|
| 1963 |
-
render_out.put(
|
| 1964 |
break
|
| 1965 |
except Full:
|
| 1966 |
if render_done:
|
|
@@ -1969,7 +1982,7 @@ def run_grounded_sam2_tracking(
|
|
| 1969 |
logging.exception("Render failed for frame %d", fidx)
|
| 1970 |
blank = np.zeros((height, width, 3), dtype=np.uint8)
|
| 1971 |
try:
|
| 1972 |
-
render_out.put((fidx, blank), timeout=5.0)
|
| 1973 |
except Full:
|
| 1974 |
pass
|
| 1975 |
|
|
@@ -1980,10 +1993,105 @@ def run_grounded_sam2_tracking(
|
|
| 1980 |
for t in r_workers:
|
| 1981 |
t.start()
|
| 1982 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1983 |
def _writer_loop():
|
| 1984 |
nonlocal render_done
|
| 1985 |
next_idx = 0
|
| 1986 |
-
buf: Dict[int,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1987 |
try:
|
| 1988 |
with StreamingVideoWriter(
|
| 1989 |
output_video_path, fps, width, height
|
|
@@ -1998,10 +2106,104 @@ def run_grounded_sam2_tracking(
|
|
| 1998 |
len(buf), next_idx,
|
| 1999 |
)
|
| 2000 |
time.sleep(0.05)
|
| 2001 |
-
idx, frm = render_out.get(timeout=1.0)
|
| 2002 |
-
buf[idx] = frm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2003 |
|
| 2004 |
-
frm = buf.pop(next_idx)
|
| 2005 |
writer.write(frm)
|
| 2006 |
|
| 2007 |
if stream_queue:
|
|
@@ -2034,6 +2236,13 @@ def run_grounded_sam2_tracking(
|
|
| 2034 |
continue
|
| 2035 |
finally:
|
| 2036 |
render_done = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2037 |
|
| 2038 |
writer_thread = Thread(target=_writer_loop, daemon=True)
|
| 2039 |
writer_thread.start()
|
|
|
|
| 1593 |
frame_objects: Dict,
|
| 1594 |
height: int,
|
| 1595 |
width: int,
|
| 1596 |
+
masks_only: bool = False,
|
| 1597 |
) -> np.ndarray:
|
| 1598 |
+
"""Render a single GSAM2 tracking frame (masks + boxes). CPU-only.
|
| 1599 |
+
|
| 1600 |
+
When *masks_only* is True, skip box rendering so the writer thread can
|
| 1601 |
+
draw boxes later with enriched (GPT) labels.
|
| 1602 |
+
"""
|
| 1603 |
from models.segmenters.grounded_sam2 import ObjectInfo
|
| 1604 |
|
| 1605 |
frame_path = os.path.join(frame_dir, frame_names[frame_idx])
|
|
|
|
| 1641 |
box_labels.append(label)
|
| 1642 |
|
| 1643 |
if masks_list:
|
| 1644 |
+
frame = draw_masks(
|
| 1645 |
+
frame, np.stack(masks_list),
|
| 1646 |
+
labels=None if masks_only else mask_labels,
|
| 1647 |
+
)
|
| 1648 |
+
if boxes_list and not masks_only:
|
| 1649 |
frame = draw_boxes(frame, np.array(boxes_list), label_names=box_labels)
|
| 1650 |
|
| 1651 |
return frame
|
|
|
|
| 1660 |
job_id: Optional[str] = None,
|
| 1661 |
stream_queue: Optional[Queue] = None,
|
| 1662 |
step: int = 20,
|
| 1663 |
+
enable_gpt: bool = False,
|
| 1664 |
+
mission_spec=None, # Optional[MissionSpecification]
|
| 1665 |
+
first_frame_gpt_results: Optional[Dict[str, Any]] = None,
|
| 1666 |
) -> str:
|
| 1667 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1668 |
|
|
|
|
| 1968 |
frm = _gsam2_render_frame(
|
| 1969 |
frame_dir, frame_names, fidx, fobjs,
|
| 1970 |
height, width,
|
| 1971 |
+
masks_only=enable_gpt,
|
| 1972 |
)
|
| 1973 |
+
payload = (fidx, frm, fobjs) if enable_gpt else (fidx, frm, {})
|
| 1974 |
while True:
|
| 1975 |
try:
|
| 1976 |
+
render_out.put(payload, timeout=1.0)
|
| 1977 |
break
|
| 1978 |
except Full:
|
| 1979 |
if render_done:
|
|
|
|
| 1982 |
logging.exception("Render failed for frame %d", fidx)
|
| 1983 |
blank = np.zeros((height, width, 3), dtype=np.uint8)
|
| 1984 |
try:
|
| 1985 |
+
render_out.put((fidx, blank, {}), timeout=5.0)
|
| 1986 |
except Full:
|
| 1987 |
pass
|
| 1988 |
|
|
|
|
| 1993 |
for t in r_workers:
|
| 1994 |
t.start()
|
| 1995 |
|
| 1996 |
+
# --- ObjectInfo → detection dict adapter ---
|
| 1997 |
+
def _objectinfo_to_dets(frame_objects_dict):
|
| 1998 |
+
dets = []
|
| 1999 |
+
for obj_id, info in frame_objects_dict.items():
|
| 2000 |
+
dets.append({
|
| 2001 |
+
"label": info.class_name,
|
| 2002 |
+
"bbox": [info.x1, info.y1, info.x2, info.y2],
|
| 2003 |
+
"score": 1.0,
|
| 2004 |
+
"track_id": f"T{obj_id:02d}",
|
| 2005 |
+
"instance_id": obj_id,
|
| 2006 |
+
})
|
| 2007 |
+
return dets
|
| 2008 |
+
|
| 2009 |
+
# --- GPT enrichment thread (when enabled) ---
|
| 2010 |
+
gpt_enrichment_queue: Queue = Queue(maxsize=4)
|
| 2011 |
+
gpt_data_by_track: Dict[str, Dict] = {}
|
| 2012 |
+
gpt_data_lock = RLock()
|
| 2013 |
+
_relevance_refined = [False]
|
| 2014 |
+
|
| 2015 |
+
def _gsam2_enrichment_thread_fn():
|
| 2016 |
+
while True:
|
| 2017 |
+
item = gpt_enrichment_queue.get()
|
| 2018 |
+
if item is None:
|
| 2019 |
+
break
|
| 2020 |
+
frame_idx, frame_data, gpt_dets, ms = item
|
| 2021 |
+
try:
|
| 2022 |
+
# LLM post-filter (LLM_EXTRACTED mode)
|
| 2023 |
+
if ms and ms.parse_mode == "LLM_EXTRACTED":
|
| 2024 |
+
unique_labels = list({
|
| 2025 |
+
d.get("label", "").lower()
|
| 2026 |
+
for d in gpt_dets if d.get("label")
|
| 2027 |
+
})
|
| 2028 |
+
relevant_labels = evaluate_relevance_llm(
|
| 2029 |
+
unique_labels, ms.operator_text
|
| 2030 |
+
)
|
| 2031 |
+
ms.relevance_criteria.required_classes = list(relevant_labels)
|
| 2032 |
+
_relevance_refined[0] = True
|
| 2033 |
+
logging.info(
|
| 2034 |
+
"GSAM2 enrichment: LLM post-filter frame %d: relevant=%s",
|
| 2035 |
+
frame_idx, relevant_labels,
|
| 2036 |
+
)
|
| 2037 |
+
for d in gpt_dets:
|
| 2038 |
+
decision = evaluate_relevance(d, ms.relevance_criteria)
|
| 2039 |
+
d["mission_relevant"] = decision.relevant
|
| 2040 |
+
gpt_dets = [d for d in gpt_dets if d.get("mission_relevant", True)]
|
| 2041 |
+
|
| 2042 |
+
# GPT threat assessment
|
| 2043 |
+
if gpt_dets:
|
| 2044 |
+
cached_gpt = first_frame_gpt_results
|
| 2045 |
+
if not cached_gpt and job_id:
|
| 2046 |
+
try:
|
| 2047 |
+
from jobs.storage import get_job_storage as _gjs
|
| 2048 |
+
_job = _gjs().get(job_id)
|
| 2049 |
+
if _job and _job.first_frame_gpt_results:
|
| 2050 |
+
cached_gpt = _job.first_frame_gpt_results
|
| 2051 |
+
except Exception:
|
| 2052 |
+
pass
|
| 2053 |
+
|
| 2054 |
+
if cached_gpt:
|
| 2055 |
+
logging.info("GSAM2 enrichment: re-using cached GPT for frame %d", frame_idx)
|
| 2056 |
+
gpt_res = cached_gpt
|
| 2057 |
+
else:
|
| 2058 |
+
logging.info("GSAM2 enrichment: running GPT for frame %d...", frame_idx)
|
| 2059 |
+
frame_b64 = encode_frame_to_b64(frame_data)
|
| 2060 |
+
gpt_res = estimate_threat_gpt(
|
| 2061 |
+
detections=gpt_dets, mission_spec=ms,
|
| 2062 |
+
image_b64=frame_b64,
|
| 2063 |
+
)
|
| 2064 |
+
|
| 2065 |
+
for d in gpt_dets:
|
| 2066 |
+
tid = d.get("track_id")
|
| 2067 |
+
if tid and tid in gpt_res:
|
| 2068 |
+
merged = dict(gpt_res[tid])
|
| 2069 |
+
merged["gpt_raw"] = gpt_res[tid]
|
| 2070 |
+
merged["assessment_frame_index"] = frame_idx
|
| 2071 |
+
merged["assessment_status"] = "ASSESSED"
|
| 2072 |
+
with gpt_data_lock:
|
| 2073 |
+
gpt_data_by_track[tid] = merged
|
| 2074 |
+
logging.info("GSAM2 enrichment: GPT results stored for %d tracks", len(gpt_data_by_track))
|
| 2075 |
+
|
| 2076 |
+
except Exception as e:
|
| 2077 |
+
logging.error("GSAM2 enrichment thread failed for frame %d: %s", frame_idx, e)
|
| 2078 |
+
|
| 2079 |
def _writer_loop():
|
| 2080 |
nonlocal render_done
|
| 2081 |
next_idx = 0
|
| 2082 |
+
buf: Dict[int, Tuple] = {}
|
| 2083 |
+
|
| 2084 |
+
# Per-track bbox history (replaces ByteTracker for GSAM2)
|
| 2085 |
+
track_history: Dict[int, List] = {}
|
| 2086 |
+
speed_est = SpeedEstimator(fps=fps) if enable_gpt else None
|
| 2087 |
+
gpt_submitted = False
|
| 2088 |
+
|
| 2089 |
+
# Start enrichment thread when GPT enabled
|
| 2090 |
+
enrich_thread = None
|
| 2091 |
+
if enable_gpt:
|
| 2092 |
+
enrich_thread = Thread(target=_gsam2_enrichment_thread_fn, daemon=True)
|
| 2093 |
+
enrich_thread.start()
|
| 2094 |
+
|
| 2095 |
try:
|
| 2096 |
with StreamingVideoWriter(
|
| 2097 |
output_video_path, fps, width, height
|
|
|
|
| 2106 |
len(buf), next_idx,
|
| 2107 |
)
|
| 2108 |
time.sleep(0.05)
|
| 2109 |
+
idx, frm, fobjs = render_out.get(timeout=1.0)
|
| 2110 |
+
buf[idx] = (frm, fobjs)
|
| 2111 |
+
|
| 2112 |
+
frm, fobjs = buf.pop(next_idx)
|
| 2113 |
+
|
| 2114 |
+
# --- GPT enrichment path ---
|
| 2115 |
+
if enable_gpt and fobjs:
|
| 2116 |
+
dets = _objectinfo_to_dets(fobjs)
|
| 2117 |
+
|
| 2118 |
+
# Maintain per-track bbox history (30-frame window)
|
| 2119 |
+
for det in dets:
|
| 2120 |
+
iid = det["instance_id"]
|
| 2121 |
+
track_history.setdefault(iid, []).append(det["bbox"])
|
| 2122 |
+
if len(track_history[iid]) > 30:
|
| 2123 |
+
track_history[iid].pop(0)
|
| 2124 |
+
# Store an immutable per-frame snapshot.
|
| 2125 |
+
det["history"] = list(track_history[iid])
|
| 2126 |
+
|
| 2127 |
+
# Speed estimation
|
| 2128 |
+
if speed_est:
|
| 2129 |
+
speed_est.estimate(dets)
|
| 2130 |
+
|
| 2131 |
+
# Relevance gate
|
| 2132 |
+
if mission_spec:
|
| 2133 |
+
if (mission_spec.parse_mode == "LLM_EXTRACTED"
|
| 2134 |
+
and not _relevance_refined[0]):
|
| 2135 |
+
for d in dets:
|
| 2136 |
+
d["mission_relevant"] = True
|
| 2137 |
+
d["relevance_reason"] = "pending_llm_postfilter"
|
| 2138 |
+
gpt_dets = dets
|
| 2139 |
+
else:
|
| 2140 |
+
for d in dets:
|
| 2141 |
+
decision = evaluate_relevance(d, mission_spec.relevance_criteria)
|
| 2142 |
+
d["mission_relevant"] = decision.relevant
|
| 2143 |
+
d["relevance_reason"] = decision.reason
|
| 2144 |
+
gpt_dets = [d for d in dets if d.get("mission_relevant", True)]
|
| 2145 |
+
else:
|
| 2146 |
+
for d in dets:
|
| 2147 |
+
d["mission_relevant"] = None
|
| 2148 |
+
gpt_dets = dets
|
| 2149 |
+
|
| 2150 |
+
# GPT enrichment (one-shot, first frame with detections)
|
| 2151 |
+
if gpt_dets and not gpt_submitted:
|
| 2152 |
+
for d in gpt_dets:
|
| 2153 |
+
d["assessment_status"] = "PENDING_GPT"
|
| 2154 |
+
try:
|
| 2155 |
+
gpt_enrichment_queue.put(
|
| 2156 |
+
(
|
| 2157 |
+
next_idx,
|
| 2158 |
+
frm.copy(),
|
| 2159 |
+
copy.deepcopy(gpt_dets),
|
| 2160 |
+
mission_spec,
|
| 2161 |
+
),
|
| 2162 |
+
timeout=1.0,
|
| 2163 |
+
)
|
| 2164 |
+
gpt_submitted = True
|
| 2165 |
+
logging.info("GSAM2 writer: offloaded GPT enrichment for frame %d", next_idx)
|
| 2166 |
+
except Full:
|
| 2167 |
+
logging.warning("GSAM2 GPT enrichment queue full, skipping")
|
| 2168 |
+
|
| 2169 |
+
# Merge persistent GPT data
|
| 2170 |
+
for det in dets:
|
| 2171 |
+
tid = det["track_id"]
|
| 2172 |
+
with gpt_data_lock:
|
| 2173 |
+
gpt_payload = gpt_data_by_track.get(tid)
|
| 2174 |
+
if gpt_payload:
|
| 2175 |
+
det.update(gpt_payload)
|
| 2176 |
+
det["assessment_status"] = "ASSESSED"
|
| 2177 |
+
elif "assessment_status" not in det:
|
| 2178 |
+
det["assessment_status"] = "UNASSESSED"
|
| 2179 |
+
|
| 2180 |
+
# Build enriched display labels
|
| 2181 |
+
display_labels = []
|
| 2182 |
+
for d in dets:
|
| 2183 |
+
lbl = d.get("label", "obj")
|
| 2184 |
+
if "track_id" in d:
|
| 2185 |
+
lbl = f"{d['track_id']} {lbl}"
|
| 2186 |
+
if d.get("gpt_distance_m") is not None:
|
| 2187 |
+
try:
|
| 2188 |
+
lbl = f"{lbl} {int(float(d['gpt_distance_m']))}m"
|
| 2189 |
+
except (TypeError, ValueError):
|
| 2190 |
+
pass
|
| 2191 |
+
display_labels.append(lbl)
|
| 2192 |
+
|
| 2193 |
+
# Draw boxes on mask-rendered frame
|
| 2194 |
+
if dets:
|
| 2195 |
+
boxes = np.array([d["bbox"] for d in dets])
|
| 2196 |
+
frm = draw_boxes(frm, boxes, label_names=display_labels)
|
| 2197 |
+
|
| 2198 |
+
# Store tracks for frontend
|
| 2199 |
+
if job_id:
|
| 2200 |
+
set_track_data(job_id, next_idx, copy.deepcopy(dets))
|
| 2201 |
+
|
| 2202 |
+
elif enable_gpt:
|
| 2203 |
+
# No objects this frame — still store empty track data
|
| 2204 |
+
if job_id:
|
| 2205 |
+
set_track_data(job_id, next_idx, [])
|
| 2206 |
|
|
|
|
| 2207 |
writer.write(frm)
|
| 2208 |
|
| 2209 |
if stream_queue:
|
|
|
|
| 2236 |
continue
|
| 2237 |
finally:
|
| 2238 |
render_done = True
|
| 2239 |
+
# Shut down enrichment thread
|
| 2240 |
+
if enrich_thread:
|
| 2241 |
+
try:
|
| 2242 |
+
gpt_enrichment_queue.put(None, timeout=5.0)
|
| 2243 |
+
enrich_thread.join(timeout=30)
|
| 2244 |
+
except Exception:
|
| 2245 |
+
logging.warning("GSAM2 enrichment thread shutdown timed out")
|
| 2246 |
|
| 2247 |
writer_thread = Thread(target=_writer_loop, daemon=True)
|
| 2248 |
writer_thread.start()
|
jobs/background.py
CHANGED
|
@@ -29,13 +29,17 @@ async def process_video_async(job_id: str) -> None:
|
|
| 29 |
if job.mode == "segmentation":
|
| 30 |
detection_path = await asyncio.to_thread(
|
| 31 |
run_grounded_sam2_tracking,
|
| 32 |
-
job.input_video_path,
|
| 33 |
-
job.output_video_path,
|
| 34 |
-
job.queries,
|
| 35 |
-
None,
|
| 36 |
-
job.segmenter_name,
|
| 37 |
-
job_id,
|
| 38 |
-
stream_queue,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
)
|
| 40 |
else:
|
| 41 |
detections_list = None
|
|
|
|
| 29 |
if job.mode == "segmentation":
|
| 30 |
detection_path = await asyncio.to_thread(
|
| 31 |
run_grounded_sam2_tracking,
|
| 32 |
+
input_video_path=job.input_video_path,
|
| 33 |
+
output_video_path=job.output_video_path,
|
| 34 |
+
queries=job.queries,
|
| 35 |
+
max_frames=None,
|
| 36 |
+
segmenter_name=job.segmenter_name,
|
| 37 |
+
job_id=job_id,
|
| 38 |
+
stream_queue=stream_queue,
|
| 39 |
+
step=20,
|
| 40 |
+
enable_gpt=job.enable_gpt,
|
| 41 |
+
mission_spec=job.mission_spec,
|
| 42 |
+
first_frame_gpt_results=job.first_frame_gpt_results,
|
| 43 |
)
|
| 44 |
else:
|
| 45 |
detections_list = None
|