Vjeong Claude Opus 4.6 committed on
Commit
af13727
·
1 Parent(s): 2a50172

Fix LR warmup ordering and align adam_eps with Meta LLaMA

Browse files

Move scheduler.set_lr() before optimizer.step() so the first training
step uses the correct warmup LR instead of the full peak LR, which
could perturb pretrained weights during CPT. Change adam_eps from 1e-8
to 1e-5 to match Meta LLaMA's value for better bf16 numerical stability.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

llm_lab/config/train_config.py CHANGED
@@ -37,7 +37,9 @@ class TrainConfig:
37
  """Adam momentum coefficients. Ξ²2=0.95 is more stable than Ξ²2=0.999 for LLM training.
38
  With large batches and long training, a β2 that is too large slows adaptation."""
39
 
40
- adam_eps: float = 1e-8
 
 
41
  grad_clip: float = 1.0
42
  """Gradient Clipping: rescales gradients when their norm exceeds 1.0.
43
  Prevents gradient spikes that occur during early training or with noisy data."""
@@ -112,8 +114,7 @@ class TrainConfig:
112
 
113
  @property
114
  def tokens_per_step(self) -> int:
115
- """Number of tokens processed per optimizer step."""
116
- # max_seq_len is injected externally (see ModelConfig)
117
  return self.effective_batch_size * 2048
118
 
119
  @property
 
37
  """Adam momentum coefficients. Ξ²2=0.95 is more stable than Ξ²2=0.999 for LLM training.
38
  With large batches and long training, a β2 that is too large slows adaptation."""
39
 
40
+ adam_eps: float = 1e-5
41
+ """Adam epsilon. LLaMA uses 1e-5 (not PyTorch default 1e-8) for
42
+ numerical stability with bf16, which has fewer mantissa bits than fp32."""
43
  grad_clip: float = 1.0
44
  """Gradient Clipping: rescales gradients when their norm exceeds 1.0.
45
  Prevents gradient spikes that occur during early training or with noisy data."""
 
114
 
115
  @property
116
  def tokens_per_step(self) -> int:
117
+ """Number of tokens processed per optimizer step (assumes max_seq_len=2048)."""
 
118
  return self.effective_batch_size * 2048
119
 
120
  @property
llm_lab/training/trainer.py CHANGED
@@ -172,12 +172,14 @@ class Trainer:
172
  max_norm=self.config.grad_clip,
173
  ).item()
174
 
 
 
 
 
 
175
  # ── Optimizer Step ──
176
  self.optimizer.step()
177
 
178
- # ── LR Update ──
179
- self.scheduler.set_lr(self.optimizer, self.global_step)
180
-
181
  avg_loss = total_loss / self.config.gradient_accumulation_steps
182
  return avg_loss, grad_norm
183
 
 
172
  max_norm=self.config.grad_clip,
173
  ).item()
174
 
175
+ # ── LR Update (before optimizer step) ──
176
+ # Must set LR before step() so the very first step uses warmup LR (not peak LR).
177
+ # Otherwise step 0 would use the full peak LR, perturbing pretrained weights.
178
+ self.scheduler.set_lr(self.optimizer, self.global_step)
179
+
180
  # ── Optimizer Step ──
181
  self.optimizer.step()
182
 
 
 
 
183
  avg_loss = total_loss / self.config.gradient_accumulation_steps
184
  return avg_loss, grad_norm
185