Stroke-ia committed on
Commit
d58ff9a
·
verified ·
1 Parent(s): f7d71e0

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +45 -101
api.py CHANGED
@@ -4,14 +4,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
4
  import uvicorn
5
  import logging
6
  import io
7
- import os
8
- from typing import Tuple, Optional
9
  import time
10
  import numpy as np
11
  from PIL import Image
12
  import cv2
13
-
14
- # ML
15
  from ultralytics import YOLO
16
  import mediapipe as mp
17
 
@@ -29,10 +25,7 @@ def verify_api_key(api_key: str = Security(api_key_header)):
29
  # ==========================
30
  # 📝 Logger
31
  # ==========================
32
- logging.basicConfig(
33
- level=logging.INFO,
34
- format="%(asctime)s - %(levelname)s - %(message)s"
35
- )
36
  logger = logging.getLogger("stroke-api")
37
 
38
  # ==========================
@@ -41,11 +34,7 @@ logger = logging.getLogger("stroke-api")
41
  app = FastAPI(
42
  title="Stroke Detection API",
43
  version="1.2.0",
44
- description="""
45
- 🚑 Stroke Detection API using YOLOv8 + Face Detection (MediaPipe)
46
-
47
- ⚠️ **Disclaimer**: Research/demo only — not a medical device.
48
- """
49
  )
50
 
51
  # ==========================
@@ -66,18 +55,13 @@ mp_face_detection = mp.solutions.face_detection
66
  ALLOWED_EXT = (".png", ".jpg", ".jpeg")
67
  ALLOWED_MIME = {"image/png", "image/jpeg"}
68
  MAX_BYTES = 8 * 1024 * 1024 # 8 MB
69
- CROP_ON_FACE = True # recadrer sur le visage détecté
70
 
71
  def _validate_file(file: UploadFile, raw: bytes):
72
- # extension
73
  if not file.filename.lower().endswith(ALLOWED_EXT):
74
- raise HTTPException(status_code=400, detail="Invalid file extension. Use .png/.jpg/.jpeg")
75
- # MIME
76
- if (file.content_type or "").lower() not in ALLOWED_MIME:
77
- # On continue si extension OK mais content_type vide côté client
78
- if file.content_type:
79
- raise HTTPException(status_code=400, detail="Invalid content-type. Use image/png or image/jpeg")
80
- # taille
81
  if len(raw) > MAX_BYTES:
82
  raise HTTPException(status_code=413, detail=f"Image too large. Max {MAX_BYTES//(1024*1024)} MB")
83
 
@@ -88,10 +72,7 @@ def _read_image_to_numpy(raw: bytes) -> np.ndarray:
88
  except Exception:
89
  raise HTTPException(status_code=400, detail="Unreadable image file")
90
 
91
- def _largest_face_bbox(np_img: np.ndarray, min_conf: float = 0.6) -> Optional[Tuple[int,int,int,int]]:
92
- """
93
- Retourne (x1,y1,x2,y2) du plus grand visage détecté, ou None.
94
- """
95
  h, w = np_img.shape[:2]
96
  with mp_face_detection.FaceDetection(min_detection_confidence=min_conf) as fd:
97
  results = fd.process(cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR))
@@ -105,26 +86,22 @@ def _largest_face_bbox(np_img: np.ndarray, min_conf: float = 0.6) -> Optional[Tu
105
  x2 = int(min(1.0, rel.xmin + rel.width) * w)
106
  y2 = int(min(1.0, rel.ymin + rel.height) * h)
107
  boxes.append((x1, y1, x2, y2))
108
- # choisir le plus grand
109
  boxes.sort(key=lambda b: (b[2]-b[0])*(b[3]-b[1]), reverse=True)
110
  return boxes[0] if boxes else None
111
 
112
- def _crop_to_bbox(np_img: np.ndarray, bbox: Tuple[int,int,int,int], margin: float = 0.15) -> np.ndarray:
113
  h, w = np_img.shape[:2]
114
  x1, y1, x2, y2 = bbox
115
  bw, bh = x2 - x1, y2 - y1
116
- # marge autour du visage
117
- dx, dy = int(bw * margin), int(bh * margin)
118
- X1 = max(0, x1 - dx)
119
- Y1 = max(0, y1 - dy)
120
- X2 = min(w, x2 + dx)
121
- Y2 = min(h, y2 + dy)
122
  return np_img[Y1:Y2, X1:X2].copy()
123
 
124
- def _annotate_face_box(np_img: np.ndarray, bbox: Tuple[int,int,int,int]) -> np.ndarray:
125
  annotated = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR).copy()
