Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from PIL import Image | |
| from ultralytics import YOLO | |
| import cv2, os | |
| from datetime import datetime | |
| import numpy as np | |
| from dotenv import load_dotenv | |
# ---------------- Load configuration ----------------
load_dotenv()
# Number of analyses a free user may run per quota counter
# (images/videos share one counter, MRI has its own).
SAVE_LIMIT_FREE = int(os.getenv("SAVE_LIMIT_FREE", 5))
# Premium unlock key. Deliberately no hard-coded fallback: the former default
# "VOTRE_CLE_PREMIUM" is visible in the source, so anyone could type it to lift
# the free limit. An unset or empty variable yields None, which no text input
# can equal, so premium stays locked until a key is actually configured.
PREMIUM_KEY = os.getenv("PREMIUM_KEY") or None
# ---------------- General settings ----------------
MODEL_PATH = "best.pt"  # checkpoint used for photo/video detection
MODEL_IRM_PATH = "best_seg.pt"  # checkpoint used for MRI (mask) prediction
# Annotated outputs are written here; the same directory serves every session.
SAVE_DIR = os.path.join("/tmp", "results")
os.makedirs(SAVE_DIR, exist_ok=True)
# ---------------- Load YOLO models ----------------
# Streamlit re-executes this whole script on every widget interaction, so
# building the models at module level reloaded both checkpoints on each click.
# Keep one instance per browser session instead of a process-wide
# st.cache_resource: Ultralytics' thread-safety guide advises against sharing a
# single model instance across threads, and Streamlit may run several sessions
# concurrently.
if "yolo_model" not in st.session_state:
    st.session_state.yolo_model = YOLO(MODEL_PATH)  # photo/video detector
if "yolo_model_irm" not in st.session_state:
    st.session_state.yolo_model_irm = YOLO(MODEL_IRM_PATH)  # MRI model
model = st.session_state.yolo_model
model_irm = st.session_state.yolo_model_irm
# ---------------- Per-session user state ----------------
# Free-tier usage counters and the premium flag, created on the first run of a
# browser session and preserved across reruns afterwards.
_SESSION_DEFAULTS = {
    "uploads_count": 0,
    "uploads_count_irm": 0,
    "premium_access": False,
}
for _state_key, _initial in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# ---------------- Utility functions ----------------
def _to_rgb(np_img):
    """Return *np_img* as a C-contiguous 3-channel array, channel order untouched.

    Greyscale input -- shaped (H, W), (H, W, 1), or (H, W, 2) with alpha -- has
    its first channel replicated three times; a 4-channel input has its last
    (alpha) channel dropped; a 3-channel input is returned as-is.

    Raises:
        ValueError: for any other channel count.
    """
    if np_img.ndim == 2:
        np_img = np_img[:, :, np.newaxis]
    channels = np_img.shape[2]
    if channels in (1, 2):
        np_img = np.repeat(np_img[:, :, :1], 3, axis=2)
    elif channels == 4:
        np_img = np_img[:, :, :3]
    elif channels != 3:
        raise ValueError(f"Unsupported number of image channels: {channels}")
    # Slices such as [:, :, :3] are strided views; hand native consumers a
    # packed buffer.
    return np.ascontiguousarray(np_img)


def _pixel_box(rel, w, h):
    """Convert a relative bounding box into integer pixel corners (x1, y1, x2, y2).

    *rel* exposes ``xmin``, ``ymin``, ``width`` and ``height`` as fractions of
    the image size. The top-left corner is clamped at 0 and the bottom-right
    corner at the far image edge before scaling by *w* / *h*.
    """
    x1 = int(max(0.0, rel.xmin) * w)
    y1 = int(max(0.0, rel.ymin) * h)
    x2 = int(min(1.0, rel.xmin + rel.width) * w)
    y2 = int(min(1.0, rel.ymin + rel.height) * h)
    return x1, y1, x2, y2


