ISR / models /isr /explainer.py
zye0616's picture
refactor: remove depth estimation pipeline entirely
1fad3ed
"""Multi-LLM Explainability Pipeline.
Orchestrates GPT-4o (primary analyzer) + Claude & Gemini (validators)
to produce a hierarchical feature tree explaining why an object was
classified as mission-relevant.
"""
import asyncio
import json
import logging
import os
from typing import Optional
from models.isr.utils import crop_and_encode, encode_frame, parse_llm_json
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Category color map (synced with frontend).
# Keys are the category names the primary prompt lets the model choose from;
# values are hex color strings attached to each category during merging.
_CATEGORY_COLORS = {
    "Structure": "#3b82f6",
    "Function": "#06b6d4",
    "Material": "#f59e0b",
    "Color": "#ef4444",
    "Size": "#10b981",
    "Type": "#8b5cf6",
    "Motion": "#ec4899",
    "Context": "#64748b",
    "Shape": "#f97316",
    "Markings": "#a855f7",
}
# System prompt for the primary analyzer (sent as the "system" message of the
# GPT-4o request). It asks for a JSON-only hierarchical feature tree.
# NOTE(review): the prompt lists "depth" among the detection metadata fields,
# but ISRExplainer._build_metadata_text emits no depth line — confirm whether
# the prompt wording is stale.
_PRIMARY_SYSTEM_PROMPT = """You are an ISR (Intelligence, Surveillance, Reconnaissance) analyst explaining WHY a detected object matches or does not match a mission objective.
You will receive:
- A cropped image of the detected object
- The full frame showing spatial context
- Detection metadata (label, confidence, speed, depth, direction)
- The mission objective
Analyze the object and produce a HIERARCHICAL FEATURE TREE explaining the key visual and functional features that led to the classification.
Return ONLY a JSON object (no markdown, no explanation) with this exact structure:
{
"object": "<detected class label>",
"satisfies": true/false/null,
"confidence": 0.0-1.0,
"reasoning_summary": "<1-2 sentence summary>",
"categories": [
{
"name": "<category name>",
"features": [
{
"name": "<feature name>",
"value": true/false,
"reasoning": "<1 sentence explaining this observation>"
}
]
}
]
}
Rules:
- Pick 3-6 categories most relevant to THIS SPECIFIC object from: Structure, Function, Material, Color, Size, Type, Motion, Context, Shape, Markings
- Each category should have 1-4 features (total 5-20 features across all categories)
- Features must be VISUAL OBSERVATIONS from the image, not assumptions
- Be specific and expert-level (a program manager should find this insightful)
- confidence reflects how certain you are about the overall assessment"""
# System prompt shared by both validators (Claude and Gemini). The
# "CategoryName/FeatureName" key format it requests is the exact key the merge
# step looks up first when attaching validations to features.
_VALIDATOR_SYSTEM_PROMPT = """You are an ISR analyst reviewing another analyst's feature assessment of a detected object.
You will receive:
- The same cropped image and full frame
- Detection metadata
- The primary analyst's hierarchical feature tree
Your job: independently validate each feature by examining the images yourself.
Return ONLY a JSON object (no markdown) with this structure:
{
"agreement": true/false,
"confidence": 0.0-1.0,
"feature_validations": {
"CategoryName/FeatureName": {
"agree": true/false,
"note": "<brief observation>"
}
}
}
Rules:
- Validate EVERY feature in the tree
- Use the key format "CategoryName/FeatureName" exactly
- Be honest — disagree when the image doesn't support the claim
- Keep notes to 1 sentence"""
class ISRExplainer:
    """Orchestrate the multi-LLM explanation pipeline for a single track.

    GPT-4o produces the primary hierarchical feature tree; Claude and Gemini
    then validate that tree in parallel, and the three results are merged
    into one tree annotated with per-feature consensus. The OpenAI and
    Anthropic clients are created lazily on first use and cached on the
    instance.
    """

    def __init__(self):
        self._openai_client = None
        self._anthropic_client = None

    def _get_openai(self):
        """Return the cached OpenAI client, creating it on first use.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        if self._openai_client is None:
            key = os.environ.get("OPENAI_API_KEY")
            if not key:
                raise ValueError("OPENAI_API_KEY not set")
            import openai

            self._openai_client = openai.OpenAI(api_key=key)
        return self._openai_client

    def _get_anthropic(self):
        """Return the cached Anthropic client, or ``None`` when no key is set.

        The key is checked before the SDK is imported so that a deployment
        without ``ANTHROPIC_API_KEY`` (and possibly without the ``anthropic``
        package) simply skips Claude validation instead of raising.
        """
        if self._anthropic_client is None:
            key = os.environ.get("ANTHROPIC_API_KEY")
            if not key:
                return None
            import anthropic

            self._anthropic_client = anthropic.Anthropic(api_key=key)
        return self._anthropic_client

    async def explain(
        self,
        crop_b64: str,
        frame_b64: str,
        metadata: dict,
        mission: str,
    ) -> dict:
        """Run the full 3-LLM explanation pipeline.

        Args:
            crop_b64: Base64-encoded JPEG of the cropped ROI.
            frame_b64: Base64-encoded JPEG of the full frame.
            metadata: Detection metadata dict (label, score, speed_kph, etc.).
            mission: Mission objective string.

        Returns:
            Merged explanation tree with consensus data.

        Raises:
            ValueError: If the primary GPT-4o analysis fails or does not
                yield a JSON object.
        """
        # Step 1: GPT-4o primary analysis
        primary_tree = await self._call_gpt(crop_b64, frame_b64, metadata, mission)
        # A non-dict payload (e.g. a JSON list) cannot be merged, so it is
        # treated the same way as a failed call.
        if not isinstance(primary_tree, dict):
            raise ValueError("Primary GPT-4o analysis failed")

        # Step 2: Claude + Gemini validation in parallel. Exceptions are
        # returned rather than raised so one validator cannot sink the other.
        claude_raw, gemini_raw = await asyncio.gather(
            self._call_claude(crop_b64, frame_b64, metadata, mission, primary_tree),
            self._call_gemini(crop_b64, frame_b64, metadata, mission, primary_tree),
            return_exceptions=True,
        )
        claude_result = self._usable_validator_result("Claude", claude_raw)
        gemini_result = self._usable_validator_result("Gemini", gemini_raw)

        # Step 3: Merge into consensus tree
        return self._merge(primary_tree, claude_result, gemini_result)

    @staticmethod
    def _usable_validator_result(label: str, result) -> Optional[dict]:
        """Reduce a raw ``gather`` result to a validator dict or ``None``."""
        # BaseException (not just Exception) so a CancelledError captured by
        # gather is also treated as "validator unavailable".
        if isinstance(result, BaseException):
            logger.warning("%s validation failed: %s", label, result)
            return None
        if result is not None and not isinstance(result, dict):
            logger.warning(
                "%s validation returned %s instead of a JSON object",
                label,
                type(result).__name__,
            )
            return None
        return result

    async def _call_gpt(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str) -> Optional[dict]:
        """Call GPT-4o to generate the primary feature tree.

        Returns the parsed response, or ``None`` if anything fails (the
        failure is logged with its traceback).
        """
        try:
            client = self._get_openai()
            user_text = self._build_metadata_text(metadata, mission)
            # The SDK call is blocking, so it runs in a worker thread.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": _PRIMARY_SYSTEM_PROMPT},
                    {"role": "user", "content": [
                        {"type": "text", "text": user_text},
                        {"type": "text", "text": "\n[Cropped object]:"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64}", "detail": "high"}},
                        {"type": "text", "text": "\n[Full frame context]:"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}", "detail": "low"}},
                    ]},
                ],
                max_tokens=2048,
                temperature=0.3,
            )
            raw = response.choices[0].message.content
            return parse_llm_json(raw)
        except Exception:
            logger.exception("GPT-4o primary analysis failed")
            return None

    async def _call_claude(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]:
        """Call Claude to validate the primary tree.

        Returns the parsed validation, or ``None`` when the API key is not
        set or the call fails.
        """
        client = self._get_anthropic()
        if client is None:
            logger.info("Skipping Claude validation — ANTHROPIC_API_KEY not set")
            return None
        try:
            user_text = self._build_validator_text(metadata, mission, tree)
            response = await asyncio.to_thread(
                client.messages.create,
                model="claude-sonnet-4-20250514",
                max_tokens=1024,
                system=_VALIDATOR_SYSTEM_PROMPT,
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_text},
                        {"type": "text", "text": "\n[Cropped object]:"},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": crop_b64}},
                        {"type": "text", "text": "\n[Full frame context]:"},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": frame_b64}},
                    ],
                }],
            )
            raw = response.content[0].text
            logger.info("Claude raw response: %s", raw[:200] if raw else "empty")
            return parse_llm_json(raw)
        except Exception:
            logger.exception("Claude validation failed")
            return None

    async def _call_gemini(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]:
        """Call Gemini to validate the primary tree.

        Returns the parsed validation, or ``None`` when the API key is not
        set or the call fails.
        """
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            logger.info("Skipping Gemini validation — GEMINI_API_KEY not set")
            return None
        try:
            import base64

            from google import genai
            from google.genai import types

            client = genai.Client(api_key=api_key)
            user_text = self._build_validator_text(metadata, mission, tree)

            # Decode images for Gemini
            crop_bytes = base64.b64decode(crop_b64)
            frame_bytes = base64.b64decode(frame_b64)

            # ``text`` is passed by keyword: that form is accepted whether or
            # not the installed SDK also allows it positionally. The image
            # labels mirror the ones sent to the other two models.
            response = await asyncio.to_thread(
                client.models.generate_content,
                model="gemini-2.0-flash",
                contents=[
                    types.Content(role="user", parts=[
                        types.Part.from_text(text=_VALIDATOR_SYSTEM_PROMPT + "\n\n" + user_text),
                        types.Part.from_text(text="\n[Cropped object]:"),
                        types.Part.from_bytes(data=crop_bytes, mime_type="image/jpeg"),
                        types.Part.from_text(text="\n[Full frame context]:"),
                        types.Part.from_bytes(data=frame_bytes, mime_type="image/jpeg"),
                    ]),
                ],
                config=types.GenerateContentConfig(
                    max_output_tokens=1024,
                    temperature=0.3,
                ),
            )
            raw = response.text
            logger.info("Gemini raw response: %s", raw[:200] if raw else "empty")
            return parse_llm_json(raw)
        except Exception:
            logger.exception("Gemini validation failed")
            return None

    def _build_validator_text(self, metadata: dict, mission: str, tree: dict) -> str:
        """Build the validator user text: metadata plus the primary tree as JSON."""
        user_text = self._build_metadata_text(metadata, mission)
        user_text += f"\n\nPrimary analyst's feature tree:\n```json\n{json.dumps(tree, indent=2)}\n```"
        return user_text

    @staticmethod
    def _as_float(value, default: float = 0.0) -> float:
        """Coerce *value* to float, falling back to *default* when it cannot be."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _build_metadata_text(self, metadata: dict, mission: str) -> str:
        """Build the text portion describing the detection.

        Missing, ``None`` or non-numeric ``score`` / ``speed_kph`` values are
        rendered as 0 rather than raising, and a ``bbox`` that cannot be
        indexed as ``[x1, y1, x2, y2]`` is omitted.
        """
        score = self._as_float(metadata.get("score"))
        speed_kph = self._as_float(metadata.get("speed_kph"))
        lines = [
            f'Mission: "{mission}"',
            "",
            "Detection metadata:",
            f"- Label: {metadata.get('label', 'unknown')}",
            f"- Confidence: {score:.2f}",
            f"- Speed: {speed_kph:.1f} kph",
            f"- Direction: {metadata.get('direction_clock', 'unknown')}",
            f"- Angle: {metadata.get('angle_deg', 'N/A')}°",
        ]
        bbox = metadata.get("bbox")
        if bbox is not None:
            try:
                bw = bbox[2] - bbox[0]
                bh = bbox[3] - bbox[1]
            except (TypeError, IndexError, KeyError):
                # Empty or malformed box: leave the size line out.
                pass
            else:
                lines.append(f"- Bounding box size: {bw}x{bh} px")
        return "\n".join(lines)

    def _merge(self, tree: dict, claude: Optional[dict], gemini: Optional[dict]) -> dict:
        """Merge primary tree with validator results into consensus output.

        The tree is annotated in place and returned: every category gets a
        ``color``, every feature gets ``validators`` (per-model validation
        entries that were found) and ``consensus`` (how many validators
        agreed), and the tree gets a ``consensus_bar`` summary. A feature
        counts as agreed only when every available validator agreed with it.
        Categories or features that are not JSON objects are skipped.
        """
        # A validator counts as available whenever it returned a result,
        # even if that result carries no usable feature validations.
        validators_available = sum(1 for v in (claude, gemini) if v is not None)
        validations_by_model = {
            "claude": claude.get("feature_validations") if isinstance(claude, dict) else None,
            "gemini": gemini.get("feature_validations") if isinstance(gemini, dict) else None,
        }

        total_features = 0
        agreed = 0
        for cat in tree.get("categories") or []:
            if not isinstance(cat, dict):
                continue
            cat_name = str(cat.get("name") or "")
            cat["color"] = _CATEGORY_COLORS.get(cat_name, "#64748b")
            for feat in cat.get("features") or []:
                if not isinstance(feat, dict):
                    continue
                total_features += 1
                # LLM output is not guaranteed to carry a name.
                feat_name = str(feat.get("name") or "")
                feat_key = f"{cat_name}/{feat_name}"

                validators = {}
                feat_agreed = 0
                for model, validations in validations_by_model.items():
                    found = self._find_validation(validations, feat_key, feat_name)
                    if found:
                        validators[model] = found
                        if found.get("agree"):
                            feat_agreed += 1

                feat["validators"] = validators
                feat["consensus"] = feat_agreed
                if validators_available > 0 and feat_agreed == validators_available:
                    agreed += 1

        tree["consensus_bar"] = {
            "total_features": total_features,
            "agreed": agreed,
            "disagreed": total_features - agreed,
            "validators_available": validators_available,
        }
        return tree

    @staticmethod
    def _find_validation(validations: dict, exact_key: str, feat_name: str) -> Optional[dict]:
        """Find validation by exact key first, then fuzzy match on feature name.

        Lookup order: exact key, case-insensitive key, then the part of each
        key after the last "/" compared case-insensitively with *feat_name*
        (the validator may omit or alter the category). Returns ``None`` when
        nothing matches or *validations* is not a dict; entries whose value
        is not a dict are ignored.
        """
        if not isinstance(validations, dict):
            return None

        # Exact match
        exact = validations.get(exact_key)
        if isinstance(exact, dict) and exact:
            return exact

        lower_key = exact_key.lower()
        lower_name = feat_name.lower()
        name_match = None
        for key, value in validations.items():
            if not isinstance(value, dict):
                continue
            lowered = str(key).lower()
            # A full-key match always wins over a name-only match.
            if lowered == lower_key:
                return value
            if name_match is None and lowered.rsplit("/", 1)[-1] == lower_name:
                name_match = value
        return name_match