126
  x1, y1, x2, y2 = bbox
127
- cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2) # couleur par défaut
128
  return annotated
129
 
130
  # ==========================
@@ -138,121 +115,88 @@ async def health():
138
  # 📦 Endpoint JSON
139
  # ==========================
140
  @app.post("/v1/predict/")
141
- async def predict(
142
- file: UploadFile = File(...),
143
- api_key: str = Depends(verify_api_key)
144
- ):
145
  raw = await file.read()
146
  _validate_file(file, raw)
147
-
148
  try:
149
  np_img = _read_image_to_numpy(raw)
150
-
151
- # 1) Détection visage obligatoire
152
  face_bbox = _largest_face_bbox(np_img)
153
  if face_bbox is None:
154
- return JSONResponse(
155
- status_code=422,
156
- content={"status": "error", "message": "Aucun visage humain détecté. Veuillez centrer le visage."}
157
- )
158
-
159
- # 2) Option : recadrer sur le visage pour améliorer la détection
160
  input_img = _crop_to_bbox(np_img, face_bbox) if CROP_ON_FACE else np_img
161
 
162
- # 3) YOLO inference (en mémoire)
163
  start_time = time.time()
164
  results = model.predict(source=input_img, verbose=False)
165
  elapsed = time.time() - start_time
166
-
167
-
168
- # 4) Format des prédictions
169
- output = []
170
- for r in results:
171
- for box in r.boxes:
172
- output.append({
173
- "class": r.names[int(box.cls[0].item())],
174
- "confidence": float(box.conf[0].item()),
175
- "bbox": box.xyxy[0].tolist()
176
- })
177
 
178
- logger.info(f"/predict {file.filename} -> {len(output)} detections (face ok)")
 
 
 
 
179
  return JSONResponse(content={
180
  "status": "ok",
181
  "face_detected": True,
182
  "face_bbox": list(map(int, face_bbox)),
183
- "predictions": output
 
184
  })
185
 
186
- except HTTPException:
187
- raise
188
  except Exception as e:
189
  logger.exception("Error in /v1/predict")
190
  raise HTTPException(status_code=500, detail=str(e))
191
 
192
  # ==========================
193
- # 🖼️ Endpoint Image (annotée)
194
  # ==========================
195
  @app.post("/v1/predict_image/")
196
- async def predict_image(
197
- file: UploadFile = File(...),
198
- api_key: str = Depends(verify_api_key)
199
- ):
200
  raw = await file.read()
201
  _validate_file(file, raw)
202
-
203
  try:
204
  np_img = _read_image_to_numpy(raw)
205
-
206
- # 1) Détection visage
207
  face_bbox = _largest_face_bbox(np_img)
208
  if face_bbox is None:
209
- return JSONResponse(
210
- status_code=422,
211
- content={"status": "error", "message": "Aucun visage humain détecté. Veuillez centrer le visage."}
212
- )
213
-
214
- # 2) Recadrer sur le visage (optionnel)
215
  input_img = _crop_to_bbox(np_img, face_bbox) if CROP_ON_FACE else np_img
216
 
217
- # 3) YOLO
218
  start_time = time.time()
219
  results = model.predict(source=input_img, verbose=False)
220
  elapsed = time.time() - start_time
221
 
222
- # 4) Annotations YOLO
223
- yolo_annot = results[0].plot() # BGR
224
- yolo_annot = cv2.cvtColor(yolo_annot, cv2.COLOR_BGR2RGB)
225
 
226
- # 5) Si on n’a pas recadré, on dessine aussi le cadre visage sur l’image d’origine
227
- if not CROP_ON_FACE:
228
- annotated = _annotate_face_box(np_img, face_bbox)
229
- # fusion simple : ici on retourne juste l’annot YOLO (non redimensionnée)
230
- out_rgb = annotated
231
- else:
232
- # On retourne l’image annotée sur le crop visage
233
- out_rgb = yolo_annot
234
-
235
- # 6) Retour en PNG (stream)
236
  pil_img = Image.fromarray(out_rgb)
237
  buf = io.BytesIO()
