Spaces:
Sleeping
Sleeping
"""
=================================================================
Disaster AI - HuggingFace Spaces API
Final version - all models integrated
=================================================================
"""
| import os | |
| import io | |
| import time | |
| import base64 | |
| import threading | |
| import traceback | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import torch | |
| import requests | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from huggingface_hub import hf_hub_download | |
# ----------------------------------------------------------------
# App Setup
# ----------------------------------------------------------------
# Metadata surfaced by FastAPI in the generated OpenAPI document.
_APP_METADATA = {
    "title": "Disaster AI Inference API",
    "description": "Multi-model disaster scene analysis for Dokai / RoboXavier",
    "version": "3.0.0",
}
app = FastAPI(**_APP_METADATA)

# CORS is left fully open: every origin, method and header is accepted.
_CORS_POLICY = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_POLICY)
# ----------------------------------------------------------------
# Configuration - all from secrets
# ----------------------------------------------------------------
# Hugging Face access token; None means the Hub is accessed anonymously.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repositories holding the fine-tuned detection weights. The defaults are
# repository names (not credentials) and can be overridden per deployment.
HF_VICTIM_MODEL_REPO = os.getenv("HF_VICTIM_MODEL_REPO", "EgoisticCoderX/dokai-victim-detection")
HF_XVIEW2_MODEL_REPO = os.getenv("HF_XVIEW2_MODEL_REPO", "EgoisticCoderX/dokai-xview2-damage")

# SECURITY: the API key must come from the environment only. An earlier
# revision used a literal key as the fallback value, which committed a
# credential to source; that key should be treated as leaked and rotated.
# When the variable is unset this is None and Roboflow calls are skipped.
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")

# Local directory used as the download cache for model files.
MODEL_CACHE_DIR = "/tmp/model_cache"
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# -- Victim detection class map --
# Detector class index -> label; a label's position in the tuple is its index.
TARGET_CLASSES = dict(enumerate((
    "injured_civilian",
    "trapped_civilian",
    "safe_civilian",
    "rescue_personnel",
)))

# Weight applied to a detection's confidence when computing its triage score.
CLASS_PRIORITY = dict(
    injured_civilian=1.0,
    trapped_civilian=0.95,
    safe_civilian=0.3,
    rescue_personnel=0.0,
)

# -- xView2 damage severity map --
# Sort rank per severity label: lower rank means more severe.
DAMAGE_SEVERITY_ORDER = {
    label: rank
    for rank, label in enumerate(
        ("destroyed", "major_damage", "minor_damage", "no_damage")
    )
}
# ----------------------------------------------------------------
# Model Registry
# ----------------------------------------------------------------
class ModelRegistry:
    """In-process store of loaded models and of the last load error per model.

    Models and errors are kept in two dicts keyed by model name. Mutations
    and the combined ``status()`` snapshot are taken under a lock.
    """

    def __init__(self):
        self._models = {}
        self._errors = {}
        self._lock = threading.Lock()

    def get(self, name):
        """Return the object registered under *name*, or None if absent."""
        return self._models.get(name)

    def register(self, name, model):
        """Store *model* under *name* and drop any earlier load error for it."""
        with self._lock:
            self._models[name] = model
            # A successful (re)load supersedes a previous failure; without
            # this, status() would report the model as loaded AND errored.
            self._errors.pop(name, None)
        print(f"β Model registered: {name}")

    def set_error(self, name, error):
        """Record *error* (stringified) as the load failure for *name*."""
        with self._lock:
            self._errors[name] = str(error)
        print(f"β Model error [{name}]: {error}")

    def is_loaded(self, name):
        """Return True when a model is registered under *name*."""
        return name in self._models

    def get_error(self, name):
        """Return the recorded error text for *name*, or "Unknown error"."""
        return self._errors.get(name, "Unknown error")

    def status(self):
        """Return a snapshot: names of loaded models and a copy of the errors."""
        with self._lock:
            return {
                "loaded": list(self._models),
                "errored": dict(self._errors),
            }


