# Source: lingbot-vla/lingbotvla/data/dynamic_batching.py
# (uploaded by bazaar-research via huggingface_hub; revision fb11af9)
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import sys
import traceback
from collections import deque
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterator, Optional
from ..utils import logging
# Module-level logger obtained from the project's own logging utilities (``..utils.logging``).
logger = logging.get_logger(__name__)

# BaseBatchingStrategy is only referenced in string annotations in this module, so it is
# imported for static type checkers only (presumably to avoid a runtime import cycle -- TODO confirm).
if TYPE_CHECKING:
    from .batching_strategy import BaseBatchingStrategy
class DynamicBatchSizeDataLoader:
    """Dynamic batch DataLoader.

    Pulls raw items from ``dataloader``, feeds them to ``batching_strategy`` and, each time
    the strategy reports it is full, takes one micro-batch from it. Every value yielded is a
    list of ``num_micro_batch`` micro-batches (e.g. one gradient-accumulation step).

    Args:
        dataloader: torch DataLoader (any re-iterable object). If it exposes
            ``state_dict``/``load_state_dict`` those are used for checkpointing.
        batching_strategy: dynamic batch strategy; must provide ``put_item``,
            ``is_full_filled``, ``get_micro_batch(step)`` and ``empty``.
        collate_fn: DataLoader collate_fn, applied to each micro-batch after it is taken
            from ``batching_strategy``.
        num_micro_batch: number of micro-batches grouped into each yielded list (>= 1).
        length: number of steps (yielded lists) before iteration stops. ``-1`` means
            ``sys.maxsize``; any other value <= 0 (the default ``0``) means
            ``len(dataloader)``. The dataloader is re-iterated while fewer than ``length``
            steps have been produced.
        drop_last: if True, micro-batches still held by the strategy when the data runs
            out are dropped. If False they are yielded, and a trailing incomplete list is
            padded up to ``num_micro_batch`` with deep copies of its last micro-batch
            carrying ``padding_flag=True`` (so micro-batches must support item assignment).
    """

    # Sentinel returned by ``_next_dataloader_item`` when no more data can be obtained.
    # Class-level, so it is never part of ``state_dict()``.
    _EXHAUSTED = object()

    def __init__(
        self,
        dataloader: Any,
        batching_strategy: "BaseBatchingStrategy",
        collate_fn: Optional[Callable] = None,
        num_micro_batch: int = 1,
        length: int = 0,
        drop_last: bool = True,
    ) -> None:
        if num_micro_batch < 1:
            # A group size below 1 is never reached by appending, so the generator would
            # buffer forever instead of yielding.
            raise ValueError(f"num_micro_batch must be >= 1, got {num_micro_batch}")
        # Public attributes below are what state_dict() saves; "_"-prefixed ones are not.
        self.batching_strategy = batching_strategy
        self.num_micro_batch = num_micro_batch
        self.dataloader_item_buffer = deque()
        self.item_buffer = deque()
        self.step = 0
        self._collate_fn = collate_fn
        self._dataloader = dataloader
        self._drop_last = drop_last
        self._data_iter: Iterator
        self._resume = False
        self._batch_data_iter: Generator
        if length > 0:
            self._length = length
        elif length == -1:
            self._length = sys.maxsize
        else:
            self._length = len(self._dataloader)

    def __len__(self):
        """Return the configured number of steps; raise RuntimeError if it resolved to 0."""
        if self._length:
            return self._length
        else:
            raise RuntimeError("length must set at init. before call len()")

    def __iter__(self) -> Iterator:
        """Start a fresh pass, or continue from the position restored by ``load_state_dict``."""
        if not self._resume:
            self.step = 0
            self._data_iter = iter(self._dataloader)
            self._batch_data_iter = self.batch_data_generator()
        # A restored position is consumed by exactly one call to __iter__.
        self._resume = False
        return self

    def __next__(self):
        return next(self._batch_data_iter)

    def _take_micro_batch(self) -> Any:
        """Take one micro-batch from the strategy (tagged with the current step) and collate it."""
        micro_batch = self.batching_strategy.get_micro_batch(self.step)
        if self._collate_fn:
            micro_batch = self._collate_fn(micro_batch)
        return micro_batch

    def _next_dataloader_item(self) -> Any:
        """Return the next raw dataloader item, or ``_EXHAUSTED`` when no more data comes.

        While fewer than ``length`` steps have been produced, an exhausted dataloader is
        re-iterated once; if that fresh pass is empty too, the data counts as exhausted.
        (A bare ``next()`` here would leak StopIteration out of the generator, which
        PEP 479 turns into an opaque RuntimeError.)
        """
        try:
            item = next(self._data_iter, self._EXHAUSTED)
            if item is self._EXHAUSTED and self.step < self._length:
                # call iter until reach length
                self._data_iter = iter(self._dataloader)
                item = next(self._data_iter, self._EXHAUSTED)
        except Exception as e:
            logger.error(f"DynamicBatchDataset iter data exception: {e} \n{traceback.format_exc()}")
            raise
        return item

    def _flush_remaining(self, batch: list) -> Generator:
        """Drain the strategy after the data is exhausted (``drop_last=False`` only).

        ``batch`` holds micro-batches already collected for the current step. Full groups
        are yielded as they complete; a trailing partial group is padded with deep copies
        of its last micro-batch flagged ``padding_flag=True``. Nothing is yielded when
        nothing is left, so a batch made purely of padding is never produced.
        """
        while not self.batching_strategy.empty():
            batch.append(self._take_micro_batch())
            if len(batch) == self.num_micro_batch:
                self.step += 1
                yield batch
                batch = []
        if batch:
            template = batch[-1]
            while len(batch) < self.num_micro_batch:
                padding_batch = copy.deepcopy(template)
                padding_batch["padding_flag"] = True
                batch.append(padding_batch)
            self.step += 1
            yield batch

    def batch_data_generator(self) -> Generator:
        """Yield lists of ``num_micro_batch`` (collated) micro-batches until ``length`` steps.

        ``self.step`` is incremented right before each yield, so it always equals the number
        of lists already handed to the consumer; a ``state_dict()`` taken between steps
        therefore agrees with the strategy/dataloader state it is saved alongside.
        """
        batch = []
        while True:
            if self._length and self.step >= self._length:
                return

            if self.batching_strategy.is_full_filled():
                batch.append(self._take_micro_batch())
                if len(batch) == self.num_micro_batch:
                    self.step += 1
                    yield batch
                    # Rebind rather than clear: the consumer still owns the yielded list.
                    batch = []

            processing_item = self._next_dataloader_item()
            if processing_item is self._EXHAUSTED:
                if not self._drop_last:
                    yield from self._flush_remaining(batch)
                return

            # put processing_item to buffer: a dict is one item, anything else is
            # iterated as a collection of items
            if isinstance(processing_item, dict):
                processing_item = [processing_item]
            for item in processing_item:
                self.batching_strategy.put_item(item)

    def state_dict(self) -> Dict[str, Any]:
        """Return a deep-copied snapshot of the loader's resumable state.

        Contains every public (non-underscore) instance attribute, plus ``dataloader_state``
        and, when the strategy supports ``state_dict``, ``batching_strategy_state`` (in that
        case the live ``batching_strategy`` object itself is left out).
        """
        # save state
        state = self.__dict__.copy()
        # remove internal fields
        for k in list(state.keys()):
            if k.startswith("_"):
                del state[k]
        # save dataloader state; note that on Python 3.11+ every object has __getstate__,
        # so the fallback applies to any dataloader without state_dict
        if hasattr(self._dataloader, "state_dict"):
            state["dataloader_state"] = self._dataloader.state_dict()
        elif hasattr(self._dataloader, "__getstate__"):
            state["dataloader_state"] = self._dataloader.__getstate__()
        if hasattr(self.batching_strategy, "state_dict"):
            state["batching_strategy_state"] = self.batching_strategy.state_dict()  # type: ignore
            del state["batching_strategy"]
        return copy.deepcopy(state)

    def load_state_dict(self, state: Dict[str, Any]):
        """Restore a snapshot produced by ``state_dict()``; the next ``__iter__`` resumes from it.

        The caller's ``state`` dict is not modified. If ``num_micro_batch`` differs from the
        current value, the current value is kept and a warning is logged.

        Raises:
            KeyError: if ``num_micro_batch`` is missing, or the dataloader has
                ``load_state_dict`` but ``dataloader_state`` is missing.
        """
        # Work on a shallow copy so the caller's checkpoint dict stays intact.
        state = dict(state)
        if state["num_micro_batch"] != self.num_micro_batch:
            logger.warning(
                f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
            )
            del state["num_micro_batch"]
        # Sub-component states go to their owners, not onto this instance as attributes.
        sub_states = {
            key: state.pop(key)
            for key in ("dataloader_state", "batching_strategy_state")
            if key in state
        }
        self.__dict__.update(state)
        self._resume = True
        if hasattr(self._dataloader, "load_state_dict"):
            self._dataloader.load_state_dict(sub_states["dataloader_state"])
        elif "dataloader_state" in sub_states and hasattr(self._dataloader, "__setstate__"):
            # Check the method actually called: __getstate__ exists on every object in
            # Python 3.11+, __setstate__ does not.
            self._dataloader.__setstate__(sub_states["dataloader_state"])
        if "batching_strategy_state" in sub_states:
            self.batching_strategy.load_state_dict(  # type: ignore
                sub_states["batching_strategy_state"]
            )
        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()