functionNormally commited on
Commit
cdc317a
·
1 Parent(s): 15c5bd9

Restructurer l'app : backbone préentraîné + ML classique + FC head + CNN de zéro

Browse files

- Ajout backbone_utils.py : chargement du backbone ResNet18 depuis HF, extraction
de features 512-dim avec cache mémoire
- Ajout classical_ml_utils.py : SVM / LogReg / k-NN / RF / LDA sur les features
extraites (pipeline sklearn avec StandardScaler + joblib)
- Refactorisation train_utils.py : train_fc_head (tête FC seule, ~200Ko sauvegardés)
et train_cnn (SimpleCNN de zéro) ; evaluate_saved_model unifié pour tous les types
- Mise à jour model.py : BackboneWithFC (backbone gelé + tête FC) + SimpleCNN conservé
- Mise à jour predict_utils.py : dispatch automatique selon model_type
- Mise à jour app.py : 4 onglets (dataset / ML classique / neuronaux / test-prédiction)
- Ajout config.py : HF_BACKBONE_REPO, CLASSICAL_MODEL_TYPES
- Ajout .gitignore : exclut data/, backbone/, saved_models/, __pycache__/
- Ajout finetune_backbone.py : script local pour entraîner le backbone sur les données

Files changed (10) hide show
  1. .gitignore +8 -0
  2. app.py +308 -295
  3. backbone_utils.py +81 -0
  4. classical_ml_utils.py +140 -0
  5. config.py +4 -4
  6. data_utils.py +1 -1
  7. finetune_backbone.py +245 -0
  8. model.py +11 -26
  9. predict_utils.py +57 -38
  10. train_utils.py +291 -170
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ backbone/
3
+ saved_models/
4
+ saved_models_meta/
5
+ saved_figures/
6
+ __pycache__/
7
+ *.pyc
8
+ .DS_Store
app.py CHANGED
@@ -3,90 +3,145 @@ import json
3
  import gradio as gr
4
  import spaces
5
 
6
- from data_utils import (
7
- dataset_overview,
8
- get_class_names,
9
- get_images_for_gallery,
10
- )
11
  from train_utils import (
12
- train_model,
13
  list_saved_models,
14
  model_meta_path,
15
- evaluate_saved_model,
16
- )
17
- from predict_utils import (
18
- predict_uploaded_image,
19
- test_random_sample,
20
  )
21
 
 
 
 
22
 
23
- def load_dataset_overview_callback():
24
  try:
25
  summary, distribution_df = dataset_overview()
26
  class_names = ["Toutes les classes"] + get_class_names()
 
 
 
27
 
28
- return (
29
- summary,
30
- distribution_df,
31
- gr.update(choices=class_names, value="Toutes les classes"),
32
- )
33
 
 
 
 
34
  except Exception as e:
35
- return (
36
- {"Erreur": str(e)},
37
- None,
38
- gr.update(),
39
- )
40
 
41
 
42
- def refresh_gallery_callback(split_name, class_name, max_images):
 
 
 
 
43
  try:
44
- gallery = get_images_for_gallery(
45
- split_name=split_name,
46
- class_name=class_name,
47
- max_images=int(max_images),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  )
49
- return gallery
50
  except Exception as e:
51
- return [(None, f"Erreur : {str(e)}")]
52
 
53
 
54
- def on_model_type_change(model_type):
55
- is_cnn = (model_type == "CNN simple")
56
- default_lr = 0.001 if is_cnn else 0.0001
 
 
 
 
57
  return gr.update(visible=is_cnn), gr.update(value=default_lr)
58
 
59
 
60
- @spaces.GPU(duration=200)
61
- def train_callback(
62
  model_type,
63
- num_conv_blocks,
64
- base_filters,
65
- kernel_size,
66
- use_batchnorm,
67
- dropout,
68
- fc_dim,
69
- learning_rate,
70
- weight_decay,
71
- batch_size,
72
- epochs,
73
  model_tag,
74
  ):
75
  try:
76
- result = train_model(
77
- model_type="cnn" if model_type == "CNN simple" else "resnet18",
78
- num_conv_blocks=int(num_conv_blocks),
79
- base_filters=int(base_filters),
80
- kernel_size=int(kernel_size),
81
- use_batchnorm=bool(use_batchnorm),
82
- dropout=float(dropout),
83
- fc_dim=int(fc_dim),
84
- learning_rate=float(learning_rate),
85
- weight_decay=float(weight_decay),
86
- batch_size=int(batch_size),
87
- epochs=int(epochs),
88
- model_tag=model_tag,
89
- )
 
 
 
 
 
 
 
 
 
 
90
 
91
  models = list_saved_models()
92
  selected = result["model_name"] if result["model_name"] in models else None
@@ -100,21 +155,31 @@ def train_callback(
100
  result["confusion_matrix_path"],
101
  gr.update(choices=models, value=selected),
102
  )
103
-
104
  except Exception as e:
105
- return (
106
- f"Échec de l’entraînement :\n{str(e)}",
107
- None,
108
- None,
109
- None,
110
- None,
111
- None,
112
- gr.update(),
113
- )
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  @spaces.GPU(duration=120)
117
- def evaluate_saved_model_callback(model_name):
118
  try:
119
  summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
120
  return summary, report_df, cm_df, cm_path
@@ -123,269 +188,219 @@ def evaluate_saved_model_callback(model_name):
123
 
124
 
125
  @spaces.GPU(duration=60)
126
- def predict_uploaded_image_callback(model_name, image):
127
  try:
128
  return predict_uploaded_image(model_name, image)
129
  except Exception as e:
130
- return f"Échec de la prédiction :\n{str(e)}", None
131
 
132
 
133
  @spaces.GPU(duration=60)
134
- def test_random_sample_callback(model_name):
135
  try:
136
  return test_random_sample(model_name)
137
  except Exception as e:
138
- return None, f"Échec du test aléatoire :\n{str(e)}", None
139
-
140
-
141
- def refresh_models_dropdown():
142
- models = list_saved_models()
143
- return gr.update(choices=models, value=models[0] if models else None)
144
 
145
 
146
- def get_model_info(model_name: str):
147
- if not model_name:
148
- return {"message": "Aucun modèle sélectionné."}
149
-
150
- meta_file = model_meta_path(model_name)
151
-
152
- try:
153
- with open(meta_file, "r", encoding="utf-8") as f:
154
- return json.load(f)
155
- except FileNotFoundError:
156
- return {"message": "Métadonnées introuvables."}
157
-
158
 
159
  initial_models = list_saved_models()
160
 
161
-
162
- with gr.Blocks(title="Classification dimages microscopiques") as demo:
163
- gr.Markdown("# Classification d’images microscopiques de charbons de bois")
164
  gr.Markdown(
165
- "Application pédagogique pour explorer un jeu de données d’images microscopiques, "
166
- "entraîner un modèle de classification et analyser ses performances."
 
167
  )
168
 
169
  with gr.Tabs():
170
 
 
 
 
171
  with gr.Tab("1. Explorer le jeu de données"):
172
- gr.Markdown("## Comprendre le jeu de données avant lentraînement")
173
-
174
- load_dataset_btn = gr.Button(
175
- "Charger les informations du dataset",
176
- variant="primary",
177
- )
178
-
179
- dataset_summary = gr.JSON(label="Résumé général du dataset")
180
 
 
 
181
  class_distribution = gr.Dataframe(
182
- label="Distribution des images par split et par classe",
183
- interactive=False,
184
  )
185
 
186
  gr.Markdown("## Visualisation des images")
187
-
188
  with gr.Row():
189
  split_selector = gr.Dropdown(
190
- choices=["train", "validation", "test"],
191
- value="train",
192
- label="Split",
193
  )
194
  class_selector = gr.Dropdown(
195
- choices=["Toutes les classes"],
196
- value="Toutes les classes",
197
- label="Classe",
198
- )
199
- max_images = gr.Slider(
200
- minimum=4,
201
- maximum=48,
202
- value=24,
203
- step=4,
204
- label="Nombre d’images à afficher",
205
  )
 
206
 
207
  refresh_gallery_btn = gr.Button("Afficher des exemples")
208
-
209
- image_gallery = gr.Gallery(
210
- label="Exemples d’images",
211
- columns=4,
212
- height=600,
 
 
 
 
 
 
213
  )
214
 
215
- with gr.Tab("2. Entraîner un modèle"):
216
- gr.Markdown("## Choix du modèle et entraînement")
 
 
217
 
218
  with gr.Row():
219
  with gr.Column():
220
- model_type = gr.Radio(
221
- choices=["CNN simple", "ResNet18"],
222
- value="CNN simple",
223
- label="Architecture",
224
- info=(
225
- "CNN simple : entraîné de zéro, paramètres configurables. "
226
- "ResNet18 : pré-entraîné ImageNet, fine-tuning layer4 + classifieur."
227
- ),
228
  )
229
 
230
- with gr.Column(visible=True) as cnn_params_col:
231
- gr.Markdown("#### Paramètres CNN")
232
- num_conv_blocks = gr.Slider(
233
- minimum=2,
234
- maximum=5,
235
- value=3,
236
- step=1,
237
- label="Nombre de blocs convolutionnels",
238
- info="Chaque bloc enchaîne Conv2d → (BN) → ReLU → MaxPool2d.",
 
 
 
 
 
 
 
239
  )
240
 
241
- base_filters = gr.Dropdown(
242
- choices=[16, 32, 64, 128],
243
- value=32,
244
- label="Filtres du premier bloc (doublent à chaque bloc)",
245
- )
246
-
247
- kernel_size = gr.Dropdown(
248
- choices=[3, 5],
249
- value=3,
250
- label="Taille du noyau de convolution",
251
- )
252
 
253
- use_batchnorm = gr.Checkbox(
254
- value=True,
255
- label="Normalisation par lots (BatchNorm)",
256
- )
257
 
258
- gr.Markdown("#### Hyperparamètres d’entraînement")
 
259
 
260
- dropout = gr.Slider(
261
- minimum=0.0,
262
- maximum=0.8,
263
- value=0.4,
264
- step=0.05,
265
- label="Dropout",
266
- )
267
-
268
- fc_dim = gr.Dropdown(
269
- choices=[64, 128, 256, 512],
270
- value=256,
271
- label="Dimension de la couche cachée (classifieur)",
272
- )
273
 
274
- learning_rate = gr.Number(
275
- value=0.001,
276
- label="Taux d’apprentissage",
277
- )
278
 
279
- weight_decay = gr.Number(
280
- value=0.0001,
281
- label="Weight decay",
282
- )
 
283
 
284
- batch_size = gr.Dropdown(
285
- choices=[8, 16, 32, 64],
286
- value=16,
287
- label="Taille du batch",
 
 
 
 
 
 
288
  )
289
 
290
- epochs = gr.Slider(
291
- minimum=1,
292
- maximum=50,
293
- value=30,
294
- step=1,
295
- label="Nombre d’époques",
296
- )
297
 
298
- model_tag = gr.Textbox(
299
- label="Nom court du modèle",
300
- placeholder="ex. cnn_3blocs ou resnet18_ft",
301
- )
 
 
 
 
302
 
303
- train_btn = gr.Button("Lancer lentraînement", variant="primary")
304
 
305
  with gr.Column():
306
- train_status = gr.Textbox(
307
- label="Journal d’entraînement",
308
- lines=18,
309
- )
310
- train_history = gr.JSON(label="Historique d’entraînement")
311
- train_summary = gr.JSON(label="Résumé final")
312
 
313
  gr.Markdown("## Résultats sur le test set")
