File size: 4,594 Bytes
dc5057c
 
a55aeb4
dc5057c
a55aeb4
 
 
 
dc5057c
a55aeb4
 
dc5057c
 
 
 
a55aeb4
dc5057c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55aeb4
 
 
 
dc5057c
 
a55aeb4
dc5057c
 
 
 
 
 
 
a55aeb4
dc5057c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55aeb4
 
 
 
dc5057c
a55aeb4
 
 
dc5057c
a55aeb4
 
 
 
 
dc5057c
a55aeb4
dc5057c
 
 
a55aeb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc5057c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from fastapi import FastAPI, UploadFile, File, HTTPException, Security, Depends
from fastapi.security.api_key import APIKeyHeader
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
import io
import numpy as np
from PIL import Image
import cv2
from ultralytics import YOLO
import requests
import os

# ==========================
# 🔑 Security: API key
# ==========================
# The expected key is read from the API_KEY environment variable so the real
# secret does not have to live in the source. "1234" is only the legacy
# fallback and must be overridden before the service is shared.
API_KEY = os.environ.get("API_KEY", "1234")
# Clients supply the key in the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")

def verify_api_key(api_key: str = Security(api_key_header)) -> str:
    """Validate the key supplied in the ``X-API-Key`` header.

    Args:
        api_key: Header value extracted by ``api_key_header``.

    Returns:
        The supplied key, unchanged, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 "Forbidden" when the key is missing or wrong.
    """
    import secrets  # local import keeps this block self-contained

    # Compare in constant time so response latency does not reveal how many
    # leading characters of a guess were correct. Bytes are compared because
    # compare_digest() rejects str values containing non-ASCII characters.
    if api_key is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Forbidden")
    return api_key

# ==========================
# 🚀 Application
# ==========================
# Create the FastAPI application. The description uses Markdown emphasis
# markers and carries the research/demo-only disclaimer.
app = FastAPI(
    title="Stroke Detection API",
    version="1.0.0",
    description="""
    🚑 Stroke Detection API using YOLOv8

    ⚠️ **Disclaimer**: This API is for **research/demo purposes only**.  
    It is **not a certified medical tool**. Do not use for medical decisions.
    """
)

# Load the YOLOv8 weights from "best.pt" once at import time; the prediction
# endpoints reuse this single module-level model instance.
model = YOLO("best.pt")

# ==========================
# 📦 Endpoint JSON
# ==========================
@app.post("/v1/predict/")
async def predict(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Run YOLO detection on an uploaded image and return the boxes as JSON.

    Args:
        file: Uploaded image; it is read fully into memory.
        api_key: Injected by ``verify_api_key``; only used for access control.

    Returns:
        JSONResponse ``{"predictions": [...]}`` where every entry holds the
        ``class`` name, the ``confidence`` as a float and the ``bbox`` list
        taken from ``box.xyxy``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 when inference or result formatting fails.
    """
    # Read the upload directly into memory; nothing is written to disk.
    contents = await file.read()

    # Decoding problems are the client's fault, so they are reported as 400
    # rather than being folded into the generic 500 handler below.
    # UnidentifiedImageError is an OSError subclass, so OSError covers both
    # non-image data and truncated files.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    try:
        np_image = np.array(image)

        # YOLO prediction; conf=0.5 is passed as the confidence threshold.
        results = model.predict(np_image, conf=0.5, verbose=False)

        output = []
        for r in results:
            for box in r.boxes:
                output.append({
                    "class": r.names[int(box.cls[0].item())],
                    "confidence": float(box.conf[0].item()),
                    "bbox": box.xyxy[0].tolist()
                })
    except Exception as e:
        # Top-level boundary: surface the failure as HTTP 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse(content={"predictions": output})

# ==========================
# 📦 Endpoint Image
# ==========================
@app.post("/v1/predict_image/")
async def predict_image(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Run YOLO detection on an uploaded image and return it annotated as PNG.

    Args:
        file: Uploaded image; it is read fully into memory.
        api_key: Injected by ``verify_api_key``; only used for access control.

    Returns:
        StreamingResponse with media type ``image/png`` containing the plot of
        the first result.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 when inference, plotting or PNG encoding fails.
    """
    # Read the upload directly into memory; nothing is written to disk.
    contents = await file.read()

    # Decoding problems are the client's fault, so they are reported as 400
    # rather than being folded into the generic 500 handler below.
    # UnidentifiedImageError is an OSError subclass, so OSError covers both
    # non-image data and truncated files.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    try:
        np_image = np.array(image)

        # YOLO prediction, then render the detections onto the image.
        results = model.predict(np_image, conf=0.5, verbose=False)
        annotated = results[0].plot()

        # The plot output is treated as BGR here and converted to RGB before
        # PIL encodes it as PNG into an in-memory buffer.
        annotated_pil = Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
        img_byte_arr = io.BytesIO()
        annotated_pil.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)  # rewind so the response streams from the start
    except Exception as e:
        # Top-level boundary: surface the failure as HTTP 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return StreamingResponse(img_byte_arr, media_type="image/png")

