Spaces:
Running on Zero
Running on Zero
| import json | |
| import spaces | |
| import gradio as gr | |
| from train_utils import train_model, list_saved_models, model_meta_path | |
| from predict_utils import predict_uploaded_image, test_random_sample | |
def train_callback(
    conv1_channels,
    conv2_channels,
    kernel_size,
    dropout,
    fc_dim,
    learning_rate,
    batch_size,
    epochs,
    model_tag,
):
    """Run a training job from the UI values and refresh the model dropdown.

    The numeric inputs are coerced with ``int``/``float`` before being handed
    to ``train_model``; ``model_tag`` is forwarded unchanged.

    Returns:
        A 4-tuple ``(logs, history, summary, dropdown_update)``. On success
        the first three items come straight from ``train_model`` and the
        dropdown update lists the saved models, selecting the freshly trained
        one when it is listed, otherwise the first entry (or ``None`` when
        the list is empty). On any exception the tuple is
        ``(error_message, None, None, gr.update())``.
    """
    try:
        # Coerce every hyperparameter up front, in the positional order
        # expected by train_model.
        hyperparams = (
            int(conv1_channels),
            int(conv2_channels),
            int(kernel_size),
            float(dropout),
            int(fc_dim),
            float(learning_rate),
            int(batch_size),
            int(epochs),
        )
        logs, history, summary, model_name = train_model(*hyperparams, model_tag)

        available = list_saved_models()
        if model_name in available:
            chosen = model_name
        elif available:
            chosen = available[0]
        else:
            chosen = None

        dropdown_update = gr.update(choices=available, value=chosen)
        return logs, history, summary, dropdown_update
    except Exception as err:
        # Report the failure in the log textbox instead of raising.
        return f"Échec de l’entraînement :\n{err!s}", None, None, gr.update()
def predict_uploaded_image_callback(model_name, image):
    """Predict the class of an uploaded image, reporting errors as text.

    Returns whatever ``predict_uploaded_image(model_name, image)`` returns.
    If that call raises, returns ``(error_message, None)`` instead so the
    failure is displayed rather than propagated.
    """
    try:
        outcome = predict_uploaded_image(model_name, image)
    except Exception as err:
        return f"Échec de la prédiction :\n{err!s}", None
    else:
        return outcome
def test_random_sample_callback(model_name):
    """Evaluate the selected model on a random sample, reporting errors as text.

    Returns whatever ``test_random_sample(model_name)`` returns. If that call
    raises, returns ``(None, error_message, None)`` instead so the failure is
    displayed rather than propagated.
    """
    try:
        sample_result = test_random_sample(model_name)
    except Exception as err:
        failure_text = f"Échec du test aléatoire :\n{err!s}"
        return None, failure_text, None
    return sample_result
def refresh_models_dropdown():
    """Build a dropdown update from the current list of saved models.

    The first listed model becomes the selected value; when the list is
    empty the value is ``None``.
    """
    available = list_saved_models()
    if available:
        default_choice = available[0]
    else:
        default_choice = None
    return gr.update(choices=available, value=default_choice)
def get_model_info(model_name: str):
    """Load the metadata file of a saved model for display.

    Args:
        model_name: Name of the selected model. A falsy value (empty string
            or ``None``) means nothing is selected.

    Returns:
        The parsed JSON content of the file located by
        ``model_meta_path(model_name)``, or a ``{"message": ...}`` dict when
        no model is selected, the file is missing, or the file cannot be
        read or parsed.
    """
    if not model_name:
        return {"message": "Aucun modèle sélectionné."}

    meta_file = model_meta_path(model_name)
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {"message": "Métadonnées introuvables."}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Report unreadable or corrupt metadata as a message, like the other
        # callbacks do, instead of letting the exception escape.
        return {"message": f"Échec de la lecture des métadonnées :\n{e}"}
# Saved-model list read once at import time; it seeds the model dropdown below.
initial_models = list_saved_models()

# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a flattened copy of this file -- confirm the Row/Column/Tab structure.
with gr.Blocks(title="Classification d’images microscopiques") as demo:
    gr.Markdown("# Classification d’images microscopiques de charbons de bois")
    gr.Markdown(
        "Cette application permet d’entraîner un réseau de neurones convolutif simple "
        "sur un jeu de données privé Hugging Face, puis de tester les modèles sauvegardés "
        "sur une image importée ou sur un échantillon aléatoire."
    )
    with gr.Tabs():
        # --- Training tab: hyperparameter inputs and training outputs. ---
        with gr.Tab("Entraîner"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Paramètres d’entraînement")
                    conv1_channels = gr.Slider(
                        8, 64, value=16, step=8, label="Nombre de canaux - couche convolutionnelle 1"
                    )
                    conv2_channels = gr.Slider(
                        16, 128, value=32, step=16, label="Nombre de canaux - couche convolutionnelle 2"
                    )
                    kernel_size = gr.Dropdown(
                        choices=[3, 5], value=3, label="Taille du noyau"
                    )
                    dropout = gr.Slider(
                        0.0, 0.7, value=0.2, step=0.05, label="Dropout"
                    )
                    fc_dim = gr.Slider(
                        32, 256, value=128, step=32, label="Dimension de la couche cachée fully-connected"
                    )
                    learning_rate = gr.Number(
                        value=0.001, label="Taux d’apprentissage"
                    )
                    batch_size = gr.Dropdown(
                        choices=[16, 32, 64, 128], value=32, label="Taille du batch"
                    )
                    epochs = gr.Slider(
                        1, 20, value=5, step=1, label="Nombre d’époques"
                    )
                    model_tag = gr.Textbox(
                        label="Nom court du modèle",
                        placeholder="ex. charbon_cnn_test"
                    )
                    train_btn = gr.Button("Lancer l’entraînement", variant="primary")
                with gr.Column():
                    # Outputs filled by train_callback (log text, history, summary).
                    train_status = gr.Textbox(label="Journal d’entraînement", lines=18)
                    train_history = gr.JSON(label="Historique d’entraînement")
                    train_summary = gr.JSON(label="Résumé d’entraînement")
        # --- Testing tab: model selection, uploaded-image prediction, random sample. ---
        with gr.Tab("Tester"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Modèle sauvegardé")
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        # Preselect the first saved model when at least one exists.
                        value=initial_models[0] if initial_models else None,
                        label="Sélectionner un modèle",
                    )
                    refresh_btn = gr.Button("Actualiser la liste des modèles")
                    load_info_btn = gr.Button("Afficher les informations du modèle")
                    model_info = gr.JSON(label="Métadonnées du modèle")
                with gr.Column():
                    gr.Markdown("### Prédiction sur une image importée")
                    upload_image = gr.Image(type="pil", label="Importer une image")
                    predict_btn = gr.Button("Prédire la classe", variant="primary")
                    predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
                    predict_probs = gr.Label(label="Probabilités par classe")
            with gr.Row():
                random_test_btn = gr.Button("Tester un échantillon aléatoire")
            with gr.Row():
                random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
                random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
                random_sample_probs = gr.Label(label="Probabilités par classe")

    # --- Event wiring: each button is bound to one of the callbacks above. ---
    # Training also targets model_selector, so the dropdown in the testing tab
    # receives the update returned by train_callback.
    train_btn.click(
        fn=train_callback,
        inputs=[
            conv1_channels,
            conv2_channels,
            kernel_size,
            dropout,
            fc_dim,
            learning_rate,
            batch_size,
            epochs,
            model_tag,
        ],
        outputs=[train_status, train_history, train_summary, model_selector],
    )
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=[model_selector],
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()