# Copyright (c) Alibaba, Inc. and its affiliates.
import multiprocessing as mp
import time
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import torch.distributed as dist
from datasets import Dataset as HfDataset
from torch.utils.data import Dataset, IterableDataset
from tqdm import tqdm
from swift.utils import get_logger, is_dist, is_master
from ..template import MaxLengthError
from .preprocessor import RowPreprocessor
# Module-level logger obtained from swift.utils.get_logger; used for the warnings emitted by the helpers below.
logger = get_logger()
def sample_dataset(
    dataset: HfDataset,
    dataset_sample: Optional[int],
    shuffle: bool = True,
    random_state: Optional[np.random.RandomState] = None,
) -> HfDataset:
    """Sample a fixed number of rows from a dataset.

    The whole dataset is repeated ``dataset_sample // len(dataset)`` times (rows kept in their original order),
    then ``dataset_sample % len(dataset)`` further distinct rows are appended: drawn at random when ``shuffle``
    is True, otherwise the first rows of the dataset.

    Args:
        dataset: The dataset instance, iterable dataset is not supported
        dataset_sample: The number of rows to return; ``None`` returns ``dataset`` unchanged
        shuffle: Whether the remainder rows are drawn at random (True) or taken from the head (False)
        random_state: The random state used when ``shuffle`` is True; an unseeded one is created if None

    Returns:
        The sampled dataset. An empty ``dataset`` is returned unchanged, since there is nothing to sample from.
    """
    if dataset_sample is None:
        return dataset
    n_dataset = len(dataset)
    if n_dataset == 0:
        # Guard the divmod below (ZeroDivisionError): there are no rows to sample or repeat.
        if dataset_sample > 0:
            logger.warning(f'dataset_sample:{dataset_sample} was requested but the dataset is empty; '
                           'returning it unchanged.')
        return dataset
    n_repeat_sample, n_remain_sample = divmod(dataset_sample, n_dataset)
    # Warn whenever rows will be repeated, including exact multiples of len(dataset).
    if dataset_sample > n_dataset:
        logger.warning(f'dataset_sample:{dataset_sample} is greater than len(dataset):{n_dataset}, '
                       'repeated sampling will be performed.')
    # Full copies of the dataset, in original order (not shuffled).
    idx = np.tile(np.arange(n_dataset), n_repeat_sample)
    if n_remain_sample >= 1:
        if shuffle:
            if random_state is None:
                random_state = np.random.RandomState()
            # Distinct random rows (no replacement) for the remainder.
            idx_remain = random_state.permutation(n_dataset)[:n_remain_sample]
        else:
            idx_remain = np.arange(n_remain_sample)
        idx = np.concatenate([idx, idx_remain])
    return dataset.select(idx)
