# functionNormally
# Restaurer les paramètres CNN qui fonctionnaient + epoch max à 50
# e8074db
import json
import gradio as gr
import spaces
from data_utils import (
dataset_overview,
get_class_names,
get_images_for_gallery,
)
from train_utils import (
train_model,
list_saved_models,
model_meta_path,
evaluate_saved_model,
)
from predict_utils import (
predict_uploaded_image,
test_random_sample,
)
def load_dataset_overview_callback():
    """Load the dataset overview shown in tab 1.

    Returns a 3-tuple for (summary JSON, distribution Dataframe, class
    dropdown):

    - on success: the two values returned by ``dataset_overview()`` and a
      dropdown update whose choices are "Toutes les classes" followed by
      ``get_class_names()``, with "Toutes les classes" selected;
    - on any exception: ``{"Erreur": <message>}``, ``None`` and a no-op
      dropdown update, so the UI shows the error instead of crashing.
    """
    all_classes_label = "Toutes les classes"
    try:
        overview, per_split_counts = dataset_overview()
        dropdown_choices = [all_classes_label] + get_class_names()
        selector = gr.update(choices=dropdown_choices, value=all_classes_label)
    except Exception as err:
        return {"Erreur": str(err)}, None, gr.update()
    return overview, per_split_counts, selector
def refresh_gallery_callback(split_name, class_name, max_images):
    """Return example images for the gallery of tab 1.

    Parameters
    ----------
    split_name:
        Split chosen in the dropdown ("train", "validation" or "test").
    class_name:
        Class chosen in the dropdown, or "Toutes les classes".
    max_images:
        Slider value; cast to ``int`` before being forwarded.

    Returns
    -------
    Whatever ``get_images_for_gallery`` returns (presumably a list of images
    or (image, caption) pairs accepted by ``gr.Gallery`` — confirm in
    data_utils).

    Raises
    ------
    gr.Error
        If loading the images fails. Gradio displays the message to the user.
        The previous ``[(None, caption)]`` return could not be rendered:
        ``gr.Gallery`` postprocessing raises for items that are not an
        image or a path, so the message was lost behind a generic error.
    """
    try:
        return get_images_for_gallery(
            split_name=split_name,
            class_name=class_name,
            max_images=int(max_images),
        )
    except Exception as e:
        raise gr.Error(f"Erreur : {str(e)}") from e
def on_model_type_change(model_type):
    """Adapt the training form when the architecture radio changes.

    Returns two updates, for the CNN-parameters column and the learning-rate
    field:

    - "CNN simple": column shown, learning rate reset to 0.001;
    - anything else (the radio's other choice is "ResNet18"): column hidden,
      learning rate reset to 0.0001.
    """
    if model_type != "CNN simple":
        return gr.update(visible=False), gr.update(value=0.0001)
    return gr.update(visible=True), gr.update(value=0.001)
@spaces.GPU(duration=200)
def train_callback(
    model_type,
    num_conv_blocks,
    base_filters,
    kernel_size,
    use_batchnorm,
    dropout,
    fc_dim,
    learning_rate,
    weight_decay,
    batch_size,
    epochs,
    model_tag,
):
    """Train a model from the tab-2 form values and report the results.

    Raw widget values are cast to the types ``train_model`` receives (``int``
    for counts and sizes, ``float`` for rates, ``bool`` for BatchNorm). The
    radio label is mapped to ``"cnn"`` ("CNN simple") or ``"resnet18"``
    (anything else).

    Returns a 7-tuple: logs, history, summary, classification report,
    confusion matrix, confusion-matrix image path, and a dropdown update
    listing the saved models. The new model is selected if it appears in
    that list, otherwise nothing is selected. If anything raises, the first
    item is an error message, the next five are ``None`` and the dropdown
    is left unchanged.
    """
    # NOTE(review): the GPU allocation is bounded by `duration` above
    # (presumably seconds); confirm it covers training at the 50-epoch
    # maximum offered by the form.
    result_keys = (
        "logs",
        "history",
        "summary",
        "classification_report",
        "confusion_matrix",
        "confusion_matrix_path",
    )
    try:
        architecture = "resnet18" if model_type != "CNN simple" else "cnn"
        result = train_model(
            model_type=architecture,
            num_conv_blocks=int(num_conv_blocks),
            base_filters=int(base_filters),
            kernel_size=int(kernel_size),
            use_batchnorm=bool(use_batchnorm),
            dropout=float(dropout),
            fc_dim=int(fc_dim),
            learning_rate=float(learning_rate),
            weight_decay=float(weight_decay),
            batch_size=int(batch_size),
            epochs=int(epochs),
            model_tag=model_tag,
        )
        saved_models = list_saved_models()
        trained_name = result["model_name"]
        dropdown = gr.update(
            choices=saved_models,
            value=trained_name if trained_name in saved_models else None,
        )
        return (*(result[key] for key in result_keys), dropdown)
    except Exception as err:
        empty_outputs = (None,) * 5
        return (f"Échec de l’entraînement :\n{err!s}", *empty_outputs, gr.update())