238
  pil_img.save(buf, format="PNG")
239
  buf.seek(0)
240
-
241
- # 7) Ajouter temps d'inférence dans header
242
- headers = {"X-Inference-Time": str(round(elapsed,3))}
243
 
244
- logger.info(f"/predict_image {file.filename} -> face ok + image annotée")
245
- return StreamingResponse(buf, media_type="image/png")
246
 
247
- except HTTPException:
248
- raise
249
  except Exception as e:
250
  logger.exception("Error in /v1/predict_image")
251
  raise HTTPException(status_code=500, detail=str(e))
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  # ==========================
254
  # 🚀 Lancement local
255
  # ==========================
256
  if __name__ == "__main__":
257
- # Sur HF Spaces, c’est Gradio/Space qui lance; localement :
258
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
4
  import uvicorn
5
  import logging
6
  import io
 
 
7
  import time
8
  import numpy as np
9
  from PIL import Image
10
  import cv2
 
 
11
  from ultralytics import YOLO
12
  import mediapipe as mp
13
 
 
25
  # ==========================
26
  # 📝 Logger
27
  # ==========================
28
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
 
29
  logger = logging.getLogger("stroke-api")
30
 
31
  # ==========================
 
34
  app = FastAPI(
35
  title="Stroke Detection API",
36
  version="1.2.0",
37
+ description="🚑 Stroke Detection API using YOLOv8 + Face Detection (MediaPipe). Research/demo only."
 
 
 
 
38
  )
39
 
40
  # ==========================
 
55
  ALLOWED_EXT = (".png", ".jpg", ".jpeg")
56
  ALLOWED_MIME = {"image/png", "image/jpeg"}
57
  MAX_BYTES = 8 * 1024 * 1024 # 8 MB
58
+ CROP_ON_FACE = True
59
 
60
def _validate_file(file: UploadFile, raw: bytes) -> None:
    """Validate an uploaded image: file extension, declared MIME type, size.

    Raises:
        HTTPException: 400 on a bad extension or content-type,
            413 when the payload exceeds MAX_BYTES.
    """
    # file.filename can be None on some clients; treat that as an invalid
    # extension instead of letting .lower() raise AttributeError (-> 500).
    if not (file.filename or "").lower().endswith(ALLOWED_EXT):
        raise HTTPException(status_code=400, detail="Invalid file extension")
    # An empty content_type is tolerated (extension already checked);
    # a present-but-wrong one is rejected.
    if file.content_type and file.content_type.lower() not in ALLOWED_MIME:
        raise HTTPException(status_code=400, detail="Invalid content-type")
    if len(raw) > MAX_BYTES:
        raise HTTPException(status_code=413, detail=f"Image too large. Max {MAX_BYTES//(1024*1024)} MB")
67
 
 
72
  except Exception:
73
  raise HTTPException(status_code=400, detail="Unreadable image file")
74
 
75
+ def _largest_face_bbox(np_img: np.ndarray, min_conf: float = 0.6):
 
 
 
76
  h, w = np_img.shape[:2]
77
  with mp_face_detection.FaceDetection(min_detection_confidence=min_conf) as fd:
78
  results = fd.process(cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR))
 
86
  x2 = int(min(1.0, rel.xmin + rel.width) * w)
87
  y2 = int(min(1.0, rel.ymin + rel.height) * h)
88
  boxes.append((x1, y1, x2, y2))
 
89
  boxes.sort(key=lambda b: (b[2]-b[0])*(b[3]-b[1]), reverse=True)
90
  return boxes[0] if boxes else None
91
 
92
+ def _crop_to_bbox(np_img: np.ndarray, bbox, margin: float = 0.15) -> np.ndarray:
93
  h, w = np_img.shape[:2]
94
  x1, y1, x2, y2 = bbox
95
  bw, bh = x2 - x1, y2 - y1
96
+ dx, dy = int(bw*margin), int(bh*margin)
97
+ X1, Y1 = max(0,x1-dx), max(0,y1-dy)
98
+ X2, Y2 = min(w,x2+dx), min(h,y2+dy)
 
 
 
99
  return np_img[Y1:Y2, X1:X2].copy()
100
 
