Stroke-ia commited on
Commit
796bc65
·
verified ·
1 Parent(s): 4c7f6d0

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +163 -65
api.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import cv2
3
  import time
@@ -12,100 +14,196 @@ from PIL import Image
12
  # -----------------------------
13
  # 1. Config & Model
14
  # -----------------------------
15
- MODEL_STROKE_PATH = "stroke.pt"
16
  OUTPUT_DIR = "/tmp/outputs"
17
  os.makedirs(OUTPUT_DIR, exist_ok=True)
18
 
19
- # Charger YOLO une seule fois
20
- model_stroke = YOLO(MODEL_STROKE_PATH)
21
 
22
- BASE_URL = "https://stroke-ia-avc-detect.hf.space" # ⚠️ à adapter selon ton déploiement
 
23
 
24
- # Mapping des classes vers un rapport médical
25
- CLASS_LABELS = {
26
- 0: "Hémorragie intracrânienne",
27
- 1: "Suspicion de zone ischémique",
28
- 2: "Normale Brain",
29
- # 👉 adapte en fonction des classes de ton modèle
30
  }
31
 
32
  # -----------------------------
33
  # 2. Génération de rapport
34
  # -----------------------------
35
- def generate_report(results) -> str:
36
- boxes = results[0].boxes
37
- if len(boxes) == 0:
38
- return "=== RAPPORT AUTOMATIQUE ===\n\nAucune anomalie détectée.\n"
39
-
40
- rapport = "=== RAPPORT AUTOMATIQUE AVC ===\n\n"
41
- rapport += f"Nombre de lésions détectées : {len(boxes)}\n\n"
 
 
42
 
43
- detected_classes = boxes.cls.cpu().numpy().astype(int)
 
 
 
44
 
45
- for i, cls_id in enumerate(detected_classes, 1):
46
- label = CLASS_LABELS.get(cls_id, f"Classe inconnue {cls_id}")
47
- rapport += f"- Lésion {i}: {label}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  rapport += "\nRecommandations :\n"
50
- rapport += "- Vérifier la concordance clinique.\n"
51
- rapport += "- Considérer un suivi neurologique urgent.\n"
52
 
53
  return rapport
54
 
55
  # -----------------------------
56
  # 3. FastAPI
57
  # -----------------------------
58
- app = FastAPI(title="Stroke Detection API")
59
  app.mount("/files", StaticFiles(directory=OUTPUT_DIR), name="files")
60
 
61
- @app.post("/predict/")
62
- async def predict_stroke(image_file: UploadFile = File(...), conf: float = 0.8):
63
  """
64
- Endpoint qui reçoit une image IRM et renvoie une image annotée + rapport texte
 
 
65
  """
66
  # Sauvegarde temporaire
67
- tmp_path = f"/tmp/{image_file.filename}"
68
  with open(tmp_path, "wb") as f:
69
  f.write(await image_file.read())
70
 
71
- # Charger image
72
- image = Image.open(tmp_path).convert("RGB")
73
- np_img = np.array(image)
74
-
75
- # Conversion en BGR pour OpenCV
76
- np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
77
-
78
- # Prédiction
79
- results = model_stroke.predict(source=np_img, conf=conf, verbose=False)
80
-
81
- if len(results[0].boxes) == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  os.remove(tmp_path)
83
- return {"message": "⚠️ Aucun AVC détecté."}
84
-
85
- # Annoter l’image
86
- annotated_image = results[0].plot(labels=True)
87
-
88
- # Sauvegarder sortie image
89
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
90
- out_img_name = f"stroke_result_{timestamp}.png"
91
- out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
92
- cv2.imwrite(out_img_path, annotated_image)
93
-
94
- # Sauvegarder rapport
95
- rapport_text = generate_report(results)
96
- out_txt_name = f"rapport_{timestamp}.txt"
97
- out_txt_path = os.path.join(OUTPUT_DIR, out_txt_name)
98
- with open(out_txt_path, "w", encoding="utf-8") as f:
99
- f.write(rapport_text)
100
-
101
- # Nettoyage input
102
- os.remove(tmp_path)
103
-
104
- return {
105
- "annotated_result_url": f"{BASE_URL}/files/{out_img_name}",
106
- "rapport_url": f"{BASE_URL}/files/{out_txt_name}",
107
- "message": "✅ Prédiction réussie avec rapport"
108
- }
109
 
110
  # -----------------------------
111
  # 4. Auto-cleanup toutes les 10 min
 
1
+
2
+
3
  import os
4
  import cv2
5
  import time
 
