Spaces:
Sleeping
Sleeping
Update train_utils.py
Browse files- train_utils.py +23 -14
train_utils.py
CHANGED
|
@@ -55,6 +55,7 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
|
|
| 55 |
|
| 56 |
if not os.path.exists(meta_file):
|
| 57 |
raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
|
|
|
|
| 58 |
if not os.path.exists(weight_file):
|
| 59 |
raise FileNotFoundError(f"Poids introuvables pour le modèle : {model_name}")
|
| 60 |
|
|
@@ -65,9 +66,9 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
|
|
| 65 |
|
| 66 |
model = ResNet18Classifier(
|
| 67 |
num_classes=cfg["num_classes"],
|
| 68 |
-
dropout=cfg
|
| 69 |
-
fc_dim=cfg
|
| 70 |
-
|
| 71 |
)
|
| 72 |
|
| 73 |
state_dict = torch.load(weight_file, map_location="cpu")
|
|
@@ -113,6 +114,7 @@ def collect_predictions(model, loader, device):
|
|
| 113 |
with torch.no_grad():
|
| 114 |
for images, labels in loader:
|
| 115 |
images = images.to(device)
|
|
|
|
| 116 |
outputs = model(images)
|
| 117 |
preds = outputs.argmax(dim=1).detach().cpu().tolist()
|
| 118 |
|
|
@@ -123,14 +125,14 @@ def collect_predictions(model, loader, device):
|
|
| 123 |
|
| 124 |
|
| 125 |
def train_model(
|
| 126 |
-
dropout: float,
|
| 127 |
-
fc_dim: int,
|
| 128 |
-
learning_rate: float,
|
| 129 |
-
weight_decay: float,
|
| 130 |
-
batch_size: int,
|
| 131 |
-
epochs: int,
|
| 132 |
-
|
| 133 |
-
model_tag: str,
|
| 134 |
):
|
| 135 |
device = get_runtime_device()
|
| 136 |
|
|
@@ -141,13 +143,14 @@ def train_model(
|
|
| 141 |
num_classes=num_classes,
|
| 142 |
dropout=dropout,
|
| 143 |
fc_dim=fc_dim,
|
| 144 |
-
|
| 145 |
).to(device)
|
| 146 |
|
| 147 |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 148 |
total_params = sum(p.numel() for p in model.parameters())
|
| 149 |
|
| 150 |
criterion = nn.CrossEntropyLoss()
|
|
|
|
| 151 |
optimizer = optim.AdamW(
|
| 152 |
filter(lambda p: p.requires_grad, model.parameters()),
|
| 153 |
lr=learning_rate,
|
|
@@ -176,6 +179,10 @@ def train_model(
|
|
| 176 |
|
| 177 |
loss = criterion(outputs, labels)
|
| 178 |
loss.backward()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
optimizer.step()
|
| 180 |
|
| 181 |
running_loss += loss.item() * images.size(0)
|
|
@@ -239,13 +246,14 @@ def train_model(
|
|
| 239 |
"weight_decay": weight_decay,
|
| 240 |
"batch_size": batch_size,
|
| 241 |
"epochs": epochs,
|
| 242 |
-
"
|
| 243 |
}
|
| 244 |
|
| 245 |
training_summary = {
|
| 246 |
"final_train_loss": history[-1]["train_loss"] if history else None,
|
| 247 |
"final_train_acc": history[-1]["train_acc"] if history else None,
|
| 248 |
"best_val_loss": round(best_val_loss, 4),
|
|
|
|
| 249 |
"final_val_acc": history[-1]["val_acc"] if history else None,
|
| 250 |
"test_cross_entropy_loss": round(test_loss, 4),
|
| 251 |
"test_accuracy": round(test_acc, 4),
|
|
@@ -263,6 +271,7 @@ def train_model(
|
|
| 263 |
logs.append("Entraînement terminé.")
|
| 264 |
logs.append(f"Modèle sauvegardé : {model_name}")
|
| 265 |
logs.append(f"Appareil utilisé : {device}")
|
|
|
|
| 266 |
logs.append(f"Nombre total de paramètres : {total_params}")
|
| 267 |
logs.append(f"Paramètres entraînables : {trainable_params}")
|
| 268 |
logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
|
|
@@ -289,7 +298,7 @@ def evaluate_saved_model(model_name: str):
|
|
| 289 |
device = get_runtime_device()
|
| 290 |
model, meta = load_model(model_name, device)
|
| 291 |
|
| 292 |
-
batch_size = int(meta["config"].get("batch_size",
|
| 293 |
_, _, test_loader, class_names = make_loaders(batch_size)
|
| 294 |
|
| 295 |
criterion = nn.CrossEntropyLoss()
|
|
|
|
| 55 |
|
| 56 |
if not os.path.exists(meta_file):
|
| 57 |
raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
|
| 58 |
+
|
| 59 |
if not os.path.exists(weight_file):
|
| 60 |
raise FileNotFoundError(f"Poids introuvables pour le modèle : {model_name}")
|
| 61 |
|
|
|
|
| 66 |
|
| 67 |
model = ResNet18Classifier(
|
| 68 |
num_classes=cfg["num_classes"],
|
| 69 |
+
dropout=cfg.get("dropout", 0.4),
|
| 70 |
+
fc_dim=cfg.get("fc_dim", 256),
|
| 71 |
+
fine_tune_mode=cfg.get("fine_tune_mode", "layer4"),
|
| 72 |
)
|
| 73 |
|
| 74 |
state_dict = torch.load(weight_file, map_location="cpu")
|
|
|
|
| 114 |
with torch.no_grad():
|
| 115 |
for images, labels in loader:
|
| 116 |
images = images.to(device)
|
| 117 |
+
|
| 118 |
outputs = model(images)
|
| 119 |
preds = outputs.argmax(dim=1).detach().cpu().tolist()
|
| 120 |
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def train_model(
|
| 128 |
+
dropout: float = 0.4,
|
| 129 |
+
fc_dim: int = 256,
|
| 130 |
+
learning_rate: float = 0.00001,
|
| 131 |
+
weight_decay: float = 0.0001,
|
| 132 |
+
batch_size: int = 16,
|
| 133 |
+
epochs: int = 30,
|
| 134 |
+
fine_tune_mode: str = "layer4",
|
| 135 |
+
model_tag: str = "",
|
| 136 |
):
|
| 137 |
device = get_runtime_device()
|
| 138 |
|
|
|
|
| 143 |
num_classes=num_classes,
|
| 144 |
dropout=dropout,
|
| 145 |
fc_dim=fc_dim,
|
| 146 |
+
fine_tune_mode=fine_tune_mode,
|
| 147 |
).to(device)
|
| 148 |
|
| 149 |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 150 |
total_params = sum(p.numel() for p in model.parameters())
|
| 151 |
|
| 152 |
criterion = nn.CrossEntropyLoss()
|
| 153 |
+
|
| 154 |
optimizer = optim.AdamW(
|
| 155 |
filter(lambda p: p.requires_grad, model.parameters()),
|
| 156 |
lr=learning_rate,
|
|
|
|
| 179 |
|
| 180 |
loss = criterion(outputs, labels)
|
| 181 |
loss.backward()
|
| 182 |
+
|
| 183 |
+
# Important: prevents unstable fine-tuning / exploding gradients
|
| 184 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 185 |
+
|
| 186 |
optimizer.step()
|
| 187 |
|
| 188 |
running_loss += loss.item() * images.size(0)
|
|
|
|
| 246 |
"weight_decay": weight_decay,
|
| 247 |
"batch_size": batch_size,
|
| 248 |
"epochs": epochs,
|
| 249 |
+
"fine_tune_mode": fine_tune_mode,
|
| 250 |
}
|
| 251 |
|
| 252 |
training_summary = {
|
| 253 |
"final_train_loss": history[-1]["train_loss"] if history else None,
|
| 254 |
"final_train_acc": history[-1]["train_acc"] if history else None,
|
| 255 |
"best_val_loss": round(best_val_loss, 4),
|
| 256 |
+
"final_val_loss": history[-1]["val_loss"] if history else None,
|
| 257 |
"final_val_acc": history[-1]["val_acc"] if history else None,
|
| 258 |
"test_cross_entropy_loss": round(test_loss, 4),
|
| 259 |
"test_accuracy": round(test_acc, 4),
|
|
|
|
| 271 |
logs.append("Entraînement terminé.")
|
| 272 |
logs.append(f"Modèle sauvegardé : {model_name}")
|
| 273 |
logs.append(f"Appareil utilisé : {device}")
|
| 274 |
+
logs.append(f"Mode de fine-tuning : {fine_tune_mode}")
|
| 275 |
logs.append(f"Nombre total de paramètres : {total_params}")
|
| 276 |
logs.append(f"Paramètres entraînables : {trainable_params}")
|
| 277 |
logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
|
|
|
|
| 298 |
device = get_runtime_device()
|
| 299 |
model, meta = load_model(model_name, device)
|
| 300 |
|
| 301 |
+
batch_size = int(meta["config"].get("batch_size", 16))
|
| 302 |
_, _, test_loader, class_names = make_loaders(batch_size)
|
| 303 |
|
| 304 |
criterion = nn.CrossEntropyLoss()
|