CircleStar committed · verified
Commit 5fce1fe · Parent(s): e8577ab

Update train_utils.py

Files changed (1):
  1. train_utils.py +23 -14

train_utils.py CHANGED
@@ -55,6 +55,7 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
 
     if not os.path.exists(meta_file):
         raise FileNotFoundError(f"Metadata not found for model: {model_name}")
+
     if not os.path.exists(weight_file):
         raise FileNotFoundError(f"Weights not found for model: {model_name}")
 
@@ -65,9 +66,9 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
 
     model = ResNet18Classifier(
         num_classes=cfg["num_classes"],
-        dropout=cfg["dropout"],
-        fc_dim=cfg["fc_dim"],
-        freeze_backbone=cfg.get("freeze_backbone", True),
+        dropout=cfg.get("dropout", 0.4),
+        fc_dim=cfg.get("fc_dim", 256),
+        fine_tune_mode=cfg.get("fine_tune_mode", "layer4"),
     )
 
     state_dict = torch.load(weight_file, map_location="cpu")
@@ -113,6 +114,7 @@ def collect_predictions(model, loader, device):
     with torch.no_grad():
         for images, labels in loader:
             images = images.to(device)
+
             outputs = model(images)
             preds = outputs.argmax(dim=1).detach().cpu().tolist()
 
@@ -123,14 +125,14 @@ def collect_predictions(model, loader, device):
 
 
 def train_model(
-    dropout: float,
-    fc_dim: int,
-    learning_rate: float,
-    weight_decay: float,
-    batch_size: int,
-    epochs: int,
-    freeze_backbone: bool,
-    model_tag: str,
+    dropout: float = 0.4,
+    fc_dim: int = 256,
+    learning_rate: float = 0.00001,
+    weight_decay: float = 0.0001,
+    batch_size: int = 16,
+    epochs: int = 30,
+    fine_tune_mode: str = "layer4",
+    model_tag: str = "",
 ):
     device = get_runtime_device()
 
@@ -141,13 +143,14 @@ def train_model(
         num_classes=num_classes,
         dropout=dropout,
         fc_dim=fc_dim,
-        freeze_backbone=freeze_backbone,
+        fine_tune_mode=fine_tune_mode,
     ).to(device)
 
     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     total_params = sum(p.numel() for p in model.parameters())
 
     criterion = nn.CrossEntropyLoss()
+
     optimizer = optim.AdamW(
         filter(lambda p: p.requires_grad, model.parameters()),
         lr=learning_rate,
@@ -176,6 +179,10 @@ def train_model(
 
             loss = criterion(outputs, labels)
             loss.backward()
+
+            # Important: prevents unstable fine-tuning / exploding gradients
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
            optimizer.step()
 
             running_loss += loss.item() * images.size(0)
@@ -239,13 +246,14 @@ def train_model(
         "weight_decay": weight_decay,
         "batch_size": batch_size,
         "epochs": epochs,
-        "freeze_backbone": freeze_backbone,
+        "fine_tune_mode": fine_tune_mode,
     }
 
     training_summary = {
         "final_train_loss": history[-1]["train_loss"] if history else None,
         "final_train_acc": history[-1]["train_acc"] if history else None,
         "best_val_loss": round(best_val_loss, 4),
+        "final_val_loss": history[-1]["val_loss"] if history else None,
         "final_val_acc": history[-1]["val_acc"] if history else None,
         "test_cross_entropy_loss": round(test_loss, 4),
         "test_accuracy": round(test_acc, 4),
@@ -263,6 +271,7 @@ def train_model(
     logs.append("Training finished.")
     logs.append(f"Model saved: {model_name}")
     logs.append(f"Device used: {device}")
+    logs.append(f"Fine-tuning mode: {fine_tune_mode}")
     logs.append(f"Total number of parameters: {total_params}")
     logs.append(f"Trainable parameters: {trainable_params}")
     logs.append(f"Test cross-entropy loss: {test_loss:.4f}")
@@ -289,7 +298,7 @@ def evaluate_saved_model(model_name: str):
     device = get_runtime_device()
     model, meta = load_model(model_name, device)
 
-    batch_size = int(meta["config"].get("batch_size", 32))
+    batch_size = int(meta["config"].get("batch_size", 16))
     _, _, test_loader, class_names = make_loaders(batch_size)
 
     criterion = nn.CrossEntropyLoss()
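Note on the central change: the commit replaces the boolean `freeze_backbone` flag with a string-valued `fine_tune_mode` defaulting to `"layer4"`. `ResNet18Classifier` is defined elsewhere in the repo and is not part of this diff, so the following is only a minimal sketch of what a `"layer4"` mode could plausibly do; apart from the constructor parameters visible in the diff (`num_classes`, `dropout`, `fc_dim`, `fine_tune_mode`), every name and the `"full"` mode string are assumptions, not the repo's actual code.

import torch.nn as nn
from torchvision import models


class ResNet18Classifier(nn.Module):
    # Hypothetical sketch -- the real class lives elsewhere in this repo.
    def __init__(self, num_classes: int, dropout: float = 0.4,
                 fc_dim: int = 256, fine_tune_mode: str = "layer4"):
        super().__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Swap the ImageNet head for a small task-specific classifier.
        in_features = self.backbone.fc.in_features  # 512 for ResNet-18
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, fc_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(fc_dim, num_classes),
        )

        # Freeze everything, then unfreeze according to the mode.
        for p in self.backbone.parameters():
            p.requires_grad = False
        for p in self.backbone.fc.parameters():  # the new head always trains
            p.requires_grad = True
        if fine_tune_mode == "layer4":
            # Unfreeze only the last residual stage.
            for p in self.backbone.layer4.parameters():
                p.requires_grad = True
        elif fine_tune_mode == "full":  # assumed alternative mode name
            for p in self.backbone.parameters():
                p.requires_grad = True

    def forward(self, x):
        return self.backbone(x)

Read this way, the rest of the diff is consistent: `optim.AdamW(filter(lambda p: p.requires_grad, ...))` optimizes only the unfrozen subset, and the `trainable_params` vs. `total_params` log lines report how much of the network the chosen mode actually trains.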
 
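The other safety change is the gradient-clipping line in the training loop. `torch.nn.utils.clip_grad_norm_` rescales all gradients in place so their global L2 norm is at most `max_norm`, and it must run after `loss.backward()` (gradients exist) and before `optimizer.step()` (before they are applied), exactly where the diff puts it. A self-contained illustration of the behavior (the gradient values here are made up for the demo):

import torch

# A parameter with a hand-set gradient of L2 norm 5.0.
w = torch.nn.Parameter(torch.zeros(3))
w.grad = torch.tensor([3.0, 4.0, 0.0])

# Rescales w.grad in place so its norm is at most 1.0;
# returns the norm measured *before* clipping.
total_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)

print(total_norm)  # tensor(5.)
print(w.grad)      # ~[0.6, 0.8, 0.0], norm clipped to 1.0

Capping the update magnitude this way pairs naturally with the new conservative defaults (learning rate 0.00001, batch size 16): a single bad batch can no longer produce an outsized step on the few unfrozen parameters, which is the usual motivation for clipping during partial fine-tuning.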