# File: API / api-v0.py
# Author: Stroke-ia
# Last commit: 19d9980 — "Rename api.py to api-v0.py" (verified)
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
import numpy as np
from PIL import Image
import io
import cv2
import requests # ✅ ajouté
# Detection model weights. "best.pt" is a relative path, so it is presumably
# resolved against the process's working directory — start the server from the
# folder that contains the weights (TODO confirm for the deployment).
model = YOLO("best.pt")

# Human-readable label for each class index the model emits:
# class index i  ->  CLASS_NAMES[i].
# Labels starting with "normal" mark healthy regions; labels starting with
# "stroke" mark a stroke sign with a severity suffix (Weak / Mid / Severe).
# NOTE(review): this order must match the class order the weights were trained
# with — verify against model.names.
CLASS_NAMES = [
    "normalEye", "normalMouth",
    "strokeEyeMid", "strokeEyeSevere", "strokeEyeWeak",
    "strokeMouthMid", "strokeMouthSevere", "strokeMouthWeak",
]
# FastAPI application; title/description/version feed the generated OpenAPI
# schema shown by the interactive docs.
app = FastAPI(
    title="Stroke-IA Detection API",
    description="REST API for stroke sign detection (tech demo, not medical advice).",
    version="1.0",
)

# Cross-origin policy: every origin, method and header is accepted so that
# browser front-ends hosted on other domains can call the API.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# combination the CORS spec does not allow for credentialed responses. No
# endpoint here uses cookies or auth, so allow_credentials is probably
# unnecessary — confirm no front-end sends credentials before removing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/")
async def root():
    """Liveness endpoint: confirm the service is up and point to the prediction routes."""
    status_text = "Stroke-IA API is running. Use /predict/ or /predict_image/."
    return {"message": status_text}
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Run stroke-sign detection on an uploaded image and return JSON results.

    The upload is decoded with Pillow, converted to RGB and passed to the YOLO
    model with a fixed confidence threshold of 0.85.

    Returns a dict with:
      - "message": headline verdict,
      - "detections": every box kept by the model, each as
        {"box": xyxy coordinates, "score": float, "class": int, "label": str},
      - "summary": one-line human-readable explanation.

    Only classes whose label starts with "stroke" count as stroke signs; the
    "normalEye" / "normalMouth" classes denote healthy regions and must not
    trigger the stroke warning.

    Errors: JSON {"error": ...} with status 400 when the upload cannot be
    decoded as an image, status 500 for any other failure.
    """
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # Empty, truncated or non-image upload: a client error, not a server fault.
            return JSONResponse({"error": f"Invalid image: {e}"}, status_code=400)

        # NOTE(review): model.predict is synchronous and CPU-bound, so it blocks
        # the event loop while inference runs.
        results = model.predict(source=np.array(image), conf=0.85, verbose=False)
        boxes = results[0].boxes
        detections = [
            {
                "box": box,
                "score": float(score),
                "class": int(cls),
                "label": CLASS_NAMES[int(cls)],
            }
            for box, score, cls in zip(
                boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
            )
        ]

        # Only stroke classes are warning signs; normal eye/mouth boxes are healthy.
        stroke_detections = [d for d in detections if d["label"].startswith("stroke")]
        if not stroke_detections:
            if detections:
                summary = "Healthy face detected with no significant asymmetry."
            else:
                # Nothing recognised at all — do not claim a healthy face.
                summary = "No eye or mouth region detected with confidence ≥ 85%; unable to assess."
            return {
                "message": "✅ No stroke signs detected (confidence ≥ 85%)",
                "detections": detections,
                "summary": summary,
            }

        best_det = max(stroke_detections, key=lambda d: d["score"])
        summary = f"⚠️ {best_det['label']} detected with {best_det['score']*100:.1f}% confidence."
        return {
            "message": "⚠️ Possible stroke signs detected",
            "detections": detections,
            "summary": summary,
        }
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/predict_image/")
async def predict_image(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return it annotated, as PNG.

    Boxes kept at confidence >= 0.85 are drawn by the model result's
    ``plot()`` method, and the rendered image is streamed back with media
    type ``image/png``.

    Errors: JSON {"error": ...} with status 400 when the upload cannot be
    decoded as an image, status 500 for any other failure.
    """
    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # Empty, truncated or non-image upload: a client error, not a server fault.
            return JSONResponse({"error": f"Invalid image: {e}"}, status_code=400)

        results = model.predict(source=np.array(image), conf=0.85, verbose=False)
        # plot() output is treated as BGR (OpenCV order); Pillow needs RGB.
        annotated = results[0].plot()
        annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))

        png_buffer = io.BytesIO()
        annotated_pil.save(png_buffer, format="PNG")
        png_buffer.seek(0)  # rewind so the response streams from the start
        return StreamingResponse(png_buffer, media_type="image/png")
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# ✅ New test route that exercises the API deployed on Hugging Face
@app.get("/test_request/")
def test_request():
    """Smoke-test the API deployed on Hugging Face.

    Sends the local ``test.jpg`` to the deployed ``/predict/`` and
    ``/predict_image/`` endpoints. The annotated image returned by the latter
    is written to ``result.png`` in the working directory, and the JSON
    detection result is returned.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool. The blocking ``requests`` calls would otherwise stall the event
    loop. If this process is itself the deployed Space, it would then be unable
    to serve the very requests it is waiting on.

    Errors: any failure (missing test file, network error, timeout, non-2xx
    response) is returned as JSON {"error": ...} with status 500, like the
    other endpoints.
    """
    file_path = "test.jpg"  # the test image must be present in the Space
    base_url = "https://Stroke-ia.hf.space"  # ⚠️ adjust to the exact Space name
    timeout_s = 60  # seconds; requests waits indefinitely without a timeout

    try:
        # 1) JSON detection test
        with open(file_path, "rb") as image_file:
            response = requests.post(
                f"{base_url}/predict/", files={"file": image_file}, timeout=timeout_s
            )
        response.raise_for_status()
        json_result = response.json()

        # 2) Annotated-image test
        with open(file_path, "rb") as image_file:
            response_img = requests.post(
                f"{base_url}/predict_image/", files={"file": image_file}, timeout=timeout_s
            )
        # Fail before writing, so an error body is never saved as result.png.
        response_img.raise_for_status()
        with open("result.png", "wb") as out:
            out.write(response_img.content)

        return {
            "message": "✅ Test request executed on Hugging Face API. Results saved.",
            "json_result": json_result,
            "saved_image": "result.png",
        }
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)