Student0809 committed
Commit 326a7fe · verified · 1 Parent(s): c076144

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. docs/transformers/src/transformers/models/myt5/tokenization_myt5.py +380 -0
  2. docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py +156 -0
  3. docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py +346 -0
  4. docs/transformers/src/transformers/models/nllb/tokenization_nllb.py +394 -0
  5. docs/transformers/src/transformers/models/nllb_moe/__init__.py +27 -0
  6. docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py +219 -0
  7. docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py +161 -0
  8. docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py +1784 -0
  9. docs/transformers/src/transformers/models/nougat/__init__.py +28 -0
  10. docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py +282 -0
  11. docs/transformers/src/transformers/models/nougat/image_processing_nougat.py +525 -0
  12. docs/transformers/src/transformers/models/nougat/processing_nougat.py +163 -0
  13. docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py +620 -0
  14. docs/transformers/src/transformers/models/nystromformer/__init__.py +27 -0
  15. docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py +132 -0
  16. docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py +111 -0
  17. docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py +1124 -0
  18. docs/transformers/src/transformers/models/olmo/__init__.py +27 -0
  19. docs/transformers/src/transformers/models/olmo/configuration_olmo.py +198 -0
  20. docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py +248 -0
  21. docs/transformers/src/transformers/models/olmo/modeling_olmo.py +814 -0
  22. docs/transformers/src/transformers/models/olmo/modular_olmo.py +148 -0
  23. docs/transformers/src/transformers/models/olmo2/__init__.py +27 -0
  24. docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py +180 -0
  25. docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py +306 -0
  26. docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py +820 -0
  27. docs/transformers/src/transformers/models/olmo2/modular_olmo2.py +320 -0
  28. docs/transformers/src/transformers/models/olmoe/__init__.py +27 -0
  29. docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py +182 -0
  30. docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py +281 -0
  31. docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py +1273 -0
  32. docs/transformers/src/transformers/models/omdet_turbo/__init__.py +28 -0
  33. docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py +293 -0
  34. docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py +349 -0
  35. docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +1711 -0
  36. docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py +415 -0
  37. docs/transformers/src/transformers/models/oneformer/__init__.py +29 -0
  38. docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py +277 -0
  39. docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py +1191 -0
  40. docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py +1356 -0
  41. docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py +0 -0
  42. docs/transformers/src/transformers/models/oneformer/processing_oneformer.py +207 -0
  43. docs/transformers/src/transformers/models/openai/__init__.py +30 -0
  44. docs/transformers/src/transformers/models/openai/configuration_openai.py +156 -0
  45. docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py +74 -0
  46. docs/transformers/src/transformers/models/openai/modeling_openai.py +967 -0
  47. docs/transformers/src/transformers/models/openai/modeling_tf_openai.py +937 -0
  48. docs/transformers/src/transformers/models/openai/tokenization_openai.py +396 -0
  49. docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py +66 -0
  50. docs/transformers/src/transformers/models/opt/__init__.py +29 -0
docs/transformers/src/transformers/models/myt5/tokenization_myt5.py ADDED
@@ -0,0 +1,380 @@
+ # coding=utf-8
+ # Copyright 2024
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization class for model MyT5."""
+
+ import json
+ import os
+ import warnings
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"}
+
+
+ class ByteRewriter:
+     """
+     Byte rewriter class for MyT5 tokenizer.
+     This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules.
+
+     Args:
+         rewriting_rules (`str` or `Dict[str, str]`):
+             A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules.
+     """
+
+     LEAF = "[LEAF]"
+
+     def __init__(self, rewriting_rules: Union[str, Dict[str, str]]):
+         if isinstance(rewriting_rules, str):
+             with open(rewriting_rules, "r") as f:
+                 rewriting_rules = json.load(f)
+         elif not isinstance(rewriting_rules, dict):
+             raise ValueError(
+                 f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}"
+             )
+
+         self.hash_tree = self.construct_hash_tree(rewriting_rules)
+         reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()}
+         self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules)
+
+     def add_leaf(self, hash_tree: Dict[str, Union[dict, List[str]]], byte_in_sequence: str, byte_out_sequence: str):
+         """
+         Add a leaf with the output byte sequence to the hash tree.
+         """
+         byte_in_list = byte_in_sequence.split(" ")
+         byte_out_list = byte_out_sequence.split(" ")
+
+         tree_pointer = hash_tree
+         for b in byte_in_list:
+             if b not in tree_pointer:
+                 tree_pointer[b] = {}
+             tree_pointer = tree_pointer[b]
+
+         tree_pointer[self.LEAF] = byte_out_list
+
+     def construct_hash_tree(self, rewriting_rules: Dict[str, str]) -> Dict[str, Union[dict, List[str]]]:
+         """
+         Construct a hash tree for rewritten byte sequences.
+         """
+         hash_tree = defaultdict(dict)
+         # Every single byte rewrites to itself unless a rule overrides it
+         for b in (f"{x:02x}" for x in range(256)):
+             hash_tree[b][self.LEAF] = [b]
+
+         for in_sequence, out_sequence in rewriting_rules.items():
+             self.add_leaf(hash_tree, in_sequence, out_sequence)
+
+         return hash_tree
+
+     def search_hash_tree(self, byte_sequence: List[str]) -> Union[None, List[str]]:
+         """
+         Search the hash tree and return the rewritten byte sequence if found.
+         """
+         tree_pointer = self.hash_tree
+         for b in byte_sequence:
+             if b in tree_pointer:
+                 tree_pointer = tree_pointer[b]
+             else:
+                 return None
+
+         return tree_pointer[self.LEAF]
+
+     def rewrite_bytes(self, in_bytes: List[str], reverse=False) -> List[str]:
+         """
+         Rewrite a sequence of bytes using the hash tree.
+
+         Args:
+             in_bytes (`List[str]`): A list of bytes to be rewritten.
+             reverse (`bool`): If True, decoding is performed with the reverse hash tree.
+         Returns:
+             `List[str]`: The rewritten byte sequence.
+         """
+         out_bytes = []
+         b_start = 0
+         b_end = 0
+
+         while b_start < len(in_bytes):
+             # Greedily match the longest rewriting rule starting at b_start
+             tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree
+             for j in range(b_start, len(in_bytes)):
+                 b = in_bytes[j]
+                 if b in tree_pointer:
+                     tree_pointer = tree_pointer[b]
+                 elif j == b_start:
+                     # No rule starts with this byte; emit it unchanged
+                     cur_leaf = [b]
+                     b_end = j
+                     break
+                 else:
+                     break
+                 if self.LEAF in tree_pointer:
+                     # Remember the longest match found so far
+                     cur_leaf = tree_pointer[self.LEAF]
+                     b_end = j
+             out_bytes.extend(cur_leaf)
+             b_start = b_end + 1
+
+         return out_bytes
+
+
+ class MyT5Tokenizer(PreTrainedTokenizer):
+     """
+     Construct a MyT5 tokenizer.
+
+     This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+     this superclass for more information regarding those methods.
+
+     Args:
+         vocab_file (`str`): The file containing the byte rewriting rules.
+         eos_token (`str`, *optional*, defaults to `"</s>"`):
+             The end of sequence token.
+         unk_token (`str`, *optional*, defaults to `"<unk>"`):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+             token instead.
+         pad_token (`str`, *optional*, defaults to `"<pad>"`):
+             The token used for padding, for example when batching sequences of different lengths.
+         extra_ids (`int`, *optional*, defaults to 125):
+             Number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
+             accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
+             indexed from the end of the vocabulary down to the beginning ("<extra_id_0>" is the last token in the
+             vocabulary, as in ByT5 preprocessing; see
+             [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
+         additional_special_tokens (`List[str]`, *optional*):
+             Additional special tokens used by the tokenizer.
+     """
+
+     model_input_names = ["input_ids", "attention_mask"]
+     vocab_files_names = VOCAB_FILES_NAMES
+
+     def __init__(
+         self,
+         vocab_file,
+         eos_token="</s>",
+         unk_token="<unk>",
+         pad_token="<pad>",
+         extra_ids=125,
+         additional_special_tokens=None,
+         **kwargs,
+     ) -> None:
+         # Add extra_ids to the special token list
+         if extra_ids > 0 and additional_special_tokens is None:
+             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
+         elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
+             # Check that we have the right number of extra_id special tokens
+             extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
+             if extra_tokens != extra_ids:
+                 raise ValueError(
+                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
+                     " provided to MyT5Tokenizer. In this case the additional_special_tokens must include the"
+                     " extra_ids tokens"
+                 )
+
+         pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
+         eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
+         unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
+         # unk token needs to be in the vocab with correct index
+         self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
+         self.offset = len(self._added_tokens_decoder)
+         self._utf_vocab_size = 2**8  # utf is 8 bits
+
+         # Load byte maps
+         self.byte_maps = json.load(open(vocab_file, "r"))
+
+         self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"])
+         self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"])
+
+         super().__init__(
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             extra_ids=0,
+             additional_special_tokens=additional_special_tokens,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self):
+         return self._utf_vocab_size
+
+     # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab
+     def get_vocab(self):
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         # normal case: some special tokens
+         if token_ids_1 is None:
+             return ([0] * len(token_ids_0)) + [1]
+         return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+     def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
+         """Do not add eos again if the user already added it."""
+         if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
+             warnings.warn(
+                 f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
+                 " eos tokens being added."
+             )
+             return token_ids
+         else:
+             return token_ids + [self.eos_token_id]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not
+         make use of token type ids, therefore a list of zeros is returned.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of zeros.
+         """
+         eos = [self.eos_token_id]
+
+         if token_ids_1 is None:
+             return len(token_ids_0 + eos) * [0]
+         return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
+     # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+         and adding special tokens. A sequence has the following format:
+
+         - single sequence: `X </s>`
+         - pair of sequences: `A </s> B </s>`
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs to which the special tokens will be added.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+         """
+         token_ids_0 = self._add_eos_if_not_present(token_ids_0)
+         if token_ids_1 is None:
+             return token_ids_0
+         else:
+             token_ids_1 = self._add_eos_if_not_present(token_ids_1)
+             return token_ids_0 + token_ids_1
+
+     def _tokenize(self, text: str, **kwargs) -> List[str]:
+         """Take as input a string and return a list of strings (tokens) for words/sub-words.
+         Tokens are represented in two-character hex format."""
+         tokens = [f"{i:02x}" for i in text.encode("utf-8")]
+         tokens = self.morphological_encode(tokens)
+         return tokens
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         if len(token) != 2:
+             token_id = None
+         else:
+             token_id = int(token, 16) + self.offset
+
+         return token_id
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         token = f"{index - self.offset:02x}"
+         return token
+
+     def morphological_encode(self, indices: List[str]) -> List[str]:
+         # Decompose and merge morphological sequences
+         indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
+         indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
+         return indices
+
+     def morphological_decode(self, indices: List[str]) -> List[str]:
+         # Demerge and compose morphological sequences
+         indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
+         indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
+         return indices
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) into a single string."""
+         bstring = b""
+
+         out_tokens = []
+         for token in tokens:
+             if token in self.added_tokens_decoder:
+                 out_tokens.append(self.added_tokens_decoder[token])
+             elif token in self.added_tokens_encoder:
+                 out_tokens.append(token)
+             else:
+                 out_tokens.append(token)
+
+         out_tokens = self.morphological_decode(out_tokens)
+         _added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder)
+         for token in out_tokens:
+             if token in _added_tokens:
+                 # Special tokens are kept verbatim
+                 bstring += bytes(token, "utf-8")
+             else:
+                 # Everything else is a two-character hex byte
+                 bstring += bytes.fromhex(token)
+         string = bstring.decode("utf-8", errors="ignore")
+         return string
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+             )
+         else:
+             vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False))
+         return (vocab_file,)
+
+
+ __all__ = ["MyT5Tokenizer"]
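
To make the decompose/merge mechanics above concrete, here is a minimal sketch with a single hypothetical rewriting rule (the real rules ship in a checkpoint's `byte_maps.json`); it assumes `ByteRewriter` is importable from the module added above:

```python
from transformers.models.myt5.tokenization_myt5 import ByteRewriter

# _tokenize first maps UTF-8 bytes to two-character hex tokens
tokens = [f"{b:02x}" for b in "hello".encode("utf-8")]
print(tokens)  # ['68', '65', '6c', '6c', '6f']

# Hypothetical rule: rewrite the byte pair 68 65 ("he") to the single byte ff
rewriter = ByteRewriter({"68 65": "ff"})
merged = rewriter.rewrite_bytes(tokens)
print(merged)  # ['ff', '6c', '6c', '6f']

# reverse=True walks the reverse hash tree and restores the original bytes
print(rewriter.rewrite_bytes(merged, reverse=True))  # ['68', '65', '6c', '6c', '6f']

# Each hex token then maps to id int(token, 16) + 3, after <pad>, </s>, <unk> at ids 0-2
```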
docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py ADDED
@@ -0,0 +1,156 @@
+ # coding=utf-8
+ # Copyright 2024 HuggingFace Inc. team. All rights reserved.
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Nemotron model configuration"""
+
+ from ...configuration_utils import PretrainedConfig
+ from ...modeling_rope_utils import rope_config_validation
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class NemotronConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`NemotronModel`]. It is used to instantiate a
+     Nemotron model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the Nemotron-8B,
+     e.g. [nvidia/nemotron-3-8b-base-4k-hf](https://huggingface.co/nvidia/nemotron-3-8b-base-4k-hf).
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 256000):
+             Vocabulary size of the Nemotron model. Defines the number of different tokens that can be represented by
+             the `input_ids` passed when calling [`NemotronModel`].
+         hidden_size (`int`, *optional*, defaults to 6144):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 24576):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer decoder.
+         num_attention_heads (`int`, *optional*, defaults to 48):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         head_dim (`int`, *optional*):
+             Projection weights dimension in multi-head attention. Set to `hidden_size // num_attention_heads` if None.
+         num_key_value_heads (`int`, *optional*):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+             `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details check out [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+             `num_attention_heads`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 4096):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.0134):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*):
+             Padding token id.
+         bos_token_id (`int`, *optional*, defaults to 2):
+             Beginning of stream token id.
+         eos_token_id (`int`, *optional*, defaults to 3):
+             End of stream token id.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether to tie weight embeddings.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+             Percentage of the query and keys which will have rotary embedding.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         mlp_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the up_proj and down_proj layers in the MLP.
+
+     ```python
+     >>> from transformers import NemotronModel, NemotronConfig
+
+     >>> # Initializing a Nemotron nemotron-15b style configuration
+     >>> configuration = NemotronConfig()
+
+     >>> # Initializing a model from the nemotron-15b style configuration
+     >>> model = NemotronModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "nemotron"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=256000,
+         hidden_size=6144,
+         intermediate_size=24576,
+         num_hidden_layers=32,
+         num_attention_heads=48,
+         head_dim=None,
+         num_key_value_heads=None,
+         hidden_act="relu2",
+         max_position_embeddings=4096,
+         initializer_range=0.0134,
+         norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=2,
+         eos_token_id=3,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         partial_rotary_factor=0.5,
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.norm_eps = norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.partial_rotary_factor = partial_rotary_factor
+         rope_config_validation(self)
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.mlp_bias = mlp_bias
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["NemotronConfig"]
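
A quick sketch of how the head arguments above interact (assuming a transformers build that ships this class): `head_dim` falls back to `hidden_size // num_attention_heads`, and `num_key_value_heads` selects the attention variant described in the docstring.

```python
from transformers import NemotronConfig

config = NemotronConfig()
# head_dim defaults to hidden_size // num_attention_heads: 6144 // 48 == 128
assert config.head_dim == 128

# num_key_value_heads picks the attention variant:
mha = NemotronConfig(num_key_value_heads=48)  # equal to num_attention_heads -> MHA
mqa = NemotronConfig(num_key_value_heads=1)   # a single KV head -> MQA
gqa = NemotronConfig(num_key_value_heads=8)   # 48 query heads share 8 KV heads -> GQA
```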
docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py ADDED
@@ -0,0 +1,346 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import os
+ import shutil
+ from argparse import ArgumentParser
+ from collections import OrderedDict
+
+ import torch
+ from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+ from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+ from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+ from nemo.utils import logging
+ from pytorch_lightning import Trainer
+
+ from transformers import LlamaTokenizer, PreTrainedTokenizerFast
+ from transformers.convert_slow_tokenizer import LlamaConverter
+
+
+ """
+ Script to convert a nemotron checkpoint in nemo (mcore path) into a HuggingFace checkpoint.
+ This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.
+
+ 1) Generate only HF weights from a nemo file:
+
+     python convert_nemotron_nemo_to_hf.py \
+     --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
+     --output_path /path/to/pytorch_model.bin
+
+ 2) Generate the full HF model folder:
+
+     python convert_nemotron_nemo_to_hf.py \
+     --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
+     --hf_input_path /path/to/input_hf_folder \
+     --hf_output_path /path/to/output_hf_folder
+
+ Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Nemotron4 340b).
+ However this option makes the conversion script significantly slower.
+ """
+
+
+ def get_args():
+     parser = ArgumentParser()
+     parser.add_argument(
+         "--input_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to .nemo file or extracted folder",
+     )
+     parser.add_argument("--output_path", type=str, default=None, required=False, help="Path to HF .bin file")
+     parser.add_argument(
+         "--hf_input_path",
+         type=str,
+         default=None,
+         help="An HF model path, e.g. a folder containing https://huggingface.co/nvidia/Minitron-8B-Base",
+     )
+     parser.add_argument(
+         "--hf_output_path",
+         type=str,
+         default=None,
+         help="Output HF model path, with the same format as above but user's own weights",
+     )
+     parser.add_argument(
+         "--precision",
+         type=str,
+         default=None,
+         help="Precision of output weights. "
+         "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
+     )
+     parser.add_argument(
+         "--cpu-only",
+         action="store_true",
+         help="Load model on CPU only. Useful if the model cannot fit in GPU memory, "
+         "but this option makes the conversion script significantly slower.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, hf_url="nvidia/Minitron-8B-Base"):
+     """
+     Convert NeMo config to HF config
+     """
+     NEMO_ACT2HF = {
+         "squared-relu": "relu2",
+         "fast-swiglu": "silu",
+     }
+     DTYPE2HF = {
+         torch.bfloat16: "bfloat16",
+         torch.float16: "float16",
+         torch.float32: "float32",
+     }
+     hf_config = {
+         "_name_or_path": hf_url,
+         "architectures": ["NemotronForCausalLM"],
+         "bos_token_id": tokenizer.bos_id,
+         "eos_token_id": tokenizer.eos_id,
+         "hidden_act": NEMO_ACT2HF[nemo_config.activation],
+         "hidden_size": nemo_config.hidden_size,
+         "initializer_range": nemo_config.init_method_std,
+         "intermediate_size": nemo_config.ffn_hidden_size,
+         "max_position_embeddings": nemo_config.max_position_embeddings,
+         "model_type": "nemotron",
+         "num_attention_heads": nemo_config.num_attention_heads,
+         "num_hidden_layers": nemo_config.num_layers,
+         "num_key_value_heads": nemo_config.get("num_query_groups", nemo_config.num_attention_heads),
+         "norm_eps": nemo_config.layernorm_epsilon,
+         "rope_theta": nemo_config.get("rotary_base", 10000),
+         "partial_rotary_factor": nemo_config.get("rotary_percentage", 1.0),
+         "tie_word_embeddings": False,
+         "torch_dtype": DTYPE2HF[dtype],
+         "transformers_version": "4.32.0.dev0",  # TODO
+         "use_cache": True,
+         "vocab_size": vocab_size,
+     }
+     if nemo_config.kv_channels is not None:
+         hf_config["kv_channels"] = nemo_config.kv_channels
+     json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)
+
+
+ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False):
+     """
+     Convert NeMo weights to HF weights
+     """
+     dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
+     model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
+     model_config.tensor_model_parallel_size = 1
+     model_config.pipeline_model_parallel_size = 1
+     model_config.sequence_parallel = False
+     model_config.transformer_engine = True
+     if cpu_only:
+         map_location = torch.device("cpu")
+         model_config.use_cpu_initialization = True
+         model_config.dist_ckpt_load_on_device = False
+     else:
+         map_location = None
+
+     if cpu_only:
+         logging.info("******** Loading model on CPU. This will take a significant amount of time.")
+
+     model = MegatronGPTModel.restore_from(
+         input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
+     )
+
+     vocab_size = model.padded_vocab_size
+
+     if precision is None:
+         precision = model.cfg.precision
+     if precision in [32, "32"]:
+         dtype = torch.float32
+     elif precision in [16, "16", "16-mixed"]:
+         dtype = torch.float16
+     elif precision in ["bf16", "bf16-mixed"]:
+         dtype = torch.bfloat16
+     else:
+         logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
+         dtype = torch.float32  # fallback
+     logging.info(f"Using precision {dtype}")
+
+     def param_to_weights(param):
+         return param.to(dtype)
+
+     checkpoint = OrderedDict()
+
+     hidden_size = model.cfg.hidden_size
+     head_num = model.cfg.num_attention_heads
+     num_layers = model.cfg.num_layers
+     ffn_hidden_size = model.cfg.ffn_hidden_size
+     num_query_groups = model.cfg.get("num_query_groups", head_num)  # different num_query_groups for 70B
+     if num_query_groups is None:
+         num_query_groups = head_num
+     heads_per_group = head_num // num_query_groups
+     qkv_total_dim = head_num + 2 * num_query_groups
+
+     # Embedding
+     embed_weight = model.state_dict()["model.embedding.word_embeddings.weight"]
+     embed_weights_base_name = "model.embed_tokens.weight"
+     checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)
+
+     for l in range(int(num_layers)):
+         print(f"converting layer {l}")
+
+         qkv_weights = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.weight"]
+         qkv_weights = qkv_weights.reshape([qkv_total_dim, -1, hidden_size])
+
+         q_slice = torch.cat(
+             [
+                 torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
+                 for i in range(num_query_groups)
+             ]
+         )
+         k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
+         v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
+         ## Example of slices
+         ## (without GQA): num_query_groups = head_num = 32,
+         ## q_slice = [0, 3, 6, 9 , ... 90, 93]
+         ## k_slice = [1, 4, 7, 10, ... 91, 94]
+         ## v_slice = [2, 5, 8, 11, ... 92, 95]
+         ## (with GQA): num_query_groups = 8, head_num = 64
+         ## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
+         ## k_slice = [8, 18, 28, ... 68, 78]
+         ## v_slice = [9, 19, 29, ... 69, 79]
+
+         q_weights_base_name = f"model.layers.{l}.self_attn.q_proj.weight"
+         k_weights_base_name = f"model.layers.{l}.self_attn.k_proj.weight"
+         v_weights_base_name = f"model.layers.{l}.self_attn.v_proj.weight"
+
+         checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
+         checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
+         checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
+
+         # attention dense
+         o_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_proj.weight"]
+         o_weight_base_name = f"model.layers.{l}.self_attn.o_proj.weight"
+         checkpoint[o_weight_base_name] = param_to_weights(o_weight)
+
+         # mlp
+         mlp_weights = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.weight"]
+         mlp_up_proj_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc2.weight"]
+
+         if mlp_weights.shape[0] != mlp_up_proj_weight.shape[1]:
+             # Has projection (used for swi-glu)
+             logging.warning(
+                 "Gated projection layers detected in NeMo checkpoint. Currently Nemotron HF does not support gated MLP."
+             )
+             assert mlp_weights.shape[0] == 2 * mlp_up_proj_weight.shape[1]
+
+             mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :]
+             mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :]
+
+             mlp_down_proj_base_name = f"model.layers.{l}.mlp.gate_proj.weight"
+             mlp_gate_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
+
+             checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
+             checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
+         else:
+             mlp_down_proj_weight = mlp_weights
+             mlp_down_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
+             checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
+
+         mlp_up_proj_base_name = f"model.layers.{l}.mlp.down_proj.weight"
+         checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)
+
+         # layernorm
+         input_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight"]
+         input_ln_base_name = f"model.layers.{l}.input_layernorm.weight"
+         checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
+         if (
+             model.state_dict().get(f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias", None)
+             is not None
+         ):
+             input_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias"]
+             input_ln_bias_name = f"model.layers.{l}.input_layernorm.bias"
+             checkpoint[input_ln_bias_name] = param_to_weights(input_ln_bias)
+
+         post_attn_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight"]
+         post_attn_ln_base_name = f"model.layers.{l}.post_attention_layernorm.weight"
+         checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
+         if model.state_dict().get(f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias", None) is not None:
+             post_attn_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias"]
+             post_attn_ln_bias_name = f"model.layers.{l}.post_attention_layernorm.bias"
+             checkpoint[post_attn_ln_bias_name] = param_to_weights(post_attn_ln_bias)
+
+         print(f"done layer {l}")
+
+     final_ln_weight = model.state_dict()["model.decoder.final_layernorm.weight"]
+     final_ln_base_name = "model.norm.weight"
+     checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
+     if model.state_dict().get("model.decoder.final_layernorm.bias", None) is not None:
+         final_ln_bias = model.state_dict()["model.decoder.final_layernorm.bias"]
+         final_ln_bias_name = "model.norm.bias"
+         checkpoint[final_ln_bias_name] = param_to_weights(final_ln_bias)
+
+     output_layer_weight = model.state_dict()["model.output_layer.weight"]
+     output_layer_base_name = "lm_head.weight"
+     checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)
+
+     os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
+     torch.save(checkpoint, output_hf_file)
+     logging.info(f"Weights saved to {output_hf_file}")
+
+     return model_config, model.tokenizer, dtype, vocab_size
+
+
+ def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tokenizer):
+     tokenizer_cfg = model_config.tokenizer
+     if tokenizer_cfg.library == "sentencepiece":
+         # For a sentencepiece tokenizer, we wrap it with HF's LlamaTokenizer
+         # and convert it to a PreTrainedTokenizerFast
+         tokenizer_fn = tokenizer_cfg.model[5:]
+         output_tokenizer = f"{output_hf_path}/tokenizer.model"
+         if nemo_file.endswith(".nemo"):
+             import tarfile
+
+             archive = tarfile.open(nemo_file, "r")
+             tokenizer_filename = "./" + tokenizer_fn  # exclude 'nemo:' prefix
+             archive.extract(tokenizer_filename, output_hf_path)
+             archive.close()
+             os.rename(f"{output_hf_path}/{tokenizer_fn}", output_tokenizer)
+         elif os.path.isdir(nemo_file):
+             shutil.copy(f"{nemo_file}/{tokenizer_fn}", output_tokenizer)
+         # We use LlamaTokenizer for sentencepiece based tokenizers
+         tokenizer = LlamaTokenizer.from_pretrained(output_hf_path, legacy=False)
+         # Convert the LlamaTokenizer to a PreTrainedTokenizerFast instance
+         tokenizer = PreTrainedTokenizerFast(
+             tokenizer_object=LlamaConverter(tokenizer).converted(), model_input_names=["input_ids", "token_type_ids"]
+         )
+         tokenizer.save_pretrained(output_hf_path)
+         logging.info(f"Sentencepiece tokenizer has been saved to {output_tokenizer}")
+     elif isinstance(nemo_tokenizer, AutoTokenizer):
+         nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
+         logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
+     else:
+         raise ValueError(f"Unsupported tokenizer type: library: {tokenizer_cfg.library}, type: {tokenizer_cfg.type}")
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     if not args.hf_output_path:
+         assert args.output_path is not None, "Need to provide either output_path or hf_output_path"
+     else:
+         args.output_path = f"{args.hf_output_path}/pytorch_model.bin"
+         logging.info(f"Weights will be saved to {args.output_path}")
+
+     nemo_config, nemo_tokenizer, dtype, vocab_size = convert(
+         args.input_name_or_path, args.output_path, precision=args.precision, cpu_only=args.cpu_only
+     )
+     if args.hf_input_path and args.hf_output_path:
+         convert_hf_config(nemo_config, nemo_tokenizer, vocab_size, dtype, args.hf_output_path, args.hf_input_path)
+         extract_nemotron_tokenizer(args.input_name_or_path, nemo_config, args.hf_output_path, nemo_tokenizer)
+     else:
+         logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.")
+         logging.info(f".bin file is saved to {args.output_path}")
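
The interleaved QKV layout handled in `convert()` is the subtle part of this script: Megatron fuses QKV so that each query group's rows are followed by its key and value rows. Below is a self-contained sketch of the same index arithmetic with toy sizes (no NeMo checkpoint required):

```python
import torch

head_num, num_query_groups = 8, 4                 # toy sizes; real checkpoints are larger
heads_per_group = head_num // num_query_groups    # 2 query heads per KV group
qkv_total_dim = head_num + 2 * num_query_groups   # 16 head-rows in the fused tensor

# Same expressions as in convert(): each group contributes heads_per_group
# query rows, followed by one key row and one value row.
q_slice = torch.cat(
    [
        torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
        for i in range(num_query_groups)
    ]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)

print(q_slice.tolist())  # [0, 1, 4, 5, 8, 9, 12, 13]
print(k_slice.tolist())  # [2, 6, 10, 14]
print(v_slice.tolist())  # [3, 7, 11, 15]
```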
docs/transformers/src/transformers/models/nllb/tokenization_nllb.py ADDED
@@ -0,0 +1,394 @@
+ # coding=utf-8
+ # Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import sentencepiece as spm
+
+ from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
+ from ...utils import logging
+ from ...utils.import_utils import requires
+
+
+ logger = logging.get_logger(__name__)
+
+ SPIECE_UNDERLINE = "▁"
+
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+
+ FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']  # fmt: skip
+
+
+ @requires(backends=("sentencepiece",))
+ class NllbTokenizer(PreTrainedTokenizer):
+     """
+     Construct an NLLB tokenizer.
+
+     Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+     [SentencePiece](https://github.com/google/sentencepiece).
+
+     The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
+     <tokens> <eos>` for target language documents.
+
+     Examples:
+
+     ```python
+     >>> from transformers import NllbTokenizer
+
+     >>> tokenizer = NllbTokenizer.from_pretrained(
+     ...     "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
+     ... )
+     >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+     >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
+     >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
+     ```
+
+     Args:
+         vocab_file (`str`):
+             Path to the vocabulary file.
+         bos_token (`str`, *optional*, defaults to `"<s>"`):
+             The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+             token.
+
+             <Tip>
+
+             When building a sequence using special tokens, this is not the token that is used for the beginning of
+             sequence. The token used is the `cls_token`.
+
+             </Tip>
+
+         eos_token (`str`, *optional*, defaults to `"</s>"`):
+             The end of sequence token.
+
+             <Tip>
+
+             When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+             The token used is the `sep_token`.
+
+             </Tip>
+
+         sep_token (`str`, *optional*, defaults to `"</s>"`):
+             The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+             sequence classification or for a text and a question for question answering. It is also used as the last
+             token of a sequence built with special tokens.
+         cls_token (`str`, *optional*, defaults to `"<s>"`):
+             The classifier token which is used when doing sequence classification (classification of the whole sequence
+             instead of per-token classification). It is the first token of the sequence when built with special tokens.
+         unk_token (`str`, *optional*, defaults to `"<unk>"`):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+             token instead.
+         pad_token (`str`, *optional*, defaults to `"<pad>"`):
+             The token used for padding, for example when batching sequences of different lengths.
+         mask_token (`str`, *optional*, defaults to `"<mask>"`):
+             The token used for masking values. This is the token used when training this model with masked language
+             modeling. This is the token which the model will try to predict.
+         tokenizer_file (`str`, *optional*):
+             The path to a tokenizer file to use instead of the vocab file.
+         src_lang (`str`, *optional*):
+             The language to use as source language for translation.
+         tgt_lang (`str`, *optional*):
+             The language to use as target language for translation.
+         sp_model_kwargs (`Dict[str, str]`):
+             Additional keyword arguments passed to the `SentencePieceProcessor` initialization.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     prefix_tokens: List[int] = []
+     suffix_tokens: List[int] = []
+
+     def __init__(
+         self,
+         vocab_file,
+         bos_token="<s>",
+         eos_token="</s>",
+         sep_token="</s>",
+         cls_token="<s>",
+         unk_token="<unk>",
+         pad_token="<pad>",
+         mask_token="<mask>",
+         tokenizer_file=None,
+         src_lang=None,
+         tgt_lang=None,
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         additional_special_tokens=None,
+         legacy_behaviour=False,
+         **kwargs,
+     ):
+         if additional_special_tokens is None:
+             additional_special_tokens = FAIRSEQ_LANGUAGE_CODES
+         bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+         pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+         eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+         unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+         # The mask token behaves like a normal word, i.e. includes the space before it
+         mask_token = (
+             AddedToken(mask_token, normalized=True, lstrip=True, special=True)
+             if isinstance(mask_token, str)
+             else mask_token
+         )
+
+         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+         self.legacy_behaviour = legacy_behaviour
+
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(str(vocab_file))
+         self.vocab_file = vocab_file
+         # Original fairseq vocab and spm vocab must be "aligned":
+         # Vocab    | 0       | 1       | 2      | 3       | 4    | 5    | 6    | 7    | 8    | 9
+         # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
+         # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a'
+         # spm      | '<unk>' | '<s>'   | '</s>' | 'an'    | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' | '▁s'
+
+         # unk token needs to be in the vocab with correct index
+         self._added_tokens_decoder = {0: bos_token, 1: pad_token, 2: eos_token, 3: unk_token}
+         # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+         self.fairseq_offset = 1
+         self.sp_model_size = len(self.sp_model)
+
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             sep_token=sep_token,
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             tokenizer_file=tokenizer_file,
+             src_lang=src_lang,
+             tgt_lang=tgt_lang,
+             additional_special_tokens=additional_special_tokens,
+             sp_model_kwargs=self.sp_model_kwargs,
+             legacy_behaviour=legacy_behaviour,
+             **kwargs,
+         )
+
+         self._src_lang = src_lang if src_lang is not None else "eng_Latn"
+         self.cur_lang_code_id = self.convert_tokens_to_ids(self._src_lang)
+         self.tgt_lang = tgt_lang
+         self.set_src_lang_special_tokens(self._src_lang)
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         state["sp_model"] = None
+         state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+         return state
+
+     def __setstate__(self, d):
+         self.__dict__ = d
+
+         # for backward compatibility
+         if not hasattr(self, "sp_model_kwargs"):
+             self.sp_model_kwargs = {}
+
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+     @property
+     def vocab_size(self):
+         return len(self.sp_model) + self.fairseq_offset
+
+     @property
+     def src_lang(self) -> str:
+         return self._src_lang
+
+     @src_lang.setter
+     def src_lang(self, new_src_lang: str) -> None:
+         self._src_lang = new_src_lang
+         self.set_src_lang_special_tokens(self._src_lang)
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         prefix_ones = [1] * len(self.prefix_tokens)
+         suffix_ones = [1] * len(self.suffix_tokens)
+         if token_ids_1 is None:
+             return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+         return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+         and adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence:
+
+         - `input_ids` (for encoder) `X [eos, src_lang_code]`
+         - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+         BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+         separator.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs to which the special tokens will be added.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+         """
+         if token_ids_1 is None:
+             return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+         # We don't expect to process pairs, but leave the pair logic for API consistency
+         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Create a mask from the two sequences passed to be used in a sequence-pair classification task. NLLB does not
+         make use of token type ids, therefore a list of zeros is returned.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of zeros.
+         """
+         sep = [self.sep_token_id]
+         cls = [self.cls_token_id]
+
+         if token_ids_1 is None:
+             return len(cls + token_ids_0 + sep) * [0]
+         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+     def _build_translation_inputs(
+         self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
+     ):
+         """Used by the translation pipeline to prepare inputs for the generate function"""
+         if src_lang is None or tgt_lang is None:
+             raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+         self.src_lang = src_lang
+         inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
+         tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
306
+ inputs["forced_bos_token_id"] = tgt_lang_id
307
+ return inputs
308
+
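A quick sketch of what `_build_translation_inputs` produces, assuming the `facebook/nllb-200-distilled-600M` checkpoint is available (the checkpoint name is an illustration; any NLLB checkpoint works):

```python
from transformers import NllbTokenizer

tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
inputs = tok._build_translation_inputs(
    "Hello", return_tensors="pt", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
# inputs carries the encoded source text plus the id of "fra_Latn" as
# `forced_bos_token_id`, which `generate()` emits as the first decoder token
```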
309
+ def get_vocab(self):
310
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
311
+ vocab.update(self.added_tokens_encoder)
312
+ return vocab
313
+
314
+ def _tokenize(self, text: str) -> List[str]:
315
+ return self.sp_model.encode(text, out_type=str)
316
+
317
+ def _convert_token_to_id(self, token):
318
+ """Converts a token (str) in an id using the vocab."""
319
+ spm_id = self.sp_model.PieceToId(token)
320
+ # Need to return unknown token if the SP model returned 0
321
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
322
+
323
+ def _convert_id_to_token(self, index):
324
+ """Converts an index (integer) in a token (str) using the vocab."""
325
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
326
+
327
+ def convert_tokens_to_string(self, tokens):
328
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
329
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
330
+ return out_string
331
+
332
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
333
+ if not os.path.isdir(save_directory):
334
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
335
+ return
336
+ out_vocab_file = os.path.join(
337
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
338
+ )
339
+
340
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
341
+ copyfile(self.vocab_file, out_vocab_file)
342
+ elif not os.path.isfile(self.vocab_file):
343
+ with open(out_vocab_file, "wb") as fi:
344
+ content_spiece_model = self.sp_model.serialized_model_proto()
345
+ fi.write(content_spiece_model)
346
+
347
+ return (out_vocab_file,)
348
+
349
+ def prepare_seq2seq_batch(
350
+ self,
351
+ src_texts: List[str],
352
+ src_lang: str = "eng_Latn",
353
+ tgt_texts: Optional[List[str]] = None,
354
+ tgt_lang: str = "fra_Latn",
355
+ **kwargs,
356
+ ) -> BatchEncoding:
357
+ self.src_lang = src_lang
358
+ self.tgt_lang = tgt_lang
359
+ return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
360
+
361
+ def _switch_to_input_mode(self):
362
+ return self.set_src_lang_special_tokens(self.src_lang)
363
+
364
+ def _switch_to_target_mode(self):
365
+ return self.set_tgt_lang_special_tokens(self.tgt_lang)
366
+
367
+ def set_src_lang_special_tokens(self, src_lang) -> None:
368
+ """Reset the special tokens to the source lang setting.
369
+ - In legacy mode: No prefix and suffix=[eos, src_lang_code].
370
+ - In default mode: Prefix=[src_lang_code], suffix=[eos].
371
+ """
372
+ self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
373
+ if self.legacy_behaviour:
374
+ self.prefix_tokens = []
375
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
376
+ else:
377
+ self.prefix_tokens = [self.cur_lang_code]
378
+ self.suffix_tokens = [self.eos_token_id]
379
+
380
+ def set_tgt_lang_special_tokens(self, lang: str) -> None:
381
+ """Reset the special tokens to the target lang setting.
382
+ - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
383
+ - In default mode: Prefix=[tgt_lang_code], suffix=[eos].
384
+ """
385
+ self.cur_lang_code = self.convert_tokens_to_ids(lang)
386
+ if self.legacy_behaviour:
387
+ self.prefix_tokens = []
388
+ self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
389
+ else:
390
+ self.prefix_tokens = [self.cur_lang_code]
391
+ self.suffix_tokens = [self.eos_token_id]
392
+
393
+
394
+ __all__ = ["NllbTokenizer"]
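To make the legacy/default distinction concrete, here is a hedged sketch (the checkpoint name is an assumption, and the exact subword split of the input may differ):

```python
from transformers import NllbTokenizer

tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
print(tok.convert_ids_to_tokens(tok("hi")["input_ids"]))
# default mode puts the language code first, e.g. ['eng_Latn', ..., '</s>']

legacy = NllbTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", legacy_behaviour=True
)
print(legacy.convert_ids_to_tokens(legacy("hi")["input_ids"]))
# legacy mode puts the language code last, e.g. [..., '</s>', 'eng_Latn']
```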
docs/transformers/src/transformers/models/nllb_moe/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_nllb_moe import *
22
+ from .modeling_nllb_moe import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py ADDED
@@ -0,0 +1,219 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """NLLB-MoE model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class NllbMoeConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`NllbMoeModel`]. It is used to instantiate an
27
+ NLLB-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
28
+ with the defaults will yield a similar configuration to that of the NLLB-MoE
29
+ [facebook/nllb-moe-54b](https://huggingface.co/facebook/nllb-moe-54b) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 128112):
37
+ Vocabulary size of the NllbMoe model. Defines the number of different tokens that can be represented by the
38
+ `input_ids` passed when calling [`NllbMoeModel`].
39
+ d_model (`int`, *optional*, defaults to 1024):
40
+ Dimensionality of the layers and the pooler layer.
41
+ encoder_layers (`int`, *optional*, defaults to 12):
42
+ Number of encoder layers.
43
+ decoder_layers (`int`, *optional*, defaults to 12):
44
+ Number of decoder layers.
45
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
46
+ Number of attention heads for each attention layer in the Transformer encoder.
47
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
48
+ Number of attention heads for each attention layer in the Transformer decoder.
49
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
50
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
51
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
52
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
53
+ activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
54
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
55
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
56
+ dropout (`float`, *optional*, defaults to 0.1):
57
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
+ attention_dropout (`float`, *optional*, defaults to 0.1):
59
+ The dropout ratio for the attention probabilities.
60
+ activation_dropout (`float`, *optional*, defaults to 0.0):
61
+ The dropout ratio for activations inside the fully connected layer.
62
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for classifier.
64
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
65
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
66
+ just in case (e.g., 512 or 1024 or 2048).
67
+ init_std (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ encoder_layerdrop (`float`, *optional*, defaults to 0.05):
70
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
71
+ for more details.
72
+ decoder_layerdrop (`float`, *optional*, defaults to 0.05):
73
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
74
+ for more details.
75
+ second_expert_policy (`str`, *optional*, defaults to `"all"`):
76
+ The policy used to sample the probability of routing each token to a second expert.
77
+ normalize_router_prob_before_dropping (`bool`, *optional*, defaults to `False`):
78
+ Whether or not to normalize the router probabilities before applying a mask based on the expert capacity
79
+ (capacity dropping).
80
+ batch_prioritized_routing (`bool`, *optional*, defaults to `False`):
81
+ Whether or not to order the tokens by their router probabilities before capacity dropping. This means the
82
+ tokens with the highest probabilities are routed before other tokens that appear later in the
83
+ sequence.
84
+ moe_eval_capacity_token_fraction (`float`, *optional*, defaults to 1.0):
85
+ Fraction of tokens used as expert capacity during validation. If set to a negative value, the training
86
+ capacity is used instead. Should be in the range (0.0, 1.0].
87
+ num_experts (`int`, *optional*, defaults to 128):
88
+ Number of experts for each NllbMoeSparseMlp layer.
89
+ expert_capacity (`int`, *optional*, defaults to 64):
90
+ Number of tokens that can be stored in each expert.
91
+ encoder_sparse_step (`int`, *optional*, defaults to 4):
92
+ Frequency of the sparse layers in the encoder. 4 means that one out of 4 layers will be sparse.
93
+ decoder_sparse_step (`int`, *optional*, defaults to 4):
94
+ Frequency of the sparse layers in the decoder. 4 means that one out of 4 layers will be sparse.
95
+ router_dtype (`str`, *optional*, defaults to `"float32"`):
96
+ The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
97
+ *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
98
+ router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
99
+ Whether to ignore padding tokens when routing. If `False`, the padding tokens are not routed to any
100
+ experts.
101
+ router_bias (`bool`, *optional*, defaults to `False`):
102
+ Whether or not the classifier of the router should have a bias.
103
+ moe_token_dropout (`float`, *optional*, defaults to 0.2):
104
+ Masking rate for MoE expert output masking (EOM), which is implemented via a Dropout2d on the expert
105
+ outputs.
106
+ output_router_logits (`bool`, *optional*, defaults to `False`):
107
+ Whether or not to return the router logits. Only set to `True` to get the auxiliary loss when training.
108
+ use_cache (`bool`, *optional*, defaults to `True`):
109
+ Whether or not the model should return the last key/values attentions (not used by all models).
110
+
111
+ Example:
112
+
113
+ ```python
114
+ >>> from transformers import NllbMoeModel, NllbMoeConfig
115
+
116
+ >>> # Initializing a NllbMoe facebook/nllb-moe-54b style configuration
117
+ >>> configuration = NllbMoeConfig()
118
+
119
+ >>> # Initializing a model from the facebook/nllb-moe-54b style configuration
120
+ >>> model = NllbMoeModel(configuration)
121
+
122
+ >>> # Accessing the model configuration
123
+ >>> configuration = model.config
124
+ ```"""
125
+
126
+ model_type = "nllb-moe"
127
+ keys_to_ignore_at_inference = ["past_key_values"]
128
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
129
+
130
+ def __init__(
131
+ self,
132
+ vocab_size=128112,
133
+ max_position_embeddings=1024,
134
+ encoder_layers=12,
135
+ encoder_ffn_dim=4096,
136
+ encoder_attention_heads=16,
137
+ decoder_layers=12,
138
+ decoder_ffn_dim=4096,
139
+ decoder_attention_heads=16,
140
+ encoder_layerdrop=0.05,
141
+ decoder_layerdrop=0.05,
142
+ use_cache=True,
143
+ is_encoder_decoder=True,
144
+ activation_function="relu",
145
+ d_model=1024,
146
+ dropout=0.1,
147
+ attention_dropout=0.1,
148
+ activation_dropout=0.0,
149
+ init_std=0.02,
150
+ decoder_start_token_id=2,
151
+ scale_embedding=True,
152
+ router_bias=False,
153
+ router_dtype="float32",
154
+ router_ignore_padding_tokens=False,
155
+ num_experts=128,
156
+ expert_capacity=64,
157
+ encoder_sparse_step=4,
158
+ decoder_sparse_step=4,
159
+ router_z_loss_coef=0.001,
160
+ router_aux_loss_coef=0.001,
161
+ second_expert_policy="all",
162
+ normalize_router_prob_before_dropping=False,
163
+ batch_prioritized_routing=False,
164
+ moe_eval_capacity_token_fraction=1.0,
165
+ moe_token_dropout=0.2,
166
+ pad_token_id=1,
167
+ bos_token_id=0,
168
+ eos_token_id=2,
169
+ output_router_logits=False,
170
+ **kwargs,
171
+ ):
172
+ self.vocab_size = vocab_size
173
+ self.max_position_embeddings = max_position_embeddings
174
+ self.d_model = d_model
175
+ self.encoder_ffn_dim = encoder_ffn_dim
176
+ self.encoder_layers = encoder_layers
177
+ self.encoder_attention_heads = encoder_attention_heads
178
+ self.decoder_ffn_dim = decoder_ffn_dim
179
+ self.decoder_layers = decoder_layers
180
+ self.decoder_attention_heads = decoder_attention_heads
181
+ self.dropout = dropout
182
+ self.attention_dropout = attention_dropout
183
+ self.activation_dropout = activation_dropout
184
+ self.activation_function = activation_function
185
+ self.init_std = init_std
186
+ self.encoder_layerdrop = encoder_layerdrop
187
+ self.decoder_layerdrop = decoder_layerdrop
188
+ self.use_cache = use_cache
189
+ self.num_hidden_layers = encoder_layers
190
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
191
+ self.router_z_loss_coef = router_z_loss_coef
192
+ self.router_aux_loss_coef = router_aux_loss_coef
193
+ self.decoder_sparse_step = decoder_sparse_step
194
+ self.encoder_sparse_step = encoder_sparse_step
195
+ self.num_experts = num_experts
196
+ self.expert_capacity = expert_capacity
197
+ self.router_bias = router_bias
198
+ if router_dtype not in ["float32", "float16", "bfloat16"]:
199
+ raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}")
200
+ self.router_dtype = router_dtype
201
+
202
+ self.router_ignore_padding_tokens = router_ignore_padding_tokens
203
+ self.batch_prioritized_routing = batch_prioritized_routing
204
+ self.second_expert_policy = second_expert_policy
205
+ self.normalize_router_prob_before_dropping = normalize_router_prob_before_dropping
206
+ self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
207
+ self.moe_token_dropout = moe_token_dropout
208
+ self.output_router_logits = output_router_logits
209
+ super().__init__(
210
+ pad_token_id=pad_token_id,
211
+ bos_token_id=bos_token_id,
212
+ eos_token_id=eos_token_id,
213
+ is_encoder_decoder=is_encoder_decoder,
214
+ decoder_start_token_id=decoder_start_token_id,
215
+ **kwargs,
216
+ )
217
+
218
+
219
+ __all__ = ["NllbMoeConfig"]
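Beyond the docstring example above, a deliberately tiny configuration is handy for smoke tests; a minimal sketch (all sizes are arbitrary):

```python
from transformers import NllbMoeConfig

tiny = NllbMoeConfig(
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    num_experts=4,
    encoder_sparse_step=2,  # every second encoder layer is sparse
    decoder_sparse_step=2,  # every second decoder layer is sparse
)
```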
docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import json
16
+ import os
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from transformers import NllbMoeConfig, NllbMoeModel
22
+ from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
23
+
24
+
25
+ def remove_ignore_keys_(state_dict):
26
+ ignore_keys = [
27
+ "encoder.version",
28
+ "decoder.version",
29
+ "model.encoder.version",
30
+ "model.decoder.version",
31
+ "decoder.output_projection.weight",
32
+ "_float_tensor",
33
+ "encoder.embed_positions._float_tensor",
34
+ "decoder.embed_positions._float_tensor",
35
+ ]
36
+ for k in ignore_keys:
37
+ state_dict.pop(k, None)
38
+
39
+
40
+ def make_linear_from_emb(emb):
41
+ vocab_size, emb_size = emb.weight.shape
42
+ lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
43
+ lin_layer.weight.data = emb.weight.data
44
+ return lin_layer
45
+
46
+
47
+ def rename_fairseq_keys(state_dict, expert_idx=None):
48
+ new_dict = {}
49
+ for old_key in state_dict.keys():
50
+ key = old_key
51
+ if "moe_layer.experts." in key:
52
+ if expert_idx is not None:
53
+ key = key.replace("moe_layer.experts.0", f"ffn.experts.expert_{expert_idx}")
54
+ else:
55
+ key = key.replace("moe_layer.experts.", "ffn.experts.expert_")
56
+ if "gate" in key:
57
+ key = key.replace(".moe_layer.gate.wg", ".ffn.router.classifier")
58
+ if "fc2" in key and "experts" not in key:
59
+ key = key.replace(".fc2.", ".ffn.fc2.")
60
+ if "fc1" in key and "experts" not in key:
61
+ key = key.replace(".fc1.", ".ffn.fc1.")
62
+ if ".encoder_attn." in key:
63
+ key = key.replace(".encoder_attn.", ".cross_attention.")
64
+ if "encoder_attn_layer_norm" in key:
65
+ key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
66
+ if "final_layer_norm" in key:
67
+ key = key.replace("final_layer_norm", "ff_layer_norm")
68
+ new_dict[key] = state_dict[old_key]
69
+ return new_dict
70
+
71
+
72
+ def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME):
73
+ sharded_state_dicts = []
74
+ total_size = 0
75
+ os.makedirs(dump_path, exist_ok=True)
76
+
77
+ for expert in range(num_experts):
78
+ expert_path = switch_checkpoint_path + f"-rank-{expert}.pt"
79
+ if os.path.isfile(expert_path):
80
+ expert_state = torch.load(expert_path, weights_only=True)["model"]
81
+ remove_ignore_keys_(expert_state)
82
+ expert_state = rename_fairseq_keys(expert_state, expert)
83
+ save_path = os.path.join(
84
+ dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin")
85
+ )
86
+ torch.save(expert_state, save_path)
87
+ sharded_state_dicts.append(expert_state.keys())
88
+ total_size += sum([value.numel() for key, value in expert_state.items()]) * (
89
+ expert_state[list(expert_state)[0]].element_size()
90
+ )
91
+
92
+ # Add the last block
93
+ save_path = os.path.join(
94
+ dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin")
95
+ )
96
+ shared_weights = torch.load(switch_checkpoint_path + "-shared.pt", weights_only=True)["model"]
97
+ remove_ignore_keys_(shared_weights)
98
+ shared_weights = rename_fairseq_keys(shared_weights, None)
99
+ shared_weights["shared.weight"] = shared_weights["decoder.embed_tokens.weight"]
100
+ sharded_state_dicts.append(shared_weights.keys())
101
+
102
+ # If we only have the shared weights (dummy model/experts saved on the same file)
103
+ if len(sharded_state_dicts) == 1:
104
+ save_path = os.path.join(dump_path, weights_name)
105
+ torch.save(shared_weights, save_path)
106
+ return {weights_name: sharded_state_dicts[0]}, None
107
+ else:
108
+ torch.save(shared_weights, save_path)
109
+ # Otherwise, let's build the index
110
+ weight_map = {}
111
+ for idx, shard in enumerate(sharded_state_dicts):
112
+ shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin")
113
+ temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx + 1:05d}-of-???.bin"))
114
+ os.rename(temp_filename, os.path.join(dump_path, shard_file))
115
+ for key in shard:
116
+ weight_map[key] = shard_file
117
+
118
+ # Add the metadata
119
+ metadata = {"total_size": total_size}
120
+ index = {"metadata": metadata, "weight_map": weight_map}
121
+
122
+ with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
123
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
124
+ f.write(content)
125
+
126
+ return metadata, index
127
+
128
+
129
+ if __name__ == "__main__":
130
+ parser = argparse.ArgumentParser()
131
+ # Required parameters
132
+ parser.add_argument(
133
+ "--nllb_moe_checkpoint_path",
134
+ default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/model_moe_54b/checkpoint_2_300000",
135
+ type=str,
136
+ required=False,
137
+ help="Path to a directory containing a folder per layer. Follows the original Google format.",
138
+ )
139
+ parser.add_argument("--dtype", default="float32", type=str, required=False, help="dtype of the saved model")
140
+ parser.add_argument(
141
+ "--pytorch_dump_folder_path",
142
+ default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/hf-converted-moe-54b",
143
+ type=str,
144
+ required=False,
145
+ help="Path to the output pytorch model.",
146
+ )
147
+ args = parser.parse_args()
148
+ metadata, index = shard_on_the_fly(
149
+ args.nllb_moe_checkpoint_path,
150
+ args.pytorch_dump_folder_path,
151
+ 128,
152
+ args.dtype,
153
+ )
154
+
155
+ config = NllbMoeConfig.from_pretrained(
156
+ "facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128
157
+ )
158
+ config.save_pretrained(args.pytorch_dump_folder_path)
159
+ model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path)
160
+ print("Done")
161
+ model.save_pretrained(args.pytorch_dump_folder_path)
docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py ADDED
@@ -0,0 +1,1784 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 NllbMoe Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch NLLB-MoE model."""
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn import CrossEntropyLoss
23
+
24
+ from ...activations import ACT2FN
25
+ from ...generation import GenerationMixin
26
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
27
+ from ...integrations.fsdp import is_fsdp_managed_module
28
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
29
+ from ...modeling_outputs import (
30
+ MoEModelOutput,
31
+ MoEModelOutputWithPastAndCrossAttentions,
32
+ Seq2SeqMoEModelOutput,
33
+ Seq2SeqMoEOutput,
34
+ )
35
+ from ...modeling_utils import PreTrainedModel
36
+ from ...utils import (
37
+ add_end_docstrings,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
+ from .configuration_nllb_moe import NllbMoeConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CONFIG_FOR_DOC = "NllbMoeConfig"
49
+ _CHECKPOINT_FOR_DOC = "hf-internal-testing/dummy-nllb-moe-2-experts"
50
+ _REAL_CHECKPOINT_FOR_DOC = "facebook/nllb-moe-54b"
51
+
52
+
53
+ ####################################################
54
+ # This dict contains ids and associated url
55
+ # for the pretrained weights provided with the models
56
+ ####################################################
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
60
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
61
+ """
62
+ Shift input ids one token to the right.
63
+ """
64
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
65
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
66
+ shifted_input_ids[:, 0] = decoder_start_token_id
67
+
68
+ if pad_token_id is None:
69
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
70
+ # replace possible -100 values in labels by `pad_token_id`
71
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
72
+
73
+ return shifted_input_ids
74
+
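A small sketch of the shift, assuming `shift_tokens_right` above is in scope (the token ids are arbitrary):

```python
import torch

labels = torch.tensor([[42, 43, -100, 2]])  # -100 marks ignored label positions
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# tensor([[ 2, 42, 43,  1]]): shifted right, start token prepended, -100 -> pad id
```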
75
+
76
+ # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
77
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
78
+ """
79
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
80
+ are ignored. This is modified from fairseq's `utils.make_positions`.
81
+
82
+ Args:
83
+ input_ids: torch.Tensor
84
+
85
+ Returns: torch.Tensor
86
+ """
87
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
88
+ mask = input_ids.ne(padding_idx).int()
89
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
90
+ return incremental_indices.long() + padding_idx
91
+
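A quick sketch of the resulting position ids, assuming `create_position_ids_from_input_ids` above is in scope:

```python
import torch

ids = torch.tensor([[1, 1, 42, 43, 44]])  # two leading pads, pad_token_id = 1
create_position_ids_from_input_ids(ids, padding_idx=1)
# tensor([[1, 1, 2, 3, 4]]): real tokens count up from padding_idx + 1, pads stay at padding_idx
```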
92
+
93
+ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
94
+ r"""
95
+ Computes the auxiliary load balancing loss as in Switch Transformer, implemented in PyTorch.
96
+
97
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
98
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
99
+ experts is too unbalanced.
100
+
101
+ Args:
102
+ router_probs (`torch.Tensor`):
103
+ Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
104
+ expert_indices (`torch.Tensor`):
105
+ Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token.
106
+
107
+ Returns:
108
+ The auxiliary loss.
109
+ """
110
+ if router_probs is None:
111
+ return 0
112
+
113
+ num_experts = router_probs.shape[-1]
114
+
115
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
116
+ if expert_indices.dtype != torch.int64:
117
+ expert_indices = expert_indices.to(torch.int64)
118
+
119
+ if len(expert_indices.shape) == 2:
120
+ expert_indices = expert_indices.unsqueeze(2)
121
+
122
+ expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
123
+
124
+ # For a given token, determine if it was routed to a given expert.
125
+ expert_mask = torch.max(expert_mask, axis=-2).values
126
+
127
+ # cast to float32 otherwise mean will fail
128
+ expert_mask = expert_mask.to(torch.float32)
129
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
130
+
131
+ router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)
132
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
133
+
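A sanity-check sketch, assuming `load_balancing_loss_func` above is in scope: perfectly uniform routing drives the loss to its minimum of 1.0.

```python
import torch

router_probs = torch.full((2, 4, 8), 1 / 8)     # uniform probabilities over 8 experts
expert_indices = torch.arange(8).reshape(2, 4)  # tokens spread evenly across experts
load_balancing_loss_func(router_probs, expert_indices)
# tensor(1.) -- any routing imbalance pushes the value above 1
```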
134
+
135
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe
136
+ class NllbMoeScaledWordEmbedding(nn.Embedding):
137
+ """
138
+ This module overrides nn.Embedding's forward by multiplying it with the embedding scale.
139
+ """
140
+
141
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
142
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
143
+ self.embed_scale = embed_scale
144
+
145
+ def forward(self, input_ids: torch.Tensor):
146
+ return super().forward(input_ids) * self.embed_scale
147
+
148
+
149
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
150
+ class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
151
+ """This module produces sinusoidal positional embeddings of any length."""
152
+
153
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
154
+ super().__init__()
155
+ self.offset = 2
156
+ self.embedding_dim = embedding_dim
157
+ self.padding_idx = padding_idx
158
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
159
+
160
+ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
161
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
162
+ if hasattr(self, "weights"):
163
+ # in forward put the weights on the correct dtype and device of the param
164
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
165
+
166
+ self.register_buffer("weights", emb_weights, persistent=False)
167
+
168
+ @staticmethod
169
+ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
170
+ """
171
+ Build sinusoidal embeddings.
172
+
173
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
174
+ "Attention Is All You Need".
175
+ """
176
+ half_dim = embedding_dim // 2
177
+ emb = math.log(10000) / (half_dim - 1)
178
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
179
+ emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
180
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
181
+ if embedding_dim % 2 == 1:
182
+ # zero pad
183
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
184
+ if padding_idx is not None:
185
+ emb[padding_idx, :] = 0
186
+
187
+ return emb.to(torch.get_default_dtype())
188
+
189
+ @torch.no_grad()
190
+ def forward(
191
+ self,
192
+ input_ids: Optional[torch.Tensor] = None,
193
+ inputs_embeds: Optional[torch.Tensor] = None,
194
+ past_key_values_length: int = 0,
195
+ ):
196
+ if input_ids is not None:
197
+ bsz, seq_len = input_ids.size()
198
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
199
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
200
+ input_ids.device
201
+ )
202
+ else:
203
+ bsz, seq_len = inputs_embeds.size()[:-1]
204
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
205
+
206
+ # expand embeddings if needed
207
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
208
+ if max_pos > self.weights.size(0):
209
+ self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
210
+
211
+ return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
212
+
213
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
214
+ """
215
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
216
+
217
+ Args:
218
+ inputs_embeds: torch.Tensor
219
+
220
+ Returns: torch.Tensor
221
+ """
222
+ input_shape = inputs_embeds.size()[:-1]
223
+ sequence_length = input_shape[1]
224
+
225
+ position_ids = torch.arange(
226
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
227
+ )
228
+ return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
229
+
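A small sketch of the module in isolation, assuming the class above is in scope (the sizes are arbitrary):

```python
import torch

pos_emb = NllbMoeSinusoidalPositionalEmbedding(num_positions=64, embedding_dim=16, padding_idx=1)
ids = torch.tensor([[1, 1, 5, 6, 7]])  # two leading pad tokens (pad_token_id = 1)
out = pos_emb(input_ids=ids)
# out.shape == (1, 5, 16); the two pad positions map to the zeroed padding embedding
```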
230
+
231
+ class NllbMoeTop2Router(nn.Module):
232
+ """
233
+ Router implementing the "tokens choose top-2 experts" assignment.
234
+
235
+ This router uses the same mechanism as in NLLB-MoE from the fairseq repository. Items are sorted by router_probs
236
+ and then routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee
237
+ that each token is processed by an expert**, or that each expert receives at least one token.
238
+
239
+ The router combining weights are also returned so that hidden states which are not updated by any expert can be masked out.
240
+
241
+ """
242
+
243
+ def __init__(self, config: NllbMoeConfig):
244
+ super().__init__()
245
+ self.num_experts = config.num_experts
246
+ self.expert_capacity = config.expert_capacity
247
+ self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
248
+ self.router_ignore_padding_tokens = config.router_ignore_padding_tokens
249
+ self.dtype = getattr(torch, config.router_dtype)
250
+
251
+ self.second_expert_policy = config.second_expert_policy
252
+ self.normalize_router_prob_before_dropping = config.normalize_router_prob_before_dropping
253
+ self.batch_prioritized_routing = config.batch_prioritized_routing
254
+ self.moe_eval_capacity_token_fraction = config.moe_eval_capacity_token_fraction
255
+
256
+ def _cast_classifier(self):
257
+ r"""
258
+ `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore we need to check whether they
259
+ are an instance of the `Linear8bitLt` class by checking for its special attributes.
260
+ """
261
+ if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
262
+ self.classifier = self.classifier.to(self.dtype)
263
+
264
+ def normalize_router_probabilities(self, router_probs, top_1_mask, top_2_mask):
265
+ top_1_max_probs = (router_probs * top_1_mask).sum(dim=1)
266
+ top_2_max_probs = (router_probs * top_2_mask).sum(dim=1)
267
+ denom_s = torch.clamp(top_1_max_probs + top_2_max_probs, min=torch.finfo(router_probs.dtype).eps)
268
+ top_1_max_probs = top_1_max_probs / denom_s
269
+ top_2_max_probs = top_2_max_probs / denom_s
270
+ return top_1_max_probs, top_2_max_probs
271
+
272
+ def route_tokens(
273
+ self,
274
+ router_logits: torch.Tensor,
275
+ input_dtype: torch.dtype = torch.float32,
276
+ padding_mask: Optional[torch.LongTensor] = None,
277
+ ) -> Tuple:
278
+ """
279
+ Computes the `dispatch_mask` and the `dispatch_weights` for each expert. The masks are adapted to the expert
280
+ capacity.
281
+ """
282
+ nb_tokens = router_logits.shape[0]
283
+ # Apply Softmax and cast back to the original `dtype`
284
+ router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(input_dtype)
285
+ top_1_expert_index = torch.argmax(router_probs, dim=-1)
286
+ top_1_mask = torch.nn.functional.one_hot(top_1_expert_index, num_classes=self.num_experts)
287
+
288
+ if self.second_expert_policy == "sampling":
289
+ gumbel = torch.distributions.gumbel.Gumbel(0, 1).rsample
290
+ router_logits += gumbel(router_logits.shape).to(router_logits.device)
291
+
292
+ # replace top_1_expert_index with min values
293
+ logits_except_top_1 = router_logits.masked_fill(top_1_mask.bool(), float("-inf"))
294
+ top_2_expert_index = torch.argmax(logits_except_top_1, dim=-1)
295
+ top_2_mask = torch.nn.functional.one_hot(top_2_expert_index, num_classes=self.num_experts)
296
+
297
+ if self.normalize_router_prob_before_dropping:
298
+ top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities(
299
+ router_probs, top_1_mask, top_2_mask
300
+ )
301
+
302
+ if self.second_expert_policy == "random":
303
+ top_2_max_probs = (router_probs * top_2_mask).sum(dim=1)
304
+ sampled = (2 * top_2_max_probs) > torch.rand_like(top_2_max_probs.float())
305
+ top_2_mask = top_2_mask * sampled.repeat(self.num_experts, 1).transpose(1, 0)
306
+
307
+ if padding_mask is not None and not self.router_ignore_padding_tokens:
308
+ if len(padding_mask.shape) == 4:
309
+ # only get the last causal mask
310
+ padding_mask = padding_mask[:, :, -1, :].reshape(-1)[-nb_tokens:]
311
+ non_padding = ~padding_mask.bool()
312
+ top_1_mask = top_1_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype)
313
+ top_2_mask = top_2_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype)
314
+
315
+ if self.batch_prioritized_routing:
316
+ # sort tokens based on their routing probability
317
+ # to make sure important tokens are routed, first
318
+ importance_scores = -1 * router_probs.max(dim=1)[0]
319
+ sorted_top_1_mask = top_1_mask[importance_scores.argsort(dim=0)]
320
+ sorted_cumsum1 = (torch.cumsum(sorted_top_1_mask, dim=0) - 1) * sorted_top_1_mask
321
+ locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
322
+
323
+ sorted_top_2_mask = top_2_mask[importance_scores.argsort(dim=0)]
324
+ sorted_cumsum2 = (torch.cumsum(sorted_top_2_mask, dim=0) - 1) * sorted_top_2_mask
325
+ locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
326
+ # Update 2nd's location by accounting for locations of 1st
327
+ locations2 += torch.sum(top_1_mask, dim=0, keepdim=True)
328
+
329
+ else:
330
+ locations1 = torch.cumsum(top_1_mask, dim=0) - 1
331
+ locations2 = torch.cumsum(top_2_mask, dim=0) - 1
332
+ # Update 2nd's location by accounting for locations of 1st
333
+ locations2 += torch.sum(top_1_mask, dim=0, keepdim=True)
334
+
335
+ if not self.training and self.moe_eval_capacity_token_fraction > 0:
336
+ self.expert_capacity = math.ceil(self.moe_eval_capacity_token_fraction * nb_tokens)
337
+ else:
338
+ capacity = 2 * math.ceil(nb_tokens / self.num_experts)
339
+ self.expert_capacity = capacity if self.expert_capacity is None else self.expert_capacity
340
+
341
+ # Mask out assignments beyond the expert capacity (torch.lt == False means the token is not routed)
342
+ top_1_mask = top_1_mask * torch.lt(locations1, self.expert_capacity)
343
+ top_2_mask = top_2_mask * torch.lt(locations2, self.expert_capacity)
344
+
345
+ if not self.normalize_router_prob_before_dropping:
346
+ top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities(
347
+ router_probs, top_1_mask, top_2_mask
348
+ )
349
+
350
+ # Calculate combine_weights and dispatch_mask
351
+ gates1 = top_1_max_probs[:, None] * top_1_mask
352
+ gates2 = top_2_max_probs[:, None] * top_2_mask
353
+ router_probs = gates1 + gates2
354
+
355
+ return top_1_mask, router_probs
356
+
357
+ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.LongTensor] = None) -> Tuple:
358
+ r"""
359
+ The hidden states are reshaped to simplify the computation of the router probabilities (the combining
360
+ weights for each expert).
361
+
362
+ Args:
363
+ hidden_states (`torch.Tensor`):
364
+ (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
365
+ Returns:
366
+ top_1_mask (`torch.Tensor` of shape (batch_size, sequence_length)):
367
+ Index tensor of shape [batch_size, sequence_length] corresponding to the expert selected for each token
368
+ using the top1 probabilities of the router.
369
+ router_probabilities (`torch.Tensor` of shape (batch_size, sequence_length, num_experts)):
370
+ Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
371
+ token and expert. Used for routing tokens to experts.
372
+ router_logits (`torch.Tensor` of shape (batch_size, sequence_length, num_experts)):
373
+ Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
374
+ This is used later for computing router z-loss.
375
+ """
376
+ self.input_dtype = hidden_states.dtype
377
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
378
+ hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim)
379
+ hidden_states = hidden_states.to(self.dtype)
380
+ self._cast_classifier()
381
+ router_logits = self.classifier(hidden_states)
382
+ top_1_mask, router_probs = self.route_tokens(router_logits, self.input_dtype, padding_mask)
383
+ return top_1_mask, router_probs
384
+
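A minimal sketch of the router on a toy configuration, assuming `NllbMoeConfig` and the class above are in scope (all sizes are arbitrary):

```python
import torch

config = NllbMoeConfig(d_model=8, num_experts=4, expert_capacity=2)
router = NllbMoeTop2Router(config)
hidden_states = torch.randn(1, 6, 8)
top_1_mask, router_probs = router(hidden_states)
# top_1_mask: (6, 4) one-hot of each token's first-choice expert, capacity-dropped
# router_probs: (6, 4) combined top-2 gating weights used to mix the expert outputs
```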
385
+
386
+ class NllbMoeDenseActDense(nn.Module):
387
+ def __init__(self, config: NllbMoeConfig, ffn_dim: int):
388
+ super().__init__()
389
+ self.fc1 = nn.Linear(config.d_model, ffn_dim)
390
+ self.fc2 = nn.Linear(ffn_dim, config.d_model)
391
+ self.dropout = nn.Dropout(config.activation_dropout)
392
+ self.act = ACT2FN[config.activation_function]
393
+
394
+ def forward(self, hidden_states):
395
+ hidden_states = self.fc1(hidden_states)
396
+ hidden_states = self.act(hidden_states)
397
+ hidden_states = self.dropout(hidden_states)
398
+ if (
399
+ isinstance(self.fc2.weight, torch.Tensor)
400
+ and hidden_states.dtype != self.fc2.weight.dtype
401
+ and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8)
402
+ ):
403
+ hidden_states = hidden_states.to(self.fc2.weight.dtype)
404
+ hidden_states = self.fc2(hidden_states)
405
+ return hidden_states
406
+
407
+
408
+ class NllbMoeSparseMLP(nn.Module):
409
+ r"""
410
+ Implementation of the NLLB-MoE sparse MLP module.
411
+ """
412
+
413
+ def __init__(self, config: NllbMoeConfig, ffn_dim: int, expert_class: nn.Module = NllbMoeDenseActDense):
414
+ super().__init__()
415
+ self.router = NllbMoeTop2Router(config)
416
+ self.moe_token_dropout = config.moe_token_dropout
417
+ self.token_dropout = nn.Dropout(self.moe_token_dropout)
418
+ self.num_experts = config.num_experts
419
+
420
+ self.experts = nn.ModuleDict()
421
+ for idx in range(self.num_experts):
422
+ self.experts[f"expert_{idx}"] = expert_class(config, ffn_dim)
423
+
424
+ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = False):
425
+ r"""
426
+ The goal of this forward pass is to have the same number of operations as the equivalent `NllbMoeDenseActDense`
427
+ (mlp) layer. This means that each hidden state is processed at most twice (since we are using a
428
+ top-2 gating mechanism), which keeps the complexity at O(batch_size x sequence_length x hidden_dim)
429
+ instead of O(num_experts x batch_size x sequence_length x hidden_dim).
430
+
431
+ 1- Get the `router_probs` from the `router`. The shape of the `router_mask` is `(batch_size X sequence_length,
432
+ num_expert)` and corresponds to the boolean version of the `router_probs`. The inputs are masked using the
433
+ `router_mask`.
434
+
435
+ 2- Dispatch the hidden_states to their associated experts. The router probabilities are used to weight the
436
+ contribution of each expert when updating the masked hidden states.
437
+
438
+ Args:
439
+ hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
440
+ The hidden states
441
+ padding_mask (`torch.Tensor`, *optional*, defaults to `False`):
442
+ Attention mask. Can be in the causal form or not.
443
+
444
+ Returns:
445
+ hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
446
+ Updated hidden states
447
+ router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`):
448
+ Needed for computing the loss
449
+
450
+ """
451
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
452
+
453
+ top_1_mask, router_probs = self.router(hidden_states, padding_mask)
454
+ router_mask = router_probs.bool()
455
+ hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim)
456
+ masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask)
457
+ for idx, expert in enumerate(self.experts.values()):
458
+ token_indices = router_mask[:, idx]
459
+ combining_weights = router_probs[token_indices, idx]
460
+ expert_output = expert(masked_hidden_states[idx, token_indices])
461
+ if self.moe_token_dropout > 0:
462
+ if self.training:
463
+ expert_output = self.token_dropout(expert_output)
464
+ else:
465
+ expert_output *= 1 - self.moe_token_dropout
466
+ masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output)
467
+ hidden_states = masked_hidden_states.sum(dim=0).reshape(batch_size, sequence_length, hidden_dim)
468
+
469
+ top_1_expert_index = torch.argmax(top_1_mask, dim=-1)
470
+ return hidden_states, (router_probs, top_1_expert_index)
471
+
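A minimal sketch of the sparse MLP end to end, assuming the classes above are in scope; setting `router_ignore_padding_tokens=True` lets us skip passing a padding mask, and the sizes are arbitrary:

```python
import torch

config = NllbMoeConfig(
    d_model=8, num_experts=4, expert_capacity=8, moe_token_dropout=0.0,
    router_ignore_padding_tokens=True,
)
sparse_mlp = NllbMoeSparseMLP(config, ffn_dim=16)
hidden_states = torch.randn(2, 5, 8)
out, (router_probs, top_1_expert_index) = sparse_mlp(hidden_states)
# out: (2, 5, 8) gate-weighted mix of at most two experts per token
# router_probs: (10, 4) and top_1_expert_index: (10,) over the flattened tokens
```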
472
+
473
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->NllbMoe,key_value_states->encoder_hidden_states
474
+ class NllbMoeAttention(nn.Module):
475
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
476
+
477
+ def __init__(
478
+ self,
479
+ embed_dim: int,
480
+ num_heads: int,
481
+ dropout: float = 0.0,
482
+ is_decoder: bool = False,
483
+ bias: bool = True,
484
+ is_causal: bool = False,
485
+ config: Optional[NllbMoeConfig] = None,
486
+ ):
487
+ super().__init__()
488
+ self.embed_dim = embed_dim
489
+ self.num_heads = num_heads
490
+ self.dropout = dropout
491
+ self.head_dim = embed_dim // num_heads
492
+ self.config = config
493
+
494
+ if (self.head_dim * num_heads) != self.embed_dim:
495
+ raise ValueError(
496
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
497
+ f" and `num_heads`: {num_heads})."
498
+ )
499
+ self.scaling = self.head_dim**-0.5
500
+ self.is_decoder = is_decoder
501
+ self.is_causal = is_causal
502
+
503
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
504
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
505
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
506
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
507
+
508
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
509
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
510
+
511
+ def forward(
512
+ self,
513
+ hidden_states: torch.Tensor,
514
+ encoder_hidden_states: Optional[torch.Tensor] = None,
515
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
516
+ attention_mask: Optional[torch.Tensor] = None,
517
+ layer_head_mask: Optional[torch.Tensor] = None,
518
+ output_attentions: bool = False,
519
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
520
+ """Input shape: Batch x Time x Channel"""
521
+
522
+ # if encoder_hidden_states are provided this layer is used as a cross-attention layer
523
+ # for the decoder
524
+ is_cross_attention = encoder_hidden_states is not None
525
+
526
+ bsz, tgt_len, _ = hidden_states.size()
527
+
528
+ # get query proj
529
+ query_states = self.q_proj(hidden_states) * self.scaling
530
+ # get key, value proj
531
+ # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]`
532
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
533
+ # the provided `encoder_hidden_states` to support prefix tuning
534
+ if (
535
+ is_cross_attention
536
+ and past_key_value is not None
537
+ and past_key_value[0].shape[2] == encoder_hidden_states.shape[1]
538
+ ):
539
+ # reuse k,v, cross_attentions
540
+ key_states = past_key_value[0]
541
+ value_states = past_key_value[1]
542
+ elif is_cross_attention:
543
+ # cross_attentions
544
+ key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
545
+ value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
546
+ elif past_key_value is not None:
547
+ # reuse k, v, self_attention
548
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
549
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
550
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
551
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
552
+ else:
553
+ # self_attention
554
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
555
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
556
+
557
+ if self.is_decoder:
558
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
559
+ # Further calls to cross_attention layer can then reuse all cross-attention
560
+ # key/value_states (first "if" case)
561
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
562
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
563
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
564
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
565
+ past_key_value = (key_states, value_states)
566
+
567
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
568
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
569
+ key_states = key_states.reshape(*proj_shape)
570
+ value_states = value_states.reshape(*proj_shape)
571
+
572
+ src_len = key_states.size(1)
573
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
574
+
575
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
576
+ raise ValueError(
577
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
578
+ f" {attn_weights.size()}"
579
+ )
580
+
581
+ if attention_mask is not None:
582
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
583
+ raise ValueError(
584
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
585
+ )
586
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
587
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
588
+
589
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
590
+
591
+ if layer_head_mask is not None:
592
+ if layer_head_mask.size() != (self.num_heads,):
593
+ raise ValueError(
594
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
595
+ f" {layer_head_mask.size()}"
596
+ )
597
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
598
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
599
+
600
+ if output_attentions:
601
+ # this operation is a bit awkward, but it's required to
602
+ # make sure that attn_weights keeps its gradient.
603
+ # In order to do so, attn_weights have to be reshaped
604
+ # twice and have to be reused in the following
605
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
606
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
607
+ else:
608
+ attn_weights_reshaped = None
609
+
610
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
611
+
612
+ attn_output = torch.bmm(attn_probs, value_states)
613
+
614
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
615
+ raise ValueError(
616
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
617
+ f" {attn_output.size()}"
618
+ )
619
+
620
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
621
+ attn_output = attn_output.transpose(1, 2)
622
+
623
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
624
+ # partitioned across GPUs when using tensor-parallelism.
625
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
626
+
627
+ attn_output = self.out_proj(attn_output)
628
+
629
+ return attn_output, attn_weights_reshaped, past_key_value
630
+
631
+
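A minimal shape sketch of the attention module above (hypothetical sizes; it assumes `NllbMoeAttention` is importable in scope, and mirrors the constructor arguments used by the layers below):

```python
import torch

# Hypothetical sizes: embed_dim=16, num_heads=4 -> head_dim=4.
attn = NllbMoeAttention(embed_dim=16, num_heads=4, dropout=0.0, is_decoder=True)
hidden = torch.randn(2, 5, 16)  # (batch, tgt_len, embed_dim)

out, weights, past = attn(hidden, output_attentions=True)
assert out.shape == (2, 5, 16)
assert weights.shape == (2, 4, 5, 5)  # (batch, num_heads, tgt_len, src_len)
assert past[0].shape == (2, 4, 5, 4)  # cached keys: (batch, num_heads, seq_len, head_dim)
```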
632
+ class NllbMoeEncoderLayer(nn.Module):
633
+ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
634
+ super().__init__()
635
+ self.embed_dim = config.d_model
636
+ self.is_sparse = is_sparse
637
+ self.self_attn = NllbMoeAttention(
638
+ embed_dim=self.embed_dim,
639
+ num_heads=config.encoder_attention_heads,
640
+ dropout=config.attention_dropout,
641
+ )
642
+ self.attn_dropout = nn.Dropout(config.dropout)
643
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
644
+ if not self.is_sparse:
645
+ self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.encoder_ffn_dim)
646
+ else:
647
+ self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.encoder_ffn_dim)
648
+ self.ff_layer_norm = nn.LayerNorm(config.d_model)
649
+ self.ff_dropout = nn.Dropout(config.activation_dropout)
650
+
651
+ def forward(
652
+ self,
653
+ hidden_states: torch.Tensor,
654
+ attention_mask: torch.Tensor,
655
+ layer_head_mask: torch.Tensor,
656
+ output_attentions: bool = False,
657
+ output_router_logits: bool = False,
658
+ ) -> torch.Tensor:
659
+ """
660
+ Args:
661
+ hidden_states (`torch.FloatTensor`):
662
+ input to the layer of shape `(batch, seq_len, embed_dim)`
663
+ attention_mask (`torch.FloatTensor`):
664
+ attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
665
+ large negative values.
666
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
667
+ `(encoder_attention_heads,)`.
668
+ output_attentions (`bool`, *optional*):
669
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
670
+ returned tensors for more detail.
671
+ """
672
+ residual = hidden_states
673
+ hidden_states = self.self_attn_layer_norm(hidden_states)
674
+ hidden_states, attn_weights, _ = self.self_attn(
675
+ hidden_states=hidden_states,
676
+ attention_mask=attention_mask,
677
+ layer_head_mask=layer_head_mask,
678
+ output_attentions=output_attentions,
679
+ )
680
+ hidden_states = self.attn_dropout(hidden_states)
681
+ hidden_states = residual + hidden_states
682
+
683
+ residual = hidden_states
684
+
685
+ hidden_states = self.ff_layer_norm(hidden_states)
686
+ if self.is_sparse:
687
+ hidden_states, router_states = self.ffn(hidden_states, attention_mask)
688
+ else:
689
+ # router_states is kept as None for dense layers so the per-layer outputs stay aligned.
690
+ hidden_states, router_states = self.ffn(hidden_states), None
691
+
692
+ hidden_states = self.ff_dropout(hidden_states)
693
+
694
+ hidden_states = residual + hidden_states
695
+
696
+ if hidden_states.dtype == torch.float16 and (
697
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
698
+ ):
699
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
700
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
701
+
702
+ outputs = (hidden_states,)
703
+
704
+ if output_attentions:
705
+ outputs += (attn_weights,)
706
+
707
+ if output_router_logits:
708
+ outputs += (router_states,)
709
+
710
+ return outputs
711
+
712
+
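For reference, a standalone sketch of the fp16 overflow guard used at the end of the layer above (illustrative values only):

```python
import torch

# Activations that overflowed to inf in float16 are pulled just inside the
# representable range so the residual addition and later ops stay finite.
h = torch.tensor([float("inf"), -65504.0, 1.0]).to(torch.float16)
clamp_value = torch.finfo(torch.float16).max - 1000
h = torch.clamp(h, min=-clamp_value, max=clamp_value)
print(torch.isinf(h).any())  # tensor(False)
```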
713
+ class NllbMoeDecoderLayer(nn.Module):
714
+ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
715
+ super().__init__()
716
+ self.embed_dim = config.d_model
717
+ self.is_sparse = is_sparse
718
+ self.self_attn = NllbMoeAttention(
719
+ embed_dim=self.embed_dim,
720
+ num_heads=config.decoder_attention_heads,
721
+ dropout=config.attention_dropout,
722
+ is_decoder=True,
723
+ )
724
+ self.dropout = config.dropout
725
+ self.activation_fn = ACT2FN[config.activation_function]
726
+ self.attn_dropout = nn.Dropout(config.dropout)
727
+
728
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
729
+ self.cross_attention = NllbMoeAttention(
730
+ self.embed_dim, config.decoder_attention_heads, config.attention_dropout, is_decoder=True
731
+ )
732
+ self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim)
733
+ if not self.is_sparse:
734
+ self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.decoder_ffn_dim)
735
+ else:
736
+ self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.decoder_ffn_dim)
737
+ self.ff_layer_norm = nn.LayerNorm(config.d_model)
738
+ self.ff_dropout = nn.Dropout(config.activation_dropout)
739
+
740
+ def forward(
741
+ self,
742
+ hidden_states: torch.Tensor,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ encoder_hidden_states: Optional[torch.Tensor] = None,
745
+ encoder_attention_mask: Optional[torch.Tensor] = None,
746
+ layer_head_mask: Optional[torch.Tensor] = None,
747
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
748
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
749
+ output_attentions: Optional[bool] = False,
750
+ output_router_logits: Optional[bool] = False,
751
+ use_cache: Optional[bool] = True,
752
+ ) -> torch.Tensor:
753
+ """
754
+ Args:
755
+ hidden_states (`torch.FloatTensor`):
756
+ input to the layer of shape `(batch, seq_len, embed_dim)`
757
+ attention_mask (`torch.FloatTensor`):
758
+ attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
759
+ large negative values.
760
+ encoder_hidden_states (`torch.FloatTensor`):
761
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
762
+ encoder_attention_mask (`torch.FloatTensor`):
763
+ encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by
764
+ very large negative values.
765
+ layer_head_mask (`torch.FloatTensor`):
766
+ mask for attention heads in a given layer of size `(encoder_attention_heads,)`.
767
+ cross_attn_layer_head_mask (`torch.FloatTensor`):
768
+ mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`.
769
+ past_key_value (`Tuple(torch.FloatTensor)`):
770
+ cached past key and value projection states
771
+ output_attentions (`bool`, *optional*):
772
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
773
+ returned tensors for more detail.
774
+ """
775
+ residual = hidden_states
776
+ hidden_states = self.self_attn_layer_norm(hidden_states)
777
+
778
+ # Self Attention
779
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
780
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
781
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
782
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
783
+ hidden_states=hidden_states,
784
+ past_key_value=self_attn_past_key_value,
785
+ attention_mask=attention_mask,
786
+ layer_head_mask=layer_head_mask,
787
+ output_attentions=output_attentions,
788
+ )
789
+ hidden_states = self.attn_dropout(hidden_states)
790
+ hidden_states = residual + hidden_states
791
+
792
+ # Cross-Attention Block
793
+ cross_attn_present_key_value = None
794
+ cross_attn_weights = None
795
+ if encoder_hidden_states is not None:
796
+ residual = hidden_states
797
+ hidden_states = self.cross_attention_layer_norm(hidden_states)
798
+
799
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
800
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
801
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention(
802
+ hidden_states=hidden_states,
803
+ encoder_hidden_states=encoder_hidden_states,
804
+ past_key_value=cross_attn_past_key_value,
805
+ attention_mask=encoder_attention_mask,
806
+ layer_head_mask=cross_attn_layer_head_mask,
807
+ output_attentions=output_attentions,
808
+ )
809
+ hidden_states = self.attn_dropout(hidden_states)
810
+ hidden_states = residual + hidden_states
811
+
812
+ # add cross-attn to positions 3,4 of present_key_value tuple
813
+ present_key_value += cross_attn_present_key_value
814
+
815
+ # Fully Connected
816
+ residual = hidden_states
817
+
818
+ hidden_states = self.ff_layer_norm(hidden_states)
819
+ if self.is_sparse:
820
+ hidden_states, router_states = self.ffn(hidden_states, attention_mask)
821
+ else:
822
+ hidden_states, router_states = self.ffn(hidden_states), None
823
+
824
+ hidden_states = self.ff_dropout(hidden_states)
825
+
826
+ hidden_states = residual + hidden_states
827
+
828
+ # clamp inf values to enable fp16 training
829
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
830
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
831
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
832
+
833
+ outputs = (hidden_states, present_key_value)
834
+
835
+ if output_attentions:
836
+ outputs += (self_attn_weights, cross_attn_weights)
837
+
838
+ if output_router_logits:
839
+ outputs += (router_states,)
840
+
841
+ return outputs
842
+
843
+
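The 4-tuple cache layout the decoder layer relies on can be summarized with a small sketch (shapes are illustrative, not tied to any checkpoint):

```python
import torch

# One decoder layer's cache entry: self-attention K/V at positions 0-1 (grow by
# one step per generated token), cross-attention K/V at positions 2-3 (computed
# once from the encoder states and reused on every step).
bsz, num_heads, head_dim, tgt_so_far, src_len = 2, 16, 64, 7, 20
past_key_value = (
    torch.zeros(bsz, num_heads, tgt_so_far, head_dim),  # self-attn keys
    torch.zeros(bsz, num_heads, tgt_so_far, head_dim),  # self-attn values
    torch.zeros(bsz, num_heads, src_len, head_dim),     # cross-attn keys
    torch.zeros(bsz, num_heads, src_len, head_dim),     # cross-attn values
)
self_attn_past, cross_attn_past = past_key_value[:2], past_key_value[-2:]
```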
844
+ class NllbMoePreTrainedModel(PreTrainedModel):
845
+ config_class = NllbMoeConfig
846
+ base_model_prefix = "model"
847
+ supports_gradient_checkpointing = True
848
+ _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"]
849
+
850
+ def _init_weights(self, module):
851
+ """Initialize the weights"""
852
+ std = self.config.init_std
853
+ if isinstance(module, nn.Linear):
854
+ module.weight.data.normal_(mean=0.0, std=std)
855
+ if module.bias is not None:
856
+ module.bias.data.zero_()
857
+ elif isinstance(module, nn.Embedding):
858
+ module.weight.data.normal_(mean=0.0, std=std)
859
+ if module.padding_idx is not None:
860
+ module.weight.data[module.padding_idx].zero_()
861
+
862
+
863
+ NLLB_MOE_START_DOCSTRING = r"""
864
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
865
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
866
+ etc.)
867
+
868
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
869
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
870
+ and behavior.
871
+
872
+ Parameters:
873
+ config ([`NllbMoeConfig`]):
874
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
875
+ load the weights associated with the model, only the configuration. Check out the
876
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
877
+ """
878
+
879
+ NLLB_MOE_GENERATION_EXAMPLE = r"""
880
+ Translation example:
881
+
882
+ ```python
883
+ >>> from transformers import AutoTokenizer, NllbMoeForConditionalGeneration
884
+
885
+ >>> model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")
886
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")
887
+
888
+ >>> text_to_translate = "Life is like a box of chocolates"
889
+ >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt")
890
+
891
+ >>> # translate to French
892
+ >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fra_Latn"))
893
+ >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
894
+ ```
895
+ """
896
+
897
+ NLLB_MOE_INPUTS_DOCSTRING = r"""
898
+ Args:
899
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
900
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
901
+ it.
902
+
903
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
904
+ [`PreTrainedTokenizer.__call__`] for details.
905
+
906
+ [What are input IDs?](../glossary#input-ids)
907
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
908
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
909
+
910
+ - 1 for tokens that are **not masked**,
911
+ - 0 for tokens that are **masked**.
912
+
913
+ [What are attention masks?](../glossary#attention-mask)
914
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
915
+ Indices of decoder input sequence tokens in the vocabulary.
916
+
917
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
918
+ [`PreTrainedTokenizer.__call__`] for details.
919
+
920
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
921
+
922
+ NllbMoe uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
923
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
924
+ `past_key_values`).
925
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
926
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
927
+ be used by default.
928
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
929
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
930
+
931
+ - 1 indicates the head is **not masked**,
932
+ - 0 indicates the head is **masked**.
933
+
934
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
935
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
936
+
937
+ - 1 indicates the head is **not masked**,
938
+ - 0 indicates the head is **masked**.
939
+
940
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
941
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
942
+ 1]`:
943
+
944
+ - 1 indicates the head is **not masked**,
945
+ - 0 indicates the head is **masked**.
946
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
947
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
948
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
949
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
950
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
951
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
952
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
953
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
954
+
955
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
956
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
957
+
958
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
959
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
960
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
961
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
962
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
963
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
964
+ than the model's internal embedding lookup matrix.
965
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
966
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
967
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
968
+ input (see `past_key_values`). This is useful if you want more control over how to convert
969
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
970
+
971
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
972
+ of `inputs_embeds`.
973
+ use_cache (`bool`, *optional*):
974
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
975
+ `past_key_values`).
976
+ output_attentions (`bool`, *optional*):
977
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
978
+ tensors for more detail.
979
+ output_hidden_states (`bool`, *optional*):
980
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
981
+ more detail.
982
+ output_router_logits (`bool`, *optional*):
983
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
984
+ should not be returned during inference.
985
+ return_dict (`bool`, *optional*):
986
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
987
+ """
988
+
989
+
990
+ class NllbMoeEncoder(NllbMoePreTrainedModel):
991
+ """
992
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
993
+ [`NllbMoeEncoderLayer`].
994
+
995
+ Args:
996
+ config:
997
+ NllbMoeConfig
998
+ embed_tokens (nn.Embedding):
999
+ output embedding
1000
+ """
1001
+
1002
+ def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None):
1003
+ super().__init__(config)
1004
+
1005
+ self.dropout = config.dropout
1006
+ self.layerdrop = config.encoder_layerdrop
1007
+
1008
+ embed_dim = config.d_model
1009
+ self.padding_idx = config.pad_token_id
1010
+ self.max_source_positions = config.max_position_embeddings
1011
+ embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
1012
+
1013
+ self.embed_tokens = NllbMoeScaledWordEmbedding(
1014
+ config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
1015
+ )
1016
+
1017
+ if embed_tokens is not None:
1018
+ self.embed_tokens.weight = embed_tokens.weight
1019
+
1020
+ self.embed_positions = NllbMoeSinusoidalPositionalEmbedding(
1021
+ config.max_position_embeddings,
1022
+ embed_dim,
1023
+ self.padding_idx,
1024
+ )
1025
+ sparse_step = config.encoder_sparse_step
1026
+ self.layers = nn.ModuleList()
1027
+ for i in range(config.encoder_layers):
1028
+ is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False
1029
+ self.layers.append(NllbMoeEncoderLayer(config, is_sparse))
1030
+
1031
+ self.layer_norm = nn.LayerNorm(config.d_model)
1032
+
1033
+ self.gradient_checkpointing = False
1034
+ # Initialize weights and apply final processing
1035
+ self.post_init()
1036
+
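The `sparse_step` logic above alternates dense and MoE feed-forward layers; a quick sketch of which layers come out sparse (hypothetical 12-layer encoder):

```python
# With encoder_layers=12 and encoder_sparse_step=4, every 4th layer is sparse.
encoder_layers, sparse_step = 12, 4
sparse_layers = [i + 1 for i in range(encoder_layers) if (i + 1) % sparse_step == 0]
print(sparse_layers)  # [4, 8, 12]
```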
1037
+ def forward(
1038
+ self,
1039
+ input_ids: Optional[torch.Tensor] = None,
1040
+ attention_mask: Optional[torch.Tensor] = None,
1041
+ head_mask: Optional[torch.Tensor] = None,
1042
+ inputs_embeds: Optional[torch.Tensor] = None,
1043
+ output_attentions: Optional[bool] = None,
1044
+ output_hidden_states: Optional[bool] = None,
1045
+ output_router_logits: Optional[bool] = None,
1046
+ return_dict: Optional[bool] = None,
1047
+ ):
1048
+ r"""
1049
+ Args:
1050
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1051
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1052
+ provide it.
1053
+
1054
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1055
+ [`PreTrainedTokenizer.__call__`] for details.
1056
+
1057
+ [What are input IDs?](../glossary#input-ids)
1058
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1059
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1060
+
1061
+ - 1 for tokens that are **not masked**,
1062
+ - 0 for tokens that are **masked**.
1063
+
1064
+ [What are attention masks?](../glossary#attention-mask)
1065
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
1066
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1067
+
1068
+ - 1 indicates the head is **not masked**,
1069
+ - 0 indicates the head is **masked**.
1070
+
1071
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1072
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1073
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1074
+ than the model's internal embedding lookup matrix.
1075
+ output_attentions (`bool`, *optional*):
1076
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1077
+ returned tensors for more detail.
1078
+ output_hidden_states (`bool`, *optional*):
1079
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1080
+ for more detail.
1081
+ output_router_logits (`bool`, *optional*):
1082
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
1083
+ and should not be returned during inference.
1084
+ return_dict (`bool`, *optional*):
1085
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1086
+ """
1087
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1088
+ output_hidden_states = (
1089
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1090
+ )
1091
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1092
+
1093
+ # retrieve input_ids and inputs_embeds
1094
+ if input_ids is not None and inputs_embeds is not None:
1095
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1096
+ elif input_ids is not None:
1097
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1098
+ input_shape = input_ids.size()
1099
+ input_ids = input_ids.view(-1, input_shape[-1])
1100
+ elif inputs_embeds is not None:
1101
+ input_shape = inputs_embeds.size()[:-1]
1102
+ else:
1103
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1104
+
1105
+ if inputs_embeds is None:
1106
+ inputs_embeds = self.embed_tokens(input_ids)
1107
+
1108
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
1109
+ embed_pos = embed_pos.to(inputs_embeds.device)
1110
+
1111
+ hidden_states = inputs_embeds + embed_pos
1112
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1113
+
1114
+ # expand attention_mask
1115
+ if attention_mask is not None:
1116
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1117
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1118
+
1119
+ encoder_states = () if output_hidden_states else None
1120
+ all_router_probs = () if output_router_logits else None
1121
+ all_attentions = () if output_attentions else None
1122
+
1123
+ # check if head_mask has a correct number of layers specified if desired
1124
+ if head_mask is not None:
1125
+ if head_mask.size()[0] != len(self.layers):
1126
+ raise ValueError(
1127
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
1128
+ f" {head_mask.size()[0]}."
1129
+ )
1130
+
1131
+ for idx, encoder_layer in enumerate(self.layers):
1132
+ if output_hidden_states:
1133
+ encoder_states = encoder_states + (hidden_states,)
1134
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1135
+ dropout_probability = torch.rand([])
1136
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
1137
+ layer_outputs = (None, None, None)
1138
+ else:
1139
+ if self.gradient_checkpointing and self.training:
1140
+ layer_outputs = self._gradient_checkpointing_func(
1141
+ encoder_layer.__call__,
1142
+ hidden_states,
1143
+ attention_mask,
1144
+ (head_mask[idx] if head_mask is not None else None),
1145
+ output_attentions,
1146
+ )
1147
+ else:
1148
+ layer_outputs = encoder_layer(
1149
+ hidden_states,
1150
+ attention_mask,
1151
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1152
+ output_attentions=output_attentions,
1153
+ output_router_logits=output_router_logits,
1154
+ )
1155
+
1156
+ hidden_states = layer_outputs[0]
1157
+
1158
+ if output_attentions:
1159
+ all_attentions += (layer_outputs[1],)
1160
+
1161
+ if output_router_logits:
1162
+ all_router_probs += (layer_outputs[-1],)
1163
+
1164
+ last_hidden_state = self.layer_norm(hidden_states)
1165
+
1166
+ if output_hidden_states:
1167
+ encoder_states += (last_hidden_state,)
1168
+
1169
+ if not return_dict:
1170
+ return tuple(
1171
+ v for v in [last_hidden_state, encoder_states, all_attentions, all_router_probs] if v is not None
1172
+ )
1173
+
1174
+ return MoEModelOutput(
1175
+ last_hidden_state=last_hidden_state,
1176
+ hidden_states=encoder_states,
1177
+ attentions=all_attentions,
1178
+ router_probs=all_router_probs,
1179
+ )
1180
+
1181
+
1182
+ class NllbMoeDecoder(NllbMoePreTrainedModel):
1183
+ """
1184
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`NllbMoeDecoderLayer`]
1185
+
1186
+ Args:
1187
+ config:
1188
+ NllbMoeConfig
1189
+ embed_tokens (nn.Embedding):
1190
+ output embedding
1191
+ """
1192
+
1193
+ def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None):
1194
+ super().__init__(config)
1195
+ self.dropout = config.dropout
1196
+ self.layerdrop = config.decoder_layerdrop
1197
+ self.padding_idx = config.pad_token_id
1198
+ self.max_target_positions = config.max_position_embeddings
1199
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1200
+
1201
+ self.embed_tokens = NllbMoeScaledWordEmbedding(
1202
+ config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
1203
+ )
1204
+
1205
+ if embed_tokens is not None:
1206
+ self.embed_tokens.weight = embed_tokens.weight
1207
+
1208
+ self.embed_positions = NllbMoeSinusoidalPositionalEmbedding(
1209
+ config.max_position_embeddings,
1210
+ config.d_model,
1211
+ self.padding_idx,
1212
+ )
1213
+
1214
+ sparse_step = config.decoder_sparse_step
1215
+ self.layers = nn.ModuleList()
1216
+ for i in range(config.decoder_layers):
1217
+ is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False
1218
+ self.layers.append(NllbMoeDecoderLayer(config, is_sparse))
1219
+
1220
+ self.layer_norm = nn.LayerNorm(config.d_model)
1221
+
1222
+ self.gradient_checkpointing = False
1223
+ # Initialize weights and apply final processing
1224
+ self.post_init()
1225
+
1226
+ def forward(
1227
+ self,
1228
+ input_ids: Optional[torch.Tensor] = None,
1229
+ attention_mask: Optional[torch.Tensor] = None,
1230
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1231
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1232
+ head_mask: Optional[torch.Tensor] = None,
1233
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1234
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1235
+ inputs_embeds: Optional[torch.Tensor] = None,
1236
+ use_cache: Optional[bool] = None,
1237
+ output_attentions: Optional[bool] = None,
1238
+ output_hidden_states: Optional[bool] = None,
1239
+ output_router_logits: Optional[bool] = None,
1240
+ return_dict: Optional[bool] = None,
1241
+ ):
1242
+ r"""
1243
+ Args:
1244
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1245
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1246
+ provide it.
1247
+
1248
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1249
+ [`PreTrainedTokenizer.__call__`] for details.
1250
+
1251
+ [What are input IDs?](../glossary#input-ids)
1252
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1253
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1254
+
1255
+ - 1 for tokens that are **not masked**,
1256
+ - 0 for tokens that are **masked**.
1257
+
1258
+ [What are attention masks?](../glossary#attention-mask)
1259
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1260
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1261
+ of the decoder.
1262
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1263
+ Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
1264
+ selected in `[0, 1]`:
1265
+
1266
+ - 1 for tokens that are **not masked**,
1267
+ - 0 for tokens that are **masked**.
1268
+
1269
+ [What are attention masks?](../glossary#attention-mask)
1270
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1271
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1272
+
1273
+ - 1 indicates the head is **not masked**,
1274
+ - 0 indicates the head is **masked**.
1275
+
1276
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1277
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1278
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1279
+
1280
+ - 1 indicates the head is **not masked**,
1281
+ - 0 indicates the head is **masked**.
1282
+
1283
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1284
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1285
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1286
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1287
+
1288
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1289
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1290
+
1291
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1292
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1293
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1294
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1295
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1296
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1297
+ than the model's internal embedding lookup matrix.
1298
+ output_attentions (`bool`, *optional*):
1299
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1300
+ returned tensors for more detail.
1301
+ output_hidden_states (`bool`, *optional*):
1302
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1303
+ for more detail.
1304
+ output_router_logits (`bool`, *optional*):
1305
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
1306
+ and should not be returned during inference.
1307
+ return_dict (`bool`, *optional*):
1308
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1309
+ """
1310
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1311
+ output_hidden_states = (
1312
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1313
+ )
1314
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1315
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1316
+
1317
+ # retrieve input_ids and inputs_embeds
1318
+ if input_ids is not None and inputs_embeds is not None:
1319
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1320
+ elif input_ids is not None:
1321
+ input_shape = input_ids.size()
1322
+ input_ids = input_ids.view(-1, input_shape[-1])
1323
+ elif inputs_embeds is not None:
1324
+ input_shape = inputs_embeds.size()[:-1]
1325
+ else:
1326
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1327
+
1328
+ # past_key_values_length
1329
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1330
+
1331
+ if inputs_embeds is None:
1332
+ inputs_embeds = self.embed_tokens(input_ids)
1333
+
1334
+ # create causal mask
1335
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1336
+ combined_attention_mask = _prepare_4d_causal_attention_mask(
1337
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1338
+ )
1339
+
1340
+ # expand encoder attention mask
1341
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1342
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1343
+ encoder_attention_mask = _prepare_4d_attention_mask(
1344
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1345
+ )
1346
+
1347
+ # embed positions
1348
+ positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
1349
+ positions = positions.to(inputs_embeds.device)
1350
+
1351
+ hidden_states = inputs_embeds + positions
1352
+
1353
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1354
+
1355
+ if self.gradient_checkpointing and self.training:
1356
+ if use_cache:
1357
+ logger.warning_once(
1358
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1359
+ )
1360
+ use_cache = False
1361
+
1362
+ # decoder layers
1363
+ all_hidden_states = () if output_hidden_states else None
1364
+ all_self_attns = () if output_attentions else None
1365
+ all_router_probs = () if output_router_logits else None
1366
+ all_cross_attentions = () if output_attentions else None
1367
+ present_key_value_states = () if use_cache else None
1368
+
1369
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1370
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1371
+ if attn_mask is not None:
1372
+ if attn_mask.size()[0] != len(self.layers):
1373
+ raise ValueError(
1374
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1375
+ f" {head_mask.size()[0]}."
1376
+ )
1377
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
1378
+
1379
+ for idx, decoder_layer in enumerate(self.layers):
1380
+ if output_hidden_states:
1381
+ all_hidden_states += (hidden_states,)
1382
+
1383
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1384
+ dropout_probability = torch.rand([])
1385
+
1386
+ skip_the_layer = self.training and (dropout_probability < self.layerdrop)
1387
+ if not skip_the_layer or synced_gpus:
1388
+ layer_head_mask = head_mask[idx] if head_mask is not None else None
1389
+ cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1390
+
1391
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1392
+
1393
+ # under fsdp or deepspeed zero3 all gpus must run in sync
1394
+ if self.gradient_checkpointing and self.training:
1395
+ if use_cache:
1396
+ logger.warning_once(
1397
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1398
+ )
1399
+ use_cache = False
1400
+ layer_outputs = self._gradient_checkpointing_func(
1401
+ decoder_layer.forward,
1402
+ hidden_states,
1403
+ combined_attention_mask,
1404
+ encoder_hidden_states,
1405
+ encoder_attention_mask,
1406
+ layer_head_mask,
1407
+ cross_attn_layer_head_mask,
1408
+ None, # past_key_value is always None with gradient checkpointing
1409
+ use_cache,
1410
+ output_attentions,
1411
+ )
1412
+ else:
1413
+ layer_outputs = decoder_layer(
1414
+ hidden_states,
1415
+ attention_mask=combined_attention_mask,
1416
+ encoder_hidden_states=encoder_hidden_states,
1417
+ encoder_attention_mask=encoder_attention_mask,
1418
+ layer_head_mask=layer_head_mask,
1419
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1420
+ past_key_value=past_key_value,
1421
+ use_cache=use_cache,
1422
+ output_attentions=output_attentions,
1423
+ output_router_logits=output_router_logits,
1424
+ )
1425
+
1426
+ hidden_states = layer_outputs[0]
1427
+
1428
+ if skip_the_layer:
1429
+ continue
1430
+
1431
+ if use_cache:
1432
+ present_key_value_states += (layer_outputs[1],)
1433
+
1434
+ if output_attentions:
1435
+ all_self_attns += (layer_outputs[2],)
1436
+ all_cross_attentions += (layer_outputs[3],)
1437
+
1438
+ if output_router_logits:
1439
+ all_router_probs += (layer_outputs[-1],)
1440
+
1441
+ hidden_states = self.layer_norm(hidden_states)
1442
+
1443
+ # Add last layer
1444
+ if output_hidden_states:
1445
+ all_hidden_states += (hidden_states,)
1446
+
1447
+ if not return_dict:
1448
+ return tuple(
1449
+ v
1450
+ for v in [
1451
+ hidden_states,
1452
+ present_key_value_states,
1453
+ all_hidden_states,
1454
+ all_self_attns,
1455
+ all_cross_attentions,
1456
+ all_router_probs,
1457
+ ]
1458
+ if v is not None
1459
+ )
1460
+ return MoEModelOutputWithPastAndCrossAttentions(
1461
+ last_hidden_state=hidden_states,
1462
+ past_key_values=present_key_value_states,
1463
+ hidden_states=all_hidden_states,
1464
+ attentions=all_self_attns,
1465
+ cross_attentions=all_cross_attentions,
1466
+ router_probs=all_router_probs,
1467
+ )
1468
+
1469
+
1470
+ @add_start_docstrings(
1471
+ "The bare NllbMoe Model outputting raw hidden-states without any specific head on top.",
1472
+ NLLB_MOE_START_DOCSTRING,
1473
+ )
1474
+ class NllbMoeModel(NllbMoePreTrainedModel):
1475
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1476
+
1477
+ def __init__(self, config: NllbMoeConfig):
1478
+ super().__init__(config)
1479
+
1480
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1481
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1482
+ self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
1483
+
1484
+ self.encoder = NllbMoeEncoder(config, self.shared)
1485
+ self.decoder = NllbMoeDecoder(config, self.shared)
1486
+
1487
+ # Initialize weights and apply final processing
1488
+ self.post_init()
1489
+
1490
+ def get_input_embeddings(self):
1491
+ return self.shared
1492
+
1493
+ def set_input_embeddings(self, value):
1494
+ self.shared = value
1495
+ self.encoder.embed_tokens = self.shared
1496
+ self.decoder.embed_tokens = self.shared
1497
+
1498
+ def _tie_weights(self):
1499
+ if self.config.tie_word_embeddings:
1500
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1501
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1502
+
1503
+ def get_encoder(self):
1504
+ return self.encoder
1505
+
1506
+ def get_decoder(self):
1507
+ return self.decoder
1508
+
1509
+ @add_start_docstrings_to_model_forward(NLLB_MOE_INPUTS_DOCSTRING)
1511
+ @replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC)
1512
+ def forward(
1513
+ self,
1514
+ input_ids: Optional[torch.LongTensor] = None,
1515
+ attention_mask: Optional[torch.Tensor] = None,
1516
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1517
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1518
+ head_mask: Optional[torch.Tensor] = None,
1519
+ decoder_head_mask: Optional[torch.Tensor] = None,
1520
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1521
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1522
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1523
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1524
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1525
+ use_cache: Optional[bool] = None,
1526
+ output_attentions: Optional[bool] = None,
1527
+ output_hidden_states: Optional[bool] = None,
1528
+ output_router_logits: Optional[bool] = None,
1529
+ return_dict: Optional[bool] = None,
1530
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqMoEModelOutput]:
1531
+ r"""
1532
+ Returns:
1533
+
1534
+ Example:
1535
+
1536
+ ```python
1537
+ >>> from transformers import AutoTokenizer, NllbMoeModel
1538
+
1539
+ >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
1540
+ >>> model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
1541
+
1542
+ >>> input_ids = tokenizer(
1543
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1544
+ ... ).input_ids # Batch size 1
1545
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1546
+
1547
+ >>> # preprocess: prepend decoder_input_ids with the decoder start token (NllbMoe uses the `eos_token_id`)
1548
+ >>> from transformers.models.nllb_moe.modeling_nllb_moe import shift_tokens_right
+ >>> decoder_input_ids = shift_tokens_right(
+ ...     decoder_input_ids, model.config.pad_token_id, model.config.decoder_start_token_id
+ ... )
1549
+
1550
+ >>> # forward pass
1551
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1552
+ >>> last_hidden_states = outputs.last_hidden_state
1553
+ ```"""
1554
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1555
+ if encoder_outputs is None:
1556
+ encoder_outputs = self.encoder(
1557
+ input_ids=input_ids,
1558
+ attention_mask=attention_mask,
1559
+ head_mask=head_mask,
1560
+ inputs_embeds=inputs_embeds,
1561
+ output_attentions=output_attentions,
1562
+ output_hidden_states=output_hidden_states,
1563
+ output_router_logits=output_router_logits,
1564
+ return_dict=return_dict,
1565
+ )
1566
+ # If the user passed a tuple for encoder_outputs, we wrap it in a MoEModelOutput when return_dict=True
1567
+ elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
1568
+ encoder_outputs = MoEModelOutput(
1569
+ last_hidden_state=encoder_outputs[0],
1570
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1571
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1572
+ router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
1573
+ )
1574
+
1575
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1576
+ decoder_outputs = self.decoder(
1577
+ input_ids=decoder_input_ids,
1578
+ attention_mask=decoder_attention_mask,
1579
+ encoder_hidden_states=encoder_outputs[0],
1580
+ encoder_attention_mask=attention_mask,
1581
+ head_mask=decoder_head_mask,
1582
+ cross_attn_head_mask=cross_attn_head_mask,
1583
+ past_key_values=past_key_values,
1584
+ inputs_embeds=decoder_inputs_embeds,
1585
+ use_cache=use_cache,
1586
+ output_attentions=output_attentions,
1587
+ output_hidden_states=output_hidden_states,
1588
+ output_router_logits=output_router_logits,
1589
+ return_dict=return_dict,
1590
+ )
1591
+
1592
+ if not return_dict:
1593
+ return decoder_outputs + encoder_outputs
1594
+
1595
+ return Seq2SeqMoEModelOutput(
1596
+ past_key_values=decoder_outputs.past_key_values,
1597
+ cross_attentions=decoder_outputs.cross_attentions,
1598
+ last_hidden_state=decoder_outputs.last_hidden_state,
1599
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1600
+ encoder_hidden_states=encoder_outputs.hidden_states,
1601
+ decoder_hidden_states=decoder_outputs.hidden_states,
1602
+ encoder_attentions=encoder_outputs.attentions,
1603
+ decoder_attentions=decoder_outputs.attentions,
1604
+ encoder_router_logits=encoder_outputs.router_probs,
1605
+ decoder_router_logits=decoder_outputs.router_probs,
1606
+ )
1607
+
1608
+
1609
+ @add_start_docstrings(
1610
+ "The NllbMoe Model with a language modeling head. Can be used for summarization.", NLLB_MOE_START_DOCSTRING
1611
+ )
1612
+ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin):
1613
+ base_model_prefix = "model"
1614
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1615
+
1616
+ def __init__(self, config: NllbMoeConfig):
1617
+ super().__init__(config)
1618
+ self.model = NllbMoeModel(config)
1619
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1620
+
1621
+ self.router_z_loss_coef = config.router_z_loss_coef
1622
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1623
+ # Initialize weights and apply final processing
1624
+ self.post_init()
1625
+
1626
+ def get_encoder(self):
1627
+ return self.model.get_encoder()
1628
+
1629
+ def get_decoder(self):
1630
+ return self.model.get_decoder()
1631
+
1632
+ def get_output_embeddings(self):
1633
+ return self.lm_head
1634
+
1635
+ def set_output_embeddings(self, new_embeddings):
1636
+ self.lm_head = new_embeddings
1637
+
1638
+ @add_start_docstrings_to_model_forward(NLLB_MOE_INPUTS_DOCSTRING)
1639
+ @replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC)
1640
+ @add_end_docstrings(NLLB_MOE_GENERATION_EXAMPLE)
1641
+ def forward(
1642
+ self,
1643
+ input_ids: Optional[torch.LongTensor] = None,
1644
+ attention_mask: Optional[torch.Tensor] = None,
1645
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1646
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1647
+ head_mask: Optional[torch.Tensor] = None,
1648
+ decoder_head_mask: Optional[torch.Tensor] = None,
1649
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1650
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1651
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1652
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1653
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1654
+ labels: Optional[torch.LongTensor] = None,
1655
+ use_cache: Optional[bool] = None,
1656
+ output_attentions: Optional[bool] = None,
1657
+ output_hidden_states: Optional[bool] = None,
1658
+ output_router_logits: Optional[bool] = None,
1659
+ return_dict: Optional[bool] = None,
1660
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqMoEOutput]:
1661
+ r"""
1662
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1663
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1664
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1665
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1666
+
1667
+ Returns:
1668
+ """
1669
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1670
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1671
+ output_router_logits = (
1672
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1673
+ )
1674
+ if labels is not None:
1675
+ if decoder_input_ids is None:
1676
+ decoder_input_ids = shift_tokens_right(
1677
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1678
+ )
1679
+
1680
+ outputs = self.model(
1681
+ input_ids,
1682
+ attention_mask=attention_mask,
1683
+ decoder_input_ids=decoder_input_ids,
1684
+ encoder_outputs=encoder_outputs,
1685
+ decoder_attention_mask=decoder_attention_mask,
1686
+ head_mask=head_mask,
1687
+ decoder_head_mask=decoder_head_mask,
1688
+ cross_attn_head_mask=cross_attn_head_mask,
1689
+ past_key_values=past_key_values,
1690
+ inputs_embeds=inputs_embeds,
1691
+ decoder_inputs_embeds=decoder_inputs_embeds,
1692
+ use_cache=use_cache,
1693
+ output_attentions=output_attentions,
1694
+ output_hidden_states=output_hidden_states,
1695
+ output_router_logits=output_router_logits,
1696
+ return_dict=return_dict,
1697
+ )
1698
+ lm_logits = self.lm_head(outputs[0])
1699
+
1700
+ loss = None
1701
+ encoder_aux_loss = None
1702
+ decoder_aux_loss = None
1703
+
1704
+ if labels is not None:
1705
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1706
+ # TODO: check in the config whether the router loss is enabled
1707
+
1708
+ if output_router_logits:
1709
+ encoder_router_logits = outputs[-1]
1710
+ decoder_router_logits = outputs[3 if output_attentions else 4]
1711
+
1712
+ # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
1713
+ encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
1714
+ encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes)
1715
+
1716
+ decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_router_logits)
1717
+ decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes)
1718
+
1719
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1720
+
1721
+ if output_router_logits and labels is not None:
1722
+ aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
1723
+ loss = loss + aux_loss
1724
+
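+ # Note: only the auxiliary (load-balancing) losses are folded into the total
+ # loss here; `self.router_z_loss_coef` is stored in `__init__` but is not
+ # applied in this forward pass.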
1725
+ output = (loss,) if loss is not None else ()
1726
+ if not return_dict:
1727
+ output += (lm_logits,)
1728
+ if output_router_logits:  # the aux losses are only computed (and thus non-None) in this case
1729
+ output += (
1730
+ encoder_aux_loss,
1731
+ decoder_aux_loss,
1732
+ *outputs[1:],
1733
+ )
1734
+ else:
1735
+ output += outputs[1:]
1736
+
1737
+ return output
1738
+
1739
+ return Seq2SeqMoEOutput(
1740
+ loss=loss,
1741
+ logits=lm_logits,
1742
+ past_key_values=outputs.past_key_values,
1743
+ cross_attentions=outputs.cross_attentions,
1744
+ encoder_aux_loss=encoder_aux_loss,
1745
+ decoder_aux_loss=decoder_aux_loss,
1746
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1747
+ encoder_hidden_states=outputs.encoder_hidden_states,
1748
+ decoder_hidden_states=outputs.decoder_hidden_states,
1749
+ encoder_attentions=outputs.encoder_attentions,
1750
+ decoder_attentions=outputs.decoder_attentions,
1751
+ encoder_router_logits=outputs.encoder_router_logits,
1752
+ decoder_router_logits=outputs.decoder_router_logits,
1753
+ )
1754
+
1755
+ def _unpack_router_logits(self, router_outputs):
1756
+ total_router_logits = []
1757
+ total_expert_indexes = []
1758
+ for router_output in router_outputs:
1759
+ if router_output is not None:
1760
+ router_logits, expert_indexes = router_output
1761
+ total_router_logits.append(router_logits)
1762
+ total_expert_indexes.append(expert_indexes)
1763
+
1764
+ total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
1765
+ total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
1766
+ return total_router_logits, total_expert_indexes
1767
+
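A small sketch of what `_unpack_router_logits` consumes (illustrative shapes; dense layers contribute `None`, and only the non-`None` logits are concatenated along the sequence dimension):

```python
import torch

num_experts, seq_len = 4, 6
router_outputs = [
    (torch.randn(1, seq_len, num_experts), torch.randint(0, num_experts, (1, seq_len))),
    None,  # a dense layer: no router output
    (torch.randn(1, seq_len, num_experts), torch.randint(0, num_experts, (1, seq_len))),
]
logits = torch.cat([r[0] for r in router_outputs if r is not None], dim=1)
print(logits.shape)  # torch.Size([1, 12, 4])
```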
1768
+ @staticmethod
1769
+ def _reorder_cache(past_key_values, beam_idx):
1770
+ reordered_past = ()
1771
+ for layer_past in past_key_values:
1772
+ reordered_past += (
1773
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1774
+ )
1775
+ return reordered_past
1776
+
1777
+
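How `_reorder_cache` tracks beams during beam search, in miniature (toy tensors):

```python
import torch

# Each cached tensor is re-indexed along the batch (beam) dimension so that
# surviving beams carry their own key/value history forward.
layer_past = tuple(torch.arange(6.0).view(3, 1, 1, 2) for _ in range(4))  # 3 beams
beam_idx = torch.tensor([2, 0, 0])  # beam 2 survives; beam 0 is duplicated
reordered = tuple(t.index_select(0, beam_idx) for t in layer_past)
print(reordered[0][:, 0, 0, 0])  # tensor([4., 0., 0.])
```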
1778
+ __all__ = [
1779
+ "NllbMoeForConditionalGeneration",
1780
+ "NllbMoeModel",
1781
+ "NllbMoePreTrainedModel",
1782
+ "NllbMoeTop2Router",
1783
+ "NllbMoeSparseMLP",
1784
+ ]
docs/transformers/src/transformers/models/nougat/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .image_processing_nougat import *
22
+ from .processing_nougat import *
23
+ from .tokenization_nougat_fast import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py ADDED
@@ -0,0 +1,282 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert Nougat checkpoints using the original `nougat` library. URL:
+ https://github.com/facebookresearch/nougat/tree/main"""
+
+ import argparse
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from nougat import NougatModel
+ from nougat.dataset.rasterize import rasterize_paper
+ from nougat.utils.checkpoint import get_checkpoint
+ from PIL import Image
+
+ from transformers import (
+     DonutSwinConfig,
+     DonutSwinModel,
+     MBartConfig,
+     MBartForCausalLM,
+     NougatImageProcessor,
+     NougatProcessor,
+     NougatTokenizerFast,
+     VisionEncoderDecoderModel,
+ )
+
+
+ def get_configs(model):
+     original_config = model.config
+
+     encoder_config = DonutSwinConfig(
+         image_size=original_config.input_size,
+         patch_size=4,
+         depths=original_config.encoder_layer,
+         num_heads=[4, 8, 16, 32],
+         window_size=original_config.window_size,
+         embed_dim=128,
+     )
+     decoder_config = MBartConfig(
+         is_decoder=True,
+         is_encoder_decoder=False,
+         add_cross_attention=True,
+         decoder_layers=original_config.decoder_layer,
+         max_position_embeddings=original_config.max_position_embeddings,
+         vocab_size=len(
+             model.decoder.tokenizer
+         ),  # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json)
+         scale_embedding=True,
+         add_final_layer_norm=True,
+         tie_word_embeddings=False,
+     )
+
+     return encoder_config, decoder_config
+
+
+ # Copied from transformers.models.donut.convert_donut_to_pytorch.rename_key
+ def rename_key(name):
+     if "encoder.model" in name:
+         name = name.replace("encoder.model", "encoder")
+     if "decoder.model" in name:
+         name = name.replace("decoder.model", "decoder")
+     if "patch_embed.proj" in name:
+         name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
+     if "patch_embed.norm" in name:
+         name = name.replace("patch_embed.norm", "embeddings.norm")
+     if name.startswith("encoder"):
+         if "layers" in name:
+             name = "encoder." + name
+         if "attn.proj" in name:
+             name = name.replace("attn.proj", "attention.output.dense")
+         if "attn" in name and "mask" not in name:
+             name = name.replace("attn", "attention.self")
+         if "norm1" in name:
+             name = name.replace("norm1", "layernorm_before")
+         if "norm2" in name:
+             name = name.replace("norm2", "layernorm_after")
+         if "mlp.fc1" in name:
+             name = name.replace("mlp.fc1", "intermediate.dense")
+         if "mlp.fc2" in name:
+             name = name.replace("mlp.fc2", "output.dense")
+
+         if name == "encoder.norm.weight":
+             name = "encoder.layernorm.weight"
+         if name == "encoder.norm.bias":
+             name = "encoder.layernorm.bias"
+
+     return name
+
+
+ # Copied from transformers.models.donut.convert_donut_to_pytorch.convert_state_dict
+ def convert_state_dict(orig_state_dict, model):
+     for key in orig_state_dict.copy().keys():
+         val = orig_state_dict.pop(key)
+
+         if "qkv" in key:
+             key_split = key.split(".")
+             layer_num = int(key_split[3])
+             block_num = int(key_split[5])
+             dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size
+
+             if "weight" in key:
+                 orig_state_dict[
+                     f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
+                 ] = val[:dim, :]
+                 orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = (
+                     val[dim : dim * 2, :]
+                 )
+                 orig_state_dict[
+                     f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
+                 ] = val[-dim:, :]
+             else:
+                 orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = (
+                     val[:dim]
+                 )
+                 orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = (
+                     val[dim : dim * 2]
+                 )
+                 orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = (
+                     val[-dim:]
+                 )
+         elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]:
+             # HuggingFace implementation doesn't use attn_mask buffer
+             # and model doesn't use final LayerNorms for the encoder
+             pass
+         else:
+             orig_state_dict[rename_key(key)] = val
+
+     return orig_state_dict
+
+
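The `qkv` branch above splits each fused attention projection row-wise into separate query/key/value tensors. A self-contained sketch of that slicing, with an illustrative `dim` (the script reads the real `all_head_size` from the HF model):

```python
import torch

dim = 6  # illustrative all_head_size
qkv_weight = torch.randn(3 * dim, dim)  # fused projection, rows stacked as [q; k; v]

query = qkv_weight[:dim, :]
key = qkv_weight[dim : dim * 2, :]
value = qkv_weight[-dim:, :]
# Re-stacking the slices recovers the fused tensor exactly.
assert torch.equal(torch.cat([query, key, value], dim=0), qkv_weight)
```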
+ def convert_nougat_checkpoint(model_tag, pytorch_dump_folder_path=None, push_to_hub=False):
+     # load original model
+     checkpoint_path = get_checkpoint(None, model_tag)
+     original_model = NougatModel.from_pretrained(checkpoint_path)
+     original_model.eval()
+
+     # load HuggingFace model
+     encoder_config, decoder_config = get_configs(original_model)
+     encoder = DonutSwinModel(encoder_config)
+     decoder = MBartForCausalLM(decoder_config)
+     model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+     model.eval()
+
+     state_dict = original_model.state_dict()
+     new_state_dict = convert_state_dict(state_dict, model)
+     model.load_state_dict(new_state_dict)
+
+     # verify results on PDF
+     filepath = hf_hub_download(repo_id="ysharma/nougat", filename="input/nougat.pdf", repo_type="space")
+     images = rasterize_paper(pdf=filepath, return_pil=True)
+     image = Image.open(images[0])
+
+     tokenizer_file = checkpoint_path / "tokenizer.json"
+     tokenizer = NougatTokenizerFast(tokenizer_file=str(tokenizer_file))
+     tokenizer.pad_token = "<pad>"
+     tokenizer.bos_token = "<s>"
+     tokenizer.eos_token = "</s>"
+     tokenizer.unk_token = "<unk>"
+     tokenizer.model_max_length = original_model.config.max_length
+
+     size = {"height": original_model.config.input_size[0], "width": original_model.config.input_size[1]}
+     image_processor = NougatImageProcessor(
+         do_align_long_axis=original_model.config.align_long_axis,
+         size=size,
+     )
+     processor = NougatProcessor(image_processor=image_processor, tokenizer=tokenizer)
+
+     # verify pixel_values
+     pixel_values = processor(image, return_tensors="pt").pixel_values
+     original_pixel_values = original_model.encoder.prepare_input(image).unsqueeze(0)
+
+     assert torch.allclose(original_pixel_values, pixel_values)
+
+     # verify patch embeddings
+     original_patch_embed = original_model.encoder.model.patch_embed(pixel_values)
+     patch_embeddings, _ = model.encoder.embeddings(pixel_values)
+     assert torch.allclose(original_patch_embed, patch_embeddings)
+
+     # verify encoder hidden states
+     original_last_hidden_state = original_model.encoder(pixel_values)
+     last_hidden_state = model.encoder(pixel_values).last_hidden_state
+     assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2)
+
+     # NOTE original model does not use tied weights for embeddings of decoder
+     original_embeddings = original_model.decoder.model.model.decoder.embed_tokens
+     embeddings = model.decoder.model.decoder.embed_tokens
+     assert torch.allclose(original_embeddings.weight, embeddings.weight, atol=1e-3)
+
+     # verify decoder hidden states
+     prompt = "hello world"
+     decoder_input_ids = original_model.decoder.tokenizer(
+         prompt, add_special_tokens=False, return_tensors="pt"
+     ).input_ids
+     decoder_attention_mask = torch.ones_like(decoder_input_ids)
+     original_logits = original_model(
+         image_tensors=pixel_values, decoder_input_ids=decoder_input_ids, attention_mask=decoder_attention_mask
+     ).logits
+     logits = model(
+         pixel_values,
+         decoder_input_ids=decoder_input_ids[:, :-1],
+         decoder_attention_mask=decoder_attention_mask[:, :-1],
+     ).logits
+     assert torch.allclose(original_logits, logits, atol=1e-3)
+
+     # verify generation
+     outputs = model.generate(
+         pixel_values,
+         min_length=1,
+         max_length=30,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         use_cache=True,
+         bad_words_ids=[
+             [tokenizer.unk_token_id],
+         ],
+         return_dict_in_generate=True,
+         do_sample=False,
+     )
+     generated = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
+
+     if model_tag == "0.1.0-base":
+         expected_generation = "# Nougat: Neural Optical Understanding for Academic Documents\n\nLukas Blecher\n\nCorrespondence to: lblec"
+     elif model_tag == "0.1.0-small":
+         expected_generation = (
+             "# Nougat: Neural Optical Understanding for Academic Documents\n\nLukas Blecher\n\nCorrespondence to: lble"
+         )
+     else:
+         raise ValueError(f"Unexpected model tag: {model_tag}")
+
+     assert generated == expected_generation
+     print("Looks ok!")
+
+     if pytorch_dump_folder_path is not None:
+         print(f"Saving model and processor to {pytorch_dump_folder_path}")
+         model.save_pretrained(pytorch_dump_folder_path)
+         processor.save_pretrained(pytorch_dump_folder_path)
+
+     if push_to_hub:
+         tag_to_name = {"0.1.0-base": "nougat-base", "0.1.0-small": "nougat-small"}
+         model_name = tag_to_name[model_tag]
+
+         model.push_to_hub(f"facebook/{model_name}")
+         processor.push_to_hub(f"facebook/{model_name}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--model_tag",
+         default="0.1.0-base",
+         required=False,
+         type=str,
+         choices=["0.1.0-base", "0.1.0-small"],
+         help="Tag of the original model you'd like to convert.",
+     )
+     parser.add_argument(
+         "--pytorch_dump_folder_path",
+         default=None,
+         required=False,
+         type=str,
+         help="Path to the output PyTorch model directory.",
+     )
+     parser.add_argument(
+         "--push_to_hub",
+         action="store_true",
+         help="Whether or not to push the converted model and processor to the 🤗 hub.",
+     )
+
+     args = parser.parse_args()
+     convert_nougat_checkpoint(args.model_tag, args.pytorch_dump_folder_path, args.push_to_hub)
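Assuming the original `nougat` package and its checkpoints are available, the script could be driven from the command line or programmatically; the paths and the module import below are illustrative, not part of the shipped package:

```python
# Command line: python convert_nougat_to_hf.py --model_tag 0.1.0-base --pytorch_dump_folder_path ./nougat-base
from convert_nougat_to_hf import convert_nougat_checkpoint  # hypothetical import of this script as a module

convert_nougat_checkpoint("0.1.0-base", pytorch_dump_folder_path="./nougat-base", push_to_hub=False)
```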
docs/transformers/src/transformers/models/nougat/image_processing_nougat.py ADDED
@@ -0,0 +1,525 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Image processor class for Nougat."""
+
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+ from ...image_transforms import (
+     get_resize_output_image_size,
+     pad,
+     resize,
+     to_channel_dimension_format,
+     to_pil_image,
+ )
+ from ...image_utils import (
+     IMAGENET_DEFAULT_MEAN,
+     IMAGENET_DEFAULT_STD,
+     ChannelDimension,
+     ImageInput,
+     PILImageResampling,
+     get_image_size,
+     infer_channel_dimension_format,
+     is_scaled_image,
+     make_list_of_images,
+     to_numpy_array,
+     valid_images,
+     validate_preprocess_arguments,
+ )
+ from ...utils import TensorType, filter_out_non_signature_kwargs, logging
+ from ...utils.import_utils import is_cv2_available, is_vision_available
+
+
+ logger = logging.get_logger(__name__)
+
+
+ if is_cv2_available():
+     pass
+
+
+ if is_vision_available():
+     import PIL
+
+
+ class NougatImageProcessor(BaseImageProcessor):
+     r"""
+     Constructs a Nougat image processor.
+
+     Args:
+         do_crop_margin (`bool`, *optional*, defaults to `True`):
+             Whether to crop the image margins.
+         do_resize (`bool`, *optional*, defaults to `True`):
+             Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+             `do_resize` in the `preprocess` method.
+         size (`Dict[str, int]`, *optional*, defaults to `{"height": 896, "width": 672}`):
+             Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
+         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+             Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess`
+             method.
+         do_thumbnail (`bool`, *optional*, defaults to `True`):
+             Whether to resize the image using the thumbnail method.
+         do_align_long_axis (`bool`, *optional*, defaults to `False`):
+             Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
+         do_pad (`bool`, *optional*, defaults to `True`):
+             Whether to pad the images to the largest image size in the batch.
+         do_rescale (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+             `do_rescale` parameter in the `preprocess` method.
+         rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+             Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+             `preprocess` method.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+         image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+             Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+             channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+             Image standard deviation.
+     """
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         do_crop_margin: bool = True,
+         do_resize: bool = True,
+         size: Dict[str, int] = None,
+         resample: PILImageResampling = PILImageResampling.BILINEAR,
+         do_thumbnail: bool = True,
+         do_align_long_axis: bool = False,
+         do_pad: bool = True,
+         do_rescale: bool = True,
+         rescale_factor: Union[int, float] = 1 / 255,
+         do_normalize: bool = True,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+
+         size = size if size is not None else {"height": 896, "width": 672}
+         size = get_size_dict(size)
+
+         self.do_crop_margin = do_crop_margin
+         self.do_resize = do_resize
+         self.size = size
+         self.resample = resample
+         self.do_thumbnail = do_thumbnail
+         self.do_align_long_axis = do_align_long_axis
+         self.do_pad = do_pad
+         self.do_rescale = do_rescale
+         self.rescale_factor = rescale_factor
+         self.do_normalize = do_normalize
+         self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+         self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+
+     def python_find_non_zero(self, image: np.array):
+         """This is a reimplementation of a findNonZero function equivalent to cv2."""
+         non_zero_indices = np.column_stack(np.nonzero(image))
+         idxvec = non_zero_indices[:, [1, 0]]
+         idxvec = idxvec.reshape(-1, 1, 2)
+         return idxvec
+
+     def python_bounding_rect(self, coordinates):
+         """This is a reimplementation of a BoundingRect function equivalent to cv2."""
+         min_values = np.min(coordinates, axis=(0, 1)).astype(int)
+         max_values = np.max(coordinates, axis=(0, 1)).astype(int)
+         x_min, y_min = min_values[0], min_values[1]
+         width = max_values[0] - x_min + 1
+         height = max_values[1] - y_min + 1
+         return x_min, y_min, width, height
+
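These two helpers replicate `cv2.findNonZero` and `cv2.boundingRect` with plain NumPy so that OpenCV stays optional. A quick check of the same arithmetic on a toy mask:

```python
import numpy as np

mask = np.zeros((5, 5), dtype=bool)
mask[1:3, 2:4] = True  # a 2x2 block of "dark" pixels

coords = np.column_stack(np.nonzero(mask))[:, [1, 0]].reshape(-1, 1, 2)  # (x, y) pairs, cv2-style
x_min, y_min = np.min(coords, axis=(0, 1)).astype(int)
x_max, y_max = np.max(coords, axis=(0, 1)).astype(int)
print(x_min, y_min, x_max - x_min + 1, y_max - y_min + 1)  # 2 1 2 2 -> x, y, width, height
```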
+     def crop_margin(
+         self,
+         image: np.array,
+         gray_threshold: int = 200,
+         data_format: Optional[ChannelDimension] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> np.array:
+         """
+         Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the
+         threshold).
+
+         Args:
+             image (`np.array`):
+                 The image to be cropped.
+             gray_threshold (`int`, *optional*, defaults to `200`):
+                 Value below which pixels are considered to be gray.
+             data_format (`ChannelDimension`, *optional*):
+                 The channel dimension format of the output image. If unset, will use the inferred format from the
+                 input.
+             input_data_format (`ChannelDimension`, *optional*):
+                 The channel dimension format of the input image. If unset, will use the inferred format from the
+                 input.
+         """
+         if input_data_format is None:
+             input_data_format = infer_channel_dimension_format(image)
+
+         image = to_pil_image(image, input_data_format=input_data_format)
+         data = np.array(image.convert("L")).astype(np.uint8)
+         max_val = data.max()
+         min_val = data.min()
+         if max_val == min_val:
+             image = np.array(image)
+             image = (
+                 to_channel_dimension_format(image, data_format, input_data_format)
+                 if data_format is not None
+                 else image
+             )
+             return image
+         data = (data - min_val) / (max_val - min_val) * 255
+         gray = data < gray_threshold
+         coords = self.python_find_non_zero(gray)
+         x_min, y_min, width, height = self.python_bounding_rect(coords)
+         image = image.crop((x_min, y_min, x_min + width, y_min + height))
+         image = np.array(image).astype(np.uint8)
+         image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST)
+
+         image = (
+             to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
+         )
+
+         return image
+
+     # Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.align_long_axis
+     def align_long_axis(
+         self,
+         image: np.ndarray,
+         size: Dict[str, int],
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> np.ndarray:
+         """
+         Align the long axis of the image to the longest axis of the specified size.
+
+         Args:
+             image (`np.ndarray`):
+                 The image to be aligned.
+             size (`Dict[str, int]`):
+                 The size `{"height": h, "width": w}` to align the long axis to.
+             data_format (`str` or `ChannelDimension`, *optional*):
+                 The data format of the output image. If unset, the same format as the input image is used.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format of the input image. If not provided, it will be inferred.
+
+         Returns:
+             `np.ndarray`: The aligned image.
+         """
+         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+         output_height, output_width = size["height"], size["width"]
+
+         if input_data_format is None:
+             # We assume that all images have the same channel dimension format.
+             input_data_format = infer_channel_dimension_format(image)
+
+         if input_data_format == ChannelDimension.LAST:
+             rot_axes = (0, 1)
+         elif input_data_format == ChannelDimension.FIRST:
+             rot_axes = (1, 2)
+         else:
+             raise ValueError(f"Unsupported data format: {input_data_format}")
+
+         if (output_width < output_height and input_width > input_height) or (
+             output_width > output_height and input_width < input_height
+         ):
+             image = np.rot90(image, 3, axes=rot_axes)
+
+         if data_format is not None:
+             image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+
+         return image
+
+     def pad_image(
+         self,
+         image: np.ndarray,
+         size: Dict[str, int],
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> np.ndarray:
+         """
+         Pad the image to the specified size at the top, bottom, left and right.
+
+         Args:
+             image (`np.ndarray`):
+                 The image to be padded.
+             size (`Dict[str, int]`):
+                 The size `{"height": h, "width": w}` to pad the image to.
+             data_format (`str` or `ChannelDimension`, *optional*):
+                 The data format of the output image. If unset, the same format as the input image is used.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format of the input image. If not provided, it will be inferred.
+         """
+         output_height, output_width = size["height"], size["width"]
+         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+
+         delta_width = output_width - input_width
+         delta_height = output_height - input_height
+
+         pad_top = delta_height // 2
+         pad_left = delta_width // 2
+
+         pad_bottom = delta_height - pad_top
+         pad_right = delta_width - pad_left
+
+         padding = ((pad_top, pad_bottom), (pad_left, pad_right))
+         return pad(image, padding, data_format=data_format, input_data_format=input_data_format)
+
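`pad_image` centers the page inside the target canvas, giving the extra pixel to the bottom/right when the difference is odd. The arithmetic in isolation, with illustrative sizes:

```python
output_height, output_width = 896, 672  # target canvas
input_height, input_width = 893, 671    # resized page (illustrative)

delta_height = output_height - input_height  # 3
delta_width = output_width - input_width     # 1
pad_top, pad_left = delta_height // 2, delta_width // 2                  # 1, 0
pad_bottom, pad_right = delta_height - pad_top, delta_width - pad_left   # 2, 1
assert input_height + pad_top + pad_bottom == output_height
assert input_width + pad_left + pad_right == output_width
```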
+     # Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.thumbnail
+     def thumbnail(
+         self,
+         image: np.ndarray,
+         size: Dict[str, int],
+         resample: PILImageResampling = PILImageResampling.BICUBIC,
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+         **kwargs,
+     ) -> np.ndarray:
+         """
+         Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
+         corresponding dimension of the specified size.
+
+         Args:
+             image (`np.ndarray`):
+                 The image to be resized.
+             size (`Dict[str, int]`):
+                 The size `{"height": h, "width": w}` to resize the image to.
+             resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                 The resampling filter to use.
+             data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
+                 The data format of the output image. If unset, the same format as the input image is used.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format of the input image. If not provided, it will be inferred.
+         """
+         input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+         output_height, output_width = size["height"], size["width"]
+
+         # We always resize to the smallest of either the input or output size.
+         height = min(input_height, output_height)
+         width = min(input_width, output_width)
+
+         if height == input_height and width == input_width:
+             return image
+
+         if input_height > input_width:
+             width = int(input_width * height / input_height)
+         elif input_width > input_height:
+             height = int(input_height * width / input_width)
+
+         return resize(
+             image,
+             size=(height, width),
+             resample=resample,
+             reducing_gap=2.0,
+             data_format=data_format,
+             input_data_format=input_data_format,
+             **kwargs,
+         )
+
+     # Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.resize
+     def resize(
+         self,
+         image: np.ndarray,
+         size: Dict[str, int],
+         resample: PILImageResampling = PILImageResampling.BICUBIC,
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+         **kwargs,
+     ) -> np.ndarray:
+         """
+         Resizes `image` to `(height, width)` specified by `size` using the PIL library.
+
+         Args:
+             image (`np.ndarray`):
+                 Image to resize.
+             size (`Dict[str, int]`):
+                 Size of the output image.
+             resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                 Resampling filter to use when resizing the image.
+             data_format (`str` or `ChannelDimension`, *optional*):
+                 The channel dimension format of the image. If not provided, it will be the same as the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format of the input image. If not provided, it will be inferred.
+         """
+         size = get_size_dict(size)
+         shortest_edge = min(size["height"], size["width"])
+         output_size = get_resize_output_image_size(
+             image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
+         )
+         resized_image = resize(
+             image,
+             size=output_size,
+             resample=resample,
+             data_format=data_format,
+             input_data_format=input_data_format,
+             **kwargs,
+         )
+         return resized_image
+
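`thumbnail` shrinks the longer side proportionally so that neither dimension exceeds the target while the aspect ratio is preserved. The core arithmetic on illustrative sizes:

```python
input_height, input_width = 1000, 500   # portrait page (illustrative)
output_height, output_width = 896, 672  # target size

height = min(input_height, output_height)  # 896
width = min(input_width, output_width)     # 500
if input_height > input_width:
    width = int(input_width * height / input_height)  # 448: width follows the capped height
elif input_width > input_height:
    height = int(input_height * width / input_width)
print(height, width)  # 896 448, same 2:1 aspect ratio as the input
```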
+     @filter_out_non_signature_kwargs()
+     def preprocess(
+         self,
+         images: ImageInput,
+         do_crop_margin: Optional[bool] = None,
+         do_resize: Optional[bool] = None,
+         size: Dict[str, int] = None,
+         resample: PILImageResampling = None,
+         do_thumbnail: Optional[bool] = None,
+         do_align_long_axis: Optional[bool] = None,
+         do_pad: Optional[bool] = None,
+         do_rescale: Optional[bool] = None,
+         rescale_factor: Union[int, float] = None,
+         do_normalize: Optional[bool] = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> PIL.Image.Image:
+         """
+         Preprocess an image or batch of images.
+
+         Args:
+             images (`ImageInput`):
+                 Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
+             do_crop_margin (`bool`, *optional*, defaults to `self.do_crop_margin`):
+                 Whether to crop the image margins.
+             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                 Whether to resize the image.
+             size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                 Size of the image after resizing. Shortest edge of the image is resized to min(size["height"],
+                 size["width"]) with the longest edge resized to keep the input aspect ratio.
+             resample (`int`, *optional*, defaults to `self.resample`):
+                 Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`.
+                 Only has an effect if `do_resize` is set to `True`.
+             do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
+                 Whether to resize the image using the thumbnail method.
+             do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
+                 Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
+             do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                 Whether to pad the images to the largest image size in the batch.
+             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                 Whether to rescale the image by the specified scale `rescale_factor`.
+             rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
+                 Scale factor to use if rescaling the image.
+             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                 Whether to normalize the image.
+             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                 Image mean to use for normalization.
+             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                 Image standard deviation to use for normalization.
+             return_tensors (`str` or `TensorType`, *optional*):
+                 The type of tensors to return. Can be one of:
+                 - Unset: Return a list of `np.ndarray`.
+                 - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                 - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                 - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: defaults to the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                 from the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         do_crop_margin = do_crop_margin if do_crop_margin is not None else self.do_crop_margin
+         do_resize = do_resize if do_resize is not None else self.do_resize
+         size = size if size is not None else self.size
+         resample = resample if resample is not None else self.resample
+         do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail
+         do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis
+         do_pad = do_pad if do_pad is not None else self.do_pad
+         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+         image_mean = image_mean if image_mean is not None else self.image_mean
+         image_std = image_std if image_std is not None else self.image_std
+
+         images = make_list_of_images(images)
+
+         if not valid_images(images):
+             raise ValueError(
+                 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                 "torch.Tensor, tf.Tensor or jax.ndarray."
+             )
+         validate_preprocess_arguments(
+             do_rescale=do_rescale,
+             rescale_factor=rescale_factor,
+             do_normalize=do_normalize,
+             image_mean=image_mean,
+             image_std=image_std,
+             do_pad=do_pad,
+             size_divisibility=size,  # There is no pad divisibility in this processor, but pad requires the size arg.
+             do_resize=do_resize,
+             size=size,
+             resample=resample,
+         )
+
+         # All transformations expect numpy arrays.
+         images = [to_numpy_array(image) for image in images]
+
+         if do_rescale and is_scaled_image(images[0]):
+             logger.warning_once(
+                 "It looks like you are trying to rescale already rescaled images. If the input"
+                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+             )
+
+         if input_data_format is None:
+             # We assume that all images have the same channel dimension format.
+             input_data_format = infer_channel_dimension_format(images[0])
+
+         if do_crop_margin:
+             images = [self.crop_margin(image, input_data_format=input_data_format) for image in images]
+
+         if do_align_long_axis:
+             images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images]
+
+         if do_resize:
+             images = [
+                 self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                 for image in images
+             ]
+
+         if do_thumbnail:
+             images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images]
+
+         if do_pad:
+             images = [self.pad_image(image=image, size=size, input_data_format=input_data_format) for image in images]
+
+         if do_rescale:
+             images = [
+                 self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                 for image in images
+             ]
+
+         if do_normalize:
+             images = [
+                 self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                 for image in images
+             ]
+
+         images = [
+             to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+         ]
+
+         data = {"pixel_values": images}
+         return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+ __all__ = ["NougatImageProcessor"]
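End to end, `preprocess` applies crop margin -> align long axis -> resize -> thumbnail -> pad -> rescale -> normalize. A minimal sketch on a synthetic page (the noise image is just a stand-in for a scan; the output shape shown is for the default 896x672 configuration):

```python
import numpy as np
from transformers import NougatImageProcessor

image_processor = NougatImageProcessor()
page = np.random.randint(0, 256, (1024, 768, 3), dtype=np.uint8)  # synthetic "scan"

batch = image_processor(page, return_tensors="pt")
print(batch.pixel_values.shape)  # torch.Size([1, 3, 896, 672])
```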
docs/transformers/src/transformers/models/nougat/processing_nougat.py ADDED
@@ -0,0 +1,163 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Processor class for Nougat.
+ """
+
+ from typing import Dict, List, Optional, Union
+
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
+
+ from ...processing_utils import ProcessorMixin
+ from ...utils import PaddingStrategy, TensorType
+
+
+ class NougatProcessor(ProcessorMixin):
+     r"""
+     Constructs a Nougat processor which wraps a Nougat image processor and a Nougat tokenizer into a single
+     processor.
+
+     [`NougatProcessor`] offers all the functionalities of [`NougatImageProcessor`] and [`NougatTokenizerFast`]. See
+     the [`~NougatProcessor.__call__`] and [`~NougatProcessor.decode`] for more information.
+
+     Args:
+         image_processor ([`NougatImageProcessor`]):
+             An instance of [`NougatImageProcessor`]. The image processor is a required input.
+         tokenizer ([`NougatTokenizerFast`]):
+             An instance of [`NougatTokenizerFast`]. The tokenizer is a required input.
+     """
+
+     attributes = ["image_processor", "tokenizer"]
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(self, image_processor, tokenizer):
+         super().__init__(image_processor, tokenizer)
+         self.current_processor = self.image_processor
+
+     def __call__(
+         self,
+         images=None,
+         text=None,
+         do_crop_margin: Optional[bool] = None,
+         do_resize: Optional[bool] = None,
+         size: Dict[str, int] = None,
+         resample: "PILImageResampling" = None,  # noqa: F821
+         do_thumbnail: Optional[bool] = None,
+         do_align_long_axis: Optional[bool] = None,
+         do_pad: Optional[bool] = None,
+         do_rescale: Optional[bool] = None,
+         rescale_factor: Union[int, float] = None,
+         do_normalize: Optional[bool] = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
+         input_data_format: Optional[Union[str, "ChannelDimension"]] = None,  # noqa: F821
+         text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+         text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+         text_pair_target: Optional[
+             Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+         ] = None,
+         add_special_tokens: bool = True,
+         padding: Union[bool, str, PaddingStrategy] = False,
+         truncation: Union[bool, str, TruncationStrategy] = None,
+         max_length: Optional[int] = None,
+         stride: int = 0,
+         is_split_into_words: bool = False,
+         pad_to_multiple_of: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         return_token_type_ids: Optional[bool] = None,
+         return_attention_mask: Optional[bool] = None,
+         return_overflowing_tokens: bool = False,
+         return_special_tokens_mask: bool = False,
+         return_offsets_mapping: bool = False,
+         return_length: bool = False,
+         verbose: bool = True,
+     ):
+         if images is None and text is None:
+             raise ValueError("You need to specify either an `images` or `text` input to process.")
+
+         if images is not None:
+             inputs = self.image_processor(
+                 images,
+                 do_crop_margin=do_crop_margin,
+                 do_resize=do_resize,
+                 size=size,
+                 resample=resample,
+                 do_thumbnail=do_thumbnail,
+                 do_align_long_axis=do_align_long_axis,
+                 do_pad=do_pad,
+                 do_rescale=do_rescale,
+                 rescale_factor=rescale_factor,
+                 do_normalize=do_normalize,
+                 image_mean=image_mean,
+                 image_std=image_std,
+                 return_tensors=return_tensors,
+                 data_format=data_format,
+                 input_data_format=input_data_format,
+             )
+         if text is not None:
+             encodings = self.tokenizer(
+                 text,
+                 text_pair=text_pair,
+                 text_target=text_target,
+                 text_pair_target=text_pair_target,
+                 add_special_tokens=add_special_tokens,
+                 padding=padding,
+                 truncation=truncation,
+                 max_length=max_length,
+                 stride=stride,
+                 is_split_into_words=is_split_into_words,
+                 pad_to_multiple_of=pad_to_multiple_of,
+                 return_tensors=return_tensors,
+                 return_token_type_ids=return_token_type_ids,
+                 return_attention_mask=return_attention_mask,
+                 return_overflowing_tokens=return_overflowing_tokens,
+                 return_special_tokens_mask=return_special_tokens_mask,
+                 return_offsets_mapping=return_offsets_mapping,
+                 return_length=return_length,
+                 verbose=verbose,
+             )
+
+         if text is None:
+             return inputs
+         elif images is None:
+             return encodings
+         else:
+             inputs["labels"] = encodings["input_ids"]
+             return inputs
+
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def post_process_generation(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.post_process_generation`].
+         Please refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.post_process_generation(*args, **kwargs)
+
+
+ __all__ = ["NougatProcessor"]
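When both `images` and `text` are passed, the pixel values and the tokenized text come back together, with the `input_ids` exposed as `labels` for training a vision-encoder-decoder model. A usage sketch (checkpoint name as published on the Hub; the random image stands in for a real page):

```python
import numpy as np
from PIL import Image
from transformers import NougatProcessor

processor = NougatProcessor.from_pretrained("facebook/nougat-base")
page = Image.fromarray(np.random.randint(0, 256, (896, 672, 3), dtype=np.uint8))

inputs = processor(images=page, text="# A section title", return_tensors="pt")
print(inputs.pixel_values.shape, inputs.labels.shape)  # pixel values plus input_ids-as-labels
```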
docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py ADDED
@@ -0,0 +1,620 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fast tokenizer class for Nougat.
+ """
+
+ import re
+ from functools import partial
+ from multiprocessing import Pool
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ from transformers.tokenization_utils_base import INIT_TOKENIZER_DOCSTRING
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+ from transformers.utils import add_end_docstrings
+
+ from ...utils import is_levenshtein_available, is_nltk_available, logging, requires_backends
+
+
+ if is_levenshtein_available():
+     from Levenshtein import ratio
+
+ if is_nltk_available():
+     import nltk
+
+
+ logger = logging.get_logger(__name__)
+
+
+ INIT_TOKENIZER_DOCSTRING += """
+         tokenizer_object ([`tokenizers.Tokenizer`]):
+             A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
+             tokenizers](../fast_tokenizers) for more information.
+         tokenizer_file ([`str`]):
+             A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
+             tokenizers.
+ """
+
+
+ VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
+
+
+ def markdown_compatible(text: str) -> str:
+     """
+     Make text compatible with Markdown formatting.
+
+     This function makes various text formatting adjustments to make it compatible with Markdown.
+
+     Args:
+         text (`str`):
+             The input text to be made Markdown-compatible.
+
+     Returns:
+         `str`: The Markdown-compatible text.
+     """
+     # equation tag
+     # Replace lines that start with a pattern like (decimal) \[some text\] with \[[some text] \tag{decimal}\].
+     text = re.sub(r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]", text, flags=re.M)
+     # Replace lines that start with a pattern like \[some text\] (decimal) with \[[some text] \tag{decimal}\].
+     text = re.sub(r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]", text, flags=re.M)
+     # Replace lines that start with a pattern like \[some text\] (digits) \[another text\] with \[[some text] \tag{digits}\] [another text].
+     text = re.sub(
+         r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$",
+         r"\[\1 \\tag{\2}\] \3",
+         text,
+         flags=re.M,
+     )
+     # multi line
+     text = text.replace(r"\. ", ". ")
+     # bold formatting
+     text = text.replace(r"\bm{", r"\mathbf{").replace(r"{\\bm ", r"\mathbf{")
+     text = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", text)
+     # Reformat urls (http, ftp and https only) to markdown [url](url) clickable format
+     text = re.sub(
+         r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))",
+         r"[\1](\1)",
+         text,
+     )
+     # algorithms
+     text = re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", text, flags=re.S)
+
+     return text
+
+
+
98
+ def normalize_list_like_lines(generation):
99
+ """
100
+ Normalize lines in the given text that resemble list items. The function looks for lines that start optionally with
101
+ '-' or '*', possibly followed by Roman numerals or digits indicating nesting levels. The function reformats such
102
+ lines to make them more structured.
103
+
104
+ Args:
105
+ generation (str): The input text containing lines that need to be normalized.
106
+
107
+ Returns:
108
+ str: The input text with the list-like lines normalized.
109
+
110
+ Note:
111
+ The function uses regular expressions to identify and reformat the list-like lines. The patterns capture
112
+ optional bullet points, nesting levels indicated by numerals, and the actual list item content. The
113
+ normalization adjusts the bullet point style and nesting levels based on the captured patterns.
114
+ """
115
+
116
+ lines = generation.split("\n")
117
+ output_lines = []
118
+ for line_no, line in enumerate(lines):
119
+ match = re.search(r". ([-*]) ", line)
120
+ if not match or line[0] not in ("-", "*"):
121
+ output_lines.append(line)
122
+ continue # Doesn't fit the pattern we want, no changes
123
+ delim = match.group(1) + " "
124
+ splits = line.split(delim)[1:]
125
+ replacement = ""
126
+ delim1 = line[0] + " "
127
+
128
+ for i, item in enumerate(splits):
129
+ level = 0
130
+ potential_numeral, _, rest = item.strip().partition(" ")
131
+ if not rest:
132
+ continue
133
+ # Infer current nesting level based on detected numbering
134
+ if re.match(r"^[\dixv]+((?:\.[\dixv])?)+$", potential_numeral, flags=re.I | re.M):
135
+ level = potential_numeral.count(".")
136
+
137
+ replacement += (
138
+ ("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip()
139
+ )
140
+
141
+ if line_no == len(lines) - 1: # If this is the last line in the generation
142
+ replacement += "\n" # Add an empty line to the end of the generation
143
+
144
+ output_lines.append(replacement)
145
+
146
+ return "\n".join(output_lines)
147
+
148
+
149
+ def find_next_punctuation(text: str, start_idx=0):
150
+ """
151
+ Find the index of the next punctuation mark.
152
+
153
+ Args:
154
+ text (`str`):
155
+ String to examine
156
+ start_idx (`int`, *optional*)
157
+ Index where to start
158
+ """
159
+
160
+ for i in range(start_idx, len(text)):
161
+ if text[i] in [".", "?", "!", "\n"]:
162
+ return i
163
+
164
+ return None
165
+
166
+
167
+ def truncate_repetitions(text: str, min_len: int = 30) -> str:
168
+ """
169
+ Attempt to truncate repeating segments in the input string.
170
+
171
+ This function looks for the longest repeating substring at the end of the input string and truncates it to appear
172
+ only once. To be considered for removal, repetitions need to be continuous.
173
+
174
+ Args:
175
+ text (`str`):
176
+ The input raw prediction to be truncated.
177
+ min_len (int):
178
+ The minimum length of the repeating segment.
179
+
180
+ Returns:
181
+ `str`: The input string with repeated segments truncated.
182
+ """
183
+ text_lower = text.lower()
184
+ text_length = len(text_lower)
185
+
186
+ if text_length < 2 * min_len:
187
+ return text
188
+
189
+ # try to find a length at which the tail is repeating
190
+ max_repetition_length = None
191
+ for repetition_length in range(min_len, int(text_length / 2)):
192
+ # check if there is a repetition at the end
193
+ same = True
194
+ for i in range(0, repetition_length):
195
+ if text_lower[text_length - repetition_length - i - 1] != text_lower[text_length - i - 1]:
196
+ same = False
197
+ break
198
+
199
+ if same:
200
+ max_repetition_length = repetition_length
201
+
202
+ if max_repetition_length is None:
203
+ return text
204
+
205
+ lcs = text_lower[-max_repetition_length:]
206
+
207
+ # remove all but the last repetition
208
+ substituted_text = text
209
+ substituted_text_lower = text_lower
210
+ while substituted_text_lower.endswith(lcs):
211
+ substituted_text = substituted_text[:-max_repetition_length]
212
+ substituted_text_lower = substituted_text_lower[:-max_repetition_length]
213
+
214
+ # this is the tail with the repetitions
215
+ repeating_tail = text_lower[len(substituted_text_lower) :]
216
+
217
+ # add until next punctuation and make sure last sentence is not repeating
218
+ substituted_text_lower_out = substituted_text_lower
219
+ while True:
220
+ sentence_end = find_next_punctuation(text_lower, len(substituted_text_lower_out))
221
+ sentence_start = find_next_punctuation(text_lower[::-1], len(substituted_text_lower_out))
222
+ if sentence_end and sentence_start:
223
+ sentence = text_lower[sentence_start:sentence_end]
224
+ substituted_text_lower_out = text_lower[: sentence_end + 1]
225
+ if sentence in repeating_tail:
226
+ break
227
+ else:
228
+ break
229
+
230
+ text_out = text[: len(substituted_text_lower_out)]
231
+
232
+ return text_out
233
+
234
+
235
+ def remove_numbers(lines):
236
+ def _clean(s):
237
+ return re.sub(r"(?:[\d_]|\*\*)", "", s).strip()
238
+
239
+ if isinstance(lines, str):
240
+ return _clean(lines)
241
+ out = []
242
+ for l in lines:
243
+ out.append(_clean(l))
244
+ return out
245
+
246
+
247
+ def get_slices(lines, clean_lines):
248
+ """
249
+ Get slices of text based on specific criteria within the lines.
250
+
251
+ This function identifies and returns slices of text from the input lines based on certain conditions.
252
+
253
+ These conditions were chosen by the Nougat authors:
254
+ - The slice is less than 200 characters long.
255
+ - The slice is more than 3 characters long.
256
+ - The slice does not start with "[MISSING_PAGE".
257
+ - The slice is either the same as the next slice or the ratio of the two in terms of Levensthein distance is
258
+ greater than 0.9.
259
+
260
+ Args:
261
+ lines (`List[str]`):
262
+ The list of lines containing the text.
263
+ clean_lines (`List[str]`):
264
+ A cleaned version of the text (without numbers).
265
+
266
+ Returns:
267
+ `List[tuple]`: A list of tuples representing the start and end indices of text slices.
268
+ """
269
+ indices = np.zeros(len(lines))
270
+ for i in range(len(lines) - 1):
271
+ j = i + 1
272
+ while not clean_lines[j] and j < len(lines) - 1:
273
+ j += 1
274
+ if (
275
+ len(clean_lines[i]) < 200
276
+ and len(clean_lines[i]) > 3
277
+ and len(clean_lines[j]) < 200
278
+ and len(clean_lines[j]) > 3
279
+ and not clean_lines[i].startswith("[MISSING_PAGE")
280
+ and (clean_lines[i] == clean_lines[j] or ratio(clean_lines[i], clean_lines[j]) > 0.9)
281
+ ):
282
+ indices[i:j] = 1
283
+ ids = np.where(indices)[0]
284
+ slices = []
285
+ if len(ids) == 0:
286
+ return slices
287
+ j0 = 0
288
+ for j, x in enumerate(np.diff(ids) > 3):
289
+ if x:
290
+ slices.append((ids[j0], ids[j] + 2))
291
+ j0 = j + 1
292
+ slices.append((ids[j0], ids[-1] + 2))
293
+ return [sli for sli in slices if sli[1] - sli[0] > 15]
294
+
295
+
296
+ def remove_slice_from_lines(lines, clean_text, slice) -> str:
297
+ """
298
+ Remove a slice of text from the lines based on specific criteria.
299
+
300
+ This function identifies a slice of text within the lines and removes it based on certain conditions.
301
+
302
+ Args:
303
+ lines (list of str): The list of lines containing the text.
304
+ clean_text (list of str): A cleaned version of the text (without numbers).
305
+ slice (tuple): A tuple representing the start and end indices of the slice to be removed.
306
+
307
+ Returns:
308
+ str: The removed slice of text as a single string.
309
+ """
310
+ base = clean_text[slice[0]]
311
+ section = list(slice)
312
+ check_start_flag = False
313
+ # backwards pass, at most 5 lines
314
+ for line_idx in range(max(0, slice[0] - 1), max(0, slice[0] - 5), -1):
315
+ if not lines[line_idx]:
316
+ continue
317
+ if lines[line_idx] == "## References":
318
+ section[0] = line_idx
319
+ break
320
+ elif ratio(base, remove_numbers(lines[line_idx])) < 0.9:
321
+ section[0] = line_idx + 1
322
+ potential_ref = remove_numbers(lines[max(0, line_idx - 1)].partition("* [")[-1])
323
+ if len(potential_ref) >= 0.75 * len(base) and ratio(base, potential_ref) < 0.9:
324
+ section[0] = line_idx
325
+ check_start_flag = True
326
+ break
327
+ # forward pass, at most 5 lines
328
+ for line_idx in range(min(len(lines), slice[1]), min(len(lines), slice[1] + 5)):
329
+ if ratio(base, remove_numbers(lines[line_idx])) < 0.9:
330
+ section[1] = line_idx
331
+ break
332
+ if len(lines) <= section[1]:
333
+ section[1] = len(lines) - 1
334
+ to_delete = "\n".join(lines[section[0] : section[1] + 1])
335
+ # cut off next page content
336
+ itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]])
337
+ while True:
338
+ try:
339
+ (ia, a) = next(itera)
340
+ while a.isnumeric():
341
+ (ia, a) = next(itera)
342
+ (ib, b) = next(iterb)
343
+ while b.isnumeric():
344
+ (ib, b) = next(iterb)
345
+ if a != b:
346
+ break
347
+ except StopIteration:
348
+ break
349
+ if check_start_flag and "* [" in to_delete:
350
+ to_delete = "* [" + to_delete.partition("* [")[-1]
351
+ try:
352
+ delta = len(lines[section[1]]) - ib - 1
353
+ if delta > 0:
354
+ to_delete = to_delete[:-delta]
355
+ except UnboundLocalError:
356
+ pass
357
+
358
+ return to_delete.strip()
359
+
360
+
361
+ @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
362
+ class NougatTokenizerFast(PreTrainedTokenizerFast):
363
+ """
364
+ Fast tokenizer for Nougat (backed by HuggingFace tokenizers library).
365
+
366
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
367
+ refer to this superclass for more information regarding those methods. This class mainly adds Nougat-specific
368
+ methods for postprocessing the generated text.
369
+
370
+ Args:
371
+ vocab_file (`str`, *optional*):
372
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
373
+ contains the vocabulary necessary to instantiate a tokenizer.
374
+ tokenizer_file (`str`, *optional*):
375
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
376
+ contains everything needed to load the tokenizer.
377
+
378
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
379
+ Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
380
+ spaces.
381
+
382
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
383
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
384
+ token instead.
385
+
386
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
387
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
388
+
389
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
390
+ The end of sequence token.
391
+
392
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
393
+ The token used for padding, for example when batching sequences of different lengths.
394
+ """
395
+
396
+ vocab_files_names = VOCAB_FILES_NAMES
397
+ model_input_names = ["input_ids", "attention_mask"]
398
+ slow_tokenizer_class = None
399
+
400
+ def __init__(
401
+ self,
402
+ vocab_file=None,
403
+ tokenizer_file=None,
404
+ clean_up_tokenization_spaces=False,
405
+ unk_token="<unk>",
406
+ bos_token="<s>",
407
+ eos_token="</s>",
408
+ pad_token="<pad>",
409
+ **kwargs,
410
+ ):
411
+ super().__init__(
412
+ vocab_file=vocab_file,
413
+ tokenizer_file=tokenizer_file,
414
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
415
+ unk_token=unk_token,
416
+ bos_token=bos_token,
417
+ eos_token=eos_token,
418
+ pad_token=pad_token,
419
+ **kwargs,
420
+ )
421
+ self.vocab_file = vocab_file
422
+
423
+ def remove_hallucinated_references(self, text: str) -> str:
424
+ """
425
+ Remove hallucinated or missing references from the text.
426
+
427
+ This function identifies and removes references that are marked as missing or hallucinated from the input text.
428
+
429
+ Args:
430
+ text (`str`):
431
+ The input text containing references.
432
+
433
+ Returns:
434
+ `str`: The text with hallucinated references removed.
435
+ """
436
+ lines = text.split("\n")
437
+ if len(lines) == 0:
438
+ return ""
439
+ clean_lines = remove_numbers(lines)
440
+ slices = get_slices(lines, clean_lines)
441
+ to_delete = []
442
+ for slice in slices:
443
+ to_delete.append(remove_slice_from_lines(lines, clean_lines, slice))
444
+ for to_delete in reversed(to_delete):
445
+ text = text.replace(to_delete, "\n\n[MISSING_PAGE_POST]\n\n")
446
+ text = re.sub(
447
+ r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
448
+ "\n\n[MISSING_PAGE_POST\\1]",
449
+ text,
450
+ )
451
+ return text
452
+
453
+ def correct_tables(self, generation: str) -> str:
454
+ """
455
+ Takes a generated string and fixes tables/tabulars to make them match the markdown format needed.
456
+
457
+ Args:
458
+ generation (str): The generated text to be postprocessed.
459
+
460
+ Returns:
461
+ str: The postprocessed text.
462
+
463
+ Example:
464
+
465
+ ```python
466
+ correct_tables("\\begin{table} \\begin{tabular}{l l} & \\ \\end{tabular} \\end{table}")
467
+ "\\begin{table}\n\\begin{tabular}{l l} & \\ \\end{tabular}\n\\end{table}"
468
+ ```
469
+ """
470
+ # remove obvious wrong tables
471
+ for l in generation.split("\n"):
472
+ if l.count("\\begin{tabular}") > 15 or l.count("\\multicolumn") > 60 or l.count("&") > 400:
473
+ generation = generation.replace(l, "")
474
+ # whitespace corrections
475
+
476
+ generation = generation.replace("\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}")
477
+ generation = generation.replace("\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}")
478
+ generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")
479
+
480
+ generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.M)
481
+
482
+ # Remove left-aligned empty LaTeX tabular blocks.
483
+ generation = generation.replace(r"\begin{tabular}{l l} & \\ \end{tabular}", "")
484
+ # Remove tabulars with just 2 newline characters.
485
+ generation = generation.replace("\\begin{tabular}{}\n\n\\end{tabular}", "")
486
+ return generation
487
+
488
+ def post_process_single(self, generation: str, fix_markdown: bool = True) -> str:
489
+ """
490
+ Postprocess a single generated text. Regular expressions used here are taken directly from the Nougat article
491
+ authors. These expressions are commented for clarity and tested end-to-end in most cases.
492
+
493
+ Args:
494
+ generation (str): The generated text to be postprocessed.
495
+ fix_markdown (bool, optional): Whether to perform Markdown formatting fixes. Default is True.
496
+
497
+ Returns:
498
+ str: The postprocessed text.
499
+ """
500
+ generation = re.sub(
501
+ r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation
502
+ ) # overly long section titles are probably not real section titles
503
+ generation = generation.strip()
504
+ # Remove LaTeX left margin tag
505
+ generation = generation.replace("\n* [leftmargin=*]\n", "\n")
506
+ # Remove lines with markdown headings starting with #, with numerals,
507
+ # and possibly roman numerals with trailing spaces and newlines
508
+ generation = re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.M)
509
+ # most likely hallucinated titles
510
+ lines = generation.split("\n")
511
+ if lines[-1].startswith("#") and lines[-1].lstrip("#").startswith(" ") and len(lines) > 1:
512
+ logger.info("Likely hallucinated title at the end of the page: " + lines[-1])
513
+ generation = "\n".join(lines[:-1])
514
+ # obvious repetition detection
515
+ generation = truncate_repetitions(generation)
516
+ # Reference corrections
517
+ generation = self.remove_hallucinated_references(generation)
518
+ # Remove lines starting with asterisks and numbers like "*[1]" followed by capital letters and periods (i.e., references that are too long)
519
+ generation = re.sub(r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.M)
520
+ # Remove empty brackets after a reference number in brackets. *[12][]ABC will become *[12]ABC
521
+ generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.M)
522
+ # Remove single characters before or after 2 new lines
523
+ generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation)
524
+ # PMC math artifact correction
525
+ generation = re.sub(
526
+ r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])",
527
+ r"\1\(\2_{\3}\)\4",
528
+ generation,
529
+ )
530
+ generation = re.sub(r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation)
531
+ # footnote mistakes
532
+ generation = re.sub(
533
+ r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))",
534
+ r"\1 \2",
535
+ generation,
536
+ )
537
+ # TODO Come up with footnote formatting inside a table
538
+ generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation)
539
+ # itemize post processing
540
+ generation = normalize_list_like_lines(generation)
541
+
542
+ if generation.endswith((".", "}")):
543
+ generation += "\n\n"
544
+ if re.match(r"[A-Z0-9,;:]$", generation):
545
+ # add a space in case there is a comma or word ending
546
+ generation += " "
547
+ elif generation.startswith(("#", "**", "\\begin")):
548
+ generation = "\n\n" + generation
549
+ elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")):
550
+ generation = generation + "\n\n"
551
+ else:
552
+ try:
553
+ last_word = generation.split(" ")[-1]
554
+ if last_word in nltk.corpus.words.words():
555
+ generation += " "
556
+ except LookupError:
557
+ # add space just in case. Will split words but better than concatenating them
558
+ generation += " "
559
+
560
+ # table corrections
561
+ generation = self.correct_tables(generation)
562
+ # Remove optional, empty square brackets after begin{array}
563
+ generation = generation.replace("\\begin{array}[]{", "\\begin{array}{")
564
+ # Remove empty or malformed LaTeX tabular blocks with 2 or more columns specified, with spaces and ampersands.
565
+ generation = re.sub(
566
+ r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}",
567
+ "",
568
+ generation,
569
+ )
570
+ # Remove lines containing "S.A.B." one or more times. Was included in Nougat's code.
571
+ generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation)
572
+ # Remove markdown-style headers that are incomplete or empty on multiple lines.
573
+ generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.M)
574
+ # Remove lines with just one period.
575
+ generation = re.sub(r"^\.\s*$", "", generation, flags=re.M)
576
+ # Replace instances of three or more newlines with just two newlines.
577
+ generation = re.sub(r"\n{3,}", "\n\n", generation)
578
+ if fix_markdown:
579
+ return markdown_compatible(generation)
580
+ else:
581
+ return generation
582
+
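+ # Rough illustration of two of the fixes above (hypothetical input; the exact output depends
+ # on the full regex pipeline): an empty heading line such as "#" is dropped by the header
+ # cleanup regex, and any run of three or more newlines is collapsed to a single blank line.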
583
+ def post_process_generation(
584
+ self,
585
+ generation: Union[str, List[str]],
586
+ fix_markdown: bool = True,
587
+ num_workers: Optional[int] = None,
588
+ ) -> Union[str, List[str]]:
589
+ """
590
+ Postprocess a generated text or a list of generated texts.
591
+
592
+ This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting.
593
+
594
+ Postprocessing is quite slow so it is recommended to use multiprocessing to speed up the process.
595
+
596
+ Args:
597
+ generation (Union[str, List[str]]):
598
+ The generated text or a list of generated texts.
599
+ fix_markdown (`bool`, *optional*, defaults to `True`):
600
+ Whether to perform Markdown formatting fixes.
601
+ num_workers (`int`, *optional*):
602
+ Optional number of workers to pass to leverage multiprocessing (postprocessing several texts in
603
+ parallel).
604
+
605
+ Returns:
606
+ Union[str, List[str]]: The postprocessed text or list of postprocessed texts.
607
+ """
608
+ requires_backends(self, ["nltk", "levenshtein"])
609
+
610
+ if isinstance(generation, list):
611
+ if num_workers is not None and isinstance(num_workers, int):
612
+ with Pool(num_workers) as p:
613
+ return p.map(partial(self.post_process_single, fix_markdown=fix_markdown), generation)
614
+ else:
615
+ return [self.post_process_single(s, fix_markdown=fix_markdown) for s in generation]
616
+ else:
617
+ return self.post_process_single(generation, fix_markdown=fix_markdown)
618
+
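+ # A minimal usage sketch (assumes the publicly released "facebook/nougat-base" checkpoint
+ # and the `nltk`/`levenshtein` backends required by `post_process_generation`):
+ #
+ # from transformers import NougatTokenizerFast
+ # tokenizer = NougatTokenizerFast.from_pretrained("facebook/nougat-base")
+ # fixed = tokenizer.post_process_generation(generated_text, fix_markdown=True)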
619
+
620
+ __all__ = ["NougatTokenizerFast"]
docs/transformers/src/transformers/models/nystromformer/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_nystromformer import *
22
+ from .modeling_nystromformer import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py ADDED
@@ -0,0 +1,132 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 UW-Madison and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Nystromformer model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class NystromformerConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`NystromformerModel`]. It is used to instantiate
27
+ a Nystromformer model according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of the Nystromformer
29
+ [uw-madison/nystromformer-512](https://huggingface.co/uw-madison/nystromformer-512) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 30000):
36
+ Vocabulary size of the Nystromformer model. Defines the number of different tokens that can be represented
37
+ by the `inputs_ids` passed when calling [`NystromformerModel`].
38
+ hidden_size (`int`, *optional*, defaults to 768):
39
+ Dimension of the encoder layers and the pooler layer.
40
+ num_hidden_layers (`int`, *optional*, defaults to 12):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 12):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ intermediate_size (`int`, *optional*, defaults to 3072):
45
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
46
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
49
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
52
+ The dropout ratio for the attention probabilities.
53
+ max_position_embeddings (`int`, *optional*, defaults to 512):
54
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
55
+ just in case (e.g., 512 or 1024 or 2048).
56
+ type_vocab_size (`int`, *optional*, defaults to 2):
57
+ The vocabulary size of the `token_type_ids` passed when calling [`NystromformerModel`].
58
+ segment_means_seq_len (`int`, *optional*, defaults to 64):
59
+ Sequence length used in segment-means.
60
+ num_landmarks (`int`, *optional*, defaults to 64):
61
+ The number of landmark (or Nystrom) points to use in Nystrom approximation of the softmax self-attention
62
+ matrix.
63
+ conv_kernel_size (`int`, *optional*, defaults to 65):
64
+ The kernel size of depthwise convolution used in Nystrom approximation.
65
+ inv_coeff_init_option (`bool`, *optional*, defaults to `False`):
66
+ Whether or not to use exact coefficient computation for the initial values for the iterative method of
67
+ calculating the Moore-Penrose inverse of a matrix.
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
70
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
71
+ The epsilon used by the layer normalization layers.
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import NystromformerModel, NystromformerConfig
77
+
78
+ >>> # Initializing a Nystromformer uw-madison/nystromformer-512 style configuration
79
+ >>> configuration = NystromformerConfig()
80
+
81
+ >>> # Initializing a model from the uw-madison/nystromformer-512 style configuration
82
+ >>> model = NystromformerModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "nystromformer"
89
+
90
+ def __init__(
91
+ self,
92
+ vocab_size=30000,
93
+ hidden_size=768,
94
+ num_hidden_layers=12,
95
+ num_attention_heads=12,
96
+ intermediate_size=3072,
97
+ hidden_act="gelu_new",
98
+ hidden_dropout_prob=0.1,
99
+ attention_probs_dropout_prob=0.1,
100
+ max_position_embeddings=510,
101
+ type_vocab_size=2,
102
+ segment_means_seq_len=64,
103
+ num_landmarks=64,
104
+ conv_kernel_size=65,
105
+ inv_coeff_init_option=False,
106
+ initializer_range=0.02,
107
+ layer_norm_eps=1e-5,
108
+ pad_token_id=1,
109
+ bos_token_id=0,
110
+ eos_token_id=2,
111
+ **kwargs,
112
+ ):
113
+ self.vocab_size = vocab_size
114
+ self.max_position_embeddings = max_position_embeddings
115
+ self.hidden_size = hidden_size
116
+ self.num_hidden_layers = num_hidden_layers
117
+ self.num_attention_heads = num_attention_heads
118
+ self.intermediate_size = intermediate_size
119
+ self.hidden_act = hidden_act
120
+ self.hidden_dropout_prob = hidden_dropout_prob
121
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
122
+ self.initializer_range = initializer_range
123
+ self.type_vocab_size = type_vocab_size
124
+ self.segment_means_seq_len = segment_means_seq_len
125
+ self.num_landmarks = num_landmarks
126
+ self.conv_kernel_size = conv_kernel_size
127
+ self.inv_coeff_init_option = inv_coeff_init_option
128
+ self.layer_norm_eps = layer_norm_eps
129
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
130
+
131
+
132
+ __all__ = ["NystromformerConfig"]
docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,111 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Convert Nystromformer checkpoints from the original repository."""
17
+
18
+ import argparse
19
+
20
+ import torch
21
+
22
+ from transformers import NystromformerConfig, NystromformerForMaskedLM
23
+
24
+
25
+ def rename_key(orig_key):
26
+ if "model" in orig_key:
27
+ orig_key = orig_key.replace("model.", "")
28
+ if "norm1" in orig_key:
29
+ orig_key = orig_key.replace("norm1", "attention.output.LayerNorm")
30
+ if "norm2" in orig_key:
31
+ orig_key = orig_key.replace("norm2", "output.LayerNorm")
32
+ if "norm" in orig_key:
33
+ orig_key = orig_key.replace("norm", "LayerNorm")
34
+ if "transformer" in orig_key:
35
+ layer_num = orig_key.split(".")[0].split("_")[-1]
36
+ orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}")
37
+ if "mha.attn" in orig_key:
38
+ orig_key = orig_key.replace("mha.attn", "attention.self")
39
+ if "mha" in orig_key:
40
+ orig_key = orig_key.replace("mha", "attention")
41
+ if "W_q" in orig_key:
42
+ orig_key = orig_key.replace("W_q", "self.query")
43
+ if "W_k" in orig_key:
44
+ orig_key = orig_key.replace("W_k", "self.key")
45
+ if "W_v" in orig_key:
46
+ orig_key = orig_key.replace("W_v", "self.value")
47
+ if "ff1" in orig_key:
48
+ orig_key = orig_key.replace("ff1", "intermediate.dense")
49
+ if "ff2" in orig_key:
50
+ orig_key = orig_key.replace("ff2", "output.dense")
51
+ if "ff" in orig_key:
52
+ orig_key = orig_key.replace("ff", "output.dense")
53
+ if "mlm_class" in orig_key:
54
+ orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder")
55
+ if "mlm" in orig_key:
56
+ orig_key = orig_key.replace("mlm", "cls.predictions.transform")
57
+ if "cls" not in orig_key:
58
+ orig_key = "nystromformer." + orig_key
59
+
60
+ return orig_key
61
+
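+ # For example, a hypothetical original checkpoint key is mapped by the rules above as:
+ # rename_key("model.transformer_0.mha.W_q.weight")
+ # -> "nystromformer.encoder.layer.0.attention.self.query.weight"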
62
+
63
+ def convert_checkpoint_helper(config, orig_state_dict):
64
+ for key in orig_state_dict.copy().keys():
65
+ val = orig_state_dict.pop(key)
66
+
67
+ if ("pooler" in key) or ("sen_class" in key) or ("conv.bias" in key):
68
+ continue
69
+ else:
70
+ orig_state_dict[rename_key(key)] = val
71
+
72
+ orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"]
73
+ orig_state_dict["nystromformer.embeddings.position_ids"] = (
74
+ torch.arange(config.max_position_embeddings).expand((1, -1)) + 2
75
+ )
76
+
77
+ return orig_state_dict
78
+
79
+
80
+ def convert_nystromformer_checkpoint(checkpoint_path, nystromformer_config_file, pytorch_dump_path):
81
+ orig_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model_state_dict"]
82
+ config = NystromformerConfig.from_json_file(nystromformer_config_file)
83
+ model = NystromformerForMaskedLM(config)
84
+
85
+ new_state_dict = convert_checkpoint_helper(config, orig_state_dict)
86
+
87
+ model.load_state_dict(new_state_dict)
88
+ model.eval()
89
+ model.save_pretrained(pytorch_dump_path)
90
+
91
+ print(f"Checkpoint successfuly converted. Model saved at {pytorch_dump_path}")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ parser = argparse.ArgumentParser()
96
+ # Required parameters
97
+ parser.add_argument(
98
+ "--pytorch_model_path", default=None, type=str, required=True, help="Path to Nystromformer pytorch checkpoint."
99
+ )
100
+ parser.add_argument(
101
+ "--config_file",
102
+ default=None,
103
+ type=str,
104
+ required=True,
105
+ help="The json file for Nystromformer model config.",
106
+ )
107
+ parser.add_argument(
108
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
109
+ )
110
+ args = parser.parse_args()
111
+ convert_nystromformer_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path)
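+ # Example invocation (all paths are placeholders):
+ # python convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py \
+ #     --pytorch_model_path ./nystromformer.ckpt \
+ #     --config_file ./config.json \
+ #     --pytorch_dump_path ./hf-nystromformer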
docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py ADDED
@@ -0,0 +1,1124 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 UW-Madison The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Nystromformer model."""
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import (
27
+ BaseModelOutputWithPastAndCrossAttentions,
28
+ MaskedLMOutput,
29
+ MultipleChoiceModelOutput,
30
+ QuestionAnsweringModelOutput,
31
+ SequenceClassifierOutput,
32
+ TokenClassifierOutput,
33
+ )
34
+ from ...modeling_utils import PreTrainedModel
35
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
37
+ from .configuration_nystromformer import NystromformerConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "uw-madison/nystromformer-512"
43
+ _CONFIG_FOR_DOC = "NystromformerConfig"
44
+
45
+
46
+ class NystromformerEmbeddings(nn.Module):
47
+ """Construct the embeddings from word, position and token_type embeddings."""
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
52
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size)
53
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
54
+
55
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
56
+ # any TensorFlow checkpoint file
57
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
58
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
59
+
60
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
61
+ self.register_buffer(
62
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False
63
+ )
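+ # positions start at index 2 (an offset past the padding index, RoBERTa-style), which is
+ # why `position_embeddings` above is sized `config.max_position_embeddings + 2`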
64
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
65
+ self.register_buffer(
66
+ "token_type_ids",
67
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
68
+ persistent=False,
69
+ )
70
+
71
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
72
+ if input_ids is not None:
73
+ input_shape = input_ids.size()
74
+ else:
75
+ input_shape = inputs_embeds.size()[:-1]
76
+
77
+ seq_length = input_shape[1]
78
+
79
+ if position_ids is None:
80
+ position_ids = self.position_ids[:, :seq_length]
81
+
82
+ # Set token_type_ids to the registered buffer from the constructor, where it is all zeros; this usually occurs
84
+ # when it is auto-generated. The registered buffer helps users trace the model without passing token_type_ids and
85
+ # solves issue #5664
85
+ if token_type_ids is None:
86
+ if hasattr(self, "token_type_ids"):
87
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
88
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
89
+ token_type_ids = buffered_token_type_ids_expanded
90
+ else:
91
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
92
+
93
+ if inputs_embeds is None:
94
+ inputs_embeds = self.word_embeddings(input_ids)
95
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
96
+
97
+ embeddings = inputs_embeds + token_type_embeddings
98
+ if self.position_embedding_type == "absolute":
99
+ position_embeddings = self.position_embeddings(position_ids)
100
+ embeddings += position_embeddings
101
+ embeddings = self.LayerNorm(embeddings)
102
+ embeddings = self.dropout(embeddings)
103
+ return embeddings
104
+
105
+
106
+ class NystromformerSelfAttention(nn.Module):
107
+ def __init__(self, config, position_embedding_type=None):
108
+ super().__init__()
109
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
110
+ raise ValueError(
111
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
112
+ f"heads ({config.num_attention_heads})"
113
+ )
114
+
115
+ self.num_attention_heads = config.num_attention_heads
116
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
117
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
118
+
119
+ self.num_landmarks = config.num_landmarks
120
+ self.seq_len = config.segment_means_seq_len
121
+ self.conv_kernel_size = config.conv_kernel_size
122
+
123
+ if config.inv_coeff_init_option:
124
+ self.init_option = "exact"  # any value other than "original" selects the exact Z_0 coefficient in iterative_inv
125
+ else:
126
+ self.init_option = "original"
127
+
128
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
129
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
130
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
131
+
132
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
133
+ self.position_embedding_type = position_embedding_type or getattr(
134
+ config, "position_embedding_type", "absolute"
135
+ )
136
+
137
+ if self.conv_kernel_size is not None:
138
+ self.conv = nn.Conv2d(
139
+ in_channels=self.num_attention_heads,
140
+ out_channels=self.num_attention_heads,
141
+ kernel_size=(self.conv_kernel_size, 1),
142
+ padding=(self.conv_kernel_size // 2, 0),
143
+ bias=False,
144
+ groups=self.num_attention_heads,
145
+ )
146
+
147
+ # Function to approximate Moore-Penrose inverse via the iterative method
148
+ def iterative_inv(self, mat, n_iter=6):
149
+ identity = torch.eye(mat.size(-1), device=mat.device)
150
+ key = mat
151
+
152
+ # The entries of key are positive and ||key||_{\infty} = 1 due to softmax
153
+ if self.init_option == "original":
154
+ # This original implementation is more conservative to compute coefficient of Z_0.
155
+ value = 1 / torch.max(torch.sum(key, dim=-2)) * key.transpose(-1, -2)
156
+ else:
157
+ # This is the exact coefficient computation, 1 / ||key||_1, for the initialization of Z_0, leading to faster convergence.
158
+ value = 1 / torch.max(torch.sum(key, dim=-2), dim=-1).values[:, :, None, None] * key.transpose(-1, -2)
159
+
160
+ for _ in range(n_iter):
161
+ key_value = torch.matmul(key, value)
162
+ value = torch.matmul(
163
+ 0.25 * value,
164
+ 13 * identity
165
+ - torch.matmul(key_value, 15 * identity - torch.matmul(key_value, 7 * identity - key_value)),
166
+ )
167
+ return value
168
+
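+ # The loop above is the cubic Newton-Schulz-style iteration used by the Nystromformer authors,
+ # Z_{k+1} = 0.25 * Z_k @ (13*I - (K @ Z_k) @ (15*I - (K @ Z_k) @ (7*I - K @ Z_k))),
+ # which converges to the Moore-Penrose pseudoinverse of `mat` for a suitable initial Z_0.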
169
+ def transpose_for_scores(self, layer):
170
+ new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
171
+ layer = layer.view(*new_layer_shape)
172
+ return layer.permute(0, 2, 1, 3)
173
+
174
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
175
+ mixed_query_layer = self.query(hidden_states)
176
+
177
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
178
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
179
+ query_layer = self.transpose_for_scores(mixed_query_layer)
180
+
181
+ query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size))
182
+ key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size))
183
+
184
+ if self.num_landmarks == self.seq_len:
185
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
186
+
187
+ if attention_mask is not None:
188
+ # Apply the attention mask (precomputed for all layers in the NystromformerModel forward() function)
189
+ attention_scores = attention_scores + attention_mask
190
+
191
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
192
+ context_layer = torch.matmul(attention_probs, value_layer)
193
+
194
+ else:
195
+ q_landmarks = query_layer.reshape(
196
+ -1,
197
+ self.num_attention_heads,
198
+ self.num_landmarks,
199
+ self.seq_len // self.num_landmarks,
200
+ self.attention_head_size,
201
+ ).mean(dim=-2)
202
+ k_landmarks = key_layer.reshape(
203
+ -1,
204
+ self.num_attention_heads,
205
+ self.num_landmarks,
206
+ self.seq_len // self.num_landmarks,
207
+ self.attention_head_size,
208
+ ).mean(dim=-2)
209
+
210
+ kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1)
211
+ kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1)
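+ # Nystrom factorization: softmax(Q K^T) is approximated as kernel_1 @ pinv(kernel_2) @ kernel_3,
+ # where kernel_1 = softmax(Q K_landmarks^T), kernel_2 = softmax(Q_landmarks K_landmarks^T), and
+ # kernel_3 = softmax(Q_landmarks K^T) is computed just below from `attention_scores`.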
212
+
213
+ attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2))
214
+
215
+ if attention_mask is not None:
216
+ # Apply the attention mask (precomputed for all layers in the NystromformerModel forward() function)
217
+ attention_scores = attention_scores + attention_mask
218
+
219
+ kernel_3 = nn.functional.softmax(attention_scores, dim=-1)
220
+ attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2))
221
+ new_value_layer = torch.matmul(kernel_3, value_layer)
222
+ context_layer = torch.matmul(attention_probs, new_value_layer)
223
+
224
+ if self.conv_kernel_size is not None:
225
+ context_layer += self.conv(value_layer)
226
+
227
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
228
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
229
+ context_layer = context_layer.view(*new_context_layer_shape)
230
+
231
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
232
+
233
+ return outputs
234
+
235
+
236
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
237
+ class NystromformerSelfOutput(nn.Module):
238
+ def __init__(self, config):
239
+ super().__init__()
240
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
241
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
242
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
243
+
244
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
245
+ hidden_states = self.dense(hidden_states)
246
+ hidden_states = self.dropout(hidden_states)
247
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
248
+ return hidden_states
249
+
250
+
251
+ class NystromformerAttention(nn.Module):
252
+ def __init__(self, config, position_embedding_type=None):
253
+ super().__init__()
254
+ self.self = NystromformerSelfAttention(config, position_embedding_type=position_embedding_type)
255
+ self.output = NystromformerSelfOutput(config)
256
+ self.pruned_heads = set()
257
+
258
+ def prune_heads(self, heads):
259
+ if len(heads) == 0:
260
+ return
261
+ heads, index = find_pruneable_heads_and_indices(
262
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
263
+ )
264
+
265
+ # Prune linear layers
266
+ self.self.query = prune_linear_layer(self.self.query, index)
267
+ self.self.key = prune_linear_layer(self.self.key, index)
268
+ self.self.value = prune_linear_layer(self.self.value, index)
269
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
270
+
271
+ # Update hyper params and store pruned heads
272
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
273
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
274
+ self.pruned_heads = self.pruned_heads.union(heads)
275
+
276
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
277
+ self_outputs = self.self(hidden_states, attention_mask, output_attentions)
278
+ attention_output = self.output(self_outputs[0], hidden_states)
279
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
+ return outputs
281
+
282
+
283
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Nystromformer
284
+ class NystromformerIntermediate(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
288
+ if isinstance(config.hidden_act, str):
289
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
290
+ else:
291
+ self.intermediate_act_fn = config.hidden_act
292
+
293
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
294
+ hidden_states = self.dense(hidden_states)
295
+ hidden_states = self.intermediate_act_fn(hidden_states)
296
+ return hidden_states
297
+
298
+
299
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Nystromformer
300
+ class NystromformerOutput(nn.Module):
301
+ def __init__(self, config):
302
+ super().__init__()
303
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
304
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
305
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
306
+
307
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
308
+ hidden_states = self.dense(hidden_states)
309
+ hidden_states = self.dropout(hidden_states)
310
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
311
+ return hidden_states
312
+
313
+
314
+ class NystromformerLayer(nn.Module):
315
+ def __init__(self, config):
316
+ super().__init__()
317
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
318
+ self.seq_len_dim = 1
319
+ self.attention = NystromformerAttention(config)
320
+ self.add_cross_attention = config.add_cross_attention
321
+ self.intermediate = NystromformerIntermediate(config)
322
+ self.output = NystromformerOutput(config)
323
+
324
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
325
+ self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
326
+ attention_output = self_attention_outputs[0]
327
+
328
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
329
+
330
+ layer_output = apply_chunking_to_forward(
331
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
332
+ )
333
+ outputs = (layer_output,) + outputs
334
+
335
+ return outputs
336
+
337
+ def feed_forward_chunk(self, attention_output):
338
+ intermediate_output = self.intermediate(attention_output)
339
+ layer_output = self.output(intermediate_output, attention_output)
340
+ return layer_output
341
+
342
+
343
+ class NystromformerEncoder(nn.Module):
344
+ def __init__(self, config):
345
+ super().__init__()
346
+ self.config = config
347
+ self.layer = nn.ModuleList([NystromformerLayer(config) for _ in range(config.num_hidden_layers)])
348
+ self.gradient_checkpointing = False
349
+
350
+ def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ head_mask: Optional[torch.Tensor] = None,
355
+ output_attentions: bool = False,
356
+ output_hidden_states: bool = False,
357
+ return_dict: bool = True,
358
+ ):
359
+ all_hidden_states = () if output_hidden_states else None
360
+ all_self_attentions = () if output_attentions else None
361
+
362
+ for i, layer_module in enumerate(self.layer):
363
+ if output_hidden_states:
364
+ all_hidden_states = all_hidden_states + (hidden_states,)
365
+
366
+ if self.gradient_checkpointing and self.training:
367
+ layer_outputs = self._gradient_checkpointing_func(
368
+ layer_module.__call__,
369
+ hidden_states,
370
+ attention_mask,
371
+ output_attentions,
372
+ )
373
+ else:
374
+ layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
375
+
376
+ hidden_states = layer_outputs[0]
377
+ if output_attentions:
378
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
379
+
380
+ if output_hidden_states:
381
+ all_hidden_states = all_hidden_states + (hidden_states,)
382
+
383
+ if not return_dict:
384
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
385
+ return BaseModelOutputWithPastAndCrossAttentions(
386
+ last_hidden_state=hidden_states,
387
+ hidden_states=all_hidden_states,
388
+ attentions=all_self_attentions,
389
+ )
390
+
391
+
392
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Nystromformer
393
+ class NystromformerPredictionHeadTransform(nn.Module):
394
+ def __init__(self, config):
395
+ super().__init__()
396
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
397
+ if isinstance(config.hidden_act, str):
398
+ self.transform_act_fn = ACT2FN[config.hidden_act]
399
+ else:
400
+ self.transform_act_fn = config.hidden_act
401
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
402
+
403
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
404
+ hidden_states = self.dense(hidden_states)
405
+ hidden_states = self.transform_act_fn(hidden_states)
406
+ hidden_states = self.LayerNorm(hidden_states)
407
+ return hidden_states
408
+
409
+
410
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Nystromformer
411
+ class NystromformerLMPredictionHead(nn.Module):
412
+ def __init__(self, config):
413
+ super().__init__()
414
+ self.transform = NystromformerPredictionHeadTransform(config)
415
+
416
+ # The output weights are the same as the input embeddings, but there is
417
+ # an output-only bias for each token.
418
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
419
+
420
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
421
+
422
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
423
+ self.decoder.bias = self.bias
424
+
425
+ def _tie_weights(self):
426
+ self.decoder.bias = self.bias
427
+
428
+ def forward(self, hidden_states):
429
+ hidden_states = self.transform(hidden_states)
430
+ hidden_states = self.decoder(hidden_states)
431
+ return hidden_states
432
+
433
+
434
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Nystromformer
435
+ class NystromformerOnlyMLMHead(nn.Module):
436
+ def __init__(self, config):
437
+ super().__init__()
438
+ self.predictions = NystromformerLMPredictionHead(config)
439
+
440
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
441
+ prediction_scores = self.predictions(sequence_output)
442
+ return prediction_scores
443
+
444
+
445
+ class NystromformerPreTrainedModel(PreTrainedModel):
446
+ """
447
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
448
+ models.
449
+ """
450
+
451
+ config_class = NystromformerConfig
452
+ base_model_prefix = "nystromformer"
453
+ supports_gradient_checkpointing = True
454
+
455
+ def _init_weights(self, module):
456
+ """Initialize the weights"""
457
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
458
+ # Slightly different from the TF version which uses truncated_normal for initialization
459
+ # cf https://github.com/pytorch/pytorch/pull/5617
460
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
461
+ if module.bias is not None:
462
+ module.bias.data.zero_()
463
+ elif isinstance(module, nn.Embedding):
464
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
465
+ if module.padding_idx is not None:
466
+ module.weight.data[module.padding_idx].zero_()
467
+ elif isinstance(module, nn.LayerNorm):
468
+ module.bias.data.zero_()
469
+ module.weight.data.fill_(1.0)
470
+
471
+
472
+ NYSTROMFORMER_START_DOCSTRING = r"""
473
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
474
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
475
+ behavior.
476
+
477
+ Parameters:
478
+ config ([`NystromformerConfig`]): Model configuration class with all the parameters of the model.
479
+ Initializing with a config file does not load the weights associated with the model, only the
480
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
481
+ """
482
+
483
+ NYSTROMFORMER_INPUTS_DOCSTRING = r"""
484
+ Args:
485
+ input_ids (`torch.LongTensor` of shape `({0})`):
486
+ Indices of input sequence tokens in the vocabulary.
487
+
488
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
489
+ [`PreTrainedTokenizer.__call__`] for details.
490
+
491
+ [What are input IDs?](../glossary#input-ids)
492
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
493
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
494
+
495
+ - 1 for tokens that are **not masked**,
496
+ - 0 for tokens that are **masked**.
497
+
498
+ [What are attention masks?](../glossary#attention-mask)
499
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
500
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
501
+ 1]`:
502
+
503
+ - 0 corresponds to a *sentence A* token,
504
+ - 1 corresponds to a *sentence B* token.
505
+
506
+ [What are token type IDs?](../glossary#token-type-ids)
507
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
508
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
509
+ config.max_position_embeddings - 1]`.
510
+
511
+ [What are position IDs?](../glossary#position-ids)
512
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
513
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
514
+
515
+ - 1 indicates the head is **not masked**,
516
+ - 0 indicates the head is **masked**.
517
+
518
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
519
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
520
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
521
+ model's internal embedding lookup matrix.
522
+ output_attentions (`bool`, *optional*):
523
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
524
+ tensors for more detail.
525
+ output_hidden_states (`bool`, *optional*):
526
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
527
+ more detail.
528
+ return_dict (`bool`, *optional*):
529
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
530
+ """
531
+
532
+
533
+ @add_start_docstrings(
534
+ "The bare Nyströmformer Model transformer outputting raw hidden-states without any specific head on top.",
535
+ NYSTROMFORMER_START_DOCSTRING,
536
+ )
537
+ class NystromformerModel(NystromformerPreTrainedModel):
538
+ def __init__(self, config):
539
+ super().__init__(config)
540
+ self.config = config
541
+
542
+ self.embeddings = NystromformerEmbeddings(config)
543
+ self.encoder = NystromformerEncoder(config)
544
+
545
+ # Initialize weights and apply final processing
546
+ self.post_init()
547
+
548
+ def get_input_embeddings(self):
549
+ return self.embeddings.word_embeddings
550
+
551
+ def set_input_embeddings(self, value):
552
+ self.embeddings.word_embeddings = value
553
+
554
+ def _prune_heads(self, heads_to_prune):
555
+ """
556
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
557
+ class PreTrainedModel
558
+ """
559
+ for layer, heads in heads_to_prune.items():
560
+ self.encoder.layer[layer].attention.prune_heads(heads)
561
+
562
+ @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
563
+ @add_code_sample_docstrings(
564
+ checkpoint=_CHECKPOINT_FOR_DOC,
565
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
566
+ config_class=_CONFIG_FOR_DOC,
567
+ )
568
+ def forward(
569
+ self,
570
+ input_ids: Optional[torch.LongTensor] = None,
571
+ attention_mask: Optional[torch.FloatTensor] = None,
572
+ token_type_ids: Optional[torch.LongTensor] = None,
573
+ position_ids: Optional[torch.LongTensor] = None,
574
+ head_mask: Optional[torch.FloatTensor] = None,
575
+ inputs_embeds: Optional[torch.FloatTensor] = None,
576
+ output_attentions: Optional[bool] = None,
577
+ output_hidden_states: Optional[bool] = None,
578
+ return_dict: Optional[bool] = None,
579
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
580
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
581
+ output_hidden_states = (
582
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
583
+ )
584
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
585
+
586
+ if input_ids is not None and inputs_embeds is not None:
587
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
588
+ elif input_ids is not None:
589
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
590
+ input_shape = input_ids.size()
591
+ elif inputs_embeds is not None:
592
+ input_shape = inputs_embeds.size()[:-1]
593
+ else:
594
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
595
+
596
+ batch_size, seq_length = input_shape
597
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
598
+
599
+ if attention_mask is None:
600
+ attention_mask = torch.ones(((batch_size, seq_length)), device=device)
601
+
602
+ if token_type_ids is None:
603
+ if hasattr(self.embeddings, "token_type_ids"):
604
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
605
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
606
+ token_type_ids = buffered_token_type_ids_expanded
607
+ else:
608
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
609
+
610
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
611
+ # ourselves in which case we just need to make it broadcastable to all heads.
612
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
613
+
614
+ # Prepare head mask if needed
615
+ # 1.0 in head_mask indicate we keep the head
616
+ # attention_probs has shape bsz x n_heads x N x N
617
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
618
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
619
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
620
+
621
+ embedding_output = self.embeddings(
622
+ input_ids=input_ids,
623
+ position_ids=position_ids,
624
+ token_type_ids=token_type_ids,
625
+ inputs_embeds=inputs_embeds,
626
+ )
627
+ encoder_outputs = self.encoder(
628
+ embedding_output,
629
+ attention_mask=extended_attention_mask,
630
+ head_mask=head_mask,
631
+ output_attentions=output_attentions,
632
+ output_hidden_states=output_hidden_states,
633
+ return_dict=return_dict,
634
+ )
635
+ sequence_output = encoder_outputs[0]
636
+
637
+ if not return_dict:
638
+ return (sequence_output,) + encoder_outputs[1:]
639
+
640
+ return BaseModelOutputWithPastAndCrossAttentions(
641
+ last_hidden_state=sequence_output,
642
+ hidden_states=encoder_outputs.hidden_states,
643
+ attentions=encoder_outputs.attentions,
644
+ cross_attentions=encoder_outputs.cross_attentions,
645
+ )
646
+
647
+
648
+ @add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
649
+ class NystromformerForMaskedLM(NystromformerPreTrainedModel):
650
+ _tied_weights_keys = ["cls.predictions.decoder"]
651
+
652
+ def __init__(self, config):
653
+ super().__init__(config)
654
+
655
+ self.nystromformer = NystromformerModel(config)
656
+ self.cls = NystromformerOnlyMLMHead(config)
657
+
658
+ # Initialize weights and apply final processing
659
+ self.post_init()
660
+
661
+ def get_output_embeddings(self):
662
+ return self.cls.predictions.decoder
663
+
664
+ def set_output_embeddings(self, new_embeddings):
665
+ self.cls.predictions.decoder = new_embeddings
666
+ self.cls.predictions.bias = new_embeddings.bias
667
+
668
+ @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
669
+ @add_code_sample_docstrings(
670
+ checkpoint=_CHECKPOINT_FOR_DOC,
671
+ output_type=MaskedLMOutput,
672
+ config_class=_CONFIG_FOR_DOC,
673
+ )
674
+ def forward(
675
+ self,
676
+ input_ids: Optional[torch.LongTensor] = None,
677
+ attention_mask: Optional[torch.FloatTensor] = None,
678
+ token_type_ids: Optional[torch.LongTensor] = None,
679
+ position_ids: Optional[torch.LongTensor] = None,
680
+ head_mask: Optional[torch.FloatTensor] = None,
681
+ inputs_embeds: Optional[torch.FloatTensor] = None,
682
+ labels: Optional[torch.LongTensor] = None,
683
+ output_attentions: Optional[bool] = None,
684
+ output_hidden_states: Optional[bool] = None,
685
+ return_dict: Optional[bool] = None,
686
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
687
+ r"""
688
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
689
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
690
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
691
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
692
+ """
693
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
694
+
695
+ outputs = self.nystromformer(
696
+ input_ids,
697
+ attention_mask=attention_mask,
698
+ token_type_ids=token_type_ids,
699
+ position_ids=position_ids,
700
+ head_mask=head_mask,
701
+ inputs_embeds=inputs_embeds,
702
+ output_attentions=output_attentions,
703
+ output_hidden_states=output_hidden_states,
704
+ return_dict=return_dict,
705
+ )
706
+
707
+ sequence_output = outputs[0]
708
+ prediction_scores = self.cls(sequence_output)
709
+
710
+ masked_lm_loss = None
711
+ if labels is not None:
712
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
713
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
714
+
715
+ if not return_dict:
716
+ output = (prediction_scores,) + outputs[1:]
717
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
718
+
719
+ return MaskedLMOutput(
720
+ loss=masked_lm_loss,
721
+ logits=prediction_scores,
722
+ hidden_states=outputs.hidden_states,
723
+ attentions=outputs.attentions,
724
+ )
725
+
726
+
727
+ class NystromformerClassificationHead(nn.Module):
728
+ """Head for sentence-level classification tasks."""
729
+
730
+ def __init__(self, config):
731
+ super().__init__()
732
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
733
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
734
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
735
+
736
+ self.config = config
737
+
738
+ def forward(self, features, **kwargs):
739
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
740
+ x = self.dropout(x)
741
+ x = self.dense(x)
742
+ x = ACT2FN[self.config.hidden_act](x)
743
+ x = self.dropout(x)
744
+ x = self.out_proj(x)
745
+ return x
746
+
747
+
748
+ @add_start_docstrings(
749
+ """
750
+ Nyströmformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
751
+ pooled output) e.g. for GLUE tasks.
752
+ """,
753
+ NYSTROMFORMER_START_DOCSTRING,
754
+ )
755
+ class NystromformerForSequenceClassification(NystromformerPreTrainedModel):
756
+ def __init__(self, config):
757
+ super().__init__(config)
758
+ self.num_labels = config.num_labels
759
+ self.nystromformer = NystromformerModel(config)
760
+ self.classifier = NystromformerClassificationHead(config)
761
+
762
+ # Initialize weights and apply final processing
763
+ self.post_init()
764
+
765
+ @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
766
+ @add_code_sample_docstrings(
767
+ checkpoint=_CHECKPOINT_FOR_DOC,
768
+ output_type=SequenceClassifierOutput,
769
+ config_class=_CONFIG_FOR_DOC,
770
+ )
771
+ def forward(
772
+ self,
773
+ input_ids: Optional[torch.LongTensor] = None,
774
+ attention_mask: Optional[torch.FloatTensor] = None,
775
+ token_type_ids: Optional[torch.LongTensor] = None,
776
+ position_ids: Optional[torch.LongTensor] = None,
777
+ head_mask: Optional[torch.FloatTensor] = None,
778
+ inputs_embeds: Optional[torch.FloatTensor] = None,
779
+ labels: Optional[torch.LongTensor] = None,
780
+ output_attentions: Optional[bool] = None,
781
+ output_hidden_states: Optional[bool] = None,
782
+ return_dict: Optional[bool] = None,
783
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
784
+ r"""
785
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
786
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
787
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
788
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
789
+ """
790
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
791
+
792
+ outputs = self.nystromformer(
793
+ input_ids,
794
+ attention_mask=attention_mask,
795
+ token_type_ids=token_type_ids,
796
+ position_ids=position_ids,
797
+ head_mask=head_mask,
798
+ inputs_embeds=inputs_embeds,
799
+ output_attentions=output_attentions,
800
+ output_hidden_states=output_hidden_states,
801
+ return_dict=return_dict,
802
+ )
803
+
804
+ sequence_output = outputs[0]
805
+ logits = self.classifier(sequence_output)
806
+
807
+ loss = None
808
+ if labels is not None:
809
+ if self.config.problem_type is None:
810
+ if self.num_labels == 1:
811
+ self.config.problem_type = "regression"
812
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
813
+ self.config.problem_type = "single_label_classification"
814
+ else:
815
+ self.config.problem_type = "multi_label_classification"
816
+
817
+ if self.config.problem_type == "regression":
818
+ loss_fct = MSELoss()
819
+ if self.num_labels == 1:
820
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
821
+ else:
822
+ loss = loss_fct(logits, labels)
823
+ elif self.config.problem_type == "single_label_classification":
824
+ loss_fct = CrossEntropyLoss()
825
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
826
+ elif self.config.problem_type == "multi_label_classification":
827
+ loss_fct = BCEWithLogitsLoss()
828
+ loss = loss_fct(logits, labels)
829
+ if not return_dict:
830
+ output = (logits,) + outputs[1:]
831
+ return ((loss,) + output) if loss is not None else output
832
+
833
+ return SequenceClassifierOutput(
834
+ loss=loss,
835
+ logits=logits,
836
+ hidden_states=outputs.hidden_states,
837
+ attentions=outputs.attentions,
838
+ )
839
+
840
+
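+ # Sketch of the `problem_type` dispatch above (shapes and names assumed for the
+ # example): `num_labels == 1` selects MSE regression; integer labels of shape
+ # (batch_size,) with `num_labels > 1` select cross-entropy; float multi-hot labels
+ # of shape (batch_size, num_labels) select BCEWithLogitsLoss. Setting the field
+ # explicitly skips the inference:
+ # >>> model.config.problem_type = "multi_label_classification"
+ # >>> outputs = model(**inputs, labels=multi_hot_labels.float())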
841
+ @add_start_docstrings(
842
+ """
843
+ Nyströmformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output
844
+ and a softmax) e.g. for RocStories/SWAG tasks.
845
+ """,
846
+ NYSTROMFORMER_START_DOCSTRING,
847
+ )
848
+ class NystromformerForMultipleChoice(NystromformerPreTrainedModel):
849
+ def __init__(self, config):
850
+ super().__init__(config)
851
+
852
+ self.nystromformer = NystromformerModel(config)
853
+ self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
854
+ self.classifier = nn.Linear(config.hidden_size, 1)
855
+
856
+ # Initialize weights and apply final processing
857
+ self.post_init()
858
+
859
+ @add_start_docstrings_to_model_forward(
860
+ NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
861
+ )
862
+ @add_code_sample_docstrings(
863
+ checkpoint=_CHECKPOINT_FOR_DOC,
864
+ output_type=MultipleChoiceModelOutput,
865
+ config_class=_CONFIG_FOR_DOC,
866
+ )
867
+ def forward(
868
+ self,
869
+ input_ids: Optional[torch.LongTensor] = None,
870
+ attention_mask: Optional[torch.FloatTensor] = None,
871
+ token_type_ids: Optional[torch.LongTensor] = None,
872
+ position_ids: Optional[torch.LongTensor] = None,
873
+ head_mask: Optional[torch.FloatTensor] = None,
874
+ inputs_embeds: Optional[torch.FloatTensor] = None,
875
+ labels: Optional[torch.LongTensor] = None,
876
+ output_attentions: Optional[bool] = None,
877
+ output_hidden_states: Optional[bool] = None,
878
+ return_dict: Optional[bool] = None,
879
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
880
+ r"""
881
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
882
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
883
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
884
+ `input_ids` above)
885
+ """
886
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
887
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
888
+
889
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
890
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
891
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
892
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
893
+ inputs_embeds = (
894
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
895
+ if inputs_embeds is not None
896
+ else None
897
+ )
898
+
899
+ outputs = self.nystromformer(
900
+ input_ids,
901
+ attention_mask=attention_mask,
902
+ token_type_ids=token_type_ids,
903
+ position_ids=position_ids,
904
+ head_mask=head_mask,
905
+ inputs_embeds=inputs_embeds,
906
+ output_attentions=output_attentions,
907
+ output_hidden_states=output_hidden_states,
908
+ return_dict=return_dict,
909
+ )
910
+
911
+ hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
912
+ pooled_output = hidden_state[:, 0] # (bs * num_choices, dim)
913
+ pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim)
914
+ pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim)
915
+ logits = self.classifier(pooled_output)
916
+
917
+ reshaped_logits = logits.view(-1, num_choices)
918
+
919
+ loss = None
920
+ if labels is not None:
921
+ loss_fct = CrossEntropyLoss()
922
+ loss = loss_fct(reshaped_logits, labels)
923
+
924
+ if not return_dict:
925
+ output = (reshaped_logits,) + outputs[1:]
926
+ return ((loss,) + output) if loss is not None else output
927
+
928
+ return MultipleChoiceModelOutput(
929
+ loss=loss,
930
+ logits=reshaped_logits,
931
+ hidden_states=outputs.hidden_states,
932
+ attentions=outputs.attentions,
933
+ )
934
+
935
+
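+ # Shape sketch for the flattening above (numbers assumed): with 2 questions and 4
+ # candidate answers each, `input_ids` is (2, 4, seq_len); the forward views it as
+ # (8, seq_len), scores every (question, choice) pair with the single-output
+ # classifier, and reshapes the logits back to (2, 4) for the cross-entropy over choices.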
936
+ @add_start_docstrings(
937
+ """
938
+ Nyströmformer Model with a token classification head on top (a linear layer on top of the hidden-states output)
939
+ e.g. for Named-Entity-Recognition (NER) tasks.
940
+ """,
941
+ NYSTROMFORMER_START_DOCSTRING,
942
+ )
943
+ class NystromformerForTokenClassification(NystromformerPreTrainedModel):
944
+ def __init__(self, config):
945
+ super().__init__(config)
946
+ self.num_labels = config.num_labels
947
+
948
+ self.nystromformer = NystromformerModel(config)
949
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
950
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
951
+
952
+ # Initialize weights and apply final processing
953
+ self.post_init()
954
+
955
+ @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
956
+ @add_code_sample_docstrings(
957
+ checkpoint=_CHECKPOINT_FOR_DOC,
958
+ output_type=TokenClassifierOutput,
959
+ config_class=_CONFIG_FOR_DOC,
960
+ )
961
+ def forward(
962
+ self,
963
+ input_ids: Optional[torch.LongTensor] = None,
964
+ attention_mask: Optional[torch.FloatTensor] = None,
965
+ token_type_ids: Optional[torch.LongTensor] = None,
966
+ position_ids: Optional[torch.LongTensor] = None,
967
+ head_mask: Optional[torch.FloatTensor] = None,
968
+ inputs_embeds: Optional[torch.FloatTensor] = None,
969
+ labels: Optional[torch.LongTensor] = None,
970
+ output_attentions: Optional[bool] = None,
971
+ output_hidden_states: Optional[bool] = None,
972
+ return_dict: Optional[bool] = None,
973
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
974
+ r"""
975
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
976
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
977
+ """
978
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
979
+
980
+ outputs = self.nystromformer(
981
+ input_ids,
982
+ attention_mask=attention_mask,
983
+ token_type_ids=token_type_ids,
984
+ position_ids=position_ids,
985
+ head_mask=head_mask,
986
+ inputs_embeds=inputs_embeds,
987
+ output_attentions=output_attentions,
988
+ output_hidden_states=output_hidden_states,
989
+ return_dict=return_dict,
990
+ )
991
+
992
+ sequence_output = outputs[0]
993
+
994
+ sequence_output = self.dropout(sequence_output)
995
+ logits = self.classifier(sequence_output)
996
+
997
+ loss = None
998
+ if labels is not None:
999
+ loss_fct = CrossEntropyLoss()
1000
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1001
+
1002
+ if not return_dict:
1003
+ output = (logits,) + outputs[1:]
1004
+ return ((loss,) + output) if loss is not None else output
1005
+
1006
+ return TokenClassifierOutput(
1007
+ loss=loss,
1008
+ logits=logits,
1009
+ hidden_states=outputs.hidden_states,
1010
+ attentions=outputs.attentions,
1011
+ )
1012
+
1013
+
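+ # Decoding sketch for the token head above (`outputs` assumed from a prior forward
+ # pass): logits have shape (batch_size, sequence_length, num_labels), so per-token
+ # tags are an argmax over the last dimension.
+ # >>> predictions = outputs.logits.argmax(-1)  # (batch_size, sequence_length)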
1014
+ @add_start_docstrings(
1015
+ """
1016
+ Nyströmformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
1017
+ linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1018
+ """,
1019
+ NYSTROMFORMER_START_DOCSTRING,
1020
+ )
1021
+ class NystromformerForQuestionAnswering(NystromformerPreTrainedModel):
1022
+ def __init__(self, config):
1023
+ super().__init__(config)
1024
+
1025
+ config.num_labels = 2
1026
+ self.num_labels = config.num_labels
1027
+
1028
+ self.nystromformer = NystromformerModel(config)
1029
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1030
+
1031
+ # Initialize weights and apply final processing
1032
+ self.post_init()
1033
+
1034
+ @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1035
+ @add_code_sample_docstrings(
1036
+ checkpoint=_CHECKPOINT_FOR_DOC,
1037
+ output_type=QuestionAnsweringModelOutput,
1038
+ config_class=_CONFIG_FOR_DOC,
1039
+ )
1040
+ def forward(
1041
+ self,
1042
+ input_ids: Optional[torch.LongTensor] = None,
1043
+ attention_mask: Optional[torch.FloatTensor] = None,
1044
+ token_type_ids: Optional[torch.LongTensor] = None,
1045
+ position_ids: Optional[torch.LongTensor] = None,
1046
+ head_mask: Optional[torch.FloatTensor] = None,
1047
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1048
+ start_positions: Optional[torch.LongTensor] = None,
1049
+ end_positions: Optional[torch.LongTensor] = None,
1050
+ output_attentions: Optional[bool] = None,
1051
+ output_hidden_states: Optional[bool] = None,
1052
+ return_dict: Optional[bool] = None,
1053
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1054
+ r"""
1055
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1056
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1057
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1058
+ are not taken into account for computing the loss.
1059
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1060
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1061
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1062
+ are not taken into account for computing the loss.
1063
+ """
1064
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1065
+
1066
+ outputs = self.nystromformer(
1067
+ input_ids,
1068
+ attention_mask=attention_mask,
1069
+ token_type_ids=token_type_ids,
1070
+ position_ids=position_ids,
1071
+ head_mask=head_mask,
1072
+ inputs_embeds=inputs_embeds,
1073
+ output_attentions=output_attentions,
1074
+ output_hidden_states=output_hidden_states,
1075
+ return_dict=return_dict,
1076
+ )
1077
+
1078
+ sequence_output = outputs[0]
1079
+
1080
+ logits = self.qa_outputs(sequence_output)
1081
+ start_logits, end_logits = logits.split(1, dim=-1)
1082
+ start_logits = start_logits.squeeze(-1)
1083
+ end_logits = end_logits.squeeze(-1)
1084
+
1085
+ total_loss = None
1086
+ if start_positions is not None and end_positions is not None:
1087
+ # If we are on multi-GPU, split add a dimension
1088
+ if len(start_positions.size()) > 1:
1089
+ start_positions = start_positions.squeeze(-1)
1090
+ if len(end_positions.size()) > 1:
1091
+ end_positions = end_positions.squeeze(-1)
1092
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1093
+ ignored_index = start_logits.size(1)
1094
+ start_positions = start_positions.clamp(0, ignored_index)
1095
+ end_positions = end_positions.clamp(0, ignored_index)
1096
+
1097
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1098
+ start_loss = loss_fct(start_logits, start_positions)
1099
+ end_loss = loss_fct(end_logits, end_positions)
1100
+ total_loss = (start_loss + end_loss) / 2
1101
+
1102
+ if not return_dict:
1103
+ output = (start_logits, end_logits) + outputs[1:]
1104
+ return ((total_loss,) + output) if total_loss is not None else output
1105
+
1106
+ return QuestionAnsweringModelOutput(
1107
+ loss=total_loss,
1108
+ start_logits=start_logits,
1109
+ end_logits=end_logits,
1110
+ hidden_states=outputs.hidden_states,
1111
+ attentions=outputs.attentions,
1112
+ )
1113
+
1114
+
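+ # Span-decoding sketch for the QA head above (`inputs`, `outputs` and `tokenizer`
+ # assumed from a prior forward pass; real pipelines also constrain end >= start):
+ # >>> start = int(outputs.start_logits.argmax(-1)[0])
+ # >>> end = int(outputs.end_logits.argmax(-1)[0])
+ # >>> answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])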
1115
+ __all__ = [
1116
+ "NystromformerForMaskedLM",
1117
+ "NystromformerForMultipleChoice",
1118
+ "NystromformerForQuestionAnswering",
1119
+ "NystromformerForSequenceClassification",
1120
+ "NystromformerForTokenClassification",
1121
+ "NystromformerLayer",
1122
+ "NystromformerModel",
1123
+ "NystromformerPreTrainedModel",
1124
+ ]
docs/transformers/src/transformers/models/olmo/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_olmo import *
22
+ from .modeling_olmo import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
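+ # Note: with the `_LazyModule` indirection above, submodules such as `modeling_olmo`
+ # are only imported once one of their attributes is first accessed, so importing the
+ # package does not eagerly pull in the torch-heavy modeling code.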
docs/transformers/src/transformers/models/olmo/configuration_olmo.py ADDED
@@ -0,0 +1,198 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """OLMo model configuration"""
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class OlmoConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`OlmoModel`]. It is used to instantiate an OLMo
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [allenai/OLMo-7B-hf](https://huggingface.co/allenai/OLMo-7B-hf).
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50304):
41
+ Vocabulary size of the OLMo model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`OlmoModel`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 11008):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer decoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer decoder.
51
+ num_key_value_heads (`int`, *optional*):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
58
+ `num_attention_heads`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
62
+ The maximum sequence length that this model might ever be used with.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ pad_token_id (`int`, *optional*, defaults to 1):
69
+ Padding token id.
70
+ bos_token_id (`int`, *optional*):
71
+ Beginning of stream token id.
72
+ eos_token_id (`int`, *optional*, defaults to 50279):
73
+ End of stream token id.
74
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
75
+ Whether to tie weight embeddings
76
+ rope_theta (`float`, *optional*, defaults to 10000.0):
77
+ The base period of the RoPE embeddings.
78
+ rope_scaling (`Dict`, *optional*):
79
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
80
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
81
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
82
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
83
+ these scaling strategies behave:
84
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
85
+ experimental feature, subject to breaking API changes in future versions.
86
+ attention_bias (`bool`, *optional*, defaults to `False`):
87
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
88
+ attention_dropout (`float`, *optional*, defaults to 0.0):
89
+ The dropout ratio for the attention probabilities.
90
+ clip_qkv (`float`, *optional*):
91
+ If not `None`, elements of query, key and value attention states are clipped so that their
92
+ absolute value does not exceed this value.
93
+
94
+ ```python
95
+ >>> from transformers import OlmoModel, OlmoConfig
96
+
97
+ >>> # Initializing a OLMo 7B style configuration
98
+ >>> configuration = OlmoConfig()
99
+
100
+ >>> # Initializing a model from the OLMo 7B style configuration
101
+ >>> model = OlmoModel(configuration)
102
+
103
+ >>> # Accessing the model configuration
104
+ >>> configuration = model.config
105
+ ```"""
106
+
107
+ model_type = "olmo"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+ base_model_tp_plan = {
110
+ "layers.*.self_attn.q_proj": "colwise",
111
+ "layers.*.self_attn.k_proj": "colwise",
112
+ "layers.*.self_attn.v_proj": "colwise",
113
+ "layers.*.self_attn.o_proj": "rowwise",
114
+ "layers.*.mlp.gate_proj": "colwise",
115
+ "layers.*.mlp.up_proj": "colwise",
116
+ "layers.*.mlp.down_proj": "rowwise",
117
+ }
118
+ base_model_pp_plan = {
119
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
120
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
121
+ "norm": (["hidden_states"], ["hidden_states"]),
122
+ }
123
+
124
+ def __init__(
125
+ self,
126
+ vocab_size=50304,
127
+ hidden_size=4096,
128
+ intermediate_size=11008,
129
+ num_hidden_layers=32,
130
+ num_attention_heads=32,
131
+ num_key_value_heads=None,
132
+ hidden_act="silu",
133
+ max_position_embeddings=2048,
134
+ initializer_range=0.02,
135
+ use_cache=True,
136
+ pad_token_id=1,
137
+ bos_token_id=None,
138
+ eos_token_id=50279,
139
+ tie_word_embeddings=False,
140
+ rope_theta=10000.0,
141
+ rope_scaling=None,
142
+ attention_bias=False,
143
+ attention_dropout=0.0,
144
+ clip_qkv=None,
145
+ **kwargs,
146
+ ):
147
+ self.vocab_size = vocab_size
148
+ self.max_position_embeddings = max_position_embeddings
149
+ self.hidden_size = hidden_size
150
+ self.intermediate_size = intermediate_size
151
+ self.num_hidden_layers = num_hidden_layers
152
+ self.num_attention_heads = num_attention_heads
153
+
154
+ # for backward compatibility
155
+ if num_key_value_heads is None:
156
+ num_key_value_heads = num_attention_heads
157
+
158
+ self.num_key_value_heads = num_key_value_heads
159
+ self.hidden_act = hidden_act
160
+ self.initializer_range = initializer_range
161
+ self.use_cache = use_cache
162
+ self.rope_theta = rope_theta
163
+ self.rope_scaling = rope_scaling
164
+ self._rope_scaling_validation()
165
+ self.attention_bias = attention_bias
166
+ self.attention_dropout = attention_dropout
167
+ self.clip_qkv = clip_qkv
168
+
169
+ super().__init__(
170
+ pad_token_id=pad_token_id,
171
+ bos_token_id=bos_token_id,
172
+ eos_token_id=eos_token_id,
173
+ tie_word_embeddings=tie_word_embeddings,
174
+ **kwargs,
175
+ )
176
+
177
+ def _rope_scaling_validation(self):
178
+ """
179
+ Validate the `rope_scaling` configuration.
180
+ """
181
+ if self.rope_scaling is None:
182
+ return
183
+
184
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
185
+ raise ValueError(
186
+ f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
187
+ )
188
+ rope_scaling_type = self.rope_scaling.get("type", None)
189
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
190
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
191
+ raise ValueError(
192
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
193
+ )
194
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
195
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
196
+
197
+
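+ # A configuration sketch combining the GQA and RoPE-scaling knobs documented above
+ # (values assumed for illustration; the dict must carry exactly `type` and `factor`):
+ # >>> config = OlmoConfig(
+ # ...     num_attention_heads=32,
+ # ...     num_key_value_heads=8,  # GQA: every 4 query heads share one key/value head
+ # ...     rope_scaling={"type": "linear", "factor": 2.0},  # accepted by _rope_scaling_validation
+ # ... )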
198
+ __all__ = ["OlmoConfig"]
docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py ADDED
@@ -0,0 +1,248 @@
1
+ # Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import os
18
+ import shutil
19
+ from pathlib import Path
20
+
21
+ import torch
22
+ import yaml
23
+ from tokenizers import Tokenizer
24
+
25
+ from transformers import OlmoConfig, OlmoForCausalLM
26
+ from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
27
+
28
+
29
+ """
30
+ Sample usage:
31
+
32
+ ```
33
+ python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \
34
+ --input_dir /path/to/downloaded/olmo/weights --model_size 7B --output_dir /output/path
35
+ ```
36
+
37
+ Thereafter, models can be loaded via:
38
+
39
+ ```py
40
+ from transformers import OlmoForCausalLM, AutoTokenizer
41
+
42
+ model = OlmoForCausalLM.from_pretrained("/output/path")
43
+ tokenizer = AutoTokenizer.from_pretrained("/output/path")
44
+ ```
45
+
46
+ Important note: you need to be able to host the whole model in RAM to execute this script (even though the biggest versions
47
+ come in several checkpoints, each checkpoint contains a part of every weight of the model, so they all have to be loaded in RAM).
48
+ """
49
+
50
+
51
+ def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
52
+ return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
53
+
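+ # Worked example: compute_intermediate_size(4096) computes int(8 * 4096 / 3) = 10922
+ # and rounds it up to the next multiple of 256, giving 11008, the `intermediate_size`
+ # default used by `OlmoConfig` for the 7B model.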
54
+
55
+ def read_json(path):
56
+ with open(path, "r") as f:
57
+ return json.load(f)
58
+
59
+
60
+ def write_json(text, path):
61
+ with open(path, "w") as f:
62
+ json.dump(text, f)
63
+
64
+
65
+ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True):
66
+ os.makedirs(model_path, exist_ok=True)
67
+ tmp_model_path = os.path.join(model_path, "tmp")
68
+ os.makedirs(tmp_model_path, exist_ok=True)
69
+
70
+ config_path = Path(input_base_path) / "config.yaml"
71
+ olmo_config = yaml.safe_load(config_path.read_text())["model"]
72
+
73
+ n_layers = olmo_config["n_layers"]
74
+ n_heads = olmo_config["n_heads"]
75
+ dim = olmo_config["d_model"]
76
+ dims_per_head = dim // n_heads
77
+ base = 10000.0
78
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
79
+ max_position_embeddings = olmo_config["max_sequence_length"]
80
+
81
+ vocab_size = olmo_config.get("embedding_size", olmo_config["vocab_size"])
82
+
83
+ if olmo_config.get("n_kv_heads", None) is not None:
84
+ num_key_value_heads = olmo_config["n_kv_heads"] # for GQA / MQA
85
+ elif olmo_config["multi_query_attention"]: # compatibility with other checkpoints
86
+ num_key_value_heads = 1
87
+ else:
88
+ num_key_value_heads = n_heads
89
+
90
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
91
+
92
+ # Not sharded
93
+ # (The sharded implementation would also work, but this is simpler.)
94
+ loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True)
95
+
96
+ param_count = 0
97
+ index_dict = {"weight_map": {}}
98
+ for layer_i in range(n_layers):
99
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
100
+ # Unsharded
101
+ # TODO: Layernorm stuff
102
+ # TODO: multi query attention
103
+ fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
104
+ q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
105
+ loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
106
+ )
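+ # Shape note (7B numbers assumed: d_model=4096, n_heads=32, and MHA so
+ # num_key_value_heads=32): att_proj.weight is (12288, 4096) and splits into three
+ # (4096, 4096) blocks; under MQA (num_key_value_heads=1), fused_dims would instead
+ # make the k/v blocks (128, 4096) each.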
107
+ up_proj_weight, gate_proj_weight = torch.chunk(
108
+ loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0
109
+ )
110
+ state_dict = {
111
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
112
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
113
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
114
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
115
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
116
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
117
+ f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
118
+ }
119
+
120
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
121
+
122
+ for k, v in state_dict.items():
123
+ index_dict["weight_map"][k] = filename
124
+ param_count += v.numel()
125
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
126
+
127
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
128
+
129
+ # Unsharded
130
+ # TODO: Deal with weight-tying
131
+ state_dict = {
132
+ "model.embed_tokens.weight": loaded["transformer.wte.weight"],
133
+ "lm_head.weight": loaded["transformer.ff_out.weight"]
134
+ if "transformer.ff_out.weight" in loaded
135
+ else loaded["transformer.wte.weight"],
136
+ }
137
+
138
+ for k, v in state_dict.items():
139
+ index_dict["weight_map"][k] = filename
140
+ param_count += v.numel()
141
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
142
+
143
+ # Write configs
144
+ index_dict["metadata"] = {"total_size": param_count * 2}
145
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
146
+
147
+ if olmo_config.get("mlp_hidden_size", None) is not None:
148
+ intermediate_size = olmo_config["mlp_hidden_size"] // 2
149
+ else:
150
+ intermediate_size = (dim * olmo_config["mlp_ratio"]) // 2
151
+
152
+ config = OlmoConfig(
153
+ vocab_size=vocab_size,
154
+ hidden_size=dim,
155
+ intermediate_size=intermediate_size,
156
+ num_hidden_layers=n_layers,
157
+ num_attention_heads=n_heads,
158
+ num_key_value_heads=num_key_value_heads,
159
+ max_position_embeddings=max_position_embeddings,
160
+ pad_token_id=olmo_config["pad_token_id"],
161
+ bos_token_id=None,
162
+ eos_token_id=olmo_config["eos_token_id"],
163
+ tie_word_embeddings=olmo_config["weight_tying"],
164
+ rope_theta=base,
165
+ clip_qkv=olmo_config.get("clip_qkv"),
166
+ )
167
+ config.save_pretrained(tmp_model_path)
168
+
169
+ # Make space so we can load the model properly now.
170
+ del state_dict
171
+ del loaded
172
+ gc.collect()
173
+
174
+ if tokenizer_path is not None:
175
+ _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id)
176
+
177
+ print("Loading the checkpoint in a OLMo model.")
178
+ model = OlmoForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
179
+ # Avoid saving this as part of the config.
180
+ del model.config._name_or_path
181
+ print("Saving in the Transformers format.")
182
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
183
+ shutil.rmtree(tmp_model_path)
184
+
185
+
186
+ def _write_tokenizer(
187
+ output_path: Path, config: OlmoConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True
188
+ ) -> None:
189
+ print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.")
190
+
191
+ base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
192
+
193
+ eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
194
+ pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id
195
+
196
+ if fix_eos_token_id and eos_token_id == 0:
197
+ # Fixing a bug in OLMo where eos token id was incorrectly set
198
+ print("Changing eos_token_id from 0 to 50279.")
199
+ eos_token_id = 50279
200
+
201
+ tokenizer = GPTNeoXTokenizerFast(
202
+ tokenizer_object=base_tokenizer,
203
+ eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
204
+ pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False),
205
+ unk_token=None,
206
+ bos_token=None,
207
+ )
208
+
209
+ tokenizer.save_pretrained(output_path)
210
+
211
+
212
+ def main():
213
+ parser = argparse.ArgumentParser()
214
+ parser.add_argument(
215
+ "--input_dir",
216
+ required=True,
217
+ help="Location of OLMo weights, which contains config.yaml and model.pt.",
218
+ )
219
+ parser.add_argument(
220
+ "--tokenizer_json_path",
221
+ default=None,
222
+ help="Location of OLMo tokenizer json file.",
223
+ )
224
+ parser.add_argument(
225
+ "--output_dir",
226
+ required=True,
227
+ help="Location to write HF model and tokenizer",
228
+ )
229
+ parser.add_argument(
230
+ "--no_fix_eos_token_id",
231
+ action="store_false",
232
+ dest="fix_eos_token_id",
233
+ help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
234
+ )
235
+ # argparse's `type=bool` treats any non-empty string (even "False") as True, so an explicit flag pair is used instead.
+ parser.add_argument("--safe_serialization", action=argparse.BooleanOptionalAction, default=True, help="Whether or not to save using `safetensors`.")
236
+ # Different OLMo versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
237
+ args = parser.parse_args()
238
+ write_model(
239
+ model_path=args.output_dir,
240
+ input_base_path=args.input_dir,
241
+ safe_serialization=args.safe_serialization,
242
+ tokenizer_path=args.tokenizer_json_path,
243
+ fix_eos_token_id=args.fix_eos_token_id,
244
+ )
245
+
246
+
247
+ if __name__ == "__main__":
248
+ main()
docs/transformers/src/transformers/models/olmo/modeling_olmo.py ADDED
@@ -0,0 +1,814 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/olmo/modular_olmo.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_olmo.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...activations import ACT2FN
14
+ from ...cache_utils import Cache, DynamicCache, StaticCache
15
+ from ...generation import GenerationMixin
16
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
17
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from ...modeling_layers import GradientCheckpointingLayer
19
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
20
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
21
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
22
+ from ...processing_utils import Unpack
23
+ from ...utils import (
24
+ LossKwargs,
25
+ add_start_docstrings,
26
+ add_start_docstrings_to_model_forward,
27
+ can_return_tuple,
28
+ is_torch_flex_attn_available,
29
+ logging,
30
+ replace_return_docstrings,
31
+ )
32
+ from .configuration_olmo import OlmoConfig
33
+
34
+
35
+ if is_torch_flex_attn_available():
36
+ from torch.nn.attention.flex_attention import BlockMask
37
+
38
+ from ...integrations.flex_attention import make_flex_block_causal_mask
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+ _CONFIG_FOR_DOC = "OlmoConfig"
43
+
44
+
45
+ class OlmoLayerNorm(nn.Module):
46
+ """LayerNorm but with no learnable weight or bias."""
47
+
48
+ def __init__(self, hidden_size: int) -> None:
49
+ super().__init__()
50
+ self.normalized_shape = (hidden_size,)
51
+
52
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
53
+ orig_dtype = hidden_states.dtype
54
+ return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
55
+ orig_dtype
56
+ )
57
+
58
+
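+ # Equivalence sketch: OlmoLayerNorm behaves like a weightless LayerNorm evaluated in
+ # float32 and cast back to the input dtype, e.g.
+ # >>> x = torch.randn(2, 4, 8, dtype=torch.float16)
+ # >>> ref = nn.LayerNorm(8, elementwise_affine=False, eps=1e-5)
+ # >>> torch.allclose(OlmoLayerNorm(8)(x), ref(x.float()).to(x.dtype))  # expected True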
59
+ class OlmoMLP(nn.Module):
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.config = config
63
+ self.hidden_size = config.hidden_size
64
+ self.intermediate_size = config.intermediate_size
65
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
66
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
67
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
68
+ self.act_fn = ACT2FN[config.hidden_act]
69
+
70
+ def forward(self, x):
71
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
72
+ return down_proj
73
+
74
+
75
+ def rotate_half(x):
76
+ """Rotates half the hidden dims of the input."""
77
+ x1 = x[..., : x.shape[-1] // 2]
78
+ x2 = x[..., x.shape[-1] // 2 :]
79
+ return torch.cat((-x2, x1), dim=-1)
80
+
81
+
82
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
83
+ """Applies Rotary Position Embedding to the query and key tensors.
84
+
85
+ Args:
86
+ q (`torch.Tensor`): The query tensor.
87
+ k (`torch.Tensor`): The key tensor.
88
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
89
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
90
+ position_ids (`torch.Tensor`, *optional*):
91
+ Deprecated and unused.
92
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
93
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
94
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
95
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
96
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
97
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
98
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
99
+ Returns:
100
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
101
+ """
102
+ cos = cos.unsqueeze(unsqueeze_dim)
103
+ sin = sin.unsqueeze(unsqueeze_dim)
104
+ q_embed = (q * cos) + (rotate_half(q) * sin)
105
+ k_embed = (k * cos) + (rotate_half(k) * sin)
106
+ return q_embed, k_embed
107
+
108
+
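+ # Broadcasting sketch for apply_rotary_pos_emb (shapes assumed): with q and k of shape
+ # (batch, heads, seq, head_dim) and cos/sin of shape (batch, seq, head_dim), the
+ # default unsqueeze_dim=1 turns cos/sin into (batch, 1, seq, head_dim) so they
+ # broadcast across the heads axis; the returned tensors keep the input shapes.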
109
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
110
+ """
111
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
112
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
113
+ """
114
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
115
+ if n_rep == 1:
116
+ return hidden_states
117
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
118
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
119
+
120
+
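+ # Quick equivalence check for repeat_kv:
+ # >>> kv = torch.randn(1, 2, 5, 4)  # (batch, num_key_value_heads, seq_len, head_dim)
+ # >>> torch.equal(repeat_kv(kv, 3), kv.repeat_interleave(3, dim=1))  # expected True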
121
+ def eager_attention_forward(
122
+ module: nn.Module,
123
+ query: torch.Tensor,
124
+ key: torch.Tensor,
125
+ value: torch.Tensor,
126
+ attention_mask: Optional[torch.Tensor],
127
+ scaling: float,
128
+ dropout: float = 0.0,
129
+ **kwargs,
130
+ ):
131
+ key_states = repeat_kv(key, module.num_key_value_groups)
132
+ value_states = repeat_kv(value, module.num_key_value_groups)
133
+
134
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
135
+ if attention_mask is not None:
136
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
137
+ attn_weights = attn_weights + causal_mask
138
+
139
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
140
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
141
+ attn_output = torch.matmul(attn_weights, value_states)
142
+ attn_output = attn_output.transpose(1, 2).contiguous()
143
+
144
+ return attn_output, attn_weights
145
+
146
+
147
+ class OlmoAttention(nn.Module):
148
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
149
+
150
+ def __init__(self, config: OlmoConfig, layer_idx: int):
151
+ super().__init__()
152
+ self.config = config
153
+ self.layer_idx = layer_idx
154
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
155
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
156
+ self.scaling = self.head_dim**-0.5
157
+ self.attention_dropout = config.attention_dropout
158
+ self.is_causal = True
159
+
160
+ self.q_proj = nn.Linear(
161
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
162
+ )
163
+ self.k_proj = nn.Linear(
164
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
165
+ )
166
+ self.v_proj = nn.Linear(
167
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
168
+ )
169
+ self.o_proj = nn.Linear(
170
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
171
+ )
172
+
173
+ def forward(
174
+ self,
175
+ hidden_states: torch.Tensor,
176
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
177
+ attention_mask: Optional[torch.Tensor],
178
+ past_key_value: Optional[Cache] = None,
179
+ cache_position: Optional[torch.LongTensor] = None,
180
+ **kwargs,
181
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
182
+ input_shape = hidden_states.shape[:-1]
183
+ hidden_shape = (*input_shape, -1, self.head_dim)
184
+
185
+ query_states = self.q_proj(hidden_states)
186
+ key_states = self.k_proj(hidden_states)
187
+ value_states = self.v_proj(hidden_states)
188
+
189
+ if self.config.clip_qkv is not None:
190
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
191
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
192
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
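+ # e.g. with config.clip_qkv = 8.0 (value assumed), every projected activation is
+ # clamped in-place to [-8.0, 8.0] before the heads are split out below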
193
+
194
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
195
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
196
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
197
+
198
+ cos, sin = position_embeddings
199
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
200
+
201
+ if past_key_value is not None:
202
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
203
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
204
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
205
+
206
+ attention_interface: Callable = eager_attention_forward
207
+ if self.config._attn_implementation != "eager":
208
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
209
+ logger.warning_once(
210
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
211
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
212
+ )
213
+ else:
214
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
215
+
216
+ attn_output, attn_weights = attention_interface(
217
+ self,
218
+ query_states,
219
+ key_states,
220
+ value_states,
221
+ attention_mask,
222
+ dropout=0.0 if not self.training else self.attention_dropout,
223
+ scaling=self.scaling,
224
+ **kwargs,
225
+ )
226
+
227
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
228
+ attn_output = self.o_proj(attn_output)
229
+ return attn_output, attn_weights
230
+
231
+
232
+ class OlmoDecoderLayer(GradientCheckpointingLayer):
233
+ def __init__(self, config: OlmoConfig, layer_idx: int):
234
+ super().__init__()
235
+ self.hidden_size = config.hidden_size
236
+ self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
237
+
238
+ self.mlp = OlmoMLP(config)
239
+ self.input_layernorm = OlmoLayerNorm(config.hidden_size)
240
+ self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ attention_mask: Optional[torch.Tensor] = None,
246
+ position_ids: Optional[torch.LongTensor] = None,
247
+ past_key_value: Optional[Cache] = None,
248
+ output_attentions: Optional[bool] = False,
249
+ use_cache: Optional[bool] = False,
250
+ cache_position: Optional[torch.LongTensor] = None,
251
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
252
+ **kwargs: Unpack[FlashAttentionKwargs],
253
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
254
+ residual = hidden_states
255
+ hidden_states = self.input_layernorm(hidden_states)
256
+
257
+ # Self Attention
258
+ hidden_states, self_attn_weights = self.self_attn(
259
+ hidden_states=hidden_states,
260
+ attention_mask=attention_mask,
261
+ position_ids=position_ids,
262
+ past_key_value=past_key_value,
263
+ output_attentions=output_attentions,
264
+ use_cache=use_cache,
265
+ cache_position=cache_position,
266
+ position_embeddings=position_embeddings,
267
+ **kwargs,
268
+ )
269
+ hidden_states = residual + hidden_states
270
+
271
+ # Fully Connected
272
+ residual = hidden_states
273
+ hidden_states = self.post_attention_layernorm(hidden_states)
274
+ hidden_states = self.mlp(hidden_states)
275
+ hidden_states = residual + hidden_states
276
+
277
+ outputs = (hidden_states,)
278
+ if output_attentions:
279
+ outputs += (self_attn_weights,)
280
+
281
+ return outputs
282
+
283
+
284
+ class OlmoRotaryEmbedding(nn.Module):
285
+ def __init__(self, config: OlmoConfig, device=None):
286
+ super().__init__()
287
+ # BC: "rope_type" was originally "type"
288
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
289
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
290
+ else:
291
+ self.rope_type = "default"
292
+ self.max_seq_len_cached = config.max_position_embeddings
293
+ self.original_max_seq_len = config.max_position_embeddings
294
+
295
+ self.config = config
296
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
297
+
298
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
299
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
300
+ self.original_inv_freq = self.inv_freq
301
+
302
+ @torch.no_grad()
303
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
304
+ def forward(self, x, position_ids):
305
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
306
+ position_ids_expanded = position_ids[:, None, :].float()
307
+
308
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
309
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
310
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
311
+ emb = torch.cat((freqs, freqs), dim=-1)
312
+ cos = emb.cos() * self.attention_scaling
313
+ sin = emb.sin() * self.attention_scaling
314
+
315
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
316
+
317
+
318
+ OLMO_START_DOCSTRING = r"""
319
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
320
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
321
+ etc.)
322
+
323
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
324
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
325
+ and behavior.
326
+
327
+ Parameters:
328
+ config ([`OlmoConfig`]):
329
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
330
+ load the weights associated with the model, only the configuration. Check out the
331
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
332
+ """
333
+
334
+
335
+ @add_start_docstrings(
336
+ "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
337
+ OLMO_START_DOCSTRING,
338
+ )
339
+ class OlmoPreTrainedModel(PreTrainedModel):
340
+ config_class = OlmoConfig
341
+ base_model_prefix = "model"
342
+ supports_gradient_checkpointing = True
343
+ _no_split_modules = ["OlmoDecoderLayer"]
344
+ _skip_keys_device_placement = ["past_key_values"]
345
+ _supports_flash_attn_2 = True
346
+ _supports_sdpa = True
347
+ _supports_flex_attn = True
348
+ _supports_cache_class = True
349
+ _supports_quantized_cache = True
350
+ _supports_static_cache = True
351
+ _supports_attention_backend = True
352
+
353
+ def _init_weights(self, module):
354
+ std = self.config.initializer_range
355
+ if isinstance(module, nn.Linear):
356
+ module.weight.data.normal_(mean=0.0, std=std)
357
+ if module.bias is not None:
358
+ module.bias.data.zero_()
359
+ elif isinstance(module, nn.Embedding):
360
+ module.weight.data.normal_(mean=0.0, std=std)
361
+ if module.padding_idx is not None:
362
+ module.weight.data[module.padding_idx].zero_()
363
+
364
+
365
+ OLMO_INPUTS_DOCSTRING = r"""
366
+ Args:
367
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
368
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
369
+ it.
370
+
371
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
372
+ [`PreTrainedTokenizer.__call__`] for details.
373
+
374
+ [What are input IDs?](../glossary#input-ids)
375
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
376
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
377
+
378
+ - 1 for tokens that are **not masked**,
379
+ - 0 for tokens that are **masked**.
380
+
381
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
382
+ but you can also pass a `BlockMask` object directly here.
383
+
384
+ [What are attention masks?](../glossary#attention-mask)
385
+
386
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
387
+ [`PreTrainedTokenizer.__call__`] for details.
388
+
389
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
390
+ `past_key_values`).
391
+
392
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
393
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
394
+ information on the default strategy.
395
398
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
399
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
400
+ config.n_positions - 1]`.
401
+
402
+ [What are position IDs?](../glossary#position-ids)
403
+ past_key_values (`Cache`, *optional*):
404
+ Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
405
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
406
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
407
+
408
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
409
+
410
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
411
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
412
+ of shape `(batch_size, sequence_length)`.
413
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
414
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
415
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
416
+ model's internal embedding lookup matrix.
417
+ use_cache (`bool`, *optional*):
418
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
419
+ `past_key_values`).
420
+ output_attentions (`bool`, *optional*):
421
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
422
+ tensors for more detail.
423
+ output_hidden_states (`bool`, *optional*):
424
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
425
+ more detail.
426
+ return_dict (`bool`, *optional*):
427
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
428
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
429
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
430
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
431
+ the complete sequence length.
432
+ """
433
+
434
+
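To make the `past_key_values`/`cache_position` contract described above concrete, here is a minimal two-step decoding sketch. The `allenai/OLMo-1B-hf` checkpoint and the choice of the EOS token as the next input are illustrative assumptions, not part of this file.

```python
# Minimal sketch of incremental decoding with a KV cache; the checkpoint id
# and the choice of next token are illustrative assumptions.
import torch
from transformers import AutoTokenizer, DynamicCache, OlmoModel

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
model = OlmoModel.from_pretrained("allenai/OLMo-1B-hf")

inputs = tokenizer("Hello world", return_tensors="pt")
out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

# Second step: feed only the newest token. The cache supplies all earlier
# keys/values, and `cache_position` is inferred from the cache length.
next_token = torch.tensor([[tokenizer.eos_token_id]])
out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
print(out.last_hidden_state.shape)  # (1, 1, hidden_size)
```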
435
+ @add_start_docstrings(
436
+ "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
437
+ OLMO_START_DOCSTRING,
438
+ )
439
+ class OlmoModel(OlmoPreTrainedModel):
440
+ """
441
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoDecoderLayer`]
442
+
443
+ Args:
444
+ config: OlmoConfig
445
+ """
446
+
447
+ def __init__(self, config: OlmoConfig):
448
+ super().__init__(config)
449
+ self.padding_idx = config.pad_token_id
450
+ self.vocab_size = config.vocab_size
451
+
452
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
453
+ self.layers = nn.ModuleList(
454
+ [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
455
+ )
456
+ self.norm = OlmoLayerNorm(config.hidden_size)
457
+ self.rotary_emb = OlmoRotaryEmbedding(config=config)
458
+ self.gradient_checkpointing = False
459
+
460
+ # Initialize weights and apply final processing
461
+ self.post_init()
462
+
463
+ def get_input_embeddings(self):
464
+ return self.embed_tokens
465
+
466
+ def set_input_embeddings(self, value):
467
+ self.embed_tokens = value
468
+
469
+ @can_return_tuple
470
+ @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
471
+ def forward(
472
+ self,
473
+ input_ids: Optional[torch.LongTensor] = None,
474
+ attention_mask: Optional[torch.Tensor] = None,
475
+ position_ids: Optional[torch.LongTensor] = None,
476
+ past_key_values: Optional[Cache] = None,
477
+ inputs_embeds: Optional[torch.FloatTensor] = None,
478
+ use_cache: Optional[bool] = None,
479
+ output_attentions: Optional[bool] = None,
480
+ output_hidden_states: Optional[bool] = None,
481
+ cache_position: Optional[torch.LongTensor] = None,
482
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
483
+ ) -> BaseModelOutputWithPast:
484
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
485
+ output_hidden_states = (
486
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
487
+ )
488
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
489
+
490
+ if (input_ids is None) ^ (inputs_embeds is not None):
491
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
492
+
493
+ if self.gradient_checkpointing and self.training and use_cache:
494
+ logger.warning_once(
495
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
496
+ )
497
+ use_cache = False
498
+
499
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
500
+ if not isinstance(past_key_values, (type(None), Cache)):
501
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
502
+
503
+ if inputs_embeds is None:
504
+ inputs_embeds = self.embed_tokens(input_ids)
505
+
506
+ if use_cache and past_key_values is None:
507
+ past_key_values = DynamicCache()
508
+
509
+ if cache_position is None:
510
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
511
+ cache_position = torch.arange(
512
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
513
+ )
514
+
515
+ if position_ids is None:
516
+ position_ids = cache_position.unsqueeze(0)
517
+
518
+ causal_mask = self._update_causal_mask(
519
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
520
+ )
521
+
522
+ hidden_states = inputs_embeds
523
+
524
+ # create position embeddings to be shared across the decoder layers
525
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
526
+
527
+ # decoder layers
528
+ all_hidden_states = () if output_hidden_states else None
529
+ all_self_attns = () if output_attentions else None
530
+
531
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
532
+ if output_hidden_states:
533
+ all_hidden_states += (hidden_states,)
534
+
535
+ layer_outputs = decoder_layer(
536
+ hidden_states,
537
+ attention_mask=causal_mask,
538
+ position_ids=position_ids,
539
+ past_key_value=past_key_values,
540
+ output_attentions=output_attentions,
541
+ use_cache=use_cache,
542
+ cache_position=cache_position,
543
+ position_embeddings=position_embeddings,
544
+ **flash_attn_kwargs,
545
+ )
546
+
547
+ hidden_states = layer_outputs[0]
548
+
549
+ if output_attentions:
550
+ all_self_attns += (layer_outputs[1],)
551
+
552
+ hidden_states = self.norm(hidden_states)
553
+
554
+ # add hidden states from the last decoder layer
555
+ if output_hidden_states:
556
+ all_hidden_states += (hidden_states,)
557
+
558
+ return BaseModelOutputWithPast(
559
+ last_hidden_state=hidden_states,
560
+ past_key_values=past_key_values if use_cache else None,
561
+ hidden_states=all_hidden_states,
562
+ attentions=all_self_attns,
563
+ )
564
+
565
+ def _update_causal_mask(
566
+ self,
567
+ attention_mask: Union[torch.Tensor, "BlockMask"],
568
+ input_tensor: torch.Tensor,
569
+ cache_position: torch.Tensor,
570
+ past_key_values: Cache,
571
+ output_attentions: bool = False,
572
+ ):
573
+ if self.config._attn_implementation == "flash_attention_2":
574
+ if attention_mask is not None and (attention_mask == 0.0).any():
575
+ return attention_mask
576
+ return None
577
+ if self.config._attn_implementation == "flex_attention":
578
+ if isinstance(attention_mask, torch.Tensor):
579
+ attention_mask = make_flex_block_causal_mask(attention_mask)
580
+ return attention_mask
581
+
582
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
583
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
584
+ # to infer the attention mask.
585
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
586
+ using_static_cache = isinstance(past_key_values, StaticCache)
587
+
588
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
589
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
590
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
591
+ attention_mask,
592
+ inputs_embeds=input_tensor,
593
+ past_key_values_length=past_seen_tokens,
594
+ is_training=self.training,
595
+ ):
596
+ return None
597
+
598
+ dtype, device = input_tensor.dtype, input_tensor.device
599
+ sequence_length = input_tensor.shape[1]
600
+ if using_static_cache:
601
+ target_length = past_key_values.get_max_cache_shape()
602
+ else:
603
+ target_length = (
604
+ attention_mask.shape[-1]
605
+ if isinstance(attention_mask, torch.Tensor)
606
+ else past_seen_tokens + sequence_length + 1
607
+ )
608
+
609
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
610
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
611
+ attention_mask,
612
+ sequence_length=sequence_length,
613
+ target_length=target_length,
614
+ dtype=dtype,
615
+ device=device,
616
+ cache_position=cache_position,
617
+ batch_size=input_tensor.shape[0],
618
+ )
619
+
620
+ if (
621
+ self.config._attn_implementation == "sdpa"
622
+ and attention_mask is not None
623
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
624
+ and not output_attentions
625
+ ):
626
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
627
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
628
+ # Details: https://github.com/pytorch/pytorch/issues/110213
629
+ min_dtype = torch.finfo(dtype).min
630
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
631
+
632
+ return causal_mask
633
+
634
+ @staticmethod
635
+ def _prepare_4d_causal_attention_mask_with_cache_position(
636
+ attention_mask: torch.Tensor,
637
+ sequence_length: int,
638
+ target_length: int,
639
+ dtype: torch.dtype,
640
+ device: torch.device,
641
+ cache_position: torch.Tensor,
642
+ batch_size: int,
643
+ **kwargs,
644
+ ):
645
+ """
646
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
647
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
648
+
649
+ Args:
650
+ attention_mask (`torch.Tensor`):
651
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
652
+ `(batch_size, 1, query_length, key_value_length)`.
653
+ sequence_length (`int`):
654
+ The sequence length being processed.
655
+ target_length (`int`):
656
+ The target length: when generating with static cache, the mask should be as long as the static cache,
657
+ to account for the 0 padding (the part of the cache that is not filled yet).
658
+ dtype (`torch.dtype`):
659
+ The dtype to use for the 4D attention mask.
660
+ device (`torch.device`):
661
+ The device to place the 4D attention mask on.
662
+ cache_position (`torch.Tensor`):
663
+ Indices depicting the position of the input sequence tokens in the sequence.
664
+ batch_size (`int`):
665
+ Batch size.
666
+ """
667
+ if attention_mask is not None and attention_mask.dim() == 4:
668
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
669
+ causal_mask = attention_mask
670
+ else:
671
+ min_dtype = torch.finfo(dtype).min
672
+ causal_mask = torch.full(
673
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
674
+ )
675
+ if sequence_length != 1:
676
+ causal_mask = torch.triu(causal_mask, diagonal=1)
677
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
678
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
679
+ if attention_mask is not None:
680
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
681
+ mask_length = attention_mask.shape[-1]
682
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
683
+ causal_mask.device
684
+ )
685
+ padding_mask = padding_mask == 0
686
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
687
+ padding_mask, min_dtype
688
+ )
689
+
690
+ return causal_mask
691
+
692
+
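To make the mask construction above concrete, here is a tiny numeric sketch of the core steps of `_prepare_4d_causal_attention_mask_with_cache_position`, with illustrative sizes (3 query tokens, 2 tokens already in the cache, 5 key/value slots in total):

```python
# Numeric sketch of the 4D causal mask construction; sizes are illustrative.
import torch

sequence_length, target_length, batch_size = 3, 5, 1
cache_position = torch.arange(2, 2 + sequence_length)  # 2 tokens already cached
min_dtype = torch.finfo(torch.float32).min

mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
mask = torch.triu(mask, diagonal=1)  # strictly-future positions stay masked
mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# 1 marks attendable positions, 0 marks masked ones:
print((mask == 0.0).int().squeeze())
```

Each query row may attend to every position up to and including its own cache position, which is exactly the pattern the decoder layers receive as `attention_mask`.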
693
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
694
+
695
+
696
+ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
697
+ _tied_weights_keys = ["lm_head.weight"]
698
+ _tp_plan = {"lm_head": "colwise_rep"}
699
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
700
+
701
+ def __init__(self, config):
702
+ super().__init__(config)
703
+ self.model = OlmoModel(config)
704
+ self.vocab_size = config.vocab_size
705
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
706
+
707
+ # Initialize weights and apply final processing
708
+ self.post_init()
709
+
710
+ def get_input_embeddings(self):
711
+ return self.model.embed_tokens
712
+
713
+ def set_input_embeddings(self, value):
714
+ self.model.embed_tokens = value
715
+
716
+ def get_output_embeddings(self):
717
+ return self.lm_head
718
+
719
+ def set_output_embeddings(self, new_embeddings):
720
+ self.lm_head = new_embeddings
721
+
722
+ def set_decoder(self, decoder):
723
+ self.model = decoder
724
+
725
+ def get_decoder(self):
726
+ return self.model
727
+
728
+ @can_return_tuple
729
+ @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
730
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
731
+ def forward(
732
+ self,
733
+ input_ids: Optional[torch.LongTensor] = None,
734
+ attention_mask: Optional[torch.Tensor] = None,
735
+ position_ids: Optional[torch.LongTensor] = None,
736
+ past_key_values: Optional[Cache] = None,
737
+ inputs_embeds: Optional[torch.FloatTensor] = None,
738
+ labels: Optional[torch.LongTensor] = None,
739
+ use_cache: Optional[bool] = None,
740
+ output_attentions: Optional[bool] = None,
741
+ output_hidden_states: Optional[bool] = None,
742
+ cache_position: Optional[torch.LongTensor] = None,
743
+ logits_to_keep: Union[int, torch.Tensor] = 0,
744
+ **kwargs: Unpack[KwargsForCausalLM],
745
+ ) -> CausalLMOutputWithPast:
746
+ r"""
747
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
748
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
749
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
750
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
751
+
752
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
753
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
754
+ `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only for that
755
+ token can save memory, which becomes quite significant for long sequences or large vocabulary sizes.
756
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
757
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
758
+
759
+ Returns:
760
+
761
+ Example:
762
+
763
+ ```python
764
+ >>> from transformers import AutoTokenizer, OlmoForCausalLM
765
+
766
+ >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-7B-hf")
767
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B-hf")
768
+
769
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
770
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
771
+
772
+ >>> # Generate
773
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
774
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
775
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
776
+ ```"""
777
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
778
+ output_hidden_states = (
779
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
780
+ )
781
+
782
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
783
+ outputs: BaseModelOutputWithPast = self.model(
784
+ input_ids=input_ids,
785
+ attention_mask=attention_mask,
786
+ position_ids=position_ids,
787
+ past_key_values=past_key_values,
788
+ inputs_embeds=inputs_embeds,
789
+ use_cache=use_cache,
790
+ output_attentions=output_attentions,
791
+ output_hidden_states=output_hidden_states,
792
+ cache_position=cache_position,
793
+ **kwargs,
794
+ )
795
+
796
+ hidden_states = outputs.last_hidden_state
797
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
798
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
799
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
800
+
801
+ loss = None
802
+ if labels is not None:
803
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
804
+
805
+ return CausalLMOutputWithPast(
806
+ loss=loss,
807
+ logits=logits,
808
+ past_key_values=outputs.past_key_values,
809
+ hidden_states=outputs.hidden_states,
810
+ attentions=outputs.attentions,
811
+ )
812
+
813
+
814
+ __all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]
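End-to-end, the classes exported here are used like any other causal LM in the library. A minimal generation sketch, assuming the `allenai/OLMo-1B-hf` checkpoint (the id and generation settings are illustrative):

```python
# Minimal generation sketch for OlmoForCausalLM; checkpoint id and settings
# are illustrative assumptions.
from transformers import AutoTokenizer, OlmoForCausalLM

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
model = OlmoForCausalLM.from_pretrained("allenai/OLMo-1B-hf")

inputs = tokenizer("Language modeling is", return_tensors="pt")
# `generate` reuses `past_key_values` between steps, so after the prompt each
# forward pass only processes the newly sampled token.
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```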
docs/transformers/src/transformers/models/olmo/modular_olmo.py ADDED
@@ -0,0 +1,148 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint
7
+
8
+ from ...cache_utils import Cache
9
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
10
+ from ...utils import logging
11
+ from ..llama.modeling_llama import (
12
+ LlamaAttention,
13
+ LlamaDecoderLayer,
14
+ LlamaForCausalLM,
15
+ LlamaMLP,
16
+ LlamaModel,
17
+ LlamaPreTrainedModel,
18
+ LlamaRotaryEmbedding,
19
+ apply_rotary_pos_emb,
20
+ eager_attention_forward,
21
+ )
22
+ from .configuration_olmo import OlmoConfig
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class OlmoLayerNorm(nn.Module):
29
+ """LayerNorm but with no learnable weight or bias."""
30
+
31
+ def __init__(self, hidden_size: int) -> None:
32
+ super().__init__()
33
+ self.normalized_shape = (hidden_size,)
34
+
35
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
36
+ orig_dtype = hidden_states.dtype
37
+ return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
38
+ orig_dtype
39
+ )
40
+
41
+
42
+ class OlmoMLP(LlamaMLP):
43
+ def __init__(self, config):
44
+ super().__init__(config)
45
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
46
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
47
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
48
+
49
+
50
+ class OlmoAttention(LlamaAttention):
51
+ def forward(
52
+ self,
53
+ hidden_states: torch.Tensor,
54
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
55
+ attention_mask: Optional[torch.Tensor],
56
+ past_key_value: Optional[Cache] = None,
57
+ cache_position: Optional[torch.LongTensor] = None,
58
+ **kwargs,
59
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
60
+ input_shape = hidden_states.shape[:-1]
61
+ hidden_shape = (*input_shape, -1, self.head_dim)
62
+
63
+ query_states = self.q_proj(hidden_states)
64
+ key_states = self.k_proj(hidden_states)
65
+ value_states = self.v_proj(hidden_states)
66
+
67
+ if self.config.clip_qkv is not None:
68
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
69
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
70
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
71
+
72
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
73
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
74
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
75
+
76
+ cos, sin = position_embeddings
77
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
78
+
79
+ if past_key_value is not None:
80
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
81
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
82
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
83
+
84
+ attention_interface: Callable = eager_attention_forward
85
+ if self.config._attn_implementation != "eager":
86
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
87
+ logger.warning_once(
88
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
89
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
90
+ )
91
+ else:
92
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
93
+
94
+ attn_output, attn_weights = attention_interface(
95
+ self,
96
+ query_states,
97
+ key_states,
98
+ value_states,
99
+ attention_mask,
100
+ dropout=0.0 if not self.training else self.attention_dropout,
101
+ scaling=self.scaling,
102
+ **kwargs,
103
+ )
104
+
105
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
106
+ attn_output = self.o_proj(attn_output)
107
+ return attn_output, attn_weights
108
+
109
+
110
+ class OlmoDecoderLayer(LlamaDecoderLayer):
111
+ def __init__(self, config: OlmoConfig, layer_idx: int):
112
+ super().__init__(config, layer_idx)
113
+ self.input_layernorm = OlmoLayerNorm(config.hidden_size)
114
+ self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
115
+ self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
116
+
117
+
118
+ class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
119
+ pass
120
+
121
+
122
+ class OlmoPreTrainedModel(LlamaPreTrainedModel):
123
+ def _init_weights(self, module):
124
+ std = self.config.initializer_range
125
+ if isinstance(module, nn.Linear):
126
+ module.weight.data.normal_(mean=0.0, std=std)
127
+ if module.bias is not None:
128
+ module.bias.data.zero_()
129
+ elif isinstance(module, nn.Embedding):
130
+ module.weight.data.normal_(mean=0.0, std=std)
131
+ if module.padding_idx is not None:
132
+ module.weight.data[module.padding_idx].zero_()
133
+
134
+
135
+ class OlmoModel(LlamaModel):
136
+ def __init__(self, config: OlmoConfig):
137
+ super().__init__(config)
138
+ self.layers = nn.ModuleList(
139
+ [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
140
+ )
141
+ self.norm = OlmoLayerNorm(config.hidden_size)
142
+
143
+
144
+ class OlmoForCausalLM(LlamaForCausalLM):
145
+ pass
146
+
147
+
148
+ __all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]
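The two OLMo-specific tweaks in this modular file are the non-parametric `OlmoLayerNorm` (computed in float32, no learnable weight or bias) and the optional `clip_qkv` clamp on the Q/K/V projections. A standalone sketch of both, with made-up sizes:

```python
# Standalone sketch of OlmoLayerNorm semantics and the clip_qkv clamp;
# tensor sizes and the clip value are illustrative.
import torch
import torch.nn.functional as F

hidden = torch.randn(2, 4, 8, dtype=torch.bfloat16)

# OlmoLayerNorm: upcast to float32, normalize with no affine parameters,
# then cast back to the original dtype.
normed = F.layer_norm(hidden.to(torch.float32), (8,), None, None, eps=1e-5).to(hidden.dtype)
assert normed.dtype == torch.bfloat16

# clip_qkv: OlmoAttention clamps the projections in place when configured.
clip_qkv = 8.0  # stands in for config.clip_qkv
qkv = torch.randn(2, 4, 24) * 10
qkv.clamp_(min=-clip_qkv, max=clip_qkv)
assert qkv.abs().max() <= clip_qkv
```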
docs/transformers/src/transformers/models/olmo2/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_olmo2 import *
22
+ from .modeling_olmo2 import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
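The `_LazyModule` indirection above defers the heavy imports until an attribute is actually used. As a self-contained illustration of the pattern (a simplified stand-in, not the actual `_LazyModule` implementation):

```python
# Simplified stand-in for the lazy-import pattern; not the real _LazyModule.
import importlib
import types


class LazyModule(types.ModuleType):
    """Defers importing a module until one of its attributes is accessed."""

    def __init__(self, name, attr_to_module):
        super().__init__(name)
        self._attr_to_module = attr_to_module  # e.g. {"sqrt": "math"}

    def __getattr__(self, attr):
        module = importlib.import_module(self._attr_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so __getattr__ is not hit again
        return value


lazy = LazyModule("demo", {"sqrt": "math"})
print(lazy.sqrt(9.0))  # `math` is only imported at this point
```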
docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py ADDED
@@ -0,0 +1,180 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_olmo2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+
8
+ from ...configuration_utils import PretrainedConfig
9
+
10
+
11
+ class Olmo2Config(PretrainedConfig):
12
+ r"""
13
+ This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
14
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
15
+ defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
16
+
17
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
18
+ documentation from [`PretrainedConfig`] for more information.
19
+
20
+
21
+ Args:
22
+ vocab_size (`int`, *optional*, defaults to 50304):
23
+ Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
24
+ `input_ids` passed when calling [`Olmo2Model`].
25
+ hidden_size (`int`, *optional*, defaults to 4096):
26
+ Dimension of the hidden representations.
27
+ intermediate_size (`int`, *optional*, defaults to 11008):
28
+ Dimension of the MLP representations.
29
+ num_hidden_layers (`int`, *optional*, defaults to 32):
30
+ Number of hidden layers in the Transformer decoder.
31
+ num_attention_heads (`int`, *optional*, defaults to 32):
32
+ Number of attention heads for each attention layer in the Transformer decoder.
33
+ num_key_value_heads (`int`, *optional*):
34
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
35
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
36
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
37
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
38
+ by mean-pooling all the original heads within that group. For more details, check out [this
39
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
40
+ `num_attention_heads`.
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
42
+ The non-linear activation function (function or string) in the decoder.
43
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
44
+ The maximum sequence length that this model might ever be used with.
45
+ initializer_range (`float`, *optional*, defaults to 0.02):
46
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
47
+ use_cache (`bool`, *optional*, defaults to `True`):
48
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
49
+ relevant if `config.is_decoder=True`.
50
+ pad_token_id (`int`, *optional*, defaults to 1):
51
+ Padding token id.
52
+ bos_token_id (`int`, *optional*):
53
+ Beginning of stream token id.
54
+ eos_token_id (`int`, *optional*, defaults to 50279):
55
+ End of stream token id.
56
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
57
+ Whether to tie the input and output word embeddings.
58
+ rope_theta (`float`, *optional*, defaults to 10000.0):
59
+ The base period of the RoPE embeddings.
60
+ rope_scaling (`Dict`, *optional*):
61
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
62
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
63
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
64
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
65
+ these scaling strategies behave:
66
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
67
+ experimental feature, subject to breaking API changes in future versions.
68
+ attention_bias (`bool`, *optional*, defaults to `False`):
69
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
70
+ attention_dropout (`float`, *optional*, defaults to 0.0):
71
+ The dropout ratio for the attention probabilities.
72
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
73
+ The epsilon used by the rms normalization layers.
74
+
75
+ ```python
76
+ >>> from transformers import Olmo2Model, Olmo2Config
77
+
78
+ >>> # Initializing a Olmo2 7B style configuration
79
+ >>> configuration = Olmo2Config()
80
+
81
+ >>> # Initializing a model from the Olmo2 7B style configuration
82
+ >>> model = Olmo2Model(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```
87
+ """
88
+
89
+ model_type = "olmo2"
90
+ keys_to_ignore_at_inference = ["past_key_values"]
91
+ base_model_tp_plan = {
92
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
93
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
94
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
95
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
96
+ "layers.*.mlp.gate_proj": "colwise",
97
+ "layers.*.mlp.up_proj": "colwise",
98
+ "layers.*.mlp.down_proj": "rowwise",
99
+ }
100
+ base_model_pp_plan = {
101
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
102
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
103
+ "norm": (["hidden_states"], ["hidden_states"]),
104
+ }
105
+
106
+ def __init__(
107
+ self,
108
+ vocab_size=50304,
109
+ hidden_size=4096,
110
+ intermediate_size=11008,
111
+ num_hidden_layers=32,
112
+ num_attention_heads=32,
113
+ num_key_value_heads=None,
114
+ hidden_act="silu",
115
+ max_position_embeddings=2048,
116
+ initializer_range=0.02,
117
+ use_cache=True,
118
+ pad_token_id=1,
119
+ bos_token_id=None,
120
+ eos_token_id=50279,
121
+ tie_word_embeddings=False,
122
+ rope_theta=10000.0,
123
+ rope_scaling=None,
124
+ attention_bias=False,
125
+ attention_dropout=0.0,
126
+ rms_norm_eps=1e-5,
127
+ **kwargs,
128
+ ):
129
+ super().__init__(
130
+ pad_token_id=pad_token_id,
131
+ bos_token_id=bos_token_id,
132
+ eos_token_id=eos_token_id,
133
+ tie_word_embeddings=tie_word_embeddings,
134
+ **kwargs,
135
+ )
136
+ self.vocab_size = vocab_size
137
+ self.max_position_embeddings = max_position_embeddings
138
+ self.hidden_size = hidden_size
139
+ self.intermediate_size = intermediate_size
140
+ self.num_hidden_layers = num_hidden_layers
141
+ self.num_attention_heads = num_attention_heads
142
+
143
+ # for backward compatibility
144
+ if num_key_value_heads is None:
145
+ num_key_value_heads = num_attention_heads
146
+
147
+ self.num_key_value_heads = num_key_value_heads
148
+ self.hidden_act = hidden_act
149
+ self.initializer_range = initializer_range
150
+ self.use_cache = use_cache
151
+ self.rope_theta = rope_theta
152
+ self.rope_scaling = rope_scaling
153
+ self._rope_scaling_validation()
154
+ self.attention_bias = attention_bias
155
+ self.attention_dropout = attention_dropout
156
+
157
+ self.rms_norm_eps = rms_norm_eps
158
+
159
+ def _rope_scaling_validation(self):
160
+ """
161
+ Validate the `rope_scaling` configuration.
162
+ """
163
+ if self.rope_scaling is None:
164
+ return
165
+
166
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
167
+ raise ValueError(
168
+ f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
169
+ )
170
+ rope_scaling_type = self.rope_scaling.get("type", None)
171
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
172
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
173
+ raise ValueError(
174
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
175
+ )
176
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
177
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
178
+
179
+
180
+ __all__ = ["Olmo2Config"]
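A short sketch of how `_rope_scaling_validation` behaves in practice; the values are illustrative:

```python
# Illustrative check of Olmo2Config's rope_scaling validation.
from transformers import Olmo2Config

# Valid: a dict with exactly the `type` and `factor` fields.
config = Olmo2Config(rope_scaling={"type": "linear", "factor": 2.0})
print(config.rope_scaling)

# Invalid: a factor <= 1.0 raises a ValueError.
try:
    Olmo2Config(rope_scaling={"type": "dynamic", "factor": 0.5})
except ValueError as err:
    print(err)
```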
docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import gc
18
+ import json
19
+ import os
20
+ import shutil
21
+ from pathlib import Path
22
+ from typing import Any, Dict
23
+
24
+ import torch
25
+ import yaml
26
+ from tokenizers import Tokenizer
27
+
28
+ from transformers import Olmo2Config, Olmo2ForCausalLM
29
+ from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
30
+
31
+
32
+ """
33
+ Sample usage:
34
+
35
+ ```
36
+ python src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py \
37
+ --input_dir /path/to/downloaded/olmo2/weights --model_size 7B --output_dir /output/path
38
+ ```
39
+
40
+ Thereafter, models can be loaded via:
41
+
42
+ ```py
43
+ from transformers import Olmo2ForCausalLM, AutoTokenizer
44
+
45
+ model = Olmo2ForCausalLM.from_pretrained("/output/path")
46
+ tokenizer = AutoTokenizer.from_pretrained("/output/path")
47
+ ```
48
+
49
+ Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
50
+ come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
51
+ """
52
+
53
+
54
+ def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
55
+ return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
56
+
57
+
58
+ def read_json(path):
59
+ with open(path, "r") as f:
60
+ return json.load(f)
61
+
62
+
63
+ def write_json(text, path):
64
+ with open(path, "w") as f:
65
+ json.dump(text, f)
66
+
67
+
68
+ def write_model(
69
+ model_path,
70
+ input_base_path,
71
+ include_tokenizer=True,
72
+ tokenizer_path=None,
73
+ safe_serialization=True,
74
+ fix_eos_token_id=True,
75
+ tmp_cleanup=True,
76
+ ):
77
+ os.makedirs(model_path, exist_ok=True)
78
+ tmp_model_path = os.path.join(model_path, "tmp")
79
+ os.makedirs(tmp_model_path, exist_ok=True)
80
+
81
+ config_path = Path(input_base_path) / "config.yaml"
82
+ olmo2_config = yaml.safe_load(config_path.read_text())["model"]
83
+
84
+ if not olmo2_config.get("attention_layer_norm", False):
85
+ raise RuntimeError("OLMo2 checkpoints must have attention layer norm")
86
+ if not olmo2_config.get("norm_after", False):
87
+ raise RuntimeError("OLMo2 checkpoints must set norm_after to True")
88
+
89
+ n_layers = olmo2_config["n_layers"]
90
+ n_heads = olmo2_config["n_heads"]
91
+ dim = olmo2_config["d_model"]
92
+ dims_per_head = dim // n_heads
93
+ base = olmo2_config["rope_theta"]
94
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
95
+ max_position_embeddings = olmo2_config["max_sequence_length"]
96
+
97
+ vocab_size = olmo2_config.get("embedding_size", olmo2_config["vocab_size"])
98
+
99
+ if olmo2_config.get("n_kv_heads", None) is not None:
100
+ num_key_value_heads = olmo2_config["n_kv_heads"] # for GQA / MQA
101
+ elif olmo2_config["multi_query_attention"]: # compatibility with other checkpoints
102
+ num_key_value_heads = 1
103
+ else:
104
+ num_key_value_heads = n_heads
105
+
106
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
107
+
108
+ # Not sharded
109
+ # (The sharded implementation would also work, but this is simpler.)
110
+ loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True)
111
+
112
+ param_count = 0
113
+ index_dict: Dict[str, Any] = {"weight_map": {}}
114
+ for layer_i in range(n_layers):
115
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
116
+ # Unsharded
117
+ # TODO: Layernorm stuff
118
+ # TODO: multi query attention
119
+ fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
120
+ q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
121
+ loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
122
+ )
123
+ up_proj_weight, gate_proj_weight = torch.chunk(
124
+ loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0
125
+ )
126
+ state_dict = {
127
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
128
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
129
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
130
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
131
+ f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"],
132
+ f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"],
133
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
134
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
135
+ f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
136
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
137
+ f"transformer.blocks.{layer_i}.attn_norm.weight"
138
+ ],
139
+ f"model.layers.{layer_i}.post_feedforward_layernorm.weight": loaded[
140
+ f"transformer.blocks.{layer_i}.ff_norm.weight"
141
+ ],
142
+ }
143
+
144
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
145
+
146
+ for k, v in state_dict.items():
147
+ index_dict["weight_map"][k] = filename
148
+ param_count += v.numel()
149
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
150
+
151
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
152
+
153
+ # Unsharded
154
+ # TODO: Deal with weight-tying
155
+ state_dict = {
156
+ "model.embed_tokens.weight": loaded["transformer.wte.weight"],
157
+ "model.norm.weight": loaded["transformer.ln_f.weight"],
158
+ "lm_head.weight": loaded["transformer.ff_out.weight"]
159
+ if "transformer.ff_out.weight" in loaded
160
+ else loaded["transformer.wte.weight"],
161
+ }
162
+
163
+ for k, v in state_dict.items():
164
+ index_dict["weight_map"][k] = filename
165
+ param_count += v.numel()
166
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
167
+
168
+ # Write configs
169
+ index_dict["metadata"] = {"total_size": param_count * 2}
170
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
171
+
172
+ if olmo2_config.get("mlp_hidden_size", None) is not None:
173
+ intermediate_size = olmo2_config["mlp_hidden_size"] // 2
174
+ else:
175
+ intermediate_size = (dim * olmo2_config["mlp_ratio"]) // 2
176
+
177
+ if fix_eos_token_id and olmo2_config["eos_token_id"] == 0:
178
+ # Fixing a bug in OLMo where eos token id was incorrectly set
179
+ print("Changing eos_token_id from 0 to 50279.")
180
+ olmo2_config["eos_token_id"] = 50279
181
+
182
+ config = Olmo2Config(
183
+ vocab_size=vocab_size,
184
+ hidden_size=dim,
185
+ intermediate_size=intermediate_size,
186
+ num_hidden_layers=n_layers,
187
+ num_attention_heads=n_heads,
188
+ num_key_value_heads=num_key_value_heads,
189
+ max_position_embeddings=max_position_embeddings,
190
+ pad_token_id=olmo2_config["pad_token_id"],
191
+ bos_token_id=None,
192
+ eos_token_id=olmo2_config["eos_token_id"],
193
+ tie_word_embeddings=olmo2_config["weight_tying"],
194
+ rms_norm_eps=olmo2_config["layer_norm_eps"],
195
+ rope_theta=base,
196
+ )
197
+ config.save_pretrained(tmp_model_path)
198
+
199
+ # Make space so we can load the model properly now.
200
+ del state_dict
201
+ del loaded
202
+ gc.collect()
203
+
204
+ if include_tokenizer:
205
+ _write_tokenizer(model_path, config, input_base_path, tokenizer_path)
206
+
207
+ print("Loading the checkpoint in a OLMo2 model.")
208
+ model = Olmo2ForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
209
+ # Avoid saving this as part of the config.
210
+ del model.config._name_or_path
211
+ print("Saving in the Transformers format.")
212
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
213
+ if tmp_cleanup:
214
+ # Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes
215
+ # errors if using NFS.
216
+ shutil.rmtree(tmp_model_path)
217
+
218
+
219
+ def _write_tokenizer(
220
+ output_path: Path,
221
+ config: Olmo2Config,
222
+ checkpoint_dir: str,
223
+ input_tokenizer_path: Path | None,
224
+ ) -> None:
225
+ print(f"Saving a {GPT2TokenizerFast.__name__} to {output_path}.")
226
+
227
+ if input_tokenizer_path is not None:
228
+ base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
229
+ else:
230
+ config_path = Path(checkpoint_dir) / "config.yaml"
231
+ tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"]
232
+
233
+ # Initialize tokenizer and validate vocab size.
234
+ if Path(tokenizer_config["identifier"]).is_file():
235
+ base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"])
236
+ else:
237
+ base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"])
238
+
239
+ eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
240
+ pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id
241
+
242
+ tokenizer = GPT2TokenizerFast(
243
+ tokenizer_object=base_tokenizer,
244
+ eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
245
+ pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False),
246
+ )
247
+
248
+ tokenizer.save_pretrained(output_path)
249
+
250
+
251
+ def main():
252
+ parser = argparse.ArgumentParser()
253
+ parser.add_argument(
254
+ "--input_dir",
255
+ required=True,
256
+ help="Location of OLMo2 weights, which contains config.yaml and model.pt.",
257
+ )
258
+ parser.add_argument(
259
+ "--no_tokenizer",
260
+ action="store_false",
261
+ dest="include_tokenizer",
262
+ help="If set, do not convert OLMo tokenizer to HF tokenizer.",
263
+ )
264
+ parser.add_argument(
265
+ "--tokenizer_json_path",
266
+ type=Path,
267
+ default=None,
268
+ help="Location of OLMo2 tokenizer json file. Defaults to what is set in the config file.",
269
+ )
270
+ parser.add_argument(
271
+ "--output_dir",
272
+ required=True,
273
+ help="Location to write HF model and tokenizer",
274
+ )
275
+ parser.add_argument(
276
+ "--no_fix_eos_token_id",
277
+ action="store_false",
278
+ dest="fix_eos_token_id",
279
+ help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
280
+ )
281
+ parser.add_argument(
282
+ "--no_tmp_cleanup",
283
+ action="store_false",
284
+ dest="tmp_cleanup",
285
+ help="If passed, don't remove temp dir at end of HF conversion.",
286
+ )
287
+ parser.add_argument(
288
+ "--no_safe_serialization",
289
+ action="store_false",
290
+ dest="safe_serialization",
291
+ help="Whether or not to save using `safetensors`.",
292
+ )
293
+ args = parser.parse_args()
294
+ write_model(
295
+ model_path=args.output_dir,
296
+ input_base_path=args.input_dir,
297
+ safe_serialization=args.safe_serialization,
298
+ include_tokenizer=args.include_tokenizer,
299
+ tokenizer_path=args.tokenizer_json_path,
300
+ fix_eos_token_id=args.fix_eos_token_id,
301
+ tmp_cleanup=args.tmp_cleanup,
302
+ )
303
+
304
+
305
+ if __name__ == "__main__":
306
+ main()
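The core of `write_model` is slicing the fused OLMo2 projections back into the separate HF weights. A toy-sized sketch of those two splits, with made-up dimensions:

```python
# Toy-sized sketch of the weight splits performed in write_model; dimensions
# are made up for illustration.
import torch

dim, n_heads, num_key_value_heads = 16, 4, 2
dims_per_head = dim // n_heads

# The fused att_proj holds the Q rows, then the K rows, then the V rows.
fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
att_proj = torch.randn(sum(fused_dims), dim)
q_w, k_w, v_w = torch.split(att_proj, fused_dims, dim=0)
print(q_w.shape, k_w.shape, v_w.shape)  # (16, 16), (8, 16), (8, 16)

# The fused ff_proj holds the up_proj rows followed by the gate_proj rows.
mlp_hidden_size = 64
ff_proj = torch.randn(mlp_hidden_size, dim)
up_w, gate_w = torch.chunk(ff_proj, 2, dim=0)
print(up_w.shape, gate_w.shape)  # (32, 16), (32, 16)
```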
docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py ADDED
@@ -0,0 +1,820 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_olmo2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from ...activations import ACT2FN
13
+ from ...cache_utils import Cache, DynamicCache, StaticCache
14
+ from ...generation import GenerationMixin
15
+ from ...integrations import use_kernel_forward_from_hub
16
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
17
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from ...modeling_layers import GradientCheckpointingLayer
19
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
20
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
21
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
22
+ from ...processing_utils import Unpack
23
+ from ...utils import (
24
+ LossKwargs,
25
+ add_start_docstrings,
26
+ add_start_docstrings_to_model_forward,
27
+ can_return_tuple,
28
+ is_torch_flex_attn_available,
29
+ logging,
30
+ replace_return_docstrings,
31
+ )
32
+ from .configuration_olmo2 import Olmo2Config
33
+
34
+
35
+ if is_torch_flex_attn_available():
36
+ from torch.nn.attention.flex_attention import BlockMask
37
+
38
+ from ...integrations.flex_attention import make_flex_block_causal_mask
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+ _CONFIG_FOR_DOC = "Olmo2Config"
43
+
44
+
45
+ @use_kernel_forward_from_hub("RMSNorm")
46
+ class Olmo2RMSNorm(nn.Module):
47
+ def __init__(self, hidden_size, eps=1e-6):
48
+ """
49
+ Olmo2RMSNorm is equivalent to T5LayerNorm
50
+ """
51
+ super().__init__()
52
+ self.weight = nn.Parameter(torch.ones(hidden_size))
53
+ self.variance_epsilon = eps
54
+
55
+ def forward(self, hidden_states):
56
+ input_dtype = hidden_states.dtype
57
+ hidden_states = hidden_states.to(torch.float32)
58
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
59
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
60
+ return self.weight * hidden_states.to(input_dtype)
61
+
62
+ def extra_repr(self):
63
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
64
+
65
+
66
+ def rotate_half(x):
67
+ """Rotates half the hidden dims of the input."""
68
+ x1 = x[..., : x.shape[-1] // 2]
69
+ x2 = x[..., x.shape[-1] // 2 :]
70
+ return torch.cat((-x2, x1), dim=-1)
71
+
72
+
73
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
74
+ """Applies Rotary Position Embedding to the query and key tensors.
75
+
76
+ Args:
77
+ q (`torch.Tensor`): The query tensor.
78
+ k (`torch.Tensor`): The key tensor.
79
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
80
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
81
+ position_ids (`torch.Tensor`, *optional*):
82
+ Deprecated and unused.
83
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
84
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
85
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
86
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
87
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
88
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
89
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
90
+ Returns:
91
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
92
+ """
93
+ cos = cos.unsqueeze(unsqueeze_dim)
94
+ sin = sin.unsqueeze(unsqueeze_dim)
95
+ q_embed = (q * cos) + (rotate_half(q) * sin)
96
+ k_embed = (k * cos) + (rotate_half(k) * sin)
97
+ return q_embed, k_embed
98
+
99
+
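A tiny numeric check of `rotate_half` above: the second half of the last dimension is negated and moved to the front.

```python
# rotate_half on a 4-vector; values chosen to make the swap visible.
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
x1, x2 = x[: x.shape[-1] // 2], x[x.shape[-1] // 2 :]
print(torch.cat((-x2, x1), dim=-1))  # tensor([-3., -4., 1., 2.])
```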
100
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
101
+ """
102
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
103
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
104
+ """
105
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
106
+ if n_rep == 1:
107
+ return hidden_states
108
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
109
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
110
+
111
+
112
+ def eager_attention_forward(
113
+ module: nn.Module,
114
+ query: torch.Tensor,
115
+ key: torch.Tensor,
116
+ value: torch.Tensor,
117
+ attention_mask: Optional[torch.Tensor],
118
+ scaling: float,
119
+ dropout: float = 0.0,
120
+ **kwargs,
121
+ ):
122
+ key_states = repeat_kv(key, module.num_key_value_groups)
123
+ value_states = repeat_kv(value, module.num_key_value_groups)
124
+
125
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
126
+ if attention_mask is not None:
127
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
128
+ attn_weights = attn_weights + causal_mask
129
+
130
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
131
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
132
+ attn_output = torch.matmul(attn_weights, value_states)
133
+ attn_output = attn_output.transpose(1, 2).contiguous()
134
+
135
+ return attn_output, attn_weights
136
+
137
+
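To see what `repeat_kv` does for grouped-query attention, here is a numeric sketch in which 2 key/value heads are expanded to serve 4 query heads (`n_rep = 2`); all sizes are illustrative.

```python
# Numeric sketch of repeat_kv; sizes are illustrative.
import torch

batch, kv_heads, seq, head_dim, n_rep = 1, 2, 3, 4, 2
kv = torch.randn(batch, kv_heads, seq, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seq, head_dim)
print(expanded.shape)  # torch.Size([1, 4, 3, 4])

# Each KV head is shared by its whole group of query heads:
assert torch.equal(expanded[0, 0], expanded[0, 1])
assert torch.equal(expanded[0, 2], expanded[0, 3])
```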
138
+ class Olmo2Attention(nn.Module):
139
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
140
+
141
+ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
142
+ super().__init__()
143
+ self.config = config
144
+ self.layer_idx = layer_idx
145
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
146
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
147
+ self.scaling = self.head_dim**-0.5
148
+ self.attention_dropout = config.attention_dropout
149
+ self.is_causal = True
150
+
151
+ self.q_proj = nn.Linear(
152
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
153
+ )
154
+ self.k_proj = nn.Linear(
155
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
156
+ )
157
+ self.v_proj = nn.Linear(
158
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
159
+ )
160
+ self.o_proj = nn.Linear(
161
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
162
+ )
163
+ self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
164
+ self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
165
+
166
+ def forward(
167
+ self,
168
+ hidden_states: torch.Tensor,
169
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
170
+ attention_mask: Optional[torch.Tensor],
171
+ past_key_value: Optional[Cache] = None,
172
+ cache_position: Optional[torch.LongTensor] = None,
173
+ **kwargs,
174
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
175
+ input_shape = hidden_states.shape[:-1]
176
+ hidden_shape = (*input_shape, -1, self.head_dim)
177
+
178
+ query_states = self.q_norm(self.q_proj(hidden_states))
179
+ key_states = self.k_norm(self.k_proj(hidden_states))
180
+ value_states = self.v_proj(hidden_states)
181
+
182
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
183
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
184
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
185
+
186
+ cos, sin = position_embeddings
187
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
188
+
189
+ if past_key_value is not None:
190
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
191
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
192
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
193
+
194
+ attention_interface: Callable = eager_attention_forward
195
+ if self.config._attn_implementation != "eager":
196
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
197
+ logger.warning_once(
198
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
199
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
200
+ )
201
+ else:
202
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
203
+
204
+ attn_output, attn_weights = attention_interface(
205
+ self,
206
+ query_states,
207
+ key_states,
208
+ value_states,
209
+ attention_mask,
210
+ dropout=0.0 if not self.training else self.attention_dropout,
211
+ scaling=self.scaling,
212
+ **kwargs,
213
+ )
214
+
215
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
216
+ attn_output = self.o_proj(attn_output)
217
+ return attn_output, attn_weights
218
+
219
+
220
+ class Olmo2MLP(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.config = config
224
+ self.hidden_size = config.hidden_size
225
+ self.intermediate_size = config.intermediate_size
226
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
227
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
228
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
229
+ self.act_fn = ACT2FN[config.hidden_act]
230
+
231
+ def forward(self, x):
232
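+ # SwiGLU-style MLP: activation-gated gate projection times the up projection, then projected back down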
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
233
+ return down_proj
234
+
235
+
236
+ class Olmo2DecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: Olmo2Config, layer_idx: int):
238
+ super().__init__()
239
+ self.hidden_size = config.hidden_size
240
+ self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
241
+
242
+ self.mlp = Olmo2MLP(config)
243
+ self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
244
+ self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
245
+
246
+ def forward(
247
+ self,
248
+ hidden_states: torch.Tensor,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ position_ids: Optional[torch.LongTensor] = None,
251
+ past_key_value: Optional[Cache] = None,
252
+ output_attentions: Optional[bool] = False,
253
+ use_cache: Optional[bool] = False,
254
+ cache_position: Optional[torch.LongTensor] = None,
255
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
256
+ **kwargs,
257
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
258
+ residual = hidden_states
259
+
260
+ # Self Attention
261
+ hidden_states, self_attn_weights = self.self_attn(
262
+ hidden_states=hidden_states,
263
+ attention_mask=attention_mask,
264
+ position_ids=position_ids,
265
+ past_key_value=past_key_value,
266
+ output_attentions=output_attentions,
267
+ use_cache=use_cache,
268
+ cache_position=cache_position,
269
+ position_embeddings=position_embeddings,
270
+ **kwargs,
271
+ )
272
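+ # Post-norm: OLMo2 normalizes the attention output before the residual add, instead of normalizing the input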
+ hidden_states = self.post_attention_layernorm(hidden_states)
273
+ hidden_states = residual + hidden_states
274
+
275
+ # Fully Connected
276
+ residual = hidden_states
277
+ hidden_states = self.mlp(hidden_states)
278
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
279
+ hidden_states = residual + hidden_states
280
+
281
+ outputs = (hidden_states,)
282
+ if output_attentions:
283
+ outputs += (self_attn_weights,)
284
+
285
+ return outputs
286
+
287
+
288
+ class Olmo2RotaryEmbedding(nn.Module):
289
+ def __init__(self, config: Olmo2Config, device=None):
290
+ super().__init__()
291
+ # BC: "rope_type" was originally "type"
292
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
293
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
294
+ else:
295
+ self.rope_type = "default"
296
+ self.max_seq_len_cached = config.max_position_embeddings
297
+ self.original_max_seq_len = config.max_position_embeddings
298
+
299
+ self.config = config
300
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
301
+
302
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
303
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
304
+ self.original_inv_freq = self.inv_freq
305
+
306
+ @torch.no_grad()
307
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
308
+ def forward(self, x, position_ids):
309
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
310
+ position_ids_expanded = position_ids[:, None, :].float()
311
+
312
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
313
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
314
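+ # Outer product of inverse frequencies and positions yields the per-position rotation angles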
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
315
+ emb = torch.cat((freqs, freqs), dim=-1)
316
+ cos = emb.cos() * self.attention_scaling
317
+ sin = emb.sin() * self.attention_scaling
318
+
319
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
320
+
321
+
322
+ OLMO2_START_DOCSTRING = r"""
323
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
324
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
325
+ etc.)
326
+
327
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
328
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
329
+ and behavior.
330
+
331
+ Parameters:
332
+ config ([`Olmo2Config`]):
333
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
334
+ load the weights associated with the model, only the configuration. Check out the
335
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
336
+ """
337
+
338
+
339
+ @add_start_docstrings(
340
+ "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.",
341
+ OLMO2_START_DOCSTRING,
342
+ )
343
+ class Olmo2PreTrainedModel(PreTrainedModel):
344
+ config_class = Olmo2Config
345
+ base_model_prefix = "model"
346
+ supports_gradient_checkpointing = True
347
+ _no_split_modules = ["Olmo2DecoderLayer"]
348
+ _skip_keys_device_placement = ["past_key_values"]
349
+ _supports_flash_attn_2 = True
350
+ _supports_sdpa = True
351
+ _supports_flex_attn = True
352
+ _supports_cache_class = True
353
+ _supports_quantized_cache = True
354
+ _supports_static_cache = True
355
+ _supports_attention_backend = True
356
+
357
+ def _init_weights(self, module):
358
+ std = self.config.initializer_range
359
+ if isinstance(module, nn.Linear):
360
+ module.weight.data.normal_(mean=0.0, std=std)
361
+ if module.bias is not None:
362
+ module.bias.data.zero_()
363
+ elif isinstance(module, nn.Embedding):
364
+ module.weight.data.normal_(mean=0.0, std=std)
365
+ if module.padding_idx is not None:
366
+ module.weight.data[module.padding_idx].zero_()
367
+ elif isinstance(module, Olmo2RMSNorm):
368
+ module.weight.data.fill_(1.0)
369
+
370
+
371
+ OLMO2_INPUTS_DOCSTRING = r"""
372
+ Args:
373
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
374
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
375
+ it.
376
+
377
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
378
+ [`PreTrainedTokenizer.__call__`] for details.
379
+
380
+ [What are input IDs?](../glossary#input-ids)
381
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
382
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
383
+
384
+ - 1 for tokens that are **not masked**,
385
+ - 0 for tokens that are **masked**.
386
+
387
+ If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
388
+ but you can also pass a `BlockMask` object directly here.
389
+
390
+ [What are attention masks?](../glossary#attention-mask)
391
+
392
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
393
+ [`PreTrainedTokenizer.__call__`] for details.
394
+
395
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
396
+ `past_key_values`).
397
+
398
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
399
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
400
+ information on the default strategy.
401
+
402
+ - 1 indicates the head is **not masked**,
403
+ - 0 indicates the head is **masked**.
404
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
405
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
406
+ config.n_positions - 1]`.
407
+
408
+ [What are position IDs?](../glossary#position-ids)
409
+ past_key_values (`Cache`, *optional*):
410
+ Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
411
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
412
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
413
+
414
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
415
+
416
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
417
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
418
+ of shape `(batch_size, sequence_length)`.
419
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
420
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
421
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
422
+ model's internal embedding lookup matrix.
423
+ use_cache (`bool`, *optional*):
424
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
425
+ `past_key_values`).
426
+ output_attentions (`bool`, *optional*):
427
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
428
+ tensors for more detail.
429
+ output_hidden_states (`bool`, *optional*):
430
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
431
+ more detail.
432
+ return_dict (`bool`, *optional*):
433
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
434
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
435
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
436
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
437
+ the complete sequence length.
438
+ """
439
+
440
+
441
+ @add_start_docstrings(
442
+ "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.",
443
+ OLMO2_START_DOCSTRING,
444
+ )
445
+ class Olmo2Model(Olmo2PreTrainedModel):
446
+ """
447
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo2DecoderLayer`]
448
+
449
+ Args:
450
+ config: Olmo2Config
451
+ """
452
+
453
+ def __init__(self, config: Olmo2Config):
454
+ super().__init__(config)
455
+ self.padding_idx = config.pad_token_id
456
+ self.vocab_size = config.vocab_size
457
+
458
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
459
+ self.layers = nn.ModuleList(
460
+ [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
461
+ )
462
+ self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
463
+ self.rotary_emb = Olmo2RotaryEmbedding(config=config)
464
+ self.gradient_checkpointing = False
465
+
466
+ # Initialize weights and apply final processing
467
+ self.post_init()
468
+
469
+ def get_input_embeddings(self):
470
+ return self.embed_tokens
471
+
472
+ def set_input_embeddings(self, value):
473
+ self.embed_tokens = value
474
+
475
+ @can_return_tuple
476
+ @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
477
+ def forward(
478
+ self,
479
+ input_ids: Optional[torch.LongTensor] = None,
480
+ attention_mask: Optional[torch.Tensor] = None,
481
+ position_ids: Optional[torch.LongTensor] = None,
482
+ past_key_values: Optional[Cache] = None,
483
+ inputs_embeds: Optional[torch.FloatTensor] = None,
484
+ use_cache: Optional[bool] = None,
485
+ output_attentions: Optional[bool] = None,
486
+ output_hidden_states: Optional[bool] = None,
487
+ cache_position: Optional[torch.LongTensor] = None,
488
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
489
+ ) -> BaseModelOutputWithPast:
490
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
491
+ output_hidden_states = (
492
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
493
+ )
494
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
495
+
496
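+ # XOR guard: exactly one of input_ids or inputs_embeds must be provided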
+ if (input_ids is None) ^ (inputs_embeds is not None):
497
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
498
+
499
+ if self.gradient_checkpointing and self.training and use_cache:
500
+ logger.warning_once(
501
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
502
+ )
503
+ use_cache = False
504
+
505
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
506
+ if not isinstance(past_key_values, (type(None), Cache)):
507
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
508
+
509
+ if inputs_embeds is None:
510
+ inputs_embeds = self.embed_tokens(input_ids)
511
+
512
+ if use_cache and past_key_values is None:
513
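+ # Lazily create a DynamicCache that grows with the generated sequence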
+ past_key_values = DynamicCache()
514
+
515
+ if cache_position is None:
516
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
517
+ cache_position = torch.arange(
518
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
519
+ )
520
+
521
+ if position_ids is None:
522
+ position_ids = cache_position.unsqueeze(0)
523
+
524
+ causal_mask = self._update_causal_mask(
525
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
526
+ )
527
+
528
+ hidden_states = inputs_embeds
529
+
530
+ # create position embeddings to be shared across the decoder layers
531
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
532
+
533
+ # decoder layers
534
+ all_hidden_states = () if output_hidden_states else None
535
+ all_self_attns = () if output_attentions else None
536
+
537
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
538
+ if output_hidden_states:
539
+ all_hidden_states += (hidden_states,)
540
+
541
+ layer_outputs = decoder_layer(
542
+ hidden_states,
543
+ attention_mask=causal_mask,
544
+ position_ids=position_ids,
545
+ past_key_value=past_key_values,
546
+ output_attentions=output_attentions,
547
+ use_cache=use_cache,
548
+ cache_position=cache_position,
549
+ position_embeddings=position_embeddings,
550
+ **flash_attn_kwargs,
551
+ )
552
+
553
+ hidden_states = layer_outputs[0]
554
+
555
+ if output_attentions:
556
+ all_self_attns += (layer_outputs[1],)
557
+
558
+ hidden_states = self.norm(hidden_states)
559
+
560
+ # add hidden states from the last decoder layer
561
+ if output_hidden_states:
562
+ all_hidden_states += (hidden_states,)
563
+
564
+ return BaseModelOutputWithPast(
565
+ last_hidden_state=hidden_states,
566
+ past_key_values=past_key_values if use_cache else None,
567
+ hidden_states=all_hidden_states,
568
+ attentions=all_self_attns,
569
+ )
570
+
571
+ def _update_causal_mask(
572
+ self,
573
+ attention_mask: Union[torch.Tensor, "BlockMask"],
574
+ input_tensor: torch.Tensor,
575
+ cache_position: torch.Tensor,
576
+ past_key_values: Cache,
577
+ output_attentions: bool = False,
578
+ ):
579
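+ # Flash Attention 2 applies causal masking internally; only forward the 2D mask when padding is present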
+ if self.config._attn_implementation == "flash_attention_2":
580
+ if attention_mask is not None and (attention_mask == 0.0).any():
581
+ return attention_mask
582
+ return None
583
+ if self.config._attn_implementation == "flex_attention":
584
+ if isinstance(attention_mask, torch.Tensor):
585
+ attention_mask = make_flex_block_causal_mask(attention_mask)
586
+ return attention_mask
587
+
588
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
589
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
590
+ # to infer the attention mask.
591
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
592
+ using_static_cache = isinstance(past_key_values, StaticCache)
593
+
594
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
595
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
596
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
597
+ attention_mask,
598
+ inputs_embeds=input_tensor,
599
+ past_key_values_length=past_seen_tokens,
600
+ is_training=self.training,
601
+ ):
602
+ return None
603
+
604
+ dtype, device = input_tensor.dtype, input_tensor.device
605
+ sequence_length = input_tensor.shape[1]
606
+ if using_static_cache:
607
+ target_length = past_key_values.get_max_cache_shape()
608
+ else:
609
+ target_length = (
610
+ attention_mask.shape[-1]
611
+ if isinstance(attention_mask, torch.Tensor)
612
+ else past_seen_tokens + sequence_length + 1
613
+ )
614
+
615
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
616
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
617
+ attention_mask,
618
+ sequence_length=sequence_length,
619
+ target_length=target_length,
620
+ dtype=dtype,
621
+ device=device,
622
+ cache_position=cache_position,
623
+ batch_size=input_tensor.shape[0],
624
+ )
625
+
626
+ if (
627
+ self.config._attn_implementation == "sdpa"
628
+ and attention_mask is not None
629
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
630
+ and not output_attentions
631
+ ):
632
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
633
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
634
+ # Details: https://github.com/pytorch/pytorch/issues/110213
635
+ min_dtype = torch.finfo(dtype).min
636
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
637
+
638
+ return causal_mask
639
+
640
+ @staticmethod
641
+ def _prepare_4d_causal_attention_mask_with_cache_position(
642
+ attention_mask: torch.Tensor,
643
+ sequence_length: int,
644
+ target_length: int,
645
+ dtype: torch.dtype,
646
+ device: torch.device,
647
+ cache_position: torch.Tensor,
648
+ batch_size: int,
649
+ **kwargs,
650
+ ):
651
+ """
652
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
653
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
654
+
655
+ Args:
656
+ attention_mask (`torch.Tensor`):
657
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
658
+ `(batch_size, 1, query_length, key_value_length)`.
659
+ sequence_length (`int`):
660
+ The sequence length being processed.
661
+ target_length (`int`):
662
+ The target length: when generating with static cache, the mask should be as long as the static cache,
663
+ to account for the 0 padding, the part of the cache that is not filled yet.
664
+ dtype (`torch.dtype`):
665
+ The dtype to use for the 4D attention mask.
666
+ device (`torch.device`):
667
+ The device to place the 4D attention mask on.
668
+ cache_position (`torch.Tensor`):
669
+ Indices depicting the position of the input sequence tokens in the sequence.
670
+ batch_size (`torch.Tensor`):
671
+ Batch size.
672
+ """
673
+ if attention_mask is not None and attention_mask.dim() == 4:
674
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
675
+ causal_mask = attention_mask
676
+ else:
677
+ min_dtype = torch.finfo(dtype).min
678
+ causal_mask = torch.full(
679
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
680
+ )
681
+ if sequence_length != 1:
682
+ causal_mask = torch.triu(causal_mask, diagonal=1)
683
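+ # Positions beyond each token's cache position stay masked; current and earlier positions become attendable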
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
684
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
685
+ if attention_mask is not None:
686
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
687
+ mask_length = attention_mask.shape[-1]
688
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
689
+ causal_mask.device
690
+ )
691
+ padding_mask = padding_mask == 0
692
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
693
+ padding_mask, min_dtype
694
+ )
695
+
696
+ return causal_mask
697
+
698
+
699
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
700
+
701
+
702
+ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
703
+ _tied_weights_keys = ["lm_head.weight"]
704
+ _tp_plan = {"lm_head": "colwise_rep"}
705
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
706
+
707
+ def __init__(self, config):
708
+ super().__init__(config)
709
+ self.model = Olmo2Model(config)
710
+ self.vocab_size = config.vocab_size
711
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
712
+
713
+ # Initialize weights and apply final processing
714
+ self.post_init()
715
+
716
+ def get_input_embeddings(self):
717
+ return self.model.embed_tokens
718
+
719
+ def set_input_embeddings(self, value):
720
+ self.model.embed_tokens = value
721
+
722
+ def get_output_embeddings(self):
723
+ return self.lm_head
724
+
725
+ def set_output_embeddings(self, new_embeddings):
726
+ self.lm_head = new_embeddings
727
+
728
+ def set_decoder(self, decoder):
729
+ self.model = decoder
730
+
731
+ def get_decoder(self):
732
+ return self.model
733
+
734
+ @can_return_tuple
735
+ @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
736
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
737
+ def forward(
738
+ self,
739
+ input_ids: Optional[torch.LongTensor] = None,
740
+ attention_mask: Optional[torch.Tensor] = None,
741
+ position_ids: Optional[torch.LongTensor] = None,
742
+ past_key_values: Optional[Cache] = None,
743
+ inputs_embeds: Optional[torch.FloatTensor] = None,
744
+ labels: Optional[torch.LongTensor] = None,
745
+ use_cache: Optional[bool] = None,
746
+ output_attentions: Optional[bool] = None,
747
+ output_hidden_states: Optional[bool] = None,
748
+ cache_position: Optional[torch.LongTensor] = None,
749
+ logits_to_keep: Union[int, torch.Tensor] = 0,
750
+ **kwargs: Unpack[KwargsForCausalLM],
751
+ ) -> CausalLMOutputWithPast:
752
+ r"""
753
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
754
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
755
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
756
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
757
+
758
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
759
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
760
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
761
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
762
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
763
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
764
+
765
+ Returns:
766
+
767
+ Example:
768
+
769
+ ```python
770
+ >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
771
+
772
+ >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-7B-1124-hf")
773
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-7B-1124-hf")
774
+
775
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
776
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
777
+
778
+ >>> # Generate
779
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
780
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
781
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
782
+ ```"""
783
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
784
+ output_hidden_states = (
785
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
786
+ )
787
+
788
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
789
+ outputs: BaseModelOutputWithPast = self.model(
790
+ input_ids=input_ids,
791
+ attention_mask=attention_mask,
792
+ position_ids=position_ids,
793
+ past_key_values=past_key_values,
794
+ inputs_embeds=inputs_embeds,
795
+ use_cache=use_cache,
796
+ output_attentions=output_attentions,
797
+ output_hidden_states=output_hidden_states,
798
+ cache_position=cache_position,
799
+ **kwargs,
800
+ )
801
+
802
+ hidden_states = outputs.last_hidden_state
803
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
804
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
805
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
806
+
807
+ loss = None
808
+ if labels is not None:
809
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
810
+
811
+ return CausalLMOutputWithPast(
812
+ loss=loss,
813
+ logits=logits,
814
+ past_key_values=outputs.past_key_values,
815
+ hidden_states=outputs.hidden_states,
816
+ attentions=outputs.attentions,
817
+ )
818
+
819
+
820
+ __all__ = ["Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]
docs/transformers/src/transformers/models/olmo2/modular_olmo2.py ADDED
@@ -0,0 +1,320 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from ...cache_utils import Cache
7
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
8
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
9
+ from ...utils import logging
10
+ from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
11
+ from ..olmo.configuration_olmo import OlmoConfig
12
+ from ..olmo.modeling_olmo import (
13
+ OlmoAttention,
14
+ OlmoDecoderLayer,
15
+ OlmoForCausalLM,
16
+ OlmoModel,
17
+ OlmoRotaryEmbedding,
18
+ apply_rotary_pos_emb,
19
+ )
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Olmo2Config(OlmoConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
28
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 50304):
37
+ Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`Olmo2Model`]
39
+ hidden_size (`int`, *optional*, defaults to 4096):
40
+ Dimension of the hidden representations.
41
+ intermediate_size (`int`, *optional*, defaults to 11008):
42
+ Dimension of the MLP representations.
43
+ num_hidden_layers (`int`, *optional*, defaults to 32):
44
+ Number of hidden layers in the Transformer decoder.
45
+ num_attention_heads (`int`, *optional*, defaults to 32):
46
+ Number of attention heads for each attention layer in the Transformer decoder.
47
+ num_key_value_heads (`int`, *optional*):
48
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
49
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
50
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
51
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
52
+ by meanpooling all the original heads within that group. For more details checkout [this
53
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
54
+ `num_attention_heads`.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
56
+ The non-linear activation function (function or string) in the decoder.
57
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
58
+ The maximum sequence length that this model might ever be used with.
59
+ initializer_range (`float`, *optional*, defaults to 0.02):
60
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61
+ use_cache (`bool`, *optional*, defaults to `True`):
62
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
63
+ relevant if `config.is_decoder=True`.
64
+ pad_token_id (`int`, *optional*, defaults to 1):
65
+ Padding token id.
66
+ bos_token_id (`int`, *optional*):
67
+ Beginning of stream token id.
68
+ eos_token_id (`int`, *optional*, defaults to 50279):
69
+ End of stream token id.
70
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
71
+ Whether to tie weight embeddings
72
+ rope_theta (`float`, *optional*, defaults to 10000.0):
73
+ The base period of the RoPE embeddings.
74
+ rope_scaling (`Dict`, *optional*):
75
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
76
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
77
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
78
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
79
+ these scaling strategies behave:
80
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
81
+ experimental feature, subject to breaking API changes in future versions.
82
+ attention_bias (`bool`, *optional*, defaults to `False`):
83
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
84
+ attention_dropout (`float`, *optional*, defaults to 0.0):
85
+ The dropout ratio for the attention probabilities.
86
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
87
+ The epsilon used by the rms normalization layers.
88
+
89
+ ```python
90
+ >>> from transformers import Olmo2Model, Olmo2Config
91
+
92
+ >>> # Initializing a Olmo2 7B style configuration
93
+ >>> configuration = Olmo2Config()
94
+
95
+ >>> # Initializing a model from the Olmo2 7B style configuration
96
+ >>> model = Olmo2Model(configuration)
97
+
98
+ >>> # Accessing the model configuration
99
+ >>> configuration = model.config
100
+ ```
101
+ """
102
+
103
+ model_type = "olmo2"
104
+ base_model_tp_plan = {
105
+ "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
106
+ "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
107
+ "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
108
+ "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
109
+ "layers.*.mlp.gate_proj": "colwise",
110
+ "layers.*.mlp.up_proj": "colwise",
111
+ "layers.*.mlp.down_proj": "rowwise",
112
+ }
113
+ base_model_pp_plan = {
114
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
115
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
116
+ "norm": (["hidden_states"], ["hidden_states"]),
117
+ }
118
+
119
+ def __init__(
120
+ self,
121
+ vocab_size=50304,
122
+ hidden_size=4096,
123
+ intermediate_size=11008,
124
+ num_hidden_layers=32,
125
+ num_attention_heads=32,
126
+ num_key_value_heads=None,
127
+ hidden_act="silu",
128
+ max_position_embeddings=2048,
129
+ initializer_range=0.02,
130
+ use_cache=True,
131
+ pad_token_id=1,
132
+ bos_token_id=None,
133
+ eos_token_id=50279,
134
+ tie_word_embeddings=False,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ attention_bias=False,
138
+ attention_dropout=0.0,
139
+ rms_norm_eps=1e-5,
140
+ **kwargs,
141
+ ):
142
+ super().__init__(
143
+ vocab_size=vocab_size,
144
+ hidden_size=hidden_size,
145
+ intermediate_size=intermediate_size,
146
+ num_hidden_layers=num_hidden_layers,
147
+ num_attention_heads=num_attention_heads,
148
+ num_key_value_heads=num_key_value_heads,
149
+ hidden_act=hidden_act,
150
+ max_position_embeddings=max_position_embeddings,
151
+ initializer_range=initializer_range,
152
+ use_cache=use_cache,
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ tie_word_embeddings=tie_word_embeddings,
157
+ rope_theta=rope_theta,
158
+ rope_scaling=rope_scaling,
159
+ attention_bias=attention_bias,
160
+ attention_dropout=attention_dropout,
161
+ **kwargs,
162
+ )
163
+
164
+ self.rms_norm_eps = rms_norm_eps
165
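+ # OLMo2 drops the qkv clipping used by OLMo, so remove the attribute inherited from OlmoConfig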
+ del self.clip_qkv
166
+
167
+
168
+ class Olmo2RMSNorm(LlamaRMSNorm):
169
+ pass
170
+
171
+
172
+ ALL_LAYERNORM_LAYERS.append(Olmo2RMSNorm)
173
+
174
+
175
+ # Olmo2 attention is identical to OLMo attention except:
176
+ # - Norm is applied to attention queries and keys.
177
+ # - No qkv clipping.
178
+ class Olmo2Attention(OlmoAttention):
179
+ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
180
+ super().__init__(config, layer_idx=layer_idx)
181
+ self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
182
+ self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
183
+
184
+ def forward(
185
+ self,
186
+ hidden_states: torch.Tensor,
187
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
188
+ attention_mask: Optional[torch.Tensor],
189
+ past_key_value: Optional[Cache] = None,
190
+ cache_position: Optional[torch.LongTensor] = None,
191
+ **kwargs,
192
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
193
+ input_shape = hidden_states.shape[:-1]
194
+ hidden_shape = (*input_shape, -1, self.head_dim)
195
+
196
+ query_states = self.q_norm(self.q_proj(hidden_states))
197
+ key_states = self.k_norm(self.k_proj(hidden_states))
198
+ value_states = self.v_proj(hidden_states)
199
+
200
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
201
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
202
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
203
+
204
+ cos, sin = position_embeddings
205
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
206
+
207
+ if past_key_value is not None:
208
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
209
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
210
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
211
+
212
+ attention_interface: Callable = eager_attention_forward
213
+ if self.config._attn_implementation != "eager":
214
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
215
+ logger.warning_once(
216
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
217
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
218
+ )
219
+ else:
220
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
221
+
222
+ attn_output, attn_weights = attention_interface(
223
+ self,
224
+ query_states,
225
+ key_states,
226
+ value_states,
227
+ attention_mask,
228
+ dropout=0.0 if not self.training else self.attention_dropout,
229
+ scaling=self.scaling,
230
+ **kwargs,
231
+ )
232
+
233
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
234
+ attn_output = self.o_proj(attn_output)
235
+ return attn_output, attn_weights
236
+
237
+
238
+ # The OLMo2 layers are identical to those of the OLMo model except:
239
+ # - RMSNorm is used instead of standard layer norm.
240
+ # - Norm is applied after attention/feedforward rather than before.
241
+ class Olmo2DecoderLayer(OlmoDecoderLayer):
242
+ def __init__(self, config: Olmo2Config, layer_idx: int):
243
+ super().__init__(config, layer_idx=layer_idx)
244
+ self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
245
+ self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
246
+ self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
247
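+ # Post-norm architecture: the pre-attention input_layernorm inherited from OLMo is no longer needed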
+ del self.input_layernorm
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states: torch.Tensor,
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ position_ids: Optional[torch.LongTensor] = None,
254
+ past_key_value: Optional[Cache] = None,
255
+ output_attentions: Optional[bool] = False,
256
+ use_cache: Optional[bool] = False,
257
+ cache_position: Optional[torch.LongTensor] = None,
258
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
259
+ **kwargs,
260
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
261
+ residual = hidden_states
262
+
263
+ # Self Attention
264
+ hidden_states, self_attn_weights = self.self_attn(
265
+ hidden_states=hidden_states,
266
+ attention_mask=attention_mask,
267
+ position_ids=position_ids,
268
+ past_key_value=past_key_value,
269
+ output_attentions=output_attentions,
270
+ use_cache=use_cache,
271
+ cache_position=cache_position,
272
+ position_embeddings=position_embeddings,
273
+ **kwargs,
274
+ )
275
+ hidden_states = self.post_attention_layernorm(hidden_states)
276
+ hidden_states = residual + hidden_states
277
+
278
+ # Fully Connected
279
+ residual = hidden_states
280
+ hidden_states = self.mlp(hidden_states)
281
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
282
+ hidden_states = residual + hidden_states
283
+
284
+ outputs = (hidden_states,)
285
+ if output_attentions:
286
+ outputs += (self_attn_weights,)
287
+
288
+ return outputs
289
+
290
+
291
+ class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
292
+ pass
293
+
294
+
295
+ class Olmo2PreTrainedModel(LlamaPreTrainedModel):
296
+ pass
297
+
298
+
299
+ # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
300
+ # standard layer norm for the output norm.
301
+ class Olmo2Model(OlmoModel):
302
+ def __init__(self, config: Olmo2Config):
303
+ super().__init__(config)
304
+ self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
305
+ self.layers = nn.ModuleList(
306
+ [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
307
+ )
308
+
309
+
310
+ # The heads now only need to redefine the model inside to the correct `Olmo2Model`
311
+ class Olmo2ForCausalLM(OlmoForCausalLM):
312
+ pass
313
+
314
+
315
+ __all__ = [
316
+ "Olmo2Config",
317
+ "Olmo2ForCausalLM",
318
+ "Olmo2Model",
319
+ "Olmo2PreTrainedModel", # noqa: F822
320
+ ]
docs/transformers/src/transformers/models/olmoe/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_olmoe import *
22
+ from .modeling_olmoe import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py ADDED
@@ -0,0 +1,182 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ """OLMoE model configuration"""
13
+
14
+ from ...configuration_utils import PretrainedConfig
15
+ from ...modeling_rope_utils import rope_config_validation
16
+
17
+
18
+ class OlmoeConfig(PretrainedConfig):
19
+ r"""
20
+ This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE
21
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
22
+ defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).
23
+
24
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
25
+ documentation from [`PretrainedConfig`] for more information.
26
+
27
+
28
+ Args:
29
+ vocab_size (`int`, *optional*, defaults to 50304):
30
+ Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
31
+ `inputs_ids` passed when calling [`OlmoeModel`]
32
+ hidden_size (`int`, *optional*, defaults to 2048):
33
+ Dimension of the hidden representations.
34
+ intermediate_size (`int`, *optional*, defaults to 2048):
35
+ Dimension of the MLP representations.
36
+ num_hidden_layers (`int`, *optional*, defaults to 16):
37
+ Number of hidden layers in the Transformer decoder.
38
+ num_attention_heads (`int`, *optional*, defaults to 16):
39
+ Number of attention heads for each attention layer in the Transformer decoder.
40
+ num_key_value_heads (`int`, *optional*):
41
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
42
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
43
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
44
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
45
+ by meanpooling all the original heads within that group. For more details checkout [this
46
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
47
+ `num_attention_heads`.
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function (function or string) in the decoder.
50
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
51
+ The maximum sequence length that this model might ever be used with.
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
58
+ relevant if `config.is_decoder=True`.
59
+ pad_token_id (`int`, *optional*, defaults to 1):
60
+ Padding token id.
61
+ bos_token_id (`int`, *optional*):
62
+ Beginning of stream token id.
63
+ eos_token_id (`int`, *optional*, defaults to 50279):
64
+ End of stream token id.
65
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
66
+ Whether to tie weight embeddings
67
+ rope_theta (`float`, *optional*, defaults to 10000.0):
68
+ The base period of the RoPE embeddings.
69
+ rope_scaling (`Dict`, *optional*):
70
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
71
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
72
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
73
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
74
+ these scaling strategies behave:
75
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
76
+ experimental feature, subject to breaking API changes in future versions.
77
+ attention_bias (`bool`, *optional*, defaults to `False`):
78
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
79
+ attention_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout ratio for the attention probabilities.
81
+ clip_qkv (`float`, *optional*):
82
+ If not `None`, elements of query, key and value attention states are clipped so that their
83
+ absolute value does not exceed this value.
84
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
85
+ Number of selected experts.
86
+ num_experts (`int`, *optional*, defaults to 64):
87
+ Number of routed experts.
88
+ output_router_logits (`bool`, *optional*, defaults to `False`):
89
+ Whether or not the router logits should be returned by the model. Enabling this will also
90
+ allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
91
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
92
+ The aux loss factor for the total loss.
93
+ norm_topk_prob (`bool`, *optional*, defaults to `False`):
94
+ Whether to normalize the topk probabilities.
95
+
96
+ ```python
97
+ >>> from transformers import OlmoeModel, OlmoeConfig
98
+
99
+ >>> # Initializing a OLMoE 7B A1B style configuration
100
+ >>> configuration = OlmoeConfig()
101
+
102
+ >>> # Initializing a model from the OLMoE 7B A1B style configuration
103
+ >>> model = OlmoeModel(configuration)
104
+
105
+ >>> # Accessing the model configuration
106
+ >>> configuration = model.config
107
+ ```"""
108
+
109
+ model_type = "olmoe"
110
+ keys_to_ignore_at_inference = ["past_key_values"]
111
+
112
+ def __init__(
113
+ self,
114
+ vocab_size=50304,
115
+ hidden_size=2048,
116
+ intermediate_size=2048,
117
+ num_hidden_layers=16,
118
+ num_attention_heads=16,
119
+ num_key_value_heads=None,
120
+ hidden_act="silu",
121
+ max_position_embeddings=4096,
122
+ initializer_range=0.02,
123
+ rms_norm_eps=1e-05,
124
+ use_cache=True,
125
+ pad_token_id=1,
126
+ bos_token_id=None,
127
+ eos_token_id=50279,
128
+ tie_word_embeddings=False,
129
+ rope_theta=10000.0,
130
+ rope_scaling=None,
131
+ attention_bias=False,
132
+ attention_dropout=0.0,
133
+ clip_qkv=None,
134
+ num_experts_per_tok=8,
135
+ num_experts=64,
136
+ output_router_logits=False,
137
+ router_aux_loss_coef=0.01,
138
+ norm_topk_prob=False,
139
+ **kwargs,
140
+ ):
141
+ self.vocab_size = vocab_size
142
+ self.max_position_embeddings = max_position_embeddings
143
+ self.hidden_size = hidden_size
144
+ self.intermediate_size = intermediate_size
145
+ self.num_hidden_layers = num_hidden_layers
146
+ self.num_attention_heads = num_attention_heads
147
+
148
+ # for backward compatibility
149
+ if num_key_value_heads is None:
150
+ num_key_value_heads = num_attention_heads
151
+
152
+ self.num_key_value_heads = num_key_value_heads
153
+ self.hidden_act = hidden_act
154
+ self.initializer_range = initializer_range
155
+ self.rms_norm_eps = rms_norm_eps
156
+ self.use_cache = use_cache
157
+ self.rope_theta = rope_theta
158
+ self.rope_scaling = rope_scaling
159
+ self.attention_bias = attention_bias
160
+ self.attention_dropout = attention_dropout
161
+ self.clip_qkv = clip_qkv
162
+ self.num_experts_per_tok = num_experts_per_tok
163
+ self.num_experts = num_experts
164
+ self.output_router_logits = output_router_logits
165
+ self.router_aux_loss_coef = router_aux_loss_coef
166
+ self.norm_topk_prob = norm_topk_prob
167
+ # Validate the correctness of rotary position embeddings parameters
168
+ # BC: if there is a 'type' field, move it to 'rope_type'.
169
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
170
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
171
+ rope_config_validation(self)
172
+
173
+ super().__init__(
174
+ pad_token_id=pad_token_id,
175
+ bos_token_id=bos_token_id,
176
+ eos_token_id=eos_token_id,
177
+ tie_word_embeddings=tie_word_embeddings,
178
+ **kwargs,
179
+ )
180
+
181
+
182
+ __all__ = ["OlmoeConfig"]
docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py ADDED
@@ -0,0 +1,281 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ """
13
+ Example for running:
14
+ 0. Cp ckpts to local
15
+ aws s3 cp --recursive s3://ai2-llm/checkpoints/OLMoE/olmoe-8x1b-newhp-newds-final-annealFrom1200000/step23842 /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842
16
+ 1. Unshard your OLMoE checkpoint using https://github.com/allenai/OLMo/blob/7d63fe09d23cf23714da5aa633a44a90180195da/scripts/unshard.py
17
+ python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --model-only
19
+ python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842 /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842-unsharded --model-only
20
+ 2. Convert to transformers
21
+ rm -rf olmoe; mkdir olmoe; python /data/niklas/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py --input_dir /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842-unsharded --tokenizer_json_path /data/niklas/llm/checkpoints/olmoe-step1200000-unsharded/tokenizer.json --output_dir olmoe
22
+ 3. Load model via:
23
+ ```
24
+ from transformers import OlmoeForCausalLM, AutoTokenizer
25
+ import torch
26
+ model = OlmoeForCausalLM.from_pretrained("../transformers/olmoe", torch_dtype=torch.bfloat16).cuda()
27
+ model = OlmoeForCausalLM.from_pretrained("../transformers/olmoe").cuda()
28
+ tokenizer = AutoTokenizer.from_pretrained("../transformers/olmoe")
29
+ inputs = tokenizer("Bitcoin is", return_tensors="pt")
30
+ inputs = {k: v.cuda() for k, v in inputs.items()}
31
+ out = model.generate(**inputs, max_length=64)
32
+ print(tokenizer.decode(out[0]))
33
+ # > # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical
34
+ # Or quick sanity check:
35
+ o = model(torch.tensor([[0, 1]]).cuda())
36
+ # If the checkpoint is not converted to BF16 but kept in FP32:
37
+ # > # Bitcoin is a digital currency that is not controlled by any central authority. It is a peer-to-peer payment system that allows users to send and receive payments from anywhere in the world. Bitcoin is also known as a cryptocurrency because it uses cryptography to secure transactions and prevent fraud.
38
+ ```
39
+
40
+ Note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
41
+ come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
42
+
43
+ Compare with OLMo codebase:
44
+ ```
45
+ from olmo.model import OLMo
46
+ import torch
47
+ model = OLMo.from_checkpoint("/data/niklas/llm/checkpoints/olmoe-step1200000-unsharded-pt")
48
+ model = model.cuda()
49
+ model = model.to(torch.bfloat16)
50
+ from transformers import AutoTokenizer
51
+ tokenizer = AutoTokenizer.from_pretrained("../transformers/olmoe")
52
+ inputs = tokenizer("Bitcoin is", return_tensors="pt")
53
+ inputs = {k: v.cuda() for k, v in inputs.items()}
54
+ out = model.generate(**inputs)
55
+ print(tokenizer.decode(out[0][0][0]))
56
+ # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical problems. It’s the first example of a growing category of money
57
+ # Or quick sanity check:
58
+ o = model(torch.tensor([[0, 1]]).cuda())
59
+ ```
60
+ """
61
+
62
+ import argparse
63
+ import gc
64
+ import json
65
+ import os
66
+ import shutil
67
+ from pathlib import Path
68
+
69
+ import torch
70
+ import yaml
71
+ from tokenizers import Tokenizer
72
+
73
+ from transformers import OlmoeConfig, OlmoeForCausalLM
74
+ from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
75
+
76
+
77
+ def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
78
+ return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
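+ # Sanity-check example (this helper is not called elsewhere in this script):
+ # with the defaults it rounds 8*n/3 up to a multiple of 256, e.g.
+ # compute_intermediate_size(4096) == 11008.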
79
+
80
+
81
+ def read_json(path):
82
+ with open(path, "r") as f:
83
+ return json.load(f)
84
+
85
+
86
+ def write_json(text, path):
87
+ with open(path, "w") as f:
88
+ json.dump(text, f)
89
+
90
+
91
+ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True):
92
+ os.makedirs(model_path, exist_ok=True)
93
+ tmp_model_path = os.path.join(model_path, "tmp")
94
+ os.makedirs(tmp_model_path, exist_ok=True)
95
+
96
+ config_path = Path(input_base_path) / "config.yaml"
97
+ olmoe_config = yaml.safe_load(config_path.read_text())["model"]
98
+
99
+ if fix_eos_token_id:
100
+ olmoe_config["eos_token_id"] = 50279
101
+
102
+ n_layers = olmoe_config["n_layers"]
103
+ n_heads = olmoe_config["n_heads"]
104
+ dim = olmoe_config["d_model"]
105
+ dims_per_head = dim // n_heads
106
+ base = 10000.0
107
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
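+ # Standard RoPE inverse frequencies (theta = 10000); written into each layer's
+ # state dict below as `self_attn.rotary_emb.inv_freq`.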
108
+ max_position_embeddings = olmoe_config["max_sequence_length"]
109
+
110
+ vocab_size = olmoe_config.get("embedding_size", olmoe_config["vocab_size"])
111
+
112
+ if olmoe_config.get("n_kv_heads", None) is not None:
113
+ num_key_value_heads = olmoe_config["n_kv_heads"] # for GQA / MQA
114
+ elif olmoe_config["multi_query_attention"]: # compatibility with other checkpoints
115
+ num_key_value_heads = 1
116
+ else:
117
+ num_key_value_heads = n_heads
118
+
119
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
120
+
121
+ # Not sharded
122
+ loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True)
123
+
124
+ param_count = 0
125
+ index_dict = {"weight_map": {}}
126
+ for layer_i in range(n_layers):
127
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
128
+ fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
129
+ q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
130
+ loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
131
+ )
132
+ state_dict = {
133
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
134
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
135
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
136
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
137
+ f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"],
138
+ f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"],
139
+ f"model.layers.{layer_i}.mlp.gate.weight": loaded[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"],
140
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.blocks.{layer_i}.attn_norm.weight"],
141
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
142
+ f"transformer.blocks.{layer_i}.ff_norm.weight"
143
+ ],
144
+ }
145
+
146
+ num_experts = loaded[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"].shape[0]
147
+ dim_per_expert = loaded[f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"].shape[0] // num_experts
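+ # The checkpoint concatenates all experts' MLP weights into single matrices
+ # (w1 = gate, v1 = up, w2 = down), so we slice out one `dim_per_expert`-sized
+ # block per expert; w2 is stored transposed, hence the `.T.contiguous()` below.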
148
+ for expert_i in range(num_experts):
149
+ state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.gate_proj.weight"] = loaded[
150
+ f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"
151
+ ][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :]
152
+ state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.up_proj.weight"] = loaded[
153
+ f"transformer.blocks.{layer_i}.ffn.experts.mlp.v1"
154
+ ][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :]
155
+ state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.down_proj.weight"] = loaded[
156
+ f"transformer.blocks.{layer_i}.ffn.experts.mlp.w2"
157
+ ][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :].T.contiguous()
158
+
159
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
160
+
161
+ for k, v in state_dict.items():
162
+ index_dict["weight_map"][k] = filename
163
+ param_count += v.numel()
164
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
165
+
166
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
167
+
168
+ # Unsharded
169
+ state_dict = {
170
+ "model.embed_tokens.weight": loaded["transformer.wte.weight"],
171
+ "lm_head.weight": loaded["transformer.ff_out.weight"],
172
+ "model.norm.weight": loaded["transformer.ln_f.weight"],
173
+ }
174
+
175
+ for k, v in state_dict.items():
176
+ index_dict["weight_map"][k] = filename
177
+ param_count += v.numel()
178
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
179
+
180
+ # Write configs
181
+ index_dict["metadata"] = {"total_size": param_count * 2}
182
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
183
+
184
+ config = OlmoeConfig(
185
+ vocab_size=vocab_size,
186
+ hidden_size=dim,
187
+ intermediate_size=dim_per_expert,
188
+ num_hidden_layers=n_layers,
189
+ num_attention_heads=n_heads,
190
+ num_key_value_heads=num_key_value_heads,
191
+ max_position_embeddings=max_position_embeddings,
192
+ pad_token_id=olmoe_config["pad_token_id"],
193
+ bos_token_id=None,
194
+ eos_token_id=olmoe_config["eos_token_id"],
195
+ tie_word_embeddings=olmoe_config["weight_tying"],
196
+ rope_theta=base,
197
+ clip_qkv=olmoe_config.get("clip_qkv"),
198
+ )
199
+ config.save_pretrained(tmp_model_path)
200
+
201
+ # Make space so we can load the model properly now.
202
+ del state_dict
203
+ del loaded
204
+ gc.collect()
205
+
206
+ if tokenizer_path is not None:
207
+ _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id)
208
+
209
+ print("Loading the checkpoint in a OLMoE model.")
210
+ model = OlmoeForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16)
211
+ # Avoid saving this as part of the config.
212
+ del model.config._name_or_path
213
+ print("Saving in the Transformers format.")
214
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
215
+ shutil.rmtree(tmp_model_path)
216
+
217
+
218
+ def _write_tokenizer(
219
+ output_path: Path, config: OlmoeConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True
220
+ ) -> None:
221
+ print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.")
222
+
223
+ base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
224
+
225
+ eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
226
+ pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id
227
+
228
+ if fix_eos_token_id and eos_token_id == 0:
229
+ # Fixing a bug in OLMo where eos token id was incorrectly set
230
+ print("Changing eos_token_id from 0 to 50279.")
231
+ eos_token_id = 50279
232
+
233
+ tokenizer = GPTNeoXTokenizerFast(
234
+ tokenizer_object=base_tokenizer,
235
+ eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
236
+ pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False),
237
+ unk_token=None,
238
+ bos_token=None,
239
+ )
240
+
241
+ tokenizer.save_pretrained(output_path)
242
+
243
+
244
+ def main():
245
+ parser = argparse.ArgumentParser()
246
+ parser.add_argument(
247
+ "--input_dir",
248
+ required=True,
249
+ help="Location of OLMoE weights, which contains config.yaml and model.pt.",
250
+ )
251
+ parser.add_argument(
252
+ "--tokenizer_json_path",
253
+ default=None,
254
+ help="Location of OLMoE tokenizer json file.",
255
+ )
256
+ parser.add_argument(
257
+ "--output_dir",
258
+ required=True,
259
+ help="Location to write HF model and tokenizer",
260
+ )
261
+ parser.add_argument(
262
+ "--no_fix_eos_token_id",
263
+ action="store_false",
264
+ dest="fix_eos_token_id",
265
+ help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
266
+ )
267
+ parser.add_argument(
268
+ "--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`."
269
+ )
270
+ args = parser.parse_args()
271
+ write_model(
272
+ model_path=args.output_dir,
273
+ input_base_path=args.input_dir,
274
+ safe_serialization=args.safe_serialization,
275
+ tokenizer_path=args.tokenizer_json_path,
276
+ fix_eos_token_id=args.fix_eos_token_id,
277
+ )
278
+
279
+
280
+ if __name__ == "__main__":
281
+ main()
docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py ADDED
@@ -0,0 +1,1273 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ """PyTorch OLMoE model."""
13
+
14
+ import math
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+
22
+ from ...activations import ACT2FN
23
+ from ...cache_utils import Cache, DynamicCache, StaticCache
24
+ from ...generation import GenerationMixin
25
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
26
+ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
27
+ from ...modeling_outputs import (
28
+ MoeCausalLMOutputWithPast,
29
+ MoeModelOutputWithPast,
30
+ )
31
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
34
+ from ...utils import (
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from .configuration_olmoe import OlmoeConfig
41
+
42
+
43
+ if is_flash_attn_available():
44
+ from ...modeling_flash_attention_utils import _flash_attention_forward
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CONFIG_FOR_DOC = "OlmoeConfig"
50
+
51
+
52
+ # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
53
+ def load_balancing_loss_func(
54
+ gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
55
+ num_experts: Optional[int] = None,
56
+ top_k=2,
57
+ attention_mask: Optional[torch.Tensor] = None,
58
+ ) -> Union[torch.Tensor, int]:
59
+ r"""
60
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
61
+
62
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
63
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
64
+ experts is too unbalanced.
65
+
66
+ Args:
67
+ gate_logits:
68
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
69
+ shape [batch_size X sequence_length, num_experts].
70
+ num_experts:
71
+ Number of experts
72
+ top_k:
73
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
74
+ parameter.
75
+ attention_mask (`torch.Tensor`, *optional*):
76
+ The attention_mask used in forward function
77
+ shape [batch_size X sequence_length] if not None.
78
+
79
+ Returns:
80
+ The auxiliary loss.
81
+ """
82
+ if gate_logits is None or not isinstance(gate_logits, tuple):
83
+ return 0
84
+
85
+ if isinstance(gate_logits, tuple):
86
+ compute_device = gate_logits[0].device
87
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
88
+
89
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
90
+
91
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
92
+
93
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
94
+
95
+ if attention_mask is None:
96
+ # Compute the percentage of tokens routed to each expert
97
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
98
+
99
+ # Compute the average probability of routing to these experts
100
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
101
+ else:
102
+ batch_size, sequence_length = attention_mask.shape
103
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
104
+
105
+ # Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
106
+ expert_attention_mask = (
107
+ attention_mask[None, :, :, None, None]
108
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
109
+ .reshape(-1, top_k, num_experts)
110
+ .to(compute_device)
111
+ )
112
+
113
+ # Compute the percentage of tokens routed to each expert
114
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
115
+ expert_attention_mask, dim=0
116
+ )
117
+
118
+ # Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
119
+ router_per_expert_attention_mask = (
120
+ attention_mask[None, :, :, None]
121
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
122
+ .reshape(-1, num_experts)
123
+ .to(compute_device)
124
+ )
125
+
126
+ # Compute the average probability of routing to these experts
127
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
128
+ router_per_expert_attention_mask, dim=0
129
+ )
130
+
131
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
132
+ return overall_loss * num_experts
133
+
134
+
135
+ class OlmoeRMSNorm(nn.Module):
136
+ def __init__(self, hidden_size, eps=1e-5):
137
+ """
138
+ OlmoeRMSNorm is equivalent to T5LayerNorm
139
+ """
140
+ super().__init__()
141
+ self.weight = nn.Parameter(torch.ones(hidden_size))
142
+ self.variance_epsilon = eps
143
+
144
+ def forward(self, hidden_states):
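+ # RMSNorm: y = w * x / sqrt(mean(x^2) + eps), computed in fp32 and cast back.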
145
+ input_dtype = hidden_states.dtype
146
+ hidden_states = hidden_states.to(torch.float32)
147
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
148
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
149
+ return self.weight * hidden_states.to(input_dtype)
150
+
151
+ def extra_repr(self):
152
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
153
+
154
+
155
+ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
156
+
157
+
158
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe
159
+ class OlmoeRotaryEmbedding(nn.Module):
160
+ def __init__(self, config: OlmoeConfig, device=None):
161
+ super().__init__()
162
+ # BC: "rope_type" was originally "type"
163
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
164
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
165
+ else:
166
+ self.rope_type = "default"
167
+ self.max_seq_len_cached = config.max_position_embeddings
168
+ self.original_max_seq_len = config.max_position_embeddings
169
+
170
+ self.config = config
171
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
172
+
173
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
174
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
175
+ self.original_inv_freq = self.inv_freq
176
+
177
+ @torch.no_grad()
178
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
179
+ def forward(self, x, position_ids):
180
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
181
+ position_ids_expanded = position_ids[:, None, :].float()
182
+
183
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
184
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
185
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
186
+ emb = torch.cat((freqs, freqs), dim=-1)
187
+ cos = emb.cos() * self.attention_scaling
188
+ sin = emb.sin() * self.attention_scaling
189
+
190
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
191
+
192
+
193
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
194
+ def rotate_half(x):
195
+ """Rotates half the hidden dims of the input."""
196
+ x1 = x[..., : x.shape[-1] // 2]
197
+ x2 = x[..., x.shape[-1] // 2 :]
198
+ return torch.cat((-x2, x1), dim=-1)
199
+
200
+
201
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
202
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
203
+ """Applies Rotary Position Embedding to the query and key tensors.
204
+
205
+ Args:
206
+ q (`torch.Tensor`): The query tensor.
207
+ k (`torch.Tensor`): The key tensor.
208
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
209
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
210
+ position_ids (`torch.Tensor`, *optional*):
211
+ Deprecated and unused.
212
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
213
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
214
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
215
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
216
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
217
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
218
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
219
+ Returns:
220
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
221
+ """
222
+ cos = cos.unsqueeze(unsqueeze_dim)
223
+ sin = sin.unsqueeze(unsqueeze_dim)
224
+ q_embed = (q * cos) + (rotate_half(q) * sin)
225
+ k_embed = (k * cos) + (rotate_half(k) * sin)
226
+ return q_embed, k_embed
227
+
228
+
229
+ # Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe
230
+ class OlmoeMLP(nn.Module):
231
+ def __init__(self, config):
232
+ super().__init__()
233
+ self.config = config
234
+ self.hidden_size = config.hidden_size
235
+ self.intermediate_size = config.intermediate_size
236
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
237
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
238
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
239
+ self.act_fn = ACT2FN[config.hidden_act]
240
+
241
+ def forward(self, x):
242
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
243
+ return down_proj
244
+
245
+
246
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
247
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
248
+ """
249
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
250
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
251
+ """
252
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
253
+ if n_rep == 1:
254
+ return hidden_states
255
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
256
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
257
+
258
+
259
+ class OlmoeAttention(nn.Module):
260
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
261
+
262
+ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
263
+ super().__init__()
264
+ self.config = config
265
+ self.layer_idx = layer_idx
266
+ if layer_idx is None:
267
+ logger.warning_once(
268
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
269
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
270
+ "when creating this class."
271
+ )
272
+
273
+ self.attention_dropout = config.attention_dropout
274
+ self.hidden_size = config.hidden_size
275
+ self.num_heads = config.num_attention_heads
276
+ self.head_dim = self.hidden_size // self.num_heads
277
+ self.num_key_value_heads = config.num_key_value_heads
278
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
279
+ self.max_position_embeddings = config.max_position_embeddings
280
+ self.rope_theta = config.rope_theta
281
+ self.is_causal = True
282
+
283
+ if (self.head_dim * self.num_heads) != self.hidden_size:
284
+ raise ValueError(
285
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
286
+ f" and `num_heads`: {self.num_heads})."
287
+ )
288
+
289
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
290
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
291
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
292
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
293
+ self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
294
+ self.k_norm = OlmoeRMSNorm(
295
+ (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps
296
+ )
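+ # QK-norm: RMSNorm is applied to the full Q and K projections before they are
+ # reshaped into heads, hence the normalized sizes above.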
297
+
298
+ def forward(
299
+ self,
300
+ hidden_states: torch.Tensor,
301
+ attention_mask: Optional[torch.Tensor] = None,
302
+ position_ids: Optional[torch.LongTensor] = None,
303
+ past_key_value: Optional[Cache] = None,
304
+ output_attentions: bool = False,
305
+ use_cache: bool = False,
306
+ cache_position: Optional[torch.LongTensor] = None,
307
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
308
+ **kwargs,
309
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
310
+ bsz, q_len, _ = hidden_states.size()
311
+
312
+ query_states = self.q_norm(self.q_proj(hidden_states))
313
+ key_states = self.k_norm(self.k_proj(hidden_states))
314
+ value_states = self.v_proj(hidden_states)
315
+
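+ # OLMoE optionally clamps the QKV activations to [-clip_qkv, clip_qkv]: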
316
+ if self.config.clip_qkv is not None:
317
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
318
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
319
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
320
+
321
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
322
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
323
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
324
+
325
+ cos, sin = position_embeddings
326
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
327
+
328
+ if past_key_value is not None:
329
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
330
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
331
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
332
+
333
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
334
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
335
+
336
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
337
+
338
+ if attention_mask is not None: # no matter the length, we just slice it
339
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
340
+ attn_weights = attn_weights + causal_mask
341
+
342
+ # upcast attention to fp32
343
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
344
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
345
+ attn_output = torch.matmul(attn_weights, value_states)
346
+
347
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
348
+ raise ValueError(
349
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
350
+ f" {attn_output.size()}"
351
+ )
352
+
353
+ attn_output = attn_output.transpose(1, 2).contiguous()
354
+
355
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
356
+
357
+ attn_output = self.o_proj(attn_output)
358
+
359
+ if not output_attentions:
360
+ attn_weights = None
361
+
362
+ return attn_output, attn_weights, past_key_value
363
+
364
+
365
+ class OlmoeFlashAttention2(OlmoeAttention):
366
+ """
367
+ OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stay
368
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
369
+ flash attention and deal with padding tokens in case the input contains any of them.
370
+ """
371
+
372
+ def __init__(self, *args, **kwargs):
373
+ super().__init__(*args, **kwargs)
374
+
375
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
376
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
377
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
378
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
379
+
380
+ def forward(
381
+ self,
382
+ hidden_states: torch.Tensor,
383
+ attention_mask: Optional[torch.LongTensor] = None,
384
+ position_ids: Optional[torch.LongTensor] = None,
385
+ past_key_value: Optional[Cache] = None,
386
+ output_attentions: bool = False,
387
+ use_cache: bool = False,
388
+ cache_position: Optional[torch.LongTensor] = None,
389
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
390
+ **kwargs,
391
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
392
+ output_attentions = False
393
+
394
+ bsz, q_len, _ = hidden_states.size()
395
+
396
+ query_states = self.q_norm(self.q_proj(hidden_states))
397
+ key_states = self.k_norm(self.k_proj(hidden_states))
398
+ value_states = self.v_proj(hidden_states)
399
+ if self.config.clip_qkv is not None:
400
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
401
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
402
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
403
+
404
+ # Flash attention expects the layout [batch_size, seq_length, num_heads, head_dim];
+ # we reshape into heads here ([bsz, heads, seq, head_dim] for RoPE and the KV
+ # cache) and transpose back to the flash-attention layout further down.
407
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
408
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
409
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
410
+
411
+ cos, sin = position_embeddings
412
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
413
+
414
+ if past_key_value is not None:
415
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
416
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
417
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
418
+
419
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
420
+ # to be able to avoid many of these transpose/reshape/view.
421
+ query_states = query_states.transpose(1, 2)
422
+ key_states = key_states.transpose(1, 2)
423
+ value_states = value_states.transpose(1, 2)
424
+
425
+ dropout_rate = self.attention_dropout if self.training else 0.0
426
+
427
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
428
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
429
+ # cast them back in the correct dtype just to be sure everything works as expected.
430
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
431
+ # in fp32. (OlmoeRMSNorm handles it correctly)
432
+
433
+ input_dtype = query_states.dtype
434
+ if input_dtype == torch.float32:
435
+ if torch.is_autocast_enabled():
436
+ target_dtype = torch.get_autocast_gpu_dtype()
437
+ # Handle the case where the model is quantized
438
+ elif hasattr(self.config, "_pre_quantization_dtype"):
439
+ target_dtype = self.config._pre_quantization_dtype
440
+ else:
441
+ target_dtype = self.q_proj.weight.dtype
442
+
443
+ logger.warning_once(
444
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
445
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
446
+ f" {target_dtype}."
447
+ )
448
+
449
+ query_states = query_states.to(target_dtype)
450
+ key_states = key_states.to(target_dtype)
451
+ value_states = value_states.to(target_dtype)
452
+
453
+ attn_output = _flash_attention_forward(
454
+ query_states,
455
+ key_states,
456
+ value_states,
457
+ attention_mask,
458
+ q_len,
459
+ dropout=dropout_rate,
460
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
461
+ is_causal=self.is_causal,
462
+ )
463
+
464
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
465
+ attn_output = self.o_proj(attn_output)
466
+
467
+ if not output_attentions:
468
+ attn_weights = None
469
+
470
+ return attn_output, attn_weights, past_key_value
471
+
472
+
473
+ class OlmoeSdpaAttention(OlmoeAttention):
474
+ """
475
+ OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
476
+ `OlmoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
477
+ SDPA API.
478
+ """
479
+
480
+ # Adapted from OlmoeAttention.forward
481
+ def forward(
482
+ self,
483
+ hidden_states: torch.Tensor,
484
+ attention_mask: Optional[torch.Tensor] = None,
485
+ position_ids: Optional[torch.LongTensor] = None,
486
+ past_key_value: Optional[Cache] = None,
487
+ output_attentions: bool = False,
488
+ use_cache: bool = False,
489
+ cache_position: Optional[torch.LongTensor] = None,
490
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
491
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
492
+ if output_attentions:
493
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
494
+ logger.warning_once(
495
+ "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
496
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
497
+ )
498
+ return super().forward(
499
+ hidden_states=hidden_states,
500
+ attention_mask=attention_mask,
501
+ position_ids=position_ids,
502
+ past_key_value=past_key_value,
503
+ output_attentions=output_attentions,
504
+ use_cache=use_cache,
505
+ cache_position=cache_position,
506
+ position_embeddings=position_embeddings,
507
+ )
508
+
509
+ bsz, q_len, _ = hidden_states.size()
510
+
511
+ query_states = self.q_norm(self.q_proj(hidden_states))
512
+ key_states = self.k_norm(self.k_proj(hidden_states))
513
+ value_states = self.v_proj(hidden_states)
514
+
515
+ if self.config.clip_qkv is not None:
516
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
517
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
518
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
519
+
520
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
521
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
522
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
523
+
524
+ cos, sin = position_embeddings
525
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
526
+
527
+ if past_key_value is not None:
528
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
529
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
530
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
531
+
532
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
533
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
534
+
535
+ causal_mask = attention_mask
536
537
+ if attention_mask is not None:
538
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
539
+
540
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
541
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
542
+ if query_states.device.type == "cuda" and causal_mask is not None:
543
+ query_states = query_states.contiguous()
544
+ key_states = key_states.contiguous()
545
+ value_states = value_states.contiguous()
546
+
547
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
548
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
549
+ is_causal = True if causal_mask is None and q_len > 1 else False
550
+
551
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
552
+ query_states,
553
+ key_states,
554
+ value_states,
555
+ attn_mask=causal_mask,
556
+ dropout_p=self.attention_dropout if self.training else 0.0,
557
+ is_causal=is_causal,
558
+ )
559
+
560
+ attn_output = attn_output.transpose(1, 2).contiguous()
561
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
562
+
563
+ attn_output = self.o_proj(attn_output)
564
+
565
+ return attn_output, None, past_key_value
566
+
567
+
568
+ OLMOE_ATTENTION_CLASSES = {
569
+ "eager": OlmoeAttention,
570
+ "flash_attention_2": OlmoeFlashAttention2,
571
+ "sdpa": OlmoeSdpaAttention,
572
+ }
573
+
574
+
575
+ class OlmoeSparseMoeBlock(nn.Module):
576
+ def __init__(self, config):
577
+ super().__init__()
578
+ self.num_experts = config.num_experts
579
+ self.top_k = config.num_experts_per_tok
580
+ self.norm_topk_prob = config.norm_topk_prob
581
+ self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
582
+ self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
583
+
584
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
585
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
586
+ hidden_states = hidden_states.view(-1, hidden_dim)
587
+ # router_logits: (batch * sequence_length, n_experts)
588
+ router_logits = self.gate(hidden_states)
589
+
590
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
591
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
592
+ if self.norm_topk_prob:
593
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
594
+ # we cast back to the input dtype
595
+ routing_weights = routing_weights.to(hidden_states.dtype)
596
+
597
+ final_hidden_states = torch.zeros(
598
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
599
+ )
600
+
601
+ # One hot encode the selected experts to create an expert mask
602
+ # this will be used to easily index which expert is going to be selected
603
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
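+ # expert_mask has shape (num_experts, top_k, batch*seq_len); torch.where below
+ # recovers, for each expert, the top-k slot and the token indices it serves.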
604
+
605
+ # Loop over all available experts in the model and perform the computation on each expert
606
+ for expert_idx in range(self.num_experts):
607
+ expert_layer = self.experts[expert_idx]
608
+ idx, top_x = torch.where(expert_mask[expert_idx])
609
+
610
+ # Index the correct hidden states and compute the expert hidden state for
611
+ # the current expert. We need to make sure to multiply the output hidden
612
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
613
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
614
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
615
+
616
+ # However `index_add_` only support torch tensors for indexing so we'll use
617
+ # the `top_x` tensor here.
618
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
619
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
620
+ return final_hidden_states, router_logits
621
+
622
+
623
+ class OlmoeDecoderLayer(nn.Module):
624
+ def __init__(self, config: OlmoeConfig, layer_idx: int):
625
+ super().__init__()
626
+ self.hidden_size = config.hidden_size
627
+
628
+ self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
629
+
630
+ self.mlp = OlmoeSparseMoeBlock(config)
631
+ self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
632
+ self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
633
+
634
+ def forward(
635
+ self,
636
+ hidden_states: torch.Tensor,
637
+ attention_mask: Optional[torch.Tensor] = None,
638
+ position_ids: Optional[torch.LongTensor] = None,
639
+ past_key_value: Optional[Cache] = None,
640
+ output_attentions: Optional[bool] = False,
641
+ output_router_logits: Optional[bool] = False,
642
+ use_cache: Optional[bool] = False,
643
+ cache_position: Optional[torch.LongTensor] = None,
644
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
645
+ **kwargs,
646
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
647
+ """
648
+ Args:
649
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
650
+ attention_mask (`torch.FloatTensor`, *optional*):
651
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
652
+ query_sequence_length, key_sequence_length)` if default attention is used.
653
+ output_attentions (`bool`, *optional*):
654
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
655
+ returned tensors for more detail.
656
+ output_router_logits (`bool`, *optional*):
657
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
658
+ and should not be returned during inference.
659
+ use_cache (`bool`, *optional*):
660
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
661
+ (see `past_key_values`).
662
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
663
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
664
+ Indices depicting the position of the input sequence tokens in the sequence
665
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
666
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
667
+ with `head_dim` being the embedding dimension of each attention head.
668
+ kwargs (`dict`, *optional*):
669
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
670
+ into the model
671
+ """
672
+ residual = hidden_states
673
+
674
+ hidden_states = self.input_layernorm(hidden_states)
675
+
676
+ # Self Attention
677
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
678
+ hidden_states=hidden_states,
679
+ attention_mask=attention_mask,
680
+ position_ids=position_ids,
681
+ past_key_value=past_key_value,
682
+ output_attentions=output_attentions,
683
+ use_cache=use_cache,
684
+ cache_position=cache_position,
685
+ position_embeddings=position_embeddings,
686
+ **kwargs,
687
+ )
688
+ hidden_states = residual + hidden_states
689
+
690
+ # Fully Connected
691
+ residual = hidden_states
692
+ hidden_states = self.post_attention_layernorm(hidden_states)
693
+ hidden_states, router_logits = self.mlp(hidden_states)
694
+ hidden_states = residual + hidden_states
695
+
696
+ outputs = (hidden_states,)
697
+
698
+ if output_attentions:
699
+ outputs += (self_attn_weights,)
700
+
701
+ if use_cache:
702
+ outputs += (present_key_value,)
703
+
704
+ if output_router_logits:
705
+ outputs += (router_logits,)
706
+
707
+ return outputs
708
+
709
+
710
+ OLMOE_START_DOCSTRING = r"""
711
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
712
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
713
+ etc.)
714
+
715
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
716
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
717
+ and behavior.
718
+
719
+ Parameters:
720
+ config ([`OlmoeConfig`]):
721
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
722
+ load the weights associated with the model, only the configuration. Check out the
723
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
724
+ """
725
+
726
+
727
+ @add_start_docstrings(
728
+ "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
729
+ OLMOE_START_DOCSTRING,
730
+ )
731
+ class OlmoePreTrainedModel(PreTrainedModel):
732
+ config_class = OlmoeConfig
733
+ base_model_prefix = "model"
734
+ supports_gradient_checkpointing = True
735
+ _no_split_modules = ["OlmoeDecoderLayer"]
736
+ _skip_keys_device_placement = ["past_key_values"]
737
+ _supports_flash_attn_2 = True
738
+ _supports_sdpa = True
739
+ _supports_cache_class = True
740
+ _supports_quantized_cache = True
741
+ _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
742
+
743
+ def _init_weights(self, module):
744
+ std = self.config.initializer_range
745
+ if isinstance(module, nn.Linear):
746
+ module.weight.data.normal_(mean=0.0, std=std)
747
+ if module.bias is not None:
748
+ module.bias.data.zero_()
749
+ elif isinstance(module, OlmoeRMSNorm):
750
+ module.weight.data.fill_(1.0)
751
+ elif isinstance(module, nn.Embedding):
752
+ module.weight.data.normal_(mean=0.0, std=std)
753
+ if module.padding_idx is not None:
754
+ module.weight.data[module.padding_idx].zero_()
755
+
756
+
757
+ OLMOE_INPUTS_DOCSTRING = r"""
758
+ Args:
759
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
760
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
761
+ it.
762
+
763
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
764
+ [`PreTrainedTokenizer.__call__`] for details.
765
+
766
+ [What are input IDs?](../glossary#input-ids)
767
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
768
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
769
+
770
+ - 1 for tokens that are **not masked**,
771
+ - 0 for tokens that are **masked**.
772
+
773
+ [What are attention masks?](../glossary#attention-mask)
774
+
775
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
776
+ [`PreTrainedTokenizer.__call__`] for details.
777
+
778
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
779
+ `past_key_values`).
780
+
781
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
782
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
783
+ information on the default strategy.
784
+
785
+ - 1 indicates the head is **not masked**,
786
+ - 0 indicates the head is **masked**.
787
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
788
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
789
+ config.n_positions - 1]`.
790
+
791
+ [What are position IDs?](../glossary#position-ids)
792
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
793
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
794
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
795
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
796
+
797
+ Two formats are allowed:
798
+ - a [`~cache_utils.Cache`] instance;
799
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
800
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
801
+ cache format.
802
+
803
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
804
+ legacy cache format will be returned.
805
+
806
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
807
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
808
+ of shape `(batch_size, sequence_length)`.
809
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
810
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
811
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
812
+ model's internal embedding lookup matrix.
813
+ use_cache (`bool`, *optional*):
814
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
815
+ `past_key_values`).
816
+ output_attentions (`bool`, *optional*):
817
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
818
+ tensors for more detail.
819
+ output_hidden_states (`bool`, *optional*):
820
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
821
+ more detail.
822
+ output_router_logits (`bool`, *optional*):
823
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
824
+ should not be returned during inference.
825
+ return_dict (`bool`, *optional*):
826
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
827
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
828
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
829
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
830
+ the complete sequence length.
831
+ """
832
+
833
+
834
+ @add_start_docstrings(
835
+ "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
836
+ OLMOE_START_DOCSTRING,
837
+ )
838
+ # TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
839
+ class OlmoeModel(OlmoePreTrainedModel):
840
+ """
841
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
842
+
843
+ Args:
844
+ config: OlmoeConfig
845
+ """
846
+
847
+ def __init__(self, config: OlmoeConfig):
848
+ super().__init__(config)
849
+ self.padding_idx = config.pad_token_id
850
+ self.vocab_size = config.vocab_size
851
+
852
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
853
+ self.layers = nn.ModuleList(
854
+ [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
855
+ )
856
+ self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
857
+ self.rotary_emb = OlmoeRotaryEmbedding(config=config)
858
+ self.gradient_checkpointing = False
859
+
860
+ # Initialize weights and apply final processing
861
+ self.post_init()
862
+
863
+ def get_input_embeddings(self):
864
+ return self.embed_tokens
865
+
866
+ def set_input_embeddings(self, value):
867
+ self.embed_tokens = value
868
+
869
+ @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
870
+ # Ignore copy
871
+ def forward(
872
+ self,
873
+ input_ids: Optional[torch.LongTensor] = None,
874
+ attention_mask: Optional[torch.Tensor] = None,
875
+ position_ids: Optional[torch.LongTensor] = None,
876
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
877
+ inputs_embeds: Optional[torch.FloatTensor] = None,
878
+ use_cache: Optional[bool] = None,
879
+ output_attentions: Optional[bool] = None,
880
+ output_hidden_states: Optional[bool] = None,
881
+ output_router_logits: Optional[bool] = None,
882
+ return_dict: Optional[bool] = None,
883
+ cache_position: Optional[torch.LongTensor] = None,
884
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
885
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
886
+ output_router_logits = (
887
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
888
+ )
889
+ output_hidden_states = (
890
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
891
+ )
892
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
893
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
894
+
895
+ if (input_ids is None) ^ (inputs_embeds is not None):
896
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
897
+
898
+ if self.gradient_checkpointing and self.training and use_cache:
899
+ logger.warning_once(
900
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
901
+ )
902
+ use_cache = False
903
+
904
+ if inputs_embeds is None:
905
+ inputs_embeds = self.embed_tokens(input_ids)
906
+
907
+ # kept for BC (non `Cache` `past_key_values` inputs)
908
+ return_legacy_cache = False
909
+ if use_cache and not isinstance(past_key_values, Cache):
910
+ return_legacy_cache = True
911
+ if past_key_values is None:
912
+ past_key_values = DynamicCache()
913
+ else:
914
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
915
+ logger.warning_once(
916
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
917
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
918
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
919
+ )
920
+
921
+ if cache_position is None:
922
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
923
+ cache_position = torch.arange(
924
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
925
+ )
926
+ if position_ids is None:
927
+ position_ids = cache_position.unsqueeze(0)
928
+
929
+ causal_mask = self._update_causal_mask(
930
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
931
+ )
932
+
933
+ # embed positions
934
+ hidden_states = inputs_embeds
935
+
936
+ # create position embeddings to be shared across the decoder layers
937
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
938
+
939
+ # decoder layers
940
+ all_hidden_states = () if output_hidden_states else None
941
+ all_self_attns = () if output_attentions else None
942
+ all_router_logits = () if output_router_logits else None
943
+ next_decoder_cache = None
944
+
945
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
946
+ if output_hidden_states:
947
+ all_hidden_states += (hidden_states,)
948
+
949
+ if self.gradient_checkpointing and self.training:
950
+ layer_outputs = self._gradient_checkpointing_func(
951
+ decoder_layer.__call__,
952
+ hidden_states,
953
+ causal_mask,
954
+ position_ids,
955
+ past_key_values,
956
+ output_attentions,
957
+ output_router_logits,
958
+ use_cache,
959
+ cache_position,
960
+ position_embeddings,
961
+ )
962
+ else:
963
+ layer_outputs = decoder_layer(
964
+ hidden_states,
965
+ attention_mask=causal_mask,
966
+ position_ids=position_ids,
967
+ past_key_value=past_key_values,
968
+ output_attentions=output_attentions,
969
+ output_router_logits=output_router_logits,
970
+ use_cache=use_cache,
971
+ cache_position=cache_position,
972
+ position_embeddings=position_embeddings,
973
+ )
974
+
975
+ hidden_states = layer_outputs[0]
976
+
977
+ if use_cache:
978
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
979
+
980
+ if output_attentions:
981
+ all_self_attns += (layer_outputs[1],)
982
+
983
+ if output_router_logits and layer_outputs[-1] is not None:
984
+ all_router_logits += (layer_outputs[-1],)
985
+
986
+ hidden_states = self.norm(hidden_states)
987
+
988
+ # add hidden states from the last decoder layer
989
+ if output_hidden_states:
990
+ all_hidden_states += (hidden_states,)
991
+
992
+ next_cache = next_decoder_cache if use_cache else None
993
+ if return_legacy_cache:
994
+ next_cache = next_cache.to_legacy_cache()
995
+
996
+ if not return_dict:
997
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None)
998
+ return MoeModelOutputWithPast(
999
+ last_hidden_state=hidden_states,
1000
+ past_key_values=next_cache,
1001
+ hidden_states=all_hidden_states,
1002
+ attentions=all_self_attns,
1003
+ router_logits=all_router_logits,
1004
+ )
1005
+
1006
+ def _update_causal_mask(
1007
+ self,
1008
+ attention_mask: torch.Tensor,
1009
+ input_tensor: torch.Tensor,
1010
+ cache_position: torch.Tensor,
1011
+ past_key_values: Cache,
1012
+ output_attentions: bool,
1013
+ ):
1014
+ if self.config._attn_implementation == "flash_attention_2":
1015
+ if attention_mask is not None and 0.0 in attention_mask:
1016
+ return attention_mask
1017
+ return None
1018
+
1019
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1020
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1021
+ # to infer the attention mask.
1022
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1023
+ using_static_cache = isinstance(past_key_values, StaticCache)
1024
+
1025
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1026
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1027
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1028
+ attention_mask,
1029
+ inputs_embeds=input_tensor,
1030
+ past_key_values_length=past_seen_tokens,
1031
+ is_training=self.training,
1032
+ ):
1033
+ return None
1034
+
1035
+ dtype, device = input_tensor.dtype, input_tensor.device
1036
+ sequence_length = input_tensor.shape[1]
1037
+ if using_static_cache:
1038
+ target_length = past_key_values.get_max_cache_shape()
1039
+ else:
1040
+ target_length = (
1041
+ attention_mask.shape[-1]
1042
+ if isinstance(attention_mask, torch.Tensor)
1043
+ else past_seen_tokens + sequence_length + 1
1044
+ )
1045
+
1046
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1047
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1048
+ attention_mask,
1049
+ sequence_length=sequence_length,
1050
+ target_length=target_length,
1051
+ dtype=dtype,
1052
+ device=device,
1053
+ cache_position=cache_position,
1054
+ batch_size=input_tensor.shape[0],
1055
+ )
1056
+
1057
+ if (
1058
+ self.config._attn_implementation == "sdpa"
1059
+ and attention_mask is not None
1060
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1061
+ and not output_attentions
1062
+ ):
1063
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1064
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1065
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1066
+ min_dtype = torch.finfo(dtype).min
1067
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1068
+
1069
+ return causal_mask
1070
+
1071
+ @staticmethod
1072
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1073
+ attention_mask: torch.Tensor,
1074
+ sequence_length: int,
1075
+ target_length: int,
1076
+ dtype: torch.dtype,
1077
+ device: torch.device,
1078
+ cache_position: torch.Tensor,
1079
+ batch_size: int,
1080
+ **kwargs,
1081
+ ):
1082
+ """
1083
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1084
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1085
+
1086
+ Args:
1087
+ attention_mask (`torch.Tensor`):
1088
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1089
+ `(batch_size, 1, query_length, key_value_length)`.
1090
+ sequence_length (`int`):
1091
+ The sequence length being processed.
1092
+ target_length (`int`):
1093
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1094
+ to account for the 0 padding, the part of the cache that is not filled yet.
1095
+ dtype (`torch.dtype`):
1096
+ The dtype to use for the 4D attention mask.
1097
+ device (`torch.device`):
1098
+ The device to place the 4D attention mask on.
1099
+ cache_position (`torch.Tensor`):
1100
+ Indices depicting the position of the input sequence tokens in the sequence.
1101
+ batch_size (`torch.Tensor`):
1102
+ Batch size.
1103
+ """
1104
+ if attention_mask is not None and attention_mask.dim() == 4:
1105
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1106
+ causal_mask = attention_mask
1107
+ else:
1108
+ min_dtype = torch.finfo(dtype).min
1109
+ causal_mask = torch.full(
1110
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1111
+ )
1112
+ if sequence_length != 1:
1113
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1114
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1115
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1116
+ if attention_mask is not None:
1117
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1118
+ mask_length = attention_mask.shape[-1]
1119
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1120
+ padding_mask = padding_mask == 0
1121
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1122
+ padding_mask, min_dtype
1123
+ )
1124
+
1125
+ return causal_mask
1126
+
1127
+
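+ # A minimal, self-contained sketch (not part of the upstream file) of the mask that
+ # `_prepare_4d_causal_attention_mask_with_cache_position` above builds, on toy sizes,
+ # assuming a 2D padding mask and an empty cache. All names and values are illustrative.
+ def _sketch_causal_mask_example():
+     sequence_length = target_length = 3
+     dtype = torch.float32
+     min_dtype = torch.finfo(dtype).min
+     cache_position = torch.arange(sequence_length)  # no tokens cached yet
+     attention_mask = torch.tensor([[0, 1, 1]])  # first position is left padding
+     causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
+     causal = torch.triu(causal, diagonal=1)  # forbid attending to future positions
+     causal *= torch.arange(target_length) > cache_position.reshape(-1, 1)
+     causal = causal[None, None, :, :].expand(1, 1, -1, -1).clone()  # (batch, 1, q, kv)
+     padding = (causal[..., :3] + attention_mask[:, None, None, :]) == 0
+     causal[..., :3] = causal[..., :3].masked_fill(padding, min_dtype)
+     return causal  # row i allows attention to non-padded positions j <= i only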
1128
+ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
1129
+ _tied_weights_keys = ["lm_head.weight"]
1130
+
1131
+ def __init__(self, config):
1132
+ super().__init__(config)
1133
+ self.model = OlmoeModel(config)
1134
+ self.vocab_size = config.vocab_size
1135
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1136
+
1137
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1138
+ self.num_experts = config.num_experts
1139
+ self.num_experts_per_tok = config.num_experts_per_tok
1140
+ # Initialize weights and apply final processing
1141
+ self.post_init()
1142
+
1143
+ def get_input_embeddings(self):
1144
+ return self.model.embed_tokens
1145
+
1146
+ def set_input_embeddings(self, value):
1147
+ self.model.embed_tokens = value
1148
+
1149
+ def get_output_embeddings(self):
1150
+ return self.lm_head
1151
+
1152
+ def set_output_embeddings(self, new_embeddings):
1153
+ self.lm_head = new_embeddings
1154
+
1155
+ def set_decoder(self, decoder):
1156
+ self.model = decoder
1157
+
1158
+ def get_decoder(self):
1159
+ return self.model
1160
+
1161
+ @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
1162
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1163
+ def forward(
1164
+ self,
1165
+ input_ids: Optional[torch.LongTensor] = None,
1166
+ attention_mask: Optional[torch.Tensor] = None,
1167
+ position_ids: Optional[torch.LongTensor] = None,
1168
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1169
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1170
+ labels: Optional[torch.LongTensor] = None,
1171
+ use_cache: Optional[bool] = None,
1172
+ output_attentions: Optional[bool] = None,
1173
+ output_hidden_states: Optional[bool] = None,
1174
+ output_router_logits: Optional[bool] = None,
1175
+ return_dict: Optional[bool] = None,
1176
+ cache_position: Optional[torch.LongTensor] = None,
1177
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1178
+ **loss_kwargs,
1179
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1180
+ r"""
1181
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1182
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1183
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1184
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1185
+
1186
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
1187
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
1188
+ `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only for that
1189
+ token saves memory, which becomes quite significant for long sequences or a large vocabulary size.
1190
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
1191
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
1192
+
1193
+ Returns:
1194
+
1195
+ Example:
1196
+
1197
+ ```python
1198
+ >>> from transformers import AutoTokenizer, OlmoeForCausalLM
1199
+
1200
+ >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
1201
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
1202
+
1203
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1204
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1205
+
1206
+ >>> # Generate
1207
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1208
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1209
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
1210
+ ```
1211
+ """
1212
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1213
+ output_router_logits = (
1214
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1215
+ )
1216
+ output_hidden_states = (
1217
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1218
+ )
1219
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1220
+
1221
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1222
+ outputs = self.model(
1223
+ input_ids=input_ids,
1224
+ attention_mask=attention_mask,
1225
+ position_ids=position_ids,
1226
+ past_key_values=past_key_values,
1227
+ inputs_embeds=inputs_embeds,
1228
+ use_cache=use_cache,
1229
+ output_attentions=output_attentions,
1230
+ output_hidden_states=output_hidden_states,
1231
+ output_router_logits=output_router_logits,
1232
+ return_dict=return_dict,
1233
+ cache_position=cache_position,
1234
+ )
1235
+
1236
+ hidden_states = outputs[0]
1237
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1238
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1239
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1240
+
1241
+ loss = None
1242
+ if labels is not None:
1243
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1244
+
1245
+ aux_loss = None
1246
+ if output_router_logits:
1247
+ aux_loss = load_balancing_loss_func(
1248
+ outputs.router_logits if return_dict else outputs[-1],
1249
+ self.num_experts,
1250
+ self.num_experts_per_tok,
1251
+ attention_mask,
1252
+ )
1253
+ if labels is not None:
1254
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1255
+
1256
+ if not return_dict:
1257
+ output = (logits,) + outputs[1:]
1258
+ if output_router_logits:
1259
+ output = (aux_loss,) + output
1260
+ return (loss,) + output if loss is not None else output
1261
+
1262
+ return MoeCausalLMOutputWithPast(
1263
+ loss=loss,
1264
+ aux_loss=aux_loss,
1265
+ logits=logits,
1266
+ past_key_values=outputs.past_key_values,
1267
+ hidden_states=outputs.hidden_states,
1268
+ attentions=outputs.attentions,
1269
+ router_logits=outputs.router_logits,
1270
+ )
1271
+
1272
+
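+ # A minimal sketch (not part of the upstream file) of the `logits_to_keep` slicing and
+ # the auxiliary-loss combination used in `OlmoeForCausalLM.forward` above. The tensors
+ # and the coefficient below are toy values, not real model state.
+ def _sketch_logits_to_keep_and_aux_loss():
+     hidden_states = torch.randn(2, 10, 16)  # (batch_size, seq_len, hidden_size)
+     lm_head = nn.Linear(16, 32, bias=False)  # toy head with vocab_size=32
+     logits_to_keep = 1  # generation only needs the last position
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     logits = lm_head(hidden_states[:, slice_indices, :])  # shape (2, 1, 32)
+     # When router logits are requested, the load-balancing aux loss is scaled by
+     # `router_aux_loss_coef` and added onto the language-modeling loss.
+     loss, aux_loss, router_aux_loss_coef = torch.tensor(2.0), torch.tensor(0.5), 0.01
+     loss = loss + router_aux_loss_coef * aux_loss.to(loss.device)
+     return logits, loss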
1273
+ __all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
docs/transformers/src/transformers/models/omdet_turbo/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_omdet_turbo import *
22
+ from .modeling_omdet_turbo import *
23
+ from .processing_omdet_turbo import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
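+ # Hedged usage sketch (not part of the file above): with the `_LazyModule` indirection,
+ # a submodule is only imported the first time one of its attributes is accessed, e.g.
+ #
+ #     from transformers.models.omdet_turbo import OmDetTurboConfig  # resolves lazily
+ #     config = OmDetTurboConfig()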
docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py ADDED
@@ -0,0 +1,293 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OmDet-Turbo model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import verify_backbone_config_arguments
20
+ from ..auto import CONFIG_MAPPING
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class OmDetTurboConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`].
29
+ It is used to instantiate an OmDet-Turbo model according to the specified arguments, defining the model architecture.
30
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo
31
+ [omlab/omdet-turbo-swin-tiny-hf](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf) architecture.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+ Args:
37
+ text_config (`PretrainedConfig`, *optional*):
38
+ The configuration of the text backbone.
39
+ backbone_config (`PretrainedConfig`, *optional*):
40
+ The configuration of the vision backbone.
41
+ use_timm_backbone (`bool`, *optional*, defaults to `True`):
42
+ Whether to use the timm library for the vision backbone.
43
+ backbone (`str`, *optional*, defaults to `"swin_tiny_patch4_window7_224"`):
44
+ The name of the pretrained vision backbone to use. If `use_pretrained_backbone=False`, a randomly initialized
45
+ backbone with the same architecture as `backbone` is used.
46
+ backbone_kwargs (`dict`, *optional*):
47
+ Additional kwargs for the vision backbone.
48
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
49
+ Whether to use a pretrained vision backbone.
50
+ apply_layernorm_after_vision_backbone (`bool`, *optional*, defaults to `True`):
51
+ Whether to apply layer normalization on the feature maps of the vision backbone output.
52
+ image_size (`int`, *optional*, defaults to 640):
53
+ The size (resolution) of each image.
54
+ disable_custom_kernels (`bool`, *optional*, defaults to `False`):
55
+ Whether to disable custom kernels.
56
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
57
+ The epsilon value for layer normalization.
58
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
59
+ The epsilon value for batch normalization.
60
+ init_std (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ text_projection_in_dim (`int`, *optional*, defaults to 512):
63
+ The input dimension for the text projection.
64
+ text_projection_out_dim (`int`, *optional*, defaults to 512):
65
+ The output dimension for the text projection.
66
+ task_encoder_hidden_dim (`int`, *optional*, defaults to 1024):
67
+ The feedforward dimension for the task encoder.
68
+ class_embed_dim (`int`, *optional*, defaults to 512):
69
+ The dimension of the class embeddings.
70
+ class_distance_type (`str`, *optional*, defaults to `"cosine"`):
71
+ The type of distance used to compare predicted classes to the projected class embeddings.
72
+ Can be `"cosine"` or `"dot"`.
73
+ num_queries (`int`, *optional*, defaults to 900):
74
+ The number of queries.
75
+ csp_activation (`str`, *optional*, defaults to `"silu"`):
76
+ The activation function of the Cross Stage Partial (CSP) networks of the encoder.
77
+ conv_norm_activation (`str`, *optional*, defaults to `"gelu"`):
78
+ The activation function of the ConvNormLayer layers of the encoder.
79
+ encoder_feedforward_activation (`str`, *optional*, defaults to `"relu"`):
80
+ The activation function for the feedforward network of the encoder.
81
+ encoder_feedforward_dropout (`float`, *optional*, defaults to 0.0):
82
+ The dropout rate following the activation of the encoder feedforward network.
83
+ encoder_dropout (`float`, *optional*, defaults to 0.0):
84
+ The dropout rate of the encoder multi-head attention module.
85
+ hidden_expansion (`int`, *optional*, defaults to 1):
86
+ The hidden expansion of the CSP networks in the encoder.
87
+ vision_features_channels (`tuple(int)`, *optional*, defaults to `[256, 256, 256]`):
88
+ The projected vision features channels used as inputs for the decoder.
89
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
90
+ The hidden dimension of the encoder.
91
+ encoder_in_channels (`List(int)`, *optional*, defaults to `[192, 384, 768]`):
92
+ The input channels for the encoder.
93
+ encoder_projection_indices (`List(int)`, *optional*, defaults to `[2]`):
94
+ The indices of the input features projected by each layer.
95
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
96
+ The number of attention heads for the encoder.
97
+ encoder_dim_feedforward (`int`, *optional*, defaults to 2048):
98
+ The feedforward dimension for the encoder.
99
+ encoder_layers (`int`, *optional*, defaults to 1):
100
+ The number of layers in the encoder.
101
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
102
+ The positional encoding temperature in the encoder.
103
+ num_feature_levels (`int`, *optional*, defaults to 3):
104
+ The number of feature levels for the multi-scale deformable attention module of the decoder.
105
+ decoder_hidden_dim (`int`, *optional*, defaults to 256):
106
+ The hidden dimension of the decoder.
107
+ decoder_num_heads (`int`, *optional*, defaults to 8):
108
+ The number of heads for the decoder.
109
+ decoder_num_layers (`int`, *optional*, defaults to 6):
110
+ The number of layers for the decoder.
111
+ decoder_activation (`str`, *optional*, defaults to `"relu"`):
112
+ The activation function for the decoder.
113
+ decoder_dim_feedforward (`int`, *optional*, defaults to 2048):
114
+ The feedforward dimension for the decoder.
115
+ decoder_num_points (`int`, *optional*, defaults to 4):
116
+ The number of points sampled in the decoder multi-scale deformable attention module.
117
+ decoder_dropout (`float`, *optional*, defaults to 0.0):
118
+ The dropout rate for the decoder.
119
+ eval_size (`Tuple[int, int]`, *optional*):
120
+ Height and width used to compute the effective height and width of the position embeddings after taking
121
+ into account the stride (see RTDetr).
122
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
123
+ Whether to learn the initial query.
124
+ cache_size (`int`, *optional*, defaults to 100):
125
+ The cache size for the class and prompt caches.
126
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
127
+ Whether the model is used as an encoder-decoder model or not.
128
+ kwargs (`Dict[str, Any]`, *optional*):
129
+ Additional parameters from the architecture. The values in kwargs will be saved as part of the configuration
130
+ and can be used to control the model outputs.
131
+
132
+ Examples:
133
+
134
+ ```python
135
+ >>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection
136
+
137
+ >>> # Initializing a OmDet-Turbo omlab/omdet-turbo-swin-tiny-hf style configuration
138
+ >>> configuration = OmDetTurboConfig()
139
+
140
+ >>> # Initializing a model (with random weights) from the omlab/omdet-turbo-swin-tiny-hf style configuration
141
+ >>> model = OmDetTurboForObjectDetection(configuration)
142
+
143
+ >>> # Accessing the model configuration
144
+ >>> configuration = model.config
145
+ ```"""
146
+
147
+ model_type = "omdet-turbo"
148
+ attribute_map = {
149
+ "encoder_hidden_dim": "d_model",
150
+ "num_attention_heads": "encoder_attention_heads",
151
+ }
152
+
153
+ def __init__(
154
+ self,
155
+ text_config=None,
156
+ backbone_config=None,
157
+ use_timm_backbone=True,
158
+ backbone="swin_tiny_patch4_window7_224",
159
+ backbone_kwargs=None,
160
+ use_pretrained_backbone=False,
161
+ apply_layernorm_after_vision_backbone=True,
162
+ image_size=640,
163
+ disable_custom_kernels=False,
164
+ layer_norm_eps=1e-5,
165
+ batch_norm_eps=1e-5,
166
+ init_std=0.02,
167
+ text_projection_in_dim=512,
168
+ text_projection_out_dim=512,
169
+ task_encoder_hidden_dim=1024,
170
+ class_embed_dim=512,
171
+ class_distance_type="cosine",
172
+ num_queries=900,
173
+ csp_activation="silu",
174
+ conv_norm_activation="gelu",
175
+ encoder_feedforward_activation="relu",
176
+ encoder_feedforward_dropout=0.0,
177
+ encoder_dropout=0.0,
178
+ hidden_expansion=1,
179
+ vision_features_channels=[256, 256, 256],
180
+ encoder_hidden_dim=256,
181
+ encoder_in_channels=[192, 384, 768],
182
+ encoder_projection_indices=[2],
183
+ encoder_attention_heads=8,
184
+ encoder_dim_feedforward=2048,
185
+ encoder_layers=1,
186
+ positional_encoding_temperature=10000,
187
+ num_feature_levels=3,
188
+ decoder_hidden_dim=256,
189
+ decoder_num_heads=8,
190
+ decoder_num_layers=6,
191
+ decoder_activation="relu",
192
+ decoder_dim_feedforward=2048,
193
+ decoder_num_points=4,
194
+ decoder_dropout=0.0,
195
+ eval_size=None,
196
+ learn_initial_query=False,
197
+ cache_size=100,
198
+ is_encoder_decoder=True,
199
+ **kwargs,
200
+ ):
201
+ if use_timm_backbone:
202
+ if backbone_config is None:
203
+ backbone_kwargs = {
204
+ "out_indices": [1, 2, 3],
205
+ "img_size": image_size,
206
+ "always_partition": True,
207
+ }
208
+ elif backbone_config is None:
209
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `swin` vision config.")
210
+ backbone_config = CONFIG_MAPPING["swin"](
211
+ window_size=7,
212
+ image_size=image_size,
213
+ embed_dim=96,
214
+ depths=[2, 2, 6, 2],
215
+ num_heads=[3, 6, 12, 24],
216
+ out_indices=[2, 3, 4],
217
+ )
218
+ elif isinstance(backbone_config, dict):
219
+ backbone_model_type = backbone_config.get("model_type")
220
+ config_class = CONFIG_MAPPING[backbone_model_type]
221
+ backbone_config = config_class.from_dict(backbone_config)
222
+
223
+ verify_backbone_config_arguments(
224
+ use_timm_backbone=use_timm_backbone,
225
+ use_pretrained_backbone=use_pretrained_backbone,
226
+ backbone=backbone,
227
+ backbone_config=backbone_config,
228
+ backbone_kwargs=backbone_kwargs,
229
+ )
230
+
231
+ if text_config is None:
232
+ logger.info(
233
+ "`text_config` is `None`. Initializing the config with the default `clip_text_model` text config."
234
+ )
235
+ text_config = CONFIG_MAPPING["clip_text_model"]()
236
+ elif isinstance(text_config, dict):
237
+ text_model_type = text_config.get("model_type")
238
+ text_config = CONFIG_MAPPING[text_model_type](**text_config)
239
+
240
+ if class_distance_type not in ["cosine", "dot"]:
241
+ raise ValueError(
242
+ f"Invalid `class_distance_type`. It should be either `cosine` or `dot`, but got {class_distance_type}."
243
+ )
244
+
245
+ self.text_config = text_config
246
+ self.backbone_config = backbone_config
247
+ self.use_timm_backbone = use_timm_backbone
248
+ self.backbone = backbone
249
+ self.backbone_kwargs = backbone_kwargs
250
+ self.use_pretrained_backbone = use_pretrained_backbone
251
+ self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone
252
+ self.image_size = image_size
253
+ self.disable_custom_kernels = disable_custom_kernels
254
+ self.layer_norm_eps = layer_norm_eps
255
+ self.batch_norm_eps = batch_norm_eps
256
+ self.init_std = init_std
257
+ self.text_projection_in_dim = text_projection_in_dim
258
+ self.text_projection_out_dim = text_projection_out_dim
259
+ self.task_encoder_hidden_dim = task_encoder_hidden_dim
260
+ self.class_embed_dim = class_embed_dim
261
+ self.class_distance_type = class_distance_type
262
+ self.num_queries = num_queries
263
+ self.csp_activation = csp_activation
264
+ self.conv_norm_activation = conv_norm_activation
265
+ self.encoder_feedforward_activation = encoder_feedforward_activation
266
+ self.encoder_feedforward_dropout = encoder_feedforward_dropout
267
+ self.encoder_dropout = encoder_dropout
268
+ self.hidden_expansion = hidden_expansion
269
+ self.vision_features_channels = vision_features_channels
270
+ self.encoder_hidden_dim = encoder_hidden_dim
271
+ self.encoder_in_channels = encoder_in_channels
272
+ self.encoder_projection_indices = encoder_projection_indices
273
+ self.encoder_attention_heads = encoder_attention_heads
274
+ self.encoder_dim_feedforward = encoder_dim_feedforward
275
+ self.encoder_layers = encoder_layers
276
+ self.positional_encoding_temperature = positional_encoding_temperature
277
+ self.num_feature_levels = num_feature_levels
278
+ self.decoder_hidden_dim = decoder_hidden_dim
279
+ self.decoder_num_heads = decoder_num_heads
280
+ self.decoder_num_layers = decoder_num_layers
281
+ self.decoder_activation = decoder_activation
282
+ self.decoder_dim_feedforward = decoder_dim_feedforward
283
+ self.decoder_num_points = decoder_num_points
284
+ self.decoder_dropout = decoder_dropout
285
+ self.eval_size = eval_size
286
+ self.learn_initial_query = learn_initial_query
287
+ self.cache_size = cache_size
288
+ self.is_encoder_decoder = is_encoder_decoder
289
+
290
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
291
+
292
+
293
+ __all__ = ["OmDetTurboConfig"]
docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py ADDED
@@ -0,0 +1,349 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert OmDet-Turbo checkpoints from the original repository.
16
+
17
+ URL: https://github.com/om-ai-lab/OmDet"""
18
+
19
+ import argparse
20
+
21
+ import requests
22
+ import torch
23
+ from PIL import Image
24
+
25
+ from transformers import (
26
+ CLIPTokenizer,
27
+ DetrImageProcessor,
28
+ OmDetTurboConfig,
29
+ OmDetTurboForObjectDetection,
30
+ OmDetTurboProcessor,
31
+ )
32
+
33
+
34
+ IMAGE_MEAN = [123.675, 116.28, 103.53]
35
+ IMAGE_STD = [58.395, 57.12, 57.375]
36
+
37
+
38
+ def get_omdet_turbo_config(model_name, use_timm_backbone):
39
+ if "tiny" in model_name:
40
+ window_size = 7
41
+ embed_dim = 96
42
+ depths = (2, 2, 6, 2)
43
+ num_heads = (3, 6, 12, 24)
44
+ image_size = 640
45
+ else:
46
+ raise ValueError("Model not supported, only supports tiny variant.")
47
+
48
+ config = OmDetTurboConfig(
49
+ backbone_window_size=window_size,
50
+ backbone_image_size=image_size,
51
+ backbone_embed_dim=embed_dim,
52
+ backbone_depths=depths,
53
+ backbone_num_heads=num_heads,
54
+ backbone_out_indices=(1, 2, 3),
55
+ text_config={"model_type": "clip_text_model"},
56
+ use_timm_backbone=use_timm_backbone,
57
+ backbone="swin_tiny_patch4_window7_224" if use_timm_backbone else None,
58
+ apply_layernorm_after_vision_backbone=True if use_timm_backbone else False,
59
+ use_pretrained_backbone=False,
60
+ )
61
+
62
+ return config
63
+
64
+
65
+ def create_rename_keys_vision(state_dict, config):
66
+ rename_keys = []
67
+ # fmt: off
68
+ ########################################## VISION BACKBONE - START
69
+ for layer_name in state_dict.keys():
70
+ if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"):
71
+ if config.use_timm_backbone:
72
+ layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone")
73
+ layer_name_replace = layer_name_replace.replace(".layers.", ".layers_")
74
+ if "downsample" in layer_name:
75
+ # get layer number
76
+ layer_num = int(layer_name.split(".")[2])
77
+ layer_name_replace = layer_name_replace.replace(f"{layer_num}.downsample", f"{layer_num+1}.downsample")
78
+ else:
79
+ layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone")
80
+ layer_name_replace = layer_name_replace.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
81
+ layer_name_replace = layer_name_replace.replace("patch_embed.norm", "embeddings.norm")
82
+ if layer_name.startswith("backbone.layers"):
83
+ layer_name_replace = layer_name_replace.replace("norm1", "layernorm_before")
84
+ layer_name_replace = layer_name_replace.replace("norm2", "layernorm_after")
85
+ layer_name_replace = layer_name_replace.replace("attn.proj", "attention.output.dense")
86
+ layer_name_replace = layer_name_replace.replace("mlp.fc1", "intermediate.dense")
87
+ layer_name_replace = layer_name_replace.replace("mlp.fc2", "output.dense")
88
+ layer_name_replace = layer_name_replace.replace(".layers.", ".encoder.layers.")
89
+ layer_name_replace = layer_name_replace.replace(".attn.", ".attention.self.")
90
+ elif layer_name.startswith("backbone.norm"):
91
+ layer_num = int(layer_name.split("norm")[1].split(".")[0])
92
+ if config.use_timm_backbone:
93
+ layer_name_replace = layer_name.replace("backbone", "vision_backbone")
94
+ layer_name_replace = layer_name_replace.replace(f"norm{layer_num}", f"layer_norms.{layer_num-1}")
95
+ else:
96
+ layer_name_replace = layer_name.replace(f"backbone.norm{layer_num}", f"vision_backbone.vision_backbone.hidden_states_norms.stage{layer_num+1}")
97
+ else:
98
+ continue
99
+ rename_keys.append((layer_name, layer_name_replace))
100
+ ########################################## VISION BACKBONE - END
101
+
102
+ ########################################## ENCODER - START
103
+ for layer_name, params in state_dict.items():
104
+ if "neck" in layer_name:
105
+ layer_name_replace = layer_name.replace("neck", "encoder")
106
+ layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
107
+ if "fpn_blocks" in layer_name or "pan_blocks" in layer_name or "lateral_convs" in layer_name or "downsample_convs" in layer_name:
108
+ layer_name_replace = layer_name_replace.replace(".m.", ".bottlenecks.")
109
+ layer_name_replace = layer_name_replace.replace(".cv", ".conv")
110
+ layer_name_replace = layer_name_replace.replace(".bn", ".norm")
111
+ if "encoder_layer" in layer_name:
112
+ layer_name_replace = layer_name_replace.replace("encoder_layer", "encoder.0.layers.0")
113
+ layer_name_replace = layer_name_replace.replace(".linear", ".fc")
114
+ layer_name_replace = layer_name_replace.replace("norm1", "self_attn_layer_norm")
115
+ layer_name_replace = layer_name_replace.replace("norm2", "final_layer_norm")
116
+ rename_keys.append((layer_name, layer_name_replace))
117
+ ########################################## ENCODER - END
118
+
119
+ ########################################## DECODER - START
120
+ for layer_name, params in state_dict.items():
121
+ if layer_name.startswith("decoder"):
122
+ layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers")
123
+ layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
124
+ layer_name_replace = layer_name_replace.replace("query_pos_head", "query_position_head")
125
+ layer_name_replace = layer_name_replace.replace("enc_bbox_head", "encoder_bbox_head")
126
+ layer_name_replace = layer_name_replace.replace("enc_output", "encoder_vision_features")
127
+ layer_name_replace = layer_name_replace.replace("dec_score_head", "decoder_class_head")
128
+ layer_name_replace = layer_name_replace.replace("dec_bbox_head", "decoder_bbox_head")
129
+ layer_name_replace = layer_name_replace.replace("enc_score_head", "encoder_class_head")
130
+ rename_keys.append((layer_name, layer_name_replace))
131
+ ########################################## DECODER - END
132
+ # fmt: on
133
+ return rename_keys
134
+
135
+
136
+ def create_rename_keys_language(state_dict):
137
+ rename_keys = []
138
+ # fmt: off
139
+ for layer_name in state_dict.keys():
140
+ if layer_name.startswith("language_backbone") and not layer_name.startswith("language_backbone.text_projection"):
141
+ layer_name_replace = layer_name.replace("language_backbone", "language_backbone.model.text_model")
142
+ layer_name_replace = layer_name_replace.replace("transformer.resblocks", "encoder.layers")
143
+ layer_name_replace = layer_name_replace.replace("token_embedding", "embeddings.token_embedding")
144
+ layer_name_replace = layer_name_replace.replace("positional_embedding", "embeddings.position_embedding.weight")
145
+ layer_name_replace = layer_name_replace.replace(".attn", ".self_attn")
146
+ layer_name_replace = layer_name_replace.replace(".mlp.c_fc", ".mlp.fc1")
147
+ layer_name_replace = layer_name_replace.replace(".mlp.c_proj", ".mlp.fc2")
148
+ layer_name_replace = layer_name_replace.replace("ln_final", "final_layer_norm")
149
+ layer_name_replace = layer_name_replace.replace(".ln_", ".layer_norm")
150
+ rename_keys.append((layer_name, layer_name_replace))
151
+ # fmt: on
152
+ return rename_keys
153
+
154
+
155
+ def rename_key(dct, old, new):
156
+ val = dct.pop(old)
157
+ dct[new] = val
158
+
159
+
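+ # A toy illustration (not part of the upstream script) of how the rename tables built
+ # above are applied: each (old, new) pair pops the old key and re-inserts the tensor
+ # under the converted name. The keys below are made up for the example.
+ def _sketch_rename_keys():
+     state_dict = {"backbone.patch_embed.proj.weight": torch.zeros(4)}
+     rename_keys = [
+         (
+             "backbone.patch_embed.proj.weight",
+             "vision_backbone.vision_backbone.embeddings.patch_embeddings.projection.weight",
+         )
+     ]
+     for src, dest in rename_keys:
+         rename_key(state_dict, src, dest)
+     return state_dict  # now keyed by the converted parameter name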
160
+ # we split up the matrix of each encoder layer into queries, keys and values
161
+ def read_in_q_k_v_vision(state_dict, config):
162
+ state_dict_keys = list(state_dict.keys())
163
+ for layer_name_vision in state_dict_keys:
164
+ if layer_name_vision.startswith("vision_backbone") and "qkv" in layer_name_vision:
165
+ layer_num = int(layer_name_vision.split(".")[4])
166
+ hidden_size = config.backbone_config.embed_dim * 2**layer_num
167
+ if "weight" in layer_name_vision:
168
+ in_proj_weight = state_dict.pop(layer_name_vision)
169
+ state_dict[layer_name_vision.replace("qkv.weight", "key.weight")] = in_proj_weight[:hidden_size, :]
170
+ state_dict[layer_name_vision.replace("qkv.weight", "query.weight")] = in_proj_weight[
171
+ hidden_size : hidden_size * 2, :
172
+ ]
173
+ state_dict[layer_name_vision.replace("qkv.weight", "value.weight")] = in_proj_weight[-hidden_size:, :]
174
+ elif "bias" in layer_name_vision:
175
+ in_proj_bias = state_dict.pop(layer_name_vision)
176
+ state_dict[layer_name_vision.replace("qkv.bias", "key.bias")] = in_proj_bias[:hidden_size]
177
+ state_dict[layer_name_vision.replace("qkv.bias", "query.bias")] = in_proj_bias[
178
+ hidden_size : hidden_size * 2
179
+ ]
180
+ state_dict[layer_name_vision.replace("qkv.bias", "value.bias")] = in_proj_bias[-hidden_size:]
181
+
182
+
183
+ def read_in_q_k_v_text(state_dict, config):
184
+ state_dict_keys = list(state_dict.keys())
185
+ hidden_size = config.text_config.projection_dim
186
+ for layer_name_text in state_dict_keys:
187
+ if layer_name_text.startswith("language_backbone") and "in_proj" in layer_name_text:
188
+ if "weight" in layer_name_text:
189
+ in_proj_weight = state_dict.pop(layer_name_text)
190
+ state_dict[layer_name_text.replace("in_proj_weight", "q_proj.weight")] = in_proj_weight[
191
+ :hidden_size, :
192
+ ]
193
+ state_dict[layer_name_text.replace("in_proj_weight", "k_proj.weight")] = in_proj_weight[
194
+ hidden_size : hidden_size * 2, :
195
+ ]
196
+ state_dict[layer_name_text.replace("in_proj_weight", "v_proj.weight")] = in_proj_weight[
197
+ -hidden_size:, :
198
+ ]
199
+ elif "bias" in layer_name_text:
200
+ in_proj_bias = state_dict.pop(layer_name_text)
201
+ state_dict[layer_name_text.replace("in_proj_bias", "q_proj.bias")] = in_proj_bias[:hidden_size]
202
+ state_dict[layer_name_text.replace("in_proj_bias", "k_proj.bias")] = in_proj_bias[
203
+ hidden_size : hidden_size * 2
204
+ ]
205
+ state_dict[layer_name_text.replace("in_proj_bias", "v_proj.bias")] = in_proj_bias[-hidden_size:]
206
+
207
+
208
+ def read_in_q_k_v_encoder(state_dict, config):
209
+ embed_dim = config.encoder_hidden_dim
210
+ # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
211
+ in_proj_weight = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_weight")
212
+ in_proj_bias = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_bias")
213
+ # next, add query, keys and values (in that order) to the state dict
214
+ state_dict["encoder.encoder.0.layers.0.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
215
+ state_dict["encoder.encoder.0.layers.0.self_attn.query.bias"] = in_proj_bias[:embed_dim]
216
+ state_dict["encoder.encoder.0.layers.0.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
217
+ state_dict["encoder.encoder.0.layers.0.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
218
+ state_dict["encoder.encoder.0.layers.0.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
219
+ state_dict["encoder.encoder.0.layers.0.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
220
+
221
+
222
+ def read_in_q_k_v_decoder(state_dict, config):
223
+ for layer_num in range(config.decoder_num_layers):
224
+ embed_dim = config.decoder_hidden_dim
225
+ # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
226
+ in_proj_weight = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_weight")
227
+ in_proj_bias = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_bias")
228
+ # next, add query, keys and values (in that order) to the state dict
229
+ state_dict[f"decoder.layers.{layer_num}.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
230
+ state_dict[f"decoder.layers.{layer_num}.self_attn.query.bias"] = in_proj_bias[:embed_dim]
231
+ state_dict[f"decoder.layers.{layer_num}.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
232
+ state_dict[f"decoder.layers.{layer_num}.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
233
+ state_dict[f"decoder.layers.{layer_num}.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
234
+ state_dict[f"decoder.layers.{layer_num}.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
235
+
236
+
237
+ def run_test(model, processor):
238
+ # We will verify our results on an image of cute cats
239
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
240
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
241
+
242
+ classes = ["cat", "remote"]
243
+ task = "Detect {}.".format(", ".join(classes))
244
+ inputs = processor(image, text=classes, task=task, return_tensors="pt")
245
+
246
+ # Running forward
247
+ with torch.no_grad():
248
+ outputs = model(**inputs)
249
+
250
+ predicted_slice = outputs[1][0, :3, :3]
251
+ print(predicted_slice)
252
+ expected_slice = torch.tensor([[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]])
253
+
254
+ assert torch.allclose(predicted_slice, expected_slice, atol=1e-4)
255
+ print("Looks ok!")
256
+
257
+
258
+ @torch.no_grad()
259
+ def convert_omdet_turbo_checkpoint(args):
260
+ model_name = args.model_name
261
+ pytorch_dump_folder_path = args.pytorch_dump_folder_path
262
+ push_to_hub = args.push_to_hub
263
+ use_timm_backbone = args.use_timm_backbone
264
+
265
+ checkpoint_mapping = {
266
+ "omdet-turbo-tiny": [
267
+ "https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/OmDet-Turbo_tiny_SWIN_T.pth",
268
+ "https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/ViT-B-16.pt",
269
+ ],
270
+ }
271
+ # Define default OmDetTurbo configuration
272
+ config = get_omdet_turbo_config(model_name, use_timm_backbone)
273
+
274
+ # Load original checkpoint
275
+ checkpoint_url = checkpoint_mapping[model_name]
276
+ original_state_dict_vision = torch.hub.load_state_dict_from_url(checkpoint_url[0], map_location="cpu")["model"]
277
+ original_state_dict_vision = {k.replace("module.", ""): v for k, v in original_state_dict_vision.items()}
278
+
279
+ # Rename keys
280
+ new_state_dict = original_state_dict_vision.copy()
281
+ rename_keys_vision = create_rename_keys_vision(new_state_dict, config)
282
+
283
+ rename_keys_language = create_rename_keys_language(new_state_dict)
284
+
285
+ for src, dest in rename_keys_vision:
286
+ rename_key(new_state_dict, src, dest)
287
+
288
+ for src, dest in rename_keys_language:
289
+ rename_key(new_state_dict, src, dest)
290
+
291
+ if not use_timm_backbone:
292
+ read_in_q_k_v_vision(new_state_dict, config)
293
+ read_in_q_k_v_text(new_state_dict, config)
294
+ read_in_q_k_v_encoder(new_state_dict, config)
295
+ read_in_q_k_v_decoder(new_state_dict, config)
296
+ # add "model" prefix to all keys
297
+ new_state_dict = {f"model.{k}": v for k, v in new_state_dict.items()}
298
+
299
+ # Load HF model
300
+ model = OmDetTurboForObjectDetection(config)
301
+ model.eval()
302
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
303
+ print("Missing keys:", missing_keys)
304
+ print("Unexpected keys:", unexpected_keys)
305
+
306
+ image_processor = DetrImageProcessor(
307
+ size={"height": config.backbone_image_size, "width": config.backbone_image_size},
308
+ do_rescale=False,
309
+ image_mean=IMAGE_MEAN,
310
+ image_std=IMAGE_STD,
311
+ do_pad=False,
312
+ )
313
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
314
+ processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer)
315
+
316
+ # end-to-end consistency test
317
+ run_test(model, processor)
318
+
319
+ if pytorch_dump_folder_path is not None:
320
+ model.save_pretrained(pytorch_dump_folder_path)
321
+ processor.save_pretrained(pytorch_dump_folder_path)
322
+
323
+ if push_to_hub:
324
+ model.push_to_hub(f"omlab/{model_name}")
325
+ processor.push_to_hub(f"omlab/{model_name}")
326
+
327
+
328
+ if __name__ == "__main__":
329
+ parser = argparse.ArgumentParser()
330
+ # Required parameters
331
+ parser.add_argument(
332
+ "--model_name",
333
+ default="omdet-turbo-tiny",
334
+ type=str,
335
+ choices=["omdet-turbo-tiny"],
336
+ help="Name of the OmDetTurbo model you'd like to convert.",
337
+ )
338
+ parser.add_argument(
339
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
340
+ )
341
+ parser.add_argument(
342
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
343
+ )
344
+ parser.add_argument(
345
+ "--use_timm_backbone", action="store_true", help="Whether or not to use timm backbone for vision backbone."
346
+ )
347
+
348
+ args = parser.parse_args()
349
+ convert_omdet_turbo_checkpoint(args)
docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py ADDED
@@ -0,0 +1,1711 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Om Research Lab and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch OmDet-Turbo model."""
16
+
17
+ import math
18
+ import warnings
19
+ from collections import OrderedDict
20
+ from dataclasses import dataclass
21
+ from functools import lru_cache
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import Tensor, nn
27
+
28
+ from ...activations import ACT2CLS, ACT2FN
29
+ from ...file_utils import (
30
+ ModelOutput,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ replace_return_docstrings,
34
+ )
35
+ from ...integrations import use_kernel_forward_from_hub
36
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...utils import logging
39
+ from ...utils.backbone_utils import load_backbone
40
+ from ..auto import AutoModel
41
+ from .configuration_omdet_turbo import OmDetTurboConfig
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+ _CONFIG_FOR_DOC = "OmDetTurboConfig"
46
+
47
+
48
+ @dataclass
49
+ class OmDetTurboEncoderOutput(ModelOutput):
50
+ """
51
+ Base class for outputs of the OmDetTurboHybridEncoder.
52
+
53
+ Args:
54
+ last_hidden_state (`torch.FloatTensor`):
55
+ Last hidden states of the encoder.
56
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
57
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
58
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
59
+
60
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
61
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
62
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
63
+ sequence_length)`.
64
+
65
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
66
+ heads.
67
+ extracted_states (`Tuple[torch.FloatTensor]`):
68
+ The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
69
+ """
70
+
71
+ last_hidden_state: Optional[torch.FloatTensor] = None
72
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
73
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
74
+ extracted_states: Tuple[torch.FloatTensor] = None
75
+
76
+
77
+ @dataclass
78
+ class OmDetTurboDecoderOutput(ModelOutput):
79
+ """
80
+ Base class for outputs of the OmDetTurboDecoder.
81
+
82
+ Args:
83
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
84
+ Sequence of hidden-states at the output of the last layer of the decoder.
85
+ decoder_coords (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
86
+ The predicted coordinates of the objects.
87
+ decoder_classes (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`):
88
+ The predicted classes of the objects.
89
+ encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
90
+ The predicted coordinates of the objects from the encoder.
91
+ encoder_class_logits (`Tuple[torch.FloatTensor]` of shape `(batch_size, num_queries, num_classes)`):
92
+ The predicted class of the objects from the encoder.
93
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
94
+ The initial reference points.
95
+ intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`):
96
+ The intermediate reference points.
97
+ hidden_states (`Optional[Tuple[torch.FloatTensor]]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
98
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
99
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
100
+ plus the initial embedding outputs.
101
+ attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
102
+ Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
103
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
104
+ weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
105
+ """
106
+
107
+ last_hidden_state: Optional[torch.FloatTensor] = None
108
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
109
+ attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
110
+ decoder_coords: Optional[torch.FloatTensor] = None
111
+ decoder_classes: Optional[torch.FloatTensor] = None
112
+ encoder_coord_logits: Optional[torch.FloatTensor] = None
113
+ encoder_class_logits: Tuple[torch.FloatTensor] = None
114
+ init_reference_points: Optional[torch.FloatTensor] = None
115
+ intermediate_reference_points: Tuple[Tuple[torch.FloatTensor]] = None
116
+
117
+
118
+ @dataclass
119
+ class OmDetTurboObjectDetectionOutput(ModelOutput):
120
+ """
121
+ Output type of [`OmDetTurboForObjectDetection`].
122
+
123
+ Args:
124
+ loss (`torch.FloatTensor`):
125
+ The loss value.
126
+ decoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
127
+ The predicted coordinates logits of the objects.
128
+ decoder_class_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`):
129
+ The predicted class of the objects.
130
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
131
+ The initial reference points.
132
+ intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`):
133
+ The intermediate reference points.
134
+ encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
135
+ The predicted coordinates of the objects from the encoder.
136
+ encoder_class_logits (`Tuple[torch.FloatTensor]`):
137
+ The predicted class of the objects from the encoder.
138
+ encoder_extracted_states (`torch.FloatTensor`):
139
+ The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
140
+ decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
141
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
142
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
143
+ plus the initial embedding outputs.
144
+ decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
145
+ Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
146
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
147
+ weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
148
+ encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
149
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
150
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
151
+ plus the initial embedding outputs.
152
+ encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
153
+ Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
154
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
155
+ weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
156
+ classes_structure (`torch.LongTensor`, *optional*):
157
+ The number of queried classes for each image.
158
+ """
159
+
160
+ loss: Optional[torch.FloatTensor] = None
161
+ decoder_coord_logits: Optional[torch.FloatTensor] = None
162
+ decoder_class_logits: Optional[torch.FloatTensor] = None
163
+ init_reference_points: Optional[torch.FloatTensor] = None
164
+ intermediate_reference_points: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
165
+ encoder_coord_logits: Optional[torch.FloatTensor] = None
166
+ encoder_class_logits: Tuple[torch.FloatTensor] = None
167
+ encoder_extracted_states: Optional[torch.FloatTensor] = None
168
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
169
+ decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
170
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
171
+ encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
172
+ classes_structure: Optional[torch.LongTensor] = None
173
+
174
+
175
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+ # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
+ class MultiScaleDeformableAttention(nn.Module):
+     def forward(
+         self,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_spatial_shapes_list: List[Tuple],
+         level_start_index: Tensor,
+         sampling_locations: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         batch_size, _, num_heads, hidden_dim = value.shape
+         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+         sampling_grids = 2 * sampling_locations - 1
+         sampling_value_list = []
+         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+             # batch_size, height*width, num_heads, hidden_dim
+             # -> batch_size, height*width, num_heads*hidden_dim
+             # -> batch_size, num_heads*hidden_dim, height*width
+             # -> batch_size*num_heads, hidden_dim, height, width
+             value_l_ = (
+                 value_list[level_id]
+                 .flatten(2)
+                 .transpose(1, 2)
+                 .reshape(batch_size * num_heads, hidden_dim, height, width)
+             )
+             # batch_size, num_queries, num_heads, num_points, 2
+             # -> batch_size, num_heads, num_queries, num_points, 2
+             # -> batch_size*num_heads, num_queries, num_points, 2
+             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+             # batch_size*num_heads, hidden_dim, num_queries, num_points
+             sampling_value_l_ = nn.functional.grid_sample(
+                 value_l_,
+                 sampling_grid_l_,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=False,
+             )
+             sampling_value_list.append(sampling_value_l_)
+         # (batch_size, num_queries, num_heads, num_levels, num_points)
+         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+         # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
+         attention_weights = attention_weights.transpose(1, 2).reshape(
+             batch_size * num_heads, 1, num_queries, num_levels * num_points
+         )
+         output = (
+             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+             .sum(-1)
+             .view(batch_size, num_heads * hidden_dim, num_queries)
+         )
+         return output.transpose(1, 2).contiguous()
+
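+ # A minimal, runnable sketch of the `grid_sample` step above, with assumed toy sizes
+ # (batch_size*num_heads=8, hidden_dim=32, one 20x20 level, 100 queries, 4 points):
+ #
+ #     import torch
+ #     value_l = torch.randn(8, 32, 20, 20)        # (batch_size*num_heads, hidden_dim, height, width)
+ #     grid = torch.rand(8, 100, 4, 2) * 2 - 1     # sampling locations rescaled to [-1, 1]
+ #     samples = torch.nn.functional.grid_sample(
+ #         value_l, grid, mode="bilinear", padding_mode="zeros", align_corners=False
+ #     )                                           # -> (8, 32, 100, 4), one value per head/query/point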
+
+ class OmDetTurboLRUCache:
+     def __init__(self, capacity: int):
+         self.cache = OrderedDict()
+         self.capacity = capacity
+         self.current_load = 0
+
+     def has(self, key) -> bool:
+         return key in self.cache
+
+     def get(self, key):
+         """
+         Get the value of the key if the key exists in the cache, otherwise return None.
+         Move the key to the end of the cache to show that it was recently used.
+         """
+         if key not in self.cache:
+             return None
+         self.cache.move_to_end(key)
+         return self.cache[key]
+
+     def put(self, key, value) -> None:
+         """
+         Add the key-value pair to the cache.
+         Move the key to the end of the cache to show that it was recently used.
+         If the cache is full, remove the first key (least recently used).
+         """
+         if key not in self.cache:
+             self.current_load += 1
+             if self.current_load > self.capacity:
+                 self.cache.popitem(last=False)
+                 self.current_load -= 1
+
+         self.cache[key] = value
+         self.cache.move_to_end(key)
+
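+ # A minimal usage sketch of the cache above (the model itself keys entries by tuples of
+ # token ids; the string keys here are only for illustration):
+ #
+ #     cache = OmDetTurboLRUCache(capacity=2)
+ #     cache.put("a", 1)
+ #     cache.put("b", 2)
+ #     cache.get("a")     # returns 1 and marks "a" as most recently used
+ #     cache.put("c", 3)  # evicts "b", the least recently used entry
+ #     cache.has("b")     # False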
+
+ class OmDetTurboLanguageBackbone(nn.Module):
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+         self.model = AutoModel.from_config(config.text_config)
+         self.text_projection = nn.Parameter(torch.zeros(config.text_projection_in_dim, config.text_projection_out_dim))
+
+     def forward(self, hidden_states, mask=None, encode_type="task"):
+         text_outputs = self.model(hidden_states)
+         pooled_output = text_outputs[0]
+         if encode_type == "task":
+             if mask is None:
+                 raise ValueError("mask is required for task encoding")
+             max_len = (mask != 0).sum(1).max().item()
+             truncated_mask = mask[:, :max_len]
+             truncated_output = pooled_output[:, :max_len, :]
+             return truncated_output.transpose(0, 1), truncated_mask
+         elif encode_type == "class":
+             max_pooled_output = pooled_output[torch.arange(pooled_output.shape[0]), hidden_states.argmax(dim=-1)]
+             projected_output = max_pooled_output @ self.text_projection
+             return projected_output
+         else:
+             raise ValueError(f"encode_type {encode_type} is not supported")
+
+
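+ # Hedged note on the "class" branch above: as in CLIP-style text pooling, the hidden state at
+ # the position of the highest token id (assumed to be the end-of-sequence token) is taken for
+ # each sequence before projection. A toy illustration:
+ #
+ #     import torch
+ #     pooled = torch.randn(2, 5, 4)               # (batch_size, sequence_length, hidden_size)
+ #     input_ids = torch.tensor([[0, 7, 9, 1, 1],  # 9 assumed to be the largest (EOS) token id
+ #                               [0, 5, 3, 9, 1]])
+ #     eos_states = pooled[torch.arange(2), input_ids.argmax(dim=-1)]  # -> (2, 4)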
+
+ class OmDetTurboVisionBackbone(nn.Module):
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+         self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone
+         self.vision_backbone = load_backbone(config)
+         self.layer_norms = nn.ModuleList(
+             [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels]
+         )
+
+     def forward(self, pixel_values):
+         outputs = self.vision_backbone(pixel_values).feature_maps
+         if self.apply_layernorm_after_vision_backbone:
+             outputs = [
+                 layer_norm(output).permute(0, 3, 1, 2).contiguous()
+                 for layer_norm, output in zip(self.layer_norms, outputs)
+             ]
+
+         return outputs
+
+
+ # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo
+ class OmDetTurboMultiscaleDeformableAttention(nn.Module):
+     """
+     Multiscale deformable attention as proposed in Deformable DETR.
+     """
+
+     def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int):
+         super().__init__()
+
+         self.attn = MultiScaleDeformableAttention()
+
+         if config.d_model % num_heads != 0:
+             raise ValueError(
+                 f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
+             )
+         dim_per_head = config.d_model // num_heads
+         # check if dim_per_head is a power of 2
+         if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
+             warnings.warn(
+                 "It is recommended to set embed_dim (d_model) in OmDetTurboMultiscaleDeformableAttention so that"
+                 " the dimension of each attention head is a power of 2, which is more efficient in the authors'"
+                 " CUDA implementation."
+             )
+
+         self.im2col_step = 64
+
+         self.d_model = config.d_model
+         self.n_levels = config.num_feature_levels
+         self.n_heads = num_heads
+         self.n_points = n_points
+
+         self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
+         self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
+         self.value_proj = nn.Linear(config.d_model, config.d_model)
+         self.output_proj = nn.Linear(config.d_model, config.d_model)
+
+         self.disable_custom_kernels = config.disable_custom_kernels
+
+     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+         return tensor if position_embeddings is None else tensor + position_embeddings
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         position_embeddings: Optional[torch.Tensor] = None,
+         reference_points=None,
+         spatial_shapes=None,
+         spatial_shapes_list=None,
+         level_start_index=None,
+         output_attentions: bool = False,
+     ):
+         # add position embeddings to the hidden states before projecting to queries and keys
+         if position_embeddings is not None:
+             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+         batch_size, num_queries, _ = hidden_states.shape
+         batch_size, sequence_length, _ = encoder_hidden_states.shape
+         # Ignore copy
+         total_elements = sum([shape[0] * shape[1] for shape in spatial_shapes_list])
+         if total_elements != sequence_length:
+             raise ValueError(
+                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+             )
+
+         value = self.value_proj(encoder_hidden_states)
+         if attention_mask is not None:
+             # we invert the attention_mask
+             value = value.masked_fill(~attention_mask[..., None], float(0))
+         value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+         sampling_offsets = self.sampling_offsets(hidden_states).view(
+             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+         )
+         attention_weights = self.attention_weights(hidden_states).view(
+             batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+         )
+         attention_weights = F.softmax(attention_weights, -1).view(
+             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+         )
+         # batch_size, num_queries, n_heads, n_levels, n_points, 2
+         num_coordinates = reference_points.shape[-1]
+         if num_coordinates == 2:
+             offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+             sampling_locations = (
+                 reference_points[:, :, None, :, None, :]
+                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+             )
+         elif num_coordinates == 4:
+             sampling_locations = (
+                 reference_points[:, :, None, :, None, :2]
+                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+             )
+         else:
+             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+
+         output = self.attn(
+             value,
+             spatial_shapes,
+             spatial_shapes_list,
+             level_start_index,
+             sampling_locations,
+             attention_weights,
+             self.im2col_step,
+         )
+
+         output = self.output_proj(output)
+
+         return output, attention_weights
+
+
+ # Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrConvNormLayer with RTDetr->OmDetTurbo
+ class OmDetTurboConvNormLayer(nn.Module):
+     def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
+         super().__init__()
+         self.conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             padding=(kernel_size - 1) // 2 if padding is None else padding,
+             bias=False,
+         )
+         self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
+         self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+     def forward(self, hidden_state):
+         hidden_state = self.conv(hidden_state)
+         hidden_state = self.norm(hidden_state)
+         hidden_state = self.activation(hidden_state)
+         return hidden_state
+
+
+ # Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrRepVggBlock with RTDetr->OmDetTurbo, activation_function->csp_activation
+ class OmDetTurboRepVggBlock(nn.Module):
+     """
+     RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
+     """
+
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+
+         activation = config.csp_activation
+         hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
+         self.conv1 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
+         self.conv2 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
+         self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
+
+     def forward(self, x):
+         y = self.conv1(x) + self.conv2(x)
+         return self.activation(y)
+
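+ # The two-branch sum above is the training-time form of RepVGG; at inference the 3x3 and 1x1
+ # branches can be reparameterized into a single 3x3 convolution (a sketch ignoring the
+ # BatchNorm folding, with assumed weight names):
+ #
+ #     import torch.nn.functional as F
+ #     fused_weight = conv3x3_weight + F.pad(conv1x1_weight, (1, 1, 1, 1))  # pad the 1x1 kernel to 3x3
+ #     # conv(x, fused_weight, padding=1) == conv(x, conv3x3_weight, padding=1) + conv(x, conv1x1_weight)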
+
+ # Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrCSPRepLayer with RTDetr->OmDetTurbo, activation_function->csp_activation
+ class OmDetTurboCSPRepLayer(nn.Module):
+     """
+     Cross Stage Partial (CSP) network layer with RepVGG blocks.
+     """
+
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+
+         in_channels = config.encoder_hidden_dim * 2
+         out_channels = config.encoder_hidden_dim
+         num_blocks = 3
+         activation = config.csp_activation
+
+         hidden_channels = int(out_channels * config.hidden_expansion)
+         self.conv1 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+         self.conv2 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
+         self.bottlenecks = nn.Sequential(*[OmDetTurboRepVggBlock(config) for _ in range(num_blocks)])
+         if hidden_channels != out_channels:
+             self.conv3 = OmDetTurboConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
+         else:
+             self.conv3 = nn.Identity()
+
+     def forward(self, hidden_state):
+         hidden_state_1 = self.conv1(hidden_state)
+         hidden_state_1 = self.bottlenecks(hidden_state_1)
+         hidden_state_2 = self.conv2(hidden_state)
+         return self.conv3(hidden_state_1 + hidden_state_2)
+
+
+ class OmDetTurboMultiheadAttention(nn.Module):
+     """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`."""
+
+     def __init__(self, config, hidden_size, num_attention_heads, dropout):
+         super().__init__()
+         if hidden_size % num_attention_heads != 0:
+             raise ValueError(
+                 f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
+                 f"heads ({num_attention_heads})"
+             )
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_size = int(hidden_size / num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+         self.query = nn.Linear(hidden_size, self.all_head_size)
+         self.key = nn.Linear(hidden_size, self.all_head_size)
+         self.value = nn.Linear(hidden_size, self.all_head_size)
+         self.out_proj = nn.Linear(hidden_size, hidden_size)
+         self.dropout = nn.Dropout(dropout)
+
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         queries: torch.Tensor,
+         keys: torch.Tensor,
+         values: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         query_layer = self.transpose_for_scores(self.query(queries))
+         key_layer = self.transpose_for_scores(self.key(keys))
+         value_layer = self.transpose_for_scores(self.value(values))
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+         if attention_mask is not None:
+             attention_scores = attention_scores + attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(new_context_layer_shape)
+
+         context_layer = self.out_proj(context_layer)
+
+         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+         return outputs
+
+
+ class OmDetTurboEncoderLayer(nn.Module):
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+         self.self_attn = OmDetTurboMultiheadAttention(
+             config,
+             hidden_size=config.encoder_hidden_dim,
+             num_attention_heads=config.num_attention_heads,
+             dropout=config.encoder_dropout,
+         )
+         self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.encoder_dropout)
+         self.activation_fn = ACT2FN[config.encoder_feedforward_activation]
+         self.encoder_feedforward_dropout = nn.Dropout(config.encoder_feedforward_dropout)
+         self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_dim_feedforward)
+         self.fc2 = nn.Linear(config.encoder_dim_feedforward, config.encoder_hidden_dim)
+         self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
+
+     @staticmethod
+     def with_pos_embed(tensor, pos_embed):
+         return tensor if pos_embed is None else tensor + pos_embed
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_embeddings: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ):
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`): attention mask of size
+                 `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
+                 values.
+             position_embeddings (`torch.FloatTensor`, *optional*):
+                 Object queries (also called content embeddings), to be added to the hidden states.
+             output_attentions (`bool`, *optional*, defaults to `False`):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+         """
+         residual = hidden_states
+         query = key = self.with_pos_embed(hidden_states, position_embeddings)
+
+         hidden_states = self.self_attn(
+             queries=query,
+             keys=key,
+             values=hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+         )
+         hidden_states, attentions = hidden_states if output_attentions else (hidden_states[0], None)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+         residual = hidden_states
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.encoder_feedforward_dropout(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+         if self.training:
+             if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
+                 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+         if output_attentions:
+             return hidden_states, attentions
+
+         return (hidden_states,)
+
+
+ class OmDetTurboEncoder(nn.Module):
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+
+         self.layers = nn.ModuleList([OmDetTurboEncoderLayer(config) for _ in range(config.encoder_layers)])
+
+     def forward(
+         self, src, src_mask=None, pos_embed=None, output_attentions: bool = False
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+         hidden_states = src
+         attention = () if output_attentions else None
+         for layer in self.layers:
+             hidden_states = layer(
+                 hidden_states,
+                 attention_mask=src_mask,
+                 position_embeddings=pos_embed,
+                 output_attentions=output_attentions,
+             )
+             if output_attentions:
+                 attention = attention + (hidden_states[1],)
+             hidden_states = hidden_states[0]
+
+         return hidden_states, attention
+
+
+ class OmDetTurboHybridEncoder(nn.Module):
+     """
+     Encoder consisting of channel projection layers, a set of `OmDetTurboEncoder`, a top-down Feature Pyramid Network
+     (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://arxiv.org/abs/2304.08069
+
+     Args:
+         config: OmDetTurboConfig
+     """
+
+     def __init__(self, config: OmDetTurboConfig):
+         super().__init__()
+         self.config = config
+         self.in_channels = config.encoder_in_channels
+         self.encoder_hidden_dim = config.encoder_hidden_dim
+         self.encoder_projection_indices = config.encoder_projection_indices
+         self.positional_encoding_temperature = config.positional_encoding_temperature
+         self.eval_size = config.eval_size
+         self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
+
+         self.channel_projection_layers = nn.ModuleList()
+         for in_channel in self.in_channels:
+             self.channel_projection_layers.append(
+                 nn.Sequential(
+                     nn.Conv2d(in_channel, self.encoder_hidden_dim, kernel_size=(1, 1), bias=False),
+                     nn.BatchNorm2d(self.encoder_hidden_dim),
+                 )
+             )
+
+         # encoder transformer
+         self.encoder = nn.ModuleList([OmDetTurboEncoder(config) for _ in range(len(self.encoder_projection_indices))])
+         # top-down fpn
+         self.lateral_convs = nn.ModuleList()
+         self.fpn_blocks = nn.ModuleList()
+         for _ in range(len(self.in_channels) - 1, 0, -1):
+             self.lateral_convs.append(
+                 OmDetTurboConvNormLayer(
+                     config,
+                     in_channels=self.encoder_hidden_dim,
+                     out_channels=self.encoder_hidden_dim,
+                     kernel_size=1,
+                     stride=1,
+                     activation=config.conv_norm_activation,
+                 )
+             )
+             self.fpn_blocks.append(OmDetTurboCSPRepLayer(config))
+
+         # bottom-up pan
+         self.downsample_convs = nn.ModuleList()
+         self.pan_blocks = nn.ModuleList()
+         for _ in range(len(self.in_channels) - 1):
+             self.downsample_convs.append(
+                 OmDetTurboConvNormLayer(
+                     config,
+                     in_channels=self.encoder_hidden_dim,
+                     out_channels=self.encoder_hidden_dim,
+                     kernel_size=3,
+                     stride=2,
+                     activation=config.conv_norm_activation,
+                 )
+             )
+             self.pan_blocks.append(OmDetTurboCSPRepLayer(config))
+
+     @staticmethod
+     def build_2d_sincos_position_embedding(
+         width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
+     ):
+         grid_w = torch.arange(int(width), dtype=dtype, device=device)
+         grid_h = torch.arange(int(height), dtype=dtype, device=device)
+         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+         if embed_dim % 4 != 0:
+             raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
+         pos_dim = embed_dim // 4
+         omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
+         omega = 1.0 / (temperature**omega)
+
+         out_w = grid_w.flatten()[..., None] @ omega[None]
+         out_h = grid_h.flatten()[..., None] @ omega[None]
+
+         return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]
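+
+     # A small sketch of what the staticmethod above returns, with assumed toy sizes:
+     #
+     #     pos = OmDetTurboHybridEncoder.build_2d_sincos_position_embedding(width=2, height=3, embed_dim=8)
+     #     pos.shape  # torch.Size([1, 6, 8]): one embedding per grid cell
+     #
+     # Each cell (x, y) is encoded as [sin(x*w), cos(x*w), sin(y*w), cos(y*w)] over embed_dim // 4
+     # frequencies w = 1 / temperature**(i / pos_dim), as in "Attention Is All You Need".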
+
+     def forward(
+         self,
+         inputs_embeddings=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         Args:
+             inputs_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                 Flattened feature map (output of the backbone + projection layers) that is passed to the encoder.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states = inputs_embeddings
+
+         encoder_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+         # get projection features
+         projected_features = [self.channel_projection_layers[i](feature) for i, feature in enumerate(hidden_states)]
+         # encoder
+         for encoder_layer_index, feature_to_project_index in enumerate(self.encoder_projection_indices):
+             if output_hidden_states:
+                 encoder_states = encoder_states + (projected_features[feature_to_project_index],)
+             height, width = projected_features[feature_to_project_index].shape[2:]
+             # flatten [batch, channel, height, width] to [batch, height*width, channel]
+             src_flatten = projected_features[feature_to_project_index].flatten(2).permute(0, 2, 1)
+             if self.training or self.eval_size is None:
+                 pos_embed = self.build_2d_sincos_position_embedding(
+                     width,
+                     height,
+                     self.encoder_hidden_dim,
+                     self.positional_encoding_temperature,
+                     device=src_flatten.device,
+                     dtype=src_flatten.dtype,
+                 ).to(src_flatten.device, src_flatten.dtype)
+             else:
+                 pos_embed = None
+             layer_outputs = self.encoder[encoder_layer_index](
+                 src_flatten,
+                 pos_embed=pos_embed,
+                 output_attentions=output_attentions,
+             )
+             projected_features[feature_to_project_index] = (
+                 layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
+             )
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[1],)
+
+             if output_hidden_states:
+                 encoder_states = encoder_states + (projected_features[feature_to_project_index],)
+
+         # Feature Pyramid Network (FPN)
+         fpn_feature_maps = [projected_features[-1]]
+         for idx in range(len(self.in_channels) - 1, 0, -1):
+             feat_high = fpn_feature_maps[0]
+             feat_low = projected_features[idx - 1]
+             feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)
+             fpn_feature_maps[0] = feat_high
+             upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest")
+             fps_map = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
+             fpn_feature_maps.insert(0, fps_map)
+
+         # Path Aggregation Network (PAN)
+         fpn_states = [fpn_feature_maps[0]]
+         for idx in range(len(self.in_channels) - 1):
+             feat_low = fpn_states[-1]
+             feat_high = fpn_feature_maps[idx + 1]
+             downsample_feat = self.downsample_convs[idx](feat_low)
+             hidden_states = self.pan_blocks[idx](
+                 torch.concat([downsample_feat, feat_high.to(downsample_feat.device)], dim=1)
+             )
+             fpn_states.append(hidden_states)
+         if not return_dict:
+             return (fpn_states[-1], encoder_states, all_attentions, fpn_states)
+         return OmDetTurboEncoderOutput(
+             last_hidden_state=fpn_states[-1],
+             hidden_states=encoder_states,
+             attentions=all_attentions,
+             extracted_states=fpn_states,
+         )
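+
+     # A hedged shape walkthrough of the FPN/PAN fusion above, assuming three feature levels at
+     # strides 8/16/32 with encoder_hidden_dim=256 on a 640x640 input:
+     #
+     #     projected_features: [(B, 256, 80, 80), (B, 256, 40, 40), (B, 256, 20, 20)]
+     #     # FPN (top-down): 2x-upsample the coarser map and fuse it with the next finer level
+     #     # PAN (bottom-up): stride-2 downsample the finer map and fuse it with the next coarser level
+     #     fpn_states: [(B, 256, 80, 80), (B, 256, 40, 40), (B, 256, 20, 20)]  # returned as `extracted_states`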
+
+
+ class OmDetTurboMLPWithDropout(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.linear1 = nn.Linear(config.class_embed_dim, config.task_encoder_hidden_dim)
+         self.activation = ACT2FN[config.decoder_activation]
+         self.dropout = nn.Dropout(config.decoder_dropout)
+         self.linear2 = nn.Linear(config.task_encoder_hidden_dim, config.class_embed_dim)
+
+     def forward(self, x):
+         return self.linear2(self.dropout(self.activation(self.linear1(x))))
+
+
+ class OmDetTurboMLP(nn.Module):
+     """Very simple multi-layer perceptron (also called FFN)"""
+
+     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         hidden_layers_dims = [hidden_dim] * (num_layers - 1)
+         layers_dims = [input_dim] + hidden_layers_dims + [output_dim]
+         self.layers = nn.ModuleList(
+             [nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(layers_dims[:-1], layers_dims[1:])]
+         )
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x
+
+
+ class OmDetTurboResidualLayer(nn.Module):
+     """
+     A residual connection followed by a layer norm.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(config.class_embed_dim, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.decoder_dropout)
+
+     def forward(self, x, y):
+         return self.norm1(x + self.dropout(y))
+
+
+ class OmDetTurboTaskEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.mlp = OmDetTurboMLPWithDropout(config)
+         self.res1 = OmDetTurboResidualLayer(config)
+
+     def forward(self, x):
+         mlp_out = self.mlp(x)
+         x = self.res1(x, mlp_out)
+         return x
+
+
+ class OmDetTurboDeformableTransformerDecoderLayer(nn.Module):
+     """
+     A single layer of the Deformable Transformer Decoder.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         # self attention
+         self.self_attn = OmDetTurboMultiheadAttention(
+             config,
+             hidden_size=config.decoder_hidden_dim,
+             num_attention_heads=config.decoder_num_heads,
+             dropout=config.decoder_dropout,
+         )
+         self.dropout1 = nn.Dropout(config.decoder_dropout)
+         self.norm1 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps)
+
+         # cross attention
+         self.cross_attn = OmDetTurboMultiscaleDeformableAttention(
+             config, num_heads=config.decoder_num_heads, n_points=config.decoder_num_points
+         )
+         self.dropout2 = nn.Dropout(config.decoder_dropout)
+         self.norm2 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps)
+
+         # feed forward network
+         self.linear1 = nn.Linear(config.decoder_hidden_dim, config.decoder_dim_feedforward)
+         self.act = ACT2FN[config.decoder_activation]
+         self.dropout3 = nn.Dropout(config.decoder_dropout)
+         self.linear2 = nn.Linear(config.decoder_dim_feedforward, config.decoder_hidden_dim)
+         self.dropout4 = nn.Dropout(config.decoder_dropout)
+         self.norm3 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps)
+
+         self.output_attentions = config.output_attentions
+         self.output_hidden_states = config.output_hidden_states
+
+     @staticmethod
+     def with_pos_embed(tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward(
+         self,
+         decoder_embeddings,
+         task_features,
+         reference_points,
+         vision_features,
+         vision_shapes,
+         vision_shapes_list,
+         level_start_index=None,
+         attention_mask=None,
+         padding_mask=None,
+         query_position=None,
+         output_attentions=None,
+         output_hidden_states=None,
+     ):
+         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+
+         origin_embedding_len = decoder_embeddings.shape[1]
+
+         # self attention
+         query = key = self.with_pos_embed(decoder_embeddings, query_position)
+         # combine task_features with query, key, value
+         task_features = task_features.transpose(0, 1)
+         query = torch.cat((query, task_features), dim=1)
+         key = torch.cat((key, task_features), dim=1)
+         decoder_embeddings = torch.cat((decoder_embeddings, task_features), dim=1)
+
+         outputs = self.self_attn(
+             query,
+             key,
+             decoder_embeddings,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+         )
+         context, self_attention = outputs if output_attentions else (outputs[0], None)
+         decoder_embeddings = decoder_embeddings + self.dropout1(context)
+         decoder_embeddings = self.norm1(decoder_embeddings)
+
+         task_features = decoder_embeddings[:, origin_embedding_len:, :].transpose(0, 1)
+         decoder_embeddings = decoder_embeddings[:, :origin_embedding_len, :]
+
+         # cross attention
+         hidden_states = self.with_pos_embed(decoder_embeddings, query_position)
+         reference_points = reference_points.unsqueeze(2)
+         outputs, cross_attention = self.cross_attn(
+             hidden_states=hidden_states,
+             attention_mask=padding_mask,
+             encoder_hidden_states=vision_features,
+             reference_points=reference_points,
+             spatial_shapes=vision_shapes,
+             spatial_shapes_list=vision_shapes_list,
+             level_start_index=level_start_index,
+         )
+         decoder_embeddings = decoder_embeddings + self.dropout2(outputs)
+         residual = self.norm2(decoder_embeddings)
+
+         # feed forward network
+         decoder_embeddings = self.linear2(self.dropout3(self.act(self.linear1(residual))))
+         decoder_embeddings = residual + self.dropout4(decoder_embeddings)
+         decoder_embeddings = self.norm3(decoder_embeddings)
+
+         return (
+             decoder_embeddings,
+             task_features,
+             self_attention if output_attentions else None,
+             cross_attention if output_attentions else None,
+         )
+
+
+ class OmDetTurboPreTrainedModel(PreTrainedModel):
+     config_class = OmDetTurboConfig
+     base_model_prefix = "model"
+     main_input_name = "pixel_values"
+
+     def _init_weights(self, module):
+         def linear_init_(module_to_init):
+             bound = 1 / math.sqrt(module_to_init.weight.shape[0])
+             nn.init.uniform_(module_to_init.weight, -bound, bound)
+             if hasattr(module_to_init, "bias") and module_to_init.bias is not None:
+                 nn.init.uniform_(module_to_init.bias, -bound, bound)
+
+         if isinstance(module, OmDetTurboEncoderLayer):
+             linear_init_(module.fc1)
+             linear_init_(module.fc2)
+         elif isinstance(module, OmDetTurboDecoder):
+             nn.init.constant_(module.encoder_bbox_head.layers[-1].weight, 0.0)
+             nn.init.constant_(module.encoder_bbox_head.layers[-1].bias, 0.0)
+             for mlp in module.decoder_bbox_head:
+                 nn.init.constant_(mlp.layers[-1].weight, 0.0)
+                 nn.init.constant_(mlp.layers[-1].bias, 0.0)
+             linear_init_(module.encoder_vision_features[0])
+             nn.init.xavier_uniform_(module.encoder_vision_features[0].weight)
+             if module.learn_initial_query:
+                 nn.init.xavier_uniform_(module.tgt_embed.weight)
+             nn.init.xavier_uniform_(module.query_position_head.layers[0].weight)
+             nn.init.xavier_uniform_(module.query_position_head.layers[1].weight)
+             for layer in module.channel_projection_layers:
+                 nn.init.xavier_uniform_(layer[0].weight)
+         elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+             module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, OmDetTurboDecoder):
+             module.gradient_checkpointing = value
+
+     @staticmethod
+     def _get_cache_key_at_index(input_ids, attention_mask, index):
+         input_ids = input_ids[index]
+         input_mask = attention_mask[index]
+         cache_key = tuple(input_ids[input_mask != 0].tolist())
+         return cache_key
+
+     def get_cached_class_embeddings(self, classes_input_ids, classes_attention_mask):
+         not_cached_index = []
+         not_cached_classes = []
+         total_embeddings = []
+         for idx, _ in enumerate(classes_input_ids):
+             cache_key = self._get_cache_key_at_index(classes_input_ids, classes_attention_mask, idx)
+             if self.language_cache_class.has(cache_key):
+                 total_embeddings.append(self.language_cache_class.get(cache_key))
+             else:
+                 total_embeddings.append(None)
+                 not_cached_index.append(idx)
+                 not_cached_classes.append(cache_key)
+
+         if not_cached_classes:
+             not_cached_classes_ids = torch.stack([classes_input_ids[idx] for idx in not_cached_index])
+             embeddings = self.language_backbone(not_cached_classes_ids, encode_type="class")
+             for idx, emb in enumerate(embeddings):
+                 idx_to_put = not_cached_index[idx]
+                 total_embeddings[idx_to_put] = emb
+                 self.language_cache_class.put(not_cached_classes[idx], emb)
+
+         total_class_embs = torch.stack(total_embeddings).to(self.device)
+         return total_class_embs
+
+     def get_cached_task_embeddings(self, tasks_input_ids, tasks_attention_mask):
+         not_cached_index = []
+         not_cached_tasks = []
+         total_task_features = []
+         total_task_masks = []
+         for idx, _ in enumerate(tasks_input_ids):
+             cache_key = self._get_cache_key_at_index(tasks_input_ids, tasks_attention_mask, idx)
+             if self.language_cache_prompt.has(cache_key):
+                 task_feature, task_mask = self.language_cache_prompt.get(cache_key)
+                 total_task_features.append(task_feature)
+                 total_task_masks.append(task_mask)
+             else:
+                 total_task_features.append(None)
+                 total_task_masks.append(None)
+                 not_cached_index.append(idx)
+                 not_cached_tasks.append(cache_key)
+
+         if not_cached_tasks:
+             not_cached_index_ids = torch.stack([tasks_input_ids[idx] for idx in not_cached_index])
+             not_cached_mask = torch.stack([tasks_attention_mask[idx] for idx in not_cached_index])
+             embeddings, masks = self.language_backbone(not_cached_index_ids, mask=not_cached_mask, encode_type="task")
+
+             for idx in range(embeddings.shape[1]):
+                 emb = embeddings[:, [idx], :]
+                 idx_to_put = not_cached_index[idx]
+                 cur_mask = torch.unsqueeze(masks[idx], dim=0).to(self.device)
+                 total_task_features[idx_to_put] = emb
+                 total_task_masks[idx_to_put] = cur_mask
+                 self.language_cache_prompt.put(not_cached_tasks[idx], (emb, cur_mask))
+
+         # pad before concat if needed
+         max_len = max([task.shape[0] for task in total_task_features])
+         for idx, task in enumerate(total_task_features):
+             if task.shape[0] < max_len:
+                 pad_size = max_len - task.shape[0]
+                 total_task_features[idx] = F.pad(task, (0, 0, 0, 0, 0, pad_size))
+                 total_task_masks[idx] = F.pad(total_task_masks[idx], (0, pad_size))
+
+         total_task_features = torch.cat(total_task_features, dim=1).to(self.device)
+         total_task_masks = torch.cat(total_task_masks, dim=0).to(self.device)
+
+         return total_task_features, total_task_masks
+
+     def get_language_embedding(
+         self,
+         classes_input_ids,
+         classes_attention_mask,
+         tasks_input_ids,
+         tasks_attention_mask,
+         classes_structure,
+     ):
+         batched_classes_embeddings = self.get_cached_class_embeddings(classes_input_ids, classes_attention_mask)
+         # regroup class embeddings using saved structure
+         max_class_size = torch.max(classes_structure)
+         class_embeddings_regrouped = []
+         start = 0
+         for size in classes_structure:
+             pad_size = max_class_size - size
+             class_embeddings_regrouped.append(
+                 F.pad(batched_classes_embeddings[start : start + size], (0, 0, 0, pad_size)).unsqueeze(1)
+             )
+             start += size
+         class_embeddings = torch.cat(class_embeddings_regrouped, dim=1)
+
+         task_embeddings, task_mask = self.get_cached_task_embeddings(tasks_input_ids, tasks_attention_mask)
+
+         return class_embeddings, task_embeddings, task_mask
+
+
+ OMDET_TURBO_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`OmDetTurboConfig`]):
+             Model configuration class with all the parameters of the model. Initializing with a config file does not
+             load the weights associated with the model, only the configuration. Check out the
+             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ OMDET_TURBO_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Padding will be ignored by default should you provide it.
+
+             Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for
+             details.
+
+         classes_input_ids (`torch.LongTensor` of shape `(total_classes (>= batch_size), sequence_length)`):
+             Indices of input classes sequence tokens in the vocabulary of the language model.
+             Several classes can be provided for each task, thus the tokenized classes are flattened
+             and the structure of the classes is provided in the `classes_structure` argument.
+
+             Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for
+             details.
+
+             [What are input IDs?](../glossary#input-ids)
+
+         classes_attention_mask (`torch.BoolTensor` of shape `(total_classes (>= batch_size), num_classes, sequence_length)`):
+             Attention mask for the classes. This is a binary mask that indicates which tokens should be attended to,
+             and which should not.
+
+         tasks_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             Indices of input tasks sequence tokens in the vocabulary of the language model.
+
+             Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for
+             details.
+
+             [What are input IDs?](../glossary#input-ids)
+
+         tasks_attention_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+             Attention mask for the tasks. This is a binary mask that indicates which tokens should be attended to,
+             and which should not.
+
+         classes_structure (`torch.LongTensor` of shape `(batch_size)`):
+             Structure of the classes. This tensor indicates the number of classes for each task.
+
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ """
+
+
+ def _cosine_similarity_scaled(a, b, logit_scale):
+     a = a / a.norm(dim=2, keepdim=True).clamp_min(1e-12)
+     b = b / b.norm(dim=1, keepdim=True).clamp_min(1e-12)
+     logit_scale = logit_scale.exp()
+     logits_per_image = logit_scale * torch.bmm(a, b)
+     return logits_per_image
+
+
+ def get_class_similarity(class_distance_type, cls_feature, class_proj):
+     logit_scale = torch.tensor(1 / 0.07).log()
+     if class_distance_type == "cosine":
+         class_logits = _cosine_similarity_scaled(cls_feature, class_proj, logit_scale)
+     elif class_distance_type == "dot":
+         class_logits = torch.bmm(cls_feature, class_proj)
+     else:
+         raise ValueError(f"Unknown class_distance_type {class_distance_type}")
+     return class_logits
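+
+ # Hedged note on the scale above: 1 / 0.07 is the CLIP-style logit scale, so cosine
+ # similarities in [-1, 1] are stretched to roughly [-14.29, 14.29]:
+ #
+ #     import torch
+ #     torch.tensor(1 / 0.07).log().exp()  # tensor(14.2857)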
+
+
+ def _inverse_sigmoid(x, eps=1e-5):
+     x = x.clamp(min=0, max=1)
+     x1 = x.clamp(min=eps)
+     x2 = (1 - x).clamp(min=eps)
+     return torch.log(x1 / x2)
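+
+ # A quick sanity check of the helper above (the clamps only matter near 0 and 1):
+ #
+ #     import torch
+ #     x = torch.tensor([0.25, 0.50, 0.75])
+ #     torch.sigmoid(_inverse_sigmoid(x))  # tensor([0.2500, 0.5000, 0.7500])
+ #
+ # The decoder refines boxes in this logit space: sigmoid(bbox_head(features) + _inverse_sigmoid(boxes)).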
+
+
+ class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
+     def __init__(self, config: OmDetTurboConfig):
+         self.config = config
+         super().__init__(config)
+         self.gradient_checkpointing = False
+
+         hidden_dim = config.decoder_hidden_dim
+         self.num_queries = config.num_queries
+         self.class_distance_type = config.class_distance_type
+         self.learn_initial_query = config.learn_initial_query
+
+         # backbone feature projection
+         self.channel_projection_layers = nn.ModuleList(
+             nn.Sequential(nn.Conv2d(x, hidden_dim, 1, bias=False), nn.BatchNorm2d(hidden_dim))
+             for x in config.vision_features_channels
+         )
+         self.task_encoder = OmDetTurboTaskEncoder(config)
+         # `forward` checks `self.task_project is not None`, so default the attribute to None
+         # instead of leaving it undefined when the dimensions already match
+         self.task_project = None
+         if config.class_embed_dim != hidden_dim:
+             self.task_project = nn.Linear(config.class_embed_dim, hidden_dim)
+
+         # Transformer module
+         self.layers = nn.ModuleList(
+             [OmDetTurboDeformableTransformerDecoderLayer(config) for _ in range(config.decoder_num_layers)]
+         )
+         self.decoder_num_layers = config.decoder_num_layers
+         # decoder embedding
+         if self.learn_initial_query:
+             self.tgt_embed = nn.Embedding(self.num_queries, hidden_dim)
+         self.query_position_head = OmDetTurboMLP(
+             input_dim=4, hidden_dim=2 * hidden_dim, output_dim=hidden_dim, num_layers=2
+         )
+
+         # encoder head
+         self.encoder_vision_features = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
+         )
+         self.encoder_class_head = nn.Linear(config.class_embed_dim, hidden_dim)
+         self.encoder_bbox_head = OmDetTurboMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=4, num_layers=3)
+
+         # decoder head
+         self.decoder_class_head = nn.ModuleList(
+             [nn.Linear(config.class_embed_dim, hidden_dim) for _ in range(config.decoder_num_layers)]
+         )
+         self.decoder_bbox_head = nn.ModuleList(
+             [OmDetTurboMLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(config.decoder_num_layers)]
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @lru_cache(maxsize=32)
+     def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
+         # We always generate anchors in float32 to preserve equivalence between
+         # dynamic and static anchor inference
+         # Ignore copy
+         if spatial_shapes is None:
+             raise ValueError("spatial_shapes must be provided")
+
+         anchors = []
+         for level, (height, width) in enumerate(spatial_shapes):
+             grid_y, grid_x = torch.meshgrid(
+                 torch.arange(end=height, dtype=dtype, device=device),
+                 torch.arange(end=width, dtype=dtype, device=device),
+                 indexing="ij",
+             )
+             grid_xy = torch.stack([grid_x, grid_y], -1)
+             valid_wh = torch.tensor([width, height], dtype=dtype, device=device)
+             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh
+             wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**level)
+             anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
+         # define the valid range for anchor coordinates
+         eps = 1e-2
+         anchors = torch.concat(anchors, 1)
+         valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
+         anchors = torch.log(anchors / (1 - anchors))
+         anchors = torch.where(valid_mask, anchors, torch.inf)
+
+         return anchors, valid_mask
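+
+     # A small worked example of the method above, assuming a decoder instance and a single
+     # 2x2 feature level:
+     #
+     #     anchors, valid_mask = decoder.generate_anchors(spatial_shapes=((2, 2),))
+     #     # centers (x, y): (0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75), each with
+     #     # width = height = grid_size = 0.05, then mapped through log(p / (1 - p)) so the
+     #     # decoder's sigmoid recovers them; anchors outside (eps, 1 - eps) become +inf.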
+
+     def _get_encoder_input(self, vision_features):
+         # get projection features
+         vision_features = [self.channel_projection_layers[i](feat) for i, feat in enumerate(vision_features)]
+         # get encoder inputs
+         new_vision_features = []
+         new_vision_shapes_list = []
+         for feat in vision_features:
+             height, width = feat.shape[2:]
+             # [batch_size, channels, height, width] -> [batch_size, height*width, channels]
+             new_vision_features.append(feat.flatten(2).permute(0, 2, 1))
+             # [num_feature_levels, 2]
+             new_vision_shapes_list.append((height, width))
+
+         # [batch_size, height*width, channels]
+         new_vision_features = torch.cat(new_vision_features, 1)
+         new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64, device=vision_features[0].device)
+         level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1]))
+
+         return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index
+
+     def _get_decoder_input(
+         self, vision_features, vision_shapes, class_features, denoise_embeddings=None, denoise_bboxes=None
+     ):
+         batch_size = len(vision_features)
+         # prepare input for decoder
+         anchors, valid_mask = self.generate_anchors(
+             vision_shapes, device=vision_features.device, dtype=vision_features.dtype
+         )
+         predicted_class_features = self.encoder_vision_features(
+             torch.where(
+                 valid_mask,
+                 vision_features,
+                 torch.tensor(0.0, dtype=vision_features.dtype, device=vision_features.device),
+             )
+         )
+
+         original_class_projected = self.encoder_class_head(class_features).permute(1, 2, 0)
+         encoder_class_similarity = get_class_similarity(
+             self.class_distance_type, predicted_class_features, original_class_projected
+         )
+
+         # dynamic anchors + static content
+         # (batch_size, height*width, 4)
+         encoder_outputs_bboxes = self.encoder_bbox_head(predicted_class_features) + anchors
+
+         # query selection
+         # (batch_size, num_queries)
+         topk_ind = torch.topk(encoder_class_similarity.max(-1).values, self.num_queries, dim=1).indices.view(-1)
+         # (batch_size, num_queries)
+         batch_ind = (
+             torch.arange(end=batch_size, dtype=topk_ind.dtype, device=topk_ind.device)
+             .unsqueeze(-1)
+             .repeat(1, self.num_queries)
+             .view(-1)
+         )
+
+         reference_points = encoder_outputs_bboxes[batch_ind, topk_ind].view(batch_size, self.num_queries, -1)
+         encoder_bboxes = reference_points.sigmoid()
+         if denoise_bboxes is not None:
+             reference_points = torch.cat([denoise_bboxes, reference_points], 1)
+         if self.training:
+             reference_points = reference_points.detach()
+         encoder_class_similarity = encoder_class_similarity[batch_ind, topk_ind].view(batch_size, self.num_queries, -1)
+
+         if self.learn_initial_query:
+             embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)
+         else:
+             embeddings = predicted_class_features[batch_ind, topk_ind].view(batch_size, self.num_queries, -1)
+             if self.training:
+                 embeddings = embeddings.detach()
+         if denoise_embeddings is not None:
+             embeddings = torch.cat([denoise_embeddings, embeddings], 1)
+
+         return embeddings, reference_points, encoder_bboxes, encoder_class_similarity, anchors
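+
+     # A hedged sketch of the top-k query selection above: each anchor is scored by its best
+     # class similarity and the `num_queries` highest-scoring positions seed the decoder:
+     #
+     #     scores = encoder_class_similarity.max(-1).values  # (batch_size, num_anchors)
+     #     topk_ind = torch.topk(scores, k=self.num_queries, dim=1).indices
+     #     # these positions provide the initial boxes (reference_points) and, when
+     #     # learn_initial_query=False, the initial query embeddings as well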
+
+     def forward(
+         self,
+         vision_features,
+         class_features,
+         task_features,
+         task_mask,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         """
+         Args:
+             vision_features (`torch.FloatTensor`): The sequence of vision features. Shape depends on the vision
+                 backbone.
+             class_features (`torch.FloatTensor`): The sequence of class features of shape
+                 `(class_sequence_length, batch_size, class_embed_dim)`.
+             task_features (`torch.FloatTensor`): The sequence of task features of shape
+                 `(task_sequence_length, batch_size, decoder_hidden_dim)`.
+             task_mask (`torch.LongTensor`): The mask for the task features of shape `(batch_size, task_sequence_length)`.
+             output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention
+                 layers. See `attentions` under returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See
+                 `hidden_states` under returned tensors for more detail.
+             return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain
+                 tuple.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         vision_features, vision_shapes, vision_shapes_list, level_start_index = self._get_encoder_input(
+             vision_features
+         )
+
+         # todo add denoising for training
+         denoise_embeddings, denoise_bboxes, key_padding_mask = None, None, None
+         batch_size = task_mask.shape[0]
+
+         # compose attn_mask for vision_emb and task_emb fusion
+         task_features = self.task_encoder(task_features)
+         if self.task_project is not None:
+             task_features = self.task_project(task_features)
+         src_key_mask = (task_mask == 0).detach()
+         attn_mask_len = self.num_queries
+         fusion_size = attn_mask_len + task_features.shape[0]
+         key_padding_mask = torch.zeros([batch_size, fusion_size], dtype=torch.bool).to(task_features.device)
+         key_padding_mask[:, attn_mask_len:] = src_key_mask
+         attention_mask = _prepare_4d_attention_mask(~key_padding_mask, dtype=vision_features.dtype)
+         decoder_embeddings, reference_points, encoder_bboxes, encoder_class_similarity, init_reference_points = (
+             self._get_decoder_input(
+                 vision_features, tuple(vision_shapes_list), class_features, denoise_embeddings, denoise_bboxes
+             )
+         )
+
+         all_hidden_states = () if output_hidden_states else None
+         all_attns = () if output_attentions else None
+         all_self_attns = () if output_attentions else None
+         all_cross_attns = () if output_attentions else None
+         predicted_class_features = decoder_embeddings
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (predicted_class_features,)
+         decoder_bboxes = []
+         decoder_classes = []
+         last_refined_bbox = None
+         reference_points = reference_points.sigmoid()
+         for i, layer in enumerate(self.layers):
+             if self.gradient_checkpointing and self.training:
+                 predicted_class_features, task_features, self_attention, cross_attention = (
+                     self._gradient_checkpointing_func(
+                         layer.__call__,
+                         predicted_class_features,
+                         task_features,
+                         reference_points,
+                         vision_features,
+                         vision_shapes,
+                         vision_shapes_list,
+                         level_start_index=level_start_index,
+                         attention_mask=attention_mask,
+                         query_position=self.query_position_head(reference_points),
+                         output_attentions=output_attentions,
+                         output_hidden_states=output_hidden_states,
+                     )
+                 )
+             else:
+                 predicted_class_features, task_features, self_attention, cross_attention = layer(
+                     predicted_class_features,
+                     task_features,
+                     reference_points,
+                     vision_features,
+                     vision_shapes,
+                     vision_shapes_list,
+                     level_start_index=level_start_index,
+                     attention_mask=attention_mask,
+                     query_position=self.query_position_head(reference_points),
+                     output_attentions=output_attentions,
+                     output_hidden_states=output_hidden_states,
+                 )
+             if output_attentions:
+                 all_self_attns = all_self_attns + (self_attention,)
+                 all_cross_attns = all_cross_attns + (cross_attention,)
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (predicted_class_features,)
+
+             refined_bbox = torch.sigmoid(
+                 self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(reference_points)
+             )
+             original_class_projected = self.decoder_class_head[i](class_features).permute(1, 2, 0)
+             if self.training:
+                 decoder_classes.append(
+                     get_class_similarity(
+                         class_distance_type=self.class_distance_type,
+                         cls_feature=predicted_class_features,
+                         class_proj=original_class_projected,
+                     )
+                 )
+                 if i == 0:
+                     decoder_bboxes.append(refined_bbox)
+                 else:
+                     decoder_bboxes.append(
+                         torch.sigmoid(
+                             self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(last_refined_bbox)
+                         )
+                     )
+             elif i == self.decoder_num_layers - 1:
+                 decoder_classes.append(
+                     get_class_similarity(self.class_distance_type, predicted_class_features, original_class_projected)
1507
+ )
1508
+ decoder_bboxes.append(refined_bbox)
1509
+ break
1510
+ last_refined_bbox = refined_bbox
1511
+ reference_points = refined_bbox.detach() if self.training else refined_bbox
1512
+ if output_attentions:
1513
+ all_attns += (all_self_attns, all_cross_attns)
1514
+
1515
+ last_hidden_state = predicted_class_features
1516
+ decoder_bboxes = torch.stack(decoder_bboxes)
1517
+ decoder_classes = torch.stack(decoder_classes)
1518
+
1519
+ if not return_dict:
1520
+ return (
1521
+ last_hidden_state,
1522
+ all_hidden_states,
1523
+ all_attns,
1524
+ decoder_bboxes,
1525
+ decoder_classes,
1526
+ encoder_bboxes,
1527
+ encoder_class_similarity,
1528
+ init_reference_points,
1529
+ reference_points,
1530
+ )
1531
+
1532
+ return OmDetTurboDecoderOutput(
1533
+ last_hidden_state=last_hidden_state,
1534
+ hidden_states=all_hidden_states,
1535
+ attentions=all_attns,
1536
+ decoder_coords=decoder_bboxes,
1537
+ decoder_classes=decoder_classes,
1538
+ encoder_coord_logits=encoder_bboxes,
1539
+ encoder_class_logits=encoder_class_similarity,
1540
+ init_reference_points=init_reference_points,
1541
+ intermediate_reference_points=reference_points,
1542
+ )
1543
+
1544
+
1545
+ @add_start_docstrings(
1546
+ """
1547
+ OmDetTurbo Model (consisting of a vision and a text backbone, and encoder-decoder architecture) outputting
1548
+ bounding boxes and classes scores for tasks such as COCO detection.
1549
+ """,
1550
+ OMDET_TURBO_START_DOCSTRING,
1551
+ )
1552
+ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
1553
+ def __init__(self, config: OmDetTurboConfig):
1554
+ super().__init__(config)
1555
+ self.vision_backbone = OmDetTurboVisionBackbone(config)
1556
+ self.language_backbone = OmDetTurboLanguageBackbone(config)
1557
+ self.encoder = OmDetTurboHybridEncoder(config)
1558
+ self.decoder = OmDetTurboDecoder(config)
1559
+ self.num_queries = config.num_queries
1560
+
1561
+ self.language_cache_class = OmDetTurboLRUCache(config.cache_size)
1562
+ self.language_cache_prompt = OmDetTurboLRUCache(config.cache_size)
1563
+ self.vocab_size = config.text_config.vocab_size
1564
+ self.post_init()
1565
+
1566
+ def get_input_embeddings(self):
1567
+ return self.language_backbone.model.get_input_embeddings()
1568
+
1569
+ def set_input_embeddings(self, value):
1570
+ self.language_backbone.model.set_input_embeddings(value)
1571
+
1572
+ def resize_token_embeddings(
1573
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
1574
+ ) -> nn.Embedding:
1575
+ model_embeds = self.language_backbone.model.resize_token_embeddings(
1576
+ new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
1577
+ )
1578
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
1579
+ self.vocab_size = model_embeds.num_embeddings
1580
+ return model_embeds
1581
+
1582
+ @add_start_docstrings_to_model_forward(OMDET_TURBO_INPUTS_DOCSTRING)
1583
+ @replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
1584
+ def forward(
1585
+ self,
1586
+ pixel_values: torch.FloatTensor,
1587
+ classes_input_ids: torch.LongTensor,
1588
+ classes_attention_mask: torch.LongTensor,
1589
+ tasks_input_ids: torch.LongTensor,
1590
+ tasks_attention_mask: torch.LongTensor,
1591
+ classes_structure: torch.LongTensor,
1592
+ labels: Optional[torch.LongTensor] = None,
1593
+ output_attentions: Optional[bool] = None,
1594
+ output_hidden_states: Optional[bool] = None,
1595
+ return_dict: Optional[bool] = None,
1596
+ ) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
1597
+ r"""
1598
+ Returns:
1599
+
1600
+ Examples:
1601
+
1602
+ ```python
1603
+ >>> import requests
1604
+ >>> from PIL import Image
1605
+
1606
+ >>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
1607
+
1608
+ >>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
1609
+ >>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
1610
+
1611
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1612
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1613
+ >>> classes = ["cat", "remote"]
1614
+ >>> task = "Detect {}.".format(", ".join(classes))
1615
+ >>> inputs = processor(image, text=classes, task=task, return_tensors="pt")
1616
+
1617
+ >>> outputs = model(**inputs)
1618
+
1619
+ >>> # convert outputs (bounding boxes and class logits)
1620
+ >>> results = processor.post_process_grounded_object_detection(
1621
+ ... outputs,
1622
+ ... classes=classes,
1623
+ ... target_sizes=[image.size[::-1]],
1624
+ ... score_threshold=0.3,
1625
+ ... nms_threshold=0.3,
1626
+ ... )[0]
1627
+ >>> for score, class_name, box in zip(results["scores"], results["classes"], results["boxes"]):
1628
+ ... box = [round(i, 1) for i in box.tolist()]
1629
+ ... print(
1630
+ ... f"Detected {class_name} with confidence "
1631
+ ... f"{round(score.item(), 2)} at location {box}"
1632
+ ... )
1633
+ Detected remote with confidence 0.76 at location [39.9, 71.3, 176.5, 117.9]
1634
+ Detected cat with confidence 0.72 at location [345.1, 22.5, 639.7, 371.9]
1635
+ Detected cat with confidence 0.65 at location [12.7, 53.8, 315.5, 475.3]
1636
+ Detected remote with confidence 0.57 at location [333.4, 75.6, 370.7, 187.0]
1637
+ ```"""
1638
+ if labels is not None:
1639
+ raise NotImplementedError("Training is not implemented yet")
1640
+
1641
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1642
+ output_hidden_states = (
1643
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1644
+ )
1645
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1646
+
1647
+ loss = None
1648
+ image_features = self.vision_backbone(pixel_values)
1649
+ encoder_outputs = self.encoder(
1650
+ image_features,
1651
+ output_attentions=output_attentions,
1652
+ output_hidden_states=output_hidden_states,
1653
+ return_dict=return_dict,
1654
+ )
1655
+ class_features, task_features, task_mask = self.get_language_embedding(
1656
+ classes_input_ids,
1657
+ classes_attention_mask,
1658
+ tasks_input_ids,
1659
+ tasks_attention_mask,
1660
+ classes_structure,
1661
+ )
1662
+ encoder_extracted_states = encoder_outputs.extracted_states if return_dict else encoder_outputs[-1]
1663
+ decoder_outputs = self.decoder(
1664
+ encoder_extracted_states,
1665
+ class_features,
1666
+ task_features,
1667
+ task_mask,
1668
+ output_attentions=output_attentions,
1669
+ output_hidden_states=output_hidden_states,
1670
+ return_dict=return_dict,
1671
+ )
1672
+
1673
+ if not return_dict:
1674
+ return tuple(
1675
+ output
1676
+ for output in [
1677
+ loss,
1678
+ decoder_outputs[3][-1],
1679
+ decoder_outputs[4][-1],
1680
+ decoder_outputs[7],
1681
+ decoder_outputs[8],
1682
+ decoder_outputs[5],
1683
+ decoder_outputs[6],
1684
+ encoder_outputs[-1],
1685
+ decoder_outputs[1],
1686
+ decoder_outputs[2],
1687
+ encoder_outputs[1],
1688
+ encoder_outputs[2],
1689
+ classes_structure,
1690
+ ]
1691
+ if output is not None
1692
+ )
1693
+
1694
+ return OmDetTurboObjectDetectionOutput(
1695
+ loss=loss,
1696
+ decoder_coord_logits=decoder_outputs.decoder_coords[-1],
1697
+ decoder_class_logits=decoder_outputs.decoder_classes[-1],
1698
+ init_reference_points=decoder_outputs.init_reference_points,
1699
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
1700
+ encoder_coord_logits=decoder_outputs.encoder_coord_logits,
1701
+ encoder_class_logits=decoder_outputs.encoder_class_logits,
1702
+ encoder_extracted_states=encoder_outputs.extracted_states,
1703
+ decoder_hidden_states=decoder_outputs.hidden_states,
1704
+ decoder_attentions=decoder_outputs.attentions,
1705
+ encoder_hidden_states=encoder_outputs.hidden_states,
1706
+ encoder_attentions=encoder_outputs.attentions,
1707
+ classes_structure=classes_structure,
1708
+ )
1709
+
1710
+
1711
+ __all__ = ["OmDetTurboForObjectDetection", "OmDetTurboPreTrainedModel"]
docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py ADDED
@@ -0,0 +1,415 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for OmDet-Turbo.
17
+ """
18
+
19
+ import warnings
20
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
21
+
22
+ from ...feature_extraction_utils import BatchFeature
23
+ from ...image_transforms import center_to_corners_format
24
+ from ...image_utils import ImageInput
25
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
26
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
27
+ from ...utils import (
28
+ TensorType,
29
+ is_torch_available,
30
+ is_torchvision_available,
31
+ )
32
+ from ...utils.deprecation import deprecate_kwarg
33
+ from ...utils.import_utils import requires
34
+
35
+
36
+ if TYPE_CHECKING:
37
+ from .modeling_omdet_turbo import OmDetTurboObjectDetectionOutput
38
+
39
+
40
+ class OmDetTurboTextKwargs(TextKwargs, total=False):
41
+ task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]
42
+
43
+
44
+ if is_torch_available():
45
+ import torch
46
+
47
+
48
+ if is_torchvision_available():
49
+ from torchvision.ops.boxes import batched_nms
50
+
51
+
52
+ class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False):
53
+ text_kwargs: OmDetTurboTextKwargs
54
+ _defaults = {
55
+ "text_kwargs": {
56
+ "add_special_tokens": True,
57
+ "padding": "max_length",
58
+ "truncation": True,
59
+ "max_length": 77,
60
+ "stride": 0,
61
+ "return_overflowing_tokens": False,
62
+ "return_special_tokens_mask": False,
63
+ "return_offsets_mapping": False,
64
+ "return_token_type_ids": False,
65
+ "return_length": False,
66
+ "verbose": True,
67
+ "task": None,
68
+ },
69
+ "images_kwargs": {},
70
+ }
71
+
72
+
73
+ class DictWithDeprecationWarning(dict):
74
+ message = (
75
+ "The `classes` key is deprecated for `OmDetTurboProcessor.post_process_grounded_object_detection` "
76
+ "output dict and will be removed in a 4.51.0 version. Please use `text_labels` instead."
77
+ )
78
+
79
+ def __getitem__(self, key):
80
+ if key == "classes":
81
+ warnings.warn(self.message, FutureWarning)
82
+ return super().__getitem__("text_labels")
83
+ return super().__getitem__(key)
84
+
85
+ def get(self, key, *args, **kwargs):
86
+ if key == "classes":
87
+ warnings.warn(self.message, FutureWarning)
88
+ return super().get("text_labels", *args, **kwargs)
89
+ return super().get(key, *args, **kwargs)
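A quick illustration of the deprecation shim above (the dictionary contents are dummies):

    import warnings

    result = DictWithDeprecationWarning({"text_labels": ["cat"], "scores": [0.9]})
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert result["classes"] == result["text_labels"]  # old key still answered
    assert issubclass(caught[0].category, FutureWarning)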
90
+
91
+
92
+ def clip_boxes(box, box_size: Tuple[int, int]):
93
+ """
94
+ Clip the boxes by limiting x coordinates to the range [0, width]
95
+ and y coordinates to the range [0, height].
96
+
97
+ Args:
98
+ box (Tensor): The box to be clipped.
99
+ box_size (height, width): The clipping box's size.
100
+ """
101
+ assert torch.isfinite(box).all(), "Box tensor contains infinite or NaN!"
102
+ height, width = box_size
103
+ x1 = box[:, 0].clamp(min=0, max=width)
104
+ y1 = box[:, 1].clamp(min=0, max=height)
105
+ x2 = box[:, 2].clamp(min=0, max=width)
106
+ y2 = box[:, 3].clamp(min=0, max=height)
107
+ box = torch.stack((x1, y1, x2, y2), dim=-1)
108
+
109
+ return box
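For reference, `clip_boxes` above applied to a single out-of-bounds box (values are illustrative):

    import torch

    box = torch.tensor([[-5.0, 2.0, 120.0, 90.0]])
    # box_size is (height, width); coordinates are clamped into the image
    print(clip_boxes(box, box_size=(80, 100)))  # tensor([[  0.,   2., 100.,  80.]])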
110
+
111
+
112
+ def compute_score(boxes):
113
+ """
114
+ Compute logit scores per class for each box (proposal) and an array of class indices
115
+ corresponding to each proposal, flattened across the proposal_num.
116
+ The indices in `classes` will later be used to filter and match the predicted classes
117
+ with the input class names.
118
+ """
119
+ num_classes = boxes.shape[2]
120
+ proposal_num = boxes.shape[1]
121
+ scores = torch.sigmoid(boxes)
122
+ classes = torch.arange(num_classes, device=boxes.device).unsqueeze(0).repeat(proposal_num, 1).flatten(0, 1)
123
+ return scores, classes
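A small sketch of the (query, class) flattening that `compute_score` above relies on (sizes are illustrative):

    import torch

    num_queries, num_classes = 3, 2
    labels = torch.arange(num_classes).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
    print(labels)  # tensor([0, 1, 0, 1, 0, 1]) -> one class id per flattened candidate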
124
+
125
+
126
+ def _post_process_boxes_for_image(
127
+ boxes: "torch.Tensor",
128
+ scores: "torch.Tensor",
129
+ labels: "torch.Tensor",
130
+ image_num_classes: int,
131
+ image_size: Tuple[int, int],
132
+ threshold: float,
133
+ nms_threshold: float,
134
+ max_num_det: Optional[int] = None,
135
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
136
+ """
137
+ Filter predicted results using given thresholds and NMS.
138
+
139
+ Args:
140
+ boxes (`torch.Tensor`):
141
+ A Tensor of predicted class-specific or class-agnostic boxes for the image.
142
+ Shape (num_queries, max_num_classes_in_batch * 4) if doing class-specific regression,
143
+ or (num_queries, 4) if doing class-agnostic regression.
144
+ scores (`torch.Tensor` of shape (num_queries, max_num_classes_in_batch + 1)):
145
+ A Tensor of predicted class scores for the image.
146
+ labels (`torch.Tensor` of shape (num_queries * (max_num_classes_in_batch + 1),)):
147
+ A Tensor of predicted labels for the image.
148
+ image_num_classes (`int`):
149
+ The number of classes queried for detection on the image.
150
+ image_size (`Tuple[int, int]`):
151
+ A tuple of (height, width) for the image.
152
+ threshold (`float`):
153
+ Only return detections with a confidence score exceeding this threshold.
154
+ nms_threshold (`float`):
155
+ The threshold to use for box non-maximum suppression. Value in [0, 1].
156
+ max_num_det (`int`, *optional*):
157
+ The maximum number of detections to return. Default is None.
158
+
159
+ Returns:
160
+ Tuple: A tuple with the following:
161
+ "boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format.
162
+ "scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection.
163
+ "labels" (Tensor): A tensor of ids, where each id is the predicted class id for the corresponding detection
164
+ """
165
+
166
+ # Filter by max number of detections
167
+ proposal_num = len(boxes) if max_num_det is None else max_num_det
168
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False)
169
+ labels_per_image = labels[topk_indices]
170
+ boxes_per_image = boxes.view(-1, 1, 4).repeat(1, scores.shape[1], 1).view(-1, 4)
171
+ boxes_per_image = boxes_per_image[topk_indices]
172
+
173
+ # Convert and scale boxes to original image size
174
+ boxes_per_image = center_to_corners_format(boxes_per_image)
175
+ boxes_per_image = boxes_per_image * torch.tensor(image_size[::-1]).repeat(2).to(boxes_per_image.device)
176
+
177
+ # Filtering by confidence score
178
+ filter_mask = scores_per_image > threshold # R x K
179
+ score_keep = filter_mask.nonzero(as_tuple=False).view(-1)
180
+ boxes_per_image = boxes_per_image[score_keep]
181
+ scores_per_image = scores_per_image[score_keep]
182
+ labels_per_image = labels_per_image[score_keep]
183
+
184
+ # Ensure we did not overflow to non-existing classes
185
+ filter_classes_mask = labels_per_image < image_num_classes
186
+ classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1)
187
+ boxes_per_image = boxes_per_image[classes_keep]
188
+ scores_per_image = scores_per_image[classes_keep]
189
+ labels_per_image = labels_per_image[classes_keep]
190
+
191
+ # NMS
192
+ keep = batched_nms(boxes_per_image, scores_per_image, labels_per_image, nms_threshold)
193
+ boxes_per_image = boxes_per_image[keep]
194
+ scores_per_image = scores_per_image[keep]
195
+ labels_per_image = labels_per_image[keep]
196
+
197
+ # Clip to image size
198
+ boxes_per_image = clip_boxes(boxes_per_image, image_size)
199
+
200
+ return boxes_per_image, scores_per_image, labels_per_image
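A minimal sketch of the confidence-threshold-plus-NMS step performed above, on two dummy candidates:

    import torch
    from torchvision.ops.boxes import batched_nms

    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])  # heavy overlap
    scores = torch.tensor([0.9, 0.4])
    labels = torch.tensor([0, 0])  # same class, so NMS compares them

    keep = scores > 0.3  # confidence threshold
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
    keep = batched_nms(boxes, scores, labels, 0.5)
    print(boxes[keep])  # only the 0.9-score box survives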
201
+
202
+
203
+ @requires(backends=("vision", "torchvision"))
204
+ class OmDetTurboProcessor(ProcessorMixin):
205
+ r"""
206
+ Constructs an OmDet-Turbo processor which wraps a Deformable DETR image processor and an AutoTokenizer into a
207
+ single processor.
208
+
209
+ [`OmDetTurboProcessor`] offers all the functionalities of [`DetrImageProcessor`] and
210
+ [`AutoTokenizer`]. See the docstring of [`~OmDetTurboProcessor.__call__`] and [`~OmDetTurboProcessor.decode`]
211
+ for more information.
212
+
213
+ Args:
214
+ image_processor (`DetrImageProcessor`):
215
+ An instance of [`DetrImageProcessor`]. The image processor is a required input.
216
+ tokenizer (`AutoTokenizer`):
217
+ An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
218
+ """
219
+
220
+ attributes = ["image_processor", "tokenizer"]
221
+ image_processor_class = ("DetrImageProcessor", "DetrImageProcessorFast")
222
+ tokenizer_class = "AutoTokenizer"
223
+
224
+ def __init__(self, image_processor, tokenizer):
225
+ super().__init__(image_processor, tokenizer)
226
+
227
+ def __call__(
228
+ self,
229
+ images: ImageInput = None,
230
+ text: Optional[Union[str, List[str], List[List[str]]]] = None,
231
+ audio=None,
232
+ videos=None,
233
+ **kwargs: Unpack[OmDetTurboProcessorKwargs],
234
+ ) -> BatchFeature:
235
+ """
236
+ This method uses the [`DetrImageProcessor.__call__`] method to prepare image(s) for the model, and
237
+ [`CLIPTokenizerFast.__call__`] to prepare text for the model.
238
+
239
+ Please refer to the docstring of the above two methods for more information.
240
+
241
+ Args:
242
+ images (`ImageInput`):
243
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
244
+ text (`Union[str, List[str], List[List[str]]]`):
245
+ The classes used to limit the scope of the open vocabulary detection. Expects a list of strings or a list
246
+ of lists of strings. Batched classes can be of different lengths.
247
+ Examples: ["cat", "dog", "bird"], [["cat", "dog", "bird"], ["hat", "person"], ["car"]]
248
+ Kwargs:
249
+ task (`Union[str, List[str], TextInput, PreTokenizedInput]`):
250
+ The grounded text used to guide open vocabulary detection. Expects a single string or a list of strings.
251
+ Examples: "Detect a cat, a dog, and a bird.",[ "Detect everything.", "Detect trees and flowers."]
252
+ When not provided, the default task is "Detect [class1], [class2], [class3]" etc.
253
+ ...
254
+ """
255
+ if images is None or text is None:
256
+ raise ValueError("You have to specify both `images` and `text`")
257
+
258
+ output_kwargs = self._merge_kwargs(
259
+ OmDetTurboProcessorKwargs,
260
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
261
+ **kwargs,
262
+ )
263
+
264
+ if isinstance(text, str):
265
+ text = text.strip(" ").split(",")
266
+
267
+ if not (len(text) and isinstance(text[0], (list, tuple))):
268
+ text = [text]
269
+
270
+ task = output_kwargs["text_kwargs"].pop("task", None)
271
+ if task is None:
272
+ task = ["Detect {}.".format(", ".join(text_single)) for text_single in text]
273
+ elif not isinstance(task, (list, tuple)):
274
+ task = [task]
275
+
276
+ encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
277
+ tasks_encoding = self.tokenizer(text=task, **output_kwargs["text_kwargs"])
278
+
279
+ classes = text
280
+
281
+ classes_structure = torch.tensor([len(class_single) for class_single in classes], dtype=torch.long)
282
+ classes_flattened = [class_single for class_batch in classes for class_single in class_batch]
283
+ classes_encoding = self.tokenizer(text=classes_flattened, **output_kwargs["text_kwargs"])
284
+
285
+ encoding = BatchFeature()
286
+ encoding.update({f"tasks_{key}": value for key, value in tasks_encoding.items()})
287
+ encoding.update({f"classes_{key}": value for key, value in classes_encoding.items()})
288
+ encoding.update({"classes_structure": classes_structure})
289
+ encoding.update(encoding_image_processor)
290
+
291
+ return encoding
292
+
293
+ # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
294
+ def batch_decode(self, *args, **kwargs):
295
+ """
296
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
297
+ refer to the docstring of this method for more information.
298
+ """
299
+ return self.tokenizer.batch_decode(*args, **kwargs)
300
+
301
+ # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
302
+ def decode(self, *args, **kwargs):
303
+ """
304
+ This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
305
+ the docstring of this method for more information.
306
+ """
307
+ return self.tokenizer.decode(*args, **kwargs)
308
+
309
+ def _get_default_image_size(self) -> Tuple[int, int]:
310
+ height = (
311
+ self.image_processor.size["height"]
312
+ if "height" in self.image_processor.size
313
+ else self.image_processor.size["shortest_edge"]
314
+ )
315
+ width = (
316
+ self.image_processor.size["width"]
317
+ if "width" in self.image_processor.size
318
+ else self.image_processor.size["longest_edge"]
319
+ )
320
+ return height, width
321
+
322
+ @deprecate_kwarg("score_threshold", new_name="threshold", version="4.51.0")
323
+ @deprecate_kwarg("classes", new_name="text_labels", version="4.51.0")
324
+ def post_process_grounded_object_detection(
325
+ self,
326
+ outputs: "OmDetTurboObjectDetectionOutput",
327
+ text_labels: Optional[Union[List[str], List[List[str]]]] = None,
328
+ threshold: float = 0.3,
329
+ nms_threshold: float = 0.5,
330
+ target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
331
+ max_num_det: Optional[int] = None,
332
+ ):
333
+ """
334
+ Converts the raw output of [`OmDetTurboForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
335
+ bottom_right_x, bottom_right_y) format and get the associated text class.
336
+
337
+ Args:
338
+ outputs ([`OmDetTurboObjectDetectionOutput`]):
339
+ Raw outputs of the model.
340
+ text_labels (Union[List[str], List[List[str]]], *optional*):
341
+ The input class names. If not provided, `text_labels` will be set to `None` in `outputs`.
342
+ threshold (float, defaults to 0.3):
343
+ Only return detections with a confidence score exceeding this threshold.
344
+ nms_threshold (float, defaults to 0.5):
345
+ The threshold to use for box non-maximum suppression. Value in [0, 1].
346
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
347
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
348
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
349
+ max_num_det (`int`, *optional*):
350
+ The maximum number of detections to return.
351
+ Returns:
352
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and text labels for an image
353
+ in the batch as predicted by the model.
354
+ """
355
+
356
+ batch_size = len(outputs.decoder_coord_logits)
357
+
358
+ # Inputs consistency check for target sizes
359
+ if target_sizes is None:
360
+ height, width = self._get_default_image_size()
361
+ target_sizes = [(height, width)] * batch_size
362
+
363
+ if any(len(image_size) != 2 for image_size in target_sizes):
364
+ raise ValueError(
365
+ "Each element of target_sizes must contain the size (height, width) of each image of the batch"
366
+ )
367
+
368
+ if len(target_sizes) != batch_size:
369
+ raise ValueError("Make sure that you pass in as many target sizes as output sequences")
370
+
371
+ # Inputs consistency check for text labels
372
+ if text_labels is not None and isinstance(text_labels[0], str):
373
+ text_labels = [text_labels]
374
+
375
+ if text_labels is not None and len(text_labels) != batch_size:
376
+ raise ValueError("Make sure that you pass in as many classes group as output sequences")
377
+
378
+ # Convert target_sizes to list for easier handling
379
+ if isinstance(target_sizes, torch.Tensor):
380
+ target_sizes = target_sizes.tolist()
381
+
382
+ batch_boxes = outputs.decoder_coord_logits
383
+ batch_logits = outputs.decoder_class_logits
384
+ batch_num_classes = outputs.classes_structure
385
+
386
+ batch_scores, batch_labels = compute_score(batch_logits)
387
+
388
+ results = []
389
+ for boxes, scores, image_size, image_num_classes in zip(
390
+ batch_boxes, batch_scores, target_sizes, batch_num_classes
391
+ ):
392
+ boxes, scores, labels = _post_process_boxes_for_image(
393
+ boxes=boxes,
394
+ scores=scores,
395
+ labels=batch_labels,
396
+ image_num_classes=image_num_classes,
397
+ image_size=image_size,
398
+ threshold=threshold,
399
+ nms_threshold=nms_threshold,
400
+ max_num_det=max_num_det,
401
+ )
402
+ result = DictWithDeprecationWarning(
403
+ {"boxes": boxes, "scores": scores, "labels": labels, "text_labels": None}
404
+ )
405
+ results.append(result)
406
+
407
+ # Add text labels
408
+ if text_labels is not None:
409
+ for result, image_text_labels in zip(results, text_labels):
410
+ result["text_labels"] = [image_text_labels[idx] for idx in result["labels"]]
411
+
412
+ return results
413
+
414
+
415
+ __all__ = ["OmDetTurboProcessor"]
docs/transformers/src/transformers/models/oneformer/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_oneformer import *
22
+ from .image_processing_oneformer import *
23
+ from .modeling_oneformer import *
24
+ from .processing_oneformer import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
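The `_LazyModule` indirection above defers submodule imports until first attribute access; a quick illustration:

    import transformers.models.oneformer as oneformer

    # the configuration submodule is only imported when the attribute is touched
    print(oneformer.OneFormerConfig.model_type)  # "oneformer"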
docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py ADDED
@@ -0,0 +1,277 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """OneFormer model configuration"""
16
+
17
+ from typing import Dict, Optional
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+ from ...utils.backbone_utils import verify_backbone_config_arguments
22
+ from ..auto import CONFIG_MAPPING
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class OneFormerConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`OneFormerModel`]. It is used to instantiate a
31
+ OneFormer model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the OneFormer
33
+ [shi-labs/oneformer_ade20k_swin_tiny](https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny) architecture
34
+ trained on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150).
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Args:
40
+ backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
41
+ The configuration of the backbone model.
42
+ backbone (`str`, *optional*):
43
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
44
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
45
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
46
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
47
+ Whether to use pretrained weights for the backbone.
48
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
49
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
50
+ library.
51
+ backbone_kwargs (`dict`, *optional*):
52
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
53
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
54
+ ignore_value (`int`, *optional*, defaults to 255):
55
+ Values to be ignored in GT label while calculating loss.
56
+ num_queries (`int`, *optional*, defaults to 150):
57
+ Number of object queries.
58
+ no_object_weight (`float`, *optional*, defaults to 0.1):
59
+ Weight for no-object class predictions.
60
+ class_weight (`float`, *optional*, defaults to 2.0):
61
+ Weight for Classification CE loss.
62
+ mask_weight (`float`, *optional*, defaults to 5.0):
63
+ Weight for binary CE loss.
64
+ dice_weight (`float`, *optional*, defaults to 5.0):
65
+ Weight for dice loss.
66
+ contrastive_weight (`float`, *optional*, defaults to 0.5):
67
+ Weight for contrastive loss.
68
+ contrastive_temperature (`float`, *optional*, defaults to 0.07):
69
+ Initial value for scaling the contrastive logits.
70
+ train_num_points (`int`, *optional*, defaults to 12544):
71
+ Number of points to sample while calculating losses on mask predictions.
72
+ oversample_ratio (`float`, *optional*, defaults to 3.0):
73
+ Ratio to decide how many points to oversample.
74
+ importance_sample_ratio (`float`, *optional*, defaults to 0.75):
75
+ Ratio of points that are sampled via importance sampling.
76
+ init_std (`float`, *optional*, defaults to 0.02):
77
+ Standard deviation for normal initialization.
78
+ init_xavier_std (`float`, *optional*, defaults to 1.0):
79
+ Standard deviation for xavier uniform initialization.
80
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
81
+ Epsilon for layer normalization.
82
+ is_training (`bool`, *optional*, defaults to `False`):
83
+ Whether to run in training or inference mode.
84
+ use_auxiliary_loss (`bool`, *optional*, defaults to `True`):
85
+ Whether to calculate loss using intermediate predictions from transformer decoder.
86
+ output_auxiliary_logits (`bool`, *optional*, defaults to `True`):
87
+ Whether to return intermediate predictions from transformer decoder.
88
+ strides (`list`, *optional*, defaults to `[4, 8, 16, 32]`):
89
+ List containing the strides for feature maps in the encoder.
90
+ task_seq_len (`int`, *optional*, defaults to 77):
91
+ Sequence length for tokenizing text list input.
92
+ text_encoder_width (`int`, *optional*, defaults to 256):
93
+ Hidden size for text encoder.
94
+ text_encoder_context_length (`int`, *optional*, defaults to 77):
95
+ Input sequence length for text encoder.
96
+ text_encoder_num_layers (`int`, *optional*, defaults to 6):
97
+ Number of layers for transformer in text encoder.
98
+ text_encoder_vocab_size (`int`, *optional*, defaults to 49408):
99
+ Vocabulary size for tokenizer.
100
+ text_encoder_proj_layers (`int`, *optional*, defaults to 2):
101
+ Number of layers in the MLP for projecting text queries.
102
+ text_encoder_n_ctx (`int`, *optional*, defaults to 16):
103
+ Number of learnable text context queries.
104
+ conv_dim (`int`, *optional*, defaults to 256):
105
+ Feature map dimension to map outputs from the backbone.
106
+ mask_dim (`int`, *optional*, defaults to 256):
107
+ Dimension for feature maps in pixel decoder.
108
+ hidden_dim (`int`, *optional*, defaults to 256):
109
+ Dimension for hidden states in transformer decoder.
110
+ encoder_feedforward_dim (`int`, *optional*, defaults to 1024):
111
+ Dimension for FFN layer in pixel decoder.
112
+ norm (`str`, *optional*, defaults to `"GN"`):
113
+ Type of normalization.
114
+ encoder_layers (`int`, *optional*, defaults to 6):
115
+ Number of layers in pixel decoder.
116
+ decoder_layers (`int`, *optional*, defaults to 10):
117
+ Number of layers in transformer decoder.
118
+ use_task_norm (`bool`, *optional*, defaults to `True`):
119
+ Whether to normalize the task token.
120
+ num_attention_heads (`int`, *optional*, defaults to 8):
121
+ Number of attention heads in transformer layers in the pixel and transformer decoders.
122
+ dropout (`float`, *optional*, defaults to 0.1):
123
+ Dropout probability for pixel and transformer decoders.
124
+ dim_feedforward (`int`, *optional*, defaults to 2048):
125
+ Dimension for FFN layer in transformer decoder.
126
+ pre_norm (`bool`, *optional*, defaults to `False`):
127
+ Whether to normalize hidden states before attention layers in transformer decoder.
128
+ enforce_input_proj (`bool`, *optional*, defaults to `False`):
129
+ Whether to project hidden states in transformer decoder.
130
+ query_dec_layers (`int`, *optional*, defaults to 2):
131
+ Number of layers in query transformer.
132
+ common_stride (`int`, *optional*, defaults to 4):
133
+ Common stride used for features in pixel decoder.
134
+
135
+ Examples:
136
+ ```python
137
+ >>> from transformers import OneFormerConfig, OneFormerModel
138
+
139
+ >>> # Initializing a OneFormer shi-labs/oneformer_ade20k_swin_tiny configuration
140
+ >>> configuration = OneFormerConfig()
141
+ >>> # Initializing a model (with random weights) from the shi-labs/oneformer_ade20k_swin_tiny style configuration
142
+ >>> model = OneFormerModel(configuration)
143
+ >>> # Accessing the model configuration
144
+ >>> configuration = model.config
145
+ ```
146
+ """
147
+
148
+ model_type = "oneformer"
149
+ attribute_map = {"hidden_size": "hidden_dim"}
150
+
151
+ def __init__(
152
+ self,
153
+ backbone_config: Optional[Dict] = None,
154
+ backbone: Optional[str] = None,
155
+ use_pretrained_backbone: bool = False,
156
+ use_timm_backbone: bool = False,
157
+ backbone_kwargs: Optional[Dict] = None,
158
+ ignore_value: int = 255,
159
+ num_queries: int = 150,
160
+ no_object_weight: float = 0.1,
161
+ class_weight: float = 2.0,
162
+ mask_weight: float = 5.0,
163
+ dice_weight: float = 5.0,
164
+ contrastive_weight: float = 0.5,
165
+ contrastive_temperature: float = 0.07,
166
+ train_num_points: int = 12544,
167
+ oversample_ratio: float = 3.0,
168
+ importance_sample_ratio: float = 0.75,
169
+ init_std: float = 0.02,
170
+ init_xavier_std: float = 1.0,
171
+ layer_norm_eps: float = 1e-05,
172
+ is_training: bool = False,
173
+ use_auxiliary_loss: bool = True,
174
+ output_auxiliary_logits: bool = True,
175
+ strides: Optional[list] = [4, 8, 16, 32],
176
+ task_seq_len: int = 77,
177
+ text_encoder_width: int = 256,
178
+ text_encoder_context_length: int = 77,
179
+ text_encoder_num_layers: int = 6,
180
+ text_encoder_vocab_size: int = 49408,
181
+ text_encoder_proj_layers: int = 2,
182
+ text_encoder_n_ctx: int = 16,
183
+ conv_dim: int = 256,
184
+ mask_dim: int = 256,
185
+ hidden_dim: int = 256,
186
+ encoder_feedforward_dim: int = 1024,
187
+ norm: str = "GN",
188
+ encoder_layers: int = 6,
189
+ decoder_layers: int = 10,
190
+ use_task_norm: bool = True,
191
+ num_attention_heads: int = 8,
192
+ dropout: float = 0.1,
193
+ dim_feedforward: int = 2048,
194
+ pre_norm: bool = False,
195
+ enforce_input_proj: bool = False,
196
+ query_dec_layers: int = 2,
197
+ common_stride: int = 4,
198
+ **kwargs,
199
+ ):
200
+ if backbone_config is None and backbone is None:
201
+ logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
202
+ backbone_config = CONFIG_MAPPING["swin"](
203
+ image_size=224,
204
+ num_channels=3,
205
+ patch_size=4,
206
+ embed_dim=96,
207
+ depths=[2, 2, 6, 2],
208
+ num_heads=[3, 6, 12, 24],
209
+ window_size=7,
210
+ drop_path_rate=0.3,
211
+ use_absolute_embeddings=False,
212
+ out_features=["stage1", "stage2", "stage3", "stage4"],
213
+ )
214
+ elif isinstance(backbone_config, dict):
215
+ backbone_model_type = backbone_config.get("model_type")
216
+ config_class = CONFIG_MAPPING[backbone_model_type]
217
+ backbone_config = config_class.from_dict(backbone_config)
218
+
219
+ verify_backbone_config_arguments(
220
+ use_timm_backbone=use_timm_backbone,
221
+ use_pretrained_backbone=use_pretrained_backbone,
222
+ backbone=backbone,
223
+ backbone_config=backbone_config,
224
+ backbone_kwargs=backbone_kwargs,
225
+ )
226
+
227
+ self.backbone_config = backbone_config
228
+ self.backbone = backbone
229
+ self.use_pretrained_backbone = use_pretrained_backbone
230
+ self.use_timm_backbone = use_timm_backbone
231
+ self.backbone_kwargs = backbone_kwargs
232
+ self.ignore_value = ignore_value
233
+ self.num_queries = num_queries
234
+ self.no_object_weight = no_object_weight
235
+ self.class_weight = class_weight
236
+ self.mask_weight = mask_weight
237
+ self.dice_weight = dice_weight
238
+ self.contrastive_weight = contrastive_weight
239
+ self.contrastive_temperature = contrastive_temperature
240
+ self.train_num_points = train_num_points
241
+ self.oversample_ratio = oversample_ratio
242
+ self.importance_sample_ratio = importance_sample_ratio
243
+ self.init_std = init_std
244
+ self.init_xavier_std = init_xavier_std
245
+ self.layer_norm_eps = layer_norm_eps
246
+ self.is_training = is_training
247
+ self.use_auxiliary_loss = use_auxiliary_loss
248
+ self.output_auxiliary_logits = output_auxiliary_logits
249
+ self.strides = strides
250
+ self.task_seq_len = task_seq_len
251
+ self.text_encoder_width = text_encoder_width
252
+ self.text_encoder_context_length = text_encoder_context_length
253
+ self.text_encoder_num_layers = text_encoder_num_layers
254
+ self.text_encoder_vocab_size = text_encoder_vocab_size
255
+ self.text_encoder_proj_layers = text_encoder_proj_layers
256
+ self.text_encoder_n_ctx = text_encoder_n_ctx
257
+ self.conv_dim = conv_dim
258
+ self.mask_dim = mask_dim
259
+ self.hidden_dim = hidden_dim
260
+ self.encoder_feedforward_dim = encoder_feedforward_dim
261
+ self.norm = norm
262
+ self.encoder_layers = encoder_layers
263
+ self.decoder_layers = decoder_layers
264
+ self.use_task_norm = use_task_norm
265
+ self.num_attention_heads = num_attention_heads
266
+ self.dropout = dropout
267
+ self.dim_feedforward = dim_feedforward
268
+ self.pre_norm = pre_norm
269
+ self.enforce_input_proj = enforce_input_proj
270
+ self.query_dec_layers = query_dec_layers
271
+ self.common_stride = common_stride
272
+ self.num_hidden_layers = decoder_layers
273
+
274
+ super().__init__(**kwargs)
275
+
276
+
277
+ __all__ = ["OneFormerConfig"]
docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py ADDED
@@ -0,0 +1,1191 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Convert OneFormer checkpoints from the original repository. URL: https://github.com/SHI-Labs/OneFormer"""
17
+
18
+ import os
19
+ import sys
20
+ from argparse import ArgumentParser
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from pprint import pformat
24
+ from typing import Any, Dict, Iterator, List, Set, Tuple
25
+
26
+ import requests
27
+ import torch
28
+ import torchvision.transforms as T
29
+ from PIL import Image
30
+ from torch import Tensor, nn
31
+
32
+
33
+ try:
34
+ from detectron2.checkpoint import DetectionCheckpointer
35
+ from detectron2.config import get_cfg
36
+ from detectron2.data import MetadataCatalog
37
+ from detectron2.projects.deeplab import add_deeplab_config
38
+ except ImportError:
39
+ pass
40
+ from transformers import CLIPTokenizer, DinatConfig, SwinConfig
41
+ from transformers.models.oneformer.image_processing_oneformer import OneFormerImageProcessor
42
+ from transformers.models.oneformer.modeling_oneformer import (
43
+ OneFormerConfig,
44
+ OneFormerForUniversalSegmentation,
45
+ OneFormerForUniversalSegmentationOutput,
46
+ OneFormerModel,
47
+ OneFormerModelOutput,
48
+ )
49
+ from transformers.models.oneformer.processing_oneformer import OneFormerProcessor
50
+ from transformers.utils import logging
51
+
52
+
53
+ StateDict = Dict[str, Tensor]
54
+
55
+ logging.set_verbosity_info()
56
+ logger = logging.get_logger()
57
+
58
+ torch.manual_seed(0)
59
+
60
+
61
+ class TrackedStateDict:
62
+ def __init__(self, to_track: Dict):
63
+ """This class "tracks" a python dictionary by keeping track of which item is accessed.
64
+
65
+ Args:
66
+ to_track (Dict): The dictionary we wish to track
67
+ """
68
+ self.to_track = to_track
69
+ self._seen: Set[str] = set()
70
+
71
+ def __getitem__(self, key: str) -> Any:
72
+ return self.to_track[key]
73
+
74
+ def __setitem__(self, key: str, item: Any):
75
+ self._seen.add(key)
76
+ self.to_track[key] = item
77
+
78
+ def diff(self) -> Set[str]:
79
+ """This method returns a set difference between the keys in the tracked state dict and the one we have access so far.
80
+ This is an effective way to check whether we have updated all the keys.
81
+
82
+ Returns:
83
+ Set[str]: The set of keys not yet updated
84
+ """
85
+ return set(self.to_track.keys()) - self._seen
86
+
87
+ def copy(self) -> Dict:
88
+ # proxy the call to the internal dictionary
89
+ return self.to_track.copy()
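A short usage sketch of `TrackedStateDict` above (keys and values are dummies):

    tracked = TrackedStateDict({"a": 1, "b": 2})
    tracked["a"] = 10      # writing marks "a" as seen
    print(tracked.diff())  # {'b'} -> keys not yet updated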
90
+
91
+
92
+ # Image to verify the result
93
+ def prepare_img():
94
+ url = "https://praeclarumjj3.github.io/files/coco.jpeg"
95
+ img_data = requests.get(url, stream=True).raw
96
+ im = Image.open(img_data)
97
+ return im
98
+
99
+
100
+ @dataclass
101
+ class Args:
102
+ """Fake command line arguments needed by oneformer/detectron2 implementation"""
103
+
104
+ config_file: str
105
+
106
+
107
+ def setup_cfg(args: Args):
108
+ # load config from file and command-line arguments
109
+ cfg = get_cfg()
110
+ add_deeplab_config(cfg)
111
+ add_common_config(cfg)
112
+ add_oneformer_config(cfg)
113
+ add_swin_config(cfg)
114
+ add_dinat_config(cfg)
115
+ cfg.merge_from_file(args.config_file)
116
+ cfg.freeze()
117
+ return cfg
118
+
119
+
120
+ class OriginalOneFormerConfigToOursConverter:
121
+ def __call__(self, original_config: object, is_swin: bool) -> OneFormerConfig:
122
+ model = original_config.MODEL
123
+
124
+ dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
125
+ id2label = dict(enumerate(dataset_catalog.stuff_classes))
126
+ label2id = {label: idx for idx, label in id2label.items()}
127
+
128
+ if is_swin:
129
+ if model.SWIN.EMBED_DIM == 96:
130
+ backbone_config = SwinConfig.from_pretrained(
131
+ "microsoft/swin-tiny-patch4-window7-224",
132
+ drop_path_rate=model.SWIN.DROP_PATH_RATE,
133
+ out_features=["stage1", "stage2", "stage3", "stage4"],
134
+ )
135
+ elif model.SWIN.EMBED_DIM == 192:
136
+ backbone_config = SwinConfig.from_pretrained(
137
+ "microsoft/swin-large-patch4-window12-384",
138
+ drop_path_rate=model.SWIN.DROP_PATH_RATE,
139
+ out_features=["stage1", "stage2", "stage3", "stage4"],
140
+ )
141
+ else:
142
+ raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!")
143
+ else:
144
+ backbone_config = DinatConfig.from_pretrained(
145
+ "shi-labs/dinat-large-11x11-in22k-in1k-384",
146
+ dilations=model.DiNAT.DILATIONS,
147
+ kernel_size=model.DiNAT.KERNEL_SIZE,
148
+ out_features=["stage1", "stage2", "stage3", "stage4"],
149
+ )
150
+
151
+ config: OneFormerConfig = OneFormerConfig(
152
+ backbone_config=backbone_config,
153
+ output_attentions=True,
154
+ output_hidden_states=True,
155
+ return_dict=True,
156
+ ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE,
157
+ num_classes=model.SEM_SEG_HEAD.NUM_CLASSES,
158
+ num_queries=model.ONE_FORMER.NUM_OBJECT_QUERIES,
159
+ no_object_weight=model.ONE_FORMER.NO_OBJECT_WEIGHT,
160
+ class_weight=model.ONE_FORMER.CLASS_WEIGHT,
161
+ mask_weight=model.ONE_FORMER.MASK_WEIGHT,
162
+ dice_weight=model.ONE_FORMER.DICE_WEIGHT,
163
+ contrastive_weight=model.ONE_FORMER.CONTRASTIVE_WEIGHT,
164
+ contrastive_temperature=model.ONE_FORMER.CONTRASTIVE_TEMPERATURE,
165
+ train_num_points=model.ONE_FORMER.TRAIN_NUM_POINTS,
166
+ oversample_ratio=model.ONE_FORMER.OVERSAMPLE_RATIO,
167
+ importance_sample_ratio=model.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO,
168
+ init_std=0.02,
169
+ init_xavier_std=1.0,
170
+ layer_norm_eps=1e-05,
171
+ is_training=False,
172
+ use_auxiliary_loss=model.ONE_FORMER.DEEP_SUPERVISION,
173
+ output_auxiliary_logits=True,
174
+ strides=[4, 8, 16, 32],
175
+ task_seq_len=original_config.INPUT.TASK_SEQ_LEN,
176
+ max_seq_len=original_config.INPUT.MAX_SEQ_LEN,
177
+ text_encoder_width=model.TEXT_ENCODER.WIDTH,
178
+ text_encoder_context_length=model.TEXT_ENCODER.CONTEXT_LENGTH,
179
+ text_encoder_num_layers=model.TEXT_ENCODER.NUM_LAYERS,
180
+ text_encoder_vocab_size=model.TEXT_ENCODER.VOCAB_SIZE,
181
+ text_encoder_proj_layers=model.TEXT_ENCODER.PROJ_NUM_LAYERS,
182
+ text_encoder_n_ctx=model.TEXT_ENCODER.N_CTX,
183
+ conv_dim=model.SEM_SEG_HEAD.CONVS_DIM,
184
+ mask_dim=model.SEM_SEG_HEAD.MASK_DIM,
185
+ hidden_dim=model.ONE_FORMER.HIDDEN_DIM,
186
+ norm=model.SEM_SEG_HEAD.NORM,
187
+ encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS,
188
+ encoder_feedforward_dim=1024,
189
+ decoder_layers=model.ONE_FORMER.DEC_LAYERS,
190
+ use_task_norm=model.ONE_FORMER.USE_TASK_NORM,
191
+ num_attention_heads=model.ONE_FORMER.NHEADS,
192
+ dropout=model.ONE_FORMER.DROPOUT,
193
+ dim_feedforward=model.ONE_FORMER.DIM_FEEDFORWARD,
194
+ pre_norm=model.ONE_FORMER.PRE_NORM,
195
+ enforce_input_proj=model.ONE_FORMER.ENFORCE_INPUT_PROJ,
196
+ query_dec_layers=model.ONE_FORMER.CLASS_DEC_LAYERS,
197
+ common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE,
198
+ id2label=id2label,
199
+ label2id=label2id,
200
+ )
201
+
202
+ return config
203
+
204
+
205
+ class OriginalOneFormerConfigToProcessorConverter:
206
+ def __call__(self, original_config: object, model_repo: str) -> OneFormerProcessor:
207
+ model = original_config.MODEL
208
+ model_input = original_config.INPUT
209
+ dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
210
+
211
+ if "ade20k" in model_repo:
212
+ class_info_file = "ade20k_panoptic.json"
213
+ elif "coco" in model_repo:
214
+ class_info_file = "coco_panoptic.json"
215
+ elif "cityscapes" in model_repo:
216
+ class_info_file = "cityscapes_panoptic.json"
217
+ else:
218
+ raise ValueError("Invalid Dataset!")
219
+
220
+ image_processor = OneFormerImageProcessor(
221
+ image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
222
+ image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
223
+ size=model_input.MIN_SIZE_TEST,
224
+ max_size=model_input.MAX_SIZE_TEST,
225
+ num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
226
+ ignore_index=dataset_catalog.ignore_label,
227
+ class_info_file=class_info_file,
228
+ )
229
+
230
+ tokenizer = CLIPTokenizer.from_pretrained(model_repo)
231
+
232
+ return OneFormerProcessor(
233
+ image_processor=image_processor,
234
+ tokenizer=tokenizer,
235
+ task_seq_length=original_config.INPUT.TASK_SEQ_LEN,
236
+ max_seq_length=original_config.INPUT.MAX_SEQ_LEN,
237
+ )
238
+
239
+
240
+ class OriginalOneFormerCheckpointToOursConverter:
241
+ def __init__(self, original_model: nn.Module, config: OneFormerConfig):
242
+ self.original_model = original_model
243
+ self.config = config
244
+
245
+ def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
246
+ for src_key, dst_key in renamed_keys:
247
+ dst_state_dict[dst_key] = src_state_dict.pop(src_key)
248
+
249
+ # Swin Backbone
250
+ def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig):
251
+ dst_prefix: str = "pixel_level_module.encoder"
252
+ src_prefix: str = "backbone"
253
+
254
+ renamed_keys = [
255
+ (
256
+ f"{src_prefix}.patch_embed.proj.weight",
257
+ f"{dst_prefix}.embeddings.patch_embeddings.projection.weight",
258
+ ),
259
+ (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"),
260
+ (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"),
261
+ (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"),
262
+ ]
263
+ num_layers = len(config.backbone_config.depths)
264
+ for layer_idx in range(num_layers):
265
+ for block_idx in range(config.backbone_config.depths[layer_idx]):
266
+ renamed_keys.extend(
267
+ [ # src, dst
268
+ (
269
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight",
270
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight",
271
+ ),
272
+ (
273
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias",
274
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias",
275
+ ),
276
+ (
277
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table",
278
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table",
279
+ ),
280
+ ]
281
+ )
282
+ # now we need to handle the attentions
283
+ # read in weights + bias of input projection layer of cross-attention
284
+
285
+ src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
286
+ src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
287
+
288
+ size = src_att_weight.shape[0]
289
+ offset = size // 3
290
+ dst_state_dict[
291
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight"
292
+ ] = src_att_weight[:offset, :]
293
+ dst_state_dict[
294
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias"
295
+ ] = src_att_bias[:offset]
296
+
297
+ dst_state_dict[
298
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight"
299
+ ] = src_att_weight[offset : offset * 2, :]
300
+ dst_state_dict[
301
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias"
302
+ ] = src_att_bias[offset : offset * 2]
303
+
304
+ dst_state_dict[
305
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight"
306
+ ] = src_att_weight[-offset:, :]
307
+ dst_state_dict[
308
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias"
309
+ ] = src_att_bias[-offset:]
310
+
311
+ # let's pop them
312
+ src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
313
+ src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
314
+ # proj
315
+ renamed_keys.extend(
316
+ [
317
+ (
318
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight",
319
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight",
320
+ ),
321
+ (
322
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias",
323
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias",
324
+ ),
325
+ ]
326
+ )
327
+
328
+ # second norm
329
+ renamed_keys.extend(
330
+ [
331
+ (
332
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight",
333
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight",
334
+ ),
335
+ (
336
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias",
337
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias",
338
+ ),
339
+ ]
340
+ )
341
+
342
+ # mlp
343
+ renamed_keys.extend(
344
+ [
345
+ (
346
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight",
347
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight",
348
+ ),
349
+ (
350
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias",
351
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias",
352
+ ),
353
+ (
354
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight",
355
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight",
356
+ ),
357
+ (
358
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias",
359
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias",
360
+ ),
361
+ ]
362
+ )
363
+
364
+ renamed_keys.extend(
365
+ [
366
+ (
367
+ f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index",
368
+ f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index",
369
+ )
370
+ ]
371
+ )
372
+
373
+ if layer_idx < num_layers - 1:
374
+ # patch merging
375
+ renamed_keys.extend(
376
+ [
377
+ (
378
+ f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight",
379
+ f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight",
380
+ ),
381
+ (
382
+ f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight",
383
+ f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight",
384
+ ),
385
+ (
386
+ f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias",
387
+ f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias",
388
+ ),
389
+ ]
390
+ )
391
+
392
+ # hidden states norms
393
+ renamed_keys.extend(
394
+ [
395
+ (
396
+ f"{src_prefix}.norm{layer_idx}.weight",
397
+ f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.weight",
398
+ ),
399
+ (
400
+ f"{src_prefix}.norm{layer_idx}.bias",
401
+ f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.bias",
402
+ ),
403
+ ]
404
+ )
405
+
406
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
407
+
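The qkv handling above is worth seeing in isolation. A standalone sketch with a toy dim (not the real hidden size): the fused projection has shape (3 * dim, dim) and is sliced into equal query/key/value thirds, exactly as done for each Swin block.

import torch

dim = 8  # toy embedding size, for illustration only
fused_weight = torch.randn(3 * dim, dim)
fused_bias = torch.randn(3 * dim)
offset = fused_weight.shape[0] // 3
q_w, q_b = fused_weight[:offset, :], fused_bias[:offset]
k_w, k_b = fused_weight[offset : offset * 2, :], fused_bias[offset : offset * 2]
v_w, v_b = fused_weight[-offset:, :], fused_bias[-offset:]
assert q_w.shape == k_w.shape == v_w.shape == (dim, dim)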
408
+ # Dinat Backbone
409
+ def replace_dinat_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig):
410
+ dst_prefix: str = "pixel_level_module.encoder"
411
+ src_prefix: str = "backbone"
412
+
413
+ def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
414
+ return [
415
+ (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
416
+ (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
417
+ ]
418
+
419
+ renamed_keys = rename_keys_for_weight_bias(f"{src_prefix}.patch_embed.norm", f"{dst_prefix}.embeddings.norm")
420
+
421
+ for i in range(2):
422
+ renamed_keys.extend(
423
+ rename_keys_for_weight_bias(
424
+ f"{src_prefix}.patch_embed.proj.{i}",
425
+ f"{dst_prefix}.embeddings.patch_embeddings.projection.{i}",
426
+ )
427
+ )
428
+
429
+ num_layers = len(config.backbone_config.depths)
430
+ for layer_idx in range(num_layers):
431
+ for block_idx in range(config.backbone_config.depths[layer_idx]):
432
+ renamed_keys.extend(
433
+ rename_keys_for_weight_bias(
434
+ f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm1",
435
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_before",
436
+ )
437
+ )
438
+
439
+ renamed_keys.extend(
440
+ rename_keys_for_weight_bias(
441
+ f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm2",
442
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_after",
443
+ )
444
+ )
445
+
446
+ renamed_keys.extend(
447
+ [ # src, dst
448
+ (
449
+ f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.rpb",
450
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.rpb",
451
+ ),
452
+ ]
453
+ )
454
+ # now we need to handle the attentions
455
+ # read in weights + bias of input projection layer of cross-attention
456
+
457
+ src_att_weight = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
458
+ src_att_bias = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
459
+
460
+ size = src_att_weight.shape[0]
461
+ offset = size // 3
462
+ dst_state_dict[
463
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.weight"
464
+ ] = src_att_weight[:offset, :]
465
+ dst_state_dict[
466
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.bias"
467
+ ] = src_att_bias[:offset]
468
+
469
+ dst_state_dict[
470
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.weight"
471
+ ] = src_att_weight[offset : offset * 2, :]
472
+ dst_state_dict[
473
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.bias"
474
+ ] = src_att_bias[offset : offset * 2]
475
+
476
+ dst_state_dict[
477
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.weight"
478
+ ] = src_att_weight[-offset:, :]
479
+ dst_state_dict[
480
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.bias"
481
+ ] = src_att_bias[-offset:]
482
+
483
+ # let's pop them
484
+ src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
485
+ src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
486
+ # proj
487
+
488
+ renamed_keys.extend(
489
+ rename_keys_for_weight_bias(
490
+ f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.proj",
491
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.output.dense",
492
+ )
493
+ )
494
+
495
+ # mlp
496
+ renamed_keys.extend(
497
+ rename_keys_for_weight_bias(
498
+ f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc1",
499
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.intermediate.dense",
500
+ )
501
+ )
502
+
503
+ renamed_keys.extend(
504
+ rename_keys_for_weight_bias(
505
+ f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc2",
506
+ f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.output.dense",
507
+ )
508
+ )
509
+
510
+ if layer_idx < num_layers - 1:
511
+ # patch merging
512
+ renamed_keys.extend(
513
+ [
514
+ (
515
+ f"{src_prefix}.levels.{layer_idx}.downsample.reduction.weight",
516
+ f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.reduction.weight",
517
+ ),
518
+ (
519
+ f"{src_prefix}.levels.{layer_idx}.downsample.norm.weight",
520
+ f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.weight",
521
+ ),
522
+ (
523
+ f"{src_prefix}.levels.{layer_idx}.downsample.norm.bias",
524
+ f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.bias",
525
+ ),
526
+ ]
527
+ )
528
+
529
+ # hidden states norms
530
+ renamed_keys.extend(
531
+ [
532
+ (
533
+ f"{src_prefix}.norm{layer_idx}.weight",
534
+ f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.weight",
535
+ ),
536
+ (
537
+ f"{src_prefix}.norm{layer_idx}.bias",
538
+ f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.bias",
539
+ ),
540
+ ]
541
+ )
542
+
543
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
544
+
545
+ # Backbone + Pixel Decoder
546
+ def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict, is_swin: bool):
547
+ dst_prefix: str = "pixel_level_module.decoder"
548
+ src_prefix: str = "sem_seg_head.pixel_decoder"
549
+
550
+ if is_swin:
551
+ self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config)
552
+ else:
553
+ self.replace_dinat_backbone(dst_state_dict, src_state_dict, self.config)
554
+
555
+ def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
556
+ return [
557
+ (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
558
+ (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
559
+ ]
560
+
561
+ def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
562
+ self_attn_keys = []
563
+ self_attn_keys.extend(
564
+ rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights")
565
+ )
566
+ self_attn_keys.extend(
567
+ rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj")
568
+ )
569
+ self_attn_keys.extend(
570
+ rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets")
571
+ )
572
+ self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj"))
573
+
574
+ return self_attn_keys
575
+
576
+ def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str):
577
+ encoder_keys = []
578
+ encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1"))
579
+ encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2"))
580
+ encoder_keys.extend(
581
+ rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm")
582
+ )
583
+ encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm"))
584
+ encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn"))
585
+
586
+ return encoder_keys
587
+
588
+ # convolution layer for final features
589
+ renamed_keys = [
590
+ (f"{src_prefix}.adapter_1.weight", f"{dst_prefix}.adapter_1.0.weight"),
591
+ (f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"),
592
+ (f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"),
593
+ ]
594
+
595
+ renamed_keys.extend(
596
+ [
597
+ (f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"),
598
+ (f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"),
599
+ (f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"),
600
+ ]
601
+ )
602
+
603
+ # proj layers
604
+ for i in range(3):
605
+ for j in range(2):
606
+ renamed_keys.extend(
607
+ [
608
+ (f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"),
609
+ (f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"),
610
+ ]
611
+ )
612
+
613
+ renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")])
614
+
615
+ # layers
616
+ for layer_idx in range(self.config.encoder_layers):
617
+ renamed_keys.extend(
618
+ rename_keys_for_encoder_layer(
619
+ f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}"
620
+ )
621
+ )
622
+
623
+ # proj
624
+ renamed_keys.extend(
625
+ [
626
+ (f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"),
627
+ (f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"),
628
+ ]
629
+ )
630
+
631
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
632
+
633
+ # Transformer Decoder
634
+ def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
635
+ dst_prefix: str = "transformer_module.decoder.layers"
636
+ src_prefix: str = "sem_seg_head.predictor"
637
+ for i in range(self.config.decoder_layers - 1):
638
+ # read in weights + bias of input projection layer of self-attention
639
+ in_proj_weight = src_state_dict.pop(
640
+ f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight"
641
+ )
642
+ in_proj_bias = src_state_dict.pop(
643
+ f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias"
644
+ )
645
+ # next, add query, keys and values (in that order) to the state dict
646
+ dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
647
+ dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.bias"] = in_proj_bias[:256]
648
+ dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
649
+ dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.bias"] = in_proj_bias[256:512]
650
+ dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
651
+ dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.bias"] = in_proj_bias[-256:]
652
+
653
+ def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
654
+ dst_prefix: str = "transformer_module"
655
+ src_prefix: str = "sem_seg_head.predictor"
656
+
657
+ def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
658
+ return [
659
+ (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
660
+ (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
661
+ ]
662
+
663
+ def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
664
+ attn_keys = [
665
+ (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
666
+ (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
667
+ ]
668
+ attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
669
+
670
+ return attn_keys
671
+
672
+ def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
673
+ attn_keys = []
674
+ attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
675
+
676
+ return attn_keys
677
+
678
+ def rename_keys_for_query_transformer_layer(src_prefix: str, dst_prefix: str):
679
+ query_transformer_layer_keys = []
680
+
681
+ query_transformer_layer_keys.extend(
682
+ rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1")
683
+ )
684
+ query_transformer_layer_keys.extend(
685
+ rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2")
686
+ )
687
+ query_transformer_layer_keys.extend(
688
+ rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.norm1")
689
+ )
690
+ query_transformer_layer_keys.extend(
691
+ rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.norm2")
692
+ )
693
+ query_transformer_layer_keys.extend(
694
+ rename_keys_for_weight_bias(f"{src_prefix}.norm3", f"{dst_prefix}.norm3")
695
+ )
696
+
697
+ query_transformer_layer_keys.extend(
698
+ rename_keys_for_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
699
+ )
700
+
701
+ query_transformer_layer_keys.extend(
702
+ rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
703
+ )
704
+
705
+ return query_transformer_layer_keys
706
+
707
+ def rename_keys_for_cross_attn_layer(src_prefix: str, dst_prefix: str):
708
+ cross_attn_layer_keys = []
709
+
710
+ cross_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
711
+ cross_attn_layer_keys.extend(
712
+ rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
713
+ )
714
+
715
+ return cross_attn_layer_keys
716
+
717
+ def rename_keys_for_self_attn_layer(src_prefix: str, dst_prefix: str):
718
+ self_attn_layer_keys = []
719
+
720
+ self_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
721
+ self_attn_layer_keys.extend(
722
+ rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
723
+ )
724
+
725
+ return self_attn_layer_keys
726
+
727
+ def rename_keys_for_ffn_layer(src_prefix: str, dst_prefix: str):
728
+ ffn_layer_keys = []
729
+
730
+ ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1"))
731
+ ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2"))
732
+ ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
733
+
734
+ return ffn_layer_keys
735
+
736
+ def rename_keys_for_transformer_decoder_layer(src_prefix: str, dst_prefix: str, idx: int):
737
+ transformer_decoder_layer_keys = []
738
+
739
+ transformer_decoder_layer_keys.extend(
740
+ rename_keys_for_cross_attn_layer(
741
+ f"{src_prefix}.transformer_cross_attention_layers.{idx}", f"{dst_prefix}.{idx}.cross_attn"
742
+ )
743
+ )
744
+
745
+ transformer_decoder_layer_keys.extend(
746
+ rename_keys_for_self_attn_layer(
747
+ f"{src_prefix}.transformer_self_attention_layers.{idx}", f"{dst_prefix}.{idx}.self_attn"
748
+ )
749
+ )
750
+
751
+ transformer_decoder_layer_keys.extend(
752
+ rename_keys_for_ffn_layer(f"{src_prefix}.transformer_ffn_layers.{idx}", f"{dst_prefix}.{idx}.ffn")
753
+ )
754
+
755
+ return transformer_decoder_layer_keys
756
+
757
+ # positional embedding for object queries
758
+ renamed_keys = [
759
+ (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
760
+ (f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"),
761
+ ]
762
+
763
+ # norm
764
+ renamed_keys.extend(
765
+ rename_keys_for_weight_bias(f"{src_prefix}.decoder_norm", f"{dst_prefix}.decoder.decoder_norm")
766
+ )
767
+
768
+ # proj
769
+ renamed_keys.extend(
770
+ rename_keys_for_weight_bias(
771
+ f"{src_prefix}.class_input_proj", f"{dst_prefix}.decoder.query_input_projection"
772
+ )
773
+ )
774
+
775
+ renamed_keys.extend(
776
+ rename_keys_for_weight_bias(f"{src_prefix}.class_embed", f"{dst_prefix}.decoder.class_embed")
777
+ )
778
+
779
+ for i in range(3):
780
+ renamed_keys.extend(
781
+ rename_keys_for_weight_bias(
782
+ f"{src_prefix}.mask_embed.layers.{i}", f"{dst_prefix}.decoder.mask_embed.layers.{i}.0"
783
+ )
784
+ )
785
+
786
+ # norm
787
+ renamed_keys.extend(
788
+ rename_keys_for_weight_bias(
789
+ f"{src_prefix}.class_transformer.decoder.norm", f"{dst_prefix}.decoder.query_transformer.decoder.norm"
790
+ )
791
+ )
792
+
793
+ # transformer to update queries with task tokens
794
+ for i in range(self.config.query_dec_layers):
795
+ renamed_keys.extend(
796
+ rename_keys_for_query_transformer_layer(
797
+ f"{src_prefix}.class_transformer.decoder.layers.{i}",
798
+ f"{dst_prefix}.decoder.query_transformer.decoder.layers.{i}",
799
+ )
800
+ )
801
+
802
+ # decoder layers
803
+ for i in range(self.config.decoder_layers - 1):
804
+ renamed_keys.extend(
805
+ rename_keys_for_transformer_decoder_layer(
806
+ f"{src_prefix}",
807
+ f"{dst_prefix}.decoder.layers",
808
+ i,
809
+ )
810
+ )
811
+
812
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
813
+ self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict)
814
+
815
+ def replace_task_mlp(self, dst_state_dict: StateDict, src_state_dict: StateDict):
816
+ dst_prefix: str = "task_encoder"
817
+ src_prefix: str = "task_mlp"
818
+
819
+ def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
820
+ return [
821
+ (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
822
+ (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
823
+ ]
824
+
825
+ renamed_keys = []
826
+
827
+ for i in range(2):
828
+ renamed_keys.extend(
829
+ rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.task_mlp.layers.{i}.0")
830
+ )
831
+
832
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
833
+
834
+ def replace_text_projector(self, dst_state_dict: StateDict, src_state_dict: StateDict):
835
+ dst_prefix: str = "text_mapper.text_projector"
836
+ src_prefix: str = "text_projector"
837
+
838
+ def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
839
+ return [
840
+ (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
841
+ (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
842
+ ]
843
+
844
+ renamed_keys = []
845
+
846
+ for i in range(self.config.text_encoder_config["text_encoder_proj_layers"]):
847
+ renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.{i}.0"))
848
+
849
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
850
+
851
+ def replace_text_mapper(self, dst_state_dict: StateDict, src_state_dict: StateDict):
852
+ dst_prefix: str = "text_mapper.text_encoder"
853
+ src_prefix: str = "text_encoder"
854
+
855
+ self.replace_text_projector(dst_state_dict, src_state_dict)
856
+
857
+ def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
858
+ return [
859
+ (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
860
+ (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
861
+ ]
862
+
863
+ def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
864
+ attn_keys = [
865
+ (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
866
+ (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
867
+ ]
868
+ attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
869
+
870
+ return attn_keys
871
+
872
+ def rename_keys_for_layer(src_prefix: str, dst_prefix: str):
873
+ resblock_keys = []
874
+
875
+ resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_fc", f"{dst_prefix}.mlp.fc1"))
876
+ resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_proj", f"{dst_prefix}.mlp.fc2"))
877
+ resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_1", f"{dst_prefix}.layer_norm1"))
878
+ resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_2", f"{dst_prefix}.layer_norm2"))
879
+ resblock_keys.extend(rename_keys_for_attn(f"{src_prefix}.attn", f"{dst_prefix}.self_attn"))
880
+
881
+ return resblock_keys
882
+
883
+ renamed_keys = [
884
+ ("prompt_ctx.weight", "text_mapper.prompt_ctx.weight"),
885
+ ]
886
+
887
+ renamed_keys.extend(
888
+ [
889
+ (f"{src_prefix}.positional_embedding", f"{dst_prefix}.positional_embedding"),
890
+ (f"{src_prefix}.token_embedding.weight", f"{dst_prefix}.token_embedding.weight"),
891
+ ]
892
+ )
893
+
894
+ renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_final", f"{dst_prefix}.ln_final"))
895
+
896
+ for i in range(self.config.text_encoder_config["text_encoder_num_layers"]):
897
+ renamed_keys.extend(
898
+ rename_keys_for_layer(
899
+ f"{src_prefix}.transformer.resblocks.{i}", f"{dst_prefix}.transformer.layers.{i}"
900
+ )
901
+ )
902
+
903
+ self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
904
+
905
+ def convert(self, oneformer: OneFormerModel, is_swin: bool) -> OneFormerModel:
906
+ dst_state_dict = TrackedStateDict(oneformer.state_dict())
907
+ src_state_dict = self.original_model.state_dict()
908
+
909
+ self.replace_pixel_module(dst_state_dict, src_state_dict, is_swin)
910
+ self.replace_transformer_module(dst_state_dict, src_state_dict)
911
+ self.replace_task_mlp(dst_state_dict, src_state_dict)
912
+ if self.config.is_training:
913
+ self.replace_text_mapper(dst_state_dict, src_state_dict)
914
+
915
+ logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}")
916
+ logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}")
917
+ logger.info("🙌 Done")
918
+
919
+ oneformer.load_state_dict(dst_state_dict)
920
+
921
+ return oneformer
922
+
923
+ @staticmethod
924
+ def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[Path, Path]]:
925
+ checkpoints: List[Path] = list(checkpoints_dir.glob("**/*.pth"))
926
+
927
+ for checkpoint in checkpoints:
928
+ logger.info(f"💪 Converting {checkpoint.stem}")
929
+ # find associated config file
930
+ config: Path = config_dir / f"{checkpoint.stem}.yaml"
931
+
932
+ yield config, checkpoint
933
+
934
+
935
+ def post_process_sem_seg_output(outputs: OneFormerForUniversalSegmentationOutput, target_size: Tuple[int, int]):
936
+ # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
937
+ class_queries_logits = outputs.class_queries_logits
938
+ # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
939
+ masks_queries_logits = outputs.masks_queries_logits
940
+ if target_size is not None:
941
+ masks_queries_logits = torch.nn.functional.interpolate(
942
+ masks_queries_logits,
943
+ size=target_size,
944
+ mode="bilinear",
945
+ align_corners=False,
946
+ )
947
+ # remove the null class `[..., :-1]`
948
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
949
+ # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
950
+ masks_probs = masks_queries_logits.sigmoid()
951
+ # now we want to sum over the queries,
952
+ # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
953
+ # where $ softmax(p) \in R^{q, c} $ is the mask classes
954
+ # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
955
+ # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
956
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
957
+
958
+ return segmentation
959
+
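A shape-only sketch of the einsum aggregation above, with arbitrary toy sizes: each pixel's per-class score is the query-weighted sum of mask probabilities.

import torch

batch, queries, classes, height, width = 1, 3, 4, 2, 2
masks_classes = torch.rand(batch, queries, classes)      # softmax-ed class probs, null class removed
masks_probs = torch.rand(batch, queries, height, width)  # sigmoid-ed mask probs
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
assert segmentation.shape == (batch, classes, height, width)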
960
+
961
+ def test(
962
+ original_model,
963
+ our_model: OneFormerForUniversalSegmentation,
964
+ processor: OneFormerProcessor,
965
+ model_repo: str,
966
+ ):
967
+ def _preprocess_text(text_list=None, max_length=77):
968
+ if text_list is None:
969
+ raise ValueError("tokens cannot be None.")
970
+
971
+ tokens = tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
972
+
973
+ attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
974
+
975
+ token_inputs = []
976
+ for attn_mask, input_id in zip(attention_masks, input_ids):
977
+ token = torch.tensor(attn_mask) * torch.tensor(input_id)
978
+ token_inputs.append(token.unsqueeze(0))
979
+
980
+ token_inputs = torch.cat(token_inputs, dim=0)
981
+ return token_inputs
982
+
983
+ with torch.no_grad():
984
+ tokenizer = CLIPTokenizer.from_pretrained(model_repo)
985
+ original_model = original_model.eval()
986
+ our_model = our_model.eval()
987
+
988
+ im = prepare_img()
989
+
990
+ tr = T.Compose(
991
+ [
992
+ T.Resize((640, 640)),
993
+ T.ToTensor(),
994
+ T.Normalize(
995
+ mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0,
996
+ std=torch.tensor([58.395, 57.120, 57.375]) / 255.0,
997
+ ),
998
+ ],
999
+ )
1000
+
1001
+ x = tr(im).unsqueeze(0)
1002
+
1003
+ task_input = ["the task is semantic"]
1004
+ task_token = _preprocess_text(task_input, max_length=processor.task_seq_length)
1005
+
1006
+ original_model_backbone_features = original_model.backbone(x.clone())
1007
+
1008
+ our_model_output: OneFormerModelOutput = our_model.model(x.clone(), task_token, output_hidden_states=True)
1009
+
1010
+ for original_model_feature, our_model_feature in zip(
1011
+ original_model_backbone_features.values(), our_model_output.encoder_hidden_states
1012
+ ):
1013
+ assert torch.allclose(original_model_feature, our_model_feature, atol=3e-3), (
1014
+ "The backbone features are not the same."
1015
+ )
1016
+ mask_features, _, multi_scale_features, _, _ = original_model.sem_seg_head.pixel_decoder.forward_features(
1017
+ original_model_backbone_features
1018
+ )
1019
+
1020
+ original_pixel_decoder_features = []
1021
+ original_pixel_decoder_features.append(mask_features)
1022
+ for i in range(len(multi_scale_features)):
1023
+ original_pixel_decoder_features.append(multi_scale_features[i])
1024
+
1025
+ for original_model_feature, our_model_feature in zip(
1026
+ original_pixel_decoder_features, our_model_output.pixel_decoder_hidden_states
1027
+ ):
1028
+ assert torch.allclose(original_model_feature, our_model_feature, atol=3e-4), (
1029
+ "The pixel decoder feature are not the same"
1030
+ )
1031
+
1032
+ tr_complete = T.Compose(
1033
+ [
1034
+ T.Resize((640, 640)),
1035
+ T.ToTensor(),
1036
+ ],
1037
+ )
1038
+
1039
+ y = (tr_complete(im) * 255.0).to(torch.int).float()
1040
+
1041
+ # let's test the full model
1042
+ original_model_out = original_model([{"image": y.clone(), "task": "The task is semantic"}])
1043
+
1044
+ original_segmentation = original_model_out[0]["sem_seg"]
1045
+
1046
+ our_model_out: OneFormerForUniversalSegmentationOutput = our_model(
1047
+ x.clone(), task_token, output_hidden_states=True
1048
+ )
1049
+
1050
+ our_segmentation = post_process_sem_seg_output(our_model_out, target_size=(640, 640))[0]
1051
+
1052
+ assert torch.allclose(original_segmentation, our_segmentation, atol=1e-3), (
1053
+ "The segmentation image is not the same."
1054
+ )
1055
+
1056
+ logger.info("✅ Test passed!")
1057
+
1058
+
1059
+ def get_name(checkpoint_file: Path):
1060
+ model_name_raw: str = checkpoint_file.stem
1061
+
1062
+ backbone = "swin" if "swin" in model_name_raw else "dinat"
1063
+ dataset = ""
1064
+ if "coco" in model_name_raw:
1065
+ dataset = "coco"
1066
+ elif "ade20k" in model_name_raw:
1067
+ dataset = "ade20k"
1068
+ elif "cityscapes" in model_name_raw:
1069
+ dataset = "cityscapes"
1070
+ else:
1071
+ raise ValueError(
1072
+ f"{model_name_raw} must be wrong since we didn't find 'coco' or 'ade20k' or 'cityscapes' in it "
1073
+ )
1074
+
1075
+ backbone_types = ["tiny", "large"]
1076
+
1077
+ backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0]
1078
+
1079
+ model_name = f"oneformer_{dataset}_{backbone}_{backbone_type}"
1080
+
1081
+ return model_name
1082
+
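As a quick usage check (the filename below is hypothetical, but follows the documented nomenclature):

assert get_name(Path("oneformer_coco_swin_large.pth")) == "oneformer_coco_swin_large"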
1083
+
1084
+ if __name__ == "__main__":
1085
+ parser = ArgumentParser(
1086
+ description=(
1087
+ "Command line to convert the original oneformer models (with swin backbone) to transformers"
1088
+ " implementation."
1089
+ )
1090
+ )
1091
+
1092
+ parser.add_argument(
1093
+ "--checkpoints_dir",
1094
+ type=Path,
1095
+ help=(
1096
+ "A directory containing the model's checkpoints. The directory has to have the following structure:"
1097
+ " structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.pth; where <CONFIG_NAME> name must follow the"
1098
+ " following nomenclature nomenclature: oneformer_<DATASET_NAME>_<BACKBONE>_<BACKBONE_TYPE>"
1099
+ ),
1100
+ )
1101
+ parser.add_argument(
1102
+ "--configs_dir",
1103
+ type=Path,
1104
+ help=(
1105
+ "A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
1106
+ " structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.yaml; where <CONFIG_NAME> name must follow the"
1107
+ " following nomenclature nomenclature: oneformer_<DATASET_NAME>_<BACKBONE>_<BACKBONE_TYPE>"
1108
+ ),
1109
+ )
1110
+ parser.add_argument(
1111
+ "--pytorch_dump_folder_path",
1112
+ required=True,
1113
+ type=Path,
1114
+ help="Path to the folder to output PyTorch models.",
1115
+ )
1116
+ parser.add_argument(
1117
+ "--oneformer_dir",
1118
+ required=True,
1119
+ type=Path,
1120
+ help=(
1121
+ "A path to OneFormer's original implementation directory. You can download from here: "
1122
+ "https://github.com/SHI-Labs/OneFormer"
1123
+ ),
1124
+ )
1125
+
1126
+ args = parser.parse_args()
1127
+
1128
+ checkpoints_dir: Path = args.checkpoints_dir
1129
+ config_dir: Path = args.configs_dir
1130
+ save_directory: Path = args.pytorch_dump_folder_path
1131
+ oneformer_dir: Path = args.oneformer_dir
1132
+ # add the parent of the OneFormer directory to the Python path
1133
+ sys.path.append(str(oneformer_dir.parent))
1134
+ # and import what's needed
1135
+ from OneFormer.oneformer import add_common_config, add_dinat_config, add_oneformer_config, add_swin_config
1136
+ from OneFormer.oneformer.oneformer_model import OneFormer as OriginalOneFormer
1137
+
1138
+ if not save_directory.exists():
1139
+ save_directory.mkdir(parents=True)
1140
+
1141
+ for config_file, checkpoint_file in OriginalOneFormerCheckpointToOursConverter.using_dirs(
1142
+ checkpoints_dir, config_dir
1143
+ ):
1144
+ processor = OriginalOneFormerConfigToProcessorConverter()(
1145
+ setup_cfg(Args(config_file=config_file)), os.path.join("shi-labs", config_file.stem)
1146
+ )
1147
+
1148
+ original_config = setup_cfg(Args(config_file=config_file))
1149
+ oneformer_kwargs = OriginalOneFormer.from_config(original_config)
1150
+
1151
+ original_model = OriginalOneFormer(**oneformer_kwargs).eval()
1152
+
1153
+ DetectionCheckpointer(original_model).load(str(checkpoint_file))
1154
+
1155
+ is_swin = "swin" in config_file.stem
1156
+
1157
+ config: OneFormerConfig = OriginalOneFormerConfigToOursConverter()(original_config, is_swin)
1158
+
1159
+ oneformer = OneFormerModel(config=config).eval()
1160
+
1161
+ converter = OriginalOneFormerCheckpointToOursConverter(original_model, config)
1162
+
1163
+ oneformer = converter.convert(oneformer, is_swin)
1164
+
1165
+ oneformer_for_universal_segmentation = OneFormerForUniversalSegmentation(config=config).eval()
1166
+
1167
+ oneformer_for_universal_segmentation.model = oneformer
1168
+
1169
+ test(
1170
+ original_model,
1171
+ oneformer_for_universal_segmentation,
1172
+ processor,
1173
+ os.path.join("shi-labs", config_file.stem),
1174
+ )
1175
+
1176
+ model_name = get_name(checkpoint_file)
1177
+ logger.info(f"🪄 Saving {model_name}")
1178
+
1179
+ processor.save_pretrained(save_directory / model_name)
1180
+ oneformer_for_universal_segmentation.save_pretrained(save_directory / model_name)
1181
+
1182
+ processor.push_to_hub(
1183
+ repo_id=os.path.join("shi-labs", config_file.stem),
1184
+ commit_message="Add configs",
1185
+ use_temp_dir=True,
1186
+ )
1187
+ oneformer_for_universal_segmentation.push_to_hub(
1188
+ repo_id=os.path.join("shi-labs", config_file.stem),
1189
+ commit_message="Add model",
1190
+ use_temp_dir=True,
1191
+ )
docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py ADDED
@@ -0,0 +1,1356 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for OneFormer."""
16
+
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
20
+
21
+ import numpy as np
22
+ from huggingface_hub import hf_hub_download
23
+ from huggingface_hub.utils import RepositoryNotFoundError
24
+
25
+ from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
26
+ from ...image_transforms import (
27
+ PaddingMode,
28
+ get_resize_output_image_size,
29
+ pad,
30
+ rescale,
31
+ resize,
32
+ to_channel_dimension_format,
33
+ )
34
+ from ...image_utils import (
35
+ ChannelDimension,
36
+ ImageInput,
37
+ PILImageResampling,
38
+ get_image_size,
39
+ infer_channel_dimension_format,
40
+ is_scaled_image,
41
+ make_list_of_images,
42
+ to_numpy_array,
43
+ valid_images,
44
+ validate_preprocess_arguments,
45
+ )
46
+ from ...utils import (
47
+ IMAGENET_DEFAULT_MEAN,
48
+ IMAGENET_DEFAULT_STD,
49
+ TensorType,
50
+ filter_out_non_signature_kwargs,
51
+ is_torch_available,
52
+ is_torch_tensor,
53
+ logging,
54
+ )
55
+ from ...utils.deprecation import deprecate_kwarg
56
+
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+
61
+ if is_torch_available():
62
+ import torch
63
+ from torch import nn
64
+
65
+
66
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
67
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
68
+ """
69
+ Return the maximum value across all indices of an iterable of values.
70
+ """
71
+ return [max(values_i) for values_i in zip(*values)]
72
+
73
+
74
+ # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
75
+ def get_max_height_width(
76
+ images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
77
+ ) -> List[int]:
78
+ """
79
+ Get the maximum height and width across all images in a batch.
80
+ """
81
+ if input_data_format is None:
82
+ input_data_format = infer_channel_dimension_format(images[0])
83
+
84
+ if input_data_format == ChannelDimension.FIRST:
85
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
86
+ elif input_data_format == ChannelDimension.LAST:
87
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
88
+ else:
89
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
90
+ return (max_height, max_width)
91
+
92
+
93
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
94
+ def make_pixel_mask(
95
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
96
+ ) -> np.ndarray:
97
+ """
98
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
99
+
100
+ Args:
101
+ image (`np.ndarray`):
102
+ Image to make the pixel mask for.
103
+ output_size (`Tuple[int, int]`):
104
+ Output size of the mask.
105
+ """
106
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
107
+ mask = np.zeros(output_size, dtype=np.int64)
108
+ mask[:input_height, :input_width] = 1
109
+ return mask
110
+
111
+
112
+ # Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
113
+ def binary_mask_to_rle(mask):
114
+ """
115
+ Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
116
+
117
+ Args:
118
+ mask (`torch.Tensor` or `numpy.array`):
119
+ A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
120
+ segment_id or class_id.
121
+ Returns:
122
+ `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
123
+ format.
124
+ """
125
+ if is_torch_tensor(mask):
126
+ mask = mask.numpy()
127
+
128
+ pixels = mask.flatten()
129
+ pixels = np.concatenate([[0], pixels, [0]])
130
+ runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
131
+ runs[1::2] -= runs[::2]
132
+ return list(runs)
133
+
134
+
135
+ # Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
136
+ def convert_segmentation_to_rle(segmentation):
137
+ """
138
+ Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
139
+
140
+ Args:
141
+ segmentation (`torch.Tensor` or `numpy.array`):
142
+ A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
143
+ Returns:
144
+ `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
145
+ """
146
+ segment_ids = torch.unique(segmentation)
147
+
148
+ run_length_encodings = []
149
+ for idx in segment_ids:
150
+ mask = torch.where(segmentation == idx, 1, 0)
151
+ rle = binary_mask_to_rle(mask)
152
+ run_length_encodings.append(rle)
153
+
154
+ return run_length_encodings
155
+
156
+
157
+ # Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
158
+ def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
159
+ """
160
+ Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
161
+ `labels`.
162
+
163
+ Args:
164
+ masks (`torch.Tensor`):
165
+ A tensor of shape `(num_queries, height, width)`.
166
+ scores (`torch.Tensor`):
167
+ A tensor of shape `(num_queries)`.
168
+ labels (`torch.Tensor`):
169
+ A tensor of shape `(num_queries)`.
170
+ object_mask_threshold (`float`):
171
+ A number between 0 and 1 used to binarize the masks.
172
+ Raises:
173
+ `ValueError`: Raised when the first dimension doesn't match in all input tensors.
174
+ Returns:
175
+ `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
176
+ < `object_mask_threshold`.
177
+ """
178
+ if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
179
+ raise ValueError("mask, scores and labels must have the same shape!")
180
+
181
+ to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
182
+
183
+ return masks[to_keep], scores[to_keep], labels[to_keep]
184
+
185
+
186
+ # Copied from transformers.models.detr.image_processing_detr.check_segment_validity
187
+ def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
188
+ # Get the mask associated with the k class
189
+ mask_k = mask_labels == k
190
+ mask_k_area = mask_k.sum()
191
+
192
+ # Compute the area of all the stuff in query k
193
+ original_area = (mask_probs[k] >= mask_threshold).sum()
194
+ mask_exists = mask_k_area > 0 and original_area > 0
195
+
196
+ # Eliminate disconnected tiny segments
197
+ if mask_exists:
198
+ area_ratio = mask_k_area / original_area
199
+ if not area_ratio.item() > overlap_mask_area_threshold:
200
+ mask_exists = False
201
+
202
+ return mask_exists, mask_k
203
+
204
+
205
+ # Copied from transformers.models.detr.image_processing_detr.compute_segments
206
+ def compute_segments(
207
+ mask_probs,
208
+ pred_scores,
209
+ pred_labels,
210
+ mask_threshold: float = 0.5,
211
+ overlap_mask_area_threshold: float = 0.8,
212
+ label_ids_to_fuse: Optional[Set[int]] = None,
213
+ target_size: Tuple[int, int] = None,
214
+ ):
215
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
216
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
217
+
218
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
219
+ segments: List[Dict] = []
220
+
221
+ if target_size is not None:
222
+ mask_probs = nn.functional.interpolate(
223
+ mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
224
+ )[0]
225
+
226
+ current_segment_id = 0
227
+
228
+ # Weigh each mask by its prediction score
229
+ mask_probs *= pred_scores.view(-1, 1, 1)
230
+ mask_labels = mask_probs.argmax(0) # [height, width]
231
+
232
+ # Keep track of instances of each class
233
+ stuff_memory_list: Dict[str, int] = {}
234
+ for k in range(pred_labels.shape[0]):
235
+ pred_class = pred_labels[k].item()
236
+ should_fuse = pred_class in label_ids_to_fuse
237
+
238
+ # Check if mask exists and large enough to be a segment
239
+ mask_exists, mask_k = check_segment_validity(
240
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
241
+ )
242
+
243
+ if mask_exists:
244
+ if pred_class in stuff_memory_list:
245
+ current_segment_id = stuff_memory_list[pred_class]
246
+ else:
247
+ current_segment_id += 1
248
+
249
+ # Add current object segment to final segmentation map
250
+ segmentation[mask_k] = current_segment_id
251
+ segment_score = round(pred_scores[k].item(), 6)
252
+ segments.append(
253
+ {
254
+ "id": current_segment_id,
255
+ "label_id": pred_class,
256
+ "was_fused": should_fuse,
257
+ "score": segment_score,
258
+ }
259
+ )
260
+ if should_fuse:
261
+ stuff_memory_list[pred_class] = current_segment_id
262
+
263
+ return segmentation, segments
264
+
265
+
266
+ # Copied from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks
267
+ def convert_segmentation_map_to_binary_masks(
268
+ segmentation_map: "np.ndarray",
269
+ instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
270
+ ignore_index: Optional[int] = None,
271
+ do_reduce_labels: bool = False,
272
+ ):
273
+ if do_reduce_labels and ignore_index is None:
274
+ raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
275
+
276
+ if do_reduce_labels:
277
+ segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
278
+
279
+ # Get unique ids (class or instance ids based on input)
280
+ all_labels = np.unique(segmentation_map)
281
+
282
+ # Drop background label if applicable
283
+ if ignore_index is not None:
284
+ all_labels = all_labels[all_labels != ignore_index]
285
+
286
+ # Generate a binary mask for each object instance
287
+ binary_masks = [(segmentation_map == i) for i in all_labels]
288
+
289
+ # Stack the binary masks
290
+ if binary_masks:
291
+ binary_masks = np.stack(binary_masks, axis=0)
292
+ else:
293
+ binary_masks = np.zeros((0, *segmentation_map.shape))
294
+
295
+ # Convert instance ids to class ids
296
+ if instance_id_to_semantic_id is not None:
297
+ labels = np.zeros(all_labels.shape[0])
298
+
299
+ for label in all_labels:
300
+ class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
301
+ labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
302
+ else:
303
+ labels = all_labels
304
+
305
+ return binary_masks.astype(np.float32), labels.astype(np.int64)
306
+
307
+
308
+ def get_oneformer_resize_output_image_size(
309
+ image: np.ndarray,
310
+ size: Union[int, Tuple[int, int], List[int], Tuple[int]],
311
+ max_size: Optional[int] = None,
312
+ default_to_square: bool = True,
313
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
314
+ ) -> tuple:
315
+ """
316
+ Computes the output size given the desired size.
317
+
318
+ Args:
319
+ image (`np.ndarray`):
320
+ The input image.
321
+ size (`int` or `Tuple[int, int]` or `List[int]` or `Tuple[int]`):
322
+ The size of the output image.
323
+ max_size (`int`, *optional*):
324
+ The maximum size of the output image.
325
+ default_to_square (`bool`, *optional*, defaults to `True`):
326
+ Whether to default to square if no size is provided.
327
+ input_data_format (`ChannelDimension` or `str`, *optional*):
328
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
329
+
330
+ Returns:
331
+ `Tuple[int, int]`: The output size.
332
+ """
333
+ output_size = get_resize_output_image_size(
334
+ input_image=image,
335
+ size=size,
336
+ default_to_square=default_to_square,
337
+ max_size=max_size,
338
+ input_data_format=input_data_format,
339
+ )
340
+ return output_size
341
+
342
+
343
+ def prepare_metadata(class_info):
344
+ metadata = {}
345
+ class_names = []
346
+ thing_ids = []
347
+ for key, info in class_info.items():
348
+ metadata[key] = info["name"]
349
+ class_names.append(info["name"])
350
+ if info["isthing"]:
351
+ thing_ids.append(int(key))
352
+ metadata["thing_ids"] = thing_ids
353
+ metadata["class_names"] = class_names
354
+ return metadata
355
+
356
+
357
+ def load_metadata(repo_id, class_info_file):
358
+ fname = os.path.join("" if repo_id is None else repo_id, class_info_file)
359
+
360
+ if not os.path.exists(fname) or not os.path.isfile(fname):
361
+ if repo_id is None:
362
+ raise ValueError(f"Could not file {fname} locally. repo_id must be defined if loading from the hub")
363
+ # We try downloading from a dataset by default for backward compatibility
364
+ try:
365
+ fname = hf_hub_download(repo_id, class_info_file, repo_type="dataset")
366
+ except RepositoryNotFoundError:
367
+ fname = hf_hub_download(repo_id, class_info_file)
368
+
369
+ with open(fname, "r") as f:
370
+ class_info = json.load(f)
371
+
372
+ return class_info
373
+
374
+
375
+ class OneFormerImageProcessor(BaseImageProcessor):
376
+ r"""
377
+ Constructs a OneFormer image processor. The image processor can be used to prepare image(s), task input(s) and
378
+ optional text inputs and targets for the model.
379
+
380
+ This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
381
+ refer to this superclass for more information regarding those methods.
382
+
383
+ Args:
384
+ do_resize (`bool`, *optional*, defaults to `True`):
385
+ Whether to resize the input to a certain `size`.
386
+ size (`int`, *optional*, defaults to 800):
387
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
388
+ sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
389
+ the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
390
+ height / width, size)`.
391
+ resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
392
+ An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
393
+ `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
394
+ `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
395
+ to `True`.
396
+ do_rescale (`bool`, *optional*, defaults to `True`):
397
+ Whether to rescale the input to a certain `scale`.
398
+ rescale_factor (`float`, *optional*, defaults to `1/ 255`):
399
+ Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
400
+ do_normalize (`bool`, *optional*, defaults to `True`):
401
+ Whether or not to normalize the input with mean and standard deviation.
402
+ image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
403
+ The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
404
+ image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
405
+ The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
406
+ ImageNet std.
407
+ ignore_index (`int`, *optional*):
408
+ Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
409
+ denoted with 0 (background) will be replaced with `ignore_index`.
410
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
411
+ Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
412
+ is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
413
+ The background label will be replaced by `ignore_index`.
414
+ repo_path (`str`, *optional*, defaults to `"shi-labs/oneformer_demo"`):
415
+ Path to hub repo or local directory containing the JSON file with class information for the dataset.
416
+ If unset, will look for `class_info_file` in the current working directory.
417
+ class_info_file (`str`, *optional*):
418
+ JSON file containing class information for the dataset. See `shi-labs/oneformer_demo/cityscapes_panoptic.json` for an example.
419
+ num_text (`int`, *optional*):
420
+ Number of text entries in the text input list.
421
+ num_labels (`int`, *optional*):
422
+ The number of labels in the segmentation map.
423
+ """
424
+
425
+ model_input_names = ["pixel_values", "pixel_mask", "task_inputs"]
426
+
427
+ @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
428
+ @deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
429
+ @filter_out_non_signature_kwargs(extra=["max_size", "metadata", *INIT_SERVICE_KWARGS])
430
+ def __init__(
431
+ self,
432
+ do_resize: bool = True,
433
+ size: Dict[str, int] = None,
434
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
435
+ do_rescale: bool = True,
436
+ rescale_factor: float = 1 / 255,
437
+ do_normalize: bool = True,
438
+ image_mean: Union[float, List[float]] = None,
439
+ image_std: Union[float, List[float]] = None,
440
+ ignore_index: Optional[int] = None,
441
+ do_reduce_labels: bool = False,
442
+ repo_path: Optional[str] = "shi-labs/oneformer_demo",
443
+ class_info_file: Optional[str] = None,
444
+ num_text: Optional[int] = None,
445
+ num_labels: Optional[int] = None,
446
+ **kwargs,
447
+ ):
448
+ super().__init__(**kwargs)
449
+
450
+ # Deprecated, backward compatibility
451
+ self._max_size = kwargs.pop("max_size", 1333)
452
+
453
+ size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
454
+ size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
455
+
456
+ if class_info_file is None:
457
+ raise ValueError("You must provide a `class_info_file`")
458
+
459
+ self.do_resize = do_resize
460
+ self.size = size
461
+ self.resample = resample
462
+ self.do_rescale = do_rescale
463
+ self.rescale_factor = rescale_factor
464
+ self.do_normalize = do_normalize
465
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
466
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
467
+ self.ignore_index = ignore_index
468
+ self.do_reduce_labels = do_reduce_labels
469
+ self.class_info_file = class_info_file
470
+ self.repo_path = repo_path
471
+ self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file))
472
+ self.num_text = num_text
473
+ self.num_labels = num_labels
474
+
475
+ @classmethod
476
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
477
+ """
478
+ Overrides the `from_dict` method from the base class to preserve support for the deprecated `reduce_labels` key in old configs
479
+ """
480
+ image_processor_dict = image_processor_dict.copy()
481
+ if "reduce_labels" in image_processor_dict:
482
+ image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
483
+ return super().from_dict(image_processor_dict, **kwargs)
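Not part of the upstream file: a standalone sketch of the key rename this `from_dict` override performs, so configs saved with the legacy `reduce_labels` key keep working (the config values below are invented):

```python
# Legacy config dict saved before the `reduce_labels` -> `do_reduce_labels` rename.
legacy_dict = {"do_resize": True, "reduce_labels": True}

# The same rename the override applies before delegating to the base class.
migrated = legacy_dict.copy()
if "reduce_labels" in migrated:
    migrated["do_reduce_labels"] = migrated.pop("reduce_labels")

assert migrated == {"do_resize": True, "do_reduce_labels": True}
```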
484
+
485
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.to_dict
486
+ def to_dict(self) -> Dict[str, Any]:
487
+ """
488
+ Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
489
+ `_max_size` attribute from the dictionary.
490
+ """
491
+ image_processor_dict = super().to_dict()
492
+ image_processor_dict.pop("_max_size", None)
493
+ return image_processor_dict
494
+
495
+ @deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
496
+ @filter_out_non_signature_kwargs(extra=["max_size"])
497
+ def resize(
498
+ self,
499
+ image: np.ndarray,
500
+ size: Dict[str, int],
501
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
502
+ data_format=None,
503
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
504
+ **kwargs,
505
+ ) -> np.ndarray:
506
+ """
507
+ Resize the image to the given size. Size can be `min_size` (scalar) or a `(height, width)` tuple. If size is an
508
+ int, smaller edge of the image will be matched to this number.
509
+ """
510
+
511
+ # Deprecated, backward compatibility
512
+ max_size = kwargs.pop("max_size", None)
513
+
514
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
515
+ if "shortest_edge" in size and "longest_edge" in size:
516
+ size, max_size = size["shortest_edge"], size["longest_edge"]
517
+ elif "height" in size and "width" in size:
518
+ size = (size["height"], size["width"])
519
+ max_size = None
520
+ else:
521
+ raise ValueError(
522
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
523
+ f" {size.keys()}."
524
+ )
525
+ size = get_oneformer_resize_output_image_size(
526
+ image=image, size=size, max_size=max_size, default_to_square=False, input_data_format=input_data_format
527
+ )
528
+ image = resize(
529
+ image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format
530
+ )
531
+ return image
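For orientation, a hand-rolled sketch of the shortest-edge/longest-edge rule that `resize` delegates to `get_oneformer_resize_output_image_size` (approximate: the real helper's rounding may differ slightly, and the input sizes below are invented):

```python
def shortest_edge_resize(height: int, width: int, shortest_edge: int = 800, longest_edge: int = 1333):
    # Scale so the smaller side hits `shortest_edge`...
    scale = shortest_edge / min(height, width)
    # ...unless that would push the larger side past `longest_edge`.
    if max(height, width) * scale > longest_edge:
        scale = longest_edge / max(height, width)
    return round(height * scale), round(width * scale)

print(shortest_edge_resize(480, 640))   # (800, 1067): smaller edge matched to 800
print(shortest_edge_resize(480, 1200))  # (533, 1333): capped by the longest edge
```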
532
+
533
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
534
+ def rescale(
535
+ self,
536
+ image: np.ndarray,
537
+ rescale_factor: float,
538
+ data_format: Optional[Union[str, ChannelDimension]] = None,
539
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
540
+ ) -> np.ndarray:
541
+ """
542
+ Rescale the image by the given factor. image = image * rescale_factor.
543
+
544
+ Args:
545
+ image (`np.ndarray`):
546
+ Image to rescale.
547
+ rescale_factor (`float`):
548
+ The value to use for rescaling.
549
+ data_format (`str` or `ChannelDimension`, *optional*):
550
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
551
+ image is used. Can be one of:
552
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
553
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
554
+ input_data_format (`str` or `ChannelDimension`, *optional*):
555
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
556
+ one of:
557
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
558
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
559
+ """
560
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
561
+
562
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
563
+ def convert_segmentation_map_to_binary_masks(
564
+ self,
565
+ segmentation_map: "np.ndarray",
566
+ instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
567
+ ignore_index: Optional[int] = None,
568
+ do_reduce_labels: bool = False,
569
+ ):
570
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
571
+ ignore_index = ignore_index if ignore_index is not None else self.ignore_index
572
+ return convert_segmentation_map_to_binary_masks(
573
+ segmentation_map=segmentation_map,
574
+ instance_id_to_semantic_id=instance_id_to_semantic_id,
575
+ ignore_index=ignore_index,
576
+ do_reduce_labels=do_reduce_labels,
577
+ )
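To make the mask-classification conversion concrete, a toy version of the idea (assumed behavior for illustration, not the actual `convert_segmentation_map_to_binary_masks` helper; the map is invented):

```python
import numpy as np

segmentation_map = np.array([[2, 2, 6],
                             [2, 9, 9]])

classes = np.unique(segmentation_map)  # one entry per segment label
masks = (segmentation_map[None, :, :] == classes[:, None, None]).astype(np.float32)

print(classes)      # [2 6 9]
print(masks.shape)  # (3, 2, 3) -> (num_labels, height, width)
```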
578
+
579
+ def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature:
580
+ return self.preprocess(images, task_inputs=task_inputs, segmentation_maps=segmentation_maps, **kwargs)
581
+
582
+ def _preprocess(
583
+ self,
584
+ image: ImageInput,
585
+ do_resize: Optional[bool] = None,
586
+ size: Dict[str, int] = None,
587
+ resample: PILImageResampling = None,
588
+ do_rescale: Optional[bool] = None,
589
+ rescale_factor: Optional[float] = None,
590
+ do_normalize: Optional[bool] = None,
591
+ image_mean: Optional[Union[float, List[float]]] = None,
592
+ image_std: Optional[Union[float, List[float]]] = None,
593
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
594
+ ):
595
+ if do_resize:
596
+ image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
597
+ if do_rescale:
598
+ image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
599
+ if do_normalize:
600
+ image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
601
+ return image
602
+
603
+ def _preprocess_image(
604
+ self,
605
+ image: ImageInput,
606
+ do_resize: Optional[bool] = None,
607
+ size: Dict[str, int] = None,
608
+ resample: PILImageResampling = None,
609
+ do_rescale: Optional[bool] = None,
610
+ rescale_factor: Optional[float] = None,
611
+ do_normalize: Optional[bool] = None,
612
+ image_mean: Optional[Union[float, List[float]]] = None,
613
+ image_std: Optional[Union[float, List[float]]] = None,
614
+ data_format: Optional[Union[str, ChannelDimension]] = None,
615
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
616
+ ) -> np.ndarray:
617
+ """Preprocesses a single image."""
618
+ # All transformations expect numpy arrays.
619
+ image = to_numpy_array(image)
620
+ if do_rescale and is_scaled_image(image):
621
+ logger.warning_once(
622
+ "It looks like you are trying to rescale already rescaled images. If the input"
623
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
624
+ )
625
+ if input_data_format is None:
626
+ input_data_format = infer_channel_dimension_format(image)
627
+ image = self._preprocess(
628
+ image=image,
629
+ do_resize=do_resize,
630
+ size=size,
631
+ resample=resample,
632
+ do_rescale=do_rescale,
633
+ rescale_factor=rescale_factor,
634
+ do_normalize=do_normalize,
635
+ image_mean=image_mean,
636
+ image_std=image_std,
637
+ input_data_format=input_data_format,
638
+ )
639
+ if data_format is not None:
640
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
641
+ return image
642
+
643
+ def _preprocess_mask(
644
+ self,
645
+ segmentation_map: ImageInput,
646
+ do_resize: Optional[bool] = None,
647
+ size: Dict[str, int] = None,
648
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
649
+ ) -> np.ndarray:
650
+ """Preprocesses a single mask."""
651
+ segmentation_map = to_numpy_array(segmentation_map)
652
+ # Add channel dimension if missing - needed for certain transformations
653
+ if segmentation_map.ndim == 2:
654
+ added_channel_dim = True
655
+ segmentation_map = segmentation_map[None, ...]
656
+ input_data_format = ChannelDimension.FIRST
657
+ else:
658
+ added_channel_dim = False
659
+ if input_data_format is None:
660
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
661
+ # TODO: (Amy)
662
+ # Rework segmentation map processing to include reducing labels and resizing which doesn't
663
+ # drop segment IDs > 255.
664
+ segmentation_map = self._preprocess(
665
+ image=segmentation_map,
666
+ do_resize=do_resize,
667
+ resample=PILImageResampling.NEAREST,
668
+ size=size,
669
+ do_rescale=False,
670
+ do_normalize=False,
671
+ input_data_format=input_data_format,
672
+ )
673
+ # Remove extra channel dimension if added for processing
674
+ if added_channel_dim:
675
+ segmentation_map = segmentation_map.squeeze(0)
676
+ return segmentation_map
677
+
678
+ @filter_out_non_signature_kwargs()
679
+ def preprocess(
680
+ self,
681
+ images: ImageInput,
682
+ task_inputs: Optional[List[str]] = None,
683
+ segmentation_maps: Optional[ImageInput] = None,
684
+ instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
685
+ do_resize: Optional[bool] = None,
686
+ size: Optional[Dict[str, int]] = None,
687
+ resample: PILImageResampling = None,
688
+ do_rescale: Optional[bool] = None,
689
+ rescale_factor: Optional[float] = None,
690
+ do_normalize: Optional[bool] = None,
691
+ image_mean: Optional[Union[float, List[float]]] = None,
692
+ image_std: Optional[Union[float, List[float]]] = None,
693
+ ignore_index: Optional[int] = None,
694
+ do_reduce_labels: Optional[bool] = None,
695
+ return_tensors: Optional[Union[str, TensorType]] = None,
696
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
697
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
698
+ ) -> BatchFeature:
699
+ if task_inputs is None:
700
+ # Default value
701
+ task_inputs = ["panoptic"]
702
+
703
+ do_resize = do_resize if do_resize is not None else self.do_resize
704
+ size = size if size is not None else self.size
705
+ size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
706
+ resample = resample if resample is not None else self.resample
707
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
708
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
709
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
710
+ image_mean = image_mean if image_mean is not None else self.image_mean
711
+ image_std = image_std if image_std is not None else self.image_std
712
+ ignore_index = ignore_index if ignore_index is not None else self.ignore_index
713
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
714
+
715
+ if not valid_images(images):
716
+ raise ValueError(
717
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
718
+ "torch.Tensor, tf.Tensor or jax.ndarray."
719
+ )
720
+
721
+ validate_preprocess_arguments(
722
+ do_rescale=do_rescale,
723
+ rescale_factor=rescale_factor,
724
+ do_normalize=do_normalize,
725
+ image_mean=image_mean,
726
+ image_std=image_std,
727
+ do_resize=do_resize,
728
+ size=size,
729
+ resample=resample,
730
+ )
731
+
732
+ if segmentation_maps is not None and not valid_images(segmentation_maps):
733
+ raise ValueError(
734
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
735
+ "torch.Tensor, tf.Tensor or jax.ndarray."
736
+ )
737
+
738
+ images = make_list_of_images(images)
739
+ if segmentation_maps is not None:
740
+ segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
741
+
742
+ if segmentation_maps is not None and len(images) != len(segmentation_maps):
743
+ raise ValueError("Images and segmentation maps must have the same length.")
744
+
745
+ images = [
746
+ self._preprocess_image(
747
+ image,
748
+ do_resize=do_resize,
749
+ size=size,
750
+ resample=resample,
751
+ do_rescale=do_rescale,
752
+ rescale_factor=rescale_factor,
753
+ do_normalize=do_normalize,
754
+ image_mean=image_mean,
755
+ image_std=image_std,
756
+ data_format=data_format,
757
+ input_data_format=input_data_format,
758
+ )
759
+ for image in images
760
+ ]
761
+
762
+ if segmentation_maps is not None:
763
+ segmentation_maps = [
764
+ self._preprocess_mask(segmentation_map, do_resize, size, input_data_format=input_data_format)
765
+ for segmentation_map in segmentation_maps
766
+ ]
767
+ encoded_inputs = self.encode_inputs(
768
+ images,
769
+ task_inputs,
770
+ segmentation_maps,
771
+ instance_id_to_semantic_id,
772
+ ignore_index,
773
+ do_reduce_labels,
774
+ return_tensors,
775
+ input_data_format=data_format,
776
+ )
777
+ return encoded_inputs
778
+
779
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
780
+ def _pad_image(
781
+ self,
782
+ image: np.ndarray,
783
+ output_size: Tuple[int, int],
784
+ constant_values: Union[float, Iterable[float]] = 0,
785
+ data_format: Optional[ChannelDimension] = None,
786
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
787
+ ) -> np.ndarray:
788
+ """
789
+ Pad an image with zeros to the given size.
790
+ """
791
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
792
+ output_height, output_width = output_size
793
+
794
+ pad_bottom = output_height - input_height
795
+ pad_right = output_width - input_width
796
+ padding = ((0, pad_bottom), (0, pad_right))
797
+ padded_image = pad(
798
+ image,
799
+ padding,
800
+ mode=PaddingMode.CONSTANT,
801
+ constant_values=constant_values,
802
+ data_format=data_format,
803
+ input_data_format=input_data_format,
804
+ )
805
+ return padded_image
806
+
807
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad
808
+ def pad(
809
+ self,
810
+ images: List[np.ndarray],
811
+ constant_values: Union[float, Iterable[float]] = 0,
812
+ return_pixel_mask: bool = True,
813
+ return_tensors: Optional[Union[str, TensorType]] = None,
814
+ data_format: Optional[ChannelDimension] = None,
815
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
816
+ ) -> BatchFeature:
817
+ """
818
+ Pads a batch of images with zeros, at the bottom and on the right, up to the largest height and width
819
+ in the batch and optionally returns their corresponding pixel mask.
820
+
821
+ Args:
822
+ images (`List[np.ndarray]`):
823
+ Images to pad.
824
+ constant_values (`float` or `Iterable[float]`, *optional*):
825
+ The value to use for the padding if `mode` is `"constant"`.
826
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
827
+ Whether to return a pixel mask.
828
+ return_tensors (`str` or `TensorType`, *optional*):
829
+ The type of tensors to return. Can be one of:
830
+ - Unset: Return a list of `np.ndarray`.
831
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
832
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
833
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
834
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
835
+ data_format (`str` or `ChannelDimension`, *optional*):
836
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
837
+ input_data_format (`ChannelDimension` or `str`, *optional*):
838
+ The channel dimension format of the input image. If not provided, it will be inferred.
839
+ """
840
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
841
+
842
+ padded_images = [
843
+ self._pad_image(
844
+ image,
845
+ pad_size,
846
+ constant_values=constant_values,
847
+ data_format=data_format,
848
+ input_data_format=input_data_format,
849
+ )
850
+ for image in images
851
+ ]
852
+ data = {"pixel_values": padded_images}
853
+
854
+ if return_pixel_mask:
855
+ masks = [
856
+ make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
857
+ for image in images
858
+ ]
859
+ data["pixel_mask"] = masks
860
+
861
+ return BatchFeature(data=data, tensor_type=return_tensors)
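A self-contained NumPy illustration of the bottom/right padding and pixel-mask construction performed above (the array sizes are invented):

```python
import numpy as np

images = [np.ones((3, 2, 4)), np.ones((3, 3, 2))]  # (channels, height, width)
max_h = max(img.shape[1] for img in images)        # 3
max_w = max(img.shape[2] for img in images)        # 4

padded, pixel_masks = [], []
for img in images:
    pad_h, pad_w = max_h - img.shape[1], max_w - img.shape[2]
    padded.append(np.pad(img, ((0, 0), (0, pad_h), (0, pad_w))))  # zeros bottom/right
    mask = np.zeros((max_h, max_w), dtype=np.int64)
    mask[: img.shape[1], : img.shape[2]] = 1                      # 1 = real pixel
    pixel_masks.append(mask)

print(padded[1].shape)  # (3, 3, 4)
print(pixel_masks[1])   # rightmost two columns are 0 (padding)
```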
862
+
863
+ def get_semantic_annotations(self, label, num_class_obj):
864
+ annotation_classes = label["classes"]
865
+ annotation_masks = label["masks"]
866
+
867
+ texts = ["a semantic photo"] * self.num_text
868
+ classes = []
869
+ masks = []
870
+
871
+ for idx in range(len(annotation_classes)):
872
+ class_id = annotation_classes[idx]
873
+ mask = annotation_masks[idx]
874
+ if np.any(mask):  # process only non-empty masks
875
+ if class_id not in classes:
876
+ cls_name = self.metadata[str(class_id)]
877
+ classes.append(class_id)
878
+ masks.append(mask)
879
+ num_class_obj[cls_name] += 1
880
+ else:
881
+ idx = classes.index(class_id)
882
+ masks[idx] += mask
883
+ masks[idx] = np.clip(masks[idx], 0, 1)
884
+
885
+ num = 0
886
+ for i, cls_name in enumerate(self.metadata["class_names"]):
887
+ if num_class_obj[cls_name] > 0:
888
+ for _ in range(num_class_obj[cls_name]):
889
+ if num >= len(texts):
890
+ break
891
+ texts[num] = f"a photo with a {cls_name}"
892
+ num += 1
893
+
894
+ classes = np.array(classes)
895
+ masks = np.array(masks)
896
+ return classes, masks, texts
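The same prompt-filling pattern is shared by all three `get_*_annotations` helpers; a toy run with invented class counts and `num_text=4`:

```python
num_text = 4
texts = ["a semantic photo"] * num_text   # fallback prompt for unused slots
num_class_obj = {"road": 2, "sky": 1}     # invented per-class mask counts

num = 0
for cls_name, count in num_class_obj.items():
    for _ in range(count):
        if num >= len(texts):
            break
        texts[num] = f"a photo with a {cls_name}"
        num += 1

print(texts)
# ['a photo with a road', 'a photo with a road', 'a photo with a sky', 'a semantic photo']
```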
897
+
898
+ def get_instance_annotations(self, label, num_class_obj):
899
+ annotation_classes = label["classes"]
900
+ annotation_masks = label["masks"]
901
+
902
+ texts = ["an instance photo"] * self.num_text
903
+ classes = []
904
+ masks = []
905
+
906
+ for idx in range(len(annotation_classes)):
907
+ class_id = annotation_classes[idx]
908
+ mask = annotation_masks[idx]
909
+
910
+ if class_id in self.metadata["thing_ids"]:
911
+ if np.any(mask):  # process only non-empty masks
912
+ cls_name = self.metadata[str(class_id)]
913
+ classes.append(class_id)
914
+ masks.append(mask)
915
+ num_class_obj[cls_name] += 1
916
+
917
+ num = 0
918
+ for i, cls_name in enumerate(self.metadata["class_names"]):
919
+ if num_class_obj[cls_name] > 0:
920
+ for _ in range(num_class_obj[cls_name]):
921
+ if num >= len(texts):
922
+ break
923
+ texts[num] = f"a photo with a {cls_name}"
924
+ num += 1
925
+
926
+ classes = np.array(classes)
927
+ masks = np.array(masks)
928
+ return classes, masks, texts
929
+
930
+ def get_panoptic_annotations(self, label, num_class_obj):
931
+ annotation_classes = label["classes"]
932
+ annotation_masks = label["masks"]
933
+
934
+ texts = ["an panoptic photo"] * self.num_text
935
+ classes = []
936
+ masks = []
937
+
938
+ for idx in range(len(annotation_classes)):
939
+ class_id = annotation_classes[idx]
940
+ mask = annotation_masks[idx].data
941
+ if np.any(mask):  # process only non-empty masks
942
+ cls_name = self.metadata[str(class_id)]
943
+ classes.append(class_id)
944
+ masks.append(mask)
945
+ num_class_obj[cls_name] += 1
946
+
947
+ num = 0
948
+ for i, cls_name in enumerate(self.metadata["class_names"]):
949
+ if num_class_obj[cls_name] > 0:
950
+ for _ in range(num_class_obj[cls_name]):
951
+ if num >= len(texts):
952
+ break
953
+ texts[num] = f"a photo with a {cls_name}"
954
+ num += 1
955
+
956
+ classes = np.array(classes)
957
+ masks = np.array(masks)
958
+ return classes, masks, texts
959
+
960
+ def encode_inputs(
961
+ self,
962
+ pixel_values_list: List[ImageInput],
963
+ task_inputs: List[str],
964
+ segmentation_maps: ImageInput = None,
965
+ instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
966
+ ignore_index: Optional[int] = None,
967
+ do_reduce_labels: bool = False,
968
+ return_tensors: Optional[Union[str, TensorType]] = None,
969
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
970
+ ):
971
+ """
972
+ Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
973
+
974
+ OneFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
975
+ will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
976
+ `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
977
+ [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
978
+ each mask.
979
+
980
+ Args:
981
+ pixel_values_list (`List[ImageInput]`):
982
+ List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
983
+ width)`.
984
+
985
+ task_inputs (`List[str]`):
986
+ List of task values.
987
+
988
+ segmentation_maps (`ImageInput`, *optional*):
989
+ The corresponding semantic segmentation maps with the pixel-wise annotations.
990
+
991
+ pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
992
+ Whether or not to pad images up to the largest image in a batch and create a pixel mask.
993
+
994
+ If left to the default, will return a pixel mask that is:
995
+
996
+ - 1 for pixels that are real (i.e. **not masked**),
997
+ - 0 for pixels that are padding (i.e. **masked**).
998
+
999
+ instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
1000
+ A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
1001
+ instance segmentation map where each pixel represents an instance id. Can be provided as a single
1002
+ dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
1003
+ instance ids in each image separately.
1004
+
1005
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
1006
+ If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
1007
+ objects.
1008
+
1009
+ input_data_format (`str` or `ChannelDimension`, *optional*):
1010
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
1011
+ image.
1012
+
1013
+ Returns:
1014
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
1015
+
1016
+ - **pixel_values** -- Pixel values to be fed to a model.
1017
+ - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if `pixel_mask` is in
1018
+ `self.model_input_names`).
1019
+ - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
1020
+ (when `annotations` are provided).
1021
+ - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
1022
+ `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
1023
+ `mask_labels[i][j]` is `class_labels[i][j]`.
1024
+ - **text_inputs** -- Optional list of text string entries to be fed to a model (when `annotations` are
1025
+ provided). They identify the binary masks present in the image.
1026
+ """
1027
+ ignore_index = self.ignore_index if ignore_index is None else ignore_index
1028
+ do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels
1029
+ pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
1030
+
1031
+ if input_data_format is None:
1032
+ input_data_format = infer_channel_dimension_format(pixel_values_list[0])
1033
+
1034
+ pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
1035
+ encoded_inputs = self.pad(
1036
+ pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
1037
+ )
1038
+
1039
+ annotations = None
1040
+ if segmentation_maps is not None:
1041
+ segmentation_maps = map(np.array, segmentation_maps)
1042
+ annotations = []
1043
+ for idx, segmentation_map in enumerate(segmentation_maps):
1044
+ # Use instance2class_id mapping per image
1045
+ if isinstance(instance_id_to_semantic_id, list):
1046
+ instance_id = instance_id_to_semantic_id[idx]
1047
+ else:
1048
+ instance_id = instance_id_to_semantic_id
1049
+ # Use instance2class_id mapping per image
1050
+ masks, classes = self.convert_segmentation_map_to_binary_masks(
1051
+ segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels
1052
+ )
1053
+ annotations.append({"masks": masks, "classes": classes})
1054
+
1055
+ if annotations is not None:
1056
+ mask_labels = []
1057
+ class_labels = []
1058
+ text_inputs = []
1059
+
1060
+ num_class_obj = {}
1061
+ for cls_name in self.metadata["class_names"]:
1062
+ num_class_obj[cls_name] = 0
1063
+
1064
+ for i, label in enumerate(annotations):
1065
+ task = task_inputs[i]
1066
+ if task == "semantic":
1067
+ classes, masks, texts = self.get_semantic_annotations(label, num_class_obj)
1068
+ elif task == "instance":
1069
+ classes, masks, texts = self.get_instance_annotations(label, num_class_obj)
1070
+ elif task == "panoptic":
1071
+ classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj)
1072
+ else:
1073
+ raise ValueError(f"{task} was not expected, expected `semantic`, `instance` or `panoptic`")
1074
+
1075
+ # we cannot batch them since they don't share a common class size
1076
+ masks = [mask[None, ...] for mask in masks]
1077
+ masks = [
1078
+ self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks
1079
+ ]
1080
+ masks = np.concatenate(masks, axis=0)
1081
+ mask_labels.append(torch.from_numpy(masks))
1082
+ class_labels.append(torch.from_numpy(classes).long())
1083
+ text_inputs.append(texts)
1084
+
1085
+ encoded_inputs["mask_labels"] = mask_labels
1086
+ encoded_inputs["class_labels"] = class_labels
1087
+ encoded_inputs["text_inputs"] = text_inputs
1088
+
1089
+ # This needs to be tokenized before sending to the model.
1090
+ encoded_inputs["task_inputs"] = [f"the task is {task_input}" for task_input in task_inputs]
1091
+
1092
+ return encoded_inputs
1093
+
1094
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation
1095
+ def post_process_semantic_segmentation(
1096
+ self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None
1097
+ ) -> "torch.Tensor":
1098
+ """
1099
+ Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports
1100
+ PyTorch.
1101
+
1102
+ Args:
1103
+ outputs ([`MaskFormerForInstanceSegmentation`]):
1104
+ Raw outputs of the model.
1105
+ target_sizes (`List[Tuple[int, int]]`, *optional*):
1106
+ List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
1107
+ final size (height, width) of each prediction. If left to None, predictions will not be resized.
1108
+ Returns:
1109
+ `List[torch.Tensor]`:
1110
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
1111
+ corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
1112
+ `torch.Tensor` corresponds to a semantic class id.
1113
+ """
1114
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
1115
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
1116
+
1117
+ # Remove the null class `[..., :-1]`
1118
+ masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
1119
+ masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1120
+
1121
+ # Semantic segmentation logits of shape (batch_size, num_classes, height, width)
1122
+ segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
1123
+ batch_size = class_queries_logits.shape[0]
1124
+
1125
+ # Resize logits and compute semantic segmentation maps
1126
+ if target_sizes is not None:
1127
+ if batch_size != len(target_sizes):
1128
+ raise ValueError(
1129
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1130
+ )
1131
+
1132
+ semantic_segmentation = []
1133
+ for idx in range(batch_size):
1134
+ resized_logits = torch.nn.functional.interpolate(
1135
+ segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
1136
+ )
1137
+ semantic_map = resized_logits[0].argmax(dim=0)
1138
+ semantic_segmentation.append(semantic_map)
1139
+ else:
1140
+ semantic_segmentation = segmentation.argmax(dim=1)
1141
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
1142
+
1143
+ return semantic_segmentation
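A shape-only sketch of the einsum aggregation used above, with random tensors (all dimensions invented):

```python
import torch

batch, queries, classes, h, w = 1, 5, 3, 8, 8
masks_classes = torch.rand(batch, queries, classes)  # per-query class probabilities
masks_probs = torch.rand(batch, queries, h, w)       # per-query mask probabilities

# Weight every query mask by its class probability and sum over queries.
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
semantic_map = segmentation.argmax(dim=1)

print(segmentation.shape)  # torch.Size([1, 3, 8, 8])
print(semantic_map.shape)  # torch.Size([1, 8, 8])
```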
1144
+
1145
+ def post_process_instance_segmentation(
1146
+ self,
1147
+ outputs,
1148
+ task_type: str = "instance",
1149
+ is_demo: bool = True,
1150
+ threshold: float = 0.5,
1151
+ mask_threshold: float = 0.5,
1152
+ overlap_mask_area_threshold: float = 0.8,
1153
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
1154
+ return_coco_annotation: Optional[bool] = False,
1155
+ ):
1156
+ """
1157
+ Converts the output of [`OneFormerForUniversalSegmentationOutput`] into image instance segmentation
1158
+ predictions. Only supports PyTorch.
1159
+
1160
+ Args:
1161
+ outputs ([`OneFormerForUniversalSegmentationOutput`]):
1162
+ The outputs from [`OneFormerForUniversalSegmentationOutput`].
1163
+ task_type (`str`, *optional*, defaults to "instance"):
1164
+ The post processing depends on the task token input. If the `task_type` is "panoptic", we need to
1165
+ ignore the stuff predictions.
1166
+ is_demo (`bool`, *optional*, defaults to `True`):
1167
+ Whether the model is in demo mode. If true, use threshold to predict final masks.
1168
+ threshold (`float`, *optional*, defaults to 0.5):
1169
+ The probability score threshold to keep predicted instance masks.
1170
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1171
+ Threshold to use when turning the predicted masks into binary values.
1172
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1173
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1174
+ instance mask.
1175
+ target_sizes (`List[Tuple]`, *optional*):
1176
+ List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
1177
+ final size (height, width) of each prediction in batch. If left to None, predictions will not be
1178
+ resized.
1179
+ return_coco_annotation (`bool`, *optional*, defaults to `False`):
1180
+ Whether to return predictions in COCO format.
1181
+
1182
+ Returns:
1183
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1184
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
1185
+ to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
1186
+ to the corresponding `target_sizes` entry.
1187
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1188
+ - **id** -- an integer representing the `segment_id`.
1189
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1190
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
1191
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
1192
+ - **score** -- Prediction score of segment with `segment_id`.
1193
+ """
1194
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
1195
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
1196
+
1197
+ device = masks_queries_logits.device
1198
+ batch_size = class_queries_logits.shape[0]
1199
+ num_queries = class_queries_logits.shape[1]
1200
+ num_classes = class_queries_logits.shape[-1] - 1
1201
+
1202
+ # Loop over items in the batch
1203
+ results: List[Dict[str, torch.Tensor]] = []
1204
+
1205
+ for i in range(batch_size):
1206
+ # [Q, K]
1207
+ scores = torch.nn.functional.softmax(class_queries_logits[i], dim=-1)[:, :-1]
1208
+ labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
1209
+
1210
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
1211
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
1212
+ labels_per_image = labels[topk_indices]
1213
+
1214
+ topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
1215
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
1216
+ mask_pred = masks_queries_logits[i][topk_indices]
1217
+
1218
+ # Only consider scores with confidence over [threshold] for demo
1219
+ if is_demo:
1220
+ keep = scores_per_image > threshold
1221
+ scores_per_image = scores_per_image[keep]
1222
+ labels_per_image = labels_per_image[keep]
1223
+ mask_pred = mask_pred[keep]
1224
+
1225
+ # if this is panoptic segmentation, we only keep the "thing" classes
1226
+ if task_type == "panoptic":
1227
+ keep = torch.zeros_like(scores_per_image).bool()
1228
+ for j, lab in enumerate(labels_per_image):
1229
+ keep[j] = lab in self.metadata["thing_ids"]
1230
+
1231
+ scores_per_image = scores_per_image[keep]
1232
+ labels_per_image = labels_per_image[keep]
1233
+ mask_pred = mask_pred[keep]
1234
+
1235
+ if mask_pred.shape[0] <= 0:
1236
+ height, width = target_sizes[i] if target_sizes is not None else mask_pred.shape[1:]
1237
+ segmentation = torch.zeros((height, width)) - 1
1238
+ results.append({"segmentation": segmentation, "segments_info": []})
1239
+ continue
1240
+
1241
+ if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
1242
+ for j in range(labels_per_image.shape[0]):
1243
+ labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())
1244
+
1245
+ # Get segmentation map and segment information of batch item
1246
+ target_size = target_sizes[i] if target_sizes is not None else None
1247
+ segmentation, segments = compute_segments(
1248
+ mask_pred,
1249
+ scores_per_image,
1250
+ labels_per_image,
1251
+ mask_threshold,
1252
+ overlap_mask_area_threshold,
1253
+ set(),
1254
+ target_size,
1255
+ )
1256
+
1257
+ # Return segmentation map in run-length encoding (RLE) format
1258
+ if return_coco_annotation:
1259
+ segmentation = convert_segmentation_to_rle(segmentation)
1260
+
1261
+ results.append({"segmentation": segmentation, "segments_info": segments})
1262
+ return results
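The flattened top-k at the start of the loop pairs each kept score with both a class id and the query that produced it; a tiny worked example with invented scores:

```python
import torch

scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])  # [num_queries=2, num_classes=3]
num_queries, num_classes = scores.shape

labels = torch.arange(num_classes).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)

labels_per_image = labels[topk_indices]  # class id per kept entry
query_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")

print(scores_per_image)  # 0.7 and 0.5 (order unspecified with sorted=False)
print(labels_per_image)  # class ids 1 and 0, aligned with the scores
print(query_indices)     # queries 0 and 1 -> which query's mask to keep
```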
1263
+
1264
+ # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_panoptic_segmentation
1265
+ def post_process_panoptic_segmentation(
1266
+ self,
1267
+ outputs,
1268
+ threshold: float = 0.5,
1269
+ mask_threshold: float = 0.5,
1270
+ overlap_mask_area_threshold: float = 0.8,
1271
+ label_ids_to_fuse: Optional[Set[int]] = None,
1272
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
1273
+ ) -> List[Dict]:
1274
+ """
1275
+ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
1276
+ predictions. Only supports PyTorch.
1277
+
1278
+ Args:
1279
+ outputs ([`MaskFormerForInstanceSegmentationOutput`]):
1280
+ The outputs from [`MaskFormerForInstanceSegmentation`].
1281
+ threshold (`float`, *optional*, defaults to 0.5):
1282
+ The probability score threshold to keep predicted instance masks.
1283
+ mask_threshold (`float`, *optional*, defaults to 0.5):
1284
+ Threshold to use when turning the predicted masks into binary values.
1285
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
1286
+ The overlap mask area threshold to merge or discard small disconnected parts within each binary
1287
+ instance mask.
1288
+ label_ids_to_fuse (`Set[int]`, *optional*):
1289
+ The labels in this set will have all their instances fused together. For instance, we could say
1290
+ there can only be one sky in an image, but several persons, so the label ID for sky would be in that
1291
+ set, but not the one for person.
1292
+ target_sizes (`List[Tuple]`, *optional*):
1293
+ List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
1294
+ final size (height, width) of each prediction in batch. If left to None, predictions will not be
1295
+ resized.
1296
+
1297
+ Returns:
1298
+ `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
1299
+ - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
1300
+ to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
1301
+ to the corresponding `target_sizes` entry.
1302
+ - **segments_info** -- A dictionary that contains additional information on each segment.
1303
+ - **id** -- an integer representing the `segment_id`.
1304
+ - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
1305
+ - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
1306
+ Multiple instances of the same class / label were fused and assigned a single `segment_id`.
1307
+ - **score** -- Prediction score of segment with `segment_id`.
1308
+ """
1309
+
1310
+ if label_ids_to_fuse is None:
1311
+ logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
1312
+ label_ids_to_fuse = set()
1313
+
1314
+ class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
1315
+ masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
1316
+
1317
+ batch_size = class_queries_logits.shape[0]
1318
+ num_labels = class_queries_logits.shape[-1] - 1
1319
+
1320
+ mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
1321
+
1322
+ # Predicted label and score of each query (batch_size, num_queries)
1323
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
1324
+
1325
+ # Loop over items in the batch
1326
+ results: List[Dict[str, TensorType]] = []
1327
+
1328
+ for i in range(batch_size):
1329
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
1330
+ mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
1331
+ )
1332
+
1333
+ # No mask found
1334
+ if mask_probs_item.shape[0] <= 0:
1335
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
1336
+ segmentation = torch.zeros((height, width)) - 1
1337
+ results.append({"segmentation": segmentation, "segments_info": []})
1338
+ continue
1339
+
1340
+ # Get segmentation map and segment information of batch item
1341
+ target_size = target_sizes[i] if target_sizes is not None else None
1342
+ segmentation, segments = compute_segments(
1343
+ mask_probs=mask_probs_item,
1344
+ pred_scores=pred_scores_item,
1345
+ pred_labels=pred_labels_item,
1346
+ mask_threshold=mask_threshold,
1347
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
1348
+ label_ids_to_fuse=label_ids_to_fuse,
1349
+ target_size=target_size,
1350
+ )
1351
+
1352
+ results.append({"segmentation": segmentation, "segments_info": segments})
1353
+ return results
1354
+
1355
+
1356
+ __all__ = ["OneFormerImageProcessor"]
docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/src/transformers/models/oneformer/processing_oneformer.py ADDED
@@ -0,0 +1,207 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 SHI Labs and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for OneFormer
17
+ """
18
+
19
+ from typing import List
20
+
21
+ from ...processing_utils import ProcessorMixin
22
+ from ...utils import is_torch_available
23
+
24
+
25
+ if is_torch_available():
26
+ import torch
27
+
28
+
29
+ class OneFormerProcessor(ProcessorMixin):
30
+ r"""
31
+ Constructs a OneFormer processor which wraps [`OneFormerImageProcessor`] and
32
+ [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into a single processor that inherits both the image processor and
33
+ tokenizer functionalities.
34
+
35
+ Args:
36
+ image_processor ([`OneFormerImageProcessor`]):
37
+ The image processor is a required input.
38
+ tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
39
+ The tokenizer is a required input.
40
+ max_seq_length (`int`, *optional*, defaults to 77):
41
+ Sequence length for input text list.
42
+ task_seq_length (`int`, *optional*, defaults to 77):
43
+ Sequence length for input task token.
44
+ """
45
+
46
+ attributes = ["image_processor", "tokenizer"]
47
+ image_processor_class = "OneFormerImageProcessor"
48
+ tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
49
+
50
+ def __init__(
51
+ self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
52
+ ):
53
+ if image_processor is None:
54
+ raise ValueError("You need to specify an `image_processor`.")
55
+ if tokenizer is None:
56
+ raise ValueError("You need to specify a `tokenizer`.")
57
+
58
+ self.max_seq_length = max_seq_length
59
+ self.task_seq_length = task_seq_length
60
+
61
+ super().__init__(image_processor, tokenizer)
62
+
63
+ def _preprocess_text(self, text_list=None, max_length=77):
64
+ if text_list is None:
65
+ raise ValueError("tokens cannot be None.")
66
+
67
+ tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
68
+
69
+ attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
70
+
71
+ token_inputs = []
72
+ for attn_mask, input_id in zip(attention_masks, input_ids):
73
+ token = torch.tensor(attn_mask) * torch.tensor(input_id)
74
+ token_inputs.append(token.unsqueeze(0))
75
+
76
+ token_inputs = torch.cat(token_inputs, dim=0)
77
+ return token_inputs
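A quick, standalone check of the masking trick above: multiplying token ids elementwise by the attention mask zeroes the padded positions (the ids below are invented, not real CLIP ids):

```python
import torch

input_ids = torch.tensor([101, 7592, 2088, 102, 0, 0])  # padded to length 6
attention_mask = torch.tensor([1, 1, 1, 1, 0, 0])

token = attention_mask * input_ids
print(token)  # tensor([ 101, 7592, 2088,  102,    0,    0])
```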
78
+
79
+ def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
80
+ """
81
+ Main method to prepare for the model one or several task input(s) and image(s). This method forwards the
82
+ `task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not
83
+ `None`. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
84
+ OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the
85
+ docstring of the above two methods for more information.
86
+
87
+ Args:
88
+ task_inputs (`str`, `List[str]`):
89
+ The sequence or batch of task_inputs sequences to be encoded. Each sequence can be a string or a list
90
+ of strings of the template "the task is {task}".
91
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
92
+ `List[torch.Tensor]`):
93
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
94
+ tensor. Both channels-first and channels-last formats are supported.
95
+ segmentation_maps (`ImageInput`, *optional*):
96
+ The corresponding semantic segmentation maps with the pixel-wise annotations.
97
+
98
+ pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
99
+ Whether or not to pad images up to the largest image in a batch and create a pixel mask.
100
+
101
+ If left to the default, will return a pixel mask that is:
102
+
103
+ - 1 for pixels that are real (i.e. **not masked**),
104
+ - 0 for pixels that are padding (i.e. **masked**).
105
+ Returns:
106
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
107
+ - **task_inputs** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
108
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
109
+ """
110
+
111
+ if task_inputs is None:
112
+ raise ValueError("You have to specify the task_input. Found None.")
113
+ elif images is None:
114
+ raise ValueError("You have to specify the image. Found None.")
115
+
116
+ if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
117
+ raise ValueError("task_inputs must be semantic, instance, or panoptic.")
118
+
119
+ encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs)
120
+
121
+ if isinstance(task_inputs, str):
122
+ task_inputs = [task_inputs]
123
+
124
+ if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
125
+ task_token_inputs = []
126
+ for task in task_inputs:
127
+ task_input = f"the task is {task}"
128
+ task_token_inputs.append(task_input)
129
+ encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
130
+ else:
131
+ raise TypeError("Task Inputs should be a string or a list of strings.")
132
+
133
+ if hasattr(encoded_inputs, "text_inputs"):
134
+ texts_list = encoded_inputs.text_inputs
135
+
136
+ text_inputs = []
137
+ for texts in texts_list:
138
+ text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
139
+ text_inputs.append(text_input_list.unsqueeze(0))
140
+
141
+ encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)
142
+
143
+ return encoded_inputs
144
+
145
+ def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
146
+ """
147
+ This method forwards all its arguments to [`OneFormerImageProcessor.encode_inputs`] and then tokenizes the
148
+ task_inputs. Please refer to the docstring of this method for more information.
149
+ """
150
+
151
+ if task_inputs is None:
152
+ raise ValueError("You have to specify the task_input. Found None.")
153
+ elif images is None:
154
+ raise ValueError("You have to specify the image. Found None.")
155
+
156
+ if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
157
+ raise ValueError("task_inputs must be semantic, instance, or panoptic.")
158
+
159
+ encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs)
160
+
161
+ if isinstance(task_inputs, str):
162
+ task_inputs = [task_inputs]
163
+
164
+ if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
165
+ task_token_inputs = []
166
+ for task in task_inputs:
167
+ task_input = f"the task is {task}"
168
+ task_token_inputs.append(task_input)
169
+ encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
170
+ else:
171
+ raise TypeError("Task Inputs should be a string or a list of strings.")
172
+
173
+ if hasattr(encoded_inputs, "text_inputs"):
174
+ texts_list = encoded_inputs.text_inputs
175
+
176
+ text_inputs = []
177
+ for texts in texts_list:
178
+ text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
179
+ text_inputs.append(text_input_list.unsqueeze(0))
180
+
181
+ encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)
182
+
183
+ return encoded_inputs
184
+
185
+ def post_process_semantic_segmentation(self, *args, **kwargs):
186
+ """
187
+ This method forwards all its arguments to [`OneFormerImageProcessor.post_process_semantic_segmentation`].
188
+ Please refer to the docstring of this method for more information.
189
+ """
190
+ return self.image_processor.post_process_semantic_segmentation(*args, **kwargs)
191
+
192
+ def post_process_instance_segmentation(self, *args, **kwargs):
193
+ """
194
+ This method forwards all its arguments to [`OneFormerImageProcessor.post_process_instance_segmentation`].
195
+ Please refer to the docstring of this method for more information.
196
+ """
197
+ return self.image_processor.post_process_instance_segmentation(*args, **kwargs)
198
+
199
+ def post_process_panoptic_segmentation(self, *args, **kwargs):
200
+ """
201
+ This method forwards all its arguments to [`OneFormerImageProcessor.post_process_panoptic_segmentation`].
202
+ Please refer to the docstring of this method for more information.
203
+ """
204
+ return self.image_processor.post_process_panoptic_segmentation(*args, **kwargs)
205
+
206
+
207
+ __all__ = ["OneFormerProcessor"]
docs/transformers/src/transformers/models/openai/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_openai import *
22
+ from .modeling_openai import *
23
+ from .modeling_tf_openai import *
24
+ from .tokenization_openai import *
25
+ from .tokenization_openai_fast import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/openai/configuration_openai.py ADDED
@@ -0,0 +1,156 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """OpenAI GPT configuration"""
17
+
18
+ from ...configuration_utils import PretrainedConfig
19
+ from ...utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class OpenAIGPTConfig(PretrainedConfig):
26
+ """
27
+ This is the configuration class to store the configuration of a [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is
28
+ used to instantiate a GPT model according to the specified arguments, defining the model architecture.
29
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT
30
+ [openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 40478):
37
+ Vocabulary size of the GPT model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`OpenAIGPTModel`] or [`TFOpenAIGPTModel`].
39
+ n_positions (`int`, *optional*, defaults to 512):
40
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
41
+ just in case (e.g., 512 or 1024 or 2048).
42
+ n_embd (`int`, *optional*, defaults to 768):
43
+ Dimensionality of the embeddings and hidden states.
44
+ n_layer (`int`, *optional*, defaults to 12):
45
+ Number of hidden layers in the Transformer encoder.
46
+ n_head (`int`, *optional*, defaults to 12):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ afn (`str` or `Callable`, *optional*, defaults to `"gelu"`):
49
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
50
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
51
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
52
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
53
+ embd_pdrop (`int`, *optional*, defaults to 0.1):
54
+ The dropout ratio for the embeddings.
55
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
56
+ The dropout ratio for the attention.
57
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
58
+ The epsilon to use in the layer normalization layers.
59
+ initializer_range (`float`, *optional*, defaults to 0.02):
60
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61
+ summary_type (`str`, *optional*, defaults to `"cls_index"`):
62
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
63
+ [`TFOpenAIGPTDoubleHeadsModel`].
64
+
65
+ Has to be one of the following options:
66
+
67
+ - `"last"`: Take the last token hidden state (like XLNet).
68
+ - `"first"`: Take the first token hidden state (like BERT).
69
+ - `"mean"`: Take the mean of all tokens hidden states.
70
+ - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
71
+ - `"attn"`: Not implemented now, use multi-head attention.
72
+ summary_use_proj (`bool`, *optional*, defaults to `True`):
73
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
74
+ [`TFOpenAIGPTDoubleHeadsModel`].
75
+
76
+ Whether or not to add a projection after the vector extraction.
77
+ summary_activation (`str`, *optional*):
78
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
79
+ [`TFOpenAIGPTDoubleHeadsModel`].
80
+
81
+ Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
82
+ summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
83
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
84
+ [`TFOpenAIGPTDoubleHeadsModel`].
85
+
86
+ Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
87
+ summary_first_dropout (`float`, *optional*, defaults to 0.1):
88
+ Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
89
+ [`TFOpenAIGPTDoubleHeadsModel`].
90
+
91
+ The dropout ratio to be used after the projection and activation.
92
+
93
+
94
+ Examples:
95
+
96
+ ```python
97
+ >>> from transformers import OpenAIGPTConfig, OpenAIGPTModel
98
+
99
+ >>> # Initializing a GPT configuration
100
+ >>> configuration = OpenAIGPTConfig()
101
+
102
+ >>> # Initializing a model (with random weights) from the configuration
103
+ >>> model = OpenAIGPTModel(configuration)
104
+
105
+ >>> # Accessing the model configuration
106
+ >>> configuration = model.config
107
+ ```"""
108
+
109
+ model_type = "openai-gpt"
110
+ attribute_map = {
111
+ "max_position_embeddings": "n_positions",
112
+ "hidden_size": "n_embd",
113
+ "num_attention_heads": "n_head",
114
+ "num_hidden_layers": "n_layer",
115
+ }
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_size=40478,
120
+ n_positions=512,
121
+ n_embd=768,
122
+ n_layer=12,
123
+ n_head=12,
124
+ afn="gelu",
125
+ resid_pdrop=0.1,
126
+ embd_pdrop=0.1,
127
+ attn_pdrop=0.1,
128
+ layer_norm_epsilon=1e-5,
129
+ initializer_range=0.02,
130
+ summary_type="cls_index",
131
+ summary_use_proj=True,
132
+ summary_activation=None,
133
+ summary_proj_to_labels=True,
134
+ summary_first_dropout=0.1,
135
+ **kwargs,
136
+ ):
137
+ self.vocab_size = vocab_size
138
+ self.n_positions = n_positions
139
+ self.n_embd = n_embd
140
+ self.n_layer = n_layer
141
+ self.n_head = n_head
142
+ self.afn = afn
143
+ self.resid_pdrop = resid_pdrop
144
+ self.embd_pdrop = embd_pdrop
145
+ self.attn_pdrop = attn_pdrop
146
+ self.layer_norm_epsilon = layer_norm_epsilon
147
+ self.initializer_range = initializer_range
148
+ self.summary_type = summary_type
149
+ self.summary_use_proj = summary_use_proj
150
+ self.summary_activation = summary_activation
151
+ self.summary_first_dropout = summary_first_dropout
152
+ self.summary_proj_to_labels = summary_proj_to_labels
153
+ super().__init__(**kwargs)
154
+
155
+
156
+ __all__ = ["OpenAIGPTConfig"]
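The `attribute_map` above lets the GPT-specific attribute names coexist with the generic names the rest of the library expects. A minimal sketch of the aliasing (assumes only that `transformers` is installed):

```python
from transformers import OpenAIGPTConfig

config = OpenAIGPTConfig()
# attribute_map makes the generic names resolve to the GPT-specific attributes
assert config.hidden_size == config.n_embd == 768
assert config.max_position_embeddings == config.n_positions == 512
assert config.num_attention_heads == config.n_head == 12
assert config.num_hidden_layers == config.n_layer == 12
```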
docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,74 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert OpenAI GPT checkpoint."""
16
+
17
+ import argparse
18
+
19
+ import torch
20
+
21
+ from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
22
+ from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging
23
+
24
+
25
+ logging.set_verbosity_info()
26
+
27
+
28
+ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
29
+ # Construct model
30
+ if openai_config_file == "":
31
+ config = OpenAIGPTConfig()
32
+ else:
33
+ config = OpenAIGPTConfig.from_json_file(openai_config_file)
34
+ model = OpenAIGPTModel(config)
35
+
36
+ # Load weights from numpy
37
+ load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
38
+
39
+ # Save pytorch-model
40
+ pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
41
+ pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
42
+ print(f"Save PyTorch model to {pytorch_weights_dump_path}")
43
+ torch.save(model.state_dict(), pytorch_weights_dump_path)
44
+ print(f"Save configuration file to {pytorch_config_dump_path}")
45
+ with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
46
+ f.write(config.to_json_string())
47
+
48
+
49
+ if __name__ == "__main__":
50
+ parser = argparse.ArgumentParser()
51
+ # Required parameters
52
+ parser.add_argument(
53
+ "--openai_checkpoint_folder_path",
54
+ default=None,
55
+ type=str,
56
+ required=True,
57
+ help="Path to the TensorFlow checkpoint folder.",
58
+ )
59
+ parser.add_argument(
60
+ "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
61
+ )
62
+ parser.add_argument(
63
+ "--openai_config_file",
64
+ default="",
65
+ type=str,
66
+ help=(
67
+ "An optional config json file corresponding to the pre-trained OpenAI model. \n"
68
+ "This specifies the model architecture."
69
+ ),
70
+ )
71
+ args = parser.parse_args()
72
+ convert_openai_checkpoint_to_pytorch(
73
+ args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
74
+ )
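The converter can also be driven from Python instead of the CLI; a sketch with hypothetical paths (the input folder is expected to contain the `parameters_names.json`, `params_shapes.json`, and `params_{n}.npy` files consumed by `load_tf_weights_in_openai_gpt`):

```python
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
    convert_openai_checkpoint_to_pytorch,
)

convert_openai_checkpoint_to_pytorch(
    openai_checkpoint_folder_path="./openai-gpt-original",  # hypothetical input folder
    openai_config_file="",  # empty string falls back to the default OpenAIGPTConfig()
    pytorch_dump_folder_path="./openai-gpt-pytorch",  # hypothetical output folder
)
```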
docs/transformers/src/transformers/models/openai/modeling_openai.py ADDED
@@ -0,0 +1,967 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT model."""
17
+
18
+ import json
19
+ import math
20
+ import os
21
+ from dataclasses import dataclass
22
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from ...activations import gelu_new, get_activation, silu
29
+ from ...generation import GenerationMixin
30
+ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
33
+ from ...utils import (
34
+ ModelOutput,
35
+ add_code_sample_docstrings,
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from .configuration_openai import OpenAIGPTConfig
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CHECKPOINT_FOR_DOC = "openai-community/openai-gpt"
47
+ _CONFIG_FOR_DOC = "OpenAIGPTConfig"
48
+
49
+
50
+ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
51
+ """Load TF pre-trained weights in a PyTorch model (from NumPy arrays here)."""
52
+ import re
53
+
54
+ import numpy as np
55
+
56
+ if ".ckpt" in openai_checkpoint_folder_path:
57
+ openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
58
+
59
+ logger.info(f"Loading weights from {openai_checkpoint_folder_path}")
60
+
61
+ with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
62
+ names = json.load(names_handle)
63
+ with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
64
+ shapes = json.load(shapes_handle)
65
+ offsets = np.cumsum([np.prod(shape) for shape in shapes])
66
+ init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)]
67
+ init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
68
+ init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
69
+
70
+ # This was used when we had a single embedding matrix for positions and tokens
71
+ # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
72
+ # del init_params[1]
73
+ init_params = [arr.squeeze() for arr in init_params]
74
+
75
+ # Check that the token and position embedding weight dimensions match those of the init parameters.
76
+ if model.tokens_embed.weight.shape != init_params[1].shape:
77
+ raise ValueError(
78
+ f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
79
+ f" {init_params[1].shape}"
80
+ )
81
+
82
+ if model.positions_embed.weight.shape != init_params[0].shape:
83
+ raise ValueError(
84
+ f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
85
+ f" {init_params[0].shape}"
86
+ )
87
+
88
+ model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
89
+ model.positions_embed.weight.data = torch.from_numpy(init_params[0])
90
+ names.pop(0)
91
+ # Pop position and token embedding arrays
92
+ init_params.pop(0)
93
+ init_params.pop(0)
94
+
95
+ for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]
96
+ name = name[6:] # skip "model/"
97
+ if name[-2:] != ":0":
98
+ raise ValueError(f"Layer {name} does not end with :0")
99
+ name = name[:-2]
100
+ name = name.split("/")
101
+ pointer = model
102
+ for m_name in name:
103
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
104
+ scope_names = re.split(r"(\d+)", m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] == "g":
108
+ pointer = getattr(pointer, "weight")
109
+ elif scope_names[0] == "b":
110
+ pointer = getattr(pointer, "bias")
111
+ elif scope_names[0] == "w":
112
+ pointer = getattr(pointer, "weight")
113
+ else:
114
+ pointer = getattr(pointer, scope_names[0])
115
+ if len(scope_names) >= 2:
116
+ num = int(scope_names[1])
117
+ pointer = pointer[num]
118
+
119
+ # Ensure that the pointer and array have compatible shapes.
120
+ if pointer.shape != array.shape:
121
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
122
+
123
+ logger.info(f"Initialize PyTorch weight {name}")
124
+ pointer.data = torch.from_numpy(array)
125
+ return model
126
+
127
+
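The name-walking loop above relies on a small regex idiom to split scoped TF variable names such as `h0` into an attribute name and a layer index. A self-contained sketch of just that parsing step:

```python
import re

name = "model/h0/attn/c_attn/w:0"  # a typical variable name from the original checkpoint
name = name[6:]   # strip the "model/" prefix -> "h0/attn/c_attn/w:0"
name = name[:-2]  # strip the ":0" suffix     -> "h0/attn/c_attn/w"
for m_name in name.split("/"):
    if re.fullmatch(r"[A-Za-z]+\d+", m_name):
        print(re.split(r"(\d+)", m_name))  # ['h', '0', ''] -> attribute "h", index 0
    else:
        print([m_name])  # plain attribute, e.g. ['attn'] or ['w']
```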
128
+ ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu}
129
+
130
+
131
+ class Attention(nn.Module):
132
+ def __init__(self, nx, n_positions, config, scale=False):
133
+ super().__init__()
134
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
135
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
136
+ if n_state % config.n_head != 0:
137
+ raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
138
+ self.register_buffer(
139
+ "bias",
140
+ torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions),
141
+ persistent=False,
142
+ )
143
+ self.n_head = config.n_head
144
+ self.split_size = n_state
145
+ self.scale = scale
146
+
147
+ self.c_attn = Conv1D(n_state * 3, nx)
148
+ self.c_proj = Conv1D(n_state, nx)
149
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
150
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
151
+ self.pruned_heads = set()
152
+
153
+ def prune_heads(self, heads):
154
+ if len(heads) == 0:
155
+ return
156
+ heads, index = find_pruneable_heads_and_indices(
157
+ heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
158
+ )
159
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
160
+ # Prune conv1d layers
161
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
162
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
163
+ # Update hyper params
164
+ self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
165
+ self.n_head = self.n_head - len(heads)
166
+ self.pruned_heads = self.pruned_heads.union(heads)
167
+
168
+ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
169
+ w = torch.matmul(q, k)
170
+ if self.scale:
171
+ w = w / math.sqrt(v.size(-1))
172
+ # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights
173
+ # XD: self.b may be larger than w, so we need to crop it
174
+ b = self.bias[:, :, : w.size(-2), : w.size(-1)]
175
+ w = w * b + -1e4 * (1 - b)
176
+
177
+ if attention_mask is not None:
178
+ # Apply the attention mask
179
+ w = w + attention_mask
180
+
181
+ w = nn.functional.softmax(w, dim=-1)
182
+ w = self.attn_dropout(w)
183
+
184
+ # Mask heads if we want to
185
+ if head_mask is not None:
186
+ w = w * head_mask
187
+
188
+ outputs = [torch.matmul(w, v)]
189
+ if output_attentions:
190
+ outputs.append(w)
191
+ return outputs
192
+
193
+ def merge_heads(self, x):
194
+ x = x.permute(0, 2, 1, 3).contiguous()
195
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
196
+ return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states
197
+
198
+ def split_heads(self, x, k=False):
199
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
200
+ x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states
201
+ if k:
202
+ return x.permute(0, 2, 3, 1)
203
+ else:
204
+ return x.permute(0, 2, 1, 3)
205
+
206
+ def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
207
+ x = self.c_attn(x)
208
+ query, key, value = x.split(self.split_size, dim=2)
209
+ query = self.split_heads(query)
210
+ key = self.split_heads(key, k=True)
211
+ value = self.split_heads(value)
212
+
213
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
214
+ a = attn_outputs[0]
215
+
216
+ a = self.merge_heads(a)
217
+ a = self.c_proj(a)
218
+ a = self.resid_dropout(a)
219
+
220
+ outputs = [a] + attn_outputs[1:]
221
+ return outputs # a, (attentions)
222
+
223
+
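The `_attn` method above enforces causality by cropping the precomputed `tril` buffer to the current sequence length and pushing masked logits to a large negative value; a minimal standalone sketch of that masking step:

```python
import torch

n_positions, seq_len = 8, 4
bias = torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)

w = torch.zeros(1, 1, seq_len, seq_len)  # dummy attention logits
b = bias[:, :, :seq_len, :seq_len]       # crop the buffer to the actual length
w = w * b + -1e4 * (1 - b)               # future positions get -1e4
print(torch.softmax(w, dim=-1)[0, 0])    # upper triangle is ~0 after the softmax
```

The `-1e4` constant (rather than `-inf`) stays representable in fp16.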
224
+ class MLP(nn.Module):
225
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
226
+ super().__init__()
227
+ nx = config.n_embd
228
+ self.c_fc = Conv1D(n_state, nx)
229
+ self.c_proj = Conv1D(nx, n_state)
230
+ self.act = ACT_FNS[config.afn]
231
+ self.dropout = nn.Dropout(config.resid_pdrop)
232
+
233
+ def forward(self, x):
234
+ h = self.act(self.c_fc(x))
235
+ h2 = self.c_proj(h)
236
+ return self.dropout(h2)
237
+
238
+
239
+ class Block(nn.Module):
240
+ def __init__(self, n_positions, config, scale=False):
241
+ super().__init__()
242
+ nx = config.n_embd
243
+ self.attn = Attention(nx, n_positions, config, scale)
244
+ self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
245
+ self.mlp = MLP(4 * nx, config)
246
+ self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
247
+
248
+ def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
249
+ attn_outputs = self.attn(
250
+ x,
251
+ attention_mask=attention_mask,
252
+ head_mask=head_mask,
253
+ output_attentions=output_attentions,
254
+ )
255
+ a = attn_outputs[0]
256
+
257
+ n = self.ln_1(x + a)
258
+ m = self.mlp(n)
259
+ h = self.ln_2(n + m)
260
+
261
+ outputs = [h] + attn_outputs[1:]
262
+ return outputs
263
+
264
+
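Note that the block uses post-layer-norm residuals (each `LayerNorm` is applied after the residual sum, matching the original GPT rather than GPT-2's pre-norm layout). A quick shape check of a single block, using the internal `Block` class (not part of the public API):

```python
import torch
from transformers import OpenAIGPTConfig
from transformers.models.openai.modeling_openai import Block  # internal class

config = OpenAIGPTConfig()
block = Block(config.n_positions, config, scale=True)
x = torch.randn(1, 6, config.n_embd)
print(block(x)[0].shape)  # torch.Size([1, 6, 768])
```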
265
+ # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT
266
+ class OpenAIGPTSequenceSummary(nn.Module):
267
+ r"""
268
+ Compute a single vector summary of a sequence hidden states.
269
+
270
+ Args:
271
+ config ([`OpenAIGPTConfig`]):
272
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
273
+ config class of your model for the default values it uses):
274
+
275
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
276
+
277
+ - `"last"` -- Take the last token hidden state (like XLNet)
278
+ - `"first"` -- Take the first token hidden state (like Bert)
279
+ - `"mean"` -- Take the mean of all tokens hidden states
280
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
281
+ - `"attn"` -- Not implemented now, use multi-head attention
282
+
283
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
284
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
285
+ (otherwise to `config.hidden_size`).
286
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
287
+ another string or `None` will add no activation.
288
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
289
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
290
+ """
291
+
292
+ def __init__(self, config: OpenAIGPTConfig):
293
+ super().__init__()
294
+
295
+ self.summary_type = getattr(config, "summary_type", "last")
296
+ if self.summary_type == "attn":
297
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
298
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
299
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
300
+ raise NotImplementedError
301
+
302
+ self.summary = nn.Identity()
303
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
304
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
305
+ num_classes = config.num_labels
306
+ else:
307
+ num_classes = config.hidden_size
308
+ self.summary = nn.Linear(config.hidden_size, num_classes)
309
+
310
+ activation_string = getattr(config, "summary_activation", None)
311
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
312
+
313
+ self.first_dropout = nn.Identity()
314
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
315
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
316
+
317
+ self.last_dropout = nn.Identity()
318
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
319
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
320
+
321
+ def forward(
322
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
323
+ ) -> torch.FloatTensor:
324
+ """
325
+ Compute a single vector summary of a sequence hidden states.
326
+
327
+ Args:
328
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
329
+ The hidden states of the last layer.
330
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
331
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
332
+
333
+ Returns:
334
+ `torch.FloatTensor`: The summary of the sequence hidden states.
335
+ """
336
+ if self.summary_type == "last":
337
+ output = hidden_states[:, -1]
338
+ elif self.summary_type == "first":
339
+ output = hidden_states[:, 0]
340
+ elif self.summary_type == "mean":
341
+ output = hidden_states.mean(dim=1)
342
+ elif self.summary_type == "cls_index":
343
+ if cls_index is None:
344
+ cls_index = torch.full_like(
345
+ hidden_states[..., :1, :],
346
+ hidden_states.shape[-2] - 1,
347
+ dtype=torch.long,
348
+ )
349
+ else:
350
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
351
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
352
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
353
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
354
+ elif self.summary_type == "attn":
355
+ raise NotImplementedError
356
+
357
+ output = self.first_dropout(output)
358
+ output = self.summary(output)
359
+ output = self.activation(output)
360
+ output = self.last_dropout(output)
361
+
362
+ return output
363
+
364
+
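For the `"cls_index"` path, the summary reduces to a batched `gather` over the sequence dimension; a minimal sketch of that indexing:

```python
import torch

batch, seq_len, hidden = 2, 5, 4
hidden_states = torch.randn(batch, seq_len, hidden)
cls_index = torch.tensor([4, 2])  # position of the classification token per sequence

idx = cls_index.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, hidden)  # (batch, 1, hidden)
summary = hidden_states.gather(-2, idx).squeeze(-2)                 # (batch, hidden)
assert torch.equal(summary[0], hidden_states[0, 4])
assert torch.equal(summary[1], hidden_states[1, 2])
```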
365
+ class OpenAIGPTPreTrainedModel(PreTrainedModel):
366
+ """
367
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
368
+ models.
369
+ """
370
+
371
+ config_class = OpenAIGPTConfig
372
+ load_tf_weights = load_tf_weights_in_openai_gpt
373
+ base_model_prefix = "transformer"
374
+
375
+ def _init_weights(self, module):
376
+ """Initialize the weights."""
377
+ if isinstance(module, (nn.Linear, Conv1D)):
378
+ # Slightly different from the TF version which uses truncated_normal for initialization
379
+ # cf https://github.com/pytorch/pytorch/pull/5617
380
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
381
+ if module.bias is not None:
382
+ module.bias.data.zero_()
383
+ elif isinstance(module, nn.Embedding):
384
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
385
+ if module.padding_idx is not None:
386
+ module.weight.data[module.padding_idx].zero_()
387
+ elif isinstance(module, nn.LayerNorm):
388
+ module.bias.data.zero_()
389
+ module.weight.data.fill_(1.0)
390
+
391
+
392
+ @dataclass
393
+ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
394
+ """
395
+ Base class for outputs of models predicting if two sentences are consecutive or not.
396
+
397
+ Args:
398
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
399
+ Language modeling loss.
400
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
401
+ Multiple choice classification loss.
402
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
403
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
404
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
405
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
406
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
407
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
408
+ shape `(batch_size, sequence_length, hidden_size)`.
409
+
410
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
411
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
412
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
413
+ sequence_length)`.
414
+
415
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
416
+ heads.
417
+ """
418
+
419
+ loss: Optional[torch.FloatTensor] = None
420
+ mc_loss: Optional[torch.FloatTensor] = None
421
+ logits: Optional[torch.FloatTensor] = None
422
+ mc_logits: Optional[torch.FloatTensor] = None
423
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
424
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
425
+
426
+
427
+ OPENAI_GPT_START_DOCSTRING = r"""
428
+
429
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
430
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
431
+ etc.)
432
+
433
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
434
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
435
+ and behavior.
436
+
437
+ Parameters:
438
+ config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model.
439
+ Initializing with a config file does not load the weights associated with the model, only the
440
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
441
+ """
442
+
443
+ OPENAI_GPT_INPUTS_DOCSTRING = r"""
444
+ Args:
445
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
446
+ Indices of input sequence tokens in the vocabulary.
447
+
448
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
449
+ [`PreTrainedTokenizer.__call__`] for details.
450
+
451
+ [What are input IDs?](../glossary#input-ids)
452
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
453
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
454
+
455
+ - 1 for tokens that are **not masked**,
456
+ - 0 for tokens that are **masked**.
457
+
458
+ [What are attention masks?](../glossary#attention-mask)
459
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
460
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
461
+ 1]`:
462
+
463
+ - 0 corresponds to a *sentence A* token,
464
+ - 1 corresponds to a *sentence B* token.
465
+
466
+ [What are token type IDs?](../glossary#token-type-ids)
467
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
468
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
469
+ config.max_position_embeddings - 1]`.
470
+
471
+ [What are position IDs?](../glossary#position-ids)
472
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
473
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
474
+
475
+ - 1 indicates the head is **not masked**,
476
+ - 0 indicates the head is **masked**.
477
+
478
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
479
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
480
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
481
+ model's internal embedding lookup matrix.
482
+ output_attentions (`bool`, *optional*):
483
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
484
+ tensors for more detail.
485
+ output_hidden_states (`bool`, *optional*):
486
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
487
+ more detail.
488
+ return_dict (`bool`, *optional*):
489
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
490
+ """
491
+
492
+
493
+ @add_start_docstrings(
494
+ "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
495
+ OPENAI_GPT_START_DOCSTRING,
496
+ )
497
+ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
498
+ def __init__(self, config):
499
+ super().__init__(config)
500
+
501
+ self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
502
+ self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
503
+ self.drop = nn.Dropout(config.embd_pdrop)
504
+ self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
505
+
506
+ self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
507
+ # Initialize weights and apply final processing
508
+ self.post_init()
509
+
510
+ def get_input_embeddings(self):
511
+ return self.tokens_embed
512
+
513
+ def set_input_embeddings(self, new_embeddings):
514
+ self.tokens_embed = new_embeddings
515
+
516
+ def _prune_heads(self, heads_to_prune):
517
+ """
518
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
519
+ """
520
+ for layer, heads in heads_to_prune.items():
521
+ self.h[layer].attn.prune_heads(heads)
522
+
523
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
524
+ @add_code_sample_docstrings(
525
+ checkpoint=_CHECKPOINT_FOR_DOC,
526
+ output_type=BaseModelOutput,
527
+ config_class=_CONFIG_FOR_DOC,
528
+ )
529
+ def forward(
530
+ self,
531
+ input_ids: Optional[torch.LongTensor] = None,
532
+ attention_mask: Optional[torch.FloatTensor] = None,
533
+ token_type_ids: Optional[torch.LongTensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ head_mask: Optional[torch.FloatTensor] = None,
536
+ inputs_embeds: Optional[torch.FloatTensor] = None,
537
+ output_attentions: Optional[bool] = None,
538
+ output_hidden_states: Optional[bool] = None,
539
+ return_dict: Optional[bool] = None,
540
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
541
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
542
+ output_hidden_states = (
543
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
544
+ )
545
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
546
+
547
+ if input_ids is not None and inputs_embeds is not None:
548
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
549
+ elif input_ids is not None:
550
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
551
+ input_shape = input_ids.size()
552
+ input_ids = input_ids.view(-1, input_shape[-1])
553
+ elif inputs_embeds is not None:
554
+ input_shape = inputs_embeds.size()[:-1]
555
+ else:
556
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
557
+
558
+ if position_ids is None:
559
+ # Code is different from when we had a single embedding matrix for position and token embeddings
560
+ position_ids = self.position_ids[None, : input_shape[-1]]
561
+
562
+ # Attention mask.
563
+ if attention_mask is not None:
564
+ # We create a 3D attention mask from a 2D tensor mask.
565
+ # Sizes are [batch_size, 1, 1, to_seq_length]
566
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
567
+ # this attention mask is simpler than the triangular masking of causal attention
568
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
569
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
570
+
571
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
572
+ # masked positions, this operation will create a tensor which is 0.0 for
573
+ # positions we want to attend and the dtype's smallest value for masked positions.
574
+ # Since we are adding it to the raw scores before the softmax, this is
575
+ # effectively the same as removing these entirely.
576
+ attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
577
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
578
+
579
+ # Prepare head mask if needed
580
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
581
+
582
+ if inputs_embeds is None:
583
+ inputs_embeds = self.tokens_embed(input_ids)
584
+ position_embeds = self.positions_embed(position_ids)
585
+ if token_type_ids is not None:
586
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
587
+ token_type_embeds = self.tokens_embed(token_type_ids)
588
+ else:
589
+ token_type_embeds = 0
590
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
591
+ hidden_states = self.drop(hidden_states)
592
+
593
+ output_shape = input_shape + (hidden_states.size(-1),)
594
+
595
+ all_attentions = () if output_attentions else None
596
+ all_hidden_states = () if output_hidden_states else None
597
+ for i, block in enumerate(self.h):
598
+ if output_hidden_states:
599
+ all_hidden_states = all_hidden_states + (hidden_states,)
600
+
601
+ outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
602
+ hidden_states = outputs[0]
603
+ if output_attentions:
604
+ all_attentions = all_attentions + (outputs[1],)
605
+
606
+ hidden_states = hidden_states.view(*output_shape)
607
+ # Add last layer
608
+ if output_hidden_states:
609
+ all_hidden_states = all_hidden_states + (hidden_states,)
610
+
611
+ if not return_dict:
612
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
613
+
614
+ return BaseModelOutput(
615
+ last_hidden_state=hidden_states,
616
+ hidden_states=all_hidden_states,
617
+ attentions=all_attentions,
618
+ )
619
+
620
+
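The 2D-to-4D mask conversion in the forward above turns a padding mask into an additive bias that broadcasts over heads and query positions; a standalone sketch:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])    # 1 = attend, 0 = padding
mask = attention_mask[:, None, None, :].float()  # (batch, 1, 1, to_seq_length)
additive = (1.0 - mask) * torch.finfo(torch.float32).min
print(additive)  # 0.0 where attended, dtype minimum where masked
```

Adding `additive` to the raw scores before the softmax is effectively the same as removing the masked positions entirely.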
621
+ @add_start_docstrings(
622
+ """
623
+ OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
624
+ embeddings).
625
+ """,
626
+ OPENAI_GPT_START_DOCSTRING,
627
+ )
628
+ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin):
629
+ _tied_weights_keys = ["lm_head.weight"]
630
+
631
+ def __init__(self, config):
632
+ super().__init__(config)
633
+ self.transformer = OpenAIGPTModel(config)
634
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
635
+
636
+ # Initialize weights and apply final processing
637
+ self.post_init()
638
+
639
+ def get_output_embeddings(self):
640
+ return self.lm_head
641
+
642
+ def set_output_embeddings(self, new_embeddings):
643
+ self.lm_head = new_embeddings
644
+
645
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
646
+ @add_code_sample_docstrings(
647
+ checkpoint=_CHECKPOINT_FOR_DOC,
648
+ output_type=CausalLMOutput,
649
+ config_class=_CONFIG_FOR_DOC,
650
+ )
651
+ def forward(
652
+ self,
653
+ input_ids: Optional[torch.LongTensor] = None,
654
+ attention_mask: Optional[torch.FloatTensor] = None,
655
+ token_type_ids: Optional[torch.LongTensor] = None,
656
+ position_ids: Optional[torch.LongTensor] = None,
657
+ head_mask: Optional[torch.FloatTensor] = None,
658
+ inputs_embeds: Optional[torch.FloatTensor] = None,
659
+ labels: Optional[torch.LongTensor] = None,
660
+ output_attentions: Optional[bool] = None,
661
+ output_hidden_states: Optional[bool] = None,
662
+ return_dict: Optional[bool] = None,
663
+ **kwargs,
664
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutput]:
665
+ r"""
666
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
667
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
668
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
669
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
670
+ """
671
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
672
+
673
+ transformer_outputs = self.transformer(
674
+ input_ids,
675
+ attention_mask=attention_mask,
676
+ token_type_ids=token_type_ids,
677
+ position_ids=position_ids,
678
+ head_mask=head_mask,
679
+ inputs_embeds=inputs_embeds,
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ )
684
+ hidden_states = transformer_outputs[0]
685
+ lm_logits = self.lm_head(hidden_states)
686
+
687
+ loss = None
688
+ if labels is not None:
689
+ # Flatten the tokens
690
+ loss = self.loss_function(
691
+ lm_logits,
692
+ labels,
693
+ vocab_size=self.config.vocab_size,
694
+ **kwargs,
695
+ )
696
+
697
+ if not return_dict:
698
+ output = (lm_logits,) + transformer_outputs[1:]
699
+ return ((loss,) + output) if loss is not None else output
700
+
701
+ return CausalLMOutput(
702
+ loss=loss,
703
+ logits=lm_logits,
704
+ hidden_states=transformer_outputs.hidden_states,
705
+ attentions=transformer_outputs.attentions,
706
+ )
707
+
708
+ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
709
+ # Overwritten -- old model with reduced inputs
710
+ return {"input_ids": input_ids}
711
+
712
+
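A minimal greedy-decoding sketch for the LM head (downloads the pretrained checkpoint on first use; OpenAI GPT defines no EOS token, so generation runs until `max_length`):

```python
from transformers import AutoTokenizer, OpenAIGPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
model = OpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt")

inputs = tokenizer("the weather today is", return_tensors="pt")
output_ids = model.generate(**inputs, max_length=20, do_sample=False)
print(tokenizer.decode(output_ids[0]))
```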
713
+ @add_start_docstrings(
714
+ """
715
+ OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
716
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
717
+ input embeddings; the classification head takes as input the hidden state at a specified classification token index in
718
+ the input sequence.
719
+ """,
720
+ OPENAI_GPT_START_DOCSTRING,
721
+ )
722
+ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
723
+ _tied_weights_keys = ["lm_head.weight"]
724
+
725
+ def __init__(self, config):
726
+ super().__init__(config)
727
+
728
+ config.num_labels = 1
729
+ self.transformer = OpenAIGPTModel(config)
730
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
731
+ self.multiple_choice_head = OpenAIGPTSequenceSummary(config)
732
+
733
+ # Initialize weights and apply final processing
734
+ self.post_init()
735
+
736
+ def get_output_embeddings(self):
737
+ return self.lm_head
738
+
739
+ def set_output_embeddings(self, new_embeddings):
740
+ self.lm_head = new_embeddings
741
+
742
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
743
+ @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
744
+ def forward(
745
+ self,
746
+ input_ids: Optional[torch.LongTensor] = None,
747
+ attention_mask: Optional[torch.FloatTensor] = None,
748
+ token_type_ids: Optional[torch.LongTensor] = None,
749
+ position_ids: Optional[torch.LongTensor] = None,
750
+ head_mask: Optional[torch.FloatTensor] = None,
751
+ inputs_embeds: Optional[torch.FloatTensor] = None,
752
+ mc_token_ids: Optional[torch.LongTensor] = None,
753
+ labels: Optional[torch.LongTensor] = None,
754
+ mc_labels: Optional[torch.LongTensor] = None,
755
+ output_attentions: Optional[bool] = None,
756
+ output_hidden_states: Optional[bool] = None,
757
+ return_dict: Optional[bool] = None,
758
+ ) -> Union[Tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]:
759
+ r"""
760
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
761
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
762
+ 1]`.
763
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
764
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
765
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are
766
+ ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
767
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
768
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
769
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
770
+
771
+ Return:
772
+
773
+ Examples:
774
+
775
+ ```python
776
+ >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel
777
+ >>> import torch
778
+
779
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
780
+ >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
781
+ >>> tokenizer.add_special_tokens(
782
+ ... {"cls_token": "[CLS]"}
783
+ ... ) # Add a [CLS] to the vocabulary (we should train it also!)
784
+ >>> model.resize_token_embeddings(len(tokenizer))
785
+
786
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
787
+ >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
788
+ >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1
789
+
790
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
791
+ >>> lm_logits = outputs.logits
792
+ >>> mc_logits = outputs.mc_logits
793
+ ```"""
794
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
795
+
796
+ transformer_outputs = self.transformer(
797
+ input_ids,
798
+ attention_mask=attention_mask,
799
+ token_type_ids=token_type_ids,
800
+ position_ids=position_ids,
801
+ head_mask=head_mask,
802
+ inputs_embeds=inputs_embeds,
803
+ output_attentions=output_attentions,
804
+ output_hidden_states=output_hidden_states,
805
+ return_dict=return_dict,
806
+ )
807
+ hidden_states = transformer_outputs[0]
808
+
809
+ lm_logits = self.lm_head(hidden_states)
810
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
811
+
812
+ lm_loss, mc_loss = None, None
813
+ if mc_labels is not None:
814
+ loss_fct = CrossEntropyLoss()
815
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
816
+ if labels is not None:
817
+ shift_logits = lm_logits[..., :-1, :].contiguous()
818
+ shift_labels = labels[..., 1:].contiguous()
819
+ loss_fct = CrossEntropyLoss()
820
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
821
+
822
+ if not return_dict:
823
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
824
+ if mc_loss is not None:
825
+ output = (mc_loss,) + output
826
+ return ((lm_loss,) + output) if lm_loss is not None else output
827
+
828
+ return OpenAIGPTDoubleHeadsModelOutput(
829
+ loss=lm_loss,
830
+ mc_loss=mc_loss,
831
+ logits=lm_logits,
832
+ mc_logits=mc_logits,
833
+ hidden_states=transformer_outputs.hidden_states,
834
+ attentions=transformer_outputs.attentions,
835
+ )
836
+
837
+
838
+ @add_start_docstrings(
839
+ """
840
+ The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
841
+ [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
842
+ models (e.g. GPT-2) do. Since it does classification on the last token, it needs to know the position of the
843
+ last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
844
+ token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
845
+ it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
846
+ the last value in each row of the batch).
847
+ """,
848
+ OPENAI_GPT_START_DOCSTRING,
849
+ )
850
+ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
851
+ def __init__(self, config):
852
+ super().__init__(config)
853
+ self.num_labels = config.num_labels
854
+ self.transformer = OpenAIGPTModel(config)
855
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
856
+
857
+ # Initialize weights and apply final processing
858
+ self.post_init()
859
+
860
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
861
+ @add_code_sample_docstrings(
862
+ checkpoint=_CHECKPOINT_FOR_DOC,
863
+ output_type=SequenceClassifierOutput,
864
+ config_class=_CONFIG_FOR_DOC,
865
+ )
866
+ def forward(
867
+ self,
868
+ input_ids: Optional[torch.LongTensor] = None,
869
+ attention_mask: Optional[torch.FloatTensor] = None,
870
+ token_type_ids: Optional[torch.LongTensor] = None,
871
+ position_ids: Optional[torch.LongTensor] = None,
872
+ head_mask: Optional[torch.FloatTensor] = None,
873
+ inputs_embeds: Optional[torch.FloatTensor] = None,
874
+ labels: Optional[torch.LongTensor] = None,
875
+ output_attentions: Optional[bool] = None,
876
+ output_hidden_states: Optional[bool] = None,
877
+ return_dict: Optional[bool] = None,
878
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
879
+ r"""
880
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
881
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
882
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
883
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
884
+ """
885
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
886
+
887
+ transformer_outputs = self.transformer(
888
+ input_ids,
889
+ attention_mask=attention_mask,
890
+ token_type_ids=token_type_ids,
891
+ position_ids=position_ids,
892
+ head_mask=head_mask,
893
+ inputs_embeds=inputs_embeds,
894
+ output_attentions=output_attentions,
895
+ output_hidden_states=output_hidden_states,
896
+ return_dict=return_dict,
897
+ )
898
+
899
+ hidden_states = transformer_outputs[0]
900
+ logits = self.score(hidden_states)
901
+
902
+ if input_ids is not None:
903
+ batch_size, sequence_length = input_ids.shape[:2]
904
+ else:
905
+ batch_size, sequence_length = inputs_embeds.shape[:2]
906
+
907
+ # Without a padding token we can only handle a batch size of 1.
908
+ if self.config.pad_token_id is None and batch_size != 1:
909
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
910
+ if self.config.pad_token_id is None:
911
+ last_non_pad_token = -1
912
+ elif input_ids is not None:
913
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
914
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
915
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
916
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
917
+ else:
918
+ last_non_pad_token = -1
919
+ logger.warning_once(
920
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
921
+ "unexpected if using padding tokens in conjunction with `inputs_embeds`."
922
+ )
923
+
924
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
925
+
926
+ loss = None
927
+ if labels is not None:
928
+ if self.config.problem_type is None:
929
+ if self.num_labels == 1:
930
+ self.config.problem_type = "regression"
931
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
932
+ self.config.problem_type = "single_label_classification"
933
+ else:
934
+ self.config.problem_type = "multi_label_classification"
935
+
936
+ if self.config.problem_type == "regression":
937
+ loss_fct = MSELoss()
938
+ if self.num_labels == 1:
939
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
940
+ else:
941
+ loss = loss_fct(pooled_logits, labels)
942
+ elif self.config.problem_type == "single_label_classification":
943
+ loss_fct = CrossEntropyLoss()
944
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
945
+ elif self.config.problem_type == "multi_label_classification":
946
+ loss_fct = BCEWithLogitsLoss()
947
+ loss = loss_fct(pooled_logits, labels)
948
+ if not return_dict:
949
+ output = (pooled_logits,) + transformer_outputs[1:]
950
+ return ((loss,) + output) if loss is not None else output
951
+
952
+ return SequenceClassifierOutput(
953
+ loss=loss,
954
+ logits=pooled_logits,
955
+ hidden_states=transformer_outputs.hidden_states,
956
+ attentions=transformer_outputs.attentions,
957
+ )
958
+
959
+
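The rightmost-non-pad-token trick above works for both left- and right-padded batches, because multiplying positions by the mask zeroes out pad slots before the `argmax`; a standalone sketch:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],    # right-padded
                          [0, 0, 8, 9, 10]])  # left-padded
non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])
```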
960
+ __all__ = [
961
+ "OpenAIGPTDoubleHeadsModel",
962
+ "OpenAIGPTForSequenceClassification",
963
+ "OpenAIGPTLMHeadModel",
964
+ "OpenAIGPTModel",
965
+ "OpenAIGPTPreTrainedModel",
966
+ "load_tf_weights_in_openai_gpt",
967
+ ]
docs/transformers/src/transformers/models/openai/modeling_tf_openai.py ADDED
@@ -0,0 +1,937 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """TF 2.0 OpenAI GPT model."""
17
+
18
+ from __future__ import annotations
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ...activations_tf import get_tf_activation
27
+ from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
28
+ from ...modeling_tf_utils import (
29
+ TFCausalLanguageModelingLoss,
30
+ TFConv1D,
31
+ TFModelInputType,
32
+ TFPreTrainedModel,
33
+ TFSequenceClassificationLoss,
34
+ TFSequenceSummary,
35
+ TFSharedEmbeddings,
36
+ get_initializer,
37
+ keras,
38
+ keras_serializable,
39
+ unpack_inputs,
40
+ )
41
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
42
+ from ...utils import (
43
+ ModelOutput,
44
+ add_code_sample_docstrings,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ logging,
48
+ replace_return_docstrings,
49
+ )
50
+ from .configuration_openai import OpenAIGPTConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ _CHECKPOINT_FOR_DOC = "openai-community/openai-gpt"
56
+ _CONFIG_FOR_DOC = "OpenAIGPTConfig"
57
+
58
+
59
+ class TFAttention(keras.layers.Layer):
60
+ def __init__(self, nx, config, scale=False, **kwargs):
61
+ super().__init__(**kwargs)
62
+
63
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
64
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
65
+ assert n_state % config.n_head == 0, (
66
+ f"Hidden dimension {n_state} not divisible by number of heads {config.n_head}"
67
+ )
68
+ self.n_head = config.n_head
69
+ self.split_size = n_state
70
+ self.scale = scale
71
+ self.output_attentions = config.output_attentions
72
+
73
+ self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
74
+ self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
75
+ self.attn_dropout = keras.layers.Dropout(config.attn_pdrop)
76
+ self.resid_dropout = keras.layers.Dropout(config.resid_pdrop)
77
+ self.n_state = n_state
78
+ self.pruned_heads = set()
79
+
80
+ def prune_heads(self, heads):
81
+ pass
82
+
83
+ @staticmethod
84
+ def causal_attention_mask(nd, ns):
85
+ """
86
+ 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
87
+ -1, ns-nd), but doesn't produce garbage on TPUs.
88
+ """
89
+ i = tf.range(nd)[:, None]
90
+ j = tf.range(ns)
91
+ m = i >= j - ns + nd
92
+ return m
93
+
94
+ def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
95
+ # q, k, v have shape [batch, heads, sequence, features]
96
+ w = tf.matmul(q, k, transpose_b=True)
97
+ if self.scale:
98
+ dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
99
+ w = w / tf.math.sqrt(dk)
100
+
101
+ # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
102
+ _, _, nd, ns = shape_list(w)
103
+ b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
104
+ b = tf.reshape(b, [1, 1, nd, ns])
105
+ w = w * b - 1e4 * (1 - b)
106
+
107
+ if attention_mask is not None:
108
+ # Apply the attention mask
109
+ attention_mask = tf.cast(attention_mask, dtype=w.dtype)
110
+ w = w + attention_mask
111
+
112
+ w = stable_softmax(w, axis=-1)
113
+ w = self.attn_dropout(w, training=training)
114
+
115
+ # Mask heads if we want to
116
+ if head_mask is not None:
117
+ w = w * head_mask
118
+
119
+ outputs = [tf.matmul(w, v)]
120
+ if output_attentions:
121
+ outputs.append(w)
122
+ return outputs
123
+
124
+ def merge_heads(self, x):
125
+ x = tf.transpose(x, [0, 2, 1, 3])
126
+ x_shape = shape_list(x)
127
+ new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
128
+ return tf.reshape(x, new_x_shape)
129
+
130
+ def split_heads(self, x):
131
+ x_shape = shape_list(x)
132
+ new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
133
+ x = tf.reshape(x, new_x_shape)
134
+ return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
135
+
136
+ def call(self, x, attention_mask, head_mask, output_attentions, training=False):
137
+ x = self.c_attn(x)
138
+ query, key, value = tf.split(x, 3, axis=2)
139
+ query = self.split_heads(query)
140
+ key = self.split_heads(key)
141
+ value = self.split_heads(value)
142
+
143
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
144
+ a = attn_outputs[0]
145
+
146
+ a = self.merge_heads(a)
147
+ a = self.c_proj(a)
148
+ a = self.resid_dropout(a, training=training)
149
+
150
+ outputs = [a] + attn_outputs[1:]
151
+ return outputs # a, (attentions)
152
+
153
+ def build(self, input_shape=None):
154
+ if self.built:
155
+ return
156
+ self.built = True
157
+ if getattr(self, "c_attn", None) is not None:
158
+ with tf.name_scope(self.c_attn.name):
159
+ self.c_attn.build([None, None, self.n_state * 3])
160
+ if getattr(self, "c_proj", None) is not None:
161
+ with tf.name_scope(self.c_proj.name):
162
+ self.c_proj.build([None, None, self.n_state])
163
+
164
+
165
+ class TFMLP(keras.layers.Layer):
166
+ def __init__(self, n_state, config, **kwargs):
167
+ super().__init__(**kwargs)
168
+ nx = config.n_embd
169
+ self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
170
+ self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
171
+ self.act = get_tf_activation("gelu")
172
+ self.dropout = keras.layers.Dropout(config.resid_pdrop)
173
+ self.nx = nx
174
+ self.n_state = n_state
175
+
176
+ def call(self, x, training=False):
177
+ h = self.act(self.c_fc(x))
178
+ h2 = self.c_proj(h)
179
+ h2 = self.dropout(h2, training=training)
180
+ return h2
181
+
182
+ def build(self, input_shape=None):
183
+ if self.built:
184
+ return
185
+ self.built = True
186
+ if getattr(self, "c_fc", None) is not None:
187
+ with tf.name_scope(self.c_fc.name):
188
+ self.c_fc.build([None, None, self.n_state])
189
+ if getattr(self, "c_proj", None) is not None:
190
+ with tf.name_scope(self.c_proj.name):
191
+ self.c_proj.build([None, None, self.nx])
192
+
193
+
194
+ class TFBlock(keras.layers.Layer):
195
+ def __init__(self, config, scale=False, **kwargs):
196
+ super().__init__(**kwargs)
197
+ nx = config.n_embd
198
+ self.attn = TFAttention(nx, config, scale, name="attn")
199
+ self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
200
+ self.mlp = TFMLP(4 * nx, config, name="mlp")
201
+ self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
202
+ self.nx = nx
203
+
204
+ def call(self, x, attention_mask, head_mask, output_attentions, training=False):
205
+ output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
206
+ a = output_attn[0] # output_attn: a, (attentions)
207
+
208
+ n = self.ln_1(x + a)
209
+ m = self.mlp(n, training=training)
210
+ h = self.ln_2(n + m)
211
+
212
+ outputs = [h] + output_attn[1:]
213
+ return outputs # x, (attentions)
214
+
215
+ def build(self, input_shape=None):
216
+ if self.built:
217
+ return
218
+ self.built = True
219
+ if getattr(self, "attn", None) is not None:
220
+ with tf.name_scope(self.attn.name):
221
+ self.attn.build(None)
222
+ if getattr(self, "ln_1", None) is not None:
223
+ with tf.name_scope(self.ln_1.name):
224
+ self.ln_1.build([None, None, self.nx])
225
+ if getattr(self, "mlp", None) is not None:
226
+ with tf.name_scope(self.mlp.name):
227
+ self.mlp.build(None)
228
+ if getattr(self, "ln_2", None) is not None:
229
+ with tf.name_scope(self.ln_2.name):
230
+ self.ln_2.build([None, None, self.nx])
231
+
232
+
233
+ @keras_serializable
234
+ class TFOpenAIGPTMainLayer(keras.layers.Layer):
235
+ config_class = OpenAIGPTConfig
236
+
237
+ def __init__(self, config, *inputs, **kwargs):
238
+ super().__init__(*inputs, **kwargs)
239
+
240
+ self.config = config
241
+ self.output_hidden_states = config.output_hidden_states
242
+ self.output_attentions = config.output_attentions
243
+ self.return_dict = config.use_return_dict
244
+ self.num_hidden_layers = config.n_layer
245
+ self.n_embd = config.n_embd
246
+ self.n_positions = config.n_positions
247
+ self.initializer_range = config.initializer_range
248
+
249
+ self.tokens_embed = TFSharedEmbeddings(
250
+ config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
251
+ )
252
+ self.drop = keras.layers.Dropout(config.embd_pdrop)
253
+ self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
254
+
255
+ def build(self, input_shape=None):
256
+ with tf.name_scope("positions_embed"):
257
+ self.positions_embed = self.add_weight(
258
+ name="embeddings",
259
+ shape=[self.n_positions, self.n_embd],
260
+ initializer=get_initializer(self.initializer_range),
261
+ )
262
+
263
+ if self.built:
264
+ return
265
+ self.built = True
266
+ if getattr(self, "tokens_embed", None) is not None:
267
+ with tf.name_scope(self.tokens_embed.name):
268
+ self.tokens_embed.build(None)
269
+ if getattr(self, "h", None) is not None:
270
+ for layer in self.h:
271
+ with tf.name_scope(layer.name):
272
+ layer.build(None)
273
+
274
+ def get_input_embeddings(self):
275
+ return self.tokens_embed
276
+
277
+ def set_input_embeddings(self, value):
278
+ self.tokens_embed.weight = value
279
+ self.tokens_embed.vocab_size = shape_list(value)[0]
280
+
281
+ def _prune_heads(self, heads_to_prune):
282
+ """
283
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
284
+ """
285
+ raise NotImplementedError
286
+
287
+ @unpack_inputs
288
+ def call(
289
+ self,
290
+ input_ids: TFModelInputType | None = None,
291
+ attention_mask: np.ndarray | tf.Tensor | None = None,
292
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
293
+ position_ids: np.ndarray | tf.Tensor | None = None,
294
+ head_mask: np.ndarray | tf.Tensor | None = None,
295
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
296
+ output_attentions: Optional[bool] = None,
297
+ output_hidden_states: Optional[bool] = None,
298
+ return_dict: Optional[bool] = None,
299
+ training: Optional[bool] = False,
300
+ ) -> Union[Tuple, TFBaseModelOutput]:
301
+ if input_ids is not None and inputs_embeds is not None:
302
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
303
+ elif input_ids is not None:
304
+ input_shape = shape_list(input_ids)
305
+ input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
306
+ elif inputs_embeds is not None:
307
+ input_shape = shape_list(inputs_embeds)[:-1]
308
+ else:
309
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
310
+
311
+ if position_ids is None:
312
+ position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
313
+
314
+ if attention_mask is not None:
315
+ # We create a 3D attention mask from a 2D tensor mask.
316
+ # Sizes are [batch_size, 1, 1, to_seq_length]
317
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
318
+ # this attention mask is simpler than the triangular masking of causal attention
319
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
320
+ attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
321
+
322
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
323
+ # masked positions, this operation will create a tensor which is 0.0 for
324
+ # positions we want to attend and -10000.0 for masked positions.
325
+ # Since we are adding it to the raw scores before the softmax, this is
326
+ # effectively the same as removing these entirely.
327
+
328
+ one_cst = tf.constant(1.0)
329
+ attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
330
+ attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
331
+ else:
332
+ attention_mask = None
333
+
334
+ # Prepare head mask if needed
335
+ # 1.0 in head_mask indicates we keep the head
336
+ # attention_probs has shape bsz x n_heads x N x N
337
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
338
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
339
+ if head_mask is not None:
340
+ raise NotImplementedError
341
+ else:
342
+ head_mask = [None] * self.num_hidden_layers
343
+ # head_mask = tf.constant([0] * self.num_hidden_layers)
344
+
345
+ position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
346
+
347
+ if inputs_embeds is None:
348
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
349
+ inputs_embeds = self.tokens_embed(input_ids, mode="embedding")
350
+ position_embeds = tf.gather(self.positions_embed, position_ids)
351
+ if token_type_ids is not None:
352
+ token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
353
+ check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids")
354
+ token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
355
+ else:
356
+ token_type_embeds = 0
357
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
358
+ hidden_states = self.drop(hidden_states, training=training)
359
+
360
+ output_shape = input_shape + [shape_list(hidden_states)[-1]]
361
+
362
+ all_attentions = () if output_attentions else None
363
+ all_hidden_states = () if output_hidden_states else None
364
+ for i, block in enumerate(self.h):
365
+ if output_hidden_states:
366
+ all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
367
+
368
+ outputs = block(
369
+ hidden_states,
370
+ attention_mask,
371
+ head_mask[i],
372
+ output_attentions,
373
+ training=training,
374
+ )
375
+ hidden_states = outputs[0]
376
+ if output_attentions:
377
+ all_attentions = all_attentions + (outputs[1],)
378
+
379
+ hidden_states = tf.reshape(hidden_states, output_shape)
380
+ # Add last hidden state
381
+ if output_hidden_states:
382
+ all_hidden_states = all_hidden_states + (hidden_states,)
383
+
384
+ if output_attentions:
385
+ # leave the number of heads free (-1) so we can extract attention even after head pruning
386
+ attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
387
+ all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
388
+
389
+ if not return_dict:
390
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
391
+
392
+ return TFBaseModelOutput(
393
+ last_hidden_state=hidden_states,
394
+ hidden_states=all_hidden_states,
395
+ attentions=all_attentions,
396
+ )
397
+
398
+
399
+ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
400
+ """
401
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
402
+ models.
403
+ """
404
+
405
+ config_class = OpenAIGPTConfig
406
+ base_model_prefix = "transformer"
407
+
408
+
409
+ @dataclass
410
+ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
411
+ """
412
+ Base class for outputs of models with a language modeling head and a multiple-choice classification head.
413
+
414
+ Args:
415
+ logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
416
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
417
+ mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
418
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
419
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
420
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
421
+ `(batch_size, sequence_length, hidden_size)`.
422
+
423
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
424
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
425
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
426
+ sequence_length)`.
427
+
428
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
429
+ heads.
430
+ """
431
+
432
+ logits: Optional[tf.Tensor] = None
433
+ mc_logits: Optional[tf.Tensor] = None
434
+ hidden_states: Tuple[tf.Tensor] | None = None
435
+ attentions: Tuple[tf.Tensor] | None = None
436
+
437
+
438
+ OPENAI_GPT_START_DOCSTRING = r"""
439
+
440
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
441
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
442
+ etc.)
443
+
444
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
445
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
446
+ behavior.
447
+
448
+ <Tip>
449
+
450
+ TensorFlow models and layers in `transformers` accept two formats as input:
451
+
452
+ - having all inputs as keyword arguments (like PyTorch models), or
453
+ - having all inputs as a list, tuple or dict in the first positional argument.
454
+
455
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
456
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
457
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
458
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
459
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
460
+ positional argument:
461
+
462
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
463
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
464
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
465
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
466
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
467
+
468
+ Note that when creating models and layers with
469
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
470
+ about any of this, as you can just pass inputs like you would to any other Python function!
471
+
472
+ </Tip>
473
+
474
+ Parameters:
475
+ config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model.
476
+ Initializing with a config file does not load the weights associated with the model, only the
477
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
478
+ """
479
+
480
+ OPENAI_GPT_INPUTS_DOCSTRING = r"""
481
+ Args:
482
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
483
+ Indices of input sequence tokens in the vocabulary.
484
+
485
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
486
+ [`PreTrainedTokenizer.encode`] for details.
487
+
488
+ [What are input IDs?](../glossary#input-ids)
489
+ attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
490
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
491
+
492
+ - 1 for tokens that are **not masked**,
493
+ - 0 for tokens that are **masked**.
494
+
495
+ [What are attention masks?](../glossary#attention-mask)
496
+ token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
497
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
498
+ 1]`:
499
+
500
+ - 0 corresponds to a *sentence A* token,
501
+ - 1 corresponds to a *sentence B* token.
502
+
503
+ [What are token type IDs?](../glossary#token-type-ids)
504
+ position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
505
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
506
+ config.max_position_embeddings - 1]`.
507
+
508
+ [What are position IDs?](../glossary#position-ids)
509
+ head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
510
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
511
+
512
+ - 1 indicates the head is **not masked**,
513
+ - 0 indicates the head is **masked**.
514
+
515
+ inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
516
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
517
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
518
+ model's internal embedding lookup matrix.
519
+ output_attentions (`bool`, *optional*):
520
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
521
+ tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
522
+ config will be used instead.
523
+ output_hidden_states (`bool`, *optional*):
524
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
525
+ more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
526
+ used instead.
527
+ return_dict (`bool`, *optional*):
528
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
529
+ eager mode; in graph mode the value will always be set to `True`.
530
+ training (`bool`, *optional*, defaults to `False`):
531
+ Whether or not to use the model in training mode (some modules like dropout modules have different
532
+ behaviors between training and evaluation).
533
+ """
534
+
535
+
536
+ @add_start_docstrings(
537
+ "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
538
+ OPENAI_GPT_START_DOCSTRING,
539
+ )
540
+ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
541
+ def __init__(self, config, *inputs, **kwargs):
542
+ super().__init__(config, *inputs, **kwargs)
543
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
544
+
545
+ @unpack_inputs
546
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
547
+ @add_code_sample_docstrings(
548
+ checkpoint=_CHECKPOINT_FOR_DOC,
549
+ output_type=TFBaseModelOutput,
550
+ config_class=_CONFIG_FOR_DOC,
551
+ )
552
+ def call(
553
+ self,
554
+ input_ids: TFModelInputType | None = None,
555
+ attention_mask: np.ndarray | tf.Tensor | None = None,
556
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
557
+ position_ids: np.ndarray | tf.Tensor | None = None,
558
+ head_mask: np.ndarray | tf.Tensor | None = None,
559
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
560
+ output_attentions: Optional[bool] = None,
561
+ output_hidden_states: Optional[bool] = None,
562
+ return_dict: Optional[bool] = None,
563
+ training: Optional[bool] = False,
564
+ ) -> Union[Tuple, TFBaseModelOutput]:
565
+ outputs = self.transformer(
566
+ input_ids=input_ids,
567
+ attention_mask=attention_mask,
568
+ token_type_ids=token_type_ids,
569
+ position_ids=position_ids,
570
+ head_mask=head_mask,
571
+ inputs_embeds=inputs_embeds,
572
+ output_attentions=output_attentions,
573
+ output_hidden_states=output_hidden_states,
574
+ return_dict=return_dict,
575
+ training=training,
576
+ )
577
+ return outputs
578
+
579
+ def build(self, input_shape=None):
580
+ if self.built:
581
+ return
582
+ self.built = True
583
+ if getattr(self, "transformer", None) is not None:
584
+ with tf.name_scope(self.transformer.name):
585
+ self.transformer.build(None)
586
+
587
+
588
+ @add_start_docstrings(
589
+ """
590
+ OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
591
+ embeddings).
592
+ """,
593
+ OPENAI_GPT_START_DOCSTRING,
594
+ )
595
+ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss):
596
+ def __init__(self, config, *inputs, **kwargs):
597
+ super().__init__(config, *inputs, **kwargs)
598
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
599
+ # OpenAIGPT does not have past caching features
600
+ self.supports_xla_generation = False
601
+
602
+ def get_output_embeddings(self):
603
+ return self.get_input_embeddings()
604
+
605
+ def set_output_embeddings(self, value):
606
+ self.set_input_embeddings(value)
607
+
608
+ @unpack_inputs
609
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
610
+ @add_code_sample_docstrings(
611
+ checkpoint=_CHECKPOINT_FOR_DOC,
612
+ output_type=TFCausalLMOutput,
613
+ config_class=_CONFIG_FOR_DOC,
614
+ )
615
+ def call(
616
+ self,
617
+ input_ids: TFModelInputType | None = None,
618
+ attention_mask: np.ndarray | tf.Tensor | None = None,
619
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
620
+ position_ids: np.ndarray | tf.Tensor | None = None,
621
+ head_mask: np.ndarray | tf.Tensor | None = None,
622
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
623
+ output_attentions: Optional[bool] = None,
624
+ output_hidden_states: Optional[bool] = None,
625
+ return_dict: Optional[bool] = None,
626
+ labels: np.ndarray | tf.Tensor | None = None,
627
+ training: Optional[bool] = False,
628
+ ) -> Union[Tuple, TFCausalLMOutput]:
629
+ r"""
630
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
631
+ Labels for computing the language modeling (cross-entropy) loss. Indices should be in `[0, ...,
632
+ config.vocab_size - 1]`.
633
+ """
634
+
635
+ transformer_outputs = self.transformer(
636
+ input_ids=input_ids,
637
+ attention_mask=attention_mask,
638
+ token_type_ids=token_type_ids,
639
+ position_ids=position_ids,
640
+ head_mask=head_mask,
641
+ inputs_embeds=inputs_embeds,
642
+ output_attentions=output_attentions,
643
+ output_hidden_states=output_hidden_states,
644
+ return_dict=return_dict,
645
+ training=training,
646
+ )
647
+ hidden_states = transformer_outputs[0]
648
+
649
+ logits = self.transformer.tokens_embed(hidden_states, mode="linear")
650
+
651
+ loss = None
652
+ if labels is not None:
653
+ # shift labels to the left and cut last logit token
654
+ shifted_logits = logits[:, :-1]
655
+ labels = labels[:, 1:]
656
+ loss = self.hf_compute_loss(labels, shifted_logits)
657
+
658
+ if not return_dict:
659
+ output = (logits,) + transformer_outputs[1:]
660
+ return ((loss,) + output) if loss is not None else output
661
+
662
+ return TFCausalLMOutput(
663
+ loss=loss,
664
+ logits=logits,
665
+ hidden_states=transformer_outputs.hidden_states,
666
+ attentions=transformer_outputs.attentions,
667
+ )
668
+
669
+ def prepare_inputs_for_generation(self, inputs, **kwargs):
670
+ return {"input_ids": inputs}
671
+
672
+ def build(self, input_shape=None):
673
+ if self.built:
674
+ return
675
+ self.built = True
676
+ if getattr(self, "transformer", None) is not None:
677
+ with tf.name_scope(self.transformer.name):
678
+ self.transformer.build(None)
679
+
680
+
681
+ @add_start_docstrings(
682
+ """
683
+ OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
684
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
685
+ input embeddings; the classification head takes as input the hidden state at a specified classification token index in the
686
+ input sequence.
687
+ """,
688
+ OPENAI_GPT_START_DOCSTRING,
689
+ )
690
+ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
691
+ def __init__(self, config, *inputs, **kwargs):
692
+ super().__init__(config, *inputs, **kwargs)
693
+ config.num_labels = 1
694
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
695
+ self.multiple_choice_head = TFSequenceSummary(
696
+ config, initializer_range=config.initializer_range, name="multiple_choice_head"
697
+ )
698
+
699
+ @unpack_inputs
700
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
701
+ @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
702
+ def call(
703
+ self,
704
+ input_ids: TFModelInputType | None = None,
705
+ attention_mask: np.ndarray | tf.Tensor | None = None,
706
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
707
+ position_ids: np.ndarray | tf.Tensor | None = None,
708
+ head_mask: np.ndarray | tf.Tensor | None = None,
709
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
710
+ mc_token_ids: np.ndarray | tf.Tensor | None = None,
711
+ output_attentions: Optional[bool] = None,
712
+ output_hidden_states: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ training: Optional[bool] = False,
715
+ ) -> Union[Tuple, TFOpenAIGPTDoubleHeadsModelOutput]:
716
+ r"""
717
+ mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
718
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
719
+ 1]`.
720
+
721
+ Return:
722
+
723
+ Examples:
724
+
725
+ ```python
726
+ >>> import tensorflow as tf
727
+ >>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel
728
+
729
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
730
+ >>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
731
+
732
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
733
+ >>> tokenizer.add_special_tokens({"cls_token": "[CLS]"})
734
+ >>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
735
+ >>> print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary
736
+
737
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
738
+ >>> encoding = tokenizer(choices, return_tensors="tf")
739
+ >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}
740
+ >>> inputs["mc_token_ids"] = tf.constant(
741
+ ... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1]
742
+ ... )[
743
+ ... None, :
744
+ ... ] # Batch size 1
745
+ >>> outputs = model(inputs)
746
+ >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
747
+ ```"""
748
+
749
+ if input_ids is not None:
750
+ input_shapes = shape_list(input_ids)
751
+ else:
752
+ input_shapes = shape_list(inputs_embeds)[:-1]
753
+
754
+ seq_length = input_shapes[-1]
755
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
756
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
757
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
758
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
759
+ transformer_outputs = self.transformer(
760
+ flat_input_ids,
761
+ flat_attention_mask,
762
+ flat_token_type_ids,
763
+ flat_position_ids,
764
+ head_mask,
765
+ inputs_embeds,
766
+ output_attentions,
767
+ output_hidden_states,
768
+ return_dict=return_dict,
769
+ training=training,
770
+ )
771
+ hidden_states = transformer_outputs[0]
772
+ hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
773
+ if return_dict and output_hidden_states:
774
+ # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
775
+ # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
776
+ all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
777
+ else:
778
+ all_hidden_states = None
779
+ lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
780
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
781
+ mc_logits = tf.squeeze(mc_logits, axis=-1)
782
+
783
+ if not return_dict:
784
+ return (lm_logits, mc_logits) + transformer_outputs[1:]
785
+
786
+ return TFOpenAIGPTDoubleHeadsModelOutput(
787
+ logits=lm_logits,
788
+ mc_logits=mc_logits,
789
+ hidden_states=all_hidden_states,
790
+ attentions=transformer_outputs.attentions,
791
+ )
792
+
793
+ @property
794
+ def input_signature(self):
795
+ return {
796
+ "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
797
+ "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
798
+ "mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
799
+ }
800
+
801
+ def build(self, input_shape=None):
802
+ if self.built:
803
+ return
804
+ self.built = True
805
+ if getattr(self, "transformer", None) is not None:
806
+ with tf.name_scope(self.transformer.name):
807
+ self.transformer.build(None)
808
+ if getattr(self, "multiple_choice_head", None) is not None:
809
+ with tf.name_scope(self.multiple_choice_head.name):
810
+ self.multiple_choice_head.build(None)
811
+
812
+
813
+ @add_start_docstrings(
814
+ """
815
+ The OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
816
+
817
+ [`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
818
+ models (e.g. GPT-2) do.
819
+
820
+ Since it does classification on the last token, it needs to know the position of the last token. If a
821
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
822
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
823
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
824
+ each row of the batch).
825
+ """,
826
+ OPENAI_GPT_START_DOCSTRING,
827
+ )
828
+ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss):
829
+ def __init__(self, config, *inputs, **kwargs):
830
+ super().__init__(config, *inputs, **kwargs)
831
+ self.num_labels = config.num_labels
832
+ self.score = keras.layers.Dense(
833
+ config.num_labels,
834
+ kernel_initializer=get_initializer(config.initializer_range),
835
+ name="score",
836
+ use_bias=False,
837
+ )
838
+ self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
839
+ self.config = config
840
+
841
+ @unpack_inputs
842
+ @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
843
+ @add_code_sample_docstrings(
844
+ checkpoint=_CHECKPOINT_FOR_DOC,
845
+ output_type=TFSequenceClassifierOutput,
846
+ config_class=_CONFIG_FOR_DOC,
847
+ )
848
+ def call(
849
+ self,
850
+ input_ids: TFModelInputType | None = None,
851
+ attention_mask: np.ndarray | tf.Tensor | None = None,
852
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
853
+ position_ids: np.ndarray | tf.Tensor | None = None,
854
+ head_mask: np.ndarray | tf.Tensor | None = None,
855
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
856
+ output_attentions: Optional[bool] = None,
857
+ output_hidden_states: Optional[bool] = None,
858
+ return_dict: Optional[bool] = None,
859
+ labels: np.ndarray | tf.Tensor | None = None,
860
+ training: Optional[bool] = False,
861
+ ) -> Union[Tuple, TFSequenceClassifierOutput]:
862
+ r"""
863
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
864
+ Labels for computing the sequence classification loss. Indices should be in `[0, ...,
865
+ config.num_labels - 1]`.
866
+ """
867
+ transformer_outputs = self.transformer(
868
+ input_ids=input_ids,
869
+ attention_mask=attention_mask,
870
+ token_type_ids=token_type_ids,
871
+ position_ids=position_ids,
872
+ head_mask=head_mask,
873
+ inputs_embeds=inputs_embeds,
874
+ output_attentions=output_attentions,
875
+ output_hidden_states=output_hidden_states,
876
+ return_dict=return_dict,
877
+ training=training,
878
+ )
879
+ hidden_states = transformer_outputs[0]
880
+ logits = self.score(hidden_states)
881
+ logits_shape = shape_list(logits)
882
+ batch_size = logits_shape[0]
883
+
884
+ if self.config.pad_token_id is None:
885
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
886
+ else:
887
+ if input_ids is not None:
888
+ token_indices = tf.range(shape_list(input_ids)[-1])
889
+ non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
890
+ last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
891
+ else:
892
+ last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
893
+ logger.warning_once(
894
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
895
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
896
+ )
897
+ loss = None
898
+
899
+ pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
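+ # Example: for input_ids [[5, 6, pad, pad]] the mask logic above gives last_non_pad_token = [1],
+ # so the gather pools logits[:, 1, :], i.e. the logits at the last real (non-padding) token of each row.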
900
+
901
+ if labels is not None:
902
+ if self.config.pad_token_id is None and logits_shape[0] != 1:
903
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
904
+
905
+ loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
906
+
907
+ if not return_dict:
908
+ output = (pooled_logits,) + transformer_outputs[1:]
909
+ return ((loss,) + output) if loss is not None else output
910
+
911
+ return TFSequenceClassifierOutput(
912
+ loss=loss,
913
+ logits=pooled_logits,
914
+ hidden_states=transformer_outputs.hidden_states,
915
+ attentions=transformer_outputs.attentions,
916
+ )
917
+
918
+ def build(self, input_shape=None):
919
+ if self.built:
920
+ return
921
+ self.built = True
922
+ if getattr(self, "score", None) is not None:
923
+ with tf.name_scope(self.score.name):
924
+ self.score.build([None, None, self.config.n_embd])
925
+ if getattr(self, "transformer", None) is not None:
926
+ with tf.name_scope(self.transformer.name):
927
+ self.transformer.build(None)
928
+
929
+
930
+ __all__ = [
931
+ "TFOpenAIGPTDoubleHeadsModel",
932
+ "TFOpenAIGPTForSequenceClassification",
933
+ "TFOpenAIGPTLMHeadModel",
934
+ "TFOpenAIGPTMainLayer",
935
+ "TFOpenAIGPTModel",
936
+ "TFOpenAIGPTPreTrainedModel",
937
+ ]
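
For orientation, a minimal usage sketch of the classes above (assumes TensorFlow is installed and the `openai-community/openai-gpt` checkpoint can be downloaded; greedy next-token prediction shown for brevity):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFOpenAIGPTLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
model = TFOpenAIGPTLMHeadModel.from_pretrained("openai-community/openai-gpt")

inputs = tokenizer("hello , my dog is", return_tensors="tf")
outputs = model(inputs)  # TFCausalLMOutput; logits have shape (batch, seq_len, vocab_size)

# Greedy prediction of the next token from the last position
next_id = int(tf.argmax(outputs.logits[0, -1]))
print(tokenizer.decode([next_id]))
```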
docs/transformers/src/transformers/models/openai/tokenization_openai.py ADDED
@@ -0,0 +1,396 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for OpenAI GPT."""
16
+
17
+ import json
18
+ import os
19
+ import re
20
+ import unicodedata
21
+ from typing import Optional, Tuple
22
+
23
+ from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24
+ from ...utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {
30
+ "vocab_file": "vocab.json",
31
+ "merges_file": "merges.txt",
32
+ }
33
+
34
+
35
+ # Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
36
+ def whitespace_tokenize(text):
37
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
38
+ text = text.strip()
39
+ if not text:
40
+ return []
41
+ tokens = text.split()
42
+ return tokens
43
+
44
+
45
+ # Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
46
+ class BasicTokenizer:
47
+ """
48
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
49
+
50
+ Args:
51
+ do_lower_case (`bool`, *optional*, defaults to `True`):
52
+ Whether or not to lowercase the input when tokenizing.
53
+ never_split (`Iterable`, *optional*):
54
+ Collection of tokens which will never be split during tokenization. Only has an effect when
55
+ `do_basic_tokenize=True`
56
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
57
+ Whether or not to tokenize Chinese characters.
58
+
59
+ This should likely be deactivated for Japanese (see this
60
+ [issue](https://github.com/huggingface/transformers/issues/328)).
61
+ strip_accents (`bool`, *optional*):
62
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
63
+ value for `lowercase` (as in the original BERT).
64
+ do_split_on_punc (`bool`, *optional*, defaults to `True`):
65
+ In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
66
+ the full context of the words, such as contractions.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ do_lower_case=True,
72
+ never_split=None,
73
+ tokenize_chinese_chars=True,
74
+ strip_accents=None,
75
+ do_split_on_punc=True,
76
+ ):
77
+ if never_split is None:
78
+ never_split = []
79
+ self.do_lower_case = do_lower_case
80
+ self.never_split = set(never_split)
81
+ self.tokenize_chinese_chars = tokenize_chinese_chars
82
+ self.strip_accents = strip_accents
83
+ self.do_split_on_punc = do_split_on_punc
84
+
85
+ def tokenize(self, text, never_split=None):
86
+ """
87
+ Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
88
+
89
+ Args:
90
+ never_split (`List[str]`, *optional*)
91
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
92
+ [`PreTrainedTokenizer.tokenize`]) List of token not to split.
93
+ """
94
+ # union() returns a new set by concatenating the two sets.
95
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
96
+ text = self._clean_text(text)
97
+
98
+ # This was added on November 1st, 2018 for the multilingual and Chinese
99
+ # models. This is also applied to the English models now, but it doesn't
100
+ # matter since the English models were not trained on any Chinese data
101
+ # and generally don't have any Chinese data in them (there are Chinese
102
+ # characters in the vocabulary because Wikipedia does have some Chinese
103
+ # words in the English Wikipedia.).
104
+ if self.tokenize_chinese_chars:
105
+ text = self._tokenize_chinese_chars(text)
106
+ # prevents treating the same character with different unicode codepoints as different characters
107
+ unicode_normalized_text = unicodedata.normalize("NFC", text)
108
+ orig_tokens = whitespace_tokenize(unicode_normalized_text)
109
+ split_tokens = []
110
+ for token in orig_tokens:
111
+ if token not in never_split:
112
+ if self.do_lower_case:
113
+ token = token.lower()
114
+ if self.strip_accents is not False:
115
+ token = self._run_strip_accents(token)
116
+ elif self.strip_accents:
117
+ token = self._run_strip_accents(token)
118
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
119
+
120
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
121
+ return output_tokens
122
+
123
+ def _run_strip_accents(self, text):
124
+ """Strips accents from a piece of text."""
125
+ text = unicodedata.normalize("NFD", text)
126
+ output = []
127
+ for char in text:
128
+ cat = unicodedata.category(char)
129
+ if cat == "Mn":
130
+ continue
131
+ output.append(char)
132
+ return "".join(output)
133
+
134
+ def _run_split_on_punc(self, text, never_split=None):
135
+ """Splits punctuation on a piece of text."""
136
+ if not self.do_split_on_punc or (never_split is not None and text in never_split):
137
+ return [text]
138
+ chars = list(text)
139
+ i = 0
140
+ start_new_word = True
141
+ output = []
142
+ while i < len(chars):
143
+ char = chars[i]
144
+ if _is_punctuation(char):
145
+ output.append([char])
146
+ start_new_word = True
147
+ else:
148
+ if start_new_word:
149
+ output.append([])
150
+ start_new_word = False
151
+ output[-1].append(char)
152
+ i += 1
153
+
154
+ return ["".join(x) for x in output]
155
+
156
+ def _tokenize_chinese_chars(self, text):
157
+ """Adds whitespace around any CJK character."""
158
+ output = []
159
+ for char in text:
160
+ cp = ord(char)
161
+ if self._is_chinese_char(cp):
162
+ output.append(" ")
163
+ output.append(char)
164
+ output.append(" ")
165
+ else:
166
+ output.append(char)
167
+ return "".join(output)
168
+
169
+ def _is_chinese_char(self, cp):
170
+ """Checks whether CP is the codepoint of a CJK character."""
171
+ # This defines a "chinese character" as anything in the CJK Unicode block:
172
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
173
+ #
174
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
175
+ # despite its name. The modern Korean Hangul alphabet is a different block,
176
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
177
+ # space-separated words, so they are not treated specially and handled
178
+ # like the all of the other languages.
179
+ if (
180
+ (cp >= 0x4E00 and cp <= 0x9FFF)
181
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
182
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
183
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
184
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
185
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
186
+ or (cp >= 0xF900 and cp <= 0xFAFF)
187
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
188
+ ): #
189
+ return True
190
+
191
+ return False
192
+
193
+ def _clean_text(self, text):
194
+ """Performs invalid character removal and whitespace cleanup on text."""
195
+ output = []
196
+ for char in text:
197
+ cp = ord(char)
198
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
199
+ continue
200
+ if _is_whitespace(char):
201
+ output.append(" ")
202
+ else:
203
+ output.append(char)
204
+ return "".join(output)
205
+
206
+
207
+ def get_pairs(word):
208
+ """
209
+ Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
210
+ strings)
211
+ """
212
+ pairs = set()
213
+ prev_char = word[0]
214
+ for char in word[1:]:
215
+ pairs.add((prev_char, char))
216
+ prev_char = char
217
+ return pairs
218
+
219
+
220
+ def text_standardize(text):
221
+ """
222
+ fixes some issues the spacy tokenizer had on books corpus also does some whitespace standardization
223
+ """
224
+ text = text.replace("—", "-")
225
+ text = text.replace("–", "-")
226
+ text = text.replace("―", "-")
227
+ text = text.replace("…", "...")
228
+ text = text.replace("´", "'")
229
+ text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
230
+ text = re.sub(r"\s*\n\s*", " \n ", text)
231
+ text = re.sub(r"[^\S\n]+", " ", text)
232
+ return text.strip()
233
+
234
+
235
+ class OpenAIGPTTokenizer(PreTrainedTokenizer):
236
+ """
237
+ Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities:
238
+
239
+ - lowercases all inputs,
240
+ - uses `SpaCy` tokenizer and `ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's
241
+ `BasicTokenizer` if not.
242
+
243
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
244
+ this superclass for more information regarding those methods.
245
+
246
+ Args:
247
+ vocab_file (`str`):
248
+ Path to the vocabulary file.
249
+ merges_file (`str`):
250
+ Path to the merges file.
251
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
252
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
253
+ token instead.
254
+ """
255
+
256
+ vocab_files_names = VOCAB_FILES_NAMES
257
+ model_input_names = ["input_ids", "attention_mask"]
258
+
259
+ def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
260
+ try:
261
+ import ftfy
262
+ from spacy.lang.en import English
263
+
264
+ _nlp = English()
265
+ self.nlp = _nlp.tokenizer
266
+ self.fix_text = ftfy.fix_text
267
+ except ImportError:
268
+ logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
269
+ self.nlp = BasicTokenizer(do_lower_case=True)
270
+ self.fix_text = None
271
+
272
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
273
+ self.encoder = json.load(vocab_handle)
274
+ self.decoder = {v: k for k, v in self.encoder.items()}
275
+ with open(merges_file, encoding="utf-8") as merges_handle:
276
+ merges = merges_handle.read().split("\n")[1:-1]
277
+ merges = [tuple(merge.split()) for merge in merges]
278
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
279
+ self.cache = {}
280
+
281
+ super().__init__(unk_token=unk_token, **kwargs)
282
+
283
+ @property
284
+ def do_lower_case(self):
285
+ return True
286
+
287
+ @property
288
+ def vocab_size(self):
289
+ return len(self.encoder)
290
+
291
+ def get_vocab(self):
292
+ return dict(self.encoder, **self.added_tokens_encoder)
293
+
294
+ def bpe(self, token):
295
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
296
+ if token in self.cache:
297
+ return self.cache[token]
298
+ pairs = get_pairs(word)
299
+
300
+ if not pairs:
301
+ return token + "</w>"
302
+
303
+ while True:
304
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
305
+ if bigram not in self.bpe_ranks:
306
+ break
307
+ first, second = bigram
308
+ new_word = []
309
+ i = 0
310
+ while i < len(word):
311
+ try:
312
+ j = word.index(first, i)
313
+ except ValueError:
314
+ new_word.extend(word[i:])
315
+ break
316
+ else:
317
+ new_word.extend(word[i:j])
318
+ i = j
319
+
320
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
321
+ new_word.append(first + second)
322
+ i += 2
323
+ else:
324
+ new_word.append(word[i])
325
+ i += 1
326
+ new_word = tuple(new_word)
327
+ word = new_word
328
+ if len(word) == 1:
329
+ break
330
+ else:
331
+ pairs = get_pairs(word)
332
+ word = " ".join(word)
333
+ if word == "\n </w>":
334
+ word = "\n</w>"
335
+ self.cache[token] = word
336
+ return word
337
+
338
+ def _tokenize(self, text):
339
+ """Tokenize a string."""
340
+ split_tokens = []
341
+ if self.fix_text is None:
342
+ # Using BERT's BasicTokenizer
343
+ text = self.nlp.tokenize(text)
344
+ for token in text:
345
+ split_tokens.extend(list(self.bpe(token).split(" ")))
346
+ else:
347
+ # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
348
+ text = self.nlp(text_standardize(self.fix_text(text)))
349
+ for token in text:
350
+ split_tokens.extend(list(self.bpe(token.text.lower()).split(" ")))
351
+ return split_tokens
352
+
353
+ def _convert_token_to_id(self, token):
354
+ """Converts a token (str) in an id using the vocab."""
355
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
356
+
357
+ def _convert_id_to_token(self, index):
358
+ """Converts an id in a token (BPE) using the vocab."""
359
+ return self.decoder.get(index, self.unk_token)
360
+
361
+ def convert_tokens_to_string(self, tokens):
362
+ """Converts a sequence of tokens (string) in a single string."""
363
+ out_string = "".join(tokens).replace("</w>", " ").strip()
364
+ return out_string
365
+
366
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
367
+ if not os.path.isdir(save_directory):
368
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
369
+ return
370
+ vocab_file = os.path.join(
371
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
372
+ )
373
+ merge_file = os.path.join(
374
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
375
+ )
376
+
377
+ with open(vocab_file, "w", encoding="utf-8") as f:
378
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
379
+
380
+ index = 0
381
+ with open(merge_file, "w", encoding="utf-8") as writer:
382
+ writer.write("#version: 0.2\n")
383
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
384
+ if index != token_index:
385
+ logger.warning(
386
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
387
+ " Please check that the tokenizer is not corrupted!"
388
+ )
389
+ index = token_index
390
+ writer.write(" ".join(bpe_tokens) + "\n")
391
+ index += 1
392
+
393
+ return vocab_file, merge_file
394
+
395
+
396
+ __all__ = ["OpenAIGPTTokenizer"]
docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py ADDED
@@ -0,0 +1,66 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Tokenization classes for OpenAI GPT."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
20
+ from ...utils import logging
21
+ from .tokenization_openai import OpenAIGPTTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
27
+
28
+
29
+ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
30
+ """
31
+ Construct a "fast" GPT Tokenizer (backed by HuggingFace's *tokenizers* library). Based on Byte-Pair-Encoding with
32
+ the following peculiarities:
33
+
34
+ - lowercases all inputs
35
+ - uses BERT's BasicTokenizer for pre-BPE tokenization
36
+
37
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
38
+ refer to this superclass for more information regarding those methods.
39
+
40
+ Args:
41
+ vocab_file (`str`):
42
+ Path to the vocabulary file.
43
+ merges_file (`str`):
44
+ Path to the merges file.
45
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
46
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
47
+ token instead.
48
+ """
49
+
50
+ vocab_files_names = VOCAB_FILES_NAMES
51
+ model_input_names = ["input_ids", "attention_mask"]
52
+ slow_tokenizer_class = OpenAIGPTTokenizer
53
+
54
+ def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<unk>", **kwargs):
55
+ super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs)
56
+
57
+ @property
58
+ def do_lower_case(self):
59
+ return True
60
+
61
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
62
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
63
+ return tuple(files)
64
+
65
+
66
+ __all__ = ["OpenAIGPTTokenizerFast"]
docs/transformers/src/transformers/models/opt/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_opt import *
22
+ from .modeling_flax_opt import *
23
+ from .modeling_opt import *
24
+ from .modeling_tf_opt import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)