Spaces:
Paused
Paused
| import re | |
| import json | |
| import base64 | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| from utils.schemas import AssessmentStatus | |
| from utils.openai_client import chat_completion, extract_content, get_api_key | |
| logger = logging.getLogger(__name__) | |
def encode_image(image_path: str) -> str:
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is opened in binary mode and read in full; the base64 bytes
    are decoded to ``str`` as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def encode_frame_to_b64(frame, quality=None) -> str:
    """Encode an OpenCV BGR frame to a base64 JPEG string in memory (no disk I/O).

    Args:
        frame: OpenCV BGR numpy array.
        quality: Optional JPEG quality (1-100). Uses OpenCV default if None.

    Returns:
        The JPEG bytes as a base64 string.

    Raises:
        ValueError: If ``cv2.imencode`` reports that encoding failed.
    """
    # Imported lazily so importing this module does not require OpenCV.
    import cv2

    if quality is None:
        ok, jpeg_buf = cv2.imencode('.jpg', frame)
    else:
        jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        ok, jpeg_buf = cv2.imencode('.jpg', frame, jpeg_params)

    if not ok:
        raise ValueError("Failed to encode frame to JPEG")

    jpeg_bytes = jpeg_buf.tobytes()
    return base64.b64encode(jpeg_bytes).decode('utf-8')
