Spaces:
Sleeping
Sleeping
import io
import os
import threading
import time
import uuid
from datetime import datetime

import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.staticfiles import StaticFiles
from PIL import Image
from ultralytics import YOLO
# -----------------------------
# 1. Config & Model
# -----------------------------
# Deployment-specific settings. Each one can be overridden with an environment
# variable and falls back to the value this service was originally built with,
# so an unconfigured deployment behaves exactly as before.
MODEL_STROKE_PATH = os.environ.get("STROKE_MODEL_PATH", "stroke.pt")
OUTPUT_DIR = os.environ.get("STROKE_OUTPUT_DIR", "/tmp/outputs")
# Generated files are written here and served statically, so it must exist
# before the app mounts it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the YOLO weights once at import time; every request reuses this instance.
model_stroke = YOLO(MODEL_STROKE_PATH)

# Public base URL used to build the download links returned by the API.
# A trailing slash is stripped so links never contain "//files/".
BASE_URL = os.environ.get(
    "STROKE_BASE_URL", "https://stroke-ia-avc-detect.hf.space"
).rstrip("/")

# Maps model class ids to the labels written in the text report.
# NOTE(review): these ids must follow the class order the weights were trained
# with — confirm against the model's own class names.
CLASS_LABELS = {
    0: "Hémorragie intracrânienne",
    1: "Suspicion de zone ischémique",
    2: "Normale Brain",
}
# -----------------------------
# 2. Report generation
# -----------------------------
def generate_report(results) -> str:
    """Render the detections of ``results[0]`` as a plain-text report.

    Args:
        results: sequence of prediction results as returned by
            ``model_stroke.predict``; only the first entry is read.

    Returns:
        A short "no anomaly" text when ``results[0].boxes`` is empty.
        Otherwise, a report with the number of boxes, one line per box
        labelled through ``CLASS_LABELS`` (unknown ids are rendered as
        "Classe inconnue <id>"), and a fixed list of recommendations.
    """
    detections = results[0].boxes
    if len(detections) == 0:
        return "=== RAPPORT AUTOMATIQUE ===\n\nAucune anomalie détectée.\n"

    parts = [
        "=== RAPPORT AUTOMATIQUE AVC ===\n\n",
        f"Nombre de lésions détectées : {len(detections)}\n\n",
    ]

    # Cast the class ids to integers so they can match CLASS_LABELS' int keys.
    class_ids = detections.cls.cpu().numpy().astype(int)
    parts.extend(
        "- Lésion {}: {}\n".format(
            rank, CLASS_LABELS.get(class_id, f"Classe inconnue {class_id}")
        )
        for rank, class_id in enumerate(class_ids, start=1)
    )

    parts += [
        "\nRecommandations :\n",
        "- Vérifier la concordance clinique.\n",
        "- Considérer un suivi neurologique urgent.\n",
    ]
    return "".join(parts)
# -----------------------------
# 3. FastAPI application
# -----------------------------
app = FastAPI(title="Stroke Detection API")

