# Scraped from the Hugging Face Space file viewer:
# author AbhinavGupta, commit 7a8b57c ("Update app.py", verified), file size 24.2 kB.
"""
=================================================================
Disaster AI - HuggingFace Spaces API
Final version - all models integrated
=================================================================
"""
import os
import io
import time
import base64
import threading
import traceback
import numpy as np
from PIL import Image
import cv2
import torch
import requests
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
# ════════════════════════════════
# App Setup
# ════════════════════════════════
# The FastAPI application object; every route in this module is registered on it.
app = FastAPI(
    title="Disaster AI Inference API",
    description="Multi-model disaster scene analysis for Dokai / RoboXavier",
    version="3.0.0",
)

# Fully open CORS: any origin, HTTP method and request header may call the API.
# allow_credentials is not passed, so the middleware's own default applies.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# ════════════════════════════════
# Configuration — all from secrets / environment variables
# ════════════════════════════════
# `or None` normalises "set but blank" to None, so a blank secret is treated
# exactly like a missing one everywhere below.
HF_TOKEN = os.getenv("HF_TOKEN") or None
HF_VICTIM_MODEL_REPO = os.getenv("HF_VICTIM_MODEL_REPO", "EgoisticCoderX/dokai-victim-detection")
HF_XVIEW2_MODEL_REPO = os.getenv("HF_XVIEW2_MODEL_REPO", "EgoisticCoderX/dokai-xview2-damage")
# SECURITY: never ship a real API key as a source-code default. The key that
# used to be hard-coded here has been committed to source control and should
# be revoked/rotated in the Roboflow dashboard. Without the secret, Roboflow
# features are disabled instead of silently using a shared key.
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY") or None
# Download cache for Hub models; overridable for local runs.
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "/tmp/model_cache")
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# ── Victim detection class map ──
# YOLO class id -> victim category name, in class-id order.
TARGET_CLASSES = dict(enumerate((
    "injured_civilian",
    "trapped_civilian",
    "safe_civilian",
    "rescue_personnel",
)))

# Triage weight per victim category. Triage scoring multiplies a detection's
# confidence by this weight; rescue personnel weigh 0.0, so they never raise
# the score.
CLASS_PRIORITY = dict(
    injured_civilian=1.0,
    trapped_civilian=0.95,
    safe_civilian=0.3,
    rescue_personnel=0.0,
)

# ── xView2 damage severity map ──
# Sort key for building-damage detections: lower value = more severe.
DAMAGE_SEVERITY_ORDER = {
    label: rank
    for rank, label in enumerate(("destroyed", "major_damage", "minor_damage", "no_damage"))
}
# ════════════════════════════════
# Model Registry
# ════════════════════════════════
class ModelRegistry:
    """In-memory registry of loaded models and of per-model load errors.

    Mutations take ``_lock``; ``status`` also snapshots under the lock so it
    never iterates a dict that another thread is modifying.
    """

    def __init__(self):
        self._models = {}  # name -> loaded model object
        self._errors = {}  # name -> last load-error message (str)
        self._lock = threading.Lock()

    def get(self, name):
        """Return the model registered under *name*, or None."""
        return self._models.get(name)

    def register(self, name, model):
        """Store *model* under *name* and clear any earlier load error for it."""
        with self._lock:
            self._models[name] = model
            # A successful (re)load supersedes a previous failure; otherwise
            # status() would list the model as both loaded and errored.
            self._errors.pop(name, None)
        print(f"✅ Model registered: {name}")

    def set_error(self, name, error):
        """Record ``str(error)`` as the latest load failure for *name*."""
        with self._lock:
            self._errors[name] = str(error)
        print(f"❌ Model error [{name}]: {error}")

    def is_loaded(self, name):
        """Return True if a model has been registered under *name*."""
        return name in self._models

    def get_error(self, name):
        """Return the last recorded error for *name*, or "Unknown error"."""
        return self._errors.get(name, "Unknown error")

    def status(self):
        """Return a snapshot: loaded model names and the name -> error map."""
        with self._lock:
            return {
                "loaded": list(self._models),
                "errored": dict(self._errors),
            }


# Process-wide registry shared by the loaders and the routes.
registry = ModelRegistry()
# ════════════════════════════════
# Model Loaders
# ════════════════════════════════
LADI_MODEL_ID = "MITLL/LADI-v2-classifier-small"