| _DOMAIN_ROLES = { | |
| "NAVAL": "Naval Intelligence Officer and Maritime Threat Analyst", | |
| "GROUND": "Ground Surveillance Intelligence Officer", | |
| "AERIAL": "Air Surveillance Intelligence Officer", | |
| "URBAN": "Urban Surveillance Intelligence Officer", | |
| "GENERIC": "Tactical Surveillance Analyst", | |
| } | |
| _HUMAN_LABEL_HINTS = frozenset({ | |
| "person", "people", "human", "pedestrian", | |
| "man", "woman", "boy", "girl", "child", | |
| "civilian", "soldier", "infantry", "troop", "trooper", | |
| }) | |
| def _is_human_label(label: str) -> bool: | |
| label_l = (label or "").lower().strip() | |
| if not label_l: | |
| return False | |
| parts = [p for p in re.split(r"[^a-z0-9]+", label_l) if p] | |
| return any(part in _HUMAN_LABEL_HINTS for part in parts) | |
| def _build_status_fallback( | |
| object_ids: List[str], | |
| status: str, | |
| reason: str, | |
| ) -> Dict[str, Dict[str, Any]]: | |
| return { | |
| obj_id: { | |
| "assessment_status": status, | |
| "gpt_reason": reason, | |
| } | |
| for obj_id in object_ids | |
| } | |
| _UNIVERSAL_SCHEMA = ( | |
| "RESPONSE SCHEMA (JSON):\n" | |
| "{\n" | |
| " \"objects\": {\n" | |
| " \"T01\": {\n" | |
| " \"object_type\": \"string (broad category, e.g. Warship, APC, Sedan, Person)\",\n" | |
| " \"size\": \"string (e.g. Large, Medium, Small, ~50m length)\",\n" | |
| " \"visible_weapons\": [\"string\"],\n" | |
| " \"weapon_readiness\": \"string (e.g. Stowed/PEACE, Trained/Aiming, Firing/HOSTILE, Unknown)\",\n" | |
| " \"motion_status\": \"string (e.g. Stationary, Moving Slow, Moving Fast, Hovering)\",\n" | |
| " \"range_estimate\": \"string (e.g. ~500m, ~2NM, ~1km)\",\n" | |
| " \"bearing\": \"string (e.g. 12 o'clock, NNE, 045°)\",\n" | |
| " \"threat_level\": int (1-10, 1=Benign, 10=Imminent Attack),\n" | |
| " \"threat_classification\": \"Friendly\" | \"Neutral\" | \"Suspect\" | \"Hostile\",\n" | |
| " \"tactical_intent\": \"string (e.g. Transit, Patrol, Attack Profile)\",\n" | |
| " \"dynamic_features\": [\n" | |
| " {\"key\": \"string (domain-specific observation name)\", \"value\": \"string\"}\n" | |
| " ] // up to 5 extra observations relevant to the domain\n" | |
| " }\n" | |
| " }\n" | |
| "}\n" | |
| ) | |
| def _parse_range_to_meters(range_text: str) -> Optional[float]: | |
| """Convert a free-text range string to meters. | |
| Supports patterns like '~500m', '~2NM', '~1.5km', '500 meters', '2 nautical miles'. | |
| Returns None if the string cannot be parsed. | |
| """ | |
| if not range_text or range_text == "Unknown": | |
| return None | |
| text = range_text.strip().lstrip("~").strip() | |
| # Try NM / nautical miles | |
| m = re.match(r"([0-9]*\.?[0-9]+)\s*(NM|nm|nautical\s*miles?)", text) | |
| if m: | |
| return float(m.group(1)) * 1852.0 | |
| # Try km / kilometers | |
| m = re.match(r"([0-9]*\.?[0-9]+)\s*(km|kilometers?|kilometres?)", text, re.IGNORECASE) | |
| if m: | |
| return float(m.group(1)) * 1000.0 | |
| # Try meters (default) | |
| m = re.match(r"([0-9]*\.?[0-9]+)\s*(m|meters?|metres?)?$", text, re.IGNORECASE) | |
| if m: | |
| return float(m.group(1)) | |
| return None | |
def _build_domain_system_prompt(domain: str, mission_spec=None) -> str:
    """Build a universal system prompt with domain-appropriate role.

    Args:
        domain: Key into ``_DOMAIN_ROLES``; unknown keys use the "GENERIC" role.
        mission_spec: Optional object whose ``mission_intent``, ``domain``,
            ``object_classes``, ``context_phrases`` and ``stripped_modifiers``
            attributes are rendered into a trailing MISSION CONTEXT section.

    Returns:
        Role line, response schema and rules, followed by the mission
        context section when ``mission_spec`` is truthy.
    """
    analyst_role = _DOMAIN_ROLES.get(domain, _DOMAIN_ROLES["GENERIC"])

    # Mission context is appended regardless of which domain role was chosen.
    context_parts = []
    if mission_spec:
        context_parts.append("\n\nMISSION CONTEXT:\n")
        context_parts.append(f"- Operator Intent: {mission_spec.mission_intent}\n")
        context_parts.append(f"- Domain: {mission_spec.domain}\n")
        context_parts.append(f"- Target Classes: {', '.join(mission_spec.object_classes)}\n")
        if mission_spec.context_phrases:
            context_parts.append(
                f"- Situational Context: {'; '.join(mission_spec.context_phrases)}\n"
            )
        if mission_spec.stripped_modifiers:
            context_parts.append(
                f"- Operator Modifiers (stripped): {', '.join(mission_spec.stripped_modifiers)}\n"
            )
        context_parts.append("\nUse the mission context to inform your analysis. ")
        context_parts.append("Focus assessment on the target classes and domain specified.")
    mission_context = "".join(context_parts)

    prompt_parts = [
        f"You are an elite {analyst_role}. ",
        "Your task is to analyze optical surveillance imagery and provide a detailed tactical assessment for every detected object. ",
        "You must output a STRICT JSON object that matches the following schema for every object ID provided:\n\n",
        f"{_UNIVERSAL_SCHEMA}\n",
        "RULES:\n",
        "- Use dynamic_features for domain-specific observations (e.g., wake_description, deck_activity, sensor_profile, camouflage, license_plate).\n",
        "- Provide up to 5 dynamic_features per object. Choose the most tactically relevant observations.\n",
        "- range_estimate should be a human-readable string with units (e.g., '~500m', '~2NM').\n",
        "- Visible trained weapons are IMMINENT threat (Score 9-10).\n",
        "- Ignore artifacts, focus on the objects.",
        mission_context,
    ]
    return "".join(prompt_parts)
# Reason recorded for detections that are never sent to GPT.
_POLICY_SKIP_REASON = "Human/person analysis skipped due policy constraints."