314
-
315
- train_report = gr.Dataframe(
316
- label="Rapport de classification",
317
- interactive=False,
318
- )
319
-
320
- train_confusion_matrix = gr.Dataframe(
321
- label="Matrice de confusion",
322
- interactive=False,
323
- )
324
-
325
- train_confusion_matrix_image = gr.Image(
326
- label="Matrice de confusion - figure",
327
- type="filepath",
328
- )
329
-
330
- with gr.Tab("3. Tester et analyser un modèle"):
331
  gr.Markdown("## Sélectionner un modèle sauvegardé")
 
 
 
332
 
333
  with gr.Row():
334
  with gr.Column():
335
  model_selector = gr.Dropdown(
336
  choices=initial_models,
337
  value=initial_models[0] if initial_models else None,
338
- label="Modèle sauvegardé",
339
  )
340
-
341
- refresh_btn = gr.Button("Actualiser la liste des modèles")
342
  load_info_btn = gr.Button("Afficher les informations du modèle")
343
- model_info = gr.JSON(label="Métadonnées du modèle")
344
 
345
  with gr.Column():
346
- evaluate_btn = gr.Button(
347
- "Évaluer le modèle sur le test set",
348
- variant="primary",
349
- )
350
  eval_summary = gr.JSON(label="Résumé des métriques")
351
 
352
- eval_report = gr.Dataframe(
353
- label="Rapport de classification",
354
- interactive=False,
355
- )
356
-
357
- eval_confusion_matrix = gr.Dataframe(
358
- label="Matrice de confusion",
359
- interactive=False,
360
- )
361
-
362
- eval_confusion_matrix_image = gr.Image(
363
- label="Matrice de confusion - figure",
364
- type="filepath",
365
- )
366
 
367
  gr.Markdown("## Prédiction sur une image importée")
368
-
369
  with gr.Row():
370
  with gr.Column():
371
  upload_image = gr.Image(type="pil", label="Importer une image")
372
  predict_btn = gr.Button("Prédire la classe", variant="primary")
373
-
374
  with gr.Column():
375
- predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
376
  predict_probs = gr.Label(label="Probabilités par classe")
377
 
378
  gr.Markdown("## Test sur un échantillon aléatoire du test set")
379
-
380
  random_test_btn = gr.Button("Tester un échantillon aléatoire")
381
-
382
  with gr.Row():
383
- random_sample_image = gr.Image(type="pil", label="Image test aléatoire")
384
- random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
385
- random_sample_probs = gr.Label(label="Probabilités par classe")
 
 
 
 
386
 
387
  load_dataset_btn.click(
388
- fn=load_dataset_overview_callback,
389
  inputs=None,
390
  outputs=[dataset_summary, class_distribution, class_selector],
391
  )
@@ -396,74 +411,72 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
396
  outputs=image_gallery,
397
  )
398
 
399
- model_type.change(
400
- fn=on_model_type_change,
401
- inputs=model_type,
402
- outputs=[cnn_params_col, learning_rate],
 
 
403
  )
404
 
405
- train_btn.click(
406
- fn=train_callback,
407
  inputs=[
408
- model_type,
409
- num_conv_blocks,
410
- base_filters,
411
- kernel_size,
412
- use_batchnorm,
413
- dropout,
414
- fc_dim,
415
- learning_rate,
416
- weight_decay,
417
- batch_size,
418
- epochs,
419
- model_tag,
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  ],
421
  outputs=[
422
- train_status,
423
- train_history,
424
- train_summary,
425
- train_report,
426
- train_confusion_matrix,
427
- train_confusion_matrix_image,
428
  model_selector,
429
  ],
430
  )
431
 
432
- refresh_btn.click(
433
- fn=refresh_models_dropdown,
434
- inputs=None,
435
- outputs=model_selector,
436
- )
437
 
438
- load_info_btn.click(
439
- fn=get_model_info,
440
- inputs=model_selector,
441
- outputs=model_info,
442
- )
443
 
444
  evaluate_btn.click(
445
- fn=evaluate_saved_model_callback,
446
  inputs=model_selector,
447
- outputs=[
448
- eval_summary,
449
- eval_report,
450
- eval_confusion_matrix,
451
- eval_confusion_matrix_image,
452
- ],
453
  )
454
 
455
  predict_btn.click(
456
- fn=predict_uploaded_image_callback,
457
  inputs=[model_selector, upload_image],
458
  outputs=[predict_text, predict_probs],
459
  )
460
 
461
  random_test_btn.click(
462
- fn=test_random_sample_callback,
463
  inputs=model_selector,
464
- outputs=[random_sample_image, random_sample_text, random_sample_probs],
465
  )
466
 
467
 
468
  if __name__ == "__main__":
469
- demo.launch(ssr_mode=False)
 
3
  import gradio as gr
4
  import spaces
5
 
6
+ from backbone_utils import extract_all_features, get_cached_features
7
+ from classical_ml_utils import train_classical_model
8
+ from data_utils import dataset_overview, get_class_names, get_images_for_gallery
9
+ from predict_utils import predict_uploaded_image, test_random_sample
 
10
  from train_utils import (
11
+ evaluate_saved_model,
12
  list_saved_models,
13
  model_meta_path,
14
+ train_cnn,
15
+ train_fc_head,
 
 
 
16
  )
17
 
18
+ # ---------------------------------------------------------------------------
19
+ # Tab 1 — Dataset
20
+ # ---------------------------------------------------------------------------
21
 
22
+ def load_dataset_callback():
23
  try:
24
  summary, distribution_df = dataset_overview()
25
  class_names = ["Toutes les classes"] + get_class_names()
26
+ return summary, distribution_df, gr.update(choices=class_names, value="Toutes les classes")
27
+ except Exception as e:
28
+ return {"Erreur": str(e)}, None, gr.update()
29
 
 
 
 
 
 
30
 
31
+ def refresh_gallery_callback(split_name, class_name, max_images):
32
+ try:
33
+ return get_images_for_gallery(split_name, class_name, int(max_images))
34
  except Exception as e:
35
+ return [(None, f"Erreur : {e}")]
 
 
 
 
36
 
37
 
38
+ # ---------------------------------------------------------------------------
39
+ # Tab 2 — ML classique
40
+ # ---------------------------------------------------------------------------
41
+
42
+ def extract_features_callback():
43
  try:
44
+ _, class_names, counts = extract_all_features()
45
+ lines = [f"Extraction terminée ({len(class_names)} classes)"]
46
+ for split, n in counts.items():
47
+ lines.append(f" {split} : {n} images")
48
+ return "\n".join(lines)
49
+ except Exception as e:
50
+ return f"Erreur lors de l'extraction :\n{e}"
51
+
52
+
53
+ def on_clf_type_change(clf_type):
54
+ show = lambda t: gr.update(visible=(clf_type == t))
55
+ return show("SVM"), show("Régression logistique"), show("k-NN"), show("Forêt aléatoire"), show("LDA")
56
+
57
+
58
+ def train_classical_callback(
59
+ clf_type,
60
+ svm_c, svm_kernel, svm_gamma,
61
+ logreg_c, logreg_max_iter,
62
+ knn_k, knn_metric,
63
+ rf_n_estimators, rf_max_depth,
64
+ lda_solver,
65
+ model_tag,
66
+ ):
67
+ try:
68
+ features_cache = get_cached_features()
69
+ if features_cache is None:
70
+ return {"Erreur": "Veuillez d'abord extraire les caractéristiques (bouton ci-dessus)."}, None, None, None, gr.update()
71
+
72
+ params = {}
73
+ if clf_type == "SVM":
74
+ params = {"C": float(svm_c), "kernel": svm_kernel, "gamma": svm_gamma}
75
+ elif clf_type == "Régression logistique":
76
+ params = {"C": float(logreg_c), "max_iter": int(logreg_max_iter)}
77
+ elif clf_type == "k-NN":
78
+ params = {"n_neighbors": int(knn_k), "metric": knn_metric}
79
+ elif clf_type == "Forêt aléatoire":
80
+ depth = int(rf_max_depth) if rf_max_depth and int(rf_max_depth) > 0 else None
81
+ params = {"n_estimators": int(rf_n_estimators), "max_depth": depth}
82
+ elif clf_type == "LDA":
83
+ params = {"solver": lda_solver}
84
+
85
+ class_names = get_class_names()
86
+ result = train_classical_model(clf_type, features_cache, class_names, model_tag, **params)
87
+
88
+ models = list_saved_models()
89
+ selected = result["model_name"] if result["model_name"] in models else None
90
+
91
+ return (
92
+ result["summary"],
93
+ result["classification_report"],
94
+ result["confusion_matrix"],
95
+ result["confusion_matrix_path"],
96
+ gr.update(choices=models, value=selected),
97
  )
 
98
  except Exception as e:
99
+ return {"Erreur": str(e)}, None, None, None, gr.update()
100
 
101
 
102
+ # ---------------------------------------------------------------------------
103
+ # Tab 3 Modèles neuronaux
104
+ # ---------------------------------------------------------------------------
105
+
106
+ def on_neural_type_change(model_type):
107
+ is_cnn = (model_type == "CNN de zéro")
108
+ default_lr = 1e-3 if is_cnn else 1e-4
109
  return gr.update(visible=is_cnn), gr.update(value=default_lr)
110
 
111
 
112
+ @spaces.GPU(duration=300)
113
+ def train_neural_callback(
114
  model_type,
115
+ num_conv_blocks, base_filters, kernel_size, use_batchnorm,
116
+ dropout, fc_dim,
117
+ learning_rate, weight_decay, batch_size, epochs,
 
 
 
 
 
 
 
118
  model_tag,
119
  ):
120
  try:
121
+ if model_type == "FC sur backbone préentraîné":
122
+ result = train_fc_head(
123
+ dropout=float(dropout),
124
+ fc_dim=int(fc_dim),
125
+ learning_rate=float(learning_rate),
126
+ weight_decay=float(weight_decay),
127
+ batch_size=int(batch_size),
128
+ epochs=int(epochs),
129
+ model_tag=model_tag,
130
+ )
131
+ else:
132
+ result = train_cnn(
133
+ num_conv_blocks=int(num_conv_blocks),
134
+ base_filters=int(base_filters),
135
+ kernel_size=int(kernel_size),
136
+ use_batchnorm=bool(use_batchnorm),
137
+ dropout=float(dropout),
138
+ fc_dim=int(fc_dim),
139
+ learning_rate=float(learning_rate),
140
+ weight_decay=float(weight_decay),
141
+ batch_size=int(batch_size),
142
+ epochs=int(epochs),
143
+ model_tag=model_tag,
144
+ )
145
 
146
  models = list_saved_models()
147
  selected = result["model_name"] if result["model_name"] in models else None
 
155
  result["confusion_matrix_path"],
156
  gr.update(choices=models, value=selected),
157
  )
 
158
  except Exception as e:
159
+ return f"Échec de l'entraînement :\n{e}", None, None, None, None, None, gr.update()
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # Tab 4 — Tester et prédire
164
+ # ---------------------------------------------------------------------------
165
+
166
+ def refresh_models_callback():
167
+ models = list_saved_models()
168
+ return gr.update(choices=models, value=models[0] if models else None)
169
+
170
+
171
+ def get_model_info_callback(model_name):
172
+ if not model_name:
173
+ return {"message": "Aucun modèle sélectionné."}
174
+ try:
175
+ with open(model_meta_path(model_name), "r", encoding="utf-8") as f:
176
+ return json.load(f)
177
+ except FileNotFoundError:
178
+ return {"message": "Métadonnées introuvables."}
179
 
180
 
181
  @spaces.GPU(duration=120)
182
+ def evaluate_callback(model_name):
183
  try:
184
  summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
185
  return summary, report_df, cm_df, cm_path
 
188
 
189
 
190
  @spaces.GPU(duration=60)
191
+ def predict_callback(model_name, image):
192
  try:
193
  return predict_uploaded_image(model_name, image)
