"""Gradio front-end for classifying microscopic charcoal images.

The interface has three tabs: dataset exploration, model training (a simple
configurable CNN or a ResNet18) and evaluation / prediction with a saved
model. The actual work is delegated to ``data_utils``, ``train_utils`` and
``predict_utils``; the callbacks below adapt widget values to those helpers
and turn failures into messages the UI can display.
"""

import json
import logging

import gradio as gr
import spaces

from data_utils import (
    dataset_overview,
    get_class_names,
    get_images_for_gallery,
)
from train_utils import (
    train_model,
    list_saved_models,
    model_meta_path,
    evaluate_saved_model,
)
from predict_utils import (
    predict_uploaded_image,
    test_random_sample,
)

logger = logging.getLogger(__name__)

# "No filter" option of the class dropdown in the exploration tab.
ALL_CLASSES = "Toutes les classes"

# Architecture choices offered by the training tab's radio button.
CNN_CHOICE = "CNN simple"
RESNET_CHOICE = "ResNet18"

# Learning rate pre-filled when the architecture changes. The CNN is trained
# from scratch; the ResNet18 is fine-tuned, hence the smaller default.
DEFAULT_LR_CNN = 0.001
DEFAULT_LR_RESNET = 0.0001


def _is_cnn(model_type):
    """Return True when the architecture radio is set to the simple CNN."""
    return model_type == CNN_CHOICE


def load_dataset_overview_callback():
    """Load the dataset summary and refresh the class filter choices.

    Returns:
        A 3-tuple for ``[dataset_summary, class_distribution, class_selector]``:
        the summary object, the distribution dataframe, and a dropdown update
        listing ``ALL_CLASSES`` followed by the dataset's class names. On
        failure the summary is ``{"Erreur": <message>}``, the dataframe is
        ``None`` and the dropdown is left unchanged.
    """
    try:
        summary, distribution_df = dataset_overview()
        class_names = [ALL_CLASSES, *get_class_names()]
        return (
            summary,
            distribution_df,
            gr.update(choices=class_names, value=ALL_CLASSES),
        )
    except Exception as e:
        logger.exception("Failed to load the dataset overview")
        return (
            {"Erreur": str(e)},
            None,
            gr.update(),
        )


def refresh_gallery_callback(split_name, class_name, max_images):
    """Return up to ``max_images`` gallery items for the chosen split/class.

    Raises:
        gr.Error: if the images cannot be loaded. A ``(None, caption)`` item
            is not a renderable gallery entry, so the failure is reported via
            Gradio's error modal instead of through the gallery itself.
    """
    try:
        return get_images_for_gallery(
            split_name=split_name,
            class_name=class_name,
            max_images=int(max_images),  # the slider may deliver a float
        )
    except Exception as e:
        logger.exception("Failed to load gallery images")
        raise gr.Error(f"Erreur : {e}") from e


def on_model_type_change(model_type):
    """Show the CNN-only parameters and reset the learning rate default.

    Returns:
        Updates for ``[cnn_params_col, learning_rate]``.
    """
    is_cnn = _is_cnn(model_type)
    default_lr = DEFAULT_LR_CNN if is_cnn else DEFAULT_LR_RESNET
    return gr.update(visible=is_cnn), gr.update(value=default_lr)