def _largest_face_bbox(np_img):
    """Detect faces with MediaPipe and return the largest one's pixel box.

    Args:
        np_img: image array in RGB channel order, e.g. ``np.array(pil_image)``;
            greyscale and RGBA arrays are accepted and normalised first.

    Returns:
        ``(x1, y1, x2, y2)`` of the largest-area detection (the first one on
        ties), or None when no face reaches the 0.6 confidence threshold.
    """
    import mediapipe as mp  # local import: only this helper uses MediaPipe

    rgb = _to_rgb(np_img)
    h, w = rgb.shape[:2]
    face_detection = mp.solutions.face_detection
    with face_detection.FaceDetection(min_detection_confidence=0.6) as detector:
        # MediaPipe expects RGB input; the previous RGB->BGR conversion handed
        # it channel-swapped pixels.
        results = detector.process(rgb)
    if not results.detections:
        return None
    boxes = [
        _pixel_box(det.location_data.relative_bounding_box, w, h)
        for det in results.detections
    ]
    return max(boxes, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]))
def check_limit(counter_name="uploads_count"):
    """Tell whether the user may run one more analysis on quota *counter_name*.

    Premium users are never limited. Free users are refused, with a warning
    shown in the page, once ``st.session_state[counter_name]`` has reached
    SAVE_LIMIT_FREE.

    Returns:
        True if the analysis may proceed, False otherwise.
    """
    if st.session_state.premium_access:
        return True
    if st.session_state[counter_name] < SAVE_LIMIT_FREE:
        return True
    st.warning(f"⚠️ Limite gratuite atteinte ({SAVE_LIMIT_FREE} uploads). Passez en mode premium pour continuer.")
    return False
# ---------------- Standard image prediction ----------------
def predict_image(image, conf=0.85, show_labels=True):
    """Run the photo detector on a PIL image and save the annotated result.

    Enforces the shared "uploads_count" free quota, refuses images in which no
    human face is found, then runs ``model`` on the image.

    Args:
        image: PIL image in any mode (RGB, RGBA, L, P, ...).
        conf: minimum confidence passed to ``model.predict``.
        show_labels: whether class labels are drawn on the annotated output.

    Returns:
        Path of the annotated PNG written under SAVE_DIR, or None when the
        quota is exhausted, no face is found, or the model detects nothing.
        The quota counter is only incremented when a path is returned.
    """
    if not check_limit("uploads_count"):
        return None
    # Normalise to 3-channel RGB first: greyscale/palette images produce 2-D
    # arrays on which the previous ``shape[2]`` check raised IndexError.
    # RGBA alpha is dropped, exactly as the former RGBA->BGR conversion did.
    rgb = np.array(image.convert("RGB"))
    if _largest_face_bbox(rgb) is None:
        st.warning("⚠️ Aucun visage humain détecté. Veuillez centrer le visage.")
        return None
    # Ultralytics interprets numpy inputs with OpenCV's BGR channel order.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    results = model.predict(source=bgr, conf=conf, verbose=False)
    if len(results[0].boxes) == 0:
        return None
    annotated_image = results[0].plot(labels=show_labels)
    # Microsecond stamp: SAVE_DIR is shared by every session, so second
    # resolution let concurrent analyses overwrite each other's result.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    out_path = os.path.join(SAVE_DIR, f"image_result_{stamp}.png")
    cv2.imwrite(out_path, annotated_image)
    st.session_state.uploads_count += 1
    return out_path
