Stroke-ia commited on
Commit
74015ea
·
verified ·
1 Parent(s): ff67d18

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import JSONResponse, StreamingResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from ultralytics import YOLO
5
+ import numpy as np
6
+ from PIL import Image
7
+ import io
8
+ import cv2
9
+
10
+ # Load YOLO model
11
+ model = YOLO("best.pt")
12
+
13
+ # Class labels
14
+ CLASS_NAMES = [
15
+ "normalEye",
16
+ "normalMouth",
17
+ "strokeEyeMid",
18
+ "strokeEyeSevere",
19
+ "strokeEyeWeak",
20
+ "strokeMouthMid",
21
+ "strokeMouthSevere",
22
+ "strokeMouthWeak"
23
+ ]
24
+
25
+ # Initialize FastAPI app
26
+ app = FastAPI(
27
+ title="Stroke-IA Detection API",
28
+ description="REST API for stroke sign detection (tech demo, not medical advice).",
29
+ version="1.0"
30
+ )
31
+
32
+ # ✅ Enable CORS (to avoid fetch issues in Swagger UI or front-end)
33
+ app.add_middleware(
34
+ CORSMiddleware,
35
+ allow_origins=["*"],
36
+ allow_credentials=True,
37
+ allow_methods=["*"],
38
+ allow_headers=["*"],
39
+ )
40
+
41
+ @app.get("/")
42
+ async def root():
43
+ return {"message": "Stroke-IA API is running. Use /predict/ or /predict_image/."}
44
+
45
+ @app.post("/predict/")
46
+ async def predict(file: UploadFile = File(...)):
47
+ """
48
+ Returns JSON with detections (no image).
49
+ """
50
+ try:
51
+ contents = await file.read()
52
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
53
+ np_image = np.array(image)
54
+
55
+ results = model.predict(source=np_image, conf=0.85, verbose=False)
56
+
57
+ if len(results[0].boxes) == 0:
58
+ return {
59
+ "message": "✅ No stroke signs detected (confidence ≥ 85%)",
60
+ "detections": [],
61
+ "summary": "Healthy face detected with no significant asymmetry."
62
+ }
63
+
64
+ detections = []
65
+ for box, score, cls in zip(results[0].boxes.xyxy.tolist(),
66
+ results[0].boxes.conf.tolist(),
67
+ results[0].boxes.cls.tolist()):
68
+ label = CLASS_NAMES[int(cls)]
69
+ detections.append({
70
+ "box": box,
71
+ "score": float(score),
72
+ "class": int(cls),
73
+ "label": label
74
+ })
75
+
76
+ best_det = max(detections, key=lambda x: x["score"])
77
+ summary = f"⚠️ {best_det['label']} detected with {best_det['score']*100:.1f}% confidence."
78
+
79
+ return {
80
+ "message": "⚠️ Possible stroke signs detected",
81
+ "detections": detections,
82
+ "summary": summary
83
+ }
84
+
85
+ except Exception as e:
86
+ return JSONResponse({"error": str(e)}, status_code=500)
87
+
88
+ @app.post("/predict_image/")
89
+ async def predict_image(file: UploadFile = File(...)):
90
+ """
91
+ Returns the annotated image only (PNG).
92
+ """
93
+ try:
94
+ contents = await file.read()
95
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
96
+ np_image = np.array(image)
97
+
98
+ results = model.predict(source=np_image, conf=0.85, verbose=False)
99
+
100
+ annotated = results[0].plot()
101
+ annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
102
+ img_byte_arr = io.BytesIO()
103
+ annotated_pil.save(img_byte_arr, format="PNG")
104
+ img_byte_arr.seek(0)
105
+
106
+ return StreamingResponse(img_byte_arr, media_type="image/png")
107
+
108
+ except Exception as e:
109
+ return JSONResponse({"error": str(e)}, status_code=500)