| | |
| |
|
| | from __future__ import annotations |
| |
|
| | from copy import deepcopy |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, Iterable, List, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from datasets import Dataset, IterableDataset |
| | from flame.logging import get_logger |
| | from transformers import PreTrainedTokenizer |
| |
|
# Module-level logger for this file, created with flame's `get_logger` helper.
logger = get_logger(__name__)
| |
|
| |
|
class HuggingfaceDataset(IterableDataset):
    """Stream shuffled, fixed-length token chunks from a sharded Hugging Face dataset.

    The texts of this rank's shard (``example['text']``) are tokenized in batches and
    concatenated into one token stream that is cut into ``context_len``-token chunks.
    Chunks pass through a shuffle buffer of ``buffer_size`` rows: each new chunk takes the
    slot of a randomly chosen row, and the chunk it evicts is yielded as
    ``{'input_ids': LongTensor[context_len]}``. At the end of every pass over the shard,
    all complete chunks still held (buffered or pending) are yielded in random order and a
    new pass begins. The iteration position can be saved and restored with
    :meth:`state_dict` / :meth:`load_state_dict`.

    Args:
        dataset: Source dataset; examples must carry a ``'text'`` field. It is split with
            ``dataset.shard(world_size, rank)``.
        tokenizer: Tokenizer applied to batches of texts.
        context_len: Number of tokens per yielded chunk.
        rank: Index of the shard read by this instance; also offsets the shuffling seed.
        world_size: Total number of shards.
        buffer_size: Number of chunks held in the shuffle buffer.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> None:
        self.dataset = dataset
        self.tokenizer = tokenizer

        # each rank iterates over its own shard of the data
        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # Store buffered ids in the narrowest signed integer type able to hold every id.
        # `vocab_size` does not count tokens added to the tokenizer later, whereas
        # `len(tokenizer)` does, so use the larger of the two to avoid overflow.
        n_vocab = max(tokenizer.vocab_size, len(tokenizer))
        if n_vocab < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif n_vocab < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64

        # ---- resumable iteration state (saved by `state_dict`) ----
        # `self.data.state_dict()` captured after the last text whose tokens were yielded by `tokenize`
        self.states = None
        # shuffle buffer of shape [buffer_size, context_len]; empty until first filled
        self.buffer = torch.tensor([], dtype=self.dtype)
        # token ids read from the data but not yet emitted or moved into the buffer
        self.tokens = []
        # number of indices already consumed from the current batch of `randint`
        self.rand_id = 0
        # offset in `self.tokens` of the next chunk `sample` moves into the buffer
        self.token_id = 0
        # generator state at the start of the current `randint` batch
        self.rng_state = None
        self._epoch = 0

    def __iter__(self):
        """Yield ``{'input_ids': LongTensor[context_len]}`` chunks, pass after pass over the shard.

        Iteration stops only if two consecutive passes read no tokens at all.
        """
        g = torch.Generator()
        # NOTE(review): seeds collide across (epoch, rank) pairs, e.g. epoch 1/rank 0 and
        # epoch 0/rank 1 share a seed; kept as-is so existing shuffles stay reproducible.
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            # resume the random index stream where the last checkpoint left it
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            # resume reading right after the last text whose tokens are already in self.tokens
            self.data.load_state_dict(self.states)

        # Tokens needed to fill the shuffle buffer. This gets its own name so the end-of-pass
        # bookkeeping cannot change it; every pass must refill a full [buffer_size, context_len] buffer.
        buffer_tokens = self.buffer_size * self.context_len

        empty_passes = 0
        while True:
            n_read = 0
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token list
                self.tokens += sample
                n_read += len(sample)
                # fill the buffer once enough tokens have accumulated
                if len(self.buffer) == 0 and len(self.tokens) >= buffer_tokens:
                    self.buffer = torch.tensor(self.tokens[:buffer_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[buffer_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            # end of a pass: emit every complete chunk still held (buffer included), leaving an
            # empty buffer so the next pass refills it from scratch
            yield from self._drain(g)

            # A pass resumed at the very end of the shard can legitimately read nothing, but two
            # empty passes in a row mean no tokens will ever arrive; stop instead of spinning forever.
            empty_passes = empty_passes + 1 if n_read == 0 else 0
            if empty_passes >= 2:
                logger.warning(
                    f"Rank {self.rank}: two consecutive passes over the data shard produced no tokens; "
                    f"stopping iteration."
                )
                return

    def _drain(self, g: torch.Generator):
        """Yield, in random order, every complete chunk held in the buffer or in ``self.tokens``.

        Afterwards the buffer is empty and ``self.tokens`` keeps only the incomplete tail
        (fewer than ``context_len`` tokens).

        NOTE: the drained chunks are held locally, so a checkpoint taken mid-drain does not
        record the ones not yet yielded.
        """
        # tokens before `token_id` were already moved into the buffer by `sample`
        tokens = self.tokens[self.token_id:]
        self.token_id = 0
        n_chunks = len(tokens) // self.context_len
        n_tokens = n_chunks * self.context_len

        pending = []
        if len(self.buffer) > 0:
            pending.append(self.buffer.reshape(-1, self.context_len).to(torch.long))
        if n_chunks > 0:
            pending.append(torch.tensor(tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1))

        # reset before yielding so the next pass starts by refilling an empty buffer
        self.buffer = torch.tensor([], dtype=self.dtype)
        self.tokens = tokens[n_tokens:]
        if not pending:
            return

        chunks = torch.cat(pending, dim=0)
        for i in torch.randperm(len(chunks), generator=g).tolist():
            yield {'input_ids': chunks[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Yield the token-id list of every text in ``data``, tokenizing ``batch_size`` texts per call.

        Right before a text's ids are yielded, ``self.states`` is set to the
        ``self.data.state_dict()`` captured just after that text was read.
        """
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                yield from self._tokenize_batch(texts, states)
                texts, states = [], []
        if texts:
            yield from self._tokenize_batch(texts, states)

    def _tokenize_batch(self, texts, states):
        """Tokenize ``texts`` in one call and yield each id list, recording its data state first."""
        input_ids = self.tokenizer(texts, return_attention_mask=False)['input_ids']
        for state, ids in zip(states, input_ids):
            self.states = state
            yield ids

    def sample(self, indices):
        """Swap every complete chunk of ``self.tokens`` into the buffer, yielding the evicted rows.

        Args:
            indices: Iterator of buffer row indices (from :meth:`randint`).

        Yields:
            ``{'input_ids': LongTensor[context_len]}`` holding a copy of the evicted row.
        """
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            # Copy the evicted row out (`.to` alone returns a view when dtype is already long)
            # and refill its slot *before* yielding. A checkpoint taken while suspended at the
            # yield then sees buffer, token_id and rand_id in agreement: nothing is replayed
            # or skipped on resume.
            chunk = self.buffer[i].to(torch.long, copy=True)
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
            self.token_id = end
            yield {'input_ids': chunk}
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator | None = None
    ) -> Iterable[int]:
        """Endlessly yield random ints in ``[low, high)``, drawn ``batch_size`` at a time from ``g``.

        The generator state is saved in ``self.rng_state`` before each batch is drawn, and
        ``self.rand_id`` counts the ints already yielded from that batch. Restoring both
        makes a new call regenerate the batch and continue where the old one stopped.
        """
        if g is None:
            # a fresh generator per call (a default Generator instance would be shared by all calls)
            g = torch.Generator()
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Record ``epoch`` (used to seed shuffling) and forward it to the underlying datasets."""
        self._epoch = epoch
        # `self.data` is the sharded dataset actually iterated; it may not share epoch state with
        # `self.dataset` (NOTE(review): confirm for the datasets version in use), so update both.
        for ds in (self.data, self.dataset):
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)

    def state_dict(self):
        """Return a snapshot of the iteration state; buffers and generator state are copied."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state.clone() if self.rng_state is not None else None,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        """Restore iteration state produced by :meth:`state_dict`; it takes effect on the next ``iter()``."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
