TRANSFERBABONG / app.py
LewisBabong's picture
Rename NTONGA_BABONG_TRANSFER_LEARNING_STREAMLIT_SN.py to app.py
db06b97 verified
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
# Fonction pour charger le modèle (architecture identique à celle utilisée lors de l'entraînement)
def load_model():
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# Charger les poids sauvegardés (assurez-vous que "Ntonga_brain_tumor_model.pth" est dans le même dossier)
model.load_state_dict(torch.load("Ntonga_brain_tumor_model.pth", map_location=torch.device('cpu')))
model.eval()
return model
model = load_model()
# Transformation pour préparer l'image uploadée (utilise PIL pour redimensionner)
transform = transforms.Compose([
transforms.Lambda(lambda img: img.resize((224, 224))),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Configuration de la page et style global
st.set_page_config(page_title="Diagnostic Tumeur Cérébrale", page_icon="🧠", layout="wide")
st.markdown("""
<style>
body { background-color: #f4f4f9; font-family: 'Segoe UI', sans-serif; }
.sidebar .sidebar-content { background-color: #e8f5e9; }
h1, h2, h3 { color: #283593; text-align: center; }
.stButton>button { background-color: #283593; color: white; border-radius: 8px; }
.uploaded-image { border: 2px solid #283593; border-radius: 8px; }
.header { text-align: center; margin-bottom: 20px; }
</style>
""", unsafe_allow_html=True)
# Sidebar : Options et informations complémentaires
st.sidebar.title("Menu")
st.sidebar.markdown("### Options")
confidence_threshold = st.sidebar.slider("Seuil de confiance", 0.0, 1.0, 0.5, step=0.05)
show_details = st.sidebar.checkbox("Afficher détails graphiques", value=True)
st.sidebar.markdown("### À propos")
st.sidebar.info(
"Cette application utilise un modèle de transfert learning basé sur ResNet18 pour classifier "
"les IRM cérébrales en deux catégories : Healthy et Tumor. "
"Pour un diagnostic complet, consultez un professionnel de santé."
)
# Titre et description de la page principale
st.markdown("<div class='header'><h1>Diagnostic IRM Cérébrale 🧠</h1></div>", unsafe_allow_html=True)
st.subheader("Analyse et Détection de Tumeurs")
st.write("Téléversez une ou plusieurs images d'IRM du cerveau pour détecter la présence d'une tumeur.")
# Zone de téléchargement des images
uploaded_files = st.file_uploader(
"Uploader une image (JPG, PNG, JPEG)",
type=["jpg", "png", "jpeg"],
accept_multiple_files=True
)
# Liste pour stocker les prédictions et images pour une éventuelle comparaison
predictions_history = []
if uploaded_files:
for uploaded_file in uploaded_files:
# Chargement et affichage de l'image
image = Image.open(uploaded_file)
st.image(image, caption=f"Image chargée : {uploaded_file.name}", use_column_width=True)
# Prétraitement de l'image
image_proc = transform(image).unsqueeze(0)
# Prédiction avec animation
with st.spinner(f"Analyse de {uploaded_file.name} en cours..."):
outputs = model(image_proc)
softmax = torch.nn.Softmax(dim=1)
probs = softmax(outputs).detach().numpy()[0]
_, pred = torch.max(outputs, 1)
# Définition des classes et résultat
class_names = ["Healthy", "Tumor"]
prediction = class_names[pred.item()]
predictions_history.append({
'filename': uploaded_file.name,
'prediction': prediction,
'probs': probs,
'image': image
})
# Affichage du résultat pour chaque image
st.success(f"Résultat pour **{uploaded_file.name}** : {prediction}")
st.write(f"**Score Healthy :** {probs[0]:.2f} | **Score Tumor :** {probs[1]:.2f}")
if np.max(probs) < confidence_threshold:
st.warning("La confiance du modèle est faible pour cette image.")
# Création d'onglets pour afficher plusieurs graphiques par image
tab1, tab2, tab3 = st.tabs(["Graphique Barres", "Graphique Camembert", "Détails & Radar"])
with tab1:
fig_bar, ax_bar = plt.subplots(figsize=(6, 4))
colors = ["green", "red"]
ax_bar.bar(class_names, probs, color=colors)
ax_bar.set_ylim(0, 1)
ax_bar.set_ylabel("Probabilité")
ax_bar.set_title("Répartition des Probabilités")
st.pyplot(fig_bar)
with tab2:
fig_pie, ax_pie = plt.subplots(figsize=(6, 4))
ax_pie.pie(probs, labels=class_names, autopct="%1.1f%%", startangle=90, colors=colors)
ax_pie.axis('equal')
ax_pie.set_title("Distribution en Pourcentage")
st.pyplot(fig_pie)
with tab3:
st.write("### Détails de la Prédiction")
st.write(f"Image : **{uploaded_file.name}**")
st.write(f"Prédiction : **{prediction}**")
st.write(f"Score Healthy : **{probs[0]:.2f}**")
st.write(f"Score Tumor : **{probs[1]:.2f}**")
# Graphique radar pour visualiser les scores
angles = np.linspace(0, 2 * np.pi, len(class_names), endpoint=False)
scores = probs.tolist()
scores += scores[:1]
angles = np.concatenate((angles, [angles[0]]))
fig_radar, ax_radar = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
ax_radar.plot(angles, scores, 'o-', linewidth=2, label=uploaded_file.name)
ax_radar.fill(angles, scores, alpha=0.25)
ax_radar.set_thetagrids(angles * 180 / np.pi, class_names + [class_names[0]])
ax_radar.set_ylim(0, 1)
ax_radar.set_title("Graphique Radar des Probabilités")
st.pyplot(fig_radar)
st.markdown("---")
# Analyse agrégée si plusieurs images ont été téléversées
if len(predictions_history) > 1:
st.subheader("Analyse Agrégée")
# Tableau récapitulatif
df = pd.DataFrame({
"Nom de l'image": [d["filename"] for d in predictions_history],
"Prédiction": [d["prediction"] for d in predictions_history],
"Score Healthy": [d["probs"][0] for d in predictions_history],
"Score Tumor": [d["probs"][1] for d in predictions_history]
})
st.table(df)
# Histogramme des scores de confiance
fig_hist, ax_hist = plt.subplots(figsize=(8, 4))
all_confidences = [max(d["probs"]) for d in predictions_history]
ax_hist.hist(all_confidences, bins=10, color="purple", alpha=0.7)
ax_hist.set_xlabel("Score de Confiance")
ax_hist.set_ylabel("Fréquence")
ax_hist.set_title("Distribution des Scores de Confiance")
st.pyplot(fig_hist)
# Graphique en lignes pour comparer les scores entre images
fig_line, ax_line = plt.subplots(figsize=(8, 4))
x_labels = [d["filename"] for d in predictions_history]
healthy_scores = [d["probs"][0] for d in predictions_history]
tumor_scores = [d["probs"][1] for d in predictions_history]
ax_line.plot(x_labels, healthy_scores, marker="o", linestyle="--", label="Healthy", color="green")
ax_line.plot(x_labels, tumor_scores, marker="o", linestyle="--", label="Tumor", color="red")
ax_line.set_ylim(0, 1)
ax_line.set_ylabel("Score")
ax_line.set_title("Comparaison des Scores par Image")
ax_line.legend()
plt.xticks(rotation=45)
st.pyplot(fig_line)
# Bouton pour télécharger les résultats en CSV
csv = df.to_csv(index=False).encode('utf-8')
st.download_button(
label="Télécharger les Résultats",
data=csv,
file_name="resultats_predictions.csv",
mime="text/csv"
)
# Pied de page
st.markdown("""
<hr>
<div style='text-align: center;'>
<em>Développé par l'équipe IA - Version améliorée</em>
</div>
""", unsafe_allow_html=True)