def _failure_fallback(prompt_items, skipped_human_ids, status, reason):
    """Return status-only entries for every detection when GPT gave no usable result.

    Prompted objects receive ``status``/``reason``; human detections that were
    left out of the prompt are always reported as SKIPPED_POLICY.
    """
    fallback = _build_status_fallback(
        [it["obj_id"] for it in prompt_items],
        status,
        reason,
    )
    fallback.update(
        _build_status_fallback(
            skipped_human_ids,
            AssessmentStatus.SKIPPED_POLICY,
            _POLICY_SKIP_REASON,
        )
    )
    return fallback


def _apply_legacy_fields(data):
    """Add the legacy alias keys the frontend still reads to ``data``, in place.

    Model output is not type-checked upstream, so values are coerced to text
    before string operations; one malformed field must not discard the whole
    assessment.
    """
    # 1. Distance: parse free-text range_estimate to meters.
    raw_range = data.get("range_estimate")
    range_m = _parse_range_to_meters("" if raw_range is None else str(raw_range))
    if range_m is not None:
        data["distance_m"] = range_m
        data["gpt_distance_m"] = range_m

    # 2. Direction (legacy alias).
    bearing = data.get("bearing", "")
    if bearing and bearing != "Unknown":
        data["direction"] = bearing
        data["gpt_direction"] = bearing

    # 3. Description (summary of new fields). A JSON null is treated like a
    #    missing key instead of crashing on .upper() / str.join().
    obj_type = data.get("object_type", "Unknown")
    obj_type = "Unknown" if obj_type is None else str(obj_type)
    threat = data.get("threat_classification", "Unknown")
    threat = "Unknown" if threat is None else str(threat)
    score = data.get("threat_level", 0)
    data["description"] = f"{obj_type} [{threat.upper()} Lvl:{score}]"
    data["gpt_description"] = data["description"]

    # 4. Legacy threat_level_score alias.
    data["threat_level_score"] = data.get("threat_level", 0)


