| |
| import multiprocessing as mp |
| import time |
| from typing import Any, Callable, Dict, Optional, Union |
|
|
| import numpy as np |
| import torch.distributed as dist |
| from datasets import Dataset as HfDataset |
| from torch.utils.data import Dataset, IterableDataset |
| from tqdm import tqdm |
|
|
| from swift.utils import get_logger, is_dist, is_master |
| from ..template import MaxLengthError |
| from .preprocessor import RowPreprocessor |
|
|
| logger = get_logger() |
|
|
|
|
| def sample_dataset( |
| dataset: HfDataset, |
| dataset_sample: Optional[int], |
| shuffle: bool = True, |
| random_state: Optional[np.random.RandomState] = None, |
| ) -> HfDataset: |
| """Sample dataset by a dataset_sample number |
| Args: |
| dataset: The dataset instance, iterable dataset is not supported |
| dataset_sample: The sample number |
| shuffle: Whether to perform random sampling on non-streaming datasets |
| random_state: The random state |
| Returns: |
| The sampled dataset |
| """ |
| if dataset_sample is None: |
| return dataset |
|
|
| n_repeat_sample = dataset_sample // len(dataset) |
| n_remain_sample = dataset_sample % len(dataset) |
| if n_repeat_sample >= 1 and n_remain_sample >= 1: |
| logger.warning(f'dataset_sample:{dataset_sample} is greater than len(dataset):{len(dataset)}, ' |
| 'repeated sampling will be performed.') |
| idx = np.tile(range(len(dataset)), n_repeat_sample) |
| if n_remain_sample >= 1: |
| if shuffle: |
| if random_state is None: |
| random_state = np.random.RandomState() |
| idx_remain = random_state.permutation(len(dataset))[:n_remain_sample] |
| else: |
| idx_remain = np.arange(n_remain_sample) |
| idx = np.concatenate([idx, idx_remain]) |
| dataset = dataset.select(idx) |
| return dataset |
|
|
|
|
| class LazyLLMDataset(Dataset): |
| """This class if used to lazy tokenize the dataset, and skips bad ones when training""" |
|
|
| def __init__(self, |
| dataset: HfDataset, |
| encode_func: Callable[[Dict[str, Any]], Dict[str, Any]], |
| *, |
| n_try_fetch: int = 10, |
| strict: bool = False, |
| random_state: Union[np.random.RandomState, int, None] = None, |
| traceback_limit: int = 10) -> None: |
| self.dataset = dataset |
| self.encode_func = encode_func |
|
|
| n_try_fetch = 1 if strict else min(n_try_fetch, len(self.dataset)) |
| assert n_try_fetch >= 1 |
| self.strict = strict |
| self.n_try_fetch = n_try_fetch |
|
|
| if not isinstance(random_state, np.random.RandomState): |
| random_state = np.random.RandomState(random_state) |
| self.random_state = random_state |
|
|
| self.traceback_limit = traceback_limit |
| self._traceback_counter = 0 |
| self._idx = 0 |
| self._idx_list = self.random_state.permutation(len(self.dataset)).tolist() |
|
|
| def __getitem__(self, idx: int) -> Dict[str, Any]: |
| for i in range(self.n_try_fetch): |
| n_try = i |
| if i == 0: |
| i = idx |
| else: |
| i = self._idx_list[self._idx] |
| self._idx = (self._idx + 1) % len(self.dataset) |
| data = self.dataset[i] |
| try: |
| return self.encode_func(data) |
| except Exception: |
| if n_try == self.n_try_fetch - 1: |
| if self.strict: |
| logger.warning('To avoid errors, you can pass `strict=False`.') |
| raise |
| if self.traceback_limit is not None and self._traceback_counter < self.traceback_limit: |
| import traceback |
| logger.info(traceback.format_exc()) |
| logger.warning('👆👆👆There are errors in the template.encode, ' |
| 'and another piece of data will be randomly selected.') |
| self._traceback_counter += 1 |
|
|
| raise ValueError('Failed to retrieve the dataset. You can avoid this issue by increasing `max_length` or ' |
| 'modifying the `truncation_strategy`.') |
|
|
| def __len__(self) -> int: |
| return len(self.dataset) |
|
|
|
|
| class BasePackingDataset: |
|
|
| def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False): |
| template._packing = True |
| self.template = template |
| self.dataset = dataset |
| self.num_proc = num_proc |
| self.packing_interval = packing_interval |
| self.strict = strict |
| assert num_proc >= 1, f'num_proc: {num_proc}' |
| self.workers = [] |
|
|
| @staticmethod |
| def calculate_matched_group(template, sequences, is_finished: bool = True): |
| if len(sequences) == 0: |
| return [], [] |
| |
| import binpacking |
| sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1) |
| res = [] |
| if sequences and not is_finished: |
| sequences, ret_sequences = sequences[:-1], sequences[-1] |
| else: |
| ret_sequences = [] |
| for row in sequences: |
| packed = template.packing_row(row) |
| res.append(packed) |
| return res, ret_sequences |
|
|
| def _encode_data(self, data) -> Dict[str, Any]: |
| res = None |
| try: |
| res = self.template.encode(data) |
| except Exception as e: |
| if self.strict and not isinstance(e, MaxLengthError): |
| raise |
| return res or {} |
|
|
|
|
| class PackingDataset(BasePackingDataset, Dataset): |
|
|
| def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False): |
| num_proc = min(len(dataset), num_proc) |
| super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict) |
| self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc=f'Packing (num_proc={num_proc})') |
| self._queue = mp.Queue() |
| self._terminated_workers = 0 |
| if is_master(): |
| for i in range(self.num_proc): |
| shard_dataset = self.dataset.shard(self.num_proc, i) |
| worker = mp.Process(target=self._producer, args=(shard_dataset, ), daemon=True) |
| worker.start() |
| self.workers.append(worker) |
|
|
| self.packed_dataset = self.get_packed_dataset() |
| self.prog_bar.close() |
| for worker in self.workers: |
| worker.terminate() |
| if is_dist(): |
| obj_list = [self.packed_dataset] |
| dist.broadcast_object_list(obj_list) |
| self.packed_dataset = obj_list[0] |
| elif is_dist(): |
| obj_list = [None] |
| dist.broadcast_object_list(obj_list) |
| self.packed_dataset = obj_list[0] |
|
|
| def fetch_packing_data(self, res: Optional[list] = None): |
| res = res or [] |
| for _ in range(self.packing_interval): |
| data = self._queue.get() |
| if data is None: |
| self._terminated_workers += 1 |
| if self._terminated_workers == self.num_proc: |
| break |
| continue |
| self.prog_bar.update(1) |
| if data: |
| res.append((data, len(data['input_ids']))) |
| return res |
|
|
| def get_packed_dataset(self): |
| data = [] |
| result = [] |
| while True: |
| data = self.fetch_packing_data(data) |
| is_finished = self._terminated_workers == self.num_proc |
| res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished) |
| result += res |
| if is_finished: |
| break |
| return result |
|
|
| def _producer(self, shard_dataset): |
| for data in shard_dataset: |
| encoded_data = self._encode_data(data) |
| self._queue.put(encoded_data) |
| self._queue.put(None) |
| while True: |
| |
| time.sleep(0.1) |
|
|
| def __getitem__(self, index): |
| return self.packed_dataset[index].copy() |
|
|
| def __len__(self): |
| return len(self.packed_dataset) |
|
|
|
|
| class IterablePackingDataset(BasePackingDataset, IterableDataset): |
|
|
| def __init__(self, |
| template, |
| dataset, |
| num_proc: int = 1, |
| *, |
| packing_interval: int = 128, |
| strict: bool = False, |
| cyclic: bool = False): |
| super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict) |
| self._in_queue = mp.Queue() |
| self._out_queue = mp.Queue() |
| self.workers = [] |
| self.cyclic = cyclic |
| for _ in range(self.num_proc): |
| worker = mp.Process(target=self._processor, daemon=True) |
| worker.start() |
| self.workers.append(worker) |
|
|
| def _processor(self): |
| while True: |
| i, data = self._in_queue.get() |
| if data is None: |
| encoded_data = None |
| else: |
| encoded_data = self._encode_data(data) |
| self._out_queue.put((i, encoded_data)) |
|
|
| def _put_data_in_queue(self, iterator): |
| for i in range(self.packing_interval): |
| try: |
| data = next(iterator) |
| except StopIteration: |
| self._in_queue.put((i, None)) |
| return True |
| self._in_queue.put((i, data)) |
| return False |
|
|
| def _fetch_data_out_queue(self, res): |
| res = [None] * self.packing_interval |
| for _ in range(self.packing_interval): |
| i, data = self._out_queue.get() |
| if data is None: |
| break |
| elif not data: |
| continue |
| res[i] = (data, len(data['input_ids'])) |
| res = [data for data in res if data] |
| return res |
|
|
| @staticmethod |
| def cyclic_iter(iterable): |
| while True: |
| for x in iterable: |
| yield x |
|
|
| def __iter__(self): |
| try: |
| next(iter(self.dataset)) |
| except StopIteration: |
| return |
|
|
| if self.cyclic: |
| iterator = self.cyclic_iter(self.dataset) |
| else: |
| iterator = iter(self.dataset) |
| data = [] |
| while True: |
| finished = self._put_data_in_queue(iterator) |
| data = self._fetch_data_out_queue(data) |
| res, data = self.calculate_matched_group(self.template, data, is_finished=finished) |
| yield from res |
| if finished: |
| break |
|
|
|
|
| class EncodePreprocessor(RowPreprocessor): |
|
|
| def __init__(self, template: 'Template'): |
| super().__init__() |
| self.template = template |
|
|
| def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: |
| return self.template.encode(row) |
|
|
|
|
| class GetLengthPreprocessor(RowPreprocessor): |
|
|
| def preprocess(self, row): |
| length = max([len(row[k]) for k in row.keys() if k.endswith('input_ids')]) |
| return {'length': length} |
|
|