Delta-Vector committed on
Commit
be991b1
·
verified ·
1 Parent(s): eb5278f

add micro_batch_size config key + per-micro-batch inner loop in train step (fixes OOM for fp32+seq2048)

Browse files
configs/base.toml CHANGED
@@ -25,6 +25,7 @@ grad_clip = 1.0
25
  betas = [0.9, 0.95]
26
  eps = 1.0e-8
27
  samples_per_step = 4
 
28
  max_steps = 5
29
  grad_checkpointing = true
30
  attn_implementation = "flash_attention_2"
 
25
  betas = [0.9, 0.95]
26
  eps = 1.0e-8
27
  samples_per_step = 4
28
+ micro_batch_size = 4
29
  max_steps = 5
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
configs/grow40_simple.toml CHANGED
@@ -26,6 +26,7 @@ grad_clip = 1.0
26
  betas = [0.9, 0.95]
27
  eps = 1.0e-8
28
  samples_per_step = 8
 
29
  max_steps = 2000
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
 
26
  betas = [0.9, 0.95]
27
  eps = 1.0e-8
28
  samples_per_step = 8
29
+ micro_batch_size = 8
30
  max_steps = 2000
31
  grad_checkpointing = true
32
  attn_implementation = "flash_attention_2"
configs/grow40_winning.toml CHANGED
@@ -26,6 +26,7 @@ grad_clip = 1.0
26
  betas = [0.9, 0.999]
27
  eps = 1.0e-3
28
  samples_per_step = 4
 
29
  max_steps = 2000
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
 
26
  betas = [0.9, 0.999]
27
  eps = 1.0e-3
28
  samples_per_step = 4
29
+ micro_batch_size = 1
30
  max_steps = 2000
31
  grad_checkpointing = true
32
  attn_implementation = "flash_attention_2"
configs/replicate_zero4.toml CHANGED
@@ -25,6 +25,7 @@ grad_clip = 1.0
25
  betas = [0.9, 0.999]
26
  eps = 1.0e-3
27
  samples_per_step = 4
 
28
  max_steps = 2000
29
  grad_checkpointing = true
30
  attn_implementation = "flash_attention_2"
 
25
  betas = [0.9, 0.999]
26
  eps = 1.0e-3
27
  samples_per_step = 4
28
+ micro_batch_size = 1
29
  max_steps = 2000
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
configs/zero_14_17.toml CHANGED
@@ -26,6 +26,7 @@ grad_clip = 1.0
26
  betas = [0.9, 0.95]
27
  eps = 1.0e-8
28
  samples_per_step = 8
 
29
  max_steps = 2000
30
  grad_checkpointing = true
31
  attn_implementation = "flash_attention_2"
 
26
  betas = [0.9, 0.95]
27
  eps = 1.0e-8
28
  samples_per_step = 8
29
+ micro_batch_size = 8
30
  max_steps = 2000
31
  grad_checkpointing = true
32
  attn_implementation = "flash_attention_2"
distill.py CHANGED
@@ -71,6 +71,7 @@ REQUIRED_KEYS = {
71
  "teacher_dtype",
72
  "mixed_precision",
73
  "kl_chunk_size",
 
74
  ),
75
  "eval": ("every_steps", "samples", "seed"),
76
  "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
@@ -580,6 +581,7 @@ def main():
580
 
581
  # ---- Train loop
582
  samples_per_step = cfg["train"]["samples_per_step"]
 
583
  grad_clip = cfg["train"]["grad_clip"]
584
  kl_start_pos = cfg["data"]["kl_start_pos"]
585
  kl_chunk_size = cfg["train"]["kl_chunk_size"]
@@ -590,7 +592,8 @@ def main():
590
  if accelerator.is_main_process:
591
  log.info(
592
  f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
593
- f"(per rank), effective batch={samples_per_step * accelerator.num_processes}"
 
594
  )
595
 
