# STREAMLITE / app_lent_v3.py
# Stroke-ia's picture
# Rename app.py to app_lent_v3.py
# dacdfa8 verified
import os
import tempfile
from datetime import datetime

import cv2
import numpy as np
import streamlit as st
from PIL import Image
from ultralytics import YOLO
# ---------------- General configuration ----------------
# Directory that receives the annotated images and videos; it is created
# at import time so the prediction helpers can write into it directly.
SAVE_DIR = os.path.join("/tmp", "results")
os.makedirs(SAVE_DIR, exist_ok=True)

# Path of the YOLO weights file handed to the model loader.
MODEL_PATH = "best.pt"
# ---------------- Load the model (only once) ----------------
@st.cache_resource
def load_model():
    """Build the YOLO detector from ``MODEL_PATH``.

    Decorated with ``st.cache_resource`` so the object is created once and
    reused on later script reruns instead of being rebuilt each time.
    """
    detector = YOLO(MODEL_PATH)
    return detector


# Module-level detector shared by the prediction helpers below.
model = load_model()
# ---------------- Free-account limits ----------------
MAX_IMAGES = 3
MAX_VIDEOS = 1

# Per-session usage counters: start each one at zero the first time the
# script runs for a session, and leave existing values untouched on reruns.
for _counter_key in ("image_count", "video_count"):
    if _counter_key not in st.session_state:
        st.session_state[_counter_key] = 0
# ---------------- Utility functions ----------------
def predict_image(image, conf=0.25, show_labels=True):
    """Run the detector on one image and save the annotated result as a PNG.

    Args:
        image: A PIL image, or an array-like that ``np.array`` turns into an
            H x W (grayscale) or H x W x C array with 1, 3 (RGB) or
            4 (RGBA) channels.
        conf: Confidence threshold forwarded to ``model.predict``.
        show_labels: Forwarded to ``plot(labels=...)`` on the first result.

    Returns:
        Path of the PNG file written under ``SAVE_DIR``.
    """
    # PIL images can be grayscale ("L"), palette ("P"), CMYK, ...; normalising
    # to RGB first avoids indexing a missing channel axis and avoids treating
    # a 4-channel CMYK image as RGBA.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    pixels = np.array(image)
    if pixels.ndim == 2:
        # Plain 2-D array: single-channel input without a channel axis.
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2BGR)
    elif pixels.shape[2] == 1:
        pixels = cv2.cvtColor(pixels[:, :, 0], cv2.COLOR_GRAY2BGR)
    elif pixels.shape[2] == 4:
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGR)
    else:
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

    results = model.predict(source=pixels, conf=conf, verbose=False)
    annotated_image = results[0].plot(labels=show_labels)

    # Microseconds in the name keep two requests issued within the same
    # second from overwriting each other's output in the shared directory.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    out_path = os.path.join(SAVE_DIR, f"image_result_{stamp}.png")
    cv2.imwrite(out_path, annotated_image)
    return out_path
def predict_video(video_path, conf=0.25, show_labels=True):
    """Annotate every frame of a video and write the result as an MP4 file.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        conf: Confidence threshold forwarded to ``model.predict``.
        show_labels: Forwarded to ``plot(labels=...)`` on each frame result.

    Returns:
        Path of the output file under ``SAVE_DIR``. If no frame could be
        read, the path is still returned and the file holds no frames.
    """
    # Microseconds in the name keep two requests issued within the same
    # second from overwriting each other's output in the shared directory.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    out_path = os.path.join(SAVE_DIR, f"video_result_{stamp}.mp4")

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        # CAP_PROP_FPS reads back as 0 when unknown; fall back to 30.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # NOTE(review): 'mp4v' output may not play in every browser through
        # st.video -- confirm on the deployment target.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            results = model.predict(frame, conf=conf, verbose=False)
            annotated = results[0].plot(labels=show_labels)
            out.write(annotated)
    finally:
        # Release both handles even when inference raises mid-stream, so the
        # capture and the partially written file are not left open.
        cap.release()
        if out is not None:
            out.release()
    return out_path
# ---------------- Streamlit interface ----------------
st.title("🧠 Stroke-IA – Détection AVC par IA")