14
  # -----------------------------
15
  # 1. Config & Model
16
  # -----------------------------
17
+ MODEL_IRM_PATH = "best_seg.pt" # <- place ton modèle ici
18
  OUTPUT_DIR = "/tmp/outputs"
19
  os.makedirs(OUTPUT_DIR, exist_ok=True)
20
 
21
+ # Charger YOLO (segmentation) une seule fois
22
+ model_irm = YOLO(MODEL_IRM_PATH)
23
 
24
+ # ⚠️ Adapte ce BASE_URL selon ton déploiement (ex : https://tondomaine.tld)
25
+ BASE_URL = "https://mediscan.caba31.com"
26
 
27
+ # Mapping optionnel des classes (si ton modèle prédit des classes).
28
+ # Si ton modèle ne prédit pas de classes, on utilisera "Lésion suspecte".
29
+ CLASS_LABELS_IRM = {
30
+ 1: "Lésion suspecte",
31
+ 0: "Anomalie secondaire",
32
+ # Ajoute/édite selon les classes réelles de ton modèle
33
  }
34
 
35
  # -----------------------------
36
  # 2. Génération de rapport
37
  # -----------------------------
38
+ def generate_report_irm(results, image_shape=None) -> str:
39
+ """
40
+ Génère un rapport texte simple à partir des résultats YOLO segmentation.
41
+ On compte ici le nombre de masques et, si des boxes existent, on donne
42
+ une estimation de la surface relative (en % de l'image).
43
+ """
44
+ # Sécurité si pas de masques
45
+ if results[0].masks is None or len(results[0].masks.data) == 0:
46
+ return "=== RAPPORT AUTOMATIQUE IRM ===\n\nAucun masque détecté.\n"
47
 
48
+ rapport = "=== RAPPORT AUTOMATIQUE IRM ===\n\n"
49
+ # Nombre de masques détectés
50
+ n_masks = len(results[0].masks.data)
51
+ rapport += f"Nombre de masques détectés : {n_masks}\n\n"
52
 
53
+ # Récupérer boxes si présentes
54
+ boxes = results[0].boxes
55
+ has_boxes = len(boxes) > 0
56
+
57
+ # Taille image (hauteur, largeur) si fournie
58
+ img_h, img_w = (None, None)
59
+ if image_shape is not None:
60
+ img_h, img_w = image_shape[0], image_shape[1]
61
+
62
+ # Récupérer classes si présentes
63
+ detected_classes = None
64
+ try:
65
+ if has_boxes:
66
+ detected_classes = boxes.cls.cpu().numpy().astype(int)
67
+ except Exception:
68
+ detected_classes = None
69
+
70
+ for i in range(n_masks):
71
+ # Classe si disponible (sinon texte générique)
72
+ if detected_classes is not None and i < len(detected_classes):
73
+ cls_id = int(detected_classes[i])
74
+ label = CLASS_LABELS_IRM.get(cls_id, f"Classe {cls_id}")
75
+ else:
76
+ label = "Lésion suspecte"
77
+
78
+ rapport += f"- Masque {i+1}: {label}\n"
79
+
80
+ # Si on a une boîte, calculer surface approximative
81
+ if has_boxes and i < len(boxes):
82
+ try:
83
+ xyxy = boxes.xyxyn[i] # valeurs normalisées (0..1) si disponible
84
+ # parfois boxes.xyxyn peut exister ; sinon fallback sur xyxy
85
+ if xyxy is not None and len(xyxy) == 4:
86
+ # xyxyn = (x1_norm, y1_norm, x2_norm, y2_norm)
87
+ x1n, y1n, x2n, y2n = [float(x) for x in xyxy]
88
+ if img_h and img_w:
89
+ width = (x2n - x1n) * img_w
90
+ height = (y2n - y1n) * img_h
91
+ area = max(width * height, 0.0)
92
+ percent = (area / (img_w * img_h)) * 100 if (img_w and img_h) else None
93
+ if percent is not None:
94
+ rapport += f" - Surface estimée: {percent:.2f}% de l'image\n"
95
+ else:
96
+ # fallback using absolute xyxy
97
+ xyxy_abs = boxes.xyxy[i].cpu().numpy()
98
+ x1, y1, x2, y2 = xyxy_abs
99
+ width = max(x2 - x1, 0.0)
100
+ height = max(y2 - y1, 0.0)
101
+ if img_h and img_w:
102
+ area = width * height
103
+ percent = (area / (img_w * img_h)) * 100
104
+ rapport += f" - Surface estimée: {percent:.2f}% de l'image\n"
105
+ except Exception:
106
+ # si erreur, on ignore l'estimation de surface
107
+ pass
108
 
