File size: 10,231 Bytes
50776d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# -*- coding: utf-8 -*-

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Union

import numpy as np
import torch
from datasets import Dataset, IterableDataset
from flame.logging import get_logger
from transformers import PreTrainedTokenizer

# Module-level logger from flame's logging helper, named after this module.
logger = get_logger(__name__)


class HuggingfaceDataset(IterableDataset):
    """
    Stream fixed-length chunks of token ids from one shard of a Hugging Face dataset.

    The shard ``dataset.shard(world_size, rank)`` is iterated forever: it is restarted every time it is
    exhausted. The ``'text'`` field of every sample is tokenized in batches, and the token ids are appended
    to a pending token stream that is cut into chunks of exactly ``context_len`` tokens. Chunks are shuffled
    through a buffer with ``buffer_size`` slots. Once the buffer is full, a slot is drawn at random (with
    replacement), its chunk is yielded and the slot is refilled with the next chunk of the stream.

    The iteration position (dataset state, pending tokens, buffer contents, RNG state) is exposed through
    :meth:`state_dict` / :meth:`load_state_dict` so that training can be resumed.

    Args:
        dataset (`Dataset`):
            The dataset to stream from. It must provide ``shard``. Resuming additionally relies on the
            shard's ``state_dict`` / ``load_state_dict``.
        tokenizer (`PreTrainedTokenizer`):
            Called on lists of texts; the ``'input_ids'`` of its output are used.
        context_len (`int`):
            Number of tokens in every yielded sample.
        rank (`int`):
            Index of the current process; selects the shard and offsets the RNG seed.
        world_size (`int`):
            Total number of shards.
        buffer_size (`int`):
            Number of chunks held in the shuffle buffer.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> None:

        self.dataset = dataset
        self.tokenizer = tokenizer

        # each rank only ever iterates its own shard
        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # narrowest integer type able to hold every token id, to keep the buffer small
        self.dtype = self._select_dtype(tokenizer)
        # dataset state recorded for the last sample handed to the token stream
        self.states = None
        # shuffle buffer of shape [n_chunks, context_len]; empty until enough tokens are pending
        self.buffer = torch.tensor([], dtype=self.dtype)
        # pending token ids not yet placed into the buffer
        self.tokens = []
        # number of values of the current random batch already consumed (see `randint`)
        self.rand_id = 0
        # offset into `self.tokens` of the next chunk to consume (see `sample`)
        self.token_id = 0
        # generator state captured before the current random batch was drawn
        self.rng_state = None
        self._epoch = 0

    @staticmethod
    def _select_dtype(tokenizer) -> torch.dtype:
        """Return the smallest signed integer dtype that can store every token id of ``tokenizer``."""
        # In Hugging Face tokenizers `vocab_size` is the base vocabulary only, while `len(tokenizer)`
        # also counts added tokens (whose ids come after the base vocabulary), so use the larger one.
        n_ids = tokenizer.vocab_size
        if hasattr(tokenizer, '__len__'):
            n_ids = max(n_ids, len(tokenizer))
        if n_ids < torch.iinfo(torch.int16).max:
            return torch.int16
        if n_ids < torch.iinfo(torch.int32).max:
            return torch.int32
        return torch.int64

    def __iter__(self):
        """
        Yield ``{'input_ids': LongTensor of shape [context_len]}`` samples indefinitely.

        The RNG is seeded with ``self._epoch + self.rank`` unless a checkpointed RNG state is present.
        A checkpointed dataset state, if any, is restored before iterating.
        """
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            # continue the random stream exactly where the checkpoint left it
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # capacity of the shuffle buffer, in tokens
        buffer_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the pending token stream
                self.tokens += sample
                # once enough tokens are pending, fill the whole buffer at once
                # NOTE: building one [buffer_size, context_len] tensor is much cheaper than filling row by row
                if len(self.buffer) == 0 and len(self.tokens) >= buffer_tokens:
                    self.buffer = torch.tensor(self.tokens[:buffer_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[buffer_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            # The shard is exhausted: emit any complete chunks still pending in `self.tokens` in random
            # order. They replace the buffer, which then holds fewer than `buffer_size` rows.
            n_chunks = len(self.tokens) // self.context_len
            if n_chunks > 0:
                # separate name so the buffer capacity above is not clobbered for the next pass
                leftover = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:leftover], dtype=self.dtype).view(n_chunks, -1)
                self.tokens = self.tokens[leftover:]
                for i in indices:
                    yield {'input_ids': self.buffer[i].to(torch.long, copy=True)}

    def tokenize(self, data, batch_size: int = 64):
        """
        Yield the token ids of every sample of ``data``, tokenizing ``batch_size`` texts per tokenizer call.

        Before a sample's ids are yielded, ``self.states`` is set to ``self.data.state_dict()`` as captured
        right after that sample was pulled from ``data``. A checkpoint therefore records the position of the
        last sample handed out rather than that of the read-ahead batch.
        """
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                yield from self._tokenize_batch(texts, states)
                texts, states = [], []
        if texts:
            yield from self._tokenize_batch(texts, states)

    def _tokenize_batch(self, texts, states):
        """Tokenize ``texts`` in one call and yield each result after recording its dataset state."""
        input_ids = self.tokenizer(texts, return_attention_mask=False)['input_ids']
        for state, tokenized in zip(states, input_ids):
            self.states = state
            yield tokenized

    def sample(self, indices):
        """
        Yield random chunks from the full buffer, refilling each used slot from ``self.tokens``.

        For every complete chunk pending in ``self.tokens`` (starting at offset ``self.token_id``), a slot
        ``i`` is drawn from ``indices``, the chunk stored there is yielded as a long tensor, and the slot is
        refilled with the pending chunk. Consumed tokens are dropped once all complete chunks are used.
        """
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            # Copy the chunk out and refill the slot *before* yielding. A checkpoint taken while this
            # generator is suspended then sees a buffer consistent with `token_id`/`rand_id`, and the
            # yielded tensor never aliases buffer memory that is overwritten afterwards.
            chunk = self.buffer[i].to(torch.long, copy=True)
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
            self.token_id = end
            yield {'input_ids': chunk}
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator | None = None
    ) -> Iterable[int]:
        """
        Yield an endless stream of random integers in ``[low, high)``, drawn ``batch_size`` at a time.

        ``self.rng_state`` holds the generator state captured before the current batch was drawn, and
        ``self.rand_id`` the number of values of that batch already yielded. Restoring both resumes the
        stream exactly. When ``g`` is omitted, a fresh generator is created for this call.
        """
        if g is None:
            g = torch.Generator()
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # record the generator state before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling and forward it to the datasets that support it."""
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        # `self.data` is what `dataset.shard` returned and is what gets iterated; for Hugging Face
        # datasets it is a distinct object, so the epoch must be forwarded to it as well.
        if self.data is not self.dataset and hasattr(self.data, "set_epoch"):
            self.data.set_epoch(epoch)

    def state_dict(self):
        """Return a snapshot of the iteration state; the buffer and pending tokens are copied."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by :meth:`state_dict`; the RNG and dataset states apply on the next ``iter()``."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']


@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
            Variable-length batches must contain exactly one example.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".

    The returned batch also contains `labels`, a copy of `input_ids` in which padding is set to `-100`
    (the default `ignore_index` of PyTorch's cross-entropy loss). With a dedicated pad id, every occurrence
    of it is masked. When the pad id is shared with the bos or eos token, only positions added by padding
    are masked, so genuine bos/eos targets are kept.
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        if not isinstance(examples[0], dict):
            examples = [{'input_ids': example} for example in examples]
        examples = [self._tensorize(example) for example in examples]

        # positions added by padding in this call; stays None when no padding is performed
        padding_mask = None
        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            # Check if padding is necessary.
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # If yes, check if we have a `pad_token`.
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                # request the attention mask only to locate padded positions; it is not returned
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=True)
                attention_mask = batch.pop('attention_mask', None)
                if attention_mask is not None:
                    padding_mask = attention_mask.eq(0)
                else:
                    padding_mask = batch['input_ids'].eq(self.tokenizer.pad_token_id)
        else:
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                # NOTE(review): user-supplied offsets come out as `[1, n]` while derived ones are 1-D `[n]`;
                # confirm which shape downstream consumers expect.
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                batch['offsets'] = self._compute_offsets(batch['input_ids'])

        batch["labels"] = self._build_labels(batch['input_ids'], padding_mask)
        return batch

    @staticmethod
    def _tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the `input_ids`/`offsets` of one example to tensors, dropping every other key."""
        tensorized = {}
        for key in ('input_ids', 'offsets'):
            if key not in example:
                continue
            value = example[key]
            if isinstance(value, (list, tuple)):
                tensorized[key] = torch.tensor(value, dtype=torch.long)
            elif isinstance(value, np.ndarray):
                tensorized[key] = torch.from_numpy(value)
            else:
                tensorized[key] = value
        return tensorized

    def _compute_offsets(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Derive sequence boundaries of a flattened `[1, total_len]` batch from its bos or eos tokens."""
        total = torch.tensor([input_ids.size(1)], dtype=torch.long)
        if getattr(self.tokenizer, 'add_bos_token', False):
            # every sequence starts with bos; prepend 0 if the batch does not start with one
            offsets = []
            if input_ids[0, 0] != self.tokenizer.bos_token_id:
                offsets.append(torch.tensor([0], dtype=torch.long))
            offsets.append(torch.where(input_ids.eq(self.tokenizer.bos_token_id))[1])
            offsets.append(total)
            return torch.cat(offsets, dim=0)
        if getattr(self.tokenizer, 'add_eos_token', False):
            # every sequence ends with eos; a boundary sits right after each eos
            offsets = [torch.tensor([0], dtype=torch.long)]
            offsets.append(torch.where(input_ids.eq(self.tokenizer.eos_token_id))[1] + 1)
            if input_ids[0, -1] != self.tokenizer.eos_token_id:
                offsets.append(total)
            return torch.cat(offsets, dim=0)
        raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

    def _build_labels(self, input_ids: torch.Tensor, padding_mask) -> torch.Tensor:
        """Return a copy of `input_ids` with padding set to -100."""
        labels = input_ids.clone()
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            return labels
        special_ids = (
            getattr(self.tokenizer, 'bos_token_id', None),
            getattr(self.tokenizer, 'eos_token_id', None),
        )
        if pad_token_id in special_ids:
            # the pad id is also a real bos/eos token: only mask what padding added
            if padding_mask is not None:
                labels[padding_mask] = -100
        else:
            labels[labels == pad_token_id] = -100
        return labels