import json import spaces import gradio as gr from train_utils import train_model, list_saved_models, model_meta_path from predict_utils import predict_uploaded_image, test_random_sample @spaces.GPU(duration=120) def train_callback( conv1_channels, conv2_channels, kernel_size, dropout, fc_dim, learning_rate, batch_size, epochs, model_tag, ): try: logs, history, summary, model_name = train_model( int(conv1_channels), int(conv2_channels), int(kernel_size), float(dropout), int(fc_dim), float(learning_rate), int(batch_size), int(epochs), model_tag, ) models = list_saved_models() selected = model_name if model_name in models else (models[0] if models else None) return logs, history, summary, gr.update(choices=models, value=selected) except Exception as e: return f"Échec de l’entraînement :\n{str(e)}", None, None, gr.update() @spaces.GPU(duration=60) def predict_uploaded_image_callback(model_name, image): try: return predict_uploaded_image(model_name, image) except Exception as e: return f"Échec de la prédiction :\n{str(e)}", None @spaces.GPU(duration=60) def test_random_sample_callback(model_name): try: return test_random_sample(model_name) except Exception as e: return None, f"Échec du test aléatoire :\n{str(e)}", None def refresh_models_dropdown(): models = list_saved_models() return gr.update(choices=models, value=models[0] if models else None) def get_model_info(model_name: str): if not model_name: return {"message": "Aucun modèle sélectionné."} meta_file = model_meta_path(model_name) try: with open(meta_file, "r", encoding="utf-8") as f: return json.load(f) except FileNotFoundError: return {"message": "Métadonnées introuvables."} initial_models = list_saved_models() with gr.Blocks(title="Classification d’images microscopiques") as demo: gr.Markdown("# Classification d’images microscopiques de charbons de bois") gr.Markdown( "Cette application permet d’entraîner un réseau de neurones convolutif simple " "sur un jeu de données privé Hugging Face, puis de tester les modèles sauvegardés " "sur une image importée ou sur un échantillon aléatoire." ) with gr.Tabs(): with gr.Tab("Entraîner"): with gr.Row(): with gr.Column(): gr.Markdown("### Paramètres d’entraînement") conv1_channels = gr.Slider( 8, 64, value=16, step=8, label="Nombre de canaux - couche convolutionnelle 1" ) conv2_channels = gr.Slider( 16, 128, value=32, step=16, label="Nombre de canaux - couche convolutionnelle 2" ) kernel_size = gr.Dropdown( choices=[3, 5], value=3, label="Taille du noyau" ) dropout = gr.Slider( 0.0, 0.7, value=0.2, step=0.05, label="Dropout" ) fc_dim = gr.Slider( 32, 256, value=128, step=32, label="Dimension de la couche cachée fully-connected" ) learning_rate = gr.Number( value=0.001, label="Taux d’apprentissage" ) batch_size = gr.Dropdown( choices=[16, 32, 64, 128], value=32, label="Taille du batch" ) epochs = gr.Slider( 1, 20, value=5, step=1, label="Nombre d’époques" ) model_tag = gr.Textbox( label="Nom court du modèle", placeholder="ex. charbon_cnn_test" ) train_btn = gr.Button("Lancer l’entraînement", variant="primary") with gr.Column(): train_status = gr.Textbox(label="Journal d’entraînement", lines=18) train_history = gr.JSON(label="Historique d’entraînement") train_summary = gr.JSON(label="Résumé d’entraînement") with gr.Tab("Tester"): with gr.Row(): with gr.Column(): gr.Markdown("### Modèle sauvegardé") model_selector = gr.Dropdown( choices=initial_models, value=initial_models[0] if initial_models else None, label="Sélectionner un modèle", ) refresh_btn = gr.Button("Actualiser la liste des modèles") load_info_btn = gr.Button("Afficher les informations du modèle") model_info = gr.JSON(label="Métadonnées du modèle") with gr.Column(): gr.Markdown("### Prédiction sur une image importée") upload_image = gr.Image(type="pil", label="Importer une image") predict_btn = gr.Button("Prédire la classe", variant="primary") predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7) predict_probs = gr.Label(label="Probabilités par classe") with gr.Row(): random_test_btn = gr.Button("Tester un échantillon aléatoire") with gr.Row(): random_sample_image = gr.Image(type="pil", label="Image test aléatoire") random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7) random_sample_probs = gr.Label(label="Probabilités par classe") train_btn.click( fn=train_callback, inputs=[ conv1_channels, conv2_channels, kernel_size, dropout, fc_dim, learning_rate, batch_size, epochs, model_tag, ], outputs=[train_status, train_history, train_summary, model_selector], ) refresh_btn.click( fn=refresh_models_dropdown, inputs=None, outputs=model_selector, ) load_info_btn.click( fn=get_model_info, inputs=model_selector, outputs=model_info, ) predict_btn.click( fn=predict_uploaded_image_callback, inputs=[model_selector, upload_image], outputs=[predict_text, predict_probs], ) random_test_btn.click( fn=test_random_sample_callback, inputs=[model_selector], outputs=[random_sample_image, random_sample_text, random_sample_probs], ) if __name__ == "__main__": demo.launch()