194
  except Exception as e:
195
+ return f"Échec :\n{e}", None
196
 
197
 
198
  @spaces.GPU(duration=60)
199
+ def random_test_callback(model_name):
200
  try:
201
  return test_random_sample(model_name)
202
  except Exception as e:
203
+ return None, f"Échec :\n{e}", None
 
 
 
 
 
204
 
205
 
206
+ # ---------------------------------------------------------------------------
207
+ # UI
208
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
209
 
210
  initial_models = list_saved_models()
211
 
212
+ with gr.Blocks(title="Classification d'images microscopiques") as demo:
213
+ gr.Markdown("# Classification d'images microscopiques de charbons de bois")
 
214
  gr.Markdown(
215
+ "Application pédagogique : explorez le jeu de données, entraînez des classifieurs "
216
+ "traditionnels ou neuronaux sur les caractéristiques extraites par un backbone "
217
+ "ResNet18 préentraîné, puis analysez et comparez les résultats."
218
  )
219
 
220
  with gr.Tabs():
221
 
222
+ # ------------------------------------------------------------------ #
223
+ # Tab 1
224
+ # ------------------------------------------------------------------ #
225
  with gr.Tab("1. Explorer le jeu de données"):
226
+ gr.Markdown("## Comprendre le jeu de données avant l'entraînement")
 
 
 
 
 
 
 
227
 
228
+ load_dataset_btn = gr.Button("Charger les informations du dataset", variant="primary")
229
+ dataset_summary = gr.JSON(label="Résumé général")
230
  class_distribution = gr.Dataframe(
231
+ label="Distribution par split et par classe", interactive=False
 
232
  )
233
 
234
  gr.Markdown("## Visualisation des images")
 
235
  with gr.Row():
236
  split_selector = gr.Dropdown(
237
+ choices=["train", "validation", "test"], value="train", label="Split"
 
 
238
  )
239
  class_selector = gr.Dropdown(
240
+ choices=["Toutes les classes"], value="Toutes les classes", label="Classe"
 
 
 
 
 
 
 
 
 
241
  )
242
+ max_images = gr.Slider(minimum=4, maximum=48, value=24, step=4, label="Nombre d'images")
243
 
244
  refresh_gallery_btn = gr.Button("Afficher des exemples")
245
+ image_gallery = gr.Gallery(label="Exemples d'images", columns=4, height=600)
246
+
247
+ # ------------------------------------------------------------------ #
248
+ # Tab 2
249
+ # ------------------------------------------------------------------ #
250
+ with gr.Tab("2. ML classique sur caractéristiques"):
251
+ gr.Markdown(
252
+ "## Étape 1 — Extraction des caractéristiques\n"
253
+ "Le backbone ResNet18 préentraîné sur les charbons extrait un vecteur de "
254
+ "512 dimensions par image. Cette étape s'exécute sur CPU et ne nécessite "
255
+ "aucun GPU."
256
  )
257
 
258
+ extract_btn = gr.Button("Extraire les caractéristiques (backbone gelé)", variant="primary")
259
+ extract_status = gr.Textbox(label="Statut de l'extraction", lines=4, interactive=False)
260
+
261
+ gr.Markdown("## Étape 2 — Entraîner un classifieur")
262
 
263
  with gr.Row():
264
  with gr.Column():
265
+ clf_type = gr.Radio(
266
+ choices=["SVM", "Régression logistique", "k-NN", "Forêt aléatoire", "LDA"],
267
+ value="SVM",
268
+ label="Algorithme",
 
 
 
 
269
  )
270
 
271
+ with gr.Column(visible=True) as svm_col:
272
+ gr.Markdown("#### Paramètres SVM")
273
+ svm_c = gr.Number(value=1.0, label="C (régularisation)")
274
+ svm_kernel = gr.Dropdown(choices=["rbf", "linear", "poly"], value="rbf", label="Noyau")
275
+ svm_gamma = gr.Dropdown(choices=["scale", "auto"], value="scale", label="Gamma")
276
+
277
+ with gr.Column(visible=False) as logreg_col:
278
+ gr.Markdown("#### Paramètres Régression logistique")
279
+ logreg_c = gr.Number(value=1.0, label="C (régularisation)")
280
+ logreg_max_iter = gr.Number(value=1000, label="Itérations max")
281
+
282
+ with gr.Column(visible=False) as knn_col:
283
+ gr.Markdown("#### Paramètres k-NN")
284
+ knn_k = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="k (voisins)")
285
+ knn_metric = gr.Dropdown(
286
+ choices=["euclidean", "cosine", "manhattan"], value="euclidean", label="Métrique"
287
  )
288
 
289
+ with gr.Column(visible=False) as rf_col:
290
+ gr.Markdown("#### Paramètres Forêt aléatoire")
291
+ rf_n_estimators = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Nombre d'arbres")
292
+ rf_max_depth = gr.Number(value=0, label="Profondeur max (0 = illimitée)")
 
 
 
 
 
 
 
293
 
294
+ with gr.Column(visible=False) as lda_col:
295
+ gr.Markdown("#### Paramètres LDA")
296
+ lda_solver = gr.Dropdown(choices=["svd", "lsqr", "eigen"], value="svd", label="Solveur")
 
297
 
298
+ ml_model_tag = gr.Textbox(label="Nom court du modèle", placeholder="ex. svm_rbf")
299
+ train_classical_btn = gr.Button("Entraîner le classifieur", variant="primary")
300
 
301
+ with gr.Column():
302
+ ml_summary = gr.JSON(label="Résumé des métriques")
 
 
 
 
 
 
 
 
 
 
 
303
 
304
+ ml_report = gr.Dataframe(label="Rapport de classification", interactive=False)
305
+ ml_cm = gr.Dataframe(label="Matrice de confusion", interactive=False)
306
+ ml_cm_img = gr.Image(label="Matrice de confusion — figure", type="filepath")
 
307
 
308
+ # ------------------------------------------------------------------ #
309
+ # Tab 3
310
+ # ------------------------------------------------------------------ #
311
+ with gr.Tab("3. Modèles neuronaux"):
312
+ gr.Markdown("## Architecture")
313
 
314
+ with gr.Row():
315
+ with gr.Column():
316
+ neural_type = gr.Radio(
317
+ choices=["FC sur backbone préentraîné", "CNN de zéro"],
318
+ value="FC sur backbone préentraîné",
319
+ label="Type de modèle",
320
+ info=(
321
+ "FC sur backbone : backbone gelé, seule la tête FC est entraînée — rapide, peu de GPU. "
322
+ "CNN de zéro : réseau convolutif entraîné entièrement depuis rien — référence sans transfert."
323
+ ),
324
  )
325
 
326
+ with gr.Column(visible=False) as cnn_arch_col:
327
+ gr.Markdown("#### Architecture CNN")
328
+ num_conv_blocks = gr.Slider(minimum=2, maximum=5, value=3, step=1, label="Blocs convolutionnels")
329
+ base_filters = gr.Dropdown(choices=[16, 32, 64, 128], value=32, label="Filtres du premier bloc")
330
+ kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Taille du noyau")
331
+ use_batchnorm = gr.Checkbox(value=True, label="BatchNorm")
 
332
 
333
+ gr.Markdown("#### Hyperparamètres d'entraînement")
334
+ n_dropout = gr.Slider(minimum=0.0, maximum=0.8, value=0.4, step=0.05, label="Dropout")
335
+ n_fc_dim = gr.Dropdown(choices=[64, 128, 256, 512], value=256, label="Dimension couche cachée")
336
+ n_lr = gr.Number(value=1e-4, label="Taux d'apprentissage")
337
+ n_wd = gr.Number(value=1e-4, label="Weight decay")
338
+ n_bs = gr.Dropdown(choices=[8, 16, 32, 64], value=16, label="Taille du batch")
339
+ n_epochs = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Époques")
340
+ n_tag = gr.Textbox(label="Nom court du modèle", placeholder="ex. fc_head_v1")
341
 
342
+ train_neural_btn = gr.Button("Lancer l'entraînement", variant="primary")
343
 
344
  with gr.Column():
345
+ neural_logs = gr.Textbox(label="Journal d'entraînement", lines=20)
346
+ neural_history = gr.JSON(label="Historique")
347
+ neural_summary = gr.JSON(label="Résumé final")
 
 
 
348
 
349
  gr.Markdown("## Résultats sur le test set")
350
+ neural_report = gr.Dataframe(label="Rapport de classification", interactive=False)
351
+ neural_cm = gr.Dataframe(label="Matrice de confusion", interactive=False)
352
+ neural_cm_img = gr.Image(label="Matrice de confusion — figure", type="filepath")
353
+
354
+ # ------------------------------------------------------------------ #
355
+ # Tab 4
356
+ # ------------------------------------------------------------------ #
357
+ with gr.Tab("4. Tester et analyser"):
 
 
 
 
 
 
 
 
 
358
  gr.Markdown("## Sélectionner un modèle sauvegardé")
359
+ gr.Markdown(
360
+ "_Tous les types de modèles apparaissent ici : classifieurs ML, têtes FC et CNN._"
361
+ )
362
 
363
  with gr.Row():
364
  with gr.Column():
365
  model_selector = gr.Dropdown(
366
  choices=initial_models,
367
  value=initial_models[0] if initial_models else None,
368
+ label="Modèle",
369
  )
370
+ refresh_btn = gr.Button("Actualiser la liste")
 
371
  load_info_btn = gr.Button("Afficher les informations du modèle")
372
+ model_info = gr.JSON(label="Métadonnées")
373
 
374
  with gr.Column():
375
+ evaluate_btn = gr.Button("Évaluer sur le test set", variant="primary")
 
 
 
376
  eval_summary = gr.JSON(label="Résumé des métriques")
377
 
378
+ eval_report = gr.Dataframe(label="Rapport de classification", interactive=False)
379
+ eval_cm = gr.Dataframe(label="Matrice de confusion", interactive=False)
380
+ eval_cm_img = gr.Image(label="Matrice de confusion — figure", type="filepath")
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  gr.Markdown("## Prédiction sur une image importée")
 
383
  with gr.Row():
384
  with gr.Column():
385
  upload_image = gr.Image(type="pil", label="Importer une image")
386
  predict_btn = gr.Button("Prédire la classe", variant="primary")
 
387
  with gr.Column():
388
+ predict_text = gr.Textbox(label="Résultat", lines=7)
389
  predict_probs = gr.Label(label="Probabilités par classe")
390
 
391
  gr.Markdown("## Test sur un échantillon aléatoire du test set")
 
392
  random_test_btn = gr.Button("Tester un échantillon aléatoire")
 
393
  with gr.Row():
394
+ random_img = gr.Image(type="pil", label="Image test")
395
+ random_text = gr.Textbox(label="Résultat", lines=7)
396
+ random_probs = gr.Label(label="Probabilités par classe")
397
+
398
+ # ---------------------------------------------------------------------- #
399
+ # Event wiring
400
+ # ---------------------------------------------------------------------- #
401
 
402
  load_dataset_btn.click(
403
+ fn=load_dataset_callback,
404
  inputs=None,
405
  outputs=[dataset_summary, class_distribution, class_selector],
406
  )
 
411
  outputs=image_gallery,
412
  )
413
 
414
+ extract_btn.click(fn=extract_features_callback, inputs=None, outputs=extract_status)
415
+
416
+ clf_type.change(
417
+ fn=on_clf_type_change,
418
+ inputs=clf_type,
419
+ outputs=[svm_col, logreg_col, knn_col, rf_col, lda_col],
420
  )
421
 
