Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import pandas as pd | |
| import torch | |
| # Spécifie le chemin du modèle enregistré | |
| model_path = 'chemin/vers/le_modele.pth' | |
| # Si tu utilises un GPU, déclare le dispositif (GPU ou CPU) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Définir le modèle (assure-toi que la classe du modèle est définie) | |
| class MonModele(torch.nn.Module): | |
| # Définition de ton modèle ici (à adapter selon ton architecture) | |
| def __init__(self): | |
| super(MonModele, self).__init__() | |
| # Par exemple : définir les couches du modèle | |
| self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1) | |
| self.fc1 = torch.nn.Linear(64 * 224 * 224, 7) # Exemples d'outputs pour 7 classes | |
| def forward(self, x): | |
| x = self.conv1(x) | |
| x = x.view(x.size(0), -1) # Aplatir les dimensions | |
| x = self.fc1(x) | |
| return x | |
| # Créer une instance du modèle | |
| model = MonModele() | |
| model_path = "skin_cancer_model.pth" | |
| # Charger les poids du modèle | |
| model.load_state_dict(torch.load(model_path, map_location=device)) | |
| model.eval() | |
| model.to(device) | |
| # Exemple de classes (à adapter selon ton cas) | |
| classes = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"] | |
| # Fonction de prédiction | |
| def predict(image): | |
| # Prétraitement de l'image | |
| transform = transforms.Compose( | |
| [ | |
| transforms.Resize( | |
| (224, 224) | |
| ), # Redimensionner selon la taille de ton modèle | |
| transforms.ToTensor(), | |
| transforms.Normalize( | |
| mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | |
| ), # Normalisation (modèle pré-entraîné) | |
| ] | |
| ) | |
| image = transform(image).unsqueeze(0) # Ajouter une dimension pour le batch | |
| # Prédiction | |
| with torch.no_grad(): | |
| output = model(image) | |
| _, predicted = torch.max(output, 1) # Récupérer la classe prédite | |
| # Résultat | |
| prediction = classes[predicted.item()] | |
| return prediction | |
| # Actions proposées selon la prédiction | |
| def generate_actions(prediction): | |
| actions = { | |
| "MEL": "Consultez un dermatologue immédiatement.", | |
| "NV": "C'est un grain de beauté normal, continuez à surveiller.", | |
| "BCC": "C'est une forme de cancer de la peau. Consultez un spécialiste.", | |
| "AKIEC": "Consultez un dermatologue pour un diagnostic.", | |
| "BKL": "C'est un grain de beauté bénin, mais surveillez-le.", | |
| "DF": "C'est un dermatofibrome, généralement bénin.", | |
| "VASC": "Cela peut être une lésion vasculaire, consultez un dermatologue.", | |
| } | |
| return actions.get(prediction, "Aucune action recommandée pour cette condition.") | |
| # Interface Gradio | |
| with gr.Blocks() as demo: | |
| gr.HTML( | |
| """ | |
| <h3>DermDoctor</h3> | |
| <h1 style='text-align: center'>Check your skin care with AI</h1> | |
| """ | |
| ) | |
| with gr.Row(): | |
| image = gr.Image(label="Upload image", type="pil", width="256px") | |
| gr.Markdown("## Analyse summary:") | |
| # Zone de réponse selon la prédiction | |
| with gr.Column() as response_area: | |
| prediction_output = gr.Text(label="Prediction", interactive=False) | |
| actions_output = gr.Text(label="Recommended Actions", interactive=False) | |
| # Fonction pour mettre à jour l'interface après la prédiction | |
| def update_interface(image): | |
| prediction = predict(image) # Prédiction du modèle | |
| actions = generate_actions(prediction) # Actions recommandées | |
| return prediction, actions | |
| # Bouton pour effectuer la prédiction | |
| image.submit( | |
| update_interface, inputs=image, outputs=[prediction_output, actions_output] | |
| ) | |
| demo.launch() | |