Zhen Ye Claude Opus 4.6 (1M context) committed on
Commit
29c2d5f
·
1 Parent(s): 880e261

feat: add GET /inspect/explain endpoint for multi-LLM interpretability

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. inspection/router.py +129 -1
inspection/router.py CHANGED
@@ -4,9 +4,10 @@ All endpoints are on-demand — they do not affect the main inference pipeline.
4
  Endpoints are mounted at /inspect in app.py.
5
  """
6
 
 
7
  import logging
8
  from pathlib import Path
9
- from typing import Optional
10
 
11
  from fastapi import APIRouter, HTTPException, Query
12
  from fastapi.responses import JSONResponse, Response
@@ -764,3 +765,130 @@ async def get_pointcloud(
764
  )
765
 
766
  return JSONResponse(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  Endpoints are mounted at /inspect in app.py.
5
  """
6
 
7
+ import asyncio
8
  import logging
9
  from pathlib import Path
10
+ from typing import Dict, Optional
11
 
12
  from fastapi import APIRouter, HTTPException, Query
13
  from fastapi.responses import JSONResponse, Response
 
765
  )
766
 
767
  return JSONResponse(result)
768
+
769
+
770
+ # ── Explainability (Multi-LLM) ───────────────────────────────────
771
+
772
# Per-(job_id, track_id) locks to prevent duplicate concurrent LLM calls.
#
# The registry holds its locks weakly: an entry lives only while some caller
# still references the lock (e.g. while a request is waiting on or holding
# it). Once the last reference is dropped the entry disappears on its own, so
# the registry cannot grow without bound as new jobs/tracks are requested.
import weakref  # local to this section; only the lock registry needs it

_explain_locks: "weakref.WeakValueDictionary[tuple, asyncio.Lock]" = (
    weakref.WeakValueDictionary()
)


def _get_explain_lock(job_id: str, track_id: str) -> asyncio.Lock:
    """Get or create the asyncio lock for a (job_id, track_id) pair.

    Callers must keep the returned lock referenced for as long as they use
    it (binding it to a local variable is enough); the registry itself does
    not keep the lock alive.

    Args:
        job_id: Identifier of the job the track belongs to.
        track_id: Identifier of the track within that job.

    Returns:
        The lock shared by every caller that currently references this pair.
    """
    key = (job_id, track_id)
    lock = _explain_locks.get(key)
    if lock is None:
        # Bind to a local *before* storing: the registry's reference is weak,
        # so without a strong reference the new lock would vanish at once.
        lock = asyncio.Lock()
        _explain_locks[key] = lock
    return lock
782
+
783
+
784
# Upper bound, in seconds, on a single end-to-end explainer run.
_EXPLAIN_TIMEOUT_S = 30.0


def _find_largest_detection(frames: Dict, instance_id, track_id: str) -> tuple:
    """Pick the detection of a track that has the largest bounding-box area.

    Frames are visited from the highest index down and a candidate only
    replaces the current best when its area is strictly larger, so ties are
    resolved in favour of the later frame. Detections without a bbox, or with
    a non-positive area, are never selected.

    Args:
        frames: Mapping of frame index -> list of detection dicts.
        instance_id: Parsed track identifier compared to ``det["instance_id"]``.
        track_id: Raw track identifier compared to ``det["track_id"]``.

    Returns:
        ``(frame_idx, detection_copy)`` for the best match, or ``(None, None)``
        when no detection qualifies.
    """
    best_frame_idx = None
    best_track = None
    best_area = 0
    for fidx in sorted(frames, reverse=True):
        for det in frames[fidx]:
            tid = det.get("instance_id")
            matches = (tid is not None and tid == instance_id) or (
                det.get("track_id") == track_id
            )
            if not matches:
                continue
            bbox = det.get("bbox")
            if not bbox:
                continue
            # bbox is indexed as (x0, y0, x1, y1) by the area formula below.
            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            if area > best_area:
                best_area = area
                best_frame_idx = fidx
                # Copy so the caller can use it after releasing the storage lock.
                best_track = dict(det)
    return best_frame_idx, best_track


