CircleStar commited on
Commit
6da55fb
·
verified ·
1 Parent(s): 5fce1fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -60
app.py CHANGED
@@ -59,7 +59,7 @@ def train_callback(
59
  weight_decay,
60
  batch_size,
61
  epochs,
62
- freeze_backbone,
63
  model_tag,
64
  ):
65
  try:
@@ -70,7 +70,7 @@ def train_callback(
70
  weight_decay=float(weight_decay),
71
  batch_size=int(batch_size),
72
  epochs=int(epochs),
73
- freeze_backbone=bool(freeze_backbone),
74
  model_tag=model_tag,
75
  )
76
 
@@ -157,17 +157,17 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
157
  with gr.Tab("1. Explorer le jeu de données"):
158
  gr.Markdown("## Comprendre le jeu de données avant l’entraînement")
159
 
160
- with gr.Row():
161
- load_dataset_btn = gr.Button("Charger les informations du dataset", variant="primary")
 
 
162
 
163
- with gr.Row():
164
- dataset_summary = gr.JSON(label="Résumé général du dataset")
165
 
166
- with gr.Row():
167
- class_distribution = gr.Dataframe(
168
- label="Distribution des images par split et par classe",
169
- interactive=False,
170
- )
171
 
172
  gr.Markdown("## Visualisation des images")
173
 
@@ -189,7 +189,8 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
189
  step=4,
190
  label="Nombre d’images à afficher",
191
  )
192
- refresh_gallery_btn = gr.Button("Afficher des exemples")
 
193
 
194
  image_gallery = gr.Gallery(
195
  label="Exemples d’images",
@@ -200,52 +201,64 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
200
  with gr.Tab("2. Entraîner un modèle"):
201
  gr.Markdown("## Entraînement avec ResNet18 pré-entraîné")
202
  gr.Markdown(
203
- "Le modèle utilise un backbone ResNet18 pré-entraîné sur ImageNet. "
204
- "Pour limiter le surapprentissage sur un petit dataset, il est recommandé de commencer "
205
- "avec le backbone gelé."
206
  )
207
 
208
  with gr.Row():
209
  with gr.Column():
210
  dropout = gr.Slider(
211
- 0.0,
212
- 0.8,
213
- value=0.5,
214
  step=0.05,
215
  label="Dropout",
216
  )
 
217
  fc_dim = gr.Dropdown(
218
  choices=[64, 128, 256, 512],
219
  value=256,
220
  label="Dimension de la couche cachée",
221
  )
 
222
  learning_rate = gr.Number(
223
- value=0.0001,
224
  label="Taux d’apprentissage",
225
  )
 
226
  weight_decay = gr.Number(
227
  value=0.0001,
228
  label="Weight decay",
229
  )
 
230
  batch_size = gr.Dropdown(
231
  choices=[8, 16, 32, 64],
232
  value=16,
233
  label="Taille du batch",
234
  )
 
235
  epochs = gr.Slider(
236
- 1,
237
- 50,
238
- value=10,
239
  step=1,
240
  label="Nombre d’époques",
241
  )
242
- freeze_backbone = gr.Checkbox(
243
- value=True,
244
- label="Geler le backbone ResNet18",
 
 
 
 
 
 
 
245
  )
 
246
  model_tag = gr.Textbox(
247
  label="Nom court du modèle",
248
- placeholder="ex. charbon_resnet18_test",
249
  )
250
 
251
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
@@ -260,23 +273,20 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
260
 
261
  gr.Markdown("## Résultats sur le test set")
262
 
263
- with gr.Row():
264
- train_report = gr.Dataframe(
265
- label="Rapport de classification",
266
- interactive=False,
267
- )
268
 
269
- with gr.Row():
270
- train_confusion_matrix = gr.Dataframe(
271
- label="Matrice de confusion",
272
- interactive=False,
273
- )
274
 
275
- with gr.Row():
276
- train_confusion_matrix_image = gr.Image(
277
- label="Matrice de confusion - figure",
278
- type="filepath",
279
- )
280
 
281
  with gr.Tab("3. Tester et analyser un modèle"):
282
  gr.Markdown("## Sélectionner un modèle sauvegardé")
@@ -288,29 +298,32 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
288
  value=initial_models[0] if initial_models else None,
289
  label="Modèle sauvegardé",
290
  )
 