109
  rapport += "\nRecommandations :\n"
110
+ rapport += "- Corréler ces résultats avec le tableau clinique et un radiologue.\n"
111
+ rapport += "- Si la lésion est significative, envisager un suivi/consultation spécialisée.\n"
112
 
113
  return rapport
114
 
115
  # -----------------------------
116
  # 3. FastAPI
117
  # -----------------------------
118
+ app = FastAPI(title="IRM Segmentation API")
119
  app.mount("/files", StaticFiles(directory=OUTPUT_DIR), name="files")
120
 
121
+ @app.post("/predict_irm/")
122
+ async def predict_irm(image_file: UploadFile = File(...), conf: float = 0.8, show_labels: bool = True):
123
  """
124
+ Endpoint qui reçoit une image IRM (upload) et renvoie :
125
+ - une image annotée (masques/boxes) accessible via URL
126
+ - un fichier texte de rapport accessible via URL
127
  """
128
  # Sauvegarde temporaire
129
+ tmp_path = f"/tmp/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{image_file.filename}"
130
  with open(tmp_path, "wb") as f:
131
  f.write(await image_file.read())
132
 
133
+ try:
134
+ # Charger image avec PIL -> RGB
135
+ image = Image.open(tmp_path).convert("RGB")
136
+ np_img = np.array(image)
137
+
138
+ # Conversion en BGR pour OpenCV (et pour ultralytics si nécessaire)
139
+ if np_img.shape[2] == 4:
140
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2BGR)
141
+ else:
142
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
143
+
144
+ # Prédiction
145
+ results = model_irm.predict(source=np_img, conf=conf, verbose=False)
146
+
147
+ # Vérifier présence de masques
148
+ if results[0].masks is None or len(results[0].masks.data) == 0:
149
+ # Nettoyage input
150
+ os.remove(tmp_path)
151
+ return {"message": "⚠️ Aucun masque détecté par le modèle IRM."}
152
+
153
+ # Annoter image (affiche masques + boxes)
154
+ annotated_image = results[0].plot(labels=show_labels)
155
+ # Définir timestamp une seule fois
156
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
157
+
158
+ # 2. Mask image
159
+ out_mask_name = f"mask_{timestamp}.png"
160
+ out_mask_path = os.path.join(OUTPUT_DIR, out_mask_name)
161
+ mask = np.zeros(np_img.shape[:2], dtype=np.uint8)
162
+ cv2.rectangle(mask, (50, 50), (200, 200), 255, -1)
163
+ cv2.imwrite(out_mask_path, mask)
164
+
165
+ # Sauvegarder sortie image
166
+
167
+ out_img_name = f"irm_result_{timestamp}.png"
168
+ out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
169
+
170
+
171
+ # Si annotated_image est PIL Image, convertir ; sinon sauver directement si numpy
172
+ if hasattr(annotated_image, "save"):
173
+ # PIL Image
174
+ annotated_image.save(out_img_path)
175
+ else:
176
+ # numpy array (probablement BGR)
177
+ cv2.imwrite(out_img_path, annotated_image)
178
+ cv2.imwrite(out_mask_path, mask)
179
+ # Générer & sauvegarder rapport
180
+ img_shape = np_img.shape[:2] # (h, w)
181
+
182
+ rapport_text = generate_report_irm(results)
183
+ out_txt_name = f"rapport_irm_{timestamp}.txt"
184
+ out_txt_path = os.path.join(OUTPUT_DIR, out_txt_name)
185
+ with open(out_txt_path, "w", encoding="utf-8") as f:
186
+ f.write(rapport_text)
187
+
188
+ # Nettoyage input
189
  os.remove(tmp_path)
190
+
191
+ return {
192
+ "annotated_result_url": f"{BASE_URL}/files/{out_img_name}",
193
+ "rapport_url": f"{BASE_URL}/files/{out_txt_name}",
194
+ "message": "✅ Prédiction réussie avec rapport"
195
+
196
+ "message": "✅ Prédiction réussie avec rapport"
197
+ }
198
+
199
+ except Exception as e:
200
+ # Nettoyage fichier temporaire si existant
201
+ try:
202
+ if os.path.exists(tmp_path):
203
+ os.remove(tmp_path)
204
+ except Exception:
205
+ pass
206
+ return {"error": str(e)}
 
 
 
 
 
 
 
 
 
207
 
208
  # -----------------------------
209
  # 4. Auto-cleanup toutes les 10 min