""" ================================================================= Disaster AI - HuggingFace Spaces API Final version - all models integrated ================================================================= """ import os import io import time import base64 import threading import traceback import numpy as np from PIL import Image import cv2 import torch import requests from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from huggingface_hub import hf_hub_download # ════════════════════════════════ # App Setup # ════════════════════════════════ app = FastAPI( title="Disaster AI Inference API", description="Multi-model disaster scene analysis for Dokai / RoboXavier", version="3.0.0", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ════════════════════════════════ # Configuration — all from secrets # ════════════════════════════════ HF_TOKEN = os.getenv("HF_TOKEN", None) HF_VICTIM_MODEL_REPO = os.getenv("HF_VICTIM_MODEL_REPO", "EgoisticCoderX/dokai-victim-detection") HF_XVIEW2_MODEL_REPO = os.getenv("HF_XVIEW2_MODEL_REPO", "EgoisticCoderX/dokai-xview2-damage") ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "rltTa8UANpettqj6aHJG") MODEL_CACHE_DIR = "/tmp/model_cache" os.makedirs(MODEL_CACHE_DIR, exist_ok=True) # ── Victim detection class map ── TARGET_CLASSES = { 0: "injured_civilian", 1: "trapped_civilian", 2: "safe_civilian", 3: "rescue_personnel", } CLASS_PRIORITY = { "injured_civilian": 1.0, "trapped_civilian": 0.95, "safe_civilian": 0.3, "rescue_personnel": 0.0, } # ── xView2 damage severity map ── DAMAGE_SEVERITY_ORDER = { "destroyed": 0, "major_damage": 1, "minor_damage": 2, "no_damage": 3, } # ════════════════════════════════ # Model Registry # ════════════════════════════════ class ModelRegistry: def __init__(self): self._models = {} self._errors = {} self._lock = threading.Lock() def get(self, name): return self._models.get(name) def 
register(self, name, model): with self._lock: self._models[name] = model print(f"✅ Model registered: {name}") def set_error(self, name, error): with self._lock: self._errors[name] = str(error) print(f"❌ Model error [{name}]: {error}") def is_loaded(self, name): return name in self._models def get_error(self, name): return self._errors.get(name, "Unknown error") def status(self): return { "loaded": list(self._models.keys()), "errored": {k: v for k, v in self._errors.items()}, } registry = ModelRegistry() # ════════════════════════════════ # Model Loaders # ════════════════════════════════ def load_ladi_model(): """Load LADI-v2 scene classifier from HuggingFace Hub.""" if registry.is_loaded("ladi"): return registry.get("ladi") try: from transformers import AutoImageProcessor, AutoModelForImageClassification print("⬇️ Loading MITLL/LADI-v2-classifier-small ...") processor = AutoImageProcessor.from_pretrained( "MITLL/LADI-v2-classifier-small", cache_dir=MODEL_CACHE_DIR, ) model = AutoModelForImageClassification.from_pretrained( "MITLL/LADI-v2-classifier-small", cache_dir=MODEL_CACHE_DIR, trust_remote_code=True, ignore_mismatched_sizes=True, ) model.eval() registry.register("ladi", {"model": model, "processor": processor}) print("✅ LADI-v2 ready") return registry.get("ladi") except Exception as e: print(f"❌ LADI-v2 load failed:\n{traceback.format_exc()}") registry.set_error("ladi", e) return None def load_victim_model(): """Load YOLOv8 victim detection model from HuggingFace Hub.""" if registry.is_loaded("victim"): return registry.get("victim") if not HF_VICTIM_MODEL_REPO: registry.set_error("victim", "HF_VICTIM_MODEL_REPO secret not set") return None try: from ultralytics import YOLO print(f"⬇️ Loading victim model from {HF_VICTIM_MODEL_REPO} ...") model_path = hf_hub_download( repo_id=HF_VICTIM_MODEL_REPO, filename="best.pt", cache_dir=MODEL_CACHE_DIR, token=HF_TOKEN, ) model = YOLO(model_path) registry.register("victim", model) print("✅ Victim detection model ready") 
return model except Exception as e: print(f"❌ Victim model load failed:\n{traceback.format_exc()}") registry.set_error("victim", e) return None def load_xview2_model(): """Load xView2 building damage YOLOv8 model from HuggingFace Hub.""" if registry.is_loaded("xview2"): return registry.get("xview2") if not HF_XVIEW2_MODEL_REPO: registry.set_error("xview2", "HF_XVIEW2_MODEL_REPO secret not set") return None try: from ultralytics import YOLO print(f"⬇️ Loading xView2 model from {HF_XVIEW2_MODEL_REPO} ...") model_path = hf_hub_download( repo_id=HF_XVIEW2_MODEL_REPO, filename="best.pt", cache_dir=MODEL_CACHE_DIR, token=HF_TOKEN, ) model = YOLO(model_path) registry.register("xview2", model) print("✅ xView2 damage model ready") return model except Exception as e: print(f"❌ xView2 model load failed:\n{traceback.format_exc()}") registry.set_error("xview2", e) return None # ════════════════════════════════ # Startup # ════════════════════════════════ @app.on_event("startup") async def startup_event(): print("\n" + "=" * 55) print("🚀 Disaster AI API v3.0 starting up...") print("=" * 55) # LADI always loads — public model load_ladi_model() # Victim model — needs secret if HF_VICTIM_MODEL_REPO: load_victim_model() else: print("⚠️ Victim model skipped — HF_VICTIM_MODEL_REPO not set") # xView2 model — needs secret if HF_XVIEW2_MODEL_REPO: load_xview2_model() else: print("⚠️ xView2 model skipped — HF_XVIEW2_MODEL_REPO not set") print("=" * 55) print(f"📊 Registry: {registry.status()}") print("=" * 55 + "\n") # ════════════════════════════════ # Utilities # ════════════════════════════════ def read_image(file_bytes: bytes) -> np.ndarray: nparr = np.frombuffer(file_bytes, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if img is None: raise HTTPException(status_code=400, detail="Invalid image — cannot decode") return img def call_roboflow(image: np.ndarray, model_id: str, confidence: int = 40) -> list: if not ROBOFLOW_API_KEY: return [] try: _, buffer = cv2.imencode('.jpg', 
image, [cv2.IMWRITE_JPEG_QUALITY, 80]) img_b64 = base64.b64encode(buffer) url = f"https://detect.roboflow.com/{model_id}?api_key={ROBOFLOW_API_KEY}&confidence={confidence}" res = requests.post( url, data=img_b64, headers={"Content-Type": "application/x-www-form-urlencoded"}, timeout=8, ) res.raise_for_status() preds = res.json().get("predictions", []) return [ { "class": p["class"], "confidence": round(p["confidence"], 4), "box": { "xmin": int(p["x"] - p["width"] / 2), "ymin": int(p["y"] - p["height"] / 2), "xmax": int(p["x"] + p["width"] / 2), "ymax": int(p["y"] + p["height"] / 2), }, } for p in preds ] except Exception as e: print(f"Roboflow error ({model_id}): {e}") return [] def compute_triage(detections: list) -> dict: if not detections: return { "total": 0, "critical": 0, "high": 0, "moderate": 0, "low": 0, "highest_score": 0.0, "action": "No victims detected", "ranked_victims": [], } scored = [] for d in detections: cls_name = d.get("class", "") conf = d.get("confidence", 0.5) weight = CLASS_PRIORITY.get(cls_name, 0.5) score = round(conf * weight, 4) rank = ( "CRITICAL" if score >= 0.7 else "HIGH" if score >= 0.4 else "MODERATE" if score >= 0.2 else "LOW" ) scored.append({**d, "priority_score": score, "priority_rank": rank}) scored.sort(key=lambda x: x["priority_score"], reverse=True) critical = sum(1 for d in scored if d["priority_rank"] == "CRITICAL") high = sum(1 for d in scored if d["priority_rank"] == "HIGH") moderate = sum(1 for d in scored if d["priority_rank"] == "MODERATE") low = sum(1 for d in scored if d["priority_rank"] == "LOW") action = ( "IMMEDIATE RESCUE - Critical victims present" if critical else "Deploy rescue team - High priority victims" if high else "Assess and triage - Moderate victims present" if moderate else "Low priority - Monitor the area" ) return { "total": len(scored), "critical": critical, "high": high, "moderate": moderate, "low": low, "highest_score": scored[0]["priority_score"] if scored else 0.0, "action": action, 
"ranked_victims": scored, } def compute_zone_color(triage_data: dict, damage_counts: dict, top_class: str) -> str: """ Unified zone color logic combining victim triage + building damage + scene class. red > orange > yellow > green """ critical = triage_data.get("critical", 0) high = triage_data.get("high", 0) destroyed = damage_counts.get("destroyed", 0) major_damage = damage_counts.get("major_damage", 0) minor_damage = damage_counts.get("minor_damage", 0) victim_total = triage_data.get("total", 0) scene_critical = any(w in top_class for w in ["destroy", "collapse", "major"]) scene_moderate = "minor_damage" in top_class if critical > 0 or destroyed > 0 or scene_critical: return "red" elif high > 0 or major_damage > 0 or scene_moderate: return "orange" elif victim_total > 0 or minor_damage > 0: return "yellow" else: return "green" # ════════════════════════════════ # Routes # ════════════════════════════════ @app.get("/") def root(): return { "service": "Disaster AI Inference API", "version": "3.0.0", "status": registry.status(), "endpoints": { "GET /health": "Health check + model status", "POST /classify": "LADI-v2 disaster scene classification", "POST /detect/victims": "Victim detection + triage priority", "POST /detect/vehicles": "Emergency vehicle detection (Roboflow)", "POST /detect/damage": "xView2 building damage assessment", "POST /analyze/full": "All models in one call (main endpoint)", }, } @app.get("/health") def health(): return { "status": "ok", "registry": registry.status(), "gpu_available": torch.cuda.is_available(), "secrets_set": { "HF_TOKEN": HF_TOKEN is not None, "HF_VICTIM_MODEL_REPO": bool(HF_VICTIM_MODEL_REPO), "HF_XVIEW2_MODEL_REPO": bool(HF_XVIEW2_MODEL_REPO), "ROBOFLOW_API_KEY": bool(ROBOFLOW_API_KEY), }, "timestamp": time.time(), } # ───────────────────────────────────────────── # 1. 
# ─────────────────────────────────────────────
# 1. LADI-v2 Scene Classification
# ─────────────────────────────────────────────

