Spaces:
Running
Running
Zhen Ye
commited on
Commit
·
bb6e650
1
Parent(s):
b3371b1
feat(backend): enhance inference pipeline with GLM logic and structured outputs
Browse files- inference.py +33 -117
- utils/enrichment.py +122 -0
- utils/gpt_reasoning.py +22 -33
- utils/mission_parser.py +8 -29
- utils/openai_client.py +80 -0
- utils/relevance.py +6 -21
- utils/schemas.py +12 -0
- utils/tracker.py +6 -4
inference.py
CHANGED
|
@@ -7,7 +7,7 @@ import os
|
|
| 7 |
|
| 8 |
import logging
|
| 9 |
import time
|
| 10 |
-
from threading import RLock, Thread
|
| 11 |
from queue import Queue, PriorityQueue, Full, Empty
|
| 12 |
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 13 |
|
|
@@ -15,15 +15,15 @@ import cv2
|
|
| 15 |
import numpy as np
|
| 16 |
import torch
|
| 17 |
from concurrent.futures import ThreadPoolExecutor
|
| 18 |
-
from threading import RLock
|
| 19 |
from models.detectors.base import ObjectDetector
|
| 20 |
from models.model_loader import load_detector, load_detector_on_device
|
| 21 |
from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
|
| 22 |
from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
|
| 23 |
from models.depth_estimators.base import DepthEstimator
|
| 24 |
-
from utils.video import extract_frames, write_video, VideoReader, VideoWriter,
|
| 25 |
-
from utils.
|
| 26 |
-
from utils.
|
|
|
|
| 27 |
from jobs.storage import set_track_data
|
| 28 |
import tempfile
|
| 29 |
import json as json_module
|
|
@@ -781,7 +781,7 @@ def process_first_frame(
|
|
| 781 |
"bbox": [int(c) for c in box],
|
| 782 |
"score": float(seg_result.scores[idx]) if seg_result.scores is not None and idx < len(seg_result.scores) else 1.0,
|
| 783 |
"track_id": f"T{idx + 1:02d}",
|
| 784 |
-
"assessment_status":
|
| 785 |
})
|
| 786 |
return processed, detections
|
| 787 |
|
|
@@ -791,7 +791,7 @@ def process_first_frame(
|
|
| 791 |
|
| 792 |
# Tag all detections as unassessed — GPT runs later in enrichment thread
|
| 793 |
for det in detections:
|
| 794 |
-
det["assessment_status"] =
|
| 795 |
|
| 796 |
return processed, detections
|
| 797 |
|
|
@@ -1067,7 +1067,7 @@ def run_inference(
|
|
| 1067 |
# --- GPT Enrichment Thread (non-blocking) ---
|
| 1068 |
# Runs LLM relevance + GPT threat assessment off the writer's critical path.
|
| 1069 |
gpt_enrichment_queue = Queue(maxsize=4)
|
| 1070 |
-
_relevance_refined =
|
| 1071 |
|
| 1072 |
def enrichment_thread_fn(tracker_ref):
|
| 1073 |
"""Dedicated thread for GPT/LLM calls. Receives work from writer, injects results via tracker."""
|
|
@@ -1077,64 +1077,13 @@ def run_inference(
|
|
| 1077 |
break # Sentinel — shutdown
|
| 1078 |
frame_idx, frame_data, gpt_dets, ms = item
|
| 1079 |
try:
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
unique_labels, ms.operator_text
|
| 1088 |
-
)
|
| 1089 |
-
ms.relevance_criteria.required_classes = list(relevant_labels)
|
| 1090 |
-
_relevance_refined[0] = True # signal writer_loop to switch to deterministic gate
|
| 1091 |
-
logging.info(
|
| 1092 |
-
"Enrichment: LLM post-filter applied on frame %d: relevant=%s",
|
| 1093 |
-
frame_idx, relevant_labels,
|
| 1094 |
-
)
|
| 1095 |
-
# Re-filter with refined classes
|
| 1096 |
-
for d in gpt_dets:
|
| 1097 |
-
decision = evaluate_relevance(d, ms.relevance_criteria)
|
| 1098 |
-
d["mission_relevant"] = decision.relevant
|
| 1099 |
-
gpt_dets = [d for d in gpt_dets if d.get("mission_relevant", True)]
|
| 1100 |
-
|
| 1101 |
-
# GPT threat assessment
|
| 1102 |
-
if gpt_dets:
|
| 1103 |
-
# Check for cached results: passed directly or from app-level background task
|
| 1104 |
-
cached_gpt = first_frame_gpt_results
|
| 1105 |
-
if not cached_gpt and job_id:
|
| 1106 |
-
try:
|
| 1107 |
-
from jobs.storage import get_job_storage as _gjs
|
| 1108 |
-
_job = _gjs().get(job_id)
|
| 1109 |
-
if _job and _job.first_frame_gpt_results:
|
| 1110 |
-
cached_gpt = _job.first_frame_gpt_results
|
| 1111 |
-
except Exception:
|
| 1112 |
-
pass
|
| 1113 |
-
|
| 1114 |
-
if cached_gpt:
|
| 1115 |
-
logging.info("Enrichment: re-using cached GPT results for frame %d", frame_idx)
|
| 1116 |
-
gpt_res = cached_gpt
|
| 1117 |
-
else:
|
| 1118 |
-
logging.info("Enrichment: running GPT estimation for frame %d...", frame_idx)
|
| 1119 |
-
frame_b64 = encode_frame_to_b64(frame_data)
|
| 1120 |
-
gpt_res = estimate_threat_gpt(
|
| 1121 |
-
detections=gpt_dets, mission_spec=ms,
|
| 1122 |
-
image_b64=frame_b64,
|
| 1123 |
-
)
|
| 1124 |
-
|
| 1125 |
-
# Merge using real track_id assigned by ByteTracker
|
| 1126 |
-
for d in gpt_dets:
|
| 1127 |
-
oid = d.get('track_id')
|
| 1128 |
-
if oid and oid in gpt_res:
|
| 1129 |
-
gpt_payload = gpt_res[oid]
|
| 1130 |
-
d.update(gpt_payload)
|
| 1131 |
-
d["gpt_raw"] = gpt_payload
|
| 1132 |
-
d["assessment_frame_index"] = frame_idx
|
| 1133 |
-
d["assessment_status"] = gpt_payload.get(
|
| 1134 |
-
"assessment_status", "ASSESSED"
|
| 1135 |
-
)
|
| 1136 |
-
|
| 1137 |
-
# Push GPT data back into tracker's internal STrack objects
|
| 1138 |
tracker_ref.inject_metadata(gpt_dets)
|
| 1139 |
logging.info("Enrichment: GPT results injected into tracker for frame %d", frame_idx)
|
| 1140 |
|
|
@@ -1186,7 +1135,7 @@ def run_inference(
|
|
| 1186 |
# --- RELEVANCE GATE (deterministic, fast — stays in writer) ---
|
| 1187 |
if mission_spec:
|
| 1188 |
if (mission_spec.parse_mode == "LLM_EXTRACTED"
|
| 1189 |
-
and not _relevance_refined
|
| 1190 |
# LLM post-filter hasn't run yet — pass all through
|
| 1191 |
for d in dets:
|
| 1192 |
d["mission_relevant"] = True
|
|
@@ -1220,7 +1169,7 @@ def run_inference(
|
|
| 1220 |
if enable_gpt and gpt_dets and not gpt_submitted:
|
| 1221 |
# Tag as pending — enrichment thread will update to ASSESSED later
|
| 1222 |
for d in gpt_dets:
|
| 1223 |
-
d["assessment_status"] =
|
| 1224 |
try:
|
| 1225 |
gpt_enrichment_queue.put(
|
| 1226 |
(next_idx, p_frame.copy(), gpt_dets, mission_spec),
|
|
@@ -1234,7 +1183,7 @@ def run_inference(
|
|
| 1234 |
# Tag unassessed detections (INV-6)
|
| 1235 |
for d in dets:
|
| 1236 |
if "assessment_status" not in d:
|
| 1237 |
-
d["assessment_status"] =
|
| 1238 |
|
| 1239 |
# --- RENDER BOXES & OVERLAYS ---
|
| 1240 |
if dets:
|
|
@@ -2027,7 +1976,7 @@ def run_grounded_sam2_tracking(
|
|
| 2027 |
gpt_enrichment_queue: Queue = Queue(maxsize=4)
|
| 2028 |
gpt_data_by_track: Dict[str, Dict] = {}
|
| 2029 |
gpt_data_lock = RLock()
|
| 2030 |
-
_relevance_refined =
|
| 2031 |
|
| 2032 |
def _gsam2_enrichment_thread_fn():
|
| 2033 |
while True:
|
|
@@ -2036,49 +1985,15 @@ def run_grounded_sam2_tracking(
|
|
| 2036 |
break
|
| 2037 |
frame_idx, frame_data, gpt_dets, ms = item
|
| 2038 |
try:
|
| 2039 |
-
|
| 2040 |
-
|
| 2041 |
-
|
| 2042 |
-
|
| 2043 |
-
|
| 2044 |
-
|
| 2045 |
-
relevant_labels = evaluate_relevance_llm(
|
| 2046 |
-
unique_labels, ms.operator_text
|
| 2047 |
-
)
|
| 2048 |
-
ms.relevance_criteria.required_classes = list(relevant_labels)
|
| 2049 |
-
_relevance_refined[0] = True
|
| 2050 |
-
logging.info(
|
| 2051 |
-
"GSAM2 enrichment: LLM post-filter frame %d: relevant=%s",
|
| 2052 |
-
frame_idx, relevant_labels,
|
| 2053 |
-
)
|
| 2054 |
-
for d in gpt_dets:
|
| 2055 |
-
decision = evaluate_relevance(d, ms.relevance_criteria)
|
| 2056 |
-
d["mission_relevant"] = decision.relevant
|
| 2057 |
-
gpt_dets = [d for d in gpt_dets if d.get("mission_relevant", True)]
|
| 2058 |
-
|
| 2059 |
-
# GPT threat assessment
|
| 2060 |
-
if gpt_dets:
|
| 2061 |
-
cached_gpt = first_frame_gpt_results
|
| 2062 |
-
if not cached_gpt and job_id:
|
| 2063 |
-
try:
|
| 2064 |
-
from jobs.storage import get_job_storage as _gjs
|
| 2065 |
-
_job = _gjs().get(job_id)
|
| 2066 |
-
if _job and _job.first_frame_gpt_results:
|
| 2067 |
-
cached_gpt = _job.first_frame_gpt_results
|
| 2068 |
-
except Exception:
|
| 2069 |
-
pass
|
| 2070 |
-
|
| 2071 |
-
if cached_gpt:
|
| 2072 |
-
logging.info("GSAM2 enrichment: re-using cached GPT for frame %d", frame_idx)
|
| 2073 |
-
gpt_res = cached_gpt
|
| 2074 |
-
else:
|
| 2075 |
-
logging.info("GSAM2 enrichment: running GPT for frame %d...", frame_idx)
|
| 2076 |
-
frame_b64 = encode_frame_to_b64(frame_data)
|
| 2077 |
-
gpt_res = estimate_threat_gpt(
|
| 2078 |
-
detections=gpt_dets, mission_spec=ms,
|
| 2079 |
-
image_b64=frame_b64,
|
| 2080 |
-
)
|
| 2081 |
|
|
|
|
|
|
|
| 2082 |
for d in gpt_dets:
|
| 2083 |
tid = d.get("track_id")
|
| 2084 |
if tid and tid in gpt_res:
|
|
@@ -2086,7 +2001,7 @@ def run_grounded_sam2_tracking(
|
|
| 2086 |
merged["gpt_raw"] = gpt_res[tid]
|
| 2087 |
merged["assessment_frame_index"] = frame_idx
|
| 2088 |
merged["assessment_status"] = merged.get(
|
| 2089 |
-
"assessment_status",
|
| 2090 |
)
|
| 2091 |
with gpt_data_lock:
|
| 2092 |
gpt_data_by_track[tid] = merged
|
|
@@ -2096,6 +2011,7 @@ def run_grounded_sam2_tracking(
|
|
| 2096 |
# frontend polling (/detect/status) picks them up.
|
| 2097 |
if job_id:
|
| 2098 |
try:
|
|
|
|
| 2099 |
_st = _gjs().get(job_id)
|
| 2100 |
if _st and _st.first_frame_detections:
|
| 2101 |
for det in _st.first_frame_detections:
|
|
@@ -2184,7 +2100,7 @@ def run_grounded_sam2_tracking(
|
|
| 2184 |
# Relevance gate
|
| 2185 |
if mission_spec:
|
| 2186 |
if (mission_spec.parse_mode == "LLM_EXTRACTED"
|
| 2187 |
-
and not _relevance_refined
|
| 2188 |
for d in dets:
|
| 2189 |
d["mission_relevant"] = True
|
| 2190 |
d["relevance_reason"] = "pending_llm_postfilter"
|
|
@@ -2203,7 +2119,7 @@ def run_grounded_sam2_tracking(
|
|
| 2203 |
# GPT enrichment (one-shot, first frame with detections)
|
| 2204 |
if gpt_dets and not gpt_submitted:
|
| 2205 |
for d in gpt_dets:
|
| 2206 |
-
d["assessment_status"] =
|
| 2207 |
try:
|
| 2208 |
gpt_enrichment_queue.put(
|
| 2209 |
(
|
|
@@ -2226,9 +2142,9 @@ def run_grounded_sam2_tracking(
|
|
| 2226 |
gpt_payload = gpt_data_by_track.get(tid)
|
| 2227 |
if gpt_payload:
|
| 2228 |
det.update(gpt_payload)
|
| 2229 |
-
det["assessment_status"] =
|
| 2230 |
elif "assessment_status" not in det:
|
| 2231 |
-
det["assessment_status"] =
|
| 2232 |
|
| 2233 |
# Build enriched display labels
|
| 2234 |
display_labels = []
|
|
|
|
| 7 |
|
| 8 |
import logging
|
| 9 |
import time
|
| 10 |
+
from threading import Event, RLock, Thread
|
| 11 |
from queue import Queue, PriorityQueue, Full, Empty
|
| 12 |
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 13 |
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
import torch
|
| 17 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 18 |
from models.detectors.base import ObjectDetector
|
| 19 |
from models.model_loader import load_detector, load_detector_on_device
|
| 20 |
from models.segmenters.model_loader import load_segmenter, load_segmenter_on_device
|
| 21 |
from models.depth_estimators.model_loader import load_depth_estimator, load_depth_estimator_on_device
|
| 22 |
from models.depth_estimators.base import DepthEstimator
|
| 23 |
+
from utils.video import extract_frames, write_video, VideoReader, VideoWriter, StreamingVideoWriter
|
| 24 |
+
from utils.relevance import evaluate_relevance
|
| 25 |
+
from utils.enrichment import run_enrichment
|
| 26 |
+
from utils.schemas import AssessmentStatus
|
| 27 |
from jobs.storage import set_track_data
|
| 28 |
import tempfile
|
| 29 |
import json as json_module
|
|
|
|
| 781 |
"bbox": [int(c) for c in box],
|
| 782 |
"score": float(seg_result.scores[idx]) if seg_result.scores is not None and idx < len(seg_result.scores) else 1.0,
|
| 783 |
"track_id": f"T{idx + 1:02d}",
|
| 784 |
+
"assessment_status": AssessmentStatus.UNASSESSED,
|
| 785 |
})
|
| 786 |
return processed, detections
|
| 787 |
|
|
|
|
| 791 |
|
| 792 |
# Tag all detections as unassessed — GPT runs later in enrichment thread
|
| 793 |
for det in detections:
|
| 794 |
+
det["assessment_status"] = AssessmentStatus.UNASSESSED
|
| 795 |
|
| 796 |
return processed, detections
|
| 797 |
|
|
|
|
| 1067 |
# --- GPT Enrichment Thread (non-blocking) ---
|
| 1068 |
# Runs LLM relevance + GPT threat assessment off the writer's critical path.
|
| 1069 |
gpt_enrichment_queue = Queue(maxsize=4)
|
| 1070 |
+
_relevance_refined = Event()
|
| 1071 |
|
| 1072 |
def enrichment_thread_fn(tracker_ref):
|
| 1073 |
"""Dedicated thread for GPT/LLM calls. Receives work from writer, injects results via tracker."""
|
|
|
|
| 1077 |
break # Sentinel — shutdown
|
| 1078 |
frame_idx, frame_data, gpt_dets, ms = item
|
| 1079 |
try:
|
| 1080 |
+
gpt_res = run_enrichment(
|
| 1081 |
+
frame_idx, frame_data, gpt_dets, ms,
|
| 1082 |
+
first_frame_gpt_results=first_frame_gpt_results,
|
| 1083 |
+
job_id=job_id,
|
| 1084 |
+
relevance_refined_event=_relevance_refined,
|
| 1085 |
+
)
|
| 1086 |
+
if gpt_res:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1087 |
tracker_ref.inject_metadata(gpt_dets)
|
| 1088 |
logging.info("Enrichment: GPT results injected into tracker for frame %d", frame_idx)
|
| 1089 |
|
|
|
|
| 1135 |
# --- RELEVANCE GATE (deterministic, fast — stays in writer) ---
|
| 1136 |
if mission_spec:
|
| 1137 |
if (mission_spec.parse_mode == "LLM_EXTRACTED"
|
| 1138 |
+
and not _relevance_refined.is_set()):
|
| 1139 |
# LLM post-filter hasn't run yet — pass all through
|
| 1140 |
for d in dets:
|
| 1141 |
d["mission_relevant"] = True
|
|
|
|
| 1169 |
if enable_gpt and gpt_dets and not gpt_submitted:
|
| 1170 |
# Tag as pending — enrichment thread will update to ASSESSED later
|
| 1171 |
for d in gpt_dets:
|
| 1172 |
+
d["assessment_status"] = AssessmentStatus.PENDING_GPT
|
| 1173 |
try:
|
| 1174 |
gpt_enrichment_queue.put(
|
| 1175 |
(next_idx, p_frame.copy(), gpt_dets, mission_spec),
|
|
|
|
| 1183 |
# Tag unassessed detections (INV-6)
|
| 1184 |
for d in dets:
|
| 1185 |
if "assessment_status" not in d:
|
| 1186 |
+
d["assessment_status"] = AssessmentStatus.UNASSESSED
|
| 1187 |
|
| 1188 |
# --- RENDER BOXES & OVERLAYS ---
|
| 1189 |
if dets:
|
|
|
|
| 1976 |
gpt_enrichment_queue: Queue = Queue(maxsize=4)
|
| 1977 |
gpt_data_by_track: Dict[str, Dict] = {}
|
| 1978 |
gpt_data_lock = RLock()
|
| 1979 |
+
_relevance_refined = Event()
|
| 1980 |
|
| 1981 |
def _gsam2_enrichment_thread_fn():
|
| 1982 |
while True:
|
|
|
|
| 1985 |
break
|
| 1986 |
frame_idx, frame_data, gpt_dets, ms = item
|
| 1987 |
try:
|
| 1988 |
+
gpt_res = run_enrichment(
|
| 1989 |
+
frame_idx, frame_data, gpt_dets, ms,
|
| 1990 |
+
first_frame_gpt_results=first_frame_gpt_results,
|
| 1991 |
+
job_id=job_id,
|
| 1992 |
+
relevance_refined_event=_relevance_refined,
|
| 1993 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1994 |
|
| 1995 |
+
# GSAM2-specific: store results in per-track dict and persist to job storage
|
| 1996 |
+
if gpt_res:
|
| 1997 |
for d in gpt_dets:
|
| 1998 |
tid = d.get("track_id")
|
| 1999 |
if tid and tid in gpt_res:
|
|
|
|
| 2001 |
merged["gpt_raw"] = gpt_res[tid]
|
| 2002 |
merged["assessment_frame_index"] = frame_idx
|
| 2003 |
merged["assessment_status"] = merged.get(
|
| 2004 |
+
"assessment_status", AssessmentStatus.ASSESSED
|
| 2005 |
)
|
| 2006 |
with gpt_data_lock:
|
| 2007 |
gpt_data_by_track[tid] = merged
|
|
|
|
| 2011 |
# frontend polling (/detect/status) picks them up.
|
| 2012 |
if job_id:
|
| 2013 |
try:
|
| 2014 |
+
from jobs.storage import get_job_storage as _gjs
|
| 2015 |
_st = _gjs().get(job_id)
|
| 2016 |
if _st and _st.first_frame_detections:
|
| 2017 |
for det in _st.first_frame_detections:
|
|
|
|
| 2100 |
# Relevance gate
|
| 2101 |
if mission_spec:
|
| 2102 |
if (mission_spec.parse_mode == "LLM_EXTRACTED"
|
| 2103 |
+
and not _relevance_refined.is_set()):
|
| 2104 |
for d in dets:
|
| 2105 |
d["mission_relevant"] = True
|
| 2106 |
d["relevance_reason"] = "pending_llm_postfilter"
|
|
|
|
| 2119 |
# GPT enrichment (one-shot, first frame with detections)
|
| 2120 |
if gpt_dets and not gpt_submitted:
|
| 2121 |
for d in gpt_dets:
|
| 2122 |
+
d["assessment_status"] = AssessmentStatus.PENDING_GPT
|
| 2123 |
try:
|
| 2124 |
gpt_enrichment_queue.put(
|
| 2125 |
(
|
|
|
|
| 2142 |
gpt_payload = gpt_data_by_track.get(tid)
|
| 2143 |
if gpt_payload:
|
| 2144 |
det.update(gpt_payload)
|
| 2145 |
+
det["assessment_status"] = AssessmentStatus.ASSESSED
|
| 2146 |
elif "assessment_status" not in det:
|
| 2147 |
+
det["assessment_status"] = AssessmentStatus.UNASSESSED
|
| 2148 |
|
| 2149 |
# Build enriched display labels
|
| 2150 |
display_labels = []
|
utils/enrichment.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared enrichment workflow — single implementation of the 5-step GPT enrichment
|
| 3 |
+
pipeline used by inference.py (detection + GSAM2) and app.py (first-frame).
|
| 4 |
+
|
| 5 |
+
Consolidates duplicated logic from:
|
| 6 |
+
- inference.py enrichment_thread_fn
|
| 7 |
+
- inference.py _gsam2_enrichment_thread_fn
|
| 8 |
+
- app.py _enrich_first_frame_gpt
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from threading import Event
|
| 13 |
+
from typing import Any, Dict, List, Optional
|
| 14 |
+
|
| 15 |
+
from utils.gpt_reasoning import estimate_threat_gpt, encode_frame_to_b64
|
| 16 |
+
from utils.relevance import evaluate_relevance, evaluate_relevance_llm
|
| 17 |
+
from utils.schemas import AssessmentStatus
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def run_enrichment(
|
| 23 |
+
frame_idx: int,
|
| 24 |
+
frame_data,
|
| 25 |
+
detections: List[Dict[str, Any]],
|
| 26 |
+
mission_spec,
|
| 27 |
+
*,
|
| 28 |
+
first_frame_gpt_results: Optional[Dict] = None,
|
| 29 |
+
job_id: Optional[str] = None,
|
| 30 |
+
relevance_refined_event: Optional[Event] = None,
|
| 31 |
+
) -> Optional[Dict[str, Any]]:
|
| 32 |
+
"""Run the shared enrichment workflow (LLM post-filter + GPT threat assessment).
|
| 33 |
+
|
| 34 |
+
Steps:
|
| 35 |
+
1. LLM post-filter via evaluate_relevance_llm() (if LLM_EXTRACTED mode)
|
| 36 |
+
2. Signal relevance_refined_event (if provided)
|
| 37 |
+
3. Check cached GPT results (parameter or JobStorage fallback)
|
| 38 |
+
4. Call estimate_threat_gpt() if no cache
|
| 39 |
+
5. Merge results into detections by track_id
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
frame_idx: Index of the frame being enriched.
|
| 43 |
+
frame_data: OpenCV BGR frame (numpy array).
|
| 44 |
+
detections: Mutable list of detection dicts to enrich in-place.
|
| 45 |
+
mission_spec: Optional MissionSpecification.
|
| 46 |
+
first_frame_gpt_results: Pre-computed GPT results (cache hit).
|
| 47 |
+
job_id: Job identifier for JobStorage fallback cache lookup.
|
| 48 |
+
relevance_refined_event: threading.Event to signal when LLM post-filter completes.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
GPT results dict (object_id -> assessment), or None if all detections
|
| 52 |
+
were filtered out.
|
| 53 |
+
"""
|
| 54 |
+
gpt_dets = detections
|
| 55 |
+
|
| 56 |
+
# --- Step 1: LLM post-filter (LLM_EXTRACTED mode) ---
|
| 57 |
+
if mission_spec and mission_spec.parse_mode == "LLM_EXTRACTED":
|
| 58 |
+
unique_labels = list({
|
| 59 |
+
d.get("label", "").lower()
|
| 60 |
+
for d in gpt_dets if d.get("label")
|
| 61 |
+
})
|
| 62 |
+
relevant_labels = evaluate_relevance_llm(
|
| 63 |
+
unique_labels, mission_spec.operator_text
|
| 64 |
+
)
|
| 65 |
+
mission_spec.relevance_criteria.required_classes = list(relevant_labels)
|
| 66 |
+
|
| 67 |
+
# --- Step 2: Signal writer loop ---
|
| 68 |
+
if relevance_refined_event is not None:
|
| 69 |
+
relevance_refined_event.set()
|
| 70 |
+
|
| 71 |
+
logger.info(
|
| 72 |
+
"Enrichment: LLM post-filter applied on frame %d: relevant=%s",
|
| 73 |
+
frame_idx, relevant_labels,
|
| 74 |
+
)
|
| 75 |
+
# Re-filter with refined classes
|
| 76 |
+
for d in gpt_dets:
|
| 77 |
+
decision = evaluate_relevance(d, mission_spec.relevance_criteria)
|
| 78 |
+
d["mission_relevant"] = decision.relevant
|
| 79 |
+
gpt_dets = [d for d in gpt_dets if d.get("mission_relevant", True)]
|
| 80 |
+
elif relevance_refined_event is not None:
|
| 81 |
+
# Non-LLM mode: signal immediately so writer doesn't block
|
| 82 |
+
relevance_refined_event.set()
|
| 83 |
+
|
| 84 |
+
if not gpt_dets:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
# --- Step 3: Check cached GPT results ---
|
| 88 |
+
cached_gpt = first_frame_gpt_results
|
| 89 |
+
if not cached_gpt and job_id:
|
| 90 |
+
try:
|
| 91 |
+
from jobs.storage import get_job_storage as _gjs
|
| 92 |
+
_job = _gjs().get(job_id)
|
| 93 |
+
if _job and _job.first_frame_gpt_results:
|
| 94 |
+
cached_gpt = _job.first_frame_gpt_results
|
| 95 |
+
except Exception:
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
# --- Step 4: Call GPT if no cache ---
|
| 99 |
+
if cached_gpt:
|
| 100 |
+
logger.info("Enrichment: re-using cached GPT results for frame %d", frame_idx)
|
| 101 |
+
gpt_res = cached_gpt
|
| 102 |
+
else:
|
| 103 |
+
logger.info("Enrichment: running GPT estimation for frame %d...", frame_idx)
|
| 104 |
+
frame_b64 = encode_frame_to_b64(frame_data)
|
| 105 |
+
gpt_res = estimate_threat_gpt(
|
| 106 |
+
detections=gpt_dets, mission_spec=mission_spec,
|
| 107 |
+
image_b64=frame_b64,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# --- Step 5: Merge results into detections by track_id ---
|
| 111 |
+
for d in gpt_dets:
|
| 112 |
+
oid = d.get("track_id")
|
| 113 |
+
if oid and oid in gpt_res:
|
| 114 |
+
gpt_payload = gpt_res[oid]
|
| 115 |
+
d.update(gpt_payload)
|
| 116 |
+
d["gpt_raw"] = gpt_payload
|
| 117 |
+
d["assessment_frame_index"] = frame_idx
|
| 118 |
+
d["assessment_status"] = gpt_payload.get(
|
| 119 |
+
"assessment_status", AssessmentStatus.ASSESSED
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
return gpt_res
|
utils/gpt_reasoning.py
CHANGED
|
@@ -4,9 +4,8 @@ import json
|
|
| 4 |
import base64
|
| 5 |
import logging
|
| 6 |
from typing import List, Dict, Any, Optional
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
from utils.schemas import FrameThreatAnalysis
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
|
@@ -15,10 +14,16 @@ def encode_image(image_path: str) -> str:
|
|
| 15 |
return base64.b64encode(image_file.read()).decode('utf-8')
|
| 16 |
|
| 17 |
|
| 18 |
-
def encode_frame_to_b64(frame) -> str:
|
| 19 |
-
"""Encode an OpenCV BGR frame to a base64 JPEG string in memory (no disk I/O).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
import cv2
|
| 21 |
-
|
|
|
|
| 22 |
if not success:
|
| 23 |
raise ValueError("Failed to encode frame to JPEG")
|
| 24 |
return base64.b64encode(buf.tobytes()).decode('utf-8')
|
|
@@ -167,8 +172,7 @@ def estimate_threat_gpt(
|
|
| 167 |
if detections is None:
|
| 168 |
detections = []
|
| 169 |
|
| 170 |
-
|
| 171 |
-
if not api_key:
|
| 172 |
logger.error("OPENAI_API_KEY not set. Skipping GPT threat assessment.")
|
| 173 |
return {}
|
| 174 |
|
|
@@ -200,7 +204,7 @@ def estimate_threat_gpt(
|
|
| 200 |
)
|
| 201 |
return _build_status_fallback(
|
| 202 |
skipped_human_ids,
|
| 203 |
-
|
| 204 |
"Human/person analysis skipped due policy constraints.",
|
| 205 |
)
|
| 206 |
return {}
|
|
@@ -264,25 +268,10 @@ def estimate_threat_gpt(
|
|
| 264 |
"response_format": { "type": "json_object" }
|
| 265 |
}
|
| 266 |
|
| 267 |
-
headers = {
|
| 268 |
-
"Content-Type": "application/json",
|
| 269 |
-
"Authorization": f"Bearer {api_key}"
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
try:
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
data=json.dumps(payload).encode('utf-8'),
|
| 276 |
-
headers=headers,
|
| 277 |
-
method="POST"
|
| 278 |
-
)
|
| 279 |
-
with urllib.request.urlopen(req, timeout=30) as response:
|
| 280 |
-
resp_data = json.loads(response.read().decode('utf-8'))
|
| 281 |
-
|
| 282 |
-
choice_msg = resp_data.get("choices", [{}])[0].get("message", {})
|
| 283 |
-
content = choice_msg.get("content")
|
| 284 |
if not content:
|
| 285 |
-
refusal = choice_msg.get("refusal")
|
| 286 |
if refusal:
|
| 287 |
logger.warning("GPT refused threat assessment: %s", refusal)
|
| 288 |
else:
|
|
@@ -293,13 +282,13 @@ def estimate_threat_gpt(
|
|
| 293 |
)
|
| 294 |
fallback = _build_status_fallback(
|
| 295 |
[it["obj_id"] for it in prompt_items],
|
| 296 |
-
|
| 297 |
refusal or "GPT returned empty content.",
|
| 298 |
)
|
| 299 |
fallback.update(
|
| 300 |
_build_status_fallback(
|
| 301 |
skipped_human_ids,
|
| 302 |
-
|
| 303 |
"Human/person analysis skipped due policy constraints.",
|
| 304 |
)
|
| 305 |
)
|
|
@@ -320,14 +309,14 @@ def estimate_threat_gpt(
|
|
| 320 |
oid = it["obj_id"]
|
| 321 |
if oid not in objects:
|
| 322 |
objects[oid] = {
|
| 323 |
-
"assessment_status":
|
| 324 |
"gpt_reason": "No structured assessment returned for object.",
|
| 325 |
}
|
| 326 |
for oid in skipped_human_ids:
|
| 327 |
objects.setdefault(
|
| 328 |
oid,
|
| 329 |
{
|
| 330 |
-
"assessment_status":
|
| 331 |
"gpt_reason": "Human/person analysis skipped due policy constraints.",
|
| 332 |
},
|
| 333 |
)
|
|
@@ -336,7 +325,7 @@ def estimate_threat_gpt(
|
|
| 336 |
for obj_id, data in objects.items():
|
| 337 |
if not isinstance(data, dict):
|
| 338 |
data = {
|
| 339 |
-
"assessment_status":
|
| 340 |
"gpt_reason": "Malformed object payload from GPT.",
|
| 341 |
}
|
| 342 |
objects[obj_id] = data
|
|
@@ -373,13 +362,13 @@ def estimate_threat_gpt(
|
|
| 373 |
logger.error("GPT API call failed: %s", e, exc_info=True)
|
| 374 |
fallback = _build_status_fallback(
|
| 375 |
[it["obj_id"] for it in prompt_items],
|
| 376 |
-
|
| 377 |
f"GPT API call failed: {e.__class__.__name__}",
|
| 378 |
)
|
| 379 |
fallback.update(
|
| 380 |
_build_status_fallback(
|
| 381 |
skipped_human_ids,
|
| 382 |
-
|
| 383 |
"Human/person analysis skipped due policy constraints.",
|
| 384 |
)
|
| 385 |
)
|
|
|
|
| 4 |
import base64
|
| 5 |
import logging
|
| 6 |
from typing import List, Dict, Any, Optional
|
| 7 |
+
from utils.schemas import FrameThreatAnalysis, AssessmentStatus
|
| 8 |
+
from utils.openai_client import chat_completion, extract_content, get_api_key, OpenAIAPIError
|
|
|
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
|
|
|
| 14 |
return base64.b64encode(image_file.read()).decode('utf-8')
|
| 15 |
|
| 16 |
|
| 17 |
+
def encode_frame_to_b64(frame, quality=None) -> str:
|
| 18 |
+
"""Encode an OpenCV BGR frame to a base64 JPEG string in memory (no disk I/O).
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
frame: OpenCV BGR numpy array.
|
| 22 |
+
quality: Optional JPEG quality (1-100). Uses OpenCV default if None.
|
| 23 |
+
"""
|
| 24 |
import cv2
|
| 25 |
+
params = [int(cv2.IMWRITE_JPEG_QUALITY), quality] if quality is not None else None
|
| 26 |
+
success, buf = cv2.imencode('.jpg', frame, params) if params else cv2.imencode('.jpg', frame)
|
| 27 |
if not success:
|
| 28 |
raise ValueError("Failed to encode frame to JPEG")
|
| 29 |
return base64.b64encode(buf.tobytes()).decode('utf-8')
|
|
|
|
| 172 |
if detections is None:
|
| 173 |
detections = []
|
| 174 |
|
| 175 |
+
if not get_api_key():
|
|
|
|
| 176 |
logger.error("OPENAI_API_KEY not set. Skipping GPT threat assessment.")
|
| 177 |
return {}
|
| 178 |
|
|
|
|
| 204 |
)
|
| 205 |
return _build_status_fallback(
|
| 206 |
skipped_human_ids,
|
| 207 |
+
AssessmentStatus.SKIPPED_POLICY,
|
| 208 |
"Human/person analysis skipped due policy constraints.",
|
| 209 |
)
|
| 210 |
return {}
|
|
|
|
| 268 |
"response_format": { "type": "json_object" }
|
| 269 |
}
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
try:
|
| 272 |
+
resp_data = chat_completion(payload)
|
| 273 |
+
content, refusal = extract_content(resp_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
if not content:
|
|
|
|
| 275 |
if refusal:
|
| 276 |
logger.warning("GPT refused threat assessment: %s", refusal)
|
| 277 |
else:
|
|
|
|
| 282 |
)
|
| 283 |
fallback = _build_status_fallback(
|
| 284 |
[it["obj_id"] for it in prompt_items],
|
| 285 |
+
AssessmentStatus.REFUSED,
|
| 286 |
refusal or "GPT returned empty content.",
|
| 287 |
)
|
| 288 |
fallback.update(
|
| 289 |
_build_status_fallback(
|
| 290 |
skipped_human_ids,
|
| 291 |
+
AssessmentStatus.SKIPPED_POLICY,
|
| 292 |
"Human/person analysis skipped due policy constraints.",
|
| 293 |
)
|
| 294 |
)
|
|
|
|
| 309 |
oid = it["obj_id"]
|
| 310 |
if oid not in objects:
|
| 311 |
objects[oid] = {
|
| 312 |
+
"assessment_status": AssessmentStatus.NO_RESPONSE,
|
| 313 |
"gpt_reason": "No structured assessment returned for object.",
|
| 314 |
}
|
| 315 |
for oid in skipped_human_ids:
|
| 316 |
objects.setdefault(
|
| 317 |
oid,
|
| 318 |
{
|
| 319 |
+
"assessment_status": AssessmentStatus.SKIPPED_POLICY,
|
| 320 |
"gpt_reason": "Human/person analysis skipped due policy constraints.",
|
| 321 |
},
|
| 322 |
)
|
|
|
|
| 325 |
for obj_id, data in objects.items():
|
| 326 |
if not isinstance(data, dict):
|
| 327 |
data = {
|
| 328 |
+
"assessment_status": AssessmentStatus.NO_RESPONSE,
|
| 329 |
"gpt_reason": "Malformed object payload from GPT.",
|
| 330 |
}
|
| 331 |
objects[obj_id] = data
|
|
|
|
| 362 |
logger.error("GPT API call failed: %s", e, exc_info=True)
|
| 363 |
fallback = _build_status_fallback(
|
| 364 |
[it["obj_id"] for it in prompt_items],
|
| 365 |
+
AssessmentStatus.ERROR,
|
| 366 |
f"GPT API call failed: {e.__class__.__name__}",
|
| 367 |
)
|
| 368 |
fallback.update(
|
| 369 |
_build_status_fallback(
|
| 370 |
skipped_human_ids,
|
| 371 |
+
AssessmentStatus.SKIPPED_POLICY,
|
| 372 |
"Human/person analysis skipped due policy constraints.",
|
| 373 |
)
|
| 374 |
)
|
utils/mission_parser.py
CHANGED
|
@@ -12,15 +12,13 @@ Internal flow:
|
|
| 12 |
6. Return validated MissionSpecification or raise MissionParseError
|
| 13 |
"""
|
| 14 |
|
| 15 |
-
import base64
|
| 16 |
import json
|
| 17 |
import logging
|
| 18 |
-
import os
|
| 19 |
import re
|
| 20 |
-
import urllib.request
|
| 21 |
-
import urllib.error
|
| 22 |
from typing import List, Optional
|
| 23 |
|
|
|
|
|
|
|
| 24 |
from coco_classes import COCO_CLASSES, canonicalize_coco_name, coco_class_catalog
|
| 25 |
from utils.schemas import MissionSpecification, RelevanceCriteria
|
| 26 |
|
|
@@ -209,16 +207,11 @@ def _extract_and_encode_first_frame(video_path: Optional[str]) -> Optional[str]:
|
|
| 209 |
if not video_path:
|
| 210 |
return None
|
| 211 |
try:
|
| 212 |
-
import cv2
|
| 213 |
from inference import extract_first_frame
|
|
|
|
| 214 |
|
| 215 |
frame, _fps, _w, _h = extract_first_frame(video_path)
|
| 216 |
-
|
| 217 |
-
".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 85]
|
| 218 |
-
)
|
| 219 |
-
if not success:
|
| 220 |
-
return None
|
| 221 |
-
return base64.b64encode(buf).decode("ascii")
|
| 222 |
except Exception:
|
| 223 |
logger.warning("Failed to extract/encode first frame for vision grounding", exc_info=True)
|
| 224 |
return None
|
|
@@ -226,8 +219,7 @@ def _extract_and_encode_first_frame(video_path: Optional[str]) -> Optional[str]:
|
|
| 226 |
|
| 227 |
def _call_extraction_llm(raw_text: str, detector_key: str, first_frame_b64: Optional[str] = None) -> dict:
|
| 228 |
"""Call GPT-4o to extract structured mission fields from natural language."""
|
| 229 |
-
|
| 230 |
-
if not api_key:
|
| 231 |
raise MissionParseError(
|
| 232 |
"OPENAI_API_KEY not set. Cannot parse natural language mission text. "
|
| 233 |
"Use comma-separated class labels instead (e.g., 'person, car, boat')."
|
|
@@ -278,28 +270,15 @@ def _call_extraction_llm(raw_text: str, detector_key: str, first_frame_b64: Opti
|
|
| 278 |
],
|
| 279 |
}
|
| 280 |
|
| 281 |
-
headers = {
|
| 282 |
-
"Content-Type": "application/json",
|
| 283 |
-
"Authorization": f"Bearer {api_key}",
|
| 284 |
-
}
|
| 285 |
-
|
| 286 |
try:
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
data=json.dumps(payload).encode("utf-8"),
|
| 290 |
-
headers=headers,
|
| 291 |
-
method="POST",
|
| 292 |
-
)
|
| 293 |
-
with urllib.request.urlopen(req, timeout=timeout_s) as response:
|
| 294 |
-
resp_data = json.loads(response.read().decode("utf-8"))
|
| 295 |
-
|
| 296 |
-
content = resp_data["choices"][0]["message"].get("content")
|
| 297 |
if not content:
|
| 298 |
raise MissionParseError("GPT returned empty content during mission parsing.")
|
| 299 |
|
| 300 |
return json.loads(content)
|
| 301 |
|
| 302 |
-
except
|
| 303 |
raise MissionParseError(f"Mission parsing API call failed: {e}")
|
| 304 |
except json.JSONDecodeError:
|
| 305 |
raise MissionParseError(
|
|
|
|
| 12 |
6. Return validated MissionSpecification or raise MissionParseError
|
| 13 |
"""
|
| 14 |
|
|
|
|
| 15 |
import json
|
| 16 |
import logging
|
|
|
|
| 17 |
import re
|
|
|
|
|
|
|
| 18 |
from typing import List, Optional
|
| 19 |
|
| 20 |
+
from utils.openai_client import chat_completion, extract_content, get_api_key, OpenAIAPIError
|
| 21 |
+
|
| 22 |
from coco_classes import COCO_CLASSES, canonicalize_coco_name, coco_class_catalog
|
| 23 |
from utils.schemas import MissionSpecification, RelevanceCriteria
|
| 24 |
|
|
|
|
| 207 |
if not video_path:
|
| 208 |
return None
|
| 209 |
try:
|
|
|
|
| 210 |
from inference import extract_first_frame
|
| 211 |
+
from utils.gpt_reasoning import encode_frame_to_b64
|
| 212 |
|
| 213 |
frame, _fps, _w, _h = extract_first_frame(video_path)
|
| 214 |
+
return encode_frame_to_b64(frame, quality=85)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
except Exception:
|
| 216 |
logger.warning("Failed to extract/encode first frame for vision grounding", exc_info=True)
|
| 217 |
return None
|
|
|
|
| 219 |
|
| 220 |
def _call_extraction_llm(raw_text: str, detector_key: str, first_frame_b64: Optional[str] = None) -> dict:
|
| 221 |
"""Call GPT-4o to extract structured mission fields from natural language."""
|
| 222 |
+
if not get_api_key():
|
|
|
|
| 223 |
raise MissionParseError(
|
| 224 |
"OPENAI_API_KEY not set. Cannot parse natural language mission text. "
|
| 225 |
"Use comma-separated class labels instead (e.g., 'person, car, boat')."
|
|
|
|
| 270 |
],
|
| 271 |
}
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
try:
|
| 274 |
+
resp_data = chat_completion(payload, timeout=timeout_s)
|
| 275 |
+
content, _refusal = extract_content(resp_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
if not content:
|
| 277 |
raise MissionParseError("GPT returned empty content during mission parsing.")
|
| 278 |
|
| 279 |
return json.loads(content)
|
| 280 |
|
| 281 |
+
except OpenAIAPIError as e:
|
| 282 |
raise MissionParseError(f"Mission parsing API call failed: {e}")
|
| 283 |
except json.JSONDecodeError:
|
| 284 |
raise MissionParseError(
|
utils/openai_client.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared OpenAI HTTP client — single implementation of the chat-completions call.
|
| 3 |
+
|
| 4 |
+
Replaces duplicated urllib boilerplate in gpt_reasoning, relevance,
|
| 5 |
+
mission_parser, and threat_chat.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import urllib.request
|
| 12 |
+
import urllib.error
|
| 13 |
+
from typing import Dict, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
_API_URL = "https://api.openai.com/v1/chat/completions"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class OpenAIAPIError(Exception):
    """Signals a failed OpenAI chat-completion request (HTTP or transport layer)."""

    def __init__(self, message: str, status_code: Optional[int] = None):
        # Let Exception carry the human-readable message; keep the HTTP
        # status (when one exists) available for programmatic handling.
        super().__init__(message)
        self.status_code = status_code
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_api_key() -> Optional[str]:
    """Look up the OpenAI API key in the process environment.

    Returns:
        The value of ``OPENAI_API_KEY``, or None when the variable is unset.
    """
    return os.getenv("OPENAI_API_KEY")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def chat_completion(payload: Dict, *, timeout: int = 30) -> Dict:
    """Send a chat-completion request and return the parsed JSON response.

    Args:
        payload: Full request body (model, messages, etc.).
        timeout: HTTP timeout in seconds.

    Returns:
        Parsed response dict.

    Raises:
        OpenAIAPIError: On missing key, HTTP, network, timeout, or
            malformed-response failure. All transport failures are wrapped
            so callers only need to catch this one type.
    """
    api_key = get_api_key()
    if not api_key:
        raise OpenAIAPIError("OPENAI_API_KEY not set")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    try:
        req = urllib.request.Request(
            _API_URL,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=timeout) as response:
            body = response.read().decode("utf-8")
    except urllib.error.HTTPError as e:
        # Surface the error body when available — the API returns JSON
        # error details that make failures actionable in logs.
        detail = ""
        try:
            detail = e.read().decode("utf-8", errors="replace")[:500]
        except Exception:
            pass
        msg = f"HTTP {e.code}: {e.reason}"
        if detail:
            msg = f"{msg} — {detail}"
        raise OpenAIAPIError(msg, status_code=e.code) from e
    except urllib.error.URLError as e:
        raise OpenAIAPIError(f"URL error: {e.reason}") from e
    except TimeoutError as e:
        # A socket timeout during the response read is raised directly
        # (not wrapped in URLError) — wrap it so callers catching
        # OpenAIAPIError are not bypassed.
        raise OpenAIAPIError(f"Request timed out after {timeout}s") from e

    try:
        return json.loads(body)
    except json.JSONDecodeError as e:
        raise OpenAIAPIError(f"Malformed JSON in API response: {e}") from e
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def extract_content(resp_data: Dict) -> Tuple[Optional[str], Optional[str]]:
    """Safely extract content and refusal from a chat-completion response.

    Tolerates a missing OR empty ``choices`` list and a null ``message``
    field, returning ``(None, None)`` instead of raising. (The previous
    ``resp_data.get("choices", [{}])[0]`` form raised ``IndexError`` when
    the key was present but the list empty — the default only applies when
    the key is absent.)

    Returns:
        (content, refusal) — either may be None.
    """
    choices = resp_data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    return message.get("content"), message.get("refusal")
|
utils/relevance.py
CHANGED
|
@@ -12,11 +12,10 @@ LLM-derived field. This is structural, not by convention.
|
|
| 12 |
|
| 13 |
import json
|
| 14 |
import logging
|
| 15 |
-
import os
|
| 16 |
-
import urllib.request
|
| 17 |
-
import urllib.error
|
| 18 |
from typing import Any, Dict, List, NamedTuple, Set
|
| 19 |
|
|
|
|
|
|
|
| 20 |
from coco_classes import canonicalize_coco_name
|
| 21 |
from utils.schemas import RelevanceCriteria
|
| 22 |
|
|
@@ -91,8 +90,7 @@ def evaluate_relevance_llm(
|
|
| 91 |
if not detected_labels:
|
| 92 |
return set()
|
| 93 |
|
| 94 |
-
|
| 95 |
-
if not api_key:
|
| 96 |
logger.warning(
|
| 97 |
"OPENAI_API_KEY not set — LLM relevance filter falling back to accept-all"
|
| 98 |
)
|
|
@@ -118,22 +116,9 @@ def evaluate_relevance_llm(
|
|
| 118 |
],
|
| 119 |
}
|
| 120 |
|
| 121 |
-
headers = {
|
| 122 |
-
"Content-Type": "application/json",
|
| 123 |
-
"Authorization": f"Bearer {api_key}",
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
try:
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
data=json.dumps(payload).encode("utf-8"),
|
| 130 |
-
headers=headers,
|
| 131 |
-
method="POST",
|
| 132 |
-
)
|
| 133 |
-
with urllib.request.urlopen(req, timeout=30) as response:
|
| 134 |
-
resp_data = json.loads(response.read().decode("utf-8"))
|
| 135 |
-
|
| 136 |
-
content = resp_data["choices"][0]["message"].get("content")
|
| 137 |
if not content:
|
| 138 |
logger.warning("GPT returned empty content for relevance filter — accept-all")
|
| 139 |
return set(detected_labels)
|
|
@@ -148,7 +133,7 @@ def evaluate_relevance_llm(
|
|
| 148 |
)
|
| 149 |
return relevant_set
|
| 150 |
|
| 151 |
-
except
|
| 152 |
logger.warning("LLM relevance API call failed: %s — accept-all fallback", e)
|
| 153 |
return set(detected_labels)
|
| 154 |
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
|
|
|
| 12 |
|
| 13 |
import json
|
| 14 |
import logging
|
|
|
|
|
|
|
|
|
|
| 15 |
from typing import Any, Dict, List, NamedTuple, Set
|
| 16 |
|
| 17 |
+
from utils.openai_client import chat_completion, extract_content, get_api_key, OpenAIAPIError
|
| 18 |
+
|
| 19 |
from coco_classes import canonicalize_coco_name
|
| 20 |
from utils.schemas import RelevanceCriteria
|
| 21 |
|
|
|
|
| 90 |
if not detected_labels:
|
| 91 |
return set()
|
| 92 |
|
| 93 |
+
if not get_api_key():
|
|
|
|
| 94 |
logger.warning(
|
| 95 |
"OPENAI_API_KEY not set — LLM relevance filter falling back to accept-all"
|
| 96 |
)
|
|
|
|
| 116 |
],
|
| 117 |
}
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
try:
|
| 120 |
+
resp_data = chat_completion(payload)
|
| 121 |
+
content, _refusal = extract_content(resp_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
if not content:
|
| 123 |
logger.warning("GPT returned empty content for relevance filter — accept-all")
|
| 124 |
return set(detected_labels)
|
|
|
|
| 133 |
)
|
| 134 |
return relevant_set
|
| 135 |
|
| 136 |
+
except OpenAIAPIError as e:
|
| 137 |
logger.warning("LLM relevance API call failed: %s — accept-all fallback", e)
|
| 138 |
return set(detected_labels)
|
| 139 |
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
utils/schemas.py
CHANGED
|
@@ -146,3 +146,15 @@ class MissionSpecification(BaseModel):
|
|
| 146 |
"E.g., 'term \"threat\" is not a visual class, stripped'."
|
| 147 |
)
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
"E.g., 'term \"threat\" is not a visual class, stripped'."
|
| 147 |
)
|
| 148 |
|
| 149 |
+
|
| 150 |
+
class AssessmentStatus:
    """Canonical string constants for detection assessment lifecycle.

    Plain ``str`` constants (not an Enum) so values serialize directly into
    JSON track payloads and compare equal to raw strings read back from
    stored track data.
    """

    # GPT produced a usable assessment (tracker marks this when metadata
    # carries threat_level_score / gpt_raw / object_type).
    ASSESSED = "ASSESSED"
    # Never evaluated yet — tracker applies this explicitly (INV-6) to any
    # track not in the ASSESSED state.
    UNASSESSED = "UNASSESSED"
    # Queued for GPT enrichment — TODO confirm against enrichment thread usage.
    PENDING_GPT = "PENDING_GPT"
    # Intentionally skipped (e.g. human/person analysis under policy constraints).
    SKIPPED_POLICY = "SKIPPED_POLICY"
    # Presumably the model declined to assess the object — verify against
    # the GPT refusal-handling path.
    REFUSED = "REFUSED"
    # GPT API call failed (HTTP/network error fallback).
    ERROR = "ERROR"
    # GPT replied but the object payload was missing or malformed.
    NO_RESPONSE = "NO_RESPONSE"
    # Assessment older than MAX_STALE_FRAMES — tracker downgrades to this.
    STALE = "STALE"
|
| 160 |
+
|
utils/tracker.py
CHANGED
|
@@ -3,6 +3,8 @@ import numpy as np
|
|
| 3 |
from scipy.optimize import linear_sum_assignment
|
| 4 |
import scipy.linalg
|
| 5 |
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class KalmanFilter:
|
| 8 |
"""
|
|
@@ -574,11 +576,11 @@ class ByteTracker:
|
|
| 574 |
if assessment_frame is not None:
|
| 575 |
frames_since = self.frame_id - assessment_frame
|
| 576 |
if frames_since > MAX_STALE_FRAMES:
|
| 577 |
-
d_out['assessment_status'] =
|
| 578 |
d_out['assessment_age_frames'] = frames_since
|
| 579 |
-
elif d_out.get('assessment_status') !=
|
| 580 |
# INV-6: Unassessed objects get explicit UNASSESSED status
|
| 581 |
-
d_out['assessment_status'] =
|
| 582 |
|
| 583 |
# Update history
|
| 584 |
if 'history' not in track.gpt_data:
|
|
@@ -634,7 +636,7 @@ class ByteTracker:
|
|
| 634 |
k in meta for k in ("threat_level_score", "gpt_raw", "object_type")
|
| 635 |
):
|
| 636 |
meta["assessment_frame_index"] = self.frame_id
|
| 637 |
-
meta["assessment_status"] =
|
| 638 |
meta_by_tid[tid] = meta
|
| 639 |
for track in self.tracked_stracks:
|
| 640 |
tid_str = f"T{str(track.track_id).zfill(2)}"
|
|
|
|
| 3 |
from scipy.optimize import linear_sum_assignment
|
| 4 |
import scipy.linalg
|
| 5 |
|
| 6 |
+
from utils.schemas import AssessmentStatus
|
| 7 |
+
|
| 8 |
|
| 9 |
class KalmanFilter:
|
| 10 |
"""
|
|
|
|
| 576 |
if assessment_frame is not None:
|
| 577 |
frames_since = self.frame_id - assessment_frame
|
| 578 |
if frames_since > MAX_STALE_FRAMES:
|
| 579 |
+
d_out['assessment_status'] = AssessmentStatus.STALE
|
| 580 |
d_out['assessment_age_frames'] = frames_since
|
| 581 |
+
elif d_out.get('assessment_status') != AssessmentStatus.ASSESSED:
|
| 582 |
# INV-6: Unassessed objects get explicit UNASSESSED status
|
| 583 |
+
d_out['assessment_status'] = AssessmentStatus.UNASSESSED
|
| 584 |
|
| 585 |
# Update history
|
| 586 |
if 'history' not in track.gpt_data:
|
|
|
|
| 636 |
k in meta for k in ("threat_level_score", "gpt_raw", "object_type")
|
| 637 |
):
|
| 638 |
meta["assessment_frame_index"] = self.frame_id
|
| 639 |
+
meta["assessment_status"] = AssessmentStatus.ASSESSED
|
| 640 |
meta_by_tid[tid] = meta
|
| 641 |
for track in self.tracked_stracks:
|
| 642 |
tid_str = f"T{str(track.track_id).zfill(2)}"
|