functionNormally Claude Sonnet 4.6 commited on
Commit
7ceea37
·
1 Parent(s): 948c799

Ajouter ResNet18 (layer4 + classifieur) comme option de modèle

Browse files

- model.py : classe ResNet18Classifier restaurée (layer4 + tête FC)
- train_utils.py : paramètre model_type, instanciation et config selon le choix
- app.py : radio CNN simple / ResNet18, panneau CNN masquable,
lr ajusté automatiquement selon le modèle sélectionné

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +54 -29
  2. model.py +29 -0
  3. train_utils.py +42 -21
app.py CHANGED
@@ -51,8 +51,15 @@ def refresh_gallery_callback(split_name, class_name, max_images):
51
  return [(None, f"Erreur : {str(e)}")]
52
 
53
 
 
 
 
 
 
 
54
  @spaces.GPU(duration=200)
55
  def train_callback(
 
56
  num_conv_blocks,
57
  base_filters,
58
  kernel_size,
@@ -67,6 +74,7 @@ def train_callback(
67
  ):
68
  try:
69
  result = train_model(
 
70
  num_conv_blocks=int(num_conv_blocks),
71
  base_filters=int(base_filters),
72
  kernel_size=int(kernel_size),
@@ -205,39 +213,49 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
205
  )
206
 
207
  with gr.Tab("2. Entraîner un modèle"):
208
- gr.Markdown("## Entraînement d’un CNN simple (entraîné de zéro)")
209
- gr.Markdown(
210
- "Configurez librement l’architecture du CNN : nombre de blocs convolutionnels, "
211
- "nombre de filtres, taille du noyau, etc. Tous les paramètres sont entraînables."
212
- )
213
 
214
  with gr.Row():
215
  with gr.Column():
216
- num_conv_blocks = gr.Slider(
217
- minimum=2,
218
- maximum=5,
219
- value=3,
220
- step=1,
221
- label="Nombre de blocs convolutionnels",
222
- info="Chaque bloc enchaîne Conv2d (BN) ReLU → MaxPool2d.",
223
- )
224
-
225
- base_filters = gr.Dropdown(
226
- choices=[16, 32, 64, 128],
227
- value=32,
228
- label="Filtres du premier bloc (doublent à chaque bloc)",
229
- )
230
-
231
- kernel_size = gr.Dropdown(
232
- choices=[3, 5],
233
- value=3,
234
- label="Taille du noyau de convolution",
235
  )
236
 
237
- use_batchnorm = gr.Checkbox(
238
- value=True,
239
- label="Normalisation par lots (BatchNorm)",
240
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  dropout = gr.Slider(
243
  minimum=0.0,
@@ -279,7 +297,7 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
279
 
280
  model_tag = gr.Textbox(
281
  label="Nom court du modèle",
282
- placeholder="ex. cnn_3blocs_32filtres",
283
  )
284
 
285
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
@@ -378,9 +396,16 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
378
  outputs=image_gallery,
379
  )
380
 
 
 
 
 
 
 
381
  train_btn.click(
382
  fn=train_callback,
383
  inputs=[
 
384
  num_conv_blocks,
385
  base_filters,
386
  kernel_size,
 
51
  return [(None, f"Erreur : {str(e)}")]
52
 
53
 
54
+ def on_model_type_change(model_type):
55
+ is_cnn = (model_type == "CNN simple")
56
+ default_lr = 0.001 if is_cnn else 0.0001
57
+ return gr.update(visible=is_cnn), gr.update(value=default_lr)
58
+
59
+
60
  @spaces.GPU(duration=200)
61
  def train_callback(
62
+ model_type,
63
  num_conv_blocks,
64
  base_filters,
65
  kernel_size,
 
74
  ):
75
  try:
76
  result = train_model(
77
+ model_type="cnn" if model_type == "CNN simple" else "resnet18",
78
  num_conv_blocks=int(num_conv_blocks),
79
  base_filters=int(base_filters),
80
  kernel_size=int(kernel_size),
 
213
  )
214
 
215
  with gr.Tab("2. Entraîner un modèle"):
216
+ gr.Markdown("## Choix du modèle et entraînement")
 
 
 
 
217
 
218
  with gr.Row():
219
  with gr.Column():
220
+ model_type = gr.Radio(
221
+ choices=["CNN simple", "ResNet18"],
222
+ value="CNN simple",
223
+ label="Architecture",
224
+ info=(
225
+ "CNN simple : entraîné de zéro, paramètres configurables. "
226
+ "ResNet18 : pré-entraîné ImageNet, fine-tuning layer4 + classifieur."
227
+ ),
 
 
 
 
 
 
 
 
 
 
 
228
  )
229
 
230
+ with gr.Column(visible=True) as cnn_params_col:
231
+ gr.Markdown("#### Paramètres CNN")
232
+ num_conv_blocks = gr.Slider(
233
+ minimum=2,
234
+ maximum=5,
235
+ value=3,
236
+ step=1,
237
+ label="Nombre de blocs convolutionnels",
238
+ info="Chaque bloc enchaîne Conv2d → (BN) → ReLU → MaxPool2d.",
239
+ )
240
+
241
+ base_filters = gr.Dropdown(
242
+ choices=[16, 32, 64, 128],
243
+ value=32,
244
+ label="Filtres du premier bloc (doublent à chaque bloc)",
245
+ )
246
+
247
+ kernel_size = gr.Dropdown(
248
+ choices=[3, 5],
249
+ value=3,
250
+ label="Taille du noyau de convolution",
251
+ )
252
+
253
+ use_batchnorm = gr.Checkbox(
254
+ value=True,
255
+ label="Normalisation par lots (BatchNorm)",
256
+ )
257
+
258
+ gr.Markdown("#### Hyperparamètres d’entraînement")
259
 
260
  dropout = gr.Slider(
261
  minimum=0.0,
 
297
 
298
  model_tag = gr.Textbox(
299
  label="Nom court du modèle",
300
+ placeholder="ex. cnn_3blocs ou resnet18_ft",
301
  )
302
 
303
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
 
396
  outputs=image_gallery,
397
  )
398
 
399
+ model_type.change(
400
+ fn=on_model_type_change,
401
+ inputs=model_type,
402
+ outputs=[cnn_params_col, learning_rate],
403
+ )
404
+
405
  train_btn.click(
406
  fn=train_callback,
407
  inputs=[
408
+ model_type,
409
  num_conv_blocks,
410
  base_filters,
411
  kernel_size,
model.py CHANGED
@@ -1,4 +1,33 @@
1
  import torch.nn as nn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  class SimpleCNN(nn.Module):
 
1
  import torch.nn as nn
2
+ from torchvision import models
3
+
4
+
5
+ class ResNet18Classifier(nn.Module):
6
+ def __init__(self, num_classes: int, dropout: float = 0.4, fc_dim: int = 256):
7
+ super().__init__()
8
+
9
+ weights = models.ResNet18_Weights.DEFAULT
10
+ self.backbone = models.resnet18(weights=weights)
11
+ in_features = self.backbone.fc.in_features
12
+
13
+ # Gel de tout le réseau sauf layer4 et classifieur
14
+ for param in self.backbone.parameters():
15
+ param.requires_grad = False
16
+ for param in self.backbone.layer4.parameters():
17
+ param.requires_grad = True
18
+
19
+ self.backbone.fc = nn.Sequential(
20
+ nn.Dropout(dropout),
21
+ nn.Linear(in_features, fc_dim),
22
+ nn.ReLU(),
23
+ nn.Dropout(dropout),
24
+ nn.Linear(fc_dim, num_classes),
25
+ )
26
+ for param in self.backbone.fc.parameters():
27
+ param.requires_grad = True
28
+
29
+ def forward(self, x):
30
+ return self.backbone(x)
31
 
32
 
33
  class SimpleCNN(nn.Module):
train_utils.py CHANGED
@@ -11,7 +11,7 @@ import torch.optim as optim
11
  from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
12
  from data_utils import make_loaders
13
  from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
- from model import SimpleCNN
15
 
16
 
17
  def model_weight_path(model_name: str) -> str:
@@ -64,15 +64,22 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
64
 
65
  cfg = meta["config"]
66
 
67
- model = SimpleCNN(
68
- num_classes=cfg["num_classes"],
69
- num_conv_blocks=cfg.get("num_conv_blocks", 3),
70
- base_filters=cfg.get("base_filters", 32),
71
- kernel_size=cfg.get("kernel_size", 3),
72
- use_batchnorm=cfg.get("use_batchnorm", True),
73
- dropout=cfg.get("dropout", 0.4),
74
- fc_dim=cfg.get("fc_dim", 256),
75
- )
 
 
 
 
 
 
 
76
 
77
  state_dict = torch.load(weight_file, map_location="cpu")
78
  model.load_state_dict(state_dict)
@@ -128,6 +135,7 @@ def collect_predictions(model, loader, device):
128
 
129
 
130
  def train_model(
 
131
  num_conv_blocks: int = 3,
132
  base_filters: int = 32,
133
  kernel_size: int = 3,
@@ -145,15 +153,22 @@ def train_model(
145
  train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
146
  num_classes = len(class_names)
147
 
148
- model = SimpleCNN(
149
- num_classes=num_classes,
150
- num_conv_blocks=num_conv_blocks,
151
- base_filters=base_filters,
152
- kernel_size=kernel_size,
153
- use_batchnorm=use_batchnorm,
154
- dropout=dropout,
155
- fc_dim=fc_dim,
156
- ).to(device)
 
 
 
 
 
 
 
157
 
158
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
159
  total_params = sum(p.numel() for p in model.parameters())
@@ -257,9 +272,15 @@ def train_model(
257
 
258
  cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
259
 
 
 
 
 
 
260
  config = {
261
  "dataset_name": DATASET_DISPLAY_NAME,
262
- "architecture": "CNN simple entraîné de zéro",
 
263
  "num_classes": num_classes,
264
  "class_names": class_names,
265
  "num_conv_blocks": num_conv_blocks,
@@ -296,7 +317,7 @@ def train_model(
296
  logs.append("Entraînement terminé.")
297
  logs.append(f"Modèle sauvegardé : {model_name}")
298
  logs.append(f"Appareil utilisé : {device}")
299
- logs.append(f"Architecture : {num_conv_blocks} blocs conv, filtres de base={base_filters}, noyau={kernel_size}x{kernel_size}, BatchNorm={use_batchnorm}")
300
  logs.append(f"Nombre total de paramètres : {total_params}")
301
  logs.append(f"Paramètres entraînables : {trainable_params}")
302
  logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
 
11
  from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
12
  from data_utils import make_loaders
13
  from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
+ from model import SimpleCNN, ResNet18Classifier
15
 
16
 
17
  def model_weight_path(model_name: str) -> str:
 
64
 
65
  cfg = meta["config"]
66
 
67
+ if cfg.get("model_type", "cnn") == "resnet18":
68
+ model = ResNet18Classifier(
69
+ num_classes=cfg["num_classes"],
70
+ dropout=cfg.get("dropout", 0.4),
71
+ fc_dim=cfg.get("fc_dim", 256),
72
+ )
73
+ else:
74
+ model = SimpleCNN(
75
+ num_classes=cfg["num_classes"],
76
+ num_conv_blocks=cfg.get("num_conv_blocks", 3),
77
+ base_filters=cfg.get("base_filters", 32),
78
+ kernel_size=cfg.get("kernel_size", 3),
79
+ use_batchnorm=cfg.get("use_batchnorm", True),
80
+ dropout=cfg.get("dropout", 0.4),
81
+ fc_dim=cfg.get("fc_dim", 256),
82
+ )
83
 
84
  state_dict = torch.load(weight_file, map_location="cpu")
85
  model.load_state_dict(state_dict)
 
135
 
136
 
137
  def train_model(
138
+ model_type: str = "cnn",
139
  num_conv_blocks: int = 3,
140
  base_filters: int = 32,
141
  kernel_size: int = 3,
 
153
  train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
154
  num_classes = len(class_names)
155
 
156
+ if model_type == "resnet18":
157
+ model = ResNet18Classifier(
158
+ num_classes=num_classes,
159
+ dropout=dropout,
160
+ fc_dim=fc_dim,
161
+ ).to(device)
162
+ else:
163
+ model = SimpleCNN(
164
+ num_classes=num_classes,
165
+ num_conv_blocks=num_conv_blocks,
166
+ base_filters=base_filters,
167
+ kernel_size=kernel_size,
168
+ use_batchnorm=use_batchnorm,
169
+ dropout=dropout,
170
+ fc_dim=fc_dim,
171
+ ).to(device)
172
 
173
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
174
  total_params = sum(p.numel() for p in model.parameters())
 
272
 
273
  cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
274
 
275
+ if model_type == "resnet18":
276
+ architecture = "ResNet18 pré-entraîné (layer4 + classifieur)"
277
+ else:
278
+ architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}x{kernel_size})"
279
+
280
  config = {
281
  "dataset_name": DATASET_DISPLAY_NAME,
282
+ "model_type": model_type,
283
+ "architecture": architecture,
284
  "num_classes": num_classes,
285
  "class_names": class_names,
286
  "num_conv_blocks": num_conv_blocks,
 
317
  logs.append("Entraînement terminé.")
318
  logs.append(f"Modèle sauvegardé : {model_name}")
319
  logs.append(f"Appareil utilisé : {device}")
320
+ logs.append(f"Architecture : {architecture}")
321
  logs.append(f"Nombre total de paramètres : {total_params}")
322
  logs.append(f"Paramètres entraînables : {trainable_params}")
323
  logs.append(f"Perte test cross-entropy : {test_loss:.4f}")