422
+ train_classical_btn.click(
423
+ fn=train_classical_callback,
424
  inputs=[
425
+ clf_type,
426
+ svm_c, svm_kernel, svm_gamma,
427
+ logreg_c, logreg_max_iter,
428
+ knn_k, knn_metric,
429
+ rf_n_estimators, rf_max_depth,
430
+ lda_solver,
431
+ ml_model_tag,
432
+ ],
433
+ outputs=[ml_summary, ml_report, ml_cm, ml_cm_img, model_selector],
434
+ )
435
+
436
+ neural_type.change(
437
+ fn=on_neural_type_change,
438
+ inputs=neural_type,
439
+ outputs=[cnn_arch_col, n_lr],
440
+ )
441
+
442
+ train_neural_btn.click(
443
+ fn=train_neural_callback,
444
+ inputs=[
445
+ neural_type,
446
+ num_conv_blocks, base_filters, kernel_size, use_batchnorm,
447
+ n_dropout, n_fc_dim,
448
+ n_lr, n_wd, n_bs, n_epochs,
449
+ n_tag,
450
  ],
451
  outputs=[
452
+ neural_logs, neural_history, neural_summary,
453
+ neural_report, neural_cm, neural_cm_img,
 
 
 
 
454
  model_selector,
455
  ],
456
  )
457
 
458
+ refresh_btn.click(fn=refresh_models_callback, inputs=None, outputs=model_selector)
 
 
 
 
459
 
460
+ load_info_btn.click(fn=get_model_info_callback, inputs=model_selector, outputs=model_info)
 
 
 
 
461
 
462
  evaluate_btn.click(
463
+ fn=evaluate_callback,
464
  inputs=model_selector,
465
+ outputs=[eval_summary, eval_report, eval_cm, eval_cm_img],
 
 
 
 
 
466
  )
467
 
468
  predict_btn.click(
469
+ fn=predict_callback,
470
  inputs=[model_selector, upload_image],
471
  outputs=[predict_text, predict_probs],
472
  )
473
 
474
  random_test_btn.click(
475
+ fn=random_test_callback,
476
  inputs=model_selector,
477
+ outputs=[random_img, random_text, random_probs],
478
  )
479
 
480
 
481
  if __name__ == "__main__":
482
+ demo.launch(ssr_mode=False)
backbone_utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from huggingface_hub import hf_hub_download
5
+ from torch.utils.data import DataLoader
6
+ from torchvision import models
7
+
8
+ from config import HF_BACKBONE_REPO, HF_TOKEN
9
+
10
+ _BACKBONE = None
11
+ _FEATURES_CACHE = None
12
+
13
+
14
+ def load_backbone(device: torch.device) -> nn.Module:
15
+ global _BACKBONE
16
+
17
+ if _BACKBONE is not None:
18
+ return _BACKBONE.to(device)
19
+
20
+ if not HF_BACKBONE_REPO:
21
+ raise RuntimeError(
22
+ "HF_BACKBONE_REPO n'est pas configuré. "
23
+ "Ajoutez-le dans les Secrets du Space Hugging Face."
24
+ )
25
+
26
+ pt_path = hf_hub_download(
27
+ repo_id=HF_BACKBONE_REPO,
28
+ filename="resnet18_charcoal_backbone.pt",
29
+ token=HF_TOKEN,
30
+ repo_type="model",
31
+ )
32
+
33
+ backbone = models.resnet18()
34
+ backbone.fc = nn.Identity()
35
+ backbone.load_state_dict(torch.load(pt_path, map_location="cpu"))
36
+
37
+ for p in backbone.parameters():
38
+ p.requires_grad = False
39
+
40
+ _BACKBONE = backbone
41
+ return _BACKBONE.to(device)
42
+
43
+
44
+ def extract_all_features(batch_size: int = 64):
45
+ global _FEATURES_CACHE
46
+
47
+ from data_utils import prepare_splits, get_class_names, HFDatasetWrapper, get_eval_transform
48
+
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ backbone = load_backbone(device)
51
+ backbone.eval()
52
+
53
+ splits = prepare_splits()
54
+ class_names = get_class_names()
55
+
56
+ cache = {}
57
+ counts = {}
58
+
59
+ for split_name, split_data in splits.items():
60
+ dataset = HFDatasetWrapper(split_data, get_eval_transform())
61
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
62
+
63
+ X_parts, y_parts = [], []
64
+ with torch.no_grad():
65
+ for images, labels in loader:
66
+ features = backbone(images.to(device))
67
+ X_parts.append(features.cpu().numpy())
68
+ y_parts.append(labels.numpy())
69
+
70
+ cache[split_name] = {
71
+ "X": np.concatenate(X_parts, axis=0),
72
+ "y": np.concatenate(y_parts, axis=0),
73
+ }
74
+ counts[split_name] = len(cache[split_name]["y"])
75
+
76
+ _FEATURES_CACHE = cache
77
+ return cache, class_names, counts
78
+
79
+
80
+ def get_cached_features():
81
+ return _FEATURES_CACHE
classical_ml_utils.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from datetime import datetime
4
+ from typing import List
5
+
6
+ import joblib
7
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
8
+ from sklearn.ensemble import RandomForestClassifier
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.neighbors import KNeighborsClassifier
11
+ from sklearn.pipeline import Pipeline
12
+ from sklearn.preprocessing import StandardScaler
13
+ from sklearn.svm import SVC
14
+
15
+ from config import MODEL_DIR, META_DIR
16
+ from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
17
+
18
+ CLF_TYPE_MAP = {
19
+ "SVM": "svm",
20
+ "Régression logistique": "logreg",
21
+ "k-NN": "knn",
22
+ "Forêt aléatoire": "rf",
23
+ "LDA": "lda",
24
+ }
25
+
26
+
27
+ def classifier_path(model_name: str) -> str:
28
+ return os.path.join(MODEL_DIR, f"{model_name}.joblib")
29
+
30
+
31
+ def meta_path(model_name: str) -> str:
32
+ return os.path.join(META_DIR, f"{model_name}.json")
33
+
34
+
35
+ def build_pipeline(clf_type: str, **params) -> Pipeline:
36
+ key = CLF_TYPE_MAP.get(clf_type, clf_type)
37
+
38
+ if key == "svm":
39
+ clf = SVC(
40
+ C=params.get("C", 1.0),
41
+ kernel=params.get("kernel", "rbf"),
42
+ gamma=params.get("gamma", "scale"),
43
+ probability=True,
44
+ random_state=42,
45
+ )
46
+ elif key == "logreg":
47
+ clf = LogisticRegression(
48
+ C=params.get("C", 1.0),
49
+ max_iter=params.get("max_iter", 1000),
50
+ random_state=42,
51
+ )
52
+ elif key == "knn":
53
+ clf = KNeighborsClassifier(
54
+ n_neighbors=params.get("n_neighbors", 5),
55
+ metric=params.get("metric", "euclidean"),
56
+ )
57
+ elif key == "rf":
58
+ max_depth = params.get("max_depth") or None
59
+ clf = RandomForestClassifier(
60
+ n_estimators=params.get("n_estimators", 100),
61
+ max_depth=max_depth,
62
+ random_state=42,
63
+ n_jobs=-1,
64
+ )
65
+ elif key == "lda":
66
+ clf = LinearDiscriminantAnalysis(solver=params.get("solver", "svd"))
67
+ else:
68
+ raise ValueError(f"Classifieur inconnu : {clf_type}")
69
+
70
+ return Pipeline([("scaler", StandardScaler()), ("clf", clf)])
71
+
72
+
73
+ def train_classical_model(
74
+ clf_type: str,
75
+ features_cache: dict,
76
+ class_names: List[str],
77
+ model_tag: str = "",
78
+ **params,
79
+ ):
80
+ X_train = features_cache["train"]["X"]
81
+ y_train = features_cache["train"]["y"]
82
+ X_test = features_cache["test"]["X"]
83
+ y_test = features_cache["test"]["y"]
84
+
85
+ pipeline = build_pipeline(clf_type, **params)
86
+ pipeline.fit(X_train, y_train)
87
+
88
+ y_pred = pipeline.predict(X_test)
89
+ metrics = compute_classification_metrics(y_test.tolist(), y_pred.tolist(), class_names)
90
+
91
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
92
+ safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else CLF_TYPE_MAP.get(clf_type, "clf")
93
+ model_name = f"{safe_tag}_{timestamp}"
94
+
95
+ joblib.dump(pipeline, classifier_path(model_name))
96
+ cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
97
+
98
+ config_dict = {
99
+ "model_type": CLF_TYPE_MAP.get(clf_type, clf_type),
100
+ "clf_type_label": clf_type,
101
+ "class_names": class_names,
102
+ "num_classes": len(class_names),
103
+ **{k: v for k, v in params.items() if v is not None},
104
+ }
105
+
106
+ training_summary = {
107
+ "test_accuracy": metrics["accuracy"],
108
+ "test_f1_macro": metrics["f1_macro"],
109
+ "test_f1_weighted": metrics["f1_weighted"],
110
+ "train_samples": int(len(X_train)),
111
+ "test_samples": int(len(X_test)),
112
+ }
113
+
114
+ with open(meta_path(model_name), "w", encoding="utf-8") as f:
115
+ json.dump(
116
+ {
117
+ "model_name": model_name,
118
+ "config": config_dict,
119
+ "training_summary": training_summary,
120
+ "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
121
+ },
122
+ f,
123
+ indent=2,
124
+ ensure_ascii=False,
125
+ )
126
+
127
+ return {
128
+ "model_name": model_name,
129
+ "summary": training_summary,
130
+ "classification_report": metrics["classification_report"],
131
+ "confusion_matrix": metrics["confusion_matrix"],
132
+ "confusion_matrix_path": cm_path,
133
+ }
134
+
135
+
136
+ def load_classical_pipeline(model_name: str) -> Pipeline:
137
+ path = classifier_path(model_name)
138
+ if not os.path.exists(path):
139
+ raise FileNotFoundError(f"Classifieur introuvable : {model_name}")
140
+ return joblib.load(path)
config.py CHANGED
@@ -10,13 +10,13 @@ os.makedirs(MODEL_DIR, exist_ok=True)
10
  os.makedirs(META_DIR, exist_ok=True)
11
  os.makedirs(FIGURE_DIR, exist_ok=True)
12
 
13
- # Replace this with your real private dataset repo
14
  HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "CircleStar/charcoal-microscopy")
15
-
16
- # Must be added in Hugging Face Space Settings → Secrets
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
 
19
  IMAGE_SIZE = 224
20
  RANDOM_SEED = 42
21
 
22
- DATASET_DISPLAY_NAME = "Images microscopiques de charbons de bois"
 
 
 
10
  os.makedirs(META_DIR, exist_ok=True)
11
  os.makedirs(FIGURE_DIR, exist_ok=True)
12
 
 
13
  HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "CircleStar/charcoal-microscopy")
14
+ HF_BACKBONE_REPO = os.environ.get("HF_BACKBONE_REPO", "")
 
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
 
17
  IMAGE_SIZE = 224
18
  RANDOM_SEED = 42
19
 
20
+ DATASET_DISPLAY_NAME = "Images microscopiques de charbons de bois"
21
+
22
+ CLASSICAL_MODEL_TYPES = frozenset({"svm", "logreg", "knn", "rf", "lda"})
data_utils.py CHANGED
@@ -24,7 +24,7 @@ class HFDatasetWrapper(Dataset):
24
 
25
  def __len__(self):
26
  return len(self.dataset)
27
-
28
  def __getitem__(self, idx):
29
  item = self.dataset[idx]
30
 
 
24
 
25
  def __len__(self):
26
  return len(self.dataset)
27
+ ·
28
  def __getitem__(self, idx):
29
  item = self.dataset[idx]
30
 