# Single shared registry used by the loaders and the route handlers.
registry = ModelRegistry()
# ----------------------------------------------------------------
# Model Loaders
# ----------------------------------------------------------------
def load_ladi_model():
    """Return the LADI-v2 classifier bundle, loading it on first use.

    The bundle is a dict with keys ``"model"`` and ``"processor"`` stored in
    the registry under ``"ladi"``. Any failure while importing or loading is
    printed with its traceback, recorded in the registry, and reported to
    the caller as ``None``.
    """
    if registry.is_loaded("ladi"):
        return registry.get("ladi")

    checkpoint = "MITLL/LADI-v2-classifier-small"
    try:
        # Imported lazily so an import failure is recorded like a load failure.
        from transformers import AutoImageProcessor, AutoModelForImageClassification

        print("β¬οΈ Loading MITLL/LADI-v2-classifier-small ...")
        image_processor = AutoImageProcessor.from_pretrained(
            checkpoint,
            cache_dir=MODEL_CACHE_DIR,
        )
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint, and ignore_mismatched_sizes=True can hide a mismatched
        # head - confirm both are actually required for this checkpoint.
        classifier = AutoModelForImageClassification.from_pretrained(
            checkpoint,
            cache_dir=MODEL_CACHE_DIR,
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )
        classifier.eval()

        bundle = {"model": classifier, "processor": image_processor}
        registry.register("ladi", bundle)
        print("β LADI-v2 ready")
        return registry.get("ladi")
    except Exception as exc:
        print(f"β LADI-v2 load failed:\n{traceback.format_exc()}")
        registry.set_error("ladi", exc)
        return None
def _load_yolo_from_hub(key, repo_id, repo_env_name, loading_label, ready_label, failure_label):
    """Download ``best.pt`` from *repo_id* and register it as a YOLO model.

    Shared implementation for the two YOLO loaders below, which previously
    duplicated this body verbatim.

    Args:
        key: Registry name for the model.
        repo_id: Hugging Face Hub repository that holds ``best.pt``.
        repo_env_name: Name of the setting reported when *repo_id* is empty.
        loading_label / ready_label / failure_label: Words inserted into the
            progress, success and failure log lines respectively.

    Returns:
        The loaded model, or ``None`` after recording the error in the registry.
    """
    if registry.is_loaded(key):
        return registry.get(key)
    if not repo_id:
        registry.set_error(key, f"{repo_env_name} secret not set")
        return None
    try:
        # Imported lazily so an import failure is recorded like a load failure.
        from ultralytics import YOLO

        print(f"β¬οΈ Loading {loading_label} model from {repo_id} ...")
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="best.pt",
            cache_dir=MODEL_CACHE_DIR,
            token=HF_TOKEN,
        )
        model = YOLO(model_path)
        registry.register(key, model)
        print(f"β {ready_label} model ready")
        return model
    except Exception as e:
        print(f"β {failure_label} model load failed:\n{traceback.format_exc()}")
        registry.set_error(key, e)
        return None


def load_victim_model():
    """Load YOLOv8 victim detection model from HuggingFace Hub."""
    return _load_yolo_from_hub(
        "victim",
        HF_VICTIM_MODEL_REPO,
        "HF_VICTIM_MODEL_REPO",
        loading_label="victim",
        ready_label="Victim detection",
        failure_label="Victim",
    )


def load_xview2_model():
    """Load xView2 building damage YOLOv8 model from HuggingFace Hub."""
    return _load_yolo_from_hub(
        "xview2",
        HF_XVIEW2_MODEL_REPO,
        "HF_XVIEW2_MODEL_REPO",
        loading_label="xView2",
        ready_label="xView2 damage",
        failure_label="xView2",
    )
# ----------------------------------------------------------------
# Startup
# ----------------------------------------------------------------
# Registered as a startup hook; without the decorator this coroutine is never
# invoked and every model would only be loaded lazily on its first request.
@app.on_event("startup")
async def startup_event():
    """Preload all models at application start and print the registry state.

    Loader failures are recorded in the registry by the loaders themselves,
    so a model that fails to load does not prevent the service from starting.
    """
    print("\n" + "=" * 55)
    print("π Disaster AI API v3.0 starting up...")
    print("=" * 55)

    # LADI is always attempted: it is loaded without any repo setting.
    load_ladi_model()

    # Victim model: only attempted when its Hub repo is configured.
    if HF_VICTIM_MODEL_REPO:
        load_victim_model()
    else:
        print("β οΈ Victim model skipped β HF_VICTIM_MODEL_REPO not set")

    # xView2 model: only attempted when its Hub repo is configured.
    if HF_XVIEW2_MODEL_REPO:
        load_xview2_model()
    else:
        print("β οΈ xView2 model skipped β HF_XVIEW2_MODEL_REPO not set")

    print("=" * 55)
    print(f"π Registry: {registry.status()}")
    print("=" * 55 + "\n")
