Spaces:
Sleeping
Sleeping
| import json | |
| import gradio as gr | |
| import spaces | |
| from data_utils import ( | |
| dataset_overview, | |
| get_class_names, | |
| get_images_for_gallery, | |
| ) | |
| from train_utils import ( | |
| train_model, | |
| list_saved_models, | |
| model_meta_path, | |
| evaluate_saved_model, | |
| ) | |
| from predict_utils import ( | |
| predict_uploaded_image, | |
| test_random_sample, | |
| ) | |
def load_dataset_overview_callback():
    """Gather the values shown when the dataset information is loaded.

    Returns:
        A 3-tuple ``(summary, distribution, dropdown_update)``. The first two
        items are whatever ``dataset_overview()`` returns; the third is a
        ``gr.update`` whose choices are "Toutes les classes" followed by the
        names from ``get_class_names()``, with "Toutes les classes" selected.
        If anything raises, the tuple is ``({"Erreur": <message>}, None,
        gr.update())`` so the dropdown is left untouched.
    """
    all_classes_label = "Toutes les classes"
    try:
        overview, per_class_df = dataset_overview()
        dropdown_choices = [all_classes_label] + get_class_names()
        dropdown_update = gr.update(
            choices=dropdown_choices,
            value=all_classes_label,
        )
        return overview, per_class_df, dropdown_update
    except Exception as exc:
        # Surface the failure in the summary output instead of crashing.
        return {"Erreur": str(exc)}, None, gr.update()
def refresh_gallery_callback(split_name, class_name, max_images):
    """Fetch example images for the gallery.

    Args:
        split_name: Value of the split dropdown, forwarded unchanged.
        class_name: Value of the class dropdown, forwarded unchanged.
        max_images: Requested number of images; converted with ``int()``
            before being forwarded.

    Returns:
        Whatever ``get_images_for_gallery`` returns for these arguments.

    Raises:
        gr.Error: If the conversion or the lookup fails. The message keeps
            the "Erreur : <details>" wording.
    """
    try:
        return get_images_for_gallery(
            split_name=split_name,
            class_name=class_name,
            max_images=int(max_images),
        )
    except Exception as e:
        # A gallery item must carry an actual image (array, PIL image, path
        # or URL). Returning ``(None, caption)`` is rejected when the
        # component processes the value, which hides the real message.
        # Raising gr.Error lets Gradio display it to the user instead.
        raise gr.Error(f"Erreur : {str(e)}") from e
def on_model_type_change(model_type):
    """Build the two UI updates triggered by a change of architecture.

    Args:
        model_type: Selected architecture label.

    Returns:
        A pair of ``gr.update`` objects: the first sets visibility (visible
        only for "CNN simple"), the second sets a default learning rate
        (0.001 for "CNN simple", 0.0001 for anything else).
    """
    if model_type == "CNN simple":
        return gr.update(visible=True), gr.update(value=0.001)
    return gr.update(visible=False), gr.update(value=0.0001)
def train_callback(
    model_type,
    num_conv_blocks,
    base_filters,
    kernel_size,
    use_batchnorm,
    dropout,
    fc_dim,
    learning_rate,
    weight_decay,
    batch_size,
    epochs,
    model_tag,
):
    """Run a training job from the raw widget values and format its outputs.

    The widget values are coerced to ``int`` / ``float`` / ``bool`` and the
    architecture label is mapped to "cnn" (for "CNN simple") or "resnet18"
    (anything else) before calling ``train_model``.

    Returns:
        A 7-tuple: the "logs", "history", "summary", "classification_report",
        "confusion_matrix" and "confusion_matrix_path" entries of the
        training result, followed by a ``gr.update`` that refreshes the model
        choices and selects the new model if it appears in
        ``list_saved_models()``. On any exception the tuple is the failure
        message, five ``None`` values and a no-op ``gr.update()``.
    """
    try:
        architecture = "cnn" if model_type == "CNN simple" else "resnet18"
        outcome = train_model(
            model_type=architecture,
            num_conv_blocks=int(num_conv_blocks),
            base_filters=int(base_filters),
            kernel_size=int(kernel_size),
            use_batchnorm=bool(use_batchnorm),
            dropout=float(dropout),
            fc_dim=int(fc_dim),
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            batch_size=int(batch_size),
            epochs=int(epochs),
            model_tag=model_tag,
        )

        saved_models = list_saved_models()
        # Only preselect the new model when it is actually listed.
        if outcome["model_name"] in saved_models:
            preselected = outcome["model_name"]
        else:
            preselected = None

        result_keys = (
            "logs",
            "history",
            "summary",
            "classification_report",
            "confusion_matrix",
            "confusion_matrix_path",
        )
        displayed = [outcome[key] for key in result_keys]
        return (*displayed, gr.update(choices=saved_models, value=preselected))
    except Exception as exc:
        failure_text = f"Échec de l’entraînement :\n{str(exc)}"
        return (failure_text, None, None, None, None, None, gr.update())
