Zhen Ye committed on
Commit
374a0ef
·
1 Parent(s): 1a9b396

Fix GPT thread safety, improve relevance logic, and add caching for COCO matching

Browse files
Files changed (3) hide show
  1. app.py +9 -10
  2. coco_classes.py +2 -0
  3. inference.py +37 -22
app.py CHANGED
@@ -90,16 +90,15 @@ async def _enrich_first_frame_gpt(
90
  # LLM relevance filter (if LLM_EXTRACTED mode)
91
  gpt_dets = detections
92
  if mission_spec and mission_spec.parse_mode == "LLM_EXTRACTED":
93
- if not mission_spec.relevance_criteria.required_classes:
94
- unique_labels = list({
95
- d.get("label", "").lower()
96
- for d in detections if d.get("label")
97
- })
98
- relevant_labels = await asyncio.to_thread(
99
- evaluate_relevance_llm, unique_labels, mission_spec.operator_text
100
- )
101
- mission_spec.relevance_criteria.required_classes = list(relevant_labels)
102
- # Apply deterministic filter
103
  for d in detections:
104
  decision = evaluate_relevance(d, mission_spec.relevance_criteria)
105
  d["mission_relevant"] = decision.relevant
 
90
  # LLM relevance filter (if LLM_EXTRACTED mode)
91
  gpt_dets = detections
92
  if mission_spec and mission_spec.parse_mode == "LLM_EXTRACTED":
93
+ unique_labels = list({
94
+ d.get("label", "").lower()
95
+ for d in detections if d.get("label")
96
+ })
97
+ relevant_labels = await asyncio.to_thread(
98
+ evaluate_relevance_llm, unique_labels, mission_spec.operator_text
99
+ )
100
+ mission_spec.relevance_criteria.required_classes = list(relevant_labels)
101
+ # Apply deterministic filter with refined classes
 
102
  for d in detections:
103
  decision = evaluate_relevance(d, mission_spec.relevance_criteria)
104
  d["mission_relevant"] = decision.relevant
coco_classes.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import difflib
 
4
  import logging
5
  import re
6
  from typing import Dict, Optional, Tuple
@@ -219,6 +220,7 @@ def _semantic_coco_match(value: str) -> Optional[str]:
219
  return None
220
 
221
 
 
222
  def canonicalize_coco_name(value: str | None) -> str | None:
223
  """Map an arbitrary string to the closest COCO class name if possible.
224
 
 
1
  from __future__ import annotations
2
 
3
  import difflib
4
+ import functools
5
  import logging
6
  import re
7
  from typing import Dict, Optional, Tuple
 
220
  return None
221
 
222
 
223
+ @functools.lru_cache(maxsize=512)
224
  def canonicalize_coco_name(value: str | None) -> str | None:
225
  """Map an arbitrary string to the closest COCO class name if possible.
226
 
inference.py CHANGED
@@ -1046,6 +1046,7 @@ def run_inference(
1046
  # --- GPT Enrichment Thread (non-blocking) ---
1047
  # Runs LLM relevance + GPT threat assessment off the writer's critical path.
1048
  gpt_enrichment_queue = Queue(maxsize=4)
 
1049
 
1050
  def enrichment_thread_fn(tracker_ref):
1051
  """Dedicated thread for GPT/LLM calls. Receives work from writer, injects results via tracker."""
@@ -1055,9 +1056,8 @@ def run_inference(
1055
  break # Sentinel — shutdown
1056
  frame_idx, frame_data, gpt_dets, ms = item
1057
  try:
1058
- # LLM post-filter (LLM_EXTRACTED mode, frame 0 only)
1059
- if (ms and ms.parse_mode == "LLM_EXTRACTED"
1060
- and not ms.relevance_criteria.required_classes):
1061
  unique_labels = list({
1062
  d.get("label", "").lower()
1063
  for d in gpt_dets if d.get("label")
@@ -1066,10 +1066,16 @@ def run_inference(
1066
  unique_labels, ms.operator_text
1067
  )
1068
  ms.relevance_criteria.required_classes = list(relevant_labels)
 
1069
  logging.info(
1070
  "Enrichment: LLM post-filter applied on frame %d: relevant=%s",
1071
  frame_idx, relevant_labels,
1072
  )
 
 
 
 
 
1073
 
1074
  # GPT threat assessment
1075
  if gpt_dets:
@@ -1151,30 +1157,39 @@ def run_inference(
1151
 
1152
  # --- RELEVANCE GATE (deterministic, fast — stays in writer) ---
1153
  if mission_spec:
1154
- for d in dets:
1155
- decision = evaluate_relevance(d, mission_spec.relevance_criteria)
1156
- d["mission_relevant"] = decision.relevant
1157
- d["relevance_reason"] = decision.reason
1158
- if not decision.relevant:
1159
- logging.info(
1160
- json_module.dumps({
1161
- "event": "relevance_decision",
1162
- "track_id": d.get("track_id"),
1163
- "label": d.get("label"),
1164
- "relevant": False,
1165
- "reason": decision.reason,
1166
- "required_classes": mission_spec.relevance_criteria.required_classes,
1167
- "frame": next_idx,
1168
- })
1169
- )
1170
- gpt_dets = [d for d in dets if d.get("mission_relevant", True)]
 
 
 
 
 
 
 
 
 
1171
  else:
1172
  for d in dets:
1173
  d["mission_relevant"] = None
1174
  gpt_dets = dets
1175
 
1176
  # --- GPT ENRICHMENT (non-blocking, offloaded to enrichment thread) ---
1177
- if next_idx == 0 and enable_gpt and gpt_dets and not gpt_submitted:
1178
  # Tag as pending — enrichment thread will update to ASSESSED later
1179
  for d in gpt_dets:
1180
  d["assessment_status"] = "PENDING_GPT"
@@ -1184,7 +1199,7 @@ def run_inference(
1184
  timeout=1.0,
1185
  )
1186
  gpt_submitted = True
1187
- logging.info("Writer: offloaded GPT enrichment for frame 0")
1188
  except Full:
1189
  logging.warning("GPT enrichment queue full, skipping frame 0 GPT")
1190
 
 
1046
  # --- GPT Enrichment Thread (non-blocking) ---
1047
  # Runs LLM relevance + GPT threat assessment off the writer's critical path.
1048
  gpt_enrichment_queue = Queue(maxsize=4)
1049
+ _relevance_refined = [False] # mutable container for thread-safe sharing
1050
 
1051
  def enrichment_thread_fn(tracker_ref):
1052
  """Dedicated thread for GPT/LLM calls. Receives work from writer, injects results via tracker."""
 
1056
  break # Sentinel — shutdown
1057
  frame_idx, frame_data, gpt_dets, ms = item
1058
  try:
1059
+ # LLM post-filter (LLM_EXTRACTED mode)
1060
+ if ms and ms.parse_mode == "LLM_EXTRACTED":
 
1061
  unique_labels = list({
1062
  d.get("label", "").lower()
1063
  for d in gpt_dets if d.get("label")
 
1066
  unique_labels, ms.operator_text
1067
  )
1068
  ms.relevance_criteria.required_classes = list(relevant_labels)
1069
+ _relevance_refined[0] = True # signal writer_loop to switch to deterministic gate
1070
  logging.info(
1071
  "Enrichment: LLM post-filter applied on frame %d: relevant=%s",
1072
  frame_idx, relevant_labels,
1073
  )
1074
+ # Re-filter with refined classes
1075
+ for d in gpt_dets:
1076
+ decision = evaluate_relevance(d, ms.relevance_criteria)
1077
+ d["mission_relevant"] = decision.relevant
1078
+ gpt_dets = [d for d in gpt_dets if d.get("mission_relevant", True)]
1079
 
1080
  # GPT threat assessment
1081
  if gpt_dets:
 
1157
 
1158
  # --- RELEVANCE GATE (deterministic, fast — stays in writer) ---
1159
  if mission_spec:
1160
+ if (mission_spec.parse_mode == "LLM_EXTRACTED"
1161
+ and not _relevance_refined[0]):
1162
+ # LLM post-filter hasn't run yet — pass all through
1163
+ for d in dets:
1164
+ d["mission_relevant"] = True
1165
+ d["relevance_reason"] = "pending_llm_postfilter"
1166
+ gpt_dets = dets
1167
+ else:
1168
+ # Normal deterministic gate (with refined or FAST_PATH classes)
1169
+ for d in dets:
1170
+ decision = evaluate_relevance(d, mission_spec.relevance_criteria)
1171
+ d["mission_relevant"] = decision.relevant
1172
+ d["relevance_reason"] = decision.reason
1173
+ if not decision.relevant:
1174
+ logging.info(
1175
+ json_module.dumps({
1176
+ "event": "relevance_decision",
1177
+ "track_id": d.get("track_id"),
1178
+ "label": d.get("label"),
1179
+ "relevant": False,
1180
+ "reason": decision.reason,
1181
+ "required_classes": mission_spec.relevance_criteria.required_classes,
1182
+ "frame": next_idx,
1183
+ })
1184
+ )
1185
+ gpt_dets = [d for d in dets if d.get("mission_relevant", True)]
1186
  else:
1187
  for d in dets:
1188
  d["mission_relevant"] = None
1189
  gpt_dets = dets
1190
 
1191
  # --- GPT ENRICHMENT (non-blocking, offloaded to enrichment thread) ---
1192
+ if enable_gpt and gpt_dets and not gpt_submitted:
1193
  # Tag as pending — enrichment thread will update to ASSESSED later
1194
  for d in gpt_dets:
1195
  d["assessment_status"] = "PENDING_GPT"
 
1199
  timeout=1.0,
1200
  )
1201
  gpt_submitted = True
1202
+ logging.info("Writer: offloaded GPT enrichment for frame %d", next_idx)
1203
  except Full:
1204
  logging.warning("GPT enrichment queue full, skipping frame 0 GPT")
1205