CircleStar's picture
Update app.py
733869b verified
raw
history blame
6.96 kB
import json
import logging

import spaces
import gradio as gr

from train_utils import train_model, list_saved_models, model_meta_path
from predict_utils import predict_uploaded_image, test_random_sample
@spaces.GPU(duration=120)
def train_callback(
    conv1_channels,
    conv2_channels,
    kernel_size,
    dropout,
    fc_dim,
    learning_rate,
    batch_size,
    epochs,
    model_tag,
):
    """Train a model from the UI hyperparameters and refresh the model selector.

    Every widget value is coerced explicitly to the type it is passed to
    ``train_model`` with (slider / number widgets may yield floats).

    Returns:
        A 4-tuple matching the click handler's outputs: the training log text,
        the training history, the training summary, and a ``gr.update`` for the
        model dropdown (new choices, preselecting the freshly trained model when
        it appears in the list, otherwise the first model or ``None``).
        On any failure: an error message, ``None``, ``None`` and an empty
        ``gr.update()`` so the dropdown is left unchanged.
    """
    try:
        logs, history, summary, model_name = train_model(
            int(conv1_channels),
            int(conv2_channels),
            int(kernel_size),
            float(dropout),
            int(fc_dim),
            float(learning_rate),
            int(batch_size),
            int(epochs),
            model_tag,
        )
        models = list_saved_models()
        selected = model_name if model_name in models else (models[0] if models else None)
        return logs, history, summary, gr.update(choices=models, value=selected)
    except Exception as e:  # UI boundary: report the error instead of failing the event
        # The UI only shows str(e); keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Training failed")
        return f"Échec de l’entraînement :\n{e}", None, None, gr.update()
@spaces.GPU(duration=60)
def predict_uploaded_image_callback(model_name, image):
    """Classify an uploaded image with the selected saved model.

    Args:
        model_name: Name chosen in the model dropdown (``None`` if no model exists).
        image: PIL image from the upload widget (``None`` if nothing was uploaded).

    Returns:
        Whatever ``predict_uploaded_image`` returns, which is wired to the
        prediction text box and the class-probability label. On any failure:
        an error message and ``None``.
    """
    try:
        return predict_uploaded_image(model_name, image)
    except Exception as e:  # UI boundary: report the error instead of failing the event
        # The UI only shows str(e); keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Prediction on uploaded image failed")
        return f"Échec de la prédiction :\n{e}", None
@spaces.GPU(duration=60)
def test_random_sample_callback(model_name):
    """Run the selected saved model on a random sample.

    Args:
        model_name: Name chosen in the model dropdown (``None`` if no model exists).

    Returns:
        Whatever ``test_random_sample`` returns, which is wired to the sample
        image, the result text box and the class-probability label. On any
        failure: ``None``, an error message, ``None``.
    """
    try:
        return test_random_sample(model_name)
    except Exception as e:  # UI boundary: report the error instead of failing the event
        # The UI only shows str(e); keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Random-sample test failed")
        return None, f"Échec du test aléatoire :\n{e}", None
def refresh_models_dropdown():
    """Re-list the saved models and reset the dropdown.

    The first listed model becomes the selection, or ``None`` when no model
    has been saved yet.
    """
    available = list_saved_models()
    first = None
    if available:
        first = available[0]
    return gr.update(choices=available, value=first)
def get_model_info(model_name: str):
    """Load the JSON metadata stored for a saved model.

    Args:
        model_name: Name chosen in the model dropdown; falsy (``None`` / ``""``)
            when nothing is selected.

    Returns:
        The parsed metadata, or a ``{"message": ...}`` dict explaining why it
        cannot be shown, so the ``gr.JSON`` panel always receives something to
        render instead of the event failing.
    """
    if not model_name:
        return {"message": "Aucun modèle sélectionné."}
    meta_file = model_meta_path(model_name)
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {"message": "Métadonnées introuvables."}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError (corrupt file) and
        # UnicodeDecodeError (not UTF-8); other OSErrors include permission errors.
        return {"message": f"Échec de la lecture des métadonnées :\n{e}"}
# Saved models found at startup; used to pre-populate the model selector.
initial_models = list_saved_models()