def load_ladi_model():
    """Load (once) the LADI-v2 scene classifier from the Hugging Face Hub.

    Returns:
        ``{"model": ..., "processor": ...}`` on success (cached in ``registry``
        under "ladi"), or None on failure. On failure the error is recorded via
        ``registry.set_error("ladi", ...)``. Failures are not cached, so the
        next call retries the load.
    """
    if registry.is_loaded("ladi"):
        return registry.get("ladi")
    try:
        # Imported lazily: a missing/broken transformers install is caught
        # below and only disables this model instead of crashing the app.
        from transformers import AutoImageProcessor, AutoModelForImageClassification

        print(f"⬇️ Loading {LADI_MODEL_ID} ...")
        processor = AutoImageProcessor.from_pretrained(
            LADI_MODEL_ID,
            cache_dir=MODEL_CACHE_DIR,
        )
        model = AutoModelForImageClassification.from_pretrained(
            LADI_MODEL_ID,
            cache_dir=MODEL_CACHE_DIR,
            # NOTE(review): trust_remote_code executes Python shipped in the
            # model repo; confirm this checkpoint actually needs it.
            trust_remote_code=True,
            # NOTE(review): this silently re-initialises any weights whose shape
            # does not match the config (e.g. the classification head), which
            # would make predictions meaningless; confirm it is required.
            ignore_mismatched_sizes=True,
        )
        model.eval()  # inference mode (disables dropout / training-only layers)
        registry.register("ladi", {"model": model, "processor": processor})
        print("✅ LADI-v2 ready")
        return registry.get("ladi")
    except Exception as e:
        # Best-effort by design: /classify reports 503 with this error instead.
        print(f"❌ LADI-v2 load failed:\n{traceback.format_exc()}")
        registry.set_error("ladi", e)
        return None
def _load_yolo_from_hub(key, repo_id, env_name, *, loading_msg, ready_msg, failure_msg):
    """Load (once) a YOLO ``best.pt`` checkpoint from a Hugging Face Hub repo.

    Shared implementation behind ``load_victim_model`` and ``load_xview2_model``.

    Args:
        key: registry name the model is cached under.
        repo_id: Hub repository containing ``best.pt``; falsy means "not configured".
        env_name: environment variable / Space secret that supplies *repo_id*;
            used in the "not set" error message.
        loading_msg, ready_msg, failure_msg: log lines specific to this model.

    Returns:
        The ``ultralytics.YOLO`` model, or None on failure. On failure the error
        is recorded via ``registry.set_error``. Failures are not cached, so the
        next call retries.
    """
    if registry.is_loaded(key):
        return registry.get(key)
    if not repo_id:
        registry.set_error(key, f"{env_name} secret not set")
        return None
    try:
        # Imported lazily: a missing ultralytics install only disables this model.
        from ultralytics import YOLO

        print(loading_msg)
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="best.pt",
            cache_dir=MODEL_CACHE_DIR,
            token=HF_TOKEN,  # needed only if the repo is private
        )
        model = YOLO(model_path)
        registry.register(key, model)
        print(ready_msg)
        return model
    except Exception as e:
        print(f"{failure_msg}:\n{traceback.format_exc()}")
        registry.set_error(key, e)
        return None


def load_victim_model():
    """Load the YOLOv8 victim detection model from the Hub (registry key "victim").

    Returns the model, or None if it is not configured or failed to load.
    """
    return _load_yolo_from_hub(
        "victim",
        HF_VICTIM_MODEL_REPO,
        "HF_VICTIM_MODEL_REPO",
        loading_msg=f"⬇️ Loading victim model from {HF_VICTIM_MODEL_REPO} ...",
        ready_msg="✅ Victim detection model ready",
        failure_msg="❌ Victim model load failed",
    )


def load_xview2_model():
    """Load the xView2 building-damage YOLOv8 model from the Hub (registry key "xview2").

    Returns the model, or None if it is not configured or failed to load.
    """
    return _load_yolo_from_hub(
        "xview2",
        HF_XVIEW2_MODEL_REPO,
        "HF_XVIEW2_MODEL_REPO",
        loading_msg=f"⬇️ Loading xView2 model from {HF_XVIEW2_MODEL_REPO} ...",
        ready_msg="✅ xView2 damage model ready",
        failure_msg="❌ xView2 model load failed",
    )