# ---------------- Video prediction ----------------
def predict_video(video_path, conf=0.85, show_labels=True):
    """Run the detector on every frame of a video and write an annotated copy.

    Args:
        video_path: path of a video file readable by ``cv2.VideoCapture``.
        conf: minimum confidence passed to ``model.predict``.
        show_labels: whether class labels are drawn on each annotated frame.

    Returns:
        Path of the annotated MP4 under SAVE_DIR, or None when the free quota
        is exhausted, the video cannot be opened, or no frame contains a
        detection. Shares the "uploads_count" quota with predict_image; the
        counter is only incremented when a path is returned.
    """
    if not check_limit("uploads_count"):
        return None
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        st.error("❌ Impossible d'ouvrir la vidéo.")
        return None
    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back when FPS is reported as 0
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    # Microsecond stamp: SAVE_DIR is shared by every session, so second
    # resolution let concurrent analyses overwrite each other's output.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    out_path = os.path.join(SAVE_DIR, f"video_result_{stamp}.mp4")
    # NOTE(review): browsers often cannot play 'mp4v' MP4s through st.video;
    # 'avc1' may be needed if the installed OpenCV build supports it -- confirm.
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
    frames_with_detections = 0
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            result = model.predict(frame, conf=conf, verbose=False)[0]
            if len(result.boxes) > 0:
                frames_with_detections += 1
            out.write(result.plot(labels=show_labels))
    finally:
        # Release both handles even if inference raises mid-video.
        cap.release()
        out.release()
    if frames_with_detections == 0:
        try:
            os.remove(out_path)  # nothing to show; don't leave an orphaned file
        except OSError:
            pass  # best-effort cleanup: the writer may never have created it
        return None
    st.session_state.uploads_count += 1
    return out_path
# ---------------- MRI prediction ----------------
def predict_image_irm(image, conf=0.8, show_labels=True):
    """Run the MRI model on a PIL image and save the annotated result.

    Enforces the separate "uploads_count_irm" free quota and only accepts a
    prediction that contains at least one mask.

    Args:
        image: PIL image in any mode (RGB, RGBA, L, P, ...).
        conf: minimum confidence passed to ``model_irm.predict``.
        show_labels: whether class labels are drawn on the annotated output.

    Returns:
        Path of the annotated PNG written under SAVE_DIR, or None when the
        quota is exhausted or no mask is predicted. The quota counter is only
        incremented when a path is returned.
    """
    if not check_limit("uploads_count_irm"):
        return None
    # Normalise to 3-channel RGB first: greyscale ("L") images -- a common
    # format for MRI exports -- give 2-D arrays on which the previous
    # ``shape[2]`` check raised IndexError. RGBA alpha is dropped, as before.
    rgb = np.array(image.convert("RGB"))
    # Ultralytics interprets numpy inputs with OpenCV's BGR channel order.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    results = model_irm.predict(source=bgr, conf=conf, verbose=False)
    if results[0].masks is None or len(results[0].masks.data) == 0:
        st.warning("⚠️ Aucun masque détecté par le modèle IRM.")
        return None
    annotated_image = results[0].plot(labels=show_labels)
    # Microsecond stamp: SAVE_DIR is shared by every session, so second
    # resolution let concurrent analyses overwrite each other's result.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    out_path = os.path.join(SAVE_DIR, f"irm_result_{stamp}.png")
    cv2.imwrite(out_path, annotated_image)
    st.session_state.uploads_count_irm += 1
    return out_path
# ---------------- Streamlit interface ----------------
st.title("🧠 Stroke-IA – Détection AVC par IA")
# ---------------- Sidebar: analysis settings ----------------
with st.sidebar:
    st.header("⚙️ Paramètres utilisateur")
    # Minimum confidence forwarded to predict_image / predict_video.
    conf_threshold = st.slider(
        "Seuil de confiance (images/vidéos)",
        min_value=0.1,
        max_value=1.0,
        value=0.85,
        step=0.05,
        key="conf_slider",
    )
    # Minimum confidence forwarded to predict_image_irm.
    conf_threshold_irm = st.slider(
        "Seuil de confiance (IRM)",
        min_value=0.1,
        max_value=1.0,
        value=0.8,
        step=0.05,
        key="conf_slider_irm",
    )
    # Shared by all three analysis modes.
    show_labels = st.checkbox("Afficher les labels", value=True, key="labels_checkbox")
# ---------------- Premium access ----------------
import hmac  # constant-time comparison for the premium key

st.sidebar.header("🔑 Premium / Essai")
if st.session_state.premium_access:
    # Drawn on every run once unlocked. It used to be emitted just before
    # st.rerun(), whose fresh run no longer rendered it, so users never saw it.
    st.sidebar.success("✅ Mode premium activé ! La limitation est levée.")
