Spaces:
Running
on
Zero
Running
on
Zero
Fix multiprocessing dataloader checkpointing and use it in the train script (#50)
Browse files
bytelatent/args.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
import json
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
from typing import Any
|
| 6 |
|
| 7 |
-
import fsspec
|
| 8 |
import numpy as np
|
| 9 |
import yaml
|
| 10 |
from omegaconf import OmegaConf
|
|
|
|
| 1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
|
| 2 |
import logging
|
| 3 |
import os
|
| 4 |
from typing import Any
|
| 5 |
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import yaml
|
| 8 |
from omegaconf import OmegaConf
|
bytelatent/data/iterators/abstract_iterator.py
CHANGED
|
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
|
|
| 21 |
@abc.abstractmethod
|
| 22 |
def build(self) -> StatefulIterator[T, C]:
|
| 23 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
@abc.abstractmethod
|
| 22 |
def build(self) -> StatefulIterator[T, C]:
|
| 23 |
pass
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_state_and_refresh(iterator: StatefulIterator):
|
| 27 |
+
# Re-init dataloader and iterator is necessary since get_state()
|
| 28 |
+
# on mp iterator shuts down MP to correctly persist state and it needs
|
| 29 |
+
# to be restarted.
|
| 30 |
+
state = iterator.get_state()
|
| 31 |
+
data_loader = state.build()
|
| 32 |
+
py_iterator = data_loader.create_iter()
|
| 33 |
+
return state, data_loader, py_iterator
|
bytelatent/data/iterators/arrow_iterator.py
CHANGED
|
@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
|
|
| 60 |
return shard_number
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
class ArrowFileIterator(StatefulIterator):
|
| 64 |
def __init__(
|
| 65 |
self,
|
|
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
|
|
| 235 |
yield out
|
| 236 |
|
| 237 |
def _set_row_num(self, target_row_num: int):
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
)
|
| 241 |
if target_row_num is None or target_row_num == 0:
|
| 242 |
self.row_num = 0
|
| 243 |
self.dataset = None
|
|
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
|
|
| 285 |
else:
|
| 286 |
curr_remaining -= len(batch)
|
| 287 |
self.row_num = target_row_num
|
|
|
|
| 288 |
logger.info(
|
| 289 |
-
f"Finished setting arrow position to {target_row_num} for {
|
| 290 |
)
|
|
|
|
| 60 |
return shard_number
|
| 61 |
|
| 62 |
|
| 63 |
+
def maybe_truncate_string(text: str, max_length: int):
|
| 64 |
+
if len(text) <= max_length:
|
| 65 |
+
return text
|
| 66 |
+
else:
|
| 67 |
+
return text[:max_length] + "..."
|
| 68 |
+
|
| 69 |
+
|
| 70 |
class ArrowFileIterator(StatefulIterator):
|
| 71 |
def __init__(
|
| 72 |
self,
|
|
|
|
| 242 |
yield out
|
| 243 |
|
| 244 |
def _set_row_num(self, target_row_num: int):
|
| 245 |
+
data_str = maybe_truncate_string(str(self.dataset_files), 200)
|
| 246 |
+
logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
|
|
|
|
| 247 |
if target_row_num is None or target_row_num == 0:
|
| 248 |
self.row_num = 0
|
| 249 |
self.dataset = None
|
|
|
|
| 291 |
else:
|
| 292 |
curr_remaining -= len(batch)
|
| 293 |
self.row_num = target_row_num
|
| 294 |
+
data_str = maybe_truncate_string(str(self.dataset_files), 200)
|
| 295 |
logger.info(
|
| 296 |
+
f"Finished setting arrow position to {target_row_num} for {data_str}"
|
| 297 |
)
|
bytelatent/data/iterators/multiprocess_iterator.py
CHANGED
|
@@ -54,9 +54,10 @@ def start_work_from_state(
|
|
| 54 |
if stop_event.is_set():
|
| 55 |
# Signal the end of output, this ensures that even if the queue takes a while to
|
| 56 |
# buffer, that the main thread receives everything (and tosses this fake batch)
|
| 57 |
-
logging.
|
| 58 |
"Worker thread: Stop event detected, outputting is_final=True batch"
|
| 59 |
)
|
|
|
|
| 60 |
batch_queue.put(
|
| 61 |
Batch(
|
| 62 |
x=np.zeros((1, 1)),
|
|
@@ -67,14 +68,17 @@ def start_work_from_state(
|
|
| 67 |
ngram_ids=None,
|
| 68 |
)
|
| 69 |
)
|
|
|
|
|
|
|
|
|
|
| 70 |
break
|
| 71 |
|
| 72 |
try:
|
| 73 |
-
logging.
|
| 74 |
-
state_queue.put(
|
| 75 |
-
logging.
|
| 76 |
state_dumped_event.set()
|
| 77 |
-
logging.
|
| 78 |
except Full:
|
| 79 |
raise ValueError(
|
| 80 |
"Attempted to dump state into the state queue, but it was full"
|
|
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
|
|
| 156 |
serialized_prefetch_buffer=serialized_prefetch_buffer,
|
| 157 |
)
|
| 158 |
else:
|
| 159 |
-
logging.
|
| 160 |
self.stop_iterating_event.set()
|
| 161 |
-
logging.
|
| 162 |
-
|
|
|
|
| 163 |
self.prefetch_buffer = []
|
| 164 |
final_batch_received = False
|
| 165 |
while True:
|
| 166 |
try:
|
| 167 |
batch = self.batch_queue.get(timeout=1)
|
| 168 |
if batch.is_final:
|
|
|
|
|
|
|
|
|
|
| 169 |
final_batch_received = True
|
| 170 |
break
|
| 171 |
self.prefetch_buffer.append(batch)
|
|
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
|
|
| 173 |
logging.warning("Main thread: batch_queue is abnormally empty")
|
| 174 |
assert final_batch_received
|
| 175 |
|
|
|
|
|
|
|
|
|
|
| 176 |
try:
|
| 177 |
base_iterator_state = self.state_queue.get(timeout=1)
|
| 178 |
assert isinstance(base_iterator_state, IteratorState)
|
|
|
|
| 54 |
if stop_event.is_set():
|
| 55 |
# Signal the end of output, this ensures that even if the queue takes a while to
|
| 56 |
# buffer, that the main thread receives everything (and tosses this fake batch)
|
| 57 |
+
logging.debug(
|
| 58 |
"Worker thread: Stop event detected, outputting is_final=True batch"
|
| 59 |
)
|
| 60 |
+
logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
|
| 61 |
batch_queue.put(
|
| 62 |
Batch(
|
| 63 |
x=np.zeros((1, 1)),
|
|
|
|
| 68 |
ngram_ids=None,
|
| 69 |
)
|
| 70 |
)
|
| 71 |
+
logging.debug(
|
| 72 |
+
"Worker thread: is_final=True batch put in queue, breaking from loop."
|
| 73 |
+
)
|
| 74 |
break
|
| 75 |
|
| 76 |
try:
|
| 77 |
+
logging.debug("Worker thread: outputting state")
|
| 78 |
+
state_queue.put(stateful_iterator.get_state(), timeout=1)
|
| 79 |
+
logging.debug("Worker thread: state dump complete")
|
| 80 |
state_dumped_event.set()
|
| 81 |
+
logging.debug("Worker thread: set state_dump_event")
|
| 82 |
except Full:
|
| 83 |
raise ValueError(
|
| 84 |
"Attempted to dump state into the state queue, but it was full"
|
|
|
|
| 160 |
serialized_prefetch_buffer=serialized_prefetch_buffer,
|
| 161 |
)
|
| 162 |
else:
|
| 163 |
+
logging.debug("Main thread: Sending stop iteration event")
|
| 164 |
self.stop_iterating_event.set()
|
| 165 |
+
logging.debug(
|
| 166 |
+
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
|
| 167 |
+
)
|
| 168 |
self.prefetch_buffer = []
|
| 169 |
final_batch_received = False
|
| 170 |
while True:
|
| 171 |
try:
|
| 172 |
batch = self.batch_queue.get(timeout=1)
|
| 173 |
if batch.is_final:
|
| 174 |
+
logging.debug(
|
| 175 |
+
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
|
| 176 |
+
)
|
| 177 |
final_batch_received = True
|
| 178 |
break
|
| 179 |
self.prefetch_buffer.append(batch)
|
|
|
|
| 181 |
logging.warning("Main thread: batch_queue is abnormally empty")
|
| 182 |
assert final_batch_received
|
| 183 |
|
| 184 |
+
logging.debug("Main thread: Waiting for state_dumped event")
|
| 185 |
+
self.state_dumped_event.wait()
|
| 186 |
+
|
| 187 |
try:
|
| 188 |
base_iterator_state = self.state_queue.get(timeout=1)
|
| 189 |
assert isinstance(base_iterator_state, IteratorState)
|
bytelatent/train.py
CHANGED
|
@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
|
|
| 26 |
from bytelatent.args import TrainArgs, parse_args
|
| 27 |
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
|
| 28 |
from bytelatent.data.file_util import get_fs
|
|
|
|
| 29 |
from bytelatent.data.iterators.multiprocess_iterator import (
|
| 30 |
MultiprocessIterator,
|
| 31 |
MultiprocessIteratorState,
|
|
@@ -35,7 +36,6 @@ from bytelatent.distributed import (
|
|
| 35 |
check_model_value_range,
|
| 36 |
clean_env,
|
| 37 |
dist_mean,
|
| 38 |
-
dist_mean_dict,
|
| 39 |
dist_sum,
|
| 40 |
get_device_mesh,
|
| 41 |
get_is_master,
|
|
@@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state):
|
|
| 88 |
raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
|
| 89 |
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# TODO: Make this pydantic based instead of data class based
|
| 92 |
# TODO: Generalize this to any iterator state
|
| 93 |
@dataclass
|
|
@@ -603,20 +610,20 @@ def train(args: TrainArgs):
|
|
| 603 |
# step: Metric at a step
|
| 604 |
# interval: Metric averaged/summed across all steps since the last log interval.
|
| 605 |
# Typically, this is 10
|
| 606 |
-
step_loss_per_gpu = loss
|
| 607 |
-
step_loss_across_gpus = dist_mean(step_loss_per_gpu)
|
| 608 |
-
interval_loss_per_gpu = np.mean(step_losses)
|
| 609 |
-
interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)
|
| 610 |
|
| 611 |
stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
|
| 612 |
-
interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
|
| 613 |
interval_total_tok_loss_across_gpus = dist_sum(
|
| 614 |
interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
|
| 615 |
-
)
|
| 616 |
-
interval_total_n_bytes_per_gpu = n_bytes
|
| 617 |
interval_total_n_bytes_across_gpus = dist_sum(
|
| 618 |
n_bytes, reduce_dtype=torch.bfloat16
|
| 619 |
-
)
|
| 620 |
|
| 621 |
interval_bpb_per_gpu = (
|
| 622 |
interval_total_tok_loss_per_gpu
|
|
@@ -645,18 +652,20 @@ def train(args: TrainArgs):
|
|
| 645 |
},
|
| 646 |
"memory": gpu_mem_stats._asdict(),
|
| 647 |
"loss": {
|
| 648 |
-
"step_per_gpu": step_loss_per_gpu,
|
| 649 |
-
"step_across_gpu": step_loss_across_gpus,
|
| 650 |
-
"interval_per_gpu": interval_loss_per_gpu,
|
| 651 |
-
"interval_across_gpu": interval_loss_across_gpus,
|
| 652 |
},
|
| 653 |
"bpb": {
|
| 654 |
-
"interval_per_gpu": interval_bpb_per_gpu,
|
| 655 |
-
"interval_across_gpus": interval_bpb_across_gpus,
|
| 656 |
},
|
| 657 |
"n_bytes": {
|
| 658 |
-
"interval_per_gpu": interval_total_n_bytes_per_gpu,
|
| 659 |
-
"interval_across_gpus":
|
|
|
|
|
|
|
| 660 |
},
|
| 661 |
}
|
| 662 |
|
|
@@ -676,8 +685,8 @@ def train(args: TrainArgs):
|
|
| 676 |
logger.info(
|
| 677 |
f"step: {train_state.step}"
|
| 678 |
f" acc: {train_state.acc_step}"
|
| 679 |
-
f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
|
| 680 |
-
f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
|
| 681 |
f" bpb_gpu: {interval_bpb_per_gpu:3f}"
|
| 682 |
f" bpb_avg: {interval_bpb_across_gpus:3f}"
|
| 683 |
f" grad: {grad_norm:.2e}"
|
|
@@ -702,6 +711,9 @@ def train(args: TrainArgs):
|
|
| 702 |
if every_n_steps(
|
| 703 |
train_state, args.checkpoint.dump.every, acc_step=0
|
| 704 |
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
|
|
|
|
|
|
|
|
|
| 705 |
saved = checkpoint.save(
|
| 706 |
model,
|
| 707 |
optimizer,
|
|
@@ -743,6 +755,9 @@ def train(args: TrainArgs):
|
|
| 743 |
|
| 744 |
if preemption_flag["flag"]:
|
| 745 |
if not saved:
|
|
|
|
|
|
|
|
|
|
| 746 |
checkpoint.save(
|
| 747 |
model,
|
| 748 |
optimizer,
|
|
@@ -754,6 +769,9 @@ def train(args: TrainArgs):
|
|
| 754 |
sys.exit(0)
|
| 755 |
|
| 756 |
if not saved:
|
|
|
|
|
|
|
|
|
|
| 757 |
checkpoint.save(
|
| 758 |
model,
|
| 759 |
optimizer,
|
|
|
|
| 26 |
from bytelatent.args import TrainArgs, parse_args
|
| 27 |
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
|
| 28 |
from bytelatent.data.file_util import get_fs
|
| 29 |
+
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
|
| 30 |
from bytelatent.data.iterators.multiprocess_iterator import (
|
| 31 |
MultiprocessIterator,
|
| 32 |
MultiprocessIteratorState,
|
|
|
|
| 36 |
check_model_value_range,
|
| 37 |
clean_env,
|
| 38 |
dist_mean,
|
|
|
|
| 39 |
dist_sum,
|
| 40 |
get_device_mesh,
|
| 41 |
get_is_master,
|
|
|
|
| 88 |
raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
|
| 89 |
|
| 90 |
|
| 91 |
+
def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
|
| 92 |
+
if isinstance(num, (torch.Tensor, np.ndarray)):
|
| 93 |
+
return num.item()
|
| 94 |
+
else:
|
| 95 |
+
return num
|
| 96 |
+
|
| 97 |
+
|
| 98 |
# TODO: Make this pydantic based instead of data class based
|
| 99 |
# TODO: Generalize this to any iterator state
|
| 100 |
@dataclass
|
|
|
|
| 610 |
# step: Metric at a step
|
| 611 |
# interval: Metric averaged/summed across all steps since the last log interval.
|
| 612 |
# Typically, this is 10
|
| 613 |
+
step_loss_per_gpu = loss
|
| 614 |
+
step_loss_across_gpus = dist_mean(step_loss_per_gpu)
|
| 615 |
+
interval_loss_per_gpu = np.mean(step_losses)
|
| 616 |
+
interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)
|
| 617 |
|
| 618 |
stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
|
| 619 |
+
interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
|
| 620 |
interval_total_tok_loss_across_gpus = dist_sum(
|
| 621 |
interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
|
| 622 |
+
)
|
| 623 |
+
interval_total_n_bytes_per_gpu = n_bytes
|
| 624 |
interval_total_n_bytes_across_gpus = dist_sum(
|
| 625 |
n_bytes, reduce_dtype=torch.bfloat16
|
| 626 |
+
)
|
| 627 |
|
| 628 |
interval_bpb_per_gpu = (
|
| 629 |
interval_total_tok_loss_per_gpu
|
|
|
|
| 652 |
},
|
| 653 |
"memory": gpu_mem_stats._asdict(),
|
| 654 |
"loss": {
|
| 655 |
+
"step_per_gpu": to_py_num(step_loss_per_gpu),
|
| 656 |
+
"step_across_gpu": to_py_num(step_loss_across_gpus),
|
| 657 |
+
"interval_per_gpu": to_py_num(interval_loss_per_gpu),
|
| 658 |
+
"interval_across_gpu": to_py_num(interval_loss_across_gpus),
|
| 659 |
},
|
| 660 |
"bpb": {
|
| 661 |
+
"interval_per_gpu": to_py_num(interval_bpb_per_gpu),
|
| 662 |
+
"interval_across_gpus": to_py_num(interval_bpb_across_gpus),
|
| 663 |
},
|
| 664 |
"n_bytes": {
|
| 665 |
+
"interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
|
| 666 |
+
"interval_across_gpus": to_py_num(
|
| 667 |
+
interval_total_n_bytes_across_gpus
|
| 668 |
+
),
|
| 669 |
},
|
| 670 |
}
|
| 671 |
|
|
|
|
| 685 |
logger.info(
|
| 686 |
f"step: {train_state.step}"
|
| 687 |
f" acc: {train_state.acc_step}"
|
| 688 |
+
f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
|
| 689 |
+
f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
|
| 690 |
f" bpb_gpu: {interval_bpb_per_gpu:3f}"
|
| 691 |
f" bpb_avg: {interval_bpb_across_gpus:3f}"
|
| 692 |
f" grad: {grad_norm:.2e}"
|
|
|
|
| 711 |
if every_n_steps(
|
| 712 |
train_state, args.checkpoint.dump.every, acc_step=0
|
| 713 |
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
| 714 |
+
train_state.data_loader_state, data_loader, batch_iterator = (
|
| 715 |
+
get_state_and_refresh(data_loader)
|
| 716 |
+
)
|
| 717 |
saved = checkpoint.save(
|
| 718 |
model,
|
| 719 |
optimizer,
|
|
|
|
| 755 |
|
| 756 |
if preemption_flag["flag"]:
|
| 757 |
if not saved:
|
| 758 |
+
train_state.data_loader_state, data_loader, batch_iterator = (
|
| 759 |
+
get_state_and_refresh(data_loader)
|
| 760 |
+
)
|
| 761 |
checkpoint.save(
|
| 762 |
model,
|
| 763 |
optimizer,
|
|
|
|
| 769 |
sys.exit(0)
|
| 770 |
|
| 771 |
if not saved:
|
| 772 |
+
train_state.data_loader_state, data_loader, batch_iterator = (
|
| 773 |
+
get_state_and_refresh(data_loader)
|
| 774 |
+
)
|
| 775 |
checkpoint.save(
|
| 776 |
model,
|
| 777 |
optimizer,
|