AbhinavGupta commited on
Commit
325633e
Β·
verified Β·
1 Parent(s): 7f92c1d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +487 -0
app.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ =================================================================
3
+ DISASTER AI β€” HuggingFace Spaces API (Permanent Free Server)
4
+ =================================================================
5
+ Deploy this to HuggingFace Spaces for a permanent, always-on
6
+ API that your friends can call without knowing any AI.
7
+
8
+ Setup:
9
+ 1. Go to https://huggingface.co/spaces
10
+ 2. Create New Space β†’ SDK: Docker (or Gradio)
11
+ 3. Repo name: EgoisticCoderX/dokai-inference-api
12
+ 4. Upload this file as app.py
13
+ 5. Upload requirements.txt
14
+ 6. Add secrets in Space Settings:
15
+ - HF_VICTIM_MODEL_REPO = EgoisticCoderX/dokai-victim-detection
16
+ - ROBOFLOW_API_KEY = your_key
17
+
18
+ Your API will be at:
19
+ https://egoisticcoderx-dokai-inference-api.hf.space/
20
+
21
+ ZeroGPU Note: HF Spaces has ZeroGPU (A10G, free) for inference.
22
+ Add @spaces.GPU decorator for GPU-accelerated endpoints.
23
+ =================================================================
24
+ """
25
+
26
+ import os
27
+ import io
28
+ import json
29
+ import time
30
+ import base64
31
+ import threading
32
+ import numpy as np
33
+ from pathlib import Path
34
+ from PIL import Image
35
+ import cv2
36
+ import torch
37
+ import requests
38
+
39
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
40
+ from fastapi.middleware.cors import CORSMiddleware
41
+ from fastapi.responses import JSONResponse
42
+ from huggingface_hub import hf_hub_download
43
+
44
+ # ── ZeroGPU support (HuggingFace Spaces) ──
45
+ try:
46
+ import spaces
47
+ HAS_ZERO_GPU = True
48
+ print("βœ… ZeroGPU available")
49
+ except ImportError:
50
+ HAS_ZERO_GPU = False
51
+ print("ℹ️ ZeroGPU not available (running locally or non-GPU space)")
52
+
53
+
54
+ # ════════════════════════════════
55
+ # App Setup
56
+ # ════════════════════════════════
57
+ app = FastAPI(
58
+ title="Disaster AI Inference API",
59
+ description="Multi-model disaster scene analysis API for Dokai/RoboXavier",
60
+ version="1.0.0",
61
+ )
62
+
63
+ app.add_middleware(
64
+ CORSMiddleware,
65
+ allow_origins=["*"],
66
+ allow_methods=["*"],
67
+ allow_headers=["*"],
68
+ )
69
+
70
+ # ════════════════════════════════
71
+ # Configuration
72
+ # ════════════════════════════════
73
+ HF_VICTIM_MODEL_REPO = os.getenv("HF_VICTIM_MODEL_REPO", "EgoisticCoderX/dokai-victim-detection")
74
+ ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "")
75
+ MODEL_CACHE_DIR = "/tmp/model_cache"
76
+ os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
77
+
78
+ TARGET_CLASSES = {
79
+ 0: "injured_civilian",
80
+ 1: "trapped_civilian",
81
+ 2: "safe_civilian",
82
+ 3: "rescue_personnel",
83
+ }
84
+
85
+ CLASS_PRIORITY = {
86
+ "injured_civilian": 1.0,
87
+ "trapped_civilian": 0.95,
88
+ "safe_civilian": 0.3,
89
+ "rescue_personnel": 0.0,
90
+ }
91
+
92
+
93
+ # ════════════════════════════════
94
+ # Model Registry β€” lazy loading
95
+ # ════════════════════════════════
96
+ class ModelRegistry:
97
+ """
98
+ Lazy-loads models on first request.
99
+ Prevents OOM on Space startup.
100
+ """
101
+ def __init__(self):
102
+ self._models = {}
103
+ self._lock = threading.Lock()
104
+ self._loading = set()
105
+
106
+ def get(self, name: str):
107
+ if name in self._models:
108
+ return self._models[name]
109
+ return None
110
+
111
+ def register(self, name: str, model):
112
+ with self._lock:
113
+ self._models[name] = model
114
+ print(f"βœ… Model registered: {name}")
115
+
116
+ def is_loaded(self, name: str) -> bool:
117
+ return name in self._models
118
+
119
+
120
+ registry = ModelRegistry()
121
+
122
+
123
+ def load_victim_model():
124
+ """Download + load YOLOv8 victim detection model from HuggingFace."""
125
+ if registry.is_loaded("victim"):
126
+ return registry.get("victim")
127
+
128
+ try:
129
+ from ultralytics import YOLO
130
+ model_path = hf_hub_download(
131
+ repo_id=HF_VICTIM_MODEL_REPO,
132
+ filename="best.pt",
133
+ cache_dir=MODEL_CACHE_DIR,
134
+ )
135
+ model = YOLO(model_path)
136
+ registry.register("victim", model)
137
+ return model
138
+ except Exception as e:
139
+ print(f"❌ Failed to load victim model: {e}")
140
+ return None
141
+
142
+
143
+ def load_ladi_model():
144
+ """Load LADI-v2 classifier from HuggingFace."""
145
+ if registry.is_loaded("ladi"):
146
+ return registry.get("ladi")
147
+
148
+ try:
149
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
150
+ processor = AutoImageProcessor.from_pretrained(
151
+ "MITLL/LADI-v2-classifier-small",
152
+ cache_dir=MODEL_CACHE_DIR,
153
+ )
154
+ model = AutoModelForImageClassification.from_pretrained(
155
+ "MITLL/LADI-v2-classifier-small",
156
+ cache_dir=MODEL_CACHE_DIR,
157
+ )
158
+ model.eval()
159
+ registry.register("ladi", {"model": model, "processor": processor})
160
+ return registry.get("ladi")
161
+ except Exception as e:
162
+ print(f"❌ Failed to load LADI model: {e}")
163
+ return None
164
+
165
+
166
+ # ════════════════════════════════
167
+ # Utility Functions
168
+ # ════════════════════════════════
169
+ def read_image_from_upload(file_bytes: bytes) -> np.ndarray:
170
+ nparr = np.frombuffer(file_bytes, np.uint8)
171
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
172
+ if img is None:
173
+ raise HTTPException(status_code=400, detail="Invalid image β€” cannot decode")
174
+ return img
175
+
176
+
177
+ def call_roboflow(image: np.ndarray, model_id: str, confidence: int = 40) -> list:
178
+ """Call Roboflow hosted model (ambulance, vest detection)."""
179
+ if not ROBOFLOW_API_KEY:
180
+ return []
181
+ try:
182
+ _, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 80])
183
+ img_b64 = base64.b64encode(buffer)
184
+ url = f"https://detect.roboflow.com/{model_id}?api_key={ROBOFLOW_API_KEY}&confidence={confidence}"
185
+ res = requests.post(
186
+ url,
187
+ data=img_b64,
188
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
189
+ timeout=8,
190
+ )
191
+ res.raise_for_status()
192
+ preds = res.json().get("predictions", [])
193
+ return [
194
+ {
195
+ "class": p["class"],
196
+ "confidence": round(p["confidence"], 4),
197
+ "box": {
198
+ "xmin": int(p["x"] - p["width"] / 2),
199
+ "ymin": int(p["y"] - p["height"] / 2),
200
+ "xmax": int(p["x"] + p["width"] / 2),
201
+ "ymax": int(p["y"] + p["height"] / 2),
202
+ },
203
+ }
204
+ for p in preds
205
+ ]
206
+ except Exception as e:
207
+ print(f"Roboflow error ({model_id}): {e}")
208
+ return []
209
+
210
+
211
+ def compute_triage(detections: list) -> dict:
212
+ """Compute triage priority summary from victim detections."""
213
+ if not detections:
214
+ return {
215
+ "total": 0,
216
+ "critical": 0,
217
+ "high": 0,
218
+ "moderate": 0,
219
+ "low": 0,
220
+ "highest_score": 0.0,
221
+ "action": "βœ… No victims detected",
222
+ }
223
+
224
+ scored = []
225
+ for d in detections:
226
+ cls_name = d.get("class", "")
227
+ conf = d.get("confidence", 0.5)
228
+ weight = CLASS_PRIORITY.get(cls_name, 0.5)
229
+ score = conf * weight
230
+ rank = (
231
+ "CRITICAL" if score >= 0.7
232
+ else "HIGH" if score >= 0.4
233
+ else "MODERATE" if score >= 0.2
234
+ else "LOW"
235
+ )
236
+ scored.append({**d, "priority_score": round(score, 4), "priority_rank": rank})
237
+
238
+ scored.sort(key=lambda x: x["priority_score"], reverse=True)
239
+
240
+ critical = sum(1 for d in scored if d["priority_rank"] == "CRITICAL")
241
+ high = sum(1 for d in scored if d["priority_rank"] == "HIGH")
242
+ moderate = sum(1 for d in scored if d["priority_rank"] == "MODERATE")
243
+ low = sum(1 for d in scored if d["priority_rank"] == "LOW")
244
+
245
+ action = (
246
+ "⚠️ IMMEDIATE RESCUE β€” Critical victims present" if critical
247
+ else "πŸ”΄ Deploy rescue team β€” High priority victims" if high
248
+ else "🟑 Assess and triage β€” Moderate victims present" if moderate
249
+ else "🟒 Low priority β€” Monitor the area"
250
+ )
251
+
252
+ return {
253
+ "total": len(scored),
254
+ "critical": critical,
255
+ "high": high,
256
+ "moderate": moderate,
257
+ "low": low,
258
+ "highest_score": scored[0]["priority_score"] if scored else 0.0,
259
+ "action": action,
260
+ "ranked_victims": scored,
261
+ }
262
+
263
+
264
+ # ════════════════════════════════
265
+ # Routes
266
+ # ════════════════════════════════
267
+
268
+ @app.get("/")
269
+ def root():
270
+ return {
271
+ "service": "Disaster AI Inference API",
272
+ "version": "1.0.0",
273
+ "endpoints": {
274
+ "/health": "GET β€” Service health check",
275
+ "/detect/victims": "POST β€” Victim detection + triage",
276
+ "/detect/hazards": "POST β€” Fire, smoke, building damage",
277
+ "/detect/vehicles": "POST β€” Emergency vehicle detection",
278
+ "/classify": "POST β€” LADI-v2 scene classification",
279
+ "/analyze/full": "POST β€” All models in parallel (full analysis)",
280
+ }
281
+ }
282
+
283
+
284
+ @app.get("/health")
285
+ def health():
286
+ return {
287
+ "status": "ok",
288
+ "models_loaded": {
289
+ "victim_detection": registry.is_loaded("victim"),
290
+ "ladi_classifier": registry.is_loaded("ladi"),
291
+ },
292
+ "gpu_available": torch.cuda.is_available(),
293
+ "timestamp": time.time(),
294
+ }
295
+
296
+
297
+ @app.post("/detect/victims")
298
+ async def detect_victims(
299
+ file: UploadFile = File(...),
300
+ confidence: float = 0.35,
301
+ ):
302
+ """
303
+ Detect victims and classify by triage priority.
304
+ Returns detections with priority scores and recommended action.
305
+ """
306
+ contents = await file.read()
307
+ img = read_image_from_upload(contents)
308
+
309
+ model = load_victim_model()
310
+ if model is None:
311
+ raise HTTPException(status_code=503, detail="Victim detection model unavailable")
312
+
313
+ t0 = time.time()
314
+ results = model.predict(source=img, conf=confidence, verbose=False)
315
+ elapsed = round((time.time() - t0) * 1000, 2)
316
+
317
+ raw_detections = []
318
+ for r in results:
319
+ for box in r.boxes:
320
+ x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
321
+ conf_val = float(box.conf[0])
322
+ cls_id = int(box.cls[0])
323
+ raw_detections.append({
324
+ "class": TARGET_CLASSES.get(cls_id, "unknown"),
325
+ "class_id": cls_id,
326
+ "confidence": round(conf_val, 4),
327
+ "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
328
+ })
329
+
330
+ triage = compute_triage(raw_detections)
331
+
332
+ return {
333
+ "detections": triage.pop("ranked_victims", raw_detections),
334
+ "triage_summary": triage,
335
+ "inference_time_ms": elapsed,
336
+ }
337
+
338
+
339
+ @app.post("/detect/vehicles")
340
+ async def detect_vehicles(file: UploadFile = File(...)):
341
+ """Detect emergency vehicles using Roboflow model."""
342
+ contents = await file.read()
343
+ img = read_image_from_upload(contents)
344
+
345
+ t0 = time.time()
346
+ detections = call_roboflow(img, "ambulance-4bova/1", confidence=40)
347
+ elapsed = round((time.time() - t0) * 1000, 2)
348
+
349
+ has_ambulance = any("ambulance" in d["class"].lower() for d in detections)
350
+ has_fire_truck = any("fire" in d["class"].lower() for d in detections)
351
+
352
+ return {
353
+ "detections": detections,
354
+ "emergency_vehicles": {
355
+ "ambulance_detected": has_ambulance,
356
+ "fire_truck_detected": has_fire_truck,
357
+ "rescue_arrived": has_ambulance or has_fire_truck,
358
+ },
359
+ "inference_time_ms": elapsed,
360
+ }
361
+
362
+
363
+ @app.post("/classify")
364
+ async def classify_scene(file: UploadFile = File(...), top_k: int = 5):
365
+ """Classify disaster scene using LADI-v2."""
366
+ contents = await file.read()
367
+ img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
368
+
369
+ ladi = load_ladi_model()
370
+ if ladi is None:
371
+ raise HTTPException(status_code=503, detail="LADI-v2 model unavailable")
372
+
373
+ model = ladi["model"]
374
+ processor = ladi["processor"]
375
+
376
+ t0 = time.time()
377
+ inputs = processor(images=img_pil, return_tensors="pt")
378
+ with torch.no_grad():
379
+ outputs = model(**inputs)
380
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
381
+ elapsed = round((time.time() - t0) * 1000, 2)
382
+
383
+ id2label = model.config.id2label
384
+ all_scores = sorted(
385
+ [
386
+ {
387
+ "class": id2label[i].lower().replace(" ", "_"),
388
+ "confidence": round(float(probs[i]), 4),
389
+ }
390
+ for i in range(len(probs))
391
+ ],
392
+ key=lambda x: x["confidence"],
393
+ reverse=True,
394
+ )
395
+
396
+ return {
397
+ "top_predictions": all_scores[:top_k],
398
+ "all_scores": all_scores,
399
+ "inference_time_ms": elapsed,
400
+ }
401
+
402
+
403
+ @app.post("/analyze/full")
404
+ async def full_analysis(
405
+ file: UploadFile = File(...),
406
+ run_victims: bool = True,
407
+ run_vehicles: bool = True,
408
+ run_classify: bool = True,
409
+ ):
410
+ """
411
+ Run all available models in parallel on one image.
412
+ This is what your rover Flask app should call for full scene analysis.
413
+
414
+ Returns a unified JSON with all detections + zone scoring.
415
+ """
416
+ import asyncio
417
+ from concurrent.futures import ThreadPoolExecutor
418
+
419
+ contents = await file.read()
420
+ t_total = time.time()
421
+
422
+ results = {}
423
+
424
+ with ThreadPoolExecutor(max_workers=3) as executor:
425
+ futures = {}
426
+
427
+ if run_victims:
428
+ async def _victims():
429
+ f = UploadFile(filename="frame.jpg", file=io.BytesIO(contents))
430
+ return await detect_victims(f)
431
+ futures["victims"] = asyncio.ensure_future(_victims())
432
+
433
+ if run_vehicles:
434
+ async def _vehicles():
435
+ f = UploadFile(filename="frame.jpg", file=io.BytesIO(contents))
436
+ return await detect_vehicles(f)
437
+ futures["vehicles"] = asyncio.ensure_future(_vehicles())
438
+
439
+ if run_classify:
440
+ async def _classify():
441
+ f = UploadFile(filename="frame.jpg", file=io.BytesIO(contents))
442
+ return await classify_scene(f)
443
+ futures["classification"] = asyncio.ensure_future(_classify())
444
+
445
+ for key, fut in futures.items():
446
+ try:
447
+ results[key] = await fut
448
+ except Exception as e:
449
+ results[key] = {"error": str(e)}
450
+
451
+ total_ms = round((time.time() - t_total) * 1000, 2)
452
+
453
+ # Determine overall zone color
454
+ victim_data = results.get("victims", {})
455
+ triage = victim_data.get("triage_summary", {})
456
+ classify_data = results.get("classification", {})
457
+
458
+ critical = triage.get("critical", 0)
459
+ high = triage.get("high", 0)
460
+
461
+ top_class = ""
462
+ if classify_data.get("top_predictions"):
463
+ top_class = classify_data["top_predictions"][0].get("class", "")
464
+
465
+ if critical > 0 or "destroy" in top_class or "collapse" in top_class:
466
+ zone_color = "red"
467
+ elif high > 0 or "major_damage" in top_class:
468
+ zone_color = "orange"
469
+ elif triage.get("total", 0) > 0:
470
+ zone_color = "yellow"
471
+ else:
472
+ zone_color = "green"
473
+
474
+ return {
475
+ "zone_color": zone_color,
476
+ "results": results,
477
+ "total_time_ms": total_ms,
478
+ "timestamp": time.time(),
479
+ }
480
+
481
+
482
+ # ════════════════════════════════
483
+ # Entry Point (local testing)
484
+ # ════════════════════════════════
485
+ if __name__ == "__main__":
486
+ import uvicorn
487
+ uvicorn.run(app, host="0.0.0.0", port=7860) # HF Spaces default port