# Stroke-IA: Streamlit demo application.
#
# Runs YOLO models on uploaded photos/videos (face-based stroke detection),
# on 2D MRI slices (tumour segmentation and stroke detection) and forwards
# 4-sequence 3D MRI volumes to ``irm_cancer_module``.  Free usage is capped
# per session unless a premium key is entered in the sidebar.
import contextlib
import hmac
import os
import tempfile
import threading
from datetime import datetime

import streamlit as st
from PIL import Image
from ultralytics import YOLO
import cv2
import numpy as np
from dotenv import load_dotenv

import irm_cancer_module

# ---------------- Load configuration ----------------
load_dotenv()
SAVE_LIMIT_FREE = int(os.getenv("SAVE_LIMIT_FREE", 5))
# NOTE(review): the fallback value is a public placeholder; set PREMIUM_KEY in
# the environment, otherwise anyone typing the placeholder unlocks premium.
PREMIUM_KEY = os.getenv("PREMIUM_KEY", "VOTRE_CLE_PREMIUM")

# ---------------- General configuration ----------------
MODEL_PATH = "best.pt"
MODEL_IRM_PATH = "best_seg.pt"
MODEL_STROKE_PATH = "stroke.pt"
SAVE_DIR = os.path.join("/tmp", "results")
os.makedirs(SAVE_DIR, exist_ok=True)


# ---------------- Load YOLO models ----------------
@st.cache_resource
def _load_model(path):
    """Load the YOLO weights at *path* once per server process.

    Streamlit re-executes this script on every widget interaction; without
    caching, all three models would be re-read from disk on each rerun.
    """
    return YOLO(path)


@st.cache_resource
def _inference_lock():
    """Return the process-wide lock that serializes YOLO inference.

    The cached models are shared by every session, so concurrent ``predict``
    calls on the same instance must not overlap.
    """
    return threading.Lock()


model = _load_model(MODEL_PATH)
model_irm = _load_model(MODEL_IRM_PATH)
model_stroke = _load_model(MODEL_STROKE_PATH)

# ---------------- User state ----------------
# Counters start at 0, the premium flag starts at False.
for key in ["uploads_count", "uploads_count_irm", "uploads_count_stroke", "premium_access"]:
    if key not in st.session_state:
        st.session_state[key] = 0 if "count" in key else False


# ---------------- Utility functions ----------------
def _to_rgb_array(image):
    """Return *image* (PIL image or array) as a 3-channel RGB ``ndarray``.

    Grayscale, palette and RGBA inputs are normalized so that downstream code
    can rely on ``shape == (h, w, 3)``.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))
    arr = np.asarray(image)
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    if arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2RGB)
    return arr


def _to_bgr_array(image):
    """Return *image* as a 3-channel BGR ``ndarray`` (OpenCV channel order)."""
    return cv2.cvtColor(_to_rgb_array(image), cv2.COLOR_RGB2BGR)


def _run_model(yolo, source, conf):
    """Run ``yolo.predict`` on *source* while holding the inference lock."""
    with _inference_lock():
        return yolo.predict(source=source, conf=conf, verbose=False)


def _result_path(prefix, extension):
    """Build a timestamped output path inside ``SAVE_DIR``.

    Microseconds are included so that two results produced within the same
    second do not overwrite each other.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    return os.path.join(SAVE_DIR, f"{prefix}_result_{stamp}.{extension}")


def _save_annotated(result, prefix, show_labels):
    """Render *result* with its detections drawn and save it as a PNG.

    Returns the path of the written file.
    """
    out_path = _result_path(prefix, "png")
    cv2.imwrite(out_path, result.plot(labels=show_labels))
    return out_path


def _largest_face_bbox(np_img):
    """Return the largest detected face as ``(x1, y1, x2, y2)`` pixels, or ``None``.

    *np_img* is an RGB (or grayscale / RGBA) array.  MediaPipe expects RGB
    input, so the array is only normalized to 3 channels, not swapped to BGR.
    """
    import mediapipe as mp  # imported lazily: heavy and only needed here

    mp_face_detection = mp.solutions.face_detection
    rgb = _to_rgb_array(np_img)
    h, w = rgb.shape[:2]
    with mp_face_detection.FaceDetection(min_detection_confidence=0.6) as fd:
        results = fd.process(rgb)
    if not results.detections:
        return None

    boxes = []
    for det in results.detections:
        rel = det.location_data.relative_bounding_box
        # Relative coordinates may fall slightly outside [0, 1]; clamp them.
        x1 = int(max(0, rel.xmin) * w)
        y1 = int(max(0, rel.ymin) * h)
        x2 = int(min(1.0, rel.xmin + rel.width) * w)
        y2 = int(min(1.0, rel.ymin + rel.height) * h)
        boxes.append((x1, y1, x2, y2))
    return max(boxes, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]))


