# Source: lingbot-vla/lingbotvla/data/data_loader.py
# Uploaded by bazaar-research using huggingface_hub (commit fb11af9, verified).
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ..distributed.parallel_state import get_parallel_state
from ..utils import logging
from .batching_strategy import TextBatchingStrategy
from .data_collator import (
CollatePipeline,
DataCollatorWithPacking,
DataCollatorWithPadding,
DataCollatorWithPositionIDs,
MakeMicroBatchCollator,
TextSequenceShardCollator,
UnpackDataCollator,
)
from .dynamic_batching import DynamicBatchSizeDataLoader
if TYPE_CHECKING:
from torch.utils.data import Dataset
# Module-level logger obtained from the project's logging utilities; used below via `info_rank0`.
logger = logging.get_logger(__name__)
class DistributedDataloader(StatefulDataLoader):
    """`StatefulDataLoader` subclass that can forward the current epoch.

    The epoch is handed to the sampler when it offers ``set_epoch``; otherwise
    the dataset is tried as a fallback.
    """

    dataset: "Dataset"
    sampler: "StatefulDistributedSampler"

    def set_epoch(self, epoch: int) -> None:
        """Forward ``epoch`` to the sampler, or else to the dataset.

        The sampler takes precedence: the dataset is only consulted when there
        is no sampler or the sampler lacks a ``set_epoch`` method. If neither
        object provides ``set_epoch``, the call does nothing.

        Args:
            epoch: Epoch index passed through unchanged to ``set_epoch``.
        """
        epoch_sampler = self.sampler
        if epoch_sampler is not None and hasattr(epoch_sampler, "set_epoch"):
            epoch_sampler.set_epoch(epoch)
            return

        epoch_dataset = self.dataset
        if hasattr(epoch_dataset, "set_epoch"):
            epoch_dataset.set_epoch(epoch)