def evaluate_saved_model_callback(model_name):
    """Evaluate a saved model and return the four values to display.

    Args:
        model_name: Name forwarded to ``evaluate_saved_model``.

    Returns:
        The 4-tuple unpacked from ``evaluate_saved_model(model_name)``, or
        ``({"Erreur": <message>}, None, None, None)`` if it raises or does
        not yield exactly four values.
    """
    try:
        metrics, report, matrix, figure_path = evaluate_saved_model(model_name)
    except Exception as exc:
        return {"Erreur": str(exc)}, None, None, None
    return metrics, report, matrix, figure_path
def predict_uploaded_image_callback(model_name, image):
    """Predict on an uploaded image, turning failures into a text message.

    Args:
        model_name: Name of the model, forwarded unchanged.
        image: Uploaded image, forwarded unchanged.

    Returns:
        The result of ``predict_uploaded_image(model_name, image)``, or
        ``(<failure message>, None)`` when it raises.
    """
    try:
        prediction = predict_uploaded_image(model_name, image)
    except Exception as exc:
        reason = str(exc)
        return f"Échec de la prédiction :\n{reason}", None
    return prediction
def test_random_sample_callback(model_name):
    """Run the random-sample test, turning failures into a text message.

    Args:
        model_name: Name of the model, forwarded unchanged.

    Returns:
        The result of ``test_random_sample(model_name)``, or
        ``(None, <failure message>, None)`` when it raises.
    """
    try:
        sample_result = test_random_sample(model_name)
    except Exception as exc:
        reason = str(exc)
        return None, f"Échec du test aléatoire :\n{reason}", None
    return sample_result
def refresh_models_dropdown():
    """Return a dropdown update listing the currently saved models.

    The first listed model is selected; the value is ``None`` when the list
    is empty.
    """
    available = list_saved_models()
    if available:
        first_choice = available[0]
    else:
        first_choice = None
    return gr.update(choices=available, value=first_choice)
def get_model_info(model_name: str):
    """Load the JSON metadata stored for a saved model.

    Args:
        model_name: Name of the saved model; may be empty or ``None`` when
            nothing is selected.

    Returns:
        The parsed JSON content of the file at ``model_meta_path(model_name)``.
        When no model is selected, the file is missing, or the file cannot be
        read or parsed, a ``{"message": ...}`` dict describing the problem is
        returned instead, so the UI always receives something displayable.
    """
    if not model_name:
        return {"message": "Aucun modèle sélectionné."}

    meta_file = model_meta_path(model_name)
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {"message": "Métadonnées introuvables."}
    except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupt or unreadable metadata file should be reported like any
        # other problem rather than escaping as an unhandled exception.
        return {"message": f"Métadonnées illisibles : {e}"}
# Models available at import time; used to seed the model dropdown in tab 3.
initial_models = list_saved_models()