def estimate_threat_gpt(
    image_path: Optional[str] = None,
    detections: Optional[List[Dict[str, Any]]] = None,
    mission_spec=None,  # Optional[MissionSpecification]
    image_b64: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Perform Threat Assessment on detected objects using GPT-4o.

    Detections whose label looks human (see ``_is_human_label``) are never
    sent to the model; they are reported with a SKIPPED_POLICY status.

    Args:
        image_path: Path to the image file. Only read when image_b64 is not given.
        detections: List of detection dicts (bbox, label, etc.).
        mission_spec: Optional MissionSpecification for domain-aware assessment.
        image_b64: Pre-encoded base64 JPEG string (avoids disk round-trip).
            Takes precedence over image_path when both are supplied.

    Returns:
        Dict mapping object ID (e.g., T01) to threat assessment dict. Every
        prompted object is present: with the model's fields plus legacy
        aliases, or with only ``assessment_status``/``gpt_reason`` when the
        model refused, omitted it, or the call failed. An empty dict is
        returned when no API key is configured, when there is nothing to
        assess, or when no image could be obtained.
    """
    if detections is None:
        detections = []

    if not get_api_key():
        logger.error("OPENAI_API_KEY not set. Skipping GPT threat assessment.")
        return {}

    # 1. Prepare detections summary for prompt.
    #    Human/person classes are explicitly skipped to avoid refusal paths.
    prompt_items = []
    skipped_human_ids: List[str] = []
    for index, det in enumerate(detections, start=1):
        obj_id = str(det.get("track_id") or det.get("id") or f"T{index:02d}")
        label = str(det.get("label", "object"))
        if _is_human_label(label):
            skipped_human_ids.append(obj_id)
            continue
        prompt_items.append({"obj_id": obj_id, "label": label, "bbox": det.get("bbox", [])})

    if not prompt_items:
        if skipped_human_ids:
            logger.warning(
                "Skipping GPT threat assessment for %d human/person detections due policy constraints.",
                len(skipped_human_ids),
            )
            return _build_status_fallback(
                skipped_human_ids,
                AssessmentStatus.SKIPPED_POLICY,
                _POLICY_SKIP_REASON,
            )
        return {}

    det_text = "\n".join(
        f"- ID: {it['obj_id']}, Classification Hint: {it['label']}, BBox: {it['bbox']}"
        for it in prompt_items
    )

    # 2. Encode image (prefer pre-encoded b64 to avoid disk I/O)
    if image_b64:
        base64_image = image_b64
    elif image_path:
        try:
            base64_image = encode_image(image_path)
        except Exception as e:
            logger.error("Failed to encode image for GPT: %s", e)
            return {}
    else:
        logger.error("estimate_threat_gpt: no image_path or image_b64 provided")
        return {}

    # 3. Domain-aware prompt selection (INV-7)
    domain = "GENERIC"  # default — universal schema works for all domains
    if mission_spec:
        # An unset domain falls back to GENERIC instead of crashing on .lower().
        domain = mission_spec.domain or "GENERIC"
        if mission_spec.domain_source == "INFERRED":
            logger.info("GPT assessment using inferred domain=%s (domain_inferred=True)", domain)

    system_prompt = _build_domain_system_prompt(domain, mission_spec)
    domain_label = domain.lower()
    user_prompt = (
        f"Analyze this {domain_label} surveillance image. The following objects have been detected:\n"
        f"{det_text}\n\n"
        "Provide a detailed Threat Assessment for each object based on its visual signatures."
    )

    # 4. Call API
    payload = {
        "model": "gpt-4o",  # Use 4o for better vision analysis
        "messages": [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                            "detail": "low",
                        },
                    },
                ],
            },
        ],
        "max_tokens": 1500,
        "temperature": 0.2,  # Low temp for factual consistency
        "response_format": {"type": "json_object"},
    }

    try:
        resp_data = chat_completion(payload)
        content, refusal = extract_content(resp_data)

        if not content:
            if refusal:
                logger.warning("GPT refused threat assessment: %s", refusal)
            else:
                # `or [{}]` also covers an empty "choices" list, which would
                # otherwise raise IndexError and be misreported as an API error.
                choices = resp_data.get("choices") or [{}]
                logger.warning(
                    "GPT returned empty content. response_id=%s finish_reason=%s",
                    resp_data.get("id"),
                    choices[0].get("finish_reason"),
                )
            return _failure_fallback(
                prompt_items,
                skipped_human_ids,
                AssessmentStatus.REFUSED,
                refusal or "GPT returned empty content.",
            )

        result_json = json.loads(content)
        # A top-level list/scalar is valid JSON but carries no "objects" map;
        # route it through the same "not a dict" fallback below.
        if isinstance(result_json, dict):
            objects = result_json.get("objects", {})
        else:
            objects = result_json
        if not isinstance(objects, dict):
            logger.warning(
                "GPT response 'objects' field is not a dict (got %s); using fallback.",
                type(objects).__name__,
            )
            objects = {}

        # Ensure every requested object receives an explicit assessment state.
        for it in prompt_items:
            oid = it["obj_id"]
            if oid not in objects:
                objects[oid] = {
                    "assessment_status": AssessmentStatus.NO_RESPONSE,
                    "gpt_reason": "No structured assessment returned for object.",
                }
        for oid in skipped_human_ids:
            objects.setdefault(
                oid,
                {
                    "assessment_status": AssessmentStatus.SKIPPED_POLICY,
                    "gpt_reason": _POLICY_SKIP_REASON,
                },
            )

        # Polyfill legacy fields for frontend compatibility.
        # Replacing the value of an existing key is safe while iterating.
        for obj_id, data in objects.items():
            if not isinstance(data, dict):
                data = {
                    "assessment_status": AssessmentStatus.NO_RESPONSE,
                    "gpt_reason": "Malformed object payload from GPT.",
                }
                objects[obj_id] = data
            _apply_legacy_fields(data)

        return objects

    except Exception as e:
        # Boundary handler: any transport or parsing failure is reported per
        # object rather than propagated to the caller.
        logger.error("GPT API call failed: %s", e, exc_info=True)
        return _failure_fallback(
            prompt_items,
            skipped_human_ids,
            AssessmentStatus.ERROR,
            f"GPT API call failed: {e.__class__.__name__}",
        )