# ════════════════════════════════
# Startup
# ════════════════════════════════
# NOTE(review): newer FastAPI releases deprecate on_event in favour of a
# lifespan handler passed to FastAPI(...); consider migrating.
@app.on_event("startup")
async def startup_event():
    """Eagerly attempt to load every configured model at server start.

    Loads run synchronously inside this coroutine. Failures are recorded in
    ``registry``; the route handlers call the loaders again, which retries
    them lazily on first use.
    """
    banner = "=" * 55
    print("\n" + banner)
    # Version comes from the FastAPI app so the banner cannot drift from it.
    print(f"🚀 Disaster AI API v{app.version} starting up...")
    print(banner)

    # LADI-v2 is a public model, so it is always attempted.
    load_ladi_model()

    # The YOLO models need their Hub repo configured (env var / Space secret).
    if HF_VICTIM_MODEL_REPO:
        load_victim_model()
    else:
        print("⚠️ Victim model skipped — HF_VICTIM_MODEL_REPO not set")

    if HF_XVIEW2_MODEL_REPO:
        load_xview2_model()
    else:
        print("⚠️ xView2 model skipped — HF_XVIEW2_MODEL_REPO not set")

    print(banner)
    print(f"📊 Registry: {registry.status()}")
    print(banner + "\n")
