DermDoctor / app.py
Abed-Négo GNANGUENON
add model
b4fb607 verified
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import torch
# Spécifie le chemin du modèle enregistré
model_path = 'chemin/vers/le_modele.pth'
# Si tu utilises un GPU, déclare le dispositif (GPU ou CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Définir le modèle (assure-toi que la classe du modèle est définie)
class MonModele(torch.nn.Module):
# Définition de ton modèle ici (à adapter selon ton architecture)
def __init__(self):
super(MonModele, self).__init__()
# Par exemple : définir les couches du modèle
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.fc1 = torch.nn.Linear(64 * 224 * 224, 7) # Exemples d'outputs pour 7 classes
def forward(self, x):
x = self.conv1(x)
x = x.view(x.size(0), -1) # Aplatir les dimensions
x = self.fc1(x)
return x
# Créer une instance du modèle
model = MonModele()
model_path = "skin_cancer_model.pth"
# Charger les poids du modèle
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
model.to(device)
# Exemple de classes (à adapter selon ton cas)
classes = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
# Fonction de prédiction
def predict(image):
# Prétraitement de l'image
transform = transforms.Compose(
[
transforms.Resize(
(224, 224)
), # Redimensionner selon la taille de ton modèle
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
), # Normalisation (modèle pré-entraîné)
]
)
image = transform(image).unsqueeze(0) # Ajouter une dimension pour le batch
# Prédiction
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output, 1) # Récupérer la classe prédite
# Résultat
prediction = classes[predicted.item()]
return prediction
# Actions proposées selon la prédiction
def generate_actions(prediction):
actions = {
"MEL": "Consultez un dermatologue immédiatement.",
"NV": "C'est un grain de beauté normal, continuez à surveiller.",
"BCC": "C'est une forme de cancer de la peau. Consultez un spécialiste.",
"AKIEC": "Consultez un dermatologue pour un diagnostic.",
"BKL": "C'est un grain de beauté bénin, mais surveillez-le.",
"DF": "C'est un dermatofibrome, généralement bénin.",
"VASC": "Cela peut être une lésion vasculaire, consultez un dermatologue.",
}
return actions.get(prediction, "Aucune action recommandée pour cette condition.")
# Interface Gradio
with gr.Blocks() as demo:
gr.HTML(
"""
<h3>DermDoctor</h3>
<h1 style='text-align: center'>Check your skin care with AI</h1>
"""
)
with gr.Row():
image = gr.Image(label="Upload image", type="pil", width="256px")
gr.Markdown("## Analyse summary:")
# Zone de réponse selon la prédiction
with gr.Column() as response_area:
prediction_output = gr.Text(label="Prediction", interactive=False)
actions_output = gr.Text(label="Recommended Actions", interactive=False)
# Fonction pour mettre à jour l'interface après la prédiction
def update_interface(image):
prediction = predict(image) # Prédiction du modèle
actions = generate_actions(prediction) # Actions recommandées
return prediction, actions
# Bouton pour effectuer la prédiction
image.submit(
update_interface, inputs=image, outputs=[prediction_output, actions_output]
)
demo.launch()