class LazyLLMDataset(Dataset):
    """Dataset that encodes rows on access and, unless strict, falls back to other rows when encoding fails."""

    def __init__(self,
                 dataset: HfDataset,
                 encode_func: Callable[[Dict[str, Any]], Dict[str, Any]],
                 *,
                 n_try_fetch: int = 10,
                 strict: bool = False,
                 random_state: Union[np.random.RandomState, int, None] = None,
                 traceback_limit: int = 10) -> None:
        """
        Args:
            dataset: The source rows; indexed by position and measured with ``len``.
            encode_func: Called on a raw row; its return value is what ``__getitem__`` returns.
            n_try_fetch: Maximum number of rows tried per ``__getitem__`` call, capped at ``len(dataset)``;
                forced to 1 when ``strict`` is True.
            strict: If True, the first encoding error is re-raised instead of trying other rows.
            random_state: A ``np.random.RandomState`` or a seed for one; it shuffles the fallback order.
            traceback_limit: How many fallback tracebacks may be logged over the object's lifetime
                (``None`` disables traceback logging).
        """
        self.dataset = dataset
        self.encode_func = encode_func
        # Strict mode never falls back; otherwise never try more rows than the dataset holds.
        if strict:
            n_try_fetch = 1
        else:
            n_try_fetch = min(n_try_fetch, len(dataset))
        assert n_try_fetch >= 1
        self.strict = strict
        self.n_try_fetch = n_try_fetch
        self.random_state = (random_state if isinstance(random_state, np.random.RandomState) else
                             np.random.RandomState(random_state))
        self.traceback_limit = traceback_limit
        self._traceback_counter = 0
        # Fallback rows come from one fixed random permutation, walked with a wrapping cursor.
        self._idx_list = self.random_state.permutation(len(dataset)).tolist()
        self._idx = 0

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        last_attempt = self.n_try_fetch - 1
        for attempt in range(self.n_try_fetch):
            if attempt == 0:
                # First attempt: use exactly the index that was requested.
                row_idx = idx
            else:
                # Later attempts: next entry of the shuffled fallback list (cursor wraps around).
                row_idx = self._idx_list[self._idx]
                self._idx = (self._idx + 1) % len(self.dataset)
            data = self.dataset[row_idx]
            try:
                return self.encode_func(data)
            except Exception:
                if attempt == last_attempt:
                    # Out of attempts: surface the original encoding error.
                    if self.strict:
                        logger.warning('To avoid errors, you can pass `strict=False`.')
                    raise
                can_log = self.traceback_limit is not None and self._traceback_counter < self.traceback_limit
                if can_log:
                    import traceback
                    logger.info(traceback.format_exc())
                    logger.warning('👆👆👆There are errors in the template.encode, '
                                   'and another piece of data will be randomly selected.')
                    self._traceback_counter += 1
        # Defensive: with n_try_fetch >= 1 the last failed attempt re-raises above, so this is not reached.
        raise ValueError('Failed to retrieve the dataset. You can avoid this issue by increasing `max_length` or '
                         'modifying the `truncation_strategy`.')

    def __len__(self) -> int:
        return len(self.dataset)
class BasePackingDataset:
    """Shared configuration and helpers for datasets that bin-pack encoded samples into rows.

    Args:
        template: Switched into packing mode (``template._packing = True``); its ``encode``, ``packing_row``
            and ``max_length`` are used by the helpers.
        dataset: The source dataset.
        num_proc: Number of worker processes; must be >= 1.
        packing_interval: Stored for subclasses (number of samples handled per packing round).
        strict: If True, encoding errors other than ``MaxLengthError`` are re-raised by ``_encode_data``.
    """

    def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False):
        # Put the template into packing mode before anything else is configured.
        template._packing = True
        self.template = template
        self.dataset = dataset
        self.num_proc, self.packing_interval, self.strict = num_proc, packing_interval, strict
        assert num_proc >= 1, f'num_proc: {num_proc}'
        self.workers = []

    @staticmethod
    def calculate_matched_group(template, sequences, is_finished: bool = True):
        """Bin-pack ``(encoded, length)`` pairs into rows of at most ``template.max_length`` tokens.

        Returns:
            ``(packed_rows, leftover)``. When ``is_finished`` is False, the last bin is not packed but
            returned as ``leftover`` so it can be merged with later samples; otherwise ``leftover`` is ``[]``.
        """
        if len(sequences) == 0:
            return [], []
        # https://arxiv.org/pdf/2404.10830
        import binpacking
        bins = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1)
        if bins and not is_finished:
            *bins, leftover = bins
        else:
            leftover = []
        packed_rows = [template.packing_row(row) for row in bins]
        return packed_rows, leftover

    def _encode_data(self, data) -> Dict[str, Any]:
        """Encode one row; failures (and falsy results) become ``{}`` unless strict mode demands a raise."""
        try:
            encoded = self.template.encode(data)
        except Exception as e:
            # MaxLengthError is always skipped; other errors only propagate in strict mode.
            if self.strict and not isinstance(e, MaxLengthError):
                raise
            encoded = None
        return encoded or {}