@router.get("/explain/{job_id}/{track_id}")
async def explain_track(job_id: str, track_id: str):
    """Generate a multi-LLM interpretability tree for a tracked object.

    Calls GPT-4o (primary) to generate a hierarchical feature tree,
    then Claude + Gemini (validators) in parallel to validate each feature.
    Results are cached per (job_id, track_id).

    The frame in which the track has its largest bounding box is extracted
    from the job's input video; a crop of the track and the full frame are
    encoded and handed to ``ISRExplainer.explain`` together with the track's
    metadata and the job's mission queries.

    Args:
        job_id: Identifier of the job that owns the track.
        track_id: Identifier of the track to explain.

    Returns:
        JSONResponse carrying the explanation, with ``track_id`` added to it.

    Raises:
        HTTPException: 404 if the track is not present in any frame or the
            input video is missing on disk; 503 if ``OPENAI_API_KEY`` is not
            set; 422 if cropping or encoding fails; 504 if the explainer
            exceeds the timeout; 502 if the explainer raises ``ValueError``.
            ``_get_job_or_404`` may raise as well (presumably a 404 for an
            unknown job -- defined outside this block).
    """
    import os

    from jobs.storage import get_explanation, set_explanation
    from models.isr.utils import crop_and_encode, encode_frame
    from inspection.frames import extract_frame

    job = _get_job_or_404(job_id)

    # Fast path: serve a previously computed explanation without locking.
    cached = get_explanation(job_id, track_id)
    if cached:
        return JSONResponse(cached)

    # Serialize concurrent requests for the same track so that only the first
    # one pays for the LLM calls. The local name keeps the lock referenced
    # for the whole critical section.
    lock = _get_explain_lock(job_id, track_id)
    async with lock:
        # A request that held the lock before us may have filled the cache.
        cached = get_explanation(job_id, track_id)
        if cached:
            return JSONResponse(cached)

        # Validate OpenAI key is available
        if not os.environ.get("OPENAI_API_KEY"):
            raise HTTPException(status_code=503, detail="OpenAI API key not configured")

        storage = get_job_storage()
        instance_id = _parse_track_id(track_id)

        # Find the best frame for this track (largest bbox area). The scan
        # runs under the storage lock, as track data is guarded by it.
        with storage._lock:
            frames = storage._tracks.get(job_id, {})
            best_frame_idx, best_track = _find_largest_detection(
                frames, instance_id, track_id
            )

        if best_frame_idx is None or best_track is None:
            raise HTTPException(status_code=404, detail=f"Track {track_id} not found in any frame.")

        input_path = job.input_video_path
        if not input_path or not Path(input_path).exists():
            raise HTTPException(status_code=404, detail="Input video not found on disk.")

        # Frame decoding and image encoding are blocking calls; run them in
        # worker threads so the event loop keeps serving other requests.
        frame = await asyncio.to_thread(extract_frame, input_path, best_frame_idx)

        crop_b64 = await asyncio.to_thread(
            crop_and_encode, frame, best_track["bbox"], max_dim=512, quality=80
        )
        if not crop_b64:
            raise HTTPException(status_code=422, detail="Failed to crop track from frame.")

        frame_b64 = await asyncio.to_thread(
            encode_frame, frame, max_dim=1024, quality=70
        )
        if not frame_b64:
            raise HTTPException(status_code=422, detail="Failed to encode frame.")

        # Get mission query (job.queries is List[str])
        mission = ", ".join(job.queries) if job.queries else "general surveillance"

        # Track attributes forwarded to the explainer alongside the images.
        metadata = {
            "label": best_track.get("label", "unknown"),
            "score": best_track.get("score", 0),
            "speed_kph": best_track.get("speed_kph", 0),
            "direction_clock": best_track.get("direction_clock", "unknown"),
            "depth_rel": best_track.get("depth_rel"),
            "depth_est_m": best_track.get("depth_est_m"),
            "angle_deg": best_track.get("angle_deg"),
            "bbox": best_track.get("bbox"),
        }

        # Imported only once every precondition has passed.
        from models.isr.explainer import ISRExplainer

        explainer = ISRExplainer()
        try:
            result = await asyncio.wait_for(
                explainer.explain(crop_b64, frame_b64, metadata, mission),
                timeout=_EXPLAIN_TIMEOUT_S,
            )
        except asyncio.TimeoutError as exc:
            raise HTTPException(
                status_code=504,
                detail=f"Explanation timed out ({_EXPLAIN_TIMEOUT_S:.0f}s)",
            ) from exc
        except ValueError as exc:
            raise HTTPException(status_code=502, detail=str(exc)) from exc

        # Add track_id to result
        result["track_id"] = track_id

        # Cache and return
        set_explanation(job_id, track_id, result)
        return JSONResponse(result)