Spaces:
Runtime error
Runtime error
Merge branch 'worktree-agent-afd6bcf7'
Browse files- inspection/router.py +129 -1
- jobs/storage.py +20 -0
- models/isr/explainer.py +331 -0
- requirements.txt +2 -0
inspection/router.py
CHANGED
|
@@ -4,9 +4,10 @@ All endpoints are on-demand — they do not affect the main inference pipeline.
|
|
| 4 |
Endpoints are mounted at /inspect in app.py.
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
import logging
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Optional
|
| 10 |
|
| 11 |
from fastapi import APIRouter, HTTPException, Query
|
| 12 |
from fastapi.responses import JSONResponse, Response
|
|
@@ -764,3 +765,130 @@ async def get_pointcloud(
|
|
| 764 |
)
|
| 765 |
|
| 766 |
return JSONResponse(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
Endpoints are mounted at /inspect in app.py.
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import asyncio
|
| 8 |
import logging
|
| 9 |
from pathlib import Path
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
|
| 12 |
from fastapi import APIRouter, HTTPException, Query
|
| 13 |
from fastapi.responses import JSONResponse, Response
|
|
|
|
| 765 |
)
|
| 766 |
|
| 767 |
return JSONResponse(result)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
# ── Explainability (Multi-LLM) ───────────────────────────────────
|
| 771 |
+
|
| 772 |
+
# Registry of asyncio locks, one per (job_id, track_id) pair, used so that
# concurrent requests for the same track do not trigger duplicate LLM calls.
_explain_locks: Dict[tuple, asyncio.Lock] = {}


