import itertools
import json
import logging
import os
import re
import traceback
from typing import Any, Callable, Dict, Iterator, List, Optional, cast

import numpy as np
import torch
from torch.utils.data import IterableDataset

logger = logging.getLogger(__name__)


def get_worker_info():
    """Return (worker_id, num_workers) for the current dataloader worker.

    Falls back to a single-worker layout when called outside of a worker
    process (e.g. with num_workers=0).
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        num_workers = 1
        worker_id = 0
    else:
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
    return worker_id, num_workers


def get_global_rank_info(rank, world_size):
    """Flatten (process rank, worker id) into a single dataloader rank."""
    worker_id, num_workers = get_worker_info()
    dataloader_rank = rank * num_workers + worker_id
    dataloader_world_size = world_size * num_workers
    return dataloader_rank, dataloader_world_size
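

# Worked example (illustrative): with world_size=2 training processes and 3
# dataloader workers per process, there are 2 * 3 = 6 data streams. Process
# rank 1, worker 2 maps to dataloader rank 1 * 3 + 2 = 5 of 6, so each stream
# keeps a disjoint sixth of the lines and no line is read twice.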


class JSONLIterator:
    """Iterate over a JSONL file, sharding its lines round-robin across ranks."""

    def __init__(
        self,
        fpath: str,
        world_size: int,
        world_rank: int,
        infinite: bool,
    ):
        assert 0 <= world_rank < world_size, (world_rank, world_size)
        self.f = open(fpath, "r", encoding="utf-8")
        self.fpath = fpath
        self.world_size = world_size
        self.world_rank = world_rank
        self.line_num = 0
        self.iter = iter(self.gen(infinite))
        self.iter_id = 0

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    def gen(self, infinite: bool) -> Iterator[Dict]:
        while True:
            if self.world_rank == 0:
                logger.info(f"Starting iteration {self.iter_id} over {self.fpath} ...")
            self.iter_id += 1
            while True:
                line, self.line_num = self.f.readline(), self.line_num + 1
                if not line:
                    break
                # Round-robin sharding: rank r keeps lines r, r + world_size, ...
                if (self.line_num - 1) % self.world_size == self.world_rank:
                    yield json.loads(line)
            if not infinite:
                break
            # Rewind to the start of the file for the next pass.
            self.set_position(None)
        self.f.close()

    def set_position(self, position: Optional[int]):
        logger.warning(
            f"Setting JSONL position on {self.fpath} "
            f"({self.world_rank}/{self.world_size}): {position}"
        )
        if position is None:
            self.f.seek(0)
            self.line_num = 0
        else:
            assert isinstance(position, int)
            self.f.seek(position)
            # Positions are only saved right after yielding a line owned by
            # this rank, so restore line_num to the smallest value satisfying
            # (line_num - 1) % world_size == world_rank.
            self.line_num = self.world_rank + 1

    def get_position(self) -> Optional[int]:
        file_pos = self.f.tell()
        if file_pos == 0 and self.line_num == 0:
            return None
        # The generator is suspended at a yield, i.e. just after reading a
        # line that belongs to this rank.
        assert (self.line_num - 1) % self.world_size == self.world_rank
        return file_pos

    def get_example_file(self):
        """
        Return the path to a sample file to infer the content key
        """
        return self.fpath

    def get_id(self):
        """
        Return an identifier for the dataset this iterator represents
        """
        return self.fpath
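

# A minimal checkpoint/resume sketch (illustrative; "data.jsonl" is a
# hypothetical file path):
#
#     it = JSONLIterator("data.jsonl", world_size=2, world_rank=0, infinite=True)
#     first = next(it)
#     pos = it.get_position()  # byte offset just past the last yielded line
#
#     it2 = JSONLIterator("data.jsonl", world_size=2, world_rank=0, infinite=True)
#     it2.set_position(pos)    # next(it2) yields the sample after `first`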


class JSONLDirectoryIterator:
    """
    The JSONLDirectoryIterator is a data wrapper around a dataset folder containing
    multiple JSONL files. Internally, it reuses the JSONLIterator class to iterate
    through each individual file, and wraps onto the next file once the current one
    is exhausted.

    Once all files in the directory have been iterated over, we wrap back to the
    first file (if `infinite` is true).

    This enables us to iterate over a dataset one chunk at a time.

    Also, note that we open the next chunk file on demand, which means that chunks
    can be modified mid-training as well, to add more data, fix issues, etc.
    """

    def __init__(
        self,
        dirpath: str,
        world_size: int,
        world_rank: int,
        infinite: bool,
    ):
        assert 0 <= world_rank < world_size, (world_rank, world_size)
        self.dirpath = dirpath
        self.world_size = world_size
        self.world_rank = world_rank

        fnames = [
            x
            for x in os.listdir(self.dirpath)
            if re.fullmatch(r".*chunk\.\d+.*\.jsonl", x)
        ]
        self.fpaths = [os.path.join(self.dirpath, fname) for fname in sorted(fnames)]
        assert (
            len(self.fpaths) > 0
        ), f"Specified dataset location {self.dirpath} is empty."

        # Cycle over the chunk files forever when infinite; otherwise one pass.
        if infinite:
            self.fpaths_generator = cast(Iterator[str], itertools.cycle(self.fpaths))
        else:
            self.fpaths_generator = cast(Iterator[str], iter(self.fpaths))

        self.iter = iter(self.gen(infinite))
        self.jsonl_iterator: Optional[JSONLIterator] = None

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    def gen(self, infinite: bool) -> Iterator[Dict]:
        # If set_position() installed an iterator before the first next(),
        # drain the resumed file first.
        if self.jsonl_iterator is not None:
            yield from self.jsonl_iterator

        for fpath in self.fpaths_generator:
            # Open each chunk lazily, so chunk files can be edited mid-training.
            self.jsonl_iterator = JSONLIterator(
                fpath,
                world_size=self.world_size,
                world_rank=self.world_rank,
                infinite=False,
            )
            yield from self.jsonl_iterator

    def set_position(self, state: Dict[str, Any]):
        logger.warning(
            f"Setting JSONL position on {self.dirpath} "
            f"({self.world_rank}/{self.world_size}): {state}"
        )
        fpath: Optional[str] = state["fpath"]
        position: Optional[int] = state["position"]
        if fpath is None or position is None:
            return

        assert isinstance(fpath, str)
        assert isinstance(position, int)

        # Advance the file generator past the resumed file, so that gen()
        # continues with the chunk that follows it.
        for fpath_candidate in self.fpaths_generator:
            if fpath_candidate == fpath:
                break

        # Reopen the resumed file and seek to the saved offset.
        self.jsonl_iterator = JSONLIterator(
            fpath,
            world_size=self.world_size,
            world_rank=self.world_rank,
            infinite=False,
        )
        self.jsonl_iterator.set_position(position)

    def get_position(self):
        if self.jsonl_iterator is None:
            return {
                "fpath": None,
                "position": None,
            }
        return {
            "fpath": self.jsonl_iterator.fpath,
            "position": self.jsonl_iterator.get_position(),
        }

    def get_example_file(self):
        """
        Return the path to a sample file to infer the content key
        """
        return self.fpaths[0]

    def get_id(self):
        """
        Return an identifier for the dataset this iterator represents
        """
        return self.dirpath
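

# A minimal usage sketch (illustrative; the directory is assumed to contain
# files matching the chunk pattern, e.g. "chunk.0000.jsonl", "chunk.0001.jsonl"):
#
#     it = JSONLDirectoryIterator("data/", world_size=1, world_rank=0, infinite=False)
#     sample = next(it)           # chunk files are opened lazily, in sorted order
#     state = it.get_position()   # {"fpath": <current chunk>, "position": <offset>}
#
#     it2 = JSONLDirectoryIterator("data/", world_size=1, world_rank=0, infinite=False)
#     it2.set_position(state)     # resumes mid-chunk, then continues to later chunks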


class IterativeJSONLDataset(IterableDataset):
    def __init__(
        self,
        global_rank: int,
        world_size: int,
        dataset_name: str,
        seed: int = 0,
        dataset_configs: Dict[str, Any] = {},
    ):
        self._dataset_name = dataset_name
        self._seed = seed
        self._dataset_conf = dataset_configs[dataset_name]

        self.global_rank = global_rank
        self.world_size = world_size
        self.data_path = self._dataset_conf.annotation

    def worker_init(self, worker_id, num_workers):
        dataloader_rank = self.global_rank * num_workers + worker_id
        dataloader_world_size = self.world_size * num_workers
        # A single file gets a plain JSONLIterator; a directory of chunk
        # files gets the wrapping JSONLDirectoryIterator.
        if os.path.isfile(self.data_path):
            self.jsonl_iterator = JSONLIterator(
                self.data_path,
                world_size=dataloader_world_size,
                world_rank=dataloader_rank,
                infinite=True,
            )
        else:
            self.jsonl_iterator = JSONLDirectoryIterator(
                dirpath=self.data_path,
                world_size=dataloader_world_size,
                world_rank=dataloader_rank,
                infinite=True,
            )
        if worker_id == 0:
            logger.info(
                f"Initializing JSONLDataset {self._dataset_name} on "
                f"dataloader rank {dataloader_rank} and world size {dataloader_world_size}"
            )

    def state_dict(self):
        pos = self.jsonl_iterator.get_position()
        if isinstance(pos, dict):
            return pos
        else:
            return {"single_jsonl_position": pos}

    def load_state_dict(self, state_dict):
        if "single_jsonl_position" in state_dict:
            self.jsonl_iterator.set_position(state_dict["single_jsonl_position"])
        else:
            self.jsonl_iterator.set_position(state_dict)
        logger.info(f"JSONLDataset {self._dataset_name} resuming from {state_dict}.")

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.jsonl_iterator)


class DatasetMixer(IterableDataset):
    def __init__(
        self,
        mix: str,
        global_rank: int,
        world_size: int,
        seed: int = 0,
        preprocessors: List[Callable] = [],
        dataset_configs: Dict[str, Any] = {},
    ):
        super().__init__()

        self.dataset_and_preprocessors = []
        self.weights = []
        self.dataset_names = []
        self.totals = []

        self.global_rank = global_rank
        self.world_size = world_size
        self.seed = seed

        # Strip all whitespace so "ds_a: 1, ds_b: 3" parses like "ds_a:1,ds_b:3".
        mix = "".join(mix.split())

        for elem in mix.split(","):
            ds, weight = elem.split(":")

            if ds not in dataset_configs:
                raise ValueError(f"Dataset {ds} not found in dataset_configs.")
            if ds in self.dataset_names:
                raise ValueError(
                    f"Dataset {ds} already in the mix. Each dataset can only be used once."
                )

            dataset = IterativeJSONLDataset(
                global_rank=global_rank,
                world_size=world_size,
                dataset_name=ds,
                seed=seed,
                dataset_configs=dataset_configs,
            )
            _preprocessors = [
                p(dataset_config=dataset_configs[ds]) for p in preprocessors
            ]

            self.dataset_and_preprocessors.append((dataset, _preprocessors))
            self.weights.append(float(weight))
            self.dataset_names.append(ds)
            self.totals.append(0)

        # Normalize the mixing weights into a probability distribution.
        total_weight = sum(self.weights)
        self.weights = [w / total_weight for w in self.weights]
        self.rng = None

    def state_dict(self):
        return {
            "datasets": {
                ds_name: ds.state_dict()
                for ds_name, (ds, _) in zip(
                    self.dataset_names, self.dataset_and_preprocessors
                )
            },
            "totals": {
                ds_name: total
                for ds_name, total in zip(self.dataset_names, self.totals)
            },
            # The numpy RNG state contains an ndarray, which is not
            # serializable as-is; convert it to a plain list (and back in
            # load_state_dict).
            "rng": (
                [
                    s.tolist() if isinstance(s, np.ndarray) else s
                    for s in self.rng.get_state()
                ]
                if self.rng is not None
                else None
            ),
        }

    def load_state_dict(self, state_dict):
        for ds_name, sd in state_dict["datasets"].items():
            if ds_name in self.dataset_names:
                ds_idx = self.dataset_names.index(ds_name)
                ds, _ = self.dataset_and_preprocessors[ds_idx]
                ds.load_state_dict(sd)
                self.totals[ds_idx] = state_dict["totals"][ds_name]

        logger.info(
            f"DatasetMixer with datasets {self.dataset_names} resuming with "
            f"total samples seen {self.totals} on process {os.getpid()}."
        )

        if state_dict["rng"] is not None:
            self.rng = np.random.RandomState()
            rng_state = [
                np.array(s) if isinstance(s, list) else s for s in state_dict["rng"]
            ]
            self.rng.set_state(tuple(rng_state))

    def worker_init(self, worker_id):
        worker_info = torch.utils.data.get_worker_info()
        for dataset, _ in self.dataset_and_preprocessors:
            if hasattr(dataset, "worker_init"):
                dataset.worker_init(worker_id, worker_info.num_workers)

    def __iter__(self):
        if self.rng is None:
            # Seed per (dataloader rank, world size, seed) so that every
            # worker draws a different but reproducible sample stream.
            rank, world_size = get_global_rank_info(self.global_rank, self.world_size)
            self.rng = np.random.RandomState((rank, world_size, self.seed))

        while True:
            # Pick the source dataset outside the try block so that src_id is
            # always bound when the error handler below needs it.
            src_id = self.rng.choice(len(self.weights), p=self.weights)
            dataset, preprocessors = self.dataset_and_preprocessors[src_id]
            try:
                out = next(dataset)
                for preprocessor in preprocessors:
                    if out is not None:
                        out = preprocessor(out, self.rng)

                # Preprocessors may filter a sample out by returning None.
                if out is None:
                    continue

                self.totals[src_id] += 1
                yield out
            except Exception as e:
                logger.error(
                    f"Error while iterating over dataset {self.dataset_names[src_id]}: {e}\n"
                    f"Traceback:\n{traceback.format_exc()}"
                )
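

# A minimal usage sketch (illustrative; "web" and "code" are hypothetical
# dataset names that must exist in dataset_configs). Note that worker_init()
# is invoked via the dataloader's worker_init_fn, so the mixer is meant to be
# consumed through PersistentDataLoader below rather than iterated directly:
#
#     mixer = DatasetMixer(
#         mix="web:0.7,code:0.3",  # ~70% of samples from "web", ~30% from "code"
#         global_rank=0,
#         world_size=1,
#         dataset_configs=dataset_configs,
#     )
#     loader = PersistentDataLoader(mixer, batch_size=8, workers=2)
#     batch = next(iter(loader))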


class PersistentDataLoader:
    """
    A _very_ persistent dataloader.

    Uses StatefulDataLoader to save dataset state (make sure the dataset has
    state_dict() and load_state_dict() methods). Also keeps the dataloader
    iterator and the epoch iterator separate, so that the dataloader workers
    are persistent across epochs.

    Also laughs in the face of torch when it tries to kill the whole job
    because a worker died. Instead, this dataloader will just gracefully
    restart the underlying iterator and corresponding workers, additionally
    loading the state dict so that it resumes from where it left off.

    This may or may not be a good idea.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        workers,
        collate_fn=None,
        positions=None,
    ):
        from torchdata.stateful_dataloader import StatefulDataLoader

        self.dataloader = StatefulDataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            # fork (rather than spawn) lets workers inherit the dataset state
            # instead of re-pickling it.
            multiprocessing_context="fork" if workers > 0 else None,
            collate_fn=collate_fn,
            worker_init_fn=(
                dataset.worker_init if hasattr(dataset, "worker_init") else None
            ),
            # Snapshot the dataset state on every step so that resume is exact.
            snapshot_every_n_steps=1,
        )

        if positions is not None:
            self.load_state_dict(positions)

        # Create the long-lived dataloader iterator once, so the workers stay
        # alive across epochs; gen() wraps it for each epoch.
        self._dataloader_iter = iter(self.dataloader)

    def state_dict(self):
        return self.dataloader.state_dict()

    def load_state_dict(self, state_dict):
        self.dataloader.load_state_dict(state_dict)

    def __del__(self):
        pass

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        self.iter = self.gen()
        return self

    def __next__(self):
        return next(self.iter)

    def _refresh_iter(self):
        # Invalidate the current iterator; _get_next_sample() lazily rebuilds
        # it (and its workers) from the saved state.
        self._dataloader_iter = None

    def _get_next_sample(self):
        if self._dataloader_iter is None:
            # Reload the last snapshot into the dataloader, then rebuild the
            # iterator (and its worker pool) so iteration resumes in place.
            self.dataloader.load_state_dict(self.dataloader.state_dict())
            self._dataloader_iter = iter(self.dataloader)

        try:
            return next(self._dataloader_iter)
        except (KeyboardInterrupt, StopIteration):
            raise
        except Exception:
            logger.error(
                "Dataloader iterator failed; restarting workers and resuming "
                f"from the last snapshot.\nTraceback:\n{traceback.format_exc()}"
            )
            # A worker likely died: drop the iterator and retry, which
            # rebuilds the workers from the last per-step snapshot.
            self._refresh_iter()
            return self._get_next_sample()

    def gen(self):
        while True:
            try:
                yield self._get_next_sample()
            except StopIteration:
                # PEP 479: StopIteration must not escape a generator body, so
                # translate end-of-data into a normal generator return.
                return
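

if __name__ == "__main__":
    # A minimal smoke test (illustrative): write a tiny JSONL file and check
    # that two ranks receive disjoint, interleaved subsets of its lines.
    import tempfile

    logging.basicConfig(level=logging.INFO)
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as tmp:
        for i in range(6):
            tmp.write(json.dumps({"i": i}) + "\n")

    for rank in range(2):
        it = JSONLIterator(tmp.name, world_size=2, world_rank=rank, infinite=False)
        # Expected: rank 0 -> [0, 2, 4], rank 1 -> [1, 3, 5]
        print(rank, [sample["i"] for sample in it])

    os.unlink(tmp.name)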