Stroke-ia commited on
Commit
56632ca
·
verified ·
1 Parent(s): 71f9742

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +110 -109
api.py CHANGED
@@ -1,109 +1,110 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse, StreamingResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from ultralytics import YOLO
5
- import numpy as np
6
- from PIL import Image
7
- import io
8
- import cv2
9
-
10
- # Load YOLO model
11
- model = YOLO("best.pt")
12
-
13
- # Class labels
14
- CLASS_NAMES = [
15
- "normalEye",
16
- "normalMouth",
17
- "strokeEyeMid",
18
- "strokeEyeSevere",
19
- "strokeEyeWeak",
20
- "strokeMouthMid",
21
- "strokeMouthSevere",
22
- "strokeMouthWeak"
23
- ]
24
-
25
- # Initialize FastAPI app
26
- app = FastAPI(
27
- title="Stroke-IA Detection API",
28
- description="REST API for stroke sign detection (tech demo, not medical advice).",
29
- version="1.0"
30
- )
31
-
32
- # ✅ Enable CORS (to avoid fetch issues in Swagger UI or front-end)
33
- app.add_middleware(
34
- CORSMiddleware,
35
- allow_origins=["*"],
36
- allow_credentials=True,
37
- allow_methods=["*"],
38
- allow_headers=["*"],
39
- )
40
-
41
- @app.get("/")
42
- async def root():
43
- return {"message": "Stroke-IA API is running. Use /predict/ or /predict_image/."}
44
-
45
- @app.post("/predict/")
46
- async def predict(file: UploadFile = File(...)):
47
- """
48
- Returns JSON with detections (no image).
49
- """
50
- try:
51
- contents = await file.read()
52
- image = Image.open(io.BytesIO(contents)).convert("RGB")
53
- np_image = np.array(image)
54
-
55
- results = model.predict(source=np_image, conf=0.85, verbose=False)
56
-
57
- if len(results[0].boxes) == 0:
58
- return {
59
- "message": "✅ No stroke signs detected (confidence ≥ 85%)",
60
- "detections": [],
61
- "summary": "Healthy face detected with no significant asymmetry."
62
- }
63
-
64
- detections = []
65
- for box, score, cls in zip(results[0].boxes.xyxy.tolist(),
66
- results[0].boxes.conf.tolist(),
67
- results[0].boxes.cls.tolist()):
68
- label = CLASS_NAMES[int(cls)]
69
- detections.append({
70
- "box": box,
71
- "score": float(score),
72
- "class": int(cls),
73
- "label": label
74
- })
75
-
76
- best_det = max(detections, key=lambda x: x["score"])
77
- summary = f"⚠️ {best_det['label']} detected with {best_det['score']*100:.1f}% confidence."
78
-
79
- return {
80
- "message": "⚠️ Possible stroke signs detected",
81
- "detections": detections,
82
- "summary": summary
83
- }
84
-
85
- except Exception as e:
86
- return JSONResponse({"error": str(e)}, status_code=500)
87
-
88
- @app.post("/predict_image/")
89
- async def predict_image(file: UploadFile = File(...)):
90
- """
91
- Returns the annotated image only (PNG).
92
- """
93
- try:
94
- contents = await file.read()
95
- image = Image.open(io.BytesIO(contents)).convert("RGB")
96
- np_image = np.array(image)
97
-
98
- results = model.predict(source=np_image, conf=0.85, verbose=False)
99
-
100
- annotated = results[0].plot()
101
- annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
102
- img_byte_arr = io.BytesIO()
103
- annotated_pil.save(img_byte_arr, format="PNG")
104
- img_byte_arr.seek(0)
105
-
106
- return StreamingResponse(img_byte_arr, media_type="image/png")
107
-
108
- except Exception as e:
109
- return JSONResponse({"error": str(e)}, status_code=500)
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import JSONResponse, StreamingResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from ultralytics import YOLO
5
+ import numpy as np
6
+ from PIL import Image
7
+ import io
8
+ import cv2
9
+
10
+ # Load YOLO model
11
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "best.pt")
12
+ model = YOLO(MODEL_PATH)
13
+
14
+ # Class labels
15
+ CLASS_NAMES = [
16
+ "normalEye",
17
+ "normalMouth",
18
+ "strokeEyeMid",
19
+ "strokeEyeSevere",
20
+ "strokeEyeWeak",
21
+ "strokeMouthMid",
22
+ "strokeMouthSevere",
23
+ "strokeMouthWeak"
24
+ ]
25
+
26
+ # Initialize FastAPI app
27
+ app = FastAPI(
28
+ title="Stroke-IA Detection API",
29
+ description="REST API for stroke sign detection (tech demo, not medical advice).",
30
+ version="1.0"
31
+ )
32
+
33
+ # ✅ Enable CORS (to avoid fetch issues in Swagger UI or front-end)
34
+ app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"],
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
+
42
+ @app.get("/")
43
+ async def root():
44
+ return {"message": "Stroke-IA API is running. Use /predict/ or /predict_image/."}
45
+
46
+ @app.post("/predict/")
47
+ async def predict(file: UploadFile = File(...)):
48
+ """
49
+ Returns JSON with detections (no image).
50
+ """
51
+ try:
52
+ contents = await file.read()
53
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
54
+ np_image = np.array(image)
55
+
56
+ results = model.predict(source=np_image, conf=0.85, verbose=False)
57
+
58
+ if len(results[0].boxes) == 0:
59
+ return {
60
+ "message": "✅ No stroke signs detected (confidence ≥ 85%)",
61
+ "detections": [],
62
+ "summary": "Healthy face detected with no significant asymmetry."
63
+ }
64
+
65
+ detections = []
66
+ for box, score, cls in zip(results[0].boxes.xyxy.tolist(),
67
+ results[0].boxes.conf.tolist(),
68
+ results[0].boxes.cls.tolist()):
69
+ label = CLASS_NAMES[int(cls)]
70
+ detections.append({
71
+ "box": box,
72
+ "score": float(score),
73
+ "class": int(cls),
74
+ "label": label
75
+ })
76
+
77
+ best_det = max(detections, key=lambda x: x["score"])
78
+ summary = f"⚠️ {best_det['label']} detected with {best_det['score']*100:.1f}% confidence."
79
+
80
+ return {
81
+ "message": "⚠️ Possible stroke signs detected",
82
+ "detections": detections,
83
+ "summary": summary
84
+ }
85
+
86
+ except Exception as e:
87
+ return JSONResponse({"error": str(e)}, status_code=500)
88
+
89
+ @app.post("/predict_image/")
90
+ async def predict_image(file: UploadFile = File(...)):
91
+ """
92
+ Returns the annotated image only (PNG).
93
+ """
94
+ try:
95
+ contents = await file.read()
96
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
97
+ np_image = np.array(image)
98
+
99
+ results = model.predict(source=np_image, conf=0.85, verbose=False)
100
+
101
+ annotated = results[0].plot()
102
+ annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
103
+ img_byte_arr = io.BytesIO()
104
+ annotated_pil.save(img_byte_arr, format="PNG")
105
+ img_byte_arr.seek(0)
106
+
107
+ return StreamingResponse(img_byte_arr, media_type="image/png")
108
+
109
+ except Exception as e:
110
+ return JSONResponse({"error": str(e)}, status_code=500)