|
|
|
|
|
|
|
|
import itertools |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import re |
|
|
import traceback |
|
|
from typing import Any, Callable, Dict, Iterator, List, Optional, cast |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data import IterableDataset, get_worker_info |
|
|
|
|
|
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def get_worker_info():
    """Return ``(worker_id, num_workers)`` for the current dataloader worker.

    Outside a dataloader worker process torch reports no worker info, in which
    case we fall back to a single-worker view: ``(0, 1)``.

    NOTE(review): this shadows ``torch.utils.data.get_worker_info`` imported at
    the top of the file — callers of this module-level name get the tuple, not
    torch's WorkerInfo object.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:
        # Main process (no active worker): behave as the only worker.
        return 0, 1
    return info.id, info.num_workers
|
|
|
|
|
|
|
|
def get_global_rank_info(rank, world_size):
    """Flatten (process rank, dataloader worker) into a single global rank.

    Returns a ``(dataloader_rank, dataloader_world_size)`` pair identifying
    this worker among ``world_size * num_workers`` shards.
    """
    worker_id, num_workers = get_worker_info()
    return rank * num_workers + worker_id, world_size * num_workers
|
|
|
|
|
|
|
|
class JSONLIterator:
    """Iterate over a single JSONL file, yielding only this rank's lines.

    Lines are sharded round-robin: 0-indexed line ``i`` belongs to rank
    ``i % world_size``. Supports infinite iteration (rewinding to the start of
    the file after each pass) and byte-offset save/restore for checkpointed
    resumption.
    """

    def __init__(
        self,
        fpath: str,
        world_size: int,
        world_rank: int,
        infinite: bool,
    ):
        assert 0 <= world_rank < world_size, (world_rank, world_size)
        # Handle stays open for the iterator's lifetime; gen() closes it only
        # after a finite iteration is exhausted.
        self.f = open(fpath, "r", encoding="utf-8")
        self.fpath = fpath
        self.world_size = world_size
        self.world_rank = world_rank
        # 1-based count of lines read from the file (all ranks' lines, not
        # just ours).
        self.line_num = 0
        self.iter = iter(self.gen(infinite))
        # Number of full passes started over the file (used for logging).
        self.iter_id = 0

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    def gen(self, infinite: bool) -> Iterator[Dict]:
        """Yield parsed JSON objects for this rank's share of lines.

        When ``infinite`` is True, rewinds to the start of the file after each
        full pass; otherwise closes the file and stops after one pass.
        """
        while True:
            if self.world_rank == 0:
                logger.info(f"Starting iteration {self.iter_id} over {self.fpath} ...")
            self.iter_id += 1
            while True:
                line, self.line_num = self.f.readline(), self.line_num + 1
                if not line:
                    # readline() returns "" only at EOF.
                    break
                # Round-robin sharding on the 0-indexed line number.
                if (self.line_num - 1) % self.world_size == self.world_rank:
                    yield json.loads(line)
            if not infinite:
                break
            # Rewind for the next pass.
            self.set_position(None)
        self.f.close()

    def set_position(self, position: Optional[int]):
        """Seek to a byte ``position`` from get_position(), or rewind if None."""
        logger.warning(
            f"Setting JSONL position on {self.fpath} "
            f"({self.world_rank}/{self.world_size}): {position}"
        )
        if position is None:
            self.f.seek(0)
            self.line_num = 0
        else:
            assert isinstance(position, int)
            self.f.seek(position)
            # Restore line_num so the invariant checked by get_position()
            # holds again: (line_num - 1) % world_size == world_rank.
            self.line_num = (
                self.world_rank + 1
            )

    def get_position(self) -> Optional[int]:
        """Return the current byte offset, or None if iteration never started.

        NOTE(review): the assert implies this is only called right after this
        rank yielded a line, so the offset sits on one of our line
        boundaries — confirm against callers.
        """
        file_pos = self.f.tell()
        if file_pos == 0 and self.line_num == 0:
            return None
        assert (self.line_num - 1) % self.world_size == self.world_rank
        return file_pos

    def get_example_file(self):
        """
        Return the path to a sample file to infer the content key
        """
        return self.fpath

    def get_id(self):
        """
        Return an identifier for the dataset this iterator represents
        """
        return self.fpath
|
|
|
|
|
|
|
|
class JSONLDirectoryIterator:
    """
    The JSONLDirectoryIterator is a data wrapper around a dataset folder, which contains
    multiple JSONL files. Internally, it reuses the JSONLIterator class to iterate through
    each individual file, and then wraps onto the next file once the current one is exhausted.

    Once all files in the directory have been iterated over, we wrap back to the first file
    ( if infinite is true ).

    This enables us to iterate over a dataset one chunk at a time.

    Also, note that we open the next chunk file on an ondemand basis, which means that we can
    modify chunks mid training as well to add more data, fix issues, etc.
    """

    def __init__(
        self,
        dirpath: str,
        world_size: int,
        world_rank: int,
        infinite: bool,
    ):
        assert 0 <= world_rank < world_size, (world_rank, world_size)
        self.dirpath = dirpath
        self.world_size = world_size
        self.world_rank = world_rank

        # Only files named like "...chunk.<digits>....jsonl" count as chunks.
        fnames = [
            x
            for x in os.listdir(self.dirpath)
            if re.fullmatch(r".*chunk\.\d+.*\.jsonl", x)
        ]
        # Sorted for a deterministic chunk order across processes.
        self.fpaths = [os.path.join(self.dirpath, fname) for fname in sorted(fnames)]
        assert (
            len(self.fpaths) > 0
        ), f"Specified dataset location {self.dirpath} is empty."

        # Infinite mode cycles over the chunk list forever; finite mode makes
        # a single pass.
        if infinite:
            self.fpaths_generator = cast(Iterator[str], itertools.cycle(self.fpaths))
        else:
            self.fpaths_generator = cast(Iterator[str], iter(self.fpaths))

        self.iter = iter(self.gen(infinite))
        # Iterator over the chunk currently being consumed; stays None until
        # the first chunk is opened or a position is restored.
        self.jsonl_iterator: Optional[JSONLIterator] = None

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    def gen(self, infinite: bool) -> Iterator[Dict]:
        """Yield samples chunk by chunk, opening each file on demand."""
        # Resume an in-flight chunk first (set by set_position()).
        if self.jsonl_iterator is not None:
            yield from self.jsonl_iterator

        for fpath in self.fpaths_generator:
            # Per-file iterators are always finite; wrapping across chunks is
            # handled by fpaths_generator.
            self.jsonl_iterator = JSONLIterator(
                fpath,
                world_size=self.world_size,
                world_rank=self.world_rank,
                infinite=False,
            )
            yield from self.jsonl_iterator

    def set_position(self, state: Dict[str, Any]):
        """Restore iteration to ``state`` as returned by get_position().

        Advances the chunk generator past the saved file, then reopens that
        file at the saved byte offset so gen() resumes it first.

        Raises:
            ValueError: if the saved file is no longer among this directory's
                chunks (e.g. it was removed or renamed between runs).
        """
        logger.warning(
            f"Setting JSONL position on {self.dirpath} "
            f"({self.world_rank}/{self.world_size}): {state}"
        )
        fpath: Optional[str] = state["fpath"]
        position: Optional[int] = state["position"]
        if fpath is None or position is None:
            # Nothing was in flight when the state was captured.
            return

        assert isinstance(fpath, str)
        assert isinstance(position, int)

        # Guard BEFORE scanning fpaths_generator: with infinite=True the
        # generator is an endless itertools.cycle, so searching for a file
        # that no longer exists would loop forever.
        if fpath not in self.fpaths:
            raise ValueError(
                f"Cannot restore position: {fpath} is not a chunk of {self.dirpath}."
            )

        # Advance the chunk generator so the next chunk gen() opens is the one
        # following the saved file.
        for fpath_candidate in self.fpaths_generator:
            if fpath_candidate == fpath:
                break

        # Reopen the saved chunk at the saved offset.
        self.jsonl_iterator = JSONLIterator(
            fpath,
            world_size=self.world_size,
            world_rank=self.world_rank,
            infinite=False,
        )
        self.jsonl_iterator.set_position(position)

    def get_position(self):
        """Return a JSON-serializable dict with the current file and offset."""
        if self.jsonl_iterator is None:
            return {
                "fpath": None,
                "position": None,
            }
        return {
            "fpath": self.jsonl_iterator.fpath,
            "position": self.jsonl_iterator.get_position(),
        }

    def get_example_file(self):
        """
        Return the path to a sample file to infer the content key
        """
        return self.fpaths[0]

    def get_id(self):
        """
        Return an identifier for the dataset this iterator represents
        """
        return self.dirpath
|
|
|
|
|
|
|
|
class IterativeJSONLDataset(IterableDataset):
    """An IterableDataset streaming samples from a JSONL file or a directory
    of JSONL chunk files, sharded across ranks and dataloader workers.

    The underlying iterator is created lazily in worker_init() because the
    shard assignment depends on the dataloader worker id, which is only known
    inside the worker process.
    """

    def __init__(
        self,
        global_rank: int,
        world_size: int,
        dataset_name: str,
        seed: int = 0,
        dataset_configs: Optional[Dict[str, Any]] = None,
    ):
        # None instead of a mutable {} default (shared across all calls).
        dataset_configs = {} if dataset_configs is None else dataset_configs
        self._dataset_name = dataset_name
        self._seed = seed
        self._dataset_conf = dataset_configs[dataset_name]

        self.global_rank = global_rank
        self.world_size = world_size
        # Path to either a single .jsonl file or a directory of chunk files.
        self.data_path = self._dataset_conf.annotation

    def worker_init(self, worker_id, num_workers):
        """Create the shard-aware iterator inside the dataloader worker."""
        # Flatten (process rank, worker id) into a single dataloader rank.
        dataloader_rank = self.global_rank * num_workers + worker_id
        dataloader_world_size = self.world_size * num_workers
        if os.path.isfile(self.data_path):
            self.jsonl_iterator = JSONLIterator(
                self.data_path,
                world_size=dataloader_world_size,
                world_rank=dataloader_rank,
                infinite=True,
            )
        else:
            # Not a file: treat as a directory of chunk files.
            self.jsonl_iterator = JSONLDirectoryIterator(
                dirpath=self.data_path,
                world_size=dataloader_world_size,
                world_rank=dataloader_rank,
                infinite=True,
            )
        if worker_id == 0:
            logger.info(
                f"Initializing JSONLDataset {self._dataset_name} on "
                f"dataloader rank {dataloader_rank} and world size {dataloader_world_size}"
            )

    def state_dict(self):
        """Return the underlying iterator's resumable position."""
        pos = self.jsonl_iterator.get_position()
        if isinstance(pos, dict):
            # Directory iterator: position is already a dict state.
            return pos
        else:
            # Single-file iterator: position is a byte offset (or None).
            return {"single_jsonl_position": pos}

    def load_state_dict(self, state_dict):
        """Restore the underlying iterator from a state_dict() snapshot."""
        if "single_jsonl_position" in state_dict:
            self.jsonl_iterator.set_position(state_dict["single_jsonl_position"])
        else:
            self.jsonl_iterator.set_position(state_dict)
        logger.info(f"JSONLDataset {self._dataset_name} resuming from {state_dict}.")

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.jsonl_iterator)
|
|
|
|
|
|
|
|
class DatasetMixer(IterableDataset):
    """Mix several IterativeJSONLDatasets by sampling with given probabilities.

    ``mix`` is a spec of the form "name1:weight1,name2:weight2,...". Weights
    are normalized into probabilities. Each drawn sample passes through the
    dataset's preprocessors; samples mapped to None are skipped. Errors from a
    single dataset are logged and iteration continues (best-effort streaming).
    """

    def __init__(
        self,
        mix: str,
        global_rank: int,
        world_size: int,
        seed: int = 0,
        preprocessors: Optional[List[Callable]] = None,
        dataset_configs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        # None instead of mutable []/{} defaults (shared across all calls).
        preprocessors = [] if preprocessors is None else preprocessors
        dataset_configs = {} if dataset_configs is None else dataset_configs

        self.dataset_and_preprocessors = []
        self.weights = []
        self.dataset_names = []
        # Per-dataset count of samples yielded so far (checkpointed).
        self.totals = []

        self.global_rank = global_rank
        self.world_size = world_size
        self.seed = seed

        # Strip all whitespace so "a : 1, b : 2" parses like "a:1,b:2".
        mix = "".join(mix.split())

        for elem in mix.split(","):
            ds, weight = elem.split(":")

            if ds not in dataset_configs:
                raise ValueError(f"Dataset {ds} not found in dataset_configs.")
            if ds in self.dataset_names:
                raise ValueError(
                    f"Dataset {ds} already in the mix. Each dataset can only be used once."
                )

            dataset = IterativeJSONLDataset(
                global_rank=global_rank,
                world_size=world_size,
                dataset_name=ds,
                seed=seed,
                dataset_configs=dataset_configs,
            )
            # Preprocessors are factories instantiated per dataset config.
            _preprocessors = [
                p(dataset_config=dataset_configs[ds]) for p in preprocessors
            ]

            self.dataset_and_preprocessors.append((dataset, _preprocessors))
            self.weights.append(float(weight))
            self.dataset_names.append(ds)
            self.totals.append(0)

        # Normalize weights to probabilities; hoist the sum out of the
        # comprehension (the original recomputed it per element).
        weight_sum = sum(self.weights)
        self.weights = [w / weight_sum for w in self.weights]
        # Created lazily in __iter__ so each worker seeds its own RNG.
        self.rng = None

    def state_dict(self):
        """Return a JSON-serializable snapshot: per-dataset positions,
        per-dataset totals, and the sampling RNG state (or None)."""
        return {
            "datasets": {
                ds_name: ds.state_dict()
                for ds_name, (ds, _) in zip(
                    self.dataset_names, self.dataset_and_preprocessors
                )
            },
            "totals": {
                ds_name: total
                for ds_name, total in zip(self.dataset_names, self.totals)
            },
            "rng": (
                [
                    # np.ndarray is not JSON-serializable; store as list.
                    s.tolist() if isinstance(s, np.ndarray) else s
                    for s in self.rng.get_state()
                ]
                if self.rng is not None
                else None
            ),
        }

    def load_state_dict(self, state_dict):
        """Restore positions, totals, and RNG state from state_dict().

        Datasets present in the snapshot but absent from the current mix are
        silently ignored (the mix may have changed between runs).
        """
        for ds_name, sd in state_dict["datasets"].items():
            if ds_name in self.dataset_names:
                ds_idx = self.dataset_names.index(ds_name)
                ds, _ = self.dataset_and_preprocessors[ds_idx]
                ds.load_state_dict(sd)
                self.totals[ds_idx] = state_dict["totals"][ds_name]

        logger.info(
            f"DatasetMixer with datasets {self.dataset_names} resuming with total samples seen {self.totals} on process {os.getpid()}."
        )

        if state_dict["rng"] is not None:
            self.rng = np.random.RandomState()
            rng_state = [
                np.array(s) if isinstance(s, list) else s for s in state_dict["rng"]
            ]
            self.rng.set_state(rng_state)

    def worker_init(self, worker_id):
        """Propagate dataloader worker initialization to member datasets."""
        worker_info = torch.utils.data.get_worker_info()
        for dataset, _ in self.dataset_and_preprocessors:
            if hasattr(dataset, "worker_init"):
                dataset.worker_init(worker_id, worker_info.num_workers)

    def __iter__(self):
        if self.rng is None:
            # Seed per (dataloader rank, world size, seed) so every worker
            # draws a different but reproducible sample stream.
            rank, world_size = get_global_rank_info(self.global_rank, self.world_size)
            self.rng = np.random.RandomState((rank, world_size, self.seed))

        src_id = None  # Keeps the error log below safe if choice() itself fails.
        while True:
            try:
                src_id = self.rng.choice(len(self.weights), p=self.weights)
                dataset, preprocessors = self.dataset_and_preprocessors[src_id]
                out = next(dataset)
                for preprocessor in preprocessors:
                    if out is not None:
                        out = preprocessor(out, self.rng)

                if out is None:
                    # A preprocessor filtered this sample out.
                    continue

                self.totals[src_id] += 1
                yield out
            except Exception as e:
                # Best-effort: log and keep sampling instead of killing the
                # whole training job. src_id may be None if drawing the source
                # failed (the original would NameError here in that case).
                name = self.dataset_names[src_id] if src_id is not None else "<unknown>"
                logger.error(
                    f"Error while iterating over dataset {name}: {e}\n"
                    f"Traceback:\n{traceback.format_exc()}"
                )
|
|
|
|
|
|
|
|
class PersistentDataLoader:
    """
    A _very_ persistent dataloader.

    Uses StatefulDataLoader to save dataset state (make sure dataset has a state_dict() and load_state_dict() method).
    Also keeps the dataloader iterator and the epoch iterator separate, so that the dataloader workers are persistent.

    Also laughs in the face of torch when it tries to kill the whole job because a worker died. Instead, this dataloader
    will just gracefully restart the underlying iterator and corresponding workers, while additionally loading the state dict
    so that it resumes from where it left off.

    This may or may not be a good idea.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        workers,
        collate_fn=None,
        positions=None,
    ):
        # Local import so torchdata is only required when this class is used.
        from torchdata.stateful_dataloader import StatefulDataLoader

        self.dataloader = StatefulDataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            # "fork" is only valid when there actually are worker processes.
            multiprocessing_context="fork" if workers > 0 else None,
            collate_fn=collate_fn,
            worker_init_fn=(
                dataset.worker_init if hasattr(dataset, "worker_init") else None
            ),
            # Snapshot after every step so we can always resume exactly.
            snapshot_every_n_steps=1,
        )

        if positions is not None:
            self.load_state_dict(positions)

        # Worker-backed iterator created once and kept across epochs so the
        # workers stay alive; per-epoch iteration goes through gen().
        self._dataloader_iter = iter(self.dataloader)

    def state_dict(self):
        """Delegate to the StatefulDataLoader snapshot."""
        return self.dataloader.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the StatefulDataLoader (and its dataset) state."""
        self.dataloader.load_state_dict(state_dict)

    def __del__(self):
        # Intentionally a no-op. NOTE(review): a plain class has no default
        # __del__, so this presumably exists to pin destructor behavior —
        # confirm before removing.
        pass

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        self.iter = self.gen()
        return self

    def __next__(self):
        return next(self.iter)

    def _refresh_iter(self):
        # Drop the (presumed broken) iterator; _get_next_sample() will rebuild
        # it from the latest snapshot on the next call.
        self._dataloader_iter = None

    def _get_next_sample(self):
        """Fetch one batch, transparently rebuilding the iterator if it was
        invalidated by _refresh_iter() (e.g. after a worker death)."""
        if self._dataloader_iter is None:
            # Reload our own latest snapshot, then restart the workers.
            self.dataloader.load_state_dict(self.dataloader.state_dict())
            self._dataloader_iter = iter(self.dataloader)

        try:
            return next(self._dataloader_iter)
        except (KeyboardInterrupt, StopIteration):
            # Genuine end-of-data / user interrupt: propagate untouched.
            raise
        except Exception as e:
            if self._dataloader_iter is None:
                # _refresh_iter() fired while we were blocked in next();
                # rebuild and retry.
                return self._get_next_sample()
            else:
                raise e

    def gen(self):
        """Per-epoch generator over _get_next_sample()."""
        while True:
            try:
                yield self._get_next_sample()
            except StopIteration:
                # Fix: the original re-raised StopIteration here, which PEP
                # 479 converts into "RuntimeError: generator raised
                # StopIteration" when it escapes a generator frame — crashing
                # at end of epoch. Returning ends the generator cleanly.
                return
|