| |
|
| |
|
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Collate ``examples`` into a batch with ``input_ids``, ``labels`` and, in varlen mode, ``offsets``.

        Each example is either a sequence of token ids or a mapping with an ``'input_ids'``
        entry (and optionally ``'offsets'``). Values may be lists/tuples, NumPy arrays or tensors.
        ``labels`` is a clone of ``input_ids`` in which positions equal to the tokenizer's
        ``pad_token_id`` (when it is set) are replaced by -100. In varlen mode ``offsets`` is
        always a 1-D tensor.

        Raises:
            ValueError: If ``examples`` is empty; if padding is needed but the tokenizer has no
                pad token; if ``varlen`` is set and more than one example is given; or if
                ``varlen`` is set, no offsets are supplied and the tokenizer adds neither a bos
                nor an eos token.
        """
        from collections.abc import Mapping

        if len(examples) == 0:
            raise ValueError("Cannot collate an empty list of examples.")
        # raw id sequences are wrapped; `Mapping` also accepts dict-like objects such as UserDict
        if not isinstance(examples[0], Mapping):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Mapping[str, Any]) -> Dict[str, Any]:
            # keep only the keys the batch is built from, converting each to a tensor
            tensorized = {}
            for key in ('input_ids', 'offsets'):
                if key not in example:
                    continue
                value = example[key]
                if isinstance(value, (list, tuple)):
                    tensorized[key] = torch.tensor(value, dtype=torch.long)
                elif isinstance(value, np.ndarray):
                    tensorized[key] = torch.from_numpy(value)
                else:
                    tensorized[key] = value
            return tensorized

        examples = [tensorize(example) for example in examples]

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                # equal lengths: stack directly, no padding required
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                # flattened to 1-D so supplied offsets have the same shape as computed ones
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).flatten()
            else:
                batch['offsets'] = self._separator_offsets(batch['input_ids'])

        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            # NOTE(review): when pad_token_id equals eos_token_id, genuine eos tokens are masked
            # here as well; confirm this is intended.
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch

    def _separator_offsets(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return 1-D sequence boundaries of ``input_ids`` (shape ``[1, seq_len]``) found via bos/eos tokens."""
        seq_len = input_ids.size(1)
        # `getattr`: not every tokenizer class defines add_bos_token / add_eos_token
        if getattr(self.tokenizer, 'add_bos_token', False):
            # every bos starts a sequence; the stream may begin in the middle of one
            offsets = []
            if input_ids[0, 0] != self.tokenizer.bos_token_id:
                offsets.append(torch.tensor([0], dtype=torch.long))
            offsets.append(torch.where(input_ids.eq(self.tokenizer.bos_token_id))[1])
            offsets.append(torch.tensor([seq_len], dtype=torch.long))
        elif getattr(self.tokenizer, 'add_eos_token', False):
            # every eos ends a sequence; the stream may end in the middle of one
            offsets = [torch.tensor([0], dtype=torch.long)]
            offsets.append(torch.where(input_ids.eq(self.tokenizer.eos_token_id))[1] + 1)
            if input_ids[0, -1] != self.tokenizer.eos_token_id:
                offsets.append(torch.tensor([seq_len], dtype=torch.long))
        else:
            raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")
        return torch.cat(offsets, dim=0)
| |
|