# NOTE(review): the layout nesting below (which components sit inside which
# Row / Column) was reconstructed from a flattened copy of this file in which
# indentation was lost. Statement order is unchanged; confirm the nesting
# against the intended layout.
with gr.Blocks(title="Classification d’images microscopiques") as demo:
    gr.Markdown("# Classification d’images microscopiques de charbons de bois")
    gr.Markdown(
        "Application pédagogique pour explorer un jeu de données d’images microscopiques, "
        "entraîner un modèle de classification et analyser ses performances."
    )

    with gr.Tabs():
        # ------------------------------------------------------------------
        # Tab 1: dataset summary, class distribution and image gallery.
        # ------------------------------------------------------------------
        with gr.Tab("1. Explorer le jeu de données"):
            gr.Markdown("## Comprendre le jeu de données avant l’entraînement")
            load_dataset_btn = gr.Button(
                "Charger les informations du dataset",
                variant="primary",
            )
            dataset_summary = gr.JSON(label="Résumé général du dataset")
            class_distribution = gr.Dataframe(
                label="Distribution des images par split et par classe",
                interactive=False,
            )
            gr.Markdown("## Visualisation des images")
            with gr.Row():
                split_selector = gr.Dropdown(
                    choices=["train", "validation", "test"],
                    value="train",
                    label="Split",
                )
                # Starts with a single placeholder choice; the real class
                # names are filled in by load_dataset_overview_callback.
                class_selector = gr.Dropdown(
                    choices=["Toutes les classes"],
                    value="Toutes les classes",
                    label="Classe",
                )
                max_images = gr.Slider(
                    minimum=4,
                    maximum=48,
                    value=24,
                    step=4,
                    label="Nombre d’images à afficher",
                )
            refresh_gallery_btn = gr.Button("Afficher des exemples")
            image_gallery = gr.Gallery(
                label="Exemples d’images",
                columns=4,
                height=600,
            )

        # ------------------------------------------------------------------
        # Tab 2: model configuration, training and test-set results.
        # ------------------------------------------------------------------
        with gr.Tab("2. Entraîner un modèle"):
            gr.Markdown("## Choix du modèle et entraînement")
            with gr.Row():
                with gr.Column():
                    model_type = gr.Radio(
                        choices=["CNN simple", "ResNet18"],
                        value="CNN simple",
                        label="Architecture",
                        info=(
                            "CNN simple : entraîné de zéro, paramètres configurables. "
                            "ResNet18 : pré-entraîné ImageNet, fine-tuning layer4 + classifieur."
                        ),
                    )
                    # CNN-only settings; this column's visibility is toggled
                    # by on_model_type_change via the model_type.change event.
                    with gr.Column(visible=True) as cnn_params_col:
                        gr.Markdown("#### Paramètres CNN")
                        num_conv_blocks = gr.Slider(
                            minimum=2,
                            maximum=5,
                            value=3,
                            step=1,
                            label="Nombre de blocs convolutionnels",
                            info="Chaque bloc enchaîne Conv2d → (BN) → ReLU → MaxPool2d.",
                        )
                        base_filters = gr.Dropdown(
                            choices=[16, 32, 64, 128],
                            value=32,
                            label="Filtres du premier bloc (doublent à chaque bloc)",
                        )
                        kernel_size = gr.Dropdown(
                            choices=[3, 5],
                            value=3,
                            label="Taille du noyau de convolution",
                        )
                        use_batchnorm = gr.Checkbox(
                            value=True,
                            label="Normalisation par lots (BatchNorm)",
                        )
                    gr.Markdown("#### Hyperparamètres d’entraînement")
                    dropout = gr.Slider(
                        minimum=0.0,
                        maximum=0.8,
                        value=0.4,
                        step=0.05,
                        label="Dropout",
                    )
                    fc_dim = gr.Dropdown(
                        choices=[64, 128, 256, 512],
                        value=256,
                        label="Dimension de la couche cachée (classifieur)",
                    )
                    # Initial value matches the "CNN simple" default that
                    # on_model_type_change applies (0.001).
                    learning_rate = gr.Number(
                        value=0.001,
                        label="Taux d’apprentissage",
                    )
                    weight_decay = gr.Number(
                        value=0.0001,
                        label="Weight decay",
                    )
                    batch_size = gr.Dropdown(
                        choices=[8, 16, 32, 64],
                        value=16,
                        label="Taille du batch",
                    )
                    epochs = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=30,
                        step=1,
                        label="Nombre d’époques",
                    )
                    model_tag = gr.Textbox(
                        label="Nom court du modèle",
                        placeholder="ex. cnn_3blocs ou resnet18_ft",
                    )
                    train_btn = gr.Button("Lancer l’entraînement", variant="primary")
                with gr.Column():
                    train_status = gr.Textbox(
                        label="Journal d’entraînement",
                        lines=18,
                    )
                    train_history = gr.JSON(label="Historique d’entraînement")
                    train_summary = gr.JSON(label="Résumé final")
            gr.Markdown("## Résultats sur le test set")
            train_report = gr.Dataframe(
                label="Rapport de classification",
                interactive=False,
            )
            train_confusion_matrix = gr.Dataframe(
                label="Matrice de confusion",
                interactive=False,
            )
            train_confusion_matrix_image = gr.Image(
                label="Matrice de confusion - figure",
                type="filepath",
            )

        # ------------------------------------------------------------------
        # Tab 3: saved-model selection, evaluation and predictions.
        # ------------------------------------------------------------------
        with gr.Tab("3. Tester et analyser un modèle"):
            gr.Markdown("## Sélectionner un modèle sauvegardé")
            with gr.Row():
                with gr.Column():
                    # Shared by every action in this tab; also updated by
                    # train_callback once a training run finishes.
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Modèle sauvegardé",
                    )
                    refresh_btn = gr.Button("Actualiser la liste des modèles")
                    load_info_btn = gr.Button("Afficher les informations du modèle")
                    model_info = gr.JSON(label="Métadonnées du modèle")
                with gr.Column():
                    evaluate_btn = gr.Button(
                        "Évaluer le modèle sur le test set",
                        variant="primary",
                    )
                    eval_summary = gr.JSON(label="Résumé des métriques")
                    eval_report = gr.Dataframe(
                        label="Rapport de classification",
                        interactive=False,
                    )
                    eval_confusion_matrix = gr.Dataframe(
                        label="Matrice de confusion",
                        interactive=False,
                    )
                    eval_confusion_matrix_image = gr.Image(
                        label="Matrice de confusion - figure",
                        type="filepath",
                    )
            gr.Markdown("## Prédiction sur une image importée")
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type="pil", label="Importer une image")
                    predict_btn = gr.Button("Prédire la classe", variant="primary")
                with gr.Column():
                    predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
                    predict_probs = gr.Label(label="Probabilités par classe")
            gr.Markdown("## Test sur un échantillon aléatoire du test set")
            random_test_btn = gr.Button("Tester un échantillon aléatoire")
            with gr.Row():
                random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
                random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
                random_sample_probs = gr.Label(label="Probabilités par classe")

    # ----------------------------------------------------------------------
    # Event wiring. The order of each `outputs` list must match the order of
    # the values returned by the corresponding callback.
    # ----------------------------------------------------------------------
    load_dataset_btn.click(
        fn=load_dataset_overview_callback,
        inputs=None,
        outputs=[dataset_summary, class_distribution, class_selector],
    )
    refresh_gallery_btn.click(
        fn=refresh_gallery_callback,
        inputs=[split_selector, class_selector, max_images],
        outputs=image_gallery,
    )
    # Switching architecture toggles the CNN column and resets the learning rate.
    model_type.change(
        fn=on_model_type_change,
        inputs=model_type,
        outputs=[cnn_params_col, learning_rate],
    )
    # Inputs are listed in the same order as train_callback's parameters.
    train_btn.click(
        fn=train_callback,
        inputs=[
            model_type,
            num_conv_blocks,
            base_filters,
            kernel_size,
            use_batchnorm,
            dropout,
            fc_dim,
            learning_rate,
            weight_decay,
            batch_size,
            epochs,
            model_tag,
        ],
        outputs=[
            train_status,
            train_history,
            train_summary,
            train_report,
            train_confusion_matrix,
            train_confusion_matrix_image,
            model_selector,
        ],
    )
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    evaluate_btn.click(
        fn=evaluate_saved_model_callback,
        inputs=model_selector,
        outputs=[
            eval_summary,
            eval_report,
            eval_confusion_matrix,
            eval_confusion_matrix_image,
        ],
    )
    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=model_selector,
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

if __name__ == "__main__":
    # ssr_mode=False is passed explicitly; presumably this turns off Gradio's
    # server-side rendering — confirm against the Gradio launch() docs.
    demo.launch(ssr_mode=False)