@app.post("/classify")
async def classify_scene(
    file: UploadFile = File(...),
    top_k: int = 5,
):
    """
    Classify disaster scene using LADI-v2 (aerial damage categories).
    Returns top-k predicted labels with confidence scores.

    Raises HTTPException 503 if the model is unavailable, 400 for an
    undecodable image, 500 if inference fails.
    """
    ladi = load_ladi_model()
    if ladi is None:
        raise HTTPException(
            status_code=503,
            detail=f"LADI-v2 unavailable: {registry.get_error('ladi')}"
        )

    # A negative top_k would otherwise slice from the end (e.g. [:-2]).
    top_k = max(top_k, 0)

    contents = await file.read()
    try:
        img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    model = ladi["model"]
    processor = ladi["processor"]

    t0 = time.time()
    try:
        # NOTE(review): inference runs synchronously on the event loop. That blocks
        # other requests, but it also serializes use of the shared model instance;
        # moving it to a threadpool would need per-model locking.
        inputs = processor(images=img_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    id2label = model.config.id2label
    all_scores = sorted(
        [
            {
                "class": id2label[i].lower().replace(" ", "_"),
                "confidence": round(float(probs[i]), 4),
            }
            for i in range(len(probs))
        ],
        key=lambda x: x["confidence"],
        reverse=True,
    )

    # Same ranking with water/flood labels filtered out.
    relevant = [
        s for s in all_scores
        if not any(w in s["class"] for w in ["water", "flood"])
    ]

    return {
        "top_predictions": all_scores[:top_k],
        "relevant_only": relevant[:top_k],
        "all_scores": all_scores,
        "inference_time_ms": elapsed,
    }


# ─────────────────────────────────────────────
# 2. Victim Detection + Triage
# ─────────────────────────────────────────────

@app.post("/detect/victims")
async def detect_victims(
    file: UploadFile = File(...),
    confidence: float = 0.35,
):
    """
    Detect victims and rank by triage priority.
    Classes: injured_civilian, trapped_civilian, safe_civilian, rescue_personnel.
    Priority ranks: CRITICAL / HIGH / MODERATE / LOW
    """
    model = load_victim_model()
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"Victim model unavailable: {registry.get_error('victim')}"
        )

    contents = await file.read()
    img = read_image(contents)

    t0 = time.time()
    try:
        # Synchronous on purpose; see the NOTE(review) in classify_scene.
        results = model.predict(source=img, conf=confidence, verbose=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    raw = []
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            conf_val = float(box.conf[0])
            cls_id = int(box.cls[0])
            raw.append({
                "class": TARGET_CLASSES.get(cls_id, "unknown"),
                "class_id": cls_id,
                "confidence": round(conf_val, 4),
                "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
            })

    triage = compute_triage(raw)
    # Move the ranked list out of the summary so it is returned once, as "detections".
    victims = triage.pop("ranked_victims", raw)

    return {
        "detections": victims,
        "triage_summary": triage,
        "inference_time_ms": elapsed,
    }


# ─────────────────────────────────────────────
# 3. Emergency Vehicle Detection (Roboflow)
# ─────────────────────────────────────────────

@app.post("/detect/vehicles")
async def detect_vehicles(file: UploadFile = File(...)):
    """
    Detect emergency vehicles via Roboflow hosted model.
    Returns ambulance / fire truck detections and rescue_arrived flag.
    """
    from fastapi.concurrency import run_in_threadpool

    if not ROBOFLOW_API_KEY:
        raise HTTPException(status_code=503, detail="ROBOFLOW_API_KEY secret not set")

    contents = await file.read()
    img = read_image(contents)

    t0 = time.time()
    # call_roboflow makes a blocking HTTP request (8 s requests timeout); run it in
    # the threadpool so it doesn't freeze the event loop for all other requests.
    detections = await run_in_threadpool(
        call_roboflow, img, "ambulance-4bova/1", confidence=40
    )
    elapsed = round((time.time() - t0) * 1000, 2)

    has_ambulance = any("ambulance" in d["class"].lower() for d in detections)
    has_fire_truck = any("fire" in d["class"].lower() for d in detections)

    return {
        "detections": detections,
        "emergency_vehicles": {
            "ambulance_detected": has_ambulance,
            "fire_truck_detected": has_fire_truck,
            "rescue_arrived": has_ambulance or has_fire_truck,
        },
        "inference_time_ms": elapsed,
    }


# ─────────────────────────────────────────────
# 4. xView2 Building Damage Assessment
# ─────────────────────────────────────────────

@app.post("/detect/damage")
async def detect_building_damage(
    file: UploadFile = File(...),
    confidence: float = 0.30,
):
    """
    Assess building damage using xView2-trained YOLOv8.
    Classes: destroyed, major_damage, minor_damage, no_damage.
    Returns per-building detections, counts, and zone color.
    """
    model = load_xview2_model()
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"xView2 model unavailable: {registry.get_error('xview2')}"
        )

    contents = await file.read()
    img = read_image(contents)

    t0 = time.time()
    try:
        # Synchronous on purpose; see the NOTE(review) in classify_scene.
        results = model.predict(source=img, conf=confidence, verbose=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    detections = []
    counts = {"destroyed": 0, "major_damage": 0, "minor_damage": 0, "no_damage": 0}

    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            conf_val = float(box.conf[0])
            cls_id = int(box.cls[0])
            class_name = model.names[cls_id].lower().replace(" ", "_")

            # Map the raw class name to a standard severity key. Also fold hyphens
            # for matching so labels like "major-damage" aren't silently counted as
            # no_damage. NOTE(review): confirm against the model's `names`.
            match_name = class_name.replace("-", "_")
            matched_key = next(
                (k for k in counts if k in match_name),
                "no_damage"
            )
            counts[matched_key] += 1

            detections.append({
                "class": class_name,
                "confidence": round(conf_val, 4),
                "severity": matched_key,
                "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
            })

    # Sort: destroyed first, no_damage last
    detections.sort(key=lambda d: DAMAGE_SEVERITY_ORDER.get(d["severity"], 9))

    if counts["destroyed"] > 0:
        zone_color = "red"
    elif counts["major_damage"] > 0:
        zone_color = "orange"
    elif counts["minor_damage"] > 0:
        zone_color = "yellow"
    else:
        zone_color = "green"

    return {
        "detections": detections,
        "summary": counts,
        "total_buildings": sum(counts.values()),
        "zone_color": zone_color,
        "inference_time_ms": elapsed,
    }


# ─────────────────────────────────────────────
# 5. Full Analysis — all models in one call
# ─────────────────────────────────────────────

async def _run_endpoint(endpoint, contents: bytes):
    """Await one endpoint on a fresh in-memory upload; return its result or {"error": ...}."""
    try:
        return await endpoint(UploadFile(filename="f.jpg", file=io.BytesIO(contents)))
    except HTTPException as e:
        return {"error": e.detail}
    except Exception as e:
        return {"error": str(e)}


@app.post("/analyze/full")
async def full_analysis(
    file: UploadFile = File(...),
    run_classify: bool = True,
    run_victims: bool = True,
    run_vehicles: bool = True,
    run_damage: bool = True,
):
    """
    Run all available models on one image.
    Main endpoint for the RoboXavier rover Flask app.
    Returns unified zone_color, all detections, and triage/damage summaries.

    A failing sub-analysis does not fail the request; its section holds
    {"error": ...} instead.
    """
    contents = await file.read()
    t_total = time.time()
    output = {}

    # (enabled, output key, endpoint) — evaluated in this order.
    sections = (
        (run_classify, "classification", classify_scene),
        (run_victims, "victims", detect_victims),
        (run_vehicles, "vehicles", detect_vehicles),
        (run_damage, "building_damage", detect_building_damage),
    )
    for enabled, key, endpoint in sections:
        if enabled:
            output[key] = await _run_endpoint(endpoint, contents)

    # ── Unified zone color (all signals combined) ──
    triage_data = output.get("victims", {}).get("triage_summary", {})
    damage_counts = output.get("building_damage", {}).get("summary", {})
    classify_top = output.get("classification", {}).get("top_predictions", [{}])
    top_class = classify_top[0].get("class", "") if classify_top else ""

    zone_color = compute_zone_color(triage_data, damage_counts, top_class)

    return {
        "zone_color": zone_color,
        "results": output,
        "total_time_ms": round((time.time() - t_total) * 1000, 2),
        "timestamp": time.time(),
    }


# ════════════════════════════════
# Entry Point
# ════════════════════════════════

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)