DavidNgoue's picture
Update app.py
836b6d3 verified
import streamlit as st
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
import torchvision.models as models
from torchvision.models import VGG16_Weights, ResNet50_Weights, MobileNet_V2_Weights
import os
from fpdf import FPDF
import io
# Taille d'entrée pour les modèles
INPUT_SIZE = 224
# Chemins des modèles sauvegardés (adapté pour PyTorch)
MODEL_PYTORCH_PATH = "VGG16_best_accuracy_0.9660.pth" # Exemple
# Chargement des classes
class_names = ['Healthy', 'Tumor']
# Chargement des modèles PyTorch (avec gestion des poids)
models_dict = {
"VGG16": models.vgg16(weights=VGG16_Weights.DEFAULT),
"ResNet50": models.resnet50(weights=ResNet50_Weights.DEFAULT),
"MobileNetV2": models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
}
def predict_with_pytorch(model_path, image, class_names):
try:
checkpoint = torch.load(model_path, map_location=torch.device('cpu')) # Charge sur CPU si GPU pas dispo
model_name = checkpoint['model_name']
model = models_dict[model_name]
# Gestion des incompatibilités de taille de la dernière couche (si nécessaire)
try:
model.load_state_dict(checkpoint['model_state_dict'])
except RuntimeError as e:
if "size mismatch for classifier.6" in str(e): # Exemple pour VGG16
num_features = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(num_features, len(class_names))
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
else:
raise
model.eval()
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_tensor = data_transforms(image).unsqueeze(0)
with torch.no_grad():
output = model(image_tensor)
probabilities = torch.nn.functional.softmax(output, dim=1)[0].tolist()
_, predicted_class = torch.max(output, 1)
predicted_label = class_names[predicted_class.item()]
return predicted_label, probabilities
except FileNotFoundError:
return "Erreur: Fichier de modèle non trouvé.", None
except Exception as e:
return f"Erreur lors de la prédiction: {e}", None
# Fonction pour générer un PDF avec les prédictions et recommandations
def generate_pdf(prediction, probabilities, image, model_name):
pdf = FPDF()
pdf.set_auto_page_break(auto=True, margin=15)
pdf.add_page()
pdf.set_font("Arial", style="B", size=16)
pdf.cell(200, 10, "Rapport de Prédiction", ln=True, align='C')
pdf.ln(10)
pdf.set_font("Arial", size=12)
pdf.cell(200, 10, f"Modèle utilisé: {model_name}", ln=True)
pdf.cell(200, 10, f"Prédiction: {prediction}", ln=True)
pdf.cell(200, 10, "Probabilités:", ln=True)
for i, prob in enumerate(probabilities):
pdf.cell(200, 10, f"{class_names[i]}: {prob:.2%}", ln=True)
pdf.ln(10)
pdf.cell(200, 10, "Recommandations:", ln=True)
recommendations = {
"Cyst": "Consultez un urologue pour une évaluation approfondie.",
"Normal": "Aucune intervention n'est requise.",
"Stone": "Hydratation et suivi avec un spécialiste recommandés.",
"Tumor": "Consultez un oncologue pour un diagnostic précis."
}
pdf.multi_cell(0, 10, recommendations.get(prediction, "Aucune recommandation disponible."))
# Sauvegarde et insertion de l'image
img_path = "temp_image.jpg"
image.save(img_path)
pdf.image(img_path, x=60, w=90)
# Générer le PDF en mémoire
pdf_output = io.BytesIO()
pdf_bytes = pdf.output(dest='S').encode('latin1') # Générer en mémoire
pdf_output.write(pdf_bytes)
pdf_output.seek(0)
return pdf_output
# Fonction JS pour animer les ballons
def show_balloons():
st.markdown("""
<script type="text/javascript">
function createBalloon(x, y) {
var balloon = document.createElement('div');
balloon.style.position = 'absolute';
balloon.style.left = x + 'px';
balloon.style.top = y + 'px';
balloon.style.width = '50px';
balloon.style.height = '50px';
balloon.style.background = 'url(https://example.com/balloon.png) no-repeat center center';
balloon.style.backgroundSize = 'contain';
balloon.style.animation = 'floatBalloon 6s ease-in-out infinite';
document.body.appendChild(balloon);
}
for (let i = 0; i < 5; i++) {
var x = Math.random() * window.innerWidth;
var y = Math.random() * window.innerHeight;
createBalloon(x, y);
}
</script>
""", unsafe_allow_html=True)
# CSS pour une interface moderne avec dégradés et icônes
st.markdown("""
<style>
body {
background: linear-gradient(135deg, #6e7fef, #f0c6d1);
font-family: 'Arial', sans-serif;
color: #333;
}
.title {
text-align: center;
font-size: 36px;
font-weight: bold;
color: #ff85a2;
text-shadow: 2px 2px 5px rgba(0, 0, 0, 0.3);
margin-top: 20px;
}
.upload-section {
text-align: center;
margin: 20px;
}
.btn-primary {
background-color: #6e7fef;
border-color: #6e7fef;
color: white;
font-weight: bold;
}
.btn-primary:hover {
background-color: #5573d7;
}
.result {
text-align: center;
font-size: 24px;
color: #444;
}
.about-section {
margin-top: 40px;
font-size: 16px;
line-height: 1.6;
color: #555;
}
.legend {
font-size: 14px;
color: #555;
}
.sidebar .sidebar-content {
background: linear-gradient(135deg, #ff85a2, #ffeb3b);
color: #333;
padding-top: 20px;
}
.sidebar .sidebar-content .block-container {
padding-left: 20px;
}
.sidebar .sidebar-content .block {
background-color: rgba(255, 255, 255, 0.7);
padding: 10px;
border-radius: 8px;
margin-bottom: 10px;
}
.stSidebar {
background: linear-gradient(45deg, #EE4C2C, #FF8126, #FFC32C);
color: white;
font-size: 16px;
}
.logo {
max-width: 200px;
margin: 20px auto;
}
.logo-description {
text-align: center;
font-size: 18px;
color: white;
margin-bottom: 30px;
}
.sidebar select {
width: 100%;
padding: 10px;
background-color: #f3f3f3;
border-radius: 8px;
font-size: 16px;
border: 1px solid #ccc;
}
@keyframes floatBalloon {
0% { transform: translateY(0); opacity: 1; }
50% { transform: translateY(-200px); opacity: 0.5; }
100% { transform: translateY(0); opacity: 1; }
}
</style>
""", unsafe_allow_html=True)
# Identifiants d'accès
USERNAME = "abdouramandalil"
PASSWORD = "transferlearning"
def check_login():
""" Vérifie l'authentification de l'utilisateur """
st.sidebar.header("Connexion")
username = st.sidebar.text_input("Nom d'utilisateur")
password = st.sidebar.text_input("Mot de passe", type="password")
if st.sidebar.button("Se connecter"):
if username == USERNAME and password == PASSWORD:
st.session_state["authenticated"] = True
st.sidebar.success("Connexion réussie !")
st.rerun()
else:
st.sidebar.error("Identifiants incorrects")
def logout():
st.session_state["authenticated"] = False
st.rerun()
if "authenticated" in st.session_state and st.session_state["authenticated"]:
if st.sidebar.button("Se déconnecter"):
logout()
# Vérification de la connexion
if "authenticated" not in st.session_state:
st.session_state["authenticated"] = False
if not st.session_state["authenticated"]:
check_login()
st.stop()
# Application Streamlit
st.markdown('<div class="title">Application de Classification d\'Images</div>', unsafe_allow_html=True)
menu = st.sidebar.selectbox("Menu", ["🧠 Accueil", "🧠 Classification de l'état cérébral avec VGG16", "Comparatif entre modèles", "👨‍💻À propos"])
if menu == "🧠 Accueil":
st.write("Bienvenue dans notre application de classification d'images. Cette application a été développée pour la classification d'images dans le contexte du **Brain Tumor MRI images**, un dataset médical contenant des images de tomodensitométrie (CT) de reins, avec des catégories représentant des reins normaux, des kystes, des tumeurs et des calculs rénaux.")
st.write("""
Cette application a pour objectif de classer ces images en différentes catégories (**Healthy** Aucun problème: , **Tumor**: Tumeur cérébrale) à l'aide de modèles de machine learning.
Nous avons effectué du Transfer Learning sur trois modèles pré entrainés et avons fait un comparatif des trois avant de choisir **VGG16** comme celui avec les meilleurs caractéristiques.
Ces modèles ont été entraînés sur le dataset Brain Tumor MRI images et peuvent être utilisés pour prédire la catégorie d'une image donnée, en détectant des anomalies ou en validant l'état du rein à partir des images CT.
""")
# Création des onglets pour chaque modèle
tab_resnet50, tab_vgg16, tab_mobilenetv2 = st.tabs(["ResNet50", "VGG16", "MobileNetV2"])
with tab_resnet50:
st.image("resnet50_image.webp", caption="ResNet50", width=700)
st.write("""
**ResNet50** est un réseau de neurones convolutif profond (CNN) très populaire pour la classification d'images, introduit dans l'article "Deep Residual Learning for Image Recognition".
### Avantages :
- Utilise des **connexions résiduelles** pour résoudre les problèmes de dégradation des performances lors de l'augmentation de la profondeur du réseau.
- Excellente précision sur de grandes bases de données d'images comme ImageNet.
- Convient pour le **fine-tuning** sur des données spécifiques grâce à son architecture pré-entraînée.
### Utilisation :
ResNet50 est largement utilisé dans des tâches comme :
- La reconnaissance d'objets.
- La segmentation d'images.
- La détection de maladies en imagerie médicale.
""")
with tab_vgg16:
st.image("vgg16_image.jpg", caption="VGG16", width=700)
st.write("""
**VGG16** est un modèle de CNN développé par l'équipe de recherche Visual Geometry Group (VGG). Il est connu pour sa simplicité et son efficacité dans la classification d'images.
### Avantages :
- Architecture simple avec des couches convolutives empilées suivies de couches entièrement connectées.
- Bonne généralisation, même pour des données en dehors de son domaine d'origine.
- Facilement extensible pour des tâches comme la segmentation et la détection.
### Utilisation :
VGG16 est utilisé pour :
- La classification d'images dans des bases de données variées.
- L'extraction de caractéristiques pour des modèles personnalisés.
- Les applications médicales nécessitant des modèles interprétables.
""")
with tab_mobilenetv2:
st.image("mobilenetv2.webp", caption="MobileNetV2", width=700)
st.write("""
**MobileNetV2** est un modèle léger optimisé pour les appareils mobiles et embarqués. Il repose sur des blocs convolutifs de profondeur et des connexions résiduelles.
### Avantages :
- Très efficace en termes de calcul avec un compromis optimal entre précision et vitesse.
- Convient aux appareils à faible puissance (comme les smartphones).
- Supporte le déploiement facile via TensorFlow Lite ou PyTorch Mobile.
### Utilisation :
MobileNetV2 est utilisé pour :
- La reconnaissance d'images en temps réel sur des appareils mobiles.
- Les applications embarquées nécessitant des modèles compacts.
- Les tâches de vision par ordinateur sur des données limitées en ressources.
""")
st.write("Ces modèles pré-entraînés sont tous des choix puissants, adaptés à divers scénarios. Le choix dépend des besoins en performances, en taille de modèle et en capacité d'adaptation aux appareils cibles.")
elif menu == "🧠 Classification de l'état cérébral avec VGG16":
# Ajout d'un sous-titre explicatif pour informer sur la fonctionnalité
st.subheader("Classification avec VGG16 (PyTorch)")
# Présentation de la fonctionnalité pour l'utilisateur
st.markdown("""
Cette section vous permet de **classer une image d'état des reins** en fonction de son apparence.
Le modèle utilise **VGG16**, une architecture d'apprentissage profond optimisée pour analyser les images.
Voici ce que vous devez faire :
1. Téléchargez une image au format `jpg`, `jpeg` ou `png`.
2. Cliquez sur le bouton **Classifier** pour lancer l'analyse.
3. Obtenez le résultat du diagnostic (Normal ou Anomalie) accompagné d'un graphique des probabilités.
""")
# Étape 1 : L'utilisateur télécharge une image
uploaded_file = st.file_uploader("Téléchargez une image des reins (format jpg, jpeg ou png)", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
# Affiche l'image téléchargée
image = Image.open(uploaded_file)
st.image(image, caption="Image téléchargée avec succès", use_container_width=True)
# Explication pour l'étape suivante
st.markdown("""
Cliquez sur le bouton **Classifier** pour que le modèle analyse l'image et détermine si l'état des reins est **Normal** ou présente une **Anomalie**.
""")
# Étape 2 : L'utilisateur clique pour classifier l'image
if st.button("Classifier"):
# Appel de la fonction de prédiction avec le modèle ResNet50
label, probabilities = predict_with_pytorch(MODEL_PYTORCH_PATH, image, class_names)
pdf_file = generate_pdf(label, probabilities, image, MODEL_PYTORCH_PATH)
st.download_button("Télécharger le rapport PDF", pdf_file, "rapport_prediction.pdf", "application/pdf")
# Affichage du résultat
st.markdown(f"""
### Résultat de la classification :
**Classe prédite :** {label}
""", unsafe_allow_html=True)
# Affichage des probabilités sous forme de graphique
st.markdown("""
#### Confiance du modèle dans chaque catégorie :
Le graphique ci-dessous montre la probabilité associée à chaque classe. Une probabilité élevée indique la classe la plus probable.
""")
st.bar_chart(probabilities)
# Affichage d'un message visuel en fonction de la classe prédite
if label == "Healthy":
st.balloons() # Animation festive si le résultat est "Normal"
st.success("Félicitations ! L'image a été classée comme **Normale**.")
else:
st.error("Désolé, le modèle indique une **Anomalie** dans l'image téléchargée.")
st.markdown("""
#### Que faire en cas d'anomalie ?
Si une anomalie est détectée, il est recommandé de :
- Vérifier l'image téléchargée pour s'assurer qu'elle est correcte.
- Contacter un professionnel de santé pour une analyse approfondie.
""")
elif menu == "Comparatif entre modèles":
st.write("""
## **Analyse et Conclusions**
### **Pourquoi VGG16 est un meilleur choix ?**
Dans ce projet, nous avons utilisé **VGG16** pour la classification d'images en nous basant sur le **transfert d'apprentissage** sous **PyTorch**. Après avoir comparé plusieurs modèles, **VGG16 s'est révélé être le plus adapté** aux tâches de classification d'images grâce aux éléments suivants :
---
### **1. Simplicité et efficacité de l’architecture**
L’architecture de **VGG16** repose sur des couches convolutives empilées de manière séquentielle, suivies de couches entièrement connectées. Cette approche simplifiée permet :
- Une **meilleure généralisation** sur des tâches de classification avec peu de données.
- Une extraction de **caractéristiques précises et hiérarchisées**, facilitant l’adaptation du modèle via le transfert learning.
👉 **Avantage** : Contrairement à ResNet50, VGG16 ne dépend pas de connexions résiduelles complexes, ce qui facilite sa mise en œuvre et son ajustement en fonction des besoins spécifiques du projet.
---
### **2. Robustesse et fiabilité pour des images complexes**
VGG16 a été utilisé dans de nombreux projets de **vision par ordinateur** et s’est avéré efficace dans divers contextes. Son architecture permet de capturer des **détails fins** dans les images, ce qui est un atout pour :
- La **reconnaissance d’objets** et de **textures complexes**.
- La **classification médicale**, où la précision est cruciale, ce qui est le cas de notre jeu de données.
👉 **Avantage** : Sa capacité à apprendre des caractéristiques riches **sans nécessiter une optimisation trop poussée** en fait un modèle fiable et éprouvé.
---
### **3. Équilibre entre performance et coût computationnel**
Bien que **ResNet50** possède moins de paramètres (~25M contre ~138M pour VGG16), cela ne signifie pas nécessairement une meilleure performance sur des datasets limités. En effet, **VGG16 tire parti de sa structure régulière** pour apprendre efficacement même lorsque les données sont en quantité modérée.
- **VGG16** fonctionne bien **sans nécessiter de grandes ressources GPU**, ce qui le rend **idéal pour des environnements à puissance limitée**.
- Son absence de connexions résiduelles permet d’éviter des erreurs d’instabilité lors de l’entraînement.
👉 **Avantage** : **Facile à entraîner** et adapté aux **configurations matérielles classiques**.
---
### **4. Meilleures performances en transfert learning**
L'utilisation de **VGG16 pré-entraîné sur ImageNet** sous PyTorch nous a permis d'obtenir :
- Une **précision élevée** sur notre ensemble de validation.
- Une **stabilité des performances**, sans risque de dégradation des gradients.
👉 **Avantage** : En **gelant les couches de convolution** et en **réentraînant uniquement la tête de classification**, VGG16 s’adapte rapidement à de nouvelles tâches.
---
### **5. Comparaison avec d’autres modèles**
| Critère | VGG16 | ResNet50 | MobileNetV2 |
|-----------------------|------------------------|--------------------------|-------------------------|
| **Nombre de paramètres** | ~138M | ~25M | ~3.4M |
| **Profondeur** | 16 couches | 50 couches | Profondeur dynamique |
| **Performance** | Haute précision | Très bonne mais plus complexe à optimiser | Légèrement inférieure |
| **Taille et vitesse** | Modérée | Plus rapide mais parfois instable | Très légère et rapide |
- **ResNet50** : Son architecture résiduelle est puissante, mais **nécessite un ajustement plus fin des hyperparamètres**, ce qui peut poser problème sur des datasets limités.
- **MobileNetV2** : Très léger, mais **moins performant** pour des tâches nécessitant une haute précision.
👉 **Pourquoi VGG16 ?** Sa simplicité et son efficacité en font un choix sûr et performant pour **la classification d’images sous PyTorch**.
---
### **6. Résultats observés dans ce projet**
En utilisant **VGG16 avec du transfert learning**, nous avons obtenu :
✅ Une **précision élevée**, même avec un dataset limité.
✅ Une **convergence rapide**, sans nécessiter d'ajustements complexes.
✅ Une **bonne stabilité des résultats**, sans risque de sur-apprentissage excessif.
---
## **Conclusion**
En résumé, **VGG16** s'est avéré être **le meilleur choix pour ce projet** en raison de :
- **Son architecture simple et efficace**, qui facilite son implémentation sous PyTorch.
- **Son excellente capacité d’apprentissage** pour des tâches de classification d’images complexes.
- **Son efficacité en transfert learning**, permettant une adaptation rapide sans nécessiter d’énormes ressources GPU.
👉 **Recommandation** : Pour des tâches de **classification d’images nécessitant une précision élevée et une mise en œuvre rapide**, **VGG16 est le modèle idéal sous PyTorch**. 🚀
""")
elif menu == "👨‍💻À propos":
st.header("À propos de moi")
st.markdown(
"""
<div style="text-align:center; font-family: Arial; margin: 20px 0;">
<h2>Mon Parcours</h2>
<p>Je suis un passionné de l'intelligence artificielle et de la donnée. Actuellement en Master 2 en IA et Big Data, je travaille sur des solutions innovantes dans le domaine de l'Intelligence Artificielle appliquée à la santé.</p>
</div>
<div style="display: flex; justify-content: center; gap: 30px; flex-wrap: wrap;">
<div style="text-align: center;">
<img src="https://avatars.githubusercontent.com/u/1234567" alt="Ngoue David" style="width: 150px; height: 150px; border-radius: 50%;">
<h4>Ngoue David</h4>
<p>🎓 Master 2 IA & Big Data</p>
<p>📧 <a href="mailto:ngouedavidrogeryannick@gmail.com">ngouedavidrogeryannick@gmail.com</a></p>
<p>🌐 <a href="https://github.com/TheBeyonder237" target="_blank">Profil GitHub</a></p>
</div>
</div>
""", unsafe_allow_html=True)