Stroke-ia commited on
Commit
50e76fd
·
verified ·
1 Parent(s): 43b917b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -105
app.py CHANGED
@@ -4,90 +4,44 @@ from ultralytics import YOLO
4
  import cv2, os
5
  from datetime import datetime
6
  import numpy as np
7
- import mediapipe as mp
8
- from dotenv import load_dotenv
9
-
10
- # ---------------- Charger config ----------------
11
- load_dotenv()
12
- SAVE_LIMIT_FREE = int(os.getenv("SAVE_LIMIT_FREE", 5))
13
- PREMIUM_KEY = os.getenv("PREMIUM_KEY", "VOTRE_CLE_PREMIUM")
14
 
15
  # ---------------- Config générale ----------------
16
  MODEL_PATH = "best.pt"
17
  SAVE_DIR = os.path.join("/tmp", "results")
18
  os.makedirs(SAVE_DIR, exist_ok=True)
19
 
20
- # Charger le modèle YOLO une seule fois
21
  @st.cache_resource
22
  def load_model():
23
  return YOLO(MODEL_PATH)
24
 
25
  model = load_model()
26
 
27
- # ---------------- MediaPipe Face Detection (chargé 1 seule fois) ----------------
28
- mp_face_detection = mp.solutions.face_detection.FaceDetection(min_detection_confidence=0.6)
29
-
30
- def _largest_face_bbox(np_img):
31
- h, w = np_img.shape[:2]
32
- results = mp_face_detection.process(cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR))
33
- if not results.detections:
34
- return None
35
- boxes = []
36
- for det in results.detections:
37
- rel = det.location_data.relative_bounding_box
38
- x1 = int(max(0, rel.xmin) * w)
39
- y1 = int(max(0, rel.ymin) * h)
40
- x2 = int(min(1.0, rel.xmin + rel.width) * w)
41
- y2 = int(min(1.0, rel.ymin + rel.height) * h)
42
- boxes.append((x1, y1, x2, y2))
43
- boxes.sort(key=lambda b: (b[2]-b[0])*(b[3]-b[1]), reverse=True)
44
- return boxes[0] if boxes else None
45
-
46
- # ---------------- Etat utilisateur ----------------
47
- if "uploads_count" not in st.session_state:
48
- st.session_state.uploads_count = 0
49
- if "premium_access" not in st.session_state:
50
- st.session_state.premium_access = False
51
 
52
  # ---------------- Fonctions utilitaires ----------------
53
- def check_limit():
54
- """Vérifie la limite gratuite."""
55
- if not st.session_state.premium_access and st.session_state.uploads_count >= SAVE_LIMIT_FREE:
56
- st.warning(f"⚠️ Limite gratuite atteinte ({SAVE_LIMIT_FREE} uploads). Passez en mode premium pour continuer.")
57
- return False
58
- return True
59
-
60
- def predict_image(image, conf=0.85, show_labels=True):
61
- if not check_limit():
62
- return None
63
- np_img = np.array(image)
64
-
65
- # Détection visage obligatoire
66
- face_bbox = _largest_face_bbox(np_img)
67
- if face_bbox is None:
68
- st.warning("⚠️ Aucun visage humain détecté. Veuillez centrer le visage.")
69
- return None
70
-
71
- if np_img.shape[2] == 4:
72
- np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2BGR)
73
  else:
74
- np_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
75
-
76
- results = model.predict(source=np_img, conf=conf, verbose=False)
77
- if len(results[0].boxes) == 0:
78
- return None
79
 
 
80
  annotated_image = results[0].plot(labels=show_labels)
 
