# Hugging Face Space — hardware: ZeroGPU ("Running on Zero").
| import json | |
| import gradio as gr | |
| import spaces | |
| from data_utils import ( | |
| dataset_overview, | |
| get_class_names, | |
| get_images_for_gallery, | |
| ) | |
| from train_utils import ( | |
| train_model, | |
| list_saved_models, | |
| model_meta_path, | |
| evaluate_saved_model, | |
| ) | |
| from predict_utils import ( | |
| predict_uploaded_image, | |
| test_random_sample, | |
| ) | |
def load_dataset_overview_callback():
    """Load the dataset overview and reset the class filter dropdown.

    Returns:
        A 3-tuple wired to (JSON summary panel, distribution Dataframe,
        class dropdown):

        - on success: the summary and distribution returned by
          ``dataset_overview()``, plus a ``gr.update`` that fills the dropdown
          with "Toutes les classes" followed by ``get_class_names()`` and
          selects "Toutes les classes";
        - on any exception: ``{"Erreur": <message>}``, ``None`` (clears the
          table) and an empty ``gr.update()`` (dropdown left unchanged).
    """
    all_classes_label = "Toutes les classes"
    try:
        overview, per_class_counts = dataset_overview()
        dropdown_choices = [all_classes_label] + get_class_names()
        class_dropdown = gr.update(
            choices=dropdown_choices,
            value=all_classes_label,
        )
        return overview, per_class_counts, class_dropdown
    except Exception as exc:
        return {"Erreur": str(exc)}, None, gr.update()
def refresh_gallery_callback(split_name, class_name, max_images):
    """Return gallery items for the selected split and class.

    Args:
        split_name: value of the "Split" dropdown.
        class_name: value of the "Classe" dropdown; it may be the
            "Toutes les classes" entry, which is passed through unchanged to
            ``get_images_for_gallery``.
        max_images: value of the image-count slider. It is coerced with
            ``int()`` because slider values are numeric but not guaranteed to
            be ints.

    Returns:
        Whatever ``get_images_for_gallery`` returns, sent straight to the
        ``gr.Gallery`` output.

    Raises:
        gr.Error: when loading the images fails. The message is shown to the
            user in the UI.
    """
    try:
        return get_images_for_gallery(
            split_name=split_name,
            class_name=class_name,
            max_images=int(max_images),
        )
    except Exception as e:
        # The previous fallback returned ``[(None, message)]``. gr.Gallery
        # cannot render a ``None`` image (its postprocess rejects non-image
        # values), so the user saw a generic server error instead of this
        # message. gr.Error surfaces the message properly.
        raise gr.Error(f"Erreur : {e}") from e
def train_callback(
    dropout,
    fc_dim,
    learning_rate,
    weight_decay,
    batch_size,
    epochs,
    fine_tune_mode,
    model_tag,
):
    """Train a model from the UI hyperparameters and refresh the model list.

    Raw widget values are coerced before the call: float for dropout,
    learning rate and weight decay; int for hidden size, batch size and
    epochs; str for the fine-tuning mode. ``model_tag`` is forwarded
    unchanged to ``train_model``.

    Returns:
        A 7-tuple matching the outputs of the training button:

        - logs, history, summary, classification report, confusion matrix,
          and confusion-matrix image path, all taken from the dict returned
          by ``train_model``;
        - a ``gr.update`` for the saved-model dropdown. It uses the fresh
          ``list_saved_models()`` as choices and selects the new model only
          if its name appears in that list.

        On any exception, the first slot holds the error text, the next five
        are ``None``, and the dropdown is left unchanged.
    """
    # NOTE(review): this file imports ``spaces`` and the Space runs on
    # ZeroGPU, but no ``@spaces.GPU`` decorator is visible here. Confirm that
    # train_utils requests the GPU itself.
    try:
        hyperparams = {
            "dropout": float(dropout),
            "fc_dim": int(fc_dim),
            "learning_rate": float(learning_rate),
            "weight_decay": float(weight_decay),
            "batch_size": int(batch_size),
            "epochs": int(epochs),
            "fine_tune_mode": str(fine_tune_mode),
            "model_tag": model_tag,
        }
        result = train_model(**hyperparams)

        saved_models = list_saved_models()
        new_name = result["model_name"]
        chosen = new_name if new_name in saved_models else None

        result_keys = (
            "logs",
            "history",
            "summary",
            "classification_report",
            "confusion_matrix",
            "confusion_matrix_path",
        )
        panels = tuple(result[key] for key in result_keys)
        return (*panels, gr.update(choices=saved_models, value=chosen))
    except Exception as exc:
        failure = f"Échec de l’entraînement :\n{exc}"
        return (failure,) + (None,) * 5 + (gr.update(),)