# ==========================
# 🧪 Internal test endpoint
# ==========================
@app.get("/test_request/")
async def test_request():
    """Smoke-test the deployed API by calling both prediction endpoints.

    Posts the local image ``test.jpg`` (it must be present in the Space
    repository) to ``/v1/predict/`` and ``/v1/predict_image/`` under
    ``base_url`` and saves the annotated image to ``result.png``.

    Returns:
        On success, a dict with ``message``, ``json_result`` (decoded JSON of
        the first call) and ``saved_image``. On any failure, including a
        missing test image, a timeout or a non-2xx response, a dict with a
        single ``error`` key.

    NOTE(review): this route has no API-key dependency although it sends
    ``API_KEY`` on the caller's behalf; confirm that is intended.
    """
    from fastapi.concurrency import run_in_threadpool

    file_path = "test.jpg"  # must exist in the Space repository
    base_url = "https://stroke-ia-api.hf.space"  # adapt to the exact Space name
    timeout_s = 60  # without a timeout a stalled request would hang forever

    def _call_endpoints():
        """Perform both blocking HTTP calls and return the JSON result."""
        # Read the image once; the context manager closes the handle.
        with open(file_path, "rb") as fh:
            image_bytes = fh.read()
        upload_name = os.path.basename(file_path)
        headers = {"X-API-Key": API_KEY}

        # JSON endpoint.
        response = requests.post(
            f"{base_url}/v1/predict/",
            files={"file": (upload_name, image_bytes)},
            headers=headers,
            timeout=timeout_s,
        )
        response.raise_for_status()
        json_result = response.json()

        # Annotated-image endpoint.
        response_img = requests.post(
            f"{base_url}/v1/predict_image/",
            files={"file": (upload_name, image_bytes)},
            headers=headers,
            timeout=timeout_s,
        )
        # Never write an error payload to disk as if it were a PNG.
        response_img.raise_for_status()

        with open("result.png", "wb") as f:
            f.write(response_img.content)

        return json_result

    try:
        if not os.path.exists(file_path):
            return {"error": f"{file_path} introuvable dans le Space."}

        # requests is blocking: run it in a worker thread so the event loop
        # stays free. If base_url resolves to this same server process, a
        # blocked loop could not answer the nested requests at all.
        json_result = await run_in_threadpool(_call_endpoints)

        return {
            "message": "✅ Test request exécuté sur Hugging Face API. Résultats sauvegardés.",
            "json_result": json_result,
            "saved_image": "result.png"
        }

    except Exception as e:
        # Diagnostic endpoint: report the failure in the body instead of raising.
        return {"error": str(e)}

# ==========================
# 🚀 Local launch
# ==========================
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 7860.
    bind_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **bind_options)