Spaces:
Runtime error
Runtime error
Zhen Ye Claude Opus 4.6 (1M context) commited on
Commit Β·
29c2d5f
1
Parent(s): 880e261
feat: add GET /inspect/explain endpoint for multi-LLM interpretability
Browse filesCo-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- inspection/router.py +129 -1
inspection/router.py
CHANGED
|
@@ -4,9 +4,10 @@ All endpoints are on-demand β they do not affect the main inference pipeline.
|
|
| 4 |
Endpoints are mounted at /inspect in app.py.
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
import logging
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Optional
|
| 10 |
|
| 11 |
from fastapi import APIRouter, HTTPException, Query
|
| 12 |
from fastapi.responses import JSONResponse, Response
|
|
@@ -764,3 +765,130 @@ async def get_pointcloud(
|
|
| 764 |
)
|
| 765 |
|
| 766 |
return JSONResponse(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
Endpoints are mounted at /inspect in app.py.
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import asyncio
|
| 8 |
import logging
|
| 9 |
from pathlib import Path
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
|
| 12 |
from fastapi import APIRouter, HTTPException, Query
|
| 13 |
from fastapi.responses import JSONResponse, Response
|
|
|
|
| 765 |
)
|
| 766 |
|
| 767 |
return JSONResponse(result)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
# ββ Explainability (Multi-LLM) βββββββββββββββββββββββββββββββββββ
|
| 771 |
+
|
| 772 |
+
# Per-(job_id, track_id) locks to prevent duplicate concurrent LLM calls
|
| 773 |
+
_explain_locks: Dict[tuple, asyncio.Lock] = {}
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def _get_explain_lock(job_id: str, track_id: str) -> asyncio.Lock:
|
| 777 |
+
"""Get or create an asyncio lock for a (job_id, track_id) pair."""
|
| 778 |
+
key = (job_id, track_id)
|
| 779 |
+
if key not in _explain_locks:
|
| 780 |
+
_explain_locks[key] = asyncio.Lock()
|
| 781 |
+
return _explain_locks[key]
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
@router.get("/explain/{job_id}/{track_id}")
|
| 785 |
+
async def explain_track(job_id: str, track_id: str):
|
| 786 |
+
"""Generate a multi-LLM interpretability tree for a tracked object.
|
| 787 |
+
|
| 788 |
+
Calls GPT-4o (primary) to generate a hierarchical feature tree,
|
| 789 |
+
then Claude + Gemini (validators) in parallel to validate each feature.
|
| 790 |
+
Results are cached per (job_id, track_id).
|
| 791 |
+
"""
|
| 792 |
+
from jobs.storage import get_explanation, set_explanation
|
| 793 |
+
from models.isr.utils import crop_and_encode, encode_frame
|
| 794 |
+
from inspection.frames import extract_frame
|
| 795 |
+
|
| 796 |
+
job = _get_job_or_404(job_id)
|
| 797 |
+
|
| 798 |
+
# Check cache first
|
| 799 |
+
cached = get_explanation(job_id, track_id)
|
| 800 |
+
if cached:
|
| 801 |
+
return JSONResponse(cached)
|
| 802 |
+
|
| 803 |
+
# Acquire per-track lock to prevent duplicate LLM calls
|
| 804 |
+
lock = _get_explain_lock(job_id, track_id)
|
| 805 |
+
async with lock:
|
| 806 |
+
# Re-check cache after acquiring lock
|
| 807 |
+
cached = get_explanation(job_id, track_id)
|
| 808 |
+
if cached:
|
| 809 |
+
return JSONResponse(cached)
|
| 810 |
+
|
| 811 |
+
# Validate OpenAI key is available
|
| 812 |
+
import os
|
| 813 |
+
if not os.environ.get("OPENAI_API_KEY"):
|
| 814 |
+
raise HTTPException(status_code=503, detail="OpenAI API key not configured")
|
| 815 |
+
|
| 816 |
+
storage = get_job_storage()
|
| 817 |
+
|
| 818 |
+
# Parse track_id
|
| 819 |
+
instance_id = _parse_track_id(track_id)
|
| 820 |
+
|
| 821 |
+
# Find the best frame for this track (largest bbox area)
|
| 822 |
+
best_frame_idx = None
|
| 823 |
+
best_area = 0
|
| 824 |
+
best_track = None
|
| 825 |
+
|
| 826 |
+
with storage._lock:
|
| 827 |
+
frames = storage._tracks.get(job_id, {})
|
| 828 |
+
for fidx in sorted(frames.keys(), reverse=True):
|
| 829 |
+
for det in frames[fidx]:
|
| 830 |
+
tid = det.get("instance_id")
|
| 831 |
+
tid_str = det.get("track_id")
|
| 832 |
+
if (tid is not None and tid == instance_id) or tid_str == track_id:
|
| 833 |
+
bbox = det.get("bbox")
|
| 834 |
+
if bbox:
|
| 835 |
+
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
| 836 |
+
if area > best_area:
|
| 837 |
+
best_area = area
|
| 838 |
+
best_frame_idx = fidx
|
| 839 |
+
best_track = dict(det)
|
| 840 |
+
|
| 841 |
+
if best_frame_idx is None or best_track is None:
|
| 842 |
+
raise HTTPException(status_code=404, detail=f"Track {track_id} not found in any frame.")
|
| 843 |
+
|
| 844 |
+
# Extract frame
|
| 845 |
+
input_path = job.input_video_path
|
| 846 |
+
if not input_path or not Path(input_path).exists():
|
| 847 |
+
raise HTTPException(status_code=404, detail="Input video not found on disk.")
|
| 848 |
+
|
| 849 |
+
frame = await asyncio.to_thread(extract_frame, input_path, best_frame_idx)
|
| 850 |
+
|
| 851 |
+
# Encode images
|
| 852 |
+
crop_b64 = crop_and_encode(frame, best_track["bbox"], max_dim=512, quality=80)
|
| 853 |
+
if not crop_b64:
|
| 854 |
+
raise HTTPException(status_code=422, detail="Failed to crop track from frame.")
|
| 855 |
+
|
| 856 |
+
frame_b64 = encode_frame(frame, max_dim=1024, quality=70)
|
| 857 |
+
if not frame_b64:
|
| 858 |
+
raise HTTPException(status_code=422, detail="Failed to encode frame.")
|
| 859 |
+
|
| 860 |
+
# Get mission query (job.queries is List[str])
|
| 861 |
+
mission = ", ".join(job.queries) if job.queries else "general surveillance"
|
| 862 |
+
|
| 863 |
+
# Build metadata
|
| 864 |
+
metadata = {
|
| 865 |
+
"label": best_track.get("label", "unknown"),
|
| 866 |
+
"score": best_track.get("score", 0),
|
| 867 |
+
"speed_kph": best_track.get("speed_kph", 0),
|
| 868 |
+
"direction_clock": best_track.get("direction_clock", "unknown"),
|
| 869 |
+
"depth_rel": best_track.get("depth_rel"),
|
| 870 |
+
"depth_est_m": best_track.get("depth_est_m"),
|
| 871 |
+
"angle_deg": best_track.get("angle_deg"),
|
| 872 |
+
"bbox": best_track.get("bbox"),
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
# Run explainer
|
| 876 |
+
from models.isr.explainer import ISRExplainer
|
| 877 |
+
|
| 878 |
+
explainer = ISRExplainer()
|
| 879 |
+
try:
|
| 880 |
+
result = await asyncio.wait_for(
|
| 881 |
+
explainer.explain(crop_b64, frame_b64, metadata, mission),
|
| 882 |
+
timeout=30.0,
|
| 883 |
+
)
|
| 884 |
+
except asyncio.TimeoutError:
|
| 885 |
+
raise HTTPException(status_code=504, detail="Explanation timed out (30s)")
|
| 886 |
+
except ValueError as e:
|
| 887 |
+
raise HTTPException(status_code=502, detail=str(e))
|
| 888 |
+
|
| 889 |
+
# Add track_id to result
|
| 890 |
+
result["track_id"] = track_id
|
| 891 |
+
|
| 892 |
+
# Cache and return
|
| 893 |
+
set_explanation(job_id, track_id, result)
|
| 894 |
+
return JSONResponse(result)
|