Stroke-ia committed on
Commit
f7d71e0
·
verified ·
1 Parent(s): 2e3539a

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +258 -0
api.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Security, Depends
2
+ from fastapi.security.api_key import APIKeyHeader
3
+ from fastapi.responses import JSONResponse, StreamingResponse
4
+ import uvicorn
5
+ import logging
6
+ import io
7
+ import os
8
+ from typing import Tuple, Optional
9
+ import time
10
+ import numpy as np
11
+ from PIL import Image
12
+ import cv2
13
+
14
+ # ML
15
+ from ultralytics import YOLO
16
+ import mediapipe as mp
17
+
18
# ==========================
# 🔑 Security: API key
# ==========================
# Read the key from the environment so the secret never lives in source code;
# the "1234" fallback preserves the previous local-dev default.
API_KEY = os.getenv("API_KEY", "1234")
api_key_header = APIKeyHeader(name="X-API-Key")

def verify_api_key(api_key: str = Security(api_key_header)):
    """FastAPI dependency: validate the X-API-Key request header.

    Returns:
        The supplied key when it matches API_KEY.

    Raises:
        HTTPException: 403 when the key is missing or wrong.
    """
    if api_key != API_KEY:
        raise HTTPException(status_code=403, detail="Forbidden")
    return api_key
28
+
29
# ==========================
# 📝 Logger
# ==========================
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-wide logger used by every endpoint below.
logger = logging.getLogger("stroke-api")
37
+
38
# ==========================
# 🚀 App
# ==========================
# Metadata shown in the auto-generated OpenAPI docs (/docs).
_DESCRIPTION = """
🚑 Stroke Detection API using YOLOv8 + Face Detection (MediaPipe)

⚠️ **Disclaimer**: Research/demo only — not a medical device.
"""

app = FastAPI(
    title="Stroke Detection API",
    version="1.2.0",
    description=_DESCRIPTION,
)
50
+
51
# ==========================
# 📦 Model loading
# ==========================
try:
    model = YOLO("best.pt")  # weights expected next to this file
    logger.info("✅ YOLO model loaded.")
except Exception as e:
    logger.exception("❌ Failed to load YOLO model")
    # Chain the original exception so the startup traceback keeps the root cause.
    raise RuntimeError(f"Model loading failed: {e}") from e

# MediaPipe face-detection solution module (detector instances are created per call).
mp_face_detection = mp.solutions.face_detection
62
+
63
# ==========================
# 🔧 Upload constraints / helpers
# ==========================
ALLOWED_EXT = (".png", ".jpg", ".jpeg")  # accepted filename extensions
ALLOWED_MIME = {"image/png", "image/jpeg"}  # accepted Content-Type values
MAX_BYTES = 8 * 1024 * 1024 # 8 MB
CROP_ON_FACE = True # crop the input to the detected face before inference
70
+
71
def _validate_file(file: UploadFile, raw: bytes) -> None:
    """Reject uploads that are not small PNG/JPEG images.

    Args:
        file: incoming upload; filename and content-type are client-supplied.
        raw: the already-read request body.

    Raises:
        HTTPException: 400 on bad extension or content-type, 413 when too large.
    """
    # Extension — filename can be None on some clients, which previously
    # crashed with AttributeError (a 500); guard it so the client gets a 400.
    if not (file.filename or "").lower().endswith(ALLOWED_EXT):
        raise HTTPException(status_code=400, detail="Invalid file extension. Use .png/.jpg/.jpeg")
    # MIME: tolerate an empty content-type when the extension already passed,
    # but refuse an explicitly wrong one.
    if (file.content_type or "").lower() not in ALLOWED_MIME:
        if file.content_type:
            raise HTTPException(status_code=400, detail="Invalid content-type. Use image/png or image/jpeg")
    # Size.
    if len(raw) > MAX_BYTES:
        raise HTTPException(status_code=413, detail=f"Image too large. Max {MAX_BYTES//(1024*1024)} MB")
83
+
84
def _read_image_to_numpy(raw: bytes) -> np.ndarray:
    """Decode uploaded bytes into an RGB numpy array, or raise HTTP 400."""
    try:
        decoded = Image.open(io.BytesIO(raw)).convert("RGB")
        return np.array(decoded)
    except Exception:
        raise HTTPException(status_code=400, detail="Unreadable image file")
90
+
91
def _largest_face_bbox(np_img: np.ndarray, min_conf: float = 0.6) -> Optional[Tuple[int, int, int, int]]:
    """Return (x1, y1, x2, y2) of the largest detected face, or None.

    Args:
        np_img: RGB image array (as produced by _read_image_to_numpy).
        min_conf: minimum MediaPipe detection confidence.
    """
    h, w = np_img.shape[:2]
    with mp_face_detection.FaceDetection(min_detection_confidence=min_conf) as fd:
        # MediaPipe's Python API expects RGB input and np_img is already RGB,
        # so pass it directly (the previous RGB->BGR conversion fed the
        # detector channel-swapped data, hurting detection quality).
        results = fd.process(np_img)
    if not results.detections:
        return None
    boxes = []
    for det in results.detections:
        rel = det.location_data.relative_bounding_box
        # Relative coords may fall slightly outside [0, 1]; clamp before scaling.
        x1 = int(max(0, rel.xmin) * w)
        y1 = int(max(0, rel.ymin) * h)
        x2 = int(min(1.0, rel.xmin + rel.width) * w)
        y2 = int(min(1.0, rel.ymin + rel.height) * h)
        boxes.append((x1, y1, x2, y2))
    # Largest face (by box area) first.
    boxes.sort(key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)
    return boxes[0] if boxes else None