596
  student.train()
@@ -604,20 +607,29 @@ def main():
604
  log.warning(f"rank {accelerator.process_index}: data exhausted")
605
  break
606
 
607
- ids, mask = collate_pad(batch, pad_id)
608
- ids = ids.to(accelerator.device)
609
- mask = mask.to(accelerator.device)
610
-
611
- with torch.no_grad():
612
- t_logits = teacher_forward(teacher, ids, mask)
613
- s_logits = student(input_ids=ids, attention_mask=mask).logits
614
- loss = kl_loss_masked(
615
- s_logits, t_logits, mask,
616
- start_pos=kl_start_pos, chunk_size=kl_chunk_size,
617
- )
618
-
619
  optimizer.zero_grad()
620
- accelerator.backward(loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  if grad_clip > 0:
622
  accelerator.clip_grad_norm_(student.parameters(), grad_clip)
623
  optimizer.step()
@@ -625,9 +637,9 @@ def main():
625
  global_step += 1
626
 
627
  elapsed = time.time() - t0
628
- kl_local = loss.detach()
629
  kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
630
- del t_logits, s_logits, loss, kl_local
631
 
632
  if accelerator.is_main_process and global_step % log_every == 0:
633
  lr_now = scheduler.get_last_lr()[0]
 
71
  "teacher_dtype",
72
  "mixed_precision",
73
  "kl_chunk_size",
74
+ "micro_batch_size",
75
  ),
76
  "eval": ("every_steps", "samples", "seed"),
77
  "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
 
581
 
582
  # ---- Train loop
583
  samples_per_step = cfg["train"]["samples_per_step"]
584
+ micro_batch_size = cfg["train"]["micro_batch_size"]
585
  grad_clip = cfg["train"]["grad_clip"]
586
  kl_start_pos = cfg["data"]["kl_start_pos"]
587
  kl_chunk_size = cfg["train"]["kl_chunk_size"]
 
592
  if accelerator.is_main_process:
593
  log.info(
594
  f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
595
+ f"(per rank, micro={micro_batch_size}), "
596
+ f"effective batch={samples_per_step * accelerator.num_processes}"
597
  )
598
 
599
  student.train()
 
607
  log.warning(f"rank {accelerator.process_index}: data exhausted")
608
  break
609
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  optimizer.zero_grad()
611
+ batch_n = len(batch)
612
+ kl_sum = 0.0
613
+ for mb_start in range(0, batch_n, micro_batch_size):
614
+ micro = batch[mb_start : mb_start + micro_batch_size]
615
+ mb_n = len(micro)
616
+ ids, mask = collate_pad(micro, pad_id)
617
+ ids = ids.to(accelerator.device)
618
+ mask = mask.to(accelerator.device)
619
+
620
+ with torch.no_grad():
621
+ t_logits = teacher_forward(teacher, ids, mask)
622
+ s_logits = student(input_ids=ids, attention_mask=mask).logits
623
+ loss = kl_loss_masked(
624
+ s_logits, t_logits, mask,
625
+ start_pos=kl_start_pos, chunk_size=kl_chunk_size,
626
+ )
627
+ # Weight by micro size so summing micros gives the batch mean
628
+ scaled = loss * (mb_n / batch_n)
629
+ accelerator.backward(scaled)
630
+ kl_sum += loss.item() * mb_n
631
+ del t_logits, s_logits, loss, scaled
632
+
633
  if grad_clip > 0:
634
  accelerator.clip_grad_norm_(student.parameters(), grad_clip)
635
  optimizer.step()
 
637
  global_step += 1
638
 
639
  elapsed = time.time() - t0
640
+ kl_local = torch.tensor(kl_sum / batch_n, device=accelerator.device)
641
  kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
642
+ del kl_local
643
 
644
  if accelerator.is_main_process and global_step % log_every == 0:
645
  lr_now = scheduler.get_last_lr()[0]