# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
from itertools import chain
from pathlib import Path
import pickle
from typing import Any, List, Union
import subprocess
import mmap
from multiprocessing.shared_memory import SharedMemory

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset

from pytorch_lightning import LightningDataModule

from src.datamodules.datasets.lm_dataset import LMDataset
from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler
from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY
from src.utils.utils import get_logger

logger = get_logger()


# https://github.com/numpy/numpy/issues/18294
class SHMArray(np.ndarray):
    # Copied from
    # https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array

    def __new__(cls, input_array, shm=None):
        obj = np.asarray(input_array).view(cls)
        obj.shm = shm
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.shm = getattr(obj, 'shm', None)
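

# Minimal sketch of how SHMArray is used further below (illustrative only): wrap an
# ndarray backed by a SharedMemory buffer so the buffer stays referenced, and thus alive,
# for as long as the array is.
#
#     shm = SharedMemory(create=True, size=8 * np.dtype(np.uint16).itemsize)
#     arr = np.ndarray((8,), dtype=np.uint16, buffer=shm.buf)
#     ids = SHMArray(arr, shm=shm)  # ids.shm keeps the SharedMemory object alive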


class LMDataModule(LightningDataModule):

    def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024,
                 cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True,
                 detokenize=False, val_only=False, batch_size=32, batch_size_eval=None,
                 num_workers=1, shuffle=False, pin_memory=False, drop_last=False,
                 fault_tolerant=False, ddp=False,
                 fast_forward_epochs=None, fast_forward_batches=None,
                 use_shmem=True):
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.tokenizer_name = tokenizer_name
        self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()
        self.max_length = max_length
        self.val_ratio = val_ratio
        self.val_split_seed = val_split_seed
        self.val_only = val_only
        self.add_eos = add_eos
        self.detokenize = detokenize
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        if fault_tolerant:
            assert self.shuffle
        self.fault_tolerant = fault_tolerant
        if ddp:
            assert fault_tolerant
        self.ddp = ddp
        self.fast_forward_epochs = fast_forward_epochs
        self.fast_forward_batches = fast_forward_batches
        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
            assert ddp and fault_tolerant
        self.use_shmem = use_shmem
        if self.use_shmem:
            assert cache_dir is not None

    def prepare_data(self):
        if self.cache_dir is None:  # Just download the dataset
            load_dataset(self.dataset_name, self.dataset_config_name)
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        if stage == 'test' and hasattr(self, 'dataset_test'):
            return
        concat_ids, self.tokenizer = self.process_dataset()
        self.vocab_size = len(self.tokenizer)
        # Create all splits
        self.dataset_train, self.dataset_val, self.dataset_test = [
            LMDataset(concat_ids[split], seq_len=self.max_length)
            for split in ['train', 'validation', 'test']
        ]

    def process_dataset(self):
        cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
        if cache_dir is not None:
            if cache_dir.is_dir():
                return self._load_from_cache(cache_dir)

        raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)
        # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py
        if 'validation' not in raw_datasets:
            assert 'train' in raw_datasets, "You must have a 'train' split to create a validation split from"
            raw_datasets = raw_datasets['train'].train_test_split(
                test_size=self.val_ratio, seed=self.val_split_seed,
                shuffle=True  # Otherwise the test split would be the end of the dataset
            )
            raw_datasets['validation'] = raw_datasets['test']
        if self.val_only:  # Should only be used for evaluation, not for training
            raw_datasets['train'] = raw_datasets['validation']

        # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse
        # (GPT2-small val ppl after 10 epochs ~22 -> ~25).
        # However, it's useful for zero-shot transfer from Openwebtext,
        # since after detokenization the text is closer to Openwebtext's format.
        # https://github.com/stanford-crfm/mistral/issues/12
        if self.detokenize:
            if self.dataset_name in DATASET_TOKENIZATION_REGISTRY:
                detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name]
                raw_datasets = raw_datasets.map(
                    lambda example: {'text': detokenizer(example['text'])},
                    num_proc=max(self.num_workers, 1),
                    desc='Running detokenizer on dataset'
                )

        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)

        # Preprocessing the datasets.
        # First we tokenize all the texts.
        column_names = raw_datasets['train'].column_names
        text_column_name = 'text' if 'text' in column_names else column_names[0]

        # [2021-12-25] TD: For wikitext, there's no need to add the EOS since each example already
        # ends with '\n', and there are no other '\n' in the examples.
        # assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t])

        # Add the EOS token to the end of the text if the text is not empty
        # https://github.com/stanford-crfm/mistral/issues/91
        # https://github.com/stanford-crfm/mistral/pull/98
        if self.add_eos:
            add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
            add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
            tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name]))
        else:
            tokenize = lambda example: tokenizer(example[text_column_name])
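
        # For example, with the GPT-2 tokenizer (whose eos_token is '<|endoftext|>'),
        # add_eos('Some text') would yield 'Some text<|endoftext|>'; empty strings are
        # passed through unchanged so they don't contribute a spurious EOS token.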

        # tokenized_datasets = raw_datasets.map(
        #     tokenize,
        #     batched=True,
        #     num_proc=max(self.num_workers, 1),
        #     remove_columns=column_names,
        #     desc="Running tokenizer on dataset",
        # )

        dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32

        def tokenize_concat(examples):
            # We just need 'input_ids', not 'attention_mask' (since it's all 1)
            input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype)
            # Need to return a list since we're doing batched processing
            return {'input_ids': [input_ids], 'len': [len(input_ids)]}

        tokenized_datasets = raw_datasets.map(
            tokenize_concat,
            batched=True,
            num_proc=max(self.num_workers, 1),
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )
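
        # Illustrative shape of one record in the intermediate dataset (token id values
        # are made up):
        #     tokenized_datasets['train'][0]
        #     -> {'input_ids': [4342, 318, 617, 2420, 50256, ...], 'len': 1337}
        # Each record holds the concatenated ids of one processed batch; the code below
        # stitches all records of a split into a single flat array.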

        if self.use_shmem:
            # Concatenate all input_ids into an array in shared memory
            def write_ids_to_shm(example, shm_name, array_len):
                shm = SharedMemory(name=shm_name)
                shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
                start_idx = example['len_offset'] - len(example['input_ids'])
                shm_arr[start_idx:example['len_offset']] = example['input_ids']
                shm.close()

            concat_ids = {}
            for name, ds in tokenized_datasets.items():
                tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
                array_len = tokenized_datasets[name][-1]['len_offset']
                shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize)
                shm_name = shm.name
                tokenized_datasets[name].map(
                    write_ids_to_shm,
                    fn_kwargs={'shm_name': shm_name, 'array_len': array_len},
                    batched=False,
                    num_proc=max(self.num_workers, 1),
                    desc="Concatenating examples",
                )
                shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
                # We need to keep a reference to the shared memory, otherwise it gets
                # garbage-collected when it goes out of scope, and that memory is gone.
                # https://github.com/numpy/numpy/issues/18294
                concat_ids[name] = SHMArray(shm_arr, shm=shm)
        else:
            # Use disk
            concat_ids = {}
            assert cache_dir is not None
            cache_dir.mkdir(parents=True, exist_ok=True)

            def write_ids_to_disk(example, filename):
                with open(filename, 'r+b') as f:
                    mm = mmap.mmap(f.fileno(), 0)
                    start_idx = example['len_offset'] - len(example['input_ids'])
                    array_len = len(example['input_ids'])
                    arr = np.ndarray((array_len,), dtype=dtype, buffer=mm,
                                     offset=np.dtype(dtype).itemsize * start_idx)
                    arr[:] = example['input_ids']
                    mm.flush()

            for name, ds in tokenized_datasets.items():
                tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
                array_len = tokenized_datasets[name][-1]['len_offset']
                filename = cache_dir / f'{name}.bin'
                # Need to create the file with this specific size first
                # https://ostechnix.com/create-files-certain-size-linux/
                subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize),
                                str(filename)], check=True)
                tokenized_datasets[name].map(
                    write_ids_to_disk,
                    fn_kwargs={'filename': filename},
                    batched=False,
                    num_proc=max(self.num_workers, 1),
                    desc="Concatenating examples",
                )
                concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,))

        if cache_dir is not None:
            self._save_to_cache(concat_ids, tokenizer, cache_dir)
            if not self.use_shmem:
                for name in concat_ids:
                    Path(cache_dir / f'{name}.bin').unlink()
        return concat_ids, tokenizer

    def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
        cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f'Saving to cache at {str(cache_dir)}')
        for k, v in concat_ids.items():
            np.save(cache_dir / f'{k}.npy', v)
        with open(cache_dir / 'tokenizer.pkl', 'wb') as f:
            pickle.dump(tokenizer, f)

    def _load_from_cache(self, cache_dir):
        assert cache_dir.is_dir()
        logger.info(f'Loading from cache at {str(cache_dir)}')
        concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r')
                      for split in ['train', 'validation', 'test']}
        with open(cache_dir / 'tokenizer.pkl', 'rb') as f:
            tokenizer = pickle.load(f)
        return concat_ids, tokenizer

    @property
    def _cache_dir_name(self):
        # Used as an attribute in process_dataset (self.cache_dir / self._cache_dir_name),
        # so it must be a property, not a plain method.
        return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}'
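    # With tokenizer_name='gpt2' and the constructor defaults, this evaluates to e.g.
    #     tokenizer_name-gpt2-val_ratio-0.0005-val_split_seed-2357-add_eos-True-detokenize-False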

    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
        """ The train dataloader """
        if self.shuffle and self.fault_tolerant:
            shuffle = False
            sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp
                       else RandomFaultTolerantSampler(self.dataset_train))
            # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now.
            # We assume that it's being resumed with the same number of GPUs.
            if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
                sampler.load_state_dict({
                    'epoch': self.fast_forward_epochs,
                    'counter': self.fast_forward_batches * self.batch_size
                })
        else:
            shuffle = self.shuffle
            sampler = None
        return self._data_loader(self.dataset_train, batch_size=self.batch_size,
                                 shuffle=shuffle, sampler=sampler)

    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The val dataloader """
        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)

    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The test dataloader """
        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)

    def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                     sampler=None) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=1,  # Data is already in memory, we don't need many workers
            shuffle=shuffle,
            sampler=sampler,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
            # persistent_workers=True
        )

    def load_state_dict(self, checkpoint):
        if self.fault_tolerant:
            self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
            # TD [2022-08-07]: ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
            # behind, so we're using the optimizer's progress. This is set correctly in seq.py.
            self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
        # At this point the train loader hasn't been constructed yet
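

# Minimal usage sketch (illustrative, not part of the original module). The dataset,
# config, tokenizer, and cache path below are assumptions and can be swapped for any
# Hugging Face dataset/tokenizer:
#
#     datamodule = LMDataModule(
#         dataset_name='wikitext', dataset_config_name='wikitext-103-raw-v1',
#         tokenizer_name='gpt2', max_length=1024, cache_dir='~/data/wikitext103/cache',
#         batch_size=8, num_workers=4, shuffle=True,
#     )
#     datamodule.prepare_data()  # downloads, tokenizes, and caches the concatenated ids
#     datamodule.setup()         # builds dataset_train / dataset_val / dataset_test
#     train_loader = datamodule.train_dataloader()
#     batch = next(iter(train_loader))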