| |
|
|
| from __future__ import annotations |
|
|
| from copy import deepcopy |
| from dataclasses import dataclass |
| from typing import Any, Dict, Iterable, List, Union |
|
|
| import numpy as np |
| import torch |
| from datasets import Dataset, IterableDataset |
| from flame.logging import get_logger |
| from transformers import PreTrainedTokenizer |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class HuggingfaceDataset(IterableDataset):
    """A resumable stream of fixed-length token chunks over a HF dataset.

    The source dataset is sharded across distributed ranks, tokenized lazily
    in batches, and split into ``context_len``-token chunks.  A buffer of
    ``buffer_size`` chunks is filled first and then sampled at random slots,
    each yielded slot being refilled with the next chunk from the stream —
    an approximate shuffle over the token stream.  All iteration state
    (dataset position, buffer contents, cursors, RNG state) is exposed via
    ``state_dict``/``load_state_dict`` so training can resume mid-stream.

    Args:
        dataset (Dataset):
            The source dataset; samples are expected to carry a ``'text'``
            field (see :meth:`tokenize`).
        tokenizer (PreTrainedTokenizer):
            The tokenizer applied to raw text.
        context_len (int):
            Length of each yielded token sequence. Default: 2048.
        rank (int):
            Rank of this process in distributed training. Default: 0.
        world_size (int):
            Total number of distributed processes. Default: 1.
        buffer_size (int):
            Number of chunks kept in the shuffle buffer. Default: 1024.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> None:
        # Fix: `__init__` was annotated `-> HuggingfaceDataset`; it always
        # returns None.
        self.dataset = dataset
        self.tokenizer = tokenizer

        # Each rank iterates only its own shard of the data.
        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # Buffer tokens in the narrowest integer dtype that can hold the
        # vocabulary; chunks are widened to int64 when yielded.
        # NOTE(review): `vocab_size` may exclude added special tokens —
        # confirm their ids also fit the chosen dtype.
        if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        # Resumable iteration state (checkpointed by `state_dict`):
        self.states = None      # dataset state at the last consumed sample
        self.buffer = torch.tensor([], dtype=self.dtype)  # shuffle buffer of chunks
        self.tokens = []        # tokens not yet placed into the buffer
        self.rand_id = 0        # position inside the current `randint` batch
        self.token_id = 0       # position inside `tokens` while sampling
        self.rng_state = None   # generator state at the current `randint` batch
        self._epoch = 0

    def __iter__(self):
        """Yield ``{'input_ids': LongTensor[context_len]}`` samples."""
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            # Resume the RNG exactly where a checkpoint left off.
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            # Fast-forward the shard to the checkpointed position.
            self.data.load_state_dict(self.states)

        # Number of tokens required to fill the shuffle buffer once.
        n_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                self.tokens += sample
                # Fill the buffer in one shot once enough tokens exist ...
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                # ... then yield from random slots, refilling as we go.
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            # Shard pass finished: flush the remaining full chunks, shuffled.
            n_chunks = len(self.tokens) // self.context_len
            if n_chunks > 0:
                n_tokens = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                # int64 here since these chunks are yielded without widening.
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Tokenize samples from `data` in batches, yielding one token-id
        list per sample.

        Updates `self.states` to the dataset state recorded right after the
        currently yielded sample was consumed, so a checkpoint resumes just
        past the last fully processed sample.
        """
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        # Flush the final, partially filled batch.
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        """Yield chunks from random buffer slots chosen by `indices`,
        refilling each consumed slot with the next chunk of `self.tokens`."""
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            self.token_id += self.context_len
            # Yield first, then overwrite the consumed slot.
            yield {'input_ids': self.buffer[i].to(torch.long)}
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator | None = None
    ) -> Iterable[int]:
        """Endlessly yield ints in ``[low, high)``, drawn in batches of
        ``batch_size`` from generator ``g``.

        `self.rng_state` snapshots the generator state at the start of the
        current batch and `self.rand_id` the position within it, so a
        resumed run can regenerate the batch and skip already-yielded draws.
        """
        # Fix: the original used a mutable default argument
        # (`g=torch.Generator()`), evaluated once at definition time and
        # silently shared across every caller that omitted `g`.
        if g is None:
            g = torch.Generator()
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # Snapshot *before* drawing so the batch is reproducible.
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Record the epoch (seeds `__iter__`) and forward it to the
        underlying dataset when it supports `set_epoch`."""
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        """Return a checkpointable snapshot of the iteration state."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        """Restore iteration state from a `state_dict()` snapshot."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
|
|
|
|
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Collate `examples` into a batch dict with `input_ids`, `labels`
        and, when `varlen=True`, `offsets`."""
        # Accept raw token-id lists by normalizing to the dict form.
        if not isinstance(examples[0], dict):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
            # Convert list / ndarray fields to long tensors; pass tensors through.
            tensorized = {}
            for key in ['input_ids', 'offsets']:
                if key not in example:
                    continue
                if isinstance(example[key], list):
                    tensorized[key] = torch.tensor(example[key], dtype=torch.long)
                elif isinstance(example[key], np.ndarray):
                    tensorized[key] = torch.from_numpy(example[key])
                else:
                    tensorized[key] = example[key]
            return tensorized

        examples = list(map(tensorize, examples))

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                # All sequences share one length: stack without padding.
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # Fix: use the public `pad_token` accessor instead of the
                # private `_pad_token` attribute (deprecated/removed in
                # recent transformers releases).
                if self.tokenizer.pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            # Variable-length mode flattens the (single) example to [1, total_len].
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                # No precomputed offsets: derive sequence boundaries from the
                # tokenizer's bos/eos separators.
                # NOTE(review): assumes the tokenizer exposes `add_bos_token`
                # / `add_eos_token` attributes — confirm for the tokenizers
                # used with this collator.
                if self.tokenizer.add_bos_token:
                    offsets = []
                    # A sequence may start without bos (e.g. mid-document split).
                    if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
                        offsets.append(torch.tensor([0], dtype=torch.long))
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
                    offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                elif self.tokenizer.add_eos_token:
                    offsets = [torch.tensor([0], dtype=torch.long)]
                    # Each boundary sits just after an eos token.
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
                    if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
                        offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                else:
                    raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

        # Labels are the inputs with padding positions masked out for the loss.
        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
|
|