KD099 commited on
Commit
d2e04af
·
verified ·
1 Parent(s): c142c1b

v3.2: update nqr_snn/snn/train.py

Browse files
Files changed (1) hide show
  1. nqr_snn/snn/train.py +19 -5
nqr_snn/snn/train.py CHANGED
@@ -249,13 +249,19 @@ def train_snn(model: nn.Module, train_loader, val_loader, seed: int,
249
  print(f" Epoch {epoch:3d}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
250
  f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} lr={current_lr:.2e} ({epoch_time:.1f}s)")
251
 
252
- # Early stopping on val_loss
253
- if val_loss < best_val_loss:
254
- best_val_loss = val_loss
 
255
  best_val_acc = val_acc
 
256
  best_epoch = epoch
257
  epochs_without_improvement = 0
258
- torch.save(model.state_dict(), ckpt_path)
 
 
 
 
259
  else:
260
  epochs_without_improvement += 1
261
  if epochs_without_improvement >= patience:
@@ -266,7 +272,15 @@ def train_snn(model: nn.Module, train_loader, val_loader, seed: int,
266
  df.to_csv(csv_path, index=False)
267
 
268
  if os.path.exists(ckpt_path):
269
- model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
 
 
 
 
 
 
 
 
270
 
271
  print(f" Best: val_acc={best_val_acc:.4f} at epoch {best_epoch}")
272
  return best_val_acc, best_epoch
 
249
  print(f" Epoch {epoch:3d}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
250
  f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} lr={current_lr:.2e} ({epoch_time:.1f}s)")
251
 
252
+ # v3.2: Early stopping on val_accuracy (was val_loss).
253
+ # At 99%+ accuracy, val_loss fluctuates while accuracy plateaus.
254
+ # Tracking accuracy prevents premature stopping on loss noise.
255
+ if val_acc > best_val_acc:
256
  best_val_acc = val_acc
257
+ best_val_loss = val_loss
258
  best_epoch = epoch
259
  epochs_without_improvement = 0
260
+ # v3.2: Save encoder alongside model to fix train/inference mismatch
261
+ ckpt = {"model_state_dict": model.state_dict()}
262
+ if not skip_encoder and isinstance(encoder, nn.Module):
263
+ ckpt["encoder_state_dict"] = encoder.state_dict()
264
+ torch.save(ckpt, ckpt_path)
265
  else:
266
  epochs_without_improvement += 1
267
  if epochs_without_improvement >= patience:
 
272
  df.to_csv(csv_path, index=False)
273
 
274
  if os.path.exists(ckpt_path):
275
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
276
+ # v3.2: Support new dict format and legacy state_dict format
277
+ if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
278
+ model.load_state_dict(ckpt["model_state_dict"])
279
+ if not skip_encoder and isinstance(encoder, nn.Module) and "encoder_state_dict" in ckpt:
280
+ encoder.load_state_dict(ckpt["encoder_state_dict"])
281
+ else:
282
+ # Legacy checkpoint: bare state_dict
283
+ model.load_state_dict(ckpt)
284
 
285
  print(f" Best: val_acc={best_val_acc:.4f} at epoch {best_epoch}")
286
  return best_val_acc, best_epoch