Spaces:
Sleeping
Restructurer l'app : backbone préentraîné + ML classique + FC head + CNN de zéro
Browse files- Ajout backbone_utils.py : chargement du backbone ResNet18 depuis HF, extraction
de features 512-dim avec cache mémoire
- Ajout classical_ml_utils.py : SVM / LogReg / k-NN / RF / LDA sur les features
extraites (pipeline sklearn avec StandardScaler + joblib)
- Refactorisation train_utils.py : train_fc_head (tête FC seule, ~200Ko sauvegardés)
et train_cnn (SimpleCNN de zéro) ; evaluate_saved_model unifié pour tous les types
- Mise à jour model.py : BackboneWithFC (backbone gelé + tête FC) + SimpleCNN conservé
- Mise à jour predict_utils.py : dispatch automatique selon model_type
- Mise à jour app.py : 4 onglets (dataset / ML classique / neuronaux / test-prédiction)
- Ajout config.py : HF_BACKBONE_REPO, CLASSICAL_MODEL_TYPES
- Ajout .gitignore : exclut data/, backbone/, saved_models/, __pycache__/
- Ajout finetune_backbone.py : script local pour entraîner le backbone sur les données
- .gitignore +8 -0
- app.py +308 -295
- backbone_utils.py +81 -0
- classical_ml_utils.py +140 -0
- config.py +4 -4
- data_utils.py +1 -1
- finetune_backbone.py +245 -0
- model.py +11 -26
- predict_utils.py +57 -38
- train_utils.py +291 -170
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
backbone/
|
| 3 |
+
saved_models/
|
| 4 |
+
saved_models_meta/
|
| 5 |
+
saved_figures/
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.pyc
|
| 8 |
+
.DS_Store
|
|
@@ -3,90 +3,145 @@ import json
|
|
| 3 |
import gradio as gr
|
| 4 |
import spaces
|
| 5 |
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
)
|
| 11 |
from train_utils import (
|
| 12 |
-
|
| 13 |
list_saved_models,
|
| 14 |
model_meta_path,
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
from predict_utils import (
|
| 18 |
-
predict_uploaded_image,
|
| 19 |
-
test_random_sample,
|
| 20 |
)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
def
|
| 24 |
try:
|
| 25 |
summary, distribution_df = dataset_overview()
|
| 26 |
class_names = ["Toutes les classes"] + get_class_names()
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
return (
|
| 29 |
-
summary,
|
| 30 |
-
distribution_df,
|
| 31 |
-
gr.update(choices=class_names, value="Toutes les classes"),
|
| 32 |
-
)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
except Exception as e:
|
| 35 |
-
return (
|
| 36 |
-
{"Erreur": str(e)},
|
| 37 |
-
None,
|
| 38 |
-
gr.update(),
|
| 39 |
-
)
|
| 40 |
|
| 41 |
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
try:
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
)
|
| 49 |
-
return gallery
|
| 50 |
except Exception as e:
|
| 51 |
-
return
|
| 52 |
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
return gr.update(visible=is_cnn), gr.update(value=default_lr)
|
| 58 |
|
| 59 |
|
| 60 |
-
@spaces.GPU(duration=
|
| 61 |
-
def
|
| 62 |
model_type,
|
| 63 |
-
num_conv_blocks,
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
use_batchnorm,
|
| 67 |
-
dropout,
|
| 68 |
-
fc_dim,
|
| 69 |
-
learning_rate,
|
| 70 |
-
weight_decay,
|
| 71 |
-
batch_size,
|
| 72 |
-
epochs,
|
| 73 |
model_tag,
|
| 74 |
):
|
| 75 |
try:
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
models = list_saved_models()
|
| 92 |
selected = result["model_name"] if result["model_name"] in models else None
|
|
@@ -100,21 +155,31 @@ def train_callback(
|
|
| 100 |
result["confusion_matrix_path"],
|
| 101 |
gr.update(choices=models, value=selected),
|
| 102 |
)
|
| 103 |
-
|
| 104 |
except Exception as e:
|
| 105 |
-
return (
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
@spaces.GPU(duration=120)
|
| 117 |
-
def
|
| 118 |
try:
|
| 119 |
summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
|
| 120 |
return summary, report_df, cm_df, cm_path
|
|
@@ -123,269 +188,219 @@ def evaluate_saved_model_callback(model_name):
|
|
| 123 |
|
| 124 |
|
| 125 |
@spaces.GPU(duration=60)
|
| 126 |
-
def
|
| 127 |
try:
|
| 128 |
return predict_uploaded_image(model_name, image)
|
| 129 |
except Exception as e:
|
| 130 |
-
return f"Échec
|
| 131 |
|
| 132 |
|
| 133 |
@spaces.GPU(duration=60)
|
| 134 |
-
def
|
| 135 |
try:
|
| 136 |
return test_random_sample(model_name)
|
| 137 |
except Exception as e:
|
| 138 |
-
return None, f"Échec
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def refresh_models_dropdown():
|
| 142 |
-
models = list_saved_models()
|
| 143 |
-
return gr.update(choices=models, value=models[0] if models else None)
|
| 144 |
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
meta_file = model_meta_path(model_name)
|
| 151 |
-
|
| 152 |
-
try:
|
| 153 |
-
with open(meta_file, "r", encoding="utf-8") as f:
|
| 154 |
-
return json.load(f)
|
| 155 |
-
except FileNotFoundError:
|
| 156 |
-
return {"message": "Métadonnées introuvables."}
|
| 157 |
-
|
| 158 |
|
| 159 |
initial_models = list_saved_models()
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
gr.Markdown("# Classification d’images microscopiques de charbons de bois")
|
| 164 |
gr.Markdown(
|
| 165 |
-
"Application pédagogique
|
| 166 |
-
"
|
|
|
|
| 167 |
)
|
| 168 |
|
| 169 |
with gr.Tabs():
|
| 170 |
|
|
|
|
|
|
|
|
|
|
| 171 |
with gr.Tab("1. Explorer le jeu de données"):
|
| 172 |
-
gr.Markdown("## Comprendre le jeu de données avant l
|
| 173 |
-
|
| 174 |
-
load_dataset_btn = gr.Button(
|
| 175 |
-
"Charger les informations du dataset",
|
| 176 |
-
variant="primary",
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
dataset_summary = gr.JSON(label="Résumé général du dataset")
|
| 180 |
|
|
|
|
|
|
|
| 181 |
class_distribution = gr.Dataframe(
|
| 182 |
-
label="Distribution
|
| 183 |
-
interactive=False,
|
| 184 |
)
|
| 185 |
|
| 186 |
gr.Markdown("## Visualisation des images")
|
| 187 |
-
|
| 188 |
with gr.Row():
|
| 189 |
split_selector = gr.Dropdown(
|
| 190 |
-
choices=["train", "validation", "test"],
|
| 191 |
-
value="train",
|
| 192 |
-
label="Split",
|
| 193 |
)
|
| 194 |
class_selector = gr.Dropdown(
|
| 195 |
-
choices=["Toutes les classes"],
|
| 196 |
-
value="Toutes les classes",
|
| 197 |
-
label="Classe",
|
| 198 |
-
)
|
| 199 |
-
max_images = gr.Slider(
|
| 200 |
-
minimum=4,
|
| 201 |
-
maximum=48,
|
| 202 |
-
value=24,
|
| 203 |
-
step=4,
|
| 204 |
-
label="Nombre d’images à afficher",
|
| 205 |
)
|
|
|
|
| 206 |
|
| 207 |
refresh_gallery_btn = gr.Button("Afficher des exemples")
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
)
|
| 214 |
|
| 215 |
-
|
| 216 |
-
gr.
|
|
|
|
|
|
|
| 217 |
|
| 218 |
with gr.Row():
|
| 219 |
with gr.Column():
|
| 220 |
-
|
| 221 |
-
choices=["
|
| 222 |
-
value="
|
| 223 |
-
label="
|
| 224 |
-
info=(
|
| 225 |
-
"CNN simple : entraîné de zéro, paramètres configurables. "
|
| 226 |
-
"ResNet18 : pré-entraîné ImageNet, fine-tuning layer4 + classifieur."
|
| 227 |
-
),
|
| 228 |
)
|
| 229 |
|
| 230 |
-
with gr.Column(visible=True) as
|
| 231 |
-
gr.Markdown("#### Paramètres
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
kernel_size = gr.Dropdown(
|
| 248 |
-
choices=[3, 5],
|
| 249 |
-
value=3,
|
| 250 |
-
label="Taille du noyau de convolution",
|
| 251 |
-
)
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
)
|
| 257 |
|
| 258 |
-
gr.
|
|
|
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
maximum=0.8,
|
| 263 |
-
value=0.4,
|
| 264 |
-
step=0.05,
|
| 265 |
-
label="Dropout",
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
fc_dim = gr.Dropdown(
|
| 269 |
-
choices=[64, 128, 256, 512],
|
| 270 |
-
value=256,
|
| 271 |
-
label="Dimension de la couche cachée (classifieur)",
|
| 272 |
-
)
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
)
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
| 283 |
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
)
|
| 289 |
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
maximum=
|
| 293 |
-
value=
|
| 294 |
-
|
| 295 |
-
label="
|
| 296 |
-
)
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
-
|
| 304 |
|
| 305 |
with gr.Column():
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
)
|
| 310 |
-
train_history = gr.JSON(label="Historique d’entraînement")
|
| 311 |
-
train_summary = gr.JSON(label="Résumé final")
|
| 312 |
|
| 313 |
gr.Markdown("## Résultats sur le test set")
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
interactive=False,
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
train_confusion_matrix_image = gr.Image(
|
| 326 |
-
label="Matrice de confusion - figure",
|
| 327 |
-
type="filepath",
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
with gr.Tab("3. Tester et analyser un modèle"):
|
| 331 |
gr.Markdown("## Sélectionner un modèle sauvegardé")
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
with gr.Row():
|
| 334 |
with gr.Column():
|
| 335 |
model_selector = gr.Dropdown(
|
| 336 |
choices=initial_models,
|
| 337 |
value=initial_models[0] if initial_models else None,
|
| 338 |
-
label="Modèle
|
| 339 |
)
|
| 340 |
-
|
| 341 |
-
refresh_btn = gr.Button("Actualiser la liste des modèles")
|
| 342 |
load_info_btn = gr.Button("Afficher les informations du modèle")
|
| 343 |
-
model_info = gr.JSON(label="Métadonnées
|
| 344 |
|
| 345 |
with gr.Column():
|
| 346 |
-
evaluate_btn = gr.Button(
|
| 347 |
-
"Évaluer le modèle sur le test set",
|
| 348 |
-
variant="primary",
|
| 349 |
-
)
|
| 350 |
eval_summary = gr.JSON(label="Résumé des métriques")
|
| 351 |
|
| 352 |
-
eval_report = gr.Dataframe(
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
)
|
| 356 |
-
|
| 357 |
-
eval_confusion_matrix = gr.Dataframe(
|
| 358 |
-
label="Matrice de confusion",
|
| 359 |
-
interactive=False,
|
| 360 |
-
)
|
| 361 |
-
|
| 362 |
-
eval_confusion_matrix_image = gr.Image(
|
| 363 |
-
label="Matrice de confusion - figure",
|
| 364 |
-
type="filepath",
|
| 365 |
-
)
|
| 366 |
|
| 367 |
gr.Markdown("## Prédiction sur une image importée")
|
| 368 |
-
|
| 369 |
with gr.Row():
|
| 370 |
with gr.Column():
|
| 371 |
upload_image = gr.Image(type="pil", label="Importer une image")
|
| 372 |
predict_btn = gr.Button("Prédire la classe", variant="primary")
|
| 373 |
-
|
| 374 |
with gr.Column():
|
| 375 |
-
predict_text = gr.Textbox(label="Résultat
|
| 376 |
predict_probs = gr.Label(label="Probabilités par classe")
|
| 377 |
|
| 378 |
gr.Markdown("## Test sur un échantillon aléatoire du test set")
|
| 379 |
-
|
| 380 |
random_test_btn = gr.Button("Tester un échantillon aléatoire")
|
| 381 |
-
|
| 382 |
with gr.Row():
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
load_dataset_btn.click(
|
| 388 |
-
fn=
|
| 389 |
inputs=None,
|
| 390 |
outputs=[dataset_summary, class_distribution, class_selector],
|
| 391 |
)
|
|
@@ -396,74 +411,72 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
|
|
| 396 |
outputs=image_gallery,
|
| 397 |
)
|
| 398 |
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
| 403 |
)
|
| 404 |
|
| 405 |
-
|
| 406 |
-
fn=
|
| 407 |
inputs=[
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
],
|
| 421 |
outputs=[
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
train_summary,
|
| 425 |
-
train_report,
|
| 426 |
-
train_confusion_matrix,
|
| 427 |
-
train_confusion_matrix_image,
|
| 428 |
model_selector,
|
| 429 |
],
|
| 430 |
)
|
| 431 |
|
| 432 |
-
refresh_btn.click(
|
| 433 |
-
fn=refresh_models_dropdown,
|
| 434 |
-
inputs=None,
|
| 435 |
-
outputs=model_selector,
|
| 436 |
-
)
|
| 437 |
|
| 438 |
-
load_info_btn.click(
|
| 439 |
-
fn=get_model_info,
|
| 440 |
-
inputs=model_selector,
|
| 441 |
-
outputs=model_info,
|
| 442 |
-
)
|
| 443 |
|
| 444 |
evaluate_btn.click(
|
| 445 |
-
fn=
|
| 446 |
inputs=model_selector,
|
| 447 |
-
outputs=[
|
| 448 |
-
eval_summary,
|
| 449 |
-
eval_report,
|
| 450 |
-
eval_confusion_matrix,
|
| 451 |
-
eval_confusion_matrix_image,
|
| 452 |
-
],
|
| 453 |
)
|
| 454 |
|
| 455 |
predict_btn.click(
|
| 456 |
-
fn=
|
| 457 |
inputs=[model_selector, upload_image],
|
| 458 |
outputs=[predict_text, predict_probs],
|
| 459 |
)
|
| 460 |
|
| 461 |
random_test_btn.click(
|
| 462 |
-
fn=
|
| 463 |
inputs=model_selector,
|
| 464 |
-
outputs=[
|
| 465 |
)
|
| 466 |
|
| 467 |
|
| 468 |
if __name__ == "__main__":
|
| 469 |
-
demo.launch(ssr_mode=False)
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import spaces
|
| 5 |
|
| 6 |
+
from backbone_utils import extract_all_features, get_cached_features
|
| 7 |
+
from classical_ml_utils import train_classical_model
|
| 8 |
+
from data_utils import dataset_overview, get_class_names, get_images_for_gallery
|
| 9 |
+
from predict_utils import predict_uploaded_image, test_random_sample
|
|
|
|
| 10 |
from train_utils import (
|
| 11 |
+
evaluate_saved_model,
|
| 12 |
list_saved_models,
|
| 13 |
model_meta_path,
|
| 14 |
+
train_cnn,
|
| 15 |
+
train_fc_head,
|
|
|
|
|
|
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Tab 1 — Dataset
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
|
| 22 |
+
def load_dataset_callback():
|
| 23 |
try:
|
| 24 |
summary, distribution_df = dataset_overview()
|
| 25 |
class_names = ["Toutes les classes"] + get_class_names()
|
| 26 |
+
return summary, distribution_df, gr.update(choices=class_names, value="Toutes les classes")
|
| 27 |
+
except Exception as e:
|
| 28 |
+
return {"Erreur": str(e)}, None, gr.update()
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
def refresh_gallery_callback(split_name, class_name, max_images):
|
| 32 |
+
try:
|
| 33 |
+
return get_images_for_gallery(split_name, class_name, int(max_images))
|
| 34 |
except Exception as e:
|
| 35 |
+
return [(None, f"Erreur : {e}")]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Tab 2 — ML classique
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def extract_features_callback():
|
| 43 |
try:
|
| 44 |
+
_, class_names, counts = extract_all_features()
|
| 45 |
+
lines = [f"Extraction terminée ({len(class_names)} classes)"]
|
| 46 |
+
for split, n in counts.items():
|
| 47 |
+
lines.append(f" {split} : {n} images")
|
| 48 |
+
return "\n".join(lines)
|
| 49 |
+
except Exception as e:
|
| 50 |
+
return f"Erreur lors de l'extraction :\n{e}"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def on_clf_type_change(clf_type):
|
| 54 |
+
show = lambda t: gr.update(visible=(clf_type == t))
|
| 55 |
+
return show("SVM"), show("Régression logistique"), show("k-NN"), show("Forêt aléatoire"), show("LDA")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def train_classical_callback(
|
| 59 |
+
clf_type,
|
| 60 |
+
svm_c, svm_kernel, svm_gamma,
|
| 61 |
+
logreg_c, logreg_max_iter,
|
| 62 |
+
knn_k, knn_metric,
|
| 63 |
+
rf_n_estimators, rf_max_depth,
|
| 64 |
+
lda_solver,
|
| 65 |
+
model_tag,
|
| 66 |
+
):
|
| 67 |
+
try:
|
| 68 |
+
features_cache = get_cached_features()
|
| 69 |
+
if features_cache is None:
|
| 70 |
+
return {"Erreur": "Veuillez d'abord extraire les caractéristiques (bouton ci-dessus)."}, None, None, None, gr.update()
|
| 71 |
+
|
| 72 |
+
params = {}
|
| 73 |
+
if clf_type == "SVM":
|
| 74 |
+
params = {"C": float(svm_c), "kernel": svm_kernel, "gamma": svm_gamma}
|
| 75 |
+
elif clf_type == "Régression logistique":
|
| 76 |
+
params = {"C": float(logreg_c), "max_iter": int(logreg_max_iter)}
|
| 77 |
+
elif clf_type == "k-NN":
|
| 78 |
+
params = {"n_neighbors": int(knn_k), "metric": knn_metric}
|
| 79 |
+
elif clf_type == "Forêt aléatoire":
|
| 80 |
+
depth = int(rf_max_depth) if rf_max_depth and int(rf_max_depth) > 0 else None
|
| 81 |
+
params = {"n_estimators": int(rf_n_estimators), "max_depth": depth}
|
| 82 |
+
elif clf_type == "LDA":
|
| 83 |
+
params = {"solver": lda_solver}
|
| 84 |
+
|
| 85 |
+
class_names = get_class_names()
|
| 86 |
+
result = train_classical_model(clf_type, features_cache, class_names, model_tag, **params)
|
| 87 |
+
|
| 88 |
+
models = list_saved_models()
|
| 89 |
+
selected = result["model_name"] if result["model_name"] in models else None
|
| 90 |
+
|
| 91 |
+
return (
|
| 92 |
+
result["summary"],
|
| 93 |
+
result["classification_report"],
|
| 94 |
+
result["confusion_matrix"],
|
| 95 |
+
result["confusion_matrix_path"],
|
| 96 |
+
gr.update(choices=models, value=selected),
|
| 97 |
)
|
|
|
|
| 98 |
except Exception as e:
|
| 99 |
+
return {"Erreur": str(e)}, None, None, None, gr.update()
|
| 100 |
|
| 101 |
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
# Tab 3 — Modèles neuronaux
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
def on_neural_type_change(model_type):
|
| 107 |
+
is_cnn = (model_type == "CNN de zéro")
|
| 108 |
+
default_lr = 1e-3 if is_cnn else 1e-4
|
| 109 |
return gr.update(visible=is_cnn), gr.update(value=default_lr)
|
| 110 |
|
| 111 |
|
| 112 |
+
@spaces.GPU(duration=300)
|
| 113 |
+
def train_neural_callback(
|
| 114 |
model_type,
|
| 115 |
+
num_conv_blocks, base_filters, kernel_size, use_batchnorm,
|
| 116 |
+
dropout, fc_dim,
|
| 117 |
+
learning_rate, weight_decay, batch_size, epochs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
model_tag,
|
| 119 |
):
|
| 120 |
try:
|
| 121 |
+
if model_type == "FC sur backbone préentraîné":
|
| 122 |
+
result = train_fc_head(
|
| 123 |
+
dropout=float(dropout),
|
| 124 |
+
fc_dim=int(fc_dim),
|
| 125 |
+
learning_rate=float(learning_rate),
|
| 126 |
+
weight_decay=float(weight_decay),
|
| 127 |
+
batch_size=int(batch_size),
|
| 128 |
+
epochs=int(epochs),
|
| 129 |
+
model_tag=model_tag,
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
result = train_cnn(
|
| 133 |
+
num_conv_blocks=int(num_conv_blocks),
|
| 134 |
+
base_filters=int(base_filters),
|
| 135 |
+
kernel_size=int(kernel_size),
|
| 136 |
+
use_batchnorm=bool(use_batchnorm),
|
| 137 |
+
dropout=float(dropout),
|
| 138 |
+
fc_dim=int(fc_dim),
|
| 139 |
+
learning_rate=float(learning_rate),
|
| 140 |
+
weight_decay=float(weight_decay),
|
| 141 |
+
batch_size=int(batch_size),
|
| 142 |
+
epochs=int(epochs),
|
| 143 |
+
model_tag=model_tag,
|
| 144 |
+
)
|
| 145 |
|
| 146 |
models = list_saved_models()
|
| 147 |
selected = result["model_name"] if result["model_name"] in models else None
|
|
|
|
| 155 |
result["confusion_matrix_path"],
|
| 156 |
gr.update(choices=models, value=selected),
|
| 157 |
)
|
|
|
|
| 158 |
except Exception as e:
|
| 159 |
+
return f"Échec de l'entraînement :\n{e}", None, None, None, None, None, gr.update()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# Tab 4 — Tester et prédire
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
def refresh_models_callback():
|
| 167 |
+
models = list_saved_models()
|
| 168 |
+
return gr.update(choices=models, value=models[0] if models else None)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def get_model_info_callback(model_name):
|
| 172 |
+
if not model_name:
|
| 173 |
+
return {"message": "Aucun modèle sélectionné."}
|
| 174 |
+
try:
|
| 175 |
+
with open(model_meta_path(model_name), "r", encoding="utf-8") as f:
|
| 176 |
+
return json.load(f)
|
| 177 |
+
except FileNotFoundError:
|
| 178 |
+
return {"message": "Métadonnées introuvables."}
|
| 179 |
|
| 180 |
|
| 181 |
@spaces.GPU(duration=120)
|
| 182 |
+
def evaluate_callback(model_name):
|
| 183 |
try:
|
| 184 |
summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
|
| 185 |
return summary, report_df, cm_df, cm_path
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
@spaces.GPU(duration=60)
|
| 191 |
+
def predict_callback(model_name, image):
|
| 192 |
try:
|
| 193 |
return predict_uploaded_image(model_name, image)
|
| 194 |
except Exception as e:
|
| 195 |
+
return f"Échec :\n{e}", None
|
| 196 |
|
| 197 |
|
| 198 |
@spaces.GPU(duration=60)
|
| 199 |
+
def random_test_callback(model_name):
|
| 200 |
try:
|
| 201 |
return test_random_sample(model_name)
|
| 202 |
except Exception as e:
|
| 203 |
+
return None, f"Échec :\n{e}", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
|
| 206 |
+
# ---------------------------------------------------------------------------
|
| 207 |
+
# UI
|
| 208 |
+
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
initial_models = list_saved_models()
|
| 211 |
|
| 212 |
+
with gr.Blocks(title="Classification d'images microscopiques") as demo:
|
| 213 |
+
gr.Markdown("# Classification d'images microscopiques de charbons de bois")
|
|
|
|
| 214 |
gr.Markdown(
|
| 215 |
+
"Application pédagogique : explorez le jeu de données, entraînez des classifieurs "
|
| 216 |
+
"traditionnels ou neuronaux sur les caractéristiques extraites par un backbone "
|
| 217 |
+
"ResNet18 préentraîné, puis analysez et comparez les résultats."
|
| 218 |
)
|
| 219 |
|
| 220 |
with gr.Tabs():
|
| 221 |
|
| 222 |
+
# ------------------------------------------------------------------ #
|
| 223 |
+
# Tab 1
|
| 224 |
+
# ------------------------------------------------------------------ #
|
| 225 |
with gr.Tab("1. Explorer le jeu de données"):
|
| 226 |
+
gr.Markdown("## Comprendre le jeu de données avant l'entraînement")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
+
load_dataset_btn = gr.Button("Charger les informations du dataset", variant="primary")
|
| 229 |
+
dataset_summary = gr.JSON(label="Résumé général")
|
| 230 |
class_distribution = gr.Dataframe(
|
| 231 |
+
label="Distribution par split et par classe", interactive=False
|
|
|
|
| 232 |
)
|
| 233 |
|
| 234 |
gr.Markdown("## Visualisation des images")
|
|
|
|
| 235 |
with gr.Row():
|
| 236 |
split_selector = gr.Dropdown(
|
| 237 |
+
choices=["train", "validation", "test"], value="train", label="Split"
|
|
|
|
|
|
|
| 238 |
)
|
| 239 |
class_selector = gr.Dropdown(
|
| 240 |
+
choices=["Toutes les classes"], value="Toutes les classes", label="Classe"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
)
|
| 242 |
+
max_images = gr.Slider(minimum=4, maximum=48, value=24, step=4, label="Nombre d'images")
|
| 243 |
|
| 244 |
refresh_gallery_btn = gr.Button("Afficher des exemples")
|
| 245 |
+
image_gallery = gr.Gallery(label="Exemples d'images", columns=4, height=600)
|
| 246 |
+
|
| 247 |
+
# ------------------------------------------------------------------ #
|
| 248 |
+
# Tab 2
|
| 249 |
+
# ------------------------------------------------------------------ #
|
| 250 |
+
with gr.Tab("2. ML classique sur caractéristiques"):
|
| 251 |
+
gr.Markdown(
|
| 252 |
+
"## Étape 1 — Extraction des caractéristiques\n"
|
| 253 |
+
"Le backbone ResNet18 préentraîné sur les charbons extrait un vecteur de "
|
| 254 |
+
"512 dimensions par image. Cette étape s'exécute sur CPU et ne nécessite "
|
| 255 |
+
"aucun GPU."
|
| 256 |
)
|
| 257 |
|
| 258 |
+
extract_btn = gr.Button("Extraire les caractéristiques (backbone gelé)", variant="primary")
|
| 259 |
+
extract_status = gr.Textbox(label="Statut de l'extraction", lines=4, interactive=False)
|
| 260 |
+
|
| 261 |
+
gr.Markdown("## Étape 2 — Entraîner un classifieur")
|
| 262 |
|
| 263 |
with gr.Row():
|
| 264 |
with gr.Column():
|
| 265 |
+
clf_type = gr.Radio(
|
| 266 |
+
choices=["SVM", "Régression logistique", "k-NN", "Forêt aléatoire", "LDA"],
|
| 267 |
+
value="SVM",
|
| 268 |
+
label="Algorithme",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
)
|
| 270 |
|
| 271 |
+
with gr.Column(visible=True) as svm_col:
|
| 272 |
+
gr.Markdown("#### Paramètres SVM")
|
| 273 |
+
svm_c = gr.Number(value=1.0, label="C (régularisation)")
|
| 274 |
+
svm_kernel = gr.Dropdown(choices=["rbf", "linear", "poly"], value="rbf", label="Noyau")
|
| 275 |
+
svm_gamma = gr.Dropdown(choices=["scale", "auto"], value="scale", label="Gamma")
|
| 276 |
+
|
| 277 |
+
with gr.Column(visible=False) as logreg_col:
|
| 278 |
+
gr.Markdown("#### Paramètres Régression logistique")
|
| 279 |
+
logreg_c = gr.Number(value=1.0, label="C (régularisation)")
|
| 280 |
+
logreg_max_iter = gr.Number(value=1000, label="Itérations max")
|
| 281 |
+
|
| 282 |
+
with gr.Column(visible=False) as knn_col:
|
| 283 |
+
gr.Markdown("#### Paramètres k-NN")
|
| 284 |
+
knn_k = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="k (voisins)")
|
| 285 |
+
knn_metric = gr.Dropdown(
|
| 286 |
+
choices=["euclidean", "cosine", "manhattan"], value="euclidean", label="Métrique"
|
| 287 |
)
|
| 288 |
|
| 289 |
+
with gr.Column(visible=False) as rf_col:
|
| 290 |
+
gr.Markdown("#### Paramètres Forêt aléatoire")
|
| 291 |
+
rf_n_estimators = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Nombre d'arbres")
|
| 292 |
+
rf_max_depth = gr.Number(value=0, label="Profondeur max (0 = illimitée)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
+
with gr.Column(visible=False) as lda_col:
|
| 295 |
+
gr.Markdown("#### Paramètres LDA")
|
| 296 |
+
lda_solver = gr.Dropdown(choices=["svd", "lsqr", "eigen"], value="svd", label="Solveur")
|
|
|
|
| 297 |
|
| 298 |
+
ml_model_tag = gr.Textbox(label="Nom court du modèle", placeholder="ex. svm_rbf")
|
| 299 |
+
train_classical_btn = gr.Button("Entraîner le classifieur", variant="primary")
|
| 300 |
|
| 301 |
+
with gr.Column():
|
| 302 |
+
ml_summary = gr.JSON(label="Résumé des métriques")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
+
ml_report = gr.Dataframe(label="Rapport de classification", interactive=False)
|
| 305 |
+
ml_cm = gr.Dataframe(label="Matrice de confusion", interactive=False)
|
| 306 |
+
ml_cm_img = gr.Image(label="Matrice de confusion — figure", type="filepath")
|
|
|
|
| 307 |
|
| 308 |
+
# ------------------------------------------------------------------ #
|
| 309 |
+
# Tab 3
|
| 310 |
+
# ------------------------------------------------------------------ #
|
| 311 |
+
with gr.Tab("3. Modèles neuronaux"):
|
| 312 |
+
gr.Markdown("## Architecture")
|
| 313 |
|
| 314 |
+
with gr.Row():
|
| 315 |
+
with gr.Column():
|
| 316 |
+
neural_type = gr.Radio(
|
| 317 |
+
choices=["FC sur backbone préentraîné", "CNN de zéro"],
|
| 318 |
+
value="FC sur backbone préentraîné",
|
| 319 |
+
label="Type de modèle",
|
| 320 |
+
info=(
|
| 321 |
+
"FC sur backbone : backbone gelé, seule la tête FC est entraînée — rapide, peu de GPU. "
|
| 322 |
+
"CNN de zéro : réseau convolutif entraîné entièrement depuis rien — référence sans transfert."
|
| 323 |
+
),
|
| 324 |
)
|
| 325 |
|
| 326 |
+
with gr.Column(visible=False) as cnn_arch_col:
|
| 327 |
+
gr.Markdown("#### Architecture CNN")
|
| 328 |
+
num_conv_blocks = gr.Slider(minimum=2, maximum=5, value=3, step=1, label="Blocs convolutionnels")
|
| 329 |
+
base_filters = gr.Dropdown(choices=[16, 32, 64, 128], value=32, label="Filtres du premier bloc")
|
| 330 |
+
kernel_size = gr.Dropdown(choices=[3, 5], value=3, label="Taille du noyau")
|
| 331 |
+
use_batchnorm = gr.Checkbox(value=True, label="BatchNorm")
|
|
|
|
| 332 |
|
| 333 |
+
gr.Markdown("#### Hyperparamètres d'entraînement")
|
| 334 |
+
n_dropout = gr.Slider(minimum=0.0, maximum=0.8, value=0.4, step=0.05, label="Dropout")
|
| 335 |
+
n_fc_dim = gr.Dropdown(choices=[64, 128, 256, 512], value=256, label="Dimension couche cachée")
|
| 336 |
+
n_lr = gr.Number(value=1e-4, label="Taux d'apprentissage")
|
| 337 |
+
n_wd = gr.Number(value=1e-4, label="Weight decay")
|
| 338 |
+
n_bs = gr.Dropdown(choices=[8, 16, 32, 64], value=16, label="Taille du batch")
|
| 339 |
+
n_epochs = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Époques")
|
| 340 |
+
n_tag = gr.Textbox(label="Nom court du modèle", placeholder="ex. fc_head_v1")
|
| 341 |
|
| 342 |
+
train_neural_btn = gr.Button("Lancer l'entraînement", variant="primary")
|
| 343 |
|
| 344 |
with gr.Column():
|
| 345 |
+
neural_logs = gr.Textbox(label="Journal d'entraînement", lines=20)
|
| 346 |
+
neural_history = gr.JSON(label="Historique")
|
| 347 |
+
neural_summary = gr.JSON(label="Résumé final")
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
gr.Markdown("## Résultats sur le test set")
|
| 350 |
+
neural_report = gr.Dataframe(label="Rapport de classification", interactive=False)
|
| 351 |
+
neural_cm = gr.Dataframe(label="Matrice de confusion", interactive=False)
|
| 352 |
+
neural_cm_img = gr.Image(label="Matrice de confusion — figure", type="filepath")
|
| 353 |
+
|
| 354 |
+
# ------------------------------------------------------------------ #
|
| 355 |
+
# Tab 4
|
| 356 |
+
# ------------------------------------------------------------------ #
|
| 357 |
+
with gr.Tab("4. Tester et analyser"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
gr.Markdown("## Sélectionner un modèle sauvegardé")
|
| 359 |
+
gr.Markdown(
|
| 360 |
+
"_Tous les types de modèles apparaissent ici : classifieurs ML, têtes FC et CNN._"
|
| 361 |
+
)
|
| 362 |
|
| 363 |
with gr.Row():
|
| 364 |
with gr.Column():
|
| 365 |
model_selector = gr.Dropdown(
|
| 366 |
choices=initial_models,
|
| 367 |
value=initial_models[0] if initial_models else None,
|
| 368 |
+
label="Modèle",
|
| 369 |
)
|
| 370 |
+
refresh_btn = gr.Button("Actualiser la liste")
|
|
|
|
| 371 |
load_info_btn = gr.Button("Afficher les informations du modèle")
|
| 372 |
+
model_info = gr.JSON(label="Métadonnées")
|
| 373 |
|
| 374 |
with gr.Column():
|
| 375 |
+
evaluate_btn = gr.Button("Évaluer sur le test set", variant="primary")
|
|
|
|
|
|
|
|
|
|
| 376 |
eval_summary = gr.JSON(label="Résumé des métriques")
|
| 377 |
|
| 378 |
+
eval_report = gr.Dataframe(label="Rapport de classification", interactive=False)
|
| 379 |
+
eval_cm = gr.Dataframe(label="Matrice de confusion", interactive=False)
|
| 380 |
+
eval_cm_img = gr.Image(label="Matrice de confusion — figure", type="filepath")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
gr.Markdown("## Prédiction sur une image importée")
|
|
|
|
| 383 |
with gr.Row():
|
| 384 |
with gr.Column():
|
| 385 |
upload_image = gr.Image(type="pil", label="Importer une image")
|
| 386 |
predict_btn = gr.Button("Prédire la classe", variant="primary")
|
|
|
|
| 387 |
with gr.Column():
|
| 388 |
+
predict_text = gr.Textbox(label="Résultat", lines=7)
|
| 389 |
predict_probs = gr.Label(label="Probabilités par classe")
|
| 390 |
|
| 391 |
gr.Markdown("## Test sur un échantillon aléatoire du test set")
|
|
|
|
| 392 |
random_test_btn = gr.Button("Tester un échantillon aléatoire")
|
|
|
|
| 393 |
with gr.Row():
|
| 394 |
+
random_img = gr.Image(type="pil", label="Image test")
|
| 395 |
+
random_text = gr.Textbox(label="Résultat", lines=7)
|
| 396 |
+
random_probs = gr.Label(label="Probabilités par classe")
|
| 397 |
+
|
| 398 |
+
# ---------------------------------------------------------------------- #
|
| 399 |
+
# Event wiring
|
| 400 |
+
# ---------------------------------------------------------------------- #
|
| 401 |
|
| 402 |
load_dataset_btn.click(
|
| 403 |
+
fn=load_dataset_callback,
|
| 404 |
inputs=None,
|
| 405 |
outputs=[dataset_summary, class_distribution, class_selector],
|
| 406 |
)
|
|
|
|
| 411 |
outputs=image_gallery,
|
| 412 |
)
|
| 413 |
|
| 414 |
+
extract_btn.click(fn=extract_features_callback, inputs=None, outputs=extract_status)
|
| 415 |
+
|
| 416 |
+
clf_type.change(
|
| 417 |
+
fn=on_clf_type_change,
|
| 418 |
+
inputs=clf_type,
|
| 419 |
+
outputs=[svm_col, logreg_col, knn_col, rf_col, lda_col],
|
| 420 |
)
|
| 421 |
|
| 422 |
+
train_classical_btn.click(
|
| 423 |
+
fn=train_classical_callback,
|
| 424 |
inputs=[
|
| 425 |
+
clf_type,
|
| 426 |
+
svm_c, svm_kernel, svm_gamma,
|
| 427 |
+
logreg_c, logreg_max_iter,
|
| 428 |
+
knn_k, knn_metric,
|
| 429 |
+
rf_n_estimators, rf_max_depth,
|
| 430 |
+
lda_solver,
|
| 431 |
+
ml_model_tag,
|
| 432 |
+
],
|
| 433 |
+
outputs=[ml_summary, ml_report, ml_cm, ml_cm_img, model_selector],
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
neural_type.change(
|
| 437 |
+
fn=on_neural_type_change,
|
| 438 |
+
inputs=neural_type,
|
| 439 |
+
outputs=[cnn_arch_col, n_lr],
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
train_neural_btn.click(
|
| 443 |
+
fn=train_neural_callback,
|
| 444 |
+
inputs=[
|
| 445 |
+
neural_type,
|
| 446 |
+
num_conv_blocks, base_filters, kernel_size, use_batchnorm,
|
| 447 |
+
n_dropout, n_fc_dim,
|
| 448 |
+
n_lr, n_wd, n_bs, n_epochs,
|
| 449 |
+
n_tag,
|
| 450 |
],
|
| 451 |
outputs=[
|
| 452 |
+
neural_logs, neural_history, neural_summary,
|
| 453 |
+
neural_report, neural_cm, neural_cm_img,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
model_selector,
|
| 455 |
],
|
| 456 |
)
|
| 457 |
|
| 458 |
+
refresh_btn.click(fn=refresh_models_callback, inputs=None, outputs=model_selector)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
+
load_info_btn.click(fn=get_model_info_callback, inputs=model_selector, outputs=model_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
evaluate_btn.click(
|
| 463 |
+
fn=evaluate_callback,
|
| 464 |
inputs=model_selector,
|
| 465 |
+
outputs=[eval_summary, eval_report, eval_cm, eval_cm_img],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
)
|
| 467 |
|
| 468 |
predict_btn.click(
|
| 469 |
+
fn=predict_callback,
|
| 470 |
inputs=[model_selector, upload_image],
|
| 471 |
outputs=[predict_text, predict_probs],
|
| 472 |
)
|
| 473 |
|
| 474 |
random_test_btn.click(
|
| 475 |
+
fn=random_test_callback,
|
| 476 |
inputs=model_selector,
|
| 477 |
+
outputs=[random_img, random_text, random_probs],
|
| 478 |
)
|
| 479 |
|
| 480 |
|
| 481 |
if __name__ == "__main__":
|
| 482 |
+
demo.launch(ssr_mode=False)
|
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torchvision import models
|
| 7 |
+
|
| 8 |
+
from config import HF_BACKBONE_REPO, HF_TOKEN
|
| 9 |
+
|
| 10 |
+
_BACKBONE = None
|
| 11 |
+
_FEATURES_CACHE = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_backbone(device: torch.device) -> nn.Module:
|
| 15 |
+
global _BACKBONE
|
| 16 |
+
|
| 17 |
+
if _BACKBONE is not None:
|
| 18 |
+
return _BACKBONE.to(device)
|
| 19 |
+
|
| 20 |
+
if not HF_BACKBONE_REPO:
|
| 21 |
+
raise RuntimeError(
|
| 22 |
+
"HF_BACKBONE_REPO n'est pas configuré. "
|
| 23 |
+
"Ajoutez-le dans les Secrets du Space Hugging Face."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
pt_path = hf_hub_download(
|
| 27 |
+
repo_id=HF_BACKBONE_REPO,
|
| 28 |
+
filename="resnet18_charcoal_backbone.pt",
|
| 29 |
+
token=HF_TOKEN,
|
| 30 |
+
repo_type="model",
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
backbone = models.resnet18()
|
| 34 |
+
backbone.fc = nn.Identity()
|
| 35 |
+
backbone.load_state_dict(torch.load(pt_path, map_location="cpu"))
|
| 36 |
+
|
| 37 |
+
for p in backbone.parameters():
|
| 38 |
+
p.requires_grad = False
|
| 39 |
+
|
| 40 |
+
_BACKBONE = backbone
|
| 41 |
+
return _BACKBONE.to(device)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def extract_all_features(batch_size: int = 64):
|
| 45 |
+
global _FEATURES_CACHE
|
| 46 |
+
|
| 47 |
+
from data_utils import prepare_splits, get_class_names, HFDatasetWrapper, get_eval_transform
|
| 48 |
+
|
| 49 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 50 |
+
backbone = load_backbone(device)
|
| 51 |
+
backbone.eval()
|
| 52 |
+
|
| 53 |
+
splits = prepare_splits()
|
| 54 |
+
class_names = get_class_names()
|
| 55 |
+
|
| 56 |
+
cache = {}
|
| 57 |
+
counts = {}
|
| 58 |
+
|
| 59 |
+
for split_name, split_data in splits.items():
|
| 60 |
+
dataset = HFDatasetWrapper(split_data, get_eval_transform())
|
| 61 |
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
| 62 |
+
|
| 63 |
+
X_parts, y_parts = [], []
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
for images, labels in loader:
|
| 66 |
+
features = backbone(images.to(device))
|
| 67 |
+
X_parts.append(features.cpu().numpy())
|
| 68 |
+
y_parts.append(labels.numpy())
|
| 69 |
+
|
| 70 |
+
cache[split_name] = {
|
| 71 |
+
"X": np.concatenate(X_parts, axis=0),
|
| 72 |
+
"y": np.concatenate(y_parts, axis=0),
|
| 73 |
+
}
|
| 74 |
+
counts[split_name] = len(cache[split_name]["y"])
|
| 75 |
+
|
| 76 |
+
_FEATURES_CACHE = cache
|
| 77 |
+
return cache, class_names, counts
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_cached_features():
|
| 81 |
+
return _FEATURES_CACHE
|
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import joblib
|
| 7 |
+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
|
| 8 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 9 |
+
from sklearn.linear_model import LogisticRegression
|
| 10 |
+
from sklearn.neighbors import KNeighborsClassifier
|
| 11 |
+
from sklearn.pipeline import Pipeline
|
| 12 |
+
from sklearn.preprocessing import StandardScaler
|
| 13 |
+
from sklearn.svm import SVC
|
| 14 |
+
|
| 15 |
+
from config import MODEL_DIR, META_DIR
|
| 16 |
+
from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
|
| 17 |
+
|
| 18 |
+
CLF_TYPE_MAP = {
|
| 19 |
+
"SVM": "svm",
|
| 20 |
+
"Régression logistique": "logreg",
|
| 21 |
+
"k-NN": "knn",
|
| 22 |
+
"Forêt aléatoire": "rf",
|
| 23 |
+
"LDA": "lda",
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def classifier_path(model_name: str) -> str:
|
| 28 |
+
return os.path.join(MODEL_DIR, f"{model_name}.joblib")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def meta_path(model_name: str) -> str:
|
| 32 |
+
return os.path.join(META_DIR, f"{model_name}.json")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def build_pipeline(clf_type: str, **params) -> Pipeline:
|
| 36 |
+
key = CLF_TYPE_MAP.get(clf_type, clf_type)
|
| 37 |
+
|
| 38 |
+
if key == "svm":
|
| 39 |
+
clf = SVC(
|
| 40 |
+
C=params.get("C", 1.0),
|
| 41 |
+
kernel=params.get("kernel", "rbf"),
|
| 42 |
+
gamma=params.get("gamma", "scale"),
|
| 43 |
+
probability=True,
|
| 44 |
+
random_state=42,
|
| 45 |
+
)
|
| 46 |
+
elif key == "logreg":
|
| 47 |
+
clf = LogisticRegression(
|
| 48 |
+
C=params.get("C", 1.0),
|
| 49 |
+
max_iter=params.get("max_iter", 1000),
|
| 50 |
+
random_state=42,
|
| 51 |
+
)
|
| 52 |
+
elif key == "knn":
|
| 53 |
+
clf = KNeighborsClassifier(
|
| 54 |
+
n_neighbors=params.get("n_neighbors", 5),
|
| 55 |
+
metric=params.get("metric", "euclidean"),
|
| 56 |
+
)
|
| 57 |
+
elif key == "rf":
|
| 58 |
+
max_depth = params.get("max_depth") or None
|
| 59 |
+
clf = RandomForestClassifier(
|
| 60 |
+
n_estimators=params.get("n_estimators", 100),
|
| 61 |
+
max_depth=max_depth,
|
| 62 |
+
random_state=42,
|
| 63 |
+
n_jobs=-1,
|
| 64 |
+
)
|
| 65 |
+
elif key == "lda":
|
| 66 |
+
clf = LinearDiscriminantAnalysis(solver=params.get("solver", "svd"))
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"Classifieur inconnu : {clf_type}")
|
| 69 |
+
|
| 70 |
+
return Pipeline([("scaler", StandardScaler()), ("clf", clf)])
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def train_classical_model(
|
| 74 |
+
clf_type: str,
|
| 75 |
+
features_cache: dict,
|
| 76 |
+
class_names: List[str],
|
| 77 |
+
model_tag: str = "",
|
| 78 |
+
**params,
|
| 79 |
+
):
|
| 80 |
+
X_train = features_cache["train"]["X"]
|
| 81 |
+
y_train = features_cache["train"]["y"]
|
| 82 |
+
X_test = features_cache["test"]["X"]
|
| 83 |
+
y_test = features_cache["test"]["y"]
|
| 84 |
+
|
| 85 |
+
pipeline = build_pipeline(clf_type, **params)
|
| 86 |
+
pipeline.fit(X_train, y_train)
|
| 87 |
+
|
| 88 |
+
y_pred = pipeline.predict(X_test)
|
| 89 |
+
metrics = compute_classification_metrics(y_test.tolist(), y_pred.tolist(), class_names)
|
| 90 |
+
|
| 91 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 92 |
+
safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else CLF_TYPE_MAP.get(clf_type, "clf")
|
| 93 |
+
model_name = f"{safe_tag}_{timestamp}"
|
| 94 |
+
|
| 95 |
+
joblib.dump(pipeline, classifier_path(model_name))
|
| 96 |
+
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 97 |
+
|
| 98 |
+
config_dict = {
|
| 99 |
+
"model_type": CLF_TYPE_MAP.get(clf_type, clf_type),
|
| 100 |
+
"clf_type_label": clf_type,
|
| 101 |
+
"class_names": class_names,
|
| 102 |
+
"num_classes": len(class_names),
|
| 103 |
+
**{k: v for k, v in params.items() if v is not None},
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
training_summary = {
|
| 107 |
+
"test_accuracy": metrics["accuracy"],
|
| 108 |
+
"test_f1_macro": metrics["f1_macro"],
|
| 109 |
+
"test_f1_weighted": metrics["f1_weighted"],
|
| 110 |
+
"train_samples": int(len(X_train)),
|
| 111 |
+
"test_samples": int(len(X_test)),
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
with open(meta_path(model_name), "w", encoding="utf-8") as f:
|
| 115 |
+
json.dump(
|
| 116 |
+
{
|
| 117 |
+
"model_name": model_name,
|
| 118 |
+
"config": config_dict,
|
| 119 |
+
"training_summary": training_summary,
|
| 120 |
+
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 121 |
+
},
|
| 122 |
+
f,
|
| 123 |
+
indent=2,
|
| 124 |
+
ensure_ascii=False,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return {
|
| 128 |
+
"model_name": model_name,
|
| 129 |
+
"summary": training_summary,
|
| 130 |
+
"classification_report": metrics["classification_report"],
|
| 131 |
+
"confusion_matrix": metrics["confusion_matrix"],
|
| 132 |
+
"confusion_matrix_path": cm_path,
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_classical_pipeline(model_name: str) -> Pipeline:
|
| 137 |
+
path = classifier_path(model_name)
|
| 138 |
+
if not os.path.exists(path):
|
| 139 |
+
raise FileNotFoundError(f"Classifieur introuvable : {model_name}")
|
| 140 |
+
return joblib.load(path)
|
|
@@ -10,13 +10,13 @@ os.makedirs(MODEL_DIR, exist_ok=True)
|
|
| 10 |
os.makedirs(META_DIR, exist_ok=True)
|
| 11 |
os.makedirs(FIGURE_DIR, exist_ok=True)
|
| 12 |
|
| 13 |
-
# Replace this with your real private dataset repo
|
| 14 |
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "CircleStar/charcoal-microscopy")
|
| 15 |
-
|
| 16 |
-
# Must be added in Hugging Face Space Settings → Secrets
|
| 17 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 18 |
|
| 19 |
IMAGE_SIZE = 224
|
| 20 |
RANDOM_SEED = 42
|
| 21 |
|
| 22 |
-
DATASET_DISPLAY_NAME = "Images microscopiques de charbons de bois"
|
|
|
|
|
|
|
|
|
| 10 |
os.makedirs(META_DIR, exist_ok=True)
|
| 11 |
os.makedirs(FIGURE_DIR, exist_ok=True)
|
| 12 |
|
|
|
|
| 13 |
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "CircleStar/charcoal-microscopy")
|
| 14 |
+
HF_BACKBONE_REPO = os.environ.get("HF_BACKBONE_REPO", "")
|
|
|
|
| 15 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 16 |
|
| 17 |
IMAGE_SIZE = 224
|
| 18 |
RANDOM_SEED = 42
|
| 19 |
|
| 20 |
+
DATASET_DISPLAY_NAME = "Images microscopiques de charbons de bois"
|
| 21 |
+
|
| 22 |
+
CLASSICAL_MODEL_TYPES = frozenset({"svm", "logreg", "knn", "rf", "lda"})
|
|
@@ -24,7 +24,7 @@ class HFDatasetWrapper(Dataset):
|
|
| 24 |
|
| 25 |
def __len__(self):
|
| 26 |
return len(self.dataset)
|
| 27 |
-
|
| 28 |
def __getitem__(self, idx):
|
| 29 |
item = self.dataset[idx]
|
| 30 |
|
|
|
|
| 24 |
|
| 25 |
def __len__(self):
|
| 26 |
return len(self.dataset)
|
| 27 |
+
·
|
| 28 |
def __getitem__(self, idx):
|
| 29 |
item = self.dataset[idx]
|
| 30 |
|
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
finetune_backbone.py
|
| 3 |
+
|
| 4 |
+
Fine-tune ResNet18 (ImageNet) on the local charcoal microscopy dataset.
|
| 5 |
+
Goal: produce a domain-adapted backbone for students to use as a frozen
|
| 6 |
+
feature extractor. The full dataset is used intentionally — this is a
|
| 7 |
+
teaching artifact, not a research model with a held-out test split.
|
| 8 |
+
|
| 9 |
+
Output (in backbone/):
|
| 10 |
+
resnet18_charcoal_backbone.pt — backbone weights, FC replaced by Identity
|
| 11 |
+
backbone_meta.json — class names, feature dim, training info
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python finetune_backbone.py
|
| 15 |
+
python finetune_backbone.py --epochs 40 --batch-size 16
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import json
|
| 20 |
+
import time
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.optim as optim
|
| 26 |
+
from PIL import Image
|
| 27 |
+
from torch.utils.data import DataLoader, Dataset
|
| 28 |
+
from torchvision import models, transforms
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Paths
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
ROOT = Path(__file__).parent
|
| 34 |
+
DATA_DIR = ROOT / "data"
|
| 35 |
+
OUTPUT_DIR = ROOT / "backbone"
|
| 36 |
+
OUTPUT_DIR.mkdir(exist_ok=True)
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Defaults
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
IMAGE_SIZE = 224
|
| 42 |
+
SEED = 42
|
| 43 |
+
|
| 44 |
+
WARMUP_EPOCHS = 10 # backbone frozen, only FC trained
|
| 45 |
+
WARMUP_LR = 1e-3
|
| 46 |
+
|
| 47 |
+
FINETUNE_EPOCHS = 40 # all layers unfrozen, small LR
|
| 48 |
+
FINETUNE_LR = 5e-5
|
| 49 |
+
WEIGHT_DECAY = 1e-4
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Dataset
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
class CharcoalDataset(Dataset):
|
| 56 |
+
"""Flat ImageFolder-style dataset that handles .tif files."""
|
| 57 |
+
|
| 58 |
+
EXTENSIONS = {".tif", ".tiff", ".jpg", ".jpeg", ".png"}
|
| 59 |
+
|
| 60 |
+
def __init__(self, root: Path, transform=None):
|
| 61 |
+
self.transform = transform
|
| 62 |
+
self.classes = sorted(
|
| 63 |
+
d.name for d in root.iterdir()
|
| 64 |
+
if d.is_dir() and not d.name.startswith(".")
|
| 65 |
+
)
|
| 66 |
+
self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
|
| 67 |
+
|
| 68 |
+
self.samples = []
|
| 69 |
+
for cls in self.classes:
|
| 70 |
+
for p in sorted((root / cls).iterdir()):
|
| 71 |
+
if p.suffix.lower() in self.EXTENSIONS:
|
| 72 |
+
self.samples.append((p, self.class_to_idx[cls]))
|
| 73 |
+
|
| 74 |
+
def __len__(self):
|
| 75 |
+
return len(self.samples)
|
| 76 |
+
|
| 77 |
+
def __getitem__(self, idx):
|
| 78 |
+
path, label = self.samples[idx]
|
| 79 |
+
image = Image.open(path).convert("RGB")
|
| 80 |
+
if self.transform:
|
| 81 |
+
image = self.transform(image)
|
| 82 |
+
return image, label
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def make_transform():
|
| 86 |
+
# Aggressive augmentation: microscopy images have no canonical orientation
|
| 87 |
+
# and vary in staining intensity.
|
| 88 |
+
return transforms.Compose([
|
| 89 |
+
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 90 |
+
transforms.RandomHorizontalFlip(),
|
| 91 |
+
transforms.RandomVerticalFlip(),
|
| 92 |
+
transforms.RandomRotation(180),
|
| 93 |
+
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
|
| 94 |
+
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.85, 1.15)),
|
| 95 |
+
transforms.ToTensor(),
|
| 96 |
+
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
| 97 |
+
])
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
# Training helpers
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
def run_epoch(model, loader, criterion, optimizer, device):
|
| 104 |
+
model.train()
|
| 105 |
+
total_loss, correct, total = 0.0, 0, 0
|
| 106 |
+
|
| 107 |
+
for images, labels in loader:
|
| 108 |
+
images, labels = images.to(device), labels.to(device)
|
| 109 |
+
|
| 110 |
+
optimizer.zero_grad()
|
| 111 |
+
outputs = model(images)
|
| 112 |
+
loss = criterion(outputs, labels)
|
| 113 |
+
loss.backward()
|
| 114 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 115 |
+
optimizer.step()
|
| 116 |
+
|
| 117 |
+
total_loss += loss.item() * images.size(0)
|
| 118 |
+
correct += (outputs.argmax(1) == labels).sum().item()
|
| 119 |
+
total += labels.size(0)
|
| 120 |
+
|
| 121 |
+
return total_loss / total, correct / total
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# Main
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
def main():
|
| 128 |
+
parser = argparse.ArgumentParser()
|
| 129 |
+
parser.add_argument("--warmup-epochs", type=int, default=WARMUP_EPOCHS)
|
| 130 |
+
parser.add_argument("--finetune-epochs", type=int, default=FINETUNE_EPOCHS)
|
| 131 |
+
parser.add_argument("--batch-size", type=int, default=8)
|
| 132 |
+
parser.add_argument("--warmup-lr", type=float, default=WARMUP_LR)
|
| 133 |
+
parser.add_argument("--finetune-lr", type=float, default=FINETUNE_LR)
|
| 134 |
+
args = parser.parse_args()
|
| 135 |
+
|
| 136 |
+
torch.manual_seed(SEED)
|
| 137 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 138 |
+
print(f"Device : {device}")
|
| 139 |
+
|
| 140 |
+
dataset = CharcoalDataset(DATA_DIR, transform=make_transform())
|
| 141 |
+
num_classes = len(dataset.classes)
|
| 142 |
+
print(f"Classes : {num_classes} | Images : {len(dataset)}")
|
| 143 |
+
print(f" {', '.join(dataset.classes)}\n")
|
| 144 |
+
|
| 145 |
+
loader = DataLoader(
|
| 146 |
+
dataset,
|
| 147 |
+
batch_size=args.batch_size,
|
| 148 |
+
shuffle=True,
|
| 149 |
+
num_workers=0, # 0 = safe on Windows
|
| 150 |
+
pin_memory=(device.type == "cuda"),
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# -----------------------------------------------------------------------
|
| 154 |
+
# Build model
|
| 155 |
+
# -----------------------------------------------------------------------
|
| 156 |
+
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
|
| 157 |
+
model.fc = nn.Linear(model.fc.in_features, num_classes)
|
| 158 |
+
model.to(device)
|
| 159 |
+
|
| 160 |
+
# Label smoothing helps regularise with tiny datasets
|
| 161 |
+
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
|
| 162 |
+
|
| 163 |
+
# -----------------------------------------------------------------------
|
| 164 |
+
# Phase 1 — warm-up: freeze backbone, train FC only
|
| 165 |
+
# -----------------------------------------------------------------------
|
| 166 |
+
print(f"=== Phase 1 : warm-up ({args.warmup_epochs} epochs, backbone frozen) ===")
|
| 167 |
+
for p in model.parameters():
|
| 168 |
+
p.requires_grad = False
|
| 169 |
+
for p in model.fc.parameters():
|
| 170 |
+
p.requires_grad = True
|
| 171 |
+
|
| 172 |
+
optimizer = optim.AdamW(model.fc.parameters(), lr=args.warmup_lr, weight_decay=WEIGHT_DECAY)
|
| 173 |
+
|
| 174 |
+
for epoch in range(1, args.warmup_epochs + 1):
|
| 175 |
+
loss, acc = run_epoch(model, loader, criterion, optimizer, device)
|
| 176 |
+
print(f" [{epoch:>3}/{args.warmup_epochs}] loss={loss:.4f} acc={acc:.4f}")
|
| 177 |
+
|
| 178 |
+
# -----------------------------------------------------------------------
|
| 179 |
+
# Phase 2 — full fine-tune: unfreeze all layers
|
| 180 |
+
# -----------------------------------------------------------------------
|
| 181 |
+
print(f"\n=== Phase 2 : fine-tune ({args.finetune_epochs} epochs, all layers) ===")
|
| 182 |
+
for p in model.parameters():
|
| 183 |
+
p.requires_grad = True
|
| 184 |
+
|
| 185 |
+
optimizer = optim.AdamW(
|
| 186 |
+
model.parameters(), lr=args.finetune_lr, weight_decay=WEIGHT_DECAY
|
| 187 |
+
)
|
| 188 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
| 189 |
+
optimizer, T_max=args.finetune_epochs, eta_min=args.finetune_lr * 0.05
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
best_acc = 0.0
|
| 193 |
+
best_state = None
|
| 194 |
+
t0 = time.time()
|
| 195 |
+
|
| 196 |
+
for epoch in range(1, args.finetune_epochs + 1):
|
| 197 |
+
loss, acc = run_epoch(model, loader, criterion, optimizer, device)
|
| 198 |
+
scheduler.step()
|
| 199 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 200 |
+
print(f" [{epoch:>3}/{args.finetune_epochs}] loss={loss:.4f} acc={acc:.4f} lr={lr:.2e}")
|
| 201 |
+
|
| 202 |
+
if acc > best_acc:
|
| 203 |
+
best_acc = acc
|
| 204 |
+
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
| 205 |
+
|
| 206 |
+
elapsed = time.time() - t0
|
| 207 |
+
print(f"\nTemps phase 2 : {elapsed:.0f}s | Meilleure accuracy entraînement : {best_acc:.4f}")
|
| 208 |
+
|
| 209 |
+
# -----------------------------------------------------------------------
|
| 210 |
+
# Save backbone (FC replaced by Identity — outputs 512-dim feature vector)
|
| 211 |
+
# -----------------------------------------------------------------------
|
| 212 |
+
model.load_state_dict(best_state)
|
| 213 |
+
|
| 214 |
+
backbone = models.resnet18()
|
| 215 |
+
backbone.fc = nn.Identity()
|
| 216 |
+
|
| 217 |
+
# Transfer all weights except fc (which is now Identity with no parameters)
|
| 218 |
+
backbone_state = {k: v for k, v in best_state.items() if not k.startswith("fc.")}
|
| 219 |
+
backbone.load_state_dict(backbone_state, strict=False)
|
| 220 |
+
|
| 221 |
+
backbone_path = OUTPUT_DIR / "resnet18_charcoal_backbone.pt"
|
| 222 |
+
torch.save(backbone.state_dict(), backbone_path)
|
| 223 |
+
print(f"Backbone sauvegardé : {backbone_path}")
|
| 224 |
+
|
| 225 |
+
# -----------------------------------------------------------------------
|
| 226 |
+
# Save metadata
|
| 227 |
+
# -----------------------------------------------------------------------
|
| 228 |
+
meta = {
|
| 229 |
+
"classes": dataset.classes,
|
| 230 |
+
"num_classes": num_classes,
|
| 231 |
+
"image_size": IMAGE_SIZE,
|
| 232 |
+
"feature_dim": 512,
|
| 233 |
+
"warmup_epochs": args.warmup_epochs,
|
| 234 |
+
"finetune_epochs": args.finetune_epochs,
|
| 235 |
+
"best_train_acc": round(float(best_acc), 4),
|
| 236 |
+
"device": str(device),
|
| 237 |
+
}
|
| 238 |
+
meta_path = OUTPUT_DIR / "backbone_meta.json"
|
| 239 |
+
with open(meta_path, "w", encoding="utf-8") as f:
|
| 240 |
+
json.dump(meta, f, indent=2, ensure_ascii=False)
|
| 241 |
+
print(f"Métadonnées sauvegardées : {meta_path}")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
|
| 245 |
+
main()
|
|
@@ -1,33 +1,22 @@
|
|
| 1 |
import torch.nn as nn
|
| 2 |
-
from torchvision import models
|
| 3 |
|
| 4 |
|
| 5 |
-
class
|
| 6 |
-
|
| 7 |
-
super().__init__()
|
| 8 |
-
|
| 9 |
-
weights = models.ResNet18_Weights.DEFAULT
|
| 10 |
-
self.backbone = models.resnet18(weights=weights)
|
| 11 |
-
in_features = self.backbone.fc.in_features
|
| 12 |
-
|
| 13 |
-
# Gel de tout le réseau sauf layer4 et classifieur
|
| 14 |
-
for param in self.backbone.parameters():
|
| 15 |
-
param.requires_grad = False
|
| 16 |
-
for param in self.backbone.layer4.parameters():
|
| 17 |
-
param.requires_grad = True
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
nn.Dropout(dropout),
|
| 21 |
-
nn.Linear(
|
| 22 |
-
nn.ReLU(),
|
| 23 |
nn.Dropout(dropout),
|
| 24 |
nn.Linear(fc_dim, num_classes),
|
| 25 |
)
|
| 26 |
-
for param in self.backbone.fc.parameters():
|
| 27 |
-
param.requires_grad = True
|
| 28 |
|
| 29 |
def forward(self, x):
|
| 30 |
-
return self.backbone(x)
|
| 31 |
|
| 32 |
|
| 33 |
class SimpleCNN(nn.Module):
|
|
@@ -48,7 +37,6 @@ class SimpleCNN(nn.Module):
|
|
| 48 |
in_channels = 3
|
| 49 |
|
| 50 |
for i in range(num_conv_blocks):
|
| 51 |
-
# Les filtres doublent à chaque bloc, plafonnés à 512
|
| 52 |
out_channels = min(base_filters * (2 ** i), 512)
|
| 53 |
layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
|
| 54 |
if use_batchnorm:
|
|
@@ -58,7 +46,6 @@ class SimpleCNN(nn.Module):
|
|
| 58 |
in_channels = out_channels
|
| 59 |
|
| 60 |
self.features = nn.Sequential(*layers)
|
| 61 |
-
# Pooling global : indépendant de la taille spatiale d'entrée
|
| 62 |
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 63 |
|
| 64 |
self.classifier = nn.Sequential(
|
|
@@ -70,7 +57,5 @@ class SimpleCNN(nn.Module):
|
|
| 70 |
)
|
| 71 |
|
| 72 |
def forward(self, x):
|
| 73 |
-
x = self.features(x)
|
| 74 |
-
|
| 75 |
-
x = x.flatten(1)
|
| 76 |
-
return self.classifier(x)
|
|
|
|
| 1 |
import torch.nn as nn
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
+
class BackboneWithFC(nn.Module):
|
| 5 |
+
"""Frozen ResNet18 backbone + trainable FC classifier head."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
def __init__(self, backbone: nn.Module, num_classes: int, dropout: float = 0.4, fc_dim: int = 256):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.backbone = backbone
|
| 10 |
+
self.classifier = nn.Sequential(
|
| 11 |
nn.Dropout(dropout),
|
| 12 |
+
nn.Linear(512, fc_dim),
|
| 13 |
+
nn.ReLU(inplace=True),
|
| 14 |
nn.Dropout(dropout),
|
| 15 |
nn.Linear(fc_dim, num_classes),
|
| 16 |
)
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def forward(self, x):
|
| 19 |
+
return self.classifier(self.backbone(x))
|
| 20 |
|
| 21 |
|
| 22 |
class SimpleCNN(nn.Module):
|
|
|
|
| 37 |
in_channels = 3
|
| 38 |
|
| 39 |
for i in range(num_conv_blocks):
|
|
|
|
| 40 |
out_channels = min(base_filters * (2 ** i), 512)
|
| 41 |
layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
|
| 42 |
if use_batchnorm:
|
|
|
|
| 46 |
in_channels = out_channels
|
| 47 |
|
| 48 |
self.features = nn.Sequential(*layers)
|
|
|
|
| 49 |
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 50 |
|
| 51 |
self.classifier = nn.Sequential(
|
|
|
|
| 57 |
)
|
| 58 |
|
| 59 |
def forward(self, x):
|
| 60 |
+
x = self.pool(self.features(x))
|
| 61 |
+
return self.classifier(x.flatten(1))
|
|
|
|
|
|
|
@@ -1,41 +1,56 @@
|
|
| 1 |
import random
|
| 2 |
|
|
|
|
| 3 |
import torch
|
| 4 |
from PIL import Image
|
| 5 |
|
|
|
|
| 6 |
from data_utils import get_eval_transform, prepare_splits, get_class_names
|
| 7 |
-
from train_utils import load_model, get_runtime_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def predict_uploaded_image(model_name: str, image: Image.Image):
|
| 11 |
if not model_name:
|
| 12 |
return "Veuillez sélectionner un modèle.", None
|
| 13 |
-
|
| 14 |
if image is None:
|
| 15 |
return "Veuillez importer une image.", None
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
class_names = meta["config"]["class_names"]
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
image = image.convert("RGB")
|
| 24 |
-
tensor = transform(image).unsqueeze(0).to(device)
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
result_text = (
|
| 32 |
f"Prédiction : {class_names[pred_idx]}\n"
|
| 33 |
f"Confiance : {max(probs):.4f}\n\n"
|
| 34 |
f"Modèle : {model_name}\n"
|
| 35 |
-
f"
|
| 36 |
-
f"Appareil
|
| 37 |
)
|
| 38 |
-
|
| 39 |
prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
|
| 40 |
return result_text, prob_dict
|
| 41 |
|
|
@@ -44,40 +59,44 @@ def test_random_sample(model_name: str):
|
|
| 44 |
if not model_name:
|
| 45 |
return None, "Veuillez sélectionner un modèle.", None
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
device = get_runtime_device()
|
| 48 |
-
model, meta = load_model(model_name, device)
|
| 49 |
|
| 50 |
splits = prepare_splits()
|
| 51 |
-
class_names = get_class_names()
|
| 52 |
test_dataset = splits["test"]
|
| 53 |
|
| 54 |
idx = random.randint(0, len(test_dataset) - 1)
|
| 55 |
item = test_dataset[idx]
|
| 56 |
-
|
| 57 |
image = item["image"]
|
| 58 |
if not isinstance(image, Image.Image):
|
| 59 |
image = Image.open(image)
|
| 60 |
-
|
| 61 |
image = image.convert("RGB")
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
result_text = (
|
| 75 |
f"Échantillon test aléatoire\n"
|
| 76 |
f"Vérité terrain : {label_name}\n"
|
| 77 |
-
f"Prédiction
|
| 78 |
-
f"Confiance
|
| 79 |
-
f"
|
|
|
|
| 80 |
)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
return image, result_text, prob_dict
|
|
|
|
| 1 |
import random
|
| 2 |
|
| 3 |
+
import numpy as np
|
| 4 |
import torch
|
| 5 |
from PIL import Image
|
| 6 |
|
| 7 |
+
from config import CLASSICAL_MODEL_TYPES
|
| 8 |
from data_utils import get_eval_transform, prepare_splits, get_class_names
|
| 9 |
+
from train_utils import load_model, get_runtime_device, _load_meta
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _extract_feature(image: Image.Image, device: torch.device) -> np.ndarray:
|
| 13 |
+
from backbone_utils import load_backbone
|
| 14 |
+
backbone = load_backbone(device)
|
| 15 |
+
backbone.eval()
|
| 16 |
+
tensor = get_eval_transform()(image.convert("RGB")).unsqueeze(0).to(device)
|
| 17 |
+
with torch.no_grad():
|
| 18 |
+
feat = backbone(tensor)
|
| 19 |
+
return feat.cpu().numpy()
|
| 20 |
|
| 21 |
|
| 22 |
def predict_uploaded_image(model_name: str, image: Image.Image):
|
| 23 |
if not model_name:
|
| 24 |
return "Veuillez sélectionner un modèle.", None
|
|
|
|
| 25 |
if image is None:
|
| 26 |
return "Veuillez importer une image.", None
|
| 27 |
|
| 28 |
+
meta = _load_meta(model_name)
|
| 29 |
+
model_type = meta["config"].get("model_type", "cnn")
|
|
|
|
| 30 |
class_names = meta["config"]["class_names"]
|
| 31 |
+
device = get_runtime_device()
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
if model_type in CLASSICAL_MODEL_TYPES:
|
| 34 |
+
from classical_ml_utils import load_classical_pipeline
|
| 35 |
+
pipeline = load_classical_pipeline(model_name)
|
| 36 |
+
feat = _extract_feature(image, device)
|
| 37 |
+
probs = pipeline.predict_proba(feat)[0].tolist()
|
| 38 |
+
pred_idx = int(np.argmax(probs))
|
| 39 |
+
else:
|
| 40 |
+
model, _ = load_model(model_name, device)
|
| 41 |
+
tensor = get_eval_transform()(image.convert("RGB")).unsqueeze(0).to(device)
|
| 42 |
+
with torch.no_grad():
|
| 43 |
+
logits = model(tensor)
|
| 44 |
+
probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
|
| 45 |
+
pred_idx = int(torch.argmax(logits, dim=1).item())
|
| 46 |
|
| 47 |
result_text = (
|
| 48 |
f"Prédiction : {class_names[pred_idx]}\n"
|
| 49 |
f"Confiance : {max(probs):.4f}\n\n"
|
| 50 |
f"Modèle : {model_name}\n"
|
| 51 |
+
f"Type : {model_type}\n"
|
| 52 |
+
f"Appareil : {device}"
|
| 53 |
)
|
|
|
|
| 54 |
prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
|
| 55 |
return result_text, prob_dict
|
| 56 |
|
|
|
|
| 59 |
if not model_name:
|
| 60 |
return None, "Veuillez sélectionner un modèle.", None
|
| 61 |
|
| 62 |
+
meta = _load_meta(model_name)
|
| 63 |
+
model_type = meta["config"].get("model_type", "cnn")
|
| 64 |
+
class_names = get_class_names()
|
| 65 |
device = get_runtime_device()
|
|
|
|
| 66 |
|
| 67 |
splits = prepare_splits()
|
|
|
|
| 68 |
test_dataset = splits["test"]
|
| 69 |
|
| 70 |
idx = random.randint(0, len(test_dataset) - 1)
|
| 71 |
item = test_dataset[idx]
|
|
|
|
| 72 |
image = item["image"]
|
| 73 |
if not isinstance(image, Image.Image):
|
| 74 |
image = Image.open(image)
|
|
|
|
| 75 |
image = image.convert("RGB")
|
| 76 |
+
label_name = class_names[int(item["label"])]
|
| 77 |
+
|
| 78 |
+
if model_type in CLASSICAL_MODEL_TYPES:
|
| 79 |
+
from classical_ml_utils import load_classical_pipeline
|
| 80 |
+
pipeline = load_classical_pipeline(model_name)
|
| 81 |
+
feat = _extract_feature(image, device)
|
| 82 |
+
probs = pipeline.predict_proba(feat)[0].tolist()
|
| 83 |
+
pred_idx = int(np.argmax(probs))
|
| 84 |
+
else:
|
| 85 |
+
model, _ = load_model(model_name, device)
|
| 86 |
+
tensor = get_eval_transform()(image).unsqueeze(0).to(device)
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
logits = model(tensor)
|
| 89 |
+
probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
|
| 90 |
+
pred_idx = int(torch.argmax(logits, dim=1).item())
|
| 91 |
+
|
| 92 |
+
model_class_names = meta["config"]["class_names"]
|
| 93 |
result_text = (
|
| 94 |
f"Échantillon test aléatoire\n"
|
| 95 |
f"Vérité terrain : {label_name}\n"
|
| 96 |
+
f"Prédiction : {model_class_names[pred_idx]}\n"
|
| 97 |
+
f"Confiance : {max(probs):.4f}\n"
|
| 98 |
+
f"Type modèle : {model_type}\n"
|
| 99 |
+
f"Appareil : {device}"
|
| 100 |
)
|
| 101 |
+
prob_dict = {model_class_names[i]: float(probs[i]) for i in range(len(model_class_names))}
|
| 102 |
+
return image, result_text, prob_dict
|
|
|
|
@@ -8,16 +8,24 @@ import torch
|
|
| 8 |
import torch.nn as nn
|
| 9 |
import torch.optim as optim
|
| 10 |
|
| 11 |
-
from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
|
| 12 |
from data_utils import make_loaders
|
| 13 |
from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
|
| 14 |
-
from model import SimpleCNN,
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def model_weight_path(model_name: str) -> str:
|
| 18 |
return os.path.join(MODEL_DIR, f"{model_name}.pt")
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def model_meta_path(model_name: str) -> str:
|
| 22 |
return os.path.join(META_DIR, f"{model_name}.json")
|
| 23 |
|
|
@@ -34,43 +42,54 @@ def get_runtime_device() -> torch.device:
|
|
| 34 |
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
"model_name": model_name,
|
| 43 |
-
"config": config,
|
| 44 |
-
"training_summary": training_summary,
|
| 45 |
-
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 46 |
-
}
|
| 47 |
|
| 48 |
with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
|
| 49 |
-
json.dump(
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
if not os.path.exists(meta_file):
|
| 57 |
-
raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
with open(meta_file, "r", encoding="utf-8") as f:
|
| 63 |
-
meta = json.load(f)
|
| 64 |
|
|
|
|
|
|
|
| 65 |
cfg = meta["config"]
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
-
|
|
|
|
| 74 |
model = SimpleCNN(
|
| 75 |
num_classes=cfg["num_classes"],
|
| 76 |
num_conv_blocks=cfg.get("num_conv_blocks", 3),
|
|
@@ -80,206 +99,264 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
|
|
| 80 |
dropout=cfg.get("dropout", 0.4),
|
| 81 |
fc_dim=cfg.get("fc_dim", 256),
|
| 82 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
state_dict = torch.load(weight_file, map_location="cpu")
|
| 85 |
-
model.load_state_dict(state_dict)
|
| 86 |
model.to(device)
|
| 87 |
model.eval()
|
| 88 |
-
|
| 89 |
return model, meta
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def evaluate_loss_acc(model, loader, criterion, device):
|
| 93 |
model.eval()
|
| 94 |
-
|
| 95 |
-
total_loss = 0.0
|
| 96 |
-
total = 0
|
| 97 |
-
correct = 0
|
| 98 |
|
| 99 |
with torch.no_grad():
|
| 100 |
for images, labels in loader:
|
| 101 |
images, labels = images.to(device), labels.to(device)
|
| 102 |
-
|
| 103 |
outputs = model(images)
|
| 104 |
loss = criterion(outputs, labels)
|
| 105 |
-
|
| 106 |
total_loss += loss.item() * images.size(0)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
correct += (preds == labels).sum().item()
|
| 110 |
total += labels.size(0)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
acc = correct / total if total else 0.0
|
| 114 |
-
|
| 115 |
-
return avg_loss, acc
|
| 116 |
|
| 117 |
|
| 118 |
def collect_predictions(model, loader, device):
|
| 119 |
model.eval()
|
| 120 |
-
|
| 121 |
-
y_true = []
|
| 122 |
-
y_pred = []
|
| 123 |
|
| 124 |
with torch.no_grad():
|
| 125 |
for images, labels in loader:
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
outputs = model(images)
|
| 129 |
-
preds = outputs.argmax(dim=1).detach().cpu().tolist()
|
| 130 |
-
|
| 131 |
-
y_pred.extend(preds)
|
| 132 |
y_true.extend(labels.tolist())
|
| 133 |
|
| 134 |
return y_true, y_pred
|
| 135 |
|
| 136 |
|
| 137 |
-
def
|
| 138 |
-
model_type: str = "cnn",
|
| 139 |
-
num_conv_blocks: int = 3,
|
| 140 |
-
base_filters: int = 32,
|
| 141 |
-
kernel_size: int = 3,
|
| 142 |
-
use_batchnorm: bool = True,
|
| 143 |
-
dropout: float = 0.4,
|
| 144 |
-
fc_dim: int = 256,
|
| 145 |
-
learning_rate: float = 0.001,
|
| 146 |
-
weight_decay: float = 0.0001,
|
| 147 |
-
batch_size: int = 16,
|
| 148 |
-
epochs: int = 30,
|
| 149 |
-
model_tag: str = "",
|
| 150 |
-
):
|
| 151 |
-
device = get_runtime_device()
|
| 152 |
-
|
| 153 |
-
train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
|
| 154 |
-
num_classes = len(class_names)
|
| 155 |
-
|
| 156 |
-
if model_type == "resnet18":
|
| 157 |
-
model = ResNet18Classifier(
|
| 158 |
-
num_classes=num_classes,
|
| 159 |
-
dropout=dropout,
|
| 160 |
-
fc_dim=fc_dim,
|
| 161 |
-
).to(device)
|
| 162 |
-
else:
|
| 163 |
-
model = SimpleCNN(
|
| 164 |
-
num_classes=num_classes,
|
| 165 |
-
num_conv_blocks=num_conv_blocks,
|
| 166 |
-
base_filters=base_filters,
|
| 167 |
-
kernel_size=kernel_size,
|
| 168 |
-
use_batchnorm=use_batchnorm,
|
| 169 |
-
dropout=dropout,
|
| 170 |
-
fc_dim=fc_dim,
|
| 171 |
-
).to(device)
|
| 172 |
-
|
| 173 |
-
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 174 |
-
total_params = sum(p.numel() for p in model.parameters())
|
| 175 |
-
|
| 176 |
-
criterion = nn.CrossEntropyLoss()
|
| 177 |
-
|
| 178 |
-
optimizer = optim.AdamW(
|
| 179 |
-
filter(lambda p: p.requires_grad, model.parameters()),
|
| 180 |
-
lr=learning_rate,
|
| 181 |
-
weight_decay=weight_decay,
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
# Réduit le LR de moitié si val_loss ne s'améliore pas pendant 8 époques
|
| 185 |
-
# patience élevée car le val set est très petit (bruit important)
|
| 186 |
-
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 187 |
-
optimizer,
|
| 188 |
-
mode="min",
|
| 189 |
-
factor=0.5,
|
| 190 |
-
patience=8,
|
| 191 |
-
min_lr=learning_rate * 0.2,
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
history = []
|
| 195 |
logs = []
|
| 196 |
-
start_time = time.time()
|
| 197 |
-
|
| 198 |
best_val_loss = float("inf")
|
| 199 |
-
|
| 200 |
|
| 201 |
for epoch in range(1, epochs + 1):
|
| 202 |
model.train()
|
| 203 |
-
|
| 204 |
-
running_loss = 0.0
|
| 205 |
-
total = 0
|
| 206 |
-
correct = 0
|
| 207 |
|
| 208 |
for images, labels in train_loader:
|
| 209 |
images, labels = images.to(device), labels.to(device)
|
| 210 |
-
|
| 211 |
optimizer.zero_grad()
|
| 212 |
outputs = model(images)
|
| 213 |
-
|
| 214 |
loss = criterion(outputs, labels)
|
| 215 |
loss.backward()
|
| 216 |
-
|
| 217 |
-
# Important: prevents unstable fine-tuning / exploding gradients
|
| 218 |
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 219 |
-
|
| 220 |
optimizer.step()
|
| 221 |
|
| 222 |
running_loss += loss.item() * images.size(0)
|
| 223 |
-
|
| 224 |
-
preds = outputs.argmax(dim=1)
|
| 225 |
-
correct += (preds == labels).sum().item()
|
| 226 |
total += labels.size(0)
|
| 227 |
|
| 228 |
train_loss = running_loss / total if total else 0.0
|
| 229 |
train_acc = correct / total if total else 0.0
|
| 230 |
-
|
| 231 |
val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
|
| 232 |
scheduler.step(val_loss)
|
| 233 |
current_lr = optimizer.param_groups[0]["lr"]
|
| 234 |
|
| 235 |
if val_loss < best_val_loss:
|
| 236 |
best_val_loss = val_loss
|
| 237 |
-
|
| 238 |
-
k: v.detach().cpu().clone()
|
| 239 |
-
for k, v in model.state_dict().items()
|
| 240 |
-
}
|
| 241 |
|
| 242 |
-
|
| 243 |
"epoch": epoch,
|
| 244 |
"train_loss": round(train_loss, 4),
|
| 245 |
"train_acc": round(train_acc, 4),
|
| 246 |
"val_loss": round(val_loss, 4),
|
| 247 |
"val_acc": round(val_acc, 4),
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
history.append(row)
|
| 251 |
-
|
| 252 |
logs.append(
|
| 253 |
f"Époque {epoch}/{epochs} | "
|
| 254 |
-
f"perte
|
| 255 |
-
f"perte
|
| 256 |
-
f"lr={current_lr:.
|
| 257 |
)
|
| 258 |
|
| 259 |
-
|
| 260 |
-
model.load_state_dict(best_state_dict)
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
|
| 263 |
y_true, y_pred = collect_predictions(model, test_loader, device)
|
| 264 |
-
|
| 265 |
metrics = compute_classification_metrics(y_true, y_pred, class_names)
|
|
|
|
| 266 |
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 270 |
-
safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "
|
| 271 |
model_name = f"{safe_tag}_{timestamp}"
|
| 272 |
|
| 273 |
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 274 |
|
| 275 |
-
|
| 276 |
-
architecture = "ResNet18 pré-entraîné (layer4 + classifieur)"
|
| 277 |
-
else:
|
| 278 |
-
architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}x{kernel_size})"
|
| 279 |
|
| 280 |
config = {
|
| 281 |
"dataset_name": DATASET_DISPLAY_NAME,
|
| 282 |
-
"model_type":
|
| 283 |
"architecture": architecture,
|
| 284 |
"num_classes": num_classes,
|
| 285 |
"class_names": class_names,
|
|
@@ -313,18 +390,16 @@ def train_model(
|
|
| 313 |
|
| 314 |
save_model(model, model_name, config, training_summary)
|
| 315 |
|
| 316 |
-
logs
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
logs.append(f"F1 pondéré test : {metrics['f1_weighted']:.4f}")
|
| 327 |
-
logs.append(f"Temps écoulé : {elapsed:.1f}s")
|
| 328 |
|
| 329 |
return {
|
| 330 |
"logs": "\n".join(logs),
|
|
@@ -337,10 +412,24 @@ def train_model(
|
|
| 337 |
}
|
| 338 |
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
def evaluate_saved_model(model_name: str):
|
| 341 |
if not model_name:
|
| 342 |
raise ValueError("Aucun modèle sélectionné.")
|
| 343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
device = get_runtime_device()
|
| 345 |
model, meta = load_model(model_name, device)
|
| 346 |
|
|
@@ -348,19 +437,51 @@ def evaluate_saved_model(model_name: str):
|
|
| 348 |
_, _, test_loader, class_names = make_loaders(batch_size)
|
| 349 |
|
| 350 |
criterion = nn.CrossEntropyLoss()
|
| 351 |
-
|
| 352 |
test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
|
| 353 |
y_true, y_pred = collect_predictions(model, test_loader, device)
|
| 354 |
|
| 355 |
metrics = compute_classification_metrics(y_true, y_pred, class_names)
|
| 356 |
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 357 |
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import torch.nn as nn
|
| 9 |
import torch.optim as optim
|
| 10 |
|
| 11 |
+
from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME, CLASSICAL_MODEL_TYPES
|
| 12 |
from data_utils import make_loaders
|
| 13 |
from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
|
| 14 |
+
from model import SimpleCNN, BackboneWithFC
|
| 15 |
|
| 16 |
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Path helpers
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
def model_weight_path(model_name: str) -> str:
|
| 22 |
return os.path.join(MODEL_DIR, f"{model_name}.pt")
|
| 23 |
|
| 24 |
|
| 25 |
+
def classifier_weight_path(model_name: str) -> str:
|
| 26 |
+
return os.path.join(MODEL_DIR, f"{model_name}.joblib")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
def model_meta_path(model_name: str) -> str:
|
| 30 |
return os.path.join(META_DIR, f"{model_name}.json")
|
| 31 |
|
|
|
|
| 42 |
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
|
| 44 |
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Save / load
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
|
| 50 |
+
if config["model_type"] == "fc_head":
|
| 51 |
+
state_dict = {k: v.detach().cpu() for k, v in model.classifier.state_dict().items()}
|
| 52 |
+
else:
|
| 53 |
+
state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
|
| 54 |
|
| 55 |
+
torch.save(state_dict, model_weight_path(model_name))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
|
| 58 |
+
json.dump(
|
| 59 |
+
{
|
| 60 |
+
"model_name": model_name,
|
| 61 |
+
"config": config,
|
| 62 |
+
"training_summary": training_summary,
|
| 63 |
+
"created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 64 |
+
},
|
| 65 |
+
f,
|
| 66 |
+
indent=2,
|
| 67 |
+
ensure_ascii=False,
|
| 68 |
+
)
|
| 69 |
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
def _load_meta(model_name: str) -> dict:
|
| 72 |
+
path = model_meta_path(model_name)
|
| 73 |
+
if not os.path.exists(path):
|
| 74 |
+
raise FileNotFoundError(f"Métadonnées introuvables : {model_name}")
|
| 75 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 76 |
+
return json.load(f)
|
| 77 |
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
|
| 80 |
+
meta = _load_meta(model_name)
|
| 81 |
cfg = meta["config"]
|
| 82 |
+
model_type = cfg.get("model_type", "cnn")
|
| 83 |
+
|
| 84 |
+
if model_type == "fc_head":
|
| 85 |
+
from backbone_utils import load_backbone
|
| 86 |
+
backbone = load_backbone(device)
|
| 87 |
+
model = BackboneWithFC(backbone, cfg["num_classes"], cfg.get("dropout", 0.4), cfg.get("fc_dim", 256))
|
| 88 |
+
model.classifier.load_state_dict(
|
| 89 |
+
torch.load(model_weight_path(model_name), map_location="cpu")
|
| 90 |
)
|
| 91 |
+
|
| 92 |
+
elif model_type == "cnn":
|
| 93 |
model = SimpleCNN(
|
| 94 |
num_classes=cfg["num_classes"],
|
| 95 |
num_conv_blocks=cfg.get("num_conv_blocks", 3),
|
|
|
|
| 99 |
dropout=cfg.get("dropout", 0.4),
|
| 100 |
fc_dim=cfg.get("fc_dim", 256),
|
| 101 |
)
|
| 102 |
+
model.load_state_dict(torch.load(model_weight_path(model_name), map_location="cpu"))
|
| 103 |
+
|
| 104 |
+
else:
|
| 105 |
+
raise ValueError(f"load_model n'accepte pas le type '{model_type}'. Utilisez load_classical_pipeline pour les modèles ML classiques.")
|
| 106 |
|
|
|
|
|
|
|
| 107 |
model.to(device)
|
| 108 |
model.eval()
|
|
|
|
| 109 |
return model, meta
|
| 110 |
|
| 111 |
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# Training helpers
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
def evaluate_loss_acc(model, loader, criterion, device):
|
| 117 |
model.eval()
|
| 118 |
+
total_loss, total, correct = 0.0, 0, 0
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
with torch.no_grad():
|
| 121 |
for images, labels in loader:
|
| 122 |
images, labels = images.to(device), labels.to(device)
|
|
|
|
| 123 |
outputs = model(images)
|
| 124 |
loss = criterion(outputs, labels)
|
|
|
|
| 125 |
total_loss += loss.item() * images.size(0)
|
| 126 |
+
correct += (outputs.argmax(1) == labels).sum().item()
|
|
|
|
|
|
|
| 127 |
total += labels.size(0)
|
| 128 |
|
| 129 |
+
return (total_loss / total if total else 0.0), (correct / total if total else 0.0)
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
def collect_predictions(model, loader, device):
|
| 133 |
model.eval()
|
| 134 |
+
y_true, y_pred = [], []
|
|
|
|
|
|
|
| 135 |
|
| 136 |
with torch.no_grad():
|
| 137 |
for images, labels in loader:
|
| 138 |
+
outputs = model(images.to(device))
|
| 139 |
+
y_pred.extend(outputs.argmax(1).detach().cpu().tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
y_true.extend(labels.tolist())
|
| 141 |
|
| 142 |
return y_true, y_pred
|
| 143 |
|
| 144 |
|
| 145 |
+
def _training_loop(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
history = []
|
| 147 |
logs = []
|
|
|
|
|
|
|
| 148 |
best_val_loss = float("inf")
|
| 149 |
+
best_state = None
|
| 150 |
|
| 151 |
for epoch in range(1, epochs + 1):
|
| 152 |
model.train()
|
| 153 |
+
running_loss, total, correct = 0.0, 0, 0
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
for images, labels in train_loader:
|
| 156 |
images, labels = images.to(device), labels.to(device)
|
|
|
|
| 157 |
optimizer.zero_grad()
|
| 158 |
outputs = model(images)
|
|
|
|
| 159 |
loss = criterion(outputs, labels)
|
| 160 |
loss.backward()
|
|
|
|
|
|
|
| 161 |
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
|
|
| 162 |
optimizer.step()
|
| 163 |
|
| 164 |
running_loss += loss.item() * images.size(0)
|
| 165 |
+
correct += (outputs.argmax(1) == labels).sum().item()
|
|
|
|
|
|
|
| 166 |
total += labels.size(0)
|
| 167 |
|
| 168 |
train_loss = running_loss / total if total else 0.0
|
| 169 |
train_acc = correct / total if total else 0.0
|
|
|
|
| 170 |
val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
|
| 171 |
scheduler.step(val_loss)
|
| 172 |
current_lr = optimizer.param_groups[0]["lr"]
|
| 173 |
|
| 174 |
if val_loss < best_val_loss:
|
| 175 |
best_val_loss = val_loss
|
| 176 |
+
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
+
history.append({
|
| 179 |
"epoch": epoch,
|
| 180 |
"train_loss": round(train_loss, 4),
|
| 181 |
"train_acc": round(train_acc, 4),
|
| 182 |
"val_loss": round(val_loss, 4),
|
| 183 |
"val_acc": round(val_acc, 4),
|
| 184 |
+
})
|
|
|
|
|
|
|
|
|
|
| 185 |
logs.append(
|
| 186 |
f"Époque {epoch}/{epochs} | "
|
| 187 |
+
f"perte train={train_loss:.4f} acc train={train_acc:.4f} | "
|
| 188 |
+
f"perte val={val_loss:.4f} acc val={val_acc:.4f} | "
|
| 189 |
+
f"lr={current_lr:.2e}"
|
| 190 |
)
|
| 191 |
|
| 192 |
+
return history, logs, best_state, best_val_loss
|
|
|
|
| 193 |
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Train FC head on frozen backbone
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
def train_fc_head(
|
| 200 |
+
dropout: float = 0.4,
|
| 201 |
+
fc_dim: int = 256,
|
| 202 |
+
learning_rate: float = 1e-4,
|
| 203 |
+
weight_decay: float = 1e-4,
|
| 204 |
+
batch_size: int = 16,
|
| 205 |
+
epochs: int = 20,
|
| 206 |
+
model_tag: str = "",
|
| 207 |
+
):
|
| 208 |
+
from backbone_utils import load_backbone
|
| 209 |
+
|
| 210 |
+
device = get_runtime_device()
|
| 211 |
+
train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
|
| 212 |
+
num_classes = len(class_names)
|
| 213 |
+
|
| 214 |
+
backbone = load_backbone(device)
|
| 215 |
+
|
| 216 |
+
model = BackboneWithFC(backbone, num_classes, dropout, fc_dim).to(device)
|
| 217 |
+
|
| 218 |
+
trainable_params = sum(p.numel() for p in model.classifier.parameters())
|
| 219 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 220 |
+
|
| 221 |
+
criterion = nn.CrossEntropyLoss()
|
| 222 |
+
optimizer = optim.AdamW(model.classifier.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
| 223 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 224 |
+
optimizer, mode="min", factor=0.5, patience=5, min_lr=learning_rate * 0.1
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
t0 = time.time()
|
| 228 |
+
history, logs, best_state, best_val_loss = _training_loop(
|
| 229 |
+
model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
model.load_state_dict(best_state)
|
| 233 |
test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
|
| 234 |
y_true, y_pred = collect_predictions(model, test_loader, device)
|
|
|
|
| 235 |
metrics = compute_classification_metrics(y_true, y_pred, class_names)
|
| 236 |
+
elapsed = time.time() - t0
|
| 237 |
|
| 238 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 239 |
+
safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "fc_head"
|
| 240 |
+
model_name = f"{safe_tag}_{timestamp}"
|
| 241 |
+
|
| 242 |
+
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 243 |
+
|
| 244 |
+
config = {
|
| 245 |
+
"dataset_name": DATASET_DISPLAY_NAME,
|
| 246 |
+
"model_type": "fc_head",
|
| 247 |
+
"architecture": f"ResNet18 backbone (gelé) + FC({fc_dim})",
|
| 248 |
+
"num_classes": num_classes,
|
| 249 |
+
"class_names": class_names,
|
| 250 |
+
"dropout": dropout,
|
| 251 |
+
"fc_dim": fc_dim,
|
| 252 |
+
"learning_rate": learning_rate,
|
| 253 |
+
"weight_decay": weight_decay,
|
| 254 |
+
"batch_size": batch_size,
|
| 255 |
+
"epochs": epochs,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
training_summary = {
|
| 259 |
+
"final_train_loss": history[-1]["train_loss"] if history else None,
|
| 260 |
+
"final_train_acc": history[-1]["train_acc"] if history else None,
|
| 261 |
+
"best_val_loss": round(best_val_loss, 4),
|
| 262 |
+
"final_val_loss": history[-1]["val_loss"] if history else None,
|
| 263 |
+
"final_val_acc": history[-1]["val_acc"] if history else None,
|
| 264 |
+
"test_cross_entropy_loss": round(test_loss, 4),
|
| 265 |
+
"test_accuracy": round(test_acc, 4),
|
| 266 |
+
"test_f1_macro": metrics["f1_macro"],
|
| 267 |
+
"test_f1_weighted": metrics["f1_weighted"],
|
| 268 |
+
"elapsed_seconds": round(elapsed, 2),
|
| 269 |
+
"device": str(device),
|
| 270 |
+
"total_params": total_params,
|
| 271 |
+
"trainable_params": trainable_params,
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
save_model(model, model_name, config, training_summary)
|
| 275 |
+
|
| 276 |
+
logs += [
|
| 277 |
+
"",
|
| 278 |
+
"Entraînement terminé.",
|
| 279 |
+
f"Modèle sauvegardé : {model_name}",
|
| 280 |
+
f"Architecture : {config['architecture']}",
|
| 281 |
+
f"Paramètres entraînables : {trainable_params} / {total_params}",
|
| 282 |
+
f"Perte test : {test_loss:.4f} | Accuracy test : {test_acc:.4f}",
|
| 283 |
+
f"F1 macro : {metrics['f1_macro']:.4f} | F1 pondéré : {metrics['f1_weighted']:.4f}",
|
| 284 |
+
f"Temps : {elapsed:.1f}s | Appareil : {device}",
|
| 285 |
+
]
|
| 286 |
+
|
| 287 |
+
return {
|
| 288 |
+
"logs": "\n".join(logs),
|
| 289 |
+
"history": history,
|
| 290 |
+
"summary": training_summary,
|
| 291 |
+
"model_name": model_name,
|
| 292 |
+
"classification_report": metrics["classification_report"],
|
| 293 |
+
"confusion_matrix": metrics["confusion_matrix"],
|
| 294 |
+
"confusion_matrix_path": cm_path,
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# ---------------------------------------------------------------------------
|
| 299 |
+
# Train SimpleCNN from scratch
|
| 300 |
+
# ---------------------------------------------------------------------------
|
| 301 |
+
|
| 302 |
+
def train_cnn(
|
| 303 |
+
num_conv_blocks: int = 3,
|
| 304 |
+
base_filters: int = 32,
|
| 305 |
+
kernel_size: int = 3,
|
| 306 |
+
use_batchnorm: bool = True,
|
| 307 |
+
dropout: float = 0.4,
|
| 308 |
+
fc_dim: int = 256,
|
| 309 |
+
learning_rate: float = 1e-3,
|
| 310 |
+
weight_decay: float = 1e-4,
|
| 311 |
+
batch_size: int = 16,
|
| 312 |
+
epochs: int = 30,
|
| 313 |
+
model_tag: str = "",
|
| 314 |
+
):
|
| 315 |
+
device = get_runtime_device()
|
| 316 |
+
train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
|
| 317 |
+
num_classes = len(class_names)
|
| 318 |
+
|
| 319 |
+
model = SimpleCNN(
|
| 320 |
+
num_classes=num_classes,
|
| 321 |
+
num_conv_blocks=num_conv_blocks,
|
| 322 |
+
base_filters=base_filters,
|
| 323 |
+
kernel_size=kernel_size,
|
| 324 |
+
use_batchnorm=use_batchnorm,
|
| 325 |
+
dropout=dropout,
|
| 326 |
+
fc_dim=fc_dim,
|
| 327 |
+
).to(device)
|
| 328 |
+
|
| 329 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 330 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 331 |
+
|
| 332 |
+
criterion = nn.CrossEntropyLoss()
|
| 333 |
+
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
| 334 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 335 |
+
optimizer, mode="min", factor=0.5, patience=8, min_lr=learning_rate * 0.2
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
t0 = time.time()
|
| 339 |
+
history, logs, best_state, best_val_loss = _training_loop(
|
| 340 |
+
model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, device
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
model.load_state_dict(best_state)
|
| 344 |
+
test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
|
| 345 |
+
y_true, y_pred = collect_predictions(model, test_loader, device)
|
| 346 |
+
metrics = compute_classification_metrics(y_true, y_pred, class_names)
|
| 347 |
+
elapsed = time.time() - t0
|
| 348 |
|
| 349 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 350 |
+
safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "cnn"
|
| 351 |
model_name = f"{safe_tag}_{timestamp}"
|
| 352 |
|
| 353 |
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 354 |
|
| 355 |
+
architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}×{kernel_size})"
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
config = {
|
| 358 |
"dataset_name": DATASET_DISPLAY_NAME,
|
| 359 |
+
"model_type": "cnn",
|
| 360 |
"architecture": architecture,
|
| 361 |
"num_classes": num_classes,
|
| 362 |
"class_names": class_names,
|
|
|
|
| 390 |
|
| 391 |
save_model(model, model_name, config, training_summary)
|
| 392 |
|
| 393 |
+
logs += [
|
| 394 |
+
"",
|
| 395 |
+
"Entraînement terminé.",
|
| 396 |
+
f"Modèle sauvegardé : {model_name}",
|
| 397 |
+
f"Architecture : {architecture}",
|
| 398 |
+
f"Paramètres : {total_params}",
|
| 399 |
+
f"Perte test : {test_loss:.4f} | Accuracy test : {test_acc:.4f}",
|
| 400 |
+
f"F1 macro : {metrics['f1_macro']:.4f} | F1 pondéré : {metrics['f1_weighted']:.4f}",
|
| 401 |
+
f"Temps : {elapsed:.1f}s | Appareil : {device}",
|
| 402 |
+
]
|
|
|
|
|
|
|
| 403 |
|
| 404 |
return {
|
| 405 |
"logs": "\n".join(logs),
|
|
|
|
| 412 |
}
|
| 413 |
|
| 414 |
|
| 415 |
+
# ---------------------------------------------------------------------------
|
| 416 |
+
# Evaluate any saved model
|
| 417 |
+
# ---------------------------------------------------------------------------
|
| 418 |
+
|
| 419 |
def evaluate_saved_model(model_name: str):
|
| 420 |
if not model_name:
|
| 421 |
raise ValueError("Aucun modèle sélectionné.")
|
| 422 |
|
| 423 |
+
meta = _load_meta(model_name)
|
| 424 |
+
model_type = meta["config"].get("model_type", "cnn")
|
| 425 |
+
|
| 426 |
+
if model_type in CLASSICAL_MODEL_TYPES:
|
| 427 |
+
return _evaluate_classical(model_name, meta)
|
| 428 |
+
else:
|
| 429 |
+
return _evaluate_neural(model_name, meta)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _evaluate_neural(model_name: str, meta: dict):
|
| 433 |
device = get_runtime_device()
|
| 434 |
model, meta = load_model(model_name, device)
|
| 435 |
|
|
|
|
| 437 |
_, _, test_loader, class_names = make_loaders(batch_size)
|
| 438 |
|
| 439 |
criterion = nn.CrossEntropyLoss()
|
|
|
|
| 440 |
test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
|
| 441 |
y_true, y_pred = collect_predictions(model, test_loader, device)
|
| 442 |
|
| 443 |
metrics = compute_classification_metrics(y_true, y_pred, class_names)
|
| 444 |
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 445 |
|
| 446 |
+
return (
|
| 447 |
+
{
|
| 448 |
+
"test_cross_entropy_loss": round(test_loss, 4),
|
| 449 |
+
"test_accuracy": round(test_acc, 4),
|
| 450 |
+
"test_f1_macro": metrics["f1_macro"],
|
| 451 |
+
"test_f1_weighted": metrics["f1_weighted"],
|
| 452 |
+
"device": str(device),
|
| 453 |
+
},
|
| 454 |
+
metrics["classification_report"],
|
| 455 |
+
metrics["confusion_matrix"],
|
| 456 |
+
cm_path,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _evaluate_classical(model_name: str, meta: dict):
|
| 461 |
+
from backbone_utils import get_cached_features, extract_all_features
|
| 462 |
+
from classical_ml_utils import load_classical_pipeline
|
| 463 |
+
|
| 464 |
+
features_cache = get_cached_features()
|
| 465 |
+
if features_cache is None:
|
| 466 |
+
features_cache, _, _ = extract_all_features()
|
| 467 |
|
| 468 |
+
class_names = meta["config"]["class_names"]
|
| 469 |
+
pipeline = load_classical_pipeline(model_name)
|
| 470 |
+
|
| 471 |
+
X_test = features_cache["test"]["X"]
|
| 472 |
+
y_test = features_cache["test"]["y"]
|
| 473 |
+
y_pred = pipeline.predict(X_test)
|
| 474 |
+
|
| 475 |
+
metrics = compute_classification_metrics(y_test.tolist(), y_pred.tolist(), class_names)
|
| 476 |
+
cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
|
| 477 |
+
|
| 478 |
+
return (
|
| 479 |
+
{
|
| 480 |
+
"test_accuracy": metrics["accuracy"],
|
| 481 |
+
"test_f1_macro": metrics["f1_macro"],
|
| 482 |
+
"test_f1_weighted": metrics["f1_weighted"],
|
| 483 |
+
},
|
| 484 |
+
metrics["classification_report"],
|
| 485 |
+
metrics["confusion_matrix"],
|
| 486 |
+
cm_path,
|
| 487 |
+
)
|