def evaluate_saved_model_callback(model_name):
    """Evaluate a saved model on the test set for the "Évaluer" button.

    Args:
        model_name: value of the saved-model dropdown. It is ``None`` or empty
            when no model exists or none is selected.

    Returns:
        ``(summary, report_df, cm_df, cm_path)`` as returned by
        ``evaluate_saved_model``. When no model is selected, or evaluation
        raises, returns ``({"Erreur": <message>}, None, None, None)``.
    """
    if not model_name:
        # Fail early with a clear message instead of sending an empty
        # selection downstream and showing whatever error that produces.
        return {"Erreur": "Aucun modèle sélectionné."}, None, None, None
    try:
        summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
    except Exception as e:
        return {"Erreur": str(e)}, None, None, None
    return summary, report_df, cm_df, cm_path
def predict_uploaded_image_callback(model_name, image):
    """Predict the class of an uploaded image with the selected model.

    Args:
        model_name: value of the saved-model dropdown (``None`` or empty when
            nothing is selected).
        image: the uploaded image (the component uses ``type="pil"``), or
            ``None`` if nothing was uploaded.

    Returns:
        Whatever ``predict_uploaded_image`` returns on success, sent to the
        (result text, class probabilities) outputs. When no model is
        selected, no image was uploaded, or prediction raises, returns
        ``(<"Échec de la prédiction" message>, None)``.
    """
    failure_prefix = "Échec de la prédiction :\n"
    # Clicking "Prédire" without a model or an image is common. Report it
    # explicitly rather than forwarding None downstream.
    if not model_name:
        return failure_prefix + "Aucun modèle sélectionné.", None
    if image is None:
        return failure_prefix + "Aucune image importée.", None
    try:
        return predict_uploaded_image(model_name, image)
    except Exception as e:
        return f"{failure_prefix}{e}", None
def test_random_sample_callback(model_name):
    """Run the selected model on a random test-set sample.

    Args:
        model_name: value of the saved-model dropdown (``None`` or empty when
            nothing is selected).

    Returns:
        Whatever ``test_random_sample`` returns on success, sent to the
        (image, result text, class probabilities) outputs. When no model is
        selected, or the call raises, returns
        ``(None, <"Échec du test aléatoire" message>, None)``.
    """
    if not model_name:
        # Same early guard as the other model-based callbacks: a clear
        # message instead of a downstream error on an empty selection.
        return None, "Échec du test aléatoire :\nAucun modèle sélectionné.", None
    try:
        return test_random_sample(model_name)
    except Exception as e:
        return None, f"Échec du test aléatoire :\n{e}", None
def refresh_models_dropdown():
    """Re-scan saved models and rebuild the model dropdown.

    Returns:
        A ``gr.update`` whose choices are the current ``list_saved_models()``.
        The value is set to the first entry, or ``None`` when the list is
        empty. Any previous selection is therefore replaced.
    """
    available = list_saved_models()
    default = None
    if available:
        default = available[0]
    return gr.update(choices=available, value=default)
def get_model_info(model_name: str):
    """Return the JSON metadata file stored for *model_name*.

    Args:
        model_name: value of the saved-model dropdown; may be ``None`` or
            empty when nothing is selected.

    Returns:
        The parsed metadata. Otherwise a one-key dict for the JSON panel:

        - ``{"message": ...}`` when nothing is selected or the metadata file
          does not exist;
        - ``{"Erreur": ...}`` when the file exists but cannot be read or
          parsed.
    """
    if not model_name:
        return {"message": "Aucun modèle sélectionné."}
    meta_file = model_meta_path(model_name)
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Must precede OSError: FileNotFoundError is a subclass of it.
        return {"message": "Métadonnées introuvables."}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Previously a corrupt or unreadable file escaped as an unhandled
        # exception. Report it the way the other callbacks report errors.
        return {"Erreur": str(e)}