def build_dataloader(
    dataset: "Dataset",
    micro_batch_size: int,
    global_batch_size: int,
    dataloader_batch_size: int,
    max_seq_len: int,
    train_steps: int,
    rmpad: bool = True,
    rmpad_with_pos_ids: bool = False,
    bsz_warmup_ratio: float = 0.02,
    bsz_warmup_init_mbtoken: int = 200,
    dyn_bsz_buffer_size: int = 500,
    dyn_bsz_margin: int = 0,
    collate_fn: Optional[Union[Callable, List[Callable]]] = None,
    num_workers: int = 8,
    drop_last: bool = True,
    pin_memory: bool = True,
    prefetch_factor: Optional[int] = 2,
    seed: int = 0,
) -> "DistributedDataloader":
    """Build the training dataloader for the current data-parallel rank.

    A `DistributedDataloader` is always created. When remove-padding is enabled
    (``rmpad`` or ``rmpad_with_pos_ids``), it is additionally wrapped in a
    `DynamicBatchSizeDataLoader` driven by a `TextBatchingStrategy`, and the
    wrapper is what gets returned.

    Args:
        dataset: Dataset to read from. Any dataset that is not an
            `IterableDataset` gets a shuffling `StatefulDistributedSampler`
            sharded over ``dp_size`` replicas; iterable datasets get no sampler.
        micro_batch_size: Samples per micro batch. Multiplied by
            ``max_seq_len`` to obtain the per-micro-batch token budget.
        global_batch_size: Global batch size. The number of micro batches is
            ``global_batch_size // (micro_batch_size * dp_size)`` (floor
            division, so any remainder is dropped).
        dataloader_batch_size: ``batch_size`` given to the underlying
            `DistributedDataloader`.
        max_seq_len: Maximum sequence length used for the token budget.
        train_steps: Number of training steps. Used to derive the warmup step
            count and passed as ``length`` to the dynamic-batch wrapper.
        rmpad: Select `DataCollatorWithPacking` as the default collator and
            enable dynamic batching.
        rmpad_with_pos_ids: Select `DataCollatorWithPositionIDs` as the default
            collator (takes precedence over ``rmpad``) and enable dynamic
            batching.
        bsz_warmup_ratio: Fraction of ``train_steps`` used as batch-size warmup
            steps (truncated to an int). A result of 0 is passed on as ``-1``.
        bsz_warmup_init_mbtoken: Forwarded to `TextBatchingStrategy`.
        dyn_bsz_buffer_size: ``buffer_size`` of the `TextBatchingStrategy`.
        dyn_bsz_margin: Reduces the token budget by
            ``dyn_bsz_margin * max_seq_len``.
        collate_fn: Custom collator. A list is wrapped in a `CollatePipeline`;
            ``None`` builds a default pipeline from the ``rmpad`` flags (plus a
            `TextSequenceShardCollator` when sequence parallelism is enabled).
        num_workers: Number of dataloader worker processes.
        drop_last: Forwarded to the dataloader and to the dynamic-batch wrapper.
        pin_memory: Forwarded to the dataloader.
        prefetch_factor: Per-worker prefetch depth. Ignored (treated as
            ``None``) when ``num_workers`` is 0, because torch-style
            dataloaders reject a prefetch factor without worker processes.
        seed: Seed for the distributed sampler's shuffling.

    Returns:
        The `DistributedDataloader`, or — when remove-padding is enabled — the
        `DynamicBatchSizeDataLoader` that wraps it.
    """
    parallel_state = get_parallel_state()
    token_micro_bsz = micro_batch_size * max_seq_len
    num_micro_batch = global_batch_size // (
        micro_batch_size * parallel_state.dp_size
    )  # num_micro_batch = num accumulation steps
    bsz_warmup_steps = int(train_steps * bsz_warmup_ratio)
    use_rmpad = rmpad or rmpad_with_pos_ids

    logger.info_rank0(
        f"train_steps: {train_steps}, max_seq_len: {max_seq_len}, use_rmpad: {use_rmpad}, "
        f"bsz_warmup_steps: {bsz_warmup_steps}, bsz_warmup_init_mbtoken: {bsz_warmup_init_mbtoken}, "
        f"token_micro_bsz: {token_micro_bsz}, num_micro_batch: {num_micro_batch}, "
        f"micro_batch_size: {micro_batch_size}, global_batch_size: {global_batch_size}, "
        f"dp_size: {parallel_state.dp_size}, sp_size: {parallel_state.sp_size}."
    )

    # Build the default collate pipeline when the caller did not supply one.
    if collate_fn is None:
        collate_fn_list = []
        if rmpad_with_pos_ids:
            collate_fn_list.append(DataCollatorWithPositionIDs())
        elif rmpad:
            collate_fn_list.append(DataCollatorWithPacking())
        else:
            collate_fn_list.append(DataCollatorWithPadding())

        if parallel_state.sp_enabled:
            collate_fn_list.append(TextSequenceShardCollator(rmpad=rmpad, rmpad_with_pos_ids=rmpad_with_pos_ids))

        collate_fn = CollatePipeline(collate_fn_list)

    if isinstance(collate_fn, list):
        collate_fn = CollatePipeline(collate_fn)

    if use_rmpad:
        batching_strategy = TextBatchingStrategy(
            token_micro_bsz=token_micro_bsz - dyn_bsz_margin * max_seq_len,
            buffer_size=dyn_bsz_buffer_size,
            bsz_warmup_steps=bsz_warmup_steps if bsz_warmup_steps else -1,
            bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
        )
        # The real collator runs inside the dynamic-batch wrapper; the inner
        # dataloader only unpacks its batches.
        dyn_bsz_collate_fn = collate_fn
        collate_fn = UnpackDataCollator()
    else:
        collate_fn = MakeMicroBatchCollator(num_micro_batch=num_micro_batch, internal_data_collator=collate_fn)

    sampler = None
    if not isinstance(dataset, IterableDataset):
        sampler = StatefulDistributedSampler(
            dataset,
            num_replicas=parallel_state.dp_size,
            rank=parallel_state.dp_rank,
            shuffle=True,
            seed=seed,
        )

    # Torch-style dataloaders raise ValueError when prefetch_factor is set while
    # loading in the main process (num_workers == 0). Since prefetch_factor
    # defaults to 2 here, drop it in that case so num_workers=0 stays usable.
    effective_prefetch_factor = prefetch_factor if num_workers > 0 else None

    dataloader = DistributedDataloader(
        dataset,
        batch_size=dataloader_batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        prefetch_factor=effective_prefetch_factor,
    )

    if use_rmpad:
        dataloader = DynamicBatchSizeDataLoader(
            dataloader,
            batching_strategy=batching_strategy,
            collate_fn=dyn_bsz_collate_fn,
            num_micro_batch=num_micro_batch,
            length=train_steps,
            drop_last=drop_last,
        )

    return dataloader