Fix LR warmup ordering and align adam_eps with Meta LLaMA
Move scheduler.set_lr() before optimizer.step() so the first training
step uses the correct warmup LR instead of the full peak LR, which
could perturb pretrained weights during CPT. Change adam_eps from 1e-8
to 1e-5 to match Meta LLaMA's value for better bf16 numerical stability.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
llm_lab/config/train_config.py
CHANGED
|
@@ -37,7 +37,9 @@ class TrainConfig:
|
|
| 37 |
"""Adam momentum coefficients. β2=0.95 is more stable than β2=0.999 for LLM training.
|
| 38 |
With large batches and long training, a β2 that is too large slows adaptation."""
|
| 39 |
|
| 40 |
-
adam_eps: float = 1e-8
|
|
|
|
|
|
|
| 41 |
grad_clip: float = 1.0
|
| 42 |
"""Gradient Clipping: rescales gradients when their norm exceeds 1.0.
|
| 43 |
Prevents gradient spikes that occur during early training or with noisy data."""
|
|
@@ -112,8 +114,7 @@ class TrainConfig:
|
|
| 112 |
|
| 113 |
@property
|
| 114 |
def tokens_per_step(self) -> int:
|
| 115 |
-
"""Number of tokens processed per optimizer step."""
|
| 116 |
-
# max_seq_len is injected externally (see ModelConfig)
|
| 117 |
return self.effective_batch_size * 2048
|
| 118 |
|
| 119 |
@property
|
|
|
|
| 37 |
"""Adam momentum coefficients. β2=0.95 is more stable than β2=0.999 for LLM training.
|
| 38 |
With large batches and long training, a β2 that is too large slows adaptation."""
|
| 39 |
|
| 40 |
+
adam_eps: float = 1e-5
|
| 41 |
+
"""Adam epsilon. LLaMA uses 1e-5 (not PyTorch default 1e-8) for
|
| 42 |
+
numerical stability with bf16, which has fewer mantissa bits than fp32."""
|
| 43 |
grad_clip: float = 1.0
|
| 44 |
"""Gradient Clipping: rescales gradients when their norm exceeds 1.0.
|
| 45 |
Prevents gradient spikes that occur during early training or with noisy data."""
|
|
|
|
| 114 |
|
| 115 |
@property
|
| 116 |
def tokens_per_step(self) -> int:
|
| 117 |
+
"""Number of tokens processed per optimizer step (assumes max_seq_len=2048)."""
|
|
|
|
| 118 |
return self.effective_batch_size * 2048
|
| 119 |
|
| 120 |
@property
|
llm_lab/training/trainer.py
CHANGED
|
@@ -172,12 +172,14 @@ class Trainer:
|
|
| 172 |
max_norm=self.config.grad_clip,
|
| 173 |
).item()
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
# ── Optimizer Step ──
|
| 176 |
self.optimizer.step()
|
| 177 |
|
| 178 |
-
# ── LR Update ──
|
| 179 |
-
self.scheduler.set_lr(self.optimizer, self.global_step)
|
| 180 |
-
|
| 181 |
avg_loss = total_loss / self.config.gradient_accumulation_steps
|
| 182 |
return avg_loss, grad_norm
|
| 183 |
|
|
|
|
| 172 |
max_norm=self.config.grad_clip,
|
| 173 |
).item()
|
| 174 |
|
| 175 |
+
# ── LR Update (before optimizer step) ──
|
| 176 |
+
# Must set LR before step() so the very first step uses warmup LR (not peak LR).
|
| 177 |
+
# Otherwise step 0 would use the full peak LR, perturbing pretrained weights.
|
| 178 |
+
self.scheduler.set_lr(self.optimizer, self.global_step)
|
| 179 |
+
|
| 180 |
# ── Optimizer Step ──
|
| 181 |
self.optimizer.step()
|
| 182 |
|
|
|
|
|
|
|
|
|
|
| 183 |
avg_loss = total_loss / self.config.gradient_accumulation_steps
|
| 184 |
return avg_loss, grad_norm
|
| 185 |
|