@spaces.GPU(duration=200)
def train_callback(
    model_type,
    num_conv_blocks,
    base_filters,
    kernel_size,
    use_batchnorm,
    dropout,
    fc_dim,
    learning_rate,
    weight_decay,
    batch_size,
    epochs,
    model_tag,
):
    """Train a model from the widget values and publish its results.

    Widget values are cast to the types ``train_model`` expects. The radio
    label is mapped to the backend identifier (``"cnn"`` / ``"resnet18"``).

    Returns:
        A 7-tuple for the training outputs: logs, history, summary,
        classification report, confusion matrix, confusion-matrix image path,
        and an update refreshing the saved-model dropdown (selecting the new
        model when it is listed). On failure the first element is an error
        message, the next five are ``None`` and the dropdown is unchanged.
    """
    try:
        result = train_model(
            model_type="cnn" if _is_cnn(model_type) else "resnet18",
            num_conv_blocks=int(num_conv_blocks),
            base_filters=int(base_filters),
            kernel_size=int(kernel_size),
            use_batchnorm=bool(use_batchnorm),
            dropout=float(dropout),
            fc_dim=int(fc_dim),
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            batch_size=int(batch_size),
            epochs=int(epochs),
            model_tag=model_tag,
        )
        models = list_saved_models()
        selected = result["model_name"] if result["model_name"] in models else None
        return (
            result["logs"],
            result["history"],
            result["summary"],
            result["classification_report"],
            result["confusion_matrix"],
            result["confusion_matrix_path"],
            gr.update(choices=models, value=selected),
        )
    except Exception as e:
        logger.exception("Training failed")
        return (
            f"Échec de l’entraînement :\n{e}",
            None,
            None,
            None,
            None,
            None,
            gr.update(),
        )


@spaces.GPU(duration=120)
def evaluate_saved_model_callback(model_name):
    """Evaluate a saved model on the test set.

    Returns:
        ``(summary, report_df, cm_df, cm_path)``; on failure
        ``({"Erreur": <message>}, None, None, None)``.
    """
    try:
        summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
        return summary, report_df, cm_df, cm_path
    except Exception as e:
        logger.exception("Evaluation failed for model %r", model_name)
        return {"Erreur": str(e)}, None, None, None


@spaces.GPU(duration=60)
def predict_uploaded_image_callback(model_name, image):
    """Classify an uploaded image with the selected model.

    Returns:
        Whatever ``predict_uploaded_image`` returns for
        ``[predict_text, predict_probs]``; on failure an error message and
        ``None``.
    """
    try:
        return predict_uploaded_image(model_name, image)
    except Exception as e:
        logger.exception("Prediction failed for model %r", model_name)
        return f"Échec de la prédiction :\n{e}", None


@spaces.GPU(duration=60)
def test_random_sample_callback(model_name):
    """Run the selected model on a random test-set sample.

    Returns:
        Whatever ``test_random_sample`` returns for
        ``[random_sample_image, random_sample_text, random_sample_probs]``;
        on failure ``(None, <message>, None)``.
    """
    try:
        return test_random_sample(model_name)
    except Exception as e:
        logger.exception("Random-sample test failed for model %r", model_name)
        return None, f"Échec du test aléatoire :\n{e}", None


def refresh_models_dropdown(current=None):
    """Reload the saved-model list into the model dropdown.

    Args:
        current: The dropdown's current value. It is kept when it is still in
            the refreshed list; otherwise the first saved model (or ``None``
            when there is none) is selected. Omitting it reproduces the old
            "always select the first model" behaviour.
    """
    models = list_saved_models()
    if current in models:
        value = current
    else:
        value = models[0] if models else None
    return gr.update(choices=models, value=value)


def get_model_info(model_name: str):
    """Return the JSON metadata stored for ``model_name``.

    Returns:
        The parsed metadata. Otherwise a ``{"message": ...}`` dict when no
        model is selected or the metadata file is missing, or an
        ``{"Erreur": ...}`` dict when the file cannot be read or parsed.
    """
    if not model_name:
        return {"message": "Aucun modèle sélectionné."}

    meta_file = model_meta_path(model_name)
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {"message": "Métadonnées introuvables."}
    except (OSError, json.JSONDecodeError) as e:
        # A corrupt or unreadable file must not crash the callback.
        logger.exception("Could not read metadata file %s", meta_file)
        return {"Erreur": f"Métadonnées illisibles : {e}"}


initial_models = list_saved_models()