# Saved models are listed once, when the module is imported, to seed the
# model dropdown in tab 3. The refresh button and a finished training run
# update that dropdown later.
initial_models = list_saved_models()

# NOTE(review): the indentation of this block was lost in the recovered
# source. The nesting of the Row/Column contexts below is a reconstruction;
# check it against the intended layout. Components render in creation order,
# so statement order here is the layout.
with gr.Blocks(title="Classification d’images microscopiques") as demo:
    gr.Markdown("# Classification d’images microscopiques de charbons de bois")
    gr.Markdown(
        "Application pédagogique pour explorer un jeu de données d’images microscopiques, "
        "entraîner un modèle de classification et analyser ses performances."
    )

    with gr.Tabs():
        # ---- Tab 1: dataset overview and image gallery ----------------------
        with gr.Tab("1. Explorer le jeu de données"):
            gr.Markdown("## Comprendre le jeu de données avant l’entraînement")
            load_dataset_btn = gr.Button(
                "Charger les informations du dataset",
                variant="primary",
            )
            dataset_summary = gr.JSON(label="Résumé général du dataset")
            class_distribution = gr.Dataframe(
                label="Distribution des images par split et par classe",
                interactive=False,
            )
            gr.Markdown("## Visualisation des images")
            with gr.Row():
                split_selector = gr.Dropdown(
                    choices=["train", "validation", "test"],
                    value="train",
                    label="Split",
                )
                # Only the "all classes" entry exists until the load button
                # above fills in the real class names.
                class_selector = gr.Dropdown(
                    choices=["Toutes les classes"],
                    value="Toutes les classes",
                    label="Classe",
                )
                max_images = gr.Slider(
                    minimum=4,
                    maximum=48,
                    value=24,
                    step=4,
                    label="Nombre d’images à afficher",
                )
            refresh_gallery_btn = gr.Button("Afficher des exemples")
            image_gallery = gr.Gallery(
                label="Exemples d’images",
                columns=4,
                height=600,
            )

        # ---- Tab 2: training ------------------------------------------------
        with gr.Tab("2. Entraîner un modèle"):
            gr.Markdown("## Entraînement avec ResNet18 pré-entraîné")
            gr.Markdown(
                "Paramètres par défaut recommandés : fine-tuning de la dernière couche convolutionnelle "
                "du ResNet18, faible taux d’apprentissage, augmentation légère des données."
            )
            with gr.Row():
                # Left column: hyperparameter inputs and the launch button.
                with gr.Column():
                    dropout = gr.Slider(
                        minimum=0.0,
                        maximum=0.8,
                        value=0.4,
                        step=0.05,
                        label="Dropout",
                    )
                    fc_dim = gr.Dropdown(
                        choices=[64, 128, 256, 512],
                        value=256,
                        label="Dimension de la couche cachée",
                    )
                    learning_rate = gr.Number(
                        value=0.00001,
                        label="Taux d’apprentissage",
                    )
                    weight_decay = gr.Number(
                        value=0.0001,
                        label="Weight decay",
                    )
                    batch_size = gr.Dropdown(
                        choices=[8, 16, 32, 64],
                        value=16,
                        label="Taille du batch",
                    )
                    epochs = gr.Slider(
                        minimum=1,
                        maximum=80,
                        value=30,
                        step=1,
                        label="Nombre d’époques",
                    )
                    fine_tune_mode = gr.Dropdown(
                        choices=["frozen", "layer4", "full"],
                        value="layer4",
                        label="Mode de fine-tuning",
                        info=(
                            "frozen = seul le classifieur est entraîné ; "
                            "layer4 = dernière partie du ResNet18 + classifieur ; "
                            "full = tout le réseau est ajusté."
                        ),
                    )
                    model_tag = gr.Textbox(
                        label="Nom court du modèle",
                        placeholder="ex. charbon_resnet18_layer4",
                    )
                    train_btn = gr.Button("Lancer l’entraînement", variant="primary")
                # Right column: training log, history and final summary.
                with gr.Column():
                    train_status = gr.Textbox(
                        label="Journal d’entraînement",
                        lines=18,
                    )
                    train_history = gr.JSON(label="Historique d’entraînement")
                    train_summary = gr.JSON(label="Résumé final")
            gr.Markdown("## Résultats sur le test set")
            train_report = gr.Dataframe(
                label="Rapport de classification",
                interactive=False,
            )
            train_confusion_matrix = gr.Dataframe(
                label="Matrice de confusion",
                interactive=False,
            )
            # type="filepath": the callback returns a path to an image file.
            train_confusion_matrix_image = gr.Image(
                label="Matrice de confusion - figure",
                type="filepath",
            )

        # ---- Tab 3: evaluation and prediction with a saved model ------------
        with gr.Tab("3. Tester et analyser un modèle"):
            gr.Markdown("## Sélectionner un modèle sauvegardé")
            with gr.Row():
                with gr.Column():
                    # Shared input for every tab-3 action (info, evaluate,
                    # predict, random test). None when no model is saved yet.
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Modèle sauvegardé",
                    )
                    refresh_btn = gr.Button("Actualiser la liste des modèles")
                    load_info_btn = gr.Button("Afficher les informations du modèle")
                    model_info = gr.JSON(label="Métadonnées du modèle")
                with gr.Column():
                    evaluate_btn = gr.Button(
                        "Évaluer le modèle sur le test set",
                        variant="primary",
                    )
                    eval_summary = gr.JSON(label="Résumé des métriques")
                    eval_report = gr.Dataframe(
                        label="Rapport de classification",
                        interactive=False,
                    )
                    eval_confusion_matrix = gr.Dataframe(
                        label="Matrice de confusion",
                        interactive=False,
                    )
                    eval_confusion_matrix_image = gr.Image(
                        label="Matrice de confusion - figure",
                        type="filepath",
                    )
            gr.Markdown("## Prédiction sur une image importée")
            with gr.Row():
                with gr.Column():
                    # type="pil": the prediction callback receives a PIL image.
                    upload_image = gr.Image(type="pil", label="Importer une image")
                    predict_btn = gr.Button("Prédire la classe", variant="primary")
                with gr.Column():
                    predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
                    predict_probs = gr.Label(label="Probabilités par classe")
            gr.Markdown("## Test sur un échantillon aléatoire du test set")
            random_test_btn = gr.Button("Tester un échantillon aléatoire")
            with gr.Row():
                random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
                random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
                random_sample_probs = gr.Label(label="Probabilités par classe")

    # ---- Event wiring ---------------------------------------------------------
    # Each outputs list must match, in order, the tuple returned by its
    # callback.
    # Also repopulates class_selector with the real class names.
    load_dataset_btn.click(
        fn=load_dataset_overview_callback,
        inputs=None,
        outputs=[dataset_summary, class_distribution, class_selector],
    )
    refresh_gallery_btn.click(
        fn=refresh_gallery_callback,
        inputs=[split_selector, class_selector, max_images],
        outputs=image_gallery,
    )
    # The last output updates the tab-3 model dropdown so the new model
    # becomes selectable.
    train_btn.click(
        fn=train_callback,
        inputs=[
            dropout,
            fc_dim,
            learning_rate,
            weight_decay,
            batch_size,
            epochs,
            fine_tune_mode,
            model_tag,
        ],
        outputs=[
            train_status,
            train_history,
            train_summary,
            train_report,
            train_confusion_matrix,
            train_confusion_matrix_image,
            model_selector,
        ],
    )
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    evaluate_btn.click(
        fn=evaluate_saved_model_callback,
        inputs=model_selector,
        outputs=[
            eval_summary,
            eval_report,
            eval_confusion_matrix,
            eval_confusion_matrix_image,
        ],
    )
    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=model_selector,
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

if __name__ == "__main__":
    # ssr_mode=False turns off Gradio's server-side rendering mode.
    demo.launch(ssr_mode=False)