Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| import pandas as pd | |
| # Fonction pour charger le modèle (architecture identique à celle utilisée lors de l'entraînement) | |
| def load_model(): | |
| model = models.resnet18(pretrained=False) | |
| num_ftrs = model.fc.in_features | |
| model.fc = nn.Linear(num_ftrs, 2) | |
| # Charger les poids sauvegardés (assurez-vous que "Ntonga_brain_tumor_model.pth" est dans le même dossier) | |
| model.load_state_dict(torch.load("Ntonga_brain_tumor_model.pth", map_location=torch.device('cpu'))) | |
| model.eval() | |
| return model | |
| model = load_model() | |
| # Transformation pour préparer l'image uploadée (utilise PIL pour redimensionner) | |
| transform = transforms.Compose([ | |
| transforms.Lambda(lambda img: img.resize((224, 224))), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]) | |
| ]) | |
| # Configuration de la page et style global | |
| st.set_page_config(page_title="Diagnostic Tumeur Cérébrale", page_icon="🧠", layout="wide") | |
| st.markdown(""" | |
| <style> | |
| body { background-color: #f4f4f9; font-family: 'Segoe UI', sans-serif; } | |
| .sidebar .sidebar-content { background-color: #e8f5e9; } | |
| h1, h2, h3 { color: #283593; text-align: center; } | |
| .stButton>button { background-color: #283593; color: white; border-radius: 8px; } | |
| .uploaded-image { border: 2px solid #283593; border-radius: 8px; } | |
| .header { text-align: center; margin-bottom: 20px; } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # Sidebar : Options et informations complémentaires | |
| st.sidebar.title("Menu") | |
| st.sidebar.markdown("### Options") | |
| confidence_threshold = st.sidebar.slider("Seuil de confiance", 0.0, 1.0, 0.5, step=0.05) | |
| show_details = st.sidebar.checkbox("Afficher détails graphiques", value=True) | |
| st.sidebar.markdown("### À propos") | |
| st.sidebar.info( | |
| "Cette application utilise un modèle de transfert learning basé sur ResNet18 pour classifier " | |
| "les IRM cérébrales en deux catégories : Healthy et Tumor. " | |
| "Pour un diagnostic complet, consultez un professionnel de santé." | |
| ) | |
| # Titre et description de la page principale | |
| st.markdown("<div class='header'><h1>Diagnostic IRM Cérébrale 🧠</h1></div>", unsafe_allow_html=True) | |
| st.subheader("Analyse et Détection de Tumeurs") | |
| st.write("Téléversez une ou plusieurs images d'IRM du cerveau pour détecter la présence d'une tumeur.") | |
| # Zone de téléchargement des images | |
| uploaded_files = st.file_uploader( | |
| "Uploader une image (JPG, PNG, JPEG)", | |
| type=["jpg", "png", "jpeg"], | |
| accept_multiple_files=True | |
| ) | |
| # Liste pour stocker les prédictions et images pour une éventuelle comparaison | |
| predictions_history = [] | |
| if uploaded_files: | |
| for uploaded_file in uploaded_files: | |
| # Chargement et affichage de l'image | |
| image = Image.open(uploaded_file) | |
| st.image(image, caption=f"Image chargée : {uploaded_file.name}", use_column_width=True) | |
| # Prétraitement de l'image | |
| image_proc = transform(image).unsqueeze(0) | |
| # Prédiction avec animation | |
| with st.spinner(f"Analyse de {uploaded_file.name} en cours..."): | |
| outputs = model(image_proc) | |
| softmax = torch.nn.Softmax(dim=1) | |
| probs = softmax(outputs).detach().numpy()[0] | |
| _, pred = torch.max(outputs, 1) | |
| # Définition des classes et résultat | |
| class_names = ["Healthy", "Tumor"] | |
| prediction = class_names[pred.item()] | |
| predictions_history.append({ | |
| 'filename': uploaded_file.name, | |
| 'prediction': prediction, | |
| 'probs': probs, | |
| 'image': image | |
| }) | |
| # Affichage du résultat pour chaque image | |
| st.success(f"Résultat pour **{uploaded_file.name}** : {prediction}") | |
| st.write(f"**Score Healthy :** {probs[0]:.2f} | **Score Tumor :** {probs[1]:.2f}") | |
| if np.max(probs) < confidence_threshold: | |
| st.warning("La confiance du modèle est faible pour cette image.") | |
| # Création d'onglets pour afficher plusieurs graphiques par image | |
| tab1, tab2, tab3 = st.tabs(["Graphique Barres", "Graphique Camembert", "Détails & Radar"]) | |
| with tab1: | |
| fig_bar, ax_bar = plt.subplots(figsize=(6, 4)) | |
| colors = ["green", "red"] | |
| ax_bar.bar(class_names, probs, color=colors) | |
| ax_bar.set_ylim(0, 1) | |
| ax_bar.set_ylabel("Probabilité") | |
| ax_bar.set_title("Répartition des Probabilités") | |
| st.pyplot(fig_bar) | |
| with tab2: | |
| fig_pie, ax_pie = plt.subplots(figsize=(6, 4)) | |
| ax_pie.pie(probs, labels=class_names, autopct="%1.1f%%", startangle=90, colors=colors) | |
| ax_pie.axis('equal') | |
| ax_pie.set_title("Distribution en Pourcentage") | |
| st.pyplot(fig_pie) | |
| with tab3: | |
| st.write("### Détails de la Prédiction") | |
| st.write(f"Image : **{uploaded_file.name}**") | |
| st.write(f"Prédiction : **{prediction}**") | |
| st.write(f"Score Healthy : **{probs[0]:.2f}**") | |
| st.write(f"Score Tumor : **{probs[1]:.2f}**") | |
| # Graphique radar pour visualiser les scores | |
| angles = np.linspace(0, 2 * np.pi, len(class_names), endpoint=False) | |
| scores = probs.tolist() | |
| scores += scores[:1] | |
| angles = np.concatenate((angles, [angles[0]])) | |
| fig_radar, ax_radar = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True)) | |
| ax_radar.plot(angles, scores, 'o-', linewidth=2, label=uploaded_file.name) | |
| ax_radar.fill(angles, scores, alpha=0.25) | |
| ax_radar.set_thetagrids(angles * 180 / np.pi, class_names + [class_names[0]]) | |
| ax_radar.set_ylim(0, 1) | |
| ax_radar.set_title("Graphique Radar des Probabilités") | |
| st.pyplot(fig_radar) | |
| st.markdown("---") | |
| # Analyse agrégée si plusieurs images ont été téléversées | |
| if len(predictions_history) > 1: | |
| st.subheader("Analyse Agrégée") | |
| # Tableau récapitulatif | |
| df = pd.DataFrame({ | |
| "Nom de l'image": [d["filename"] for d in predictions_history], | |
| "Prédiction": [d["prediction"] for d in predictions_history], | |
| "Score Healthy": [d["probs"][0] for d in predictions_history], | |
| "Score Tumor": [d["probs"][1] for d in predictions_history] | |
| }) | |
| st.table(df) | |
| # Histogramme des scores de confiance | |
| fig_hist, ax_hist = plt.subplots(figsize=(8, 4)) | |
| all_confidences = [max(d["probs"]) for d in predictions_history] | |
| ax_hist.hist(all_confidences, bins=10, color="purple", alpha=0.7) | |
| ax_hist.set_xlabel("Score de Confiance") | |
| ax_hist.set_ylabel("Fréquence") | |
| ax_hist.set_title("Distribution des Scores de Confiance") | |
| st.pyplot(fig_hist) | |
| # Graphique en lignes pour comparer les scores entre images | |
| fig_line, ax_line = plt.subplots(figsize=(8, 4)) | |
| x_labels = [d["filename"] for d in predictions_history] | |
| healthy_scores = [d["probs"][0] for d in predictions_history] | |
| tumor_scores = [d["probs"][1] for d in predictions_history] | |
| ax_line.plot(x_labels, healthy_scores, marker="o", linestyle="--", label="Healthy", color="green") | |
| ax_line.plot(x_labels, tumor_scores, marker="o", linestyle="--", label="Tumor", color="red") | |
| ax_line.set_ylim(0, 1) | |
| ax_line.set_ylabel("Score") | |
| ax_line.set_title("Comparaison des Scores par Image") | |
| ax_line.legend() | |
| plt.xticks(rotation=45) | |
| st.pyplot(fig_line) | |
| # Bouton pour télécharger les résultats en CSV | |
| csv = df.to_csv(index=False).encode('utf-8') | |
| st.download_button( | |
| label="Télécharger les Résultats", | |
| data=csv, | |
| file_name="resultats_predictions.csv", | |
| mime="text/csv" | |
| ) | |
| # Pied de page | |
| st.markdown(""" | |
| <hr> | |
| <div style='text-align: center;'> | |
| <em>Développé par l'équipe IA - Version améliorée</em> | |
| </div> | |
| """, unsafe_allow_html=True) | |