add micro_batch_size config key + per-micro inner loop in train step (fixes OOM for fp32+seq2048)
Browse files- configs/base.toml +1 -0
- configs/grow40_simple.toml +1 -0
- configs/grow40_winning.toml +1 -0
- configs/replicate_zero4.toml +1 -0
- configs/zero_14_17.toml +1 -0
- distill.py +28 -16
configs/base.toml
CHANGED
|
@@ -25,6 +25,7 @@ grad_clip = 1.0
|
|
| 25 |
betas = [0.9, 0.95]
|
| 26 |
eps = 1.0e-8
|
| 27 |
samples_per_step = 4
|
|
|
|
| 28 |
max_steps = 5
|
| 29 |
grad_checkpointing = true
|
| 30 |
attn_implementation = "flash_attention_2"
|
|
|
|
| 25 |
betas = [0.9, 0.95]
|
| 26 |
eps = 1.0e-8
|
| 27 |
samples_per_step = 4
|
| 28 |
+
micro_batch_size = 4
|
| 29 |
max_steps = 5
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
configs/grow40_simple.toml
CHANGED
|
@@ -26,6 +26,7 @@ grad_clip = 1.0
|
|
| 26 |
betas = [0.9, 0.95]
|
| 27 |
eps = 1.0e-8
|
| 28 |
samples_per_step = 8
|
|
|
|
| 29 |
max_steps = 2000
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
|
|
|
| 26 |
betas = [0.9, 0.95]
|
| 27 |
eps = 1.0e-8
|
| 28 |
samples_per_step = 8
|
| 29 |
+
micro_batch_size = 8
|
| 30 |
max_steps = 2000
|
| 31 |
grad_checkpointing = true
|
| 32 |
attn_implementation = "flash_attention_2"
|
configs/grow40_winning.toml
CHANGED
|
@@ -26,6 +26,7 @@ grad_clip = 1.0
|
|
| 26 |
betas = [0.9, 0.999]
|
| 27 |
eps = 1.0e-3
|
| 28 |
samples_per_step = 4
|
|
|
|
| 29 |
max_steps = 2000
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
|
|
|
| 26 |
betas = [0.9, 0.999]
|
| 27 |
eps = 1.0e-3
|
| 28 |
samples_per_step = 4
|
| 29 |
+
micro_batch_size = 1
|
| 30 |
max_steps = 2000
|
| 31 |
grad_checkpointing = true
|
| 32 |
attn_implementation = "flash_attention_2"
|
configs/replicate_zero4.toml
CHANGED
|
@@ -25,6 +25,7 @@ grad_clip = 1.0
|
|
| 25 |
betas = [0.9, 0.999]
|
| 26 |
eps = 1.0e-3
|
| 27 |
samples_per_step = 4
|
|
|
|
| 28 |
max_steps = 2000
|
| 29 |
grad_checkpointing = true
|
| 30 |
attn_implementation = "flash_attention_2"
|
|
|
|
| 25 |
betas = [0.9, 0.999]
|
| 26 |
eps = 1.0e-3
|
| 27 |
samples_per_step = 4
|
| 28 |
+
micro_batch_size = 1
|
| 29 |
max_steps = 2000
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
configs/zero_14_17.toml
CHANGED
|
@@ -26,6 +26,7 @@ grad_clip = 1.0
|
|
| 26 |
betas = [0.9, 0.95]
|
| 27 |
eps = 1.0e-8
|
| 28 |
samples_per_step = 8
|
|
|
|
| 29 |
max_steps = 2000
|
| 30 |
grad_checkpointing = true
|
| 31 |
attn_implementation = "flash_attention_2"
|
|
|
|
| 26 |
betas = [0.9, 0.95]
|
| 27 |
eps = 1.0e-8
|
| 28 |
samples_per_step = 8
|
| 29 |
+
micro_batch_size = 8
|
| 30 |
max_steps = 2000
|
| 31 |
grad_checkpointing = true
|
| 32 |
attn_implementation = "flash_attention_2"
|
distill.py
CHANGED
|
@@ -71,6 +71,7 @@ REQUIRED_KEYS = {
|
|
| 71 |
"teacher_dtype",
|
| 72 |
"mixed_precision",
|
| 73 |
"kl_chunk_size",
|
|
|
|
| 74 |
),
|
| 75 |
"eval": ("every_steps", "samples", "seed"),
|
| 76 |
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
|
|
@@ -580,6 +581,7 @@ def main():
|
|
| 580 |
|
| 581 |
# ---- Train loop
|
| 582 |
samples_per_step = cfg["train"]["samples_per_step"]
|
|
|
|
| 583 |
grad_clip = cfg["train"]["grad_clip"]
|
| 584 |
kl_start_pos = cfg["data"]["kl_start_pos"]
|
| 585 |
kl_chunk_size = cfg["train"]["kl_chunk_size"]
|
|
@@ -590,7 +592,8 @@ def main():
|
|
| 590 |
if accelerator.is_main_process:
|
| 591 |
log.info(
|
| 592 |
f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
|
| 593 |
-
f"(per rank
|
|
|
|
| 594 |
)
|
| 595 |
|
| 596 |
student.train()
|
|
@@ -604,20 +607,29 @@ def main():
|
|
| 604 |
log.warning(f"rank {accelerator.process_index}: data exhausted")
|
| 605 |
break
|
| 606 |
|
| 607 |
-
ids, mask = collate_pad(batch, pad_id)
|
| 608 |
-
ids = ids.to(accelerator.device)
|
| 609 |
-
mask = mask.to(accelerator.device)
|
| 610 |
-
|
| 611 |
-
with torch.no_grad():
|
| 612 |
-
t_logits = teacher_forward(teacher, ids, mask)
|
| 613 |
-
s_logits = student(input_ids=ids, attention_mask=mask).logits
|
| 614 |
-
loss = kl_loss_masked(
|
| 615 |
-
s_logits, t_logits, mask,
|
| 616 |
-
start_pos=kl_start_pos, chunk_size=kl_chunk_size,
|
| 617 |
-
)
|
| 618 |
-
|
| 619 |
optimizer.zero_grad()
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
if grad_clip > 0:
|
| 622 |
accelerator.clip_grad_norm_(student.parameters(), grad_clip)
|
| 623 |
optimizer.step()
|
|
@@ -625,9 +637,9 @@ def main():
|
|
| 625 |
global_step += 1
|
| 626 |
|
| 627 |
elapsed = time.time() - t0
|
| 628 |
-
kl_local =
|
| 629 |
kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
|
| 630 |
-
del
|
| 631 |
|
| 632 |
if accelerator.is_main_process and global_step % log_every == 0:
|
| 633 |
lr_now = scheduler.get_last_lr()[0]
|
|
|
|
| 71 |
"teacher_dtype",
|
| 72 |
"mixed_precision",
|
| 73 |
"kl_chunk_size",
|
| 74 |
+
"micro_batch_size",
|
| 75 |
),
|
| 76 |
"eval": ("every_steps", "samples", "seed"),
|
| 77 |
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
|
|
|
|
| 581 |
|
| 582 |
# ---- Train loop
|
| 583 |
samples_per_step = cfg["train"]["samples_per_step"]
|
| 584 |
+
micro_batch_size = cfg["train"]["micro_batch_size"]
|
| 585 |
grad_clip = cfg["train"]["grad_clip"]
|
| 586 |
kl_start_pos = cfg["data"]["kl_start_pos"]
|
| 587 |
kl_chunk_size = cfg["train"]["kl_chunk_size"]
|
|
|
|
| 592 |
if accelerator.is_main_process:
|
| 593 |
log.info(
|
| 594 |
f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
|
| 595 |
+
f"(per rank, micro={micro_batch_size}), "
|
| 596 |
+
f"effective batch={samples_per_step * accelerator.num_processes}"
|
| 597 |
)
|
| 598 |
|
| 599 |
student.train()
|
|
|
|
| 607 |
log.warning(f"rank {accelerator.process_index}: data exhausted")
|
| 608 |
break
|
| 609 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
optimizer.zero_grad()
|
| 611 |
+
batch_n = len(batch)
|
| 612 |
+
kl_sum = 0.0
|
| 613 |
+
for mb_start in range(0, batch_n, micro_batch_size):
|
| 614 |
+
micro = batch[mb_start : mb_start + micro_batch_size]
|
| 615 |
+
mb_n = len(micro)
|
| 616 |
+
ids, mask = collate_pad(micro, pad_id)
|
| 617 |
+
ids = ids.to(accelerator.device)
|
| 618 |
+
mask = mask.to(accelerator.device)
|
| 619 |
+
|
| 620 |
+
with torch.no_grad():
|
| 621 |
+
t_logits = teacher_forward(teacher, ids, mask)
|
| 622 |
+
s_logits = student(input_ids=ids, attention_mask=mask).logits
|
| 623 |
+
loss = kl_loss_masked(
|
| 624 |
+
s_logits, t_logits, mask,
|
| 625 |
+
start_pos=kl_start_pos, chunk_size=kl_chunk_size,
|
| 626 |
+
)
|
| 627 |
+
# Weight by micro size so summing micros gives the batch mean
|
| 628 |
+
scaled = loss * (mb_n / batch_n)
|
| 629 |
+
accelerator.backward(scaled)
|
| 630 |
+
kl_sum += loss.item() * mb_n
|
| 631 |
+
del t_logits, s_logits, loss, scaled
|
| 632 |
+
|
| 633 |
if grad_clip > 0:
|
| 634 |
accelerator.clip_grad_norm_(student.parameters(), grad_clip)
|
| 635 |
optimizer.step()
|
|
|
|
| 637 |
global_step += 1
|
| 638 |
|
| 639 |
elapsed = time.time() - t0
|
| 640 |
+
kl_local = torch.tensor(kl_sum / batch_n, device=accelerator.device)
|
| 641 |
kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
|
| 642 |
+
del kl_local
|
| 643 |
|
| 644 |
if accelerator.is_main_process and global_step % log_every == 0:
|
| 645 |
lr_now = scheduler.get_last_lr()[0]
|