111
+
112
+ def _crop_to_bbox(np_img: np.ndarray, bbox: Tuple[int,int,int,int], margin: float = 0.15) -> np.ndarray:
113
+ h, w = np_img.shape[:2]
114
+ x1, y1, x2, y2 = bbox
115
+ bw, bh = x2 - x1, y2 - y1
116
+ # marge autour du visage
117
+ dx, dy = int(bw * margin), int(bh * margin)
118
+ X1 = max(0, x1 - dx)
119
+ Y1 = max(0, y1 - dy)
120
+ X2 = min(w, x2 + dx)
121
+ Y2 = min(h, y2 + dy)
122
+ return np_img[Y1:Y2, X1:X2].copy()
123
+
124
def _annotate_face_box(np_img: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
    """Return a BGR copy of *np_img* with the face bbox drawn on it.

    NOTE: the returned array is BGR (OpenCV channel order), not RGB.
    """
    x1, y1, x2, y2 = bbox
    canvas = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR).copy()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)  # green, 2 px
    return canvas
129
+
130
# ==========================
# 🩺 Healthcheck
# ==========================
@app.get("/health")
async def health():
    """Liveness probe (a failed model load would have aborted import)."""
    payload = {"status": "ok", "model_loaded": True}
    return payload
136
+
137
# ==========================
# 📦 JSON endpoint
# ==========================
@app.post("/v1/predict/")
async def predict(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Face-gated stroke detection returning YOLO predictions as JSON.

    Responses: 422 when no face is found, 400/413 for invalid uploads,
    403 for a bad API key, otherwise a JSON payload with the face bbox,
    the inference time and the list of detections.
    """
    raw = await file.read()
    _validate_file(file, raw)

    try:
        np_img = _read_image_to_numpy(raw)

        # 1) A human face is mandatory before running the stroke model.
        face_bbox = _largest_face_bbox(np_img)
        if face_bbox is None:
            return JSONResponse(
                status_code=422,
                content={"status": "error", "message": "Aucun visage humain détecté. Veuillez centrer le visage."}
            )

        # 2) Optionally crop to the face to improve detection.
        input_img = _crop_to_bbox(np_img, face_bbox) if CROP_ON_FACE else np_img

        # 3) YOLO inference (in memory).
        start_time = time.time()
        results = model.predict(source=input_img, verbose=False)
        elapsed = time.time() - start_time

        # 4) Flatten predictions.
        output = []
        for r in results:
            for box in r.boxes:
                output.append({
                    "class": r.names[int(box.cls[0].item())],
                    "confidence": float(box.conf[0].item()),
                    "bbox": box.xyxy[0].tolist()
                })

        logger.info("/predict %s -> %d detections (face ok)", file.filename, len(output))
        return JSONResponse(content={
            "status": "ok",
            "face_detected": True,
            "face_bbox": list(map(int, face_bbox)),
            # Backward-compatible addition: `elapsed` was measured before but
            # never reported anywhere.
            "inference_time_s": round(elapsed, 3),
            "predictions": output
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in /v1/predict")
        raise HTTPException(status_code=500, detail=str(e))
191
+
192
# ==========================
# 🖼️ Annotated-image endpoint
# ==========================
@app.post("/v1/predict_image/")
async def predict_image(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Same pipeline as /v1/predict but streams back an annotated PNG.

    The X-Inference-Time response header carries the YOLO runtime in seconds.
    Responses: 422 when no face is found, 400/413 for invalid uploads.
    """
    raw = await file.read()
    _validate_file(file, raw)

    try:
        np_img = _read_image_to_numpy(raw)

        # 1) Face detection gate.
        face_bbox = _largest_face_bbox(np_img)
        if face_bbox is None:
            return JSONResponse(
                status_code=422,
                content={"status": "error", "message": "Aucun visage humain détecté. Veuillez centrer le visage."}
            )

        # 2) Optional crop on the face.
        input_img = _crop_to_bbox(np_img, face_bbox) if CROP_ON_FACE else np_img

        # 3) YOLO inference.
        start_time = time.time()
        results = model.predict(source=input_img, verbose=False)
        elapsed = time.time() - start_time

        # 4) YOLO annotations (Results.plot() returns BGR).
        yolo_annot = results[0].plot()
        yolo_annot = cv2.cvtColor(yolo_annot, cv2.COLOR_BGR2RGB)

        # 5) Without cropping, draw the face box on the original image instead.
        if not CROP_ON_FACE:
            annotated = _annotate_face_box(np_img, face_bbox)
            # FIX: _annotate_face_box returns BGR; convert before PIL encodes
            # the PNG (the previous code saved the BGR array as-is, swapping
            # the red and blue channels in the output image).
            out_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
        else:
            # Return the YOLO-annotated face crop.
            out_rgb = yolo_annot

        # 6) Encode to PNG in memory.
        pil_img = Image.fromarray(out_rgb)
        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        buf.seek(0)

        # 7) Inference time in a header. FIX: the dict was built before but
        # never attached to the response.
        headers = {"X-Inference-Time": str(round(elapsed, 3))}

        logger.info("/predict_image %s -> face ok + image annotée", file.filename)
        return StreamingResponse(buf, media_type="image/png", headers=headers)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in /v1/predict_image")
        raise HTTPException(status_code=500, detail=str(e))
252
+
253
# ==========================
# 🚀 Local launch
# ==========================
if __name__ == "__main__":
    # On HF Spaces the platform (Gradio/Space) launches the app; this branch
    # is only used when running locally.
    uvicorn.run(app, host="0.0.0.0", port=7860)