# Serve everything written to OUTPUT_DIR (annotated images and text reports)
# as static downloads under /files/<filename>. This is where the URLs returned
# by the API point.
app.mount(path="/files", app=StaticFiles(directory=OUTPUT_DIR), name="files")
# NOTE(review): no route was registered for this handler, so it was unreachable.
# The path is named after the function; adjust it if clients expect another one.
@app.post("/predict_stroke")
async def predict_stroke(image_file: UploadFile = File(...), conf: float = 0.8):
    """Run stroke detection on an uploaded MRI image and publish the results.

    The annotated image and the text report are written to OUTPUT_DIR and
    exposed through the ``/files`` static mount.

    Args:
        image_file: uploaded image, in any format Pillow can decode.
        conf: minimum confidence threshold passed to ``model_stroke.predict``.

    Returns:
        ``{"message": ...}`` when nothing is detected. Otherwise a dict with
        ``annotated_result_url``, ``rapport_url`` and ``message``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 if the annotated image cannot be written.
    """
    # Decode the upload in memory. The client-supplied filename is never used
    # as a filesystem path: it may contain "../" segments or collide with a
    # concurrent request. There is also no temporary file left behind when
    # decoding or inference raises.
    payload = await image_file.read()
    try:
        with Image.open(io.BytesIO(payload)) as source:
            image = source.convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Fichier image invalide ou illisible."
        ) from err

    # Ultralytics reads NumPy inputs as BGR (OpenCV order), so convert from RGB.
    np_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # NOTE(review): predict() is synchronous, so inference blocks the event loop
    # while it runs. It also keeps inference on the shared model sequential
    # within this process. Moving it to a thread pool would need a lock.
    results = model_stroke.predict(source=np_img, conf=conf, verbose=False)
    if len(results[0].boxes) == 0:
        return {"message": "⚠️ Aucun AVC détecté."}

    # Results.plot() returns a BGR array, which is what cv2.imwrite expects.
    annotated_image = results[0].plot(labels=True)

    # A random suffix keeps names unique when several requests land in the same
    # second. It also makes them hard to guess, since every file in OUTPUT_DIR
    # is publicly downloadable.
    stem = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex}"

    out_img_name = f"stroke_result_{stem}.png"
    out_img_path = os.path.join(OUTPUT_DIR, out_img_name)
    if not cv2.imwrite(out_img_path, annotated_image):
        raise HTTPException(
            status_code=500, detail="Échec de l'écriture de l'image annotée."
        )

    rapport_text = generate_report(results)
    out_txt_name = f"rapport_{stem}.txt"
    out_txt_path = os.path.join(OUTPUT_DIR, out_txt_name)
    with open(out_txt_path, "w", encoding="utf-8") as f:
        f.write(rapport_text)

    return {
        "annotated_result_url": f"{BASE_URL}/files/{out_img_name}",
        "rapport_url": f"{BASE_URL}/files/{out_txt_name}",
        "message": "✅ Prédiction réussie avec rapport",
    }
# -----------------------------
# 4. Auto-cleanup every 10 minutes
# -----------------------------
def auto_cleanup(interval_minutes=10, max_age_minutes=None):
    """Periodically delete expired generated files from OUTPUT_DIR; never returns.

    Every ``interval_minutes`` minutes, each regular file in OUTPUT_DIR whose
    modification time is at least ``max_age_minutes`` old is removed.
    ``max_age_minutes`` defaults to ``interval_minutes``, so a file survives
    at least one full interval (and at most about two) after it is written.

    Failures on individual files, e.g. a file removed concurrently, are printed
    and skipped. If the directory cannot be listed, the sweep is skipped and
    retried at the next interval, so the loop keeps running.
    """
    if max_age_minutes is None:
        max_age_minutes = interval_minutes
    max_age_seconds = max_age_minutes * 60

    while True:
        time.sleep(interval_minutes * 60)
        try:
            filenames = os.listdir(OUTPUT_DIR)
        except OSError as e:
            # Without this guard, one failed listing would end the daemon
            # thread and disable cleanup for the rest of the process lifetime.
            print(f"[CLEANUP] Erreur lecture {OUTPUT_DIR} : {e}")
            continue

        now = time.time()
        for filename in filenames:
            file_path = os.path.join(OUTPUT_DIR, filename)
            try:
                if not os.path.isfile(file_path):
                    continue
                # Keep recent files: their URLs were just returned to a client
                # that may not have downloaded them yet.
                if now - os.path.getmtime(file_path) < max_age_seconds:
                    continue
                os.remove(file_path)
                print(f"[CLEANUP] Fichier supprimé : {file_path}")
            except Exception as e:
                # Best-effort: never let one bad file stop the cleanup thread.
                print(f"[CLEANUP] Erreur suppression {file_path} : {e}")


threading.Thread(target=auto_cleanup, args=(10,), daemon=True).start()