@spaces.GPU(duration=120)
def evaluate_saved_model_callback(model_name):
    """Evaluate the selected saved model on the test set (tab 3).

    Returns the 4 values produced by ``evaluate_saved_model``: summary,
    classification report, confusion matrix and confusion-matrix image
    path. On any exception (including a result that does not unpack into
    4 values) returns ``({"Erreur": <message>}, None, None, None)``.
    """
    try:
        metrics, report_table, matrix_table, matrix_figure = evaluate_saved_model(
            model_name
        )
    except Exception as err:
        return {"Erreur": str(err)}, None, None, None
    return metrics, report_table, matrix_table, matrix_figure
@spaces.GPU(duration=60)
def predict_uploaded_image_callback(model_name, image):
    """Predict the class of an uploaded image with the selected model.

    Returns the value of ``predict_uploaded_image`` unchanged (wired to a
    result textbox and a probability label). On any exception returns an
    error message and ``None``.
    """
    try:
        outcome = predict_uploaded_image(model_name, image)
    except Exception as err:
        return f"Échec de la prédiction :\n{err!s}", None
    return outcome
@spaces.GPU(duration=60)
def test_random_sample_callback(model_name):
    """Run the selected model on a random test-set sample.

    Returns the value of ``test_random_sample`` unchanged (wired to an
    image, a result textbox and a probability label). On any exception
    returns ``None``, an error message and ``None``.
    """
    try:
        outcome = test_random_sample(model_name)
    except Exception as err:
        return None, f"Échec du test aléatoire :\n{err!s}", None
    return outcome
def refresh_models_dropdown():
    """Reload the saved-model list and select its first entry.

    Returns a dropdown update. When no model is saved, choices are empty
    and nothing is selected.
    """
    available = list_saved_models()
    if available:
        return gr.update(choices=available, value=available[0])
    return gr.update(choices=available, value=None)
def get_model_info(model_name: str):
    """Return the stored metadata of a saved model for display as JSON.

    Parameters
    ----------
    model_name:
        Name selected in the model dropdown. An empty value (``""`` or
        ``None`` when the dropdown has no selection) is handled explicitly.

    Returns
    -------
    The object parsed from the model's metadata file, or a
    ``{"message": ...}`` dict explaining why it could not be loaded.
    Like the other callbacks of this app, failures are reported in the
    returned value rather than raised.
    """
    if not model_name:
        return {"message": "Aucun modèle sélectionné."}
    meta_file = model_meta_path(model_name)
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {"message": "Métadonnées introuvables."}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError:
        # a truncated or corrupted metadata file (or an unreadable one,
        # OSError) previously escaped as an unhandled error in the UI.
        return {"message": f"Métadonnées illisibles : {exc}"}
# Models already saved when the app starts; used to pre-fill the
# model dropdown of tab 3.
initial_models = list_saved_models()