def _get_explain_lock(job_id: str, track_id: str) -> asyncio.Lock:
    """Return the lock registered for (job_id, track_id), creating it on first use."""
    pair = (job_id, track_id)
    try:
        return _explain_locks[pair]
    except KeyError:
        # First request for this pair: register a fresh lock and hand it out.
        fresh = asyncio.Lock()
        _explain_locks[pair] = fresh
        return fresh
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
@router.get("/explain/{job_id}/{track_id}")
async def explain_track(job_id: str, track_id: str):
    """Generate a multi-LLM interpretability tree for a tracked object.

    The detection of the track with the largest bounding-box area is located,
    its frame is extracted from the job's input video, and the crop plus the
    full frame are handed to ``ISRExplainer`` (GPT-4o primary analysis,
    Claude + Gemini validation). Results are cached per (job_id, track_id);
    a per-track asyncio lock keeps concurrent requests from issuing duplicate
    LLM calls.

    Args:
        job_id: Identifier of the job that owns the track.
        track_id: Track identifier as used by the client; also parsed into an
            instance id via ``_parse_track_id``.

    Returns:
        JSONResponse carrying the explanation tree, with ``track_id`` added.

    Raises:
        HTTPException: 503 when ``OPENAI_API_KEY`` is not set; 404 when the
            track or the input video cannot be found; 422 when cropping or
            encoding fails; 504 when the explainer exceeds 30 seconds; 502
            when the explainer raises ``ValueError``. ``_get_job_or_404`` is
            presumably the source of a 404 for unknown jobs (defined outside
            this block).
    """
    import os

    from jobs.storage import get_explanation, set_explanation
    from models.isr.utils import crop_and_encode, encode_frame
    from inspection.frames import extract_frame

    job = _get_job_or_404(job_id)

    # Check cache first
    cached = get_explanation(job_id, track_id)
    if cached:
        return JSONResponse(cached)

    # Acquire per-track lock to prevent duplicate LLM calls
    lock_key = (job_id, track_id)
    lock = _get_explain_lock(job_id, track_id)
    async with lock:
        # Re-check cache after acquiring lock: another request may have
        # finished the work while this one was waiting.
        cached = get_explanation(job_id, track_id)
        if cached:
            _forget_explain_lock(lock_key, lock)
            return JSONResponse(cached)

        # Validate OpenAI key is available
        if not os.environ.get("OPENAI_API_KEY"):
            raise HTTPException(status_code=503, detail="OpenAI API key not configured")

        storage = get_job_storage()

        # Parse track_id
        instance_id = _parse_track_id(track_id)

        # Find the best frame for this track (largest bbox area)
        best_frame_idx, best_track = _find_best_detection(
            storage, job_id, track_id, instance_id
        )
        if best_frame_idx is None or best_track is None:
            raise HTTPException(status_code=404, detail=f"Track {track_id} not found in any frame.")

        # Extract frame
        input_path = job.input_video_path
        if not input_path or not Path(input_path).exists():
            raise HTTPException(status_code=404, detail="Input video not found on disk.")

        frame = await asyncio.to_thread(extract_frame, input_path, best_frame_idx)

        # Encode images
        crop_b64 = crop_and_encode(frame, best_track["bbox"], max_dim=512, quality=80)
        if not crop_b64:
            raise HTTPException(status_code=422, detail="Failed to crop track from frame.")

        frame_b64 = encode_frame(frame, max_dim=1024, quality=70)
        if not frame_b64:
            raise HTTPException(status_code=422, detail="Failed to encode frame.")

        # Get mission query (job.queries is List[str])
        mission = ", ".join(job.queries) if job.queries else "general surveillance"

        # Build metadata
        metadata = {
            "label": best_track.get("label", "unknown"),
            "score": best_track.get("score", 0),
            "speed_kph": best_track.get("speed_kph", 0),
            "direction_clock": best_track.get("direction_clock", "unknown"),
            "depth_rel": best_track.get("depth_rel"),
            "depth_est_m": best_track.get("depth_est_m"),
            "angle_deg": best_track.get("angle_deg"),
            "bbox": best_track.get("bbox"),
        }

        # Run explainer (imported lazily, only once all cheap checks passed)
        from models.isr.explainer import ISRExplainer

        explainer = ISRExplainer()
        try:
            result = await asyncio.wait_for(
                explainer.explain(crop_b64, frame_b64, metadata, mission),
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            raise HTTPException(status_code=504, detail="Explanation timed out (30s)")
        except ValueError as e:
            raise HTTPException(status_code=502, detail=str(e))

        # Add track_id to result
        result["track_id"] = track_id

        # Cache and return. Once the result is cached, later requests return
        # from the cache check above before ever asking for a lock, so the
        # lock entry can be dropped instead of living for the process lifetime.
        set_explanation(job_id, track_id, result)
        _forget_explain_lock(lock_key, lock)
        return JSONResponse(result)


def _forget_explain_lock(lock_key: tuple, lock: asyncio.Lock) -> None:
    """Remove ``lock`` from the registry if it is still the registered entry."""
    if _explain_locks.get(lock_key) is lock:
        del _explain_locks[lock_key]


def _find_best_detection(storage, job_id: str, track_id: str, instance_id):
    """Return (frame_idx, detection copy) of the track's largest-bbox detection.

    Scans the tracks stored for ``job_id`` from the highest frame index down
    and keeps the detection with the strictly largest bbox area, so on ties the
    higher frame index wins. Detections match when their ``instance_id`` equals
    ``instance_id`` or their ``track_id`` equals ``track_id``. Returns
    ``(None, None)`` when no matching detection has a positive-area bbox.
    """
    best_frame_idx = None
    best_area = 0
    best_track = None

    # The storage lock is held only for the in-memory scan.
    with storage._lock:
        frames = storage._tracks.get(job_id, {})
        for fidx in sorted(frames.keys(), reverse=True):
            for det in frames[fidx]:
                tid = det.get("instance_id")
                tid_str = det.get("track_id")
                if not ((tid is not None and tid == instance_id) or tid_str == track_id):
                    continue
                bbox = det.get("bbox")
                # Skip missing or truncated boxes instead of raising IndexError.
                if not bbox or len(bbox) < 4:
                    continue
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                if area > best_area:
                    best_area = area
                    best_frame_idx = fidx
                    best_track = dict(det)

    return best_frame_idx, best_track
|
jobs/storage.py
CHANGED
|
@@ -42,6 +42,7 @@ class JobStorage:
|
|
| 42 |
self._latest_frames: Dict[str, any] = {} # job_id -> np.ndarray
|
| 43 |
self._mask_data: Dict[str, Dict[str, any]] = {} # job_id -> {f"{frame_idx}:{track_id}" -> rle_dict}
|
| 44 |
self._mission_verdicts: Dict[str, Dict[str, bool]] = {} # job_id -> {track_id -> mission_relevant}
|
|
|
|
| 45 |
self._lock = RLock()
|
| 46 |
|
| 47 |
def create(self, job: JobInfo) -> None:
|
|
@@ -106,6 +107,18 @@ class JobStorage:
|
|
| 106 |
with self._lock:
|
| 107 |
return dict(self._mission_verdicts.get(job_id, {}))
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
def get_all_masks_for_frame(self, job_id: str, frame_idx: int) -> dict:
|
| 110 |
"""Return {track_id: rle_dict} for all objects in a frame."""
|
| 111 |
with self._lock:
|
|
@@ -136,6 +149,7 @@ class JobStorage:
|
|
| 136 |
self._latest_frames.pop(job_id, None)
|
| 137 |
self._mask_data.pop(job_id, None)
|
| 138 |
self._mission_verdicts.pop(job_id, None)
|
|
|
|
| 139 |
shutil.rmtree(get_job_directory(job_id), ignore_errors=True)
|
| 140 |
|
| 141 |
def cleanup_expired(self, max_age: timedelta) -> None:
|
|
@@ -181,3 +195,9 @@ def get_mask_data(job_id: str, frame_idx: int, track_id: int) -> Optional[dict]:
|
|
| 181 |
|
| 182 |
def get_all_masks_for_frame(job_id: str, frame_idx: int) -> dict:
|
| 183 |
return get_job_storage().get_all_masks_for_frame(job_id, frame_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
self._latest_frames: Dict[str, any] = {} # job_id -> np.ndarray
|
| 43 |
self._mask_data: Dict[str, Dict[str, any]] = {} # job_id -> {f"{frame_idx}:{track_id}" -> rle_dict}
|
| 44 |
self._mission_verdicts: Dict[str, Dict[str, bool]] = {} # job_id -> {track_id -> mission_relevant}
|
| 45 |
+
self._explanations: Dict[str, Dict[str, dict]] = {} # job_id -> {track_id -> explanation}
|
| 46 |
self._lock = RLock()
|
| 47 |
|
| 48 |
def create(self, job: JobInfo) -> None:
|
|
|
|
| 107 |
with self._lock:
|
| 108 |
return dict(self._mission_verdicts.get(job_id, {}))
|
| 109 |
|
| 110 |
+
def set_explanation(self, job_id: str, track_id: str, data: dict) -> None:
    """Store ``data`` as the cached explanation of ``track_id`` under ``job_id``."""
    with self._lock:
        # Create the per-job mapping on first use, then record the entry.
        per_job = self._explanations.setdefault(job_id, {})
        per_job[track_id] = data
|
| 116 |
+
|
| 117 |
+
def get_explanation(self, job_id: str, track_id: str) -> Optional[dict]:
    """Return the cached explanation for ``track_id`` of ``job_id``, or None if absent."""
    with self._lock:
        per_job = self._explanations.get(job_id)
        if per_job is None:
            return None
        return per_job.get(track_id)
|
| 121 |
+
|
| 122 |
def get_all_masks_for_frame(self, job_id: str, frame_idx: int) -> dict:
|
| 123 |
"""Return {track_id: rle_dict} for all objects in a frame."""
|
| 124 |
with self._lock:
|
|
|
|
| 149 |
self._latest_frames.pop(job_id, None)
|
| 150 |
self._mask_data.pop(job_id, None)
|
| 151 |
self._mission_verdicts.pop(job_id, None)
|
| 152 |
+
self._explanations.pop(job_id, None)
|
| 153 |
shutil.rmtree(get_job_directory(job_id), ignore_errors=True)
|
| 154 |
|
| 155 |
def cleanup_expired(self, max_age: timedelta) -> None:
|
|
|
|
| 195 |
|
| 196 |
def get_all_masks_for_frame(job_id: str, frame_idx: int) -> dict:
|
| 197 |
return get_job_storage().get_all_masks_for_frame(job_id, frame_idx)
|
| 198 |
+
|
| 199 |
+
def get_explanation(job_id: str, track_id: str) -> Optional[dict]:
    """Fetch the cached explanation for a track from the shared job storage."""
    storage = get_job_storage()
    return storage.get_explanation(job_id, track_id)
|
| 201 |
+
|
| 202 |
+
def set_explanation(job_id: str, track_id: str, data: dict) -> None:
    """Cache an explanation for a track in the shared job storage."""
    storage = get_job_storage()
    storage.set_explanation(job_id, track_id, data)
|
models/isr/explainer.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-LLM Explainability Pipeline.
|
| 2 |
+
|
| 3 |
+
Orchestrates GPT-4o (primary analyzer) + Claude & Gemini (validators)
|
| 4 |
+
to produce a hierarchical feature tree explaining why an object was
|
| 5 |
+
classified as mission-relevant.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
from models.isr.utils import crop_and_encode, encode_frame, parse_llm_json
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# Category color map (synced with frontend)
|
| 19 |
+
_CATEGORY_COLORS = {
|
| 20 |
+
"Structure": "#3b82f6",
|
| 21 |
+
"Function": "#06b6d4",
|
| 22 |
+
"Material": "#f59e0b",
|
| 23 |
+
"Color": "#ef4444",
|
| 24 |
+
"Size": "#10b981",
|
| 25 |
+
"Type": "#8b5cf6",
|
| 26 |
+
"Motion": "#ec4899",
|
| 27 |
+
"Context": "#64748b",
|
| 28 |
+
"Shape": "#f97316",
|
| 29 |
+
"Markings": "#a855f7",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
_PRIMARY_SYSTEM_PROMPT = """You are an ISR (Intelligence, Surveillance, Reconnaissance) analyst explaining WHY a detected object matches or does not match a mission objective.
|
| 33 |
+
|
| 34 |
+
You will receive:
|
| 35 |
+
- A cropped image of the detected object
|
| 36 |
+
- The full frame showing spatial context
|
| 37 |
+
- Detection metadata (label, confidence, speed, depth, direction)
|
| 38 |
+
- The mission objective
|
| 39 |
+
|
| 40 |
+
Analyze the object and produce a HIERARCHICAL FEATURE TREE explaining the key visual and functional features that led to the classification.
|
| 41 |
+
|
| 42 |
+
Return ONLY a JSON object (no markdown, no explanation) with this exact structure:
|
| 43 |
+
{
|
| 44 |
+
"object": "<detected class label>",
|
| 45 |
+
"satisfies": true/false/null,
|
| 46 |
+
"confidence": 0.0-1.0,
|
| 47 |
+
"reasoning_summary": "<1-2 sentence summary>",
|
| 48 |
+
"categories": [
|
| 49 |
+
{
|
| 50 |
+
"name": "<category name>",
|
| 51 |
+
"features": [
|
| 52 |
+
{
|
| 53 |
+
"name": "<feature name>",
|
| 54 |
+
"value": true/false,
|
| 55 |
+
"reasoning": "<1 sentence explaining this observation>"
|
| 56 |
+
}
|
| 57 |
+
]
|
| 58 |
+
}
|
| 59 |
+
]
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
Rules:
|
| 63 |
+
- Pick 3-6 categories most relevant to THIS SPECIFIC object from: Structure, Function, Material, Color, Size, Type, Motion, Context, Shape, Markings
|
| 64 |
+
- Each category should have 1-4 features (total 5-20 features across all categories)
|
| 65 |
+
- Features must be VISUAL OBSERVATIONS from the image, not assumptions
|
| 66 |
+
- Be specific and expert-level (a program manager should find this insightful)
|
| 67 |
+
- confidence reflects how certain you are about the overall assessment"""
|
| 68 |
+
|
| 69 |
+
_VALIDATOR_SYSTEM_PROMPT = """You are an ISR analyst reviewing another analyst's feature assessment of a detected object.
|
| 70 |
+
|
| 71 |
+
You will receive:
|
| 72 |
+
- The same cropped image and full frame
|
| 73 |
+
- Detection metadata
|
| 74 |
+
- The primary analyst's hierarchical feature tree
|
| 75 |
+
|
| 76 |
+
Your job: independently validate each feature by examining the images yourself.
|
| 77 |
+
|
| 78 |
+
Return ONLY a JSON object (no markdown) with this structure:
|
| 79 |
+
{
|
| 80 |
+
"agreement": true/false,
|
| 81 |
+
"confidence": 0.0-1.0,
|
| 82 |
+
"feature_validations": {
|
| 83 |
+
"CategoryName/FeatureName": {
|
| 84 |
+
"agree": true/false,
|
| 85 |
+
"note": "<brief observation>"
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
Rules:
|
| 91 |
+
- Validate EVERY feature in the tree
|
| 92 |
+
- Use the key format "CategoryName/FeatureName" exactly
|
| 93 |
+
- Be honest — disagree when the image doesn't support the claim
|
| 94 |
+
- Keep notes to 1 sentence"""
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ISRExplainer:
    """Orchestrates multi-LLM explanation pipeline for a single track.

    GPT-4o produces the primary hierarchical feature tree; Claude and Gemini
    then validate each feature in parallel, and the results are merged into a
    single tree annotated with per-feature validator verdicts and an overall
    consensus summary. Validators are optional: a missing API key or a failed
    call only removes that validator from the consensus.
    """

    def __init__(self):
        # API clients are created lazily on first use.
        self._openai_client = None
        self._anthropic_client = None

    def _get_openai(self):
        """Return the cached OpenAI client, creating it on first use.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        if self._openai_client is None:
            import openai
            key = os.environ.get("OPENAI_API_KEY")
            if not key:
                raise ValueError("OPENAI_API_KEY not set")
            self._openai_client = openai.OpenAI(api_key=key)
        return self._openai_client

    def _get_anthropic(self):
        """Return the cached Anthropic client, or None if ``ANTHROPIC_API_KEY`` is unset."""
        if self._anthropic_client is None:
            import anthropic
            key = os.environ.get("ANTHROPIC_API_KEY")
            if not key:
                return None
            self._anthropic_client = anthropic.Anthropic(api_key=key)
        return self._anthropic_client

    async def explain(
        self,
        crop_b64: str,
        frame_b64: str,
        metadata: dict,
        mission: str,
    ) -> dict:
        """Run the full 3-LLM explanation pipeline.

        Args:
            crop_b64: Base64-encoded JPEG of the cropped ROI.
            frame_b64: Base64-encoded JPEG of the full frame.
            metadata: Detection metadata dict (label, score, speed_kph, etc.).
            mission: Mission objective string.

        Returns:
            Merged explanation tree with consensus data.

        Raises:
            ValueError: If the primary GPT-4o analysis fails or does not yield
                a JSON object.
        """
        # Step 1: GPT-4o primary analysis
        primary_tree = await self._call_gpt(crop_b64, frame_b64, metadata, mission)
        # A non-dict result (None on failure, or e.g. a JSON list) cannot be
        # merged, so it is reported as a failed primary analysis.
        if not isinstance(primary_tree, dict):
            raise ValueError("Primary GPT-4o analysis failed")

        # Step 2: Claude + Gemini validation in parallel
        claude_result, gemini_result = await asyncio.gather(
            self._call_claude(crop_b64, frame_b64, metadata, mission, primary_tree),
            self._call_gemini(crop_b64, frame_b64, metadata, mission, primary_tree),
            return_exceptions=True,
        )

        # Handle exceptions from validators
        if isinstance(claude_result, Exception):
            logger.warning("Claude validation failed: %s", claude_result)
            claude_result = None
        if isinstance(gemini_result, Exception):
            logger.warning("Gemini validation failed: %s", gemini_result)
            gemini_result = None

        # Step 3: Merge into consensus tree
        return self._merge(primary_tree, claude_result, gemini_result)

    async def _call_gpt(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str) -> Optional[dict]:
        """Call GPT-4o to generate the primary feature tree.

        Returns the parsed response, or None if anything in the call fails
        (the failure is logged).
        """
        try:
            client = self._get_openai()
            user_text = self._build_metadata_text(metadata, mission)

            # The client call is blocking, so it runs in a worker thread.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": _PRIMARY_SYSTEM_PROMPT},
                    {"role": "user", "content": [
                        {"type": "text", "text": user_text},
                        {"type": "text", "text": "\n[Cropped object]:"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64}", "detail": "high"}},
                        {"type": "text", "text": "\n[Full frame context]:"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}", "detail": "low"}},
                    ]},
                ],
                max_tokens=2048,
                temperature=0.3,
            )

            raw = response.choices[0].message.content
            return parse_llm_json(raw)
        except Exception:
            logger.exception("GPT-4o primary analysis failed")
            return None

    async def _call_claude(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]:
        """Call Claude to validate the primary tree.

        Returns the parsed validation, or None when the API key is missing or
        the call fails (the failure is logged).
        """
        client = self._get_anthropic()
        if client is None:
            logger.info("Skipping Claude validation — ANTHROPIC_API_KEY not set")
            return None

        try:
            user_text = self._build_validator_text(metadata, mission, tree)

            response = await asyncio.to_thread(
                client.messages.create,
                model="claude-sonnet-4-20250514",
                max_tokens=1024,
                system=_VALIDATOR_SYSTEM_PROMPT,
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_text},
                        {"type": "text", "text": "\n[Cropped object]:"},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": crop_b64}},
                        {"type": "text", "text": "\n[Full frame context]:"},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": frame_b64}},
                    ],
                }],
            )

            raw = response.content[0].text
            return parse_llm_json(raw)
        except Exception:
            logger.exception("Claude validation failed")
            return None

    async def _call_gemini(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]:
        """Call Gemini to validate the primary tree.

        Returns the parsed validation, or None when the API key is missing or
        the call fails (the failure is logged).
        """
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            logger.info("Skipping Gemini validation — GEMINI_API_KEY not set")
            return None

        try:
            import base64
            import google.generativeai as genai

            genai.configure(api_key=api_key)
            model = genai.GenerativeModel("gemini-2.0-flash")

            user_text = self._build_validator_text(metadata, mission, tree)

            # Decode images for Gemini
            crop_bytes = base64.b64decode(crop_b64)
            frame_bytes = base64.b64decode(frame_b64)

            response = await asyncio.to_thread(
                model.generate_content,
                [
                    _VALIDATOR_SYSTEM_PROMPT + "\n\n" + user_text,
                    # Label the crop the same way the GPT and Claude calls do.
                    "\n[Cropped object]:",
                    {"mime_type": "image/jpeg", "data": crop_bytes},
                    "\n[Full frame context]:",
                    {"mime_type": "image/jpeg", "data": frame_bytes},
                ],
                generation_config=genai.GenerationConfig(
                    max_output_tokens=1024,
                    temperature=0.3,
                ),
            )

            raw = response.text
            return parse_llm_json(raw)
        except Exception:
            logger.exception("Gemini validation failed")
            return None

    def _build_validator_text(self, metadata: dict, mission: str, tree: dict) -> str:
        """Build the validator prompt text: detection description plus the primary tree."""
        user_text = self._build_metadata_text(metadata, mission)
        user_text += f"\n\nPrimary analyst's feature tree:\n```json\n{json.dumps(tree, indent=2)}\n```"
        return user_text

    def _build_metadata_text(self, metadata: dict, mission: str) -> str:
        """Build the text portion describing the detection.

        Missing keys and keys whose value is None are treated alike: numeric
        fields fall back to 0 and textual fields to "unknown" / "N/A", so a
        detection with unset values never breaks the numeric formatting.
        """

        def _number(key: str) -> float:
            # None or non-numeric values would raise on the ':.2f' format spec.
            try:
                return float(metadata.get(key))
            except (TypeError, ValueError):
                return 0.0

        def _text(key: str, fallback: str):
            value = metadata.get(key)
            return fallback if value is None else value

        lines = [
            f'Mission: "{mission}"',
            "",
            "Detection metadata:",
            f"- Label: {_text('label', 'unknown')}",
            f"- Confidence: {_number('score'):.2f}",
            f"- Speed: {_number('speed_kph'):.1f} kph",
            f"- Direction: {_text('direction_clock', 'unknown')}",
            f"- Depth (relative): {_text('depth_rel', 'N/A')}",
            f"- Depth (estimated): {_text('depth_est_m', 'N/A')}m",
            f"- Angle: {_text('angle_deg', 'N/A')}°",
        ]
        bbox = metadata.get("bbox")
        if bbox and len(bbox) >= 4:
            bw = bbox[2] - bbox[0]
            bh = bbox[3] - bbox[1]
            lines.append(f"- Bounding box size: {bw}x{bh} px")
        return "\n".join(lines)

    def _merge(self, tree: dict, claude: Optional[dict], gemini: Optional[dict]) -> dict:
        """Merge primary tree with validator results into consensus output.

        Mutates and returns ``tree``: each category gets a ``color``, each
        feature gets ``validators`` (verdicts keyed by validator name) and
        ``consensus`` (number of validators that agreed), and the tree gets a
        ``consensus_bar`` summary. A feature counts as agreed only when every
        available validator agreed with it. Malformed LLM output (non-dict
        entries, missing feature names) is skipped or defaulted rather than
        raising.
        """
        # Per-validator "Category/Feature" -> verdict maps; a validator is
        # available when it returned a JSON object at all.
        validations = {}
        for validator_name, result in (("claude", claude), ("gemini", gemini)):
            if isinstance(result, dict):
                verdict_map = result.get("feature_validations")
                validations[validator_name] = verdict_map if isinstance(verdict_map, dict) else {}

        validators_available = len(validations)
        total_features = 0
        agreed = 0

        for cat in tree.get("categories") or []:
            if not isinstance(cat, dict):
                continue
            cat_name = cat.get("name", "")
            cat["color"] = _CATEGORY_COLORS.get(cat_name, "#64748b")

            for feat in cat.get("features") or []:
                if not isinstance(feat, dict):
                    continue
                total_features += 1
                feat_key = f"{cat_name}/{feat.get('name', '')}"
                validators = {}
                feat_agreed = 0

                for validator_name, verdict_map in validations.items():
                    verdict = verdict_map.get(feat_key)
                    # Only non-empty dict verdicts are recorded.
                    if isinstance(verdict, dict) and verdict:
                        validators[validator_name] = verdict
                        if verdict.get("agree"):
                            feat_agreed += 1

                feat["validators"] = validators
                feat["consensus"] = feat_agreed

                if validators_available > 0 and feat_agreed == validators_available:
                    agreed += 1

        tree["consensus_bar"] = {
            "total_features": total_features,
            "agreed": agreed,
            "disagreed": total_features - agreed,
            "validators_available": validators_available,
        }

        return tree
|
requirements.txt
CHANGED
|
@@ -16,3 +16,5 @@ iopath>=0.1.10
|
|
| 16 |
psutil
|
| 17 |
dill
|
| 18 |
openai>=1.0.0
|
|
|
|
|
|
|
|
|
| 16 |
psutil
|
| 17 |
dill
|
| 18 |
openai>=1.0.0
|
| 19 |
+
anthropic>=0.40.0
|
| 20 |
+
google-generativeai>=0.8.0
|