# ════════════════════════════════
# Utilities
# ════════════════════════════════
def read_image(file_bytes: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a 3-channel OpenCV (BGR) array.

    Raises:
        HTTPException(400): if the bytes are empty or cannot be decoded as an
            image, so a bad upload is a client error rather than a 500.
    """
    invalid = HTTPException(status_code=400, detail="Invalid image — cannot decode")
    if not file_bytes:
        # cv2.imdecode asserts on an empty buffer (raises cv2.error) instead
        # of returning None, so reject empty uploads up front.
        raise invalid
    buffer = np.frombuffer(file_bytes, np.uint8)
    try:
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    except cv2.error:
        img = None
    if img is None:
        raise invalid
    return img
def call_roboflow(image: np.ndarray, model_id: str, confidence: int = 40) -> list:
    """Run *image* through a Roboflow hosted detection model.

    Args:
        image: OpenCV image array; it is JPEG-encoded (quality 80) before upload.
        model_id: Roboflow "project/version" id, e.g. "ambulance-4bova/1".
        confidence: minimum confidence, forwarded as Roboflow's ``confidence``
            query parameter (Roboflow interprets it as a percentage).

    Returns:
        A list of ``{"class", "confidence", "box": {xmin, ymin, xmax, ymax}}``
        dicts. Roboflow reports boxes as centre + width/height; they are
        converted to integer corners here. Best-effort: returns ``[]`` when no
        API key is configured or on any error (the error is logged).
    """
    if not ROBOFLOW_API_KEY:
        return []
    try:
        ok, buffer = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, 80])
        if not ok:
            print(f"Roboflow error ({model_id}): JPEG encoding failed")
            return []
        res = requests.post(
            f"https://detect.roboflow.com/{model_id}",
            params={"api_key": ROBOFLOW_API_KEY, "confidence": confidence},
            # Roboflow's hosted API takes the base64 image as the raw body.
            data=base64.b64encode(buffer),
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=8,
        )
        res.raise_for_status()
        preds = res.json().get("predictions", [])
        return [
            {
                "class": p["class"],
                "confidence": round(p["confidence"], 4),
                "box": {
                    "xmin": int(p["x"] - p["width"] / 2),
                    "ymin": int(p["y"] - p["height"] / 2),
                    "xmax": int(p["x"] + p["width"] / 2),
                    "ymax": int(p["y"] + p["height"] / 2),
                },
            }
            for p in preds
        ]
    except Exception as e:
        # Best-effort by design, but never log the API key: requests' HTTPError
        # message embeds the full request URL, query string included.
        message = str(e).replace(ROBOFLOW_API_KEY, "***")
        print(f"Roboflow error ({model_id}): {message}")
        return []
# (minimum priority score, rank) pairs, checked from most to least urgent;
# anything below the last floor is "LOW".
_TRIAGE_RANK_THRESHOLDS = (
    (0.7, "CRITICAL"),
    (0.4, "HIGH"),
    (0.2, "MODERATE"),
)


def compute_triage(detections: list, *, class_priority=None) -> dict:
    """Score and rank victim detections for rescue triage.

    Each detection's ``priority_score`` is ``confidence * weight`` rounded to 4
    decimals. The weight comes from *class_priority*, which defaults to the
    module-level ``CLASS_PRIORITY``. Unknown classes weigh 0.5, and a missing
    confidence counts as 0.5. Scores map to ranks: >= 0.7 CRITICAL,
    >= 0.4 HIGH, >= 0.2 MODERATE, otherwise LOW.

    Args:
        detections: dicts with (at least) ``"class"`` and ``"confidence"``.
            They are not mutated; annotated copies are returned.
        class_priority: optional ``{class_name: weight}`` override.

    Returns:
        A summary dict containing:
        - ``total`` and per-rank counts;
        - ``highest_score``;
        - a recommended ``action`` string;
        - ``ranked_victims``: copies of the detections with
          ``priority_score``/``priority_rank`` added, highest score first.
    """
    if not detections:
        return {
            "total": 0, "critical": 0, "high": 0,
            "moderate": 0, "low": 0,
            "highest_score": 0.0,
            "action": "No victims detected",
            "ranked_victims": [],
        }

    weights = CLASS_PRIORITY if class_priority is None else class_priority

    scored = []
    for d in detections:
        conf = d.get("confidence", 0.5)
        weight = weights.get(d.get("class", ""), 0.5)
        score = round(conf * weight, 4)
        rank = next(
            (name for floor, name in _TRIAGE_RANK_THRESHOLDS if score >= floor),
            "LOW",
        )
        scored.append({**d, "priority_score": score, "priority_rank": rank})
    # Stable sort: equal scores keep their detection order.
    scored.sort(key=lambda v: v["priority_score"], reverse=True)

    counts = {"CRITICAL": 0, "HIGH": 0, "MODERATE": 0, "LOW": 0}
    for v in scored:
        counts[v["priority_rank"]] += 1

    # The most urgent rank present determines the recommended action.
    action = (
        "IMMEDIATE RESCUE - Critical victims present" if counts["CRITICAL"] else
        "Deploy rescue team - High priority victims" if counts["HIGH"] else
        "Assess and triage - Moderate victims present" if counts["MODERATE"] else
        "Low priority - Monitor the area"
    )
    return {
        "total": len(scored),
        "critical": counts["CRITICAL"],
        "high": counts["HIGH"],
        "moderate": counts["MODERATE"],
        "low": counts["LOW"],
        "highest_score": scored[0]["priority_score"],
        "action": action,
        "ranked_victims": scored,
    }
def compute_zone_color(triage_data: dict, damage_counts: dict, top_class: str) -> str:
    """Fuse victim triage, building damage and scene class into one zone color.

    Precedence is red > orange > yellow > green; the first tier that holds wins:

    * red    - any CRITICAL victim, any destroyed building, or a scene label
               containing "destroy", "collapse" or "major";
    * orange - any HIGH victim, any major_damage building, or a scene label
               containing "minor_damage";
    * yellow - any victim at all, or any minor_damage building;
    * green  - none of the above.

    Keys missing from *triage_data* / *damage_counts* count as 0.
    """
    critical = triage_data.get("critical", 0)
    high = triage_data.get("high", 0)
    destroyed = damage_counts.get("destroyed", 0)
    major = damage_counts.get("major_damage", 0)
    minor = damage_counts.get("minor_damage", 0)
    victims = triage_data.get("total", 0)

    # NOTE(review): a scene label containing "major" lands in red, while a
    # major_damage building count only reaches orange — confirm intended.
    scene_is_severe = any(word in top_class for word in ["destroy", "collapse", "major"])
    scene_is_moderate = "minor_damage" in top_class

    # Each tier's test is evaluated lazily, only if every higher tier failed.
    tiers = (
        ("red", lambda: critical > 0 or destroyed > 0 or scene_is_severe),
        ("orange", lambda: high > 0 or major > 0 or scene_is_moderate),
        ("yellow", lambda: victims > 0 or minor > 0),
    )
    for color, holds in tiers:
        if holds():
            return color
    return "green"
# ════════════════════════════════
# Routes
# ════════════════════════════════
@app.get("/")
def root():
    """Describe the service: name, version, model-registry status and routes."""
    return {
        # Read from the FastAPI app so these cannot drift from the OpenAPI metadata.
        "service": app.title,
        "version": app.version,
        "status": registry.status(),
        "endpoints": {
            "GET /health": "Health check + model status",
            "POST /classify": "LADI-v2 disaster scene classification",
            "POST /detect/victims": "Victim detection + triage priority",
            "POST /detect/vehicles": "Emergency vehicle detection (Roboflow)",
            "POST /detect/damage": "xView2 building damage assessment",
            "POST /analyze/full": "All models in one call (main endpoint)",
        },
    }
@app.get("/health")
def health():
    """Liveness probe with model-registry status and which secrets are configured.

    Only the presence of each secret is reported, never its value.
    """
    return {
        "status": "ok",
        "registry": registry.status(),
        "gpu_available": torch.cuda.is_available(),
        "secrets_set": {
            # bool() like the other entries: an empty HF_TOKEN is not "set".
            "HF_TOKEN": bool(HF_TOKEN),
            "HF_VICTIM_MODEL_REPO": bool(HF_VICTIM_MODEL_REPO),
            "HF_XVIEW2_MODEL_REPO": bool(HF_XVIEW2_MODEL_REPO),
            "ROBOFLOW_API_KEY": bool(ROBOFLOW_API_KEY),
        },
        "timestamp": time.time(),
    }
# ─────────────────────────────────────────────
# 1. LADI-v2 Scene Classification
# ─────────────────────────────────────────────
@app.post("/classify")
async def classify_scene(
    file: UploadFile = File(...),
    top_k: int = 5,
):
    """
    Classify disaster scene using LADI-v2 (aerial damage categories).
    Returns top-k predicted labels with confidence scores.

    Labels are lower-cased with spaces replaced by "_". ``relevant_only``
    drops labels mentioning "water" or "flood". A negative ``top_k`` is
    treated as 0.

    Raises 503 if the model is unavailable, 400 for undecodable images and
    500 if inference fails.
    """
    ladi = load_ladi_model()
    if ladi is None:
        raise HTTPException(
            status_code=503,
            detail=f"LADI-v2 unavailable: {registry.get_error('ladi')}"
        )
    contents = await file.read()
    try:
        img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    # A negative top_k would otherwise slice from the end of the list.
    top_k = max(top_k, 0)

    model = ladi["model"]
    processor = ladi["processor"]
    t0 = time.time()
    try:
        inputs = processor(images=img_pil, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits[0]  # single image -> first batch row
        # Multi-label checkpoints score each label independently (sigmoid/BCE).
        # A softmax would make labels compete and understate every confidence.
        # Both functions preserve the ordering of the logits.
        if getattr(model.config, "problem_type", None) == "multi_label_classification":
            probs = torch.sigmoid(logits)
        else:
            probs = torch.softmax(logits, dim=-1)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    id2label = model.config.id2label
    all_scores = sorted(
        (
            {
                "class": id2label[i].lower().replace(" ", "_"),
                "confidence": round(p, 4),
            }
            for i, p in enumerate(probs.tolist())
        ),
        key=lambda s: s["confidence"],
        reverse=True,
    )
    # NOTE(review): water/flood labels are excluded here, presumably because
    # they are not actionable for the rover — confirm.
    relevant = [
        s for s in all_scores
        if not any(w in s["class"] for w in ("water", "flood"))
    ]
    return {
        "top_predictions": all_scores[:top_k],
        "relevant_only": relevant[:top_k],
        "all_scores": all_scores,
        "inference_time_ms": elapsed,
    }
# ─────────────────────────────────────────────
# 2. Victim Detection + Triage
# ─────────────────────────────────────────────
@app.post("/detect/victims")
async def detect_victims(
    file: UploadFile = File(...),
    confidence: float = 0.35,
):
    """
    Detect victims and rank by triage priority.
    Classes: injured_civilian, trapped_civilian, safe_civilian, rescue_personnel.
    Priority ranks: CRITICAL / HIGH / MODERATE / LOW

    ``confidence`` is the YOLO detection threshold and must lie in [0, 1]
    (400 otherwise). Class ids missing from ``TARGET_CLASSES`` are reported
    as "unknown". Raises 503 if the model is unavailable and 500 if
    inference fails.
    """
    if not 0.0 <= confidence <= 1.0:  # also rejects NaN
        raise HTTPException(status_code=400, detail="confidence must be between 0 and 1")
    model = load_victim_model()
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"Victim model unavailable: {registry.get_error('victim')}"
        )
    contents = await file.read()
    img = read_image(contents)
    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop for its duration.
    t0 = time.time()
    try:
        results = model.predict(source=img, conf=confidence, verbose=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    raw = []
    for r in results:
        boxes = r.boxes
        if boxes is None:  # results without detection boxes
            continue
        for (x1, y1, x2, y2), conf_val, cls_val in zip(
            boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
        ):
            cls_id = int(cls_val)
            raw.append({
                "class": TARGET_CLASSES.get(cls_id, "unknown"),
                "class_id": cls_id,
                "confidence": round(conf_val, 4),
                "box": {"xmin": int(x1), "ymin": int(y1), "xmax": int(x2), "ymax": int(y2)},
            })

    triage = compute_triage(raw)
    # The ranked list becomes "detections"; the rest is the summary.
    victims = triage.pop("ranked_victims", raw)
    return {
        "detections": victims,
        "triage_summary": triage,
        "inference_time_ms": elapsed,
    }
# ─────────────────────────────────────────────
# 3. Emergency Vehicle Detection (Roboflow)
# ─────────────────────────────────────────────
@app.post("/detect/vehicles")
async def detect_vehicles(file: UploadFile = File(...)):
    """
    Detect emergency vehicles via Roboflow hosted model.
    Returns ambulance / fire truck detections and rescue_arrived flag.

    Vehicle types are inferred by a case-insensitive substring match on the
    predicted class name ("ambulance", "fire"). Raises 503 when
    ROBOFLOW_API_KEY is not configured.
    """
    if not ROBOFLOW_API_KEY:
        raise HTTPException(status_code=503, detail="ROBOFLOW_API_KEY secret not set")

    image = read_image(await file.read())

    started = time.time()
    # NOTE(review): call_roboflow returns [] on any error, so a failed call is
    # indistinguishable from "no vehicles found" in this response.
    found = call_roboflow(image, "ambulance-4bova/1", confidence=40)
    elapsed_ms = round((time.time() - started) * 1000, 2)

    def mentions(keyword):
        return any(keyword in det["class"].lower() for det in found)

    ambulance = mentions("ambulance")
    fire_truck = mentions("fire")
    return {
        "detections": found,
        "emergency_vehicles": {
            "ambulance_detected": ambulance,
            "fire_truck_detected": fire_truck,
            "rescue_arrived": ambulance or fire_truck,
        },
        "inference_time_ms": elapsed_ms,
    }
# ─────────────────────────────────────────────
# 4. xView2 Building Damage Assessment
# ─────────────────────────────────────────────
@app.post("/detect/damage")
async def detect_building_damage(
    file: UploadFile = File(...),
    confidence: float = 0.30,
):
    """
    Assess building damage using xView2-trained YOLOv8.
    Classes: destroyed, major_damage, minor_damage, no_damage.
    Returns per-building detections, counts, and zone color.

    Each raw class name is lower-cased with spaces replaced by "_", which is
    reported as "class". It is then mapped to a severity key by substring
    match, with "-" also treated as "_", so xView2-style "major-damage"
    matches. Unmatched names count as "no_damage".

    ``confidence`` must lie in [0, 1] (400 otherwise). Raises 503 if the
    model is unavailable and 500 if inference fails.
    """
    if not 0.0 <= confidence <= 1.0:  # also rejects NaN
        raise HTTPException(status_code=400, detail="confidence must be between 0 and 1")
    model = load_xview2_model()
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"xView2 model unavailable: {registry.get_error('xview2')}"
        )
    contents = await file.read()
    img = read_image(contents)
    t0 = time.time()
    try:
        results = model.predict(source=img, conf=confidence, verbose=False)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    elapsed = round((time.time() - t0) * 1000, 2)

    detections = []
    # Insertion order matters: severity matching checks keys in this order.
    counts = {"destroyed": 0, "major_damage": 0, "minor_damage": 0, "no_damage": 0}
    for r in results:
        boxes = r.boxes
        if boxes is None:  # results without detection boxes
            continue
        for (x1, y1, x2, y2), conf_val, cls_val in zip(
            boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
        ):
            class_name = model.names[int(cls_val)].lower().replace(" ", "_")
            # Without the hyphen normalisation, "major-damage"/"minor-damage"
            # would fall through to "no_damage".
            match_name = class_name.replace("-", "_")
            # NOTE(review): labels such as xView2's "un-classified" also land in
            # "no_damage" — confirm that is acceptable.
            severity = next((k for k in counts if k in match_name), "no_damage")
            counts[severity] += 1
            detections.append({
                "class": class_name,
                "confidence": round(conf_val, 4),
                "severity": severity,
                "box": {"xmin": int(x1), "ymin": int(y1), "xmax": int(x2), "ymax": int(y2)},
            })

    # Sort: destroyed first, no_damage last
    detections.sort(key=lambda d: DAMAGE_SEVERITY_ORDER.get(d["severity"], 9))

    # Damage-only zone color, from the same rules as /analyze/full
    # (no victim or scene signals).
    zone_color = compute_zone_color({}, counts, "")

    return {
        "detections": detections,
        "summary": counts,
        "total_buildings": sum(counts.values()),
        "zone_color": zone_color,
        "inference_time_ms": elapsed,
    }
# ─────────────────────────────────────────────
# 5. Full Analysis — all models in one call
# ─────────────────────────────────────────────
async def _run_analysis_step(endpoint, contents):
    """Call one endpoint coroutine on a fresh in-memory upload of *contents*.

    Returns the endpoint's response dict, or ``{"error": ...}`` if it raised,
    so one failing model never aborts the combined analysis.
    """
    try:
        return await endpoint(UploadFile(filename="f.jpg", file=io.BytesIO(contents)))
    except HTTPException as e:
        return {"error": e.detail}
    except Exception as e:
        return {"error": str(e)}


@app.post("/analyze/full")
async def full_analysis(
    file: UploadFile = File(...),
    run_classify: bool = True,
    run_victims: bool = True,
    run_vehicles: bool = True,
    run_damage: bool = True,
):
    """
    Run all available models on one image.
    Main endpoint for the RoboXavier rover Flask app.
    Returns unified zone_color, all detections, and triage/damage summaries.

    Each enabled step's result (or an ``{"error": ...}`` dict) is stored
    under "classification", "victims", "vehicles" or "building_damage" in
    ``results``. Failed or disabled steps contribute nothing to the zone color.
    """
    contents = await file.read()
    t_total = time.time()

    # (result key, enabled flag, endpoint), executed sequentially in this order.
    # NOTE(review): each endpoint runs model inference synchronously on the
    # event loop, so this request blocks other requests until it finishes.
    steps = (
        ("classification", run_classify, classify_scene),
        ("victims", run_victims, detect_victims),
        ("vehicles", run_vehicles, detect_vehicles),
        ("building_damage", run_damage, detect_building_damage),
    )
    output = {}
    for key, enabled, endpoint in steps:
        if enabled:
            output[key] = await _run_analysis_step(endpoint, contents)

    # ── Unified zone color (all signals combined) ──
    # Missing or errored steps yield empty data, which counts as zero / "".
    triage_data = output.get("victims", {}).get("triage_summary", {})
    damage_counts = output.get("building_damage", {}).get("summary", {})
    classify_top = output.get("classification", {}).get("top_predictions", [{}])
    top_class = classify_top[0].get("class", "") if classify_top else ""
    zone_color = compute_zone_color(triage_data, damage_counts, top_class)

    return {
        "zone_color": zone_color,
        "results": output,
        "total_time_ms": round((time.time() - t_total) * 1000, 2),
        "timestamp": time.time(),
    }
# ════════════════════════════════
# Entry Point
# ════════════════════════════════
if __name__ == "__main__":
    import uvicorn

    # 7860 is the Hugging Face Spaces default app port; PORT overrides it
    # (e.g. for local runs).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))