v3.2: update nqr_snn/snn/train.py
Browse files- nqr_snn/snn/train.py +19 -5
nqr_snn/snn/train.py
CHANGED
|
@@ -249,13 +249,19 @@ def train_snn(model: nn.Module, train_loader, val_loader, seed: int,
|
|
| 249 |
print(f" Epoch {epoch:3d}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
|
| 250 |
f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} lr={current_lr:.2e} ({epoch_time:.1f}s)")
|
| 251 |
|
| 252 |
-
# Early stopping on val_loss
|
| 253 |
-
|
| 254 |
-
|
|
|
|
| 255 |
best_val_acc = val_acc
|
|
|
|
| 256 |
best_epoch = epoch
|
| 257 |
epochs_without_improvement = 0
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
else:
|
| 260 |
epochs_without_improvement += 1
|
| 261 |
if epochs_without_improvement >= patience:
|
|
@@ -266,7 +272,15 @@ def train_snn(model: nn.Module, train_loader, val_loader, seed: int,
|
|
| 266 |
df.to_csv(csv_path, index=False)
|
| 267 |
|
| 268 |
if os.path.exists(ckpt_path):
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
print(f" Best: val_acc={best_val_acc:.4f} at epoch {best_epoch}")
|
| 272 |
return best_val_acc, best_epoch
|
|
|
|
| 249 |
print(f" Epoch {epoch:3d}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
|
| 250 |
f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} lr={current_lr:.2e} ({epoch_time:.1f}s)")
|
| 251 |
|
| 252 |
+
# v3.2: Early stopping on val_accuracy (was val_loss).
|
| 253 |
+
# At 99%+ accuracy, val_loss fluctuates while accuracy plateaus.
|
| 254 |
+
# Tracking accuracy prevents premature stopping on loss noise.
|
| 255 |
+
if val_acc > best_val_acc:
|
| 256 |
best_val_acc = val_acc
|
| 257 |
+
best_val_loss = val_loss
|
| 258 |
best_epoch = epoch
|
| 259 |
epochs_without_improvement = 0
|
| 260 |
+
# v3.2: Save encoder alongside model to fix train/inference mismatch
|
| 261 |
+
ckpt = {"model_state_dict": model.state_dict()}
|
| 262 |
+
if not skip_encoder and isinstance(encoder, nn.Module):
|
| 263 |
+
ckpt["encoder_state_dict"] = encoder.state_dict()
|
| 264 |
+
torch.save(ckpt, ckpt_path)
|
| 265 |
else:
|
| 266 |
epochs_without_improvement += 1
|
| 267 |
if epochs_without_improvement >= patience:
|
|
|
|
| 272 |
df.to_csv(csv_path, index=False)
|
| 273 |
|
| 274 |
if os.path.exists(ckpt_path):
|
| 275 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
| 276 |
+
# v3.2: Support new dict format and legacy state_dict format
|
| 277 |
+
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
| 278 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 279 |
+
if not skip_encoder and isinstance(encoder, nn.Module) and "encoder_state_dict" in ckpt:
|
| 280 |
+
encoder.load_state_dict(ckpt["encoder_state_dict"])
|
| 281 |
+
else:
|
| 282 |
+
# Legacy checkpoint: bare state_dict
|
| 283 |
+
model.load_state_dict(ckpt)
|
| 284 |
|
| 285 |
print(f" Best: val_acc={best_val_acc:.4f} at epoch {best_epoch}")
|
| 286 |
return best_val_acc, best_epoch
|