# Gradio Blocks UI. Components appear in the order they are created inside the
# nested `with` contexts, so statement order in this section defines the layout.
with gr.Blocks(title="Classification d’images microscopiques") as demo:
    gr.Markdown("# Classification d’images microscopiques de charbons de bois")
    gr.Markdown(
        "Cette application permet d’entraîner un réseau de neurones convolutif simple "
        "sur un jeu de données privé Hugging Face, puis de tester les modèles sauvegardés "
        "sur une image importée ou sur un échantillon aléatoire."
    )
    with gr.Tabs():
        # Training tab: hyperparameter inputs in the first column,
        # training outputs in the second column.
        with gr.Tab("Entraîner"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Paramètres d’entraînement")
                    conv1_channels = gr.Slider(
                        8, 64, value=16, step=8, label="Nombre de canaux - couche convolutionnelle 1"
                    )
                    conv2_channels = gr.Slider(
                        16, 128, value=32, step=16, label="Nombre de canaux - couche convolutionnelle 2"
                    )
                    kernel_size = gr.Dropdown(
                        choices=[3, 5], value=3, label="Taille du noyau"
                    )
                    dropout = gr.Slider(
                        0.0, 0.7, value=0.2, step=0.05, label="Dropout"
                    )
                    fc_dim = gr.Slider(
                        32, 256, value=128, step=32, label="Dimension de la couche cachée fully-connected"
                    )
                    # NOTE(review): no min/max bounds, so a zero or negative rate
                    # reaches train_model unchanged (train_callback only does float()).
                    learning_rate = gr.Number(
                        value=0.001, label="Taux d’apprentissage"
                    )
                    batch_size = gr.Dropdown(
                        choices=[16, 32, 64, 128], value=32, label="Taille du batch"
                    )
                    epochs = gr.Slider(
                        1, 20, value=5, step=1, label="Nombre d’époques"
                    )
                    # Forwarded as-is to train_model as `model_tag`; presumably
                    # used in the saved model's name (verify in train_utils).
                    model_tag = gr.Textbox(
                        label="Nom court du modèle",
                        placeholder="ex. charbon_cnn_test"
                    )
                    train_btn = gr.Button("Lancer l’entraînement", variant="primary")
                with gr.Column():
                    train_status = gr.Textbox(label="Journal d’entraînement", lines=18)
                    train_history = gr.JSON(label="Historique d’entraînement")
                    train_summary = gr.JSON(label="Résumé d’entraînement")
        # Testing tab: model selection and metadata in the first column,
        # prediction on an uploaded image in the second column,
        # and a random-sample test in the rows below.
        with gr.Tab("Tester"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Modèle sauvegardé")
                    model_selector = gr.Dropdown(
                        choices=initial_models,
                        value=initial_models[0] if initial_models else None,
                        label="Sélectionner un modèle",
                    )
                    refresh_btn = gr.Button("Actualiser la liste des modèles")
                    load_info_btn = gr.Button("Afficher les informations du modèle")
                    model_info = gr.JSON(label="Métadonnées du modèle")
                with gr.Column():
                    gr.Markdown("### Prédiction sur une image importée")
                    upload_image = gr.Image(type="pil", label="Importer une image")
                    predict_btn = gr.Button("Prédire la classe", variant="primary")
                    predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
                    predict_probs = gr.Label(label="Probabilités par classe")
            with gr.Row():
                random_test_btn = gr.Button("Tester un échantillon aléatoire")
            with gr.Row():
                random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
                random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
                random_sample_probs = gr.Label(label="Probabilités par classe")

    # --- Event wiring ---
    # The order of `inputs` must match train_callback's positional parameters.
    # The fourth output updates model_selector, so a newly trained model can be
    # selected in the testing tab right away.
    train_btn.click(
        fn=train_callback,
        inputs=[
            conv1_channels,
            conv2_channels,
            kernel_size,
            dropout,
            fc_dim,
            learning_rate,
            batch_size,
            epochs,
            model_tag,
        ],
        outputs=[train_status, train_history, train_summary, model_selector],
    )
    # Re-list saved models; the selection is reset to the first one.
    refresh_btn.click(
        fn=refresh_models_dropdown,
        inputs=None,
        outputs=model_selector,
    )
    load_info_btn.click(
        fn=get_model_info,
        inputs=model_selector,
        outputs=model_info,
    )
    predict_btn.click(
        fn=predict_uploaded_image_callback,
        inputs=[model_selector, upload_image],
        outputs=[predict_text, predict_probs],
    )
    # test_random_sample_callback returns a 3-tuple: (image, text, probabilities).
    random_test_btn.click(
        fn=test_random_sample_callback,
        inputs=[model_selector],
        outputs=[random_sample_image, random_sample_text, random_sample_probs],
    )

# Start the Gradio server when this file is executed directly.
if __name__ == "__main__":
    demo.launch()