import os
from unittest.mock import MagicMock

import requests
from torch.utils.data import IterableDataset


def train_tokenizer(destination_path):
    destination_path.mkdir(parents=True, exist_ok=True)

    # download the tiny shakespeare dataset to train the tokenizer on
    input_file_path = destination_path / "input.txt"
    if not input_file_path.exists():
        data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        with open(input_file_path, "w") as f:
            f.write(requests.get(data_url).text)

    from lit_llama import Tokenizer
    Tokenizer.train(
        input=input_file_path,
        destination=destination_path,
        vocab_size=100,
    )

    return destination_path / "tokenizer.model"


def test_packed_dataset(tmp_path):
    tokenizer_path = train_tokenizer(tmp_path)

    from lit_llama import Tokenizer
    tokenizer = Tokenizer(tokenizer_path)

    texts = [
        "The moment of truth is upon us.",
        "Time to open the fridge.",
    ]

    from lit_llama.packed_dataset import PackedDatasetBuilder, PackedDataset, HDR_SIZE

    block_size = 10
    n_blocks = 2
    chunk_size = block_size * n_blocks
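
    # pack the tokenized texts into fixed-size binary chunk files on disk,
    # using the BOS id as the separator token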
    builder = PackedDatasetBuilder(
        outdir=tmp_path,
        prefix="packed_dataset",
        chunk_size=chunk_size,
        sep_token=tokenizer.bos_id,
        dtype="auto",
        vocab_size=100,
    )

    for text in texts:
        text_ids = tokenizer.encode(text)
        # every encoding starts with the BOS token
        assert text_ids[0] == tokenizer.bos_id
        builder.add_array(text_ids)
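
    # the two encoded texts should have filled exactly two 20-token chunk files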
    filenames = builder.filenames

    assert len(filenames) == 2
    assert os.path.basename(filenames[0]) == "packed_dataset_0000000000.bin"
    assert os.path.basename(filenames[1]) == "packed_dataset_0000000001.bin"

    import numpy as np
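
    # reference tokenization: concatenate both encodings and keep the first
    # two chunks' worth of tokens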
    ex_tokenized = [
        tokenizer.encode(text).numpy().astype(builder.dtype)
        for text in texts
    ]
    ex_tokenized = np.concatenate(ex_tokenized)
    ex_tokenized = ex_tokenized[:2 * chunk_size]
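
    # each chunk file should contain exactly the expected tokens after the header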
    for filename, el in zip(filenames, np.array_split(ex_tokenized, 2)):
        mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
        count = len(mmap) // np.dtype(builder.dtype).itemsize
        arr = np.frombuffer(mmap, dtype=builder.dtype, count=count, offset=0)

        # np.where returns a tuple of index arrays; take the indices themselves
        # so the assert actually checks for a single BOS token per chunk
        where_bos = np.where(arr == tokenizer.bos_id)[0]
        assert len(where_bos) == 1
        assert np.array_equal(arr, el)
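
    # without shuffling, blocks are yielded in on-disk order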
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, shuffle=False)

    ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size)

    for item, el in zip(dataset, ex_split):
        assert np.array_equal(item, el)
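
    # with a fixed seed the shuffle is deterministic, so a fresh iterator's
    # _block_idxs reproduces the permutation being consumed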
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        assert np.array_equal(item, ex_split[block_idxs[i]])
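
    # wrap=True makes iteration endless, so break out manually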
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345, wrap=True)

    for i, item in enumerate(dataset):
        if i > 24:
            break
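
    # with n_chunks=1, blocks are shuffled only within one chunk at a time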
    dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        chunk_idx = i // n_blocks * n_blocks
        assert np.array_equal(item, ex_split[chunk_idx + block_idxs[i % n_blocks]])
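
    # a smaller block size that still divides the chunk size evenly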
    block_size_ = block_size // 2
    ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size_)
    dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size_, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        assert np.array_equal(item, ex_split[block_idxs[i]])
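
    # a block size that does not divide the chunk size: the trailing remainder
    # of each chunk is dropped, so build the expected blocks chunk by chunk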
    block_size_ = block_size // 3
    n_chunks = 2
    ex_chunks = np.split(ex_tokenized, n_chunks)
    n_splits = ex_tokenized.shape[0] // n_chunks // block_size_
    ex_splits = [np.split(el[:n_splits * block_size_], n_splits) for el in ex_chunks]
    ex_split = sum(ex_splits, [])

    dataset = PackedDataset(filenames=filenames, n_chunks=n_chunks, block_size=block_size_, seed=12345)

    for i, item in enumerate(dataset):
        block_idxs = iter(dataset)._block_idxs
        assert np.array_equal(item, ex_split[block_idxs[i]])


class SimpleDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self._start = start
        self._end = end

    def __iter__(self):
        return iter(range(self._start, self._end))


def test_combined_dataset(tmp_path):
    from lit_llama.packed_dataset import CombinedDataset
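
    # weight 1.0 on the first dataset: only its elements are drawn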
    dataset1 = SimpleDataset(0, 10)
    dataset2 = SimpleDataset(10, 20)
    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)

    res = list(dataset)
    assert res == list(range(0, 10))
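
    # weight 1.0 on the second dataset: only its elements are drawn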
    dataset1 = SimpleDataset(0, 10)
    dataset2 = SimpleDataset(10, 20)
    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)

    res = list(dataset)
    assert res == list(range(10, 20))
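
    # with equal weights, iteration stops once either dataset is exhausted, so
    # at least one of the final elements (9 or 19) must have been drawn; if
    # more than 10 items came out, both datasets must have contributed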
    dataset1 = SimpleDataset(0, 10)
    dataset2 = SimpleDataset(10, 20)
    dataset = CombinedDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)

    res = list(dataset)
    assert 9 in res or 19 in res
    if len(res) > 10:
        assert 0 in res and 10 in res


def test_sharded_packed_dataset(monkeypatch):
    import lit_llama.packed_dataset
    from lit_llama.packed_dataset import PackedDataset

    # swap the iterator class for a mock so we can inspect the filenames
    # each shard is constructed with
    dataset_iterator_mock = MagicMock()
    monkeypatch.setattr(lit_llama.packed_dataset, "PackedDatasetIterator", dataset_iterator_mock)
    filenames = [str(i) for i in range(10)]
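
    # a single process sees every file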
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2))
    assert dataset_iterator_mock.call_args[1]["filenames"] == filenames
    dataset_iterator_mock.reset_mock()
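
    # two processes: the files are sharded round-robin across ranks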
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=0))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "2", "4", "6", "8"]
    dataset_iterator_mock.reset_mock()

    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=2, process_rank=1))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "3", "5", "7", "9"]
    dataset_iterator_mock.reset_mock()
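
    # three processes: 10 files do not divide evenly, so the leftover file "9"
    # is assigned to no rank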
    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=0))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["0", "3", "6"]
    dataset_iterator_mock.reset_mock()

    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=1))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["1", "4", "7"]
    dataset_iterator_mock.reset_mock()

    iter(PackedDataset(filenames=filenames, n_chunks=2, block_size=2, num_processes=3, process_rank=2))
    assert dataset_iterator_mock.call_args[1]["filenames"] == ["2", "5", "8"]
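
# For context, a minimal sketch of how a PackedDataset is typically consumed in
# training code, assuming the standard torch DataLoader API; the names and
# hyperparameters below are illustrative only, not taken from this repo:
#
#     from torch.utils.data import DataLoader
#
#     dataset = PackedDataset(filenames=filenames, n_chunks=4, block_size=2048, seed=42)
#     loader = DataLoader(dataset, batch_size=8)
#     for batch in loader:
#         ...  # each batch stacks 8 blocks of 2048 tokens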