"""Multi-LLM Explainability Pipeline. Orchestrates GPT-4o (primary analyzer) + Claude & Gemini (validators) to produce a hierarchical feature tree explaining why an object was classified as mission-relevant. """ import asyncio import json import logging import os from typing import Optional from models.isr.utils import crop_and_encode, encode_frame, parse_llm_json logger = logging.getLogger(__name__) # Category color map (synced with frontend) _CATEGORY_COLORS = { "Structure": "#3b82f6", "Function": "#06b6d4", "Material": "#f59e0b", "Color": "#ef4444", "Size": "#10b981", "Type": "#8b5cf6", "Motion": "#ec4899", "Context": "#64748b", "Shape": "#f97316", "Markings": "#a855f7", } _PRIMARY_SYSTEM_PROMPT = """You are an ISR (Intelligence, Surveillance, Reconnaissance) analyst explaining WHY a detected object matches or does not match a mission objective. You will receive: - A cropped image of the detected object - The full frame showing spatial context - Detection metadata (label, confidence, speed, depth, direction) - The mission objective Analyze the object and produce a HIERARCHICAL FEATURE TREE explaining the key visual and functional features that led to the classification. Return ONLY a JSON object (no markdown, no explanation) with this exact structure: { "object": "", "satisfies": true/false/null, "confidence": 0.0-1.0, "reasoning_summary": "<1-2 sentence summary>", "categories": [ { "name": "", "features": [ { "name": "", "value": true/false, "reasoning": "<1 sentence explaining this observation>" } ] } ] } Rules: - Pick 3-6 categories most relevant to THIS SPECIFIC object from: Structure, Function, Material, Color, Size, Type, Motion, Context, Shape, Markings - Each category should have 1-4 features (total 5-20 features across all categories) - Features must be VISUAL OBSERVATIONS from the image, not assumptions - Be specific and expert-level (a program manager should find this insightful) - confidence reflects how certain you are about the overall assessment""" _VALIDATOR_SYSTEM_PROMPT = """You are an ISR analyst reviewing another analyst's feature assessment of a detected object. You will receive: - The same cropped image and full frame - Detection metadata - The primary analyst's hierarchical feature tree Your job: independently validate each feature by examining the images yourself. Return ONLY a JSON object (no markdown) with this structure: { "agreement": true/false, "confidence": 0.0-1.0, "feature_validations": { "CategoryName/FeatureName": { "agree": true/false, "note": "" } } } Rules: - Validate EVERY feature in the tree - Use the key format "CategoryName/FeatureName" exactly - Be honest — disagree when the image doesn't support the claim - Keep notes to 1 sentence""" class ISRExplainer: """Orchestrates multi-LLM explanation pipeline for a single track.""" def __init__(self): self._openai_client = None self._anthropic_client = None def _get_openai(self): if self._openai_client is None: import openai key = os.environ.get("OPENAI_API_KEY") if not key: raise ValueError("OPENAI_API_KEY not set") self._openai_client = openai.OpenAI(api_key=key) return self._openai_client def _get_anthropic(self): if self._anthropic_client is None: import anthropic key = os.environ.get("ANTHROPIC_API_KEY") if not key: return None self._anthropic_client = anthropic.Anthropic(api_key=key) return self._anthropic_client async def explain( self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, ) -> dict: """Run the full 3-LLM explanation pipeline. Args: crop_b64: Base64-encoded JPEG of the cropped ROI. frame_b64: Base64-encoded JPEG of the full frame. metadata: Detection metadata dict (label, score, speed_kph, etc.). mission: Mission objective string. Returns: Merged explanation tree with consensus data. """ # Step 1: GPT-4o primary analysis primary_tree = await self._call_gpt(crop_b64, frame_b64, metadata, mission) if primary_tree is None: raise ValueError("Primary GPT-4o analysis failed") # Step 2: Claude + Gemini validation in parallel claude_result, gemini_result = await asyncio.gather( self._call_claude(crop_b64, frame_b64, metadata, mission, primary_tree), self._call_gemini(crop_b64, frame_b64, metadata, mission, primary_tree), return_exceptions=True, ) # Handle exceptions from validators if isinstance(claude_result, Exception): logger.warning("Claude validation failed: %s", claude_result) claude_result = None if isinstance(gemini_result, Exception): logger.warning("Gemini validation failed: %s", gemini_result) gemini_result = None # Step 3: Merge into consensus tree return self._merge(primary_tree, claude_result, gemini_result) async def _call_gpt(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str) -> Optional[dict]: """Call GPT-4o to generate the primary feature tree.""" try: client = self._get_openai() user_text = self._build_metadata_text(metadata, mission) response = await asyncio.to_thread( client.chat.completions.create, model="gpt-4o", messages=[ {"role": "system", "content": _PRIMARY_SYSTEM_PROMPT}, {"role": "user", "content": [ {"type": "text", "text": user_text}, {"type": "text", "text": "\n[Cropped object]:"}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64}", "detail": "high"}}, {"type": "text", "text": "\n[Full frame context]:"}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}", "detail": "low"}}, ]}, ], max_tokens=2048, temperature=0.3, ) raw = response.choices[0].message.content return parse_llm_json(raw) except Exception: logger.exception("GPT-4o primary analysis failed") return None async def _call_claude(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]: """Call Claude to validate the primary tree.""" client = self._get_anthropic() if client is None: logger.info("Skipping Claude validation — ANTHROPIC_API_KEY not set") return None try: user_text = self._build_metadata_text(metadata, mission) user_text += f"\n\nPrimary analyst's feature tree:\n```json\n{json.dumps(tree, indent=2)}\n```" response = await asyncio.to_thread( client.messages.create, model="claude-sonnet-4-20250514", max_tokens=1024, system=_VALIDATOR_SYSTEM_PROMPT, messages=[{ "role": "user", "content": [ {"type": "text", "text": user_text}, {"type": "text", "text": "\n[Cropped object]:"}, {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": crop_b64}}, {"type": "text", "text": "\n[Full frame context]:"}, {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": frame_b64}}, ], }], ) raw = response.content[0].text logger.info("Claude raw response: %s", raw[:200] if raw else "empty") return parse_llm_json(raw) except Exception: logger.exception("Claude validation failed") return None async def _call_gemini(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]: """Call Gemini to validate the primary tree.""" api_key = os.environ.get("GEMINI_API_KEY") if not api_key: logger.info("Skipping Gemini validation — GEMINI_API_KEY not set") return None try: import base64 from google import genai from google.genai import types client = genai.Client(api_key=api_key) user_text = self._build_metadata_text(metadata, mission) user_text += f"\n\nPrimary analyst's feature tree:\n```json\n{json.dumps(tree, indent=2)}\n```" # Decode images for Gemini crop_bytes = base64.b64decode(crop_b64) frame_bytes = base64.b64decode(frame_b64) response = await asyncio.to_thread( client.models.generate_content, model="gemini-2.0-flash", contents=[ types.Content(role="user", parts=[ types.Part.from_text(_VALIDATOR_SYSTEM_PROMPT + "\n\n" + user_text), types.Part.from_bytes(data=crop_bytes, mime_type="image/jpeg"), types.Part.from_text("\n[Full frame context]:"), types.Part.from_bytes(data=frame_bytes, mime_type="image/jpeg"), ]), ], config=types.GenerateContentConfig( max_output_tokens=1024, temperature=0.3, ), ) raw = response.text logger.info("Gemini raw response: %s", raw[:200] if raw else "empty") return parse_llm_json(raw) except Exception: logger.exception("Gemini validation failed") return None def _build_metadata_text(self, metadata: dict, mission: str) -> str: """Build the text portion describing the detection.""" lines = [ f'Mission: "{mission}"', "", "Detection metadata:", f"- Label: {metadata.get('label', 'unknown')}", f"- Confidence: {metadata.get('score', 0):.2f}", f"- Speed: {metadata.get('speed_kph', 0):.1f} kph", f"- Direction: {metadata.get('direction_clock', 'unknown')}", f"- Angle: {metadata.get('angle_deg', 'N/A')}°", ] bbox = metadata.get("bbox") if bbox: bw = bbox[2] - bbox[0] bh = bbox[3] - bbox[1] lines.append(f"- Bounding box size: {bw}x{bh} px") return "\n".join(lines) def _merge(self, tree: dict, claude: Optional[dict], gemini: Optional[dict]) -> dict: """Merge primary tree with validator results into consensus output.""" validators_available = sum(1 for v in [claude, gemini] if v is not None) total_features = 0 agreed = 0 for cat in tree.get("categories", []): cat_name = cat.get("name", "") cat["color"] = _CATEGORY_COLORS.get(cat_name, "#64748b") for feat in cat.get("features", []): total_features += 1 feat_key = f"{cat_name}/{feat['name']}" validators = {} feat_agreed = 0 if claude and "feature_validations" in claude: cv = self._find_validation(claude["feature_validations"], feat_key, feat["name"]) if cv: validators["claude"] = cv if cv.get("agree"): feat_agreed += 1 if gemini and "feature_validations" in gemini: gv = self._find_validation(gemini["feature_validations"], feat_key, feat["name"]) if gv: validators["gemini"] = gv if gv.get("agree"): feat_agreed += 1 feat["validators"] = validators feat["consensus"] = feat_agreed if validators_available > 0 and feat_agreed == validators_available: agreed += 1 tree["consensus_bar"] = { "total_features": total_features, "agreed": agreed, "disagreed": total_features - agreed, "validators_available": validators_available, } return tree @staticmethod def _find_validation(validations: dict, exact_key: str, feat_name: str) -> Optional[dict]: """Find validation by exact key first, then fuzzy match on feature name.""" # Exact match val = validations.get(exact_key) if val: return val # Fuzzy: try case-insensitive exact key lower_key = exact_key.lower() for k, v in validations.items(): if k.lower() == lower_key: return v # Fuzzy: match by feature name alone (validator may omit category) lower_name = feat_name.lower() for k, v in validations.items(): parts = k.split("/") candidate = parts[-1].lower() if parts else k.lower() if candidate == lower_name: return v return None