AlekseyCalvin commited on
Commit
7e9ee68
·
verified ·
1 Parent(s): c752780

Upload 4 files

Browse files
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<bos>",
3
+ "eos_token": "<bos>",
4
+ "pad_token": "<pad>"
5
+ }
tokenization_bolmo.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from functools import lru_cache
3
+ from typing import Optional, Union
4
+ from transformers import AutoTokenizer
5
+ from transformers.tokenization_utils import PreTrainedTokenizer
6
+
7
+ # Source: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
8
+ # Also implemented in https://docs.rs/tokenizers/latest/src/tokenizers/pre_tokenizers/byte_level.rs.html#13-39
9
+ _CHARS_TO_BYTES = {
10
+ "Ā": 0, "ā": 1, "Ă": 2, "ă": 3, "Ą": 4, "ą": 5, "Ć": 6, "ć": 7, "Ĉ": 8,
11
+ "ĉ": 9, "Ċ": 10, "ċ": 11, "Č": 12, "č": 13, "Ď": 14, "ď": 15, "Đ": 16,
12
+ "đ": 17, "Ē": 18, "ē": 19, "Ĕ": 20, "ĕ": 21, "Ė": 22, "ė": 23, "Ę": 24,
13
+ "ę": 25, "Ě": 26, "ě": 27, "Ĝ": 28, "ĝ": 29, "Ğ": 30, "ğ": 31, "Ġ": 32,
14
+ "!": 33, '"': 34, "#": 35, "$": 36, "%": 37, "&": 38, "'": 39, "(": 40,
15
+ ")": 41, "*": 42, "+": 43, ",": 44, "-": 45, ".": 46, "/": 47, "0": 48,
16
+ "1": 49, "2": 50, "3": 51, "4": 52, "5": 53, "6": 54, "7": 55, "8": 56,
17
+ "9": 57, ":": 58, ";": 59, "<": 60, "=": 61, ">": 62, "?": 63, "@": 64,
18
+ "A": 65, "B": 66, "C": 67, "D": 68, "E": 69, "F": 70, "G": 71, "H": 72,
19
+ "I": 73, "J": 74, "K": 75, "L": 76, "M": 77, "N": 78, "O": 79, "P": 80,
20
+ "Q": 81, "R": 82, "S": 83, "T": 84, "U": 85, "V": 86, "W": 87, "X": 88,
21
+ "Y": 89, "Z": 90, "[": 91, "\\": 92, "]": 93, "^": 94, "_": 95, "`": 96,
22
+ "a": 97, "b": 98, "c": 99, "d": 100, "e": 101, "f": 102, "g": 103,
23
+ "h": 104, "i": 105, "j": 106, "k": 107, "l": 108, "m": 109, "n": 110,
24
+ "o": 111, "p": 112, "q": 113, "r": 114, "s": 115, "t": 116, "u": 117,
25
+ "v": 118, "w": 119, "x": 120, "y": 121, "z": 122, "{": 123, "|": 124,
26
+ "}": 125, "~": 126, "ġ": 127, "Ģ": 128, "ģ": 129, "Ĥ": 130, "ĥ": 131,
27
+ "Ħ": 132, "ħ": 133, "Ĩ": 134, "ĩ": 135, "Ī": 136, "ī": 137, "Ĭ": 138,
28
+ "ĭ": 139, "Į": 140, "į": 141, "İ": 142, "ı": 143, "IJ": 144, "ij": 145,
29
+ "Ĵ": 146, "ĵ": 147, "Ķ": 148, "ķ": 149, "ĸ": 150, "Ĺ": 151, "ĺ": 152,
30
+ "Ļ": 153, "ļ": 154, "Ľ": 155, "ľ": 156, "Ŀ": 157, "ŀ": 158, "Ł": 159,
31
+ "ł": 160, "¡": 161, "¢": 162, "£": 163, "¤": 164, "¥": 165, "¦": 166,
32
+ "§": 167, "¨": 168, "©": 169, "ª": 170, "«": 171, "¬": 172, "Ń": 173,
33
+ "®": 174, "¯": 175, "°": 176, "±": 177, "²": 178, "³": 179, "´": 180,
34
+ "µ": 181, "¶": 182, "·": 183, "¸": 184, "¹": 185, "º": 186, "»": 187,
35
+ "¼": 188, "½": 189, "¾": 190, "¿": 191, "À": 192, "Á": 193, "Â": 194,
36
+ "Ã": 195, "Ä": 196, "Å": 197, "Æ": 198, "Ç": 199, "È": 200, "É": 201,
37
+ "Ê": 202, "Ë": 203, "Ì": 204, "Í": 205, "Î": 206, "Ï": 207, "Ð": 208,
38
+ "Ñ": 209, "Ò": 210, "Ó": 211, "Ô": 212, "Õ": 213, "Ö": 214, "×": 215,
39
+ "Ø": 216, "Ù": 217, "Ú": 218, "Û": 219, "Ü": 220, "Ý": 221, "Þ": 222,
40
+ "ß": 223, "à": 224, "á": 225, "â": 226, "ã": 227, "ä": 228, "å": 229,
41
+ "æ": 230, "ç": 231, "è": 232, "é": 233, "ê": 234, "ë": 235, "ì": 236,
42
+ "í": 237, "î": 238, "ï": 239, "ð": 240, "ñ": 241, "ò": 242, "ó": 243,
43
+ "ô": 244, "õ": 245, "ö": 246, "÷": 247, "ø": 248, "ù": 249, "ú": 250,
44
+ "û": 251, "ü": 252, "ý": 253, "þ": 254, "ÿ": 255,
45
+ }
46
+ _BYTES_TO_CHARS = {v: k for k, v in _CHARS_TO_BYTES.items()}
47
+
48
+ def _bytes_to_chars(byte_sequence: bytes) -> str:
49
+ return "".join(_BYTES_TO_CHARS[byte] for byte in byte_sequence)
50
+
51
+ def _chars_to_bytes(char_sequence: str) -> list:
52
+ return list(bytes(_CHARS_TO_BYTES[char] for char in char_sequence))
53
+
54
+ @dataclass
55
+ class BolmoTokenizerConfig:
56
+ vocab_size: int
57
+ bos_token_id: int
58
+ pad_token_id: int
59
+ eos_token_id: int
60
+ bpe_token_end_id: int
61
+ special_tokens: list[str] = field(default_factory=lambda: [])
62
+ special_tokens_first: bool = True
63
+ original_identifier: Optional[str] = None
64
+
65
+
66
+ @classmethod
67
+ def bolmo(cls) -> "BolmoTokenizerConfig":
68
+ special_tokens = [
69
+ "<pad>",
70
+ "<bos>",
71
+ "<eos>",
72
+ "<bpe_token_end>",
73
+ ]
74
+
75
+ return cls(
76
+ # *2 to accomodate fused boundary tokens
77
+ vocab_size=(len(special_tokens) + 256) * 2,
78
+ special_tokens=special_tokens,
79
+ bos_token_id=special_tokens.index("<bos>"),
80
+ pad_token_id=special_tokens.index("<pad>"),
81
+ eos_token_id=special_tokens.index("<bos>"),
82
+ bpe_token_end_id=special_tokens.index("<bpe_token_end>"),
83
+ original_identifier="allenai/dolma2-tokenizer",
84
+ )
85
+
86
+ def build(self):
87
+ return BolmoTokenizer(tokenizer_config=self)
88
+
89
+
90
+ class BolmoTokenizer(PreTrainedTokenizer):
91
+ TOKEN_ID_KEY = -1
92
+
93
+ def __init__(self, **kwargs):
94
+ tokenizer_config = kwargs.pop("tokenizer_config", BolmoTokenizerConfig.bolmo())
95
+
96
+ self.config = tokenizer_config
97
+ self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.original_identifier)
98
+ if self.config.special_tokens_first:
99
+ self.offset = len(tokenizer_config.special_tokens)
100
+ self.special_tokens_offset = 0
101
+ else:
102
+ self.offset = 0
103
+ self.special_tokens_offset = self.config.vocab_size - len(tokenizer_config.special_tokens)
104
+
105
+ self.byte_sequences = {}
106
+
107
+ for key, value in self.hf_tokenizer.get_vocab().items():
108
+ if key in self.config.special_tokens:
109
+ byte_sequence = [self.special_tokens_offset + self.config.special_tokens.index(key)]
110
+ elif value == self.hf_tokenizer.eos_token_id and self.eos_token_id is not None:
111
+ byte_sequence = [self.eos_token_id]
112
+ elif value == self.hf_tokenizer.bos_token_id and self.bos_token_id is not None:
113
+ byte_sequence = [self.bos_token_id]
114
+ elif value == self.hf_tokenizer.pad_token_id and self.pad_token_id is not None:
115
+ byte_sequence = [self.pad_token_id]
116
+ else:
117
+ byte_sequence = [self.offset + i for i in _chars_to_bytes(key)]
118
+
119
+ assert self.byte_sequences.get(value) is None
120
+ self.byte_sequences[value] = byte_sequence
121
+
122
+ self.byte_trie = {}
123
+
124
+ for token_id, byte_sequence in self.byte_sequences.items():
125
+ current_dict = self.byte_trie
126
+ for byte in byte_sequence[::-1]: # retrieved from the back so store in reverse order
127
+ if byte not in current_dict:
128
+ current_dict[byte] = {}
129
+ current_dict = current_dict[byte]
130
+ current_dict[BolmoTokenizer.TOKEN_ID_KEY] = token_id
131
+
132
+ self.add_bos_token = True
133
+ self.add_eos_token = False
134
+ self.padding_side = "left" # for generate
135
+
136
+ super().__init__(
137
+ bos_token=self.config.special_tokens[self.config.bos_token_id],
138
+ eos_token=self.config.special_tokens[self.config.eos_token_id],
139
+ pad_token=self.config.special_tokens[self.config.pad_token_id],
140
+ extra_ids=0,
141
+ )
142
+
143
+ @property
144
+ def bos_token_id(self):
145
+ return self.config.bos_token_id
146
+
147
+ @property
148
+ def eos_token_id(self):
149
+ return self.config.eos_token_id
150
+
151
+ @property
152
+ def pad_token_id(self):
153
+ return self.config.pad_token_id
154
+
155
+ @property
156
+ def bpe_token_end_id(self):
157
+ return self.config.bpe_token_end_id
158
+
159
+ @property
160
+ def vocab_size(self):
161
+ return self.config.vocab_size
162
+
163
+ def _convert_id_to_token(self, index):
164
+ if index < self.offset:
165
+ return self.config.special_tokens[index - self.special_tokens_offset]
166
+
167
+ if index >= self.offset + 256 and index < self.offset * 2 + 256:
168
+ # special token with fused boundary
169
+ return self.config.special_tokens[index - self.offset - 256] + "b"
170
+
171
+ return _BYTES_TO_CHARS[index - self.offset - 256 - self.offset] + "b" if index >= self.offset + 256 else _BYTES_TO_CHARS[index - self.offset]
172
+
173
+ def _convert_token_to_id(self, token):
174
+ if token in self.config.special_tokens:
175
+ return self.config.special_tokens.index(token)
176
+
177
+ if token in [x + "b" for x in self.config.special_tokens]:
178
+ # special token with fused boundary
179
+ return 256 + self.config.special_tokens.index(token[:-1])
180
+
181
+ if len(token) > 1 and token[-1] == "b":
182
+ return self.offset + 256 + _CHARS_TO_BYTES[token[0]]
183
+ else:
184
+ return self.offset + _CHARS_TO_BYTES[token]
185
+
186
+ def get_vocab(self):
187
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
188
+ return vocab
189
+
190
+ def expand_byte_ids(self, byte_ids: list[int], n_last: Optional[int] = None) -> list[int]:
191
+ # search in the byte tree for the longest matching token at every byte position
192
+ expanded_ids = []
193
+ for i in range(len(byte_ids)):
194
+ if n_last is not None and i < len(byte_ids) - n_last:
195
+ continue
196
+
197
+ current_dict = self.byte_trie
198
+ current_expansion = None
199
+
200
+ for i in range(i, -1, -1):
201
+ byte = byte_ids[i]
202
+
203
+ if byte == self.bpe_token_end_id:
204
+ # skip bpe token end markers, needed for generation
205
+ continue
206
+
207
+ if byte >= self.offset + 256:
208
+ # ignore fused boundary
209
+ byte -= self.offset + 256
210
+
211
+ try:
212
+ current_dict = current_dict[byte]
213
+ if BolmoTokenizer.TOKEN_ID_KEY in current_dict:
214
+ current_expansion = current_dict[BolmoTokenizer.TOKEN_ID_KEY]
215
+ except KeyError:
216
+ assert current_expansion is not None
217
+ break
218
+
219
+ expanded_ids.append(current_expansion)
220
+
221
+ return expanded_ids
222
+
223
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
224
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
225
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
226
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
227
+
228
+ output = bos_token_id + token_ids_0 + eos_token_id
229
+
230
+ if token_ids_1 is not None:
231
+ output = output + bos_token_id + token_ids_1 + eos_token_id
232
+
233
+ return output
234
+
235
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
236
+ def get_special_tokens_mask(
237
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
238
+ ) -> list[int]:
239
+ """
240
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
241
+ special tokens using the tokenizer `prepare_for_model` method.
242
+ Args:
243
+ token_ids_0 (`List[int]`):
244
+ List of IDs.
245
+ token_ids_1 (`List[int]`, *optional*):
246
+ Optional second list of IDs for sequence pairs.
247
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
248
+ Whether or not the token list is already formatted with special tokens for the model.
249
+ Returns:
250
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
251
+ """
252
+ if already_has_special_tokens:
253
+ return super().get_special_tokens_mask(
254
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
255
+ )
256
+
257
+ bos_token_id = [1] if self.add_bos_token else []
258
+ eos_token_id = [1] if self.add_eos_token else []
259
+
260
+ if token_ids_1 is None:
261
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
262
+ return (
263
+ bos_token_id
264
+ + ([0] * len(token_ids_0))
265
+ + eos_token_id
266
+ + bos_token_id
267
+ + ([0] * len(token_ids_1))
268
+ + eos_token_id
269
+ )
270
+
271
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
272
+ def create_token_type_ids_from_sequences(
273
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
274
+ ) -> list[int]:
275
+ """
276
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
277
+ sequence pair mask has the following format:
278
+ ```
279
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
280
+ | first sequence | second sequence |
281
+ ```
282
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
283
+ Args:
284
+ token_ids_0 (`List[int]`):
285
+ List of ids.
286
+ token_ids_1 (`List[int]`, *optional*):
287
+ Optional second list of IDs for sequence pairs.
288
+ Returns:
289
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
290
+ """
291
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
292
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
293
+
294
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
295
+
296
+ if token_ids_1 is not None:
297
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
298
+
299
+ return output
300
+
301
+ def _tokenize(self, text: str, **kwargs) -> list[str]:
302
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
303
+ tokens = self.convert_ids_to_tokens(self._bolmo_encode(text))
304
+ return tokens
305
+
306
+ def _patch_ids_to_byte_ids(self, input_ids: list[int]):
307
+ return [byte_token_id for token_id in input_ids for byte_token_id in self.byte_sequences[token_id]]
308
+
309
+ def _bolmo_encode(self, string: str, add_special_tokens=False):
310
+ input_ids = self.hf_tokenizer.encode(string, add_special_tokens=add_special_tokens)
311
+ return self._patch_ids_to_byte_ids(input_ids)
312
+
313
+ def _bolmo_decode(self, tokens: list[int], skip_special_tokens: bool = False) -> str:
314
+ return self._decode_to_bytes(tokens, skip_special_tokens=skip_special_tokens).decode("utf-8", errors="replace")
315
+
316
+ def _decode_to_bytes(self, tokens: list[int], skip_special_tokens: bool = False) -> bytes:
317
+ tokens_without_boundary = []
318
+ for token in tokens:
319
+ if token >= (self.offset + 256):
320
+ token -= self.offset + 256
321
+
322
+ tokens_without_boundary.append(token)
323
+
324
+ utf8_bytes = []
325
+
326
+ for token in tokens_without_boundary:
327
+ if token < self.offset:
328
+ if skip_special_tokens:
329
+ continue
330
+ else:
331
+ utf8_bytes.extend(self.config.special_tokens[token].encode("utf-8"))
332
+ else:
333
+ utf8_bytes.append(min(token - self.offset, 255))
334
+
335
+ return bytes(utf8_bytes)
336
+
337
+ def get_tokens_and_patch_lengths(self, original_input_ids: list[int], add_bos=False, strip_pad=False, skip_last=False):
338
+ if add_bos and self.bos_token_id is not None:
339
+ byte_tokens = [self.bos_token_id]
340
+ patch_lengths = [1]
341
+ else:
342
+ byte_tokens = []
343
+ patch_lengths = []
344
+
345
+ for idx, token in enumerate(original_input_ids):
346
+ # optionally skip last token to keep the length the same if add_bos=True
347
+ if skip_last and idx == len(original_input_ids) - 1:
348
+ break
349
+
350
+ token_byte_tokens = self._patch_ids_to_byte_ids([int(token)])
351
+
352
+ if strip_pad and all(t == self.pad_token_id for t in token_byte_tokens):
353
+ # skip padding tokens
354
+ continue
355
+
356
+ patch_lengths.append(len(token_byte_tokens))
357
+ byte_tokens.extend(token_byte_tokens)
358
+
359
+ return byte_tokens, patch_lengths
360
+
361
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
362
+ return self._bolmo_decode(self.convert_tokens_to_ids(tokens), skip_special_tokens=False) # type: ignore
363
+
364
+ def _decode(
365
+ self,
366
+ token_ids: Union[int, list[int]],
367
+ skip_special_tokens: bool = False,
368
+ clean_up_tokenization_spaces: Optional[bool] = None,
369
+ spaces_between_special_tokens: bool = True,
370
+ **kwargs,
371
+ ) -> str:
372
+ if isinstance(token_ids, int):
373
+ token_ids = [token_ids]
374
+
375
+ return self._bolmo_decode(token_ids, skip_special_tokens=skip_special_tokens)
376
+
377
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
378
+ return () # type: ignore
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<pad>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ }
19
+ },
20
+ "auto_map": {
21
+ "AutoTokenizer": [
22
+ "tokenization_bolmo.BolmoTokenizer",
23
+ null
24
+ ]
25
+ },
26
+ "bos_token": "<bos>",
27
+ "clean_up_tokenization_spaces": false,
28
+ "eos_token": "<bos>",
29
+ "extra_ids": 0,
30
+ "extra_special_tokens": {},
31
+ "model_max_length": 1000000000000000019884624838656,
32
+ "pad_token": "<pad>",
33
+ "tokenizer_class": "BolmoTokenizer"
34
+ }
utils_bolmo.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
8
+ if boundary_threshold.startswith("sample:"):
9
+ _, temperature = boundary_threshold.split(":")
10
+ temperature = float(temperature)
11
+
12
+ if temperature == 0:
13
+ return (boundary_logprobs > math.log(0.5))
14
+ elif temperature == 1:
15
+ return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
16
+ else:
17
+ raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")
18
+ elif boundary_threshold.startswith("topk:"):
19
+ _, topk = boundary_threshold.split(":")
20
+ topk = int(topk)
21
+ thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - (topk / boundary_logprobs.shape[1]))
22
+ return (boundary_logprobs >= thresholds.unsqueeze(-1))
23
+ elif boundary_threshold.startswith("topk_percent:"):
24
+ _, topk_percent = boundary_threshold.split(":")
25
+ topk_percent = float(topk_percent)
26
+ assert 0 <= topk_percent <= 1
27
+ thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - topk_percent)
28
+ return (boundary_logprobs >= thresholds.unsqueeze(-1))
29
+ else:
30
+ raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")
31
+
32
+
33
+ def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
34
+ max_len = max(t.size(0) for t in tensors)
35
+ if multiple_of > 1:
36
+ # Round up max_len to the nearest multiple_of
37
+ max_len = ((max_len + multiple_of - 1) // multiple_of) * multiple_of
38
+ padded = []
39
+ for t in tensors:
40
+ if direction == "left":
41
+ pad_shape = (max_len - t.size(0), 0)
42
+ elif direction == "right":
43
+ pad_shape = (0, max_len - t.size(0))
44
+ else:
45
+ raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")
46
+ padded.append(F.pad(t, pad_shape, value=value))
47
+ return torch.stack(padded, dim=0)
48
+
49
+ def pad_right(
50
+ tensors: list[torch.Tensor],
51
+ multiple_of: int = 128,
52
+ value=0,
53
+ ):
54
+ return _pad(tensors, multiple_of, direction="right", value=value)
55
+
56
+ def pad_left(
57
+ tensors: list[torch.Tensor],
58
+ multiple_of: int = 128,
59
+ value=0,
60
+ ):
61
+ return _pad(tensors, multiple_of, direction="left", value=value)
62
+
63
+ class MaskState:
64
+ def __init__(self, mask):
65
+ self.cpu_mask = mask.cpu()
66
+
67
+ self.mask = mask
68
+ self.inv_mask = ~mask
69
+ self._all = self.cpu_mask.all().item()
70
+ self._any = self.cpu_mask.any().item()
71
+
72
+ def any(self):
73
+ return self._any
74
+
75
+ def all(self):
76
+ return self._all
77
+
78
+ def selective_get(self, x, inv=False):
79
+ # try to avoid sync through nonzero on index
80
+ if inv:
81
+ if self.all():
82
+ return x[[]]
83
+ elif not self.any():
84
+ return x
85
+ else:
86
+ return x[self.inv_mask]
87
+ else:
88
+ if self.all():
89
+ return x
90
+ elif not self.any():
91
+ return x[[]]
92
+ else:
93
+ return x[self.mask]
94
+
95
+ def selective_put(self, x, out, inv=False):
96
+ # try to avoid sync through nonzero on index
97
+ if inv:
98
+ if self.all():
99
+ return
100
+ elif not self.any():
101
+ out.copy_(x)
102
+ else:
103
+ out[self.inv_mask] = x
104
+ else:
105
+ if self.all():
106
+ out.copy_(x)
107
+ elif not self.any():
108
+ return
109
+ else:
110
+ out[self.mask] = x
111
+
112
+ def selective_add(self, x, out, inv=False):
113
+ # try to avoid sync through nonzero on index
114
+ if inv:
115
+ if self.all():
116
+ return
117
+ elif not self.any():
118
+ out.add_(x)
119
+ else:
120
+ out[self.inv_mask] += x
121
+ else:
122
+ if self.all():
123
+ out.add_(x)
124
+ elif not self.any():
125
+ return
126
+ else:
127
+ out[self.mask] += x