| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import copy |
| import sys |
| import traceback |
| from collections import deque |
| from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterator, Optional |
|
|
| from ..utils import logging |
|
|
|
|
# Module-level logger obtained from the package's logging utilities and named
# after this module.
logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    # Imported for static type checking only: the loader class refers to this
    # type through a string annotation, so it is not needed at runtime.
    from .batching_strategy import BaseBatchingStrategy
|
|
|
|
class DynamicBatchSizeDataLoader:
    """Iterator that regroups the items of a dataloader into dynamic batches.

    Items pulled from ``dataloader`` are handed to ``batching_strategy``.
    Whenever the strategy reports that it is full, one micro-batch is taken
    from it (and passed through ``collate_fn`` when one is given).  Every
    ``num_micro_batch`` micro-batches are yielded together as one list, and
    ``step`` counts the lists that have been consumed so far.

    When the underlying dataloader is exhausted before ``length`` steps were
    produced, it is restarted from the beginning.

    Args:
        dataloader: iterable producing either a single ``dict`` item or an
            iterable of items per element (e.g. a torch DataLoader).
        batching_strategy: object providing ``put_item``, ``is_full_filled``,
            ``get_micro_batch`` and ``empty``; optionally ``state_dict`` /
            ``load_state_dict``.
        collate_fn: optional callable applied to every micro-batch taken from
            the batching strategy.
        num_micro_batch: number of micro-batches contained in each yielded
            batch (e.g. for gradient accumulation).
        length: number of steps to produce.  A positive value is used as is,
            ``-1`` means ``sys.maxsize``, anything else falls back to
            ``len(dataloader)``.
        drop_last: if False, micro-batches still buffered when the data runs
            out are emitted too, and a final incomplete batch is completed
            with copies of its last micro-batch flagged by
            ``"padding_flag": True``.  If True, they are dropped.
    """

    def __init__(
        self,
        dataloader: Any,
        batching_strategy: "BaseBatchingStrategy",
        collate_fn: Optional[Callable] = None,
        num_micro_batch: int = 1,
        length: int = 0,
        drop_last: bool = True,
    ) -> None:
        # Public attributes are what ``state_dict`` checkpoints.
        self.batching_strategy = batching_strategy
        self.num_micro_batch = num_micro_batch
        self.dataloader_item_buffer = deque()
        self.item_buffer = deque()
        self.step = 0
        # Underscore-prefixed attributes are excluded from ``state_dict``.
        self._collate_fn = collate_fn
        self._dataloader = dataloader
        self._drop_last = drop_last
        self._data_iter: Iterator  # created by __iter__ / load_state_dict
        self._resume = False
        self._batch_data_iter: Generator  # created by __iter__ / load_state_dict

        if length > 0:
            self._length = length
        elif length == -1:
            self._length = sys.maxsize
        else:
            self._length = len(self._dataloader)

    def __len__(self):
        """Return the configured number of steps.

        Raises:
            RuntimeError: if the resolved length is 0.
        """
        if self._length:
            return self._length
        else:
            raise RuntimeError("length must set at init. before call len()")

    def __iter__(self) -> Iterator:
        """Start a new pass, or continue the restored one after ``load_state_dict``."""
        if not self._resume:
            self.step = 0
            self._data_iter = iter(self._dataloader)
            self._batch_data_iter = self.batch_data_generator()
        # A restored position is consumed once; the next ``iter()`` starts over.
        self._resume = False
        return self

    def __next__(self):
        """Return the next batch (a list of ``num_micro_batch`` micro-batches)."""
        return next(self._batch_data_iter)

    def _collated_micro_batch(self):
        """Take one micro-batch from the strategy and collate it if configured."""
        micro_batch = self.batching_strategy.get_micro_batch(self.step)
        if self._collate_fn:
            micro_batch = self._collate_fn(micro_batch)
        return micro_batch

    def _drain_tail(self, batch):
        """Yield whatever is still buffered once no more data can be fetched.

        ``batch`` is the partially assembled batch at the moment the data ran
        out.  Only a non-empty trailing batch is padded, so a batch made
        purely of padding copies is never emitted.
        """
        while not self.batching_strategy.empty():
            batch.append(self._collated_micro_batch())
            if len(batch) == self.num_micro_batch:
                yield batch
                self.step += 1
                batch = []

        if batch:
            # Complete the last batch with flagged copies of its final
            # micro-batch.  Assumes a micro-batch supports item assignment
            # (e.g. is a dict).
            template = batch[-1]
            while len(batch) < self.num_micro_batch:
                padding_batch = copy.deepcopy(template)
                padding_batch["padding_flag"] = True
                batch.append(padding_batch)
            yield batch
            self.step += 1

    def batch_data_generator(self):
        """Generate batches until ``length`` steps were produced or data ends.

        Yields:
            Lists of ``num_micro_batch`` micro-batches.

        Raises:
            Exception: any non-``StopIteration`` error raised while fetching
                from the dataloader is logged and re-raised.
        """
        batch = []

        while True:
            if self._length and self.step >= self._length:
                return

            if self.batching_strategy.is_full_filled():
                batch.append(self._collated_micro_batch())
                if len(batch) == self.num_micro_batch:
                    yield batch
                    # Only count the step once the consumer asked for more.
                    self.step += 1
                    batch = []

            try:
                processing_item = next(self._data_iter)
            except StopIteration:
                processing_item = None
                exhausted = True
                if self.step < self._length:
                    # More steps are required: start another pass over the data.
                    self._data_iter = iter(self._dataloader)
                    try:
                        processing_item = next(self._data_iter)
                    except StopIteration:
                        # Empty or one-shot dataloader: nothing more will ever
                        # arrive.  Letting StopIteration escape a generator
                        # would surface as RuntimeError (PEP 479).
                        pass
                    else:
                        exhausted = False
                if exhausted:
                    if not self._drop_last and (batch or not self.batching_strategy.empty()):
                        yield from self._drain_tail(batch)
                    return
            except Exception as e:
                logger.error(f"DynamicBatchDataset iter data exception: {e} \n{traceback.format_exc()}")
                raise

            # A single dict is one item; anything else is iterated as a
            # collection of items.
            if isinstance(processing_item, dict):
                processing_item = [processing_item]

            for item in processing_item:
                self.batching_strategy.put_item(item)

    def state_dict(self):
        """Return a deep-copied snapshot of the loader's resumable state.

        The snapshot holds every public instance attribute, plus
        ``"dataloader_state"`` when the dataloader exposes ``state_dict`` or
        ``__getstate__``.  When the batching strategy exposes ``state_dict``
        its result is stored under ``"batching_strategy_state"`` in place of
        the strategy object itself.
        """
        state = {key: value for key, value in self.__dict__.items() if not key.startswith("_")}

        if hasattr(self._dataloader, "state_dict"):
            state["dataloader_state"] = self._dataloader.state_dict()
        elif hasattr(self._dataloader, "__getstate__"):
            state["dataloader_state"] = self._dataloader.__getstate__()

        if hasattr(self.batching_strategy, "state_dict"):
            state["batching_strategy_state"] = self.batching_strategy.state_dict()
            state.pop("batching_strategy", None)

        return copy.deepcopy(state)

    def load_state_dict(self, state: Dict[str, Any]):
        """Restore a snapshot produced by ``state_dict``.

        The next ``iter()`` call continues from the restored step instead of
        starting over.  ``state`` itself is left unmodified.

        Args:
            state: snapshot to restore; must contain ``"num_micro_batch"``.
        """
        # Checkpoint-only entries are handed to their owners below and must
        # not become attributes of this loader.
        restored = {
            key: value
            for key, value in state.items()
            if key not in ("dataloader_state", "batching_strategy_state")
        }
        if state["num_micro_batch"] != self.num_micro_batch:
            logger.warning(
                f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
            )
            # Keep the currently configured value rather than the saved one.
            del restored["num_micro_batch"]
        self.__dict__.update(restored)
        self._resume = True

        if hasattr(self._dataloader, "load_state_dict"):
            self._dataloader.load_state_dict(state["dataloader_state"])
        elif hasattr(self._dataloader, "__setstate__"):
            # Probe the method that is actually called: since Python 3.11
            # every object has __getstate__, but not __setstate__.
            self._dataloader.__setstate__(state["dataloader_state"])

        if "batching_strategy_state" in state:
            self.batching_strategy.load_state_dict(state["batching_strategy_state"])

        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()
|
|