291
  refresh_btn = gr.Button("Actualiser la liste des modèles")
292
  load_info_btn = gr.Button("Afficher les informations du modèle")
293
  model_info = gr.JSON(label="Métadonnées du modèle")
294
 
295
  with gr.Column():
296
- evaluate_btn = gr.Button("Évaluer le modèle sur le test set", variant="primary")
297
- eval_summary = gr.JSON(label="Résumé des métriques")
298
- eval_report = gr.Dataframe(
299
- label="Rapport de classification",
300
- interactive=False,
301
  )
 
302
 
303
- with gr.Row():
304
- eval_confusion_matrix = gr.Dataframe(
305
- label="Matrice de confusion",
306
- interactive=False,
307
- )
308
 
309
- with gr.Row():
310
- eval_confusion_matrix_image = gr.Image(
311
- label="Matrice de confusion - figure",
312
- type="filepath",
313
- )
 
 
 
 
314
 
315
  gr.Markdown("## Prédiction sur une image importée")
316
 
@@ -318,14 +331,14 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
318
  with gr.Column():
319
  upload_image = gr.Image(type="pil", label="Importer une image")
320
  predict_btn = gr.Button("Prédire la classe", variant="primary")
 
321
  with gr.Column():
322
  predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
323
  predict_probs = gr.Label(label="Probabilités par classe")
324
 
325
  gr.Markdown("## Test sur un échantillon aléatoire du test set")
326
 
327
- with gr.Row():
328
- random_test_btn = gr.Button("Tester un échantillon aléatoire")
329
 
330
  with gr.Row():
331
  random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
@@ -353,7 +366,7 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
353
  weight_decay,
354
  batch_size,
355
  epochs,
356
- freeze_backbone,
357
  model_tag,
358
  ],
359
  outputs=[
@@ -398,7 +411,7 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
398
 
399
  random_test_btn.click(
400
  fn=test_random_sample_callback,
401
- inputs=[model_selector],
402
  outputs=[random_sample_image, random_sample_text, random_sample_probs],
403
  )
404
 
 
59
  weight_decay,
60
  batch_size,
61
  epochs,
62
+ fine_tune_mode,
63
  model_tag,
64
  ):
65
  try:
 
70
  weight_decay=float(weight_decay),
71
  batch_size=int(batch_size),
72
  epochs=int(epochs),
73
+ fine_tune_mode=str(fine_tune_mode),
74
  model_tag=model_tag,
75
  )
76
 
 
157
  with gr.Tab("1. Explorer le jeu de données"):
158
  gr.Markdown("## Comprendre le jeu de données avant l’entraînement")
159
 
160
+ load_dataset_btn = gr.Button(
161
+ "Charger les informations du dataset",
162
+ variant="primary",
163
+ )
164
 
165
+ dataset_summary = gr.JSON(label="Résumé général du dataset")
 
166
 
167
+ class_distribution = gr.Dataframe(
168
+ label="Distribution des images par split et par classe",
169
+ interactive=False,
170
+ )
 
171
 
172
  gr.Markdown("## Visualisation des images")
173
 
 
189
  step=4,
190
  label="Nombre d’images à afficher",
191
  )
192
+
193
+ refresh_gallery_btn = gr.Button("Afficher des exemples")
194
 