# Gradio UI: three tabs (explore the dataset, train a model, test/analyse a
# saved model), followed by the event wiring.
# NOTE(review): the indentation of this section was lost in the reviewed
# copy, and the nesting below is a reconstruction. Every component and
# listener is kept, but confirm the placement of containers (e.g. the
# "Résultats sur le test set" section) against the intended layout.
with gr.Blocks(title="Classification d’images microscopiques") as demo:
    gr.Markdown("# Classification d’images microscopiques de charbons de bois")
    gr.Markdown(
        "Application pédagogique pour explorer un jeu de données d’images microscopiques, "
        "entraîner un modèle de classification et analyser ses performances."
    )
    with gr.Tabs():
        # ---- Tab 1: dataset overview and image gallery ----
        with gr.Tab("1. Explorer le jeu de données"):
            gr.Markdown("## Comprendre le jeu de données avant l’entraînement")
            load_dataset_btn = gr.Button(
                "Charger les informations du dataset",
                variant="primary",
            )
            dataset_summary = gr.JSON(label="Résumé général du dataset")
            class_distribution = gr.Dataframe(
                label="Distribution des images par split et par classe",
                interactive=False,
            )
            gr.Markdown("## Visualisation des images")
            with gr.Row():
                split_selector = gr.Dropdown(
                    choices=["train", "validation", "test"],
                    value="train",
                    label="Split",
                )
                # Only the "all classes" entry at start-up; the real class
                # names are filled in when the dataset overview is loaded.
                class_selector = gr.Dropdown(
                    choices=["Toutes les classes"],
                    value="Toutes les classes",
                    label="Classe",
                )
                max_images = gr.Slider(
                    minimum=4,
                    maximum=48,
                    value=24,
                    step=4,
                    label="Nombre d’images à afficher",
                )
            refresh_gallery_btn = gr.Button("Afficher des exemples")
            image_gallery = gr.Gallery(
                label="Exemples d’images",
                columns=4,
                height=600,
            )
        # ---- Tab 2: model configuration and training ----
        with gr.Tab("2. Entraîner un modèle"):
            gr.Markdown("## Choix du modèle et entraînement")
            with gr.Row():
                # Left column: training form.
                with gr.Column():
                    model_type = gr.Radio(
                        choices=["CNN simple", "ResNet18"],
                        value="CNN simple",
                        label="Architecture",
                        info=(
                            "CNN simple : entraîné de zéro, paramètres configurables. "
                            "ResNet18 : pré-entraîné ImageNet, fine-tuning layer4 + classifieur."
                        ),
                    )
                    # Architecture-specific parameters; visibility is toggled
                    # by the model_type.change listener below.
                    with gr.Column(visible=True) as cnn_params_col:
                        gr.Markdown("#### Paramètres CNN")
                        num_conv_blocks = gr.Slider(
                            minimum=2,
                            maximum=5,
                            value=3,
                            step=1,
                            label="Nombre de blocs convolutionnels",
                            info="Chaque bloc enchaîne Conv2d → (BN) → ReLU → MaxPool2d.",
                        )
                        base_filters = gr.Dropdown(
                            choices=[16, 32, 64, 128],
                            value=32,
                            label="Filtres du premier bloc (doublent à chaque bloc)",
                        )
                        kernel_size = gr.Dropdown(
                            choices=[3, 5],
                            value=3,
                            label="Taille du noyau de convolution",
                        )
                        use_batchnorm = gr.Checkbox(
                            value=True,
                            label="Normalisation par lots (BatchNorm)",
                        )
                    # Hyperparameters passed to training for both architectures.
                    gr.Markdown("#### Hyperparamètres d’entraînement")
                    dropout = gr.Slider(
                        minimum=0.0,
                        maximum=0.8,
                        value=0.4,
                        step=0.05,
                        label="Dropout",
                    )
                    fc_dim = gr.Dropdown(
                        choices=[64, 128, 256, 512],
                        value=256,
                        label="Dimension de la couche cachée (classifieur)",
                    )
                    # Initial value matches the default architecture
                    # ("CNN simple"); it is reset when the architecture changes.
                    learning_rate = gr.Number(
                        value=0.001,
                        label="Taux d’apprentissage",
                    )
                    weight_decay = gr.Number(
                        value=0.0001,
                        label="Weight decay",
                    )
                    batch_size = gr.Dropdown(
                        choices=[8, 16, 32, 64],
                        value=16,
                        label="Taille du batch",
                    )
                    epochs = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=30,
                        step=1,
                        label="Nombre d’époques",
                    )
                    model_tag = gr.Textbox(
                        label="Nom court du modèle",
                        placeholder="ex. cnn_3blocs ou resnet18_ft",
                    )
                    train_btn = gr.Button("Lancer l’entraînement", variant="primary")
                # Right column: training log, history and test-set results.
                with gr.Column():
                    train_status = gr.Textbox(
                        label="Journal d’entraînement",
                        lines=18,
                    )
                    train_history = gr.JSON(label="Historique d’entraînement")
                    train_summary = gr.JSON(label="Résumé final")
                    gr.Markdown("## Résultats sur le test set")
                    train_report = gr.Dataframe(
                        label="Rapport de classification",
                        interactive=False,
                    )
                    train_confusion_matrix = gr.Dataframe(
                        label="Matrice de confusion",
                        interactive=False,
                    )
                    train_confusion_matrix_image = gr.Image(
                        label="Matrice de confusion - figure",
                        type="filepath",
                    )
        # ---- Tab 3: evaluate a saved model and run predictions ----
        with gr.Tab("3. Tester et analyser un modèle"):
            gr.Markdown("## Sélectionner un modèle sauvegardé")
            with gr.Row():
                with gr.Column():
                    # Shared by evaluation, metadata, upload prediction and
                    # random-sample listeners. Also updated after training.
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Modèle sauvegardé",
                    )
                    refresh_btn = gr.Button("Actualiser la liste des modèles")
                    load_info_btn = gr.Button("Afficher les informations du modèle")
                    model_info = gr.JSON(label="Métadonnées du modèle")
                with gr.Column():
                    evaluate_btn = gr.Button(
                        "Évaluer le modèle sur le test set",
                        variant="primary",
                    )
                    eval_summary = gr.JSON(label="Résumé des métriques")
                    eval_report = gr.Dataframe(
                        label="Rapport de classification",
                        interactive=False,
                    )
                    eval_confusion_matrix = gr.Dataframe(
                        label="Matrice de confusion",
                        interactive=False,
                    )
                    eval_confusion_matrix_image = gr.Image(
                        label="Matrice de confusion - figure",
                        type="filepath",
                    )
            gr.Markdown("## Prédiction sur une image importée")
            with gr.Row():
                with gr.Column():
                    # Uploaded image is handed to the callback as a PIL image.
                    upload_image = gr.Image(type="pil", label="Importer une image")
                    predict_btn = gr.Button("Prédire la classe", variant="primary")
                with gr.Column():
                    predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
                    predict_probs = gr.Label(label="Probabilités par classe")
            gr.Markdown("## Test sur un échantillon aléatoire du test set")
            random_test_btn = gr.Button("Tester un échantillon aléatoire")
            with gr.Row():
                random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
                random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
                random_sample_probs = gr.Label(label="Probabilités par classe")

    # ---- Event wiring (registered inside the Blocks context) ----
    # The number of outputs of each listener must match the tuple length
    # returned by its callback.
    load_dataset_btn.click(
        fn=load_dataset_overview_callback,
        inputs=None,
        outputs=[dataset_summary, class_distribution, class_selector],
    )
    refresh_gallery_btn.click(
        fn=refresh_gallery_callback,
        inputs=[split_selector, class_selector, max_images],
        outputs=image_gallery,
    )
    # Show/hide the CNN parameters and reset the learning rate.
    model_type.change(
        fn=on_model_type_change,
        inputs=model_type,
        outputs=[cnn_params_col, learning_rate],
    )
    # Training also refreshes the model dropdown of tab 3.
    train_btn.click(
        fn=train_callback,
        inputs=[
            model_type,
            num_conv_blocks,
            base_filters,
            kernel_size,
            use_batchnorm,
            dropout,
            fc_dim,
            learning_rate,
            weight_decay,
            batch_size,
            epochs,
            model_tag,
        ],
        outputs=[
            train_status,
            train_history,
            train_summary,
            train_report,
            train_confusion_matrix,
            train_confusion_matrix_image,
            model_selector,
        ],
    )
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    evaluate_btn.click(
        fn=evaluate_saved_model_callback,
        inputs=model_selector,
        outputs=[
            eval_summary,
            eval_report,
            eval_confusion_matrix,
            eval_confusion_matrix_image,
        ],
    )
    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=model_selector,
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

if __name__ == "__main__":
    # Launch with server-side rendering turned off (ssr_mode=False).
    demo.launch(ssr_mode=False)