| | import dataclasses |
| | from typing import Any, Iterator, List, Optional |
| |
|
| | import numpy as np |
| | from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase |
| |
|
| | from .args import DataArgs |
| | from .dataset import build_dataset |
| |
|
| |
|
@dataclasses.dataclass
class Batch:
    """A flat (concatenated) batch of token ids.

    Attributes:
        x: 1-D int64 array of input token ids for all samples, concatenated.
        y: 1-D int64 array of target token ids; same shape as ``x``.
        sizes: per-sample lengths; ``sum(sizes)`` must equal ``x.size``.
        y_mask: optional bool array marking which targets contribute to the
            loss; ``None`` means "all targets count".
        is_pad_only: True for an all-padding batch (``x``/``y`` all zeros),
            e.g. to keep distributed ranks in lockstep.
    """

    x: np.ndarray
    y: np.ndarray
    sizes: List[int]
    y_mask: Optional[np.ndarray] = None
    is_pad_only: bool = False

    def __post_init__(self):
        # Validate the flat-batch invariants right after construction.
        assert self.x.ndim == 1
        assert self.x.shape == self.y.shape
        assert self.x.dtype == np.int64
        assert self.y.dtype == np.int64
        assert isinstance(self.sizes, list)
        assert sum(self.sizes) == self.x.size == self.y.size

        if self.y_mask is not None:
            assert self.y_mask.size == self.y.size, (self.y_mask.shape, self.y.shape)
            assert self.y_mask.dtype == bool
            assert sum(self.sizes) == self.y_mask.size
            # An all-True mask should have been normalized to None upstream,
            # and an all-False mask would yield a zero loss.
            assert not self.y_mask.all()
            assert self.y_mask.any()

        if self.is_pad_only:
            assert np.sum(np.abs(self.y)) == 0
            assert np.sum(np.abs(self.x)) == 0
            assert self.y_mask is None

            # Fix: build a *bool* all-False mask so no pad token contributes
            # to the loss. Plain `np.zeros_like(self.x)` would inherit the
            # int64 dtype and break the bool-dtype invariant asserted above.
            self.y_mask = np.zeros_like(self.x, dtype=bool)
| |
|
| |
|
| |
|
| |
|
| | @dataclasses.dataclass |
| | class BatchList: |
| | x: List[List[int]] = dataclasses.field(default_factory=list) |
| | y: List[List[int]] = dataclasses.field(default_factory=list) |
| | sizes: List[List[int]] = dataclasses.field(default_factory=list) |
| | y_mask: List[Optional[List[int]]] = dataclasses.field(default_factory=list) |
| |
|
| | def __post_init__(self): |
| | assert self.x == [], "`BatchList` has to be empty at init." |
| | assert self.y == [], "`BatchList` has to be empty at init." |
| | assert self.sizes == [], "`BatchList` has to be empty at init." |
| | assert self.y_mask == [], "`BatchList` has to be empty at init." |
| |
|
| | def __len__(self) -> int: |
| | return len(self.x) |
| |
|
| | def add(self, x: List[int], y: List[int], sizes: List[int], y_mask: Optional[List[int]] = None): |
| | self.x.append(x) |
| | self.y.append(y) |
| | self.sizes.append(sizes) |
| | self.y_mask.append(y_mask) |
| |
|
| | def empty(self): |
| | self.x = [] |
| | self.y = [] |
| | self.sizes = [] |
| | self.y_mask = [] |
| |
|
| | @staticmethod |
| | def flatten_to_numpy(list_of_lists: List[List[Any]], dtype: np.dtype) -> np.array: |
| | return np.array([el for sublist in list_of_lists for el in sublist], dtype=dtype) |
| |
|
| | def create_batch(self) -> Batch: |
| | x_np: np.array = self.flatten_to_numpy(self.x, dtype=np.int64) |
| | y_np: np.array = self.flatten_to_numpy(self.y, dtype=np.int64) |
| | sizes = sum(self.sizes, []) |
| |
|
| | y_mask_np: Optional[np.array] = self.flatten_to_numpy(self.y_mask, dtype=bool) |
| | y_mask_np = None if y_mask_np.all() else y_mask_np |
| |
|
| | return Batch(x_np, y_np, sizes, y_mask_np) |
| |
|
| |
|
| |
|
| |
|
def build_data_loader(
    instruct_tokenizer: InstructTokenizerBase,
    args: DataArgs,
    batch_size: int,
    seq_len: int,
    seed: Optional[int],
    rank: int,
    world_size: int,
    is_eval: bool,
) -> Iterator[Batch]:
    """Yield `Batch`es of exactly `batch_size` samples from the dataset.

    In eval mode, pretraining data is disabled and the eval instruct split
    replaces the training one. A trailing partial batch (fewer than
    `batch_size` samples) is never yielded.
    """
    if is_eval:
        pretrain_data = ""
        instruct_data = args.eval_instruct_data
    else:
        pretrain_data = args.data
        instruct_data = args.instruct_data

    dataset = build_dataset(
        pretrain_data=pretrain_data,
        instruct_data=instruct_data,
        instruct_args=args.instruct,
        instruct_tokenizer=instruct_tokenizer,
        seq_len=seq_len,
        seed=seed,
        rank=rank,
        world_size=world_size,
        is_eval=is_eval,
        shuffle_pretrain=args.shuffle,
    )

    pending = BatchList()
    for sample in dataset:
        # Sanity check: sample sizes must be non-negative.
        assert all(size >= 0 for size in sample.sizes)

        pending.add(sample.x, sample.y, sample.sizes, sample.mask)

        # Flush once `batch_size` samples have accumulated.
        if len(pending) == batch_size:
            yield pending.create_batch()
            pending.empty()
| |
|
| |
|