class PackingDataset(BasePackingDataset, Dataset):
    """Map-style dataset whose rows are bin-packed encodings of ``dataset``, computed eagerly in ``__init__``.

    On the master process, ``num_proc`` producer processes each encode one shard of ``dataset`` and stream the
    results through a multiprocessing queue, while the main process bin-packs them every ``packing_interval``
    samples. When ``is_dist()`` is true, the packed rows are then shared with the other ranks via
    ``dist.broadcast_object_list`` (default ``src=0``).
    """

    def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False):
        # Never start more producers than there are rows. An empty dataset still gets one producer (it only sends
        # its end-of-shard sentinel) instead of num_proc becoming 0 and tripping the base-class assertion.
        num_proc = min(max(len(dataset), 1), num_proc)
        super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict)
        master = is_master()
        # Only the master consumes the queue and advances this bar; disable it elsewhere so other ranks do not
        # print a bar that never moves.
        self.prog_bar = tqdm(
            total=len(dataset), dynamic_ncols=True, desc=f'Packing (num_proc={num_proc})', disable=not master)
        self._queue = mp.Queue()
        self._terminated_workers = 0
        if master:
            try:
                for i in range(self.num_proc):
                    shard_dataset = self.dataset.shard(self.num_proc, i)
                    worker = mp.Process(target=self._producer, args=(shard_dataset, ), daemon=True)
                    worker.start()
                    self.workers.append(worker)
                self.packed_dataset = self.get_packed_dataset()
            finally:
                # Producers sleep forever after their sentinel (see `_producer`), so stop them explicitly -- also
                # when packing raised -- and join them so they are reaped instead of lingering as zombies.
                self.prog_bar.close()
                for worker in self.workers:
                    worker.terminate()
                    worker.join()
            if is_dist():
                obj_list = [self.packed_dataset]
                dist.broadcast_object_list(obj_list)
                self.packed_dataset = obj_list[0]
        else:
            self.prog_bar.close()
            if is_dist():
                # Receive the packed rows computed by the broadcasting rank.
                obj_list = [None]
                dist.broadcast_object_list(obj_list)
                self.packed_dataset = obj_list[0]

    def fetch_packing_data(self, res: Optional[list] = None):
        """Read up to ``packing_interval`` items from the producer queue and append them to ``res``.

        A ``None`` item is a producer's end-of-shard sentinel; reading stops early once every producer has sent
        one. Empty encodings (samples that failed to encode) advance the progress bar but are dropped.

        Returns:
            ``res`` (or a new list) extended with ``(encoded, len(encoded['input_ids']))`` pairs.
        """
        res = res or []
        for _ in range(self.packing_interval):
            data = self._queue.get()
            if data is None:
                self._terminated_workers += 1
                if self._terminated_workers == self.num_proc:
                    break
                continue
            self.prog_bar.update(1)
            if data:
                res.append((data, len(data['input_ids'])))
        return res

    def get_packed_dataset(self):
        """Pack everything the producers send and return the list of packed rows.

        Each round, the last bin is held back by ``calculate_matched_group`` and merged with the next batch of
        samples, until all producers have finished and the remainder is flushed.
        """
        data = []
        result = []
        while True:
            data = self.fetch_packing_data(data)
            is_finished = self._terminated_workers == self.num_proc
            res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished)
            result += res
            if is_finished:
                break
        return result

    def _producer(self, shard_dataset):
        """Producer-process loop: put the encoding of every row of ``shard_dataset`` on the queue, then ``None``."""
        for data in shard_dataset:
            # Failed encodings come back as {} and are still sent, so the consumer's progress bar stays accurate.
            encoded_data = self._encode_data(data)
            self._queue.put(encoded_data)
        self._queue.put(None)
        while True:
            # Wait for the main process to terminate to avoid fd anomalies.
            time.sleep(0.1)

    def __getitem__(self, index):
        # Shallow copy so callers that add/remove keys do not mutate the stored row (assumes rows are dicts).
        return self.packed_dataset[index].copy()

    def __len__(self):
        return len(self.packed_dataset)