def check_limit(counter_name="uploads_count"):
    """Check the free-tier quota for the session counter *counter_name*.

    Returns ``True`` when the user may proceed; otherwise shows a warning and
    returns ``False``.  Premium sessions are never limited.
    """
    if not st.session_state.premium_access and st.session_state[counter_name] >= SAVE_LIMIT_FREE:
        st.warning(f"⚠️ Limite gratuite atteinte ({SAVE_LIMIT_FREE} uploads). Passez en mode premium pour continuer.")
        return False
    return True


# ---------------- Standard image prediction ----------------
def predict_image(image, conf=0.85, show_labels=True):
    """Run the face model on *image* and save the annotated result.

    Returns the path of the annotated PNG, or ``None`` when the quota is
    exhausted, no face is found, or the model detects nothing.
    """
    if not check_limit("uploads_count"):
        return None

    rgb = _to_rgb_array(image)
    if _largest_face_bbox(rgb) is None:
        st.warning("⚠️ Aucun visage humain détecté. Veuillez centrer le visage.")
        return None

    results = _run_model(model, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR), conf)
    if len(results[0].boxes) == 0:
        return None

    out_path = _save_annotated(results[0], "image", show_labels)
    st.session_state.uploads_count += 1
    return out_path


# ---------------- Video prediction ----------------
def predict_video(video_path, conf=0.85, show_labels=True):
    """Annotate every frame of the video at *video_path*.

    Returns the path of the annotated MP4, or ``None`` when the quota is
    exhausted, the video cannot be opened, or no frame contains a detection.
    """
    if not check_limit("uploads_count"):
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        st.warning("⚠️ Impossible de lire la vidéo.")
        return None

    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some containers report 0 fps
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_path = _result_path("video", "mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    detections = 0  # number of frames with at least one box
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = _run_model(model, frame, conf)
            if len(results[0].boxes) > 0:
                detections += 1
            out.write(results[0].plot(labels=show_labels))
    finally:
        # Release the handles even if inference raises.
        cap.release()
        out.release()

    if detections == 0:
        return None
    st.session_state.uploads_count += 1
    return out_path


# ---------------- MRI prediction ----------------
def predict_image_irm(image, conf=0.8, show_labels=True):
    """Run the MRI segmentation model on *image* and save the annotated result.

    Returns the path of the annotated PNG, or ``None`` when the quota is
    exhausted or the model produces no mask.
    """
    if not check_limit("uploads_count_irm"):
        return None

    results = _run_model(model_irm, _to_bgr_array(image), conf)
    masks = results[0].masks
    if masks is None or len(masks.data) == 0:
        st.warning("⚠️ Aucun masque détecté par le modèle IRM.")
        return None

    out_path = _save_annotated(results[0], "irm", show_labels)
    st.session_state.uploads_count_irm += 1
    return out_path


# ---------------- Stroke MRI prediction ----------------
def predict_image_stroke(image, conf=0.8, show_labels=True):
    """Run the stroke MRI model on *image* and save the annotated result.

    Returns the path of the annotated PNG, or ``None`` when the quota is
    exhausted or the model detects nothing.
    """
    if not check_limit("uploads_count_stroke"):
        return None

    results = _run_model(model_stroke, _to_bgr_array(image), conf)
    if len(results[0].boxes) == 0:
        st.warning("⚠️ Aucun AVC détecté par le modèle Stroke.")
        return None

    out_path = _save_annotated(results[0], "stroke", show_labels)
    st.session_state.uploads_count_stroke += 1
    return out_path


# ---------------- Streamlit interface ----------------
st.title("🧠 Stroke-IA Détection AVC par IA")

# ---------------- Sidebar ----------------
st.sidebar.header("⚙️ Paramètres utilisateur")
conf_threshold = st.sidebar.slider("Seuil de confiance (images/vidéos)", 0.1, 1.0, 0.85, 0.05, key="conf_slider")
conf_threshold_irm = st.sidebar.slider("Seuil de confiance (IRM)", 0.1, 1.0, 0.8, 0.05, key="conf_slider_irm")
conf_threshold_stroke = st.sidebar.slider("Seuil de confiance (Stroke IRM)", 0.1, 1.0, 0.8, 0.05, key="conf_slider_stroke")
show_labels = st.sidebar.checkbox("Afficher les labels", value=True, key="labels_checkbox")

st.sidebar.header("🔑 Premium / Essai")
if not st.session_state.premium_access:
    user_key = st.sidebar.text_input("Entrez votre clé premium :", type="password", key="premium_input")
    # Constant-time comparison; an empty input never unlocks premium.
    if user_key and hmac.compare_digest(user_key.encode("utf-8"), PREMIUM_KEY.encode("utf-8")):
        st.session_state.premium_access = True
        st.sidebar.success("✅ Mode premium activé ! La limitation est levée.")
        st.rerun()

if not st.session_state.premium_access:
    st.sidebar.info(f"📊 Utilisation gratuite images/vidéos : {st.session_state.uploads_count}/{SAVE_LIMIT_FREE}")
    st.sidebar.info(f"📊 Utilisation gratuite IRM : {st.session_state.uploads_count_irm}/{SAVE_LIMIT_FREE}")
    st.sidebar.info(f"📊 Utilisation gratuite Stroke IRM : {st.session_state.uploads_count_stroke}/{SAVE_LIMIT_FREE}")

# ---------------- Video upload ----------------
st.header("🎥 Détection sur vidéo")
video_file = st.file_uploader("Uploader une vidéo", type=["mp4", "mov"], key="video_uploader")
if video_file and st.button("Analyser la vidéo", key="video_button"):
    # Unique temp file per request (keeps the real extension) so concurrent
    # sessions cannot overwrite each other's upload.
    suffix = os.path.splitext(video_file.name)[1] or ".mp4"
    with tempfile.NamedTemporaryFile(dir=SAVE_DIR, suffix=suffix, delete=False) as tmp:
        tmp.write(video_file.getvalue())
        temp_path = tmp.name
    try:
        result_path = predict_video(temp_path, conf=conf_threshold, show_labels=show_labels)
    finally:
        # Best-effort cleanup of the uploaded copy.
        with contextlib.suppress(OSError):
            os.remove(temp_path)
    if result_path is None:
        st.success("✅ Aucun AVC détecté ou limite gratuite atteinte.")
    else:
        st.video(result_path)

# ---------------- Image upload ----------------
st.header("🖼️ Détection sur image")
image_file = st.file_uploader("Uploader une image", type=["jpg", "jpeg", "png"], key="image_uploader")
if image_file and st.button("Analyser l'image", key="image_button"):
    image = Image.open(image_file)
    result_path = predict_image(image, conf=conf_threshold, show_labels=show_labels)
    if result_path is None:
        st.success("✅ Aucun AVC détecté ou limite gratuite atteinte.")
    else:
        st.image(result_path, caption="Image annotée", use_container_width=True)

# ---------------- MRI upload ----------------
st.header("🧠 Détection CANCER par IRM")
irm_file = st.file_uploader("Uploader une IRM", type=["jpg", "jpeg", "png"], key="irm_uploader")
if irm_file and st.button("Analyser l'IRM", key="irm_button"):
    irm_image = Image.open(irm_file)
    result_path_irm = predict_image_irm(irm_image, conf=conf_threshold_irm, show_labels=show_labels)
    if result_path_irm is None:
        st.success("✅ Aucun résultat détecté ou limite gratuite atteinte.")
    else:
        st.image(result_path_irm, caption="IRM annotée", use_container_width=True)

# ---------------- Stroke MRI upload ----------------
st.header("🧠 Détection AVC par IRM")
stroke_file = st.file_uploader("Uploader une IRM pour Stroke", type=["jpg", "jpeg", "png"], key="stroke_uploader")
if stroke_file and st.button("Analyser l'IRM Stroke", key="stroke_button"):
    stroke_image = Image.open(stroke_file)
    result_path_stroke = predict_image_stroke(stroke_image, conf=conf_threshold_stroke, show_labels=show_labels)
    if result_path_stroke is None:
        st.success("✅ Aucun résultat détecté ou limite gratuite atteinte.")
    else:
        st.image(result_path_stroke, caption="Stroke annotée", use_container_width=True)

# ---------------- 3D MRI upload (cancer) ----------------
st.header("🧠 Détection Tumeur (IRM 3D)")
irm3d_files = st.file_uploader("Uploader 4 séquences (FLAIR, T1, T1CE, T2)", type=["nii", "nii.gz"], accept_multiple_files=True)
if irm3d_files and st.button("Analyser IRM 3D"):
    if len(irm3d_files) != 4:
        st.error("⚠️ Merci d’uploader exactement 4 fichiers IRM (FLAIR, T1, T1CE, T2)")
    else:
        tmp_paths = []
        for f in irm3d_files:
            # The file name comes from the client: keep only its last
            # component so it cannot escape SAVE_DIR.
            path = os.path.join(SAVE_DIR, os.path.basename(f.name))
            with open(path, "wb") as out:
                out.write(f.getvalue())
            tmp_paths.append(path)
        seg, report_text, (nii_path, report_path, mask_path) = irm_cancer_module.run(tmp_paths)
        st.subheader("📝 Rapport automatique")
        st.text(report_text)
        if mask_path and os.path.exists(mask_path):
            st.image(mask_path, caption="Segmentation annotée", use_container_width=True)

# ---------------- Disclaimer ----------------
st.markdown(f"""
---
👨‍💻 **Badsi Djilali** — Ingénieur Deep Learning
🚀 Créateur de **Stroke_IA_Detection**
⚠️ Démo technique, pas un avis médical.
© {datetime.now().year} — Badsi Djilali.
""")