with gr.Blocks(title="Classification d’images microscopiques") as demo:
    gr.Markdown("# Classification d’images microscopiques de charbons de bois")
    gr.Markdown(
        "Application pédagogique pour explorer un jeu de données d’images microscopiques, "
        "entraîner un modèle de classification et analyser ses performances."
    )

    with gr.Tabs():
        # ---------------------------------------------------------------
        # Tab 1: dataset exploration
        # ---------------------------------------------------------------
        with gr.Tab("1. Explorer le jeu de données"):
            gr.Markdown("## Comprendre le jeu de données avant l’entraînement")

            load_dataset_btn = gr.Button(
                "Charger les informations du dataset",
                variant="primary",
            )
            dataset_summary = gr.JSON(label="Résumé général du dataset")
            class_distribution = gr.Dataframe(
                label="Distribution des images par split et par classe",
                interactive=False,
            )

            gr.Markdown("## Visualisation des images")
            with gr.Row():
                split_selector = gr.Dropdown(
                    choices=["train", "validation", "test"],
                    value="train",
                    label="Split",
                )
                # Real class names are filled in by load_dataset_overview_callback.
                class_selector = gr.Dropdown(
                    choices=[ALL_CLASSES],
                    value=ALL_CLASSES,
                    label="Classe",
                )
                max_images = gr.Slider(
                    minimum=4,
                    maximum=48,
                    value=24,
                    step=4,
                    label="Nombre d’images à afficher",
                )
            refresh_gallery_btn = gr.Button("Afficher des exemples")
            image_gallery = gr.Gallery(
                label="Exemples d’images",
                columns=4,
                height=600,
            )

        # ---------------------------------------------------------------
        # Tab 2: training
        # ---------------------------------------------------------------
        with gr.Tab("2. Entraîner un modèle"):
            gr.Markdown("## Choix du modèle et entraînement")

            with gr.Row():
                with gr.Column():
                    model_type = gr.Radio(
                        choices=[CNN_CHOICE, RESNET_CHOICE],
                        value=CNN_CHOICE,
                        label="Architecture",
                        info=(
                            "CNN simple : entraîné de zéro, paramètres configurables. "
                            "ResNet18 : pré-entraîné ImageNet, fine-tuning layer4 + classifieur."
                        ),
                    )

                    # Shown only for the CNN (toggled by on_model_type_change).
                    with gr.Column(visible=True) as cnn_params_col:
                        gr.Markdown("#### Paramètres CNN")
                        num_conv_blocks = gr.Slider(
                            minimum=2,
                            maximum=5,
                            value=3,
                            step=1,
                            label="Nombre de blocs convolutionnels",
                            info="Chaque bloc enchaîne Conv2d → (BN) → ReLU → MaxPool2d.",
                        )
                        base_filters = gr.Dropdown(
                            choices=[16, 32, 64, 128],
                            value=32,
                            label="Filtres du premier bloc (doublent à chaque bloc)",
                        )
                        kernel_size = gr.Dropdown(
                            choices=[3, 5],
                            value=3,
                            label="Taille du noyau de convolution",
                        )
                        use_batchnorm = gr.Checkbox(
                            value=True,
                            label="Normalisation par lots (BatchNorm)",
                        )

                    gr.Markdown("#### Hyperparamètres d’entraînement")
                    dropout = gr.Slider(
                        minimum=0.0,
                        maximum=0.8,
                        value=0.4,
                        step=0.05,
                        label="Dropout",
                    )
                    fc_dim = gr.Dropdown(
                        choices=[64, 128, 256, 512],
                        value=256,
                        label="Dimension de la couche cachée (classifieur)",
                    )
                    # Initial value matches the default architecture (CNN).
                    learning_rate = gr.Number(
                        value=DEFAULT_LR_CNN,
                        label="Taux d’apprentissage",
                    )
                    weight_decay = gr.Number(
                        value=0.0001,
                        label="Weight decay",
                    )
                    batch_size = gr.Dropdown(
                        choices=[8, 16, 32, 64],
                        value=16,
                        label="Taille du batch",
                    )
                    epochs = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=30,
                        step=1,
                        label="Nombre d’époques",
                    )
                    model_tag = gr.Textbox(
                        label="Nom court du modèle",
                        placeholder="ex. cnn_3blocs ou resnet18_ft",
                    )
                    train_btn = gr.Button("Lancer l’entraînement", variant="primary")

                with gr.Column():
                    train_status = gr.Textbox(
                        label="Journal d’entraînement",
                        lines=18,
                    )
                    train_history = gr.JSON(label="Historique d’entraînement")
                    train_summary = gr.JSON(label="Résumé final")

            gr.Markdown("## Résultats sur le test set")
            train_report = gr.Dataframe(
                label="Rapport de classification",
                interactive=False,
            )
            train_confusion_matrix = gr.Dataframe(
                label="Matrice de confusion",
                interactive=False,
            )
            train_confusion_matrix_image = gr.Image(
                label="Matrice de confusion - figure",
                type="filepath",
            )

        # ---------------------------------------------------------------
        # Tab 3: evaluation and prediction with a saved model
        # ---------------------------------------------------------------
        with gr.Tab("3. Tester et analyser un modèle"):
            gr.Markdown("## Sélectionner un modèle sauvegardé")

            with gr.Row():
                with gr.Column():
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Modèle sauvegardé",
                    )
                    refresh_btn = gr.Button("Actualiser la liste des modèles")
                    load_info_btn = gr.Button("Afficher les informations du modèle")
                    model_info = gr.JSON(label="Métadonnées du modèle")

                with gr.Column():
                    evaluate_btn = gr.Button(
                        "Évaluer le modèle sur le test set",
                        variant="primary",
                    )
                    eval_summary = gr.JSON(label="Résumé des métriques")
                    eval_report = gr.Dataframe(
                        label="Rapport de classification",
                        interactive=False,
                    )
                    eval_confusion_matrix = gr.Dataframe(
                        label="Matrice de confusion",
                        interactive=False,
                    )
                    eval_confusion_matrix_image = gr.Image(
                        label="Matrice de confusion - figure",
                        type="filepath",
                    )

            gr.Markdown("## Prédiction sur une image importée")
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type="pil", label="Importer une image")
                    predict_btn = gr.Button("Prédire la classe", variant="primary")
                with gr.Column():
                    predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
                    predict_probs = gr.Label(label="Probabilités par classe")

            gr.Markdown("## Test sur un échantillon aléatoire du test set")
            random_test_btn = gr.Button("Tester un échantillon aléatoire")
            with gr.Row():
                random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
                random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
                random_sample_probs = gr.Label(label="Probabilités par classe")

    # -------------------------------------------------------------------
    # Event wiring (must happen inside the Blocks context)
    # -------------------------------------------------------------------
    load_dataset_btn.click(
        fn=load_dataset_overview_callback,
        inputs=None,
        outputs=[dataset_summary, class_distribution, class_selector],
    )

    refresh_gallery_btn.click(
        fn=refresh_gallery_callback,
        inputs=[split_selector, class_selector, max_images],
        outputs=image_gallery,
    )

    model_type.change(
        fn=on_model_type_change,
        inputs=model_type,
        outputs=[cnn_params_col, learning_rate],
    )

    train_btn.click(
        fn=train_callback,
        inputs=[
            model_type,
            num_conv_blocks,
            base_filters,
            kernel_size,
            use_batchnorm,
            dropout,
            fc_dim,
            learning_rate,
            weight_decay,
            batch_size,
            epochs,
            model_tag,
        ],
        outputs=[
            train_status,
            train_history,
            train_summary,
            train_report,
            train_confusion_matrix,
            train_confusion_matrix_image,
            model_selector,
        ],
    )

    # Pass the current selection so a refresh does not silently switch models.
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=model_selector,
        outputs=model_selector,
    )

    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )

    evaluate_btn.click(
        fn=evaluate_saved_model_callback,
        inputs=model_selector,
        outputs=[
            eval_summary,
            eval_report,
            eval_confusion_matrix,
            eval_confusion_matrix_image,
        ],
    )

    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )

    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=model_selector,
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )


if __name__ == "__main__":
    demo.launch(ssr_mode=False)