class IterablePackingDataset(BasePackingDataset, IterableDataset):
    """Iterable dataset that encodes and bin-packs rows lazily, in rounds, while it is iterated.

    ``num_proc`` encoder processes are started in ``__init__`` and serve every later iteration. Each round, up to
    ``packing_interval`` rows are submitted tagged with their position in the round, encoded by the workers,
    reassembled in submission order and bin-packed; the last bin of a round is carried into the next one.
    With ``cyclic=True`` the source dataset is repeated endlessly.
    """

    def __init__(self,
                 template,
                 dataset,
                 num_proc: int = 1,
                 *,
                 packing_interval: int = 128,
                 strict: bool = False,
                 cyclic: bool = False):
        super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict)
        self._in_queue = mp.Queue()
        self._out_queue = mp.Queue()
        self.cyclic = cyclic
        # `self.workers` is already initialised to [] by the base class.
        for _ in range(self.num_proc):
            worker = mp.Process(target=self._processor, daemon=True)
            worker.start()
            self.workers.append(worker)

    def _processor(self):
        """Worker-process loop: turn ``(i, row)`` into ``(i, encoded)``; an ``(i, None)`` sentinel passes through."""
        while True:
            i, data = self._in_queue.get()
            if data is None:
                encoded_data = None
            else:
                encoded_data = self._encode_data(data)
            self._out_queue.put((i, encoded_data))

    def _put_data_in_queue(self, iterator):
        """Submit up to ``packing_interval`` rows from ``iterator``, tagged with their index within the round.

        Returns:
            True if ``iterator`` was exhausted, in which case an ``(i, None)`` end-of-data sentinel was submitted
            at the first unused index ``i``; False if a full round of rows was submitted.
        """
        for i in range(self.packing_interval):
            try:
                data = next(iterator)
            except StopIteration:
                self._in_queue.put((i, None))
                return True
            self._in_queue.put((i, data))
        return False

    def _fetch_data_out_queue(self, res):
        """Collect all results of the current round and append them to the carried-over ``res``.

        Args:
            res: ``(encoded, length)`` pairs left over from the previous round (the bin held back by
                ``calculate_matched_group``); they stay at the front of the returned list.

        Returns:
            ``res`` followed by this round's non-empty ``(encoded, length)`` pairs, in submission order.
        """
        new_rows = [None] * self.packing_interval
        n_expected = self.packing_interval
        n_received = 0
        # Results arrive in completion order, so the end-of-data sentinel at index i may overtake rows 0..i-1.
        # Keep reading until every item submitted this round has come back; stopping at the sentinel would drop
        # those rows and leave them in the queue for the next iteration.
        while n_received < n_expected:
            i, data = self._out_queue.get()
            n_received += 1
            if data is None:
                # Sentinel at index i: this round consisted of i rows plus the sentinel itself.
                n_expected = i + 1
            elif data:
                # Empty dicts are rows that failed to encode; they are counted above but not kept.
                new_rows[i] = (data, len(data['input_ids']))
        return list(res or []) + [row for row in new_rows if row]

    @staticmethod
    def cyclic_iter(iterable):
        """Yield from ``iterable`` forever, re-iterating it on each pass.

        Unlike ``itertools.cycle`` this keeps no copy of the elements, but it requires an iterable that can be
        iterated more than once (and never yields if it is empty).
        """
        while True:
            for x in iterable:
                yield x

    def __iter__(self):
        """Yield packed rows.

        Yields nothing for an empty dataset; this is checked first so that ``cyclic_iter`` cannot loop forever
        without yielding. Otherwise rows are processed in rounds of ``packing_interval`` and each round's last bin
        is carried into the next round, then flushed once the source iterator is exhausted.

        NOTE(review): the queues are shared by all iterations; if a consumer stops iterating mid-round, items
        already submitted remain queued and would be read by the next ``__iter__`` call.
        """
        try:
            next(iter(self.dataset))
        except StopIteration:
            return
        if self.cyclic:
            iterator = self.cyclic_iter(self.dataset)
        else:
            iterator = iter(self.dataset)
        data = []
        while True:
            finished = self._put_data_in_queue(iterator)
            data = self._fetch_data_out_queue(data)
            res, data = self.calculate_matched_group(self.template, data, is_finished=finished)
            yield from res
            if finished:
                break
class EncodePreprocessor(RowPreprocessor):
    """Row preprocessor that encodes each row with the given template."""

    def __init__(self, template: 'Template'):
        super().__init__()
        # Every row handed to `preprocess` is delegated to this template's `encode`.
        self.template = template

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return ``self.template.encode(row)``; any exception raised by ``encode`` propagates unchanged."""
        encoder = self.template.encode
        return encoder(row)
class GetLengthPreprocessor(RowPreprocessor):
    """Row preprocessor that reports a row's longest ``*input_ids`` field as ``{'length': ...}``."""

    def preprocess(self, row):
        """Return ``{'length': n}``, ``n`` being the longest value among keys ending in ``'input_ids'``.

        ``max`` raises ``ValueError`` when ``row`` has no such key.
        """
        id_lengths = [len(row[key]) for key in row if key.endswith('input_ids')]
        return {'length': max(id_lengths)}
|