Spaces:
Runtime error
Runtime error
"""Multi-LLM Explainability Pipeline.

Orchestrates GPT-4o (primary analyzer) + Claude & Gemini (validators)
to produce a hierarchical feature tree explaining why an object was
classified as mission-relevant.
"""
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| from typing import Optional | |
| from models.isr.utils import crop_and_encode, encode_frame, parse_llm_json | |
| logger = logging.getLogger(__name__) | |
# Category color map (synced with frontend)
# Keys are the category names the primary prompt allows; values are hex
# colors. Categories not listed here fall back to a default in the merge step.
_CATEGORY_COLORS = dict(
    Structure="#3b82f6",
    Function="#06b6d4",
    Material="#f59e0b",
    Color="#ef4444",
    Size="#10b981",
    Type="#8b5cf6",
    Motion="#ec4899",
    Context="#64748b",
    Shape="#f97316",
    Markings="#a855f7",
)
# System prompt for the primary analyzer (sent as the "system" message of the
# GPT-4o call). It asks for a JSON-only hierarchical feature tree; the keys it
# names ("categories", "features", "name", ...) are the ones the merge step reads.
_PRIMARY_SYSTEM_PROMPT = """You are an ISR (Intelligence, Surveillance, Reconnaissance) analyst explaining WHY a detected object matches or does not match a mission objective.
You will receive:
- A cropped image of the detected object
- The full frame showing spatial context
- Detection metadata (label, confidence, speed, depth, direction)
- The mission objective
Analyze the object and produce a HIERARCHICAL FEATURE TREE explaining the key visual and functional features that led to the classification.
Return ONLY a JSON object (no markdown, no explanation) with this exact structure:
{
"object": "<detected class label>",
"satisfies": true/false/null,
"confidence": 0.0-1.0,
"reasoning_summary": "<1-2 sentence summary>",
"categories": [
{
"name": "<category name>",
"features": [
{
"name": "<feature name>",
"value": true/false,
"reasoning": "<1 sentence explaining this observation>"
}
]
}
]
}
Rules:
- Pick 3-6 categories most relevant to THIS SPECIFIC object from: Structure, Function, Material, Color, Size, Type, Motion, Context, Shape, Markings
- Each category should have 1-4 features (total 5-20 features across all categories)
- Features must be VISUAL OBSERVATIONS from the image, not assumptions
- Be specific and expert-level (a program manager should find this insightful)
- confidence reflects how certain you are about the overall assessment"""
# System prompt shared by both validators (Claude and Gemini). It asks for a
# JSON-only verdict keyed by "CategoryName/FeatureName"; the merge step looks
# validations up by that key under "feature_validations".
_VALIDATOR_SYSTEM_PROMPT = """You are an ISR analyst reviewing another analyst's feature assessment of a detected object.
You will receive:
- The same cropped image and full frame
- Detection metadata
- The primary analyst's hierarchical feature tree
Your job: independently validate each feature by examining the images yourself.
Return ONLY a JSON object (no markdown) with this structure:
{
"agreement": true/false,
"confidence": 0.0-1.0,
"feature_validations": {
"CategoryName/FeatureName": {
"agree": true/false,
"note": "<brief observation>"
}
}
}
Rules:
- Validate EVERY feature in the tree
- Use the key format "CategoryName/FeatureName" exactly
- Be honest — disagree when the image doesn't support the claim
- Keep notes to 1 sentence"""
class ISRExplainer:
    """Orchestrates multi-LLM explanation pipeline for a single track.

    GPT-4o produces the primary feature tree, Claude and Gemini each validate
    it, and the results are merged into a single tree annotated with
    per-feature validator verdicts and an overall consensus summary.
    """

    def __init__(self):
        # SDK clients are created lazily on first use and cached here, so
        # constructing the explainer needs neither the SDKs nor the API keys.
        self._openai_client = None
        self._anthropic_client = None

    def _get_openai(self):
        """Return the cached OpenAI client, creating it on first use.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        if self._openai_client is None:
            key = os.environ.get("OPENAI_API_KEY")
            if not key:
                raise ValueError("OPENAI_API_KEY not set")
            import openai

            self._openai_client = openai.OpenAI(api_key=key)
        return self._openai_client

    def _get_anthropic(self):
        """Return the cached Anthropic client, or ``None`` when no key is set.

        The key is checked before the SDK is imported so that an unconfigured
        validator is skipped cleanly even when the package is not installed.
        """
        if self._anthropic_client is None:
            key = os.environ.get("ANTHROPIC_API_KEY")
            if not key:
                return None
            import anthropic

            self._anthropic_client = anthropic.Anthropic(api_key=key)
        return self._anthropic_client

    async def explain(
        self,
        crop_b64: str,
        frame_b64: str,
        metadata: dict,
        mission: str,
    ) -> dict:
        """Run the full 3-LLM explanation pipeline.

        Args:
            crop_b64: Base64-encoded JPEG of the cropped ROI.
            frame_b64: Base64-encoded JPEG of the full frame.
            metadata: Detection metadata dict (label, score, speed_kph, etc.).
            mission: Mission objective string.

        Returns:
            Merged explanation tree with consensus data.

        Raises:
            ValueError: If the primary GPT-4o analysis produced no tree.
        """
        # Step 1: GPT-4o primary analysis
        primary_tree = await self._call_gpt(crop_b64, frame_b64, metadata, mission)
        if primary_tree is None:
            raise ValueError("Primary GPT-4o analysis failed")

        # Step 2: Claude + Gemini validation in parallel
        claude_result, gemini_result = await asyncio.gather(
            self._call_claude(crop_b64, frame_b64, metadata, mission, primary_tree),
            self._call_gemini(crop_b64, frame_b64, metadata, mission, primary_tree),
            return_exceptions=True,
        )

        # Handle exceptions from validators: a failed validator is treated as
        # unavailable rather than failing the whole explanation.
        if isinstance(claude_result, Exception):
            logger.warning("Claude validation failed: %s", claude_result)
            claude_result = None
        if isinstance(gemini_result, Exception):
            logger.warning("Gemini validation failed: %s", gemini_result)
            gemini_result = None

        # Step 3: Merge into consensus tree
        return self._merge(primary_tree, claude_result, gemini_result)

    async def _call_gpt(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str) -> Optional[dict]:
        """Call GPT-4o to generate the primary feature tree.

        Returns the parsed tree, or ``None`` if anything fails (the failure is
        logged with its traceback).
        """
        try:
            client = self._get_openai()
            user_text = self._build_metadata_text(metadata, mission)
            # The SDK call is blocking, so it runs in a worker thread.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": _PRIMARY_SYSTEM_PROMPT},
                    {"role": "user", "content": [
                        {"type": "text", "text": user_text},
                        {"type": "text", "text": "\n[Cropped object]:"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{crop_b64}", "detail": "high"}},
                        {"type": "text", "text": "\n[Full frame context]:"},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}", "detail": "low"}},
                    ]},
                ],
                max_tokens=2048,
                temperature=0.3,
            )
            raw = response.choices[0].message.content
            return parse_llm_json(raw)
        except Exception:
            logger.exception("GPT-4o primary analysis failed")
            return None

    async def _call_claude(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]:
        """Call Claude to validate the primary tree.

        Returns the parsed validation, or ``None`` when the API key is not set
        or the request/parsing fails.
        """
        client = self._get_anthropic()
        if client is None:
            logger.info("Skipping Claude validation — ANTHROPIC_API_KEY not set")
            return None
        try:
            user_text = self._build_metadata_text(metadata, mission)
            user_text += f"\n\nPrimary analyst's feature tree:\n```json\n{json.dumps(tree, indent=2)}\n```"
            # The SDK call is blocking, so it runs in a worker thread.
            response = await asyncio.to_thread(
                client.messages.create,
                model="claude-sonnet-4-20250514",
                max_tokens=1024,
                system=_VALIDATOR_SYSTEM_PROMPT,
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_text},
                        {"type": "text", "text": "\n[Cropped object]:"},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": crop_b64}},
                        {"type": "text", "text": "\n[Full frame context]:"},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": frame_b64}},
                    ],
                }],
            )
            raw = response.content[0].text
            logger.info("Claude raw response: %s", raw[:200] if raw else "empty")
            return parse_llm_json(raw)
        except Exception:
            logger.exception("Claude validation failed")
            return None

    async def _call_gemini(self, crop_b64: str, frame_b64: str, metadata: dict, mission: str, tree: dict) -> Optional[dict]:
        """Call Gemini to validate the primary tree.

        Returns the parsed validation, or ``None`` when the API key is not set
        or the request/parsing fails.
        """
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            logger.info("Skipping Gemini validation — GEMINI_API_KEY not set")
            return None
        try:
            import base64

            from google import genai
            from google.genai import types

            client = genai.Client(api_key=api_key)
            user_text = self._build_metadata_text(metadata, mission)
            user_text += f"\n\nPrimary analyst's feature tree:\n```json\n{json.dumps(tree, indent=2)}\n```"

            # Decode images for Gemini
            crop_bytes = base64.b64decode(crop_b64)
            frame_bytes = base64.b64decode(frame_b64)

            # Part.from_text takes `text` as a keyword argument; passing it
            # positionally raises TypeError on current google-genai releases,
            # which the handler below would turn into a silent None.
            response = await asyncio.to_thread(
                client.models.generate_content,
                model="gemini-2.0-flash",
                contents=[
                    types.Content(role="user", parts=[
                        types.Part.from_text(text=_VALIDATOR_SYSTEM_PROMPT + "\n\n" + user_text),
                        # Same image labels as the GPT-4o and Claude requests.
                        types.Part.from_text(text="\n[Cropped object]:"),
                        types.Part.from_bytes(data=crop_bytes, mime_type="image/jpeg"),
                        types.Part.from_text(text="\n[Full frame context]:"),
                        types.Part.from_bytes(data=frame_bytes, mime_type="image/jpeg"),
                    ]),
                ],
                config=types.GenerateContentConfig(
                    max_output_tokens=1024,
                    temperature=0.3,
                ),
            )
            raw = response.text
            logger.info("Gemini raw response: %s", raw[:200] if raw else "empty")
            return parse_llm_json(raw)
        except Exception:
            logger.exception("Gemini validation failed")
            return None

    def _build_metadata_text(self, metadata: dict, mission: str) -> str:
        """Build the text portion describing the detection.

        Missing or ``None`` ``score`` / ``speed_kph`` values are rendered as 0.
        The bounding-box line is added only when ``bbox`` has at least four
        values, read as (x1, y1, x2, y2).
        """
        score = metadata.get("score")
        if score is None:
            score = 0
        speed = metadata.get("speed_kph")
        if speed is None:
            speed = 0

        lines = [
            f'Mission: "{mission}"',
            "",
            "Detection metadata:",
            f"- Label: {metadata.get('label', 'unknown')}",
            f"- Confidence: {score:.2f}",
            f"- Speed: {speed:.1f} kph",
            f"- Direction: {metadata.get('direction_clock', 'unknown')}",
            f"- Angle: {metadata.get('angle_deg', 'N/A')}°",
        ]
        bbox = metadata.get("bbox")
        if bbox and len(bbox) >= 4:
            bw = bbox[2] - bbox[0]
            bh = bbox[3] - bbox[1]
            lines.append(f"- Bounding box size: {bw}x{bh} px")
        return "\n".join(lines)

    def _merge(self, tree: dict, claude: Optional[dict], gemini: Optional[dict]) -> dict:
        """Merge primary tree with validator results into consensus output.

        ``tree`` is annotated in place and returned: each category gets a
        ``color``, each feature gets ``validators`` (verdicts found per
        validator) and ``consensus`` (number of validators that agreed), and
        the tree gets a ``consensus_bar`` summary. A feature counts as agreed
        only when every available validator agreed with it; with no validators
        available, no feature counts as agreed.
        """
        validator_results = {"claude": claude, "gemini": gemini}
        validators_available = sum(1 for v in validator_results.values() if v is not None)
        total_features = 0
        agreed = 0

        for cat in tree.get("categories") or []:
            cat_name = cat.get("name", "")
            cat["color"] = _CATEGORY_COLORS.get(cat_name, "#64748b")
            for feat in cat.get("features") or []:
                total_features += 1
                feat_name = feat.get("name", "")
                feat_key = f"{cat_name}/{feat_name}"
                validators = {}
                feat_agreed = 0

                for source, result in validator_results.items():
                    # Validator output is untrusted LLM JSON: skip anything
                    # that is not the expected dict shape instead of crashing.
                    if not isinstance(result, dict) or "feature_validations" not in result:
                        continue
                    verdict = self._find_validation(result["feature_validations"], feat_key, feat_name)
                    if verdict and isinstance(verdict, dict):
                        validators[source] = verdict
                        if verdict.get("agree"):
                            feat_agreed += 1

                feat["validators"] = validators
                feat["consensus"] = feat_agreed
                if validators_available > 0 and feat_agreed == validators_available:
                    agreed += 1

        tree["consensus_bar"] = {
            "total_features": total_features,
            "agreed": agreed,
            "disagreed": total_features - agreed,
            "validators_available": validators_available,
        }
        return tree

    @staticmethod
    def _find_validation(validations: dict, exact_key: str, feat_name: str) -> Optional[dict]:
        """Find validation by exact key first, then fuzzy match on feature name.

        Lookup order: exact ``"Category/Feature"`` key, then the same key
        case-insensitively, then any key whose last ``/``-separated segment
        equals ``feat_name`` case-insensitively. Returns ``None`` when nothing
        matches or when ``validations`` is not a dict.
        """
        if not isinstance(validations, dict):
            return None

        # Exact match
        val = validations.get(exact_key)
        if val:
            return val

        # Fuzzy: try case-insensitive exact key
        lower_key = exact_key.lower()
        for k, v in validations.items():
            if k.lower() == lower_key:
                return v

        # Fuzzy: match by feature name alone (validator may omit category)
        lower_name = feat_name.lower()
        for k, v in validations.items():
            if k.rsplit("/", 1)[-1].lower() == lower_name:
                return v
        return None