Stroke-ia commited on
Commit
069a73c
·
verified ·
1 Parent(s): 6c73fc9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from ultralytics import YOLO
4
+ import cv2, os
5
+ from datetime import datetime
6
+ import numpy as np
7
+ from dotenv import load_dotenv
8
+ import nibabel as nib # Pour traitement NIfTI IRM
9
+
10
+ # ---------------- Charger config ----------------
11
+ load_dotenv()
12
+ SAVE_LIMIT_FREE = int(os.getenv("SAVE_LIMIT_FREE", 5))
13
+ PREMIUM_KEY = os.getenv("PREMIUM_KEY", "VOTRE_CLE_PREMIUM")
14
+
15
+ # ---------------- Config générale ----------------
16
+ SAVE_DIR = os.path.join("/tmp", "results")
17
+ os.makedirs(SAVE_DIR, exist_ok=True)
18
+
19
+ # ---------------- Modèles ----------------
20
+ MODEL_AVC_PATH = "best.pt"
21
+ model_avc = YOLO(MODEL_AVC_PATH)
22
+
23
+ MODEL_SEG_PATH = "best_seg.pt"
24
+ model_seg = YOLO(MODEL_SEG_PATH)
25
+
26
+ # ---------------- Etat utilisateur ----------------
27
+ if "uploads_count" not in st.session_state:
28
+ st.session_state.uploads_count = 0
29
+ if "premium_access" not in st.session_state:
30
+ st.session_state.premium_access = False
31
+
32
+ # ---------------- Fonctions utilitaires ----------------
33
+ def _largest_face_bbox(np_img):
34
+ import mediapipe as mp
35
+ mp_face_detection = mp.solutions.face_detection
36
+ h, w = np_img.shape[:2]
37
+ with mp_face_detection.FaceDetection(min_detection_confidence=0.6) as fd:
38
+ results = fd.process(cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR))
39
+ if not results.detections:
40
+ return None
41
+ boxes = []
42
+ for det in results.detections:
43
+ rel = det.location_data.relative_bounding_box
44
+ x1 = int(max(0, rel.xmin) * w)
45
+ y1 = int(max(0, rel.ymin) * h)
46
+ x2 = int(min(1.0, rel.xmin + rel.width) * w)
47
+ y2 = int(min(1.0, rel.ymin + rel.height) * h)
48
+ boxes.append((x1, y1, x2, y2))
49
+ boxes.sort(key=lambda b: (b[2]-b[0])*(b[3]-b[1]), reverse=True)
50
+ return boxes[0] if boxes else None
51
+
52
+ def check_limit():
53
+ if not st.session_state.premium_access and st.session_state.uploads_count >= SAVE_LIMIT_FREE:
54
+ st.warning(f"⚠️ Limite gratuite atteinte ({SAVE_LIMIT_FREE} uploads). Passez en mode premium pour continuer.")
55
+ return False
56
+ return True
57
+
58
+ def predict_image(image, conf=0.85, show_labels=True):
59
+ if not check_limit():
60
+ return None
61
+ np_img = np.array(image)
62
+ face_bbox = _largest_face_bbox(np_img)
63
+ if face_bbox is None:
64
+ st.warning("⚠️ Aucun visage humain détecté. Veuillez centrer le visage.")
65
+ return None
66
+ if np_img.shape[2] == 4:
67
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2BGR)
68
+ else:
69
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
70
+
71
+ results = model_avc.predict(source=np_img, conf=conf, verbose=False)
72
+ if len(results[0].boxes) == 0:
73
+ return None
74
+ annotated_image = results[0].plot(labels=show_labels)
75
+ out_path = os.path.join(SAVE_DIR, f"image_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
76
+ cv2.imwrite(out_path, annotated_image)
77
+ st.session_state.uploads_count += 1
78
+ return out_path
79
+
80
+ def predict_video(video_path, conf=0.85, show_labels=True):
81
+ if not check_limit():
82
+ return None
83
+ cap = cv2.VideoCapture(video_path)
84
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
85
+ out_path = os.path.join(SAVE_DIR, f"video_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4")
86
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
87
+ width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
88
+ out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
89
+ detections = 0
90
+ while cap.isOpened():
91
+ ret, frame = cap.read()
92
+ if not ret:
93
+ break
94
+ results = model_avc.predict(frame, conf=conf, verbose=False)
95
+ if len(results[0].boxes) > 0:
96
+ detections += 1
97
+ annotated = results[0].plot(labels=show_labels)
98
+ out.write(annotated)
99
+ cap.release()
100
+ out.release()
101
+ if detections == 0:
102
+ return None
103
+ st.session_state.uploads_count += 1
104
+ return out_path
105
+
106
+ def load_nii_as_image(file):
107
+ nii = nib.load(file)
108
+ data = nii.get_fdata()
109
+ slice_idx = data.shape[2] // 2 # coupe centrale
110
+ slice_img = data[:, :, slice_idx]
111
+ slice_img = ((slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255).astype(np.uint8)
112
+ return Image.fromarray(slice_img)
113
+
114
+ # ---------------- Interface Streamlit ----------------
115
+ st.title("🧠 Stroke-IA – Détection AVC par IA")
116
+
117
+ # ---------------- Sidebar ----------------
118
+ st.sidebar.header("⚙️ Paramètres utilisateur")
119
+ conf_threshold = st.sidebar.slider("Seuil de confiance AVC", 0.1, 1.0, 0.85, 0.05, key="conf_slider")
120
+ show_labels = st.sidebar.checkbox("Afficher les labels", value=True, key="labels_checkbox")
121
+
122
+ st.sidebar.header("🧠 Paramètres IRM")
123
+ conf_threshold_irm = st.sidebar.slider("Seuil de confiance IRM", 0.1, 1.0, 0.8, 0.05, key="conf_slider_irm")
124
+
125
+ st.sidebar.header("🔑 Premium / Essai")
126
+ if not st.session_state.premium_access:
127
+ user_key = st.sidebar.text_input("Entrez votre clé premium :", type="password", key="premium_input")
128
+ if user_key == PREMIUM_KEY:
129
+ st.session_state.premium_access = True
130
+ st.sidebar.success("✅ Mode premium activé ! La limitation est levée.")
131
+ st.rerun()
132
+
133
+ if not st.session_state.premium_access:
134
+ st.sidebar.info(f"📊 Utilisation gratuite : {st.session_state.uploads_count}/{SAVE_LIMIT_FREE}")
135
+
136
+ # ---------------- Upload vidéo ----------------
137
+ st.header("🎥 Détection sur vidéo")
138
+ video_file = st.file_uploader("Uploader une vidéo", type=["mp4", "mov"], key="video_uploader")
139
+ if video_file and st.button("Analyser la vidéo", key="video_button"):
140
+ temp_path = os.path.join(SAVE_DIR, "temp_video.mp4")
141
+ with open(temp_path, "wb") as f:
142
+ f.write(video_file.read())
143
+ result_path = predict_video(temp_path, conf=conf_threshold, show_labels=show_labels)
144
+ if result_path is None:
145
+ st.success(f"✅ Aucun AVC détecté ou limite gratuite atteinte.")
146
+ else:
147
+ st.video(result_path)
148
+
149
+ # ---------------- Upload image ----------------
150
+ st.header("🖼️ Détection sur image")
151
+ image_file = st.file_uploader("Uploader une image", type=["jpg", "jpeg", "png"], key="image_uploader")
152
+ if image_file and st.button("Analyser l'image", key="image_button"):
153
+ image = Image.open(image_file)
154
+ result_path = predict_image(image, conf=conf_threshold, show_labels=show_labels)
155
+ if result_path is None:
156
+ st.success(f"✅ Aucun AVC détecté ou limite gratuite atteinte.")
157
+ else:
158
+ st.image(result_path, caption="Image annotée", use_container_width=True)
159
+
160
+ # ---------------- Upload IRM ----------------
161
+ st.header("🧪 Détection sur IRM")
162
+ irm_file = st.file_uploader("Uploader une IRM", type=["jpg", "jpeg", "png", "nii"], key="irm_uploader")
163
+ if irm_file and st.button("Analyser l'IRM", key="irm_button"):
164
+ image = Image.open(irm_file) if not irm_file.name.endswith(".nii") else load_nii_as_image(irm_file)
165
+ np_img = np.array(image)
166
+ if np_img.shape[2] == 4:
167
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2BGR)
168
+ else:
169
+ np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
170
+ results = model_seg.predict(np_img, conf=conf_threshold_irm, verbose=False)
171
+ if len(results[0].masks.data) == 0:
172
+ st.warning("⚠️ Aucune anomalie détectée sur cette IRM.")
173
+ else:
174
+ annotated_img = results[0].plot()
175
+ out_path = os.path.join(SAVE_DIR, f"irm_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
176
+ cv2.imwrite(out_path, annotated_img)
177
+ st.image(out_path, caption="IRM annotée", use_container_width=True)
178
+
179
+ # ---------------- Disclaimer ----------------
180
+ st.markdown(f"""
181
+ ---
182
+ 👨‍💻 **Badsi Djilali** — Ingénieur Deep Learning
183
+ 🚀 Créateur de **Stroke_IA_Detection**
184
+ 🧠 (Détection d'asymétrie faciale & AVC par IA + anomalies IRM)
185
+
186
+ ⚠️ **Disclaimer :** Stroke-IA est une démo technique, pas un avis médical.
187
+ © {datetime.now().year} — Badsi Djilali.
188
+ """)