Stroke-ia commited on
Commit
32b647a
·
verified ·
1 Parent(s): 05bf2ce

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +191 -0
api.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import time
3
+ import threading
4
+ import numpy as np
5
+ from datetime import datetime
6
+ from fastapi import FastAPI, UploadFile, File
7
+ from fastapi.staticfiles import StaticFiles
8
+ from ultralytics import YOLO
9
+ from PIL import Image
10
+ import os
11
+
12
+ # -----------------------------
13
+ # 1. Config & Model
14
+ # -----------------------------
15
+ MODEL_STROKE_PATH = "stroke.pt"
16
+ OUTPUT_DIR = "/tmp/outputs"
17
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
18
+
19
+ # Charger YOLO une seule fois
20
+ model_stroke = YOLO(MODEL_STROKE_PATH)
21
+
22
+ BASE_URL = "https://stroke-ia-avc-detect.hf.space" # ⚠️ à adapter selon ton déploiement
23
+
24
+ # Mapping des classes vers un rapport médical
25
+ CLASS_LABELS = {
26
+ 0: "Hémorragie intracrânienne",
27
+ 1: "Suspicion de zone ischémique",
28
+ 2: "Normale Brain", # 👉 adapte en fonction des classes de ton modèle
29
+ }
30
+ # -----------------------------
31
+ # DEMO MODE CONFIG (AJOUT)
32
+ # -----------------------------
33
+ DEMO_DIR = "demo_images"
34
+
35
+ DEMO_CASES = {
36
+ "avc_ischemic": {
37
+ "file": "avc_ischemic.png",
38
+ "label": "AVC ischémique (démo)"
39
+ },
40
+ "avc_hemorrhage": {
41
+ "file": "avc_hemorrhage.png",
42
+ "label": "AVC hémorragique (démo)"
43
+ },
44
+ "normal": {
45
+ "file": "normal.png",
46
+ "label": "IRM normale (démo)"
47
+ }
48
+ }
49
+ # -----------------------------
50
+ # 2. Génération de rapport
51
+ # -----------------------------
52
+ def generate_report(results) -> str:
53
+ boxes = results[0].boxes
54
+ if len(boxes) == 0:
55
+ return "=== RAPPORT AUTOMATIQUE ===\n\nAucune anomalie détectée.\n"
56
+
57
+ rapport = "=== RAPPORT AUTOMATIQUE AVC ===\n\n"
58
+ rapport += f"Nombre de lésions détectées : {len(boxes)}\n\n"
59
+
60
+ detected_classes = boxes.cls.cpu().numpy().astype(int)
61
+ for i, cls_id in enumerate(detected_classes, 1):
62
+ label = CLASS_LABELS.get(cls_id, f"Classe inconnue {cls_id}")
63
+ rapport += f"- Lésion {i}: {label}\n"
64
+
65
+ rapport += "\nRecommandations :\n"
66
+ rapport += "- Vérifier la concordance clinique.\n"
67
+ rapport += "- Considérer un suivi neurologique urgent.\n"
68
+
69
+ return rapport
70
+
71
+ # -----------------------------
72
+ # 3. FastAPI
73
+ # -----------------------------
74
+ app = FastAPI(title="Stroke Detection API")
75
+ app.mount("/files", StaticFiles(directory=OUTPUT_DIR), name="files")
76
+ # -----------------------------
77
+ # DEMO – Liste des cas (AJOUT)
78
+ # -----------------------------
79
+ @app.get("/demo/cases")
80
+ def demo_cases():
81
+ return {
82
+ "mode": "demo",
83
+ "cases": DEMO_CASES,
84
+ "warning": "Cas anonymisés – démonstration uniquement"
85
+ }
86
+
87
+ @app.post("/predict/")
88
+ async def predict_stroke(image_file: UploadFile = File(...), conf: float = 0.8):
89
+ """
90
+ Endpoint qui reçoit une image IRM et renvoie une image annotée + rapport texte
91
+ """
92
+ # Sauvegarde temporaire
93
+ tmp_path = f"/tmp/{image_file.filename}"
94
+ with open(tmp_path, "wb") as f:
95
+ f.write(await image_file.read())
96
+
97
+ # Charger image
98
+ image = Image.open(tmp_path).convert("RGB")
99
+ np_img = np.array(image)
100
+
101
+ # Conversion en BGR pour OpenCV
102
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
103
+
104
+ # Prédiction
105
+ results = model_stroke.predict(source=np_img, conf=conf, verbose=False)
106
+
107
+ if len(results[0].boxes) == 0:
108
+ os.remove(tmp_path)
109
+ return {"message": "⚠️ Aucun AVC détecté."}
110
+
111
+ # Annoter l’image
112
+ annotated_image = results[0].plot(labels=True)
113
+
114
+ # Sauvegarder sortie image
115
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
116
+ out_img_name = f"stroke_result_{timestamp}.png"
117
+ out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
118
+ cv2.imwrite(out_img_path, annotated_image)
119
+
120
+ # Sauvegarder rapport
121
+ rapport_text = generate_report(results)
122
+ out_txt_name = f"rapport_{timestamp}.txt"
123
+ out_txt_path = os.path.join(OUTPUT_DIR, out_txt_name)
124
+ with open(out_txt_path, "w", encoding="utf-8") as f:
125
+ f.write(rapport_text)
126
+
127
+ # Nettoyage input
128
+ os.remove(tmp_path)
129
+
130
+ return {
131
+ "annotated_result_url": f"{BASE_URL}/files/{out_img_name}",
132
+ "rapport_url": f"{BASE_URL}/files/{out_txt_name}",
133
+ "message": "✅ Prédiction réussie avec rapport"
134
+ }
135
+ # -----------------------------
136
+ # DEMO – Prédiction sans upload (AJOUT)
137
+ # -----------------------------
138
+ @app.post("/demo/predict/{case_id}")
139
+ def demo_predict(case_id: str, conf: float = 0.8):
140
+
141
+ if case_id not in DEMO_CASES:
142
+ return {"error": "Cas démonstratif invalide"}
143
+
144
+ img_path = os.path.join(DEMO_DIR, DEMO_CASES[case_id]["file"])
145
+
146
+ image = Image.open(img_path).convert("RGB")
147
+ np_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
148
+
149
+ results = model_stroke.predict(source=np_img, conf=conf, verbose=False)
150
+
151
+ annotated_image = results[0].plot(labels=True)
152
+
153
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
154
+ out_img_name = f"demo_{case_id}_{timestamp}.png"
155
+ out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
156
+ cv2.imwrite(out_img_path, annotated_image)
157
+
158
+ rapport_text = generate_report(results)
159
+ rapport_text = (
160
+ "⚠️ MODE DÉMONSTRATION – PAS D’USAGE CLINIQUE ⚠️\n\n"
161
+ + rapport_text
162
+ )
163
+
164
+ out_txt_name = f"demo_rapport_{timestamp}.txt"
165
+ out_txt_path = os.path.join(OUTPUT_DIR, out_txt_name)
166
+ with open(out_txt_path, "w", encoding="utf-8") as f:
167
+ f.write(rapport_text)
168
+
169
+ return {
170
+ "mode": "demo",
171
+ "case": DEMO_CASES[case_id]["label"],
172
+ "annotated_result_url": f"{BASE_URL}/files/{out_img_name}",
173
+ "rapport_url": f"{BASE_URL}/files/{out_txt_name}",
174
+ "disclaimer": "Résultat IA à des fins de démonstration uniquement"
175
+ }
176
+ # -----------------------------
177
+ # 4. Auto-cleanup toutes les 10 min
178
+ # -----------------------------
179
+ def auto_cleanup(interval_minutes=10):
180
+ while True:
181
+ time.sleep(interval_minutes * 60)
182
+ for filename in os.listdir(OUTPUT_DIR):
183
+ file_path = os.path.join(OUTPUT_DIR, filename)
184
+ try:
185
+ if os.path.isfile(file_path):
186
+ os.remove(file_path)
187
+ print(f"[CLEANUP] Fichier supprimé : {file_path}")
188
+ except Exception as e:
189
+ print(f"[CLEANUP] Erreur suppression {file_path} : {e}")
190
+
191
+ threading.Thread(target=auto_cleanup, args=(10,), daemon=True).start()