81
  out_path = os.path.join(SAVE_DIR, f"image_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
82
  cv2.imwrite(out_path, annotated_image)
83
-
84
- st.session_state.uploads_count += 1
85
  return out_path
86
 
87
- def predict_video(video_path, conf=0.85, show_labels=True):
88
- if not check_limit():
89
- return None
90
-
91
  cap = cv2.VideoCapture(video_path)
92
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
93
  out_path = os.path.join(SAVE_DIR, f"video_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4")
@@ -95,76 +49,85 @@ def predict_video(video_path, conf=0.85, show_labels=True):
95
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
96
  width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
97
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
98
- detections = 0
99
 
100
  while cap.isOpened():
101
  ret, frame = cap.read()
102
  if not ret:
103
  break
104
  results = model.predict(frame, conf=conf, verbose=False)
105
- if len(results[0].boxes) > 0:
106
- detections += 1
107
  annotated = results[0].plot(labels=show_labels)
108
  out.write(annotated)
109
 
110
  cap.release()
111
  out.release()
112
- if detections == 0:
113
- return None
114
-
115
- st.session_state.uploads_count += 1
116
  return out_path
117
 
118
  # ---------------- Interface Streamlit ----------------
119
  st.title("🧠 Stroke-IA – Détection AVC par IA")
120
 
121
- # ---------------- Sidebar ----------------
122
- st.sidebar.header("⚙️ Paramètres utilisateur")
123
- conf_threshold = st.sidebar.slider("Seuil de confiance", 0.1, 1.0, 0.85, 0.05)
124
  show_labels = st.sidebar.checkbox("Afficher les labels", value=True)
125
 
126
- st.sidebar.header("🔑 Premium / Essai")
127
- if not st.session_state.premium_access:
128
- user_key = st.sidebar.text_input("Entrez votre clé premium :", type="password")
129
- if user_key == PREMIUM_KEY:
130
- st.session_state.premium_access = True
131
- st.sidebar.success(" Mode premium activé ! La limitation est levée.")
 
 
 
 
 
 
 
132
 
133
- # Afficher compteur
134
- if not st.session_state.premium_access:
135
- st.sidebar.info(f"📊 Utilisation gratuite : {st.session_state.uploads_count}/{SAVE_LIMIT_FREE}")
 
 
 
136
 
137
- # ---------------- Upload vidéo ----------------
138
  st.header("🎥 Détection sur vidéo")
139
- video_file = st.file_uploader("Uploader une vidéo", type=["mp4", "mov"])
140
- if video_file and st.button("Analyser la vidéo"):
141
- temp_path = os.path.join(SAVE_DIR, "temp_video.mp4")
142
- with open(temp_path, "wb") as f:
143
- f.write(video_file.read())
144
- result_path = predict_video(temp_path, conf=conf_threshold, show_labels=show_labels)
145
- if result_path is None:
146
- st.success(f" Aucun AVC détecté ou limite gratuite atteinte.")
147
- else:
 
 
 
 
 
148
  st.video(result_path)
149
 
150
- # ---------------- Upload image ----------------
151
  st.header("🖼️ Détection sur image")
152
- image_file = st.file_uploader("Uploader une image", type=["jpg", "jpeg", "png"])
153
- if image_file and st.button("Analyser l'image"):
154
- image = Image.open(image_file)
155
- result_path = predict_image(image, conf=conf_threshold, show_labels=show_labels)
156
- if result_path is None:
157
- st.success(f" Aucun AVC détecté ou limite gratuite atteinte.")
158
- else:
 
 
 
 
 
159
  st.image(result_path, caption="Image annotée", use_container_width=True)
160
 
161
- # ---------------- Disclaimer ----------------
162
  st.markdown(f"""
163
  ---
164
- 👨‍💻 **Badsi Djilali** — Ingénieur Deep Learning
165
- 🚀 Créateur de **Stroke_IA_Detection**
166
- 🧠 (Détection d'asymétrie faciale & AVC par IA)
167
-
168
  ⚠️ **Disclaimer :** Stroke-IA est une démo technique, pas un avis médical.
169
  © {datetime.now().year} — Badsi Djilali.
170
  """)
 
4
  import cv2, os
5
  from datetime import datetime
6
  import numpy as np
 
 
 
 
 
 
 
7
 
8
  # ---------------- Config générale ----------------
9
  MODEL_PATH = "best.pt"
10
  SAVE_DIR = os.path.join("/tmp", "results")
11
  os.makedirs(SAVE_DIR, exist_ok=True)
12
 
13
+ # ---------------- Charger le modèle (1 seule fois) ----------------
14
  @st.cache_resource
15
  def load_model():
16
  return YOLO(MODEL_PATH)
17
 
18
  model = load_model()
19
 
20
+ # ---------------- Limitation compte gratuit ----------------
21
+ MAX_IMAGES = 3
22
+ MAX_VIDEOS = 1
23
+
24
+ if "image_count" not in st.session_state:
25
+ st.session_state.image_count = 0
26
+ if "video_count" not in st.session_state:
27
+ st.session_state.video_count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # ---------------- Fonctions utilitaires ----------------
30
+ def predict_image(image, conf=0.25, show_labels=True):
31
+ image = np.array(image)
32
+ if image.shape[2] == 4:
33
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  else:
35
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
 
 
 
 
36
 
37
+ results = model.predict(source=image, conf=conf, verbose=False)
38
  annotated_image = results[0].plot(labels=show_labels)
39
+
40
  out_path = os.path.join(SAVE_DIR, f"image_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
41
  cv2.imwrite(out_path, annotated_image)
 
 
42
  return out_path
43
 
44
+ def predict_video(video_path, conf=0.25, show_labels=True):
 
 
 
45
  cap = cv2.VideoCapture(video_path)
46
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
47
  out_path = os.path.join(SAVE_DIR, f"video_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4")
 
49
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
50
  width, height = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
51
  out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
 
52
 
53
  while cap.isOpened():
54
  ret, frame = cap.read()
55
  if not ret:
56
  break
57
  results = model.predict(frame, conf=conf, verbose=False)
 
 
58
  annotated = results[0].plot(labels=show_labels)
59
  out.write(annotated)
60
 
61
  cap.release()
62
  out.release()
 
 
 
 
63
  return out_path
64
 
65
  # ---------------- Interface Streamlit ----------------
66
  st.title("🧠 Stroke-IA – Détection AVC par IA")
67
 
68
+ # Sidebar (paramètres utilisateur)
69
+ st.sidebar.header("⚙️ Paramètres")
70
+ conf_threshold = st.sidebar.slider("Seuil de confiance", 0.1, 1.0, 0.25, 0.05)
71
  show_labels = st.sidebar.checkbox("Afficher les labels", value=True)
72
 
73
+ # Sidebar quota global
74
+ st.sidebar.header("📊 Utilisation gratuite")
75
+ st.sidebar.write(f"🖼️ Images utilisées : **{st.session_state.image_count}/{MAX_IMAGES}**")
76
+ st.sidebar.write(f"🎥 Vidéos utilisées : **{st.session_state.video_count}/{MAX_VIDEOS}**")
77
+
78
+ st.sidebar.header("📂 Exemples rapides")
79
+ if st.sidebar.button("Tester une image exemple"):
80
+ if os.path.exists("example.jpg"):
81
+ img = Image.open("example.jpg")
82
+ path = predict_image(img, conf=conf_threshold, show_labels=show_labels)
83
+ st.image(path, caption="Exemple annoté", use_container_width=True)
84
+ else:
85
+ st.warning("⚠️ Aucun fichier example.jpg trouvé.")
86
 
87
+ if st.sidebar.button("Tester une vidéo exemple"):
88
+ if os.path.exists("example.mp4"):
89
+ path = predict_video("example.mp4", conf=conf_threshold, show_labels=show_labels)
90
+ st.video(path)
91
+ else:
92
+ st.warning("⚠️ Aucun fichier example.mp4 trouvé.")
93
 
94
+ # Section vidéo upload
95
  st.header("🎥 Détection sur vidéo")
96
+
97
+ remaining_videos = MAX_VIDEOS - st.session_state.video_count
98
+ st.info(f"🎬 Il vous reste **{remaining_videos} vidéo(s)** gratuite(s).")
99
+
100
+ if st.session_state.video_count >= MAX_VIDEOS:
101
+ st.error("⚠️ Limite vidéo atteinte. Passez en premium pour continuer.")
102
+ else:
103
+ video_file = st.file_uploader("Uploader une vidéo (mp4, mov, etc.)", type=["mp4", "mov"], key="video")
104
+ if video_file and st.button("Analyser la vidéo"):
105
+ st.session_state.video_count += 1
106
+ temp_path = os.path.join(SAVE_DIR, "temp_video.mp4")
107
+ with open(temp_path, "wb") as f:
108
+ f.write(video_file.read())
109
+ result_path = predict_video(temp_path, conf=conf_threshold, show_labels=show_labels)
110
  st.video(result_path)
111
 
112
+ # Section image upload
113
  st.header("🖼️ Détection sur image")
114
+
115
+ remaining_images = MAX_IMAGES - st.session_state.image_count
116
+ st.info(f"🖼️ Il vous reste **{remaining_images} image(s)** gratuite(s).")
117
+
118
+ if st.session_state.image_count >= MAX_IMAGES:
119
+ st.error("⚠️ Limite images atteinte. Passez en premium pour continuer.")
120
+ else:
121
+ image_file = st.file_uploader("Uploader une image", type=["jpg", "jpeg", "png"], key="image")
122
+ if image_file and st.button("Analyser l'image"):
123
+ st.session_state.image_count += 1
124
+ image = Image.open(image_file)
125
+ result_path = predict_image(image, conf=conf_threshold, show_labels=show_labels)
126
  st.image(result_path, caption="Image annotée", use_container_width=True)
127
 
128
+ # Disclaimer
129
  st.markdown(f"""
130
  ---
 
 
 
 
131
  ⚠️ **Disclaimer :** Stroke-IA est une démo technique, pas un avis médical.
132
  © {datetime.now().year} — Badsi Djilali.
133
  """)