195
  image_gallery = gr.Gallery(
196
  label="Exemples d’images",
 
201
  with gr.Tab("2. Entraîner un modèle"):
202
  gr.Markdown("## Entraînement avec ResNet18 pré-entraîné")
203
  gr.Markdown(
204
+ "Paramètres par défaut recommandés : fine-tuning de la dernière couche convolutionnelle "
205
+ "du ResNet18, faible taux d’apprentissage, augmentation légère des données."
 
206
  )
207
 
208
  with gr.Row():
209
  with gr.Column():
210
  dropout = gr.Slider(
211
+ minimum=0.0,
212
+ maximum=0.8,
213
+ value=0.4,
214
  step=0.05,
215
  label="Dropout",
216
  )
217
+
218
  fc_dim = gr.Dropdown(
219
  choices=[64, 128, 256, 512],
220
  value=256,
221
  label="Dimension de la couche cachée",
222
  )
223
+
224
  learning_rate = gr.Number(
225
+ value=0.00001,
226
  label="Taux d’apprentissage",
227
  )
228
+
229
  weight_decay = gr.Number(
230
  value=0.0001,
231
  label="Weight decay",
232
  )
233
+
234
  batch_size = gr.Dropdown(
235
  choices=[8, 16, 32, 64],
236
  value=16,
237
  label="Taille du batch",
238
  )
239
+
240
  epochs = gr.Slider(
241
+ minimum=1,
242
+ maximum=80,
243
+ value=30,
244
  step=1,
245
  label="Nombre d’époques",
246
  )
247
+
248
+ fine_tune_mode = gr.Dropdown(
249
+ choices=["frozen", "layer4", "full"],
250
+ value="layer4",
251
+ label="Mode de fine-tuning",
252
+ info=(
253
+ "frozen = seul le classifieur est entraîné ; "
254
+ "layer4 = dernière partie du ResNet18 + classifieur ; "
255
+ "full = tout le réseau est ajusté."
256
+ ),
257
  )
258
+
259
  model_tag = gr.Textbox(
260
  label="Nom court du modèle",
261
+ placeholder="ex. charbon_resnet18_layer4",
262
  )
263
 
264
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
 
273
 
274
  gr.Markdown("## Résultats sur le test set")
275
 
276
+ train_report = gr.Dataframe(
277
+ label="Rapport de classification",
278
+ interactive=False,
279
+ )
 
280
 
281
+ train_confusion_matrix = gr.Dataframe(
282
+ label="Matrice de confusion",
283
+ interactive=False,
284
+ )
 
285
 
286
+ train_confusion_matrix_image = gr.Image(
287
+ label="Matrice de confusion - figure",
288
+ type="filepath",
289
+ )
 
290
 
291
  with gr.Tab("3. Tester et analyser un modèle"):
292
  gr.Markdown("## Sélectionner un modèle sauvegardé")
 
298
  value=initial_models[0] if initial_models else None,
299
  label="Modèle sauvegardé",
300
  )
301
+
302
  refresh_btn = gr.Button("Actualiser la liste des modèles")
303
  load_info_btn = gr.Button("Afficher les informations du modèle")
304
  model_info = gr.JSON(label="Métadonnées du modèle")
305
 
306
  with gr.Column():
307
+ evaluate_btn = gr.Button(
308
+ "Évaluer le modèle sur le test set",
309
+ variant="primary",
 
 
310
  )
311
+ eval_summary = gr.JSON(label="Résumé des métriques")
312
 
313
+ eval_report = gr.Dataframe(
314
+ label="Rapport de classification",
315
+ interactive=False,
316
+ )
 
317
 
318
+ eval_confusion_matrix = gr.Dataframe(
319
+ label="Matrice de confusion",
320
+ interactive=False,
321
+ )
322
+
323
+ eval_confusion_matrix_image = gr.Image(
324
+ label="Matrice de confusion - figure",
325
+ type="filepath",
326
+ )
327
 
328
  gr.Markdown("## Prédiction sur une image importée")
329
 
 
331
  with gr.Column():
332
  upload_image = gr.Image(type="pil", label="Importer une image")
333
  predict_btn = gr.Button("Prédire la classe", variant="primary")
334
+
335
  with gr.Column():
336
  predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
337
  predict_probs = gr.Label(label="Probabilités par classe")
338
 
339
  gr.Markdown("## Test sur un échantillon aléatoire du test set")
340
 
341
+ random_test_btn = gr.Button("Tester un échantillon aléatoire")
 
342
 
343
  with gr.Row():
344
  random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
 
366
  weight_decay,
367
  batch_size,
368
  epochs,
369
+ fine_tune_mode,
370
  model_tag,
371
  ],
372
  outputs=[
 
411
 
412
  random_test_btn.click(
413
  fn=test_random_sample_callback,
414
+ inputs=model_selector,
415
  outputs=[random_sample_image, random_sample_text, random_sample_probs],
416
  )
417