Zhen Ye Claude Opus 4.6 committed on
Commit
624478a
·
1 Parent(s): a2ca6f9

feat: Add LLM post-processing relevance gate for broad detection

Browse files

Move LLM's role from pre-processing (choosing detector classes) to
post-processing (filtering detected objects against mission intent).
Fast-path (comma-separated labels) is unchanged. LLM-path now detects
broadly (all COCO classes or domain-expanded queries), then asks GPT
once on frame 0 which labels are mission-relevant, caching the result
for deterministic filtering on all subsequent frames.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (5) hide show
  1. app.py +6 -5
  2. inference.py +81 -28
  3. utils/mission_parser.py +46 -0
  4. utils/relevance.py +88 -3
  5. utils/schemas.py +7 -0
app.py CHANGED
@@ -57,7 +57,7 @@ from jobs.storage import (
57
  )
58
  from utils.gpt_reasoning import estimate_threat_gpt
59
  from utils.threat_chat import chat_about_threats
60
- from utils.mission_parser import parse_mission_text, MissionParseError
61
 
62
  logging.basicConfig(level=logging.INFO)
63
 
@@ -274,7 +274,7 @@ async def detect_endpoint(
274
  if queries.strip():
275
  try:
276
  mission_spec = parse_mission_text(queries.strip(), detector_name)
277
- query_list = mission_spec.object_classes
278
  except MissionParseError as e:
279
  raise HTTPException(status_code=422, detail=str(e))
280
  else:
@@ -370,11 +370,12 @@ async def detect_async_endpoint(
370
  if queries.strip():
371
  try:
372
  mission_spec = parse_mission_text(queries.strip(), detector_name)
373
- query_list = mission_spec.object_classes
374
  mission_mode = "MISSION"
375
  logging.info(
376
- "Mission parsed: mode=%s classes=%s domain=%s(%s)",
377
- mission_mode, query_list, mission_spec.domain, mission_spec.domain_source,
 
378
  )
379
  except MissionParseError as e:
380
  raise HTTPException(
 
57
  )
58
  from utils.gpt_reasoning import estimate_threat_gpt
59
  from utils.threat_chat import chat_about_threats
60
+ from utils.mission_parser import parse_mission_text, build_broad_queries, MissionParseError
61
 
62
  logging.basicConfig(level=logging.INFO)
63
 
 
274
  if queries.strip():
275
  try:
276
  mission_spec = parse_mission_text(queries.strip(), detector_name)
277
+ query_list = build_broad_queries(detector_name, mission_spec)
278
  except MissionParseError as e:
279
  raise HTTPException(status_code=422, detail=str(e))
280
  else:
 
370
  if queries.strip():
371
  try:
372
  mission_spec = parse_mission_text(queries.strip(), detector_name)
373
+ query_list = build_broad_queries(detector_name, mission_spec)
374
  mission_mode = "MISSION"
375
  logging.info(
376
+ "Mission parsed: mode=%s classes=%s broad_queries=%s domain=%s(%s)",
377
+ mission_mode, mission_spec.object_classes, query_list,
378
+ mission_spec.domain, mission_spec.domain_source,
379
  )
380
  except MissionParseError as e:
381
  raise HTTPException(
inference.py CHANGED
@@ -23,7 +23,7 @@ from models.depth_estimators.model_loader import load_depth_estimator, load_dept
23
  from models.depth_estimators.base import DepthEstimator
24
  from utils.video import extract_frames, write_video, VideoReader, VideoWriter, AsyncVideoReader
25
  from utils.gpt_reasoning import estimate_threat_gpt
26
- from utils.relevance import evaluate_relevance
27
  from jobs.storage import set_track_data
28
  import tempfile
29
  import json as json_module
@@ -732,25 +732,57 @@ def process_first_frame(
732
 
733
  # --- RELEVANCE GATE (between detection and GPT) ---
734
  if mission_spec:
735
- relevant_dets = []
736
- for det in detections:
737
- decision = evaluate_relevance(det, mission_spec.relevance_criteria)
738
- det["mission_relevant"] = decision.relevant
739
- det["relevance_reason"] = decision.reason
740
- if decision.relevant:
741
- relevant_dets.append(det)
742
- else:
743
- logging.info(
744
- json_module.dumps({
745
- "event": "relevance_decision",
746
- "label": det.get("label"),
747
- "relevant": False,
748
- "reason": decision.reason,
749
- "required_classes": mission_spec.relevance_criteria.required_classes,
750
- "frame": 0,
751
- })
752
- )
753
- gpt_input_dets = relevant_dets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  else:
755
  # LEGACY mode: all detections pass, tagged as unresolved
756
  for det in detections:
@@ -1117,28 +1149,29 @@ def run_inference(
1117
  # Initialize Tracker & Speed Estimator
1118
  tracker = ByteTracker(frame_rate=fps)
1119
  speed_est = SpeedEstimator(fps=fps)
1120
-
 
1121
  try:
1122
  with VideoWriter(output_video_path, fps, width, height) as writer:
1123
  while next_idx < total_frames:
1124
  # Fetch from queue
1125
  try:
1126
  while next_idx not in buffer:
1127
- # Backpressure: If buffer gets too big due to out-of-order frames,
1128
- # we might want to warn or just hope for the best.
1129
  # But here we are just consuming.
1130
-
1131
  # However, if 'buffer' grows too large (because we are missing next_idx),
1132
  # we are effectively unbounded again if queue_out fills up with future frames.
1133
  # So we should monitor buffer size.
1134
  if len(buffer) > 200 and len(buffer) % 50 == 0:
1135
  logging.warning("Writer buffer large (%d items), waiting for frame %d (GPT Latency?)...", len(buffer), next_idx)
1136
-
1137
  item = queue_out.get(timeout=1.0) # wait
1138
-
1139
  idx, p_frame, dets = item
1140
  buffer[idx] = (p_frame, dets)
1141
-
1142
  # Write next_idx
1143
  p_frame, dets = buffer.pop(next_idx)
1144
 
@@ -1147,7 +1180,27 @@ def run_inference(
1147
  dets = tracker.update(dets)
1148
  speed_est.estimate(dets)
1149
 
1150
- # --- RELEVANCE GATE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1151
  if mission_spec:
1152
  for d in dets:
1153
  decision = evaluate_relevance(d, mission_spec.relevance_criteria)
 
23
  from models.depth_estimators.base import DepthEstimator
24
  from utils.video import extract_frames, write_video, VideoReader, VideoWriter, AsyncVideoReader
25
  from utils.gpt_reasoning import estimate_threat_gpt
26
+ from utils.relevance import evaluate_relevance, evaluate_relevance_llm
27
  from jobs.storage import set_track_data
28
  import tempfile
29
  import json as json_module
 
732
 
733
  # --- RELEVANCE GATE (between detection and GPT) ---
734
  if mission_spec:
735
+ if mission_spec.parse_mode == "FAST_PATH":
736
+ # Deterministic gate (unchanged)
737
+ relevant_dets = []
738
+ for det in detections:
739
+ decision = evaluate_relevance(det, mission_spec.relevance_criteria)
740
+ det["mission_relevant"] = decision.relevant
741
+ det["relevance_reason"] = decision.reason
742
+ if decision.relevant:
743
+ relevant_dets.append(det)
744
+ else:
745
+ logging.info(
746
+ json_module.dumps({
747
+ "event": "relevance_decision",
748
+ "label": det.get("label"),
749
+ "relevant": False,
750
+ "reason": decision.reason,
751
+ "required_classes": mission_spec.relevance_criteria.required_classes,
752
+ "frame": 0,
753
+ })
754
+ )
755
+ gpt_input_dets = relevant_dets
756
+ else:
757
+ # LLM_EXTRACTED: post-filter with GPT on frame 0
758
+ unique_labels = list({
759
+ d.get("label", "").lower()
760
+ for d in detections if d.get("label")
761
+ })
762
+ relevant_labels = evaluate_relevance_llm(
763
+ unique_labels, mission_spec.operator_text
764
+ )
765
+
766
+ # Cache GPT-approved labels into relevance_criteria for subsequent frames
767
+ mission_spec.relevance_criteria.required_classes = list(relevant_labels)
768
+
769
+ for det in detections:
770
+ label = (det.get("label") or "").lower()
771
+ is_relevant = label in relevant_labels
772
+ det["mission_relevant"] = is_relevant
773
+ det["relevance_reason"] = "ok" if is_relevant else "llm_excluded"
774
+ if not is_relevant:
775
+ logging.info(
776
+ json_module.dumps({
777
+ "event": "relevance_decision",
778
+ "label": det.get("label"),
779
+ "relevant": False,
780
+ "reason": "llm_excluded",
781
+ "relevant_labels": list(relevant_labels),
782
+ "frame": 0,
783
+ })
784
+ )
785
+ gpt_input_dets = [d for d in detections if d.get("mission_relevant")]
786
  else:
787
  # LEGACY mode: all detections pass, tagged as unresolved
788
  for det in detections:
 
1149
  # Initialize Tracker & Speed Estimator
1150
  tracker = ByteTracker(frame_rate=fps)
1151
  speed_est = SpeedEstimator(fps=fps)
1152
+ llm_filtered = False # LLM post-filter runs once on frame 0
1153
+
1154
  try:
1155
  with VideoWriter(output_video_path, fps, width, height) as writer:
1156
  while next_idx < total_frames:
1157
  # Fetch from queue
1158
  try:
1159
  while next_idx not in buffer:
1160
+ # Backpressure: If buffer gets too big due to out-of-order frames,
1161
+ # we might want to warn or just hope for the best.
1162
  # But here we are just consuming.
1163
+
1164
  # However, if 'buffer' grows too large (because we are missing next_idx),
1165
  # we are effectively unbounded again if queue_out fills up with future frames.
1166
  # So we should monitor buffer size.
1167
  if len(buffer) > 200 and len(buffer) % 50 == 0:
1168
  logging.warning("Writer buffer large (%d items), waiting for frame %d (GPT Latency?)...", len(buffer), next_idx)
1169
+
1170
  item = queue_out.get(timeout=1.0) # wait
1171
+
1172
  idx, p_frame, dets = item
1173
  buffer[idx] = (p_frame, dets)
1174
+
1175
  # Write next_idx
1176
  p_frame, dets = buffer.pop(next_idx)
1177
 
 
1180
  dets = tracker.update(dets)
1181
  speed_est.estimate(dets)
1182
 
1183
+ # --- LLM POST-FILTER (frame 0 only, LLM_EXTRACTED mode) ---
1184
+ if (mission_spec
1185
+ and mission_spec.parse_mode == "LLM_EXTRACTED"
1186
+ and not llm_filtered
1187
+ and next_idx == 0):
1188
+ unique_labels = list({
1189
+ d.get("label", "").lower()
1190
+ for d in dets if d.get("label")
1191
+ })
1192
+ relevant_labels = evaluate_relevance_llm(
1193
+ unique_labels, mission_spec.operator_text
1194
+ )
1195
+ # Cache into relevance_criteria for all subsequent frames
1196
+ mission_spec.relevance_criteria.required_classes = list(relevant_labels)
1197
+ llm_filtered = True
1198
+ logging.info(
1199
+ "LLM post-filter applied on frame 0: relevant=%s",
1200
+ relevant_labels,
1201
+ )
1202
+
1203
+ # --- RELEVANCE GATE (deterministic, uses updated criteria) ---
1204
  if mission_spec:
1205
  for d in dets:
1206
  decision = evaluate_relevance(d, mission_spec.relevance_criteria)
utils/mission_parser.py CHANGED
@@ -114,6 +114,7 @@ def _build_fast_path_spec(
114
  context_phrases=[],
115
  stripped_modifiers=[],
116
  operator_text=raw_text,
 
117
  parse_confidence="HIGH",
118
  parse_warnings=warnings,
119
  )
@@ -328,11 +329,56 @@ def _validate_and_build(
328
  context_phrases=context_phrases,
329
  stripped_modifiers=stripped_modifiers,
330
  operator_text=raw_text,
 
331
  parse_confidence=parse_confidence,
332
  parse_warnings=parse_warnings,
333
  )
334
 
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  def parse_mission_text(
337
  raw_text: str,
338
  detector_key: str,
 
114
  context_phrases=[],
115
  stripped_modifiers=[],
116
  operator_text=raw_text,
117
+ parse_mode="FAST_PATH",
118
  parse_confidence="HIGH",
119
  parse_warnings=warnings,
120
  )
 
329
  context_phrases=context_phrases,
330
  stripped_modifiers=stripped_modifiers,
331
  operator_text=raw_text,
332
+ parse_mode="LLM_EXTRACTED",
333
  parse_confidence=parse_confidence,
334
  parse_warnings=parse_warnings,
335
  )
336
 
337
 
338
+ _DOMAIN_BROAD_CATEGORIES: dict[str, List[str]] = {
339
+ "NAVAL": ["vessel", "ship", "boat", "buoy", "person"],
340
+ "AERIAL": ["aircraft", "helicopter", "drone", "airplane"],
341
+ "GROUND": ["vehicle", "car", "truck", "person", "building"],
342
+ "URBAN": ["person", "vehicle", "car", "bicycle"],
343
+ "GENERIC": ["object"],
344
+ }
345
+
346
+
347
+ def build_broad_queries(
348
+ detector_key: str, mission_spec: MissionSpecification
349
+ ) -> List[str]:
350
+ """Build broad detector queries for LLM post-filter mode.
351
+
352
+ For FAST_PATH: return object_classes directly (unchanged behavior).
353
+ For COCO detectors (LLM_EXTRACTED): return ALL 80 COCO classes.
354
+ For open-vocab detectors (LLM_EXTRACTED): return LLM-extracted classes
355
+ PLUS broad domain categories to maximize recall.
356
+ """
357
+ if mission_spec.parse_mode == "FAST_PATH":
358
+ return mission_spec.object_classes
359
+
360
+ # LLM_EXTRACTED path: detect broadly
361
+ if _is_coco_only(detector_key):
362
+ # COCO detectors ignore queries anyway (DETR detects all 80;
363
+ # YOLOv8 falls back to all if no matches). Send everything.
364
+ return list(COCO_CLASSES)
365
+
366
+ # Open-vocab detector (e.g. Grounding DINO):
367
+ # Combine LLM-extracted classes with domain-specific broad categories
368
+ broad = list(mission_spec.object_classes)
369
+ domain_extras = _DOMAIN_BROAD_CATEGORIES.get(
370
+ mission_spec.domain, _DOMAIN_BROAD_CATEGORIES["GENERIC"]
371
+ )
372
+ seen = {c.lower() for c in broad}
373
+ for cat in domain_extras:
374
+ if cat.lower() not in seen:
375
+ broad.append(cat)
376
+ seen.add(cat.lower())
377
+
378
+ logger.info("Broad queries for %s: %s", detector_key, broad)
379
+ return broad
380
+
381
+
382
  def parse_mission_text(
383
  raw_text: str,
384
  detector_key: str,
utils/relevance.py CHANGED
@@ -1,15 +1,21 @@
1
  """
2
  Object relevance evaluation — deterministic gate between detection and GPT assessment.
3
 
4
- Single public function: evaluate_relevance(detection, criteria) -> RelevanceDecision
 
 
5
 
6
- INVARIANT INV-13 enforcement: This function accepts RelevanceCriteria, NOT
7
  MissionSpecification. It cannot see context_phrases, stripped_modifiers, or any
8
  LLM-derived field. This is structural, not by convention.
9
  """
10
 
 
11
  import logging
12
- from typing import Any, Dict, NamedTuple
 
 
 
13
 
14
  from coco_classes import canonicalize_coco_name
15
  from utils.schemas import RelevanceCriteria
@@ -69,3 +75,82 @@ def evaluate_relevance(
69
  return RelevanceDecision(True, "ok")
70
 
71
  return RelevanceDecision(False, "label_not_in_required_classes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Object relevance evaluation — deterministic gate between detection and GPT assessment.
3
 
4
+ Public functions:
5
+ evaluate_relevance(detection, criteria) -> RelevanceDecision (deterministic)
6
+ evaluate_relevance_llm(detected_labels, mission_text) -> set[str] (LLM post-filter)
7
 
8
+ INVARIANT INV-13 enforcement: evaluate_relevance() accepts RelevanceCriteria, NOT
9
  MissionSpecification. It cannot see context_phrases, stripped_modifiers, or any
10
  LLM-derived field. This is structural, not by convention.
11
  """
12
 
13
+ import json
14
  import logging
15
+ import os
16
+ import urllib.request
17
+ import urllib.error
18
+ from typing import Any, Dict, List, NamedTuple, Set
19
 
20
  from coco_classes import canonicalize_coco_name
21
  from utils.schemas import RelevanceCriteria
 
75
  return RelevanceDecision(True, "ok")
76
 
77
  return RelevanceDecision(False, "label_not_in_required_classes")
78
+
79
+
80
+ def evaluate_relevance_llm(
81
+ detected_labels: List[str],
82
+ mission_text: str,
83
+ ) -> Set[str]:
84
+ """Ask GPT which detected labels are relevant to the mission.
85
+
86
+ Called ONCE on frame 0 with the unique labels found by the detector.
87
+ Returns a set of relevant label strings (lowercased).
88
+
89
+ On API failure, falls back to accepting all labels (fail-open, logged).
90
+ """
91
+ if not detected_labels:
92
+ return set()
93
+
94
+ api_key = os.environ.get("OPENAI_API_KEY")
95
+ if not api_key:
96
+ logger.warning(
97
+ "OPENAI_API_KEY not set — LLM relevance filter falling back to accept-all"
98
+ )
99
+ return set(detected_labels)
100
+
101
+ prompt = (
102
+ f"Given this mission: \"{mission_text}\"\n\n"
103
+ f"Which of these detected object classes are relevant to the mission?\n"
104
+ f"{json.dumps(detected_labels)}\n\n"
105
+ "Return JSON: {\"relevant_labels\": [...]}\n"
106
+ "Only include labels from the provided list that are relevant to "
107
+ "accomplishing the mission. Be inclusive — if in doubt, include it."
108
+ )
109
+
110
+ payload = {
111
+ "model": "gpt-4o",
112
+ "temperature": 0.0,
113
+ "max_tokens": 200,
114
+ "response_format": {"type": "json_object"},
115
+ "messages": [
116
+ {"role": "system", "content": "You are a mission relevance filter. Return only JSON."},
117
+ {"role": "user", "content": prompt},
118
+ ],
119
+ }
120
+
121
+ headers = {
122
+ "Content-Type": "application/json",
123
+ "Authorization": f"Bearer {api_key}",
124
+ }
125
+
126
+ try:
127
+ req = urllib.request.Request(
128
+ "https://api.openai.com/v1/chat/completions",
129
+ data=json.dumps(payload).encode("utf-8"),
130
+ headers=headers,
131
+ method="POST",
132
+ )
133
+ with urllib.request.urlopen(req, timeout=30) as response:
134
+ resp_data = json.loads(response.read().decode("utf-8"))
135
+
136
+ content = resp_data["choices"][0]["message"].get("content")
137
+ if not content:
138
+ logger.warning("GPT returned empty content for relevance filter — accept-all")
139
+ return set(detected_labels)
140
+
141
+ result = json.loads(content)
142
+ relevant = result.get("relevant_labels", detected_labels)
143
+ relevant_set = {label.lower() for label in relevant}
144
+
145
+ logger.info(
146
+ "LLM relevance filter: mission=%r detected=%s relevant=%s",
147
+ mission_text, detected_labels, relevant_set,
148
+ )
149
+ return relevant_set
150
+
151
+ except (urllib.error.HTTPError, urllib.error.URLError) as e:
152
+ logger.warning("LLM relevance API call failed: %s — accept-all fallback", e)
153
+ return set(detected_labels)
154
+ except (json.JSONDecodeError, KeyError, TypeError) as e:
155
+ logger.warning("LLM relevance response parse failed: %s — accept-all fallback", e)
156
+ return set(detected_labels)
utils/schemas.py CHANGED
@@ -124,6 +124,13 @@ class MissionSpecification(BaseModel):
124
  description="Original unmodified mission text from the operator. Preserved for audit."
125
  )
126
 
 
 
 
 
 
 
 
127
  # --- LLM self-assessment ---
128
  parse_confidence: Literal["HIGH", "MEDIUM", "LOW"] = Field(
129
  ...,
 
124
  description="Original unmodified mission text from the operator. Preserved for audit."
125
  )
126
 
127
+ # --- Parse mode ---
128
+ parse_mode: Literal["FAST_PATH", "LLM_EXTRACTED"] = Field(
129
+ default="FAST_PATH",
130
+ description="How this spec was created. FAST_PATH = comma-separated labels, "
131
+ "LLM_EXTRACTED = natural language parsed by GPT."
132
+ )
133
+
134
  # --- LLM self-assessment ---
135
  parse_confidence: Literal["HIGH", "MEDIUM", "LOW"] = Field(
136
  ...,