# ----------------------------------------------------------------
# Utilities
# ----------------------------------------------------------------
def read_image(file_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a 3-channel OpenCV image.

    Args:
        file_bytes: Encoded image content (e.g. the body of an upload).

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with ``IMREAD_COLOR``.

    Raises:
        HTTPException: 400 when the payload is empty or cannot be decoded.
    """
    undecodable = HTTPException(status_code=400, detail="Invalid image β cannot decode")

    # An empty buffer makes cv2.imdecode raise instead of returning None,
    # which would surface as a 500; treat it as a client error up front.
    if not file_bytes:
        raise undecodable

    nparr = np.frombuffer(file_bytes, np.uint8)
    try:
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except cv2.error:
        img = None
    if img is None:
        raise undecodable
    return img
def call_roboflow(image: np.ndarray, model_id: str, confidence: int = 40) -> list:
    """Run a Roboflow hosted detection model on *image*.

    This is deliberately best-effort: any failure (encoding, network, HTTP
    status, unexpected payload) is logged and reported as "no detections".

    Args:
        image: Image array accepted by ``cv2.imencode``.
        model_id: Roboflow model path, appended to the detect endpoint URL.
        confidence: Confidence threshold forwarded as a query parameter.

    Returns:
        A list of dicts with ``class``, ``confidence`` (rounded to 4 places)
        and ``box`` (``xmin``/``ymin``/``xmax``/``ymax`` ints derived from the
        centre/size fields of each prediction). Empty when no API key is
        configured or when the call fails.
    """
    if not ROBOFLOW_API_KEY:
        return []
    try:
        encoded_ok, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 80])
        if not encoded_ok:
            raise ValueError("JPEG encoding failed")
        img_b64 = base64.b64encode(buffer)

        res = requests.post(
            f"https://detect.roboflow.com/{model_id}",
            # Same query string as before, but built by requests so the
            # values are URL-encoded.
            params={"api_key": ROBOFLOW_API_KEY, "confidence": confidence},
            data=img_b64,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=8,
        )
        res.raise_for_status()
        preds = res.json().get("predictions", [])
        return [
            {
                "class": p["class"],
                "confidence": round(p["confidence"], 4),
                "box": {
                    # Predictions carry centre x/y plus width/height.
                    "xmin": int(p["x"] - p["width"] / 2),
                    "ymin": int(p["y"] - p["height"] / 2),
                    "xmax": int(p["x"] + p["width"] / 2),
                    "ymax": int(p["y"] + p["height"] / 2),
                },
            }
            for p in preds
        ]
    except Exception as e:
        # Exception text from the HTTP client can embed the full request URL,
        # including the api_key query parameter; mask it before logging.
        safe_message = str(e).replace(str(ROBOFLOW_API_KEY), "***")
        print(f"Roboflow error ({model_id}): {safe_message}")
        return []
def compute_triage(detections: list) -> dict:
    """Score, rank and summarise victim detections.

    Each detection's score is ``confidence * CLASS_PRIORITY[class]`` rounded
    to 4 places (both the confidence and the class weight default to 0.5 when
    missing). Scores map to ranks: >= 0.7 CRITICAL, >= 0.4 HIGH, >= 0.2
    MODERATE, otherwise LOW.

    Returns:
        A dict with the total, per-rank counts, the highest score, a
        recommended action string and ``ranked_victims`` (the input dicts
        extended with ``priority_score`` / ``priority_rank``, highest first).
    """
    if not detections:
        return {
            "total": 0, "critical": 0, "high": 0,
            "moderate": 0, "low": 0,
            "highest_score": 0.0,
            "action": "No victims detected",
            "ranked_victims": [],
        }

    def _rank_for(value):
        # Thresholds are checked from most to least urgent.
        if value >= 0.7:
            return "CRITICAL"
        if value >= 0.4:
            return "HIGH"
        if value >= 0.2:
            return "MODERATE"
        return "LOW"

    ranked = []
    for det in detections:
        label = det.get("class", "")
        certainty = det.get("confidence", 0.5)
        priority = round(certainty * CLASS_PRIORITY.get(label, 0.5), 4)
        entry = dict(det)
        entry["priority_score"] = priority
        entry["priority_rank"] = _rank_for(priority)
        ranked.append(entry)

    # Highest score first; equal scores keep their input order.
    ranked.sort(key=lambda item: item["priority_score"], reverse=True)

    tally = {"CRITICAL": 0, "HIGH": 0, "MODERATE": 0, "LOW": 0}
    for entry in ranked:
        tally[entry["priority_rank"]] += 1

    if tally["CRITICAL"]:
        action = "IMMEDIATE RESCUE - Critical victims present"
    elif tally["HIGH"]:
        action = "Deploy rescue team - High priority victims"
    elif tally["MODERATE"]:
        action = "Assess and triage - Moderate victims present"
    else:
        action = "Low priority - Monitor the area"

    return {
        "total": len(ranked),
        "critical": tally["CRITICAL"],
        "high": tally["HIGH"],
        "moderate": tally["MODERATE"],
        "low": tally["LOW"],
        "highest_score": ranked[0]["priority_score"],
        "action": action,
        "ranked_victims": ranked,
    }
def compute_zone_color(triage_data: dict, damage_counts: dict, top_class: str) -> str:
    """
    Derive one zone color from victim triage, building damage and scene class.

    Checked in order of severity; the first matching rule wins:
      red    - any critical victim, destroyed building, or a scene class
               containing "destroy", "collapse" or "major"
      orange - any high-priority victim, major damage, or a scene class
               containing "minor_damage"
      yellow - any victim at all, or minor building damage
      green  - none of the above
    Missing keys in either dict count as zero.
    """
    n_critical = triage_data.get("critical", 0)
    n_high = triage_data.get("high", 0)
    n_destroyed = damage_counts.get("destroyed", 0)
    n_major = damage_counts.get("major_damage", 0)
    n_minor = damage_counts.get("minor_damage", 0)
    n_victims = triage_data.get("total", 0)

    severe_scene = False
    for marker in ("destroy", "collapse", "major"):
        if marker in top_class:
            severe_scene = True
            break
    moderate_scene = "minor_damage" in top_class

    if n_critical > 0 or n_destroyed > 0 or severe_scene:
        return "red"
    if n_high > 0 or n_major > 0 or moderate_scene:
        return "orange"
    if n_victims > 0 or n_minor > 0:
        return "yellow"
    return "green"
# ----------------------------------------------------------------
# Routes
# ----------------------------------------------------------------
# Registered on the application root; without a route decorator the handler
# is unreachable over HTTP.
@app.get("/")
def root():
    """Return service metadata, model registry status and the endpoint index."""
    return {
        "service": "Disaster AI Inference API",
        "version": "3.0.0",
        "status": registry.status(),
        "endpoints": {
            "GET /health": "Health check + model status",
            "POST /classify": "LADI-v2 disaster scene classification",
            "POST /detect/victims": "Victim detection + triage priority",
            "POST /detect/vehicles": "Emergency vehicle detection (Roboflow)",
            "POST /detect/damage": "xView2 building damage assessment",
            "POST /analyze/full": "All models in one call (main endpoint)",
        },
    }
# Path matches the "GET /health" entry advertised by the root endpoint.
@app.get("/health")
def health():
    """Report liveness, registry state, GPU availability and configured secrets.

    ``secrets_set`` exposes only booleans, never the secret values themselves.
    """
    return {
        "status": "ok",
        "registry": registry.status(),
        "gpu_available": torch.cuda.is_available(),
        "secrets_set": {
            "HF_TOKEN": HF_TOKEN is not None,
            "HF_VICTIM_MODEL_REPO": bool(HF_VICTIM_MODEL_REPO),
            "HF_XVIEW2_MODEL_REPO": bool(HF_XVIEW2_MODEL_REPO),
            "ROBOFLOW_API_KEY": bool(ROBOFLOW_API_KEY),
        },
        "timestamp": time.time(),
    }
# ----------------------------------------------------------------
# 1. LADI-v2 Scene Classification
# ----------------------------------------------------------------
# Path matches the "POST /classify" entry advertised by the root endpoint.
@app.post("/classify")
async def classify_scene(
    file: UploadFile = File(...),
    top_k: int = 5,
):
    """
    Classify disaster scene using LADI-v2 (aerial damage categories).

    Args:
        file: Uploaded image.
        top_k: Number of entries in the truncated lists; negative values are
            treated as 0.

    Returns:
        ``top_predictions`` (top-k of all labels), ``relevant_only`` (top-k
        after dropping labels containing "water" or "flood"), ``all_scores``
        (every label, highest confidence first) and ``inference_time_ms``.

    Raises:
        HTTPException: 503 when the model is unavailable, 400 when the image
            cannot be opened, 500 when inference fails.
    """
    ladi = load_ladi_model()
    if ladi is None:
        raise HTTPException(
            status_code=503,
            detail=f"LADI-v2 unavailable: {registry.get_error('ladi')}"
        )

    contents = await file.read()
    try:
        img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    model = ladi["model"]
    processor = ladi["processor"]

    t0 = time.time()
    try:
        inputs = processor(images=img_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # NOTE(review): softmax assumes mutually exclusive labels; if this
        # checkpoint is multi-label a sigmoid would be appropriate - confirm.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    id2label = model.config.id2label
    all_scores = sorted(
        (
            {
                "class": id2label[i].lower().replace(" ", "_"),
                "confidence": round(float(p), 4),
            }
            for i, p in enumerate(probs)
        ),
        key=lambda x: x["confidence"],
        reverse=True,
    )
    relevant = [
        s for s in all_scores
        if not any(w in s["class"] for w in ["water", "flood"])
    ]

    # A negative slice bound would silently return "all but the last k".
    top_k = max(top_k, 0)
    return {
        "top_predictions": all_scores[:top_k],
        "relevant_only": relevant[:top_k],
        "all_scores": all_scores,
        "inference_time_ms": elapsed,
    }
# ----------------------------------------------------------------
# 2. Victim Detection + Triage
# ----------------------------------------------------------------
# Path matches the "POST /detect/victims" entry advertised by the root endpoint.
@app.post("/detect/victims")
async def detect_victims(
    file: UploadFile = File(...),
    confidence: float = 0.35,
):
    """
    Detect victims and rank by triage priority.
    Classes: injured_civilian, trapped_civilian, safe_civilian, rescue_personnel.
    Priority ranks: CRITICAL / HIGH / MODERATE / LOW

    Args:
        file: Uploaded image.
        confidence: Detection confidence threshold passed to the model.

    Returns:
        ``detections`` (ranked, highest priority first), ``triage_summary``
        (counts, highest score and recommended action) and
        ``inference_time_ms``.

    Raises:
        HTTPException: 503 when the model is unavailable, 400 for an
            undecodable image, 500 when inference fails.
    """
    model = load_victim_model()
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"Victim model unavailable: {registry.get_error('victim')}"
        )

    contents = await file.read()
    img = read_image(contents)

    t0 = time.time()
    try:
        results = model.predict(source=img, conf=confidence, verbose=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    raw = []
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            cls_id = int(box.cls[0])
            raw.append({
                # Class ids outside TARGET_CLASSES are reported as "unknown".
                "class": TARGET_CLASSES.get(cls_id, "unknown"),
                "class_id": cls_id,
                "confidence": round(float(box.conf[0]), 4),
                "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
            })

    triage = compute_triage(raw)
    # The ranked list is returned as "detections", so it is removed from the
    # summary to avoid sending it twice.
    victims = triage.pop("ranked_victims", raw)
    return {
        "detections": victims,
        "triage_summary": triage,
        "inference_time_ms": elapsed,
    }
# ----------------------------------------------------------------
# 3. Emergency Vehicle Detection (Roboflow)
# ----------------------------------------------------------------
# Path matches the "POST /detect/vehicles" entry advertised by the root endpoint.
@app.post("/detect/vehicles")
async def detect_vehicles(file: UploadFile = File(...)):
    """
    Detect emergency vehicles via Roboflow hosted model.
    Returns ambulance / fire truck detections and rescue_arrived flag.

    Raises:
        HTTPException: 503 when no Roboflow API key is configured, 400 for an
            undecodable image.

    Note: a failed Roboflow call yields an empty detection list, so it is
    indistinguishable here from "no vehicles found".
    """
    if not ROBOFLOW_API_KEY:
        raise HTTPException(status_code=503, detail="ROBOFLOW_API_KEY secret not set")

    contents = await file.read()
    img = read_image(contents)

    t0 = time.time()
    detections = call_roboflow(img, "ambulance-4bova/1", confidence=40)
    elapsed = round((time.time() - t0) * 1000, 2)

    # Class names are matched by substring, case-insensitively.
    labels = [d["class"].lower() for d in detections]
    has_ambulance = any("ambulance" in label for label in labels)
    has_fire_truck = any("fire" in label for label in labels)

    return {
        "detections": detections,
        "emergency_vehicles": {
            "ambulance_detected": has_ambulance,
            "fire_truck_detected": has_fire_truck,
            "rescue_arrived": has_ambulance or has_fire_truck,
        },
        "inference_time_ms": elapsed,
    }
# ----------------------------------------------------------------
# 4. xView2 Building Damage Assessment
# ----------------------------------------------------------------
# Path matches the "POST /detect/damage" entry advertised by the root endpoint.
@app.post("/detect/damage")
async def detect_building_damage(
    file: UploadFile = File(...),
    confidence: float = 0.30,
):
    """
    Assess building damage using xView2-trained YOLOv8.
    Classes: destroyed, major_damage, minor_damage, no_damage.
    Returns per-building detections, counts, and zone color.

    Args:
        file: Uploaded image.
        confidence: Detection confidence threshold passed to the model.

    Raises:
        HTTPException: 503 when the model is unavailable, 400 for an
            undecodable image, 500 when inference fails.
    """
    model = load_xview2_model()
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"xView2 model unavailable: {registry.get_error('xview2')}"
        )

    contents = await file.read()
    img = read_image(contents)

    t0 = time.time()
    try:
        results = model.predict(source=img, conf=confidence, verbose=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    detections = []
    counts = {"destroyed": 0, "major_damage": 0, "minor_damage": 0, "no_damage": 0}
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            cls_id = int(box.cls[0])
            class_name = model.names[cls_id].lower().replace(" ", "_")

            # Map the raw class name to a standard severity key. Hyphens are
            # normalised for matching only, so a name such as "major-damage"
            # is not silently counted as "no_damage".
            match_name = class_name.replace("-", "_")
            matched_key = next(
                (k for k in counts if k in match_name),
                "no_damage"
            )
            counts[matched_key] += 1
            detections.append({
                "class": class_name,
                "confidence": round(float(box.conf[0]), 4),
                "severity": matched_key,
                "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
            })

    # Sort: destroyed first, no_damage last.
    detections.sort(key=lambda d: DAMAGE_SEVERITY_ORDER.get(d["severity"], 9))

    # The most severe category present decides the zone color.
    if counts["destroyed"] > 0:
        zone_color = "red"
    elif counts["major_damage"] > 0:
        zone_color = "orange"
    elif counts["minor_damage"] > 0:
        zone_color = "yellow"
    else:
        zone_color = "green"

    return {
        "detections": detections,
        "summary": counts,
        "total_buildings": sum(counts.values()),
        "zone_color": zone_color,
        "inference_time_ms": elapsed,
    }
# ----------------------------------------------------------------
# 5. Full Analysis - all models in one call
# ----------------------------------------------------------------
# Path matches the "POST /analyze/full" entry advertised by the root endpoint.
@app.post("/analyze/full")
async def full_analysis(
    file: UploadFile = File(...),
    run_classify: bool = True,
    run_victims: bool = True,
    run_vehicles: bool = True,
    run_damage: bool = True,
):
    """
    Run all available models on one image.
    Main endpoint for the RoboXavier rover Flask app.
    Returns unified zone_color, all detections, and triage/damage summaries.

    Each ``run_*`` flag enables one stage. A stage that fails does not abort
    the request: its entry in ``results`` becomes ``{"error": ...}`` and it
    contributes nothing to the zone color.
    """
    contents = await file.read()
    t_total = time.time()
    output = {}

    # (enabled, result key, handler) - executed sequentially in this order.
    stages = (
        (run_classify, "classification", classify_scene),
        (run_victims, "victims", detect_victims),
        (run_vehicles, "vehicles", detect_vehicles),
        (run_damage, "building_damage", detect_building_damage),
    )
    for enabled, key, handler in stages:
        if not enabled:
            continue
        try:
            # Every stage gets its own in-memory upload, because each handler
            # reads the file object to the end.
            output[key] = await handler(
                UploadFile(filename="f.jpg", file=io.BytesIO(contents))
            )
        except HTTPException as e:
            output[key] = {"error": e.detail}
        except Exception as e:
            output[key] = {"error": str(e)}

    # Unified zone color (all signals combined). Skipped or failed stages
    # fall back to empty inputs.
    triage_data = output.get("victims", {}).get("triage_summary", {})
    damage_counts = output.get("building_damage", {}).get("summary", {})
    classify_top = output.get("classification", {}).get("top_predictions", [{}])
    top_class = classify_top[0].get("class", "") if classify_top else ""
    zone_color = compute_zone_color(triage_data, damage_counts, top_class)

    return {
        "zone_color": zone_color,
        "results": output,
        "total_time_ms": round((time.time() - t_total) * 1000, 2),
        "timestamp": time.time(),
    }
# ----------------------------------------------------------------
# Entry Point
# ----------------------------------------------------------------
if __name__ == "__main__":
    # Imported here because the server is only needed when this file is
    # executed directly as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )