Stroke-ia commited on
Commit
da29b06
·
verified ·
1 Parent(s): 19d9980

Upload api.py

Browse files
Files changed (1) hide show
  1. api.py +104 -0
api.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Security, Depends
2
+ from fastapi.security.api_key import APIKeyHeader
3
+ from fastapi.responses import FileResponse, JSONResponse
4
+ import uvicorn
5
+ import shutil
6
+ import os
7
+ from ultralytics import YOLO
8
+
9
+ # ==========================
10
+ # 🔑 Sécurité : API Key
11
+ # ==========================
12
+ API_KEY = "super-secret-key" # <-- Change ici avant de partager
13
+ api_key_header = APIKeyHeader(name="X-API-Key")
14
+
15
+ def verify_api_key(api_key: str = Security(api_key_header)):
16
+ if api_key != API_KEY:
17
+ raise HTTPException(status_code=403, detail="Forbidden")
18
+ return api_key
19
+
20
+ # ==========================
21
+ # 🚀 Application
22
+ # ==========================
23
+ app = FastAPI(
24
+ title="Stroke Detection API",
25
+ version="1.0.0",
26
+ description="""
27
+ 🚑 Stroke Detection API using YOLOv8
28
+
29
+ ⚠️ **Disclaimer**: This API is for **research/demo purposes only**.
30
+ It is **not a certified medical tool**. Do not use for medical decisions.
31
+ """
32
+ )
33
+
34
+ # Charger ton modèle YOLOv8
35
+ model = YOLO("best.pt")
36
+
37
+ # ==========================
38
+ # 📦 Endpoint JSON
39
+ # ==========================
40
+ @app.post("/v1/predict/")
41
+ async def predict(
42
+ file: UploadFile = File(...),
43
+ api_key: str = Depends(verify_api_key)
44
+ ):
45
+ try:
46
+ # Sauvegarde temporaire
47
+ temp_file = f"temp_{file.filename}"
48
+ with open(temp_file, "wb") as buffer:
49
+ shutil.copyfileobj(file.file, buffer)
50
+
51
+ # Prédiction YOLO
52
+ results = model.predict(temp_file)
53
+
54
+ output = []
55
+ for r in results:
56
+ for box in r.boxes:
57
+ output.append({
58
+ "class": r.names[int(box.cls[0].item())],
59
+ "confidence": float(box.conf[0].item()),
60
+ "bbox": box.xyxy[0].tolist() # [x1, y1, x2, y2]
61
+ })
62
+
63
+ return JSONResponse(content={"predictions": output})
64
+
65
+ except Exception as e:
66
+ raise HTTPException(status_code=500, detail=str(e))
67
+ finally:
68
+ if os.path.exists(temp_file):
69
+ os.remove(temp_file)
70
+
71
+ # ==========================
72
+ # 📦 Endpoint Image
73
+ # ==========================
74
+ @app.post("/v1/predict_image/")
75
+ async def predict_image(
76
+ file: UploadFile = File(...),
77
+ api_key: str = Depends(verify_api_key)
78
+ ):
79
+ try:
80
+ # Sauvegarde temporaire
81
+ temp_file = f"temp_{file.filename}"
82
+ with open(temp_file, "wb") as buffer:
83
+ shutil.copyfileobj(file.file, buffer)
84
+
85
+ # Prédiction YOLO avec sauvegarde image annotée
86
+ results = model.predict(temp_file, save=True, project="runs", name="detect", exist_ok=True)
87
+
88
+ # Récupérer la dernière image annotée
89
+ output_dir = results[0].save_dir
90
+ output_path = os.path.join(output_dir, os.path.basename(temp_file))
91
+
92
+ return FileResponse(output_path, media_type="image/png")
93
+
94
+ except Exception as e:
95
+ raise HTTPException(status_code=500, detail=str(e))
96
+ finally:
97
+ if os.path.exists(temp_file):
98
+ os.remove(temp_file)
99
+
100
+ # ==========================
101
+ # 🚀 Lancement local
102
+ # ==========================
103
+ if __name__ == "__main__":
104
+ uvicorn.run(app, host="0.0.0.0", port=7860)