Zhen Ye committed on
Commit aadca27 · 1 Parent(s): d73eff6

Handle GPT refusal paths and preserve assessment status

Files changed (3)
  1. app.py +4 -3
  2. inference.py +9 -4
  3. utils/gpt_reasoning.py +122 -11
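
This commit writes explicit assessment states into each detection dict: ASSESSED plus the new UNASSESSED, SKIPPED_POLICY, REFUSED, NO_RESPONSE, and ERROR fallbacks. A minimal, hypothetical sketch of how a downstream consumer might branch on these values; the render_status_label helper and its return strings are illustrative only and not part of this repository:

def render_status_label(det: dict) -> str:
    # Hypothetical consumer of the detection dicts enriched by this commit.
    status = det.get("assessment_status", "UNASSESSED")
    if status == "ASSESSED":
        return "assessed"
    if status == "SKIPPED_POLICY":
        return "not assessed (policy)"
    if status in ("REFUSED", "NO_RESPONSE", "ERROR"):
        return f"assessment unavailable: {det.get('gpt_reason', status.lower())}"
    return "pending assessment"
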
app.py CHANGED
@@ -147,9 +147,10 @@ async def _enrich_first_frame_gpt(
             info = gpt_results[obj_id]
             det.update(info)
             det["gpt_raw"] = info
-        # Mark ASSESSED regardless of whether GPT returned data for this object
-        det["assessment_frame_index"] = 0
-        det["assessment_status"] = "ASSESSED"
+            det.setdefault("assessment_frame_index", 0)
+            det["assessment_status"] = info.get("assessment_status", "ASSESSED")
+        else:
+            det.setdefault("assessment_status", "UNASSESSED")
 
     for det in detections:
         if "assessment_status" not in det:
inference.py CHANGED
@@ -1126,10 +1126,13 @@ def run_inference(
         for d in gpt_dets:
             oid = d.get('track_id')
             if oid and oid in gpt_res:
-                d.update(gpt_res[oid])
-                d["gpt_raw"] = gpt_res[oid]
+                gpt_payload = gpt_res[oid]
+                d.update(gpt_payload)
+                d["gpt_raw"] = gpt_payload
                 d["assessment_frame_index"] = frame_idx
-                d["assessment_status"] = "ASSESSED"
+                d["assessment_status"] = gpt_payload.get(
+                    "assessment_status", "ASSESSED"
+                )
 
         # Push GPT data back into tracker's internal STrack objects
         tracker_ref.inject_metadata(gpt_dets)
@@ -2080,7 +2083,9 @@ def run_grounded_sam2_tracking(
             merged = dict(gpt_res[tid])
             merged["gpt_raw"] = gpt_res[tid]
             merged["assessment_frame_index"] = frame_idx
-            merged["assessment_status"] = "ASSESSED"
+            merged["assessment_status"] = merged.get(
+                "assessment_status", "ASSESSED"
+            )
             with gpt_data_lock:
                 gpt_data_by_track[tid] = merged
         logging.info("GSAM2 enrichment: GPT results stored for %d tracks", len(gpt_data_by_track))
utils/gpt_reasoning.py CHANGED
@@ -32,6 +32,34 @@ _DOMAIN_ROLES = {
     "GENERIC": "Tactical Surveillance Analyst",
 }
 
+_HUMAN_LABEL_HINTS = frozenset({
+    "person", "people", "human", "pedestrian",
+    "man", "woman", "boy", "girl", "child",
+    "civilian", "soldier", "infantry", "troop", "trooper",
+})
+
+
+def _is_human_label(label: str) -> bool:
+    label_l = (label or "").lower().strip()
+    if not label_l:
+        return False
+    parts = [p for p in re.split(r"[^a-z0-9]+", label_l) if p]
+    return any(part in _HUMAN_LABEL_HINTS for part in parts)
+
+
+def _build_status_fallback(
+    object_ids: List[str],
+    status: str,
+    reason: str,
+) -> Dict[str, Dict[str, Any]]:
+    return {
+        obj_id: {
+            "assessment_status": status,
+            "gpt_reason": reason,
+        }
+        for obj_id in object_ids
+    }
+
 _UNIVERSAL_SCHEMA = (
     "RESPONSE SCHEMA (JSON):\n"
     "{\n"
@@ -144,17 +172,37 @@ def estimate_threat_gpt(
         logger.error("OPENAI_API_KEY not set. Skipping GPT threat assessment.")
         return {}
 
-    # 1. Prepare detections summary for prompt
-    det_summary = []
+    # 1. Prepare detections summary for prompt.
+    # Human/person classes are explicitly skipped to avoid refusal paths.
+    prompt_items = []
+    skipped_human_ids: List[str] = []
     for i, det in enumerate(detections):
-        obj_id = det.get("track_id") or det.get("id") or f"T{str(i+1).zfill(2)}"
+        obj_id = str(det.get("track_id") or det.get("id") or f"T{str(i+1).zfill(2)}")
         bbox = det.get("bbox", [])
-        label = det.get("label", "object")
-        det_summary.append(f"- ID: {obj_id}, Classification Hint: {label}, BBox: {bbox}")
-
-    det_text = "\n".join(det_summary)
+        label = str(det.get("label", "object"))
+        if _is_human_label(label):
+            skipped_human_ids.append(obj_id)
+            continue
+        prompt_items.append({"obj_id": obj_id, "label": label, "bbox": bbox})
+
+    det_text = "\n".join(
+        [
+            f"- ID: {it['obj_id']}, Classification Hint: {it['label']}, BBox: {it['bbox']}"
+            for it in prompt_items
+        ]
+    )
 
     if not det_text:
+        if skipped_human_ids:
+            logger.warning(
+                "Skipping GPT threat assessment for %d human/person detections due to policy constraints.",
+                len(skipped_human_ids),
+            )
+            return _build_status_fallback(
+                skipped_human_ids,
+                "SKIPPED_POLICY",
+                "Human/person analysis skipped due to policy constraints.",
+            )
        return {}
 
     # 2. Encode image (prefer pre-encoded b64 to avoid disk I/O)
@@ -231,17 +279,68 @@ def estimate_threat_gpt(
         with urllib.request.urlopen(req, timeout=30) as response:
             resp_data = json.loads(response.read().decode('utf-8'))
 
-        content = resp_data['choices'][0]['message'].get('content')
+        choice_msg = resp_data.get("choices", [{}])[0].get("message", {})
+        content = choice_msg.get("content")
         if not content:
-            logger.warning("GPT returned empty content. Full response: %s", resp_data)
-            return {}
+            refusal = choice_msg.get("refusal")
+            if refusal:
+                logger.warning("GPT refused threat assessment: %s", refusal)
+            else:
+                logger.warning(
+                    "GPT returned empty content. response_id=%s finish_reason=%s",
+                    resp_data.get("id"),
+                    resp_data.get("choices", [{}])[0].get("finish_reason"),
+                )
+            fallback = _build_status_fallback(
+                [it["obj_id"] for it in prompt_items],
+                "REFUSED",
+                refusal or "GPT returned empty content.",
+            )
+            fallback.update(
+                _build_status_fallback(
+                    skipped_human_ids,
+                    "SKIPPED_POLICY",
+                    "Human/person analysis skipped due to policy constraints.",
+                )
+            )
+            return fallback
 
         result_json = json.loads(content)
 
         objects = result_json.get("objects", {})
+        if not isinstance(objects, dict):
+            logger.warning(
+                "GPT response 'objects' field is not a dict (got %s); using fallback.",
+                type(objects).__name__,
+            )
+            objects = {}
+
+        # Ensure every requested object receives an explicit assessment state.
+        for it in prompt_items:
+            oid = it["obj_id"]
+            if oid not in objects:
+                objects[oid] = {
+                    "assessment_status": "NO_RESPONSE",
+                    "gpt_reason": "No structured assessment returned for object.",
+                }
+        for oid in skipped_human_ids:
+            objects.setdefault(
+                oid,
+                {
+                    "assessment_status": "SKIPPED_POLICY",
+                    "gpt_reason": "Human/person analysis skipped due to policy constraints.",
+                },
+            )
 
         # Polyfill legacy fields for frontend compatibility
         for obj_id, data in objects.items():
+            if not isinstance(data, dict):
+                data = {
+                    "assessment_status": "NO_RESPONSE",
+                    "gpt_reason": "Malformed object payload from GPT.",
+                }
+                objects[obj_id] = data
+
             # 1. Distance: parse free-text range_estimate to meters
             range_m = _parse_range_to_meters(data.get("range_estimate", ""))
             if range_m is not None:
@@ -272,4 +371,16 @@ def estimate_threat_gpt(
 
     except Exception as e:
         logger.error("GPT API call failed: %s", e, exc_info=True)
-        return {}
+        fallback = _build_status_fallback(
+            [it["obj_id"] for it in prompt_items],
+            "ERROR",
+            f"GPT API call failed: {e.__class__.__name__}",
+        )
+        fallback.update(
+            _build_status_fallback(
+                skipped_human_ids,
+                "SKIPPED_POLICY",
+                "Human/person analysis skipped due to policy constraints.",
+            )
+        )
+        return fallback
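
A quick standalone check of the two helpers added above (copied here so the snippet runs on its own); the sample labels and IDs are arbitrary:

import re
from typing import Any, Dict, List

_HUMAN_LABEL_HINTS = frozenset({
    "person", "people", "human", "pedestrian",
    "man", "woman", "boy", "girl", "child",
    "civilian", "soldier", "infantry", "troop", "trooper",
})

def _is_human_label(label: str) -> bool:
    label_l = (label or "").lower().strip()
    if not label_l:
        return False
    parts = [p for p in re.split(r"[^a-z0-9]+", label_l) if p]
    return any(part in _HUMAN_LABEL_HINTS for part in parts)

def _build_status_fallback(object_ids: List[str], status: str, reason: str) -> Dict[str, Dict[str, Any]]:
    return {oid: {"assessment_status": status, "gpt_reason": reason} for oid in object_ids}

assert _is_human_label("armed person") is True        # matches on the "person" token
assert _is_human_label("personnel carrier") is False  # "personnel" is not a listed hint
assert _is_human_label("pickup truck") is False

print(_build_status_fallback(["T01", "T02"], "REFUSED", "GPT returned empty content."))
# {'T01': {'assessment_status': 'REFUSED', 'gpt_reason': 'GPT returned empty content.'},
#  'T02': {'assessment_status': 'REFUSED', 'gpt_reason': 'GPT returned empty content.'}}
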