File size: 4,362 Bytes
56632ca c1a2296 56632ca c1a2296 56632ca 4a6733a 20376fb 4a6733a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
import numpy as np
from PIL import Image
import io
import cv2
import requests # ✅ ajouté
# Load the YOLO weights from "best.pt" once, at import time.
# The /predict/ and /predict_image/ handlers below reuse this single instance.
model = YOLO("best.pt")
# Human-readable label for each class id; looked up by integer index when
# building detection results.
# NOTE(review): presumably this order must match the class order baked into
# best.pt — confirm against the training configuration.
CLASS_NAMES = [
    "normalEye",
    "normalMouth",
    "strokeEyeMid",
    "strokeEyeSevere",
    "strokeEyeWeak",
    "strokeMouthMid",
    "strokeMouthSevere",
    "strokeMouthWeak",
]
# Create the FastAPI application object that every route below registers on.
# The keyword arguments are descriptive metadata for the service.
app = FastAPI(
    title="Stroke-IA Detection API",
    description="REST API for stroke sign detection (tech demo, not medical advice).",
    version="1.0",
)
# Enable CORS so browser clients on any origin (Swagger UI, front-ends) can
# call the API.
# Credentials are deliberately disabled: the FastAPI CORS documentation states
# that allow_origins / allow_methods / allow_headers must not be ["*"] when
# allow_credentials is True, and none of the endpoints visible in this file
# rely on cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Landing route: returns a fixed JSON banner that names the two prediction
# endpoints, so a GET on "/" confirms the service is up.
@app.get("/")
async def root():
    banner = "Stroke-IA API is running. Use /predict/ or /predict_image/."
    return {"message": banner}
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """
    Run stroke-sign detection on an uploaded image and return the result as JSON.

    The response always carries three keys:
      - "message":    overall verdict string,
      - "detections": list of {"box", "score", "class", "label"} dicts
                      (box is the xyxy coordinate list reported by the model),
      - "summary":    one-line description of the highest-scoring relevant detection.

    Only detections with confidence >= 0.85 are considered. Detections whose
    label starts with "stroke" are treated as stroke signs; "normal*" labels
    are reported in "detections" but do not raise the stroke warning.

    Returns HTTP 400 with {"error": ...} when the upload cannot be decoded as
    an image, and HTTP 500 with {"error": ...} for any other failure.
    """
    try:
        contents = await file.read()

        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError) as e:
            # An undecodable upload is the client's fault, not a server error.
            # PIL's UnidentifiedImageError is a subclass of OSError.
            return JSONResponse({"error": f"Invalid image file: {e}"}, status_code=400)

        # Hand the PIL image to the model directly. Ultralytics interprets PIL
        # input as RGB but raw numpy arrays as BGR, so passing np.array(image)
        # would feed the model an image with red and blue swapped.
        results = model.predict(source=image, conf=0.85, verbose=False)
        boxes = results[0].boxes

        if len(boxes) == 0:
            return {
                "message": "✅ No stroke signs detected (confidence ≥ 85%)",
                "detections": [],
                "summary": "Healthy face detected with no significant asymmetry."
            }

        detections = [
            {
                "box": box,
                "score": float(score),
                "class": int(cls),
                "label": CLASS_NAMES[int(cls)],
            }
            for box, score, cls in zip(
                boxes.xyxy.tolist(),
                boxes.conf.tolist(),
                boxes.cls.tolist(),
            )
        ]

        # Only "stroke*" classes are warning signs; "normalEye"/"normalMouth"
        # detections must not be reported as a possible stroke.
        stroke_detections = [d for d in detections if d["label"].startswith("stroke")]

        if not stroke_detections:
            best_normal = max(detections, key=lambda d: d["score"])
            return {
                "message": "✅ No stroke signs detected (confidence ≥ 85%)",
                "detections": detections,
                "summary": (
                    f"✅ {best_normal['label']} detected with "
                    f"{best_normal['score']*100:.1f}% confidence."
                ),
            }

        # Summarise the most confident stroke-class detection.
        best_det = max(stroke_detections, key=lambda d: d["score"])
        summary = f"⚠️ {best_det['label']} detected with {best_det['score']*100:.1f}% confidence."
        return {
            "message": "⚠️ Possible stroke signs detected",
            "detections": detections,
            "summary": summary
        }
    except Exception as e:
        # Top-level boundary: report unexpected failures as a JSON 500.
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/predict_image/")
async def predict_image(file: UploadFile = File(...)):
    """
    Run detection on an uploaded image and stream back the annotated picture as PNG.

    Only detections with confidence >= 0.85 are drawn.

    Returns HTTP 400 with {"error": ...} when the upload cannot be decoded as
    an image, and HTTP 500 with {"error": ...} for any other failure.
    """
    try:
        contents = await file.read()

        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError) as e:
            # An undecodable upload is the client's fault, not a server error.
            # PIL's UnidentifiedImageError is a subclass of OSError.
            return JSONResponse({"error": f"Invalid image file: {e}"}, status_code=400)

        # Hand the PIL image to the model directly. Ultralytics interprets PIL
        # input as RGB but raw numpy arrays as BGR; passing np.array(image)
        # made the model see swapped channels and left the returned picture
        # with red and blue exchanged after the BGR->RGB conversion below.
        results = model.predict(source=image, conf=0.85, verbose=False)

        # plot() yields a BGR-ordered numpy array; convert to RGB for PIL.
        annotated = results[0].plot()
        annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

        # Encode to PNG in memory and rewind so the response streams from byte 0.
        img_byte_arr = io.BytesIO()
        annotated_pil.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        # Top-level boundary: report unexpected failures as a JSON 500.
        return JSONResponse({"error": str(e)}, status_code=500)
# Test route that exercises the API deployed on Hugging Face.
@app.get("/test_request/")
async def test_request():
    """
    Test the API deployed on Hugging Face: send an image to /predict/ and
    /predict_image/, then save the results.

    Posts the local file "test.jpg" to both endpoints of the configured Space,
    returns the JSON body of the /predict/ call, and writes the PNG returned by
    /predict_image/ to "result.png".

    Returns HTTP 500 with {"error": ...} if the file is missing, a request
    fails or times out, or either endpoint answers with an error status.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    file_path = "test.jpg"  # put your image in your Space
    base_url = "https://Stroke-ia.hf.space"  # adapt to the exact name of your Space
    timeout_seconds = 60  # never wait forever on a stalled remote

    def _post_image(url):
        # Upload file_path to url; the context manager closes the handle even
        # when the request raises.
        with open(file_path, "rb") as upload:
            response = requests.post(url, files={"file": upload}, timeout=timeout_seconds)
        # Surface HTTP errors instead of treating an error body as a result.
        response.raise_for_status()
        return response

    try:
        # requests is blocking, so run it in a worker thread to keep the event
        # loop free. NOTE(review): if base_url points at this same server, a
        # blocking call here would stall the very loop that must answer it.

        # 1) JSON test (detection)
        response = await run_in_threadpool(_post_image, f"{base_url}/predict/")
        json_result = response.json()

        # 2) Annotated image test
        response_img = await run_in_threadpool(_post_image, f"{base_url}/predict_image/")
        with open("result.png", "wb") as f:
            f.write(response_img.content)

        return {
            "message": "✅ Test request executed on Hugging Face API. Results saved.",
            "json_result": json_result,
            "saved_image": "result.png"
        }
    except Exception as e:
        # Same error shape and status as the other endpoints in this file.
        return JSONResponse({"error": str(e)}, status_code=500)
|