with st.sidebar:
    # User-tunable detection parameters
    st.header("⚙️ Paramètres")
    conf_threshold = st.slider(
        "Seuil de confiance",
        min_value=0.1,
        max_value=1.0,
        value=0.25,
        step=0.05,
    )
    show_labels = st.checkbox("Afficher les labels", value=True)

    # Free-tier usage for the current session
    st.header("📊 Utilisation gratuite")
    st.write(f"🖼️ Images utilisées : **{st.session_state.image_count}/{MAX_IMAGES}**")
    st.write(f"🎥 Vidéos utilisées : **{st.session_state.video_count}/{MAX_VIDEOS}**")
# Sidebar shortcuts that run the detector on sample files looked up in the
# current working directory.
st.sidebar.header("📂 Exemples rapides")

if st.sidebar.button("Tester une image exemple"):
    if not os.path.exists("example.jpg"):
        st.warning("⚠️ Aucun fichier example.jpg trouvé.")
    else:
        sample_image = Image.open("example.jpg")
        sample_result = predict_image(
            sample_image, conf=conf_threshold, show_labels=show_labels
        )
        st.image(sample_result, caption="Exemple annoté", use_container_width=True)

if st.sidebar.button("Tester une vidéo exemple"):
    if not os.path.exists("example.mp4"):
        st.warning("⚠️ Aucun fichier example.mp4 trouvé.")
    else:
        sample_result = predict_video(
            "example.mp4", conf=conf_threshold, show_labels=show_labels
        )
        st.video(sample_result)
# Video upload section: one analysis per click while free quota remains.
st.header("🎥 Détection sur vidéo")
remaining_videos = MAX_VIDEOS - st.session_state.video_count
st.info(f"🎬 Il vous reste **{remaining_videos} vidéo(s)** gratuite(s).")
if st.session_state.video_count >= MAX_VIDEOS:
    st.error("⚠️ Limite vidéo atteinte. Passez en premium pour continuer.")
else:
    video_file = st.file_uploader("Uploader une vidéo (mp4, mov, etc.)", type=["mp4", "mov"], key="video")
    if video_file and st.button("Analyser la vidéo"):
        st.session_state.video_count += 1
        # A unique temp file per request: a fixed name in the shared
        # SAVE_DIR would let concurrent sessions overwrite each other's
        # upload. The original extension is kept (".mov" stays ".mov").
        suffix = os.path.splitext(video_file.name)[1] or ".mp4"
        fd, temp_path = tempfile.mkstemp(suffix=suffix, prefix="upload_", dir=SAVE_DIR)
        try:
            with os.fdopen(fd, "wb") as f:
                # getvalue() returns the whole upload regardless of the
                # buffer's current read position.
                f.write(video_file.getvalue())
            result_path = predict_video(temp_path, conf=conf_threshold, show_labels=show_labels)
        finally:
            # Best-effort cleanup so uploads do not pile up on disk.
            try:
                os.remove(temp_path)
            except OSError:
                pass
        st.video(result_path)
# Image upload section: one analysis per click while free quota remains.
st.header("🖼️ Détection sur image")
remaining_images = MAX_IMAGES - st.session_state.image_count
st.info(f"🖼️ Il vous reste **{remaining_images} image(s)** gratuite(s).")

if st.session_state.image_count < MAX_IMAGES:
    image_file = st.file_uploader("Uploader une image", type=["jpg", "jpeg", "png"], key="image")
    if image_file:
        # The button is only rendered once a file has been chosen.
        if st.button("Analyser l'image"):
            st.session_state.image_count += 1
            result_path = predict_image(
                Image.open(image_file),
                conf=conf_threshold,
                show_labels=show_labels,
            )
            st.image(result_path, caption="Image annotée", use_container_width=True)
else:
    st.error("⚠️ Limite images atteinte. Passez en premium pour continuer.")
# Footer: horizontal rule, disclaimer, and a copyright line that carries
# the current year.
st.markdown(
    f"""
---
⚠️ **Disclaimer :** Stroke-IA est une démo technique, pas un avis médical.
© {datetime.now().year} — Badsi Djilali.
"""
)