101
def _annotate_face_box(np_img: np.ndarray, bbox) -> np.ndarray:
    """Return a BGR copy of ``np_img`` (given as RGB) with ``bbox`` drawn
    as a 2 px green rectangle.

    NOTE(review): the returned array is in OpenCV's BGR channel order;
    callers feeding it to PIL must convert back to RGB first.
    """
    x1, y1, x2, y2 = bbox
    canvas = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR).copy()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
    return canvas
106
 
107
  # ==========================
 
115
  # 📦 Endpoint JSON
116
  # ==========================
117
  @app.post("/v1/predict/")
118
+ async def predict(file: UploadFile = File(...), api_key: str = Depends(verify_api_key)):
 
 
 
119
  raw = await file.read()
120
  _validate_file(file, raw)
 
121
  try:
122
  np_img = _read_image_to_numpy(raw)
 
 
123
  face_bbox = _largest_face_bbox(np_img)
124
  if face_bbox is None:
125
+ return JSONResponse(status_code=422, content={"status":"error","message":"Aucun visage détecté"})
 
 
 
 
 
126
  input_img = _crop_to_bbox(np_img, face_bbox) if CROP_ON_FACE else np_img
127
 
 
128
  start_time = time.time()
129
  results = model.predict(source=input_img, verbose=False)
130
  elapsed = time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ output = [{"class": r.names[int(box.cls[0].item())],
133
+ "confidence": float(box.conf[0].item()),
134
+ "bbox": box.xyxy[0].tolist()}
135
+ for r in results for box in r.boxes]
136
+
137
  return JSONResponse(content={
138
  "status": "ok",
139
  "face_detected": True,
140
  "face_bbox": list(map(int, face_bbox)),
141
+ "predictions": output,
142
+ "inference_time_sec": round(elapsed,3)
143
  })
144
 
 
 
145
  except Exception as e:
146
  logger.exception("Error in /v1/predict")
147
  raise HTTPException(status_code=500, detail=str(e))
148
 
149
  # ==========================
150
+ # 🖼️ Endpoint Image annotée
151
  # ==========================
152
  @app.post("/v1/predict_image/")
153
+ async def predict_image(file: UploadFile = File(...), api_key: str = Depends(verify_api_key)):
 
 
 
154
  raw = await file.read()
155
  _validate_file(file, raw)
 
156
  try:
157
  np_img = _read_image_to_numpy(raw)
 
 
158
  face_bbox = _largest_face_bbox(np_img)
159
  if face_bbox is None:
160
+ return JSONResponse(status_code=422, content={"status":"error","message":"Aucun visage détecté"})
 
 
 
 
 
161
  input_img = _crop_to_bbox(np_img, face_bbox) if CROP_ON_FACE else np_img
162
 
 
163
  start_time = time.time()
164
  results = model.predict(source=input_img, verbose=False)
165
  elapsed = time.time() - start_time
166
 
167
+ yolo_annot = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)
168
+ out_rgb = yolo_annot if CROP_ON_FACE else _annotate_face_box(np_img, face_bbox)
 
169
 
 
 
 
 
 
 
 
 
 
 
170
  pil_img = Image.fromarray(out_rgb)
171
  buf = io.BytesIO()
172
  pil_img.save(buf, format="PNG")
173
  buf.seek(0)
 
 
 
174
 
175
+ headers = {"X-Inference-Time": str(round(elapsed,3))}
176
+ return StreamingResponse(buf, media_type="image/png", headers=headers)
177
 
 
 
178
  except Exception as e:
179
  logger.exception("Error in /v1/predict_image")
180
  raise HTTPException(status_code=500, detail=str(e))
181
 
182
+ # ==========================
183
+ # 🧪 Test automatique
184
+ # ==========================
185
+ @app.get("/test_upload/")
186
+ async def test_upload():
187
+ try:
188
+ file_path = "test.jpg"
189
+ np_img = _read_image_to_numpy(open(file_path,"rb").read())
190
+ face_bbox = _largest_face_bbox(np_img)
191
+ if not face_bbox:
192
+ return {"status":"error","message":"Aucun visage détecté"}
193
+ results = model.predict(source=np_img, verbose=False)
194
+ return {"status":"ok","face_detected":True,"num_detections":len(results[0].boxes)}
195
+ except Exception as e:
196
+ return {"status":"error","message": str(e)}
197
+
198
  # ==========================
199
  # 🚀 Lancement local
200
  # ==========================
201
  if __name__ == "__main__":
 
202
  uvicorn.run(app, host="0.0.0.0", port=7860)