else:
    user_key = st.sidebar.text_input("Entrez votre clé premium :", type="password", key="premium_input")
    # Require both a configured key and a non-empty entry, so an empty field
    # can never match an empty/unset PREMIUM_KEY; compare in constant time.
    key_ok = (
        bool(PREMIUM_KEY)
        and bool(user_key)
        and hmac.compare_digest(user_key.encode("utf-8"), PREMIUM_KEY.encode("utf-8"))
    )
    if key_ok:
        st.session_state.premium_access = True
        st.rerun()  # redraw without the key field and the usage counters
    st.sidebar.info(f"📊 Utilisation gratuite images/vidéos : {st.session_state.uploads_count}/{SAVE_LIMIT_FREE}")
    st.sidebar.info(f"📊 Utilisation gratuite IRM : {st.session_state.uploads_count_irm}/{SAVE_LIMIT_FREE}")
# ---------------- Video upload ----------------
st.header("🎥 Détection sur vidéo")
video_file = st.file_uploader("Uploader une vidéo", type=["mp4", "mov"], key="video_uploader")
if video_file and st.button("Analyser la vidéo", key="video_button"):
    import tempfile

    # One unique temp file per analysis. The former fixed "temp_video.mp4" in
    # the shared SAVE_DIR could be overwritten by another session mid-analysis
    # and was never deleted. Keep the upload's own extension (.mp4 / .mov).
    suffix = os.path.splitext(video_file.name)[1] or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, dir=SAVE_DIR, delete=False) as tmp:
        # getvalue() returns the whole upload regardless of the stream position.
        tmp.write(video_file.getvalue())
        temp_path = tmp.name
    try:
        result_path = predict_video(temp_path, conf=conf_threshold, show_labels=show_labels)
    finally:
        os.remove(temp_path)
    if result_path is None:
        st.success("✅ Aucun AVC détecté ou limite gratuite atteinte.")
    else:
        st.video(result_path)
# ---------------- Image upload ----------------
st.header("🖼️ Détection sur image")
image_file = st.file_uploader("Uploader une image", type=["jpg", "jpeg", "png"], key="image_uploader")
if image_file:
    # The analyse button is only drawn once a file has been uploaded.
    if st.button("Analyser l'image", key="image_button"):
        image = Image.open(image_file)
        result_path = predict_image(image, conf=conf_threshold, show_labels=show_labels)
        if result_path is not None:
            st.image(result_path, caption="Image annotée", use_container_width=True)
        else:
            # predict_image returns None for every non-result alike: quota
            # reached, no face found, or no detection.
            st.success("✅ Aucun AVC détecté ou limite gratuite atteinte.")
# ---------------- MRI upload ----------------
st.header("🧠 Détection sur IRM")
irm_file = st.file_uploader("Uploader une IRM", type=["jpg", "jpeg", "png"], key="irm_uploader")
if irm_file:
    # The analyse button is only drawn once a scan has been uploaded.
    if st.button("Analyser l'IRM", key="irm_button"):
        irm_image = Image.open(irm_file)
        result_path_irm = predict_image_irm(irm_image, conf=conf_threshold_irm, show_labels=show_labels)
        if result_path_irm is not None:
            st.image(result_path_irm, caption="IRM annotée", use_container_width=True)
        else:
            # predict_image_irm returns None both when the quota is reached
            # and when no mask is predicted.
            st.success("✅ Aucun résultat détecté ou limite gratuite atteinte.")
# ---------------- Disclaimer ----------------
# Markdown (CommonMark) joins consecutive non-blank lines into one paragraph,
# so each credit line is separated by a blank line to keep it on its own row.
st.markdown(f"""
---

👨‍💻 **Badsi Djilali** — Ingénieur Deep Learning

🚀 Créateur de **Stroke_IA_Detection**

🧠 (Détection d'asymétrie faciale & AVC par IA)

⚠️ **Disclaimer :** Stroke-IA est une démo technique, pas un avis médical.

© {datetime.now().year} — Badsi Djilali.
""")