finetune_backbone.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ finetune_backbone.py
3
+
4
+ Fine-tune ResNet18 (ImageNet) on the local charcoal microscopy dataset.
5
+ Goal: produce a domain-adapted backbone for students to use as a frozen
6
+ feature extractor. The full dataset is used intentionally — this is a
7
+ teaching artifact, not a research model with a held-out test split.
8
+
9
+ Output (in backbone/):
10
+ resnet18_charcoal_backbone.pt — backbone weights, FC replaced by Identity
11
+ backbone_meta.json — class names, feature dim, training info
12
+
13
+ Usage:
14
+ python finetune_backbone.py
15
+ python finetune_backbone.py --epochs 40 --batch-size 16
16
+ """
17
+
18
+ import argparse
19
+ import json
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.optim as optim
26
+ from PIL import Image
27
+ from torch.utils.data import DataLoader, Dataset
28
+ from torchvision import models, transforms
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Paths
32
+ # ---------------------------------------------------------------------------
33
+ ROOT = Path(__file__).parent
34
+ DATA_DIR = ROOT / "data"
35
+ OUTPUT_DIR = ROOT / "backbone"
36
+ OUTPUT_DIR.mkdir(exist_ok=True)
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Defaults
40
+ # ---------------------------------------------------------------------------
41
+ IMAGE_SIZE = 224
42
+ SEED = 42
43
+
44
+ WARMUP_EPOCHS = 10 # backbone frozen, only FC trained
45
+ WARMUP_LR = 1e-3
46
+
47
+ FINETUNE_EPOCHS = 40 # all layers unfrozen, small LR
48
+ FINETUNE_LR = 5e-5
49
+ WEIGHT_DECAY = 1e-4
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Dataset
54
+ # ---------------------------------------------------------------------------
55
+ class CharcoalDataset(Dataset):
56
+ """Flat ImageFolder-style dataset that handles .tif files."""
57
+
58
+ EXTENSIONS = {".tif", ".tiff", ".jpg", ".jpeg", ".png"}
59
+
60
+ def __init__(self, root: Path, transform=None):
61
+ self.transform = transform
62
+ self.classes = sorted(
63
+ d.name for d in root.iterdir()
64
+ if d.is_dir() and not d.name.startswith(".")
65
+ )
66
+ self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
67
+
68
+ self.samples = []
69
+ for cls in self.classes:
70
+ for p in sorted((root / cls).iterdir()):
71
+ if p.suffix.lower() in self.EXTENSIONS:
72
+ self.samples.append((p, self.class_to_idx[cls]))
73
+
74
+ def __len__(self):
75
+ return len(self.samples)
76
+
77
+ def __getitem__(self, idx):
78
+ path, label = self.samples[idx]
79
+ image = Image.open(path).convert("RGB")
80
+ if self.transform:
81
+ image = self.transform(image)
82
+ return image, label
83
+
84
+
85
+ def make_transform():
86
+ # Aggressive augmentation: microscopy images have no canonical orientation
87
+ # and vary in staining intensity.
88
+ return transforms.Compose([
89
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
90
+ transforms.RandomHorizontalFlip(),
91
+ transforms.RandomVerticalFlip(),
92
+ transforms.RandomRotation(180),
93
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
94
+ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.85, 1.15)),
95
+ transforms.ToTensor(),
96
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
97
+ ])
98
+
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # Training helpers
102
+ # ---------------------------------------------------------------------------
103
+ def run_epoch(model, loader, criterion, optimizer, device):
104
+ model.train()
105
+ total_loss, correct, total = 0.0, 0, 0
106
+
107
+ for images, labels in loader:
108
+ images, labels = images.to(device), labels.to(device)
109
+
110
+ optimizer.zero_grad()
111
+ outputs = model(images)
112
+ loss = criterion(outputs, labels)
113
+ loss.backward()
114
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
115
+ optimizer.step()
116
+
117
+ total_loss += loss.item() * images.size(0)
118
+ correct += (outputs.argmax(1) == labels).sum().item()
119
+ total += labels.size(0)
120
+
121
+ return total_loss / total, correct / total
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Main
126
+ # ---------------------------------------------------------------------------
127
+ def main():
128
+ parser = argparse.ArgumentParser()
129
+ parser.add_argument("--warmup-epochs", type=int, default=WARMUP_EPOCHS)
130
+ parser.add_argument("--finetune-epochs", type=int, default=FINETUNE_EPOCHS)
131
+ parser.add_argument("--batch-size", type=int, default=8)
132
+ parser.add_argument("--warmup-lr", type=float, default=WARMUP_LR)
133
+ parser.add_argument("--finetune-lr", type=float, default=FINETUNE_LR)
134
+ args = parser.parse_args()
135
+
136
+ torch.manual_seed(SEED)
137
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
138
+ print(f"Device : {device}")
139
+
140
+ dataset = CharcoalDataset(DATA_DIR, transform=make_transform())
141
+ num_classes = len(dataset.classes)
142
+ print(f"Classes : {num_classes} | Images : {len(dataset)}")
143
+ print(f" {', '.join(dataset.classes)}\n")
144
+
145
+ loader = DataLoader(
146
+ dataset,
147
+ batch_size=args.batch_size,
148
+ shuffle=True,
149
+ num_workers=0, # 0 = safe on Windows
150
+ pin_memory=(device.type == "cuda"),
151
+ )
152
+
153
+ # -----------------------------------------------------------------------
154
+ # Build model
155
+ # -----------------------------------------------------------------------
156
+ model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
157
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
158
+ model.to(device)
159
+
160
+ # Label smoothing helps regularise with tiny datasets
161
+ criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
162
+
163
+ # -----------------------------------------------------------------------
164
+ # Phase 1 — warm-up: freeze backbone, train FC only
165
+ # -----------------------------------------------------------------------
166
+ print(f"=== Phase 1 : warm-up ({args.warmup_epochs} epochs, backbone frozen) ===")
167
+ for p in model.parameters():
168
+ p.requires_grad = False
169
+ for p in model.fc.parameters():
170
+ p.requires_grad = True
171
+
172
+ optimizer = optim.AdamW(model.fc.parameters(), lr=args.warmup_lr, weight_decay=WEIGHT_DECAY)
173
+
174
+ for epoch in range(1, args.warmup_epochs + 1):
175
+ loss, acc = run_epoch(model, loader, criterion, optimizer, device)
176
+ print(f" [{epoch:>3}/{args.warmup_epochs}] loss={loss:.4f} acc={acc:.4f}")
177
+
178
+ # -----------------------------------------------------------------------
179
+ # Phase 2 — full fine-tune: unfreeze all layers
180
+ # -----------------------------------------------------------------------
181
+ print(f"\n=== Phase 2 : fine-tune ({args.finetune_epochs} epochs, all layers) ===")
182
+ for p in model.parameters():
183
+ p.requires_grad = True
184
+
185
+ optimizer = optim.AdamW(
186
+ model.parameters(), lr=args.finetune_lr, weight_decay=WEIGHT_DECAY
187
+ )
188
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
189
+ optimizer, T_max=args.finetune_epochs, eta_min=args.finetune_lr * 0.05
190
+ )
191
+
192
+ best_acc = 0.0
193
+ best_state = None
194
+ t0 = time.time()
195
+
196
+ for epoch in range(1, args.finetune_epochs + 1):
197
+ loss, acc = run_epoch(model, loader, criterion, optimizer, device)
198
+ scheduler.step()
199
+ lr = optimizer.param_groups[0]["lr"]
200
+ print(f" [{epoch:>3}/{args.finetune_epochs}] loss={loss:.4f} acc={acc:.4f} lr={lr:.2e}")
201
+
202
+ if acc > best_acc:
203
+ best_acc = acc
204
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
205
+
206
+ elapsed = time.time() - t0
207
+ print(f"\nTemps phase 2 : {elapsed:.0f}s | Meilleure accuracy entraînement : {best_acc:.4f}")
208
+
209
+ # -----------------------------------------------------------------------
210
+ # Save backbone (FC replaced by Identity — outputs 512-dim feature vector)
211
+ # -----------------------------------------------------------------------
212
+ model.load_state_dict(best_state)
213
+
214
+ backbone = models.resnet18()
215
+ backbone.fc = nn.Identity()
216
+
217
+ # Transfer all weights except fc (which is now Identity with no parameters)
218
+ backbone_state = {k: v for k, v in best_state.items() if not k.startswith("fc.")}
219
+ backbone.load_state_dict(backbone_state, strict=False)
220
+
221
+ backbone_path = OUTPUT_DIR / "resnet18_charcoal_backbone.pt"
222
+ torch.save(backbone.state_dict(), backbone_path)
223
+ print(f"Backbone sauvegardé : {backbone_path}")
224
+
225
+ # -----------------------------------------------------------------------
226
+ # Save metadata
227
+ # -----------------------------------------------------------------------
228
+ meta = {
229
+ "classes": dataset.classes,
230
+ "num_classes": num_classes,
231
+ "image_size": IMAGE_SIZE,
232
+ "feature_dim": 512,
233
+ "warmup_epochs": args.warmup_epochs,
234
+ "finetune_epochs": args.finetune_epochs,
235
+ "best_train_acc": round(float(best_acc), 4),
236
+ "device": str(device),
237
+ }
238
+ meta_path = OUTPUT_DIR / "backbone_meta.json"
239
+ with open(meta_path, "w", encoding="utf-8") as f:
240
+ json.dump(meta, f, indent=2, ensure_ascii=False)
241
+ print(f"Métadonnées sauvegardées : {meta_path}")
242
+
243
+
244
+ if __name__ == "__main__":
245
+ main()
model.py CHANGED
@@ -1,33 +1,22 @@
1
  import torch.nn as nn
2
- from torchvision import models
3
 
4
 
5
- class ResNet18Classifier(nn.Module):
6
- def __init__(self, num_classes: int, dropout: float = 0.4, fc_dim: int = 256):
7
- super().__init__()
8
-
9
- weights = models.ResNet18_Weights.DEFAULT
10
- self.backbone = models.resnet18(weights=weights)
11
- in_features = self.backbone.fc.in_features
12
-
13
- # Gel de tout le réseau sauf layer4 et classifieur
14
- for param in self.backbone.parameters():
15
- param.requires_grad = False
16
- for param in self.backbone.layer4.parameters():
17
- param.requires_grad = True
18
 
19
- self.backbone.fc = nn.Sequential(
 
 
 
20
  nn.Dropout(dropout),
21
- nn.Linear(in_features, fc_dim),
22
- nn.ReLU(),
23
  nn.Dropout(dropout),
24
  nn.Linear(fc_dim, num_classes),
25
  )
26
- for param in self.backbone.fc.parameters():
27
- param.requires_grad = True
28
 
29
  def forward(self, x):
30
- return self.backbone(x)
31
 
32
 
33
  class SimpleCNN(nn.Module):
@@ -48,7 +37,6 @@ class SimpleCNN(nn.Module):
48
  in_channels = 3
49
 
50
  for i in range(num_conv_blocks):
51
- # Les filtres doublent à chaque bloc, plafonnés à 512
52
  out_channels = min(base_filters * (2 ** i), 512)
53
  layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
54
  if use_batchnorm:
@@ -58,7 +46,6 @@ class SimpleCNN(nn.Module):
58
  in_channels = out_channels
59
 
60
  self.features = nn.Sequential(*layers)
61
- # Pooling global : indépendant de la taille spatiale d'entrée
62
  self.pool = nn.AdaptiveAvgPool2d(1)
63
 
64
  self.classifier = nn.Sequential(
@@ -70,7 +57,5 @@ class SimpleCNN(nn.Module):
70
  )
71
 
72
  def forward(self, x):
73
- x = self.features(x)
74
- x = self.pool(x)
75
- x = x.flatten(1)
76
- return self.classifier(x)
 
1
  import torch.nn as nn
 
2
 
3
 
4
+ class BackboneWithFC(nn.Module):
5
+ """Frozen ResNet18 backbone + trainable FC classifier head."""
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ def __init__(self, backbone: nn.Module, num_classes: int, dropout: float = 0.4, fc_dim: int = 256):
8
+ super().__init__()
9
+ self.backbone = backbone
10
+ self.classifier = nn.Sequential(
11
  nn.Dropout(dropout),
12
+ nn.Linear(512, fc_dim),
13
+ nn.ReLU(inplace=True),
14
  nn.Dropout(dropout),
15
  nn.Linear(fc_dim, num_classes),
16
  )
 
 
17
 
18
  def forward(self, x):
19
+ return self.classifier(self.backbone(x))
20
 
21
 
22
  class SimpleCNN(nn.Module):
 
37
  in_channels = 3
38
 
39
  for i in range(num_conv_blocks):
 
40
  out_channels = min(base_filters * (2 ** i), 512)
41
  layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
42
  if use_batchnorm:
 
46
  in_channels = out_channels
47
 
48
  self.features = nn.Sequential(*layers)
 
49
  self.pool = nn.AdaptiveAvgPool2d(1)
50
 
51
  self.classifier = nn.Sequential(
 
57
  )
58
 
59
  def forward(self, x):
60
+ x = self.pool(self.features(x))
61
+ return self.classifier(x.flatten(1))
 
 
predict_utils.py CHANGED
@@ -1,41 +1,56 @@
1
  import random
2
 
 
3
  import torch
4
  from PIL import Image
5
 
 
6
  from data_utils import get_eval_transform, prepare_splits, get_class_names
7
- from train_utils import load_model, get_runtime_device
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def predict_uploaded_image(model_name: str, image: Image.Image):
11
  if not model_name:
12
  return "Veuillez sélectionner un modèle.", None
13
-
14
  if image is None:
15
  return "Veuillez importer une image.", None
16
 
17
- device = get_runtime_device()
18
- model, meta = load_model(model_name, device)
19
-
20
  class_names = meta["config"]["class_names"]
21
- transform = get_eval_transform()
22
-
23
- image = image.convert("RGB")
24
- tensor = transform(image).unsqueeze(0).to(device)
25
 
26
- with torch.no_grad():
27
- logits = model(tensor)
28
- probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
29
- pred_idx = int(torch.argmax(logits, dim=1).item())
 
 
 
 
 
 
 
 
 
30
 
31
  result_text = (
32
  f"Prédiction : {class_names[pred_idx]}\n"
33
  f"Confiance : {max(probs):.4f}\n\n"
34
  f"Modèle : {model_name}\n"
35
- f"Jeu de données : {meta['config']['dataset_name']}\n"
36
- f"Appareil utilisé : {device}"
37
  )
38
-
39
  prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
40
  return result_text, prob_dict
41
 
@@ -44,40 +59,44 @@ def test_random_sample(model_name: str):
44
  if not model_name:
45
  return None, "Veuillez sélectionner un modèle.", None
46
 
 
 
 
47
  device = get_runtime_device()
48
- model, meta = load_model(model_name, device)
49
 
50
  splits = prepare_splits()
51
- class_names = get_class_names()
52
  test_dataset = splits["test"]
53
 
54
  idx = random.randint(0, len(test_dataset) - 1)
55
  item = test_dataset[idx]
56
-
57
  image = item["image"]
58
  if not isinstance(image, Image.Image):
59
  image = Image.open(image)
60
-
61
  image = image.convert("RGB")
62
-
63
- label = int(item["label"])
64
- label_name = class_names[label]
65
-
66
- transform = get_eval_transform()
67
- tensor = transform(image).unsqueeze(0).to(device)
68
-
69
- with torch.no_grad():
70
- logits = model(tensor)
71
- probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
72
- pred_idx = int(torch.argmax(logits, dim=1).item())
73
-
 
 
 
 
 
74
  result_text = (
75
  f"Échantillon test aléatoire\n"
76
  f"Vérité terrain : {label_name}\n"
77
- f"Prédiction : {class_names[pred_idx]}\n"
78
- f"Confiance : {max(probs):.4f}\n"
79
- f"Appareil utilisé : {device}"
 
80
  )
81
-
82
- prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
83
- return image, result_text, prob_dict
 
1
  import random
2
 
3
+ import numpy as np
4
  import torch
5
  from PIL import Image
6
 
7
+ from config import CLASSICAL_MODEL_TYPES
8
  from data_utils import get_eval_transform, prepare_splits, get_class_names
9
+ from train_utils import load_model, get_runtime_device, _load_meta
10
+
11
+
12
+ def _extract_feature(image: Image.Image, device: torch.device) -> np.ndarray:
13
+ from backbone_utils import load_backbone
14
+ backbone = load_backbone(device)
15
+ backbone.eval()
16
+ tensor = get_eval_transform()(image.convert("RGB")).unsqueeze(0).to(device)
17
+ with torch.no_grad():
18
+ feat = backbone(tensor)
19
+ return feat.cpu().numpy()
20
 
21
 
22
  def predict_uploaded_image(model_name: str, image: Image.Image):
23
  if not model_name:
24
  return "Veuillez sélectionner un modèle.", None
 
25
  if image is None:
26
  return "Veuillez importer une image.", None
27
 
28
+ meta = _load_meta(model_name)
29
+ model_type = meta["config"].get("model_type", "cnn")
 
30
  class_names = meta["config"]["class_names"]
31
+ device = get_runtime_device()
 
 
 
32
 
33
+ if model_type in CLASSICAL_MODEL_TYPES:
34
+ from classical_ml_utils import load_classical_pipeline
35
+ pipeline = load_classical_pipeline(model_name)
36
+ feat = _extract_feature(image, device)
37
+ probs = pipeline.predict_proba(feat)[0].tolist()
38
+ pred_idx = int(np.argmax(probs))
39
+ else:
40
+ model, _ = load_model(model_name, device)
41
+ tensor = get_eval_transform()(image.convert("RGB")).unsqueeze(0).to(device)
42
+ with torch.no_grad():
43
+ logits = model(tensor)
44
+ probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
45
+ pred_idx = int(torch.argmax(logits, dim=1).item())
46
 
47
  result_text = (
48
  f"Prédiction : {class_names[pred_idx]}\n"
49
  f"Confiance : {max(probs):.4f}\n\n"
50
  f"Modèle : {model_name}\n"
51
+ f"Type : {model_type}\n"
52
+ f"Appareil : {device}"
53
  )
 
54
  prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
55
  return result_text, prob_dict
56
 
 
59
  if not model_name:
60
  return None, "Veuillez sélectionner un modèle.", None
61
 
62
+ meta = _load_meta(model_name)
63
+ model_type = meta["config"].get("model_type", "cnn")
64
+ class_names = get_class_names()
65
  device = get_runtime_device()
 
66
 
67
  splits = prepare_splits()
 
68
  test_dataset = splits["test"]
69
 
70
  idx = random.randint(0, len(test_dataset) - 1)
71
  item = test_dataset[idx]
 
72
  image = item["image"]
73
  if not isinstance(image, Image.Image):
74
  image = Image.open(image)
 
75
  image = image.convert("RGB")
76
+ label_name = class_names[int(item["label"])]
77
+
78
+ if model_type in CLASSICAL_MODEL_TYPES:
79
+ from classical_ml_utils import load_classical_pipeline
80
+ pipeline = load_classical_pipeline(model_name)
81
+ feat = _extract_feature(image, device)
82
+ probs = pipeline.predict_proba(feat)[0].tolist()
83
+ pred_idx = int(np.argmax(probs))
84
+ else:
85
+ model, _ = load_model(model_name, device)
86
+ tensor = get_eval_transform()(image).unsqueeze(0).to(device)
87
+ with torch.no_grad():
88
+ logits = model(tensor)
89
+ probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
90
+ pred_idx = int(torch.argmax(logits, dim=1).item())
91
+
92
+ model_class_names = meta["config"]["class_names"]
93
  result_text = (
94
  f"Échantillon test aléatoire\n"
95
  f"Vérité terrain : {label_name}\n"
96
+ f"Prédiction : {model_class_names[pred_idx]}\n"
97
+ f"Confiance : {max(probs):.4f}\n"
98
+ f"Type modèle : {model_type}\n"
99
+ f"Appareil : {device}"
100
  )
101
+ prob_dict = {model_class_names[i]: float(probs[i]) for i in range(len(model_class_names))}
102
+ return image, result_text, prob_dict
 
train_utils.py CHANGED
@@ -8,16 +8,24 @@ import torch
8
  import torch.nn as nn
9
  import torch.optim as optim
10
 
11
- from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
12
  from data_utils import make_loaders
13
  from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
- from model import SimpleCNN, ResNet18Classifier
15
 
16
 
 
 
 
 
17
  def model_weight_path(model_name: str) -> str:
18
  return os.path.join(MODEL_DIR, f"{model_name}.pt")
19
 
20
 
 
 
 
 
21
  def model_meta_path(model_name: str) -> str:
22
  return os.path.join(META_DIR, f"{model_name}.json")
23
 
@@ -34,43 +42,54 @@ def get_runtime_device() -> torch.device:
34
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
 
36
 
 
 
 
 
37
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
38
- cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
39
- torch.save(cpu_state_dict, model_weight_path(model_name))
 
 
40
 
41
- payload = {
42
- "model_name": model_name,
43
- "config": config,
44
- "training_summary": training_summary,
45
- "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
46
- }
47
 
48
  with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
49
- json.dump(payload, f, indent=2, ensure_ascii=False)
50
-
51
-
52
- def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
53
- meta_file = model_meta_path(model_name)
54
- weight_file = model_weight_path(model_name)
 
 
 
 
 
55
 
56
- if not os.path.exists(meta_file):
57
- raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
58
 
59
- if not os.path.exists(weight_file):
60
- raise FileNotFoundError(f"Poids introuvables pour le modèle : {model_name}")
 
 
 
 
61
 
62
- with open(meta_file, "r", encoding="utf-8") as f:
63
- meta = json.load(f)
64
 
 
 
65
  cfg = meta["config"]
66
-
67
- if cfg.get("model_type", "cnn") == "resnet18":
68
- model = ResNet18Classifier(
69
- num_classes=cfg["num_classes"],
70
- dropout=cfg.get("dropout", 0.4),
71
- fc_dim=cfg.get("fc_dim", 256),
 
 
72
  )
73
- else:
 
74
  model = SimpleCNN(
75
  num_classes=cfg["num_classes"],
76
  num_conv_blocks=cfg.get("num_conv_blocks", 3),
@@ -80,206 +99,264 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
80
  dropout=cfg.get("dropout", 0.4),
81
  fc_dim=cfg.get("fc_dim", 256),
82
  )
 
 
 
 
83
 
84
- state_dict = torch.load(weight_file, map_location="cpu")
85
- model.load_state_dict(state_dict)
86
  model.to(device)
87
  model.eval()
88
-
89
  return model, meta
90
 
91
 
 
 
 
 
92
  def evaluate_loss_acc(model, loader, criterion, device):
93
  model.eval()
94
-
95
- total_loss = 0.0
96
- total = 0
97
- correct = 0
98
 
99
  with torch.no_grad():
100
  for images, labels in loader:
101
  images, labels = images.to(device), labels.to(device)
102
-
103
  outputs = model(images)
104
  loss = criterion(outputs, labels)
105
-
106
  total_loss += loss.item() * images.size(0)
107
- preds = outputs.argmax(dim=1)
108
-
109
- correct += (preds == labels).sum().item()
110
  total += labels.size(0)
111
 
112
- avg_loss = total_loss / total if total else 0.0
113
- acc = correct / total if total else 0.0
114
-
115
- return avg_loss, acc
116
 
117
 
118
  def collect_predictions(model, loader, device):
119
  model.eval()
120
-
121
- y_true = []
122
- y_pred = []
123
 
124
  with torch.no_grad():
125
  for images, labels in loader:
126
- images = images.to(device)
127
-
128
- outputs = model(images)
129
- preds = outputs.argmax(dim=1).detach().cpu().tolist()
130
-
131
- y_pred.extend(preds)
132
  y_true.extend(labels.tolist())
133
 
134
  return y_true, y_pred
135
 
136
 
137
- def train_model(
138
- model_type: str = "cnn",
139
- num_conv_blocks: int = 3,
140
- base_filters: int = 32,
141
- kernel_size: int = 3,
142
- use_batchnorm: bool = True,
143
- dropout: float = 0.4,
144
- fc_dim: int = 256,
145
- learning_rate: float = 0.001,
146
- weight_decay: float = 0.0001,
147
- batch_size: int = 16,
148
- epochs: int = 30,
149
- model_tag: str = "",
150
- ):
151
- device = get_runtime_device()
152
-
153
- train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
154
- num_classes = len(class_names)
155
-
156
- if model_type == "resnet18":
157
- model = ResNet18Classifier(
158
- num_classes=num_classes,
159
- dropout=dropout,
160
- fc_dim=fc_dim,
161
- ).to(device)
162
- else:
163
- model = SimpleCNN(
164
- num_classes=num_classes,
165
- num_conv_blocks=num_conv_blocks,
166
- base_filters=base_filters,
167
- kernel_size=kernel_size,
168
- use_batchnorm=use_batchnorm,
169
- dropout=dropout,
170
- fc_dim=fc_dim,
171
- ).to(device)
172
-
173
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
174
- total_params = sum(p.numel() for p in model.parameters())
175
-
176
- criterion = nn.CrossEntropyLoss()
177
-
178
- optimizer = optim.AdamW(
179
- filter(lambda p: p.requires_grad, model.parameters()),
180
- lr=learning_rate,
181
- weight_decay=weight_decay,
182
- )
183
-
184
- # Réduit le LR de moitié si val_loss ne s'améliore pas pendant 8 époques
185
- # patience élevée car le val set est très petit (bruit important)
186
- scheduler = optim.lr_scheduler.ReduceLROnPlateau(
187
- optimizer,
188
- mode="min",
189
- factor=0.5,
190
- patience=8,
191
- min_lr=learning_rate * 0.2,
192
- )
193
-
194
  history = []
195
  logs = []
196
- start_time = time.time()
197
-
198
  best_val_loss = float("inf")
199
- best_state_dict = None
200
 
201
  for epoch in range(1, epochs + 1):
202
  model.train()
203
-
204
- running_loss = 0.0
205
- total = 0
206
- correct = 0
207
 
208
  for images, labels in train_loader:
209
  images, labels = images.to(device), labels.to(device)
210
-
211
  optimizer.zero_grad()
212
  outputs = model(images)
213
-
214
  loss = criterion(outputs, labels)
215
  loss.backward()
216
-
217
- # Important: prevents unstable fine-tuning / exploding gradients
218
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
219
-
220
  optimizer.step()
221
 
222
  running_loss += loss.item() * images.size(0)
223
-
224
- preds = outputs.argmax(dim=1)
225
- correct += (preds == labels).sum().item()
226
  total += labels.size(0)
227
 
228
  train_loss = running_loss / total if total else 0.0
229
  train_acc = correct / total if total else 0.0
230
-
231
  val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
232
  scheduler.step(val_loss)
233
  current_lr = optimizer.param_groups[0]["lr"]
234
 
235
  if val_loss < best_val_loss:
236
  best_val_loss = val_loss
237
- best_state_dict = {
238
- k: v.detach().cpu().clone()
239
- for k, v in model.state_dict().items()
240
- }
241
 
242
- row = {
243
  "epoch": epoch,
244
  "train_loss": round(train_loss, 4),
245
  "train_acc": round(train_acc, 4),
246
  "val_loss": round(val_loss, 4),
247
  "val_acc": round(val_acc, 4),
248
- }
249
-
250
- history.append(row)
251
-
252
  logs.append(
253
  f"Époque {epoch}/{epochs} | "
254
- f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
255
- f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}, "
256
- f"lr={current_lr:.6f}"
257
  )
258
 
259
- if best_state_dict is not None:
260
- model.load_state_dict(best_state_dict)
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
263
  y_true, y_pred = collect_predictions(model, test_loader, device)
264
-
265
  metrics = compute_classification_metrics(y_true, y_pred, class_names)
 
266
 
267
- elapsed = time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
270
- safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "charcoal_resnet18"
271
  model_name = f"{safe_tag}_{timestamp}"
272
 
273
  cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
274
 
275
- if model_type == "resnet18":
276
- architecture = "ResNet18 pré-entraîné (layer4 + classifieur)"
277
- else:
278
- architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}x{kernel_size})"
279
 
280
  config = {
281
  "dataset_name": DATASET_DISPLAY_NAME,
282
- "model_type": model_type,
283
  "architecture": architecture,
284
  "num_classes": num_classes,
285
  "class_names": class_names,
@@ -313,18 +390,16 @@ def train_model(
313
 
314
  save_model(model, model_name, config, training_summary)
315
 
316
- logs.append("")
317
- logs.append("Entraînement terminé.")
318
- logs.append(f"Modèle sauvegardé : {model_name}")
319
- logs.append(f"Appareil utilisé : {device}")
320
- logs.append(f"Architecture : {architecture}")
321
- logs.append(f"Nombre total de paramètres : {total_params}")
322
- logs.append(f"Paramètres entraînables : {trainable_params}")
323
- logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
324
- logs.append(f"Accuracy test : {test_acc:.4f}")
325
- logs.append(f"F1 macro test : {metrics['f1_macro']:.4f}")
326
- logs.append(f"F1 pondéré test : {metrics['f1_weighted']:.4f}")
327
- logs.append(f"Temps écoulé : {elapsed:.1f}s")
328
 
329
  return {
330
  "logs": "\n".join(logs),
@@ -337,10 +412,24 @@ def train_model(
337
  }
338
 
339
 
 
 
 
 
340
  def evaluate_saved_model(model_name: str):
341
  if not model_name:
342
  raise ValueError("Aucun modèle sélectionné.")
343
 
 
 
 
 
 
 
 
 
 
 
344
  device = get_runtime_device()
345
  model, meta = load_model(model_name, device)
346
 
@@ -348,19 +437,51 @@ def evaluate_saved_model(model_name: str):
348
  _, _, test_loader, class_names = make_loaders(batch_size)
349
 
350
  criterion = nn.CrossEntropyLoss()
351
-
352
  test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
353
  y_true, y_pred = collect_predictions(model, test_loader, device)
354
 
355
  metrics = compute_classification_metrics(y_true, y_pred, class_names)
356
  cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
357
 
358
- summary = {
359
- "test_cross_entropy_loss": round(test_loss, 4),
360
- "test_accuracy": round(test_acc, 4),
361
- "test_f1_macro": metrics["f1_macro"],
362
- "test_f1_weighted": metrics["f1_weighted"],
363
- "device": str(device),
364
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
- return summary, metrics["classification_report"], metrics["confusion_matrix"], cm_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import torch.nn as nn
9
  import torch.optim as optim
10
 
11
+ from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME, CLASSICAL_MODEL_TYPES
12
  from data_utils import make_loaders
13
  from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
+ from model import SimpleCNN, BackboneWithFC
15
 
16
 
17
+ # ---------------------------------------------------------------------------
18
+ # Path helpers
19
+ # ---------------------------------------------------------------------------
20
+
21
  def model_weight_path(model_name: str) -> str:
22
  return os.path.join(MODEL_DIR, f"{model_name}.pt")
23
 
24
 
25
+ def classifier_weight_path(model_name: str) -> str:
26
+ return os.path.join(MODEL_DIR, f"{model_name}.joblib")
27
+
28
+
29
  def model_meta_path(model_name: str) -> str:
30
  return os.path.join(META_DIR, f"{model_name}.json")
31
 
 
42
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
 
45
+ # ---------------------------------------------------------------------------
46
+ # Save / load
47
+ # ---------------------------------------------------------------------------
48
+
49
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
50
+ if config["model_type"] == "fc_head":
51
+ state_dict = {k: v.detach().cpu() for k, v in model.classifier.state_dict().items()}
52
+ else:
53
+ state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
54
 
55
+ torch.save(state_dict, model_weight_path(model_name))
 
 
 
 
 
56
 
57
  with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
58
+ json.dump(
59
+ {
60
+ "model_name": model_name,
61
+ "config": config,
62
+ "training_summary": training_summary,
63
+ "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
64
+ },
65
+ f,
66
+ indent=2,
67
+ ensure_ascii=False,
68
+ )
69
 
 
 
70
 
71
+ def _load_meta(model_name: str) -> dict:
72
+ path = model_meta_path(model_name)
73
+ if not os.path.exists(path):
74
+ raise FileNotFoundError(f"Métadonnées introuvables : {model_name}")
75
+ with open(path, "r", encoding="utf-8") as f:
76
+ return json.load(f)
77
 
 
 
78
 
79
+ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
80
+ meta = _load_meta(model_name)
81
  cfg = meta["config"]
82
+ model_type = cfg.get("model_type", "cnn")
83
+
84
+ if model_type == "fc_head":
85
+ from backbone_utils import load_backbone
86
+ backbone = load_backbone(device)
87
+ model = BackboneWithFC(backbone, cfg["num_classes"], cfg.get("dropout", 0.4), cfg.get("fc_dim", 256))
88
+ model.classifier.load_state_dict(
89
+ torch.load(model_weight_path(model_name), map_location="cpu")
90
  )
91
+
92
+ elif model_type == "cnn":
93
  model = SimpleCNN(
94
  num_classes=cfg["num_classes"],
95
  num_conv_blocks=cfg.get("num_conv_blocks", 3),
 
99
  dropout=cfg.get("dropout", 0.4),
100
  fc_dim=cfg.get("fc_dim", 256),
101
  )
102
+ model.load_state_dict(torch.load(model_weight_path(model_name), map_location="cpu"))
103
+
104
+ else:
105
+ raise ValueError(f"load_model n'accepte pas le type '{model_type}'. Utilisez load_classical_pipeline pour les modèles ML classiques.")
106
 
 
 
107
  model.to(device)
108
  model.eval()
 
109
  return model, meta
110
 
111
 
112
+ # ---------------------------------------------------------------------------
113
+ # Training helpers
114
+ # ---------------------------------------------------------------------------
115
+
116
  def evaluate_loss_acc(model, loader, criterion, device):
117
  model.eval()
118
+ total_loss, total, correct = 0.0, 0, 0
 
 
 
119
 
120
  with torch.no_grad():
121
  for images, labels in loader:
122
  images, labels = images.to(device), labels.to(device)
 
123
  outputs = model(images)
124
  loss = criterion(outputs, labels)
 
125
  total_loss += loss.item() * images.size(0)
126
+ correct += (outputs.argmax(1) == labels).sum().item()
 
 
127
  total += labels.size(0)
128
 
129
+ return (total_loss / total if total else 0.0), (correct / total if total else 0.0)
 
 
 
130
 
131
 
132
  def collect_predictions(model, loader, device):
133
  model.eval()
134
+ y_true, y_pred = [], []
 
 
135
 
136
  with torch.no_grad():
137
  for images, labels in loader:
138
+ outputs = model(images.to(device))
139
+ y_pred.extend(outputs.argmax(1).detach().cpu().tolist())
 
 
 
 
140
  y_true.extend(labels.tolist())
141
 
142
  return y_true, y_pred
143
 
144
 
145
+ def _training_loop(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  history = []
147
  logs = []
 
 
148
  best_val_loss = float("inf")
149
+ best_state = None
150
 
151
  for epoch in range(1, epochs + 1):
152
  model.train()
153
+ running_loss, total, correct = 0.0, 0, 0
 
 
 
154
 
155
  for images, labels in train_loader:
156
  images, labels = images.to(device), labels.to(device)
 
157
  optimizer.zero_grad()
158
  outputs = model(images)
 
159
  loss = criterion(outputs, labels)
160
  loss.backward()
 
 
161
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
162
  optimizer.step()
163
 
164
  running_loss += loss.item() * images.size(0)
165
+ correct += (outputs.argmax(1) == labels).sum().item()
 
 
166
  total += labels.size(0)
167
 
168
  train_loss = running_loss / total if total else 0.0
169
  train_acc = correct / total if total else 0.0
 
170
  val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
171
  scheduler.step(val_loss)
172
  current_lr = optimizer.param_groups[0]["lr"]
173
 
174
  if val_loss < best_val_loss:
175
  best_val_loss = val_loss
176
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
 
 
 
177
 
178
+ history.append({
179
  "epoch": epoch,
180
  "train_loss": round(train_loss, 4),
181
  "train_acc": round(train_acc, 4),
182
  "val_loss": round(val_loss, 4),
183
  "val_acc": round(val_acc, 4),
184
+ })
 
 
 
185
  logs.append(
186
  f"Époque {epoch}/{epochs} | "
187
+ f"perte train={train_loss:.4f} acc train={train_acc:.4f} | "
188
+ f"perte val={val_loss:.4f} acc val={val_acc:.4f} | "
189
+ f"lr={current_lr:.2e}"
190
  )
191
 
192
+ return history, logs, best_state, best_val_loss
 
193
 
194
+
195
+ # ---------------------------------------------------------------------------
196
+ # Train FC head on frozen backbone
197
+ # ---------------------------------------------------------------------------
198
+
199
+ def train_fc_head(
200
+ dropout: float = 0.4,
201
+ fc_dim: int = 256,
202
+ learning_rate: float = 1e-4,
203
+ weight_decay: float = 1e-4,
204
+ batch_size: int = 16,
205
+ epochs: int = 20,
206
+ model_tag: str = "",
207
+ ):
208
+ from backbone_utils import load_backbone
209
+
210
+ device = get_runtime_device()
211
+ train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
212
+ num_classes = len(class_names)
213
+
214
+ backbone = load_backbone(device)
215
+
216
+ model = BackboneWithFC(backbone, num_classes, dropout, fc_dim).to(device)
217
+
218
+ trainable_params = sum(p.numel() for p in model.classifier.parameters())
219
+ total_params = sum(p.numel() for p in model.parameters())
220
+
221
+ criterion = nn.CrossEntropyLoss()
222
+ optimizer = optim.AdamW(model.classifier.parameters(), lr=learning_rate, weight_decay=weight_decay)
223
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
224
+ optimizer, mode="min", factor=0.5, patience=5, min_lr=learning_rate * 0.1
225
+ )
226
+
227
+ t0 = time.time()
228
+ history, logs, best_state, best_val_loss = _training_loop(
229
+ model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device
230
+ )
231
+
232
+ model.load_state_dict(best_state)
233
  test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
234
  y_true, y_pred = collect_predictions(model, test_loader, device)
 
235
  metrics = compute_classification_metrics(y_true, y_pred, class_names)
236
+ elapsed = time.time() - t0
237
 
238
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
239
+ safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "fc_head"
240
+ model_name = f"{safe_tag}_{timestamp}"
241
+
242
+ cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
243
+
244
+ config = {
245
+ "dataset_name": DATASET_DISPLAY_NAME,
246
+ "model_type": "fc_head",
247
+ "architecture": f"ResNet18 backbone (gelé) + FC({fc_dim})",
248
+ "num_classes": num_classes,
249
+ "class_names": class_names,
250
+ "dropout": dropout,
251
+ "fc_dim": fc_dim,
252
+ "learning_rate": learning_rate,
253
+ "weight_decay": weight_decay,
254
+ "batch_size": batch_size,
255
+ "epochs": epochs,
256
+ }
257
+
258
+ training_summary = {
259
+ "final_train_loss": history[-1]["train_loss"] if history else None,
260
+ "final_train_acc": history[-1]["train_acc"] if history else None,
261
+ "best_val_loss": round(best_val_loss, 4),
262
+ "final_val_loss": history[-1]["val_loss"] if history else None,
263
+ "final_val_acc": history[-1]["val_acc"] if history else None,
264
+ "test_cross_entropy_loss": round(test_loss, 4),
265
+ "test_accuracy": round(test_acc, 4),
266
+ "test_f1_macro": metrics["f1_macro"],
267
+ "test_f1_weighted": metrics["f1_weighted"],
268
+ "elapsed_seconds": round(elapsed, 2),
269
+ "device": str(device),
270
+ "total_params": total_params,
271
+ "trainable_params": trainable_params,
272
+ }
273
+
274
+ save_model(model, model_name, config, training_summary)
275
+
276
+ logs += [
277
+ "",
278
+ "Entraînement terminé.",
279
+ f"Modèle sauvegardé : {model_name}",
280
+ f"Architecture : {config['architecture']}",
281
+ f"Paramètres entraînables : {trainable_params} / {total_params}",
282
+ f"Perte test : {test_loss:.4f} | Accuracy test : {test_acc:.4f}",
283
+ f"F1 macro : {metrics['f1_macro']:.4f} | F1 pondéré : {metrics['f1_weighted']:.4f}",
284
+ f"Temps : {elapsed:.1f}s | Appareil : {device}",
285
+ ]
286
+
287
+ return {
288
+ "logs": "\n".join(logs),
289
+ "history": history,
290
+ "summary": training_summary,
291
+ "model_name": model_name,
292
+ "classification_report": metrics["classification_report"],
293
+ "confusion_matrix": metrics["confusion_matrix"],
294
+ "confusion_matrix_path": cm_path,
295
+ }
296
+
297
+
298
+ # ---------------------------------------------------------------------------
299
+ # Train SimpleCNN from scratch
300
+ # ---------------------------------------------------------------------------
301
+
302
+ def train_cnn(
303
+ num_conv_blocks: int = 3,
304
+ base_filters: int = 32,
305
+ kernel_size: int = 3,
306
+ use_batchnorm: bool = True,
307
+ dropout: float = 0.4,
308
+ fc_dim: int = 256,
309
+ learning_rate: float = 1e-3,
310
+ weight_decay: float = 1e-4,
311
+ batch_size: int = 16,
312
+ epochs: int = 30,
313
+ model_tag: str = "",
314
+ ):
315
+ device = get_runtime_device()
316
+ train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
317
+ num_classes = len(class_names)
318
+
319
+ model = SimpleCNN(
320
+ num_classes=num_classes,
321
+ num_conv_blocks=num_conv_blocks,
322
+ base_filters=base_filters,
323
+ kernel_size=kernel_size,
324
+ use_batchnorm=use_batchnorm,
325
+ dropout=dropout,
326
+ fc_dim=fc_dim,
327
+ ).to(device)
328
+
329
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
330
+ total_params = sum(p.numel() for p in model.parameters())
331
+
332
+ criterion = nn.CrossEntropyLoss()
333
+ optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
334
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
335
+ optimizer, mode="min", factor=0.5, patience=8, min_lr=learning_rate * 0.2
336
+ )
337
+
338
+ t0 = time.time()
339
+ history, logs, best_state, best_val_loss = _training_loop(
340
+ model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device
341
+ )
342
+
343
+ model.load_state_dict(best_state)
344
+ test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
345
+ y_true, y_pred = collect_predictions(model, test_loader, device)
346
+ metrics = compute_classification_metrics(y_true, y_pred, class_names)
347
+ elapsed = time.time() - t0
348
 
349
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
350
+ safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "cnn"
351
  model_name = f"{safe_tag}_{timestamp}"
352
 
353
  cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
354
 
355
+ architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}×{kernel_size})"
 
 
 
356
 
357
  config = {
358
  "dataset_name": DATASET_DISPLAY_NAME,
359
+ "model_type": "cnn",
360
  "architecture": architecture,
361
  "num_classes": num_classes,
362
  "class_names": class_names,
 
390
 
391
  save_model(model, model_name, config, training_summary)
392
 
393
+ logs += [
394
+ "",
395
+ "Entraînement terminé.",
396
+ f"Modèle sauvegardé : {model_name}",
397
+ f"Architecture : {architecture}",
398
+ f"Paramètres : {total_params}",
399
+ f"Perte test : {test_loss:.4f} | Accuracy test : {test_acc:.4f}",
400
+ f"F1 macro : {metrics['f1_macro']:.4f} | F1 pondéré : {metrics['f1_weighted']:.4f}",
401
+ f"Temps : {elapsed:.1f}s | Appareil : {device}",
402
+ ]
 
 
403
 
404
  return {
405
  "logs": "\n".join(logs),
 
412
  }
413
 
414
 
415
+ # ---------------------------------------------------------------------------
416
+ # Evaluate any saved model
417
+ # ---------------------------------------------------------------------------
418
+
419
  def evaluate_saved_model(model_name: str):
420
  if not model_name:
421
  raise ValueError("Aucun modèle sélectionné.")
422
 
423
+ meta = _load_meta(model_name)
424
+ model_type = meta["config"].get("model_type", "cnn")
425
+
426
+ if model_type in CLASSICAL_MODEL_TYPES:
427
+ return _evaluate_classical(model_name, meta)
428
+ else:
429
+ return _evaluate_neural(model_name, meta)
430
+
431
+
432
+ def _evaluate_neural(model_name: str, meta: dict):
433
  device = get_runtime_device()
434
  model, meta = load_model(model_name, device)
435
 
 
437
  _, _, test_loader, class_names = make_loaders(batch_size)
438
 
439
  criterion = nn.CrossEntropyLoss()
 
440
  test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
441
  y_true, y_pred = collect_predictions(model, test_loader, device)
442
 
443
  metrics = compute_classification_metrics(y_true, y_pred, class_names)
444
  cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
445
 
446
+ return (
447
+ {
448
+ "test_cross_entropy_loss": round(test_loss, 4),
449
+ "test_accuracy": round(test_acc, 4),
450
+ "test_f1_macro": metrics["f1_macro"],
451
+ "test_f1_weighted": metrics["f1_weighted"],
452
+ "device": str(device),
453
+ },
454
+ metrics["classification_report"],
455
+ metrics["confusion_matrix"],
456
+ cm_path,
457
+ )
458
+
459
+
460
+ def _evaluate_classical(model_name: str, meta: dict):
461
+ from backbone_utils import get_cached_features, extract_all_features
462
+ from classical_ml_utils import load_classical_pipeline
463
+
464
+ features_cache = get_cached_features()
465
+ if features_cache is None:
466
+ features_cache, _, _ = extract_all_features()
467
 
468
+ class_names = meta["config"]["class_names"]
469
+ pipeline = load_classical_pipeline(model_name)
470
+
471
+ X_test = features_cache["test"]["X"]
472
+ y_test = features_cache["test"]["y"]
473
+ y_pred = pipeline.predict(X_test)
474
+
475
+ metrics = compute_classification_metrics(y_test.tolist(), y_pred.tolist(), class_names)
476
+ cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
477
+
478
+ return (
479
+ {
480
+ "test_accuracy": metrics["accuracy"],
481
+ "test_f1_macro": metrics["f1_macro"],
482
+ "test_f1_weighted": metrics["f1_weighted"],
483
+ },
484
+ metrics["classification_report"],
485
+ metrics["confusion_matrix"],
486
+ cm_path,
487
+ )