Add files using upload-large-folder tool
- docs/transformers/src/transformers/models/myt5/tokenization_myt5.py +380 -0
- docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py +156 -0
- docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py +346 -0
- docs/transformers/src/transformers/models/nllb/tokenization_nllb.py +394 -0
- docs/transformers/src/transformers/models/nllb_moe/__init__.py +27 -0
- docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py +219 -0
- docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py +161 -0
- docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py +1784 -0
- docs/transformers/src/transformers/models/nougat/__init__.py +28 -0
- docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py +282 -0
- docs/transformers/src/transformers/models/nougat/image_processing_nougat.py +525 -0
- docs/transformers/src/transformers/models/nougat/processing_nougat.py +163 -0
- docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py +620 -0
- docs/transformers/src/transformers/models/nystromformer/__init__.py +27 -0
- docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py +132 -0
- docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py +111 -0
- docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py +1124 -0
- docs/transformers/src/transformers/models/olmo/__init__.py +27 -0
- docs/transformers/src/transformers/models/olmo/configuration_olmo.py +198 -0
- docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py +248 -0
- docs/transformers/src/transformers/models/olmo/modeling_olmo.py +814 -0
- docs/transformers/src/transformers/models/olmo/modular_olmo.py +148 -0
- docs/transformers/src/transformers/models/olmo2/__init__.py +27 -0
- docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py +180 -0
- docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py +306 -0
- docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py +820 -0
- docs/transformers/src/transformers/models/olmo2/modular_olmo2.py +320 -0
- docs/transformers/src/transformers/models/olmoe/__init__.py +27 -0
- docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py +182 -0
- docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py +281 -0
- docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py +1273 -0
- docs/transformers/src/transformers/models/omdet_turbo/__init__.py +28 -0
- docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py +293 -0
- docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py +349 -0
- docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +1711 -0
- docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py +415 -0
- docs/transformers/src/transformers/models/oneformer/__init__.py +29 -0
- docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py +277 -0
- docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py +1191 -0
- docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py +1356 -0
- docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py +0 -0
- docs/transformers/src/transformers/models/oneformer/processing_oneformer.py +207 -0
- docs/transformers/src/transformers/models/openai/__init__.py +30 -0
- docs/transformers/src/transformers/models/openai/configuration_openai.py +156 -0
- docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py +74 -0
- docs/transformers/src/transformers/models/openai/modeling_openai.py +967 -0
- docs/transformers/src/transformers/models/openai/modeling_tf_openai.py +937 -0
- docs/transformers/src/transformers/models/openai/tokenization_openai.py +396 -0
- docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py +66 -0
- docs/transformers/src/transformers/models/opt/__init__.py +29 -0
docs/transformers/src/transformers/models/myt5/tokenization_myt5.py
ADDED
@@ -0,0 +1,380 @@
# coding=utf-8
# Copyright 2024
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for model MyT5."""

import json
import os
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging


logger = logging.get_logger(__name__)


VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"}


class ByteRewriter:
    """
    Byte rewriter class for MyT5 tokenizer.
    This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules.

    Args:
        rewriting_rules (`str` or `Dict[str, str]`):
            A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules.
    """

    LEAF = "[LEAF]"

    def __init__(self, rewriting_rules: Union[str, Dict[str, str]]):
        if isinstance(rewriting_rules, str):
            with open(rewriting_rules, "r") as f:
                rewriting_rules = json.load(f)
        elif not isinstance(rewriting_rules, dict):
            raise ValueError(
                f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}"
            )

        self.hash_tree = self.construct_hash_tree(rewriting_rules)
        reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()}
        self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules)

    def add_leaf(self, hash_tree: Dict[str, Union[dict, List[str]]], byte_in_sequence: str, byte_out_sequence: str):
        """
        Add a leaf with the output byte sequence to the hash tree.
        """
        byte_in_list = byte_in_sequence.split(" ")
        byte_out_list = byte_out_sequence.split(" ")

        tree_pointer = hash_tree
        for b in byte_in_list:
            if b not in tree_pointer:
                tree_pointer[b] = {}
            tree_pointer = tree_pointer[b]

        tree_pointer[self.LEAF] = byte_out_list

    def construct_hash_tree(self, rewriting_rules: Dict[str, str]) -> Dict[str, Union[dict, List[str]]]:
        """
        Construct a hash tree for rewritten byte sequences.
        """
        hash_tree = defaultdict(dict)
        for b in (f"{x:02x}" for x in range(256)):
            hash_tree[b][self.LEAF] = [b]

        for in_sequence, out_sequence in rewriting_rules.items():
            self.add_leaf(hash_tree, in_sequence, out_sequence)

        return hash_tree

    def search_hash_tree(self, byte_sequence: List[str]) -> Union[None, List[str]]:
        """
        Search the hash tree and return the rewritten byte sequence if found.
        """
        tree_pointer = self.hash_tree
        for b in byte_sequence:
            if b in tree_pointer:
                tree_pointer = tree_pointer[b]
            else:
                return None

        return tree_pointer[self.LEAF]

    def rewrite_bytes(self, in_bytes: List[str], reverse=False) -> List[str]:
        """
        Rewrite a sequence of bytes using the hash tree.

        Args:
            in_bytes (`List[str]`): A list of bytes to be rewritten.
            reverse (`bool`): If True, decoding is performed with the reverse hash tree.
        Returns:
            `List[str]`: The rewritten byte sequence.
        """
        out_bytes = []
        b_start = 0
        b_end = 0

        while b_start < len(in_bytes):
            tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree
            for j in range(b_start, len(in_bytes)):
                b = in_bytes[j]
                if b in tree_pointer:
                    tree_pointer = tree_pointer[b]
                elif j == b_start:
                    cur_leaf = [b]
                    b_end = j
                    break
                else:
                    break
                if self.LEAF in tree_pointer:
                    cur_leaf = tree_pointer[self.LEAF]
                    b_end = j
            out_bytes.extend(cur_leaf)
            b_start = b_end + 1

        return out_bytes


class MyT5Tokenizer(PreTrainedTokenizer):
    """
    Construct a MyT5 tokenizer.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`): The file containing the byte rewriting rules.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        extra_ids (`int`, *optional*, defaults to 125):
            Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
            indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
            like in ByT5 preprocessing see
            [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=125,
        additional_special_tokens=None,
        **kwargs,
    ) -> None:
        # Add extra_ids to the special token list
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
        elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
            # Check that we have the right number of extra_id special tokens
            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
            if extra_tokens != extra_ids:
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to MyT5Tokenizer. In this case the additional_special_tokens must include the"
                    " extra_ids tokens"
                )

        pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
        eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
        # unk token needs to be in the vocab with correct index
        self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
        self.offset = len(self._added_tokens_decoder)
        self._utf_vocab_size = 2**8  # utf is 8 bits

        # Load byte maps
        self.byte_maps = json.load(open(vocab_file, "r"))

        self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"])
        self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"])

        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=0,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return self._utf_vocab_size

    # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab
    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # normal case: some special tokens
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
            warnings.warn(
                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
                " eos tokens being added."
            )
            return token_ids
        else:
            return token_ids + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

    # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A sequence has the following format:

        - single sequence: `X </s>`
        - pair of sequences: `A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
        if token_ids_1 is None:
            return token_ids_0
        else:
            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
            return token_ids_0 + token_ids_1

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words.
        Represents tokens in two character hex format"""

        tokens = [f"{i:02x}" for i in text.encode("utf-8")]
        tokens = self.morphological_encode(tokens)
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""

        if len(token) != 2:
            token_id = None
        else:
            token_id = int(token, 16) + self.offset

        return token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = f"{index - self.offset:02x}"
        return token

    def morphological_encode(self, indices: List[str]) -> List[str]:
        # Decompose and merge morphological sequences
        indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
        indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
        return indices

    def morphological_decode(self, indices: List[str]) -> List[str]:
        # Demerge and compose morphological sequences
        indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
        indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
        return indices

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        bstring = b""

        out_tokens = []
        for token in tokens:
            if token in self.added_tokens_decoder:
                out_tokens.append(self.added_tokens_decoder[token])
            elif token in self.added_tokens_encoder:
                out_tokens.append(token)
            else:
                out_tokens.append(token)

        out_tokens = self.morphological_decode(out_tokens)
        _added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder)
        for token in out_tokens:
            if token in _added_tokens:
                bstring += bytes(token, "utf-8")
            else:
                bstring += bytes.fromhex(token)
        string = bstring.decode("utf-8", errors="ignore")
        return string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False))
        return (vocab_file,)


__all__ = ["MyT5Tokenizer"]
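The core of this tokenizer is the `ByteRewriter` above: text is encoded as two-character hex byte tokens and greedily rewritten with the longest matching rule from a trie ("hash tree"). Below is a minimal, self-contained sketch of that longest-match rewriting; the rule set is made up for illustration (the real rules live in `byte_maps.json`), and it is not library code.

```python
# Sketch of greedy longest-match byte rewriting over a trie (illustrative rules only).
LEAF = "[LEAF]"


def build_trie(rules):
    # Identity mapping for every single byte, then longer rules layered on top.
    trie = {f"{x:02x}": {LEAF: [f"{x:02x}"]} for x in range(256)}
    for src, dst in rules.items():
        node = trie
        for b in src.split(" "):
            node = node.setdefault(b, {})
        node[LEAF] = dst.split(" ")
    return trie


def rewrite(trie, in_bytes):
    out, start = [], 0
    while start < len(in_bytes):
        node, leaf, end = trie, [in_bytes[start]], start
        for j in range(start, len(in_bytes)):
            if in_bytes[j] not in node:
                break
            node = node[in_bytes[j]]
            if LEAF in node:
                leaf, end = node[LEAF], j  # remember the longest match seen so far
        out.extend(leaf)
        start = end + 1
    return out


# Hypothetical rule: the two-byte sequence "68 69" ("hi") is rewritten to a single byte "90".
trie = build_trie({"68 69": "90"})
print(rewrite(trie, [f"{b:02x}" for b in b"hi there"]))  # ['90', '20', '74', ...]
```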
docs/transformers/src/transformers/models/nemotron/configuration_nemotron.py
ADDED
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nemotron model configuration"""

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging


logger = logging.get_logger(__name__)


class NemotronConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NemotronModel`]. It is used to instantiate a Nemotron
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Nemotron-8B.
    e.g. [nvidia/nemotron-3-8b-base-4k-hf](https://huggingface.co/nvidia/nemotron-3-8b-base-4k-hf).
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Nemotron model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`NemotronModel`]
        hidden_size (`int`, *optional*, defaults to 6144):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer decoder.
        head_dim (`int`, *optional*):
            Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if None
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.0134):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 3):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj and down_proj layers in the MLP layers.

    ```python
    >>> from transformers import NemotronModel, NemotronConfig

    >>> # Initializing a Nemotron nemotron-15b style configuration
    >>> configuration = NemotronConfig()

    >>> # Initializing a model from the nemotron-15b style configuration
    >>> model = NemotronModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nemotron"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=6144,
        intermediate_size=24576,
        num_hidden_layers=32,
        num_attention_heads=48,
        head_dim=None,
        num_key_value_heads=None,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.0134,
        norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=2,
        eos_token_id=3,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        partial_rotary_factor=0.5,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor
        rope_config_validation(self)
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["NemotronConfig"]
docs/transformers/src/transformers/models/nemotron/convert_nemotron_nemo_to_hf.py
ADDED
@@ -0,0 +1,346 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
from argparse import ArgumentParser
from collections import OrderedDict

import torch
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging
from pytorch_lightning import Trainer

from transformers import LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import LlamaConverter


"""
Script to convert a nemotron checkpoint in nemo (mcore path) into a HuggingFace checkpoint.
This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.

1) Generate only HF weights from a nemo file:

    python convert_nemotron_nemo_to_hf.py \
    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
    --output_path /path/to/pytorch_model.bin

2) Generate the full HF model folder

    python convert_nemotron_nemo_to_hf.py \
    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
    --hf_input_path /path/to/input_hf_folder \
    --hf_output_path /path/to/output_hf_folder \

    Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Nemotron4 340b).
    However, this option makes the conversion script significantly slower.
"""


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--input_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to .nemo file or extracted folder",
    )
    parser.add_argument("--output_path", type=str, default=None, required=False, help="Path to HF .bin file")
    parser.add_argument(
        "--hf_input_path",
        type=str,
        default=None,
        help="A HF model path, e.g. a folder containing https://huggingface.co/nvidia/Minitron-8B-Base",
    )
    parser.add_argument(
        "--hf_output_path",
        type=str,
        default=None,
        help="Output HF model path, with the same format as above but user's own weights",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default=None,
        help="Precision of output weights."
        "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
    )
    parser.add_argument(
        "--cpu-only",
        action="store_true",
        help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
        "but this option makes the conversion script significantly slower.",
    )
    args = parser.parse_args()
    return args


def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, hf_url="nvidia/Minitron-8B-Base"):
    """
    Convert NeMo config to HF config
    """
    NEMO_ACT2HF = {
        "squared-relu": "relu2",
        "fast-swiglu": "silu",
    }
    DTYPE2HF = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    hf_config = {
        "_name_or_path": hf_url,
        "architectures": ["NemotronForCausalLM"],
        "bos_token_id": tokenizer.bos_id,
        "eos_token_id": tokenizer.eos_id,
        "hidden_act": NEMO_ACT2HF[nemo_config.activation],
        "hidden_size": nemo_config.hidden_size,
        "initializer_range": nemo_config.init_method_std,
        "intermediate_size": nemo_config.ffn_hidden_size,
        "max_position_embeddings": nemo_config.max_position_embeddings,
        "model_type": "nemotron",
        "num_attention_heads": nemo_config.num_attention_heads,
        "num_hidden_layers": nemo_config.num_layers,
        "num_key_value_heads": nemo_config.get("num_query_groups", nemo_config.num_attention_heads),
        "norm_eps": nemo_config.layernorm_epsilon,
        "rope_theta": nemo_config.get("rotary_base", 10000),
        "partial_rotary_factor": nemo_config.get("rotary_percentage", 1.0),
        "tie_word_embeddings": False,
        "torch_dtype": DTYPE2HF[dtype],
        "transformers_version": "4.32.0.dev0",  # TODO
        "use_cache": True,
        "vocab_size": vocab_size,
    }
    if nemo_config.kv_channels is not None:
        hf_config["kv_channels"] = nemo_config.kv_channels
    json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)


def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
    """
    Convert NeMo weights to HF weights
    """
    dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
    model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
    model_config.tensor_model_parallel_size = 1
    model_config.pipeline_model_parallel_size = 1
    model_config.sequence_parallel = False
    model_config.transformer_engine = True
    if cpu_only:
        map_location = torch.device("cpu")
        model_config.use_cpu_initialization = True
        model_config.dist_ckpt_load_on_device = False
    else:
        map_location = None

    if cpu_only:
        logging.info("******** Loading model on CPU. This will take a significant amount of time.")

    model = MegatronGPTModel.restore_from(
        input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
    )

    vocab_size = model.padded_vocab_size

    if precision is None:
        precision = model.cfg.precision
    if precision in [32, "32"]:
        dtype = torch.float32
    elif precision in [16, "16", "16-mixed"]:
        dtype = torch.float16
    elif precision in ["bf16", "bf16-mixed"]:
        dtype = torch.bfloat16
    else:
        logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
        dtype = torch.float32  # fallback
    logging.info(f"Using precision {dtype}")

    def param_to_weights(param):
        return param.to(dtype)

    checkpoint = OrderedDict()

    hidden_size = model.cfg.hidden_size
    head_num = model.cfg.num_attention_heads
    num_layers = model.cfg.num_layers
    ffn_hidden_size = model.cfg.ffn_hidden_size
    num_query_groups = model.cfg.get("num_query_groups", head_num)  # different num_query_groups for 70B
    if num_query_groups is None:
        num_query_groups = head_num
    heads_per_group = head_num // num_query_groups
    qkv_total_dim = head_num + 2 * num_query_groups

    # Embedding
    embed_weight = model.state_dict()["model.embedding.word_embeddings.weight"]
    embed_weights_base_name = "model.embed_tokens.weight"
    checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)

    for l in range(int(num_layers)):
        print(f"converting layer {l}")

        qkv_weights = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.weight"]
        qkv_weights = qkv_weights.reshape([qkv_total_dim, -1, hidden_size])

        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
        ## Example of slices
        ## (without GQA): num_query_groups = head_num = 32,
        ## q_slice = [0, 3, 6, 9 , ... 90, 93]
        ## k_slice = [1, 4, 7, 10, ... 91, 94]
        ## v_slice = [2, 5, 8, 11, ... 92, 95]
        ## (with GQA): num_query_groups = 8, head_num = 64
        ## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
        ## k_slice = [8, 18, 28, ... 68, 78]
        ## v_slice = [9, 19, 29, ... 69, 79]

        q_weights_base_name = f"model.layers.{l}.self_attn.q_proj.weight"
        k_weights_base_name = f"model.layers.{l}.self_attn.k_proj.weight"
        v_weights_base_name = f"model.layers.{l}.self_attn.v_proj.weight"

        checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
        checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
        checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))

        # attention dense
        o_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_proj.weight"]
        o_weight_base_name = f"model.layers.{l}.self_attn.o_proj.weight"
        checkpoint[o_weight_base_name] = param_to_weights(o_weight)

        # mlp
        mlp_weights = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.weight"]
        mlp_up_proj_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc2.weight"]

        if mlp_weights.shape[0] != mlp_up_proj_weight.shape[1]:
            # Has projection (used for swi-glu)
            logging.warning(
                "Gated projection layers detected in NeMo checkpoint. Currently Nemotron HF does not support gated MLP."
            )
            assert mlp_weights.shape[0] == 2 * mlp_up_proj_weight.shape[1]

            mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :]
            mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :]

            mlp_down_proj_base_name = f"model.layers.{l}.mlp.gate_proj.weight"
            mlp_gate_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"

            checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
            checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
        else:
            mlp_down_proj_weight = mlp_weights
            mlp_down_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
            checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)

        mlp_up_proj_base_name = f"model.layers.{l}.mlp.down_proj.weight"
        checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)

        # layernorm
        input_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight"]
        input_ln_base_name = f"model.layers.{l}.input_layernorm.weight"
        checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
        if (
            model.state_dict().get(f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias", None)
            is not None
        ):
            input_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias"]
            input_ln_bias_name = f"model.layers.{l}.input_layernorm.bias"
            checkpoint[input_ln_bias_name] = param_to_weights(input_ln_bias)

        post_attn_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight"]
        post_attn_ln_base_name = f"model.layers.{l}.post_attention_layernorm.weight"
        checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
        if model.state_dict().get(f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias", None) is not None:
            post_attn_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias"]
            post_attn_ln_bias_name = f"model.layers.{l}.post_attention_layernorm.bias"
            checkpoint[post_attn_ln_bias_name] = param_to_weights(post_attn_ln_bias)

        print(f"done layer {l}")

    final_ln_weight = model.state_dict()["model.decoder.final_layernorm.weight"]
    final_ln_base_name = "model.norm.weight"
    checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
    if model.state_dict().get("model.decoder.final_layernorm.bias", None) is not None:
        final_ln_bias = model.state_dict()["model.decoder.final_layernorm.bias"]
        final_ln_bias_name = "model.norm.bias"
        checkpoint[final_ln_bias_name] = param_to_weights(final_ln_bias)

    output_layer_weight = model.state_dict()["model.output_layer.weight"]
    output_layer_base_name = "lm_head.weight"
    checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)

    os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
    torch.save(checkpoint, output_hf_file)
    logging.info(f"Weights saved to {output_hf_file}")

    return model_config, model.tokenizer, dtype, vocab_size


def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tokenizer):
    tokenizer_cfg = model_config.tokenizer
    if tokenizer_cfg.library == "sentencepiece":
        # For a sentencepiece tokenizer, we wrap it with HF's LlamaTokenizer
        # and convert it to a PreTrainedTokenizerFast
        tokenizer_fn = tokenizer_cfg.model[5:]
        output_tokenizer = f"{output_hf_path}/tokenizer.model"
        if nemo_file.endswith(".nemo"):
            import tarfile

            archive = tarfile.open(nemo_file, "r")
            tokenizer_filename = "./" + tokenizer_fn  # exclude 'nemo:' prefix
            archive.extract(tokenizer_filename, output_hf_path)
            archive.close()
            os.rename(f"{output_hf_path}/{tokenizer_fn}", output_tokenizer)
        elif os.path.isdir(nemo_file):
            shutil.copy(f"{nemo_file}/{tokenizer_fn}", output_tokenizer)
        # We use LlamaTokenizer for sentencepiece based tokenizer
        tokenizer = LlamaTokenizer.from_pretrained(output_hf_path, legacy=False)
        # Convert the LlamaTokenizer to a PreTrainedTokenizerFast instance
        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=LlamaConverter(tokenizer).converted(), model_input_names=["input_ids", "token_type_ids"]
        )
        tokenizer.save_pretrained(output_hf_path)
        logging.info(f"Sentencepiece tokenizer has been saved to {output_tokenizer}")
    elif isinstance(nemo_tokenizer, AutoTokenizer):
        nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
        logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
    else:
        raise ValueError(f"Unsupported tokenizer type: library: {tokenizer_cfg.library}, type: {tokenizer_cfg.type}")


if __name__ == "__main__":
    args = get_args()
    if not args.hf_output_path:
        assert args.output_path is not None, "Need to provide either output_path or hf_output_path"
    else:
        args.output_path = f"{args.hf_output_path}/pytorch_model.bin"
        logging.info(f"weight will be saved to {args.output_path}")

    nemo_config, nemo_tokenizer, dtype, vocab_size = convert(
        args.input_name_or_path, args.output_path, precision=args.precision, cpu_only=args.cpu_only
    )
    if args.hf_input_path and args.hf_output_path:
        convert_hf_config(nemo_config, nemo_tokenizer, vocab_size, dtype, args.hf_output_path, args.hf_input_path)
        extract_nemotron_tokenizer(args.input_name_or_path, nemo_config, args.hf_output_path, nemo_tokenizer)
    else:
        logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.")
        logging.info(f".bin file is saved to {args.output_path}")
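The trickiest step in the conversion above is un-interleaving the fused Megatron `linear_qkv` weight, where each query group is stored as `[q_0 .. q_{h-1}, k, v]`. The sketch below reproduces just the slice arithmetic from the script for the toy GQA case mentioned in its comments (num_query_groups=8, head_num=64), without needing NeMo or a checkpoint.

```python
import torch

# Sketch: q/k/v slice indices for an interleaved GQA QKV layout
# (64 query heads grouped into 8 groups of 8, same formulas as the script above).
head_num, num_query_groups = 64, 8
heads_per_group = head_num // num_query_groups        # 8
qkv_total_dim = head_num + 2 * num_query_groups       # 80 interleaved head slots

q_slice = torch.cat(
    [
        torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
        for i in range(num_query_groups)
    ]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)

print(q_slice[:10].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
print(k_slice.tolist())       # [8, 18, 28, 38, 48, 58, 68, 78]
print(v_slice.tolist())       # [9, 19, 29, 39, 49, 59, 69, 79]
```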
docs/transformers/src/transformers/models/nllb/tokenization_nllb.py
ADDED
@@ -0,0 +1,394 @@
# coding=utf-8
# Copyright 2022 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm

from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
from ...utils import logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)

SPIECE_UNDERLINE = "▁"

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}


FAIRSEQ_LANGUAGE_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']  # fmt: skip


@requires(backends=("sentencepiece",))
class NllbTokenizer(PreTrainedTokenizer):
    """
    Construct an NLLB tokenizer.

    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
    <tokens> <eos>` for target language documents.

    Examples:

    ```python
    >>> from transformers import NllbTokenizer

    >>> tokenizer = NllbTokenizer.from_pretrained(
    ...     "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="fra_Latn"
    ... )
    >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
    >>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
    >>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
vocab_file (`str`):
|
| 63 |
+
Path to the vocabulary file.
|
| 64 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 65 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
| 66 |
+
|
| 67 |
+
<Tip>
|
| 68 |
+
|
| 69 |
+
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
| 70 |
+
sequence. The token used is the `cls_token`.
|
| 71 |
+
|
| 72 |
+
</Tip>
|
| 73 |
+
|
| 74 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 75 |
+
The end of sequence token.
|
| 76 |
+
|
| 77 |
+
<Tip>
|
| 78 |
+
|
| 79 |
+
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
| 80 |
+
The token used is the `sep_token`.
|
| 81 |
+
|
| 82 |
+
</Tip>
|
| 83 |
+
|
| 84 |
+
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
| 85 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
| 86 |
+
sequence classification or for a text and a question for question answering. It is also used as the last
|
| 87 |
+
token of a sequence built with special tokens.
|
| 88 |
+
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
| 89 |
+
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
| 90 |
+
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
| 91 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 92 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 93 |
+
token instead.
|
| 94 |
+
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
| 95 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 96 |
+
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
| 97 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 98 |
+
modeling. This is the token which the model will try to predict.
|
| 99 |
+
tokenizer_file (`str`, *optional*):
|
| 100 |
+
The path to a tokenizer file to use instead of the vocab file.
|
| 101 |
+
src_lang (`str`, *optional*):
|
| 102 |
+
The language to use as source language for translation.
|
| 103 |
+
tgt_lang (`str`, *optional*):
|
| 104 |
+
The language to use as target language for translation.
|
| 105 |
+
sp_model_kwargs (`Dict[str, str]`):
|
| 106 |
+
Additional keyword arguments to pass to the model initialization.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 110 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 111 |
+
|
| 112 |
+
prefix_tokens: List[int] = []
|
| 113 |
+
suffix_tokens: List[int] = []
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
vocab_file,
|
| 118 |
+
bos_token="<s>",
|
| 119 |
+
eos_token="</s>",
|
| 120 |
+
sep_token="</s>",
|
| 121 |
+
cls_token="<s>",
|
| 122 |
+
unk_token="<unk>",
|
| 123 |
+
pad_token="<pad>",
|
| 124 |
+
mask_token="<mask>",
|
| 125 |
+
tokenizer_file=None,
|
| 126 |
+
src_lang=None,
|
| 127 |
+
tgt_lang=None,
|
| 128 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 129 |
+
additional_special_tokens=None,
|
| 130 |
+
legacy_behaviour=False,
|
| 131 |
+
**kwargs,
|
| 132 |
+
):
|
| 133 |
+
if additional_special_tokens is None:
|
| 134 |
+
additional_special_tokens = FAIRSEQ_LANGUAGE_CODES
|
| 135 |
+
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
|
| 136 |
+
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
|
| 137 |
+
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
|
| 138 |
+
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
|
| 139 |
+
# Mask token behave like a normal word, i.e. include the space before it
|
| 140 |
+
mask_token = (
|
| 141 |
+
AddedToken(mask_token, normalized=True, lstrip=True, special=True)
|
| 142 |
+
if isinstance(mask_token, str)
|
| 143 |
+
else mask_token
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 147 |
+
self.legacy_behaviour = legacy_behaviour
|
| 148 |
+
|
| 149 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 150 |
+
self.sp_model.Load(str(vocab_file))
|
| 151 |
+
self.vocab_file = vocab_file
|
| 152 |
+
# Original fairseq vocab and spm vocab must be "aligned":
|
| 153 |
+
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
|
| 154 |
+
# -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
|
| 155 |
+
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a'
|
| 156 |
+
# spm | '<unk>' | '<s>' | '</s>' | 'an' | '▁n' | '▁m' | '▁t' | '▁k' | '▁a' | '▁s'
|
| 157 |
+
|
| 158 |
+
# unk token needs to be in the vocab with correct index
|
| 159 |
+
self._added_tokens_decoder = {0: bos_token, 1: pad_token, 2: eos_token, 3: unk_token}
|
| 160 |
+
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
| 161 |
+
self.fairseq_offset = 1
|
| 162 |
+
self.sp_model_size = len(self.sp_model)
|
| 163 |
+
|
| 164 |
+
super().__init__(
|
| 165 |
+
bos_token=bos_token,
|
| 166 |
+
eos_token=eos_token,
|
| 167 |
+
unk_token=unk_token,
|
| 168 |
+
sep_token=sep_token,
|
| 169 |
+
cls_token=cls_token,
|
| 170 |
+
pad_token=pad_token,
|
| 171 |
+
mask_token=mask_token,
|
| 172 |
+
tokenizer_file=tokenizer_file,
|
| 173 |
+
src_lang=src_lang,
|
| 174 |
+
tgt_lang=tgt_lang,
|
| 175 |
+
additional_special_tokens=additional_special_tokens,
|
| 176 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
| 177 |
+
legacy_behaviour=legacy_behaviour,
|
| 178 |
+
**kwargs,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
self._src_lang = src_lang if src_lang is not None else "eng_Latn"
|
| 182 |
+
self.cur_lang_code_id = self.convert_tokens_to_ids(self._src_lang)
|
| 183 |
+
self.tgt_lang = tgt_lang
|
| 184 |
+
self.set_src_lang_special_tokens(self._src_lang)
|
| 185 |
+
|
| 186 |
+
def __getstate__(self):
|
| 187 |
+
state = self.__dict__.copy()
|
| 188 |
+
state["sp_model"] = None
|
| 189 |
+
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
| 190 |
+
return state
|
| 191 |
+
|
| 192 |
+
def __setstate__(self, d):
|
| 193 |
+
self.__dict__ = d
|
| 194 |
+
|
| 195 |
+
# for backward compatibility
|
| 196 |
+
if not hasattr(self, "sp_model_kwargs"):
|
| 197 |
+
self.sp_model_kwargs = {}
|
| 198 |
+
|
| 199 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 200 |
+
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
| 201 |
+
|
| 202 |
+
@property
|
| 203 |
+
def vocab_size(self):
|
| 204 |
+
return len(self.sp_model) + self.fairseq_offset
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def src_lang(self) -> str:
|
| 208 |
+
return self._src_lang
|
| 209 |
+
|
| 210 |
+
@src_lang.setter
|
| 211 |
+
def src_lang(self, new_src_lang: str) -> None:
|
| 212 |
+
self._src_lang = new_src_lang
|
| 213 |
+
self.set_src_lang_special_tokens(self._src_lang)
|
| 214 |
+
|
| 215 |
+
def get_special_tokens_mask(
|
| 216 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 217 |
+
) -> List[int]:
|
| 218 |
+
"""
|
| 219 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 220 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
token_ids_0 (`List[int]`):
|
| 224 |
+
List of IDs.
|
| 225 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 226 |
+
Optional second list of IDs for sequence pairs.
|
| 227 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 228 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
if already_has_special_tokens:
|
| 235 |
+
return super().get_special_tokens_mask(
|
| 236 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
prefix_ones = [1] * len(self.prefix_tokens)
|
| 240 |
+
suffix_ones = [1] * len(self.suffix_tokens)
|
| 241 |
+
if token_ids_1 is None:
|
| 242 |
+
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
| 243 |
+
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
| 244 |
+
|
| 245 |
+
def build_inputs_with_special_tokens(
|
| 246 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 247 |
+
) -> List[int]:
|
| 248 |
+
"""
|
| 249 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
| 250 |
+
adding special tokens. An NLLB sequence has the following format, where `X` represents the sequence:
|
| 251 |
+
|
| 252 |
+
- `input_ids` (for encoder) `X [eos, src_lang_code]`
|
| 253 |
+
- `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
|
| 254 |
+
|
| 255 |
+
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
| 256 |
+
separator.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
token_ids_0 (`List[int]`):
|
| 260 |
+
List of IDs to which the special tokens will be added.
|
| 261 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 262 |
+
Optional second list of IDs for sequence pairs.
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
| 266 |
+
"""
|
| 267 |
+
if token_ids_1 is None:
|
| 268 |
+
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
| 269 |
+
# We don't expect to process pairs, but leave the pair logic for API consistency
|
| 270 |
+
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
| 271 |
+
|
| 272 |
+
def create_token_type_ids_from_sequences(
|
| 273 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 274 |
+
) -> List[int]:
|
| 275 |
+
"""
|
| 276 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. nllb does not
|
| 277 |
+
make use of token type ids, therefore a list of zeros is returned.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
token_ids_0 (`List[int]`):
|
| 281 |
+
List of IDs.
|
| 282 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 283 |
+
Optional second list of IDs for sequence pairs.
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
`List[int]`: List of zeros.
|
| 287 |
+
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
sep = [self.sep_token_id]
|
| 291 |
+
cls = [self.cls_token_id]
|
| 292 |
+
|
| 293 |
+
if token_ids_1 is None:
|
| 294 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 295 |
+
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
| 296 |
+
|
| 297 |
+
def _build_translation_inputs(
|
| 298 |
+
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
|
| 299 |
+
):
|
| 300 |
+
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
| 301 |
+
if src_lang is None or tgt_lang is None:
|
| 302 |
+
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
| 303 |
+
self.src_lang = src_lang
|
| 304 |
+
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
| 305 |
+
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
| 306 |
+
inputs["forced_bos_token_id"] = tgt_lang_id
|
| 307 |
+
return inputs
|
| 308 |
+
|
| 309 |
+
def get_vocab(self):
|
| 310 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 311 |
+
vocab.update(self.added_tokens_encoder)
|
| 312 |
+
return vocab
|
| 313 |
+
|
| 314 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 315 |
+
return self.sp_model.encode(text, out_type=str)
|
| 316 |
+
|
| 317 |
+
def _convert_token_to_id(self, token):
|
| 318 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 319 |
+
spm_id = self.sp_model.PieceToId(token)
|
| 320 |
+
# Need to return unknown token if the SP model returned 0
|
| 321 |
+
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
|
| 322 |
+
|
| 323 |
+
def _convert_id_to_token(self, index):
|
| 324 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 325 |
+
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
| 326 |
+
|
| 327 |
+
def convert_tokens_to_string(self, tokens):
|
| 328 |
+
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
| 329 |
+
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
| 330 |
+
return out_string
|
| 331 |
+
|
| 332 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 333 |
+
if not os.path.isdir(save_directory):
|
| 334 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 335 |
+
return
|
| 336 |
+
out_vocab_file = os.path.join(
|
| 337 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 341 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 342 |
+
elif not os.path.isfile(self.vocab_file):
|
| 343 |
+
with open(out_vocab_file, "wb") as fi:
|
| 344 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 345 |
+
fi.write(content_spiece_model)
|
| 346 |
+
|
| 347 |
+
return (out_vocab_file,)
|
| 348 |
+
|
| 349 |
+
def prepare_seq2seq_batch(
|
| 350 |
+
self,
|
| 351 |
+
src_texts: List[str],
|
| 352 |
+
src_lang: str = "eng_Latn",
|
| 353 |
+
tgt_texts: Optional[List[str]] = None,
|
| 354 |
+
tgt_lang: str = "fra_Latn",
|
| 355 |
+
**kwargs,
|
| 356 |
+
) -> BatchEncoding:
|
| 357 |
+
self.src_lang = src_lang
|
| 358 |
+
self.tgt_lang = tgt_lang
|
| 359 |
+
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
| 360 |
+
|
| 361 |
+
def _switch_to_input_mode(self):
|
| 362 |
+
return self.set_src_lang_special_tokens(self.src_lang)
|
| 363 |
+
|
| 364 |
+
def _switch_to_target_mode(self):
|
| 365 |
+
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
| 366 |
+
|
| 367 |
+
def set_src_lang_special_tokens(self, src_lang) -> None:
|
| 368 |
+
"""Reset the special tokens to the source lang setting.
|
| 369 |
+
- In legacy mode: No prefix and suffix=[eos, src_lang_code].
|
| 370 |
+
- In default mode: Prefix=[src_lang_code], suffix = [eos]
|
| 371 |
+
"""
|
| 372 |
+
self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
|
| 373 |
+
if self.legacy_behaviour:
|
| 374 |
+
self.prefix_tokens = []
|
| 375 |
+
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
| 376 |
+
else:
|
| 377 |
+
self.prefix_tokens = [self.cur_lang_code]
|
| 378 |
+
self.suffix_tokens = [self.eos_token_id]
|
| 379 |
+
|
| 380 |
+
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
| 381 |
+
"""Reset the special tokens to the target lang setting.
|
| 382 |
+
- In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
|
| 383 |
+
- In default mode: Prefix=[tgt_lang_code], suffix = [eos]
|
| 384 |
+
"""
|
| 385 |
+
self.cur_lang_code = self.convert_tokens_to_ids(lang)
|
| 386 |
+
if self.legacy_behaviour:
|
| 387 |
+
self.prefix_tokens = []
|
| 388 |
+
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
| 389 |
+
else:
|
| 390 |
+
self.prefix_tokens = [self.cur_lang_code]
|
| 391 |
+
self.suffix_tokens = [self.eos_token_id]
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
__all__ = ["NllbTokenizer"]
|
docs/transformers/src/transformers/models/nllb_moe/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_nllb_moe import *
|
| 22 |
+
from .modeling_nllb_moe import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/src/transformers/models/nllb_moe/configuration_nllb_moe.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023, HuggingFace Inc.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""NLLB-MoE model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class NllbMoeConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`NllbMoeModel`]. It is used to instantiate an
|
| 27 |
+
NLLB-MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 28 |
+
with the defaults will yield a similar configuration to that of the NLLB-MoE
|
| 29 |
+
[facebook/nllb-moe-54b](https://huggingface.co/facebook/nllb-moe-54b) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size (`int`, *optional*, defaults to 50265):
|
| 37 |
+
Vocabulary size of the NllbMoe model. Defines the number of different tokens that can be represented by the
|
| 38 |
+
`inputs_ids` passed when calling [`NllbMoeModel`] or
|
| 39 |
+
d_model (`int`, *optional*, defaults to 1024):
|
| 40 |
+
Dimensionality of the layers and the pooler layer.
|
| 41 |
+
encoder_layers (`int`, *optional*, defaults to 12):
|
| 42 |
+
Number of encoder layers.
|
| 43 |
+
decoder_layers (`int`, *optional*, defaults to 12):
|
| 44 |
+
Number of decoder layers.
|
| 45 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 46 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 47 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
| 48 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 49 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
| 50 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
| 51 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
| 52 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
| 53 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 54 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 55 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 56 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
| 57 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 58 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 59 |
+
The dropout ratio for the attention probabilities.
|
| 60 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
| 61 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 62 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
| 63 |
+
The dropout ratio for classifier.
|
| 64 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
| 65 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 66 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 67 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
| 68 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 69 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
| 70 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 71 |
+
for more details.
|
| 72 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
| 73 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
| 74 |
+
for more details.
|
| 75 |
+
second_expert_policy ( `str`, *optional*, default to `"all"`):
|
| 76 |
+
The policy used for the sampling the probability of being sampled to a second expert for each token.
|
| 77 |
+
normalize_router_prob_before_dropping (`bool`, *optional*, defaults to `True`):
|
| 78 |
+
Whether or not to normalize the router probabilities before applying a mask based on the experts capacity
|
| 79 |
+
(capacity dropping).
|
| 80 |
+
batch_prioritized_routing (`bool`, *optional*, defaults to `True`):
|
| 81 |
+
Whether or not to orders the tokens by their router probabilities before capacity dropping. This means that
|
| 82 |
+
the tokens that have the highest probabilities will be routed before other tokens that might be further in
|
| 83 |
+
the sequence.
|
| 84 |
+
moe_eval_capacity_token_fraction (`float`, *optional*, defaults to 1.0):
|
| 85 |
+
Fraction of tokens as capacity during validation, if set to negative, uses the same as training. Should be
|
| 86 |
+
in range: (0.0, 1.0].
|
| 87 |
+
num_experts (`int`, *optional*, defaults to 128):
|
| 88 |
+
Number of experts for each NllbMoeSparseMlp layer.
|
| 89 |
+
expert_capacity (`int`, *optional*, defaults to 64):
|
| 90 |
+
Number of tokens that can be stored in each expert.
|
| 91 |
+
encoder_sparse_step (`int`, *optional*, defaults to 4):
|
| 92 |
+
Frequency of the sparse layers in the encoder. 4 means that one out of 4 layers will be sparse.
|
| 93 |
+
decoder_sparse_step (`int`, *optional*, defaults to 4):
|
| 94 |
+
Frequency of the sparse layers in the decoder. 4 means that one out of 4 layers will be sparse.
|
| 95 |
+
router_dtype (`str`, *optional*, default to `"float32"`):
|
| 96 |
+
The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
|
| 97 |
+
*selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
|
| 98 |
+
router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
|
| 99 |
+
Whether to ignore padding tokens when routing. if `False`, the padding tokens are not routed to any
|
| 100 |
+
experts.
|
| 101 |
+
router_bias (`bool`, *optional*, defaults to `False`):
|
| 102 |
+
Whether or not the classifier of the router should have a bias.
|
| 103 |
+
moe_token_dropout (`float`, *optional*, default to 0.2):
|
| 104 |
+
Masking rate for MoE expert output masking (EOM), which is implemented via a Dropout2d on the expert
|
| 105 |
+
outputs.
|
| 106 |
+
output_router_logits (`bool`, *optional*, defaults to `False`):
|
| 107 |
+
Whether or not to return the router logits. Only set to `True` to get the auxiliary loss when training.
|
| 108 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 109 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
| 110 |
+
|
| 111 |
+
Example:
|
| 112 |
+
|
| 113 |
+
```python
|
| 114 |
+
>>> from transformers import NllbMoeModel, NllbMoeConfig
|
| 115 |
+
|
| 116 |
+
>>> # Initializing a NllbMoe facebook/nllb-moe-54b style configuration
|
| 117 |
+
>>> configuration = NllbMoeConfig()
|
| 118 |
+
|
| 119 |
+
>>> # Initializing a model from the facebook/nllb-moe-54b style configuration
|
| 120 |
+
>>> model = NllbMoeModel(configuration)
|
| 121 |
+
|
| 122 |
+
>>> # Accessing the model configuration
|
| 123 |
+
>>> configuration = model.config
|
| 124 |
+
```"""
|
| 125 |
+
|
| 126 |
+
model_type = "nllb-moe"
|
| 127 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 128 |
+
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
| 129 |
+
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
vocab_size=128112,
|
| 133 |
+
max_position_embeddings=1024,
|
| 134 |
+
encoder_layers=12,
|
| 135 |
+
encoder_ffn_dim=4096,
|
| 136 |
+
encoder_attention_heads=16,
|
| 137 |
+
decoder_layers=12,
|
| 138 |
+
decoder_ffn_dim=4096,
|
| 139 |
+
decoder_attention_heads=16,
|
| 140 |
+
encoder_layerdrop=0.05,
|
| 141 |
+
decoder_layerdrop=0.05,
|
| 142 |
+
use_cache=True,
|
| 143 |
+
is_encoder_decoder=True,
|
| 144 |
+
activation_function="relu",
|
| 145 |
+
d_model=1024,
|
| 146 |
+
dropout=0.1,
|
| 147 |
+
attention_dropout=0.1,
|
| 148 |
+
activation_dropout=0.0,
|
| 149 |
+
init_std=0.02,
|
| 150 |
+
decoder_start_token_id=2,
|
| 151 |
+
scale_embedding=True,
|
| 152 |
+
router_bias=False,
|
| 153 |
+
router_dtype="float32",
|
| 154 |
+
router_ignore_padding_tokens=False,
|
| 155 |
+
num_experts=128,
|
| 156 |
+
expert_capacity=64,
|
| 157 |
+
encoder_sparse_step=4,
|
| 158 |
+
decoder_sparse_step=4,
|
| 159 |
+
router_z_loss_coef=0.001,
|
| 160 |
+
router_aux_loss_coef=0.001,
|
| 161 |
+
second_expert_policy="all",
|
| 162 |
+
normalize_router_prob_before_dropping=False,
|
| 163 |
+
batch_prioritized_routing=False,
|
| 164 |
+
moe_eval_capacity_token_fraction=1.0,
|
| 165 |
+
moe_token_dropout=0.2,
|
| 166 |
+
pad_token_id=1,
|
| 167 |
+
bos_token_id=0,
|
| 168 |
+
eos_token_id=2,
|
| 169 |
+
output_router_logits=False,
|
| 170 |
+
**kwargs,
|
| 171 |
+
):
|
| 172 |
+
self.vocab_size = vocab_size
|
| 173 |
+
self.max_position_embeddings = max_position_embeddings
|
| 174 |
+
self.d_model = d_model
|
| 175 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 176 |
+
self.encoder_layers = encoder_layers
|
| 177 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 178 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 179 |
+
self.decoder_layers = decoder_layers
|
| 180 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 181 |
+
self.dropout = dropout
|
| 182 |
+
self.attention_dropout = attention_dropout
|
| 183 |
+
self.activation_dropout = activation_dropout
|
| 184 |
+
self.activation_function = activation_function
|
| 185 |
+
self.init_std = init_std
|
| 186 |
+
self.encoder_layerdrop = encoder_layerdrop
|
| 187 |
+
self.decoder_layerdrop = decoder_layerdrop
|
| 188 |
+
self.use_cache = use_cache
|
| 189 |
+
self.num_hidden_layers = encoder_layers
|
| 190 |
+
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
| 191 |
+
self.router_z_loss_coef = router_z_loss_coef
|
| 192 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
| 193 |
+
self.decoder_sparse_step = decoder_sparse_step
|
| 194 |
+
self.encoder_sparse_step = encoder_sparse_step
|
| 195 |
+
self.num_experts = num_experts
|
| 196 |
+
self.expert_capacity = expert_capacity
|
| 197 |
+
self.router_bias = router_bias
|
| 198 |
+
if router_dtype not in ["float32", "float16", "bfloat16"]:
|
| 199 |
+
raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}")
|
| 200 |
+
self.router_dtype = router_dtype
|
| 201 |
+
|
| 202 |
+
self.router_ignore_padding_tokens = router_ignore_padding_tokens
|
| 203 |
+
self.batch_prioritized_routing = batch_prioritized_routing
|
| 204 |
+
self.second_expert_policy = second_expert_policy
|
| 205 |
+
self.normalize_router_prob_before_dropping = normalize_router_prob_before_dropping
|
| 206 |
+
self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
|
| 207 |
+
self.moe_token_dropout = moe_token_dropout
|
| 208 |
+
self.output_router_logits = output_router_logits
|
| 209 |
+
super().__init__(
|
| 210 |
+
pad_token_id=pad_token_id,
|
| 211 |
+
bos_token_id=bos_token_id,
|
| 212 |
+
eos_token_id=eos_token_id,
|
| 213 |
+
is_encoder_decoder=is_encoder_decoder,
|
| 214 |
+
decoder_start_token_id=decoder_start_token_id,
|
| 215 |
+
**kwargs,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
__all__ = ["NllbMoeConfig"]
|
docs/transformers/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
from transformers import NllbMoeConfig, NllbMoeModel
|
| 22 |
+
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def remove_ignore_keys_(state_dict):
|
| 26 |
+
ignore_keys = [
|
| 27 |
+
"encoder.version",
|
| 28 |
+
"decoder.version",
|
| 29 |
+
"model.encoder.version",
|
| 30 |
+
"model.decoder.version",
|
| 31 |
+
"decoder.output_projection.weight",
|
| 32 |
+
"_float_tensor",
|
| 33 |
+
"encoder.embed_positions._float_tensor",
|
| 34 |
+
"decoder.embed_positions._float_tensor",
|
| 35 |
+
]
|
| 36 |
+
for k in ignore_keys:
|
| 37 |
+
state_dict.pop(k, None)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def make_linear_from_emb(emb):
|
| 41 |
+
vocab_size, emb_size = emb.weight.shape
|
| 42 |
+
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
| 43 |
+
lin_layer.weight.data = emb.weight.data
|
| 44 |
+
return lin_layer
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def rename_fairseq_keys(state_dict, expert_idx=None):
|
| 48 |
+
new_dict = {}
|
| 49 |
+
for old_key in state_dict.keys():
|
| 50 |
+
key = old_key
|
| 51 |
+
if "moe_layer.experts." in key:
|
| 52 |
+
if expert_idx is not None:
|
| 53 |
+
key = key.replace("moe_layer.experts.0", f"ffn.experts.expert_{expert_idx}")
|
| 54 |
+
else:
|
| 55 |
+
key = key.replace("moe_layer.experts.", "ffn.experts.expert_")
|
| 56 |
+
if "gate" in key:
|
| 57 |
+
key = key.replace(".moe_layer.gate.wg", ".ffn.router.classifier")
|
| 58 |
+
if "fc2" and "experts" not in key:
|
| 59 |
+
key = key.replace(".fc2.", ".ffn.fc2.")
|
| 60 |
+
if "fc1" and "experts" not in key:
|
| 61 |
+
key = key.replace(".fc1.", ".ffn.fc1.")
|
| 62 |
+
if ".encoder_attn." in key:
|
| 63 |
+
key = key.replace(".encoder_attn.", ".cross_attention.")
|
| 64 |
+
if "encoder_attn_layer_norm" in key:
|
| 65 |
+
key = key.replace("encoder_attn_layer_norm", "cross_attention_layer_norm")
|
| 66 |
+
if "final_layer_norm" in key:
|
| 67 |
+
key = key.replace("final_layer_norm", "ff_layer_norm")
|
| 68 |
+
new_dict[key] = state_dict[old_key]
|
| 69 |
+
return new_dict
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weights_name: str = WEIGHTS_NAME):
|
| 73 |
+
sharded_state_dicts = []
|
| 74 |
+
total_size = 0
|
| 75 |
+
os.makedirs(dump_path, exist_ok=True)
|
| 76 |
+
|
| 77 |
+
for expert in range(num_experts):
|
| 78 |
+
expert_path = switch_checkpoint_path + f"-rank-{expert}.pt"
|
| 79 |
+
if os.path.isfile(expert_path):
|
| 80 |
+
expert_state = torch.load(expert_path, weights_only=True)["model"]
|
| 81 |
+
remove_ignore_keys_(expert_state)
|
| 82 |
+
expert_state = rename_fairseq_keys(expert_state, expert)
|
| 83 |
+
save_path = os.path.join(
|
| 84 |
+
dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin")
|
| 85 |
+
)
|
| 86 |
+
torch.save(expert_state, save_path)
|
| 87 |
+
sharded_state_dicts.append(expert_state.keys())
|
| 88 |
+
total_size += sum([value.numel() for key, value in expert_state.items()]) * (
|
| 89 |
+
expert_state[list(expert_state)[0]].element_size()
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Add the last block
|
| 93 |
+
save_path = os.path.join(
|
| 94 |
+
dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts) + 1:05d}-of-???.bin")
|
| 95 |
+
)
|
| 96 |
+
shared_weights = torch.load(switch_checkpoint_path + "-shared.pt", weights_only=True)["model"]
|
| 97 |
+
remove_ignore_keys_(shared_weights)
|
| 98 |
+
shared_weights = rename_fairseq_keys(shared_weights, None)
|
| 99 |
+
shared_weights["shared.weight"] = shared_weights["decoder.embed_tokens.weight"]
|
| 100 |
+
sharded_state_dicts.append(shared_weights.keys())
|
| 101 |
+
|
| 102 |
+
# If we only have the shared weights (dummy model/experts saved on the same file)
|
| 103 |
+
if len(sharded_state_dicts) == 1:
|
| 104 |
+
save_path = os.path.join(dump_path, weights_name)
|
| 105 |
+
torch.save(shared_weights, save_path)
|
| 106 |
+
return {weights_name: sharded_state_dicts[0]}, None
|
| 107 |
+
else:
|
| 108 |
+
torch.save(shared_weights, save_path)
|
| 109 |
+
# Otherwise, let's build the index
|
| 110 |
+
weight_map = {}
|
| 111 |
+
for idx, shard in enumerate(sharded_state_dicts):
|
| 112 |
+
shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin")
|
| 113 |
+
temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx + 1:05d}-of-???.bin"))
|
| 114 |
+
os.rename(temp_filename, os.path.join(dump_path, shard_file))
|
| 115 |
+
for key in shard:
|
| 116 |
+
weight_map[key] = shard_file
|
| 117 |
+
|
| 118 |
+
# Add the metadata
|
| 119 |
+
metadata = {"total_size": total_size}
|
| 120 |
+
index = {"metadata": metadata, "weight_map": weight_map}
|
| 121 |
+
|
| 122 |
+
with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
| 123 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
| 124 |
+
f.write(content)
|
| 125 |
+
|
| 126 |
+
return metadata, index
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
parser = argparse.ArgumentParser()
|
| 131 |
+
# Required parameters
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--nllb_moe_checkpoint_path",
|
| 134 |
+
default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/model_moe_54b/checkpoint_2_300000",
|
| 135 |
+
type=str,
|
| 136 |
+
required=False,
|
| 137 |
+
help="Path to a directory containing a folder per layer. Follows the original Google format.",
|
| 138 |
+
)
|
| 139 |
+
parser.add_argument("--dtype", default="float32", type=str, required=False, help="dtype of the saved model")
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--pytorch_dump_folder_path",
|
| 142 |
+
default="/home/arthur_huggingface_co/fairseq/weights/checkpoints/hf-converted-moe-54b",
|
| 143 |
+
type=str,
|
| 144 |
+
required=False,
|
| 145 |
+
help="Path to the output pytorch model.",
|
| 146 |
+
)
|
| 147 |
+
args = parser.parse_args()
|
| 148 |
+
metadata, index = shard_on_the_fly(
|
| 149 |
+
args.nllb_moe_checkpoint_path,
|
| 150 |
+
args.pytorch_dump_folder_path,
|
| 151 |
+
128,
|
| 152 |
+
args.dtype,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
config = NllbMoeConfig.from_pretrained(
|
| 156 |
+
"facebook/nllb-200-3.3B", encoder_sparse_step=4, decoder_sparse_step=4, num_experts=128
|
| 157 |
+
)
|
| 158 |
+
config.save_pretrained(args.pytorch_dump_folder_path)
|
| 159 |
+
model = NllbMoeModel.from_pretrained(args.pytorch_dump_folder_path)
|
| 160 |
+
print("Done")
|
| 161 |
+
model.save_pretrained(args.pytorch_dump_folder_path)
|
docs/transformers/src/transformers/models/nllb_moe/modeling_nllb_moe.py
ADDED
|
@@ -0,0 +1,1784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2023 NllbMoe Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch NLLB-MoE model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    MoEModelOutput,
    MoEModelOutputWithPastAndCrossAttentions,
    Seq2SeqMoEModelOutput,
    Seq2SeqMoEOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_nllb_moe import NllbMoeConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "NllbMoeConfig"
_CHECKPOINT_FOR_DOC = "hf-internal-testing/dummy-nllb-moe-2-experts"
_REAL_CHECKPOINT_FOR_DOC = "facebook/nllb-moe-54b"


####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################


# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
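
# Illustrative sketch (hypothetical helper, not part of the upstream module): building decoder
# inputs from labels with the function above. The token ids are made up for the example; -100
# marks label positions that the loss ignores.
def _example_shift_tokens_right():
    labels = torch.tensor([[250054, 9, 17, 2, -100]])
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
    # -> tensor([[2, 250054, 9, 17, 2]]): the start token is prepended, everything shifts right,
    #    and any remaining -100 would be replaced by pad_token_id.
    return decoder_input_ids
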
# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token.

    Returns:
        The auxiliary loss.
    """
    if router_probs is None:
        return 0

    num_experts = router_probs.shape[-1]

    # cast the expert indices to int64, otherwise one-hot encoding will fail
    if expert_indices.dtype != torch.int64:
        expert_indices = expert_indices.to(torch.int64)

    if len(expert_indices.shape) == 2:
        expert_indices = expert_indices.unsqueeze(2)

    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)

    # For a given token, determine if it was routed to a given expert.
    expert_mask = torch.max(expert_mask, axis=-2).values

    # cast to float32 otherwise mean will fail
    expert_mask = expert_mask.to(torch.float32)
    tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)

    router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)
    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
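
# Illustrative sketch (hypothetical helper, not part of the upstream module): exercising the
# auxiliary loss above on random router outputs. Shapes follow the docstring: router_probs is
# [batch_size, sequence_length, num_experts] and expert_indices is [batch_size, sequence_length].
# Perfectly balanced routing yields a value close to 1.0; skewed routing is penalized more.
def _example_load_balancing_loss(batch_size: int = 2, sequence_length: int = 8, num_experts: int = 4):
    router_probs = torch.softmax(torch.randn(batch_size, sequence_length, num_experts), dim=-1)
    expert_indices = router_probs.argmax(dim=-1)
    return load_balancing_loss_func(router_probs, expert_indices)
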
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe
class NllbMoeScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with the embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
    ):
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
            # Create the position ids from the input token ids. Any padded tokens remain padded.
            position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
                input_ids.device
            )
        else:
            bsz, seq_len = inputs_embeds.size()[:-1]
            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()

    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
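
# Illustrative sketch (hypothetical helper, not part of the upstream module): the static
# `get_embedding` method above builds the sinusoidal table; each position gets `embedding_dim`
# features and the padding row stays zero.
def _example_sinusoidal_table(num_positions: int = 10, embedding_dim: int = 16, padding_idx: int = 1):
    table = NllbMoeSinusoidalPositionalEmbedding.get_embedding(num_positions, embedding_dim, padding_idx)
    assert table.shape == (num_positions, embedding_dim)
    assert torch.all(table[padding_idx] == 0)
    return table
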
class NllbMoeTop2Router(nn.Module):
    """
    Router that routes each token to its top-2 experts.

    This router uses the same mechanism as in NLLB-MoE from the fairseq repository. Items are sorted by router_probs
    and then routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee
    that each token is processed by an expert**, or that each expert receives at least one token.

    The router combining weights are also returned to make sure that the states that are not updated will be masked.

    """

    def __init__(self, config: NllbMoeConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity
        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
        self.router_ignore_padding_tokens = config.router_ignore_padding_tokens
        self.dtype = getattr(torch, config.router_dtype)

        self.second_expert_policy = config.second_expert_policy
        self.normalize_router_prob_before_dropping = config.normalize_router_prob_before_dropping
        self.batch_prioritized_routing = config.batch_prioritized_routing
        self.moe_eval_capacity_token_fraction = config.moe_eval_capacity_token_fraction

    def _cast_classifier(self):
        r"""
        `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore we need to check if they are an
        instance of the `Linear8bitLt` class by checking special attributes.
        """
        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
            self.classifier = self.classifier.to(self.dtype)

    def normalize_router_probabilities(self, router_probs, top_1_mask, top_2_mask):
        top_1_max_probs = (router_probs * top_1_mask).sum(dim=1)
        top_2_max_probs = (router_probs * top_2_mask).sum(dim=1)
        denom_s = torch.clamp(top_1_max_probs + top_2_max_probs, min=torch.finfo(router_probs.dtype).eps)
        top_1_max_probs = top_1_max_probs / denom_s
        top_2_max_probs = top_2_max_probs / denom_s
        return top_1_max_probs, top_2_max_probs

    def route_tokens(
        self,
        router_logits: torch.Tensor,
        input_dtype: torch.dtype = torch.float32,
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple:
        """
        Computes the `dispatch_mask` and the `dispatch_weights` for each expert. The masks are adapted to the expert
        capacity.
        """
        nb_tokens = router_logits.shape[0]
        # Apply Softmax and cast back to the original `dtype`
        router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(input_dtype)
        top_1_expert_index = torch.argmax(router_probs, dim=-1)
        top_1_mask = torch.nn.functional.one_hot(top_1_expert_index, num_classes=self.num_experts)

        if self.second_expert_policy == "sampling":
            gumbel = torch.distributions.gumbel.Gumbel(0, 1).rsample
            router_logits += gumbel(router_logits.shape).to(router_logits.device)

        # replace top_1_expert_index with min values
        logits_except_top_1 = router_logits.masked_fill(top_1_mask.bool(), float("-inf"))
        top_2_expert_index = torch.argmax(logits_except_top_1, dim=-1)
        top_2_mask = torch.nn.functional.one_hot(top_2_expert_index, num_classes=self.num_experts)

        if self.normalize_router_prob_before_dropping:
            top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities(
                router_probs, top_1_mask, top_2_mask
            )

        if self.second_expert_policy == "random":
            top_2_max_probs = (router_probs * top_2_mask).sum(dim=1)
            sampled = (2 * top_2_max_probs) > torch.rand_like(top_2_max_probs.float())
            top_2_mask = top_2_mask * sampled.repeat(self.num_experts, 1).transpose(1, 0)

        if padding_mask is not None and not self.router_ignore_padding_tokens:
            if len(padding_mask.shape) == 4:
                # only get the last causal mask
                padding_mask = padding_mask[:, :, -1, :].reshape(-1)[-nb_tokens:]
            non_padding = ~padding_mask.bool()
            top_1_mask = top_1_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype)
            top_2_mask = top_2_mask * non_padding.unsqueeze(-1).to(top_1_mask.dtype)

        if self.batch_prioritized_routing:
            # sort tokens based on their routing probability
            # to make sure important tokens are routed, first
            importance_scores = -1 * router_probs.max(dim=1)[0]
            sorted_top_1_mask = top_1_mask[importance_scores.argsort(dim=0)]
            sorted_cumsum1 = (torch.cumsum(sorted_top_1_mask, dim=0) - 1) * sorted_top_1_mask
            locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]

            sorted_top_2_mask = top_2_mask[importance_scores.argsort(dim=0)]
            sorted_cumsum2 = (torch.cumsum(sorted_top_2_mask, dim=0) - 1) * sorted_top_2_mask
            locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
            # Update 2nd's location by accounting for locations of 1st
            locations2 += torch.sum(top_1_mask, dim=0, keepdim=True)

        else:
            locations1 = torch.cumsum(top_1_mask, dim=0) - 1
            locations2 = torch.cumsum(top_2_mask, dim=0) - 1
            # Update 2nd's location by accounting for locations of 1st
            locations2 += torch.sum(top_1_mask, dim=0, keepdim=True)

        if not self.training and self.moe_eval_capacity_token_fraction > 0:
            self.expert_capacity = math.ceil(self.moe_eval_capacity_token_fraction * nb_tokens)
        else:
            capacity = 2 * math.ceil(nb_tokens / self.num_experts)
            self.expert_capacity = capacity if self.expert_capacity is None else self.expert_capacity

        # Remove locations outside capacity from ( cumsum < capacity = False will not be routed)
        top_1_mask = top_1_mask * torch.lt(locations1, self.expert_capacity)
        top_2_mask = top_2_mask * torch.lt(locations2, self.expert_capacity)

        if not self.normalize_router_prob_before_dropping:
            top_1_max_probs, top_2_max_probs = self.normalize_router_probabilities(
                router_probs, top_1_mask, top_2_mask
            )

        # Calculate combine_weights and dispatch_mask
        gates1 = top_1_max_probs[:, None] * top_1_mask
        gates2 = top_2_max_probs[:, None] * top_2_mask
        router_probs = gates1 + gates2

        return top_1_mask, router_probs

    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.LongTensor] = None) -> Tuple:
        r"""
        The hidden states are reshaped to simplify the computation of the router probabilities (combining weights for
        each expert).

        Args:
            hidden_states (`torch.Tensor`):
                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
        Returns:
            top_1_mask (`torch.Tensor` of shape (batch_size, sequence_length)):
                Index tensor of shape [batch_size, sequence_length] corresponding to the expert selected for each token
                using the top1 probabilities of the router.
            router_probabilities (`torch.Tensor` of shape (batch_size, sequence_length, num_experts)):
                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
                token and expert. Used for routing tokens to experts.
            router_logits (`torch.Tensor` of shape (batch_size, sequence_length)):
                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
                This is used later for computing router z-loss.
        """
        self.input_dtype = hidden_states.dtype
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim)
        hidden_states = hidden_states.to(self.dtype)
        self._cast_classifier()
        router_logits = self.classifier(hidden_states)
        top_1_mask, router_probs = self.route_tokens(router_logits, self.input_dtype, padding_mask)
        return top_1_mask, router_probs
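
# Illustrative sketch (hypothetical helper, assuming a minimal NllbMoeConfig): routing a small
# batch of hidden states and inspecting the combining weights produced for each token.
def _example_top_2_routing():
    config = NllbMoeConfig(d_model=16, num_experts=4, expert_capacity=64)
    router = NllbMoeTop2Router(config)
    hidden_states = torch.randn(2, 6, config.d_model)
    top_1_mask, router_probs = router(hidden_states)
    # Both tensors are flattened over tokens: shape [batch_size * sequence_length, num_experts].
    # top_1_mask one-hot encodes each token's first-choice expert (zeroed if dropped for capacity);
    # router_probs holds the combining weights, zero for experts a token was not routed to.
    return top_1_mask, router_probs
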
class NllbMoeDenseActDense(nn.Module):
    def __init__(self, config: NllbMoeConfig, ffn_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, config.d_model)
        self.dropout = nn.Dropout(config.activation_dropout)
        self.act = ACT2FN[config.activation_function]

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        if (
            isinstance(self.fc2.weight, torch.Tensor)
            and hidden_states.dtype != self.fc2.weight.dtype
            and (self.fc2.weight.dtype != torch.int8 and self.fc2.weight.dtype != torch.uint8)
        ):
            hidden_states = hidden_states.to(self.fc2.weight.dtype)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class NllbMoeSparseMLP(nn.Module):
    r"""
    Implementation of the NLLB-MoE sparse MLP module.
    """

    def __init__(self, config: NllbMoeConfig, ffn_dim: int, expert_class: nn.Module = NllbMoeDenseActDense):
        super().__init__()
        self.router = NllbMoeTop2Router(config)
        self.moe_token_dropout = config.moe_token_dropout
        self.token_dropout = nn.Dropout(self.moe_token_dropout)
        self.num_experts = config.num_experts

        self.experts = nn.ModuleDict()
        for idx in range(self.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config, ffn_dim)

    def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = False):
        r"""
        The goal of this forward pass is to have the same number of operations as the equivalent `NllbMoeDenseActDense`
        (mlp) layer. This means that all of the hidden states should be processed at most twice (since we are using a
        top_2 gating mechanism). This means that we keep the complexity to O(batch_size x sequence_length x hidden_dim)
        instead of O(num_experts x batch_size x sequence_length x hidden_dim).

        1- Get the `router_probs` from the `router`. The shape of the `router_mask` is `(batch_size X sequence_length,
        num_expert)` and corresponds to the boolean version of the `router_probs`. The inputs are masked using the
        `router_mask`.

        2- Dispatch the hidden_states to its associated experts. The router probabilities are used to weight the
        contribution of each expert when updating the masked hidden states.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
                The hidden states
            padding_mask (`torch.Tensor`, *optional*, defaults to `False`):
                Attention mask. Can be in the causal form or not.

        Returns:
            hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
                Updated hidden states
            router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`):
                Needed for computing the loss

        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        top_1_mask, router_probs = self.router(hidden_states, padding_mask)
        router_mask = router_probs.bool()
        hidden_states = hidden_states.reshape((batch_size * sequence_length), hidden_dim)
        masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask)
        for idx, expert in enumerate(self.experts.values()):
            token_indices = router_mask[:, idx]
            combining_weights = router_probs[token_indices, idx]
            expert_output = expert(masked_hidden_states[idx, token_indices])
            if self.moe_token_dropout > 0:
                if self.training:
                    expert_output = self.token_dropout(expert_output)
                else:
                    expert_output *= 1 - self.moe_token_dropout
            masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output)
        hidden_states = masked_hidden_states.sum(dim=0).reshape(batch_size, sequence_length, hidden_dim)

        top_1_expert_index = torch.argmax(top_1_mask, dim=-1)
        return hidden_states, (router_probs, top_1_expert_index)
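
# Illustrative sketch (hypothetical helper, assuming a small NllbMoeConfig): dispatching a batch
# through the sparse MLP; the output keeps the input shape and the returned router tuple is what
# feeds the load-balancing loss.
def _example_sparse_mlp():
    config = NllbMoeConfig(d_model=16, num_experts=4, expert_capacity=64, moe_token_dropout=0.0)
    sparse_mlp = NllbMoeSparseMLP(config, ffn_dim=32)
    hidden_states = torch.randn(2, 6, config.d_model)
    updated_states, (router_probs, top_1_expert_index) = sparse_mlp(hidden_states, padding_mask=None)
    return updated_states.shape  # torch.Size([2, 6, 16])
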
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->NllbMoe,key_value_states->encoder_hidden_states
class NllbMoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[NllbMoeConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if encoder_hidden_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = encoder_hidden_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `encoder_hidden_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == encoder_hidden_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
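
# Illustrative sketch (hypothetical helper, not part of the upstream module): the attention module
# above consumes `(batch, seq_len, embed_dim)` states and returns them in the same shape, together
# with the attention weights when requested.
def _example_attention_shapes():
    attention = NllbMoeAttention(embed_dim=16, num_heads=4)
    hidden_states = torch.randn(2, 5, 16)
    attn_output, attn_weights, past_key_value = attention(hidden_states, output_attentions=True)
    # attn_output: [2, 5, 16]; attn_weights: [2, 4, 5, 5]; past_key_value stays None because
    # this instance is not a decoder (is_decoder=False).
    return attn_output.shape, attn_weights.shape
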
class NllbMoeEncoderLayer(nn.Module):
    def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
        super().__init__()
        self.embed_dim = config.d_model
        self.is_sparse = is_sparse
        self.self_attn = NllbMoeAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.attn_dropout = nn.Dropout(config.dropout)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        if not self.is_sparse:
            self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.encoder_ffn_dim)
        else:
            self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.encoder_ffn_dim)
        self.ff_layer_norm = nn.LayerNorm(config.d_model)
        self.ff_dropout = nn.Dropout(config.activation_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        output_router_logits: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`):
                attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
                large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.attn_dropout(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states

        hidden_states = self.ff_layer_norm(hidden_states)
        if self.is_sparse:
            hidden_states, router_states = self.ffn(hidden_states, attention_mask)
        else:
            # router_states set to None to track which layers have None gradients.
            hidden_states, router_states = self.ffn(hidden_states), None

        hidden_states = self.ff_dropout(hidden_states)

        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if output_router_logits:
            outputs += (router_states,)

        return outputs
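
# Illustrative sketch (hypothetical helper, assuming a small NllbMoeConfig): a sparse encoder layer
# returns the router output alongside the hidden states when `output_router_logits=True`.
def _example_sparse_encoder_layer():
    config = NllbMoeConfig(d_model=16, encoder_attention_heads=4, encoder_ffn_dim=32, num_experts=2)
    layer = NllbMoeEncoderLayer(config, is_sparse=True)
    hidden_states = torch.randn(2, 5, config.d_model)
    outputs = layer(hidden_states, attention_mask=None, layer_head_mask=None, output_router_logits=True)
    hidden_states, router_states = outputs
    # hidden_states keeps shape [2, 5, 16]; router_states is the (router_probs, top_1_expert_index) pair.
    return hidden_states.shape, router_states
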
class NllbMoeDecoderLayer(nn.Module):
    def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
        super().__init__()
        self.embed_dim = config.d_model
        self.is_sparse = is_sparse
        self.self_attn = NllbMoeAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.attn_dropout = nn.Dropout(config.dropout)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.cross_attention = NllbMoeAttention(
            self.embed_dim, config.decoder_attention_heads, config.attention_dropout, is_decoder=True
        )
        self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim)
        if not self.is_sparse:
            self.ffn = NllbMoeDenseActDense(config, ffn_dim=config.decoder_ffn_dim)
        else:
            self.ffn = NllbMoeSparseMLP(config, ffn_dim=config.decoder_ffn_dim)
        self.ff_layer_norm = nn.LayerNorm(config.d_model)
        self.ff_dropout = nn.Dropout(config.activation_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`):
                attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
                large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`):
                encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by
                very large negative values.
            layer_head_mask (`torch.FloatTensor`):
                mask for attention heads in a given layer of size `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`):
                mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`):
                cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.attn_dropout(hidden_states)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.cross_attention_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                past_key_value=cross_attn_past_key_value,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                output_attentions=output_attentions,
            )
            hidden_states = self.attn_dropout(hidden_states)
            hidden_states = residual + hidden_states

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value += cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states

        hidden_states = self.ff_layer_norm(hidden_states)
        if self.is_sparse:
            hidden_states, router_states = self.ffn(hidden_states, attention_mask)
        else:
            hidden_states, router_states = self.ffn(hidden_states), None

        hidden_states = self.ff_dropout(hidden_states)

        hidden_states = residual + hidden_states

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states, present_key_value)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if output_router_logits:
            outputs += (router_states,)

        return outputs
class NllbMoePreTrainedModel(PreTrainedModel):
    config_class = NllbMoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


NLLB_MOE_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`NllbMoeConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

NLLB_MOE_GENERATION_EXAMPLE = r"""
    Translation example:

    ```python
    >>> from transformers import AutoTokenizer, NllbMoeForConditionalGeneration

    >>> model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b")

    >>> text_to_translate = "Life is like a box of chocolates"
    >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt")

    >>> # translate to French
    >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("eng_Latn"))
    >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
    ```
"""

NLLB_MOE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            NllbMoe uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
            should not be returned during inference.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class NllbMoeEncoder(NllbMoePreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`NllbMoeEncoderLayer`].

    Args:
        config:
            NllbMoeConfig
        embed_tokens (nn.Embedding):
            output embedding
    """

    def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = NllbMoeScaledWordEmbedding(
            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = NllbMoeSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
            self.padding_idx,
        )
        sparse_step = config.encoder_sparse_step
        self.layers = nn.ModuleList()
        for i in range(config.encoder_layers):
            is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False
            self.layers.append(NllbMoeEncoderLayer(config, is_sparse))

        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
                and should not be returned during inference.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(input_ids, inputs_embeds)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_router_probs = () if output_router_logits else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                        output_router_logits=output_router_logits,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

            if output_router_logits:
                all_router_probs += (layer_outputs[-1],)

        last_hidden_state = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states += (last_hidden_state,)

        if not return_dict:
            return tuple(
                v for v in [last_hidden_state, encoder_states, all_attentions, all_router_probs] if v is not None
            )

        return MoEModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_states,
            attentions=all_attentions,
            router_probs=all_router_probs,
        )
| 1182 |
+
class NllbMoeDecoder(NllbMoePreTrainedModel):
|
| 1183 |
+
"""
|
| 1184 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`NllbMoeDecoderLayer`]
|
| 1185 |
+
|
| 1186 |
+
Args:
|
| 1187 |
+
config:
|
| 1188 |
+
NllbMoeConfig
|
| 1189 |
+
embed_tokens (nn.Embedding):
|
| 1190 |
+
output embedding
|
| 1191 |
+
"""
|
| 1192 |
+
|
| 1193 |
+
def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None):
|
| 1194 |
+
super().__init__(config)
|
| 1195 |
+
self.dropout = config.dropout
|
| 1196 |
+
self.layerdrop = config.decoder_layerdrop
|
| 1197 |
+
self.padding_idx = config.pad_token_id
|
| 1198 |
+
self.max_target_positions = config.max_position_embeddings
|
| 1199 |
+
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 1200 |
+
|
| 1201 |
+
self.embed_tokens = NllbMoeScaledWordEmbedding(
|
| 1202 |
+
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
| 1203 |
+
)
|
| 1204 |
+
|
| 1205 |
+
if embed_tokens is not None:
|
| 1206 |
+
self.embed_tokens.weight = embed_tokens.weight
|
| 1207 |
+
|
| 1208 |
+
self.embed_positions = NllbMoeSinusoidalPositionalEmbedding(
|
| 1209 |
+
config.max_position_embeddings,
|
| 1210 |
+
config.d_model,
|
| 1211 |
+
self.padding_idx,
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
sparse_step = config.decoder_sparse_step
|
| 1215 |
+
self.layers = nn.ModuleList()
|
| 1216 |
+
for i in range(config.decoder_layers):
|
| 1217 |
+
is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False
|
| 1218 |
+
self.layers.append(NllbMoeDecoderLayer(config, is_sparse))
|
| 1219 |
+
|
| 1220 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
| 1221 |
+
|
| 1222 |
+
self.gradient_checkpointing = False
|
| 1223 |
+
# Initialize weights and apply final processing
|
| 1224 |
+
self.post_init()
|
| 1225 |
+
|
| 1226 |
+
def forward(
|
| 1227 |
+
self,
|
| 1228 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1229 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1230 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 1231 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1232 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1233 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 1234 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1235 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1236 |
+
use_cache: Optional[bool] = None,
|
| 1237 |
+
output_attentions: Optional[bool] = None,
|
| 1238 |
+
output_hidden_states: Optional[bool] = None,
|
| 1239 |
+
output_router_logits: Optional[bool] = None,
|
| 1240 |
+
return_dict: Optional[bool] = None,
|
| 1241 |
+
):
|
| 1242 |
+
r"""
|
| 1243 |
+
Args:
|
| 1244 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1245 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
| 1246 |
+
provide it.
|
| 1247 |
+
|
| 1248 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1249 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 1250 |
+
|
| 1251 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1252 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1253 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1254 |
+
|
| 1255 |
+
- 1 for tokens that are **not masked**,
|
| 1256 |
+
- 0 for tokens that are **masked**.
|
| 1257 |
+
|
| 1258 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1259 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
| 1260 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
| 1261 |
+
of the decoder.
|
| 1262 |
+
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
|
| 1263 |
+
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
| 1264 |
+
selected in `[0, 1]`:
|
| 1265 |
+
|
| 1266 |
+
- 1 for tokens that are **not masked**,
|
| 1267 |
+
- 0 for tokens that are **masked**.
|
| 1268 |
+
|
| 1269 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1270 |
+
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
| 1271 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
| 1272 |
+
|
| 1273 |
+
- 1 indicates the head is **not masked**,
|
| 1274 |
+
- 0 indicates the head is **masked**.
|
| 1275 |
+
|
| 1276 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
| 1277 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
| 1278 |
+
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
|
| 1279 |
+
|
| 1280 |
+
- 1 indicates the head is **not masked**,
|
| 1281 |
+
- 0 indicates the head is **masked**.
|
| 1282 |
+
|
| 1283 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 1284 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 1285 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
| 1286 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
| 1287 |
+
|
| 1288 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
| 1289 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
| 1290 |
+
|
| 1291 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
| 1292 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
| 1293 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
| 1294 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 1295 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
| 1296 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
| 1297 |
+
than the model's internal embedding lookup matrix.
|
| 1298 |
+
output_attentions (`bool`, *optional*):
|
| 1299 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 1300 |
+
returned tensors for more detail.
|
| 1301 |
+
output_hidden_states (`bool`, *optional*):
|
| 1302 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
| 1303 |
+
for more detail.
|
| 1304 |
+
output_router_logits (`bool`, *optional*):
|
| 1305 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss,
|
| 1306 |
+
and should not be returned during inference.
|
| 1307 |
+
return_dict (`bool`, *optional*):
|
| 1308 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1309 |
+
"""
|
| 1310 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1311 |
+
output_hidden_states = (
|
| 1312 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1313 |
+
)
|
| 1314 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1315 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 1316 |
+
|
| 1317 |
+
# retrieve input_ids and inputs_embeds
|
| 1318 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 1319 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
| 1320 |
+
elif input_ids is not None:
|
| 1321 |
+
input_shape = input_ids.size()
|
| 1322 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 1323 |
+
elif inputs_embeds is not None:
|
| 1324 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 1325 |
+
else:
|
| 1326 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 1327 |
+
|
| 1328 |
+
# past_key_values_length
|
| 1329 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 1330 |
+
|
| 1331 |
+
if inputs_embeds is None:
|
| 1332 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 1333 |
+
|
| 1334 |
+
# create causal mask
|
| 1335 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1336 |
+
combined_attention_mask = _prepare_4d_causal_attention_mask(
|
| 1337 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 1338 |
+
)
|
| 1339 |
+
|
| 1340 |
+
# expand encoder attention mask
|
| 1341 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
| 1342 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1343 |
+
encoder_attention_mask = _prepare_4d_attention_mask(
|
| 1344 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
| 1345 |
+
)
|
| 1346 |
+
|
| 1347 |
+
# embed positions
|
| 1348 |
+
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
|
| 1349 |
+
positions = positions.to(inputs_embeds.device)
|
| 1350 |
+
|
| 1351 |
+
hidden_states = inputs_embeds + positions
|
| 1352 |
+
|
| 1353 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
| 1354 |
+
|
| 1355 |
+
if self.gradient_checkpointing and self.training:
|
| 1356 |
+
if use_cache:
|
| 1357 |
+
logger.warning_once(
|
| 1358 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 1359 |
+
)
|
| 1360 |
+
use_cache = False
|
| 1361 |
+
|
| 1362 |
+
# decoder layers
|
| 1363 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1364 |
+
all_self_attns = () if output_attentions else None
|
| 1365 |
+
all_router_probs = () if output_router_logits else None
|
| 1366 |
+
all_cross_attentions = () if output_attentions else None
|
| 1367 |
+
present_key_value_states = () if use_cache else None
|
| 1368 |
+
|
| 1369 |
+
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
| 1370 |
+
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
| 1371 |
+
if attn_mask is not None:
|
| 1372 |
+
if attn_mask.size()[0] != len(self.layers):
|
| 1373 |
+
raise ValueError(
|
| 1374 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
| 1375 |
+
f" {head_mask.size()[0]}."
|
| 1376 |
+
)
|
| 1377 |
+
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
| 1378 |
+
|
| 1379 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 1380 |
+
if output_hidden_states:
|
| 1381 |
+
all_hidden_states += (hidden_states,)
|
| 1382 |
+
|
| 1383 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 1384 |
+
dropout_probability = torch.rand([])
|
| 1385 |
+
|
| 1386 |
+
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
|
| 1387 |
+
if not skip_the_layer or synced_gpus:
|
| 1388 |
+
layer_head_mask = head_mask[idx] if head_mask is not None else None
|
| 1389 |
+
cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
| 1390 |
+
|
| 1391 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 1392 |
+
|
| 1393 |
+
# under fsdp or deepspeed zero3 all gpus must run in sync
|
| 1394 |
+
if self.gradient_checkpointing and self.training:
|
| 1395 |
+
if use_cache:
|
| 1396 |
+
logger.warning_once(
|
| 1397 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 1398 |
+
)
|
| 1399 |
+
use_cache = False
|
| 1400 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1401 |
+
decoder_layer.forward,
|
| 1402 |
+
hidden_states,
|
| 1403 |
+
combined_attention_mask,
|
| 1404 |
+
encoder_hidden_states,
|
| 1405 |
+
encoder_attention_mask,
|
| 1406 |
+
layer_head_mask,
|
| 1407 |
+
cross_attn_layer_head_mask,
|
| 1408 |
+
None, # past_key_value is always None with gradient checkpointing
|
| 1409 |
+
use_cache,
|
| 1410 |
+
output_attentions,
|
| 1411 |
+
)
|
| 1412 |
+
else:
|
| 1413 |
+
layer_outputs = decoder_layer(
|
| 1414 |
+
hidden_states,
|
| 1415 |
+
attention_mask=combined_attention_mask,
|
| 1416 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1417 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1418 |
+
layer_head_mask=layer_head_mask,
|
| 1419 |
+
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
|
| 1420 |
+
past_key_value=past_key_value,
|
| 1421 |
+
use_cache=use_cache,
|
| 1422 |
+
output_attentions=output_attentions,
|
| 1423 |
+
output_router_logits=output_router_logits,
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
hidden_states = layer_outputs[0]
|
| 1427 |
+
|
| 1428 |
+
if skip_the_layer:
|
| 1429 |
+
continue
|
| 1430 |
+
|
| 1431 |
+
if use_cache:
|
| 1432 |
+
present_key_value_states += (layer_outputs[1],)
|
| 1433 |
+
|
| 1434 |
+
if output_attentions:
|
| 1435 |
+
all_self_attns += (layer_outputs[2],)
|
| 1436 |
+
all_cross_attentions += (layer_outputs[3],)
|
| 1437 |
+
|
| 1438 |
+
if output_router_logits:
|
| 1439 |
+
all_router_probs += (layer_outputs[-1],)
|
| 1440 |
+
|
| 1441 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 1442 |
+
|
| 1443 |
+
# Add last layer
|
| 1444 |
+
if output_hidden_states:
|
| 1445 |
+
all_hidden_states += (hidden_states,)
|
| 1446 |
+
|
| 1447 |
+
if not return_dict:
|
| 1448 |
+
return tuple(
|
| 1449 |
+
v
|
| 1450 |
+
for v in [
|
| 1451 |
+
hidden_states,
|
| 1452 |
+
present_key_value_states,
|
| 1453 |
+
all_hidden_states,
|
| 1454 |
+
all_self_attns,
|
| 1455 |
+
all_cross_attentions,
|
| 1456 |
+
all_router_probs,
|
| 1457 |
+
]
|
| 1458 |
+
if v is not None
|
| 1459 |
+
)
|
| 1460 |
+
return MoEModelOutputWithPastAndCrossAttentions(
|
| 1461 |
+
last_hidden_state=hidden_states,
|
| 1462 |
+
past_key_values=present_key_value_states,
|
| 1463 |
+
hidden_states=all_hidden_states,
|
| 1464 |
+
attentions=all_self_attns,
|
| 1465 |
+
cross_attentions=all_cross_attentions,
|
| 1466 |
+
router_probs=all_router_probs,
|
| 1467 |
+
)
|
| 1468 |
+
|
| 1469 |
+
|
| 1470 |
+
@add_start_docstrings(
|
| 1471 |
+
"The bare NllbMoe Model outputting raw hidden-states without any specific head on top.",
|
| 1472 |
+
NLLB_MOE_START_DOCSTRING,
|
| 1473 |
+
)
|
| 1474 |
+
class NllbMoeModel(NllbMoePreTrainedModel):
|
| 1475 |
+
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
| 1476 |
+
|
| 1477 |
+
def __init__(self, config: NllbMoeConfig):
|
| 1478 |
+
super().__init__(config)
|
| 1479 |
+
|
| 1480 |
+
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
| 1481 |
+
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 1482 |
+
self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
| 1483 |
+
|
| 1484 |
+
self.encoder = NllbMoeEncoder(config, self.shared)
|
| 1485 |
+
self.decoder = NllbMoeDecoder(config, self.shared)
|
| 1486 |
+
|
| 1487 |
+
# Initialize weights and apply final processing
|
| 1488 |
+
self.post_init()
|
| 1489 |
+
|
| 1490 |
+
def get_input_embeddings(self):
|
| 1491 |
+
return self.shared
|
| 1492 |
+
|
| 1493 |
+
def set_input_embeddings(self, value):
|
| 1494 |
+
self.shared = value
|
| 1495 |
+
self.encoder.embed_tokens = self.shared
|
| 1496 |
+
self.decoder.embed_tokens = self.shared
|
| 1497 |
+
|
| 1498 |
+
def _tie_weights(self):
|
| 1499 |
+
if self.config.tie_word_embeddings:
|
| 1500 |
+
self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
|
| 1501 |
+
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
|
| 1502 |
+
|
| 1503 |
+
def get_encoder(self):
|
| 1504 |
+
return self.encoder
|
| 1505 |
+
|
| 1506 |
+
def get_decoder(self):
|
| 1507 |
+
return self.decoder
|
| 1508 |
+
|
| 1509 |
+
@add_start_docstrings_to_model_forward(NLLB_MOE_INPUTS_DOCSTRING)
|
| 1511 |
+
@replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 1512 |
+
def forward(
|
| 1513 |
+
self,
|
| 1514 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1515 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1516 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 1517 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 1518 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1519 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
| 1520 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 1521 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1522 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1523 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1524 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1525 |
+
use_cache: Optional[bool] = None,
|
| 1526 |
+
output_attentions: Optional[bool] = None,
|
| 1527 |
+
output_hidden_states: Optional[bool] = None,
|
| 1528 |
+
output_router_logits: Optional[bool] = None,
|
| 1529 |
+
return_dict: Optional[bool] = None,
|
| 1530 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqMoEModelOutput]:
|
| 1531 |
+
r"""
|
| 1532 |
+
Returns:
|
| 1533 |
+
|
| 1534 |
+
Example:
|
| 1535 |
+
|
| 1536 |
+
```python
|
| 1537 |
+
>>> from transformers import AutoTokenizer, NllbMoeModel
|
| 1538 |
+
|
| 1539 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
|
| 1540 |
+
>>> model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
|
| 1541 |
+
|
| 1542 |
+
>>> input_ids = tokenizer(
|
| 1543 |
+
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
|
| 1544 |
+
... ).input_ids # Batch size 1
|
| 1545 |
+
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
|
| 1546 |
+
|
| 1547 |
+
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for NllbMoeModel
|
| 1548 |
+
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
|
| 1549 |
+
|
| 1550 |
+
>>> # forward pass
|
| 1551 |
+
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
| 1552 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 1553 |
+
```"""
|
| 1554 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 1555 |
+
if encoder_outputs is None:
|
| 1556 |
+
encoder_outputs = self.encoder(
|
| 1557 |
+
input_ids=input_ids,
|
| 1558 |
+
attention_mask=attention_mask,
|
| 1559 |
+
head_mask=head_mask,
|
| 1560 |
+
inputs_embeds=inputs_embeds,
|
| 1561 |
+
output_attentions=output_attentions,
|
| 1562 |
+
output_hidden_states=output_hidden_states,
|
| 1563 |
+
output_router_logits=output_router_logits,
|
| 1564 |
+
return_dict=return_dict,
|
| 1565 |
+
)
|
| 1566 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a MoEModelOutput when return_dict=True
|
| 1567 |
+
elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
|
| 1568 |
+
encoder_outputs = MoEModelOutput(
|
| 1569 |
+
last_hidden_state=encoder_outputs[0],
|
| 1570 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
| 1571 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
| 1572 |
+
router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
|
| 1573 |
+
)
|
| 1574 |
+
|
| 1575 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
| 1576 |
+
decoder_outputs = self.decoder(
|
| 1577 |
+
input_ids=decoder_input_ids,
|
| 1578 |
+
attention_mask=decoder_attention_mask,
|
| 1579 |
+
encoder_hidden_states=encoder_outputs[0],
|
| 1580 |
+
encoder_attention_mask=attention_mask,
|
| 1581 |
+
head_mask=decoder_head_mask,
|
| 1582 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
| 1583 |
+
past_key_values=past_key_values,
|
| 1584 |
+
inputs_embeds=decoder_inputs_embeds,
|
| 1585 |
+
use_cache=use_cache,
|
| 1586 |
+
output_attentions=output_attentions,
|
| 1587 |
+
output_hidden_states=output_hidden_states,
|
| 1588 |
+
output_router_logits=output_router_logits,
|
| 1589 |
+
return_dict=return_dict,
|
| 1590 |
+
)
|
| 1591 |
+
|
| 1592 |
+
if not return_dict:
|
| 1593 |
+
return decoder_outputs + encoder_outputs
|
| 1594 |
+
|
| 1595 |
+
return Seq2SeqMoEModelOutput(
|
| 1596 |
+
past_key_values=decoder_outputs.past_key_values,
|
| 1597 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 1598 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
| 1599 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
| 1600 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 1601 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 1602 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 1603 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 1604 |
+
encoder_router_logits=encoder_outputs.router_probs,
|
| 1605 |
+
decoder_router_logits=decoder_outputs.router_probs,
|
| 1606 |
+
)
|
| 1607 |
+
|
| 1608 |
+
|
| 1609 |
+
@add_start_docstrings(
|
| 1610 |
+
"The NllbMoe Model with a language modeling head. Can be used for summarization.", NLLB_MOE_START_DOCSTRING
|
| 1611 |
+
)
|
| 1612 |
+
class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin):
|
| 1613 |
+
base_model_prefix = "model"
|
| 1614 |
+
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
| 1615 |
+
|
| 1616 |
+
def __init__(self, config: NllbMoeConfig):
|
| 1617 |
+
super().__init__(config)
|
| 1618 |
+
self.model = NllbMoeModel(config)
|
| 1619 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 1620 |
+
|
| 1621 |
+
self.router_z_loss_coef = config.router_z_loss_coef
|
| 1622 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 1623 |
+
# Initialize weights and apply final processing
|
| 1624 |
+
self.post_init()
|
| 1625 |
+
|
| 1626 |
+
def get_encoder(self):
|
| 1627 |
+
return self.model.get_encoder()
|
| 1628 |
+
|
| 1629 |
+
def get_decoder(self):
|
| 1630 |
+
return self.model.get_decoder()
|
| 1631 |
+
|
| 1632 |
+
def get_output_embeddings(self):
|
| 1633 |
+
return self.lm_head
|
| 1634 |
+
|
| 1635 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1636 |
+
self.lm_head = new_embeddings
|
| 1637 |
+
|
| 1638 |
+
@add_start_docstrings_to_model_forward(NLLB_MOE_INPUTS_DOCSTRING)
|
| 1639 |
+
@replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC)
|
| 1640 |
+
@add_end_docstrings(NLLB_MOE_GENERATION_EXAMPLE)
|
| 1641 |
+
def forward(
|
| 1642 |
+
self,
|
| 1643 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1644 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1645 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 1646 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
| 1647 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 1648 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
| 1649 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 1650 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1651 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1652 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1653 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1654 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1655 |
+
use_cache: Optional[bool] = None,
|
| 1656 |
+
output_attentions: Optional[bool] = None,
|
| 1657 |
+
output_hidden_states: Optional[bool] = None,
|
| 1658 |
+
output_router_logits: Optional[bool] = None,
|
| 1659 |
+
return_dict: Optional[bool] = None,
|
| 1660 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqMoEOutput]:
|
| 1661 |
+
r"""
|
| 1662 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1663 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1664 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1665 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1666 |
+
|
| 1667 |
+
Returns:
|
| 1668 |
+
"""
|
| 1669 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 1670 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1671 |
+
output_router_logits = (
|
| 1672 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1673 |
+
)
|
| 1674 |
+
if labels is not None:
|
| 1675 |
+
if decoder_input_ids is None:
|
| 1676 |
+
decoder_input_ids = shift_tokens_right(
|
| 1677 |
+
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
| 1678 |
+
)
|
| 1679 |
+
|
| 1680 |
+
outputs = self.model(
|
| 1681 |
+
input_ids,
|
| 1682 |
+
attention_mask=attention_mask,
|
| 1683 |
+
decoder_input_ids=decoder_input_ids,
|
| 1684 |
+
encoder_outputs=encoder_outputs,
|
| 1685 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 1686 |
+
head_mask=head_mask,
|
| 1687 |
+
decoder_head_mask=decoder_head_mask,
|
| 1688 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
| 1689 |
+
past_key_values=past_key_values,
|
| 1690 |
+
inputs_embeds=inputs_embeds,
|
| 1691 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 1692 |
+
use_cache=use_cache,
|
| 1693 |
+
output_attentions=output_attentions,
|
| 1694 |
+
output_hidden_states=output_hidden_states,
|
| 1695 |
+
output_router_logits=output_router_logits,
|
| 1696 |
+
return_dict=return_dict,
|
| 1697 |
+
)
|
| 1698 |
+
lm_logits = self.lm_head(outputs[0])
|
| 1699 |
+
|
| 1700 |
+
loss = None
|
| 1701 |
+
encoder_aux_loss = None
|
| 1702 |
+
decoder_aux_loss = None
|
| 1703 |
+
|
| 1704 |
+
if labels is not None:
|
| 1705 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
| 1706 |
+
# TODO: check in the config whether the router loss is enabled
|
| 1707 |
+
|
| 1708 |
+
if output_router_logits:
|
| 1709 |
+
encoder_router_logits = outputs[-1]
|
| 1710 |
+
decoder_router_logits = outputs[3 if output_attentions else 4]
|
| 1711 |
+
|
| 1712 |
+
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
|
| 1713 |
+
encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
|
| 1714 |
+
encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes)
|
| 1715 |
+
|
| 1716 |
+
decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_router_logits)
|
| 1717 |
+
decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes)
|
| 1718 |
+
|
| 1719 |
+
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
| 1720 |
+
|
| 1721 |
+
if output_router_logits and labels is not None:
|
| 1722 |
+
aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
|
| 1723 |
+
loss = loss + aux_loss
|
| 1724 |
+
|
| 1725 |
+
output = (loss,) if loss is not None else ()
|
| 1726 |
+
if not return_dict:
|
| 1727 |
+
output += (lm_logits,)
|
| 1728 |
+
if output_router_logits:  # only include the aux losses when router logits were requested
|
| 1729 |
+
output += (
|
| 1730 |
+
encoder_aux_loss,
|
| 1731 |
+
decoder_aux_loss,
|
| 1732 |
+
*outputs[1:],
|
| 1733 |
+
)
|
| 1734 |
+
else:
|
| 1735 |
+
output += outputs[1:]
|
| 1736 |
+
|
| 1737 |
+
return output
|
| 1738 |
+
|
| 1739 |
+
return Seq2SeqMoEOutput(
|
| 1740 |
+
loss=loss,
|
| 1741 |
+
logits=lm_logits,
|
| 1742 |
+
past_key_values=outputs.past_key_values,
|
| 1743 |
+
cross_attentions=outputs.cross_attentions,
|
| 1744 |
+
encoder_aux_loss=encoder_aux_loss,
|
| 1745 |
+
decoder_aux_loss=decoder_aux_loss,
|
| 1746 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
| 1747 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
| 1748 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
| 1749 |
+
encoder_attentions=outputs.encoder_attentions,
|
| 1750 |
+
decoder_attentions=outputs.decoder_attentions,
|
| 1751 |
+
encoder_router_logits=outputs.encoder_router_logits,
|
| 1752 |
+
decoder_router_logits=outputs.decoder_router_logits,
|
| 1753 |
+
)
|
| 1754 |
+
|
| 1755 |
+
def _unpack_router_logits(self, router_outputs):
|
| 1756 |
+
total_router_logits = []
|
| 1757 |
+
total_expert_indexes = []
|
| 1758 |
+
for router_output in router_outputs:
|
| 1759 |
+
if router_output is not None:
|
| 1760 |
+
router_logits, expert_indexes = router_output
|
| 1761 |
+
total_router_logits.append(router_logits)
|
| 1762 |
+
total_expert_indexes.append(expert_indexes)
|
| 1763 |
+
|
| 1764 |
+
total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
|
| 1765 |
+
total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
|
| 1766 |
+
return total_router_logits, total_expert_indexes
|
| 1767 |
+
|
| 1768 |
+
@staticmethod
|
| 1769 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 1770 |
+
reordered_past = ()
|
| 1771 |
+
for layer_past in past_key_values:
|
| 1772 |
+
reordered_past += (
|
| 1773 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 1774 |
+
)
|
| 1775 |
+
return reordered_past
|
| 1776 |
+
|
| 1777 |
+
|
| 1778 |
+
__all__ = [
|
| 1779 |
+
"NllbMoeForConditionalGeneration",
|
| 1780 |
+
"NllbMoeModel",
|
| 1781 |
+
"NllbMoePreTrainedModel",
|
| 1782 |
+
"NllbMoeTop2Router",
|
| 1783 |
+
"NllbMoeSparseMLP",
|
| 1784 |
+
]
|
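For reference, end-to-end translation with the conditional-generation class defined above usually looks like the sketch below. The checkpoint name (`facebook/nllb-moe-54b`) and the NLLB language codes are assumptions about the published weights, not something this file defines.

```python
# Minimal sketch: translating English to French with NllbMoeForConditionalGeneration.
from transformers import AutoTokenizer, NllbMoeForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-moe-54b", src_lang="eng_Latn")  # assumed checkpoint
model = NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")

inputs = tokenizer("Life is like a box of chocolates.", return_tensors="pt")
# Force the decoder to start with the target-language token (French here).
generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
    max_new_tokens=40,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```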
docs/transformers/src/transformers/models/nougat/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .image_processing_nougat import *
|
| 22 |
+
from .processing_nougat import *
|
| 23 |
+
from .tokenization_nougat_fast import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
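The `TYPE_CHECKING`/`_LazyModule` pattern above means the Nougat submodules are only imported the first time they are actually accessed, while static type checkers still see the real imports. A minimal sketch of what this enables (the checkpoint name is an assumption):

```python
# Accessing the class triggers the lazy import machinery only on first use.
from transformers import NougatProcessor

processor = NougatProcessor.from_pretrained("facebook/nougat-base")  # assumed checkpoint name
```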
docs/transformers/src/transformers/models/nougat/convert_nougat_to_hf.py
ADDED
|
@@ -0,0 +1,282 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert Nougat checkpoints using the original `nougat` library. URL:
|
| 16 |
+
https://github.com/facebookresearch/nougat/tree/main"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
+
from nougat import NougatModel
|
| 23 |
+
from nougat.dataset.rasterize import rasterize_paper
|
| 24 |
+
from nougat.utils.checkpoint import get_checkpoint
|
| 25 |
+
from PIL import Image
|
| 26 |
+
|
| 27 |
+
from transformers import (
|
| 28 |
+
DonutSwinConfig,
|
| 29 |
+
DonutSwinModel,
|
| 30 |
+
MBartConfig,
|
| 31 |
+
MBartForCausalLM,
|
| 32 |
+
NougatImageProcessor,
|
| 33 |
+
NougatProcessor,
|
| 34 |
+
NougatTokenizerFast,
|
| 35 |
+
VisionEncoderDecoderModel,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_configs(model):
|
| 40 |
+
original_config = model.config
|
| 41 |
+
|
| 42 |
+
encoder_config = DonutSwinConfig(
|
| 43 |
+
image_size=original_config.input_size,
|
| 44 |
+
patch_size=4,
|
| 45 |
+
depths=original_config.encoder_layer,
|
| 46 |
+
num_heads=[4, 8, 16, 32],
|
| 47 |
+
window_size=original_config.window_size,
|
| 48 |
+
embed_dim=128,
|
| 49 |
+
)
|
| 50 |
+
decoder_config = MBartConfig(
|
| 51 |
+
is_decoder=True,
|
| 52 |
+
is_encoder_decoder=False,
|
| 53 |
+
add_cross_attention=True,
|
| 54 |
+
decoder_layers=original_config.decoder_layer,
|
| 55 |
+
max_position_embeddings=original_config.max_position_embeddings,
|
| 56 |
+
vocab_size=len(
|
| 57 |
+
model.decoder.tokenizer
|
| 58 |
+
), # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json)
|
| 59 |
+
scale_embedding=True,
|
| 60 |
+
add_final_layer_norm=True,
|
| 61 |
+
tie_word_embeddings=False,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return encoder_config, decoder_config
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Copied from transformers.models.donut.convert_donut_to_pytorch.rename_key
|
| 68 |
+
def rename_key(name):
|
| 69 |
+
if "encoder.model" in name:
|
| 70 |
+
name = name.replace("encoder.model", "encoder")
|
| 71 |
+
if "decoder.model" in name:
|
| 72 |
+
name = name.replace("decoder.model", "decoder")
|
| 73 |
+
if "patch_embed.proj" in name:
|
| 74 |
+
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
| 75 |
+
if "patch_embed.norm" in name:
|
| 76 |
+
name = name.replace("patch_embed.norm", "embeddings.norm")
|
| 77 |
+
if name.startswith("encoder"):
|
| 78 |
+
if "layers" in name:
|
| 79 |
+
name = "encoder." + name
|
| 80 |
+
if "attn.proj" in name:
|
| 81 |
+
name = name.replace("attn.proj", "attention.output.dense")
|
| 82 |
+
if "attn" in name and "mask" not in name:
|
| 83 |
+
name = name.replace("attn", "attention.self")
|
| 84 |
+
if "norm1" in name:
|
| 85 |
+
name = name.replace("norm1", "layernorm_before")
|
| 86 |
+
if "norm2" in name:
|
| 87 |
+
name = name.replace("norm2", "layernorm_after")
|
| 88 |
+
if "mlp.fc1" in name:
|
| 89 |
+
name = name.replace("mlp.fc1", "intermediate.dense")
|
| 90 |
+
if "mlp.fc2" in name:
|
| 91 |
+
name = name.replace("mlp.fc2", "output.dense")
|
| 92 |
+
|
| 93 |
+
if name == "encoder.norm.weight":
|
| 94 |
+
name = "encoder.layernorm.weight"
|
| 95 |
+
if name == "encoder.norm.bias":
|
| 96 |
+
name = "encoder.layernorm.bias"
|
| 97 |
+
|
| 98 |
+
return name
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Copied from transformers.models.donut.convert_donut_to_pytorch.convert_state_dict
|
| 102 |
+
def convert_state_dict(orig_state_dict, model):
|
| 103 |
+
for key in orig_state_dict.copy().keys():
|
| 104 |
+
val = orig_state_dict.pop(key)
|
| 105 |
+
|
| 106 |
+
if "qkv" in key:
|
| 107 |
+
key_split = key.split(".")
|
| 108 |
+
layer_num = int(key_split[3])
|
| 109 |
+
block_num = int(key_split[5])
|
| 110 |
+
dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size
|
| 111 |
+
|
| 112 |
+
if "weight" in key:
|
| 113 |
+
orig_state_dict[
|
| 114 |
+
f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
|
| 115 |
+
] = val[:dim, :]
|
| 116 |
+
orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = (
|
| 117 |
+
val[dim : dim * 2, :]
|
| 118 |
+
)
|
| 119 |
+
orig_state_dict[
|
| 120 |
+
f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
|
| 121 |
+
] = val[-dim:, :]
|
| 122 |
+
else:
|
| 123 |
+
orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = (
|
| 124 |
+
val[:dim]
|
| 125 |
+
)
|
| 126 |
+
orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = (
|
| 127 |
+
val[dim : dim * 2]
|
| 128 |
+
)
|
| 129 |
+
orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = (
|
| 130 |
+
val[-dim:]
|
| 131 |
+
)
|
| 132 |
+
elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]:
|
| 133 |
+
# HuggingFace implementation doesn't use attn_mask buffer
|
| 134 |
+
# and model doesn't use final LayerNorms for the encoder
|
| 135 |
+
pass
|
| 136 |
+
else:
|
| 137 |
+
orig_state_dict[rename_key(key)] = val
|
| 138 |
+
|
| 139 |
+
return orig_state_dict
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def convert_nougat_checkpoint(model_tag, pytorch_dump_folder_path=None, push_to_hub=False):
|
| 143 |
+
# load original model
|
| 144 |
+
checkpoint_path = get_checkpoint(None, model_tag)
|
| 145 |
+
original_model = NougatModel.from_pretrained(checkpoint_path)
|
| 146 |
+
original_model.eval()
|
| 147 |
+
|
| 148 |
+
# load HuggingFace model
|
| 149 |
+
encoder_config, decoder_config = get_configs(original_model)
|
| 150 |
+
encoder = DonutSwinModel(encoder_config)
|
| 151 |
+
decoder = MBartForCausalLM(decoder_config)
|
| 152 |
+
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
| 153 |
+
model.eval()
|
| 154 |
+
|
| 155 |
+
state_dict = original_model.state_dict()
|
| 156 |
+
new_state_dict = convert_state_dict(state_dict, model)
|
| 157 |
+
model.load_state_dict(new_state_dict)
|
| 158 |
+
|
| 159 |
+
# verify results on PDF
|
| 160 |
+
filepath = hf_hub_download(repo_id="ysharma/nougat", filename="input/nougat.pdf", repo_type="space")
|
| 161 |
+
images = rasterize_paper(pdf=filepath, return_pil=True)
|
| 162 |
+
image = Image.open(images[0])
|
| 163 |
+
|
| 164 |
+
tokenizer_file = checkpoint_path / "tokenizer.json"
|
| 165 |
+
tokenizer = NougatTokenizerFast(tokenizer_file=str(tokenizer_file))
|
| 166 |
+
tokenizer.pad_token = "<pad>"
|
| 167 |
+
tokenizer.bos_token = "<s>"
|
| 168 |
+
tokenizer.eos_token = "</s>"
|
| 169 |
+
tokenizer.unk_token = "<unk>"
|
| 170 |
+
tokenizer.model_max_length = original_model.config.max_length
|
| 171 |
+
|
| 172 |
+
size = {"height": original_model.config.input_size[0], "width": original_model.config.input_size[1]}
|
| 173 |
+
image_processor = NougatImageProcessor(
|
| 174 |
+
do_align_long_axis=original_model.config.align_long_axis,
|
| 175 |
+
size=size,
|
| 176 |
+
)
|
| 177 |
+
processor = NougatProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
| 178 |
+
|
| 179 |
+
# verify pixel_values
|
| 180 |
+
pixel_values = processor(image, return_tensors="pt").pixel_values
|
| 181 |
+
original_pixel_values = original_model.encoder.prepare_input(image).unsqueeze(0)
|
| 182 |
+
|
| 183 |
+
assert torch.allclose(original_pixel_values, pixel_values)
|
| 184 |
+
|
| 185 |
+
# verify patch embeddings
|
| 186 |
+
original_patch_embed = original_model.encoder.model.patch_embed(pixel_values)
|
| 187 |
+
patch_embeddings, _ = model.encoder.embeddings(pixel_values)
|
| 188 |
+
assert torch.allclose(original_patch_embed, patch_embeddings)
|
| 189 |
+
|
| 190 |
+
# verify encoder hidden states
|
| 191 |
+
original_last_hidden_state = original_model.encoder(pixel_values)
|
| 192 |
+
last_hidden_state = model.encoder(pixel_values).last_hidden_state
|
| 193 |
+
assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2)
|
| 194 |
+
|
| 195 |
+
# NOTE original model does not use tied weights for embeddings of decoder
|
| 196 |
+
original_embeddings = original_model.decoder.model.model.decoder.embed_tokens
|
| 197 |
+
embeddings = model.decoder.model.decoder.embed_tokens
|
| 198 |
+
assert torch.allclose(original_embeddings.weight, embeddings.weight, atol=1e-3)
|
| 199 |
+
|
| 200 |
+
# verify decoder hidden states
|
| 201 |
+
prompt = "hello world"
|
| 202 |
+
decoder_input_ids = original_model.decoder.tokenizer(
|
| 203 |
+
prompt, add_special_tokens=False, return_tensors="pt"
|
| 204 |
+
).input_ids
|
| 205 |
+
decoder_attention_mask = torch.ones_like(decoder_input_ids)
|
| 206 |
+
original_logits = original_model(
|
| 207 |
+
image_tensors=pixel_values, decoder_input_ids=decoder_input_ids, attention_mask=decoder_attention_mask
|
| 208 |
+
).logits
|
| 209 |
+
logits = model(
|
| 210 |
+
pixel_values,
|
| 211 |
+
decoder_input_ids=decoder_input_ids[:, :-1],
|
| 212 |
+
decoder_attention_mask=decoder_attention_mask[:, :-1],
|
| 213 |
+
).logits
|
| 214 |
+
assert torch.allclose(original_logits, logits, atol=1e-3)
|
| 215 |
+
|
| 216 |
+
# verify generation
|
| 217 |
+
outputs = model.generate(
|
| 218 |
+
pixel_values,
|
| 219 |
+
min_length=1,
|
| 220 |
+
max_length=30,
|
| 221 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 222 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 223 |
+
use_cache=True,
|
| 224 |
+
bad_words_ids=[
|
| 225 |
+
[tokenizer.unk_token_id],
|
| 226 |
+
],
|
| 227 |
+
return_dict_in_generate=True,
|
| 228 |
+
do_sample=False,
|
| 229 |
+
)
|
| 230 |
+
generated = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
|
| 231 |
+
|
| 232 |
+
if model_tag == "0.1.0-base":
|
| 233 |
+
expected_generation = "# Nougat: Neural Optical Understanding for Academic Documents\n\nLukas Blecher\n\nCorrespondence to: lblec"
|
| 234 |
+
elif model_tag == "0.1.0-small":
|
| 235 |
+
expected_generation = (
|
| 236 |
+
"# Nougat: Neural Optical Understanding for Academic Documents\n\nLukas Blecher\n\nCorrespondence to: lble"
|
| 237 |
+
)
|
| 238 |
+
else:
|
| 239 |
+
raise ValueError(f"Unexpected model tag: {model_tag}")
|
| 240 |
+
|
| 241 |
+
assert generated == expected_generation
|
| 242 |
+
print("Looks ok!")
|
| 243 |
+
|
| 244 |
+
if pytorch_dump_folder_path is not None:
|
| 245 |
+
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
| 246 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 247 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 248 |
+
|
| 249 |
+
if push_to_hub:
|
| 250 |
+
tag_to_name = {"0.1.0-base": "nougat-base", "0.1.0-small": "nougat-small"}
|
| 251 |
+
model_name = tag_to_name[model_tag]
|
| 252 |
+
|
| 253 |
+
model.push_to_hub(f"facebook/{model_name}")
|
| 254 |
+
processor.push_to_hub(f"facebook/{model_name}")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
if __name__ == "__main__":
|
| 258 |
+
parser = argparse.ArgumentParser()
|
| 259 |
+
# Required parameters
|
| 260 |
+
parser.add_argument(
|
| 261 |
+
"--model_tag",
|
| 262 |
+
default="0.1.0-base",
|
| 263 |
+
required=False,
|
| 264 |
+
type=str,
|
| 265 |
+
choices=["0.1.0-base", "0.1.0-small"],
|
| 266 |
+
help="Tag of the original model you'd like to convert.",
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
"--pytorch_dump_folder_path",
|
| 270 |
+
default=None,
|
| 271 |
+
required=False,
|
| 272 |
+
type=str,
|
| 273 |
+
help="Path to the output PyTorch model directory.",
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--push_to_hub",
|
| 277 |
+
action="store_true",
|
| 278 |
+
help="Whether or not to push the converted model and processor to the 🤗 hub.",
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
args = parser.parse_args()
|
| 282 |
+
convert_nougat_checkpoint(args.model_tag, args.pytorch_dump_folder_path, args.push_to_hub)
|
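Because the script exposes `convert_nougat_checkpoint` directly, the conversion can also be driven from Python instead of the CLI. A minimal sketch, assuming the original `nougat` package is installed and using a hypothetical output directory:

```python
# Convert the 0.1.0-base checkpoint and save the HF model + processor locally.
from convert_nougat_to_hf import convert_nougat_checkpoint

convert_nougat_checkpoint(
    model_tag="0.1.0-base",                       # or "0.1.0-small"
    pytorch_dump_folder_path="./nougat-base-hf",  # hypothetical output path
    push_to_hub=False,
)
```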
docs/transformers/src/transformers/models/nougat/image_processing_nougat.py
ADDED
|
@@ -0,0 +1,525 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for Nougat."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from ...image_transforms import (
|
| 23 |
+
get_resize_output_image_size,
|
| 24 |
+
pad,
|
| 25 |
+
resize,
|
| 26 |
+
to_channel_dimension_format,
|
| 27 |
+
to_pil_image,
|
| 28 |
+
)
|
| 29 |
+
from ...image_utils import (
|
| 30 |
+
IMAGENET_DEFAULT_MEAN,
|
| 31 |
+
IMAGENET_DEFAULT_STD,
|
| 32 |
+
ChannelDimension,
|
| 33 |
+
ImageInput,
|
| 34 |
+
PILImageResampling,
|
| 35 |
+
get_image_size,
|
| 36 |
+
infer_channel_dimension_format,
|
| 37 |
+
is_scaled_image,
|
| 38 |
+
make_list_of_images,
|
| 39 |
+
to_numpy_array,
|
| 40 |
+
valid_images,
|
| 41 |
+
validate_preprocess_arguments,
|
| 42 |
+
)
|
| 43 |
+
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
|
| 44 |
+
from ...utils.import_utils import is_cv2_available, is_vision_available
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if is_cv2_available():
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if is_vision_available():
|
| 55 |
+
import PIL
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class NougatImageProcessor(BaseImageProcessor):
|
| 59 |
+
r"""
|
| 60 |
+
Constructs a Nougat image processor.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
do_crop_margin (`bool`, *optional*, defaults to `True`):
|
| 64 |
+
Whether to crop the image margins.
|
| 65 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 66 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
| 67 |
+
`do_resize` in the `preprocess` method.
|
| 68 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 896, "width": 672}`):
|
| 69 |
+
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
| 70 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 71 |
+
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
| 72 |
+
do_thumbnail (`bool`, *optional*, defaults to `True`):
|
| 73 |
+
Whether to resize the image using thumbnail method.
|
| 74 |
+
do_align_long_axis (`bool`, *optional*, defaults to `False`):
|
| 75 |
+
Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
|
| 76 |
+
do_pad (`bool`, *optional*, defaults to `True`):
|
| 77 |
+
Whether to pad the images to the largest image size in the batch.
|
| 78 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 80 |
+
parameter in the `preprocess` method.
|
| 81 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 82 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 83 |
+
`preprocess` method.
|
| 84 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 85 |
+
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
| 86 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
|
| 87 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 88 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 89 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
|
| 90 |
+
Image standard deviation.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
model_input_names = ["pixel_values"]
|
| 94 |
+
|
| 95 |
+
def __init__(
|
| 96 |
+
self,
|
| 97 |
+
do_crop_margin: bool = True,
|
| 98 |
+
do_resize: bool = True,
|
| 99 |
+
size: Dict[str, int] = None,
|
| 100 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 101 |
+
do_thumbnail: bool = True,
|
| 102 |
+
do_align_long_axis: bool = False,
|
| 103 |
+
do_pad: bool = True,
|
| 104 |
+
do_rescale: bool = True,
|
| 105 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 106 |
+
do_normalize: bool = True,
|
| 107 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 108 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 109 |
+
**kwargs,
|
| 110 |
+
) -> None:
|
| 111 |
+
super().__init__(**kwargs)
|
| 112 |
+
|
| 113 |
+
size = size if size is not None else {"height": 896, "width": 672}
|
| 114 |
+
size = get_size_dict(size)
|
| 115 |
+
|
| 116 |
+
self.do_crop_margin = do_crop_margin
|
| 117 |
+
self.do_resize = do_resize
|
| 118 |
+
self.size = size
|
| 119 |
+
self.resample = resample
|
| 120 |
+
self.do_thumbnail = do_thumbnail
|
| 121 |
+
self.do_align_long_axis = do_align_long_axis
|
| 122 |
+
self.do_pad = do_pad
|
| 123 |
+
self.do_rescale = do_rescale
|
| 124 |
+
self.rescale_factor = rescale_factor
|
| 125 |
+
self.do_normalize = do_normalize
|
| 126 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 127 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 128 |
+
|
| 129 |
+
def python_find_non_zero(self, image: np.array):
|
| 130 |
+
"""This is a reimplementation of a findNonZero function equivalent to cv2."""
|
| 131 |
+
non_zero_indices = np.column_stack(np.nonzero(image))
|
| 132 |
+
idxvec = non_zero_indices[:, [1, 0]]
|
| 133 |
+
idxvec = idxvec.reshape(-1, 1, 2)
|
| 134 |
+
return idxvec
|
| 135 |
+
|
| 136 |
+
def python_bounding_rect(self, coordinates):
|
| 137 |
+
"""This is a reimplementation of a BoundingRect function equivalent to cv2."""
|
| 138 |
+
min_values = np.min(coordinates, axis=(0, 1)).astype(int)
|
| 139 |
+
max_values = np.max(coordinates, axis=(0, 1)).astype(int)
|
| 140 |
+
x_min, y_min = min_values[0], min_values[1]
|
| 141 |
+
width = max_values[0] - x_min + 1
|
| 142 |
+
height = max_values[1] - y_min + 1
|
| 143 |
+
return x_min, y_min, width, height
|
| 144 |
+
|
| 145 |
+
def crop_margin(
|
| 146 |
+
self,
|
| 147 |
+
image: np.array,
|
| 148 |
+
gray_threshold: int = 200,
|
| 149 |
+
data_format: Optional[ChannelDimension] = None,
|
| 150 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 151 |
+
) -> np.array:
|
| 152 |
+
"""
|
| 153 |
+
Crops the margin of the image. Gray pixels are considered margin (i.e., pixels with a value below the
|
| 154 |
+
threshold).
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
image (`np.array`):
|
| 158 |
+
The image to be cropped.
|
| 159 |
+
gray_threshold (`int`, *optional*, defaults to `200`)
|
| 160 |
+
Value below which pixels are considered to be gray.
|
| 161 |
+
data_format (`ChannelDimension`, *optional*):
|
| 162 |
+
The channel dimension format of the output image. If unset, will use the inferred format from the
|
| 163 |
+
input.
|
| 164 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 165 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 166 |
+
"""
|
| 167 |
+
if input_data_format is None:
|
| 168 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 169 |
+
|
| 170 |
+
image = to_pil_image(image, input_data_format=input_data_format)
|
| 171 |
+
data = np.array(image.convert("L")).astype(np.uint8)
|
| 172 |
+
max_val = data.max()
|
| 173 |
+
min_val = data.min()
|
| 174 |
+
if max_val == min_val:
|
| 175 |
+
image = np.array(image)
|
| 176 |
+
image = (
|
| 177 |
+
to_channel_dimension_format(image, data_format, input_data_format)
|
| 178 |
+
if data_format is not None
|
| 179 |
+
else image
|
| 180 |
+
)
|
| 181 |
+
return image
|
| 182 |
+
data = (data - min_val) / (max_val - min_val) * 255
|
| 183 |
+
gray = data < gray_threshold
|
| 184 |
+
coords = self.python_find_non_zero(gray)
|
| 185 |
+
x_min, y_min, width, height = self.python_bounding_rect(coords)
|
| 186 |
+
image = image.crop((x_min, y_min, x_min + width, y_min + height))
|
| 187 |
+
image = np.array(image).astype(np.uint8)
|
| 188 |
+
image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST)
|
| 189 |
+
|
| 190 |
+
image = (
|
| 191 |
+
to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
return image
|
| 195 |
+
|
| 196 |
+
# Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.align_long_axis
|
| 197 |
+
def align_long_axis(
|
| 198 |
+
self,
|
| 199 |
+
image: np.ndarray,
|
| 200 |
+
size: Dict[str, int],
|
| 201 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 202 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 203 |
+
) -> np.ndarray:
|
| 204 |
+
"""
|
| 205 |
+
Align the long axis of the image to the longest axis of the specified size.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
image (`np.ndarray`):
|
| 209 |
+
The image to be aligned.
|
| 210 |
+
size (`Dict[str, int]`):
|
| 211 |
+
The size `{"height": h, "width": w}` to align the long axis to.
|
| 212 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 213 |
+
The data format of the output image. If unset, the same format as the input image is used.
|
| 214 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 215 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
`np.ndarray`: The aligned image.
|
| 219 |
+
"""
|
| 220 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 221 |
+
output_height, output_width = size["height"], size["width"]
|
| 222 |
+
|
| 223 |
+
if input_data_format is None:
|
| 224 |
+
# We assume that all images have the same channel dimension format.
|
| 225 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 226 |
+
|
| 227 |
+
if input_data_format == ChannelDimension.LAST:
|
| 228 |
+
rot_axes = (0, 1)
|
| 229 |
+
elif input_data_format == ChannelDimension.FIRST:
|
| 230 |
+
rot_axes = (1, 2)
|
| 231 |
+
else:
|
| 232 |
+
raise ValueError(f"Unsupported data format: {input_data_format}")
|
| 233 |
+
|
| 234 |
+
if (output_width < output_height and input_width > input_height) or (
|
| 235 |
+
output_width > output_height and input_width < input_height
|
| 236 |
+
):
|
| 237 |
+
image = np.rot90(image, 3, axes=rot_axes)
|
| 238 |
+
|
| 239 |
+
if data_format is not None:
|
| 240 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 241 |
+
|
| 242 |
+
return image
|
| 243 |
+
|
| 244 |
+
def pad_image(
|
| 245 |
+
self,
|
| 246 |
+
image: np.ndarray,
|
| 247 |
+
size: Dict[str, int],
|
| 248 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 249 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 250 |
+
) -> np.ndarray:
|
| 251 |
+
"""
|
| 252 |
+
Pad the image to the specified size at the top, bottom, left and right.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
image (`np.ndarray`):
|
| 256 |
+
The image to be padded.
|
| 257 |
+
size (`Dict[str, int]`):
|
| 258 |
+
The size `{"height": h, "width": w}` to pad the image to.
|
| 259 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 260 |
+
The data format of the output image. If unset, the same format as the input image is used.
|
| 261 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 262 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 263 |
+
"""
|
| 264 |
+
output_height, output_width = size["height"], size["width"]
|
| 265 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 266 |
+
|
| 267 |
+
delta_width = output_width - input_width
|
| 268 |
+
delta_height = output_height - input_height
|
| 269 |
+
|
| 270 |
+
pad_top = delta_height // 2
|
| 271 |
+
pad_left = delta_width // 2
|
| 272 |
+
|
| 273 |
+
pad_bottom = delta_height - pad_top
|
| 274 |
+
pad_right = delta_width - pad_left
|
| 275 |
+
|
| 276 |
+
padding = ((pad_top, pad_bottom), (pad_left, pad_right))
|
| 277 |
+
return pad(image, padding, data_format=data_format, input_data_format=input_data_format)
|
| 278 |
+
|
| 279 |
+
# Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.thumbnail
|
| 280 |
+
def thumbnail(
|
| 281 |
+
self,
|
| 282 |
+
image: np.ndarray,
|
| 283 |
+
size: Dict[str, int],
|
| 284 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 285 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 286 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 287 |
+
**kwargs,
|
| 288 |
+
) -> np.ndarray:
|
| 289 |
+
"""
|
| 290 |
+
Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
|
| 291 |
+
corresponding dimension of the specified size.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
image (`np.ndarray`):
|
| 295 |
+
The image to be resized.
|
| 296 |
+
size (`Dict[str, int]`):
|
| 297 |
+
The size `{"height": h, "width": w}` to resize the image to.
|
| 298 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 299 |
+
The resampling filter to use.
|
| 300 |
+
data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
|
| 301 |
+
The data format of the output image. If unset, the same format as the input image is used.
|
| 302 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 303 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 304 |
+
"""
|
| 305 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 306 |
+
output_height, output_width = size["height"], size["width"]
|
| 307 |
+
|
| 308 |
+
# We always resize to the smallest of either the input or output size.
|
| 309 |
+
height = min(input_height, output_height)
|
| 310 |
+
width = min(input_width, output_width)
|
| 311 |
+
|
| 312 |
+
if height == input_height and width == input_width:
|
| 313 |
+
return image
|
| 314 |
+
|
| 315 |
+
if input_height > input_width:
|
| 316 |
+
width = int(input_width * height / input_height)
|
| 317 |
+
elif input_width > input_height:
|
| 318 |
+
height = int(input_height * width / input_width)
|
| 319 |
+
|
| 320 |
+
return resize(
|
| 321 |
+
image,
|
| 322 |
+
size=(height, width),
|
| 323 |
+
resample=resample,
|
| 324 |
+
reducing_gap=2.0,
|
| 325 |
+
data_format=data_format,
|
| 326 |
+
input_data_format=input_data_format,
|
| 327 |
+
**kwargs,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Copied from transformers.models.donut.image_processing_donut.DonutImageProcessor.resize
|
| 331 |
+
def resize(
|
| 332 |
+
self,
|
| 333 |
+
image: np.ndarray,
|
| 334 |
+
size: Dict[str, int],
|
| 335 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 336 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 337 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 338 |
+
**kwargs,
|
| 339 |
+
) -> np.ndarray:
|
| 340 |
+
"""
|
| 341 |
+
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
image (`np.ndarray`):
|
| 345 |
+
Image to resize.
|
| 346 |
+
size (`Dict[str, int]`):
|
| 347 |
+
Size of the output image.
|
| 348 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
| 349 |
+
Resampling filter to use when resiizing the image.
|
| 350 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 351 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 352 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 353 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 354 |
+
"""
|
| 355 |
+
size = get_size_dict(size)
|
| 356 |
+
shortest_edge = min(size["height"], size["width"])
|
| 357 |
+
output_size = get_resize_output_image_size(
|
| 358 |
+
image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
|
| 359 |
+
)
|
| 360 |
+
resized_image = resize(
|
| 361 |
+
image,
|
| 362 |
+
size=output_size,
|
| 363 |
+
resample=resample,
|
| 364 |
+
data_format=data_format,
|
| 365 |
+
input_data_format=input_data_format,
|
| 366 |
+
**kwargs,
|
| 367 |
+
)
|
| 368 |
+
return resized_image
|
| 369 |
+
|
| 370 |
+
@filter_out_non_signature_kwargs()
|
| 371 |
+
def preprocess(
|
| 372 |
+
self,
|
| 373 |
+
images: ImageInput,
|
| 374 |
+
do_crop_margin: Optional[bool] = None,
|
| 375 |
+
do_resize: Optional[bool] = None,
|
| 376 |
+
size: Dict[str, int] = None,
|
| 377 |
+
resample: PILImageResampling = None,
|
| 378 |
+
do_thumbnail: Optional[bool] = None,
|
| 379 |
+
do_align_long_axis: Optional[bool] = None,
|
| 380 |
+
do_pad: Optional[bool] = None,
|
| 381 |
+
do_rescale: Optional[bool] = None,
|
| 382 |
+
rescale_factor: Union[int, float] = None,
|
| 383 |
+
do_normalize: Optional[bool] = None,
|
| 384 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 385 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 386 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 387 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 388 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 389 |
+
) -> PIL.Image.Image:
|
| 390 |
+
"""
|
| 391 |
+
Preprocess an image or batch of images.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
images (`ImageInput`):
|
| 395 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
|
| 396 |
+
do_crop_margin (`bool`, *optional*, defaults to `self.do_crop_margin`):
|
| 397 |
+
Whether to crop the image margins.
|
| 398 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 399 |
+
Whether to resize the image.
|
| 400 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 401 |
+
Size of the image after resizing. Shortest edge of the image is resized to min(size["height"],
|
| 402 |
+
size["width"]) with the longest edge resized to keep the input aspect ratio.
|
| 403 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 404 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 405 |
+
has an effect if `do_resize` is set to `True`.
|
| 406 |
+
do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
|
| 407 |
+
Whether to resize the image using thumbnail method.
|
| 408 |
+
do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
|
| 409 |
+
Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
|
| 410 |
+
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
| 411 |
+
Whether to pad the images to the largest image size in the batch.
|
| 412 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 413 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 414 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
|
| 415 |
+
Scale factor to use if rescaling the image.
|
| 416 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 417 |
+
Whether to normalize the image.
|
| 418 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 419 |
+
Image mean to use for normalization.
|
| 420 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 421 |
+
Image standard deviation to use for normalization.
|
| 422 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 423 |
+
The type of tensors to return. Can be one of:
|
| 424 |
+
- Unset: Return a list of `np.ndarray`.
|
| 425 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 426 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 427 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 428 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 429 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 430 |
+
The channel dimension format for the output image. Can be one of:
|
| 431 |
+
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 432 |
+
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 433 |
+
- Unset: defaults to the channel dimension format of the input image.
|
| 434 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 435 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 436 |
+
from the input image. Can be one of:
|
| 437 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 438 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 439 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 440 |
+
"""
|
| 441 |
+
do_crop_margin = do_crop_margin if do_crop_margin is not None else self.do_crop_margin
|
| 442 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 443 |
+
size = size if size is not None else self.size
|
| 444 |
+
resample = resample if resample is not None else self.resample
|
| 445 |
+
do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail
|
| 446 |
+
do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis
|
| 447 |
+
do_pad = do_pad if do_pad is not None else self.do_pad
|
| 448 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 449 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 450 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 451 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 452 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 453 |
+
|
| 454 |
+
images = make_list_of_images(images)
|
| 455 |
+
|
| 456 |
+
if not valid_images(images):
|
| 457 |
+
raise ValueError(
|
| 458 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 459 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 460 |
+
)
|
| 461 |
+
validate_preprocess_arguments(
|
| 462 |
+
do_rescale=do_rescale,
|
| 463 |
+
rescale_factor=rescale_factor,
|
| 464 |
+
do_normalize=do_normalize,
|
| 465 |
+
image_mean=image_mean,
|
| 466 |
+
image_std=image_std,
|
| 467 |
+
do_pad=do_pad,
|
| 468 |
+
size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg.
|
| 469 |
+
do_resize=do_resize,
|
| 470 |
+
size=size,
|
| 471 |
+
resample=resample,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# All transformations expect numpy arrays.
|
| 475 |
+
images = [to_numpy_array(image) for image in images]
|
| 476 |
+
|
| 477 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 478 |
+
logger.warning_once(
|
| 479 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 480 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if input_data_format is None:
|
| 484 |
+
# We assume that all images have the same channel dimension format.
|
| 485 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 486 |
+
|
| 487 |
+
if do_crop_margin:
|
| 488 |
+
images = [self.crop_margin(image, input_data_format=input_data_format) for image in images]
|
| 489 |
+
|
| 490 |
+
if do_align_long_axis:
|
| 491 |
+
images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images]
|
| 492 |
+
|
| 493 |
+
if do_resize:
|
| 494 |
+
images = [
|
| 495 |
+
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 496 |
+
for image in images
|
| 497 |
+
]
|
| 498 |
+
|
| 499 |
+
if do_thumbnail:
|
| 500 |
+
images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images]
|
| 501 |
+
|
| 502 |
+
if do_pad:
|
| 503 |
+
images = [self.pad_image(image=image, size=size, input_data_format=input_data_format) for image in images]
|
| 504 |
+
|
| 505 |
+
if do_rescale:
|
| 506 |
+
images = [
|
| 507 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 508 |
+
for image in images
|
| 509 |
+
]
|
| 510 |
+
|
| 511 |
+
if do_normalize:
|
| 512 |
+
images = [
|
| 513 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 514 |
+
for image in images
|
| 515 |
+
]
|
| 516 |
+
|
| 517 |
+
images = [
|
| 518 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
| 519 |
+
]
|
| 520 |
+
|
| 521 |
+
data = {"pixel_values": images}
|
| 522 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
__all__ = ["NougatImageProcessor"]
|
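A minimal usage sketch of the image processor defined above. The checkpoint name and the input image path are illustrative assumptions, not part of this file.

```python
from PIL import Image
from transformers import NougatImageProcessor

# Assumed checkpoint and input file, for illustration only.
processor = NougatImageProcessor.from_pretrained("facebook/nougat-base")
page = Image.open("paper_page.png").convert("RGB")

# Crops margins, resizes/pads to the default 896x672 size, then rescales and normalizes.
inputs = processor(images=page, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 896, 672])
```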
docs/transformers/src/transformers/models/nougat/processing_nougat.py  ADDED
@@ -0,0 +1,163 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Nougat.
"""

from typing import Dict, List, Optional, Union

from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy

from ...processing_utils import ProcessorMixin
from ...utils import PaddingStrategy, TensorType


class NougatProcessor(ProcessorMixin):
    r"""
    Constructs a Nougat processor which wraps a Nougat image processor and a Nougat tokenizer into a single processor.

    [`NougatProcessor`] offers all the functionalities of [`NougatImageProcessor`] and [`NougatTokenizerFast`]. See the
    [`~NougatProcessor.__call__`] and [`~NougatProcessor.decode`] for more information.

    Args:
        image_processor ([`NougatImageProcessor`]):
            An instance of [`NougatImageProcessor`]. The image processor is a required input.
        tokenizer ([`NougatTokenizerFast`]):
            An instance of [`NougatTokenizerFast`]. The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor

    def __call__(
        self,
        images=None,
        text=None,
        do_crop_margin: Optional[bool] = None,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: "PILImageResampling" = None,  # noqa: F821
        do_thumbnail: Optional[bool] = None,
        do_align_long_axis: Optional[bool] = None,
        do_pad: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Union[int, float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
        input_data_format: Optional[Union[str, "ChannelDimension"]] = None,  # noqa: F821
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ):
        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(
                images,
                do_crop_margin=do_crop_margin,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_thumbnail=do_thumbnail,
                do_align_long_axis=do_align_long_axis,
                do_pad=do_pad,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                return_tensors=return_tensors,
                data_format=data_format,
                input_data_format=input_data_format,
            )
        if text is not None:
            encodings = self.tokenizer(
                text,
                text_pair=text_pair,
                text_target=text_target,
                text_pair_target=text_pair_target,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
            )

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_generation(self, *args, **kwargs):
        """
        This method forwards all its arguments to NougatTokenizer's [`~PreTrainedTokenizer.post_process_generation`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.post_process_generation(*args, **kwargs)


__all__ = ["NougatProcessor"]
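A short sketch of how the processor above combines both components; the checkpoint name, the input image, and the target string below are assumptions chosen for illustration.

```python
from PIL import Image
from transformers import NougatProcessor

# Assumed checkpoint and inputs, for illustration only.
processor = NougatProcessor.from_pretrained("facebook/nougat-base")
page = Image.open("paper_page.png").convert("RGB")
target_markdown = "# A Section\n\nTranscribed page text."

# When both images and text are given, the image-processor output gains tokenized labels.
batch = processor(images=page, text=target_markdown, return_tensors="pt")
print(sorted(batch.keys()))  # expected: ['labels', 'pixel_values']
```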
docs/transformers/src/transformers/models/nougat/tokenization_nougat_fast.py  ADDED
@@ -0,0 +1,620 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fast tokenizer class for Nougat.
"""

import re
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Union

import numpy as np

from transformers.tokenization_utils_base import INIT_TOKENIZER_DOCSTRING
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import add_end_docstrings

from ...utils import is_levenshtein_available, is_nltk_available, logging, requires_backends


if is_levenshtein_available():
    from Levenshtein import ratio

if is_nltk_available():
    import nltk


logger = logging.get_logger(__name__)


INIT_TOKENIZER_DOCSTRING += """
        tokenizer_object ([`tokenizers.Tokenizer`]):
            A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
            tokenizers](../fast_tokenizers) for more information.
        tokenizer_file ([`str`]):
            A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
            tokenizers.
"""


VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}


def markdown_compatible(text: str) -> str:
    """
    Make text compatible with Markdown formatting.

    This function makes various text formatting adjustments to make it compatible with Markdown.

    Args:
        text (`str`):
            The input text to be made Markdown-compatible.

    Returns:
        `str`: The Markdown-compatible text.
    """
    # equation tag
    # Replace lines that start with a pattern like (decimal) \[some text\] with \[[some text] \tag{decimal}\].
    text = re.sub(r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]", text, flags=re.M)
    # Replace lines that start with a pattern like \[some text\] (decimal) with \[[some text] \tag{decimal}\].
    text = re.sub(r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]", text, flags=re.M)
    # Replace lines that start with a pattern like \[some text\] (digits) \[another text\] with \[[some text] \tag{digits}\] [another text].
    text = re.sub(
        r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$",
        r"\[\1 \\tag{\2}\] \3",
        text,
        flags=re.M,
    )
    # multi line
    text = text.replace(r"\. ", ". ")
    # bold formatting
    text = text.replace(r"\bm{", r"\mathbf{").replace(r"{\\bm ", r"\mathbf{")
    text = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", text)
    # Reformat urls (http, ftp and https only) to markdown [url](url) clickable format
    text = re.sub(
        r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))",
        r"[\1](\1)",
        text,
    )
    # algorithms
    text = re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", text, flags=re.S)

    return text


def normalize_list_like_lines(generation):
    """
    Normalize lines in the given text that resemble list items. The function looks for lines that start optionally with
    '-' or '*', possibly followed by Roman numerals or digits indicating nesting levels. The function reformats such
    lines to make them more structured.

    Args:
        generation (str): The input text containing lines that need to be normalized.

    Returns:
        str: The input text with the list-like lines normalized.

    Note:
        The function uses regular expressions to identify and reformat the list-like lines. The patterns capture
        optional bullet points, nesting levels indicated by numerals, and the actual list item content. The
        normalization adjusts the bullet point style and nesting levels based on the captured patterns.
    """

    lines = generation.split("\n")
    output_lines = []
    for line_no, line in enumerate(lines):
        match = re.search(r". ([-*]) ", line)
        if not match or line[0] not in ("-", "*"):
            output_lines.append(line)
            continue  # Doesn't fit the pattern we want, no changes
        delim = match.group(1) + " "
        splits = line.split(delim)[1:]
        replacement = ""
        delim1 = line[0] + " "

        for i, item in enumerate(splits):
            level = 0
            potential_numeral, _, rest = item.strip().partition(" ")
            if not rest:
                continue
            # Infer current nesting level based on detected numbering
            if re.match(r"^[\dixv]+((?:\.[\dixv])?)+$", potential_numeral, flags=re.I | re.M):
                level = potential_numeral.count(".")

            replacement += (
                ("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip()
            )

        if line_no == len(lines) - 1:  # If this is the last line in the generation
            replacement += "\n"  # Add an empty line to the end of the generation

        output_lines.append(replacement)

    return "\n".join(output_lines)


def find_next_punctuation(text: str, start_idx=0):
    """
    Find the index of the next punctuation mark.

    Args:
        text (`str`):
            String to examine
        start_idx (`int`, *optional*):
            Index where to start
    """

    for i in range(start_idx, len(text)):
        if text[i] in [".", "?", "!", "\n"]:
            return i

    return None


def truncate_repetitions(text: str, min_len: int = 30) -> str:
    """
    Attempt to truncate repeating segments in the input string.

    This function looks for the longest repeating substring at the end of the input string and truncates it to appear
    only once. To be considered for removal, repetitions need to be continuous.

    Args:
        text (`str`):
            The input raw prediction to be truncated.
        min_len (int):
            The minimum length of the repeating segment.

    Returns:
        `str`: The input string with repeated segments truncated.
    """
    text_lower = text.lower()
    text_length = len(text_lower)

    if text_length < 2 * min_len:
        return text

    # try to find a length at which the tail is repeating
    max_repetition_length = None
    for repetition_length in range(min_len, int(text_length / 2)):
        # check if there is a repetition at the end
        same = True
        for i in range(0, repetition_length):
            if text_lower[text_length - repetition_length - i - 1] != text_lower[text_length - i - 1]:
                same = False
                break

        if same:
            max_repetition_length = repetition_length

    if max_repetition_length is None:
        return text

    lcs = text_lower[-max_repetition_length:]

    # remove all but the last repetition
    substituted_text = text
    substituted_text_lower = text_lower
    while substituted_text_lower.endswith(lcs):
        substituted_text = substituted_text[:-max_repetition_length]
        substituted_text_lower = substituted_text_lower[:-max_repetition_length]

    # this is the tail with the repetitions
    repeating_tail = text_lower[len(substituted_text_lower) :]

    # add until next punctuation and make sure last sentence is not repeating
    substituted_text_lower_out = substituted_text_lower
    while True:
        sentence_end = find_next_punctuation(text_lower, len(substituted_text_lower_out))
        sentence_start = find_next_punctuation(text_lower[::-1], len(substituted_text_lower_out))
        if sentence_end and sentence_start:
            sentence = text_lower[sentence_start:sentence_end]
            substituted_text_lower_out = text_lower[: sentence_end + 1]
            if sentence in repeating_tail:
                break
        else:
            break

    text_out = text[: len(substituted_text_lower_out)]

    return text_out


def remove_numbers(lines):
    def _clean(s):
        return re.sub(r"(?:[\d_]|\*\*)", "", s).strip()

    if isinstance(lines, str):
        return _clean(lines)
    out = []
    for l in lines:
        out.append(_clean(l))
    return out


def get_slices(lines, clean_lines):
    """
    Get slices of text based on specific criteria within the lines.

    This function identifies and returns slices of text from the input lines based on certain conditions.

    These conditions were chosen by the Nougat authors:
    - The slice is less than 200 characters long.
    - The slice is more than 3 characters long.
    - The slice does not start with "[MISSING_PAGE".
    - The slice is either the same as the next slice or the ratio of the two in terms of Levenshtein distance is
      greater than 0.9.

    Args:
        lines (`List[str]`):
            The list of lines containing the text.
        clean_lines (`List[str]`):
            A cleaned version of the text (without numbers).

    Returns:
        `List[tuple]`: A list of tuples representing the start and end indices of text slices.
    """
    indices = np.zeros(len(lines))
    for i in range(len(lines) - 1):
        j = i + 1
        while not clean_lines[j] and j < len(lines) - 1:
            j += 1
        if (
            len(clean_lines[i]) < 200
            and len(clean_lines[i]) > 3
            and len(clean_lines[j]) < 200
            and len(clean_lines[j]) > 3
            and not clean_lines[i].startswith("[MISSING_PAGE")
            and (clean_lines[i] == clean_lines[j] or ratio(clean_lines[i], clean_lines[j]) > 0.9)
        ):
            indices[i:j] = 1
    ids = np.where(indices)[0]
    slices = []
    if len(ids) == 0:
        return slices
    j0 = 0
    for j, x in enumerate(np.diff(ids) > 3):
        if x:
            slices.append((ids[j0], ids[j] + 2))
            j0 = j + 1
    slices.append((ids[j0], ids[-1] + 2))
    return [sli for sli in slices if sli[1] - sli[0] > 15]


def remove_slice_from_lines(lines, clean_text, slice) -> str:
    """
    Remove a slice of text from the lines based on specific criteria.

    This function identifies a slice of text within the lines and removes it based on certain conditions.

    Args:
        lines (list of str): The list of lines containing the text.
        clean_text (list of str): A cleaned version of the text (without numbers).
        slice (tuple): A tuple representing the start and end indices of the slice to be removed.

    Returns:
        str: The removed slice of text as a single string.
    """
    base = clean_text[slice[0]]
    section = list(slice)
    check_start_flag = False
    # backwards pass, at most 5 lines
    for line_idx in range(max(0, slice[0] - 1), max(0, slice[0] - 5), -1):
        if not lines[line_idx]:
            continue
        if lines[line_idx] == "## References":
            section[0] = line_idx
            break
        elif ratio(base, remove_numbers(lines[line_idx])) < 0.9:
            section[0] = line_idx + 1
            potential_ref = remove_numbers(lines[max(0, line_idx - 1)].partition("* [")[-1])
            if len(potential_ref) >= 0.75 * len(base) and ratio(base, potential_ref) < 0.9:
                section[0] = line_idx
                check_start_flag = True
            break
    # forward pass, at most 5 lines
    for line_idx in range(min(len(lines), slice[1]), min(len(lines), slice[1] + 5)):
        if ratio(base, remove_numbers(lines[line_idx])) < 0.9:
            section[1] = line_idx
            break
    if len(lines) <= section[1]:
        section[1] = len(lines) - 1
    to_delete = "\n".join(lines[section[0] : section[1] + 1])
    # cut off next page content
    itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]])
    while True:
        try:
            (ia, a) = next(itera)
            while a.isnumeric():
                (ia, a) = next(itera)
            (ib, b) = next(iterb)
            while b.isnumeric():
                (ib, b) = next(iterb)
            if a != b:
                break
        except StopIteration:
            break
    if check_start_flag and "* [" in to_delete:
        to_delete = "* [" + to_delete.partition("* [")[-1]
    try:
        delta = len(lines[section[1]]) - ib - 1
        if delta > 0:
            to_delete = to_delete[:-delta]
    except UnboundLocalError:
        pass

    return to_delete.strip()


@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class NougatTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer for Nougat (backed by HuggingFace tokenizers library).

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods. This class mainly adds Nougat-specific
    methods for postprocessing the generated text.

    Args:
        vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        tokenizer_file (`str`, *optional*):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.

        clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
            Whether to clean up spaces after decoding; cleanup consists in removing potential artifacts like extra
            spaces.

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.

        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )
        self.vocab_file = vocab_file

    def remove_hallucinated_references(self, text: str) -> str:
        """
        Remove hallucinated or missing references from the text.

        This function identifies and removes references that are marked as missing or hallucinated from the input text.

        Args:
            text (`str`):
                The input text containing references.

        Returns:
            `str`: The text with hallucinated references removed.
        """
        lines = text.split("\n")
        if len(lines) == 0:
            return ""
        clean_lines = remove_numbers(lines)
        slices = get_slices(lines, clean_lines)
        to_delete = []
        for slice in slices:
            to_delete.append(remove_slice_from_lines(lines, clean_lines, slice))
        for to_delete in reversed(to_delete):
            text = text.replace(to_delete, "\n\n[MISSING_PAGE_POST]\n\n")
        text = re.sub(
            r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
            "\n\n[MISSING_PAGE_POST\\1]",
            text,
        )
        return text

    def correct_tables(self, generation: str) -> str:
        """
        Takes a generated string and fixes tables/tabulars to make them match the markdown format needed.

        Args:
            generation (str): The generated text to be postprocessed.

        Returns:
            str: The postprocessed text.

        Example:

        ```python
        correct_tables("\\begin{table} \\begin{tabular}{l l} & \\ \\end{tabular} \\end{table}")
        "\\begin{table}\n\\begin{tabular}{l l} & \\ \\end{tabular}\n\\end{table}"
        ```
        """
        # remove obvious wrong tables
        for l in generation.split("\n"):
            if l.count("\\begin{tabular}") > 15 or l.count("\\multicolumn") > 60 or l.count("&") > 400:
                generation = generation.replace(l, "")
        # whitespace corrections

        generation = generation.replace("\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}")
        generation = generation.replace("\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}")
        generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")

        generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.M)

        # Remove left-aligned empty LaTeX tabular blocks.
        generation = generation.replace(r"\begin{tabular}{l l} & \\ \end{tabular}", "")
        # Remove tabulars with just 2 newline characters.
        generation = generation.replace("\\begin{tabular}{}\n\n\\end{tabular}", "")
        return generation

    def post_process_single(self, generation: str, fix_markdown: bool = True) -> str:
        """
        Postprocess a single generated text. Regular expressions used here are taken directly from the Nougat article
        authors. These expressions are commented for clarity and tested end-to-end in most cases.

        Args:
            generation (str): The generated text to be postprocessed.
            fix_markdown (bool, optional): Whether to perform Markdown formatting fixes. Default is True.

        Returns:
            str: The postprocessed text.
        """
        generation = re.sub(
            r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation
        )  # too long section titles probably are none
        generation = generation.strip()
        # Remove LaTeX left margin tag
        generation = generation.replace("\n* [leftmargin=*]\n", "\n")
        # Remove lines with markdown headings starting with #, with numerals,
        # and possibly roman numerals with trailing spaces and newlines
        generation = re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.M)
        # most likely hallucinated titles
        lines = generation.split("\n")
        if lines[-1].startswith("#") and lines[-1].lstrip("#").startswith(" ") and len(lines) > 1:
            logger.info("Likely hallucinated title at the end of the page: " + lines[-1])
            generation = "\n".join(lines[:-1])
        # obvious repetition detection
        generation = truncate_repetitions(generation)
        # Reference corrections
        generation = self.remove_hallucinated_references(generation)
        # Remove lines starting with asterisks and numbers like "*[1]" and followed by capital letters and periods (ie too long references)
        generation = re.sub(r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.M)
        # Remove empty brackets after a reference number in brackets. *[12][]ABC will become *[12]ABC
        generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.M)
        # Remove single characters before or after 2 new lines
        generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation)
        # pmc math artifact correction
        generation = re.sub(
            r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])",
            r"\1\(\2_{\3}\)\4",
            generation,
        )
        generation = re.sub(r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation)
        # footnote mistakes
        generation = re.sub(
            r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))",
            r"\1 \2",
            generation,
        )
        # TODO Come up with footnote formatting inside a table
        generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation)
        # itemize post processing
        generation = normalize_list_like_lines(generation)

        if generation.endswith((".", "}")):
            generation += "\n\n"
        if re.match(r"[A-Z0-9,;:]$", generation):
            # add space in case there is a comma or word ending
            generation += " "
        elif generation.startswith(("#", "**", "\\begin")):
            generation = "\n\n" + generation
        elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")):
            generation = generation + "\n\n"
        else:
            try:
                last_word = generation.split(" ")[-1]
                if last_word in nltk.corpus.words.words():
                    generation += " "
            except LookupError:
                # add space just in case. Will split words but better than concatenating them
                generation += " "

        # table corrections
        generation = self.correct_tables(generation)
        # Remove optional, empty square brackets after begin{array}
        generation = generation.replace("\\begin{array}[]{", "\\begin{array}{")
        # Remove empty or malformed LaTeX tabular blocks with 2 or more columns specified, with spaces and ampersands.
        generation = re.sub(
            r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}",
            "",
            generation,
        )
        # Remove lines containing "S.A.B." one or more times. Was included in Nougat's code.
        generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation)
        # Remove markdown-style headers that are incomplete or empty on multiple lines.
        generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.M)
        # Remove lines with just one period.
        generation = re.sub(r"^\.\s*$", "", generation, flags=re.M)
        # Replace instances of three or more newlines with just two newlines.
        generation = re.sub(r"\n{3,}", "\n\n", generation)
        if fix_markdown:
            return markdown_compatible(generation)
        else:
            return generation

    def post_process_generation(
        self,
        generation: Union[str, List[str]],
        fix_markdown: bool = True,
        num_workers: Optional[int] = None,
    ) -> Union[str, List[str]]:
        """
        Postprocess a generated text or a list of generated texts.

        This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting.

        Postprocessing is quite slow so it is recommended to use multiprocessing to speed up the process.

        Args:
            generation (Union[str, List[str]]):
                The generated text or a list of generated texts.
            fix_markdown (`bool`, *optional*, defaults to `True`):
                Whether to perform Markdown formatting fixes.
            num_workers (`int`, *optional*):
                Optional number of workers to pass to leverage multiprocessing (postprocessing several texts in
                parallel).

        Returns:
            Union[str, List[str]]: The postprocessed text or list of postprocessed texts.
        """
        requires_backends(self, ["nltk", "levenshtein"])

        if isinstance(generation, list):
            if num_workers is not None and isinstance(num_workers, int):
                with Pool(num_workers) as p:
                    return p.map(partial(self.post_process_single, fix_markdown=fix_markdown), generation)
            else:
+
return [self.post_process_single(s, fix_markdown=fix_markdown) for s in generation]
|
| 616 |
+
else:
|
| 617 |
+
return self.post_process_single(generation, fix_markdown=fix_markdown)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
__all__ = ["NougatTokenizerFast"]
|
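A minimal usage sketch for the batch post-processing path above; the checkpoint name `facebook/nougat-base` and the sample strings are illustrative assumptions, and the call requires the `nltk` and Levenshtein backends to be installed:

```python
from transformers import NougatTokenizerFast

tokenizer = NougatTokenizerFast.from_pretrained("facebook/nougat-base")

pages = [
    "# 1 Introduction\n\nThe Nougat model converts document images to markup.",
    "\\begin{tabular}{}\n\n\\end{tabular}Trailing text after an empty table",
]
# With num_workers set, each page is cleaned in its own process via multiprocessing.Pool;
# without it, the pages are post-processed sequentially in a list comprehension.
cleaned = tokenizer.post_process_generation(pages, fix_markdown=True, num_workers=2)
print(cleaned[1])
```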
docs/transformers/src/transformers/models/nystromformer/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_nystromformer import *
    from .modeling_nystromformer import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
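Because of the `_LazyModule` registration above, the public Nyströmformer classes are resolved only when first accessed; a plain import keeps working as usual (a small sketch, assuming `transformers` is installed):

```python
from transformers import NystromformerConfig  # materialized lazily on first access

config = NystromformerConfig()
print(config.model_type)  # "nystromformer"
```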
docs/transformers/src/transformers/models/nystromformer/configuration_nystromformer.py
ADDED
@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2022 UW-Madison and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nystromformer model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class NystromformerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NystromformerModel`]. It is used to instantiate
    a Nystromformer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Nystromformer
    [uw-madison/nystromformer-512](https://huggingface.co/uw-madison/nystromformer-512) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30000):
            Vocabulary size of the Nystromformer model. Defines the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`NystromformerModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 510):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`NystromformerModel`].
        segment_means_seq_len (`int`, *optional*, defaults to 64):
            Sequence length used in segment-means.
        num_landmarks (`int`, *optional*, defaults to 64):
            The number of landmark (or Nystrom) points to use in Nystrom approximation of the softmax self-attention
            matrix.
        conv_kernel_size (`int`, *optional*, defaults to 65):
            The kernel size of depthwise convolution used in Nystrom approximation.
        inv_coeff_init_option (`bool`, *optional*, defaults to `False`):
            Whether or not to use exact coefficient computation for the initial values for the iterative method of
            calculating the Moore-Penrose inverse of a matrix.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.

    Example:

    ```python
    >>> from transformers import NystromformerModel, NystromformerConfig

    >>> # Initializing a Nystromformer uw-madison/nystromformer-512 style configuration
    >>> configuration = NystromformerConfig()

    >>> # Initializing a model from the uw-madison/nystromformer-512 style configuration
    >>> model = NystromformerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nystromformer"

    def __init__(
        self,
        vocab_size=30000,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu_new",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=510,
        type_vocab_size=2,
        segment_means_seq_len=64,
        num_landmarks=64,
        conv_kernel_size=65,
        inv_coeff_init_option=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.segment_means_seq_len = segment_means_seq_len
        self.num_landmarks = num_landmarks
        self.conv_kernel_size = conv_kernel_size
        self.inv_coeff_init_option = inv_coeff_init_option
        self.layer_norm_eps = layer_norm_eps
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


__all__ = ["NystromformerConfig"]
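As a quick illustration of the fields that control the Nyström approximation, a configuration can trade accuracy for speed through `num_landmarks` and `segment_means_seq_len`; the concrete values below are arbitrary choices for the sketch, not recommended settings:

```python
from transformers import NystromformerConfig, NystromformerModel

# Fewer landmarks than the segment-means sequence length selects the landmark-based
# (approximate) attention path in the modeling code further down in this diff.
config = NystromformerConfig(
    num_hidden_layers=4,
    num_landmarks=32,
    segment_means_seq_len=64,
    conv_kernel_size=33,
)
model = NystromformerModel(config)
print(sum(p.numel() for p in model.parameters()))
```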
docs/transformers/src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py
ADDED
@@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert Nystromformer checkpoints from the original repository."""

import argparse

import torch

from transformers import NystromformerConfig, NystromformerForMaskedLM


def rename_key(orig_key):
    if "model" in orig_key:
        orig_key = orig_key.replace("model.", "")
    if "norm1" in orig_key:
        orig_key = orig_key.replace("norm1", "attention.output.LayerNorm")
    if "norm2" in orig_key:
        orig_key = orig_key.replace("norm2", "output.LayerNorm")
    if "norm" in orig_key:
        orig_key = orig_key.replace("norm", "LayerNorm")
    if "transformer" in orig_key:
        layer_num = orig_key.split(".")[0].split("_")[-1]
        orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}")
    if "mha.attn" in orig_key:
        orig_key = orig_key.replace("mha.attn", "attention.self")
    if "mha" in orig_key:
        orig_key = orig_key.replace("mha", "attention")
    if "W_q" in orig_key:
        orig_key = orig_key.replace("W_q", "self.query")
    if "W_k" in orig_key:
        orig_key = orig_key.replace("W_k", "self.key")
    if "W_v" in orig_key:
        orig_key = orig_key.replace("W_v", "self.value")
    if "ff1" in orig_key:
        orig_key = orig_key.replace("ff1", "intermediate.dense")
    if "ff2" in orig_key:
        orig_key = orig_key.replace("ff2", "output.dense")
    if "ff" in orig_key:
        orig_key = orig_key.replace("ff", "output.dense")
    if "mlm_class" in orig_key:
        orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder")
    if "mlm" in orig_key:
        orig_key = orig_key.replace("mlm", "cls.predictions.transform")
    if "cls" not in orig_key:
        orig_key = "nystromformer." + orig_key

    return orig_key


def convert_checkpoint_helper(config, orig_state_dict):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if ("pooler" in key) or ("sen_class" in key) or ("conv.bias" in key):
            continue
        else:
            orig_state_dict[rename_key(key)] = val

    orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"]
    orig_state_dict["nystromformer.embeddings.position_ids"] = (
        torch.arange(config.max_position_embeddings).expand((1, -1)) + 2
    )

    return orig_state_dict


def convert_nystromformer_checkpoint(checkpoint_path, nystromformer_config_file, pytorch_dump_path):
    orig_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model_state_dict"]
    config = NystromformerConfig.from_json_file(nystromformer_config_file)
    model = NystromformerForMaskedLM(config)

    new_state_dict = convert_checkpoint_helper(config, orig_state_dict)

    model.load_state_dict(new_state_dict)
    model.eval()
    model.save_pretrained(pytorch_dump_path)

    print(f"Checkpoint successfully converted. Model saved at {pytorch_dump_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--pytorch_model_path", default=None, type=str, required=True, help="Path to Nystromformer pytorch checkpoint."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The json file for Nystromformer model config.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_nystromformer_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path)
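The conversion entry point can also be called directly from Python instead of through argparse; the paths below are placeholders, and importing the script from the installed package is an assumption of this sketch (running the module as a script with `--pytorch_model_path`, `--config_file` and `--pytorch_dump_path` is the equivalent command-line form):

```python
from transformers.models.nystromformer.convert_nystromformer_original_pytorch_checkpoint_to_pytorch import (
    convert_nystromformer_checkpoint,
)

convert_nystromformer_checkpoint(
    checkpoint_path="original_nystromformer.ckpt",  # placeholder: original UW-Madison checkpoint
    nystromformer_config_file="nystromformer_config.json",  # placeholder: JSON model config
    pytorch_dump_path="converted_nystromformer",  # placeholder: output directory
)
```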
docs/transformers/src/transformers/models/nystromformer/modeling_nystromformer.py
ADDED
@@ -0,0 +1,1124 @@
# coding=utf-8
# Copyright 2022 UW-Madison The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Nystromformer model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_nystromformer import NystromformerConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "uw-madison/nystromformer-512"
_CONFIG_FOR_DOC = "NystromformerConfig"


class NystromformerEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
            persistent=False,
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # Setting the token_type_ids to the registered buffer in the constructor, where it is all zeros, usually occurs
        # when it is auto-generated; the registered buffer helps users when tracing the model without passing
        # token_type_ids and solves issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class NystromformerSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.num_landmarks = config.num_landmarks
        self.seq_len = config.segment_means_seq_len
        self.conv_kernel_size = config.conv_kernel_size

        if config.inv_coeff_init_option:
            self.init_option = config["inv_init_coeff_option"]
        else:
            self.init_option = "original"

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )

        if self.conv_kernel_size is not None:
            self.conv = nn.Conv2d(
                in_channels=self.num_attention_heads,
                out_channels=self.num_attention_heads,
                kernel_size=(self.conv_kernel_size, 1),
                padding=(self.conv_kernel_size // 2, 0),
                bias=False,
                groups=self.num_attention_heads,
            )

    # Function to approximate Moore-Penrose inverse via the iterative method
    def iterative_inv(self, mat, n_iter=6):
        identity = torch.eye(mat.size(-1), device=mat.device)
        key = mat

        # The entries of key are positive and ||key||_{\infty} = 1 due to softmax
        if self.init_option == "original":
            # This original implementation is more conservative to compute coefficient of Z_0.
            value = 1 / torch.max(torch.sum(key, dim=-2)) * key.transpose(-1, -2)
        else:
            # This is the exact coefficient computation, 1 / ||key||_1, of initialization of Z_0, leading to faster convergence.
            value = 1 / torch.max(torch.sum(key, dim=-2), dim=-1).values[:, :, None, None] * key.transpose(-1, -2)

        for _ in range(n_iter):
            key_value = torch.matmul(key, value)
            value = torch.matmul(
                0.25 * value,
                13 * identity
                - torch.matmul(key_value, 15 * identity - torch.matmul(key_value, 7 * identity - key_value)),
            )
        return value

    def transpose_for_scores(self, layer):
        new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        layer = layer.view(*new_layer_shape)
        return layer.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size))
        key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size))

        if self.num_landmarks == self.seq_len:
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            if attention_mask is not None:
                # Apply the attention mask (precomputed for all layers in NystromformerModel forward() function)
                attention_scores = attention_scores + attention_mask

            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            context_layer = torch.matmul(attention_probs, value_layer)

        else:
            q_landmarks = query_layer.reshape(
                -1,
                self.num_attention_heads,
                self.num_landmarks,
                self.seq_len // self.num_landmarks,
                self.attention_head_size,
            ).mean(dim=-2)
            k_landmarks = key_layer.reshape(
                -1,
                self.num_attention_heads,
                self.num_landmarks,
                self.seq_len // self.num_landmarks,
                self.attention_head_size,
            ).mean(dim=-2)

            kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1)
            kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1)

            attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2))

            if attention_mask is not None:
                # Apply the attention mask (precomputed for all layers in NystromformerModel forward() function)
                attention_scores = attention_scores + attention_mask

            kernel_3 = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2))
            new_value_layer = torch.matmul(kernel_3, value_layer)
            context_layer = torch.matmul(attention_probs, new_value_layer)

        if self.conv_kernel_size is not None:
            context_layer += self.conv(value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
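
# --- Illustrative sketch (not part of this module): the landmark factorization used above. ---
# NystromformerSelfAttention.forward replaces softmax(Q K^T) V with
# softmax(Q K_landmarks^T) @ pinv(softmax(Q_landmarks K_landmarks^T)) @ softmax(Q_landmarks K^T) V,
# where the landmarks are segment means of the queries and keys. The standalone example below
# demonstrates that factorization on plain tensors; the shapes, the seed and the use of
# torch.linalg.pinv in place of iterative_inv() are assumptions made only for this sketch,
# and no attention mask or d^(1/4) scaling is applied.
import torch

torch.manual_seed(0)
seq_len, head_dim, num_landmarks = 64, 32, 8

query = torch.randn(seq_len, head_dim)
key = torch.randn(seq_len, head_dim)
value = torch.randn(seq_len, head_dim)

# Landmarks are means over contiguous segments, mirroring reshape(...).mean(dim=-2) above.
q_landmarks = query.reshape(num_landmarks, seq_len // num_landmarks, head_dim).mean(dim=1)
k_landmarks = key.reshape(num_landmarks, seq_len // num_landmarks, head_dim).mean(dim=1)

kernel_1 = torch.softmax(query @ k_landmarks.T, dim=-1)        # (seq_len, num_landmarks)
kernel_2 = torch.softmax(q_landmarks @ k_landmarks.T, dim=-1)  # (num_landmarks, num_landmarks)
kernel_3 = torch.softmax(q_landmarks @ key.T, dim=-1)          # (num_landmarks, seq_len)

approx = kernel_1 @ torch.linalg.pinv(kernel_2) @ (kernel_3 @ value)
exact = torch.softmax(query @ key.T, dim=-1) @ value
print((approx - exact).abs().mean())  # small but nonzero: the Nystrom form is an approximation
# --- End of illustrative sketch. ---
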
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class NystromformerSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class NystromformerAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = NystromformerSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = NystromformerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        self_outputs = self.self(hidden_states, attention_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Nystromformer
class NystromformerIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Nystromformer
class NystromformerOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class NystromformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = NystromformerAttention(config)
        self.add_cross_attention = config.add_cross_attention
        self.intermediate = NystromformerIntermediate(config)
        self.output = NystromformerOutput(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class NystromformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([NystromformerLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Nystromformer
class NystromformerPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Nystromformer
class NystromformerLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = NystromformerPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Nystromformer
class NystromformerOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = NystromformerLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class NystromformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NystromformerConfig
    base_model_prefix = "nystromformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


NYSTROMFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`NystromformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

NYSTROMFORMER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Nyströmformer Model transformer outputting raw hidden-states without any specific head on top.",
    NYSTROMFORMER_START_DOCSTRING,
)
class NystromformerModel(NystromformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = NystromformerEmbeddings(config)
        self.encoder = NystromformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


@add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
class NystromformerForMaskedLM(NystromformerPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)

        self.nystromformer = NystromformerModel(config)
        self.cls = NystromformerOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 691 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 692 |
+
"""
|
| 693 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 694 |
+
|
| 695 |
+
outputs = self.nystromformer(
|
| 696 |
+
input_ids,
|
| 697 |
+
attention_mask=attention_mask,
|
| 698 |
+
token_type_ids=token_type_ids,
|
| 699 |
+
position_ids=position_ids,
|
| 700 |
+
head_mask=head_mask,
|
| 701 |
+
inputs_embeds=inputs_embeds,
|
| 702 |
+
output_attentions=output_attentions,
|
| 703 |
+
output_hidden_states=output_hidden_states,
|
| 704 |
+
return_dict=return_dict,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
sequence_output = outputs[0]
|
| 708 |
+
prediction_scores = self.cls(sequence_output)
|
| 709 |
+
|
| 710 |
+
masked_lm_loss = None
|
| 711 |
+
if labels is not None:
|
| 712 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
| 713 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 714 |
+
|
| 715 |
+
if not return_dict:
|
| 716 |
+
output = (prediction_scores,) + outputs[1:]
|
| 717 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 718 |
+
|
| 719 |
+
return MaskedLMOutput(
|
| 720 |
+
loss=masked_lm_loss,
|
| 721 |
+
logits=prediction_scores,
|
| 722 |
+
hidden_states=outputs.hidden_states,
|
| 723 |
+
attentions=outputs.attentions,
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
class NystromformerClassificationHead(nn.Module):
|
| 728 |
+
"""Head for sentence-level classification tasks."""
|
| 729 |
+
|
| 730 |
+
def __init__(self, config):
|
| 731 |
+
super().__init__()
|
| 732 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 733 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 734 |
+
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
|
| 735 |
+
|
| 736 |
+
self.config = config
|
| 737 |
+
|
| 738 |
+
def forward(self, features, **kwargs):
|
| 739 |
+
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
| 740 |
+
x = self.dropout(x)
|
| 741 |
+
x = self.dense(x)
|
| 742 |
+
x = ACT2FN[self.config.hidden_act](x)
|
| 743 |
+
x = self.dropout(x)
|
| 744 |
+
x = self.out_proj(x)
|
| 745 |
+
return x
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
@add_start_docstrings(
|
| 749 |
+
"""
|
| 750 |
+
Nyströmformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
| 751 |
+
pooled output) e.g. for GLUE tasks.
|
| 752 |
+
""",
|
| 753 |
+
NYSTROMFORMER_START_DOCSTRING,
|
| 754 |
+
)
|
| 755 |
+
class NystromformerForSequenceClassification(NystromformerPreTrainedModel):
|
| 756 |
+
def __init__(self, config):
|
| 757 |
+
super().__init__(config)
|
| 758 |
+
self.num_labels = config.num_labels
|
| 759 |
+
self.nystromformer = NystromformerModel(config)
|
| 760 |
+
self.classifier = NystromformerClassificationHead(config)
|
| 761 |
+
|
| 762 |
+
# Initialize weights and apply final processing
|
| 763 |
+
self.post_init()
|
| 764 |
+
|
| 765 |
+
@add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 766 |
+
@add_code_sample_docstrings(
|
| 767 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 768 |
+
output_type=SequenceClassifierOutput,
|
| 769 |
+
config_class=_CONFIG_FOR_DOC,
|
| 770 |
+
)
|
| 771 |
+
def forward(
|
| 772 |
+
self,
|
| 773 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 774 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 775 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 776 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 777 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 778 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 779 |
+
labels: Optional[torch.LongTensor] = None,
|
| 780 |
+
output_attentions: Optional[bool] = None,
|
| 781 |
+
output_hidden_states: Optional[bool] = None,
|
| 782 |
+
return_dict: Optional[bool] = None,
|
| 783 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 784 |
+
r"""
|
| 785 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 786 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 787 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 788 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 789 |
+
"""
|
| 790 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 791 |
+
|
| 792 |
+
outputs = self.nystromformer(
|
| 793 |
+
input_ids,
|
| 794 |
+
attention_mask=attention_mask,
|
| 795 |
+
token_type_ids=token_type_ids,
|
| 796 |
+
position_ids=position_ids,
|
| 797 |
+
head_mask=head_mask,
|
| 798 |
+
inputs_embeds=inputs_embeds,
|
| 799 |
+
output_attentions=output_attentions,
|
| 800 |
+
output_hidden_states=output_hidden_states,
|
| 801 |
+
return_dict=return_dict,
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
sequence_output = outputs[0]
|
| 805 |
+
logits = self.classifier(sequence_output)
|
| 806 |
+
|
| 807 |
+
loss = None
|
| 808 |
+
if labels is not None:
|
| 809 |
+
if self.config.problem_type is None:
|
| 810 |
+
if self.num_labels == 1:
|
| 811 |
+
self.config.problem_type = "regression"
|
| 812 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 813 |
+
self.config.problem_type = "single_label_classification"
|
| 814 |
+
else:
|
| 815 |
+
self.config.problem_type = "multi_label_classification"
|
| 816 |
+
|
| 817 |
+
if self.config.problem_type == "regression":
|
| 818 |
+
loss_fct = MSELoss()
|
| 819 |
+
if self.num_labels == 1:
|
| 820 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 821 |
+
else:
|
| 822 |
+
loss = loss_fct(logits, labels)
|
| 823 |
+
elif self.config.problem_type == "single_label_classification":
|
| 824 |
+
loss_fct = CrossEntropyLoss()
|
| 825 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 826 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 827 |
+
loss_fct = BCEWithLogitsLoss()
|
| 828 |
+
loss = loss_fct(logits, labels)
|
| 829 |
+
if not return_dict:
|
| 830 |
+
output = (logits,) + outputs[1:]
|
| 831 |
+
return ((loss,) + output) if loss is not None else output
|
| 832 |
+
|
| 833 |
+
return SequenceClassifierOutput(
|
| 834 |
+
loss=loss,
|
| 835 |
+
logits=logits,
|
| 836 |
+
hidden_states=outputs.hidden_states,
|
| 837 |
+
attentions=outputs.attentions,
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
@add_start_docstrings(
|
| 842 |
+
"""
|
| 843 |
+
Nyströmformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output
|
| 844 |
+
and a softmax) e.g. for RocStories/SWAG tasks.
|
| 845 |
+
""",
|
| 846 |
+
NYSTROMFORMER_START_DOCSTRING,
|
| 847 |
+
)
|
| 848 |
+
class NystromformerForMultipleChoice(NystromformerPreTrainedModel):
|
| 849 |
+
def __init__(self, config):
|
| 850 |
+
super().__init__(config)
|
| 851 |
+
|
| 852 |
+
self.nystromformer = NystromformerModel(config)
|
| 853 |
+
self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
|
| 854 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
| 855 |
+
|
| 856 |
+
# Initialize weights and apply final processing
|
| 857 |
+
self.post_init()
|
| 858 |
+
|
| 859 |
+
@add_start_docstrings_to_model_forward(
|
| 860 |
+
NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
| 861 |
+
)
|
| 862 |
+
@add_code_sample_docstrings(
|
| 863 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 864 |
+
output_type=MultipleChoiceModelOutput,
|
| 865 |
+
config_class=_CONFIG_FOR_DOC,
|
| 866 |
+
)
|
| 867 |
+
def forward(
|
| 868 |
+
self,
|
| 869 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 870 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 871 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 872 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 873 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 874 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 875 |
+
labels: Optional[torch.LongTensor] = None,
|
| 876 |
+
output_attentions: Optional[bool] = None,
|
| 877 |
+
output_hidden_states: Optional[bool] = None,
|
| 878 |
+
return_dict: Optional[bool] = None,
|
| 879 |
+
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
|
| 880 |
+
r"""
|
| 881 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 882 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
| 883 |
+
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
| 884 |
+
`input_ids` above)
|
| 885 |
+
"""
|
| 886 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 887 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
| 888 |
+
|
| 889 |
+
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
| 890 |
+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 891 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
| 892 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
| 893 |
+
inputs_embeds = (
|
| 894 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
| 895 |
+
if inputs_embeds is not None
|
| 896 |
+
else None
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
outputs = self.nystromformer(
|
| 900 |
+
input_ids,
|
| 901 |
+
attention_mask=attention_mask,
|
| 902 |
+
token_type_ids=token_type_ids,
|
| 903 |
+
position_ids=position_ids,
|
| 904 |
+
head_mask=head_mask,
|
| 905 |
+
inputs_embeds=inputs_embeds,
|
| 906 |
+
output_attentions=output_attentions,
|
| 907 |
+
output_hidden_states=output_hidden_states,
|
| 908 |
+
return_dict=return_dict,
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
+
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
|
| 912 |
+
pooled_output = hidden_state[:, 0] # (bs * num_choices, dim)
|
| 913 |
+
pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim)
|
| 914 |
+
pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim)
|
| 915 |
+
logits = self.classifier(pooled_output)
|
| 916 |
+
|
| 917 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 918 |
+
|
| 919 |
+
loss = None
|
| 920 |
+
if labels is not None:
|
| 921 |
+
loss_fct = CrossEntropyLoss()
|
| 922 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 923 |
+
|
| 924 |
+
if not return_dict:
|
| 925 |
+
output = (reshaped_logits,) + outputs[1:]
|
| 926 |
+
return ((loss,) + output) if loss is not None else output
|
| 927 |
+
|
| 928 |
+
return MultipleChoiceModelOutput(
|
| 929 |
+
loss=loss,
|
| 930 |
+
logits=reshaped_logits,
|
| 931 |
+
hidden_states=outputs.hidden_states,
|
| 932 |
+
attentions=outputs.attentions,
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
@add_start_docstrings(
|
| 937 |
+
"""
|
| 938 |
+
Nyströmformer Model with a token classification head on top (a linear layer on top of the hidden-states output)
|
| 939 |
+
e.g. for Named-Entity-Recognition (NER) tasks.
|
| 940 |
+
""",
|
| 941 |
+
NYSTROMFORMER_START_DOCSTRING,
|
| 942 |
+
)
|
| 943 |
+
class NystromformerForTokenClassification(NystromformerPreTrainedModel):
|
| 944 |
+
def __init__(self, config):
|
| 945 |
+
super().__init__(config)
|
| 946 |
+
self.num_labels = config.num_labels
|
| 947 |
+
|
| 948 |
+
self.nystromformer = NystromformerModel(config)
|
| 949 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 950 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 951 |
+
|
| 952 |
+
# Initialize weights and apply final processing
|
| 953 |
+
self.post_init()
|
| 954 |
+
|
| 955 |
+
@add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 956 |
+
@add_code_sample_docstrings(
|
| 957 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 958 |
+
output_type=TokenClassifierOutput,
|
| 959 |
+
config_class=_CONFIG_FOR_DOC,
|
| 960 |
+
)
|
| 961 |
+
def forward(
|
| 962 |
+
self,
|
| 963 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 964 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 965 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 966 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 967 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 968 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 969 |
+
labels: Optional[torch.LongTensor] = None,
|
| 970 |
+
output_attentions: Optional[bool] = None,
|
| 971 |
+
output_hidden_states: Optional[bool] = None,
|
| 972 |
+
return_dict: Optional[bool] = None,
|
| 973 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
| 974 |
+
r"""
|
| 975 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 976 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 977 |
+
"""
|
| 978 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 979 |
+
|
| 980 |
+
outputs = self.nystromformer(
|
| 981 |
+
input_ids,
|
| 982 |
+
attention_mask=attention_mask,
|
| 983 |
+
token_type_ids=token_type_ids,
|
| 984 |
+
position_ids=position_ids,
|
| 985 |
+
head_mask=head_mask,
|
| 986 |
+
inputs_embeds=inputs_embeds,
|
| 987 |
+
output_attentions=output_attentions,
|
| 988 |
+
output_hidden_states=output_hidden_states,
|
| 989 |
+
return_dict=return_dict,
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
sequence_output = outputs[0]
|
| 993 |
+
|
| 994 |
+
sequence_output = self.dropout(sequence_output)
|
| 995 |
+
logits = self.classifier(sequence_output)
|
| 996 |
+
|
| 997 |
+
loss = None
|
| 998 |
+
if labels is not None:
|
| 999 |
+
loss_fct = CrossEntropyLoss()
|
| 1000 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1001 |
+
|
| 1002 |
+
if not return_dict:
|
| 1003 |
+
output = (logits,) + outputs[1:]
|
| 1004 |
+
return ((loss,) + output) if loss is not None else output
|
| 1005 |
+
|
| 1006 |
+
return TokenClassifierOutput(
|
| 1007 |
+
loss=loss,
|
| 1008 |
+
logits=logits,
|
| 1009 |
+
hidden_states=outputs.hidden_states,
|
| 1010 |
+
attentions=outputs.attentions,
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
|
| 1014 |
+
@add_start_docstrings(
|
| 1015 |
+
"""
|
| 1016 |
+
Nyströmformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
|
| 1017 |
+
linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1018 |
+
""",
|
| 1019 |
+
NYSTROMFORMER_START_DOCSTRING,
|
| 1020 |
+
)
|
| 1021 |
+
class NystromformerForQuestionAnswering(NystromformerPreTrainedModel):
|
| 1022 |
+
def __init__(self, config):
|
| 1023 |
+
super().__init__(config)
|
| 1024 |
+
|
| 1025 |
+
config.num_labels = 2
|
| 1026 |
+
self.num_labels = config.num_labels
|
| 1027 |
+
|
| 1028 |
+
self.nystromformer = NystromformerModel(config)
|
| 1029 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 1030 |
+
|
| 1031 |
+
# Initialize weights and apply final processing
|
| 1032 |
+
self.post_init()
|
| 1033 |
+
|
| 1034 |
+
@add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1035 |
+
@add_code_sample_docstrings(
|
| 1036 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1037 |
+
output_type=QuestionAnsweringModelOutput,
|
| 1038 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1039 |
+
)
|
| 1040 |
+
def forward(
|
| 1041 |
+
self,
|
| 1042 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1043 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1044 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 1045 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1046 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 1047 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1048 |
+
start_positions: Optional[torch.LongTensor] = None,
|
| 1049 |
+
end_positions: Optional[torch.LongTensor] = None,
|
| 1050 |
+
output_attentions: Optional[bool] = None,
|
| 1051 |
+
output_hidden_states: Optional[bool] = None,
|
| 1052 |
+
return_dict: Optional[bool] = None,
|
| 1053 |
+
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
|
| 1054 |
+
r"""
|
| 1055 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1056 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1057 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1058 |
+
are not taken into account for computing the loss.
|
| 1059 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1060 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1061 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1062 |
+
are not taken into account for computing the loss.
|
| 1063 |
+
"""
|
| 1064 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1065 |
+
|
| 1066 |
+
outputs = self.nystromformer(
|
| 1067 |
+
input_ids,
|
| 1068 |
+
attention_mask=attention_mask,
|
| 1069 |
+
token_type_ids=token_type_ids,
|
| 1070 |
+
position_ids=position_ids,
|
| 1071 |
+
head_mask=head_mask,
|
| 1072 |
+
inputs_embeds=inputs_embeds,
|
| 1073 |
+
output_attentions=output_attentions,
|
| 1074 |
+
output_hidden_states=output_hidden_states,
|
| 1075 |
+
return_dict=return_dict,
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
sequence_output = outputs[0]
|
| 1079 |
+
|
| 1080 |
+
logits = self.qa_outputs(sequence_output)
|
| 1081 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1082 |
+
start_logits = start_logits.squeeze(-1)
|
| 1083 |
+
end_logits = end_logits.squeeze(-1)
|
| 1084 |
+
|
| 1085 |
+
total_loss = None
|
| 1086 |
+
if start_positions is not None and end_positions is not None:
|
| 1087 |
+
# If we are on multi-GPU, split add a dimension
|
| 1088 |
+
if len(start_positions.size()) > 1:
|
| 1089 |
+
start_positions = start_positions.squeeze(-1)
|
| 1090 |
+
if len(end_positions.size()) > 1:
|
| 1091 |
+
end_positions = end_positions.squeeze(-1)
|
| 1092 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1093 |
+
ignored_index = start_logits.size(1)
|
| 1094 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
| 1095 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
| 1096 |
+
|
| 1097 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1098 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1099 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1100 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1101 |
+
|
| 1102 |
+
if not return_dict:
|
| 1103 |
+
output = (start_logits, end_logits) + outputs[1:]
|
| 1104 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
| 1105 |
+
|
| 1106 |
+
return QuestionAnsweringModelOutput(
|
| 1107 |
+
loss=total_loss,
|
| 1108 |
+
start_logits=start_logits,
|
| 1109 |
+
end_logits=end_logits,
|
| 1110 |
+
hidden_states=outputs.hidden_states,
|
| 1111 |
+
attentions=outputs.attentions,
|
| 1112 |
+
)
|
| 1113 |
+
|
| 1114 |
+
|
| 1115 |
+
__all__ = [
|
| 1116 |
+
"NystromformerForMaskedLM",
|
| 1117 |
+
"NystromformerForMultipleChoice",
|
| 1118 |
+
"NystromformerForQuestionAnswering",
|
| 1119 |
+
"NystromformerForSequenceClassification",
|
| 1120 |
+
"NystromformerForTokenClassification",
|
| 1121 |
+
"NystromformerLayer",
|
| 1122 |
+
"NystromformerModel",
|
| 1123 |
+
"NystromformerPreTrainedModel",
|
| 1124 |
+
]
|
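
All of the task heads above follow the same pattern: run `NystromformerModel`, apply a small task-specific projection, and compute a standard loss when labels are given. Below is a minimal sketch of exercising the masked-LM head end to end; the checkpoint name `uw-madison/nystromformer-512` is an assumption here, substitute whatever checkpoint you actually have.

```python
import torch
from transformers import AutoTokenizer, NystromformerForMaskedLM

# Assumed checkpoint name; any Nystromformer masked-LM checkpoint works the same way.
checkpoint = "uw-madison/nystromformer-512"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = NystromformerForMaskedLM.from_pretrained(checkpoint)

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)

# Take the highest-scoring token at the masked position.
mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))
```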
docs/transformers/src/transformers/models/olmo/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_olmo import *
    from .modeling_olmo import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
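
The `_LazyModule` indirection above defers the heavy `modeling_olmo` import until a symbol is actually touched, while keeping the public import surface unchanged. A small sketch of the usage it is meant to preserve, going through the top-level `transformers` package, which re-exports these classes:

```python
# These imports resolve lazily; nothing model-related is loaded until the names are used.
from transformers import OlmoConfig, OlmoForCausalLM

# A deliberately tiny, randomly initialized model, just to show the classes are reachable.
config = OlmoConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=2, num_attention_heads=4)
model = OlmoForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))
```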
docs/transformers/src/transformers/models/olmo/configuration_olmo.py
ADDED
@@ -0,0 +1,198 @@
# coding=utf-8
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OLMo model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class OlmoConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OlmoModel`]. It is used to instantiate an OLMo
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the [allenai/OLMo-7B-hf](https://huggingface.co/allenai/OLMo-7B-hf).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the OLMo model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`OlmoModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        clip_qkv (`float`, *optional*):
            If not `None`, elements of query, key and value attention states are clipped so that their
            absolute value does not exceed this value.

    ```python
    >>> from transformers import OlmoModel, OlmoConfig

    >>> # Initializing a OLMo 7B style configuration
    >>> configuration = OlmoConfig()

    >>> # Initializing a model from the OLMo 7B style configuration
    >>> model = OlmoModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")


__all__ = ["OlmoConfig"]
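
Because `_rope_scaling_validation` runs inside `__init__`, a malformed `rope_scaling` dictionary fails at configuration time rather than when RoPE is first applied. A minimal sketch of the accepted and rejected shapes; the values are illustrative, not recommended settings:

```python
from transformers import OlmoConfig

# Accepted: exactly two fields, a known strategy and a float factor > 1.
config = OlmoConfig(rope_scaling={"type": "linear", "factor": 2.0})
print(config.rope_scaling)

# Rejected: an unknown strategy name raises immediately.
try:
    OlmoConfig(rope_scaling={"type": "unknown-strategy", "factor": 2.0})
except ValueError as err:
    print(err)
```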
docs/transformers/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
ADDED
@@ -0,0 +1,248 @@
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import shutil
from pathlib import Path

import torch
import yaml
from tokenizers import Tokenizer

from transformers import OlmoConfig, OlmoForCausalLM
from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast


"""
Sample usage:

```
python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \
    --input_dir /path/to/downloaded/olmo/weights --model_size 7B --output_dir /output/path
```

Thereafter, models can be loaded via:

```py
from transformers import OlmoForCausalLM, AutoTokenizer

model = OlmoForCausalLM.from_pretrained("/output/path")
tokenizer = AutoTokenizer.from_pretrained("/output/path")
```

Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True):
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    config_path = Path(input_base_path) / "config.yaml"
    olmo_config = yaml.safe_load(config_path.read_text())["model"]

    n_layers = olmo_config["n_layers"]
    n_heads = olmo_config["n_heads"]
    dim = olmo_config["d_model"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    max_position_embeddings = olmo_config["max_sequence_length"]

    vocab_size = olmo_config.get("embedding_size", olmo_config["vocab_size"])

    if olmo_config.get("n_kv_heads", None) is not None:
        num_key_value_heads = olmo_config["n_kv_heads"]  # for GQA / MQA
    elif olmo_config["multi_query_attention"]:  # compatibility with other checkpoints
        num_key_value_heads = 1
    else:
        num_key_value_heads = n_heads

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")

    # Not sharded
    # (The sharded implementation would also work, but this is simpler.)
    loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True)

    param_count = 0
    index_dict = {"weight_map": {}}
    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        # Unsharded
        # TODO: Layernorm stuff
        # TODO: multi query attention
        fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
        q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
            loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
        )
        up_proj_weight, gate_proj_weight = torch.chunk(
            loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0
        )
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
            f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
            f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
            f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
            f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
            f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
        }

        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq

        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"

    # Unsharded
    # TODO: Deal with weight-tying
    state_dict = {
        "model.embed_tokens.weight": loaded["transformer.wte.weight"],
        "lm_head.weight": loaded["transformer.ff_out.weight"]
        if "transformer.ff_out.weight" in loaded
        else loaded["transformer.wte.weight"],
    }

    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # Write configs
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))

    if olmo_config.get("mlp_hidden_size", None) is not None:
        intermediate_size = olmo_config["mlp_hidden_size"] // 2
    else:
        intermediate_size = (dim * olmo_config["mlp_ratio"]) // 2

    config = OlmoConfig(
        vocab_size=vocab_size,
        hidden_size=dim,
        intermediate_size=intermediate_size,
        num_hidden_layers=n_layers,
        num_attention_heads=n_heads,
        num_key_value_heads=num_key_value_heads,
        max_position_embeddings=max_position_embeddings,
        pad_token_id=olmo_config["pad_token_id"],
        bos_token_id=None,
        eos_token_id=olmo_config["eos_token_id"],
        tie_word_embeddings=olmo_config["weight_tying"],
        rope_theta=base,
        clip_qkv=olmo_config.get("clip_qkv"),
    )
    config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    if tokenizer_path is not None:
        _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id)

    print("Loading the checkpoint in a OLMo model.")
    model = OlmoForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
    # Avoid saving this as part of the config.
    del model.config._name_or_path
    print("Saving in the Transformers format.")
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
    shutil.rmtree(tmp_model_path)


def _write_tokenizer(
    output_path: Path, config: OlmoConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True
) -> None:
    print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.")

    base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))

    eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
    pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id

    if fix_eos_token_id and eos_token_id == 0:
        # Fixing a bug in OLMo where eos token id was incorrectly set
        print("Changing eos_token_id from 0 to 50279.")
        eos_token_id = 50279

    tokenizer = GPTNeoXTokenizerFast(
        tokenizer_object=base_tokenizer,
        eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
        pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False),
        unk_token=None,
        bos_token=None,
    )

    tokenizer.save_pretrained(output_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Location of OLMo weights, which contains config.yaml and model.pt.",
    )
    parser.add_argument(
        "--tokenizer_json_path",
        default=None,
        help="Location of OLMo tokenizer json file.",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--no_fix_eos_token_id",
        action="store_false",
        dest="fix_eos_token_id",
        help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
    )
    parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
    # Different OLMo versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        input_base_path=args.input_dir,
        safe_serialization=args.safe_serialization,
        tokenizer_path=args.tokenizer_json_path,
        fix_eos_token_id=args.fix_eos_token_id,
    )


if __name__ == "__main__":
    main()
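
The heart of `write_model` is remapping OLMo's fused `att_proj` weight into the separate q/k/v projections the Hugging Face layout expects. A toy sketch of that split on random tensors; the dimensions are invented to keep it small, the real script reads them from `config.yaml`:

```python
import torch

# Invented toy dimensions; the converter derives these from the OLMo config.
dim, n_heads, n_kv_heads = 64, 8, 8
dims_per_head = dim // n_heads

# Fused projection as stored in an OLMo checkpoint: q, k and v stacked along dim 0.
att_proj_weight = torch.randn(dim + 2 * dims_per_head * n_kv_heads, dim)

fused_dims = [dim, dims_per_head * n_kv_heads, dims_per_head * n_kv_heads]
q_proj_weight, k_proj_weight, v_proj_weight = torch.split(att_proj_weight, fused_dims, dim=0)

print(q_proj_weight.shape, k_proj_weight.shape, v_proj_weight.shape)
```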
docs/transformers/src/transformers/models/olmo/modeling_olmo.py
ADDED
@@ -0,0 +1,814 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/olmo/modular_olmo.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_olmo.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    LossKwargs,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
from .configuration_olmo import OlmoConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OlmoConfig"


class OlmoLayerNorm(nn.Module):
    """LayerNorm but with no learnable weight or bias."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.normalized_shape = (hidden_size,)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
            orig_dtype
        )


class OlmoMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 118 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def eager_attention_forward(
|
| 122 |
+
module: nn.Module,
|
| 123 |
+
query: torch.Tensor,
|
| 124 |
+
key: torch.Tensor,
|
| 125 |
+
value: torch.Tensor,
|
| 126 |
+
attention_mask: Optional[torch.Tensor],
|
| 127 |
+
scaling: float,
|
| 128 |
+
dropout: float = 0.0,
|
| 129 |
+
**kwargs,
|
| 130 |
+
):
|
| 131 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 132 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 133 |
+
|
| 134 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 135 |
+
if attention_mask is not None:
|
| 136 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 137 |
+
attn_weights = attn_weights + causal_mask
|
| 138 |
+
|
| 139 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 140 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 141 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 142 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 143 |
+
|
| 144 |
+
return attn_output, attn_weights
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class OlmoAttention(nn.Module):
|
| 148 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 151 |
+
super().__init__()
|
| 152 |
+
self.config = config
|
| 153 |
+
self.layer_idx = layer_idx
|
| 154 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 155 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 156 |
+
self.scaling = self.head_dim**-0.5
|
| 157 |
+
self.attention_dropout = config.attention_dropout
|
| 158 |
+
self.is_causal = True
|
| 159 |
+
|
| 160 |
+
self.q_proj = nn.Linear(
|
| 161 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 162 |
+
)
|
| 163 |
+
self.k_proj = nn.Linear(
|
| 164 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 165 |
+
)
|
| 166 |
+
self.v_proj = nn.Linear(
|
| 167 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 168 |
+
)
|
| 169 |
+
self.o_proj = nn.Linear(
|
| 170 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def forward(
|
| 174 |
+
self,
|
| 175 |
+
hidden_states: torch.Tensor,
|
| 176 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 177 |
+
attention_mask: Optional[torch.Tensor],
|
| 178 |
+
past_key_value: Optional[Cache] = None,
|
| 179 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 180 |
+
**kwargs,
|
| 181 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 182 |
+
input_shape = hidden_states.shape[:-1]
|
| 183 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 184 |
+
|
| 185 |
+
query_states = self.q_proj(hidden_states)
|
| 186 |
+
key_states = self.k_proj(hidden_states)
|
| 187 |
+
value_states = self.v_proj(hidden_states)
|
| 188 |
+
|
| 189 |
+
if self.config.clip_qkv is not None:
|
| 190 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 191 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 192 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 193 |
+
|
| 194 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 195 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 196 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 197 |
+
|
| 198 |
+
cos, sin = position_embeddings
|
| 199 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 200 |
+
|
| 201 |
+
if past_key_value is not None:
|
| 202 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 203 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 204 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 205 |
+
|
| 206 |
+
attention_interface: Callable = eager_attention_forward
|
| 207 |
+
if self.config._attn_implementation != "eager":
|
| 208 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
| 209 |
+
logger.warning_once(
|
| 210 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 211 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 212 |
+
)
|
| 213 |
+
else:
|
| 214 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 215 |
+
|
| 216 |
+
attn_output, attn_weights = attention_interface(
|
| 217 |
+
self,
|
| 218 |
+
query_states,
|
| 219 |
+
key_states,
|
| 220 |
+
value_states,
|
| 221 |
+
attention_mask,
|
| 222 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 223 |
+
scaling=self.scaling,
|
| 224 |
+
**kwargs,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 228 |
+
attn_output = self.o_proj(attn_output)
|
| 229 |
+
return attn_output, attn_weights
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class OlmoDecoderLayer(GradientCheckpointingLayer):
|
| 233 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 234 |
+
super().__init__()
|
| 235 |
+
self.hidden_size = config.hidden_size
|
| 236 |
+
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
|
| 237 |
+
|
| 238 |
+
self.mlp = OlmoMLP(config)
|
| 239 |
+
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 240 |
+
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 241 |
+
|
| 242 |
+
def forward(
|
| 243 |
+
self,
|
| 244 |
+
hidden_states: torch.Tensor,
|
| 245 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 246 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 247 |
+
past_key_value: Optional[Cache] = None,
|
| 248 |
+
output_attentions: Optional[bool] = False,
|
| 249 |
+
use_cache: Optional[bool] = False,
|
| 250 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 251 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 252 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 253 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 254 |
+
residual = hidden_states
|
| 255 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 256 |
+
|
| 257 |
+
# Self Attention
|
| 258 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 259 |
+
hidden_states=hidden_states,
|
| 260 |
+
attention_mask=attention_mask,
|
| 261 |
+
position_ids=position_ids,
|
| 262 |
+
past_key_value=past_key_value,
|
| 263 |
+
output_attentions=output_attentions,
|
| 264 |
+
use_cache=use_cache,
|
| 265 |
+
cache_position=cache_position,
|
| 266 |
+
position_embeddings=position_embeddings,
|
| 267 |
+
**kwargs,
|
| 268 |
+
)
|
| 269 |
+
hidden_states = residual + hidden_states
|
| 270 |
+
|
| 271 |
+
# Fully Connected
|
| 272 |
+
residual = hidden_states
|
| 273 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 274 |
+
hidden_states = self.mlp(hidden_states)
|
| 275 |
+
hidden_states = residual + hidden_states
|
| 276 |
+
|
| 277 |
+
outputs = (hidden_states,)
|
| 278 |
+
if output_attentions:
|
| 279 |
+
outputs += (self_attn_weights,)
|
| 280 |
+
|
| 281 |
+
return outputs
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class OlmoRotaryEmbedding(nn.Module):
|
| 285 |
+
def __init__(self, config: OlmoConfig, device=None):
|
| 286 |
+
super().__init__()
|
| 287 |
+
# BC: "rope_type" was originally "type"
|
| 288 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 289 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 290 |
+
else:
|
| 291 |
+
self.rope_type = "default"
|
| 292 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 293 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 294 |
+
|
| 295 |
+
self.config = config
|
| 296 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 297 |
+
|
| 298 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 299 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 300 |
+
self.original_inv_freq = self.inv_freq
|
| 301 |
+
|
| 302 |
+
@torch.no_grad()
|
| 303 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 304 |
+
def forward(self, x, position_ids):
|
| 305 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 306 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 307 |
+
|
| 308 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 309 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 310 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 311 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 312 |
+
cos = emb.cos() * self.attention_scaling
|
| 313 |
+
sin = emb.sin() * self.attention_scaling
|
| 314 |
+
|
| 315 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
OLMO_START_DOCSTRING = r"""
|
| 319 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 320 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 321 |
+
etc.)
|
| 322 |
+
|
| 323 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 324 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 325 |
+
and behavior.
|
| 326 |
+
|
| 327 |
+
Parameters:
|
| 328 |
+
config ([`OlmoConfig`]):
|
| 329 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 330 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 331 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
@add_start_docstrings(
|
| 336 |
+
"The bare Olmo Model outputting raw hidden-states without any specific head on top.",
|
| 337 |
+
OLMO_START_DOCSTRING,
|
| 338 |
+
)
|
| 339 |
+
class OlmoPreTrainedModel(PreTrainedModel):
|
| 340 |
+
config_class = OlmoConfig
|
| 341 |
+
base_model_prefix = "model"
|
| 342 |
+
supports_gradient_checkpointing = True
|
| 343 |
+
_no_split_modules = ["OlmoDecoderLayer"]
|
| 344 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 345 |
+
_supports_flash_attn_2 = True
|
| 346 |
+
_supports_sdpa = True
|
| 347 |
+
_supports_flex_attn = True
|
| 348 |
+
_supports_cache_class = True
|
| 349 |
+
_supports_quantized_cache = True
|
| 350 |
+
_supports_static_cache = True
|
| 351 |
+
_supports_attention_backend = True
|
| 352 |
+
|
| 353 |
+
def _init_weights(self, module):
|
| 354 |
+
std = self.config.initializer_range
|
| 355 |
+
if isinstance(module, nn.Linear):
|
| 356 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 357 |
+
if module.bias is not None:
|
| 358 |
+
module.bias.data.zero_()
|
| 359 |
+
elif isinstance(module, nn.Embedding):
|
| 360 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 361 |
+
if module.padding_idx is not None:
|
| 362 |
+
module.weight.data[module.padding_idx].zero_()
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
OLMO_INPUTS_DOCSTRING = r"""
|
| 366 |
+
Args:
|
| 367 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 368 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 369 |
+
it.
|
| 370 |
+
|
| 371 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 372 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 373 |
+
|
| 374 |
+
[What are input IDs?](../glossary#input-ids)
|
| 375 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
|
| 376 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 377 |
+
|
| 378 |
+
- 1 for tokens that are **not masked**,
|
| 379 |
+
- 0 for tokens that are **masked**.
|
| 380 |
+
|
| 381 |
+
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
|
| 382 |
+
but you can also pass a `BlockMask` object directly here.
|
| 383 |
+
|
| 384 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 385 |
+
|
| 386 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 387 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 388 |
+
|
| 389 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 390 |
+
`past_key_values`).
|
| 391 |
+
|
| 392 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 393 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 394 |
+
information on the default strategy.
|
| 395 |
+
|
| 396 |
+
- 1 indicates the head is **not masked**,
|
| 397 |
+
- 0 indicates the head is **masked**.
|
| 398 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 399 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 400 |
+
config.n_positions - 1]`.
|
| 401 |
+
|
| 402 |
+
[What are position IDs?](../glossary#position-ids)
|
| 403 |
+
past_key_values (`Cache`, *optional*):
|
| 404 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 405 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 406 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 407 |
+
|
| 408 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 409 |
+
|
| 410 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 411 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 412 |
+
of shape `(batch_size, sequence_length)`.
|
| 413 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 414 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 415 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 416 |
+
model's internal embedding lookup matrix.
|
| 417 |
+
use_cache (`bool`, *optional*):
|
| 418 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 419 |
+
`past_key_values`).
|
| 420 |
+
output_attentions (`bool`, *optional*):
|
| 421 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 422 |
+
tensors for more detail.
|
| 423 |
+
output_hidden_states (`bool`, *optional*):
|
| 424 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 425 |
+
more detail.
|
| 426 |
+
return_dict (`bool`, *optional*):
|
| 427 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 428 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 429 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 430 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 431 |
+
the complete sequence length.
|
| 432 |
+
"""
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
@add_start_docstrings(
|
| 436 |
+
"The bare Olmo Model outputting raw hidden-states without any specific head on top.",
|
| 437 |
+
OLMO_START_DOCSTRING,
|
| 438 |
+
)
|
| 439 |
+
class OlmoModel(OlmoPreTrainedModel):
|
| 440 |
+
"""
|
| 441 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoDecoderLayer`]
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
config: OlmoConfig
|
| 445 |
+
"""
|
| 446 |
+
|
| 447 |
+
def __init__(self, config: OlmoConfig):
|
| 448 |
+
super().__init__(config)
|
| 449 |
+
self.padding_idx = config.pad_token_id
|
| 450 |
+
self.vocab_size = config.vocab_size
|
| 451 |
+
|
| 452 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 453 |
+
self.layers = nn.ModuleList(
|
| 454 |
+
[OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 455 |
+
)
|
| 456 |
+
self.norm = OlmoLayerNorm(config.hidden_size)
|
| 457 |
+
self.rotary_emb = OlmoRotaryEmbedding(config=config)
|
| 458 |
+
self.gradient_checkpointing = False
|
| 459 |
+
|
| 460 |
+
# Initialize weights and apply final processing
|
| 461 |
+
self.post_init()
|
| 462 |
+
|
| 463 |
+
def get_input_embeddings(self):
|
| 464 |
+
return self.embed_tokens
|
| 465 |
+
|
| 466 |
+
def set_input_embeddings(self, value):
|
| 467 |
+
self.embed_tokens = value
|
| 468 |
+
|
| 469 |
+
@can_return_tuple
|
| 470 |
+
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
|
| 471 |
+
def forward(
|
| 472 |
+
self,
|
| 473 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 474 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 475 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 476 |
+
past_key_values: Optional[Cache] = None,
|
| 477 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 478 |
+
use_cache: Optional[bool] = None,
|
| 479 |
+
output_attentions: Optional[bool] = None,
|
| 480 |
+
output_hidden_states: Optional[bool] = None,
|
| 481 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 482 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 483 |
+
) -> BaseModelOutputWithPast:
|
| 484 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 485 |
+
output_hidden_states = (
|
| 486 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 487 |
+
)
|
| 488 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 489 |
+
|
| 490 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 491 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 492 |
+
|
| 493 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 494 |
+
logger.warning_once(
|
| 495 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 496 |
+
)
|
| 497 |
+
use_cache = False
|
| 498 |
+
|
| 499 |
+
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
| 500 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 501 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 502 |
+
|
| 503 |
+
if inputs_embeds is None:
|
| 504 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 505 |
+
|
| 506 |
+
if use_cache and past_key_values is None:
|
| 507 |
+
past_key_values = DynamicCache()
|
| 508 |
+
|
| 509 |
+
if cache_position is None:
|
| 510 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 511 |
+
cache_position = torch.arange(
|
| 512 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if position_ids is None:
|
| 516 |
+
position_ids = cache_position.unsqueeze(0)
|
| 517 |
+
|
| 518 |
+
causal_mask = self._update_causal_mask(
|
| 519 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
hidden_states = inputs_embeds
|
| 523 |
+
|
| 524 |
+
# create position embeddings to be shared across the decoder layers
|
| 525 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 526 |
+
|
| 527 |
+
# decoder layers
|
| 528 |
+
all_hidden_states = () if output_hidden_states else None
|
| 529 |
+
all_self_attns = () if output_attentions else None
|
| 530 |
+
|
| 531 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 532 |
+
if output_hidden_states:
|
| 533 |
+
all_hidden_states += (hidden_states,)
|
| 534 |
+
|
| 535 |
+
layer_outputs = decoder_layer(
|
| 536 |
+
hidden_states,
|
| 537 |
+
attention_mask=causal_mask,
|
| 538 |
+
position_ids=position_ids,
|
| 539 |
+
past_key_value=past_key_values,
|
| 540 |
+
output_attentions=output_attentions,
|
| 541 |
+
use_cache=use_cache,
|
| 542 |
+
cache_position=cache_position,
|
| 543 |
+
position_embeddings=position_embeddings,
|
| 544 |
+
**flash_attn_kwargs,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
hidden_states = layer_outputs[0]
|
| 548 |
+
|
| 549 |
+
if output_attentions:
|
| 550 |
+
all_self_attns += (layer_outputs[1],)
|
| 551 |
+
|
| 552 |
+
hidden_states = self.norm(hidden_states)
|
| 553 |
+
|
| 554 |
+
# add hidden states from the last decoder layer
|
| 555 |
+
if output_hidden_states:
|
| 556 |
+
all_hidden_states += (hidden_states,)
|
| 557 |
+
|
| 558 |
+
return BaseModelOutputWithPast(
|
| 559 |
+
last_hidden_state=hidden_states,
|
| 560 |
+
past_key_values=past_key_values if use_cache else None,
|
| 561 |
+
hidden_states=all_hidden_states,
|
| 562 |
+
attentions=all_self_attns,
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
def _update_causal_mask(
|
| 566 |
+
self,
|
| 567 |
+
attention_mask: Union[torch.Tensor, "BlockMask"],
|
| 568 |
+
input_tensor: torch.Tensor,
|
| 569 |
+
cache_position: torch.Tensor,
|
| 570 |
+
past_key_values: Cache,
|
| 571 |
+
output_attentions: bool = False,
|
| 572 |
+
):
|
| 573 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 574 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 575 |
+
return attention_mask
|
| 576 |
+
return None
|
| 577 |
+
if self.config._attn_implementation == "flex_attention":
|
| 578 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 579 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 580 |
+
return attention_mask
|
| 581 |
+
|
| 582 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 583 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 584 |
+
# to infer the attention mask.
|
| 585 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 586 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 587 |
+
|
| 588 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 589 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 590 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 591 |
+
attention_mask,
|
| 592 |
+
inputs_embeds=input_tensor,
|
| 593 |
+
past_key_values_length=past_seen_tokens,
|
| 594 |
+
is_training=self.training,
|
| 595 |
+
):
|
| 596 |
+
return None
|
| 597 |
+
|
| 598 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 599 |
+
sequence_length = input_tensor.shape[1]
|
| 600 |
+
if using_static_cache:
|
| 601 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 602 |
+
else:
|
| 603 |
+
target_length = (
|
| 604 |
+
attention_mask.shape[-1]
|
| 605 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 606 |
+
else past_seen_tokens + sequence_length + 1
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 610 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 611 |
+
attention_mask,
|
| 612 |
+
sequence_length=sequence_length,
|
| 613 |
+
target_length=target_length,
|
| 614 |
+
dtype=dtype,
|
| 615 |
+
device=device,
|
| 616 |
+
cache_position=cache_position,
|
| 617 |
+
batch_size=input_tensor.shape[0],
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
if (
|
| 621 |
+
self.config._attn_implementation == "sdpa"
|
| 622 |
+
and attention_mask is not None
|
| 623 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 624 |
+
and not output_attentions
|
| 625 |
+
):
|
| 626 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 627 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 628 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 629 |
+
min_dtype = torch.finfo(dtype).min
|
| 630 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 631 |
+
|
| 632 |
+
return causal_mask
|
| 633 |
+
|
| 634 |
+
@staticmethod
|
| 635 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 636 |
+
attention_mask: torch.Tensor,
|
| 637 |
+
sequence_length: int,
|
| 638 |
+
target_length: int,
|
| 639 |
+
dtype: torch.dtype,
|
| 640 |
+
device: torch.device,
|
| 641 |
+
cache_position: torch.Tensor,
|
| 642 |
+
batch_size: int,
|
| 643 |
+
**kwargs,
|
| 644 |
+
):
|
| 645 |
+
"""
|
| 646 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 647 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 648 |
+
|
| 649 |
+
Args:
|
| 650 |
+
attention_mask (`torch.Tensor`):
|
| 651 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 652 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 653 |
+
sequence_length (`int`):
|
| 654 |
+
The sequence length being processed.
|
| 655 |
+
target_length (`int`):
|
| 656 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 657 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 658 |
+
dtype (`torch.dtype`):
|
| 659 |
+
The dtype to use for the 4D attention mask.
|
| 660 |
+
device (`torch.device`):
|
| 661 |
+
The device to place the 4D attention mask on.
|
| 662 |
+
cache_position (`torch.Tensor`):
|
| 663 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 664 |
+
batch_size (`torch.Tensor`):
|
| 665 |
+
Batch size.
|
| 666 |
+
"""
|
| 667 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 668 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 669 |
+
causal_mask = attention_mask
|
| 670 |
+
else:
|
| 671 |
+
min_dtype = torch.finfo(dtype).min
|
| 672 |
+
causal_mask = torch.full(
|
| 673 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 674 |
+
)
|
| 675 |
+
if sequence_length != 1:
|
| 676 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 677 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 678 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 679 |
+
if attention_mask is not None:
|
| 680 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 681 |
+
mask_length = attention_mask.shape[-1]
|
| 682 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 683 |
+
causal_mask.device
|
| 684 |
+
)
|
| 685 |
+
padding_mask = padding_mask == 0
|
| 686 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 687 |
+
padding_mask, min_dtype
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
return causal_mask
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
|
| 697 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 698 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 699 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 700 |
+
|
| 701 |
+
def __init__(self, config):
|
| 702 |
+
super().__init__(config)
|
| 703 |
+
self.model = OlmoModel(config)
|
| 704 |
+
self.vocab_size = config.vocab_size
|
| 705 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 706 |
+
|
| 707 |
+
# Initialize weights and apply final processing
|
| 708 |
+
self.post_init()
|
| 709 |
+
|
| 710 |
+
def get_input_embeddings(self):
|
| 711 |
+
return self.model.embed_tokens
|
| 712 |
+
|
| 713 |
+
def set_input_embeddings(self, value):
|
| 714 |
+
self.model.embed_tokens = value
|
| 715 |
+
|
| 716 |
+
def get_output_embeddings(self):
|
| 717 |
+
return self.lm_head
|
| 718 |
+
|
| 719 |
+
def set_output_embeddings(self, new_embeddings):
|
| 720 |
+
self.lm_head = new_embeddings
|
| 721 |
+
|
| 722 |
+
def set_decoder(self, decoder):
|
| 723 |
+
self.model = decoder
|
| 724 |
+
|
| 725 |
+
def get_decoder(self):
|
| 726 |
+
return self.model
|
| 727 |
+
|
| 728 |
+
@can_return_tuple
|
| 729 |
+
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
|
| 730 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 731 |
+
def forward(
|
| 732 |
+
self,
|
| 733 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 734 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 735 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 736 |
+
past_key_values: Optional[Cache] = None,
|
| 737 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 738 |
+
labels: Optional[torch.LongTensor] = None,
|
| 739 |
+
use_cache: Optional[bool] = None,
|
| 740 |
+
output_attentions: Optional[bool] = None,
|
| 741 |
+
output_hidden_states: Optional[bool] = None,
|
| 742 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 743 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 744 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 745 |
+
) -> CausalLMOutputWithPast:
|
| 746 |
+
r"""
|
| 747 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 748 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 749 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 750 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 751 |
+
|
| 752 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 753 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 754 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 755 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 756 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 757 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 758 |
+
|
| 759 |
+
Returns:
|
| 760 |
+
|
| 761 |
+
Example:
|
| 762 |
+
|
| 763 |
+
```python
|
| 764 |
+
>>> from transformers import AutoTokenizer, OlmoForCausalLM
|
| 765 |
+
|
| 766 |
+
>>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf")
|
| 767 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf")
|
| 768 |
+
|
| 769 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 770 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 771 |
+
|
| 772 |
+
>>> # Generate
|
| 773 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 774 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 775 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 776 |
+
```"""
|
| 777 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 778 |
+
output_hidden_states = (
|
| 779 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 783 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 784 |
+
input_ids=input_ids,
|
| 785 |
+
attention_mask=attention_mask,
|
| 786 |
+
position_ids=position_ids,
|
| 787 |
+
past_key_values=past_key_values,
|
| 788 |
+
inputs_embeds=inputs_embeds,
|
| 789 |
+
use_cache=use_cache,
|
| 790 |
+
output_attentions=output_attentions,
|
| 791 |
+
output_hidden_states=output_hidden_states,
|
| 792 |
+
cache_position=cache_position,
|
| 793 |
+
**kwargs,
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
hidden_states = outputs.last_hidden_state
|
| 797 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 798 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 799 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 800 |
+
|
| 801 |
+
loss = None
|
| 802 |
+
if labels is not None:
|
| 803 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 804 |
+
|
| 805 |
+
return CausalLMOutputWithPast(
|
| 806 |
+
loss=loss,
|
| 807 |
+
logits=logits,
|
| 808 |
+
past_key_values=outputs.past_key_values,
|
| 809 |
+
hidden_states=outputs.hidden_states,
|
| 810 |
+
attentions=outputs.attentions,
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]
|
docs/transformers/src/transformers/models/olmo/modular_olmo.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.utils.checkpoint
|
| 7 |
+
|
| 8 |
+
from ...cache_utils import Cache
|
| 9 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 10 |
+
from ...utils import logging
|
| 11 |
+
from ..llama.modeling_llama import (
|
| 12 |
+
LlamaAttention,
|
| 13 |
+
LlamaDecoderLayer,
|
| 14 |
+
LlamaForCausalLM,
|
| 15 |
+
LlamaMLP,
|
| 16 |
+
LlamaModel,
|
| 17 |
+
LlamaPreTrainedModel,
|
| 18 |
+
LlamaRotaryEmbedding,
|
| 19 |
+
apply_rotary_pos_emb,
|
| 20 |
+
eager_attention_forward,
|
| 21 |
+
)
|
| 22 |
+
from .configuration_olmo import OlmoConfig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class OlmoLayerNorm(nn.Module):
|
| 29 |
+
"""LayerNorm but with no learnable weight or bias."""
|
| 30 |
+
|
| 31 |
+
def __init__(self, hidden_size: int) -> None:
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.normalized_shape = (hidden_size,)
|
| 34 |
+
|
| 35 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
orig_dtype = hidden_states.dtype
|
| 37 |
+
return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
|
| 38 |
+
orig_dtype
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class OlmoMLP(LlamaMLP):
|
| 43 |
+
def __init__(self, config):
|
| 44 |
+
super().__init__(config)
|
| 45 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 46 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 47 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class OlmoAttention(LlamaAttention):
|
| 51 |
+
def forward(
|
| 52 |
+
self,
|
| 53 |
+
hidden_states: torch.Tensor,
|
| 54 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 55 |
+
attention_mask: Optional[torch.Tensor],
|
| 56 |
+
past_key_value: Optional[Cache] = None,
|
| 57 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 58 |
+
**kwargs,
|
| 59 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 60 |
+
input_shape = hidden_states.shape[:-1]
|
| 61 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 62 |
+
|
| 63 |
+
query_states = self.q_proj(hidden_states)
|
| 64 |
+
key_states = self.k_proj(hidden_states)
|
| 65 |
+
value_states = self.v_proj(hidden_states)
|
| 66 |
+
|
| 67 |
+
if self.config.clip_qkv is not None:
|
| 68 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 69 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 70 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 71 |
+
|
| 72 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 73 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 74 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 75 |
+
|
| 76 |
+
cos, sin = position_embeddings
|
| 77 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 78 |
+
|
| 79 |
+
if past_key_value is not None:
|
| 80 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 81 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 82 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 83 |
+
|
| 84 |
+
attention_interface: Callable = eager_attention_forward
|
| 85 |
+
if self.config._attn_implementation != "eager":
|
| 86 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
| 87 |
+
logger.warning_once(
|
| 88 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 89 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 93 |
+
|
| 94 |
+
attn_output, attn_weights = attention_interface(
|
| 95 |
+
self,
|
| 96 |
+
query_states,
|
| 97 |
+
key_states,
|
| 98 |
+
value_states,
|
| 99 |
+
attention_mask,
|
| 100 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 101 |
+
scaling=self.scaling,
|
| 102 |
+
**kwargs,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 106 |
+
attn_output = self.o_proj(attn_output)
|
| 107 |
+
return attn_output, attn_weights
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class OlmoDecoderLayer(LlamaDecoderLayer):
|
| 111 |
+
def __init__(self, config: OlmoConfig, layer_idx: int):
|
| 112 |
+
super().__init__(config, layer_idx)
|
| 113 |
+
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 114 |
+
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
|
| 115 |
+
self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
|
| 119 |
+
pass
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class OlmoPreTrainedModel(LlamaPreTrainedModel):
|
| 123 |
+
def _init_weights(self, module):
|
| 124 |
+
std = self.config.initializer_range
|
| 125 |
+
if isinstance(module, nn.Linear):
|
| 126 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 127 |
+
if module.bias is not None:
|
| 128 |
+
module.bias.data.zero_()
|
| 129 |
+
elif isinstance(module, nn.Embedding):
|
| 130 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 131 |
+
if module.padding_idx is not None:
|
| 132 |
+
module.weight.data[module.padding_idx].zero_()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class OlmoModel(LlamaModel):
|
| 136 |
+
def __init__(self, config: OlmoConfig):
|
| 137 |
+
super().__init__(config)
|
| 138 |
+
self.layers = nn.ModuleList(
|
| 139 |
+
[OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 140 |
+
)
|
| 141 |
+
self.norm = OlmoLayerNorm(config.hidden_size)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class OlmoForCausalLM(LlamaForCausalLM):
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"]
|
docs/transformers/src/transformers/models/olmo2/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_olmo2 import *
|
| 22 |
+
from .modeling_olmo2 import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/src/transformers/models/olmo2/configuration_olmo2.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_olmo2.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
|
| 8 |
+
from ...configuration_utils import PretrainedConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Olmo2Config(PretrainedConfig):
|
| 12 |
+
r"""
|
| 13 |
+
This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
|
| 14 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 15 |
+
defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
|
| 16 |
+
|
| 17 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 18 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
vocab_size (`int`, *optional*, defaults to 50304):
|
| 23 |
+
Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
|
| 24 |
+
`inputs_ids` passed when calling [`Olmo2Model`]
|
| 25 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 26 |
+
Dimension of the hidden representations.
|
| 27 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
| 28 |
+
Dimension of the MLP representations.
|
| 29 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 30 |
+
Number of hidden layers in the Transformer decoder.
|
| 31 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 32 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 33 |
+
num_key_value_heads (`int`, *optional*):
|
| 34 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 35 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 36 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 37 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 38 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 39 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 40 |
+
`num_attention_heads`.
|
| 41 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 42 |
+
The non-linear activation function (function or string) in the decoder.
|
| 43 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 44 |
+
The maximum sequence length that this model might ever be used with.
|
| 45 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 46 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 47 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 48 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 49 |
+
relevant if `config.is_decoder=True`.
|
| 50 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 51 |
+
Padding token id.
|
| 52 |
+
bos_token_id (`int`, *optional*):
|
| 53 |
+
Beginning of stream token id.
|
| 54 |
+
eos_token_id (`int`, *optional*, defaults to 50279):
|
| 55 |
+
End of stream token id.
|
| 56 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 57 |
+
Whether to tie weight embeddings
|
| 58 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 59 |
+
The base period of the RoPE embeddings.
|
| 60 |
+
rope_scaling (`Dict`, *optional*):
|
| 61 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
| 62 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
| 63 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
| 64 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
| 65 |
+
these scaling strategies behave:
|
| 66 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
| 67 |
+
experimental feature, subject to breaking API changes in future versions.
|
| 68 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 69 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 70 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 71 |
+
The dropout ratio for the attention probabilities.
|
| 72 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 73 |
+
The epsilon used by the rms normalization layers.
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
>>> from transformers import Olmo2Model, Olmo2Config
|
| 77 |
+
|
| 78 |
+
>>> # Initializing a Olmo2 7B style configuration
|
| 79 |
+
>>> configuration = Olmo2Config()
|
| 80 |
+
|
| 81 |
+
>>> # Initializing a model from the Olmo2 7B style configuration
|
| 82 |
+
>>> model = Olmo2Model(configuration)
|
| 83 |
+
|
| 84 |
+
>>> # Accessing the model configuration
|
| 85 |
+
>>> configuration = model.config
|
| 86 |
+
```
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
model_type = "olmo2"
|
| 90 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 91 |
+
base_model_tp_plan = {
|
| 92 |
+
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 93 |
+
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 94 |
+
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 95 |
+
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
| 96 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 97 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 98 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 99 |
+
}
|
| 100 |
+
base_model_pp_plan = {
|
| 101 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 102 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 103 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
vocab_size=50304,
|
| 109 |
+
hidden_size=4096,
|
| 110 |
+
intermediate_size=11008,
|
| 111 |
+
num_hidden_layers=32,
|
| 112 |
+
num_attention_heads=32,
|
| 113 |
+
num_key_value_heads=None,
|
| 114 |
+
hidden_act="silu",
|
| 115 |
+
max_position_embeddings=2048,
|
| 116 |
+
initializer_range=0.02,
|
| 117 |
+
use_cache=True,
|
| 118 |
+
pad_token_id=1,
|
| 119 |
+
bos_token_id=None,
|
| 120 |
+
eos_token_id=50279,
|
| 121 |
+
tie_word_embeddings=False,
|
| 122 |
+
rope_theta=10000.0,
|
| 123 |
+
rope_scaling=None,
|
| 124 |
+
attention_bias=False,
|
| 125 |
+
attention_dropout=0.0,
|
| 126 |
+
rms_norm_eps=1e-5,
|
| 127 |
+
**kwargs,
|
| 128 |
+
):
|
| 129 |
+
super().__init__(
|
| 130 |
+
pad_token_id=pad_token_id,
|
| 131 |
+
bos_token_id=bos_token_id,
|
| 132 |
+
eos_token_id=eos_token_id,
|
| 133 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 134 |
+
**kwargs,
|
| 135 |
+
)
|
| 136 |
+
self.vocab_size = vocab_size
|
| 137 |
+
self.max_position_embeddings = max_position_embeddings
|
| 138 |
+
self.hidden_size = hidden_size
|
| 139 |
+
self.intermediate_size = intermediate_size
|
| 140 |
+
self.num_hidden_layers = num_hidden_layers
|
| 141 |
+
self.num_attention_heads = num_attention_heads
|
| 142 |
+
|
| 143 |
+
# for backward compatibility
|
| 144 |
+
if num_key_value_heads is None:
|
| 145 |
+
num_key_value_heads = num_attention_heads
|
| 146 |
+
|
| 147 |
+
self.num_key_value_heads = num_key_value_heads
|
| 148 |
+
self.hidden_act = hidden_act
|
| 149 |
+
self.initializer_range = initializer_range
|
| 150 |
+
self.use_cache = use_cache
|
| 151 |
+
self.rope_theta = rope_theta
|
| 152 |
+
self.rope_scaling = rope_scaling
|
| 153 |
+
self._rope_scaling_validation()
|
| 154 |
+
self.attention_bias = attention_bias
|
| 155 |
+
self.attention_dropout = attention_dropout
|
| 156 |
+
|
| 157 |
+
self.rms_norm_eps = rms_norm_eps
|
| 158 |
+
|
| 159 |
+
def _rope_scaling_validation(self):
|
| 160 |
+
"""
|
| 161 |
+
Validate the `rope_scaling` configuration.
|
| 162 |
+
"""
|
| 163 |
+
if self.rope_scaling is None:
|
| 164 |
+
return
|
| 165 |
+
|
| 166 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
| 167 |
+
raise ValueError(
|
| 168 |
+
f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
|
| 169 |
+
)
|
| 170 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
| 171 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
| 172 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
| 173 |
+
raise ValueError(
|
| 174 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
| 175 |
+
)
|
| 176 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
| 177 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
__all__ = ["Olmo2Config"]
|
docs/transformers/src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import gc
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import shutil
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, Dict
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import yaml
|
| 26 |
+
from tokenizers import Tokenizer
|
| 27 |
+
|
| 28 |
+
from transformers import Olmo2Config, Olmo2ForCausalLM
|
| 29 |
+
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
Sample usage:
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
python src/transformers/models/olmo2/convert_olmo2_weights_to_hf.py \
|
| 37 |
+
--input_dir /path/to/downloaded/olmo2/weights --model_size 7B --output_dir /output/path
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Thereafter, models can be loaded via:
|
| 41 |
+
|
| 42 |
+
```py
|
| 43 |
+
from transformers import Olmo2ForCausalLM, AutoTokenizer
|
| 44 |
+
|
| 45 |
+
model = Olmo2ForCausalLM.from_pretrained("/output/path")
|
| 46 |
+
tokenizer = AutoTokenizer.from_pretrained("/output/path")
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
| 50 |
+
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
| 55 |
+
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def read_json(path):
|
| 59 |
+
with open(path, "r") as f:
|
| 60 |
+
return json.load(f)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def write_json(text, path):
|
| 64 |
+
with open(path, "w") as f:
|
| 65 |
+
json.dump(text, f)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def write_model(
|
| 69 |
+
model_path,
|
| 70 |
+
input_base_path,
|
| 71 |
+
include_tokenizer=True,
|
| 72 |
+
tokenizer_path=None,
|
| 73 |
+
safe_serialization=True,
|
| 74 |
+
fix_eos_token_id=True,
|
| 75 |
+
tmp_cleanup=True,
|
| 76 |
+
):
|
| 77 |
+
os.makedirs(model_path, exist_ok=True)
|
| 78 |
+
tmp_model_path = os.path.join(model_path, "tmp")
|
| 79 |
+
os.makedirs(tmp_model_path, exist_ok=True)
|
| 80 |
+
|
| 81 |
+
config_path = Path(input_base_path) / "config.yaml"
|
| 82 |
+
olmo2_config = yaml.safe_load(config_path.read_text())["model"]
|
| 83 |
+
|
| 84 |
+
if not olmo2_config.get("attention_layer_norm", False):
|
| 85 |
+
raise RuntimeError("OLMo2 checkpoints must have attention layer norm")
|
| 86 |
+
if not olmo2_config.get("norm_after", False):
|
| 87 |
+
raise RuntimeError("OLMo2 checkpoints must set norm_after to True")
|
| 88 |
+
|
| 89 |
+
n_layers = olmo2_config["n_layers"]
|
| 90 |
+
n_heads = olmo2_config["n_heads"]
|
| 91 |
+
dim = olmo2_config["d_model"]
|
| 92 |
+
dims_per_head = dim // n_heads
|
| 93 |
+
base = olmo2_config["rope_theta"]
|
| 94 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
| 95 |
+
max_position_embeddings = olmo2_config["max_sequence_length"]
|
| 96 |
+
|
| 97 |
+
vocab_size = olmo2_config.get("embedding_size", olmo2_config["vocab_size"])
|
| 98 |
+
|
| 99 |
+
if olmo2_config.get("n_kv_heads", None) is not None:
|
| 100 |
+
num_key_value_heads = olmo2_config["n_kv_heads"] # for GQA / MQA
|
| 101 |
+
elif olmo2_config["multi_query_attention"]: # compatibility with other checkpoints
|
| 102 |
+
num_key_value_heads = 1
|
| 103 |
+
else:
|
| 104 |
+
num_key_value_heads = n_heads
|
| 105 |
+
|
| 106 |
+
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
| 107 |
+
|
| 108 |
+
# Not sharded
|
| 109 |
+
# (The sharded implementation would also work, but this is simpler.)
|
| 110 |
+
loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True)
|
| 111 |
+
|
| 112 |
+
param_count = 0
|
| 113 |
+
index_dict: Dict[str, Any] = {"weight_map": {}}
|
| 114 |
+
for layer_i in range(n_layers):
|
| 115 |
+
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
| 116 |
+
# Unsharded
|
| 117 |
+
# TODO: Layernorm stuff
|
| 118 |
+
# TODO: multi query attention
|
| 119 |
+
fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
|
| 120 |
+
q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
|
| 121 |
+
loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
|
| 122 |
+
)
|
| 123 |
+
up_proj_weight, gate_proj_weight = torch.chunk(
|
| 124 |
+
loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0
|
| 125 |
+
)
|
| 126 |
+
state_dict = {
|
| 127 |
+
f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
|
| 128 |
+
f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
|
| 129 |
+
f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
|
| 130 |
+
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
|
| 131 |
+
f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"],
|
| 132 |
+
f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"],
|
| 133 |
+
f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
|
| 134 |
+
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
|
| 135 |
+
f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
|
| 136 |
+
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
|
| 137 |
+
f"transformer.blocks.{layer_i}.attn_norm.weight"
|
| 138 |
+
],
|
| 139 |
+
f"model.layers.{layer_i}.post_feedforward_layernorm.weight": loaded[
|
| 140 |
+
f"transformer.blocks.{layer_i}.ff_norm.weight"
|
| 141 |
+
],
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
| 145 |
+
|
| 146 |
+
for k, v in state_dict.items():
|
| 147 |
+
index_dict["weight_map"][k] = filename
|
| 148 |
+
param_count += v.numel()
|
| 149 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
| 150 |
+
|
| 151 |
+
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
| 152 |
+
|
| 153 |
+
# Unsharded
|
| 154 |
+
# TODO: Deal with weight-tying
|
| 155 |
+
state_dict = {
|
| 156 |
+
"model.embed_tokens.weight": loaded["transformer.wte.weight"],
|
| 157 |
+
"model.norm.weight": loaded["transformer.ln_f.weight"],
|
| 158 |
+
"lm_head.weight": loaded["transformer.ff_out.weight"]
|
| 159 |
+
if "transformer.ff_out.weight" in loaded
|
| 160 |
+
else loaded["transformer.wte.weight"],
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
for k, v in state_dict.items():
|
| 164 |
+
index_dict["weight_map"][k] = filename
|
| 165 |
+
param_count += v.numel()
|
| 166 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
| 167 |
+
|
| 168 |
+
# Write configs
|
| 169 |
+
index_dict["metadata"] = {"total_size": param_count * 2}
|
| 170 |
+
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
| 171 |
+
|
| 172 |
+
if olmo2_config.get("mlp_hidden_size", None) is not None:
|
| 173 |
+
intermediate_size = olmo2_config["mlp_hidden_size"] // 2
|
| 174 |
+
else:
|
| 175 |
+
intermediate_size = (dim * olmo2_config["mlp_ratio"]) // 2
|
| 176 |
+
|
| 177 |
+
if fix_eos_token_id and olmo2_config["eos_token_id"] == 0:
|
| 178 |
+
# Fixing a bug in OLMo where eos token id was incorrectly set
|
| 179 |
+
print("Changing eos_token_id from 0 to 50279.")
|
| 180 |
+
olmo2_config["eos_token_id"] = 50279
|
| 181 |
+
|
| 182 |
+
config = Olmo2Config(
|
| 183 |
+
vocab_size=vocab_size,
|
| 184 |
+
hidden_size=dim,
|
| 185 |
+
intermediate_size=intermediate_size,
|
| 186 |
+
num_hidden_layers=n_layers,
|
| 187 |
+
num_attention_heads=n_heads,
|
| 188 |
+
num_key_value_heads=num_key_value_heads,
|
| 189 |
+
max_position_embeddings=max_position_embeddings,
|
| 190 |
+
pad_token_id=olmo2_config["pad_token_id"],
|
| 191 |
+
bos_token_id=None,
|
| 192 |
+
eos_token_id=olmo2_config["eos_token_id"],
|
| 193 |
+
tie_word_embeddings=olmo2_config["weight_tying"],
|
| 194 |
+
rms_norm_eps=olmo2_config["layer_norm_eps"],
|
| 195 |
+
rope_theta=base,
|
| 196 |
+
)
|
| 197 |
+
config.save_pretrained(tmp_model_path)
|
| 198 |
+
|
| 199 |
+
# Make space so we can load the model properly now.
|
| 200 |
+
del state_dict
|
| 201 |
+
del loaded
|
| 202 |
+
gc.collect()
|
| 203 |
+
|
| 204 |
+
if include_tokenizer:
|
| 205 |
+
_write_tokenizer(model_path, config, input_base_path, tokenizer_path)
|
| 206 |
+
|
| 207 |
+
print("Loading the checkpoint in a OLMo2 model.")
|
| 208 |
+
model = Olmo2ForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
|
| 209 |
+
# Avoid saving this as part of the config.
|
| 210 |
+
del model.config._name_or_path
|
| 211 |
+
print("Saving in the Transformers format.")
|
| 212 |
+
model.save_pretrained(model_path, safe_serialization=safe_serialization)
|
| 213 |
+
if tmp_cleanup:
|
| 214 |
+
# Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes
|
| 215 |
+
# errors if using NFS.
|
| 216 |
+
shutil.rmtree(tmp_model_path)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _write_tokenizer(
|
| 220 |
+
output_path: Path,
|
| 221 |
+
config: Olmo2Config,
|
| 222 |
+
checkpoint_dir: str,
|
| 223 |
+
input_tokenizer_path: Path | None,
|
| 224 |
+
) -> None:
|
| 225 |
+
print(f"Saving a {GPT2TokenizerFast.__name__} to {output_path}.")
|
| 226 |
+
|
| 227 |
+
if input_tokenizer_path is not None:
|
| 228 |
+
base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
|
| 229 |
+
else:
|
| 230 |
+
config_path = Path(checkpoint_dir) / "config.yaml"
|
| 231 |
+
tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"]
|
| 232 |
+
|
| 233 |
+
# Initialize tokenizer and validate vocab size.
|
| 234 |
+
if Path(tokenizer_config["identifier"]).is_file():
|
| 235 |
+
base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"])
|
| 236 |
+
else:
|
| 237 |
+
base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"])
|
| 238 |
+
|
| 239 |
+
eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
|
| 240 |
+
pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id
|
| 241 |
+
|
| 242 |
+
tokenizer = GPT2TokenizerFast(
|
| 243 |
+
tokenizer_object=base_tokenizer,
|
| 244 |
+
eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
|
| 245 |
+
pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False),
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
tokenizer.save_pretrained(output_path)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def main():
|
| 252 |
+
parser = argparse.ArgumentParser()
|
| 253 |
+
parser.add_argument(
|
| 254 |
+
"--input_dir",
|
| 255 |
+
required=True,
|
| 256 |
+
help="Location of OLMo2 weights, which contains config.yaml and model.pt.",
|
| 257 |
+
)
|
| 258 |
+
parser.add_argument(
|
| 259 |
+
"--no_tokenizer",
|
| 260 |
+
action="store_false",
|
| 261 |
+
dest="include_tokenizer",
|
| 262 |
+
help="If set, do not convert OLMo tokenizer to HF tokenizer.",
|
| 263 |
+
)
|
| 264 |
+
parser.add_argument(
|
| 265 |
+
"--tokenizer_json_path",
|
| 266 |
+
type=Path,
|
| 267 |
+
default=None,
|
| 268 |
+
help="Location of OLMo2 tokenizer json file. Defaults to what is set in the config file.",
|
| 269 |
+
)
|
| 270 |
+
parser.add_argument(
|
| 271 |
+
"--output_dir",
|
| 272 |
+
required=True,
|
| 273 |
+
help="Location to write HF model and tokenizer",
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--no_fix_eos_token_id",
|
| 277 |
+
action="store_false",
|
| 278 |
+
dest="fix_eos_token_id",
|
| 279 |
+
help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--no_tmp_cleanup",
|
| 283 |
+
action="store_false",
|
| 284 |
+
dest="tmp_cleanup",
|
| 285 |
+
help="If passed, don't remove temp dir at end of HF conversion.",
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--no_safe_serialization",
|
| 289 |
+
action="store_false",
|
| 290 |
+
dest="safe_serialization",
|
| 291 |
+
help="Whether or not to save using `safetensors`.",
|
| 292 |
+
)
|
| 293 |
+
args = parser.parse_args()
|
| 294 |
+
write_model(
|
| 295 |
+
model_path=args.output_dir,
|
| 296 |
+
input_base_path=args.input_dir,
|
| 297 |
+
safe_serialization=args.safe_serialization,
|
| 298 |
+
include_tokenizer=args.include_tokenizer,
|
| 299 |
+
tokenizer_path=args.tokenizer_json_path,
|
| 300 |
+
fix_eos_token_id=args.fix_eos_token_id,
|
| 301 |
+
tmp_cleanup=args.tmp_cleanup,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
main()
|
docs/transformers/src/transformers/models/olmo2/modeling_olmo2.py
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_olmo2.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
from typing import Callable, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from ...activations import ACT2FN
|
| 13 |
+
from ...cache_utils import Cache, DynamicCache, StaticCache
|
| 14 |
+
from ...generation import GenerationMixin
|
| 15 |
+
from ...integrations import use_kernel_forward_from_hub
|
| 16 |
+
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
| 17 |
+
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
| 18 |
+
from ...modeling_layers import GradientCheckpointingLayer
|
| 19 |
+
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 20 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 21 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 22 |
+
from ...processing_utils import Unpack
|
| 23 |
+
from ...utils import (
|
| 24 |
+
LossKwargs,
|
| 25 |
+
add_start_docstrings,
|
| 26 |
+
add_start_docstrings_to_model_forward,
|
| 27 |
+
can_return_tuple,
|
| 28 |
+
is_torch_flex_attn_available,
|
| 29 |
+
logging,
|
| 30 |
+
replace_return_docstrings,
|
| 31 |
+
)
|
| 32 |
+
from .configuration_olmo2 import Olmo2Config
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if is_torch_flex_attn_available():
|
| 36 |
+
from torch.nn.attention.flex_attention import BlockMask
|
| 37 |
+
|
| 38 |
+
from ...integrations.flex_attention import make_flex_block_causal_mask
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
_CONFIG_FOR_DOC = "Olmo2Config"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 46 |
+
class Olmo2RMSNorm(nn.Module):
|
| 47 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 48 |
+
"""
|
| 49 |
+
Olmo2RMSNorm is equivalent to T5LayerNorm
|
| 50 |
+
"""
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 53 |
+
self.variance_epsilon = eps
|
| 54 |
+
|
| 55 |
+
def forward(self, hidden_states):
|
| 56 |
+
input_dtype = hidden_states.dtype
|
| 57 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 58 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 59 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 60 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 61 |
+
|
| 62 |
+
def extra_repr(self):
|
| 63 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def rotate_half(x):
|
| 67 |
+
"""Rotates half the hidden dims of the input."""
|
| 68 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 69 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 70 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 74 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
q (`torch.Tensor`): The query tensor.
|
| 78 |
+
k (`torch.Tensor`): The key tensor.
|
| 79 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 80 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 81 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 82 |
+
Deprecated and unused.
|
| 83 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 84 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 85 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 86 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 87 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 88 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 89 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 90 |
+
Returns:
|
| 91 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 92 |
+
"""
|
| 93 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 94 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 95 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 96 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 97 |
+
return q_embed, k_embed
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 101 |
+
"""
|
| 102 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 103 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 104 |
+
"""
|
| 105 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 106 |
+
if n_rep == 1:
|
| 107 |
+
return hidden_states
|
| 108 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 109 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def eager_attention_forward(
|
| 113 |
+
module: nn.Module,
|
| 114 |
+
query: torch.Tensor,
|
| 115 |
+
key: torch.Tensor,
|
| 116 |
+
value: torch.Tensor,
|
| 117 |
+
attention_mask: Optional[torch.Tensor],
|
| 118 |
+
scaling: float,
|
| 119 |
+
dropout: float = 0.0,
|
| 120 |
+
**kwargs,
|
| 121 |
+
):
|
| 122 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 123 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 124 |
+
|
| 125 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 126 |
+
if attention_mask is not None:
|
| 127 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 128 |
+
attn_weights = attn_weights + causal_mask
|
| 129 |
+
|
| 130 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 131 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 132 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 133 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 134 |
+
|
| 135 |
+
return attn_output, attn_weights
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Olmo2Attention(nn.Module):
|
| 139 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 140 |
+
|
| 141 |
+
def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.config = config
|
| 144 |
+
self.layer_idx = layer_idx
|
| 145 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 146 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 147 |
+
self.scaling = self.head_dim**-0.5
|
| 148 |
+
self.attention_dropout = config.attention_dropout
|
| 149 |
+
self.is_causal = True
|
| 150 |
+
|
| 151 |
+
self.q_proj = nn.Linear(
|
| 152 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 153 |
+
)
|
| 154 |
+
self.k_proj = nn.Linear(
|
| 155 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 156 |
+
)
|
| 157 |
+
self.v_proj = nn.Linear(
|
| 158 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 159 |
+
)
|
| 160 |
+
self.o_proj = nn.Linear(
|
| 161 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 162 |
+
)
|
| 163 |
+
self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
|
| 164 |
+
self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
|
| 165 |
+
|
| 166 |
+
def forward(
|
| 167 |
+
self,
|
| 168 |
+
hidden_states: torch.Tensor,
|
| 169 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 170 |
+
attention_mask: Optional[torch.Tensor],
|
| 171 |
+
past_key_value: Optional[Cache] = None,
|
| 172 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 173 |
+
**kwargs,
|
| 174 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 175 |
+
input_shape = hidden_states.shape[:-1]
|
| 176 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 177 |
+
|
| 178 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 179 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 180 |
+
value_states = self.v_proj(hidden_states)
|
| 181 |
+
|
| 182 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 183 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 184 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 185 |
+
|
| 186 |
+
cos, sin = position_embeddings
|
| 187 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 188 |
+
|
| 189 |
+
if past_key_value is not None:
|
| 190 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 191 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 192 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 193 |
+
|
| 194 |
+
attention_interface: Callable = eager_attention_forward
|
| 195 |
+
if self.config._attn_implementation != "eager":
|
| 196 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
| 197 |
+
logger.warning_once(
|
| 198 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 199 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 200 |
+
)
|
| 201 |
+
else:
|
| 202 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 203 |
+
|
| 204 |
+
attn_output, attn_weights = attention_interface(
|
| 205 |
+
self,
|
| 206 |
+
query_states,
|
| 207 |
+
key_states,
|
| 208 |
+
value_states,
|
| 209 |
+
attention_mask,
|
| 210 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 211 |
+
scaling=self.scaling,
|
| 212 |
+
**kwargs,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 216 |
+
attn_output = self.o_proj(attn_output)
|
| 217 |
+
return attn_output, attn_weights
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class Olmo2MLP(nn.Module):
|
| 221 |
+
def __init__(self, config):
|
| 222 |
+
super().__init__()
|
| 223 |
+
self.config = config
|
| 224 |
+
self.hidden_size = config.hidden_size
|
| 225 |
+
self.intermediate_size = config.intermediate_size
|
| 226 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 227 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 228 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 229 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 230 |
+
|
| 231 |
+
def forward(self, x):
|
| 232 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 233 |
+
return down_proj
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class Olmo2DecoderLayer(GradientCheckpointingLayer):
|
| 237 |
+
def __init__(self, config: Olmo2Config, layer_idx: int):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.hidden_size = config.hidden_size
|
| 240 |
+
self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
|
| 241 |
+
|
| 242 |
+
self.mlp = Olmo2MLP(config)
|
| 243 |
+
self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 244 |
+
self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 245 |
+
|
| 246 |
+
def forward(
|
| 247 |
+
self,
|
| 248 |
+
hidden_states: torch.Tensor,
|
| 249 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 250 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 251 |
+
past_key_value: Optional[Cache] = None,
|
| 252 |
+
output_attentions: Optional[bool] = False,
|
| 253 |
+
use_cache: Optional[bool] = False,
|
| 254 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 255 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 256 |
+
**kwargs,
|
| 257 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 258 |
+
residual = hidden_states
|
| 259 |
+
|
| 260 |
+
# Self Attention
|
| 261 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 262 |
+
hidden_states=hidden_states,
|
| 263 |
+
attention_mask=attention_mask,
|
| 264 |
+
position_ids=position_ids,
|
| 265 |
+
past_key_value=past_key_value,
|
| 266 |
+
output_attentions=output_attentions,
|
| 267 |
+
use_cache=use_cache,
|
| 268 |
+
cache_position=cache_position,
|
| 269 |
+
position_embeddings=position_embeddings,
|
| 270 |
+
**kwargs,
|
| 271 |
+
)
|
| 272 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 273 |
+
hidden_states = residual + hidden_states
|
| 274 |
+
|
| 275 |
+
# Fully Connected
|
| 276 |
+
residual = hidden_states
|
| 277 |
+
hidden_states = self.mlp(hidden_states)
|
| 278 |
+
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
| 279 |
+
hidden_states = residual + hidden_states
|
| 280 |
+
|
| 281 |
+
outputs = (hidden_states,)
|
| 282 |
+
if output_attentions:
|
| 283 |
+
outputs += (self_attn_weights,)
|
| 284 |
+
|
| 285 |
+
return outputs
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class Olmo2RotaryEmbedding(nn.Module):
|
| 289 |
+
def __init__(self, config: Olmo2Config, device=None):
|
| 290 |
+
super().__init__()
|
| 291 |
+
# BC: "rope_type" was originally "type"
|
| 292 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 293 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 294 |
+
else:
|
| 295 |
+
self.rope_type = "default"
|
| 296 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 297 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 298 |
+
|
| 299 |
+
self.config = config
|
| 300 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 301 |
+
|
| 302 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 303 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 304 |
+
self.original_inv_freq = self.inv_freq
|
| 305 |
+
|
| 306 |
+
@torch.no_grad()
|
| 307 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 308 |
+
def forward(self, x, position_ids):
|
| 309 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 310 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 311 |
+
|
| 312 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 313 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 314 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 315 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 316 |
+
cos = emb.cos() * self.attention_scaling
|
| 317 |
+
sin = emb.sin() * self.attention_scaling
|
| 318 |
+
|
| 319 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
OLMO2_START_DOCSTRING = r"""
|
| 323 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 324 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 325 |
+
etc.)
|
| 326 |
+
|
| 327 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 328 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 329 |
+
and behavior.
|
| 330 |
+
|
| 331 |
+
Parameters:
|
| 332 |
+
config ([`Olmo2Config`]):
|
| 333 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 334 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 335 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@add_start_docstrings(
|
| 340 |
+
"The bare Olmo2 Model outputting raw hidden-states without any specific head on top.",
|
| 341 |
+
OLMO2_START_DOCSTRING,
|
| 342 |
+
)
|
| 343 |
+
class Olmo2PreTrainedModel(PreTrainedModel):
|
| 344 |
+
config_class = Olmo2Config
|
| 345 |
+
base_model_prefix = "model"
|
| 346 |
+
supports_gradient_checkpointing = True
|
| 347 |
+
_no_split_modules = ["Olmo2DecoderLayer"]
|
| 348 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 349 |
+
_supports_flash_attn_2 = True
|
| 350 |
+
_supports_sdpa = True
|
| 351 |
+
_supports_flex_attn = True
|
| 352 |
+
_supports_cache_class = True
|
| 353 |
+
_supports_quantized_cache = True
|
| 354 |
+
_supports_static_cache = True
|
| 355 |
+
_supports_attention_backend = True
|
| 356 |
+
|
| 357 |
+
def _init_weights(self, module):
|
| 358 |
+
std = self.config.initializer_range
|
| 359 |
+
if isinstance(module, nn.Linear):
|
| 360 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 361 |
+
if module.bias is not None:
|
| 362 |
+
module.bias.data.zero_()
|
| 363 |
+
elif isinstance(module, nn.Embedding):
|
| 364 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 365 |
+
if module.padding_idx is not None:
|
| 366 |
+
module.weight.data[module.padding_idx].zero_()
|
| 367 |
+
elif isinstance(module, Olmo2RMSNorm):
|
| 368 |
+
module.weight.data.fill_(1.0)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
OLMO2_INPUTS_DOCSTRING = r"""
|
| 372 |
+
Args:
|
| 373 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 374 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 375 |
+
it.
|
| 376 |
+
|
| 377 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 378 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 379 |
+
|
| 380 |
+
[What are input IDs?](../glossary#input-ids)
|
| 381 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
|
| 382 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 383 |
+
|
| 384 |
+
- 1 for tokens that are **not masked**,
|
| 385 |
+
- 0 for tokens that are **masked**.
|
| 386 |
+
|
| 387 |
+
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
|
| 388 |
+
but you can also pass a `BlockMask` object directly here.
|
| 389 |
+
|
| 390 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 391 |
+
|
| 392 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 393 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 394 |
+
|
| 395 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 396 |
+
`past_key_values`).
|
| 397 |
+
|
| 398 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 399 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 400 |
+
information on the default strategy.
|
| 401 |
+
|
| 402 |
+
- 1 indicates the head is **not masked**,
|
| 403 |
+
- 0 indicates the head is **masked**.
|
| 404 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 405 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 406 |
+
config.n_positions - 1]`.
|
| 407 |
+
|
| 408 |
+
[What are position IDs?](../glossary#position-ids)
|
| 409 |
+
past_key_values (`Cache`, *optional*):
|
| 410 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 411 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 412 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 413 |
+
|
| 414 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 415 |
+
|
| 416 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 417 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 418 |
+
of shape `(batch_size, sequence_length)`.
|
| 419 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 420 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 421 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 422 |
+
model's internal embedding lookup matrix.
|
| 423 |
+
use_cache (`bool`, *optional*):
|
| 424 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 425 |
+
`past_key_values`).
|
| 426 |
+
output_attentions (`bool`, *optional*):
|
| 427 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 428 |
+
tensors for more detail.
|
| 429 |
+
output_hidden_states (`bool`, *optional*):
|
| 430 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 431 |
+
more detail.
|
| 432 |
+
return_dict (`bool`, *optional*):
|
| 433 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 434 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 435 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 436 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 437 |
+
the complete sequence length.
|
| 438 |
+
"""
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
@add_start_docstrings(
|
| 442 |
+
"The bare Olmo2 Model outputting raw hidden-states without any specific head on top.",
|
| 443 |
+
OLMO2_START_DOCSTRING,
|
| 444 |
+
)
|
| 445 |
+
class Olmo2Model(Olmo2PreTrainedModel):
|
| 446 |
+
"""
|
| 447 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo2DecoderLayer`]
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
config: Olmo2Config
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
def __init__(self, config: Olmo2Config):
|
| 454 |
+
super().__init__(config)
|
| 455 |
+
self.padding_idx = config.pad_token_id
|
| 456 |
+
self.vocab_size = config.vocab_size
|
| 457 |
+
|
| 458 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 459 |
+
self.layers = nn.ModuleList(
|
| 460 |
+
[Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 461 |
+
)
|
| 462 |
+
self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 463 |
+
self.rotary_emb = Olmo2RotaryEmbedding(config=config)
|
| 464 |
+
self.gradient_checkpointing = False
|
| 465 |
+
|
| 466 |
+
# Initialize weights and apply final processing
|
| 467 |
+
self.post_init()
|
| 468 |
+
|
| 469 |
+
def get_input_embeddings(self):
|
| 470 |
+
return self.embed_tokens
|
| 471 |
+
|
| 472 |
+
def set_input_embeddings(self, value):
|
| 473 |
+
self.embed_tokens = value
|
| 474 |
+
|
| 475 |
+
@can_return_tuple
|
| 476 |
+
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
|
| 477 |
+
def forward(
|
| 478 |
+
self,
|
| 479 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 480 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 481 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 482 |
+
past_key_values: Optional[Cache] = None,
|
| 483 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 484 |
+
use_cache: Optional[bool] = None,
|
| 485 |
+
output_attentions: Optional[bool] = None,
|
| 486 |
+
output_hidden_states: Optional[bool] = None,
|
| 487 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 488 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 489 |
+
) -> BaseModelOutputWithPast:
|
| 490 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 491 |
+
output_hidden_states = (
|
| 492 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 493 |
+
)
|
| 494 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 495 |
+
|
| 496 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 497 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 498 |
+
|
| 499 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 500 |
+
logger.warning_once(
|
| 501 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 502 |
+
)
|
| 503 |
+
use_cache = False
|
| 504 |
+
|
| 505 |
+
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
| 506 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 507 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 508 |
+
|
| 509 |
+
if inputs_embeds is None:
|
| 510 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 511 |
+
|
| 512 |
+
if use_cache and past_key_values is None:
|
| 513 |
+
past_key_values = DynamicCache()
|
| 514 |
+
|
| 515 |
+
if cache_position is None:
|
| 516 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 517 |
+
cache_position = torch.arange(
|
| 518 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
if position_ids is None:
|
| 522 |
+
position_ids = cache_position.unsqueeze(0)
|
| 523 |
+
|
| 524 |
+
causal_mask = self._update_causal_mask(
|
| 525 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
hidden_states = inputs_embeds
|
| 529 |
+
|
| 530 |
+
# create position embeddings to be shared across the decoder layers
|
| 531 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 532 |
+
|
| 533 |
+
# decoder layers
|
| 534 |
+
all_hidden_states = () if output_hidden_states else None
|
| 535 |
+
all_self_attns = () if output_attentions else None
|
| 536 |
+
|
| 537 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 538 |
+
if output_hidden_states:
|
| 539 |
+
all_hidden_states += (hidden_states,)
|
| 540 |
+
|
| 541 |
+
layer_outputs = decoder_layer(
|
| 542 |
+
hidden_states,
|
| 543 |
+
attention_mask=causal_mask,
|
| 544 |
+
position_ids=position_ids,
|
| 545 |
+
past_key_value=past_key_values,
|
| 546 |
+
output_attentions=output_attentions,
|
| 547 |
+
use_cache=use_cache,
|
| 548 |
+
cache_position=cache_position,
|
| 549 |
+
position_embeddings=position_embeddings,
|
| 550 |
+
**flash_attn_kwargs,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
hidden_states = layer_outputs[0]
|
| 554 |
+
|
| 555 |
+
if output_attentions:
|
| 556 |
+
all_self_attns += (layer_outputs[1],)
|
| 557 |
+
|
| 558 |
+
hidden_states = self.norm(hidden_states)
|
| 559 |
+
|
| 560 |
+
# add hidden states from the last decoder layer
|
| 561 |
+
if output_hidden_states:
|
| 562 |
+
all_hidden_states += (hidden_states,)
|
| 563 |
+
|
| 564 |
+
return BaseModelOutputWithPast(
|
| 565 |
+
last_hidden_state=hidden_states,
|
| 566 |
+
past_key_values=past_key_values if use_cache else None,
|
| 567 |
+
hidden_states=all_hidden_states,
|
| 568 |
+
attentions=all_self_attns,
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
def _update_causal_mask(
|
| 572 |
+
self,
|
| 573 |
+
attention_mask: Union[torch.Tensor, "BlockMask"],
|
| 574 |
+
input_tensor: torch.Tensor,
|
| 575 |
+
cache_position: torch.Tensor,
|
| 576 |
+
past_key_values: Cache,
|
| 577 |
+
output_attentions: bool = False,
|
| 578 |
+
):
|
| 579 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 580 |
+
if attention_mask is not None and (attention_mask == 0.0).any():
|
| 581 |
+
return attention_mask
|
| 582 |
+
return None
|
| 583 |
+
if self.config._attn_implementation == "flex_attention":
|
| 584 |
+
if isinstance(attention_mask, torch.Tensor):
|
| 585 |
+
attention_mask = make_flex_block_causal_mask(attention_mask)
|
| 586 |
+
return attention_mask
|
| 587 |
+
|
| 588 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 589 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 590 |
+
# to infer the attention mask.
|
| 591 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 592 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 593 |
+
|
| 594 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 595 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 596 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 597 |
+
attention_mask,
|
| 598 |
+
inputs_embeds=input_tensor,
|
| 599 |
+
past_key_values_length=past_seen_tokens,
|
| 600 |
+
is_training=self.training,
|
| 601 |
+
):
|
| 602 |
+
return None
|
| 603 |
+
|
| 604 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 605 |
+
sequence_length = input_tensor.shape[1]
|
| 606 |
+
if using_static_cache:
|
| 607 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 608 |
+
else:
|
| 609 |
+
target_length = (
|
| 610 |
+
attention_mask.shape[-1]
|
| 611 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 612 |
+
else past_seen_tokens + sequence_length + 1
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 616 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 617 |
+
attention_mask,
|
| 618 |
+
sequence_length=sequence_length,
|
| 619 |
+
target_length=target_length,
|
| 620 |
+
dtype=dtype,
|
| 621 |
+
device=device,
|
| 622 |
+
cache_position=cache_position,
|
| 623 |
+
batch_size=input_tensor.shape[0],
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
if (
|
| 627 |
+
self.config._attn_implementation == "sdpa"
|
| 628 |
+
and attention_mask is not None
|
| 629 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 630 |
+
and not output_attentions
|
| 631 |
+
):
|
| 632 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 633 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 634 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 635 |
+
min_dtype = torch.finfo(dtype).min
|
| 636 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 637 |
+
|
| 638 |
+
return causal_mask
|
| 639 |
+
|
| 640 |
+
@staticmethod
|
| 641 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 642 |
+
attention_mask: torch.Tensor,
|
| 643 |
+
sequence_length: int,
|
| 644 |
+
target_length: int,
|
| 645 |
+
dtype: torch.dtype,
|
| 646 |
+
device: torch.device,
|
| 647 |
+
cache_position: torch.Tensor,
|
| 648 |
+
batch_size: int,
|
| 649 |
+
**kwargs,
|
| 650 |
+
):
|
| 651 |
+
"""
|
| 652 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 653 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 654 |
+
|
| 655 |
+
Args:
|
| 656 |
+
attention_mask (`torch.Tensor`):
|
| 657 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 658 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 659 |
+
sequence_length (`int`):
|
| 660 |
+
The sequence length being processed.
|
| 661 |
+
target_length (`int`):
|
| 662 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 663 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 664 |
+
dtype (`torch.dtype`):
|
| 665 |
+
The dtype to use for the 4D attention mask.
|
| 666 |
+
device (`torch.device`):
|
| 667 |
+
The device to place the 4D attention mask on.
|
| 668 |
+
cache_position (`torch.Tensor`):
|
| 669 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 670 |
+
batch_size (`torch.Tensor`):
|
| 671 |
+
Batch size.
|
| 672 |
+
"""
|
| 673 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 674 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 675 |
+
causal_mask = attention_mask
|
| 676 |
+
else:
|
| 677 |
+
min_dtype = torch.finfo(dtype).min
|
| 678 |
+
causal_mask = torch.full(
|
| 679 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 680 |
+
)
|
| 681 |
+
if sequence_length != 1:
|
| 682 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 683 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 684 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 685 |
+
if attention_mask is not None:
|
| 686 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 687 |
+
mask_length = attention_mask.shape[-1]
|
| 688 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
| 689 |
+
causal_mask.device
|
| 690 |
+
)
|
| 691 |
+
padding_mask = padding_mask == 0
|
| 692 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 693 |
+
padding_mask, min_dtype
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
return causal_mask
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
|
| 703 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 704 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 705 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 706 |
+
|
| 707 |
+
def __init__(self, config):
|
| 708 |
+
super().__init__(config)
|
| 709 |
+
self.model = Olmo2Model(config)
|
| 710 |
+
self.vocab_size = config.vocab_size
|
| 711 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 712 |
+
|
| 713 |
+
# Initialize weights and apply final processing
|
| 714 |
+
self.post_init()
|
| 715 |
+
|
| 716 |
+
def get_input_embeddings(self):
|
| 717 |
+
return self.model.embed_tokens
|
| 718 |
+
|
| 719 |
+
def set_input_embeddings(self, value):
|
| 720 |
+
self.model.embed_tokens = value
|
| 721 |
+
|
| 722 |
+
def get_output_embeddings(self):
|
| 723 |
+
return self.lm_head
|
| 724 |
+
|
| 725 |
+
def set_output_embeddings(self, new_embeddings):
|
| 726 |
+
self.lm_head = new_embeddings
|
| 727 |
+
|
| 728 |
+
def set_decoder(self, decoder):
|
| 729 |
+
self.model = decoder
|
| 730 |
+
|
| 731 |
+
def get_decoder(self):
|
| 732 |
+
return self.model
|
| 733 |
+
|
| 734 |
+
@can_return_tuple
|
| 735 |
+
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
|
| 736 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 737 |
+
def forward(
|
| 738 |
+
self,
|
| 739 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 740 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 741 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 742 |
+
past_key_values: Optional[Cache] = None,
|
| 743 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 744 |
+
labels: Optional[torch.LongTensor] = None,
|
| 745 |
+
use_cache: Optional[bool] = None,
|
| 746 |
+
output_attentions: Optional[bool] = None,
|
| 747 |
+
output_hidden_states: Optional[bool] = None,
|
| 748 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 749 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 750 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 751 |
+
) -> CausalLMOutputWithPast:
|
| 752 |
+
r"""
|
| 753 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 754 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 755 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 756 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 757 |
+
|
| 758 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 759 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 760 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 761 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 762 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 763 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 764 |
+
|
| 765 |
+
Returns:
|
| 766 |
+
|
| 767 |
+
Example:
|
| 768 |
+
|
| 769 |
+
```python
|
| 770 |
+
>>> from transformers import AutoTokenizer, Olmo2ForCausalLM
|
| 771 |
+
|
| 772 |
+
>>> model = Olmo2ForCausalLM.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
|
| 773 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
|
| 774 |
+
|
| 775 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 776 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 777 |
+
|
| 778 |
+
>>> # Generate
|
| 779 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 780 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 781 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 782 |
+
```"""
|
| 783 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 784 |
+
output_hidden_states = (
|
| 785 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 789 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 790 |
+
input_ids=input_ids,
|
| 791 |
+
attention_mask=attention_mask,
|
| 792 |
+
position_ids=position_ids,
|
| 793 |
+
past_key_values=past_key_values,
|
| 794 |
+
inputs_embeds=inputs_embeds,
|
| 795 |
+
use_cache=use_cache,
|
| 796 |
+
output_attentions=output_attentions,
|
| 797 |
+
output_hidden_states=output_hidden_states,
|
| 798 |
+
cache_position=cache_position,
|
| 799 |
+
**kwargs,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
hidden_states = outputs.last_hidden_state
|
| 803 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 804 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 805 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 806 |
+
|
| 807 |
+
loss = None
|
| 808 |
+
if labels is not None:
|
| 809 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 810 |
+
|
| 811 |
+
return CausalLMOutputWithPast(
|
| 812 |
+
loss=loss,
|
| 813 |
+
logits=logits,
|
| 814 |
+
past_key_values=outputs.past_key_values,
|
| 815 |
+
hidden_states=outputs.hidden_states,
|
| 816 |
+
attentions=outputs.attentions,
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
__all__ = ["Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]
|
docs/transformers/src/transformers/models/olmo2/modular_olmo2.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from ...cache_utils import Cache
|
| 7 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 8 |
+
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
| 9 |
+
from ...utils import logging
|
| 10 |
+
from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
|
| 11 |
+
from ..olmo.configuration_olmo import OlmoConfig
|
| 12 |
+
from ..olmo.modeling_olmo import (
|
| 13 |
+
OlmoAttention,
|
| 14 |
+
OlmoDecoderLayer,
|
| 15 |
+
OlmoForCausalLM,
|
| 16 |
+
OlmoModel,
|
| 17 |
+
OlmoRotaryEmbedding,
|
| 18 |
+
apply_rotary_pos_emb,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Olmo2Config(OlmoConfig):
|
| 26 |
+
r"""
|
| 27 |
+
This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
|
| 28 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 29 |
+
defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size (`int`, *optional*, defaults to 50304):
|
| 37 |
+
Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
|
| 38 |
+
`inputs_ids` passed when calling [`Olmo2Model`]
|
| 39 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 40 |
+
Dimension of the hidden representations.
|
| 41 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
| 42 |
+
Dimension of the MLP representations.
|
| 43 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 44 |
+
Number of hidden layers in the Transformer decoder.
|
| 45 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 46 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 47 |
+
num_key_value_heads (`int`, *optional*):
|
| 48 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 49 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 50 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 51 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 52 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 53 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 54 |
+
`num_attention_heads`.
|
| 55 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 56 |
+
The non-linear activation function (function or string) in the decoder.
|
| 57 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
| 58 |
+
The maximum sequence length that this model might ever be used with.
|
| 59 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 60 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 61 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 62 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 63 |
+
relevant if `config.is_decoder=True`.
|
| 64 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 65 |
+
Padding token id.
|
| 66 |
+
bos_token_id (`int`, *optional*):
|
| 67 |
+
Beginning of stream token id.
|
| 68 |
+
eos_token_id (`int`, *optional*, defaults to 50279):
|
| 69 |
+
End of stream token id.
|
| 70 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 71 |
+
Whether to tie weight embeddings
|
| 72 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 73 |
+
The base period of the RoPE embeddings.
|
| 74 |
+
rope_scaling (`Dict`, *optional*):
|
| 75 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
| 76 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
| 77 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
| 78 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
| 79 |
+
these scaling strategies behave:
|
| 80 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
| 81 |
+
experimental feature, subject to breaking API changes in future versions.
|
| 82 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 83 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 84 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 85 |
+
The dropout ratio for the attention probabilities.
|
| 86 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 87 |
+
The epsilon used by the rms normalization layers.
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
>>> from transformers import Olmo2Model, Olmo2Config
|
| 91 |
+
|
| 92 |
+
>>> # Initializing a Olmo2 7B style configuration
|
| 93 |
+
>>> configuration = Olmo2Config()
|
| 94 |
+
|
| 95 |
+
>>> # Initializing a model from the Olmo2 7B style configuration
|
| 96 |
+
>>> model = Olmo2Model(configuration)
|
| 97 |
+
|
| 98 |
+
>>> # Accessing the model configuration
|
| 99 |
+
>>> configuration = model.config
|
| 100 |
+
```
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
model_type = "olmo2"
|
| 104 |
+
base_model_tp_plan = {
|
| 105 |
+
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 106 |
+
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 107 |
+
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
| 108 |
+
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
| 109 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 110 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 111 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 112 |
+
}
|
| 113 |
+
base_model_pp_plan = {
|
| 114 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 115 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 116 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
vocab_size=50304,
|
| 122 |
+
hidden_size=4096,
|
| 123 |
+
intermediate_size=11008,
|
| 124 |
+
num_hidden_layers=32,
|
| 125 |
+
num_attention_heads=32,
|
| 126 |
+
num_key_value_heads=None,
|
| 127 |
+
hidden_act="silu",
|
| 128 |
+
max_position_embeddings=2048,
|
| 129 |
+
initializer_range=0.02,
|
| 130 |
+
use_cache=True,
|
| 131 |
+
pad_token_id=1,
|
| 132 |
+
bos_token_id=None,
|
| 133 |
+
eos_token_id=50279,
|
| 134 |
+
tie_word_embeddings=False,
|
| 135 |
+
rope_theta=10000.0,
|
| 136 |
+
rope_scaling=None,
|
| 137 |
+
attention_bias=False,
|
| 138 |
+
attention_dropout=0.0,
|
| 139 |
+
rms_norm_eps=1e-5,
|
| 140 |
+
**kwargs,
|
| 141 |
+
):
|
| 142 |
+
super().__init__(
|
| 143 |
+
vocab_size=vocab_size,
|
| 144 |
+
hidden_size=hidden_size,
|
| 145 |
+
intermediate_size=intermediate_size,
|
| 146 |
+
num_hidden_layers=num_hidden_layers,
|
| 147 |
+
num_attention_heads=num_attention_heads,
|
| 148 |
+
num_key_value_heads=num_key_value_heads,
|
| 149 |
+
hidden_act=hidden_act,
|
| 150 |
+
max_position_embeddings=max_position_embeddings,
|
| 151 |
+
initializer_range=initializer_range,
|
| 152 |
+
use_cache=use_cache,
|
| 153 |
+
pad_token_id=pad_token_id,
|
| 154 |
+
bos_token_id=bos_token_id,
|
| 155 |
+
eos_token_id=eos_token_id,
|
| 156 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 157 |
+
rope_theta=rope_theta,
|
| 158 |
+
rope_scaling=rope_scaling,
|
| 159 |
+
attention_bias=attention_bias,
|
| 160 |
+
attention_dropout=attention_dropout,
|
| 161 |
+
**kwargs,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
self.rms_norm_eps = rms_norm_eps
|
| 165 |
+
del self.clip_qkv
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Olmo2RMSNorm(LlamaRMSNorm):
|
| 169 |
+
pass
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
ALL_LAYERNORM_LAYERS.append(Olmo2RMSNorm)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# Olmo2 attention is identical to OLMo attention except:
|
| 176 |
+
# - Norm is applied to attention queries and keys.
|
| 177 |
+
# - No qkv clipping.
|
| 178 |
+
class Olmo2Attention(OlmoAttention):
|
| 179 |
+
def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
|
| 180 |
+
super().__init__(config, layer_idx=layer_idx)
|
| 181 |
+
self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
|
| 182 |
+
self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
|
| 183 |
+
|
| 184 |
+
def forward(
|
| 185 |
+
self,
|
| 186 |
+
hidden_states: torch.Tensor,
|
| 187 |
+
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
| 188 |
+
attention_mask: Optional[torch.Tensor],
|
| 189 |
+
past_key_value: Optional[Cache] = None,
|
| 190 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 191 |
+
**kwargs,
|
| 192 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 193 |
+
input_shape = hidden_states.shape[:-1]
|
| 194 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 195 |
+
|
| 196 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 197 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 198 |
+
value_states = self.v_proj(hidden_states)
|
| 199 |
+
|
| 200 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 201 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 202 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 203 |
+
|
| 204 |
+
cos, sin = position_embeddings
|
| 205 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 206 |
+
|
| 207 |
+
if past_key_value is not None:
|
| 208 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 209 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 210 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 211 |
+
|
| 212 |
+
attention_interface: Callable = eager_attention_forward
|
| 213 |
+
if self.config._attn_implementation != "eager":
|
| 214 |
+
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
| 215 |
+
logger.warning_once(
|
| 216 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 217 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 221 |
+
|
| 222 |
+
attn_output, attn_weights = attention_interface(
|
| 223 |
+
self,
|
| 224 |
+
query_states,
|
| 225 |
+
key_states,
|
| 226 |
+
value_states,
|
| 227 |
+
attention_mask,
|
| 228 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 229 |
+
scaling=self.scaling,
|
| 230 |
+
**kwargs,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 234 |
+
attn_output = self.o_proj(attn_output)
|
| 235 |
+
return attn_output, attn_weights
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# The OLMo2 layers are identical to those of the OLMo model except:
|
| 239 |
+
# - RMSNorm is used instead of standard layer norm.
|
| 240 |
+
# - Norm is applied after attention/feedforward rather than before.
|
| 241 |
+
class Olmo2DecoderLayer(OlmoDecoderLayer):
|
| 242 |
+
def __init__(self, config: Olmo2Config, layer_idx: int):
|
| 243 |
+
super().__init__(config, layer_idx=layer_idx)
|
| 244 |
+
self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 245 |
+
self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 246 |
+
self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
|
| 247 |
+
del self.input_layernorm
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
hidden_states: torch.Tensor,
|
| 252 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 253 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 254 |
+
past_key_value: Optional[Cache] = None,
|
| 255 |
+
output_attentions: Optional[bool] = False,
|
| 256 |
+
use_cache: Optional[bool] = False,
|
| 257 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 258 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 259 |
+
**kwargs,
|
| 260 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 261 |
+
residual = hidden_states
|
| 262 |
+
|
| 263 |
+
# Self Attention
|
| 264 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 265 |
+
hidden_states=hidden_states,
|
| 266 |
+
attention_mask=attention_mask,
|
| 267 |
+
position_ids=position_ids,
|
| 268 |
+
past_key_value=past_key_value,
|
| 269 |
+
output_attentions=output_attentions,
|
| 270 |
+
use_cache=use_cache,
|
| 271 |
+
cache_position=cache_position,
|
| 272 |
+
position_embeddings=position_embeddings,
|
| 273 |
+
**kwargs,
|
| 274 |
+
)
|
| 275 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 276 |
+
hidden_states = residual + hidden_states
|
| 277 |
+
|
| 278 |
+
# Fully Connected
|
| 279 |
+
residual = hidden_states
|
| 280 |
+
hidden_states = self.mlp(hidden_states)
|
| 281 |
+
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
| 282 |
+
hidden_states = residual + hidden_states
|
| 283 |
+
|
| 284 |
+
outputs = (hidden_states,)
|
| 285 |
+
if output_attentions:
|
| 286 |
+
outputs += (self_attn_weights,)
|
| 287 |
+
|
| 288 |
+
return outputs
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
|
| 292 |
+
pass
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class Olmo2PreTrainedModel(LlamaPreTrainedModel):
|
| 296 |
+
pass
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
|
| 300 |
+
# standard layer norm for the output norm.
|
| 301 |
+
class Olmo2Model(OlmoModel):
|
| 302 |
+
def __init__(self, config: Olmo2Config):
|
| 303 |
+
super().__init__(config)
|
| 304 |
+
self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 305 |
+
self.layers = nn.ModuleList(
|
| 306 |
+
[Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# The heads now only need to redefine the model inside to the correct `RobertaModel`
|
| 311 |
+
class Olmo2ForCausalLM(OlmoForCausalLM):
|
| 312 |
+
pass
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
__all__ = [
|
| 316 |
+
"Olmo2Config",
|
| 317 |
+
"Olmo2ForCausalLM",
|
| 318 |
+
"Olmo2Model",
|
| 319 |
+
"Olmo2PreTrainedModel", # noqa: F822
|
| 320 |
+
]
|
docs/transformers/src/transformers/models/olmoe/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_olmoe import *
|
| 22 |
+
from .modeling_olmoe import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/src/transformers/models/olmoe/configuration_olmoe.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
"""OLMoE model configuration"""
|
| 13 |
+
|
| 14 |
+
from ...configuration_utils import PretrainedConfig
|
| 15 |
+
from ...modeling_rope_utils import rope_config_validation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OlmoeConfig(PretrainedConfig):
|
| 19 |
+
r"""
|
| 20 |
+
This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE
|
| 21 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
| 22 |
+
defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).
|
| 23 |
+
|
| 24 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 25 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
vocab_size (`int`, *optional*, defaults to 50304):
|
| 30 |
+
Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
|
| 31 |
+
`inputs_ids` passed when calling [`OlmoeModel`]
|
| 32 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
| 33 |
+
Dimension of the hidden representations.
|
| 34 |
+
intermediate_size (`int`, *optional*, defaults to 2048):
|
| 35 |
+
Dimension of the MLP representations.
|
| 36 |
+
num_hidden_layers (`int`, *optional*, defaults to 16):
|
| 37 |
+
Number of hidden layers in the Transformer decoder.
|
| 38 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
| 39 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 40 |
+
num_key_value_heads (`int`, *optional*):
|
| 41 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 42 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 43 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 44 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 45 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 46 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 47 |
+
`num_attention_heads`.
|
| 48 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 49 |
+
The non-linear activation function (function or string) in the decoder.
|
| 50 |
+
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
| 51 |
+
The maximum sequence length that this model might ever be used with.
|
| 52 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 53 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 54 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 55 |
+
The epsilon used by the rms normalization layers.
|
| 56 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 57 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 58 |
+
relevant if `config.is_decoder=True`.
|
| 59 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 60 |
+
Padding token id.
|
| 61 |
+
bos_token_id (`int`, *optional*):
|
| 62 |
+
Beginning of stream token id.
|
| 63 |
+
eos_token_id (`int`, *optional*, defaults to 50279):
|
| 64 |
+
End of stream token id.
|
| 65 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 66 |
+
Whether to tie weight embeddings
|
| 67 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 68 |
+
The base period of the RoPE embeddings.
|
| 69 |
+
rope_scaling (`Dict`, *optional*):
|
| 70 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
| 71 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
| 72 |
+
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
| 73 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
| 74 |
+
these scaling strategies behave:
|
| 75 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
| 76 |
+
experimental feature, subject to breaking API changes in future versions.
|
| 77 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 78 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 79 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 80 |
+
The dropout ratio for the attention probabilities.
|
| 81 |
+
clip_qkv (`float`, *optional*):
|
| 82 |
+
If not `None`, elements of query, key and value attention states are clipped so that their
|
| 83 |
+
absolute value does not exceed this value.
|
| 84 |
+
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
| 85 |
+
Number of selected experts.
|
| 86 |
+
num_experts (`int`, *optional*, defaults to 64):
|
| 87 |
+
Number of routed experts.
|
| 88 |
+
output_router_logits (`bool`, *optional*, defaults to `False`):
|
| 89 |
+
Whether or not the router logits should be returned by the model. Enabeling this will also
|
| 90 |
+
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
|
| 91 |
+
router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
|
| 92 |
+
The aux loss factor for the total loss.
|
| 93 |
+
norm_topk_prob (`bool`, *optional*, defaults to `False`):
|
| 94 |
+
Whether to normalize the topk probabilities.
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
>>> from transformers import OlmoeModel, OlmoeConfig
|
| 98 |
+
|
| 99 |
+
>>> # Initializing a OLMoE 7B A1B style configuration
|
| 100 |
+
>>> configuration = OlmoeConfig()
|
| 101 |
+
|
| 102 |
+
>>> # Initializing a model from the OLMoE 7B A1B style configuration
|
| 103 |
+
>>> model = OlmoeModel(configuration)
|
| 104 |
+
|
| 105 |
+
>>> # Accessing the model configuration
|
| 106 |
+
>>> configuration = model.config
|
| 107 |
+
```"""
|
| 108 |
+
|
| 109 |
+
model_type = "olmoe"
|
| 110 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 111 |
+
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
vocab_size=50304,
|
| 115 |
+
hidden_size=2048,
|
| 116 |
+
intermediate_size=2048,
|
| 117 |
+
num_hidden_layers=16,
|
| 118 |
+
num_attention_heads=16,
|
| 119 |
+
num_key_value_heads=None,
|
| 120 |
+
hidden_act="silu",
|
| 121 |
+
max_position_embeddings=4096,
|
| 122 |
+
initializer_range=0.02,
|
| 123 |
+
rms_norm_eps=1e-05,
|
| 124 |
+
use_cache=True,
|
| 125 |
+
pad_token_id=1,
|
| 126 |
+
bos_token_id=None,
|
| 127 |
+
eos_token_id=50279,
|
| 128 |
+
tie_word_embeddings=False,
|
| 129 |
+
rope_theta=10000.0,
|
| 130 |
+
rope_scaling=None,
|
| 131 |
+
attention_bias=False,
|
| 132 |
+
attention_dropout=0.0,
|
| 133 |
+
clip_qkv=None,
|
| 134 |
+
num_experts_per_tok=8,
|
| 135 |
+
num_experts=64,
|
| 136 |
+
output_router_logits=False,
|
| 137 |
+
router_aux_loss_coef=0.01,
|
| 138 |
+
norm_topk_prob=False,
|
| 139 |
+
**kwargs,
|
| 140 |
+
):
|
| 141 |
+
self.vocab_size = vocab_size
|
| 142 |
+
self.max_position_embeddings = max_position_embeddings
|
| 143 |
+
self.hidden_size = hidden_size
|
| 144 |
+
self.intermediate_size = intermediate_size
|
| 145 |
+
self.num_hidden_layers = num_hidden_layers
|
| 146 |
+
self.num_attention_heads = num_attention_heads
|
| 147 |
+
|
| 148 |
+
# for backward compatibility
|
| 149 |
+
if num_key_value_heads is None:
|
| 150 |
+
num_key_value_heads = num_attention_heads
|
| 151 |
+
|
| 152 |
+
self.num_key_value_heads = num_key_value_heads
|
| 153 |
+
self.hidden_act = hidden_act
|
| 154 |
+
self.initializer_range = initializer_range
|
| 155 |
+
self.rms_norm_eps = rms_norm_eps
|
| 156 |
+
self.use_cache = use_cache
|
| 157 |
+
self.rope_theta = rope_theta
|
| 158 |
+
self.rope_scaling = rope_scaling
|
| 159 |
+
self.attention_bias = attention_bias
|
| 160 |
+
self.attention_dropout = attention_dropout
|
| 161 |
+
self.clip_qkv = clip_qkv
|
| 162 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 163 |
+
self.num_experts = num_experts
|
| 164 |
+
self.output_router_logits = output_router_logits
|
| 165 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
| 166 |
+
self.norm_topk_prob = norm_topk_prob
|
| 167 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 168 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 169 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 170 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 171 |
+
rope_config_validation(self)
|
| 172 |
+
|
| 173 |
+
super().__init__(
|
| 174 |
+
pad_token_id=pad_token_id,
|
| 175 |
+
bos_token_id=bos_token_id,
|
| 176 |
+
eos_token_id=eos_token_id,
|
| 177 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 178 |
+
**kwargs,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
__all__ = ["OlmoeConfig"]
|
docs/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
"""
|
| 13 |
+
Example for running:
|
| 14 |
+
0. Cp ckpts to local
|
| 15 |
+
aws s3 cp --recursive s3://ai2-llm/checkpoints/OLMoE/olmoe-8x1b-newhp-newds-final-annealFrom1200000/step23842 /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842
|
| 16 |
+
1. Unshard your OLMoE checkpoint using https://github.com/allenai/OLMo/blob/7d63fe09d23cf23714da5aa633a44a90180195da/scripts/unshard.py
|
| 17 |
+
python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --model-only
|
| 18 |
+
python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/23485/step954000 /data/niklas/llm/checkpoints/1b-954000-unsharded --model-only
|
| 19 |
+
python OLMo/scripts/unshard.py /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842 /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842-unsharded --model-only
|
| 20 |
+
2. Convert to transformers
|
| 21 |
+
rm -rf olmoe; mkdir olmoe; python /data/niklas/transformers/src/transformers/models/olmoe/convert_olmoe_weights_to_hf.py --input_dir /data/niklas/llm/checkpoints/olmoe-8x1b-newhp-newds-final-annealFrom1200000_step23842-unsharded --tokenizer_json_path /data/niklas/llm/checkpoints/olmoe-step1200000-unsharded/tokenizer.json --output_dir olmoe
|
| 22 |
+
3. Load model via:
|
| 23 |
+
```
|
| 24 |
+
from transformers import OlmoeForCausalLM, AutoTokenizer
|
| 25 |
+
import torch
|
| 26 |
+
model = OlmoeForCausalLM.from_pretrained("../transformers/olmoe", torch_dtype=torch.bfloat16).cuda()
|
| 27 |
+
model = OlmoeForCausalLM.from_pretrained("../transformers/olmoe").cuda()
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained("../transformers/olmoe")
|
| 29 |
+
inputs = tokenizer("Bitcoin is", return_tensors="pt")
|
| 30 |
+
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 31 |
+
out = model.generate(**inputs, max_length=64)
|
| 32 |
+
print(tokenizer.decode(out[0]))
|
| 33 |
+
# > # Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical
|
| 34 |
+
# Or quick sanity check:
|
| 35 |
+
o = model(torch.tensor([[0, 1]]).cuda())
|
| 36 |
+
# If the checkpoint is not converted to BF16 but kept in FP32:
|
| 37 |
+
# > # Bitcoin is a digital currency that is not controlled by any central authority. It is a peer-to-peer payment system that allows users to send and receive payments from anywhere in the world. Bitcoin is also known as a cryptocurrency because it uses cryptography to secure transactions and prevent fraud.
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
| 41 |
+
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
| 42 |
+
|
| 43 |
+
Compare with OLMo codebase:
|
| 44 |
+
```
|
| 45 |
+
from olmo.model import OLMo
|
| 46 |
+
import torch
|
| 47 |
+
model = OLMo.from_checkpoint("/data/niklas/llm/checkpoints/olmoe-step1200000-unsharded-pt")
|
| 48 |
+
model = model.cuda()
|
| 49 |
+
model = model.to(torch.bfloat16)
|
| 50 |
+
from transformers import AutoTokenizer
|
| 51 |
+
tokenizer = AutoTokenizer.from_pretrained("../transformers/olmoe")
|
| 52 |
+
inputs = tokenizer("Bitcoin is", return_tensors="pt")
|
| 53 |
+
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 54 |
+
out = model.generate(**inputs)
|
| 55 |
+
print(tokenizer.decode(out[0][0][0]))
|
| 56 |
+
# Bitcoin is a digital currency that is created and held electronically. No one controls it. Bitcoins aren’t printed, like dollars or euros – they’re produced by people and businesses running computers all around the world, using software that solves mathematical problems. It’s the first example of a growing category of money
|
| 57 |
+
# Or quick sanity check:
|
| 58 |
+
o = model(torch.tensor([[0, 1]]).cuda())
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
import argparse
|
| 63 |
+
import gc
|
| 64 |
+
import json
|
| 65 |
+
import os
|
| 66 |
+
import shutil
|
| 67 |
+
from pathlib import Path
|
| 68 |
+
|
| 69 |
+
import torch
|
| 70 |
+
import yaml
|
| 71 |
+
from tokenizers import Tokenizer
|
| 72 |
+
|
| 73 |
+
from transformers import OlmoeConfig, OlmoeForCausalLM
|
| 74 |
+
from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
| 78 |
+
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def read_json(path):
|
| 82 |
+
with open(path, "r") as f:
|
| 83 |
+
return json.load(f)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def write_json(text, path):
|
| 87 |
+
with open(path, "w") as f:
|
| 88 |
+
json.dump(text, f)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True):
|
| 92 |
+
os.makedirs(model_path, exist_ok=True)
|
| 93 |
+
tmp_model_path = os.path.join(model_path, "tmp")
|
| 94 |
+
os.makedirs(tmp_model_path, exist_ok=True)
|
| 95 |
+
|
| 96 |
+
config_path = Path(input_base_path) / "config.yaml"
|
| 97 |
+
olmoe_config = yaml.safe_load(config_path.read_text())["model"]
|
| 98 |
+
|
| 99 |
+
if fix_eos_token_id:
|
| 100 |
+
olmoe_config["eos_token_id"] = 50279
|
| 101 |
+
|
| 102 |
+
n_layers = olmoe_config["n_layers"]
|
| 103 |
+
n_heads = olmoe_config["n_heads"]
|
| 104 |
+
dim = olmoe_config["d_model"]
|
| 105 |
+
dims_per_head = dim // n_heads
|
| 106 |
+
base = 10000.0
|
| 107 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
| 108 |
+
max_position_embeddings = olmoe_config["max_sequence_length"]
|
| 109 |
+
|
| 110 |
+
vocab_size = olmoe_config.get("embedding_size", olmoe_config["vocab_size"])
|
| 111 |
+
|
| 112 |
+
if olmoe_config.get("n_kv_heads", None) is not None:
|
| 113 |
+
num_key_value_heads = olmoe_config["n_kv_heads"] # for GQA / MQA
|
| 114 |
+
elif olmoe_config["multi_query_attention"]: # compatibility with other checkpoints
|
| 115 |
+
num_key_value_heads = 1
|
| 116 |
+
else:
|
| 117 |
+
num_key_value_heads = n_heads
|
| 118 |
+
|
| 119 |
+
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
| 120 |
+
|
| 121 |
+
# Not sharded
|
| 122 |
+
loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu", weights_only=True)
|
| 123 |
+
|
| 124 |
+
param_count = 0
|
| 125 |
+
index_dict = {"weight_map": {}}
|
| 126 |
+
for layer_i in range(n_layers):
|
| 127 |
+
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
| 128 |
+
fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads]
|
| 129 |
+
q_proj_weight, k_proj_weight, v_proj_weight = torch.split(
|
| 130 |
+
loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0
|
| 131 |
+
)
|
| 132 |
+
state_dict = {
|
| 133 |
+
f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight,
|
| 134 |
+
f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
|
| 135 |
+
f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
|
| 136 |
+
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
|
| 137 |
+
f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"],
|
| 138 |
+
f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"],
|
| 139 |
+
f"model.layers.{layer_i}.mlp.gate.weight": loaded[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"],
|
| 140 |
+
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.blocks.{layer_i}.attn_norm.weight"],
|
| 141 |
+
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
|
| 142 |
+
f"transformer.blocks.{layer_i}.ff_norm.weight"
|
| 143 |
+
],
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
num_experts = loaded[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"].shape[0]
|
| 147 |
+
dim_per_expert = loaded[f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"].shape[0] // num_experts
|
| 148 |
+
for expert_i in range(num_experts):
|
| 149 |
+
state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.gate_proj.weight"] = loaded[
|
| 150 |
+
f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"
|
| 151 |
+
][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :]
|
| 152 |
+
state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.up_proj.weight"] = loaded[
|
| 153 |
+
f"transformer.blocks.{layer_i}.ffn.experts.mlp.v1"
|
| 154 |
+
][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :]
|
| 155 |
+
state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.down_proj.weight"] = loaded[
|
| 156 |
+
f"transformer.blocks.{layer_i}.ffn.experts.mlp.w2"
|
| 157 |
+
][dim_per_expert * expert_i : dim_per_expert * (expert_i + 1), :].T.contiguous()
|
| 158 |
+
|
| 159 |
+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
| 160 |
+
|
| 161 |
+
for k, v in state_dict.items():
|
| 162 |
+
index_dict["weight_map"][k] = filename
|
| 163 |
+
param_count += v.numel()
|
| 164 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
| 165 |
+
|
| 166 |
+
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
| 167 |
+
|
| 168 |
+
# Unsharded
|
| 169 |
+
state_dict = {
|
| 170 |
+
"model.embed_tokens.weight": loaded["transformer.wte.weight"],
|
| 171 |
+
"lm_head.weight": loaded["transformer.ff_out.weight"],
|
| 172 |
+
"model.norm.weight": loaded["transformer.ln_f.weight"],
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
for k, v in state_dict.items():
|
| 176 |
+
index_dict["weight_map"][k] = filename
|
| 177 |
+
param_count += v.numel()
|
| 178 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
| 179 |
+
|
| 180 |
+
# Write configs
|
| 181 |
+
index_dict["metadata"] = {"total_size": param_count * 2}
|
| 182 |
+
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
| 183 |
+
|
| 184 |
+
config = OlmoeConfig(
|
| 185 |
+
vocab_size=vocab_size,
|
| 186 |
+
hidden_size=dim,
|
| 187 |
+
intermediate_size=dim_per_expert,
|
| 188 |
+
num_hidden_layers=n_layers,
|
| 189 |
+
num_attention_heads=n_heads,
|
| 190 |
+
num_key_value_heads=num_key_value_heads,
|
| 191 |
+
max_position_embeddings=max_position_embeddings,
|
| 192 |
+
pad_token_id=olmoe_config["pad_token_id"],
|
| 193 |
+
bos_token_id=None,
|
| 194 |
+
eos_token_id=olmoe_config["eos_token_id"],
|
| 195 |
+
tie_word_embeddings=olmoe_config["weight_tying"],
|
| 196 |
+
rope_theta=base,
|
| 197 |
+
clip_qkv=olmoe_config.get("clip_qkv"),
|
| 198 |
+
)
|
| 199 |
+
config.save_pretrained(tmp_model_path)
|
| 200 |
+
|
| 201 |
+
# Make space so we can load the model properly now.
|
| 202 |
+
del state_dict
|
| 203 |
+
del loaded
|
| 204 |
+
gc.collect()
|
| 205 |
+
|
| 206 |
+
if tokenizer_path is not None:
|
| 207 |
+
_write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id)
|
| 208 |
+
|
| 209 |
+
print("Loading the checkpoint in a OLMoE model.")
|
| 210 |
+
model = OlmoeForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16)
|
| 211 |
+
# Avoid saving this as part of the config.
|
| 212 |
+
del model.config._name_or_path
|
| 213 |
+
print("Saving in the Transformers format.")
|
| 214 |
+
model.save_pretrained(model_path, safe_serialization=safe_serialization)
|
| 215 |
+
shutil.rmtree(tmp_model_path)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def _write_tokenizer(
|
| 219 |
+
output_path: Path, config: OlmoeConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True
|
| 220 |
+
) -> None:
|
| 221 |
+
print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.")
|
| 222 |
+
|
| 223 |
+
base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
|
| 224 |
+
|
| 225 |
+
eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
|
| 226 |
+
pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id
|
| 227 |
+
|
| 228 |
+
if fix_eos_token_id and eos_token_id == 0:
|
| 229 |
+
# Fixing a bug in OLMo where eos token id was incorrectly set
|
| 230 |
+
print("Changing eos_token_id from 0 to 50279.")
|
| 231 |
+
eos_token_id = 50279
|
| 232 |
+
|
| 233 |
+
tokenizer = GPTNeoXTokenizerFast(
|
| 234 |
+
tokenizer_object=base_tokenizer,
|
| 235 |
+
eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
|
| 236 |
+
pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False),
|
| 237 |
+
unk_token=None,
|
| 238 |
+
bos_token=None,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
tokenizer.save_pretrained(output_path)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def main():
|
| 245 |
+
parser = argparse.ArgumentParser()
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--input_dir",
|
| 248 |
+
required=True,
|
| 249 |
+
help="Location of OLMoE weights, which contains config.yaml and model.pt.",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--tokenizer_json_path",
|
| 253 |
+
default=None,
|
| 254 |
+
help="Location of OLMoE tokenizer json file.",
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--output_dir",
|
| 258 |
+
required=True,
|
| 259 |
+
help="Location to write HF model and tokenizer",
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--no_fix_eos_token_id",
|
| 263 |
+
action="store_false",
|
| 264 |
+
dest="fix_eos_token_id",
|
| 265 |
+
help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
|
| 266 |
+
)
|
| 267 |
+
parser.add_argument(
|
| 268 |
+
"--safe_serialization", type=bool, default=True, help="Whether or not to save using `safetensors`."
|
| 269 |
+
)
|
| 270 |
+
args = parser.parse_args()
|
| 271 |
+
write_model(
|
| 272 |
+
model_path=args.output_dir,
|
| 273 |
+
input_base_path=args.input_dir,
|
| 274 |
+
safe_serialization=args.safe_serialization,
|
| 275 |
+
tokenizer_path=args.tokenizer_json_path,
|
| 276 |
+
fix_eos_token_id=args.fix_eos_token_id,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|
docs/transformers/src/transformers/models/olmoe/modeling_olmoe.py
ADDED
|
@@ -0,0 +1,1273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OLMoE model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_olmoe import OlmoeConfig


if is_flash_attn_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "OlmoeConfig"


# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts

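
# Illustrative sketch (not part of the upstream `modeling_olmoe.py`): a minimal,
# self-contained call of `load_balancing_loss_func` on random router logits. The
# shapes and values below are made up for demonstration only; in the model this
# scalar is typically scaled by a coefficient (e.g. `router_aux_loss_coef` in the
# config) and added to the language-modeling loss.
def _sketch_load_balancing_loss():
    batch_size, seq_len, num_experts, top_k, num_layers = 2, 4, 8, 2, 3
    # One gate-logits tensor per layer, each of shape [batch_size * seq_len, num_experts].
    gate_logits = tuple(torch.randn(batch_size * seq_len, num_experts) for _ in range(num_layers))
    aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
    # For (near-)uniform routing the loss is close to `top_k`; it grows as routing
    # collapses onto a few experts, which is exactly what the penalty discourages.
    print(f"auxiliary load-balancing loss: {aux_loss.item():.4f}")

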
class OlmoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        OlmoeRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe
class OlmoeRotaryEmbedding(nn.Module):
    def __init__(self, config: OlmoeConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

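
# Illustrative sketch (not part of the upstream `modeling_olmoe.py`): applying
# `apply_rotary_pos_emb` to toy query/key tensors. The cos/sin tables are built
# by hand with an assumed base of 10000, mirroring the `[batch_size, seq_len,
# head_dim]` shape returned by `OlmoeRotaryEmbedding.forward`. RoPE rotates each
# head vector per position, so vector norms are preserved.
def _sketch_rotary_embedding():
    bsz, num_heads, seq_len, head_dim = 1, 2, 5, 8
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)[None]                 # [1, seq_len, head_dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True

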
# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe
class OlmoeMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class OlmoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(
            (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        if self.config.clip_qkv is not None:
            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

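
# Illustrative sketch (not part of the upstream `modeling_olmoe.py`): the core
# math of the eager path above on toy tensors, i.e. grouped-query attention
# where `repeat_kv` duplicates each key/value head before the usual
# softmax(QK^T / sqrt(d)) V. The RoPE step, KV cache and linear projections are
# left out so the sketch stays config-free.
def _sketch_eager_gqa_attention():
    bsz, q_len, num_heads, num_kv_heads, head_dim = 1, 4, 4, 2, 8
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
    v = torch.randn(bsz, num_kv_heads, q_len, head_dim)
    n_rep = num_heads // num_kv_heads
    k, v = repeat_kv(k, n_rep), repeat_kv(v, n_rep)
    # Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
    causal_mask = torch.full((q_len, q_len), float("-inf")).triu(1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + causal_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, v)
    print(attn_output.shape)  # torch.Size([1, 4, 4, 8])

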
| 365 |
+
class OlmoeFlashAttention2(OlmoeAttention):
|
| 366 |
+
"""
|
| 367 |
+
OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays
|
| 368 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 369 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
def __init__(self, *args, **kwargs):
|
| 373 |
+
super().__init__(*args, **kwargs)
|
| 374 |
+
|
| 375 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 376 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 377 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 378 |
+
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
| 379 |
+
|
| 380 |
+
def forward(
|
| 381 |
+
self,
|
| 382 |
+
hidden_states: torch.Tensor,
|
| 383 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 384 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 385 |
+
past_key_value: Optional[Cache] = None,
|
| 386 |
+
output_attentions: bool = False,
|
| 387 |
+
use_cache: bool = False,
|
| 388 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 389 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 390 |
+
**kwargs,
|
| 391 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 392 |
+
output_attentions = False
|
| 393 |
+
|
| 394 |
+
bsz, q_len, _ = hidden_states.size()
|
| 395 |
+
|
| 396 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 397 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 398 |
+
value_states = self.v_proj(hidden_states)
|
| 399 |
+
if self.config.clip_qkv is not None:
|
| 400 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 401 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 402 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 403 |
+
|
| 404 |
+
# Flash attention requires the input to have the shape
|
| 405 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 406 |
+
# therefore we just need to keep the original shape
|
| 407 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 408 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 409 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 410 |
+
|
| 411 |
+
cos, sin = position_embeddings
|
| 412 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 413 |
+
|
| 414 |
+
if past_key_value is not None:
|
| 415 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 416 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 417 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 418 |
+
|
| 419 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 420 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 421 |
+
query_states = query_states.transpose(1, 2)
|
| 422 |
+
key_states = key_states.transpose(1, 2)
|
| 423 |
+
value_states = value_states.transpose(1, 2)
|
| 424 |
+
|
| 425 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 426 |
+
|
| 427 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 428 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 429 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 430 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 431 |
+
# in fp32. (OlmoeRMSNorm handles it correctly)
|
| 432 |
+
|
| 433 |
+
input_dtype = query_states.dtype
|
| 434 |
+
if input_dtype == torch.float32:
|
| 435 |
+
if torch.is_autocast_enabled():
|
| 436 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 437 |
+
# Handle the case where the model is quantized
|
| 438 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 439 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 440 |
+
else:
|
| 441 |
+
target_dtype = self.q_proj.weight.dtype
|
| 442 |
+
|
| 443 |
+
logger.warning_once(
|
| 444 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 445 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 446 |
+
f" {target_dtype}."
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
query_states = query_states.to(target_dtype)
|
| 450 |
+
key_states = key_states.to(target_dtype)
|
| 451 |
+
value_states = value_states.to(target_dtype)
|
| 452 |
+
|
| 453 |
+
attn_output = _flash_attention_forward(
|
| 454 |
+
query_states,
|
| 455 |
+
key_states,
|
| 456 |
+
value_states,
|
| 457 |
+
attention_mask,
|
| 458 |
+
q_len,
|
| 459 |
+
dropout=dropout_rate,
|
| 460 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 461 |
+
is_causal=self.is_causal,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 465 |
+
attn_output = self.o_proj(attn_output)
|
| 466 |
+
|
| 467 |
+
if not output_attentions:
|
| 468 |
+
attn_weights = None
|
| 469 |
+
|
| 470 |
+
return attn_output, attn_weights, past_key_value
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
class OlmoeSdpaAttention(OlmoeAttention):
|
| 474 |
+
"""
|
| 475 |
+
OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 476 |
+
`OlmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 477 |
+
SDPA API.
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
# Adapted from OlmoeAttention.forward
|
| 481 |
+
def forward(
|
| 482 |
+
self,
|
| 483 |
+
hidden_states: torch.Tensor,
|
| 484 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 485 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 486 |
+
past_key_value: Optional[Cache] = None,
|
| 487 |
+
output_attentions: bool = False,
|
| 488 |
+
use_cache: bool = False,
|
| 489 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 490 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 491 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 492 |
+
if output_attentions:
|
| 493 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 494 |
+
logger.warning_once(
|
| 495 |
+
"OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 496 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 497 |
+
)
|
| 498 |
+
return super().forward(
|
| 499 |
+
hidden_states=hidden_states,
|
| 500 |
+
attention_mask=attention_mask,
|
| 501 |
+
position_ids=position_ids,
|
| 502 |
+
past_key_value=past_key_value,
|
| 503 |
+
output_attentions=output_attentions,
|
| 504 |
+
use_cache=use_cache,
|
| 505 |
+
cache_position=cache_position,
|
| 506 |
+
position_embeddings=position_embeddings,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
bsz, q_len, _ = hidden_states.size()
|
| 510 |
+
|
| 511 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 512 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 513 |
+
value_states = self.v_proj(hidden_states)
|
| 514 |
+
|
| 515 |
+
if self.config.clip_qkv is not None:
|
| 516 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 517 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 518 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 519 |
+
|
| 520 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 521 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 522 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 523 |
+
|
| 524 |
+
cos, sin = position_embeddings
|
| 525 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 526 |
+
|
| 527 |
+
if past_key_value is not None:
|
| 528 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 529 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 530 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 531 |
+
|
| 532 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 533 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 534 |
+
|
| 535 |
+
causal_mask = attention_mask
|
| 536 |
+
# if attention_mask is not None and cache_position is not None:
|
| 537 |
+
if attention_mask is not None:
|
| 538 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
| 539 |
+
|
| 540 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 541 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 542 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 543 |
+
query_states = query_states.contiguous()
|
| 544 |
+
key_states = key_states.contiguous()
|
| 545 |
+
value_states = value_states.contiguous()
|
| 546 |
+
|
| 547 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 548 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 549 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 550 |
+
|
| 551 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 552 |
+
query_states,
|
| 553 |
+
key_states,
|
| 554 |
+
value_states,
|
| 555 |
+
attn_mask=causal_mask,
|
| 556 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 557 |
+
is_causal=is_causal,
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 561 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
| 562 |
+
|
| 563 |
+
attn_output = self.o_proj(attn_output)
|
| 564 |
+
|
| 565 |
+
return attn_output, None, past_key_value
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
OLMOE_ATTENTION_CLASSES = {
|
| 569 |
+
"eager": OlmoeAttention,
|
| 570 |
+
"flash_attention_2": OlmoeFlashAttention2,
|
| 571 |
+
"sdpa": OlmoeSdpaAttention,
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
|
class OlmoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be selected
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

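
# Illustrative sketch (not part of the upstream `modeling_olmoe.py`): the
# routing arithmetic of `OlmoeSparseMoeBlock.forward` on toy data, with plain
# `nn.Linear` layers standing in for the `OlmoeMLP` experts so no config is
# needed. Each token is dispatched to its top-k experts and the expert outputs
# are combined with the re-normalized router weights via `index_add_`.
def _sketch_topk_routing():
    num_tokens, hidden_dim, num_experts, top_k = 6, 4, 4, 2
    tokens = torch.randn(num_tokens, hidden_dim)
    experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
    router_logits = torch.randn(num_tokens, num_experts)

    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # as with norm_topk_prob=True

    final = torch.zeros_like(tokens)
    # [num_experts, top_k, num_tokens]: entry (e, slot, t) is 1 when expert e is token t's slot-th choice.
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        slot, top_x = torch.where(expert_mask[expert_idx])
        if top_x.numel() > 0:
            out = experts[expert_idx](tokens[top_x]) * routing_weights[top_x, slot, None]
            final.index_add_(0, top_x, out)
    print(final.shape)  # torch.Size([6, 4])

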
| 623 |
+
class OlmoeDecoderLayer(nn.Module):
|
| 624 |
+
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
| 625 |
+
super().__init__()
|
| 626 |
+
self.hidden_size = config.hidden_size
|
| 627 |
+
|
| 628 |
+
self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 629 |
+
|
| 630 |
+
self.mlp = OlmoeSparseMoeBlock(config)
|
| 631 |
+
self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 632 |
+
self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 633 |
+
|
| 634 |
+
def forward(
|
| 635 |
+
self,
|
| 636 |
+
hidden_states: torch.Tensor,
|
| 637 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 638 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 639 |
+
past_key_value: Optional[Cache] = None,
|
| 640 |
+
output_attentions: Optional[bool] = False,
|
| 641 |
+
output_router_logits: Optional[bool] = False,
|
| 642 |
+
use_cache: Optional[bool] = False,
|
| 643 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 644 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 645 |
+
**kwargs,
|
| 646 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 647 |
+
"""
|
| 648 |
+
Args:
|
| 649 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 650 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
| 651 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 652 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 653 |
+
output_attentions (`bool`, *optional*):
|
| 654 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 655 |
+
returned tensors for more detail.
|
| 656 |
+
output_router_logits (`bool`, *optional*):
|
| 657 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss,
|
| 658 |
+
and should not be returned during inference.
|
| 659 |
+
use_cache (`bool`, *optional*):
|
| 660 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 661 |
+
(see `past_key_values`).
|
| 662 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 663 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 664 |
+
Indices depicting the position of the input sequence tokens in the sequence
|
| 665 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 666 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 667 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 668 |
+
kwargs (`dict`, *optional*):
|
| 669 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
| 670 |
+
into the model
|
| 671 |
+
"""
|
| 672 |
+
residual = hidden_states
|
| 673 |
+
|
| 674 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 675 |
+
|
| 676 |
+
# Self Attention
|
| 677 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 678 |
+
hidden_states=hidden_states,
|
| 679 |
+
attention_mask=attention_mask,
|
| 680 |
+
position_ids=position_ids,
|
| 681 |
+
past_key_value=past_key_value,
|
| 682 |
+
output_attentions=output_attentions,
|
| 683 |
+
use_cache=use_cache,
|
| 684 |
+
cache_position=cache_position,
|
| 685 |
+
position_embeddings=position_embeddings,
|
| 686 |
+
**kwargs,
|
| 687 |
+
)
|
| 688 |
+
hidden_states = residual + hidden_states
|
| 689 |
+
|
| 690 |
+
# Fully Connected
|
| 691 |
+
residual = hidden_states
|
| 692 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 693 |
+
hidden_states, router_logits = self.mlp(hidden_states)
|
| 694 |
+
hidden_states = residual + hidden_states
|
| 695 |
+
|
| 696 |
+
outputs = (hidden_states,)
|
| 697 |
+
|
| 698 |
+
if output_attentions:
|
| 699 |
+
outputs += (self_attn_weights,)
|
| 700 |
+
|
| 701 |
+
if use_cache:
|
| 702 |
+
outputs += (present_key_value,)
|
| 703 |
+
|
| 704 |
+
if output_router_logits:
|
| 705 |
+
outputs += (router_logits,)
|
| 706 |
+
|
| 707 |
+
return outputs
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
OLMOE_START_DOCSTRING = r"""
|
| 711 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 712 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 713 |
+
etc.)
|
| 714 |
+
|
| 715 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 716 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 717 |
+
and behavior.
|
| 718 |
+
|
| 719 |
+
Parameters:
|
| 720 |
+
config ([`OlmoeConfig`]):
|
| 721 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 722 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 723 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 724 |
+
"""
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
@add_start_docstrings(
|
| 728 |
+
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
|
| 729 |
+
OLMOE_START_DOCSTRING,
|
| 730 |
+
)
|
| 731 |
+
class OlmoePreTrainedModel(PreTrainedModel):
|
| 732 |
+
config_class = OlmoeConfig
|
| 733 |
+
base_model_prefix = "model"
|
| 734 |
+
supports_gradient_checkpointing = True
|
| 735 |
+
_no_split_modules = ["OlmoeDecoderLayer"]
|
| 736 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 737 |
+
_supports_flash_attn_2 = True
|
| 738 |
+
_supports_sdpa = True
|
| 739 |
+
_supports_cache_class = True
|
| 740 |
+
_supports_quantized_cache = True
|
| 741 |
+
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
|
| 742 |
+
|
| 743 |
+
def _init_weights(self, module):
|
| 744 |
+
std = self.config.initializer_range
|
| 745 |
+
if isinstance(module, nn.Linear):
|
| 746 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 747 |
+
if module.bias is not None:
|
| 748 |
+
module.bias.data.zero_()
|
| 749 |
+
elif isinstance(module, OlmoeRMSNorm):
|
| 750 |
+
module.weight.data.fill_(1.0)
|
| 751 |
+
elif isinstance(module, nn.Embedding):
|
| 752 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 753 |
+
if module.padding_idx is not None:
|
| 754 |
+
module.weight.data[module.padding_idx].zero_()
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
OLMOE_INPUTS_DOCSTRING = r"""
|
| 758 |
+
Args:
|
| 759 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 760 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 761 |
+
it.
|
| 762 |
+
|
| 763 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 764 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 765 |
+
|
| 766 |
+
[What are input IDs?](../glossary#input-ids)
|
| 767 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 768 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 769 |
+
|
| 770 |
+
- 1 for tokens that are **not masked**,
|
| 771 |
+
- 0 for tokens that are **masked**.
|
| 772 |
+
|
| 773 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 774 |
+
|
| 775 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 776 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 777 |
+
|
| 778 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 779 |
+
`past_key_values`).
|
| 780 |
+
|
| 781 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 782 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 783 |
+
information on the default strategy.
|
| 784 |
+
|
| 785 |
+
- 1 indicates the head is **not masked**,
|
| 786 |
+
- 0 indicates the head is **masked**.
|
| 787 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 788 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 789 |
+
config.n_positions - 1]`.
|
| 790 |
+
|
| 791 |
+
[What are position IDs?](../glossary#position-ids)
|
| 792 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 793 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 794 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 795 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 796 |
+
|
| 797 |
+
Two formats are allowed:
|
| 798 |
+
- a [`~cache_utils.Cache`] instance;
|
| 799 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 800 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 801 |
+
cache format.
|
| 802 |
+
|
| 803 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 804 |
+
legacy cache format will be returned.
|
| 805 |
+
|
| 806 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 807 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 808 |
+
of shape `(batch_size, sequence_length)`.
|
| 809 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 810 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 811 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 812 |
+
model's internal embedding lookup matrix.
|
| 813 |
+
use_cache (`bool`, *optional*):
|
| 814 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 815 |
+
`past_key_values`).
|
| 816 |
+
output_attentions (`bool`, *optional*):
|
| 817 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 818 |
+
tensors for more detail.
|
| 819 |
+
output_hidden_states (`bool`, *optional*):
|
| 820 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 821 |
+
more detail.
|
| 822 |
+
output_router_logits (`bool`, *optional*):
|
| 823 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
| 824 |
+
should not be returned during inference.
|
| 825 |
+
return_dict (`bool`, *optional*):
|
| 826 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 827 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 828 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 829 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 830 |
+
the complete sequence length.
|
| 831 |
+
"""
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
@add_start_docstrings(
|
| 835 |
+
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
|
| 836 |
+
OLMOE_START_DOCSTRING,
|
| 837 |
+
)
|
| 838 |
+
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
|
| 839 |
+
class OlmoeModel(OlmoePreTrainedModel):
|
| 840 |
+
"""
|
| 841 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
|
| 842 |
+
|
| 843 |
+
Args:
|
| 844 |
+
config: OlmoeConfig
|
| 845 |
+
"""
|
| 846 |
+
|
| 847 |
+
def __init__(self, config: OlmoeConfig):
|
| 848 |
+
super().__init__(config)
|
| 849 |
+
self.padding_idx = config.pad_token_id
|
| 850 |
+
self.vocab_size = config.vocab_size
|
| 851 |
+
|
| 852 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 853 |
+
self.layers = nn.ModuleList(
|
| 854 |
+
[OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 855 |
+
)
|
| 856 |
+
self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 857 |
+
self.rotary_emb = OlmoeRotaryEmbedding(config=config)
|
| 858 |
+
self.gradient_checkpointing = False
|
| 859 |
+
|
| 860 |
+
# Initialize weights and apply final processing
|
| 861 |
+
self.post_init()
|
| 862 |
+
|
| 863 |
+
def get_input_embeddings(self):
|
| 864 |
+
return self.embed_tokens
|
| 865 |
+
|
| 866 |
+
def set_input_embeddings(self, value):
|
| 867 |
+
self.embed_tokens = value
|
| 868 |
+
|
| 869 |
+
@add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
|
| 870 |
+
# Ignore copy
|
| 871 |
+
def forward(
|
| 872 |
+
self,
|
| 873 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 874 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 875 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 876 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 877 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 878 |
+
use_cache: Optional[bool] = None,
|
| 879 |
+
output_attentions: Optional[bool] = None,
|
| 880 |
+
output_hidden_states: Optional[bool] = None,
|
| 881 |
+
output_router_logits: Optional[bool] = None,
|
| 882 |
+
return_dict: Optional[bool] = None,
|
| 883 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 884 |
+
) -> Union[Tuple, MoeModelOutputWithPast]:
|
| 885 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 886 |
+
output_router_logits = (
|
| 887 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 888 |
+
)
|
| 889 |
+
output_hidden_states = (
|
| 890 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 891 |
+
)
|
| 892 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 893 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 894 |
+
|
| 895 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 896 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 897 |
+
|
| 898 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 899 |
+
logger.warning_once(
|
| 900 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 901 |
+
)
|
| 902 |
+
use_cache = False
|
| 903 |
+
|
| 904 |
+
if inputs_embeds is None:
|
| 905 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 906 |
+
|
| 907 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 908 |
+
return_legacy_cache = False
|
| 909 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 910 |
+
return_legacy_cache = True
|
| 911 |
+
if past_key_values is None:
|
| 912 |
+
past_key_values = DynamicCache()
|
| 913 |
+
else:
|
| 914 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 915 |
+
logger.warning_once(
|
| 916 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 917 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 918 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
if cache_position is None:
|
| 922 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 923 |
+
cache_position = torch.arange(
|
| 924 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 925 |
+
)
|
| 926 |
+
if position_ids is None:
|
| 927 |
+
position_ids = cache_position.unsqueeze(0)
|
| 928 |
+
|
| 929 |
+
causal_mask = self._update_causal_mask(
|
| 930 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
# embed positions
|
| 934 |
+
hidden_states = inputs_embeds
|
| 935 |
+
|
| 936 |
+
# create position embeddings to be shared across the decoder layers
|
| 937 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 938 |
+
|
| 939 |
+
# decoder layers
|
| 940 |
+
all_hidden_states = () if output_hidden_states else None
|
| 941 |
+
all_self_attns = () if output_attentions else None
|
| 942 |
+
all_router_logits = () if output_router_logits else None
|
| 943 |
+
next_decoder_cache = None
|
| 944 |
+
|
| 945 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 946 |
+
if output_hidden_states:
|
| 947 |
+
all_hidden_states += (hidden_states,)
|
| 948 |
+
|
| 949 |
+
if self.gradient_checkpointing and self.training:
|
| 950 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 951 |
+
decoder_layer.__call__,
|
| 952 |
+
hidden_states,
|
| 953 |
+
causal_mask,
|
| 954 |
+
position_ids,
|
| 955 |
+
past_key_values,
|
| 956 |
+
output_attentions,
|
| 957 |
+
output_router_logits,
|
| 958 |
+
use_cache,
|
| 959 |
+
cache_position,
|
| 960 |
+
position_embeddings,
|
| 961 |
+
)
|
| 962 |
+
else:
|
| 963 |
+
layer_outputs = decoder_layer(
|
| 964 |
+
hidden_states,
|
| 965 |
+
attention_mask=causal_mask,
|
| 966 |
+
position_ids=position_ids,
|
| 967 |
+
past_key_value=past_key_values,
|
| 968 |
+
output_attentions=output_attentions,
|
| 969 |
+
output_router_logits=output_router_logits,
|
| 970 |
+
use_cache=use_cache,
|
| 971 |
+
cache_position=cache_position,
|
| 972 |
+
position_embeddings=position_embeddings,
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
hidden_states = layer_outputs[0]
|
| 976 |
+
|
| 977 |
+
if use_cache:
|
| 978 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 979 |
+
|
| 980 |
+
if output_attentions:
|
| 981 |
+
all_self_attns += (layer_outputs[1],)
|
| 982 |
+
|
| 983 |
+
if output_router_logits and layer_outputs[-1] is not None:
|
| 984 |
+
all_router_logits += (layer_outputs[-1],)
|
| 985 |
+
|
| 986 |
+
hidden_states = self.norm(hidden_states)
|
| 987 |
+
|
| 988 |
+
# add hidden states from the last decoder layer
|
| 989 |
+
if output_hidden_states:
|
| 990 |
+
all_hidden_states += (hidden_states,)
|
| 991 |
+
|
| 992 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 993 |
+
if return_legacy_cache:
|
| 994 |
+
next_cache = next_cache.to_legacy_cache()
|
| 995 |
+
|
| 996 |
+
if not return_dict:
|
| 997 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 998 |
+
return MoeModelOutputWithPast(
|
| 999 |
+
last_hidden_state=hidden_states,
|
| 1000 |
+
past_key_values=next_cache,
|
| 1001 |
+
hidden_states=all_hidden_states,
|
| 1002 |
+
attentions=all_self_attns,
|
| 1003 |
+
router_logits=all_router_logits,
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
def _update_causal_mask(
|
| 1007 |
+
self,
|
| 1008 |
+
attention_mask: torch.Tensor,
|
| 1009 |
+
input_tensor: torch.Tensor,
|
| 1010 |
+
cache_position: torch.Tensor,
|
| 1011 |
+
past_key_values: Cache,
|
| 1012 |
+
output_attentions: bool,
|
| 1013 |
+
):
|
| 1014 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1015 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1016 |
+
return attention_mask
|
| 1017 |
+
return None
|
| 1018 |
+
|
| 1019 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1020 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1021 |
+
# to infer the attention mask.
|
| 1022 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1023 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1024 |
+
|
| 1025 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1026 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 1027 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1028 |
+
attention_mask,
|
| 1029 |
+
inputs_embeds=input_tensor,
|
| 1030 |
+
past_key_values_length=past_seen_tokens,
|
| 1031 |
+
is_training=self.training,
|
| 1032 |
+
):
|
| 1033 |
+
return None
|
| 1034 |
+
|
| 1035 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 1036 |
+
sequence_length = input_tensor.shape[1]
|
| 1037 |
+
if using_static_cache:
|
| 1038 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 1039 |
+
else:
|
| 1040 |
+
target_length = (
|
| 1041 |
+
attention_mask.shape[-1]
|
| 1042 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1043 |
+
else past_seen_tokens + sequence_length + 1
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1047 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 1048 |
+
attention_mask,
|
| 1049 |
+
sequence_length=sequence_length,
|
| 1050 |
+
target_length=target_length,
|
| 1051 |
+
dtype=dtype,
|
| 1052 |
+
device=device,
|
| 1053 |
+
cache_position=cache_position,
|
| 1054 |
+
batch_size=input_tensor.shape[0],
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
if (
|
| 1058 |
+
self.config._attn_implementation == "sdpa"
|
| 1059 |
+
and attention_mask is not None
|
| 1060 |
+
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
| 1061 |
+
and not output_attentions
|
| 1062 |
+
):
|
| 1063 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1064 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1065 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1066 |
+
min_dtype = torch.finfo(dtype).min
|
| 1067 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 1068 |
+
|
| 1069 |
+
return causal_mask
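A tiny standalone sketch (illustration only, not part of the model code) of the padding check used in the flash-attention-2 branch at the top of `_update_causal_mask` above: `0.0 in attention_mask` is True only when at least one position is padded, in which case the 2D mask has to be kept.

```python
import torch

attention_mask = torch.tensor([[0.0, 1.0, 1.0],   # first sequence is left-padded
                               [1.0, 1.0, 1.0]])  # second sequence has no padding
print(0.0 in attention_mask)    # True  -> padding present, the 2D mask is returned as-is
print(0.0 in torch.ones(2, 3))  # False -> no padding, None is returned and causality is handled by the kernel
```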
|
| 1070 |
+
|
| 1071 |
+
@staticmethod
|
| 1072 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1073 |
+
attention_mask: torch.Tensor,
|
| 1074 |
+
sequence_length: int,
|
| 1075 |
+
target_length: int,
|
| 1076 |
+
dtype: torch.dtype,
|
| 1077 |
+
device: torch.device,
|
| 1078 |
+
cache_position: torch.Tensor,
|
| 1079 |
+
batch_size: int,
|
| 1080 |
+
**kwargs,
|
| 1081 |
+
):
|
| 1082 |
+
"""
|
| 1083 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 1084 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 1085 |
+
|
| 1086 |
+
Args:
|
| 1087 |
+
attention_mask (`torch.Tensor`):
|
| 1088 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 1089 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 1090 |
+
sequence_length (`int`):
|
| 1091 |
+
The sequence length being processed.
|
| 1092 |
+
target_length (`int`):
|
| 1093 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 1094 |
+
to account for the 0 padding (the part of the cache that is not yet filled).
|
| 1095 |
+
dtype (`torch.dtype`):
|
| 1096 |
+
The dtype to use for the 4D attention mask.
|
| 1097 |
+
device (`torch.device`):
|
| 1098 |
+
The device to place the 4D attention mask on.
|
| 1099 |
+
cache_position (`torch.Tensor`):
|
| 1100 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 1101 |
+
batch_size (`int`):
|
| 1102 |
+
Batch size.
|
| 1103 |
+
"""
|
| 1104 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 1105 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 1106 |
+
causal_mask = attention_mask
|
| 1107 |
+
else:
|
| 1108 |
+
min_dtype = torch.finfo(dtype).min
|
| 1109 |
+
causal_mask = torch.full(
|
| 1110 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 1111 |
+
)
|
| 1112 |
+
if sequence_length != 1:
|
| 1113 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 1114 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 1115 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 1116 |
+
if attention_mask is not None:
|
| 1117 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 1118 |
+
mask_length = attention_mask.shape[-1]
|
| 1119 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 1120 |
+
padding_mask = padding_mask == 0
|
| 1121 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 1122 |
+
padding_mask, min_dtype
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
return causal_mask
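A standalone toy sketch (illustration only; the sizes are arbitrary assumptions) of what the construction above produces for a 3-token query with 2 tokens already in the cache:

```python
import torch

batch_size, past_seen_tokens, sequence_length = 1, 2, 3
target_length = past_seen_tokens + sequence_length
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + sequence_length)

# same operations as in the method above, on toy shapes
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# 1 marks key positions a query token may attend to (all cached tokens plus itself), 0 marks masked ones
print((causal_mask == 0).int().squeeze(0).squeeze(0))
```

Each row is causal relative to the absolute position given by `cache_position`, which is exactly the behaviour described in the docstring above.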
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
| 1129 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1130 |
+
|
| 1131 |
+
def __init__(self, config):
|
| 1132 |
+
super().__init__(config)
|
| 1133 |
+
self.model = OlmoeModel(config)
|
| 1134 |
+
self.vocab_size = config.vocab_size
|
| 1135 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1136 |
+
|
| 1137 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 1138 |
+
self.num_experts = config.num_experts
|
| 1139 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 1140 |
+
# Initialize weights and apply final processing
|
| 1141 |
+
self.post_init()
|
| 1142 |
+
|
| 1143 |
+
def get_input_embeddings(self):
|
| 1144 |
+
return self.model.embed_tokens
|
| 1145 |
+
|
| 1146 |
+
def set_input_embeddings(self, value):
|
| 1147 |
+
self.model.embed_tokens = value
|
| 1148 |
+
|
| 1149 |
+
def get_output_embeddings(self):
|
| 1150 |
+
return self.lm_head
|
| 1151 |
+
|
| 1152 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1153 |
+
self.lm_head = new_embeddings
|
| 1154 |
+
|
| 1155 |
+
def set_decoder(self, decoder):
|
| 1156 |
+
self.model = decoder
|
| 1157 |
+
|
| 1158 |
+
def get_decoder(self):
|
| 1159 |
+
return self.model
|
| 1160 |
+
|
| 1161 |
+
@add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
|
| 1162 |
+
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1163 |
+
def forward(
|
| 1164 |
+
self,
|
| 1165 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1166 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1167 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1168 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1169 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1170 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1171 |
+
use_cache: Optional[bool] = None,
|
| 1172 |
+
output_attentions: Optional[bool] = None,
|
| 1173 |
+
output_hidden_states: Optional[bool] = None,
|
| 1174 |
+
output_router_logits: Optional[bool] = None,
|
| 1175 |
+
return_dict: Optional[bool] = None,
|
| 1176 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1177 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1178 |
+
**loss_kwargs,
|
| 1179 |
+
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
| 1180 |
+
r"""
|
| 1181 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1182 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1183 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1184 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1185 |
+
|
| 1186 |
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
| 1187 |
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
| 1188 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 1189 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 1190 |
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
| 1191 |
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
| 1192 |
+
|
| 1193 |
+
Returns:
|
| 1194 |
+
|
| 1195 |
+
Example:
|
| 1196 |
+
|
| 1197 |
+
```python
|
| 1198 |
+
>>> from transformers import AutoTokenizer, OlmoeForCausalLM
|
| 1199 |
+
|
| 1200 |
+
>>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
|
| 1201 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
|
| 1202 |
+
|
| 1203 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1204 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1205 |
+
|
| 1206 |
+
>>> # Generate
|
| 1207 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1208 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1209 |
+
'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
|
| 1210 |
+
```
|
| 1211 |
+
"""
|
| 1212 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1213 |
+
output_router_logits = (
|
| 1214 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1215 |
+
)
|
| 1216 |
+
output_hidden_states = (
|
| 1217 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1218 |
+
)
|
| 1219 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1220 |
+
|
| 1221 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1222 |
+
outputs = self.model(
|
| 1223 |
+
input_ids=input_ids,
|
| 1224 |
+
attention_mask=attention_mask,
|
| 1225 |
+
position_ids=position_ids,
|
| 1226 |
+
past_key_values=past_key_values,
|
| 1227 |
+
inputs_embeds=inputs_embeds,
|
| 1228 |
+
use_cache=use_cache,
|
| 1229 |
+
output_attentions=output_attentions,
|
| 1230 |
+
output_hidden_states=output_hidden_states,
|
| 1231 |
+
output_router_logits=output_router_logits,
|
| 1232 |
+
return_dict=return_dict,
|
| 1233 |
+
cache_position=cache_position,
|
| 1234 |
+
)
|
| 1235 |
+
|
| 1236 |
+
hidden_states = outputs[0]
|
| 1237 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1238 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1239 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1240 |
+
|
| 1241 |
+
loss = None
|
| 1242 |
+
if labels is not None:
|
| 1243 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1244 |
+
|
| 1245 |
+
aux_loss = None
|
| 1246 |
+
if output_router_logits:
|
| 1247 |
+
aux_loss = load_balancing_loss_func(
|
| 1248 |
+
outputs.router_logits if return_dict else outputs[-1],
|
| 1249 |
+
self.num_experts,
|
| 1250 |
+
self.num_experts_per_tok,
|
| 1251 |
+
attention_mask,
|
| 1252 |
+
)
|
| 1253 |
+
if labels is not None:
|
| 1254 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 1255 |
+
|
| 1256 |
+
if not return_dict:
|
| 1257 |
+
output = (logits,) + outputs[1:]
|
| 1258 |
+
if output_router_logits:
|
| 1259 |
+
output = (aux_loss,) + output
|
| 1260 |
+
return (loss,) + output if loss is not None else output
|
| 1261 |
+
|
| 1262 |
+
return MoeCausalLMOutputWithPast(
|
| 1263 |
+
loss=loss,
|
| 1264 |
+
aux_loss=aux_loss,
|
| 1265 |
+
logits=logits,
|
| 1266 |
+
past_key_values=outputs.past_key_values,
|
| 1267 |
+
hidden_states=outputs.hidden_states,
|
| 1268 |
+
attentions=outputs.attentions,
|
| 1269 |
+
router_logits=outputs.router_logits,
|
| 1270 |
+
)
|
| 1271 |
+
|
| 1272 |
+
|
| 1273 |
+
__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
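As a complement to the docstring example above, here is a minimal usage sketch (not part of the source file) showing the `output_router_logits` and `logits_to_keep` arguments of the forward pass defined above. The checkpoint name is the one referenced in the docstring; running this requires downloading the full model.

```python
import torch
from transformers import AutoTokenizer, OlmoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")

inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
with torch.no_grad():
    # keep logits only for the final position and return the per-layer router logits
    outputs = model(**inputs, output_router_logits=True, logits_to_keep=1)

print(outputs.logits.shape)        # (batch_size, 1, vocab_size) because logits_to_keep=1
print(len(outputs.router_logits))  # one router-logits tensor per decoder layer
```

During training, the forward above scales the auxiliary load-balancing loss computed from these router logits by `config.router_aux_loss_coef` and adds it to the language-modeling loss whenever `labels` are provided.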
|
docs/transformers/src/transformers/models/omdet_turbo/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_omdet_turbo import *
|
| 22 |
+
from .modeling_omdet_turbo import *
|
| 23 |
+
from .processing_omdet_turbo import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
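For reference, a small sketch (assuming a `transformers` installation that includes this model) of what the lazy-module setup above provides: importing the package itself is cheap, and the concrete modules are only loaded when a symbol is first accessed.

```python
import importlib

omdet_turbo = importlib.import_module("transformers.models.omdet_turbo")
# attribute access triggers the actual import of configuration_omdet_turbo via _LazyModule
config_cls = omdet_turbo.OmDetTurboConfig
print(config_cls().model_type)  # "omdet-turbo"
```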
|
docs/transformers/src/transformers/models/omdet_turbo/configuration_omdet_turbo.py
ADDED
|
@@ -0,0 +1,293 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""OmDet-Turbo model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
from ...utils.backbone_utils import verify_backbone_config_arguments
|
| 20 |
+
from ..auto import CONFIG_MAPPING
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class OmDetTurboConfig(PretrainedConfig):
|
| 27 |
+
r"""
|
| 28 |
+
This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`].
|
| 29 |
+
It is used to instantiate an OmDet-Turbo model according to the specified arguments, defining the model architecture.
|
| 30 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo
|
| 31 |
+
[omlab/omdet-turbo-swin-tiny-hf](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf) architecture.
|
| 32 |
+
|
| 33 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 34 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
text_config (`PretrainedConfig`, *optional*):
|
| 38 |
+
The configuration of the text backbone.
|
| 39 |
+
backbone_config (`PretrainedConfig`, *optional*):
|
| 40 |
+
The configuration of the vision backbone.
|
| 41 |
+
use_timm_backbone (`bool`, *optional*, defaults to `True`):
|
| 42 |
+
Whether to use the timm library for the vision backbone.
|
| 43 |
+
backbone (`str`, *optional*, defaults to `"swin_tiny_patch4_window7_224"`):
|
| 44 |
+
The name of the pretrained vision backbone to use. If `use_pretrained_backbone=False` a randomly initialized
|
| 45 |
+
backbone with the same architecture `backbone` is used.
|
| 46 |
+
backbone_kwargs (`dict`, *optional*):
|
| 47 |
+
Additional kwargs for the vision backbone.
|
| 48 |
+
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
| 49 |
+
Whether to use a pretrained vision backbone.
|
| 50 |
+
apply_layernorm_after_vision_backbone (`bool`, *optional*, defaults to `True`):
|
| 51 |
+
Whether to apply layer normalization on the feature maps of the vision backbone output.
|
| 52 |
+
image_size (`int`, *optional*, defaults to 640):
|
| 53 |
+
The size (resolution) of each image.
|
| 54 |
+
disable_custom_kernels (`bool`, *optional*, defaults to `False`):
|
| 55 |
+
Whether to disable custom kernels.
|
| 56 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 57 |
+
The epsilon value for layer normalization.
|
| 58 |
+
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 59 |
+
The epsilon value for batch normalization.
|
| 60 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
| 61 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 62 |
+
text_projection_in_dim (`int`, *optional*, defaults to 512):
|
| 63 |
+
The input dimension for the text projection.
|
| 64 |
+
text_projection_out_dim (`int`, *optional*, defaults to 512):
|
| 65 |
+
The output dimension for the text projection.
|
| 66 |
+
task_encoder_hidden_dim (`int`, *optional*, defaults to 1024):
|
| 67 |
+
The feedforward dimension for the task encoder.
|
| 68 |
+
class_embed_dim (`int`, *optional*, defaults to 512):
|
| 69 |
+
The dimension of the classes embeddings.
|
| 70 |
+
class_distance_type (`str`, *optional*, defaults to `"cosine"`):
|
| 71 |
+
The type of distance used to compare predicted classes to the projected class embeddings (a short illustration follows at the end of this file).
|
| 72 |
+
Can be `"cosine"` or `"dot"`.
|
| 73 |
+
num_queries (`int`, *optional*, defaults to 900):
|
| 74 |
+
The number of queries.
|
| 75 |
+
csp_activation (`str`, *optional*, defaults to `"silu"`):
|
| 76 |
+
The activation function of the Cross Stage Partial (CSP) networks of the encoder.
|
| 77 |
+
conv_norm_activation (`str`, *optional*, defaults to `"gelu"`):
|
| 78 |
+
The activation function of the ConvNormLayer layers of the encoder.
|
| 79 |
+
encoder_feedforward_activation (`str`, *optional*, defaults to `"relu"`):
|
| 80 |
+
The activation function for the feedforward network of the encoder.
|
| 81 |
+
encoder_feedforward_dropout (`float`, *optional*, defaults to 0.0):
|
| 82 |
+
The dropout rate following the activation of the encoder feedforward network.
|
| 83 |
+
encoder_dropout (`float`, *optional*, defaults to 0.0):
|
| 84 |
+
The dropout rate of the encoder multi-head attention module.
|
| 85 |
+
hidden_expansion (`int`, *optional*, defaults to 1):
|
| 86 |
+
The hidden expansion of the CSP networks in the encoder.
|
| 87 |
+
vision_features_channels (`tuple(int)`, *optional*, defaults to `[256, 256, 256]`):
|
| 88 |
+
The projected vision features channels used as inputs for the decoder.
|
| 89 |
+
encoder_hidden_dim (`int`, *optional*, defaults to 256):
|
| 90 |
+
The hidden dimension of the encoder.
|
| 91 |
+
encoder_in_channels (`List(int)`, *optional*, defaults to `[192, 384, 768]`):
|
| 92 |
+
The input channels for the encoder.
|
| 93 |
+
encoder_projection_indices (`List(int)`, *optional*, defaults to `[2]`):
|
| 94 |
+
The indices of the input features projected by each layer.
|
| 95 |
+
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
| 96 |
+
The number of attention heads for the encoder.
|
| 97 |
+
encoder_dim_feedforward (`int`, *optional*, defaults to 2048):
|
| 98 |
+
The feedforward dimension for the encoder.
|
| 99 |
+
encoder_layers (`int`, *optional*, defaults to 1):
|
| 100 |
+
The number of layers in the encoder.
|
| 101 |
+
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
|
| 102 |
+
The positional encoding temperature in the encoder.
|
| 103 |
+
num_feature_levels (`int`, *optional*, defaults to 3):
|
| 104 |
+
The number of feature levels for the multi-scale deformable attention module of the decoder.
|
| 105 |
+
decoder_hidden_dim (`int`, *optional*, defaults to 256):
|
| 106 |
+
The hidden dimension of the decoder.
|
| 107 |
+
decoder_num_heads (`int`, *optional*, defaults to 8):
|
| 108 |
+
The number of heads for the decoder.
|
| 109 |
+
decoder_num_layers (`int`, *optional*, defaults to 6):
|
| 110 |
+
The number of layers for the decoder.
|
| 111 |
+
decoder_activation (`str`, *optional*, defaults to `"relu"`):
|
| 112 |
+
The activation function for the decoder.
|
| 113 |
+
decoder_dim_feedforward (`int`, *optional*, defaults to 2048):
|
| 114 |
+
The feedforward dimension for the decoder.
|
| 115 |
+
decoder_num_points (`int`, *optional*, defaults to 4):
|
| 116 |
+
The number of points sampled in the decoder multi-scale deformable attention module.
|
| 117 |
+
decoder_dropout (`float`, *optional*, defaults to 0.0):
|
| 118 |
+
The dropout rate for the decoder.
|
| 119 |
+
eval_size (`Tuple[int, int]`, *optional*):
|
| 120 |
+
Height and width used to compute the effective height and width of the position embeddings after taking
|
| 121 |
+
into account the stride (see RTDetr).
|
| 122 |
+
learn_initial_query (`bool`, *optional*, defaults to `False`):
|
| 123 |
+
Whether to learn the initial query.
|
| 124 |
+
cache_size (`int`, *optional*, defaults to 100):
|
| 125 |
+
The cache size for the classes and prompts caches.
|
| 126 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
| 127 |
+
Whether the model is used as an encoder-decoder model or not.
|
| 128 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 129 |
+
Additional parameters from the architecture. The values in kwargs will be saved as part of the configuration
|
| 130 |
+
and can be used to control the model outputs.
|
| 131 |
+
|
| 132 |
+
Examples:
|
| 133 |
+
|
| 134 |
+
```python
|
| 135 |
+
>>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection
|
| 136 |
+
|
| 137 |
+
>>> # Initializing a OmDet-Turbo omlab/omdet-turbo-swin-tiny-hf style configuration
|
| 138 |
+
>>> configuration = OmDetTurboConfig()
|
| 139 |
+
|
| 140 |
+
>>> # Initializing a model (with random weights) from the omlab/omdet-turbo-swin-tiny-hf style configuration
|
| 141 |
+
>>> model = OmDetTurboForObjectDetection(configuration)
|
| 142 |
+
|
| 143 |
+
>>> # Accessing the model configuration
|
| 144 |
+
>>> configuration = model.config
|
| 145 |
+
```"""
|
| 146 |
+
|
| 147 |
+
model_type = "omdet-turbo"
|
| 148 |
+
attribute_map = {
|
| 149 |
+
"encoder_hidden_dim": "d_model",
|
| 150 |
+
"num_attention_heads": "encoder_attention_heads",
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
text_config=None,
|
| 156 |
+
backbone_config=None,
|
| 157 |
+
use_timm_backbone=True,
|
| 158 |
+
backbone="swin_tiny_patch4_window7_224",
|
| 159 |
+
backbone_kwargs=None,
|
| 160 |
+
use_pretrained_backbone=False,
|
| 161 |
+
apply_layernorm_after_vision_backbone=True,
|
| 162 |
+
image_size=640,
|
| 163 |
+
disable_custom_kernels=False,
|
| 164 |
+
layer_norm_eps=1e-5,
|
| 165 |
+
batch_norm_eps=1e-5,
|
| 166 |
+
init_std=0.02,
|
| 167 |
+
text_projection_in_dim=512,
|
| 168 |
+
text_projection_out_dim=512,
|
| 169 |
+
task_encoder_hidden_dim=1024,
|
| 170 |
+
class_embed_dim=512,
|
| 171 |
+
class_distance_type="cosine",
|
| 172 |
+
num_queries=900,
|
| 173 |
+
csp_activation="silu",
|
| 174 |
+
conv_norm_activation="gelu",
|
| 175 |
+
encoder_feedforward_activation="relu",
|
| 176 |
+
encoder_feedforward_dropout=0.0,
|
| 177 |
+
encoder_dropout=0.0,
|
| 178 |
+
hidden_expansion=1,
|
| 179 |
+
vision_features_channels=[256, 256, 256],
|
| 180 |
+
encoder_hidden_dim=256,
|
| 181 |
+
encoder_in_channels=[192, 384, 768],
|
| 182 |
+
encoder_projection_indices=[2],
|
| 183 |
+
encoder_attention_heads=8,
|
| 184 |
+
encoder_dim_feedforward=2048,
|
| 185 |
+
encoder_layers=1,
|
| 186 |
+
positional_encoding_temperature=10000,
|
| 187 |
+
num_feature_levels=3,
|
| 188 |
+
decoder_hidden_dim=256,
|
| 189 |
+
decoder_num_heads=8,
|
| 190 |
+
decoder_num_layers=6,
|
| 191 |
+
decoder_activation="relu",
|
| 192 |
+
decoder_dim_feedforward=2048,
|
| 193 |
+
decoder_num_points=4,
|
| 194 |
+
decoder_dropout=0.0,
|
| 195 |
+
eval_size=None,
|
| 196 |
+
learn_initial_query=False,
|
| 197 |
+
cache_size=100,
|
| 198 |
+
is_encoder_decoder=True,
|
| 199 |
+
**kwargs,
|
| 200 |
+
):
|
| 201 |
+
if use_timm_backbone:
|
| 202 |
+
if backbone_config is None:
|
| 203 |
+
backbone_kwargs = {
|
| 204 |
+
"out_indices": [1, 2, 3],
|
| 205 |
+
"img_size": image_size,
|
| 206 |
+
"always_partition": True,
|
| 207 |
+
}
|
| 208 |
+
elif backbone_config is None:
|
| 209 |
+
logger.info("`backbone_config` is `None`. Initializing the config with the default `swin` vision config.")
|
| 210 |
+
backbone_config = CONFIG_MAPPING["swin"](
|
| 211 |
+
window_size=7,
|
| 212 |
+
image_size=image_size,
|
| 213 |
+
embed_dim=96,
|
| 214 |
+
depths=[2, 2, 6, 2],
|
| 215 |
+
num_heads=[3, 6, 12, 24],
|
| 216 |
+
out_indices=[2, 3, 4],
|
| 217 |
+
)
|
| 218 |
+
elif isinstance(backbone_config, dict):
|
| 219 |
+
backbone_model_type = backbone_config.get("model_type")
|
| 220 |
+
config_class = CONFIG_MAPPING[backbone_model_type]
|
| 221 |
+
backbone_config = config_class.from_dict(backbone_config)
|
| 222 |
+
|
| 223 |
+
verify_backbone_config_arguments(
|
| 224 |
+
use_timm_backbone=use_timm_backbone,
|
| 225 |
+
use_pretrained_backbone=use_pretrained_backbone,
|
| 226 |
+
backbone=backbone,
|
| 227 |
+
backbone_config=backbone_config,
|
| 228 |
+
backbone_kwargs=backbone_kwargs,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
if text_config is None:
|
| 232 |
+
logger.info(
|
| 233 |
+
"`text_config` is `None`. Initializing the config with the default `clip_text_model` text config."
|
| 234 |
+
)
|
| 235 |
+
text_config = CONFIG_MAPPING["clip_text_model"]()
|
| 236 |
+
elif isinstance(text_config, dict):
|
| 237 |
+
text_model_type = text_config.get("model_type")
|
| 238 |
+
text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
| 239 |
+
|
| 240 |
+
if class_distance_type not in ["cosine", "dot"]:
|
| 241 |
+
raise ValueError(
|
| 242 |
+
f"Invalid `class_distance_type`. It should be either `cosine` or `dot`, but got {class_distance_type}."
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
self.text_config = text_config
|
| 246 |
+
self.backbone_config = backbone_config
|
| 247 |
+
self.use_timm_backbone = use_timm_backbone
|
| 248 |
+
self.backbone = backbone
|
| 249 |
+
self.backbone_kwargs = backbone_kwargs
|
| 250 |
+
self.use_pretrained_backbone = use_pretrained_backbone
|
| 251 |
+
self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone
|
| 252 |
+
self.image_size = image_size
|
| 253 |
+
self.disable_custom_kernels = disable_custom_kernels
|
| 254 |
+
self.layer_norm_eps = layer_norm_eps
|
| 255 |
+
self.batch_norm_eps = batch_norm_eps
|
| 256 |
+
self.init_std = init_std
|
| 257 |
+
self.text_projection_in_dim = text_projection_in_dim
|
| 258 |
+
self.text_projection_out_dim = text_projection_out_dim
|
| 259 |
+
self.task_encoder_hidden_dim = task_encoder_hidden_dim
|
| 260 |
+
self.class_embed_dim = class_embed_dim
|
| 261 |
+
self.class_distance_type = class_distance_type
|
| 262 |
+
self.num_queries = num_queries
|
| 263 |
+
self.csp_activation = csp_activation
|
| 264 |
+
self.conv_norm_activation = conv_norm_activation
|
| 265 |
+
self.encoder_feedforward_activation = encoder_feedforward_activation
|
| 266 |
+
self.encoder_feedforward_dropout = encoder_feedforward_dropout
|
| 267 |
+
self.encoder_dropout = encoder_dropout
|
| 268 |
+
self.hidden_expansion = hidden_expansion
|
| 269 |
+
self.vision_features_channels = vision_features_channels
|
| 270 |
+
self.encoder_hidden_dim = encoder_hidden_dim
|
| 271 |
+
self.encoder_in_channels = encoder_in_channels
|
| 272 |
+
self.encoder_projection_indices = encoder_projection_indices
|
| 273 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 274 |
+
self.encoder_dim_feedforward = encoder_dim_feedforward
|
| 275 |
+
self.encoder_layers = encoder_layers
|
| 276 |
+
self.positional_encoding_temperature = positional_encoding_temperature
|
| 277 |
+
self.num_feature_levels = num_feature_levels
|
| 278 |
+
self.decoder_hidden_dim = decoder_hidden_dim
|
| 279 |
+
self.decoder_num_heads = decoder_num_heads
|
| 280 |
+
self.decoder_num_layers = decoder_num_layers
|
| 281 |
+
self.decoder_activation = decoder_activation
|
| 282 |
+
self.decoder_dim_feedforward = decoder_dim_feedforward
|
| 283 |
+
self.decoder_num_points = decoder_num_points
|
| 284 |
+
self.decoder_dropout = decoder_dropout
|
| 285 |
+
self.eval_size = eval_size
|
| 286 |
+
self.learn_initial_query = learn_initial_query
|
| 287 |
+
self.cache_size = cache_size
|
| 288 |
+
self.is_encoder_decoder = is_encoder_decoder
|
| 289 |
+
|
| 290 |
+
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
__all__ = ["OmDetTurboConfig"]
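As referenced in the `class_distance_type` argument documentation above, here is a short illustration (not part of the configuration file; shapes are arbitrary assumptions) of the difference between the two supported distance types:

```python
import torch
import torch.nn.functional as F

predicted = torch.randn(4, 512)         # e.g. one class embedding per decoder query
class_embeddings = torch.randn(2, 512)  # e.g. one projected embedding per class prompt

dot_scores = predicted @ class_embeddings.T                                               # "dot"
cosine_scores = F.normalize(predicted, dim=-1) @ F.normalize(class_embeddings, dim=-1).T  # "cosine"
print(dot_scores.shape, cosine_scores.shape)  # both torch.Size([4, 2])
```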
|
docs/transformers/src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py
ADDED
|
@@ -0,0 +1,349 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert OmDet-Turbo checkpoints from the original repository.
|
| 16 |
+
|
| 17 |
+
URL: https://github.com/om-ai-lab/OmDet"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
|
| 21 |
+
import requests
|
| 22 |
+
import torch
|
| 23 |
+
from PIL import Image
|
| 24 |
+
|
| 25 |
+
from transformers import (
|
| 26 |
+
CLIPTokenizer,
|
| 27 |
+
DetrImageProcessor,
|
| 28 |
+
OmDetTurboConfig,
|
| 29 |
+
OmDetTurboForObjectDetection,
|
| 30 |
+
OmDetTurboProcessor,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
IMAGE_MEAN = [123.675, 116.28, 103.53]
|
| 35 |
+
IMAGE_STD = [58.395, 57.12, 57.375]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_omdet_turbo_config(model_name, use_timm_backbone):
|
| 39 |
+
if "tiny" in model_name:
|
| 40 |
+
window_size = 7
|
| 41 |
+
embed_dim = 96
|
| 42 |
+
depths = (2, 2, 6, 2)
|
| 43 |
+
num_heads = (3, 6, 12, 24)
|
| 44 |
+
image_size = 640
|
| 45 |
+
else:
|
| 46 |
+
raise ValueError("Model not supported, only supports tiny variant.")
|
| 47 |
+
|
| 48 |
+
config = OmDetTurboConfig(
|
| 49 |
+
backbone_window_size=window_size,
|
| 50 |
+
backbone_image_size=image_size,
|
| 51 |
+
backbone_embed_dim=embed_dim,
|
| 52 |
+
backbone_depths=depths,
|
| 53 |
+
backbone_num_heads=num_heads,
|
| 54 |
+
backbone_out_indices=(1, 2, 3),
|
| 55 |
+
text_config={"model_type": "clip_text_model"},
|
| 56 |
+
use_timm_backbone=use_timm_backbone,
|
| 57 |
+
backbone="swin_tiny_patch4_window7_224" if use_timm_backbone else None,
|
| 58 |
+
apply_layernorm_after_vision_backbone=True if use_timm_backbone else False,
|
| 59 |
+
use_pretrained_backbone=False,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
return config
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def create_rename_keys_vision(state_dict, config):
|
| 66 |
+
rename_keys = []
|
| 67 |
+
# fmt: off
|
| 68 |
+
########################################## VISION BACKBONE - START
|
| 69 |
+
for layer_name in state_dict.keys():
|
| 70 |
+
if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"):
|
| 71 |
+
if config.use_timm_backbone:
|
| 72 |
+
layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone")
|
| 73 |
+
layer_name_replace = layer_name_replace.replace(".layers.", ".layers_")
|
| 74 |
+
if "downsample" in layer_name:
|
| 75 |
+
# get layer number
|
| 76 |
+
layer_num = int(layer_name.split(".")[2])
|
| 77 |
+
layer_name_replace = layer_name_replace.replace(f"{layer_num}.downsample", f"{layer_num+1}.downsample")
|
| 78 |
+
else:
|
| 79 |
+
layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone")
|
| 80 |
+
layer_name_replace = layer_name_replace.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
| 81 |
+
layer_name_replace = layer_name_replace.replace("patch_embed.norm", "embeddings.norm")
|
| 82 |
+
if layer_name.startswith("backbone.layers"):
|
| 83 |
+
layer_name_replace = layer_name_replace.replace("norm1", "layernorm_before")
|
| 84 |
+
layer_name_replace = layer_name_replace.replace("norm2", "layernorm_after")
|
| 85 |
+
layer_name_replace = layer_name_replace.replace("attn.proj", "attention.output.dense")
|
| 86 |
+
layer_name_replace = layer_name_replace.replace("mlp.fc1", "intermediate.dense")
|
| 87 |
+
layer_name_replace = layer_name_replace.replace("mlp.fc2", "output.dense")
|
| 88 |
+
layer_name_replace = layer_name_replace.replace(".layers.", ".encoder.layers.")
|
| 89 |
+
layer_name_replace = layer_name_replace.replace(".attn.", ".attention.self.")
|
| 90 |
+
elif layer_name.startswith("backbone.norm"):
|
| 91 |
+
layer_num = int(layer_name.split("norm")[1].split(".")[0])
|
| 92 |
+
if config.use_timm_backbone:
|
| 93 |
+
layer_name_replace = layer_name.replace("backbone", "vision_backbone")
|
| 94 |
+
layer_name_replace = layer_name_replace.replace(f"norm{layer_num}", f"layer_norms.{layer_num-1}")
|
| 95 |
+
else:
|
| 96 |
+
layer_name_replace = layer_name.replace(f"backbone.norm{layer_num}", f"vision_backbone.vision_backbone.hidden_states_norms.stage{layer_num+1}")
|
| 97 |
+
else:
|
| 98 |
+
continue
|
| 99 |
+
rename_keys.append((layer_name, layer_name_replace))
|
| 100 |
+
########################################## VISION BACKBONE - END
|
| 101 |
+
|
| 102 |
+
########################################## ENCODER - START
|
| 103 |
+
for layer_name, params in state_dict.items():
|
| 104 |
+
if "neck" in layer_name:
|
| 105 |
+
layer_name_replace = layer_name.replace("neck", "encoder")
|
| 106 |
+
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
| 107 |
+
if "fpn_blocks" in layer_name or "pan_blocks" in layer_name or "lateral_convs" in layer_name or "downsample_convs" in layer_name:
|
| 108 |
+
layer_name_replace = layer_name_replace.replace(".m.", ".bottlenecks.")
|
| 109 |
+
layer_name_replace = layer_name_replace.replace(".cv", ".conv")
|
| 110 |
+
layer_name_replace = layer_name_replace.replace(".bn", ".norm")
|
| 111 |
+
if "encoder_layer" in layer_name:
|
| 112 |
+
layer_name_replace = layer_name_replace.replace("encoder_layer", "encoder.0.layers.0")
|
| 113 |
+
layer_name_replace = layer_name_replace.replace(".linear", ".fc")
|
| 114 |
+
layer_name_replace = layer_name_replace.replace("norm1", "self_attn_layer_norm")
|
| 115 |
+
layer_name_replace = layer_name_replace.replace("norm2", "final_layer_norm")
|
| 116 |
+
rename_keys.append((layer_name, layer_name_replace))
|
| 117 |
+
########################################## ENCODER - END
|
| 118 |
+
|
| 119 |
+
########################################## DECODER - START
|
| 120 |
+
for layer_name, params in state_dict.items():
|
| 121 |
+
if layer_name.startswith("decoder"):
|
| 122 |
+
layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers")
|
| 123 |
+
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
| 124 |
+
layer_name_replace = layer_name_replace.replace("query_pos_head", "query_position_head")
|
| 125 |
+
layer_name_replace = layer_name_replace.replace("enc_bbox_head", "encoder_bbox_head")
|
| 126 |
+
layer_name_replace = layer_name_replace.replace("enc_output", "encoder_vision_features")
|
| 127 |
+
layer_name_replace = layer_name_replace.replace("dec_score_head", "decoder_class_head")
|
| 128 |
+
layer_name_replace = layer_name_replace.replace("dec_bbox_head", "decoder_bbox_head")
|
| 129 |
+
layer_name_replace = layer_name_replace.replace("enc_score_head", "encoder_class_head")
|
| 130 |
+
rename_keys.append((layer_name, layer_name_replace))
|
| 131 |
+
########################################## DECODER - END
|
| 132 |
+
# fmt: on
|
| 133 |
+
return rename_keys
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def create_rename_keys_language(state_dict):
|
| 137 |
+
rename_keys = []
|
| 138 |
+
# fmt: off
|
| 139 |
+
for layer_name in state_dict.keys():
|
| 140 |
+
if layer_name.startswith("language_backbone") and not layer_name.startswith("language_backbone.text_projection"):
|
| 141 |
+
layer_name_replace = layer_name.replace("language_backbone", "language_backbone.model.text_model")
|
| 142 |
+
layer_name_replace = layer_name_replace.replace("transformer.resblocks", "encoder.layers")
|
| 143 |
+
layer_name_replace = layer_name_replace.replace("token_embedding", "embeddings.token_embedding")
|
| 144 |
+
layer_name_replace = layer_name_replace.replace("positional_embedding", "embeddings.position_embedding.weight")
|
| 145 |
+
layer_name_replace = layer_name_replace.replace(".attn", ".self_attn")
|
| 146 |
+
layer_name_replace = layer_name_replace.replace(".mlp.c_fc", ".mlp.fc1")
|
| 147 |
+
layer_name_replace = layer_name_replace.replace(".mlp.c_proj", ".mlp.fc2")
|
| 148 |
+
layer_name_replace = layer_name_replace.replace("ln_final", "final_layer_norm")
|
| 149 |
+
layer_name_replace = layer_name_replace.replace(".ln_", ".layer_norm")
|
| 150 |
+
rename_keys.append((layer_name, layer_name_replace))
|
| 151 |
+
# fmt: on
|
| 152 |
+
return rename_keys
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def rename_key(dct, old, new):
|
| 156 |
+
val = dct.pop(old)
|
| 157 |
+
dct[new] = val
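A toy illustration (not part of the conversion script) of how one `(old, new)` pair produced by the non-timm vision branch above is applied; the helper is repeated here so the snippet runs standalone.

```python
def rename_key(dct, old, new):
    # same two-line helper as defined just above
    dct[new] = dct.pop(old)

state_dict = {"backbone.patch_embed.proj.weight": "dummy tensor"}
rename_key(
    state_dict,
    "backbone.patch_embed.proj.weight",
    "vision_backbone.vision_backbone.embeddings.patch_embeddings.projection.weight",
)
assert "vision_backbone.vision_backbone.embeddings.patch_embeddings.projection.weight" in state_dict
```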
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# we split up the matrix of each encoder layer into queries, keys and values
|
| 161 |
+
def read_in_q_k_v_vision(state_dict, config):
|
| 162 |
+
state_dict_keys = list(state_dict.keys())
|
| 163 |
+
for layer_name_vision in state_dict_keys:
|
| 164 |
+
if layer_name_vision.startswith("vision_backbone") and "qkv" in layer_name_vision:
|
| 165 |
+
layer_num = int(layer_name_vision.split(".")[4])
|
| 166 |
+
hidden_size = config.backbone_config.embed_dim * 2**layer_num
|
| 167 |
+
if "weight" in layer_name_vision:
|
| 168 |
+
in_proj_weight = state_dict.pop(layer_name_vision)
|
| 169 |
+
state_dict[layer_name_vision.replace("qkv.weight", "key.weight")] = in_proj_weight[:hidden_size, :]
|
| 170 |
+
state_dict[layer_name_vision.replace("qkv.weight", "query.weight")] = in_proj_weight[
|
| 171 |
+
hidden_size : hidden_size * 2, :
|
| 172 |
+
]
|
| 173 |
+
state_dict[layer_name_vision.replace("qkv.weight", "value.weight")] = in_proj_weight[-hidden_size:, :]
|
| 174 |
+
elif "bias" in layer_name_vision:
|
| 175 |
+
in_proj_bias = state_dict.pop(layer_name_vision)
|
| 176 |
+
state_dict[layer_name_vision.replace("qkv.bias", "key.bias")] = in_proj_bias[:hidden_size]
|
| 177 |
+
state_dict[layer_name_vision.replace("qkv.bias", "query.bias")] = in_proj_bias[
|
| 178 |
+
hidden_size : hidden_size * 2
|
| 179 |
+
]
|
| 180 |
+
state_dict[layer_name_vision.replace("qkv.bias", "value.bias")] = in_proj_bias[-hidden_size:]
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def read_in_q_k_v_text(state_dict, config):
|
| 184 |
+
state_dict_keys = list(state_dict.keys())
|
| 185 |
+
hidden_size = config.text_config.projection_dim
|
| 186 |
+
for layer_name_text in state_dict_keys:
|
| 187 |
+
if layer_name_text.startswith("language_backbone") and "in_proj" in layer_name_text:
|
| 188 |
+
if "weight" in layer_name_text:
|
| 189 |
+
in_proj_weight = state_dict.pop(layer_name_text)
|
| 190 |
+
state_dict[layer_name_text.replace("in_proj_weight", "q_proj.weight")] = in_proj_weight[
|
| 191 |
+
:hidden_size, :
|
| 192 |
+
]
|
| 193 |
+
state_dict[layer_name_text.replace("in_proj_weight", "k_proj.weight")] = in_proj_weight[
|
| 194 |
+
hidden_size : hidden_size * 2, :
|
| 195 |
+
]
|
| 196 |
+
state_dict[layer_name_text.replace("in_proj_weight", "v_proj.weight")] = in_proj_weight[
|
| 197 |
+
-hidden_size:, :
|
| 198 |
+
]
|
| 199 |
+
elif "bias" in layer_name_text:
|
| 200 |
+
in_proj_bias = state_dict.pop(layer_name_text)
|
| 201 |
+
state_dict[layer_name_text.replace("in_proj_bias", "q_proj.bias")] = in_proj_bias[:hidden_size]
|
| 202 |
+
state_dict[layer_name_text.replace("in_proj_bias", "k_proj.bias")] = in_proj_bias[
|
| 203 |
+
hidden_size : hidden_size * 2
|
| 204 |
+
]
|
| 205 |
+
state_dict[layer_name_text.replace("in_proj_bias", "v_proj.bias")] = in_proj_bias[-hidden_size:]
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def read_in_q_k_v_encoder(state_dict, config):
|
| 209 |
+
embed_dim = config.encoder_hidden_dim
|
| 210 |
+
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
| 211 |
+
in_proj_weight = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_weight")
|
| 212 |
+
in_proj_bias = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_bias")
|
| 213 |
+
# next, add query, keys and values (in that order) to the state dict
|
| 214 |
+
state_dict["encoder.encoder.0.layers.0.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
|
| 215 |
+
state_dict["encoder.encoder.0.layers.0.self_attn.query.bias"] = in_proj_bias[:embed_dim]
|
| 216 |
+
state_dict["encoder.encoder.0.layers.0.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
|
| 217 |
+
state_dict["encoder.encoder.0.layers.0.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
|
| 218 |
+
state_dict["encoder.encoder.0.layers.0.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
|
| 219 |
+
state_dict["encoder.encoder.0.layers.0.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def read_in_q_k_v_decoder(state_dict, config):
|
| 223 |
+
for layer_num in range(config.decoder_num_layers):
|
| 224 |
+
embed_dim = config.decoder_hidden_dim
|
| 225 |
+
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
| 226 |
+
in_proj_weight = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_weight")
|
| 227 |
+
in_proj_bias = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_bias")
|
| 228 |
+
# next, add query, keys and values (in that order) to the state dict
|
| 229 |
+
state_dict[f"decoder.layers.{layer_num}.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
|
| 230 |
+
state_dict[f"decoder.layers.{layer_num}.self_attn.query.bias"] = in_proj_bias[:embed_dim]
|
| 231 |
+
state_dict[f"decoder.layers.{layer_num}.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
|
| 232 |
+
state_dict[f"decoder.layers.{layer_num}.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
|
| 233 |
+
state_dict[f"decoder.layers.{layer_num}.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
|
| 234 |
+
state_dict[f"decoder.layers.{layer_num}.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def run_test(model, processor):
|
| 238 |
+
# We will verify our results on an image of cute cats
|
| 239 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 240 |
+
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
| 241 |
+
|
| 242 |
+
classes = ["cat", "remote"]
|
| 243 |
+
task = "Detect {}.".format(", ".join(classes))
|
| 244 |
+
inputs = processor(image, text=classes, task=task, return_tensors="pt")
|
| 245 |
+
|
| 246 |
+
# Running forward
|
| 247 |
+
with torch.no_grad():
|
| 248 |
+
outputs = model(**inputs)
|
| 249 |
+
|
| 250 |
+
predicted_slice = outputs[1][0, :3, :3]
|
| 251 |
+
print(predicted_slice)
|
| 252 |
+
expected_slice = torch.tensor([[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]])
|
| 253 |
+
|
| 254 |
+
assert torch.allclose(predicted_slice, expected_slice, atol=1e-4)
|
| 255 |
+
print("Looks ok!")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@torch.no_grad()
|
| 259 |
+
def convert_omdet_turbo_checkpoint(args):
|
| 260 |
+
model_name = args.model_name
|
| 261 |
+
pytorch_dump_folder_path = args.pytorch_dump_folder_path
|
| 262 |
+
push_to_hub = args.push_to_hub
|
| 263 |
+
use_timm_backbone = args.use_timm_backbone
|
| 264 |
+
|
| 265 |
+
checkpoint_mapping = {
|
| 266 |
+
"omdet-turbo-tiny": [
|
| 267 |
+
"https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/OmDet-Turbo_tiny_SWIN_T.pth",
|
| 268 |
+
"https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/ViT-B-16.pt",
|
| 269 |
+
],
|
| 270 |
+
}
|
| 271 |
+
# Define default OmDetTurbo configuration
|
| 272 |
+
config = get_omdet_turbo_config(model_name, use_timm_backbone)
|
| 273 |
+
|
| 274 |
+
# Load original checkpoint
|
| 275 |
+
checkpoint_url = checkpoint_mapping[model_name]
|
| 276 |
+
original_state_dict_vision = torch.hub.load_state_dict_from_url(checkpoint_url[0], map_location="cpu")["model"]
|
| 277 |
+
original_state_dict_vision = {k.replace("module.", ""): v for k, v in original_state_dict_vision.items()}
|
| 278 |
+
|
| 279 |
+
# Rename keys
|
| 280 |
+
new_state_dict = original_state_dict_vision.copy()
|
| 281 |
+
rename_keys_vision = create_rename_keys_vision(new_state_dict, config)
|
| 282 |
+
|
| 283 |
+
rename_keys_language = create_rename_keys_language(new_state_dict)
|
| 284 |
+
|
| 285 |
+
for src, dest in rename_keys_vision:
|
| 286 |
+
rename_key(new_state_dict, src, dest)
|
| 287 |
+
|
| 288 |
+
for src, dest in rename_keys_language:
|
| 289 |
+
rename_key(new_state_dict, src, dest)
|
| 290 |
+
|
| 291 |
+
if not use_timm_backbone:
|
| 292 |
+
read_in_q_k_v_vision(new_state_dict, config)
|
| 293 |
+
read_in_q_k_v_text(new_state_dict, config)
|
| 294 |
+
read_in_q_k_v_encoder(new_state_dict, config)
|
| 295 |
+
read_in_q_k_v_decoder(new_state_dict, config)
|
| 296 |
+
# add "model" prefix to all keys
|
| 297 |
+
new_state_dict = {f"model.{k}": v for k, v in new_state_dict.items()}
|
| 298 |
+
|
| 299 |
+
# Load HF model
|
| 300 |
+
model = OmDetTurboForObjectDetection(config)
|
| 301 |
+
model.eval()
|
| 302 |
+
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
|
| 303 |
+
print("Missing keys:", missing_keys)
|
| 304 |
+
print("Unexpected keys:", unexpected_keys)
|
| 305 |
+
|
| 306 |
+
image_processor = DetrImageProcessor(
|
| 307 |
+
size={"height": config.backbone_image_size, "width": config.backbone_image_size},
|
| 308 |
+
do_rescale=False,
|
| 309 |
+
image_mean=IMAGE_MEAN,
|
| 310 |
+
image_std=IMAGE_STD,
|
| 311 |
+
do_pad=False,
|
| 312 |
+
)
|
| 313 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
| 314 |
+
processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
| 315 |
+
|
| 316 |
+
# end-to-end consistency test
|
| 317 |
+
run_test(model, processor)
|
| 318 |
+
|
| 319 |
+
if pytorch_dump_folder_path is not None:
|
| 320 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 321 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 322 |
+
|
| 323 |
+
if push_to_hub:
|
| 324 |
+
model.push_to_hub(f"omlab/{model_name}")
|
| 325 |
+
processor.push_to_hub(f"omlab/{model_name}")
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if __name__ == "__main__":
|
| 329 |
+
parser = argparse.ArgumentParser()
|
| 330 |
+
# Required parameters
|
| 331 |
+
parser.add_argument(
|
| 332 |
+
"--model_name",
|
| 333 |
+
default="omdet-turbo-tiny",
|
| 334 |
+
type=str,
|
| 335 |
+
choices=["omdet-turbo-tiny"],
|
| 336 |
+
help="Name of the OmDetTurbo model you'd like to convert.",
|
| 337 |
+
)
|
| 338 |
+
parser.add_argument(
|
| 339 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
| 343 |
+
)
|
| 344 |
+
parser.add_argument(
|
| 345 |
+
"--use_timm_backbone", action="store_true", help="Whether or not to use timm backbone for vision backbone."
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
args = parser.parse_args()
|
| 349 |
+
convert_omdet_turbo_checkpoint(args)
|
docs/transformers/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
ADDED
|
@@ -0,0 +1,1711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2024 Om Research Lab and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OmDet-Turbo model."""

import math
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ...activations import ACT2CLS, ACT2FN
from ...file_utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils.backbone_utils import load_backbone
from ..auto import AutoModel
from .configuration_omdet_turbo import OmDetTurboConfig


logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OmDetTurboConfig"


@dataclass
class OmDetTurboEncoderOutput(ModelOutput):
    """
    Base class for outputs of the OmDetTurboHybridEncoder.

    Args:
        last_hidden_state (`torch.FloatTensor`):
            Last hidden states of the encoder.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        extracted_states (`Tuple[torch.FloatTensor]`):
            The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    extracted_states: Tuple[torch.FloatTensor] = None


@dataclass
class OmDetTurboDecoderOutput(ModelOutput):
    """
    Base class for outputs of the OmDetTurboDecoder.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder.
        decoder_coords (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            The predicted coordinates of the objects.
        decoder_classes (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`):
            The predicted classes of the objects.
        encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            The predicted coordinates of the objects from the encoder.
        encoder_class_logits (`Tuple[torch.FloatTensor]`) of shape `(batch_size, num_queries, num_classes)`:
            The predicted class of the objects from the encoder.
        init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            The initial reference points.
        intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`):
            The intermediate reference points.
        hidden_states (`Optional[Tuple[torch.FloatTensor]]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
            weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_coords: Optional[torch.FloatTensor] = None
    decoder_classes: Optional[torch.FloatTensor] = None
    encoder_coord_logits: Optional[torch.FloatTensor] = None
    encoder_class_logits: Tuple[torch.FloatTensor] = None
    init_reference_points: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Tuple[Tuple[torch.FloatTensor]] = None


@dataclass
class OmDetTurboObjectDetectionOutput(ModelOutput):
    """
    Output type of [`OmDetTurboObjectDetectionOutput`].

    Args:
        loss (`torch.FloatTensor`):
            The loss value.
        decoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            The predicted coordinates logits of the objects.
        decoder_class_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`):
            The predicted class of the objects.
        init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            The initial reference points.
        intermediate_reference_points (`Tuple[Tuple[torch.FloatTensor]]`):
            The intermediate reference points.
        encoder_coord_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            The predicted coordinates of the objects from the encoder.
        encoder_class_logits (`Tuple[torch.FloatTensor]`):
            The predicted class of the objects from the encoder.
        encoder_extracted_states (`torch.FloatTensor`):
            The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
        decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
            Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
            weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
        encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
            Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
            weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
        classes_structure (`torch.LongTensor`, *optional*):
            The number of queried classes for each image.
    """

    loss: Optional[torch.FloatTensor] = None
    decoder_coord_logits: Optional[torch.FloatTensor] = None
    decoder_class_logits: Optional[torch.FloatTensor] = None
    init_reference_points: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    encoder_coord_logits: Optional[torch.FloatTensor] = None
    encoder_class_logits: Tuple[torch.FloatTensor] = None
    encoder_extracted_states: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    classes_structure: Optional[torch.LongTensor] = None


@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttention
class MultiScaleDeformableAttention(nn.Module):
    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: List[Tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
        value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for level_id, (height, width) in enumerate(value_spatial_shapes_list):
            # batch_size, height*width, num_heads, hidden_dim
            # -> batch_size, height*width, num_heads*hidden_dim
            # -> batch_size, num_heads*hidden_dim, height*width
            # -> batch_size*num_heads, hidden_dim, height, width
            value_l_ = (
                value_list[level_id]
                .flatten(2)
                .transpose(1, 2)
                .reshape(batch_size * num_heads, hidden_dim, height, width)
            )
            # batch_size, num_queries, num_heads, num_points, 2
            # -> batch_size, num_heads, num_queries, num_points, 2
            # -> batch_size*num_heads, num_queries, num_points, 2
            sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
            # batch_size*num_heads, hidden_dim, num_queries, num_points
            sampling_value_l_ = nn.functional.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)
        # (batch_size, num_queries, num_heads, num_levels, num_points)
        # -> (batch_size, num_heads, num_queries, num_levels, num_points)
        # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
        attention_weights = attention_weights.transpose(1, 2).reshape(
            batch_size * num_heads, 1, num_queries, num_levels * num_points
        )
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(batch_size, num_heads * hidden_dim, num_queries)
        )
        return output.transpose(1, 2).contiguous()
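
# Illustrative note (not part of the original file): the pure-PyTorch deformable attention above
# samples `num_points` bilinear locations per head and per feature level with `grid_sample`, then
# averages them with the softmaxed attention weights. A minimal shape sketch with made-up sizes:
#
#     attn = MultiScaleDeformableAttention()
#     value = torch.rand(1, 8 * 8 + 4 * 4, 2, 16)           # (batch, sum(H*W), heads, head_dim)
#     shapes_list = [(8, 8), (4, 4)]
#     shapes = torch.tensor(shapes_list)
#     level_start_index = torch.tensor([0, 64])
#     sampling_locations = torch.rand(1, 10, 2, 2, 4, 2)     # (batch, queries, heads, levels, points, xy)
#     attention_weights = torch.softmax(torch.rand(1, 10, 2, 2, 4).flatten(-2), -1).view(1, 10, 2, 2, 4)
#     out = attn(value, shapes, shapes_list, level_start_index, sampling_locations, attention_weights, 64)
#     # out.shape == (1, 10, 32), i.e. (batch, queries, heads * head_dim)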


class OmDetTurboLRUCache:
    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.current_load = 0

    def has(self, key) -> bool:
        return key in self.cache

    def get(self, key):
        """
        Get the value of the key if the key exists in the cache, otherwise return None.
        Move the key to the end of the cache to show that it was recently used.
        """
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value) -> None:
        """
        Add the key-value pair to the cache.
        Move the key to the end of the cache to show that it was recently used.
        If the cache is full, remove the first key (least recently used).
        """
        if key not in self.cache:
            self.current_load += 1
            if self.current_load > self.capacity:
                self.cache.popitem(last=False)
                self.current_load -= 1

        self.cache[key] = value
        self.cache.move_to_end(key)
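
# Usage sketch (illustrative, not part of the original file): the cache above keys class/task
# prompts by their token ids and evicts the least recently used entry once `capacity` is exceeded.
#
#     cache = OmDetTurboLRUCache(capacity=2)
#     cache.put(("a",), 1)
#     cache.put(("b",), 2)
#     cache.get(("a",))        # touching "a" makes "b" the least recently used entry
#     cache.put(("c",), 3)     # evicts "b"
#     assert cache.has(("a",)) and cache.has(("c",)) and not cache.has(("b",))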


class OmDetTurboLanguageBackbone(nn.Module):
    def __init__(self, config: OmDetTurboConfig):
        super().__init__()
        self.model = AutoModel.from_config(config.text_config)
        self.text_projection = nn.Parameter(torch.zeros(config.text_projection_in_dim, config.text_projection_out_dim))

    def forward(self, hidden_states, mask=None, encode_type="task"):
        text_outputs = self.model(hidden_states)
        pooled_output = text_outputs[0]
        if encode_type == "task":
            if mask is None:
                raise ValueError("mask is required for task encoding")
            max_len = (mask != 0).sum(1).max().item()
            truncated_mask = mask[:, :max_len]
            truncated_output = pooled_output[:, :max_len, :]
            return truncated_output.transpose(0, 1), truncated_mask
        elif encode_type == "class":
            max_pooled_output = pooled_output[torch.arange(pooled_output.shape[0]), hidden_states.argmax(dim=-1)]
            projected_output = max_pooled_output @ self.text_projection
            return projected_output
        else:
            raise ValueError(f"encode_type {encode_type} is not supported")


class OmDetTurboVisionBackbone(nn.Module):
    def __init__(self, config: OmDetTurboConfig):
        super().__init__()
        self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone
        self.vision_backbone = load_backbone(config)
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels]
        )

    def forward(self, pixel_values):
        outputs = self.vision_backbone(pixel_values).feature_maps
        if self.apply_layernorm_after_vision_backbone:
            outputs = [
                layer_norm(output).permute(0, 3, 1, 2).contiguous()
                for layer_norm, output in zip(self.layer_norms, outputs)
            ]

        return outputs


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OmDetTurbo, Deformable DETR->OmDet-Turbo
class OmDetTurboMultiscaleDeformableAttention(nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.
    """

    def __init__(self, config: OmDetTurboConfig, num_heads: int, n_points: int):
        super().__init__()

        self.attn = MultiScaleDeformableAttention()

        if config.d_model % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
            )
        dim_per_head = config.d_model // num_heads
        # check if dim_per_head is power of 2
        if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
            warnings.warn(
                "You'd better set embed_dim (d_model) in OmDetTurboMultiscaleDeformableAttention to make the"
                " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
                " implementation."
            )

        self.im2col_step = 64

        self.d_model = config.d_model
        self.n_levels = config.num_feature_levels
        self.n_heads = num_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
        self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
        self.value_proj = nn.Linear(config.d_model, config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.d_model)

        self.disable_custom_kernels = config.disable_custom_kernels

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        # Ignore copy
        total_elements = sum([shape[0] * shape[1] for shape in spatial_shapes_list])
        if total_elements != sequence_length:
            raise ValueError(
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
            )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # we invert the attention_mask
            value = value.masked_fill(~attention_mask[..., None], float(0))
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )
        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
        )
        # batch_size, num_queries, n_heads, n_levels, n_points, 2
        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif num_coordinates == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        output = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)

        return output, attention_weights


# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrConvNormLayer with RTDetr->OmDetTurbo
class OmDetTurboConvNormLayer(nn.Module):
    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        hidden_state = self.norm(hidden_state)
        hidden_state = self.activation(hidden_state)
        return hidden_state


# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrRepVggBlock with RTDetr->OmDetTurbo, activation_function->csp_activation
class OmDetTurboRepVggBlock(nn.Module):
    """
    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
    """

    def __init__(self, config: OmDetTurboConfig):
        super().__init__()

        activation = config.csp_activation
        hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
        self.conv1 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
        self.conv2 = OmDetTurboConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
        self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, x):
        y = self.conv1(x) + self.conv2(x)
        return self.activation(y)


# Copied from transformers.models.rt_detr.modeling_rt_detr.RTDetrCSPRepLayer with RTDetr->OmDetTurbo, activation_function->csp_activation
class OmDetTurboCSPRepLayer(nn.Module):
    """
    Cross Stage Partial (CSP) network layer with RepVGG blocks.
    """

    def __init__(self, config: OmDetTurboConfig):
        super().__init__()

        in_channels = config.encoder_hidden_dim * 2
        out_channels = config.encoder_hidden_dim
        num_blocks = 3
        activation = config.csp_activation

        hidden_channels = int(out_channels * config.hidden_expansion)
        self.conv1 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.conv2 = OmDetTurboConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
        self.bottlenecks = nn.Sequential(*[OmDetTurboRepVggBlock(config) for _ in range(num_blocks)])
        if hidden_channels != out_channels:
            self.conv3 = OmDetTurboConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
        else:
            self.conv3 = nn.Identity()

    def forward(self, hidden_state):
        hidden_state_1 = self.conv1(hidden_state)
        hidden_state_1 = self.bottlenecks(hidden_state_1)
        hidden_state_2 = self.conv2(hidden_state)
        return self.conv3(hidden_state_1 + hidden_state_2)
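
# Illustrative note (not part of the original file): the CSP layer above takes the concatenated
# FPN/PAN input (2 * encoder_hidden_dim channels), projects it into two 1x1 branches, runs one
# branch through three RepVGG blocks, sums the branches, and only applies conv3 when the expanded
# width differs from encoder_hidden_dim. A minimal channel walk-through, assuming
# encoder_hidden_dim=256 and hidden_expansion=1.0 (so conv3 is an identity):
#
#     in_channels = 2 * 256              # concat([upsampled_high_level, low_level])
#     hidden_channels = int(256 * 1.0)   # width of both branches
#     out_channels = 256                 # matches the other encoder feature maps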


class OmDetTurboMultiheadAttention(nn.Module):
    """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`."""

    def __init__(self, config, hidden_size, num_attention_heads, dropout):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        query_layer = self.transpose_for_scores(self.query(queries))
        key_layer = self.transpose_for_scores(self.key(keys))
        value_layer = self.transpose_for_scores(self.value(values))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        context_layer = self.out_proj(context_layer)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
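
# Usage sketch (illustrative, not part of the original file): all tensors are batch-first, so this
# module can stand in for nn.MultiheadAttention(..., batch_first=True). Assuming made-up sizes
# (the `config` argument is unused by __init__, so None is enough for a standalone check):
#
#     attn = OmDetTurboMultiheadAttention(config=None, hidden_size=64, num_attention_heads=4, dropout=0.0)
#     x = torch.rand(2, 10, 64)                                  # (batch, sequence, hidden)
#     out = attn(queries=x, keys=x, values=x, output_attentions=True)
#     # out[0].shape == (2, 10, 64); out[1].shape == (2, 4, 10, 10) attention probabilities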


class OmDetTurboEncoderLayer(nn.Module):
    def __init__(self, config: OmDetTurboConfig):
        super().__init__()
        self.self_attn = OmDetTurboMultiheadAttention(
            config,
            hidden_size=config.encoder_hidden_dim,
            num_attention_heads=config.num_attention_heads,
            dropout=config.encoder_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.encoder_dropout)
        self.activation_fn = ACT2FN[config.encoder_feedforward_activation]
        self.encoder_feedforward_dropout = nn.Dropout(config.encoder_feedforward_dropout)
        self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_dim_feedforward)
        self.fc2 = nn.Linear(config.encoder_dim_feedforward, config.encoder_hidden_dim)
        self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps)

    @staticmethod
    def with_pos_embed(tensor, pos_embed):
        return tensor if pos_embed is None else tensor + pos_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
                values.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Object queries (also called content embeddings), to be added to the hidden states.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        query = key = self.with_pos_embed(hidden_states, position_embeddings)

        hidden_states = self.self_attn(
            queries=query,
            keys=key,
            values=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states, attentions = hidden_states if output_attentions else (hidden_states[0], None)
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.encoder_feedforward_dropout(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        if self.training:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        if output_attentions:
            return hidden_states, attentions

        return (hidden_states,)


class OmDetTurboEncoder(nn.Module):
    def __init__(self, config: OmDetTurboConfig):
        super().__init__()

        self.layers = nn.ModuleList([OmDetTurboEncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(
        self, src, src_mask=None, pos_embed=None, output_attentions: bool = False
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        hidden_states = src
        attention = () if output_attentions else None
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=src_mask,
                position_embeddings=pos_embed,
                output_attentions=output_attentions,
            )
            if output_attentions:
                attention = attention + (hidden_states[1],)
            hidden_states = hidden_states[0]

        return hidden_states, attention


class OmDetTurboHybridEncoder(nn.Module):
    """
    Encoder consisting of channel projection layers, a set of `OmDetTurboEncoder`, a top-down Feature Pyramid Network
    (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://arxiv.org/abs/2304.08069

    Args:
        config: OmDetTurboConfig
    """

    def __init__(self, config: OmDetTurboConfig):
        super().__init__()
        self.config = config
        self.in_channels = config.encoder_in_channels
        self.encoder_hidden_dim = config.encoder_hidden_dim
        self.encoder_projection_indices = config.encoder_projection_indices
        self.positional_encoding_temperature = config.positional_encoding_temperature
        self.eval_size = config.eval_size
        self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]

        self.channel_projection_layers = nn.ModuleList()
        for in_channel in self.in_channels:
            self.channel_projection_layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, self.encoder_hidden_dim, kernel_size=(1, 1), bias=False),
                    nn.BatchNorm2d(self.encoder_hidden_dim),
                )
            )

        # encoder transformer
        self.encoder = nn.ModuleList([OmDetTurboEncoder(config) for _ in range(len(self.encoder_projection_indices))])
        # top-down fpn
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(len(self.in_channels) - 1, 0, -1):
            self.lateral_convs.append(
                OmDetTurboConvNormLayer(
                    config,
                    in_channels=self.encoder_hidden_dim,
                    out_channels=self.encoder_hidden_dim,
                    kernel_size=1,
                    stride=1,
                    activation=config.conv_norm_activation,
                )
            )
            self.fpn_blocks.append(OmDetTurboCSPRepLayer(config))

        # bottom-up pan
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(len(self.in_channels) - 1):
            self.downsample_convs.append(
                OmDetTurboConvNormLayer(
                    config,
                    in_channels=self.encoder_hidden_dim,
                    out_channels=self.encoder_hidden_dim,
                    kernel_size=3,
                    stride=2,
                    activation=config.conv_norm_activation,
                )
            )
            self.pan_blocks.append(OmDetTurboCSPRepLayer(config))

    @staticmethod
    def build_2d_sincos_position_embedding(
        width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
    ):
        grid_w = torch.arange(int(width), dtype=dtype, device=device)
        grid_h = torch.arange(int(height), dtype=dtype, device=device)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        if embed_dim % 4 != 0:
            raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
        omega = 1.0 / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

    def forward(
        self,
        inputs_embeddings=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layers) that is passed to the encoder.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = inputs_embeddings

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # get projection features
        projected_features = [self.channel_projection_layers[i](feature) for i, feature in enumerate(hidden_states)]
        # encoder
        for encoder_layer_index, feature_to_project_index in enumerate(self.encoder_projection_indices):
            if output_hidden_states:
                encoder_states = encoder_states + (projected_features[feature_to_project_index],)
            height, width = projected_features[feature_to_project_index].shape[2:]
            # flatten [batch, channel, height, width] to [batch, height*width, channel]
            src_flatten = projected_features[feature_to_project_index].flatten(2).permute(0, 2, 1)
            if self.training or self.eval_size is None:
                pos_embed = self.build_2d_sincos_position_embedding(
                    width,
                    height,
                    self.encoder_hidden_dim,
                    self.positional_encoding_temperature,
                    device=src_flatten.device,
                    dtype=src_flatten.dtype,
                ).to(src_flatten.device, src_flatten.dtype)
            else:
                pos_embed = None
            layer_outputs = self.encoder[encoder_layer_index](
                src_flatten,
                pos_embed=pos_embed,
                output_attentions=output_attentions,
            )
            projected_features[feature_to_project_index] = (
                layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous()
            )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            if output_hidden_states:
                encoder_states = encoder_states + (projected_features[feature_to_project_index],)

        # Feature Pyramid Network (FPN)
        fpn_feature_maps = [projected_features[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = fpn_feature_maps[0]
            feat_low = projected_features[idx - 1]
            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)
            fpn_feature_maps[0] = feat_high
            upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest")
            fps_map = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
            fpn_feature_maps.insert(0, fps_map)

        # Path Aggregation Network (PAN)
        fpn_states = [fpn_feature_maps[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = fpn_states[-1]
            feat_high = fpn_feature_maps[idx + 1]
            downsample_feat = self.downsample_convs[idx](feat_low)
            hidden_states = self.pan_blocks[idx](
                torch.concat([downsample_feat, feat_high.to(downsample_feat.device)], dim=1)
            )
            fpn_states.append(hidden_states)
        if not return_dict:
            return (fpn_states[-1], encoder_states, all_attentions, fpn_states)
        return OmDetTurboEncoderOutput(
            last_hidden_state=fpn_states[-1],
            hidden_states=encoder_states,
            attentions=all_attentions,
            extracted_states=fpn_states,
        )
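
# Illustrative note (not part of the original file): `build_2d_sincos_position_embedding` returns
# one row per flattened spatial position, with the channel dimension split into sin/cos of the
# width and height coordinates. A minimal sketch with made-up sizes:
#
#     pos = OmDetTurboHybridEncoder.build_2d_sincos_position_embedding(width=20, height=20, embed_dim=256)
#     # pos.shape == (1, 400, 256): 20 * 20 positions, 4 * (256 // 4) channels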


class OmDetTurboMLPWithDropout(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.class_embed_dim, config.task_encoder_hidden_dim)
        self.activation = ACT2FN[config.decoder_activation]
        self.dropout = nn.Dropout(config.decoder_dropout)
        self.linear2 = nn.Linear(config.task_encoder_hidden_dim, config.class_embed_dim)

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))


class OmDetTurboMLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_layers_dims = [hidden_dim] * (num_layers - 1)
        layers_dims = [input_dim] + hidden_layers_dims + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(layers_dims[:-1], layers_dims[1:])]
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class OmDetTurboResidualLayer(nn.Module):
    """
    A residual connection followed by a layer norm.
    """

    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.class_embed_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.decoder_dropout)

    def forward(self, x, y):
        return self.norm1(x + self.dropout(y))


class OmDetTurboTaskEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = OmDetTurboMLPWithDropout(config)
        self.res1 = OmDetTurboResidualLayer(config)

    def forward(self, x):
        mlp_out = self.mlp(x)
        x = self.res1(x, mlp_out)
        return x
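
# Usage sketch (illustrative, not part of the original file): OmDetTurboMLP chains `num_layers`
# linear layers with a ReLU between all but the last one. Assuming made-up dimensions:
#
#     mlp = OmDetTurboMLP(input_dim=256, hidden_dim=512, output_dim=4, num_layers=3)
#     # layer dims: 256 -> 512, 512 -> 512, 512 -> 4 (e.g. a box-regression style head)
#     boxes = mlp(torch.rand(2, 900, 256))   # boxes.shape == (2, 900, 4)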


class OmDetTurboDeformableTransformerDecoderLayer(nn.Module):
    """
    A single layer of the Deformable Transformer Decoder.
    """

    def __init__(self, config):
        super().__init__()
        # self attention
        self.self_attn = OmDetTurboMultiheadAttention(
            config,
            hidden_size=config.decoder_hidden_dim,
            num_attention_heads=config.decoder_num_heads,
            dropout=config.decoder_dropout,
        )
        self.dropout1 = nn.Dropout(config.decoder_dropout)
        self.norm1 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps)

        # cross attention
        self.cross_attn = OmDetTurboMultiscaleDeformableAttention(
            config, num_heads=config.decoder_num_heads, n_points=config.decoder_num_points
        )
        self.dropout2 = nn.Dropout(config.decoder_dropout)
        self.norm2 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps)

        # feed forward network
        self.linear1 = nn.Linear(config.decoder_hidden_dim, config.decoder_dim_feedforward)
        self.act = ACT2FN[config.decoder_activation]
        self.dropout3 = nn.Dropout(config.decoder_dropout)
        self.linear2 = nn.Linear(config.decoder_dim_feedforward, config.decoder_hidden_dim)
        self.dropout4 = nn.Dropout(config.decoder_dropout)
        self.norm3 = nn.LayerNorm(config.decoder_hidden_dim, eps=config.layer_norm_eps)

        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward(
        self,
        decoder_embeddings,
        task_features,
        reference_points,
        vision_features,
        vision_shapes,
        vision_shapes_list,
        level_start_index=None,
        attention_mask=None,
        padding_mask=None,
        query_position=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

        origin_embedding_len = decoder_embeddings.shape[1]

        # self attention
        query = key = self.with_pos_embed(decoder_embeddings, query_position)
        # combine task_features with query, key, value
        task_features = task_features.transpose(0, 1)
        query = torch.cat((query, task_features), dim=1)
        key = torch.cat((key, task_features), dim=1)
        decoder_embeddings = torch.cat((decoder_embeddings, task_features), dim=1)

        outputs = self.self_attn(
            query,
            key,
            decoder_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        context, self_attention = outputs if output_attentions else (outputs[0], None)
        decoder_embeddings = decoder_embeddings + self.dropout1(context)
        decoder_embeddings = self.norm1(decoder_embeddings)

        task_features = decoder_embeddings[:, origin_embedding_len:, :].transpose(0, 1)
        decoder_embeddings = decoder_embeddings[:, :origin_embedding_len, :]

        # cross attention
        hidden_states = self.with_pos_embed(decoder_embeddings, query_position)
        reference_points = reference_points.unsqueeze(2)
        outputs, cross_attention = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=padding_mask,
            encoder_hidden_states=vision_features,
            reference_points=reference_points,
            spatial_shapes=vision_shapes,
            spatial_shapes_list=vision_shapes_list,
            level_start_index=level_start_index,
        )
        decoder_embeddings = decoder_embeddings + self.dropout2(outputs)
        residual = self.norm2(decoder_embeddings)

        # feed forward network
        decoder_embeddings = self.linear2(self.dropout3(self.act(self.linear1(residual))))
        decoder_embeddings = residual + self.dropout4(decoder_embeddings)
        decoder_embeddings = self.norm3(decoder_embeddings)

        return (
            decoder_embeddings,
            task_features,
            self_attention if output_attentions else None,
            cross_attention if output_attentions else None,
        )


class OmDetTurboPreTrainedModel(PreTrainedModel):
    config_class = OmDetTurboConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        def linear_init_(module_to_init):
            bound = 1 / math.sqrt(module_to_init.weight.shape[0])
            nn.init.uniform_(module_to_init.weight, -bound, bound)
            if hasattr(module_to_init, "bias") and module_to_init.bias is not None:
                nn.init.uniform_(module_to_init.bias, -bound, bound)

        if isinstance(module, OmDetTurboEncoderLayer):
            linear_init_(module.fc1)
            linear_init_(module.fc2)
        elif isinstance(module, OmDetTurboDecoder):
            nn.init.constant_(module.encoder_bbox_head.layers[-1].weight, 0.0)
            nn.init.constant_(module.encoder_bbox_head.layers[-1].bias, 0.0)
            for mlp in module.decoder_bbox_head:
                nn.init.constant_(mlp.layers[-1].weight, 0.0)
                nn.init.constant_(mlp.layers[-1].bias, 0.0)
            linear_init_(module.encoder_vision_features[0])
            nn.init.xavier_uniform_(module.encoder_vision_features[0].weight)
            if module.learn_initial_query:
                nn.init.xavier_uniform_(module.tgt_embed.weight)
            nn.init.xavier_uniform_(module.query_position_head.layers[0].weight)
            nn.init.xavier_uniform_(module.query_position_head.layers[1].weight)
            for layer in module.channel_projection_layers:
                nn.init.xavier_uniform_(layer[0].weight)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, OmDetTurboDecoder):
            module.gradient_checkpointing = value

    @staticmethod
    def _get_cache_key_at_index(input_ids, attention_mask, index):
        input_ids = input_ids[index]
        input_mask = attention_mask[index]
        cache_key = tuple(input_ids[input_mask != 0].tolist())
        return cache_key

    def get_cached_class_embeddings(self, classes_input_ids, classes_attention_mask):
        not_cached_index = []
        not_cached_classes = []
        total_embeddings = []
        for idx, _ in enumerate(classes_input_ids):
            cache_key = self._get_cache_key_at_index(classes_input_ids, classes_attention_mask, idx)
            if self.language_cache_class.has(cache_key):
                total_embeddings.append(self.language_cache_class.get(cache_key))
            else:
                total_embeddings.append(None)
                not_cached_index.append(idx)
                not_cached_classes.append(cache_key)

        if not_cached_classes:
            not_cached_classes_ids = torch.stack([classes_input_ids[idx] for idx in not_cached_index])
            embeddings = self.language_backbone(not_cached_classes_ids, encode_type="class")
            for idx, emb in enumerate(embeddings):
                idx_to_put = not_cached_index[idx]
                total_embeddings[idx_to_put] = emb
                self.language_cache_class.put(not_cached_classes[idx], emb)

        total_class_embs = torch.stack(total_embeddings).to(self.device)
        return total_class_embs

    def get_cached_task_embeddings(self, tasks_input_ids, tasks_attention_mask):
        not_cached_index = []
        not_cached_tasks = []
        total_task_features = []
        total_task_masks = []
        for idx, _ in enumerate(tasks_input_ids):
            cache_key = self._get_cache_key_at_index(tasks_input_ids, tasks_attention_mask, idx)
            if self.language_cache_prompt.has(cache_key):
                task_feature, task_mask = self.language_cache_prompt.get(cache_key)
                total_task_features.append(task_feature)
|
| 1074 |
+
total_task_masks.append(task_mask)
|
| 1075 |
+
else:
|
| 1076 |
+
total_task_features.append(None)
|
| 1077 |
+
total_task_masks.append(None)
|
| 1078 |
+
not_cached_index.append(idx)
|
| 1079 |
+
not_cached_tasks.append(cache_key)
|
| 1080 |
+
|
| 1081 |
+
if not_cached_tasks:
|
| 1082 |
+
not_cached_index_ids = torch.stack([tasks_input_ids[idx] for idx in not_cached_index])
|
| 1083 |
+
not_cached_mask = torch.stack([tasks_attention_mask[idx] for idx in not_cached_index])
|
| 1084 |
+
embeddings, masks = self.language_backbone(not_cached_index_ids, mask=not_cached_mask, encode_type="task")
|
| 1085 |
+
|
| 1086 |
+
for idx in range(embeddings.shape[1]):
|
| 1087 |
+
emb = embeddings[:, [idx], :]
|
| 1088 |
+
idx_to_put = not_cached_index[idx]
|
| 1089 |
+
cur_mask = torch.unsqueeze(masks[idx], dim=0).to(self.device)
|
| 1090 |
+
total_task_features[idx_to_put] = emb
|
| 1091 |
+
total_task_masks[idx_to_put] = cur_mask
|
| 1092 |
+
self.language_cache_prompt.put(not_cached_tasks[idx], (emb, cur_mask))
|
| 1093 |
+
|
| 1094 |
+
# pad before concat if needed
|
| 1095 |
+
max_len = max([task.shape[0] for task in total_task_features])
|
| 1096 |
+
for idx, task in enumerate(total_task_features):
|
| 1097 |
+
if task.shape[0] < max_len:
|
| 1098 |
+
pad_size = max_len - task.shape[0]
|
| 1099 |
+
total_task_features[idx] = F.pad(task, (0, 0, 0, 0, 0, pad_size))
|
| 1100 |
+
total_task_masks[idx] = F.pad(total_task_masks[idx], (0, pad_size))
|
| 1101 |
+
|
| 1102 |
+
total_task_features = torch.cat(total_task_features, dim=1).to(self.device)
|
| 1103 |
+
total_task_masks = torch.cat(total_task_masks, dim=0).to(self.device)
|
| 1104 |
+
|
| 1105 |
+
return total_task_features, total_task_masks
|
| 1106 |
+
|
| 1107 |
+
def get_language_embedding(
|
| 1108 |
+
self,
|
| 1109 |
+
classes_input_ids,
|
| 1110 |
+
classes_attention_mask,
|
| 1111 |
+
tasks_input_ids,
|
| 1112 |
+
tasks_attention_mask,
|
| 1113 |
+
classes_structure,
|
| 1114 |
+
):
|
| 1115 |
+
batched_classes_embeddings = self.get_cached_class_embeddings(classes_input_ids, classes_attention_mask)
|
| 1116 |
+
# regroup class embeddings using saved structure
|
| 1117 |
+
max_class_size = torch.max(classes_structure)
|
| 1118 |
+
class_embeddings_regrouped = []
|
| 1119 |
+
start = 0
|
| 1120 |
+
for size in classes_structure:
|
| 1121 |
+
pad_size = max_class_size - size
|
| 1122 |
+
class_embeddings_regrouped.append(
|
| 1123 |
+
F.pad(batched_classes_embeddings[start : start + size], (0, 0, 0, pad_size)).unsqueeze(1)
|
| 1124 |
+
)
|
| 1125 |
+
start += size
|
| 1126 |
+
class_embeddings = torch.cat(class_embeddings_regrouped, dim=1)
|
| 1127 |
+
|
| 1128 |
+
task_embeddings, task_mask = self.get_cached_task_embeddings(tasks_input_ids, tasks_attention_mask)
|
| 1129 |
+
|
| 1130 |
+
return class_embeddings, task_embeddings, task_mask
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
OMDET_TURBO_START_DOCSTRING = r"""
|
| 1134 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1135 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1136 |
+
etc.)
|
| 1137 |
+
|
| 1138 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 1139 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 1140 |
+
and behavior.
|
| 1141 |
+
|
| 1142 |
+
Parameters:
|
| 1143 |
+
config ([`OmDetTurboConfig`]):
|
| 1144 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1145 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 1146 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1147 |
+
"""
|
| 1148 |
+
|
| 1149 |
+
OMDET_TURBO_INPUTS_DOCSTRING = r"""
|
| 1150 |
+
Args:
|
| 1151 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1152 |
+
Pixel values. Padding will be ignored by default should you provide it.
|
| 1153 |
+
|
| 1154 |
+
Pixel values can be obtained using [`AutoImageProcessor`]. See [`DetrImageProcessor.__call__`] for
|
| 1155 |
+
details.
|
| 1156 |
+
|
| 1157 |
+
classes_input_ids (`torch.LongTensor` of shape `(total_classes (>= batch_size), sequence_length)`):
|
| 1158 |
+
Indices of input classes sequence tokens in the vocabulary of the language model.
|
| 1159 |
+
Several classes can be provided for each tasks, thus the tokenized classes are flattened
|
| 1160 |
+
and the structure of the classes is provided in the `classes_structure` argument.
|
| 1161 |
+
|
| 1162 |
+
Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for
|
| 1163 |
+
details.
|
| 1164 |
+
|
| 1165 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1166 |
+
|
| 1167 |
+
classes_attention_mask (`torch.BoolTensor` of shape `(total_classes (>= batch_size), num_classes, sequence_length)`):
|
| 1168 |
+
Attention mask for the classes. This is a binary mask that indicates which tokens should be attended to,
|
| 1169 |
+
and which should not.
|
| 1170 |
+
|
| 1171 |
+
tasks_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1172 |
+
Indices of input tasks sequence tokens in the vocabulary of the language model.
|
| 1173 |
+
|
| 1174 |
+
Indices can be obtained using [`OmDetTurboProcessor`]. See [`OmDetTurboProcessor.__call__`] for
|
| 1175 |
+
details.
|
| 1176 |
+
|
| 1177 |
+
[What are input IDs?](../glossary#input-ids)
|
| 1178 |
+
|
| 1179 |
+
tasks_attention_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
|
| 1180 |
+
Attention mask for the tasks. This is a binary mask that indicates which tokens should be attended to,
|
| 1181 |
+
and which should not.
|
| 1182 |
+
|
| 1183 |
+
classes_structure (torch.LongTensor of shape `(batch_size)`):
|
| 1184 |
+
Structure of the classes. This tensor indicates the number of classes for each task.
|
| 1185 |
+
|
| 1186 |
+
output_attentions (`bool`, *optional*):
|
| 1187 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1188 |
+
tensors for more detail.
|
| 1189 |
+
output_hidden_states (`bool`, *optional*):
|
| 1190 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1191 |
+
more detail.
|
| 1192 |
+
return_dict (`bool`, *optional*):
|
| 1193 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1194 |
+
|
| 1195 |
+
"""
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
def _cosine_similarity_scaled(a, b, logit_scale):
|
| 1199 |
+
a = a / a.norm(dim=2, keepdim=True).clamp_min(1e-12)
|
| 1200 |
+
b = b / b.norm(dim=1, keepdim=True).clamp_min(1e-12)
|
| 1201 |
+
logit_scale = logit_scale.exp()
|
| 1202 |
+
logits_per_image = logit_scale * torch.bmm(a, b)
|
| 1203 |
+
return logits_per_image
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def get_class_similarity(class_distance_type, cls_feature, class_proj):
|
| 1207 |
+
logit_scale = torch.tensor(1 / 0.07).log()
|
| 1208 |
+
if class_distance_type == "cosine":
|
| 1209 |
+
class_logits = _cosine_similarity_scaled(cls_feature, class_proj, logit_scale)
|
| 1210 |
+
elif class_distance_type == "dot":
|
| 1211 |
+
class_logits = torch.bmm(cls_feature, class_proj)
|
| 1212 |
+
else:
|
| 1213 |
+
raise Exception("Unknown class_distance_type {}".format(class_distance_type))
|
| 1214 |
+
return class_logits
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
def _inverse_sigmoid(x, eps=1e-5):
|
| 1218 |
+
x = x.clamp(min=0, max=1)
|
| 1219 |
+
x1 = x.clamp(min=eps)
|
| 1220 |
+
x2 = (1 - x).clamp(min=eps)
|
| 1221 |
+
return torch.log(x1 / x2)
|
| 1222 |
+
|
| 1223 |
+
|
| 1224 |
+
class OmDetTurboDecoder(OmDetTurboPreTrainedModel):
|
| 1225 |
+
def __init__(self, config: OmDetTurboConfig):
|
| 1226 |
+
self.config = config
|
| 1227 |
+
super().__init__(config)
|
| 1228 |
+
self.gradient_checkpointing = False
|
| 1229 |
+
|
| 1230 |
+
hidden_dim = config.decoder_hidden_dim
|
| 1231 |
+
self.num_queries = config.num_queries
|
| 1232 |
+
self.class_distance_type = config.class_distance_type
|
| 1233 |
+
self.learn_initial_query = config.learn_initial_query
|
| 1234 |
+
|
| 1235 |
+
# backbone feature projection
|
| 1236 |
+
self.channel_projection_layers = nn.ModuleList(
|
| 1237 |
+
nn.Sequential(nn.Conv2d(x, hidden_dim, 1, bias=False), nn.BatchNorm2d(hidden_dim))
|
| 1238 |
+
for x in config.vision_features_channels
|
| 1239 |
+
)
|
| 1240 |
+
self.task_encoder = OmDetTurboTaskEncoder(config)
|
| 1241 |
+
if config.class_embed_dim != hidden_dim:
|
| 1242 |
+
self.task_project = nn.Linear(config.class_embed_dim, hidden_dim)
|
| 1243 |
+
|
| 1244 |
+
# Transformer module
|
| 1245 |
+
self.layers = nn.ModuleList(
|
| 1246 |
+
[OmDetTurboDeformableTransformerDecoderLayer(config) for _ in range(config.decoder_num_layers)]
|
| 1247 |
+
)
|
| 1248 |
+
self.decoder_num_layers = config.decoder_num_layers
|
| 1249 |
+
# decoder embedding
|
| 1250 |
+
if self.learn_initial_query:
|
| 1251 |
+
self.tgt_embed = nn.Embedding(self.num_queries, hidden_dim)
|
| 1252 |
+
self.query_position_head = OmDetTurboMLP(
|
| 1253 |
+
input_dim=4, hidden_dim=2 * hidden_dim, output_dim=hidden_dim, num_layers=2
|
| 1254 |
+
)
|
| 1255 |
+
|
| 1256 |
+
# encoder head
|
| 1257 |
+
self.encoder_vision_features = nn.Sequential(
|
| 1258 |
+
nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim, eps=config.layer_norm_eps)
|
| 1259 |
+
)
|
| 1260 |
+
self.encoder_class_head = nn.Linear(config.class_embed_dim, hidden_dim)
|
| 1261 |
+
self.encoder_bbox_head = OmDetTurboMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=4, num_layers=3)
|
| 1262 |
+
|
| 1263 |
+
# decoder head
|
| 1264 |
+
self.decoder_class_head = nn.ModuleList(
|
| 1265 |
+
[nn.Linear(config.class_embed_dim, hidden_dim) for _ in range(config.decoder_num_layers)]
|
| 1266 |
+
)
|
| 1267 |
+
self.decoder_bbox_head = nn.ModuleList(
|
| 1268 |
+
[OmDetTurboMLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(config.decoder_num_layers)]
|
| 1269 |
+
)
|
| 1270 |
+
|
| 1271 |
+
# Initialize weights and apply final processing
|
| 1272 |
+
self.post_init()
|
| 1273 |
+
|
| 1274 |
+
@lru_cache(maxsize=32)
|
| 1275 |
+
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
|
| 1276 |
+
# We always generate anchors in float32 to preserve equivalence between
|
| 1277 |
+
# dynamic and static anchor inference
|
| 1278 |
+
# Ignore copy
|
| 1279 |
+
if spatial_shapes is None:
|
| 1280 |
+
raise ValueError("spatial_shapes must be provided")
|
| 1281 |
+
|
| 1282 |
+
anchors = []
|
| 1283 |
+
for level, (height, width) in enumerate(spatial_shapes):
|
| 1284 |
+
grid_y, grid_x = torch.meshgrid(
|
| 1285 |
+
torch.arange(end=height, dtype=dtype, device=device),
|
| 1286 |
+
torch.arange(end=width, dtype=dtype, device=device),
|
| 1287 |
+
indexing="ij",
|
| 1288 |
+
)
|
| 1289 |
+
grid_xy = torch.stack([grid_x, grid_y], -1)
|
| 1290 |
+
valid_wh = torch.tensor([width, height], dtype=dtype, device=device)
|
| 1291 |
+
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh
|
| 1292 |
+
wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**level)
|
| 1293 |
+
anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
|
| 1294 |
+
# define the valid range for anchor coordinates
|
| 1295 |
+
eps = 1e-2
|
| 1296 |
+
anchors = torch.concat(anchors, 1)
|
| 1297 |
+
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
|
| 1298 |
+
anchors = torch.log(anchors / (1 - anchors))
|
| 1299 |
+
anchors = torch.where(valid_mask, anchors, torch.inf)
|
| 1300 |
+
|
| 1301 |
+
return anchors, valid_mask
|
| 1302 |
+
|
| 1303 |
+
def _get_encoder_input(self, vision_features):
|
| 1304 |
+
# get projection features
|
| 1305 |
+
vision_features = [self.channel_projection_layers[i](feat) for i, feat in enumerate(vision_features)]
|
| 1306 |
+
# get encoder inputs
|
| 1307 |
+
new_vision_features = []
|
| 1308 |
+
new_vision_shapes_list = []
|
| 1309 |
+
for feat in vision_features:
|
| 1310 |
+
height, width = feat.shape[2:]
|
| 1311 |
+
# [batch_size, channels, height, width] -> [batch_size, height*width, channels]
|
| 1312 |
+
new_vision_features.append(feat.flatten(2).permute(0, 2, 1))
|
| 1313 |
+
# [num_feature_levels, 2]
|
| 1314 |
+
new_vision_shapes_list.append((height, width))
|
| 1315 |
+
|
| 1316 |
+
# [batch_size, height*width, channels]
|
| 1317 |
+
new_vision_features = torch.cat(new_vision_features, 1)
|
| 1318 |
+
new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64, device=vision_features[0].device)
|
| 1319 |
+
level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1]))
|
| 1320 |
+
|
| 1321 |
+
return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index
|
| 1322 |
+
|
| 1323 |
+
def _get_decoder_input(
|
| 1324 |
+
self, vision_features, vision_shapes, class_features, denoise_embeddings=None, denoise_bboxes=None
|
| 1325 |
+
):
|
| 1326 |
+
batch_size = len(vision_features)
|
| 1327 |
+
# prepare input for decoder
|
| 1328 |
+
anchors, valid_mask = self.generate_anchors(
|
| 1329 |
+
vision_shapes, device=vision_features.device, dtype=vision_features.dtype
|
| 1330 |
+
)
|
| 1331 |
+
predicted_class_features = self.encoder_vision_features(
|
| 1332 |
+
torch.where(
|
| 1333 |
+
valid_mask,
|
| 1334 |
+
vision_features,
|
| 1335 |
+
torch.tensor(0.0, dtype=vision_features.dtype, device=vision_features.device),
|
| 1336 |
+
)
|
| 1337 |
+
)
|
| 1338 |
+
|
| 1339 |
+
original_class_projected = self.encoder_class_head(class_features).permute(1, 2, 0)
|
| 1340 |
+
encoder_class_similarity = get_class_similarity(
|
| 1341 |
+
self.class_distance_type, predicted_class_features, original_class_projected
|
| 1342 |
+
)
|
| 1343 |
+
|
| 1344 |
+
# dynamic anchors + static content
|
| 1345 |
+
# (batch_size, height*width, 4)
|
| 1346 |
+
encoder_outputs_bboxes = self.encoder_bbox_head(predicted_class_features) + anchors
|
| 1347 |
+
|
| 1348 |
+
# query selection
|
| 1349 |
+
# (batch_size, num_queries)
|
| 1350 |
+
topk_ind = torch.topk(encoder_class_similarity.max(-1).values, self.num_queries, dim=1).indices.view(-1)
|
| 1351 |
+
# (batch_size, num_queries)
|
| 1352 |
+
batch_ind = (
|
| 1353 |
+
torch.arange(end=batch_size, dtype=topk_ind.dtype, device=topk_ind.device)
|
| 1354 |
+
.unsqueeze(-1)
|
| 1355 |
+
.repeat(1, self.num_queries)
|
| 1356 |
+
.view(-1)
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
reference_points = encoder_outputs_bboxes[batch_ind, topk_ind].view(batch_size, self.num_queries, -1)
|
| 1360 |
+
encoder_bboxes = reference_points.sigmoid()
|
| 1361 |
+
if denoise_bboxes is not None:
|
| 1362 |
+
reference_points = torch.cat([denoise_bboxes, reference_points], 1)
|
| 1363 |
+
if self.training:
|
| 1364 |
+
reference_points = reference_points.detach()
|
| 1365 |
+
encoder_class_similarity = encoder_class_similarity[batch_ind, topk_ind].view(batch_size, self.num_queries, -1)
|
| 1366 |
+
|
| 1367 |
+
if self.learn_initial_query:
|
| 1368 |
+
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
| 1369 |
+
else:
|
| 1370 |
+
embeddings = predicted_class_features[batch_ind, topk_ind].view(batch_size, self.num_queries, -1)
|
| 1371 |
+
if self.training:
|
| 1372 |
+
embeddings = embeddings.detach()
|
| 1373 |
+
if denoise_embeddings is not None:
|
| 1374 |
+
embeddings = torch.cat([denoise_embeddings, embeddings], 1)
|
| 1375 |
+
|
| 1376 |
+
return embeddings, reference_points, encoder_bboxes, encoder_class_similarity, anchors
|
| 1377 |
+
|
| 1378 |
+
def forward(
|
| 1379 |
+
self,
|
| 1380 |
+
vision_features,
|
| 1381 |
+
class_features,
|
| 1382 |
+
task_features,
|
| 1383 |
+
task_mask,
|
| 1384 |
+
output_attentions=None,
|
| 1385 |
+
output_hidden_states=None,
|
| 1386 |
+
return_dict=None,
|
| 1387 |
+
):
|
| 1388 |
+
"""
|
| 1389 |
+
Args:
|
| 1390 |
+
vision_features (`torch.FloatTensor`): The sequence of vision features. shape depends on the vision
|
| 1391 |
+
backbone.
|
| 1392 |
+
class_features (`torch.FloatTensor`): The sequence of class features of shape
|
| 1393 |
+
`(class_sequence_length, batch_size, class_embed_dim)`.
|
| 1394 |
+
task_features (`torch.FloatTensor`): The sequence of task features of shape
|
| 1395 |
+
`(task_sequence_length, batch_size, decoder_hidden_dim)`.
|
| 1396 |
+
task_mask (`torch.LongTensor`): The mask for the task features of shape `(batch_size, task_sequence_length)`.
|
| 1397 |
+
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention
|
| 1398 |
+
layers. See `attentions` under returned tensors for more detail.
|
| 1399 |
+
output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See
|
| 1400 |
+
`hidden_states` under returned tensors for more detail.
|
| 1401 |
+
return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain
|
| 1402 |
+
tuple.
|
| 1403 |
+
"""
|
| 1404 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1405 |
+
output_hidden_states = (
|
| 1406 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1407 |
+
)
|
| 1408 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1409 |
+
|
| 1410 |
+
vision_features, vision_shapes, vision_shapes_list, level_start_index = self._get_encoder_input(
|
| 1411 |
+
vision_features
|
| 1412 |
+
)
|
| 1413 |
+
|
| 1414 |
+
# todo add denoising for training
|
| 1415 |
+
denoise_embeddings, denoise_bboxes, key_padding_mask = None, None, None
|
| 1416 |
+
batch_size = task_mask.shape[0]
|
| 1417 |
+
|
| 1418 |
+
# compose attn_mask for vision_emb and task_emb fusion
|
| 1419 |
+
task_features = self.task_encoder(task_features)
|
| 1420 |
+
if self.task_project is not None:
|
| 1421 |
+
task_features = self.task_project(task_features)
|
| 1422 |
+
src_key_mask = (task_mask == 0).detach()
|
| 1423 |
+
attn_mask_len = self.num_queries
|
| 1424 |
+
fusion_size = attn_mask_len + task_features.shape[0]
|
| 1425 |
+
key_padding_mask = torch.zeros([batch_size, fusion_size], dtype=torch.bool).to(task_features.device)
|
| 1426 |
+
key_padding_mask[:, attn_mask_len:] = src_key_mask
|
| 1427 |
+
attention_mask = _prepare_4d_attention_mask(~key_padding_mask, dtype=vision_features.dtype)
|
| 1428 |
+
decoder_embeddings, reference_points, encoder_bboxes, encoder_class_similarity, init_reference_points = (
|
| 1429 |
+
self._get_decoder_input(
|
| 1430 |
+
vision_features, tuple(vision_shapes_list), class_features, denoise_embeddings, denoise_bboxes
|
| 1431 |
+
)
|
| 1432 |
+
)
|
| 1433 |
+
|
| 1434 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1435 |
+
all_attns = () if output_attentions else None
|
| 1436 |
+
all_self_attns = () if output_attentions else None
|
| 1437 |
+
all_cross_attns = () if output_attentions else None
|
| 1438 |
+
predicted_class_features = decoder_embeddings
|
| 1439 |
+
|
| 1440 |
+
if output_hidden_states:
|
| 1441 |
+
all_hidden_states = all_hidden_states + (predicted_class_features,)
|
| 1442 |
+
decoder_bboxes = []
|
| 1443 |
+
decoder_classes = []
|
| 1444 |
+
last_refined_bbox = None
|
| 1445 |
+
reference_points = reference_points.sigmoid()
|
| 1446 |
+
for i, layer in enumerate(self.layers):
|
| 1447 |
+
if self.gradient_checkpointing and self.training:
|
| 1448 |
+
predicted_class_features, task_features, self_attention, cross_attention = (
|
| 1449 |
+
self._gradient_checkpointing_func(
|
| 1450 |
+
layer.__call__,
|
| 1451 |
+
predicted_class_features,
|
| 1452 |
+
task_features,
|
| 1453 |
+
reference_points,
|
| 1454 |
+
vision_features,
|
| 1455 |
+
vision_shapes,
|
| 1456 |
+
vision_shapes_list,
|
| 1457 |
+
level_start_index=level_start_index,
|
| 1458 |
+
attention_mask=attention_mask,
|
| 1459 |
+
query_position=self.query_position_head(reference_points),
|
| 1460 |
+
output_attentions=output_attentions,
|
| 1461 |
+
output_hidden_states=output_hidden_states,
|
| 1462 |
+
)
|
| 1463 |
+
)
|
| 1464 |
+
else:
|
| 1465 |
+
predicted_class_features, task_features, self_attention, cross_attention = layer(
|
| 1466 |
+
predicted_class_features,
|
| 1467 |
+
task_features,
|
| 1468 |
+
reference_points,
|
| 1469 |
+
vision_features,
|
| 1470 |
+
vision_shapes,
|
| 1471 |
+
vision_shapes_list,
|
| 1472 |
+
level_start_index=level_start_index,
|
| 1473 |
+
attention_mask=attention_mask,
|
| 1474 |
+
query_position=self.query_position_head(reference_points),
|
| 1475 |
+
output_attentions=output_attentions,
|
| 1476 |
+
output_hidden_states=output_hidden_states,
|
| 1477 |
+
)
|
| 1478 |
+
if output_attentions:
|
| 1479 |
+
all_self_attns = all_self_attns + (self_attention,)
|
| 1480 |
+
all_cross_attns = all_cross_attns + (cross_attention,)
|
| 1481 |
+
if output_hidden_states:
|
| 1482 |
+
all_hidden_states = all_hidden_states + (predicted_class_features,)
|
| 1483 |
+
|
| 1484 |
+
refined_bbox = torch.sigmoid(
|
| 1485 |
+
self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(reference_points)
|
| 1486 |
+
)
|
| 1487 |
+
original_class_projected = self.decoder_class_head[i](class_features).permute(1, 2, 0)
|
| 1488 |
+
if self.training:
|
| 1489 |
+
decoder_classes.append(
|
| 1490 |
+
get_class_similarity(
|
| 1491 |
+
class_distance_type=self.class_distance_type,
|
| 1492 |
+
cls_feature=predicted_class_features,
|
| 1493 |
+
class_proj=original_class_projected,
|
| 1494 |
+
)
|
| 1495 |
+
)
|
| 1496 |
+
if i == 0:
|
| 1497 |
+
decoder_bboxes.append(refined_bbox)
|
| 1498 |
+
else:
|
| 1499 |
+
decoder_bboxes.append(
|
| 1500 |
+
torch.sigmoid(
|
| 1501 |
+
self.decoder_bbox_head[i](predicted_class_features) + _inverse_sigmoid(last_refined_bbox)
|
| 1502 |
+
)
|
| 1503 |
+
)
|
| 1504 |
+
elif i == self.decoder_num_layers - 1:
|
| 1505 |
+
decoder_classes.append(
|
| 1506 |
+
get_class_similarity(self.class_distance_type, predicted_class_features, original_class_projected)
|
| 1507 |
+
)
|
| 1508 |
+
decoder_bboxes.append(refined_bbox)
|
| 1509 |
+
break
|
| 1510 |
+
last_refined_bbox = refined_bbox
|
| 1511 |
+
reference_points = refined_bbox.detach() if self.training else refined_bbox
|
| 1512 |
+
if output_attentions:
|
| 1513 |
+
all_attns += (all_self_attns, all_cross_attns)
|
| 1514 |
+
|
| 1515 |
+
last_hidden_state = predicted_class_features
|
| 1516 |
+
decoder_bboxes = torch.stack(decoder_bboxes)
|
| 1517 |
+
decoder_classes = torch.stack(decoder_classes)
|
| 1518 |
+
|
| 1519 |
+
if not return_dict:
|
| 1520 |
+
return (
|
| 1521 |
+
last_hidden_state,
|
| 1522 |
+
all_hidden_states,
|
| 1523 |
+
all_attns,
|
| 1524 |
+
decoder_bboxes,
|
| 1525 |
+
decoder_classes,
|
| 1526 |
+
encoder_bboxes,
|
| 1527 |
+
encoder_class_similarity,
|
| 1528 |
+
init_reference_points,
|
| 1529 |
+
reference_points,
|
| 1530 |
+
)
|
| 1531 |
+
|
| 1532 |
+
return OmDetTurboDecoderOutput(
|
| 1533 |
+
last_hidden_state=last_hidden_state,
|
| 1534 |
+
hidden_states=all_hidden_states,
|
| 1535 |
+
attentions=all_attns,
|
| 1536 |
+
decoder_coords=decoder_bboxes,
|
| 1537 |
+
decoder_classes=decoder_classes,
|
| 1538 |
+
encoder_coord_logits=encoder_bboxes,
|
| 1539 |
+
encoder_class_logits=encoder_class_similarity,
|
| 1540 |
+
init_reference_points=init_reference_points,
|
| 1541 |
+
intermediate_reference_points=reference_points,
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
|
| 1545 |
+
@add_start_docstrings(
|
| 1546 |
+
"""
|
| 1547 |
+
OmDetTurbo Model (consisting of a vision and a text backbone, and encoder-decoder architecture) outputting
|
| 1548 |
+
bounding boxes and classes scores for tasks such as COCO detection.
|
| 1549 |
+
""",
|
| 1550 |
+
OMDET_TURBO_START_DOCSTRING,
|
| 1551 |
+
)
|
| 1552 |
+
class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
|
| 1553 |
+
def __init__(self, config: OmDetTurboConfig):
|
| 1554 |
+
super().__init__(config)
|
| 1555 |
+
self.vision_backbone = OmDetTurboVisionBackbone(config)
|
| 1556 |
+
self.language_backbone = OmDetTurboLanguageBackbone(config)
|
| 1557 |
+
self.encoder = OmDetTurboHybridEncoder(config)
|
| 1558 |
+
self.decoder = OmDetTurboDecoder(config)
|
| 1559 |
+
self.num_queries = config.num_queries
|
| 1560 |
+
|
| 1561 |
+
self.language_cache_class = OmDetTurboLRUCache(config.cache_size)
|
| 1562 |
+
self.language_cache_prompt = OmDetTurboLRUCache(config.cache_size)
|
| 1563 |
+
self.vocab_size = config.text_config.vocab_size
|
| 1564 |
+
self.post_init()
|
| 1565 |
+
|
| 1566 |
+
def get_input_embeddings(self):
|
| 1567 |
+
return self.language_backbone.model.get_input_embeddings()
|
| 1568 |
+
|
| 1569 |
+
def set_input_embeddings(self, value):
|
| 1570 |
+
self.language_backbone.model.set_input_embeddings(value)
|
| 1571 |
+
|
| 1572 |
+
def resize_token_embeddings(
|
| 1573 |
+
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
|
| 1574 |
+
) -> nn.Embedding:
|
| 1575 |
+
model_embeds = self.language_backbone.model.resize_token_embeddings(
|
| 1576 |
+
new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
|
| 1577 |
+
)
|
| 1578 |
+
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 1579 |
+
self.vocab_size = model_embeds.num_embeddings
|
| 1580 |
+
return model_embeds
|
| 1581 |
+
|
| 1582 |
+
@add_start_docstrings_to_model_forward(OMDET_TURBO_INPUTS_DOCSTRING)
|
| 1583 |
+
@replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
|
| 1584 |
+
def forward(
|
| 1585 |
+
self,
|
| 1586 |
+
pixel_values: torch.FloatTensor,
|
| 1587 |
+
classes_input_ids: torch.LongTensor,
|
| 1588 |
+
classes_attention_mask: torch.LongTensor,
|
| 1589 |
+
tasks_input_ids: torch.LongTensor,
|
| 1590 |
+
tasks_attention_mask: torch.LongTensor,
|
| 1591 |
+
classes_structure: torch.LongTensor,
|
| 1592 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1593 |
+
output_attentions: Optional[bool] = None,
|
| 1594 |
+
output_hidden_states: Optional[bool] = None,
|
| 1595 |
+
return_dict: Optional[bool] = None,
|
| 1596 |
+
) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
|
| 1597 |
+
r"""
|
| 1598 |
+
Returns:
|
| 1599 |
+
|
| 1600 |
+
Examples:
|
| 1601 |
+
|
| 1602 |
+
```python
|
| 1603 |
+
>>> import requests
|
| 1604 |
+
>>> from PIL import Image
|
| 1605 |
+
|
| 1606 |
+
>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
|
| 1607 |
+
|
| 1608 |
+
>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
|
| 1609 |
+
>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
|
| 1610 |
+
|
| 1611 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1612 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1613 |
+
>>> classes = ["cat", "remote"]
|
| 1614 |
+
>>> task = "Detect {}.".format(", ".join(classes))
|
| 1615 |
+
>>> inputs = processor(image, text=classes, task=task, return_tensors="pt")
|
| 1616 |
+
|
| 1617 |
+
>>> outputs = model(**inputs)
|
| 1618 |
+
|
| 1619 |
+
>>> # convert outputs (bounding boxes and class logits)
|
| 1620 |
+
>>> results = processor.post_process_grounded_object_detection(
|
| 1621 |
+
... outputs,
|
| 1622 |
+
... classes=classes,
|
| 1623 |
+
... target_sizes=[image.size[::-1]],
|
| 1624 |
+
... score_threshold=0.3,
|
| 1625 |
+
... nms_threshold=0.3,
|
| 1626 |
+
>>> )[0]
|
| 1627 |
+
>>> for score, class_name, box in zip(results["scores"], results["classes"], results["boxes"]):
|
| 1628 |
+
... box = [round(i, 1) for i in box.tolist()]
|
| 1629 |
+
... print(
|
| 1630 |
+
... f"Detected {class_name} with confidence "
|
| 1631 |
+
... f"{round(score.item(), 2)} at location {box}"
|
| 1632 |
+
... )
|
| 1633 |
+
Detected remote with confidence 0.76 at location [39.9, 71.3, 176.5, 117.9]
|
| 1634 |
+
Detected cat with confidence 0.72 at location [345.1, 22.5, 639.7, 371.9]
|
| 1635 |
+
Detected cat with confidence 0.65 at location [12.7, 53.8, 315.5, 475.3]
|
| 1636 |
+
Detected remote with confidence 0.57 at location [333.4, 75.6, 370.7, 187.0]
|
| 1637 |
+
```"""
|
| 1638 |
+
if labels is not None:
|
| 1639 |
+
raise NotImplementedError("Training is not implemented yet")
|
| 1640 |
+
|
| 1641 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1642 |
+
output_hidden_states = (
|
| 1643 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1644 |
+
)
|
| 1645 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1646 |
+
|
| 1647 |
+
loss = None
|
| 1648 |
+
image_features = self.vision_backbone(pixel_values)
|
| 1649 |
+
encoder_outputs = self.encoder(
|
| 1650 |
+
image_features,
|
| 1651 |
+
output_attentions=output_attentions,
|
| 1652 |
+
output_hidden_states=output_hidden_states,
|
| 1653 |
+
return_dict=return_dict,
|
| 1654 |
+
)
|
| 1655 |
+
class_features, task_features, task_mask = self.get_language_embedding(
|
| 1656 |
+
classes_input_ids,
|
| 1657 |
+
classes_attention_mask,
|
| 1658 |
+
tasks_input_ids,
|
| 1659 |
+
tasks_attention_mask,
|
| 1660 |
+
classes_structure,
|
| 1661 |
+
)
|
| 1662 |
+
encoder_extracted_states = encoder_outputs.extracted_states if return_dict else encoder_outputs[-1]
|
| 1663 |
+
decoder_outputs = self.decoder(
|
| 1664 |
+
encoder_extracted_states,
|
| 1665 |
+
class_features,
|
| 1666 |
+
task_features,
|
| 1667 |
+
task_mask,
|
| 1668 |
+
output_attentions=output_attentions,
|
| 1669 |
+
output_hidden_states=output_hidden_states,
|
| 1670 |
+
return_dict=return_dict,
|
| 1671 |
+
)
|
| 1672 |
+
|
| 1673 |
+
if not return_dict:
|
| 1674 |
+
return tuple(
|
| 1675 |
+
output
|
| 1676 |
+
for output in [
|
| 1677 |
+
loss,
|
| 1678 |
+
decoder_outputs[3][-1],
|
| 1679 |
+
decoder_outputs[4][-1],
|
| 1680 |
+
decoder_outputs[7],
|
| 1681 |
+
decoder_outputs[8],
|
| 1682 |
+
decoder_outputs[5],
|
| 1683 |
+
decoder_outputs[6],
|
| 1684 |
+
encoder_outputs[-1],
|
| 1685 |
+
decoder_outputs[1],
|
| 1686 |
+
decoder_outputs[2],
|
| 1687 |
+
encoder_outputs[1],
|
| 1688 |
+
encoder_outputs[2],
|
| 1689 |
+
classes_structure,
|
| 1690 |
+
]
|
| 1691 |
+
if output is not None
|
| 1692 |
+
)
|
| 1693 |
+
|
| 1694 |
+
return OmDetTurboObjectDetectionOutput(
|
| 1695 |
+
loss=loss,
|
| 1696 |
+
decoder_coord_logits=decoder_outputs.decoder_coords[-1],
|
| 1697 |
+
decoder_class_logits=decoder_outputs.decoder_classes[-1],
|
| 1698 |
+
init_reference_points=decoder_outputs.init_reference_points,
|
| 1699 |
+
intermediate_reference_points=decoder_outputs.intermediate_reference_points,
|
| 1700 |
+
encoder_coord_logits=decoder_outputs.encoder_coord_logits,
|
| 1701 |
+
encoder_class_logits=decoder_outputs.encoder_class_logits,
|
| 1702 |
+
encoder_extracted_states=encoder_outputs.extracted_states,
|
| 1703 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 1704 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 1705 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 1706 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 1707 |
+
classes_structure=classes_structure,
|
| 1708 |
+
)
|
| 1709 |
+
|
| 1710 |
+
|
| 1711 |
+
__all__ = ["OmDetTurboForObjectDetection", "OmDetTurboPreTrainedModel"]
|
docs/transformers/src/transformers/models/omdet_turbo/processing_omdet_turbo.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Processor class for OmDet-Turbo.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import warnings
|
| 20 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
from ...feature_extraction_utils import BatchFeature
|
| 23 |
+
from ...image_transforms import center_to_corners_format
|
| 24 |
+
from ...image_utils import ImageInput
|
| 25 |
+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
|
| 26 |
+
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
| 27 |
+
from ...utils import (
|
| 28 |
+
TensorType,
|
| 29 |
+
is_torch_available,
|
| 30 |
+
is_torchvision_available,
|
| 31 |
+
)
|
| 32 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 33 |
+
from ...utils.import_utils import requires
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if TYPE_CHECKING:
|
| 37 |
+
from .modeling_omdet_turbo import OmDetTurboObjectDetectionOutput
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class OmDetTurboTextKwargs(TextKwargs, total=False):
|
| 41 |
+
task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if is_torch_available():
|
| 45 |
+
import torch
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if is_torchvision_available():
|
| 49 |
+
from torchvision.ops.boxes import batched_nms
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False):
|
| 53 |
+
text_kwargs: OmDetTurboTextKwargs
|
| 54 |
+
_defaults = {
|
| 55 |
+
"text_kwargs": {
|
| 56 |
+
"add_special_tokens": True,
|
| 57 |
+
"padding": "max_length",
|
| 58 |
+
"truncation": True,
|
| 59 |
+
"max_length": 77,
|
| 60 |
+
"stride": 0,
|
| 61 |
+
"return_overflowing_tokens": False,
|
| 62 |
+
"return_special_tokens_mask": False,
|
| 63 |
+
"return_offsets_mapping": False,
|
| 64 |
+
"return_token_type_ids": False,
|
| 65 |
+
"return_length": False,
|
| 66 |
+
"verbose": True,
|
| 67 |
+
"task": None,
|
| 68 |
+
},
|
| 69 |
+
"images_kwargs": {},
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class DictWithDeprecationWarning(dict):
|
| 74 |
+
message = (
|
| 75 |
+
"The `classes` key is deprecated for `OmDetTurboProcessor.post_process_grounded_object_detection` "
|
| 76 |
+
"output dict and will be removed in a 4.51.0 version. Please use `text_labels` instead."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def __getitem__(self, key):
|
| 80 |
+
if key == "classes":
|
| 81 |
+
warnings.warn(self.message, FutureWarning)
|
| 82 |
+
return super().__getitem__("text_labels")
|
| 83 |
+
return super().__getitem__(key)
|
| 84 |
+
|
| 85 |
+
def get(self, key, *args, **kwargs):
|
| 86 |
+
if key == "classes":
|
| 87 |
+
warnings.warn(self.message, FutureWarning)
|
| 88 |
+
return super().get("text_labels", *args, **kwargs)
|
| 89 |
+
return super().get(key, *args, **kwargs)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def clip_boxes(box, box_size: Tuple[int, int]):
|
| 93 |
+
"""
|
| 94 |
+
Clip the boxes by limiting x coordinates to the range [0, width]
|
| 95 |
+
and y coordinates to the range [0, height].
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
box (Tensor): The box to be clipped.
|
| 99 |
+
box_size (height, width): The clipping box's size.
|
| 100 |
+
"""
|
| 101 |
+
assert torch.isfinite(box).all(), "Box tensor contains infinite or NaN!"
|
| 102 |
+
height, width = box_size
|
| 103 |
+
x1 = box[:, 0].clamp(min=0, max=width)
|
| 104 |
+
y1 = box[:, 1].clamp(min=0, max=height)
|
| 105 |
+
x2 = box[:, 2].clamp(min=0, max=width)
|
| 106 |
+
y2 = box[:, 3].clamp(min=0, max=height)
|
| 107 |
+
box = torch.stack((x1, y1, x2, y2), dim=-1)
|
| 108 |
+
|
| 109 |
+
return box
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def compute_score(boxes):
|
| 113 |
+
"""
|
| 114 |
+
Compute logit scores per class for each box (proposal) and an array of class indices
|
| 115 |
+
corresponding to each proposal, flattened across the proposal_num.
|
| 116 |
+
The indices in `classes` will later be used to filter and match the predicted classes
|
| 117 |
+
with the input class names.
|
| 118 |
+
"""
|
| 119 |
+
num_classes = boxes.shape[2]
|
| 120 |
+
proposal_num = boxes.shape[1]
|
| 121 |
+
scores = torch.sigmoid(boxes)
|
| 122 |
+
classes = torch.arange(num_classes, device=boxes.device).unsqueeze(0).repeat(proposal_num, 1).flatten(0, 1)
|
| 123 |
+
return scores, classes
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _post_process_boxes_for_image(
|
| 127 |
+
boxes: "torch.Tensor",
|
| 128 |
+
scores: "torch.Tensor",
|
| 129 |
+
labels: "torch.Tensor",
|
| 130 |
+
image_num_classes: int,
|
| 131 |
+
image_size: Tuple[int, int],
|
| 132 |
+
threshold: float,
|
| 133 |
+
nms_threshold: float,
|
| 134 |
+
max_num_det: Optional[int] = None,
|
| 135 |
+
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
| 136 |
+
"""
|
| 137 |
+
Filter predicted results using given thresholds and NMS.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
boxes (`torch.Tensor`):
|
| 141 |
+
A Tensor of predicted class-specific or class-agnostic boxes for the image.
|
| 142 |
+
Shape (num_queries, max_num_classes_in_batch * 4) if doing class-specific regression,
|
| 143 |
+
or (num_queries, 4) if doing class-agnostic regression.
|
| 144 |
+
scores (`torch.Tensor` of shape (num_queries, max_num_classes_in_batch + 1)):
|
| 145 |
+
A Tensor of predicted class scores for the image.
|
| 146 |
+
labels (`torch.Tensor` of shape (num_queries * (max_num_classes_in_batch + 1),)):
|
| 147 |
+
A Tensor of predicted labels for the image.
|
| 148 |
+
image_num_classes (`int`):
|
| 149 |
+
The number of classes queried for detection on the image.
|
| 150 |
+
image_size (`Tuple[int, int]`):
|
| 151 |
+
A tuple of (height, width) for the image.
|
| 152 |
+
threshold (`float`):
|
| 153 |
+
Only return detections with a confidence score exceeding this threshold.
|
| 154 |
+
nms_threshold (`float`):
|
| 155 |
+
The threshold to use for box non-maximum suppression. Value in [0, 1].
|
| 156 |
+
max_num_det (`int`, *optional*):
|
| 157 |
+
The maximum number of detections to return. Default is None.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
Tuple: A tuple with the following:
|
| 161 |
+
"boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format.
|
| 162 |
+
"scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection.
|
| 163 |
+
"labels" (Tensor): A tensor of ids, where each id is the predicted class id for the corresponding detection
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
# Filter by max number of detections
|
| 167 |
+
proposal_num = len(boxes) if max_num_det is None else max_num_det
|
| 168 |
+
scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False)
|
| 169 |
+
labels_per_image = labels[topk_indices]
|
| 170 |
+
boxes_per_image = boxes.view(-1, 1, 4).repeat(1, scores.shape[1], 1).view(-1, 4)
|
| 171 |
+
boxes_per_image = boxes_per_image[topk_indices]
|
| 172 |
+
|
| 173 |
+
# Convert and scale boxes to original image size
|
| 174 |
+
boxes_per_image = center_to_corners_format(boxes_per_image)
|
| 175 |
+
boxes_per_image = boxes_per_image * torch.tensor(image_size[::-1]).repeat(2).to(boxes_per_image.device)
|
| 176 |
+
|
| 177 |
+
# Filtering by confidence score
|
| 178 |
+
filter_mask = scores_per_image > threshold # R x K
|
| 179 |
+
score_keep = filter_mask.nonzero(as_tuple=False).view(-1)
|
| 180 |
+
boxes_per_image = boxes_per_image[score_keep]
|
| 181 |
+
scores_per_image = scores_per_image[score_keep]
|
| 182 |
+
labels_per_image = labels_per_image[score_keep]
|
| 183 |
+
|
| 184 |
+
# Ensure we did not overflow to non existing classes
|
| 185 |
+
filter_classes_mask = labels_per_image < image_num_classes
|
| 186 |
+
classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1)
|
| 187 |
+
boxes_per_image = boxes_per_image[classes_keep]
|
| 188 |
+
scores_per_image = scores_per_image[classes_keep]
|
| 189 |
+
labels_per_image = labels_per_image[classes_keep]
|
| 190 |
+
|
| 191 |
+
# NMS
|
| 192 |
+
keep = batched_nms(boxes_per_image, scores_per_image, labels_per_image, nms_threshold)
|
| 193 |
+
boxes_per_image = boxes_per_image[keep]
|
| 194 |
+
scores_per_image = scores_per_image[keep]
|
| 195 |
+
labels_per_image = labels_per_image[keep]
|
| 196 |
+
|
| 197 |
+
# Clip to image size
|
| 198 |
+
boxes_per_image = clip_boxes(boxes_per_image, image_size)
|
| 199 |
+
|
| 200 |
+
return boxes_per_image, scores_per_image, labels_per_image
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@requires(backends=("vision", "torchvision"))
|
| 204 |
+
class OmDetTurboProcessor(ProcessorMixin):
|
| 205 |
+
r"""
|
| 206 |
+
Constructs a OmDet-Turbo processor which wraps a Deformable DETR image processor and an AutoTokenizer into a
|
| 207 |
+
single processor.
|
| 208 |
+
|
| 209 |
+
[`OmDetTurboProcessor`] offers all the functionalities of [`DetrImageProcessor`] and
|
| 210 |
+
[`AutoTokenizer`]. See the docstring of [`~OmDetTurboProcessor.__call__`] and [`~OmDetTurboProcessor.decode`]
|
| 211 |
+
for more information.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
image_processor (`DetrImageProcessor`):
|
| 215 |
+
An instance of [`DetrImageProcessor`]. The image processor is a required input.
|
| 216 |
+
tokenizer (`AutoTokenizer`):
|
| 217 |
+
An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
attributes = ["image_processor", "tokenizer"]
|
| 221 |
+
image_processor_class = ("DetrImageProcessor", "DetrImageProcessorFast")
|
| 222 |
+
tokenizer_class = "AutoTokenizer"
|
| 223 |
+
|
| 224 |
+
def __init__(self, image_processor, tokenizer):
|
| 225 |
+
super().__init__(image_processor, tokenizer)
|
| 226 |
+
|
| 227 |
+
def __call__(
|
| 228 |
+
self,
|
| 229 |
+
images: ImageInput = None,
|
| 230 |
+
text: Union[List[str], List[List[str]]] = None,
|
| 231 |
+
audio=None,
|
| 232 |
+
videos=None,
|
| 233 |
+
**kwargs: Unpack[OmDetTurboProcessorKwargs],
|
| 234 |
+
) -> BatchFeature:
|
| 235 |
+
"""
|
| 236 |
+
This method uses [*DetrImageProcessor.__call__] method to prepare image(s) for the model, and
|
| 237 |
+
[CLIPTokenizerFast.__call__] to prepare text for the model.
|
| 238 |
+
|
| 239 |
+
Please refer to the docstring of the above two methods for more information.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
images (`ImageInput`):
|
| 243 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
|
| 244 |
+
text (`Union[str, List[str], List[List[str]]]`):
|
| 245 |
+
The classes used to limit the scope of the open vocabulary detection. Expects a list of strings or a list
|
| 246 |
+
of list of strings. Batched classes can be of different lengths.
|
| 247 |
+
Examples: ["cat", "dog", "bird"], [["cat", "dog", "bird"], ["hat", "person"], ["car"]]
|
| 248 |
+
Kwargs:
|
| 249 |
+
task (`Union[str, List[str], TextInput, PreTokenizedInput]`):
|
| 250 |
+
The grounded text used to guide open vocabulary detection. Expects a single string or a list of strings.
|
| 251 |
+
Examples: "Detect a cat, a dog, and a bird.",[ "Detect everything.", "Detect trees and flowers."]
|
| 252 |
+
When not provided, the default task is "Detect [class1], [class2], [class3]" etc.
|
| 253 |
+
...
|
| 254 |
+
"""
|
| 255 |
+
if images is None or text is None:
|
| 256 |
+
raise ValueError("You have to specify both `images` and `text`")
|
| 257 |
+
|
| 258 |
+
output_kwargs = self._merge_kwargs(
|
| 259 |
+
OmDetTurboProcessorKwargs,
|
| 260 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 261 |
+
**kwargs,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if isinstance(text, str):
|
| 265 |
+
text = text.strip(" ").split(",")
|
| 266 |
+
|
| 267 |
+
if not (len(text) and isinstance(text[0], (list, tuple))):
|
| 268 |
+
text = [text]
|
| 269 |
+
|
| 270 |
+
task = output_kwargs["text_kwargs"].pop("task", None)
|
| 271 |
+
if task is None:
|
| 272 |
+
task = ["Detect {}.".format(", ".join(text_single)) for text_single in text]
|
| 273 |
+
elif not isinstance(task, (list, tuple)):
|
| 274 |
+
task = [task]
|
| 275 |
+
|
| 276 |
+
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
|
| 277 |
+
tasks_encoding = self.tokenizer(text=task, **output_kwargs["text_kwargs"])
|
| 278 |
+
|
| 279 |
+
classes = text
|
| 280 |
+
|
| 281 |
+
classes_structure = torch.tensor([len(class_single) for class_single in classes], dtype=torch.long)
|
| 282 |
+
classes_flattened = [class_single for class_batch in classes for class_single in class_batch]
|
| 283 |
+
classes_encoding = self.tokenizer(text=classes_flattened, **output_kwargs["text_kwargs"])
|
| 284 |
+
|
| 285 |
+
encoding = BatchFeature()
|
| 286 |
+
encoding.update({f"tasks_{key}": value for key, value in tasks_encoding.items()})
|
| 287 |
+
encoding.update({f"classes_{key}": value for key, value in classes_encoding.items()})
|
| 288 |
+
encoding.update({"classes_structure": classes_structure})
|
| 289 |
+
encoding.update(encoding_image_processor)
|
| 290 |
+
|
| 291 |
+
return encoding
|
| 292 |
+
|
| 293 |
+
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
| 294 |
+
def batch_decode(self, *args, **kwargs):
|
| 295 |
+
"""
|
| 296 |
+
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 297 |
+
refer to the docstring of this method for more information.
|
| 298 |
+
"""
|
| 299 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 300 |
+
|
| 301 |
+
# Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
|
| 302 |
+
def decode(self, *args, **kwargs):
|
| 303 |
+
"""
|
| 304 |
+
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 305 |
+
the docstring of this method for more information.
|
| 306 |
+
"""
|
| 307 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 308 |
+
|
| 309 |
+
def _get_default_image_size(self) -> Tuple[int, int]:
|
| 310 |
+
height = (
|
| 311 |
+
self.image_processor.size["height"]
|
| 312 |
+
if "height" in self.image_processor.size
|
| 313 |
+
else self.image_processor.size["shortest_edge"]
|
| 314 |
+
)
|
| 315 |
+
width = (
|
| 316 |
+
self.image_processor.size["width"]
|
| 317 |
+
if "width" in self.image_processor.size
|
| 318 |
+
else self.image_processor.size["longest_edge"]
|
| 319 |
+
)
|
| 320 |
+
return height, width
|
| 321 |
+
|
| 322 |
+
@deprecate_kwarg("score_threshold", new_name="threshold", version="4.51.0")
|
| 323 |
+
@deprecate_kwarg("classes", new_name="text_labels", version="4.51.0")
|
| 324 |
+
def post_process_grounded_object_detection(
+        self,
+        outputs: "OmDetTurboObjectDetectionOutput",
+        text_labels: Optional[Union[List[str], List[List[str]]]] = None,
+        threshold: float = 0.3,
+        nms_threshold: float = 0.5,
+        target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
+        max_num_det: Optional[int] = None,
+    ):
+        """
+        Converts the raw output of [`OmDetTurboForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format and get the associated text class.
+
+        Args:
+            outputs ([`OmDetTurboObjectDetectionOutput`]):
+                Raw outputs of the model.
+            text_labels (Union[List[str], List[List[str]]], *optional*):
+                The input classes names. If not provided, `text_labels` will be set to `None` in `outputs`.
+            threshold (float, defaults to 0.3):
+                Only return detections with a confidence score exceeding this threshold.
+            nms_threshold (float, defaults to 0.5):
+                The threshold to use for box non-maximum suppression. Value in [0, 1].
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+            max_num_det (`int`, *optional*):
+                The maximum number of detections to return.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image
+            in the batch as predicted by the model.
+        """
+
+        batch_size = len(outputs.decoder_coord_logits)
+
+        # Inputs consistency check for target sizes
+        if target_sizes is None:
+            height, width = self._get_default_image_size()
+            target_sizes = [(height, width)] * batch_size
+
+        if any(len(image_size) != 2 for image_size in target_sizes):
+            raise ValueError(
+                "Each element of target_sizes must contain the size (height, width) of each image of the batch"
+            )
+
+        if len(target_sizes) != batch_size:
+            raise ValueError("Make sure that you pass in as many target sizes as output sequences")
+
+        # Inputs consistency check for text labels
+        if text_labels is not None and isinstance(text_labels[0], str):
+            text_labels = [text_labels]
+
+        if text_labels is not None and len(text_labels) != batch_size:
+            raise ValueError("Make sure that you pass in as many classes group as output sequences")
+
+        # Convert target_sizes to list for easier handling
+        if isinstance(target_sizes, torch.Tensor):
+            target_sizes = target_sizes.tolist()
+
+        batch_boxes = outputs.decoder_coord_logits
+        batch_logits = outputs.decoder_class_logits
+        batch_num_classes = outputs.classes_structure
+
+        batch_scores, batch_labels = compute_score(batch_logits)
+
+        results = []
+        for boxes, scores, image_size, image_num_classes in zip(
+            batch_boxes, batch_scores, target_sizes, batch_num_classes
+        ):
+            boxes, scores, labels = _post_process_boxes_for_image(
+                boxes=boxes,
+                scores=scores,
+                labels=batch_labels,
+                image_num_classes=image_num_classes,
+                image_size=image_size,
+                threshold=threshold,
+                nms_threshold=nms_threshold,
+                max_num_det=max_num_det,
+            )
+            result = DictWithDeprecationWarning(
+                {"boxes": boxes, "scores": scores, "labels": labels, "text_labels": None}
+            )
+            results.append(result)
+
+        # Add text labels
+        if text_labels is not None:
+            for result, image_text_labels in zip(results, text_labels):
+                result["text_labels"] = [image_text_labels[idx] for idx in result["labels"]]
+
+        return results
+
+
+__all__ = ["OmDetTurboProcessor"]
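Note (not part of the diff): a minimal sketch of how the post-processing method above is typically driven from user code. The checkpoint id, image URL and processor call are assumptions for illustration, not taken from this file.

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection

# Hypothetical checkpoint id; any OmDet-Turbo checkpoint published with a processor config should work the same way.
repo = "omlab/omdet-turbo-swin-tiny-hf"
processor = AutoProcessor.from_pretrained(repo)
model = OmDetTurboForObjectDetection.from_pretrained(repo)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
text_labels = ["cat", "remote"]
inputs = processor(image, text=text_labels, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Exercises the method shown above: score threshold, box NMS, rescaling to (height, width), then text labels.
results = processor.post_process_grounded_object_detection(
    outputs,
    text_labels=text_labels,
    threshold=0.3,
    nms_threshold=0.5,
    target_sizes=[image.size[::-1]],
)
print(results[0]["boxes"].shape, results[0]["text_labels"])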
docs/transformers/src/transformers/models/oneformer/__init__.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_oneformer import *
+    from .image_processing_oneformer import *
+    from .modeling_oneformer import *
+    from .processing_oneformer import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
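Note (not part of the diff): the `_LazyModule` indirection above defers the heavy submodule imports until an attribute is first accessed, so the public API is still imported the usual way; a minimal sketch:

from transformers.models.oneformer import OneFormerConfig  # resolved lazily on first attribute access

config = OneFormerConfig()
print(type(config).__name__)  # OneFormerConfig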
docs/transformers/src/transformers/models/oneformer/configuration_oneformer.py
ADDED
@@ -0,0 +1,277 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OneFormer model configuration"""
+
+from typing import Dict, Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
+class OneFormerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`OneFormerModel`]. It is used to instantiate a
+    OneFormer model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the OneFormer
+    [shi-labs/oneformer_ade20k_swin_tiny](https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny) architecture
+    trained on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
+            The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        ignore_value (`int`, *optional*, defaults to 255):
+            Values to be ignored in GT label while calculating loss.
+        num_queries (`int`, *optional*, defaults to 150):
+            Number of object queries.
+        no_object_weight (`float`, *optional*, defaults to 0.1):
+            Weight for no-object class predictions.
+        class_weight (`float`, *optional*, defaults to 2.0):
+            Weight for Classification CE loss.
+        mask_weight (`float`, *optional*, defaults to 5.0):
+            Weight for binary CE loss.
+        dice_weight (`float`, *optional*, defaults to 5.0):
+            Weight for dice loss.
+        contrastive_weight (`float`, *optional*, defaults to 0.5):
+            Weight for contrastive loss.
+        contrastive_temperature (`float`, *optional*, defaults to 0.07):
+            Initial value for scaling the contrastive logits.
+        train_num_points (`int`, *optional*, defaults to 12544):
+            Number of points to sample while calculating losses on mask predictions.
+        oversample_ratio (`float`, *optional*, defaults to 3.0):
+            Ratio to decide how many points to oversample.
+        importance_sample_ratio (`float`, *optional*, defaults to 0.75):
+            Ratio of points that are sampled via importance sampling.
+        init_std (`float`, *optional*, defaults to 0.02):
+            Standard deviation for normal intialization.
+        init_xavier_std (`float`, *optional*, defaults to 1.0):
+            Standard deviation for xavier uniform initialization.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            Epsilon for layer normalization.
+        is_training (`bool`, *optional*, defaults to `False`):
+            Whether to run in training or inference mode.
+        use_auxiliary_loss (`bool`, *optional*, defaults to `True`):
+            Whether to calculate loss using intermediate predictions from transformer decoder.
+        output_auxiliary_logits (`bool`, *optional*, defaults to `True`):
+            Whether to return intermediate predictions from transformer decoder.
+        strides (`list`, *optional*, defaults to `[4, 8, 16, 32]`):
+            List containing the strides for feature maps in the encoder.
+        task_seq_len (`int`, *optional*, defaults to 77):
+            Sequence length for tokenizing text list input.
+        text_encoder_width (`int`, *optional*, defaults to 256):
+            Hidden size for text encoder.
+        text_encoder_context_length (`int`, *optional*, defaults to 77):
+            Input sequence length for text encoder.
+        text_encoder_num_layers (`int`, *optional*, defaults to 6):
+            Number of layers for transformer in text encoder.
+        text_encoder_vocab_size (`int`, *optional*, defaults to 49408):
+            Vocabulary size for tokenizer.
+        text_encoder_proj_layers (`int`, *optional*, defaults to 2):
+            Number of layers in MLP for project text queries.
+        text_encoder_n_ctx (`int`, *optional*, defaults to 16):
+            Number of learnable text context queries.
+        conv_dim (`int`, *optional*, defaults to 256):
+            Feature map dimension to map outputs from the backbone.
+        mask_dim (`int`, *optional*, defaults to 256):
+            Dimension for feature maps in pixel decoder.
+        hidden_dim (`int`, *optional*, defaults to 256):
+            Dimension for hidden states in transformer decoder.
+        encoder_feedforward_dim (`int`, *optional*, defaults to 1024):
+            Dimension for FFN layer in pixel decoder.
+        norm (`str`, *optional*, defaults to `"GN"`):
+            Type of normalization.
+        encoder_layers (`int`, *optional*, defaults to 6):
+            Number of layers in pixel decoder.
+        decoder_layers (`int`, *optional*, defaults to 10):
+            Number of layers in transformer decoder.
+        use_task_norm (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the task token.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads in transformer layers in the pixel and transformer decoders.
+        dropout (`float`, *optional*, defaults to 0.1):
+            Dropout probability for pixel and transformer decoders.
+        dim_feedforward (`int`, *optional*, defaults to 2048):
+            Dimension for FFN layer in transformer decoder.
+        pre_norm (`bool`, *optional*, defaults to `False`):
+            Whether to normalize hidden states before attention layers in transformer decoder.
+        enforce_input_proj (`bool`, *optional*, defaults to `False`):
+            Whether to project hidden states in transformer decoder.
+        query_dec_layers (`int`, *optional*, defaults to 2):
+            Number of layers in query transformer.
+        common_stride (`int`, *optional*, defaults to 4):
+            Common stride used for features in pixel decoder.
+
+    Examples:
+    ```python
+    >>> from transformers import OneFormerConfig, OneFormerModel
+
+    >>> # Initializing a OneFormer shi-labs/oneformer_ade20k_swin_tiny configuration
+    >>> configuration = OneFormerConfig()
+    >>> # Initializing a model (with random weights) from the shi-labs/oneformer_ade20k_swin_tiny style configuration
+    >>> model = OneFormerModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "oneformer"
+    attribute_map = {"hidden_size": "hidden_dim"}
+
+    def __init__(
+        self,
+        backbone_config: Optional[Dict] = None,
+        backbone: Optional[str] = None,
+        use_pretrained_backbone: bool = False,
+        use_timm_backbone: bool = False,
+        backbone_kwargs: Optional[Dict] = None,
+        ignore_value: int = 255,
+        num_queries: int = 150,
+        no_object_weight: int = 0.1,
+        class_weight: float = 2.0,
+        mask_weight: float = 5.0,
+        dice_weight: float = 5.0,
+        contrastive_weight: float = 0.5,
+        contrastive_temperature: float = 0.07,
+        train_num_points: int = 12544,
+        oversample_ratio: float = 3.0,
+        importance_sample_ratio: float = 0.75,
+        init_std: float = 0.02,
+        init_xavier_std: float = 1.0,
+        layer_norm_eps: float = 1e-05,
+        is_training: bool = False,
+        use_auxiliary_loss: bool = True,
+        output_auxiliary_logits: bool = True,
+        strides: Optional[list] = [4, 8, 16, 32],
+        task_seq_len: int = 77,
+        text_encoder_width: int = 256,
+        text_encoder_context_length: int = 77,
+        text_encoder_num_layers: int = 6,
+        text_encoder_vocab_size: int = 49408,
+        text_encoder_proj_layers: int = 2,
+        text_encoder_n_ctx: int = 16,
+        conv_dim: int = 256,
+        mask_dim: int = 256,
+        hidden_dim: int = 256,
+        encoder_feedforward_dim: int = 1024,
+        norm: str = "GN",
+        encoder_layers: int = 6,
+        decoder_layers: int = 10,
+        use_task_norm: bool = True,
+        num_attention_heads: int = 8,
+        dropout: float = 0.1,
+        dim_feedforward: int = 2048,
+        pre_norm: bool = False,
+        enforce_input_proj: bool = False,
+        query_dec_layers: int = 2,
+        common_stride: int = 4,
+        **kwargs,
+    ):
+        if backbone_config is None and backbone is None:
+            logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
+            backbone_config = CONFIG_MAPPING["swin"](
+                image_size=224,
+                num_channels=3,
+                patch_size=4,
+                embed_dim=96,
+                depths=[2, 2, 6, 2],
+                num_heads=[3, 6, 12, 24],
+                window_size=7,
+                drop_path_rate=0.3,
+                use_absolute_embeddings=False,
+                out_features=["stage1", "stage2", "stage3", "stage4"],
+            )
+        elif isinstance(backbone_config, dict):
+            backbone_model_type = backbone_config.get("model_type")
+            config_class = CONFIG_MAPPING[backbone_model_type]
+            backbone_config = config_class.from_dict(backbone_config)
+
+        verify_backbone_config_arguments(
+            use_timm_backbone=use_timm_backbone,
+            use_pretrained_backbone=use_pretrained_backbone,
+            backbone=backbone,
+            backbone_config=backbone_config,
+            backbone_kwargs=backbone_kwargs,
+        )
+
+        self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_kwargs = backbone_kwargs
+        self.ignore_value = ignore_value
+        self.num_queries = num_queries
+        self.no_object_weight = no_object_weight
+        self.class_weight = class_weight
+        self.mask_weight = mask_weight
+        self.dice_weight = dice_weight
+        self.contrastive_weight = contrastive_weight
+        self.contrastive_temperature = contrastive_temperature
+        self.train_num_points = train_num_points
+        self.oversample_ratio = oversample_ratio
+        self.importance_sample_ratio = importance_sample_ratio
+        self.init_std = init_std
+        self.init_xavier_std = init_xavier_std
+        self.layer_norm_eps = layer_norm_eps
+        self.is_training = is_training
+        self.use_auxiliary_loss = use_auxiliary_loss
+        self.output_auxiliary_logits = output_auxiliary_logits
+        self.strides = strides
+        self.task_seq_len = task_seq_len
+        self.text_encoder_width = text_encoder_width
+        self.text_encoder_context_length = text_encoder_context_length
+        self.text_encoder_num_layers = text_encoder_num_layers
+        self.text_encoder_vocab_size = text_encoder_vocab_size
+        self.text_encoder_proj_layers = text_encoder_proj_layers
+        self.text_encoder_n_ctx = text_encoder_n_ctx
+        self.conv_dim = conv_dim
+        self.mask_dim = mask_dim
+        self.hidden_dim = hidden_dim
+        self.encoder_feedforward_dim = encoder_feedforward_dim
+        self.norm = norm
+        self.encoder_layers = encoder_layers
+        self.decoder_layers = decoder_layers
+        self.use_task_norm = use_task_norm
+        self.num_attention_heads = num_attention_heads
+        self.dropout = dropout
+        self.dim_feedforward = dim_feedforward
+        self.pre_norm = pre_norm
+        self.enforce_input_proj = enforce_input_proj
+        self.query_dec_layers = query_dec_layers
+        self.common_stride = common_stride
+        self.num_hidden_layers = decoder_layers
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["OneFormerConfig"]
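Note (not part of the diff): a small sketch of overriding the default Swin backbone through a plain dict, which goes through the `isinstance(backbone_config, dict)` branch above; the particular values are illustrative only.

from transformers import OneFormerConfig

config = OneFormerConfig(
    backbone_config={"model_type": "swin", "embed_dim": 96, "depths": [2, 2, 6, 2]},
    num_queries=150,
    decoder_layers=10,
)
print(config.backbone_config.model_type)  # "swin"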
docs/transformers/src/transformers/models/oneformer/convert_to_hf_oneformer.py
ADDED
@@ -0,0 +1,1191 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert OneFormer checkpoints from the original repository. URL: https://github.com/SHI-Labs/OneFormer"""
+
+import os
+import sys
+from argparse import ArgumentParser
+from dataclasses import dataclass
+from pathlib import Path
+from pprint import pformat
+from typing import Any, Dict, Iterator, List, Set, Tuple
+
+import requests
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from torch import Tensor, nn
+
+
+try:
+    from detectron2.checkpoint import DetectionCheckpointer
+    from detectron2.config import get_cfg
+    from detectron2.data import MetadataCatalog
+    from detectron2.projects.deeplab import add_deeplab_config
+except ImportError:
+    pass
+from transformers import CLIPTokenizer, DinatConfig, SwinConfig
+from transformers.models.oneformer.image_processing_oneformer import OneFormerImageProcessor
+from transformers.models.oneformer.modeling_oneformer import (
+    OneFormerConfig,
+    OneFormerForUniversalSegmentation,
+    OneFormerForUniversalSegmentationOutput,
+    OneFormerModel,
+    OneFormerModelOutput,
+)
+from transformers.models.oneformer.processing_oneformer import OneFormerProcessor
+from transformers.utils import logging
+
+
+StateDict = Dict[str, Tensor]
+
+logging.set_verbosity_info()
+logger = logging.get_logger()
+
+torch.manual_seed(0)
+
+
+class TrackedStateDict:
+    def __init__(self, to_track: Dict):
+        """This class "tracks" a python dictionary by keeping track of which item is accessed.
+
+        Args:
+            to_track (Dict): The dictionary we wish to track
+        """
+        self.to_track = to_track
+        self._seen: Set[str] = set()
+
+    def __getitem__(self, key: str) -> Any:
+        return self.to_track[key]
+
+    def __setitem__(self, key: str, item: Any):
+        self._seen.add(key)
+        self.to_track[key] = item
+
+    def diff(self) -> List[str]:
+        """This method returns a set difference between the keys in the tracked state dict and the one we have access so far.
+        This is an effective method to check if we have update all the keys
+
+        Returns:
+            List[str]: List of keys not yet updated
+        """
+        return set(self.to_track.keys()) - self._seen
+
+    def copy(self) -> Dict:
+        # proxy the call to the internal dictionary
+        return self.to_track.copy()
+
+
+# Image to verify the result
+def prepare_img():
+    url = "https://praeclarumjj3.github.io/files/coco.jpeg"
+    img_data = requests.get(url, stream=True).raw
+    im = Image.open(img_data)
+    return im
+
+
+@dataclass
+class Args:
+    """Fake command line arguments needed by oneformer/detectron2 implementation"""
+
+    config_file: str
+
+
+def setup_cfg(args: Args):
+    # load config from file and command-line arguments
+    cfg = get_cfg()
+    add_deeplab_config(cfg)
+    add_common_config(cfg)
+    add_oneformer_config(cfg)
+    add_swin_config(cfg)
+    add_dinat_config(cfg)
+    cfg.merge_from_file(args.config_file)
+    cfg.freeze()
+    return cfg
+
+
+class OriginalOneFormerConfigToOursConverter:
+    def __call__(self, original_config: object, is_swin: bool) -> OneFormerConfig:
+        model = original_config.MODEL
+
+        dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
+        id2label = dict(enumerate(dataset_catalog.stuff_classes))
+        label2id = {label: idx for idx, label in id2label.items()}
+
+        if is_swin:
+            if model.SWIN.EMBED_DIM == 96:
+                backbone_config = SwinConfig.from_pretrained(
+                    "microsoft/swin-tiny-patch4-window7-224",
+                    drop_path_rate=model.SWIN.DROP_PATH_RATE,
+                    out_features=["stage1", "stage2", "stage3", "stage4"],
+                )
+            elif model.SWIN.EMBED_DIM == 192:
+                backbone_config = SwinConfig.from_pretrained(
+                    "microsoft/swin-large-patch4-window12-384",
+                    drop_path_rate=model.SWIN.DROP_PATH_RATE,
+                    out_features=["stage1", "stage2", "stage3", "stage4"],
+                )
+            else:
+                raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!")
+        else:
+            backbone_config = DinatConfig.from_pretrained(
+                "shi-labs/dinat-large-11x11-in22k-in1k-384",
+                dilations=model.DiNAT.DILATIONS,
+                kernel_size=model.DiNAT.KERNEL_SIZE,
+                out_features=["stage1", "stage2", "stage3", "stage4"],
+            )
+
+        config: OneFormerConfig = OneFormerConfig(
+            backbone_config=backbone_config,
+            output_attentions=True,
+            output_hidden_states=True,
+            return_dict=True,
+            ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE,
+            num_classes=model.SEM_SEG_HEAD.NUM_CLASSES,
+            num_queries=model.ONE_FORMER.NUM_OBJECT_QUERIES,
+            no_object_weight=model.ONE_FORMER.NO_OBJECT_WEIGHT,
+            class_weight=model.ONE_FORMER.CLASS_WEIGHT,
+            mask_weight=model.ONE_FORMER.MASK_WEIGHT,
+            dice_weight=model.ONE_FORMER.DICE_WEIGHT,
+            contrastive_weight=model.ONE_FORMER.CONTRASTIVE_WEIGHT,
+            contrastive_temperature=model.ONE_FORMER.CONTRASTIVE_TEMPERATURE,
+            train_num_points=model.ONE_FORMER.TRAIN_NUM_POINTS,
+            oversample_ratio=model.ONE_FORMER.OVERSAMPLE_RATIO,
+            importance_sample_ratio=model.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO,
+            init_std=0.02,
+            init_xavier_std=1.0,
+            layer_norm_eps=1e-05,
+            is_training=False,
+            use_auxiliary_loss=model.ONE_FORMER.DEEP_SUPERVISION,
+            output_auxiliary_logits=True,
+            strides=[4, 8, 16, 32],
+            task_seq_len=original_config.INPUT.TASK_SEQ_LEN,
+            max_seq_len=original_config.INPUT.MAX_SEQ_LEN,
+            text_encoder_width=model.TEXT_ENCODER.WIDTH,
+            text_encoder_context_length=model.TEXT_ENCODER.CONTEXT_LENGTH,
+            text_encoder_num_layers=model.TEXT_ENCODER.NUM_LAYERS,
+            text_encoder_vocab_size=model.TEXT_ENCODER.VOCAB_SIZE,
+            text_encoder_proj_layers=model.TEXT_ENCODER.PROJ_NUM_LAYERS,
+            text_encoder_n_ctx=model.TEXT_ENCODER.N_CTX,
+            conv_dim=model.SEM_SEG_HEAD.CONVS_DIM,
+            mask_dim=model.SEM_SEG_HEAD.MASK_DIM,
+            hidden_dim=model.ONE_FORMER.HIDDEN_DIM,
+            norm=model.SEM_SEG_HEAD.NORM,
+            encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS,
+            encoder_feedforward_dim=1024,
+            decoder_layers=model.ONE_FORMER.DEC_LAYERS,
+            use_task_norm=model.ONE_FORMER.USE_TASK_NORM,
+            num_attention_heads=model.ONE_FORMER.NHEADS,
+            dropout=model.ONE_FORMER.DROPOUT,
+            dim_feedforward=model.ONE_FORMER.DIM_FEEDFORWARD,
+            pre_norm=model.ONE_FORMER.PRE_NORM,
+            enforce_input_proj=model.ONE_FORMER.ENFORCE_INPUT_PROJ,
+            query_dec_layers=model.ONE_FORMER.CLASS_DEC_LAYERS,
+            common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE,
+            id2label=id2label,
+            label2id=label2id,
+        )
+
+        return config
+
+
+class OriginalOneFormerConfigToProcessorConverter:
+    def __call__(self, original_config: object, model_repo: str) -> OneFormerProcessor:
+        model = original_config.MODEL
+        model_input = original_config.INPUT
+        dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
+
+        if "ade20k" in model_repo:
+            class_info_file = "ade20k_panoptic.json"
+        elif "coco" in model_repo:
+            class_info_file = "coco_panoptic.json"
+        elif "cityscapes" in model_repo:
+            class_info_file = "cityscapes_panoptic.json"
+        else:
+            raise ValueError("Invalid Dataset!")
+
+        image_processor = OneFormerImageProcessor(
+            image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
+            image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
+            size=model_input.MIN_SIZE_TEST,
+            max_size=model_input.MAX_SIZE_TEST,
+            num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
+            ignore_index=dataset_catalog.ignore_label,
+            class_info_file=class_info_file,
+        )
+
+        tokenizer = CLIPTokenizer.from_pretrained(model_repo)
+
+        return OneFormerProcessor(
+            image_processor=image_processor,
+            tokenizer=tokenizer,
+            task_seq_length=original_config.INPUT.TASK_SEQ_LEN,
+            max_seq_length=original_config.INPUT.MAX_SEQ_LEN,
+        )
+
+
+class OriginalOneFormerCheckpointToOursConverter:
+    def __init__(self, original_model: nn.Module, config: OneFormerConfig):
+        self.original_model = original_model
+        self.config = config
+
+    def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
+        for src_key, dst_key in renamed_keys:
+            dst_state_dict[dst_key] = src_state_dict.pop(src_key)
+
+    # Swin Backbone
+    def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig):
+        dst_prefix: str = "pixel_level_module.encoder"
+        src_prefix: str = "backbone"
+
+        renamed_keys = [
+            (
+                f"{src_prefix}.patch_embed.proj.weight",
+                f"{dst_prefix}.embeddings.patch_embeddings.projection.weight",
+            ),
+            (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"),
+            (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"),
+            (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"),
+        ]
+        num_layers = len(config.backbone_config.depths)
+        for layer_idx in range(num_layers):
+            for block_idx in range(config.backbone_config.depths[layer_idx]):
+                renamed_keys.extend(
+                    [  # src, dst
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table",
+                        ),
+                    ]
+                )
+                # now we need to handle the attentions
+                # read in weights + bias of input projection layer of cross-attention
+
+                src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
+                src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
+
+                size = src_att_weight.shape[0]
+                offset = size // 3
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight"
+                ] = src_att_weight[:offset, :]
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias"
+                ] = src_att_bias[:offset]
+
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight"
+                ] = src_att_weight[offset : offset * 2, :]
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias"
+                ] = src_att_bias[offset : offset * 2]
+
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight"
+                ] = src_att_weight[-offset:, :]
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias"
+                ] = src_att_bias[-offset:]
+
+                # let's pop them
+                src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
+                src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
+                # proj
+                renamed_keys.extend(
+                    [
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias",
+                        ),
+                    ]
+                )
+
+                # second norm
+                renamed_keys.extend(
+                    [
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias",
+                        ),
+                    ]
+                )
+
+                # mlp
+                renamed_keys.extend(
+                    [
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias",
+                        ),
+                    ]
+                )
+
+                renamed_keys.extend(
+                    [
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index",
+                        )
+                    ]
+                )
+
+            if layer_idx < num_layers - 1:
+                # patch merging
+                renamed_keys.extend(
+                    [
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight",
+                        ),
+                        (
+                            f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias",
+                            f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias",
+                        ),
+                    ]
+                )
+
+            # hidden states norms
+            renamed_keys.extend(
+                [
+                    (
+                        f"{src_prefix}.norm{layer_idx}.weight",
+                        f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.weight",
+                    ),
+                    (
+                        f"{src_prefix}.norm{layer_idx}.bias",
+                        f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.bias",
+                    ),
+                ]
+            )
+
+        self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
+
+    # Dinat Backbone
+    def replace_dinat_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig):
+        dst_prefix: str = "pixel_level_module.encoder"
+        src_prefix: str = "backbone"
+
+        def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
+            return [
+                (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
+                (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
+            ]
+
+        renamed_keys = rename_keys_for_weight_bias(f"{src_prefix}.patch_embed.norm", f"{dst_prefix}.embeddings.norm")
+
+        for i in range(2):
+            renamed_keys.extend(
+                rename_keys_for_weight_bias(
+                    f"{src_prefix}.patch_embed.proj.{i}",
+                    f"{dst_prefix}.embeddings.patch_embeddings.projection.{i}",
+                )
+            )
+
+        num_layers = len(config.backbone_config.depths)
+        for layer_idx in range(num_layers):
+            for block_idx in range(config.backbone_config.depths[layer_idx]):
+                renamed_keys.extend(
+                    rename_keys_for_weight_bias(
+                        f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm1",
+                        f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_before",
+                    )
+                )
+
+                renamed_keys.extend(
+                    rename_keys_for_weight_bias(
+                        f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm2",
+                        f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_after",
+                    )
+                )
+
+                renamed_keys.extend(
+                    [  # src, dst
+                        (
+                            f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.rpb",
+                            f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.rpb",
+                        ),
+                    ]
+                )
+                # now we need to handle the attentions
+                # read in weights + bias of input projection layer of cross-attention
+
+                src_att_weight = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
+                src_att_bias = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
+
+                size = src_att_weight.shape[0]
+                offset = size // 3
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.weight"
+                ] = src_att_weight[:offset, :]
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.bias"
+                ] = src_att_bias[:offset]
+
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.weight"
+                ] = src_att_weight[offset : offset * 2, :]
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.bias"
+                ] = src_att_bias[offset : offset * 2]
+
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.weight"
+                ] = src_att_weight[-offset:, :]
+                dst_state_dict[
+                    f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.bias"
+                ] = src_att_bias[-offset:]
+
+                # let's pop them
+                src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
+                src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
+                # proj
+
+                renamed_keys.extend(
+                    rename_keys_for_weight_bias(
+                        f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.proj",
+                        f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.output.dense",
+                    )
+                )
+
+                # mlp
+                renamed_keys.extend(
+                    rename_keys_for_weight_bias(
+                        f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc1",
+                        f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.intermediate.dense",
+                    )
+                )
+
+                renamed_keys.extend(
+                    rename_keys_for_weight_bias(
+                        f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc2",
+                        f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.output.dense",
+                    )
+                )
+
+            if layer_idx < num_layers - 1:
+                # patch merging
+                renamed_keys.extend(
+                    [
+                        (
+                            f"{src_prefix}.levels.{layer_idx}.downsample.reduction.weight",
+                            f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.reduction.weight",
+                        ),
+                        (
+                            f"{src_prefix}.levels.{layer_idx}.downsample.norm.weight",
+                            f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.weight",
+                        ),
+                        (
+                            f"{src_prefix}.levels.{layer_idx}.downsample.norm.bias",
+                            f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.bias",
+                        ),
+                    ]
+                )
+
+            # hidden states norms
+            renamed_keys.extend(
+                [
+                    (
+                        f"{src_prefix}.norm{layer_idx}.weight",
+                        f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.weight",
+                    ),
+                    (
+                        f"{src_prefix}.norm{layer_idx}.bias",
+                        f"{dst_prefix}.hidden_states_norms.stage{layer_idx + 1}.bias",
+                    ),
+                ]
+            )
+
+        self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
+
+    # Backbone + Pixel Decoder
+    def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict, is_swin: bool):
+        dst_prefix: str = "pixel_level_module.decoder"
+        src_prefix: str = "sem_seg_head.pixel_decoder"
+
+        if is_swin:
+            self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config)
+        else:
+            self.replace_dinat_backbone(dst_state_dict, src_state_dict, self.config)
+
+        def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
+            return [
+                (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
+                (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
+            ]
+
+        def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
+            self_attn_keys = []
+            self_attn_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights")
+            )
+            self_attn_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj")
+            )
+            self_attn_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets")
+            )
+            self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj"))
+
+            return self_attn_keys
+
+        def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str):
+            encoder_keys = []
+            encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1"))
+            encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2"))
+            encoder_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm")
+            )
+            encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm"))
+            encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn"))
+
+            return encoder_keys
+
+        # convolution layer for final features
+        renamed_keys = [
+            (f"{src_prefix}.adapter_1.weight", f"{dst_prefix}.adapter_1.0.weight"),
+            (f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"),
+            (f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"),
+        ]
+
+        renamed_keys.extend(
+            [
+                (f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"),
+                (f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"),
+                (f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"),
+            ]
+        )
+
+        # proj layers
+        for i in range(3):
+            for j in range(2):
+                renamed_keys.extend(
+                    [
+                        (f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"),
+                        (f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"),
+                    ]
+                )
+
+        renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")])
+
+        # layers
+        for layer_idx in range(self.config.encoder_layers):
+            renamed_keys.extend(
+                rename_keys_for_encoder_layer(
+                    f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}"
+                )
+            )
+
+        # proj
+        renamed_keys.extend(
+            [
+                (f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"),
+                (f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"),
+            ]
+        )
+
+        self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
+
+    # Transformer Decoder
+    def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
+        dst_prefix: str = "transformer_module.decoder.layers"
+        src_prefix: str = "sem_seg_head.predictor"
+        for i in range(self.config.decoder_layers - 1):
+            # read in weights + bias of input projection layer of self-attention
+            in_proj_weight = src_state_dict.pop(
+                f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight"
+            )
+            in_proj_bias = src_state_dict.pop(
+                f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias"
+            )
+            # next, add query, keys and values (in that order) to the state dict
+            dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
+            dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.bias"] = in_proj_bias[:256]
+            dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
+            dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.bias"] = in_proj_bias[256:512]
+            dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
+            dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.bias"] = in_proj_bias[-256:]
+
+    def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
+        dst_prefix: str = "transformer_module"
+        src_prefix: str = "sem_seg_head.predictor"
+
+        def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
+            return [
+                (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
+                (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
+            ]
+
+        def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
+            attn_keys = [
+                (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
+                (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
+            ]
+            attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
+
+            return attn_keys
+
+        def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
+            attn_keys = []
+            attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
+
+            return attn_keys
+
+        def rename_keys_for_query_transformer_layer(src_prefix: str, dst_prefix: str):
+            query_transformer_layer_keys = []
+
+            query_transformer_layer_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1")
+            )
+            query_transformer_layer_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2")
+            )
+            query_transformer_layer_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.norm1")
+            )
+            query_transformer_layer_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.norm2")
+            )
+            query_transformer_layer_keys.extend(
+                rename_keys_for_weight_bias(f"{src_prefix}.norm3", f"{dst_prefix}.norm3")
+            )
+
+            query_transformer_layer_keys.extend(
+                rename_keys_for_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
+            )
+
+            query_transformer_layer_keys.extend(
+                rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
+            )
+
+            return query_transformer_layer_keys
+
+        def rename_keys_for_cross_attn_layer(src_prefix: str, dst_prefix: str):
+            cross_attn_layer_keys = []
+
+            cross_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
+            cross_attn_layer_keys.extend(
+                rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
+            )
+
+            return cross_attn_layer_keys
+
+        def rename_keys_for_self_attn_layer(src_prefix: str, dst_prefix: str):
+            self_attn_layer_keys = []
+
+            self_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
+            self_attn_layer_keys.extend(
+                rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
+            )
+
+            return self_attn_layer_keys
+
+        def rename_keys_for_ffn_layer(src_prefix: str, dst_prefix: str):
+            ffn_layer_keys = []
+
+            ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1"))
+            ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2"))
+            ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
+
+            return ffn_layer_keys
+
+        def rename_keys_for_transformer_decoder_layer(src_prefix: str, dst_prefix: str, idx: int):
+            transformer_decoder_layer_keys = []
+
+            transformer_decoder_layer_keys.extend(
+                rename_keys_for_cross_attn_layer(
+                    f"{src_prefix}.transformer_cross_attention_layers.{idx}", f"{dst_prefix}.{idx}.cross_attn"
+                )
+            )
+
+            transformer_decoder_layer_keys.extend(
+                rename_keys_for_self_attn_layer(
+                    f"{src_prefix}.transformer_self_attention_layers.{idx}", f"{dst_prefix}.{idx}.self_attn"
+                )
+            )
+
+            transformer_decoder_layer_keys.extend(
+                rename_keys_for_ffn_layer(f"{src_prefix}.transformer_ffn_layers.{idx}", f"{dst_prefix}.{idx}.ffn")
+            )
+
+            return transformer_decoder_layer_keys
+
+        # positional embedding for object queries
+        renamed_keys = [
+            (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
+            (f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"),
+        ]
+
+        # norm
+        renamed_keys.extend(
+            rename_keys_for_weight_bias(f"{src_prefix}.decoder_norm", f"{dst_prefix}.decoder.decoder_norm")
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
# proj
|
| 769 |
+
renamed_keys.extend(
|
| 770 |
+
rename_keys_for_weight_bias(
|
| 771 |
+
f"{src_prefix}.class_input_proj", f"{dst_prefix}.decoder.query_input_projection"
|
| 772 |
+
)
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
renamed_keys.extend(
|
| 776 |
+
rename_keys_for_weight_bias(f"{src_prefix}.class_embed", f"{dst_prefix}.decoder.class_embed")
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
for i in range(3):
|
| 780 |
+
renamed_keys.extend(
|
| 781 |
+
rename_keys_for_weight_bias(
|
| 782 |
+
f"{src_prefix}.mask_embed.layers.{i}", f"{dst_prefix}.decoder.mask_embed.layers.{i}.0"
|
| 783 |
+
)
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# norm
|
| 787 |
+
renamed_keys.extend(
|
| 788 |
+
rename_keys_for_weight_bias(
|
| 789 |
+
f"{src_prefix}.class_transformer.decoder.norm", f"{dst_prefix}.decoder.query_transformer.decoder.norm"
|
| 790 |
+
)
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
# transformer to update queries with task tokens
|
| 794 |
+
for i in range(self.config.query_dec_layers):
|
| 795 |
+
renamed_keys.extend(
|
| 796 |
+
rename_keys_for_query_transformer_layer(
|
| 797 |
+
f"{src_prefix}.class_transformer.decoder.layers.{i}",
|
| 798 |
+
f"{dst_prefix}.decoder.query_transformer.decoder.layers.{i}",
|
| 799 |
+
)
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
# decoder layers
|
| 803 |
+
for i in range(self.config.decoder_layers - 1):
|
| 804 |
+
renamed_keys.extend(
|
| 805 |
+
rename_keys_for_transformer_decoder_layer(
|
| 806 |
+
f"{src_prefix}",
|
| 807 |
+
f"{dst_prefix}.decoder.layers",
|
| 808 |
+
i,
|
| 809 |
+
)
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
|
| 813 |
+
self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict)
|
| 814 |
+
|
| 815 |
+
def replace_task_mlp(self, dst_state_dict: StateDict, src_state_dict: StateDict):
|
| 816 |
+
dst_prefix: str = "task_encoder"
|
| 817 |
+
src_prefix: str = "task_mlp"
|
| 818 |
+
|
| 819 |
+
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
|
| 820 |
+
return [
|
| 821 |
+
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
|
| 822 |
+
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
|
| 823 |
+
]
|
| 824 |
+
|
| 825 |
+
renamed_keys = []
|
| 826 |
+
|
| 827 |
+
for i in range(2):
|
| 828 |
+
renamed_keys.extend(
|
| 829 |
+
rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.task_mlp.layers.{i}.0")
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
|
| 833 |
+
|
| 834 |
+
def replace_text_projector(self, dst_state_dict: StateDict, src_state_dict: StateDict):
|
| 835 |
+
dst_prefix: str = "text_mapper.text_projector"
|
| 836 |
+
src_prefix: str = "text_projector"
|
| 837 |
+
|
| 838 |
+
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
|
| 839 |
+
return [
|
| 840 |
+
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
|
| 841 |
+
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
|
| 842 |
+
]
|
| 843 |
+
|
| 844 |
+
renamed_keys = []
|
| 845 |
+
|
| 846 |
+
for i in range(self.config.text_encoder_config["text_encoder_proj_layers"]):
|
| 847 |
+
renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.{i}.0"))
|
| 848 |
+
|
| 849 |
+
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
|
| 850 |
+
|
| 851 |
+
def replace_text_mapper(self, dst_state_dict: StateDict, src_state_dict: StateDict):
|
| 852 |
+
dst_prefix: str = "text_mapper.text_encoder"
|
| 853 |
+
src_prefix: str = "text_encoder"
|
| 854 |
+
|
| 855 |
+
self.replace_text_projector(dst_state_dict, src_state_dict)
|
| 856 |
+
|
| 857 |
+
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
|
| 858 |
+
return [
|
| 859 |
+
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
|
| 860 |
+
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
|
| 861 |
+
]
|
| 862 |
+
|
| 863 |
+
def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
|
| 864 |
+
attn_keys = [
|
| 865 |
+
(f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
|
| 866 |
+
(f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
|
| 867 |
+
]
|
| 868 |
+
attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
|
| 869 |
+
|
| 870 |
+
return attn_keys
|
| 871 |
+
|
| 872 |
+
def rename_keys_for_layer(src_prefix: str, dst_prefix: str):
|
| 873 |
+
resblock_keys = []
|
| 874 |
+
|
| 875 |
+
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_fc", f"{dst_prefix}.mlp.fc1"))
|
| 876 |
+
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_proj", f"{dst_prefix}.mlp.fc2"))
|
| 877 |
+
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_1", f"{dst_prefix}.layer_norm1"))
|
| 878 |
+
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_2", f"{dst_prefix}.layer_norm2"))
|
| 879 |
+
resblock_keys.extend(rename_keys_for_attn(f"{src_prefix}.attn", f"{dst_prefix}.self_attn"))
|
| 880 |
+
|
| 881 |
+
return resblock_keys
|
| 882 |
+
|
| 883 |
+
renamed_keys = [
|
| 884 |
+
("prompt_ctx.weight", "text_mapper.prompt_ctx.weight"),
|
| 885 |
+
]
|
| 886 |
+
|
| 887 |
+
renamed_keys.extend(
|
| 888 |
+
[
|
| 889 |
+
(f"{src_prefix}.positional_embedding", f"{dst_prefix}.positional_embedding"),
|
| 890 |
+
(f"{src_prefix}.token_embedding.weight", f"{dst_prefix}.token_embedding.weight"),
|
| 891 |
+
]
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_final", f"{dst_prefix}.ln_final"))
|
| 895 |
+
|
| 896 |
+
for i in range(self.config.text_encoder_config["text_encoder_num_layers"]):
|
| 897 |
+
renamed_keys.extend(
|
| 898 |
+
rename_keys_for_layer(
|
| 899 |
+
f"{src_prefix}.transformer.resblocks.{i}", f"{dst_prefix}.transformer.layers.{i}"
|
| 900 |
+
)
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
|
| 904 |
+
|
| 905 |
+
def convert(self, oneformer: OneFormerModel, is_swin: bool) -> OneFormerModel:
|
| 906 |
+
dst_state_dict = TrackedStateDict(oneformer.state_dict())
|
| 907 |
+
src_state_dict = self.original_model.state_dict()
|
| 908 |
+
|
| 909 |
+
self.replace_pixel_module(dst_state_dict, src_state_dict, is_swin)
|
| 910 |
+
self.replace_transformer_module(dst_state_dict, src_state_dict)
|
| 911 |
+
self.replace_task_mlp(dst_state_dict, src_state_dict)
|
| 912 |
+
if self.config.is_training:
|
| 913 |
+
self.replace_text_mapper(dst_state_dict, src_state_dict)
|
| 914 |
+
|
| 915 |
+
logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}")
|
| 916 |
+
logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}")
|
| 917 |
+
logger.info("🙌 Done")
|
| 918 |
+
|
| 919 |
+
oneformer.load_state_dict(dst_state_dict)
|
| 920 |
+
|
| 921 |
+
return oneformer
|
| 922 |
+
|
| 923 |
+
@staticmethod
|
| 924 |
+
def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object, Path, Path]]:
|
| 925 |
+
checkpoints: List[Path] = checkpoints_dir.glob("**/*.pth")
|
| 926 |
+
|
| 927 |
+
for checkpoint in checkpoints:
|
| 928 |
+
logger.info(f"💪 Converting {checkpoint.stem}")
|
| 929 |
+
# find associated config file
|
| 930 |
+
config: Path = config_dir / f"{checkpoint.stem}.yaml"
|
| 931 |
+
|
| 932 |
+
yield config, checkpoint
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def post_process_sem_seg_output(outputs: OneFormerForUniversalSegmentationOutput, target_size: Tuple[int, int]):
|
| 936 |
+
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
|
| 937 |
+
class_queries_logits = outputs.class_queries_logits
|
| 938 |
+
# masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
|
| 939 |
+
masks_queries_logits = outputs.masks_queries_logits
|
| 940 |
+
if target_size is not None:
|
| 941 |
+
masks_queries_logits = torch.nn.functional.interpolate(
|
| 942 |
+
masks_queries_logits,
|
| 943 |
+
size=target_size,
|
| 944 |
+
mode="bilinear",
|
| 945 |
+
align_corners=False,
|
| 946 |
+
)
|
| 947 |
+
# remove the null class `[..., :-1]`
|
| 948 |
+
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
|
| 949 |
+
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
|
| 950 |
+
masks_probs = masks_queries_logits.sigmoid()
|
| 951 |
+
# now we want to sum over the queries,
|
| 952 |
+
# $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
|
| 953 |
+
# where $ softmax(p) \in R^{q, c} $ is the mask classes
|
| 954 |
+
# and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
|
| 955 |
+
# b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
|
| 956 |
+
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
| 957 |
+
|
| 958 |
+
return segmentation
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def test(
|
| 962 |
+
original_model,
|
| 963 |
+
our_model: OneFormerForUniversalSegmentation,
|
| 964 |
+
processor: OneFormerProcessor,
|
| 965 |
+
model_repo: str,
|
| 966 |
+
):
|
| 967 |
+
def _preprocess_text(text_list=None, max_length=77):
|
| 968 |
+
if text_list is None:
|
| 969 |
+
raise ValueError("tokens cannot be None.")
|
| 970 |
+
|
| 971 |
+
tokens = tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
|
| 972 |
+
|
| 973 |
+
attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
|
| 974 |
+
|
| 975 |
+
token_inputs = []
|
| 976 |
+
for attn_mask, input_id in zip(attention_masks, input_ids):
|
| 977 |
+
token = torch.tensor(attn_mask) * torch.tensor(input_id)
|
| 978 |
+
token_inputs.append(token.unsqueeze(0))
|
| 979 |
+
|
| 980 |
+
token_inputs = torch.cat(token_inputs, dim=0)
|
| 981 |
+
return token_inputs
|
| 982 |
+
|
| 983 |
+
with torch.no_grad():
|
| 984 |
+
tokenizer = CLIPTokenizer.from_pretrained(model_repo)
|
| 985 |
+
original_model = original_model.eval()
|
| 986 |
+
our_model = our_model.eval()
|
| 987 |
+
|
| 988 |
+
im = prepare_img()
|
| 989 |
+
|
| 990 |
+
tr = T.Compose(
|
| 991 |
+
[
|
| 992 |
+
T.Resize((640, 640)),
|
| 993 |
+
T.ToTensor(),
|
| 994 |
+
T.Normalize(
|
| 995 |
+
mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0,
|
| 996 |
+
std=torch.tensor([58.395, 57.120, 57.375]) / 255.0,
|
| 997 |
+
),
|
| 998 |
+
],
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
x = tr(im).unsqueeze(0)
|
| 1002 |
+
|
| 1003 |
+
task_input = ["the task is semantic"]
|
| 1004 |
+
task_token = _preprocess_text(task_input, max_length=processor.task_seq_length)
|
| 1005 |
+
|
| 1006 |
+
original_model_backbone_features = original_model.backbone(x.clone())
|
| 1007 |
+
|
| 1008 |
+
our_model_output: OneFormerModelOutput = our_model.model(x.clone(), task_token, output_hidden_states=True)
|
| 1009 |
+
|
| 1010 |
+
for original_model_feature, our_model_feature in zip(
|
| 1011 |
+
original_model_backbone_features.values(), our_model_output.encoder_hidden_states
|
| 1012 |
+
):
|
| 1013 |
+
assert torch.allclose(original_model_feature, our_model_feature, atol=3e-3), (
|
| 1014 |
+
"The backbone features are not the same."
|
| 1015 |
+
)
|
| 1016 |
+
mask_features, _, multi_scale_features, _, _ = original_model.sem_seg_head.pixel_decoder.forward_features(
|
| 1017 |
+
original_model_backbone_features
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
original_pixel_decoder_features = []
|
| 1021 |
+
original_pixel_decoder_features.append(mask_features)
|
| 1022 |
+
for i in range(len(multi_scale_features)):
|
| 1023 |
+
original_pixel_decoder_features.append(multi_scale_features[i])
|
| 1024 |
+
|
| 1025 |
+
for original_model_feature, our_model_feature in zip(
|
| 1026 |
+
original_pixel_decoder_features, our_model_output.pixel_decoder_hidden_states
|
| 1027 |
+
):
|
| 1028 |
+
assert torch.allclose(original_model_feature, our_model_feature, atol=3e-4), (
|
| 1029 |
+
"The pixel decoder feature are not the same"
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
+
tr_complete = T.Compose(
|
| 1033 |
+
[
|
| 1034 |
+
T.Resize((640, 640)),
|
| 1035 |
+
T.ToTensor(),
|
| 1036 |
+
],
|
| 1037 |
+
)
|
| 1038 |
+
|
| 1039 |
+
y = (tr_complete(im) * 255.0).to(torch.int).float()
|
| 1040 |
+
|
| 1041 |
+
# let's test the full model
|
| 1042 |
+
original_model_out = original_model([{"image": y.clone(), "task": "The task is semantic"}])
|
| 1043 |
+
|
| 1044 |
+
original_segmentation = original_model_out[0]["sem_seg"]
|
| 1045 |
+
|
| 1046 |
+
our_model_out: OneFormerForUniversalSegmentationOutput = our_model(
|
| 1047 |
+
x.clone(), task_token, output_hidden_states=True
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
our_segmentation = post_process_sem_seg_output(our_model_out, target_size=(640, 640))[0]
|
| 1051 |
+
|
| 1052 |
+
assert torch.allclose(original_segmentation, our_segmentation, atol=1e-3), (
|
| 1053 |
+
"The segmentation image is not the same."
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
+
logger.info("✅ Test passed!")
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def get_name(checkpoint_file: Path):
|
| 1060 |
+
model_name_raw: str = checkpoint_file.stem
|
| 1061 |
+
|
| 1062 |
+
backbone = "swin" if "swin" in model_name_raw else "dinat"
|
| 1063 |
+
dataset = ""
|
| 1064 |
+
if "coco" in model_name_raw:
|
| 1065 |
+
dataset = "coco"
|
| 1066 |
+
elif "ade20k" in model_name_raw:
|
| 1067 |
+
dataset = "ade20k"
|
| 1068 |
+
elif "cityscapes" in model_name_raw:
|
| 1069 |
+
dataset = "cityscapes"
|
| 1070 |
+
else:
|
| 1071 |
+
raise ValueError(
|
| 1072 |
+
f"{model_name_raw} must be wrong since we didn't find 'coco' or 'ade20k' or 'cityscapes' in it "
|
| 1073 |
+
)
|
| 1074 |
+
|
| 1075 |
+
backbone_types = ["tiny", "large"]
|
| 1076 |
+
|
| 1077 |
+
backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0]
|
| 1078 |
+
|
| 1079 |
+
model_name = f"oneformer_{dataset}_{backbone}_{backbone_type}"
|
| 1080 |
+
|
| 1081 |
+
return model_name
|
| 1082 |
+
|
| 1083 |
+
|
| 1084 |
+
if __name__ == "__main__":
|
| 1085 |
+
parser = ArgumentParser(
|
| 1086 |
+
description=(
|
| 1087 |
+
"Command line to convert the original oneformer models (with swin backbone) to transformers"
|
| 1088 |
+
" implementation."
|
| 1089 |
+
)
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
parser.add_argument(
|
| 1093 |
+
"--checkpoints_dir",
|
| 1094 |
+
type=Path,
|
| 1095 |
+
help=(
|
| 1096 |
+
"A directory containing the model's checkpoints. The directory has to have the following structure:"
|
| 1097 |
+
" structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.pth; where <CONFIG_NAME> name must follow the"
|
| 1098 |
+
" following nomenclature nomenclature: oneformer_<DATASET_NAME>_<BACKBONE>_<BACKBONE_TYPE>"
|
| 1099 |
+
),
|
| 1100 |
+
)
|
| 1101 |
+
parser.add_argument(
|
| 1102 |
+
"--configs_dir",
|
| 1103 |
+
type=Path,
|
| 1104 |
+
help=(
|
| 1105 |
+
"A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
|
| 1106 |
+
" structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.yaml; where <CONFIG_NAME> name must follow the"
|
| 1107 |
+
" following nomenclature nomenclature: oneformer_<DATASET_NAME>_<BACKBONE>_<BACKBONE_TYPE>"
|
| 1108 |
+
),
|
| 1109 |
+
)
|
| 1110 |
+
parser.add_argument(
|
| 1111 |
+
"--pytorch_dump_folder_path",
|
| 1112 |
+
required=True,
|
| 1113 |
+
type=Path,
|
| 1114 |
+
help="Path to the folder to output PyTorch models.",
|
| 1115 |
+
)
|
| 1116 |
+
parser.add_argument(
|
| 1117 |
+
"--oneformer_dir",
|
| 1118 |
+
required=True,
|
| 1119 |
+
type=Path,
|
| 1120 |
+
help=(
|
| 1121 |
+
"A path to OneFormer's original implementation directory. You can download from here: "
|
| 1122 |
+
"https://github.com/SHI-Labs/OneFormer"
|
| 1123 |
+
),
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
args = parser.parse_args()
|
| 1127 |
+
|
| 1128 |
+
checkpoints_dir: Path = args.checkpoints_dir
|
| 1129 |
+
config_dir: Path = args.configs_dir
|
| 1130 |
+
save_directory: Path = args.pytorch_dump_folder_path
|
| 1131 |
+
oneformer_dir: Path = args.oneformer_dir
|
| 1132 |
+
# append the path to the parents to oneformer dir
|
| 1133 |
+
sys.path.append(str(oneformer_dir.parent))
|
| 1134 |
+
# and import what's needed
|
| 1135 |
+
from OneFormer.oneformer import add_common_config, add_dinat_config, add_oneformer_config, add_swin_config
|
| 1136 |
+
from OneFormer.oneformer.oneformer_model import OneFormer as OriginalOneFormer
|
| 1137 |
+
|
| 1138 |
+
if not save_directory.exists():
|
| 1139 |
+
save_directory.mkdir(parents=True)
|
| 1140 |
+
|
| 1141 |
+
for config_file, checkpoint_file in OriginalOneFormerCheckpointToOursConverter.using_dirs(
|
| 1142 |
+
checkpoints_dir, config_dir
|
| 1143 |
+
):
|
| 1144 |
+
processor = OriginalOneFormerConfigToProcessorConverter()(
|
| 1145 |
+
setup_cfg(Args(config_file=config_file)), os.path.join("shi-labs", config_file.stem)
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
original_config = setup_cfg(Args(config_file=config_file))
|
| 1149 |
+
oneformer_kwargs = OriginalOneFormer.from_config(original_config)
|
| 1150 |
+
|
| 1151 |
+
original_model = OriginalOneFormer(**oneformer_kwargs).eval()
|
| 1152 |
+
|
| 1153 |
+
DetectionCheckpointer(original_model).load(str(checkpoint_file))
|
| 1154 |
+
|
| 1155 |
+
is_swin = "swin" in config_file.stem
|
| 1156 |
+
|
| 1157 |
+
config: OneFormerConfig = OriginalOneFormerConfigToOursConverter()(original_config, is_swin)
|
| 1158 |
+
|
| 1159 |
+
oneformer = OneFormerModel(config=config).eval()
|
| 1160 |
+
|
| 1161 |
+
converter = OriginalOneFormerCheckpointToOursConverter(original_model, config)
|
| 1162 |
+
|
| 1163 |
+
oneformer = converter.convert(oneformer, is_swin)
|
| 1164 |
+
|
| 1165 |
+
oneformer_for_universal_segmentation = OneFormerForUniversalSegmentation(config=config).eval()
|
| 1166 |
+
|
| 1167 |
+
oneformer_for_universal_segmentation.model = oneformer
|
| 1168 |
+
|
| 1169 |
+
test(
|
| 1170 |
+
original_model,
|
| 1171 |
+
oneformer_for_universal_segmentation,
|
| 1172 |
+
processor,
|
| 1173 |
+
os.path.join("shi-labs", config_file.stem),
|
| 1174 |
+
)
|
| 1175 |
+
|
| 1176 |
+
model_name = get_name(checkpoint_file)
|
| 1177 |
+
logger.info(f"🪄 Saving {model_name}")
|
| 1178 |
+
|
| 1179 |
+
processor.save_pretrained(save_directory / model_name)
|
| 1180 |
+
oneformer_for_universal_segmentation.save_pretrained(save_directory / model_name)
|
| 1181 |
+
|
| 1182 |
+
processor.push_to_hub(
|
| 1183 |
+
repo_id=os.path.join("shi-labs", config_file.stem),
|
| 1184 |
+
commit_message="Add configs",
|
| 1185 |
+
use_temp_dir=True,
|
| 1186 |
+
)
|
| 1187 |
+
oneformer_for_universal_segmentation.push_to_hub(
|
| 1188 |
+
repo_id=os.path.join("shi-labs", config_file.stem),
|
| 1189 |
+
commit_message="Add model",
|
| 1190 |
+
use_temp_dir=True,
|
| 1191 |
+
)
|
docs/transformers/src/transformers/models/oneformer/image_processing_oneformer.py
ADDED
|
@@ -0,0 +1,1356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for OneFormer."""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
from huggingface_hub import hf_hub_download
|
| 23 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 24 |
+
|
| 25 |
+
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
| 26 |
+
from ...image_transforms import (
|
| 27 |
+
PaddingMode,
|
| 28 |
+
get_resize_output_image_size,
|
| 29 |
+
pad,
|
| 30 |
+
rescale,
|
| 31 |
+
resize,
|
| 32 |
+
to_channel_dimension_format,
|
| 33 |
+
)
|
| 34 |
+
from ...image_utils import (
|
| 35 |
+
ChannelDimension,
|
| 36 |
+
ImageInput,
|
| 37 |
+
PILImageResampling,
|
| 38 |
+
get_image_size,
|
| 39 |
+
infer_channel_dimension_format,
|
| 40 |
+
is_scaled_image,
|
| 41 |
+
make_list_of_images,
|
| 42 |
+
to_numpy_array,
|
| 43 |
+
valid_images,
|
| 44 |
+
validate_preprocess_arguments,
|
| 45 |
+
)
|
| 46 |
+
from ...utils import (
|
| 47 |
+
IMAGENET_DEFAULT_MEAN,
|
| 48 |
+
IMAGENET_DEFAULT_STD,
|
| 49 |
+
TensorType,
|
| 50 |
+
filter_out_non_signature_kwargs,
|
| 51 |
+
is_torch_available,
|
| 52 |
+
is_torch_tensor,
|
| 53 |
+
logging,
|
| 54 |
+
)
|
| 55 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
logger = logging.get_logger(__name__)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if is_torch_available():
|
| 62 |
+
import torch
|
| 63 |
+
from torch import nn
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Copied from transformers.models.detr.image_processing_detr.max_across_indices
|
| 67 |
+
def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
| 68 |
+
"""
|
| 69 |
+
Return the maximum value across all indices of an iterable of values.
|
| 70 |
+
"""
|
| 71 |
+
return [max(values_i) for values_i in zip(*values)]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
| 75 |
+
def get_max_height_width(
|
| 76 |
+
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
| 77 |
+
) -> List[int]:
|
| 78 |
+
"""
|
| 79 |
+
Get the maximum height and width across all images in a batch.
|
| 80 |
+
"""
|
| 81 |
+
if input_data_format is None:
|
| 82 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 83 |
+
|
| 84 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 85 |
+
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
| 86 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 87 |
+
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
| 88 |
+
else:
|
| 89 |
+
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
| 90 |
+
return (max_height, max_width)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
| 94 |
+
def make_pixel_mask(
|
| 95 |
+
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
| 96 |
+
) -> np.ndarray:
|
| 97 |
+
"""
|
| 98 |
+
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
image (`np.ndarray`):
|
| 102 |
+
Image to make the pixel mask for.
|
| 103 |
+
output_size (`Tuple[int, int]`):
|
| 104 |
+
Output size of the mask.
|
| 105 |
+
"""
|
| 106 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 107 |
+
mask = np.zeros(output_size, dtype=np.int64)
|
| 108 |
+
mask[:input_height, :input_width] = 1
|
| 109 |
+
return mask
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
|
| 113 |
+
def binary_mask_to_rle(mask):
|
| 114 |
+
"""
|
| 115 |
+
Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
mask (`torch.Tensor` or `numpy.array`):
|
| 119 |
+
A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
|
| 120 |
+
segment_id or class_id.
|
| 121 |
+
Returns:
|
| 122 |
+
`List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
|
| 123 |
+
format.
|
| 124 |
+
"""
|
| 125 |
+
if is_torch_tensor(mask):
|
| 126 |
+
mask = mask.numpy()
|
| 127 |
+
|
| 128 |
+
pixels = mask.flatten()
|
| 129 |
+
pixels = np.concatenate([[0], pixels, [0]])
|
| 130 |
+
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
|
| 131 |
+
runs[1::2] -= runs[::2]
|
| 132 |
+
return list(runs)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
|
| 136 |
+
def convert_segmentation_to_rle(segmentation):
|
| 137 |
+
"""
|
| 138 |
+
Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
segmentation (`torch.Tensor` or `numpy.array`):
|
| 142 |
+
A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
|
| 143 |
+
Returns:
|
| 144 |
+
`List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
|
| 145 |
+
"""
|
| 146 |
+
segment_ids = torch.unique(segmentation)
|
| 147 |
+
|
| 148 |
+
run_length_encodings = []
|
| 149 |
+
for idx in segment_ids:
|
| 150 |
+
mask = torch.where(segmentation == idx, 1, 0)
|
| 151 |
+
rle = binary_mask_to_rle(mask)
|
| 152 |
+
run_length_encodings.append(rle)
|
| 153 |
+
|
| 154 |
+
return run_length_encodings
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
|
| 158 |
+
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
|
| 159 |
+
"""
|
| 160 |
+
Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and
|
| 161 |
+
`labels`.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
masks (`torch.Tensor`):
|
| 165 |
+
A tensor of shape `(num_queries, height, width)`.
|
| 166 |
+
scores (`torch.Tensor`):
|
| 167 |
+
A tensor of shape `(num_queries)`.
|
| 168 |
+
labels (`torch.Tensor`):
|
| 169 |
+
A tensor of shape `(num_queries)`.
|
| 170 |
+
object_mask_threshold (`float`):
|
| 171 |
+
A number between 0 and 1 used to binarize the masks.
|
| 172 |
+
Raises:
|
| 173 |
+
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
|
| 174 |
+
Returns:
|
| 175 |
+
`Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
|
| 176 |
+
< `object_mask_threshold`.
|
| 177 |
+
"""
|
| 178 |
+
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
|
| 179 |
+
raise ValueError("mask, scores and labels must have the same shape!")
|
| 180 |
+
|
| 181 |
+
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
|
| 182 |
+
|
| 183 |
+
return masks[to_keep], scores[to_keep], labels[to_keep]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Copied from transformers.models.detr.image_processing_detr.check_segment_validity
|
| 187 |
+
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
|
| 188 |
+
# Get the mask associated with the k class
|
| 189 |
+
mask_k = mask_labels == k
|
| 190 |
+
mask_k_area = mask_k.sum()
|
| 191 |
+
|
| 192 |
+
# Compute the area of all the stuff in query k
|
| 193 |
+
original_area = (mask_probs[k] >= mask_threshold).sum()
|
| 194 |
+
mask_exists = mask_k_area > 0 and original_area > 0
|
| 195 |
+
|
| 196 |
+
# Eliminate disconnected tiny segments
|
| 197 |
+
if mask_exists:
|
| 198 |
+
area_ratio = mask_k_area / original_area
|
| 199 |
+
if not area_ratio.item() > overlap_mask_area_threshold:
|
| 200 |
+
mask_exists = False
|
| 201 |
+
|
| 202 |
+
return mask_exists, mask_k
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Copied from transformers.models.detr.image_processing_detr.compute_segments
|
| 206 |
+
def compute_segments(
|
| 207 |
+
mask_probs,
|
| 208 |
+
pred_scores,
|
| 209 |
+
pred_labels,
|
| 210 |
+
mask_threshold: float = 0.5,
|
| 211 |
+
overlap_mask_area_threshold: float = 0.8,
|
| 212 |
+
label_ids_to_fuse: Optional[Set[int]] = None,
|
| 213 |
+
target_size: Tuple[int, int] = None,
|
| 214 |
+
):
|
| 215 |
+
height = mask_probs.shape[1] if target_size is None else target_size[0]
|
| 216 |
+
width = mask_probs.shape[2] if target_size is None else target_size[1]
|
| 217 |
+
|
| 218 |
+
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
|
| 219 |
+
segments: List[Dict] = []
|
| 220 |
+
|
| 221 |
+
if target_size is not None:
|
| 222 |
+
mask_probs = nn.functional.interpolate(
|
| 223 |
+
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
|
| 224 |
+
)[0]
|
| 225 |
+
|
| 226 |
+
current_segment_id = 0
|
| 227 |
+
|
| 228 |
+
# Weigh each mask by its prediction score
|
| 229 |
+
mask_probs *= pred_scores.view(-1, 1, 1)
|
| 230 |
+
mask_labels = mask_probs.argmax(0) # [height, width]
|
| 231 |
+
|
| 232 |
+
# Keep track of instances of each class
|
| 233 |
+
stuff_memory_list: Dict[str, int] = {}
|
| 234 |
+
for k in range(pred_labels.shape[0]):
|
| 235 |
+
pred_class = pred_labels[k].item()
|
| 236 |
+
should_fuse = pred_class in label_ids_to_fuse
|
| 237 |
+
|
| 238 |
+
# Check if mask exists and large enough to be a segment
|
| 239 |
+
mask_exists, mask_k = check_segment_validity(
|
| 240 |
+
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
if mask_exists:
|
| 244 |
+
if pred_class in stuff_memory_list:
|
| 245 |
+
current_segment_id = stuff_memory_list[pred_class]
|
| 246 |
+
else:
|
| 247 |
+
current_segment_id += 1
|
| 248 |
+
|
| 249 |
+
# Add current object segment to final segmentation map
|
| 250 |
+
segmentation[mask_k] = current_segment_id
|
| 251 |
+
segment_score = round(pred_scores[k].item(), 6)
|
| 252 |
+
segments.append(
|
| 253 |
+
{
|
| 254 |
+
"id": current_segment_id,
|
| 255 |
+
"label_id": pred_class,
|
| 256 |
+
"was_fused": should_fuse,
|
| 257 |
+
"score": segment_score,
|
| 258 |
+
}
|
| 259 |
+
)
|
| 260 |
+
if should_fuse:
|
| 261 |
+
stuff_memory_list[pred_class] = current_segment_id
|
| 262 |
+
|
| 263 |
+
return segmentation, segments
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# Copied from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks
|
| 267 |
+
def convert_segmentation_map_to_binary_masks(
|
| 268 |
+
segmentation_map: "np.ndarray",
|
| 269 |
+
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
| 270 |
+
ignore_index: Optional[int] = None,
|
| 271 |
+
do_reduce_labels: bool = False,
|
| 272 |
+
):
|
| 273 |
+
if do_reduce_labels and ignore_index is None:
|
| 274 |
+
raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")
|
| 275 |
+
|
| 276 |
+
if do_reduce_labels:
|
| 277 |
+
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
|
| 278 |
+
|
| 279 |
+
# Get unique ids (class or instance ids based on input)
|
| 280 |
+
all_labels = np.unique(segmentation_map)
|
| 281 |
+
|
| 282 |
+
# Drop background label if applicable
|
| 283 |
+
if ignore_index is not None:
|
| 284 |
+
all_labels = all_labels[all_labels != ignore_index]
|
| 285 |
+
|
| 286 |
+
# Generate a binary mask for each object instance
|
| 287 |
+
binary_masks = [(segmentation_map == i) for i in all_labels]
|
| 288 |
+
|
| 289 |
+
# Stack the binary masks
|
| 290 |
+
if binary_masks:
|
| 291 |
+
binary_masks = np.stack(binary_masks, axis=0)
|
| 292 |
+
else:
|
| 293 |
+
binary_masks = np.zeros((0, *segmentation_map.shape))
|
| 294 |
+
|
| 295 |
+
# Convert instance ids to class ids
|
| 296 |
+
if instance_id_to_semantic_id is not None:
|
| 297 |
+
labels = np.zeros(all_labels.shape[0])
|
| 298 |
+
|
| 299 |
+
for label in all_labels:
|
| 300 |
+
class_id = instance_id_to_semantic_id[label + 1 if do_reduce_labels else label]
|
| 301 |
+
labels[all_labels == label] = class_id - 1 if do_reduce_labels else class_id
|
| 302 |
+
else:
|
| 303 |
+
labels = all_labels
|
| 304 |
+
|
| 305 |
+
return binary_masks.astype(np.float32), labels.astype(np.int64)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_oneformer_resize_output_image_size(
|
| 309 |
+
image: np.ndarray,
|
| 310 |
+
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
| 311 |
+
max_size: Optional[int] = None,
|
| 312 |
+
default_to_square: bool = True,
|
| 313 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 314 |
+
) -> tuple:
|
| 315 |
+
"""
|
| 316 |
+
Computes the output size given the desired size.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
image (`np.ndarray`):
|
| 320 |
+
The input image.
|
| 321 |
+
size (`int` or `Tuple[int, int]` or `List[int]` or `Tuple[int]`):
|
| 322 |
+
The size of the output image.
|
| 323 |
+
max_size (`int`, *optional*):
|
| 324 |
+
The maximum size of the output image.
|
| 325 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 326 |
+
Whether to default to square if no size is provided.
|
| 327 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 328 |
+
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
`Tuple[int, int]`: The output size.
|
| 332 |
+
"""
|
| 333 |
+
output_size = get_resize_output_image_size(
|
| 334 |
+
input_image=image,
|
| 335 |
+
size=size,
|
| 336 |
+
default_to_square=default_to_square,
|
| 337 |
+
max_size=max_size,
|
| 338 |
+
input_data_format=input_data_format,
|
| 339 |
+
)
|
| 340 |
+
return output_size
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def prepare_metadata(class_info):
|
| 344 |
+
metadata = {}
|
| 345 |
+
class_names = []
|
| 346 |
+
thing_ids = []
|
| 347 |
+
for key, info in class_info.items():
|
| 348 |
+
metadata[key] = info["name"]
|
| 349 |
+
class_names.append(info["name"])
|
| 350 |
+
if info["isthing"]:
|
| 351 |
+
thing_ids.append(int(key))
|
| 352 |
+
metadata["thing_ids"] = thing_ids
|
| 353 |
+
metadata["class_names"] = class_names
|
| 354 |
+
return metadata
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def load_metadata(repo_id, class_info_file):
|
| 358 |
+
fname = os.path.join("" if repo_id is None else repo_id, class_info_file)
|
| 359 |
+
|
| 360 |
+
if not os.path.exists(fname) or not os.path.isfile(fname):
|
| 361 |
+
if repo_id is None:
|
| 362 |
+
raise ValueError(f"Could not file {fname} locally. repo_id must be defined if loading from the hub")
|
| 363 |
+
# We try downloading from a dataset by default for backward compatibility
|
| 364 |
+
try:
|
| 365 |
+
fname = hf_hub_download(repo_id, class_info_file, repo_type="dataset")
|
| 366 |
+
except RepositoryNotFoundError:
|
| 367 |
+
fname = hf_hub_download(repo_id, class_info_file)
|
| 368 |
+
|
| 369 |
+
with open(fname, "r") as f:
|
| 370 |
+
class_info = json.load(f)
|
| 371 |
+
|
| 372 |
+
return class_info
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class OneFormerImageProcessor(BaseImageProcessor):
|
| 376 |
+
r"""
|
| 377 |
+
Constructs a OneFormer image processor. The image processor can be used to prepare image(s), task input(s) and
|
| 378 |
+
optional text inputs and targets for the model.
|
| 379 |
+
|
| 380 |
+
This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
|
| 381 |
+
refer to this superclass for more information regarding those methods.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 385 |
+
Whether to resize the input to a certain `size`.
|
| 386 |
+
size (`int`, *optional*, defaults to 800):
|
| 387 |
+
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
|
| 388 |
+
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
|
| 389 |
+
the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
|
| 390 |
+
height / width, size)`.
|
| 391 |
+
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 392 |
+
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
|
| 393 |
+
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
|
| 394 |
+
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
|
| 395 |
+
to `True`.
|
| 396 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 397 |
+
Whether to rescale the input to a certain `scale`.
|
| 398 |
+
rescale_factor (`float`, *optional*, defaults to `1/ 255`):
|
| 399 |
+
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
|
| 400 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 401 |
+
Whether or not to normalize the input with mean and standard deviation.
|
| 402 |
+
image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
| 403 |
+
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
|
| 404 |
+
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
| 405 |
+
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
|
| 406 |
+
ImageNet std.
|
| 407 |
+
ignore_index (`int`, *optional*):
|
| 408 |
+
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
|
| 409 |
+
denoted with 0 (background) will be replaced with `ignore_index`.
|
| 410 |
+
do_reduce_labels (`bool`, *optional*, defaults to `False`):
|
| 411 |
+
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
|
| 412 |
+
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
|
| 413 |
+
The background label will be replaced by `ignore_index`.
|
| 414 |
+
repo_path (`str`, *optional*, defaults to `"shi-labs/oneformer_demo"`):
|
| 415 |
+
Path to hub repo or local directory containing the JSON file with class information for the dataset.
|
| 416 |
+
If unset, will look for `class_info_file` in the current working directory.
|
| 417 |
+
class_info_file (`str`, *optional*):
|
| 418 |
+
JSON file containing class information for the dataset. See `shi-labs/oneformer_demo/cityscapes_panoptic.json` for an example.
|
| 419 |
+
num_text (`int`, *optional*):
|
| 420 |
+
Number of text entries in the text input list.
|
| 421 |
+
num_labels (`int`, *optional*):
|
| 422 |
+
The number of labels in the segmentation map.
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
model_input_names = ["pixel_values", "pixel_mask", "task_inputs"]
|
| 426 |
+
|
| 427 |
+
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
|
| 428 |
+
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
| 429 |
+
@filter_out_non_signature_kwargs(extra=["max_size", "metadata", *INIT_SERVICE_KWARGS])
|
| 430 |
+
def __init__(
|
| 431 |
+
self,
|
| 432 |
+
do_resize: bool = True,
|
| 433 |
+
size: Dict[str, int] = None,
|
| 434 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 435 |
+
do_rescale: bool = True,
|
| 436 |
+
rescale_factor: float = 1 / 255,
|
| 437 |
+
do_normalize: bool = True,
|
| 438 |
+
image_mean: Union[float, List[float]] = None,
|
| 439 |
+
image_std: Union[float, List[float]] = None,
|
| 440 |
+
ignore_index: Optional[int] = None,
|
| 441 |
+
do_reduce_labels: bool = False,
|
| 442 |
+
repo_path: Optional[str] = "shi-labs/oneformer_demo",
|
| 443 |
+
class_info_file: Optional[str] = None,
|
| 444 |
+
num_text: Optional[int] = None,
|
| 445 |
+
num_labels: Optional[int] = None,
|
| 446 |
+
**kwargs,
|
| 447 |
+
):
|
| 448 |
+
super().__init__(**kwargs)
|
| 449 |
+
|
| 450 |
+
# Deprecated, backward compatibility
|
| 451 |
+
self._max_size = kwargs.pop("max_size", 1333)
|
| 452 |
+
|
| 453 |
+
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
|
| 454 |
+
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
|
| 455 |
+
|
| 456 |
+
if class_info_file is None:
|
| 457 |
+
raise ValueError("You must provide a `class_info_file`")
|
| 458 |
+
|
| 459 |
+
self.do_resize = do_resize
|
| 460 |
+
self.size = size
|
| 461 |
+
self.resample = resample
|
| 462 |
+
self.do_rescale = do_rescale
|
| 463 |
+
self.rescale_factor = rescale_factor
|
| 464 |
+
self.do_normalize = do_normalize
|
| 465 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
| 466 |
+
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
| 467 |
+
self.ignore_index = ignore_index
|
| 468 |
+
self.do_reduce_labels = do_reduce_labels
|
| 469 |
+
self.class_info_file = class_info_file
|
| 470 |
+
self.repo_path = repo_path
|
| 471 |
+
self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file))
|
| 472 |
+
self.num_text = num_text
|
| 473 |
+
self.num_labels = num_labels
|
| 474 |
+
|
| 475 |
+
@classmethod
|
| 476 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 477 |
+
"""
|
| 478 |
+
Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
|
| 479 |
+
"""
|
| 480 |
+
image_processor_dict = image_processor_dict.copy()
|
| 481 |
+
if "reduce_labels" in image_processor_dict:
|
| 482 |
+
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
| 483 |
+
return super().from_dict(image_processor_dict, **kwargs)
|
| 484 |
+
|
| 485 |
+
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.to_dict
|
| 486 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 487 |
+
"""
|
| 488 |
+
Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the
|
| 489 |
+
`_max_size` attribute from the dictionary.
|
| 490 |
+
"""
|
| 491 |
+
image_processor_dict = super().to_dict()
|
| 492 |
+
image_processor_dict.pop("_max_size", None)
|
| 493 |
+
return image_processor_dict
|
| 494 |
+
|
| 495 |
+
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
|
| 496 |
+
@filter_out_non_signature_kwargs(extra=["max_size"])
|
| 497 |
+
def resize(
|
| 498 |
+
self,
|
| 499 |
+
image: np.ndarray,
|
| 500 |
+
size: Dict[str, int],
|
| 501 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 502 |
+
data_format=None,
|
| 503 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 504 |
+
**kwargs,
|
| 505 |
+
) -> np.ndarray:
|
| 506 |
+
"""
|
| 507 |
+
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
|
| 508 |
+
int, smaller edge of the image will be matched to this number.
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
# Deprecated, backward compatibility
|
| 512 |
+
max_size = kwargs.pop("max_size", None)
|
| 513 |
+
|
| 514 |
+
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
| 515 |
+
if "shortest_edge" in size and "longest_edge" in size:
|
| 516 |
+
size, max_size = size["shortest_edge"], size["longest_edge"]
|
| 517 |
+
elif "height" in size and "width" in size:
|
| 518 |
+
size = (size["height"], size["width"])
|
| 519 |
+
max_size = None
|
| 520 |
+
else:
|
| 521 |
+
raise ValueError(
|
| 522 |
+
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
|
| 523 |
+
f" {size.keys()}."
|
| 524 |
+
)
|
| 525 |
+
size = get_oneformer_resize_output_image_size(
|
| 526 |
+
image=image, size=size, max_size=max_size, default_to_square=False, input_data_format=input_data_format
|
| 527 |
+
)
|
| 528 |
+
image = resize(
|
| 529 |
+
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format
|
| 530 |
+
)
|
| 531 |
+
return image
|
| 532 |
+
|
| 533 |
+
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
| 534 |
+
def rescale(
|
| 535 |
+
self,
|
| 536 |
+
image: np.ndarray,
|
| 537 |
+
rescale_factor: float,
|
| 538 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 539 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 540 |
+
) -> np.ndarray:
|
| 541 |
+
"""
|
| 542 |
+
Rescale the image by the given factor. image = image * rescale_factor.
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
image (`np.ndarray`):
|
| 546 |
+
Image to rescale.
|
| 547 |
+
rescale_factor (`float`):
|
| 548 |
+
The value to use for rescaling.
|
| 549 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 550 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 551 |
+
image is used. Can be one of:
|
| 552 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 553 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 554 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 555 |
+
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
|
| 556 |
+
one of:
|
| 557 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 558 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 559 |
+
"""
|
| 560 |
+
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
|
| 561 |
+
|
| 562 |
+
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
|
| 563 |
+
def convert_segmentation_map_to_binary_masks(
|
| 564 |
+
self,
|
| 565 |
+
segmentation_map: "np.ndarray",
|
| 566 |
+
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
| 567 |
+
ignore_index: Optional[int] = None,
|
| 568 |
+
do_reduce_labels: bool = False,
|
| 569 |
+
):
|
| 570 |
+
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
| 571 |
+
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
| 572 |
+
return convert_segmentation_map_to_binary_masks(
|
| 573 |
+
segmentation_map=segmentation_map,
|
| 574 |
+
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
| 575 |
+
ignore_index=ignore_index,
|
| 576 |
+
do_reduce_labels=do_reduce_labels,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature:
|
| 580 |
+
return self.preprocess(images, task_inputs=task_inputs, segmentation_maps=segmentation_maps, **kwargs)
|
| 581 |
+
|
| 582 |
+
def _preprocess(
|
| 583 |
+
self,
|
| 584 |
+
image: ImageInput,
|
| 585 |
+
do_resize: Optional[bool] = None,
|
| 586 |
+
size: Dict[str, int] = None,
|
| 587 |
+
resample: PILImageResampling = None,
|
| 588 |
+
do_rescale: Optional[bool] = None,
|
| 589 |
+
rescale_factor: Optional[float] = None,
|
| 590 |
+
do_normalize: Optional[bool] = None,
|
| 591 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 592 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 593 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 594 |
+
):
|
| 595 |
+
if do_resize:
|
| 596 |
+
image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
| 597 |
+
if do_rescale:
|
| 598 |
+
image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
|
| 599 |
+
if do_normalize:
|
| 600 |
+
image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 601 |
+
return image
|
| 602 |
+
|
| 603 |
+
def _preprocess_image(
|
| 604 |
+
self,
|
| 605 |
+
image: ImageInput,
|
| 606 |
+
do_resize: Optional[bool] = None,
|
| 607 |
+
size: Dict[str, int] = None,
|
| 608 |
+
resample: PILImageResampling = None,
|
| 609 |
+
do_rescale: Optional[bool] = None,
|
| 610 |
+
rescale_factor: Optional[float] = None,
|
| 611 |
+
do_normalize: Optional[bool] = None,
|
| 612 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 613 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 614 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 615 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 616 |
+
) -> np.ndarray:
|
| 617 |
+
"""Preprocesses a single image."""
|
| 618 |
+
# All transformations expect numpy arrays.
|
| 619 |
+
image = to_numpy_array(image)
|
| 620 |
+
if do_rescale and is_scaled_image(image):
|
| 621 |
+
logger.warning_once(
|
| 622 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 623 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 624 |
+
)
|
| 625 |
+
if input_data_format is None:
|
| 626 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 627 |
+
image = self._preprocess(
|
| 628 |
+
image=image,
|
| 629 |
+
do_resize=do_resize,
|
| 630 |
+
size=size,
|
| 631 |
+
resample=resample,
|
| 632 |
+
do_rescale=do_rescale,
|
| 633 |
+
rescale_factor=rescale_factor,
|
| 634 |
+
do_normalize=do_normalize,
|
| 635 |
+
image_mean=image_mean,
|
| 636 |
+
image_std=image_std,
|
| 637 |
+
input_data_format=input_data_format,
|
| 638 |
+
)
|
| 639 |
+
if data_format is not None:
|
| 640 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 641 |
+
return image
|
| 642 |
+
|
| 643 |
+
def _preprocess_mask(
|
| 644 |
+
self,
|
| 645 |
+
segmentation_map: ImageInput,
|
| 646 |
+
do_resize: Optional[bool] = None,
|
| 647 |
+
size: Dict[str, int] = None,
|
| 648 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 649 |
+
) -> np.ndarray:
|
| 650 |
+
"""Preprocesses a single mask."""
|
| 651 |
+
segmentation_map = to_numpy_array(segmentation_map)
|
| 652 |
+
# Add channel dimension if missing - needed for certain transformations
|
| 653 |
+
if segmentation_map.ndim == 2:
|
| 654 |
+
added_channel_dim = True
|
| 655 |
+
segmentation_map = segmentation_map[None, ...]
|
| 656 |
+
input_data_format = ChannelDimension.FIRST
|
| 657 |
+
else:
|
| 658 |
+
added_channel_dim = False
|
| 659 |
+
if input_data_format is None:
|
| 660 |
+
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
| 661 |
+
# TODO: (Amy)
|
| 662 |
+
# Rework segmentation map processing to include reducing labels and resizing which doesn't
|
| 663 |
+
# drop segment IDs > 255.
|
| 664 |
+
segmentation_map = self._preprocess(
|
| 665 |
+
image=segmentation_map,
|
| 666 |
+
do_resize=do_resize,
|
| 667 |
+
resample=PILImageResampling.NEAREST,
|
| 668 |
+
size=size,
|
| 669 |
+
do_rescale=False,
|
| 670 |
+
do_normalize=False,
|
| 671 |
+
input_data_format=input_data_format,
|
| 672 |
+
)
|
| 673 |
+
# Remove extra channel dimension if added for processing
|
| 674 |
+
if added_channel_dim:
|
| 675 |
+
segmentation_map = segmentation_map.squeeze(0)
|
| 676 |
+
return segmentation_map
|
| 677 |
+
|
| 678 |
+
@filter_out_non_signature_kwargs()
|
| 679 |
+
def preprocess(
|
| 680 |
+
self,
|
| 681 |
+
images: ImageInput,
|
| 682 |
+
task_inputs: Optional[List[str]] = None,
|
| 683 |
+
segmentation_maps: Optional[ImageInput] = None,
|
| 684 |
+
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
|
| 685 |
+
do_resize: Optional[bool] = None,
|
| 686 |
+
size: Optional[Dict[str, int]] = None,
|
| 687 |
+
resample: PILImageResampling = None,
|
| 688 |
+
do_rescale: Optional[bool] = None,
|
| 689 |
+
rescale_factor: Optional[float] = None,
|
| 690 |
+
do_normalize: Optional[bool] = None,
|
| 691 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 692 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 693 |
+
ignore_index: Optional[int] = None,
|
| 694 |
+
do_reduce_labels: Optional[bool] = None,
|
| 695 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 696 |
+
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
| 697 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 698 |
+
) -> BatchFeature:
|
| 699 |
+
if task_inputs is None:
|
| 700 |
+
# Default value
|
| 701 |
+
task_inputs = ["panoptic"]
|
| 702 |
+
|
| 703 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 704 |
+
size = size if size is not None else self.size
|
| 705 |
+
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
|
| 706 |
+
resample = resample if resample is not None else self.resample
|
| 707 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 708 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 709 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 710 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 711 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 712 |
+
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
|
| 713 |
+
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
| 714 |
+
|
| 715 |
+
if not valid_images(images):
|
| 716 |
+
raise ValueError(
|
| 717 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 718 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
validate_preprocess_arguments(
|
| 722 |
+
do_rescale=do_rescale,
|
| 723 |
+
rescale_factor=rescale_factor,
|
| 724 |
+
do_normalize=do_normalize,
|
| 725 |
+
image_mean=image_mean,
|
| 726 |
+
image_std=image_std,
|
| 727 |
+
do_resize=do_resize,
|
| 728 |
+
size=size,
|
| 729 |
+
resample=resample,
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
if segmentation_maps is not None and not valid_images(segmentation_maps):
|
| 733 |
+
raise ValueError(
|
| 734 |
+
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 735 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
images = make_list_of_images(images)
|
| 739 |
+
if segmentation_maps is not None:
|
| 740 |
+
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
| 741 |
+
|
| 742 |
+
if segmentation_maps is not None and len(images) != len(segmentation_maps):
|
| 743 |
+
raise ValueError("Images and segmentation maps must have the same length.")
|
| 744 |
+
|
| 745 |
+
images = [
|
| 746 |
+
self._preprocess_image(
|
| 747 |
+
image,
|
| 748 |
+
do_resize=do_resize,
|
| 749 |
+
size=size,
|
| 750 |
+
resample=resample,
|
| 751 |
+
do_rescale=do_rescale,
|
| 752 |
+
rescale_factor=rescale_factor,
|
| 753 |
+
do_normalize=do_normalize,
|
| 754 |
+
image_mean=image_mean,
|
| 755 |
+
image_std=image_std,
|
| 756 |
+
data_format=data_format,
|
| 757 |
+
input_data_format=input_data_format,
|
| 758 |
+
)
|
| 759 |
+
for image in images
|
| 760 |
+
]
|
| 761 |
+
|
| 762 |
+
if segmentation_maps is not None:
|
| 763 |
+
segmentation_maps = [
|
| 764 |
+
self._preprocess_mask(segmentation_map, do_resize, size, input_data_format=input_data_format)
|
| 765 |
+
for segmentation_map in segmentation_maps
|
| 766 |
+
]
|
| 767 |
+
encoded_inputs = self.encode_inputs(
|
| 768 |
+
images,
|
| 769 |
+
task_inputs,
|
| 770 |
+
segmentation_maps,
|
| 771 |
+
instance_id_to_semantic_id,
|
| 772 |
+
ignore_index,
|
| 773 |
+
do_reduce_labels,
|
| 774 |
+
return_tensors,
|
| 775 |
+
input_data_format=data_format,
|
| 776 |
+
)
|
| 777 |
+
return encoded_inputs
|
| 778 |
+
|
| 779 |
+
# Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
|
| 780 |
+
def _pad_image(
|
| 781 |
+
self,
|
| 782 |
+
image: np.ndarray,
|
| 783 |
+
output_size: Tuple[int, int],
|
| 784 |
+
constant_values: Union[float, Iterable[float]] = 0,
|
| 785 |
+
data_format: Optional[ChannelDimension] = None,
|
| 786 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 787 |
+
) -> np.ndarray:
|
| 788 |
+
"""
|
| 789 |
+
Pad an image with zeros to the given size.
|
| 790 |
+
"""
|
| 791 |
+
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
| 792 |
+
output_height, output_width = output_size
|
| 793 |
+
|
| 794 |
+
pad_bottom = output_height - input_height
|
| 795 |
+
pad_right = output_width - input_width
|
| 796 |
+
padding = ((0, pad_bottom), (0, pad_right))
|
| 797 |
+
padded_image = pad(
|
| 798 |
+
image,
|
| 799 |
+
padding,
|
| 800 |
+
mode=PaddingMode.CONSTANT,
|
| 801 |
+
constant_values=constant_values,
|
| 802 |
+
data_format=data_format,
|
| 803 |
+
input_data_format=input_data_format,
|
| 804 |
+
)
|
| 805 |
+
return padded_image
|
| 806 |
+
|
| 807 |
+
# Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad
|
| 808 |
+
def pad(
|
| 809 |
+
self,
|
| 810 |
+
images: List[np.ndarray],
|
| 811 |
+
constant_values: Union[float, Iterable[float]] = 0,
|
| 812 |
+
return_pixel_mask: bool = True,
|
| 813 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 814 |
+
data_format: Optional[ChannelDimension] = None,
|
| 815 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 816 |
+
) -> BatchFeature:
|
| 817 |
+
"""
|
| 818 |
+
Pads a batch of images, at the bottom and right, with zeros up to the largest height and width
|
| 819 |
+
in the batch, and optionally returns their corresponding pixel mask.
|
| 820 |
+
|
| 821 |
+
Args:
|
| 822 |
+
images (`List[np.ndarray]`):
|
| 823 |
+
Batch of images to pad.
|
| 824 |
+
constant_values (`float` or `Iterable[float]`, *optional*):
|
| 825 |
+
The value to use for the padding if `mode` is `"constant"`.
|
| 826 |
+
return_pixel_mask (`bool`, *optional*, defaults to `True`):
|
| 827 |
+
Whether to return a pixel mask.
|
| 828 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 829 |
+
The type of tensors to return. Can be one of:
|
| 830 |
+
- Unset: Return a list of `np.ndarray`.
|
| 831 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 832 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 833 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 834 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 835 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 836 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 837 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 838 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 839 |
+
"""
|
| 840 |
+
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
| 841 |
+
|
| 842 |
+
padded_images = [
|
| 843 |
+
self._pad_image(
|
| 844 |
+
image,
|
| 845 |
+
pad_size,
|
| 846 |
+
constant_values=constant_values,
|
| 847 |
+
data_format=data_format,
|
| 848 |
+
input_data_format=input_data_format,
|
| 849 |
+
)
|
| 850 |
+
for image in images
|
| 851 |
+
]
|
| 852 |
+
data = {"pixel_values": padded_images}
|
| 853 |
+
|
| 854 |
+
if return_pixel_mask:
|
| 855 |
+
masks = [
|
| 856 |
+
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
| 857 |
+
for image in images
|
| 858 |
+
]
|
| 859 |
+
data["pixel_mask"] = masks
|
| 860 |
+
|
| 861 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 862 |
+
|
| 863 |
+
def get_semantic_annotations(self, label, num_class_obj):
|
| 864 |
+
annotation_classes = label["classes"]
|
| 865 |
+
annotation_masks = label["masks"]
|
| 866 |
+
|
| 867 |
+
texts = ["a semantic photo"] * self.num_text
|
| 868 |
+
classes = []
|
| 869 |
+
masks = []
|
| 870 |
+
|
| 871 |
+
for idx in range(len(annotation_classes)):
|
| 872 |
+
class_id = annotation_classes[idx]
|
| 873 |
+
mask = annotation_masks[idx]
|
| 874 |
+
if not np.all(mask == False):
|
| 875 |
+
if class_id not in classes:
|
| 876 |
+
cls_name = self.metadata[str(class_id)]
|
| 877 |
+
classes.append(class_id)
|
| 878 |
+
masks.append(mask)
|
| 879 |
+
num_class_obj[cls_name] += 1
|
| 880 |
+
else:
|
| 881 |
+
idx = classes.index(class_id)
|
| 882 |
+
masks[idx] += mask
|
| 883 |
+
masks[idx] = np.clip(masks[idx], 0, 1)
|
| 884 |
+
|
| 885 |
+
num = 0
|
| 886 |
+
for i, cls_name in enumerate(self.metadata["class_names"]):
|
| 887 |
+
if num_class_obj[cls_name] > 0:
|
| 888 |
+
for _ in range(num_class_obj[cls_name]):
|
| 889 |
+
if num >= len(texts):
|
| 890 |
+
break
|
| 891 |
+
texts[num] = f"a photo with a {cls_name}"
|
| 892 |
+
num += 1
|
| 893 |
+
|
| 894 |
+
classes = np.array(classes)
|
| 895 |
+
masks = np.array(masks)
|
| 896 |
+
return classes, masks, texts
|
| 897 |
+
|
| 898 |
+
def get_instance_annotations(self, label, num_class_obj):
|
| 899 |
+
annotation_classes = label["classes"]
|
| 900 |
+
annotation_masks = label["masks"]
|
| 901 |
+
|
| 902 |
+
texts = ["an instance photo"] * self.num_text
|
| 903 |
+
classes = []
|
| 904 |
+
masks = []
|
| 905 |
+
|
| 906 |
+
for idx in range(len(annotation_classes)):
|
| 907 |
+
class_id = annotation_classes[idx]
|
| 908 |
+
mask = annotation_masks[idx]
|
| 909 |
+
|
| 910 |
+
if class_id in self.metadata["thing_ids"]:
|
| 911 |
+
if not np.all(mask == False):
|
| 912 |
+
cls_name = self.metadata[str(class_id)]
|
| 913 |
+
classes.append(class_id)
|
| 914 |
+
masks.append(mask)
|
| 915 |
+
num_class_obj[cls_name] += 1
|
| 916 |
+
|
| 917 |
+
num = 0
|
| 918 |
+
for i, cls_name in enumerate(self.metadata["class_names"]):
|
| 919 |
+
if num_class_obj[cls_name] > 0:
|
| 920 |
+
for _ in range(num_class_obj[cls_name]):
|
| 921 |
+
if num >= len(texts):
|
| 922 |
+
break
|
| 923 |
+
texts[num] = f"a photo with a {cls_name}"
|
| 924 |
+
num += 1
|
| 925 |
+
|
| 926 |
+
classes = np.array(classes)
|
| 927 |
+
masks = np.array(masks)
|
| 928 |
+
return classes, masks, texts
|
| 929 |
+
|
| 930 |
+
def get_panoptic_annotations(self, label, num_class_obj):
|
| 931 |
+
annotation_classes = label["classes"]
|
| 932 |
+
annotation_masks = label["masks"]
|
| 933 |
+
|
| 934 |
+
texts = ["an panoptic photo"] * self.num_text
|
| 935 |
+
classes = []
|
| 936 |
+
masks = []
|
| 937 |
+
|
| 938 |
+
for idx in range(len(annotation_classes)):
|
| 939 |
+
class_id = annotation_classes[idx]
|
| 940 |
+
mask = annotation_masks[idx].data
|
| 941 |
+
if not np.all(mask == False):
|
| 942 |
+
cls_name = self.metadata[str(class_id)]
|
| 943 |
+
classes.append(class_id)
|
| 944 |
+
masks.append(mask)
|
| 945 |
+
num_class_obj[cls_name] += 1
|
| 946 |
+
|
| 947 |
+
num = 0
|
| 948 |
+
for i, cls_name in enumerate(self.metadata["class_names"]):
|
| 949 |
+
if num_class_obj[cls_name] > 0:
|
| 950 |
+
for _ in range(num_class_obj[cls_name]):
|
| 951 |
+
if num >= len(texts):
|
| 952 |
+
break
|
| 953 |
+
texts[num] = f"a photo with a {cls_name}"
|
| 954 |
+
num += 1
|
| 955 |
+
|
| 956 |
+
classes = np.array(classes)
|
| 957 |
+
masks = np.array(masks)
|
| 958 |
+
return classes, masks, texts
|
| 959 |
+
|
| 960 |
+
def encode_inputs(
|
| 961 |
+
self,
|
| 962 |
+
pixel_values_list: List[ImageInput],
|
| 963 |
+
task_inputs: List[str],
|
| 964 |
+
segmentation_maps: ImageInput = None,
|
| 965 |
+
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
|
| 966 |
+
ignore_index: Optional[int] = None,
|
| 967 |
+
do_reduce_labels: bool = False,
|
| 968 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 969 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 970 |
+
):
|
| 971 |
+
"""
|
| 972 |
+
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
|
| 973 |
+
|
| 974 |
+
OneFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
|
| 975 |
+
will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
|
| 976 |
+
`segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
|
| 977 |
+
[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
|
| 978 |
+
each mask.
|
| 979 |
+
|
| 980 |
+
Args:
|
| 981 |
+
pixel_values_list (`List[ImageInput]`):
|
| 982 |
+
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
|
| 983 |
+
width)`.
|
| 984 |
+
|
| 985 |
+
task_inputs (`List[str]`):
|
| 986 |
+
List of task values.
|
| 987 |
+
|
| 988 |
+
segmentation_maps (`ImageInput`, *optional*):
|
| 989 |
+
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
| 990 |
+
|
| 991 |
+
(`bool`, *optional*, defaults to `True`):
|
| 992 |
+
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
|
| 993 |
+
|
| 994 |
+
If left to the default, will return a pixel mask that is:
|
| 995 |
+
|
| 996 |
+
- 1 for pixels that are real (i.e. **not masked**),
|
| 997 |
+
- 0 for pixels that are padding (i.e. **masked**).
|
| 998 |
+
|
| 999 |
+
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
|
| 1000 |
+
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
|
| 1001 |
+
instance segmentation map where each pixel represents an instance id. Can be provided as a single
|
| 1002 |
+
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
|
| 1003 |
+
instance ids in each image separately.
|
| 1004 |
+
|
| 1005 |
+
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
|
| 1006 |
+
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
|
| 1007 |
+
objects.
|
| 1008 |
+
|
| 1009 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 1010 |
+
The channel dimension format of the input image. If not provided, it will be inferred from the input
|
| 1011 |
+
image.
|
| 1012 |
+
|
| 1013 |
+
Returns:
|
| 1014 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 1015 |
+
|
| 1016 |
+
- **pixel_values** -- Pixel values to be fed to a model.
|
| 1017 |
+
- **pixel_mask** -- Pixel mask to be fed to a model (when `=True` or if `pixel_mask` is in
|
| 1018 |
+
`self.model_input_names`).
|
| 1019 |
+
- **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
|
| 1020 |
+
(when `annotations` are provided).
|
| 1021 |
+
- **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
|
| 1022 |
+
`annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
|
| 1023 |
+
`mask_labels[i][j]` is `class_labels[i][j]`.
|
| 1024 |
+
- **text_inputs** -- Optional list of text string entries to be fed to a model (when `annotations` are
|
| 1025 |
+
provided). They identify the binary masks present in the image.
|
| 1026 |
+
"""
|
| 1027 |
+
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
| 1028 |
+
do_reduce_labels = self.do_reduce_labels if do_reduce_labels is None else do_reduce_labels
|
| 1029 |
+
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
| 1030 |
+
|
| 1031 |
+
if input_data_format is None:
|
| 1032 |
+
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
|
| 1033 |
+
|
| 1034 |
+
pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
|
| 1035 |
+
encoded_inputs = self.pad(
|
| 1036 |
+
pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
|
| 1037 |
+
)
|
| 1038 |
+
|
| 1039 |
+
annotations = None
|
| 1040 |
+
if segmentation_maps is not None:
|
| 1041 |
+
segmentation_maps = map(np.array, segmentation_maps)
|
| 1042 |
+
annotations = []
|
| 1043 |
+
for idx, segmentation_map in enumerate(segmentation_maps):
|
| 1044 |
+
# Use instance2class_id mapping per image
|
| 1045 |
+
if isinstance(instance_id_to_semantic_id, list):
|
| 1046 |
+
instance_id = instance_id_to_semantic_id[idx]
|
| 1047 |
+
else:
|
| 1048 |
+
instance_id = instance_id_to_semantic_id
|
| 1049 |
+
# Use instance2class_id mapping per image
|
| 1050 |
+
masks, classes = self.convert_segmentation_map_to_binary_masks(
|
| 1051 |
+
segmentation_map, instance_id, ignore_index=ignore_index, do_reduce_labels=do_reduce_labels
|
| 1052 |
+
)
|
| 1053 |
+
annotations.append({"masks": masks, "classes": classes})
|
| 1054 |
+
|
| 1055 |
+
if annotations is not None:
|
| 1056 |
+
mask_labels = []
|
| 1057 |
+
class_labels = []
|
| 1058 |
+
text_inputs = []
|
| 1059 |
+
|
| 1060 |
+
num_class_obj = {}
|
| 1061 |
+
for cls_name in self.metadata["class_names"]:
|
| 1062 |
+
num_class_obj[cls_name] = 0
|
| 1063 |
+
|
| 1064 |
+
for i, label in enumerate(annotations):
|
| 1065 |
+
task = task_inputs[i]
|
| 1066 |
+
if task == "semantic":
|
| 1067 |
+
classes, masks, texts = self.get_semantic_annotations(label, num_class_obj)
|
| 1068 |
+
elif task == "instance":
|
| 1069 |
+
classes, masks, texts = self.get_instance_annotations(label, num_class_obj)
|
| 1070 |
+
elif task == "panoptic":
|
| 1071 |
+
classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj)
|
| 1072 |
+
else:
|
| 1073 |
+
raise ValueError(f"{task} was not expected, expected `semantic`, `instance` or `panoptic`")
|
| 1074 |
+
|
| 1075 |
+
# we cannot batch them since they don't share a common class size
|
| 1076 |
+
masks = [mask[None, ...] for mask in masks]
|
| 1077 |
+
masks = [
|
| 1078 |
+
self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks
|
| 1079 |
+
]
|
| 1080 |
+
masks = np.concatenate(masks, axis=0)
|
| 1081 |
+
mask_labels.append(torch.from_numpy(masks))
|
| 1082 |
+
class_labels.append(torch.from_numpy(classes).long())
|
| 1083 |
+
text_inputs.append(texts)
|
| 1084 |
+
|
| 1085 |
+
encoded_inputs["mask_labels"] = mask_labels
|
| 1086 |
+
encoded_inputs["class_labels"] = class_labels
|
| 1087 |
+
encoded_inputs["text_inputs"] = text_inputs
|
| 1088 |
+
|
| 1089 |
+
# This needs to be tokenized before sending to the model.
|
| 1090 |
+
encoded_inputs["task_inputs"] = [f"the task is {task_input}" for task_input in task_inputs]
|
| 1091 |
+
|
| 1092 |
+
return encoded_inputs
|
| 1093 |
+
|
| 1094 |
+
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation
|
| 1095 |
+
def post_process_semantic_segmentation(
|
| 1096 |
+
self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None
|
| 1097 |
+
) -> "torch.Tensor":
|
| 1098 |
+
"""
|
| 1099 |
+
Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports
|
| 1100 |
+
PyTorch.
|
| 1101 |
+
|
| 1102 |
+
Args:
|
| 1103 |
+
outputs ([`MaskFormerForInstanceSegmentation`]):
|
| 1104 |
+
Raw outputs of the model.
|
| 1105 |
+
target_sizes (`List[Tuple[int, int]]`, *optional*):
|
| 1106 |
+
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
|
| 1107 |
+
final size (height, width) of each prediction. If left to None, predictions will not be resized.
|
| 1108 |
+
Returns:
|
| 1109 |
+
`List[torch.Tensor]`:
|
| 1110 |
+
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
|
| 1111 |
+
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
|
| 1112 |
+
`torch.Tensor` correspond to a semantic class id.
|
| 1113 |
+
"""
|
| 1114 |
+
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
| 1115 |
+
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
| 1116 |
+
|
| 1117 |
+
# Remove the null class `[..., :-1]`
|
| 1118 |
+
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
|
| 1119 |
+
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
| 1120 |
+
|
| 1121 |
+
# Semantic segmentation logits of shape (batch_size, num_classes, height, width)
|
| 1122 |
+
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
|
| 1123 |
+
batch_size = class_queries_logits.shape[0]
|
| 1124 |
+
|
| 1125 |
+
# Resize logits and compute semantic segmentation maps
|
| 1126 |
+
if target_sizes is not None:
|
| 1127 |
+
if batch_size != len(target_sizes):
|
| 1128 |
+
raise ValueError(
|
| 1129 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
| 1130 |
+
)
|
| 1131 |
+
|
| 1132 |
+
semantic_segmentation = []
|
| 1133 |
+
for idx in range(batch_size):
|
| 1134 |
+
resized_logits = torch.nn.functional.interpolate(
|
| 1135 |
+
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
| 1136 |
+
)
|
| 1137 |
+
semantic_map = resized_logits[0].argmax(dim=0)
|
| 1138 |
+
semantic_segmentation.append(semantic_map)
|
| 1139 |
+
else:
|
| 1140 |
+
semantic_segmentation = segmentation.argmax(dim=1)
|
| 1141 |
+
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
| 1142 |
+
|
| 1143 |
+
return semantic_segmentation
|
| 1144 |
+
|
| 1145 |
+
def post_process_instance_segmentation(
|
| 1146 |
+
self,
|
| 1147 |
+
outputs,
|
| 1148 |
+
task_type: str = "instance",
|
| 1149 |
+
is_demo: bool = True,
|
| 1150 |
+
threshold: float = 0.5,
|
| 1151 |
+
mask_threshold: float = 0.5,
|
| 1152 |
+
overlap_mask_area_threshold: float = 0.8,
|
| 1153 |
+
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
| 1154 |
+
return_coco_annotation: Optional[bool] = False,
|
| 1155 |
+
):
|
| 1156 |
+
"""
|
| 1157 |
+
Converts the output of [`OneFormerForUniversalSegmentationOutput`] into image instance segmentation
|
| 1158 |
+
predictions. Only supports PyTorch.
|
| 1159 |
+
|
| 1160 |
+
Args:
|
| 1161 |
+
outputs ([`OneFormerForUniversalSegmentationOutput`]):
|
| 1162 |
+
The outputs from [`OneFormerForUniversalSegmentationOutput`].
|
| 1163 |
+
task_type (`str`, *optional*, defaults to "instance"):
|
| 1164 |
+
The post processing depends on the task token input. If the `task_type` is "panoptic", we need to
|
| 1165 |
+
ignore the stuff predictions.
|
| 1166 |
+
is_demo (`bool`, *optional*, defaults to `True`):
|
| 1167 |
+
Whether the model is in demo mode. If true, use threshold to predict final masks.
|
| 1168 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 1169 |
+
The probability score threshold to keep predicted instance masks.
|
| 1170 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 1171 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 1172 |
+
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
| 1173 |
+
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
| 1174 |
+
instance mask.
|
| 1175 |
+
target_sizes (`List[Tuple]`, *optional*):
|
| 1176 |
+
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
|
| 1177 |
+
final size (height, width) of each prediction in batch. If left to None, predictions will not be
|
| 1178 |
+
resized.
|
| 1179 |
+
return_coco_annotation (`bool`, *optional*, defaults to `False`):
|
| 1180 |
+
Whether to return predictions in COCO format.
|
| 1181 |
+
|
| 1182 |
+
Returns:
|
| 1183 |
+
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
| 1184 |
+
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
|
| 1185 |
+
to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
|
| 1186 |
+
to the corresponding `target_sizes` entry.
|
| 1187 |
+
- **segments_info** -- A dictionary that contains additional information on each segment.
|
| 1188 |
+
- **id** -- an integer representing the `segment_id`.
|
| 1189 |
+
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
| 1190 |
+
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
| 1191 |
+
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
| 1192 |
+
- **score** -- Prediction score of segment with `segment_id`.
|
| 1193 |
+
"""
|
| 1194 |
+
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
| 1195 |
+
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
| 1196 |
+
|
| 1197 |
+
device = masks_queries_logits.device
|
| 1198 |
+
batch_size = class_queries_logits.shape[0]
|
| 1199 |
+
num_queries = class_queries_logits.shape[1]
|
| 1200 |
+
num_classes = class_queries_logits.shape[-1] - 1
|
| 1201 |
+
|
| 1202 |
+
# Loop over items in batch size
|
| 1203 |
+
results: List[Dict[str, torch.Tensor]] = []
|
| 1204 |
+
|
| 1205 |
+
for i in range(batch_size):
|
| 1206 |
+
# [Q, K]
|
| 1207 |
+
scores = torch.nn.functional.softmax(class_queries_logits[i], dim=-1)[:, :-1]
|
| 1208 |
+
labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
|
| 1209 |
+
|
| 1210 |
+
# scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
|
| 1211 |
+
scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
|
| 1212 |
+
labels_per_image = labels[topk_indices]
|
| 1213 |
+
|
| 1214 |
+
topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
|
| 1215 |
+
# mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
|
| 1216 |
+
mask_pred = masks_queries_logits[i][topk_indices]
|
| 1217 |
+
|
| 1218 |
+
# Only consider scores with confidence over [threshold] for demo
|
| 1219 |
+
if is_demo:
|
| 1220 |
+
keep = scores_per_image > threshold
|
| 1221 |
+
scores_per_image = scores_per_image[keep]
|
| 1222 |
+
labels_per_image = labels_per_image[keep]
|
| 1223 |
+
mask_pred = mask_pred[keep]
|
| 1224 |
+
|
| 1225 |
+
# if this is panoptic segmentation, we only keep the "thing" classes
|
| 1226 |
+
if task_type == "panoptic":
|
| 1227 |
+
keep = torch.zeros_like(scores_per_image).bool()
|
| 1228 |
+
for j, lab in enumerate(labels_per_image):
|
| 1229 |
+
keep[j] = lab in self.metadata["thing_ids"]
|
| 1230 |
+
|
| 1231 |
+
scores_per_image = scores_per_image[keep]
|
| 1232 |
+
labels_per_image = labels_per_image[keep]
|
| 1233 |
+
mask_pred = mask_pred[keep]
|
| 1234 |
+
|
| 1235 |
+
if mask_pred.shape[0] <= 0:
|
| 1236 |
+
height, width = target_sizes[i] if target_sizes is not None else mask_pred.shape[1:]
|
| 1237 |
+
segmentation = torch.zeros((height, width)) - 1
|
| 1238 |
+
results.append({"segmentation": segmentation, "segments_info": []})
|
| 1239 |
+
continue
|
| 1240 |
+
|
| 1241 |
+
if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
|
| 1242 |
+
for j in range(labels_per_image.shape[0]):
|
| 1243 |
+
labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())
|
| 1244 |
+
|
| 1245 |
+
# Get segmentation map and segment information of batch item
|
| 1246 |
+
target_size = target_sizes[i] if target_sizes is not None else None
|
| 1247 |
+
segmentation, segments = compute_segments(
|
| 1248 |
+
mask_pred,
|
| 1249 |
+
scores_per_image,
|
| 1250 |
+
labels_per_image,
|
| 1251 |
+
mask_threshold,
|
| 1252 |
+
overlap_mask_area_threshold,
|
| 1253 |
+
set(),
|
| 1254 |
+
target_size,
|
| 1255 |
+
)
|
| 1256 |
+
|
| 1257 |
+
# Return segmentation map in run-length encoding (RLE) format
|
| 1258 |
+
if return_coco_annotation:
|
| 1259 |
+
segmentation = convert_segmentation_to_rle(segmentation)
|
| 1260 |
+
|
| 1261 |
+
results.append({"segmentation": segmentation, "segments_info": segments})
|
| 1262 |
+
return results
|
| 1263 |
+
|
| 1264 |
+
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_panoptic_segmentation
|
| 1265 |
+
def post_process_panoptic_segmentation(
|
| 1266 |
+
self,
|
| 1267 |
+
outputs,
|
| 1268 |
+
threshold: float = 0.5,
|
| 1269 |
+
mask_threshold: float = 0.5,
|
| 1270 |
+
overlap_mask_area_threshold: float = 0.8,
|
| 1271 |
+
label_ids_to_fuse: Optional[Set[int]] = None,
|
| 1272 |
+
target_sizes: Optional[List[Tuple[int, int]]] = None,
|
| 1273 |
+
) -> List[Dict]:
|
| 1274 |
+
"""
|
| 1275 |
+
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
|
| 1276 |
+
predictions. Only supports PyTorch.
|
| 1277 |
+
|
| 1278 |
+
Args:
|
| 1279 |
+
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
|
| 1280 |
+
The outputs from [`MaskFormerForInstanceSegmentation`].
|
| 1281 |
+
threshold (`float`, *optional*, defaults to 0.5):
|
| 1282 |
+
The probability score threshold to keep predicted instance masks.
|
| 1283 |
+
mask_threshold (`float`, *optional*, defaults to 0.5):
|
| 1284 |
+
Threshold to use when turning the predicted masks into binary values.
|
| 1285 |
+
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
|
| 1286 |
+
The overlap mask area threshold to merge or discard small disconnected parts within each binary
|
| 1287 |
+
instance mask.
|
| 1288 |
+
label_ids_to_fuse (`Set[int]`, *optional*):
|
| 1289 |
+
The labels in this set will have all their instances fused together. For instance, we could say
|
| 1290 |
+
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
|
| 1291 |
+
set, but not the one for person.
|
| 1292 |
+
target_sizes (`List[Tuple]`, *optional*):
|
| 1293 |
+
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
|
| 1294 |
+
final size (height, width) of each prediction in batch. If left to None, predictions will not be
|
| 1295 |
+
resized.
|
| 1296 |
+
|
| 1297 |
+
Returns:
|
| 1298 |
+
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
|
| 1299 |
+
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
|
| 1300 |
+
to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
|
| 1301 |
+
to the corresponding `target_sizes` entry.
|
| 1302 |
+
- **segments_info** -- A dictionary that contains additional information on each segment.
|
| 1303 |
+
- **id** -- an integer representing the `segment_id`.
|
| 1304 |
+
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
|
| 1305 |
+
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
|
| 1306 |
+
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
|
| 1307 |
+
- **score** -- Prediction score of segment with `segment_id`.
|
| 1308 |
+
"""
|
| 1309 |
+
|
| 1310 |
+
if label_ids_to_fuse is None:
|
| 1311 |
+
logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
|
| 1312 |
+
label_ids_to_fuse = set()
|
| 1313 |
+
|
| 1314 |
+
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
|
| 1315 |
+
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
|
| 1316 |
+
|
| 1317 |
+
batch_size = class_queries_logits.shape[0]
|
| 1318 |
+
num_labels = class_queries_logits.shape[-1] - 1
|
| 1319 |
+
|
| 1320 |
+
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
|
| 1321 |
+
|
| 1322 |
+
# Predicted label and score of each query (batch_size, num_queries)
|
| 1323 |
+
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
|
| 1324 |
+
|
| 1325 |
+
# Loop over items in batch size
|
| 1326 |
+
results: List[Dict[str, TensorType]] = []
|
| 1327 |
+
|
| 1328 |
+
for i in range(batch_size):
|
| 1329 |
+
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
|
| 1330 |
+
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
|
| 1331 |
+
)
|
| 1332 |
+
|
| 1333 |
+
# No mask found
|
| 1334 |
+
if mask_probs_item.shape[0] <= 0:
|
| 1335 |
+
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
|
| 1336 |
+
segmentation = torch.zeros((height, width)) - 1
|
| 1337 |
+
results.append({"segmentation": segmentation, "segments_info": []})
|
| 1338 |
+
continue
|
| 1339 |
+
|
| 1340 |
+
# Get segmentation map and segment information of batch item
|
| 1341 |
+
target_size = target_sizes[i] if target_sizes is not None else None
|
| 1342 |
+
segmentation, segments = compute_segments(
|
| 1343 |
+
mask_probs=mask_probs_item,
|
| 1344 |
+
pred_scores=pred_scores_item,
|
| 1345 |
+
pred_labels=pred_labels_item,
|
| 1346 |
+
mask_threshold=mask_threshold,
|
| 1347 |
+
overlap_mask_area_threshold=overlap_mask_area_threshold,
|
| 1348 |
+
label_ids_to_fuse=label_ids_to_fuse,
|
| 1349 |
+
target_size=target_size,
|
| 1350 |
+
)
|
| 1351 |
+
|
| 1352 |
+
results.append({"segmentation": segmentation, "segments_info": segments})
|
| 1353 |
+
return results
|
| 1354 |
+
|
| 1355 |
+
|
| 1356 |
+
__all__ = ["OneFormerImageProcessor"]
|
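The `encode_inputs` docstring above describes how a segmentation map is decomposed into per-class binary masks plus class labels. Below is a minimal NumPy sketch of that idea (an illustration only, not the library implementation):

```python
# Minimal sketch of the conversion described in `encode_inputs`:
# a segmentation map becomes one binary mask per label id present in the map.
import numpy as np

segmentation_map = np.array([[2, 6], [7, 9]])  # toy 2x2 map with four labels

class_labels = np.unique(segmentation_map)  # -> array([2, 6, 7, 9])
mask_labels = np.stack([(segmentation_map == c).astype(np.uint8) for c in class_labels])

# mask_labels has shape (num_labels, height, width): one binary mask per entry of
# class_labels, which is the (mask_labels, class_labels) pairing OneFormer trains on.
```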
docs/transformers/src/transformers/models/oneformer/modeling_oneformer.py
ADDED
|
The diff for this file is too large to render.
|
|
|
docs/transformers/src/transformers/models/oneformer/processing_oneformer.py
ADDED
|
@@ -0,0 +1,207 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 SHI Labs and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Image/Text processor class for OneFormer
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from typing import List
|
| 20 |
+
|
| 21 |
+
from ...processing_utils import ProcessorMixin
|
| 22 |
+
from ...utils import is_torch_available
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if is_torch_available():
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class OneFormerProcessor(ProcessorMixin):
|
| 30 |
+
r"""
|
| 31 |
+
Constructs a OneFormer processor which wraps [`OneFormerImageProcessor`] and
|
| 32 |
+
[`CLIPTokenizer`]/[`CLIPTokenizerFast`] into a single processor that inherits both the image processor and
|
| 33 |
+
tokenizer functionalities.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
image_processor ([`OneFormerImageProcessor`]):
|
| 37 |
+
The image processor is a required input.
|
| 38 |
+
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
|
| 39 |
+
The tokenizer is a required input.
|
| 40 |
+
max_seq_length (`int`, *optional*, defaults to 77):
|
| 41 |
+
Sequence length for input text list.
|
| 42 |
+
task_seq_length (`int`, *optional*, defaults to 77):
|
| 43 |
+
Sequence length for input task token.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
attributes = ["image_processor", "tokenizer"]
|
| 47 |
+
image_processor_class = "OneFormerImageProcessor"
|
| 48 |
+
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
|
| 52 |
+
):
|
| 53 |
+
if image_processor is None:
|
| 54 |
+
raise ValueError("You need to specify an `image_processor`.")
|
| 55 |
+
if tokenizer is None:
|
| 56 |
+
raise ValueError("You need to specify a `tokenizer`.")
|
| 57 |
+
|
| 58 |
+
self.max_seq_length = max_seq_length
|
| 59 |
+
self.task_seq_length = task_seq_length
|
| 60 |
+
|
| 61 |
+
super().__init__(image_processor, tokenizer)
|
| 62 |
+
|
| 63 |
+
def _preprocess_text(self, text_list=None, max_length=77):
|
| 64 |
+
if text_list is None:
|
| 65 |
+
raise ValueError("tokens cannot be None.")
|
| 66 |
+
|
| 67 |
+
tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
|
| 68 |
+
|
| 69 |
+
attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
|
| 70 |
+
|
| 71 |
+
token_inputs = []
|
| 72 |
+
for attn_mask, input_id in zip(attention_masks, input_ids):
|
| 73 |
+
token = torch.tensor(attn_mask) * torch.tensor(input_id)
|
| 74 |
+
token_inputs.append(token.unsqueeze(0))
|
| 75 |
+
|
| 76 |
+
token_inputs = torch.cat(token_inputs, dim=0)
|
| 77 |
+
return token_inputs
|
| 78 |
+
|
| 79 |
+
def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
|
| 80 |
+
"""
|
| 81 |
+
Main method to prepare one or several task input(s) and image(s) for the model. This method forwards the
|
| 82 |
+
`task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not
|
| 83 |
+
`None` to encode. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
|
| 84 |
+
OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the
|
| 85 |
+
docstring of the above two methods for more information.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
task_inputs (`str`, `List[str]`):
|
| 89 |
+
The sequence or batch of task_inputs sequences to be encoded. Each sequence can be a string or a list
|
| 90 |
+
of strings of the template "the task is {task}".
|
| 91 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
| 92 |
+
`List[torch.Tensor]`):
|
| 93 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 94 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 95 |
+
segmentation_maps (`ImageInput`, *optional*):
|
| 96 |
+
The corresponding semantic segmentation maps with the pixel-wise annotations.
|
| 97 |
+
|
| 98 |
+
(`bool`, *optional*, defaults to `True`):
|
| 99 |
+
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
|
| 100 |
+
|
| 101 |
+
If left to the default, will return a pixel mask that is:
|
| 102 |
+
|
| 103 |
+
- 1 for pixels that are real (i.e. **not masked**),
|
| 104 |
+
- 0 for pixels that are padding (i.e. **masked**).
|
| 105 |
+
Returns:
|
| 106 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 107 |
+
- **task_inputs** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 108 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
if task_inputs is None:
|
| 112 |
+
raise ValueError("You have to specify the task_input. Found None.")
|
| 113 |
+
elif images is None:
|
| 114 |
+
raise ValueError("You have to specify the image. Found None.")
|
| 115 |
+
|
| 116 |
+
if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
|
| 117 |
+
raise ValueError("task_inputs must be semantic, instance, or panoptic.")
|
| 118 |
+
|
| 119 |
+
encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs)
|
| 120 |
+
|
| 121 |
+
if isinstance(task_inputs, str):
|
| 122 |
+
task_inputs = [task_inputs]
|
| 123 |
+
|
| 124 |
+
if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
|
| 125 |
+
task_token_inputs = []
|
| 126 |
+
for task in task_inputs:
|
| 127 |
+
task_input = f"the task is {task}"
|
| 128 |
+
task_token_inputs.append(task_input)
|
| 129 |
+
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
|
| 130 |
+
else:
|
| 131 |
+
raise TypeError("Task Inputs should be a string or a list of strings.")
|
| 132 |
+
|
| 133 |
+
if hasattr(encoded_inputs, "text_inputs"):
|
| 134 |
+
texts_list = encoded_inputs.text_inputs
|
| 135 |
+
|
| 136 |
+
text_inputs = []
|
| 137 |
+
for texts in texts_list:
|
| 138 |
+
text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
|
| 139 |
+
text_inputs.append(text_input_list.unsqueeze(0))
|
| 140 |
+
|
| 141 |
+
encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)
|
| 142 |
+
|
| 143 |
+
return encoded_inputs
|
| 144 |
+
|
| 145 |
+
def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
|
| 146 |
+
"""
|
| 147 |
+
This method forwards all its arguments to [`OneFormerImageProcessor.encode_inputs`] and then tokenizes the
|
| 148 |
+
task_inputs. Please refer to the docstring of this method for more information.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
if task_inputs is None:
|
| 152 |
+
raise ValueError("You have to specify the task_input. Found None.")
|
| 153 |
+
elif images is None:
|
| 154 |
+
raise ValueError("You have to specify the image. Found None.")
|
| 155 |
+
|
| 156 |
+
if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
|
| 157 |
+
raise ValueError("task_inputs must be semantic, instance, or panoptic.")
|
| 158 |
+
|
| 159 |
+
encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs)
|
| 160 |
+
|
| 161 |
+
if isinstance(task_inputs, str):
|
| 162 |
+
task_inputs = [task_inputs]
|
| 163 |
+
|
| 164 |
+
if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
|
| 165 |
+
task_token_inputs = []
|
| 166 |
+
for task in task_inputs:
|
| 167 |
+
task_input = f"the task is {task}"
|
| 168 |
+
task_token_inputs.append(task_input)
|
| 169 |
+
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
|
| 170 |
+
else:
|
| 171 |
+
raise TypeError("Task Inputs should be a string or a list of strings.")
|
| 172 |
+
|
| 173 |
+
if hasattr(encoded_inputs, "text_inputs"):
|
| 174 |
+
texts_list = encoded_inputs.text_inputs
|
| 175 |
+
|
| 176 |
+
text_inputs = []
|
| 177 |
+
for texts in texts_list:
|
| 178 |
+
text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
|
| 179 |
+
text_inputs.append(text_input_list.unsqueeze(0))
|
| 180 |
+
|
| 181 |
+
encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)
|
| 182 |
+
|
| 183 |
+
return encoded_inputs
|
| 184 |
+
|
| 185 |
+
def post_process_semantic_segmentation(self, *args, **kwargs):
|
| 186 |
+
"""
|
| 187 |
+
This method forwards all its arguments to [`OneFormerImageProcessor.post_process_semantic_segmentation`].
|
| 188 |
+
Please refer to the docstring of this method for more information.
|
| 189 |
+
"""
|
| 190 |
+
return self.image_processor.post_process_semantic_segmentation(*args, **kwargs)
|
| 191 |
+
|
| 192 |
+
def post_process_instance_segmentation(self, *args, **kwargs):
|
| 193 |
+
"""
|
| 194 |
+
This method forwards all its arguments to [`OneFormerImageProcessor.post_process_instance_segmentation`].
|
| 195 |
+
Please refer to the docstring of this method for more information.
|
| 196 |
+
"""
|
| 197 |
+
return self.image_processor.post_process_instance_segmentation(*args, **kwargs)
|
| 198 |
+
|
| 199 |
+
def post_process_panoptic_segmentation(self, *args, **kwargs):
|
| 200 |
+
"""
|
| 201 |
+
This method forwards all its arguments to [`OneFormerImageProcessor.post_process_panoptic_segmentation`].
|
| 202 |
+
Please refer to the docstring of this method for more information.
|
| 203 |
+
"""
|
| 204 |
+
return self.image_processor.post_process_panoptic_segmentation(*args, **kwargs)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
__all__ = ["OneFormerProcessor"]
|
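For reference, a short usage sketch pairing the processor defined above with `OneFormerForUniversalSegmentation`; the checkpoint name is an assumption, and any OneFormer checkpoint with a matching processor config should work the same way:

```python
import torch
from PIL import Image
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation

# Assumed checkpoint; substitute any OneFormer checkpoint you have access to.
checkpoint = "shi-labs/oneformer_ade20k_swin_tiny"
processor = OneFormerProcessor.from_pretrained(checkpoint)
model = OneFormerForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # any RGB image
# The processor prepares pixel values and tokenizes the task prompt ("the task is semantic").
inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Map the query logits back to a (height, width) map of semantic class ids.
semantic_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
```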
docs/transformers/src/transformers/models/openai/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_openai import *
|
| 22 |
+
from .modeling_openai import *
|
| 23 |
+
from .modeling_tf_openai import *
|
| 24 |
+
from .tokenization_openai import *
|
| 25 |
+
from .tokenization_openai_fast import *
|
| 26 |
+
else:
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
_file = globals()["__file__"]
|
| 30 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
docs/transformers/src/transformers/models/openai/configuration_openai.py
ADDED
|
@@ -0,0 +1,156 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""OpenAI GPT configuration"""
|
| 17 |
+
|
| 18 |
+
from ...configuration_utils import PretrainedConfig
|
| 19 |
+
from ...utils import logging
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OpenAIGPTConfig(PretrainedConfig):
|
| 26 |
+
"""
|
| 27 |
+
This is the configuration class to store the configuration of an [`OpenAIGPTModel`] or a [`TFOpenAIGPTModel`]. It is
|
| 28 |
+
used to instantiate a GPT model according to the specified arguments, defining the model architecture.
|
| 29 |
+
Instantiating a configuration with the defaults will yield a similar configuration to that of the GPT
|
| 30 |
+
[openai-community/openai-gpt](https://huggingface.co/openai-community/openai-gpt) architecture from OpenAI.
|
| 31 |
+
|
| 32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 33 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
vocab_size (`int`, *optional*, defaults to 40478):
|
| 37 |
+
Vocabulary size of the GPT model. Defines the number of different tokens that can be represented by the
|
| 38 |
+
`inputs_ids` passed when calling [`OpenAIGPTModel`] or [`TFOpenAIGPTModel`].
|
| 39 |
+
n_positions (`int`, *optional*, defaults to 512):
|
| 40 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 41 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 42 |
+
n_embd (`int`, *optional*, defaults to 768):
|
| 43 |
+
Dimensionality of the embeddings and hidden states.
|
| 44 |
+
n_layer (`int`, *optional*, defaults to 12):
|
| 45 |
+
Number of hidden layers in the Transformer encoder.
|
| 46 |
+
n_head (`int`, *optional*, defaults to 12):
|
| 47 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 48 |
+
afn (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
| 49 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 50 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
| 51 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
| 52 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 53 |
+
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
| 54 |
+
The dropout ratio for the embeddings.
|
| 55 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
| 56 |
+
The dropout ratio for the attention.
|
| 57 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
|
| 58 |
+
The epsilon to use in the layer normalization layers
|
| 59 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 60 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 61 |
+
summary_type (`str`, *optional*, defaults to `"cls_index"`):
|
| 62 |
+
Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
|
| 63 |
+
[`OpenAIGPTDoubleHeadsModel`].
|
| 64 |
+
|
| 65 |
+
Has to be one of the following options:
|
| 66 |
+
|
| 67 |
+
- `"last"`: Take the last token hidden state (like XLNet).
|
| 68 |
+
- `"first"`: Take the first token hidden state (like BERT).
|
| 69 |
+
- `"mean"`: Take the mean of all tokens hidden states.
|
| 70 |
+
- `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
|
| 71 |
+
- `"attn"`: Not implemented now, use multi-head attention.
|
| 72 |
+
summary_use_proj (`bool`, *optional*, defaults to `True`):
|
| 73 |
+
Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
|
| 74 |
+
[`OpenAIGPTDoubleHeadsModel`].
|
| 75 |
+
|
| 76 |
+
Whether or not to add a projection after the vector extraction.
|
| 77 |
+
summary_activation (`str`, *optional*):
|
| 78 |
+
Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
|
| 79 |
+
[`OpenAIGPTDoubleHeadsModel`].
|
| 80 |
+
|
| 81 |
+
Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
|
| 82 |
+
summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
|
| 83 |
+
Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
|
| 84 |
+
[`OpenAIGPTDoubleHeadsModel`].
|
| 85 |
+
|
| 86 |
+
Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
|
| 87 |
+
summary_first_dropout (`float`, *optional*, defaults to 0.1):
|
| 88 |
+
Argument used when doing sequence summary, used in the models [`OpenAIGPTDoubleHeadsModel`] and
|
| 89 |
+
[`OpenAIGPTDoubleHeadsModel`].
|
| 90 |
+
|
| 91 |
+
The dropout ratio to be used after the projection and activation.
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
Examples:
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
>>> from transformers import OpenAIGPTConfig, OpenAIGPTModel
|
| 98 |
+
|
| 99 |
+
>>> # Initializing a GPT configuration
|
| 100 |
+
>>> configuration = OpenAIGPTConfig()
|
| 101 |
+
|
| 102 |
+
>>> # Initializing a model (with random weights) from the configuration
|
| 103 |
+
>>> model = OpenAIGPTModel(configuration)
|
| 104 |
+
|
| 105 |
+
>>> # Accessing the model configuration
|
| 106 |
+
>>> configuration = model.config
|
| 107 |
+
```"""
|
| 108 |
+
|
| 109 |
+
model_type = "openai-gpt"
|
| 110 |
+
attribute_map = {
|
| 111 |
+
"max_position_embeddings": "n_positions",
|
| 112 |
+
"hidden_size": "n_embd",
|
| 113 |
+
"num_attention_heads": "n_head",
|
| 114 |
+
"num_hidden_layers": "n_layer",
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
vocab_size=40478,
|
| 120 |
+
n_positions=512,
|
| 121 |
+
n_embd=768,
|
| 122 |
+
n_layer=12,
|
| 123 |
+
n_head=12,
|
| 124 |
+
afn="gelu",
|
| 125 |
+
resid_pdrop=0.1,
|
| 126 |
+
embd_pdrop=0.1,
|
| 127 |
+
attn_pdrop=0.1,
|
| 128 |
+
layer_norm_epsilon=1e-5,
|
| 129 |
+
initializer_range=0.02,
|
| 130 |
+
summary_type="cls_index",
|
| 131 |
+
summary_use_proj=True,
|
| 132 |
+
summary_activation=None,
|
| 133 |
+
summary_proj_to_labels=True,
|
| 134 |
+
summary_first_dropout=0.1,
|
| 135 |
+
**kwargs,
|
| 136 |
+
):
|
| 137 |
+
self.vocab_size = vocab_size
|
| 138 |
+
self.n_positions = n_positions
|
| 139 |
+
self.n_embd = n_embd
|
| 140 |
+
self.n_layer = n_layer
|
| 141 |
+
self.n_head = n_head
|
| 142 |
+
self.afn = afn
|
| 143 |
+
self.resid_pdrop = resid_pdrop
|
| 144 |
+
self.embd_pdrop = embd_pdrop
|
| 145 |
+
self.attn_pdrop = attn_pdrop
|
| 146 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
| 147 |
+
self.initializer_range = initializer_range
|
| 148 |
+
self.summary_type = summary_type
|
| 149 |
+
self.summary_use_proj = summary_use_proj
|
| 150 |
+
self.summary_activation = summary_activation
|
| 151 |
+
self.summary_first_dropout = summary_first_dropout
|
| 152 |
+
self.summary_proj_to_labels = summary_proj_to_labels
|
| 153 |
+
super().__init__(**kwargs)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
__all__ = ["OpenAIGPTConfig"]
|
docs/transformers/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert OpenAI GPT checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
| 22 |
+
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logging.set_verbosity_info()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
| 29 |
+
# Construct model
|
| 30 |
+
if openai_config_file == "":
|
| 31 |
+
config = OpenAIGPTConfig()
|
| 32 |
+
else:
|
| 33 |
+
config = OpenAIGPTConfig.from_json_file(openai_config_file)
|
| 34 |
+
model = OpenAIGPTModel(config)
|
| 35 |
+
|
| 36 |
+
# Load weights from numpy
|
| 37 |
+
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
|
| 38 |
+
|
| 39 |
+
# Save pytorch-model
|
| 40 |
+
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
| 41 |
+
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
| 42 |
+
print(f"Save PyTorch model to {pytorch_weights_dump_path}")
|
| 43 |
+
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
| 44 |
+
print(f"Save configuration file to {pytorch_config_dump_path}")
|
| 45 |
+
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
| 46 |
+
f.write(config.to_json_string())
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
parser = argparse.ArgumentParser()
|
| 51 |
+
# Required parameters
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--openai_checkpoint_folder_path",
|
| 54 |
+
default=None,
|
| 55 |
+
type=str,
|
| 56 |
+
required=True,
|
| 57 |
+
help="Path to the TensorFlow checkpoint path.",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
| 61 |
+
)
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
"--openai_config_file",
|
| 64 |
+
default="",
|
| 65 |
+
type=str,
|
| 66 |
+
help=(
|
| 67 |
+
"An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
| 68 |
+
"This specifies the model architecture."
|
| 69 |
+
),
|
| 70 |
+
)
|
| 71 |
+
args = parser.parse_args()
|
| 72 |
+
convert_openai_checkpoint_to_pytorch(
|
| 73 |
+
args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
|
| 74 |
+
)
|
docs/transformers/src/transformers/models/openai/modeling_openai.py
ADDED
|
@@ -0,0 +1,967 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch OpenAI GPT model."""
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
+
|
| 28 |
+
from ...activations import gelu_new, get_activation, silu
|
| 29 |
+
from ...generation import GenerationMixin
|
| 30 |
+
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
| 31 |
+
from ...modeling_utils import PreTrainedModel
|
| 32 |
+
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
| 33 |
+
from ...utils import (
|
| 34 |
+
ModelOutput,
|
| 35 |
+
add_code_sample_docstrings,
|
| 36 |
+
add_start_docstrings,
|
| 37 |
+
add_start_docstrings_to_model_forward,
|
| 38 |
+
logging,
|
| 39 |
+
replace_return_docstrings,
|
| 40 |
+
)
|
| 41 |
+
from .configuration_openai import OpenAIGPTConfig
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt"
|
| 47 |
+
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
|
| 51 |
+
"""Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
|
| 52 |
+
import re
|
| 53 |
+
|
| 54 |
+
import numpy as np
|
| 55 |
+
|
| 56 |
+
if ".ckpt" in openai_checkpoint_folder_path:
|
| 57 |
+
openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
|
| 58 |
+
|
| 59 |
+
logger.info(f"Loading weights from {openai_checkpoint_folder_path}")
|
| 60 |
+
|
| 61 |
+
with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
|
| 62 |
+
names = json.load(names_handle)
|
| 63 |
+
with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
|
| 64 |
+
shapes = json.load(shapes_handle)
|
| 65 |
+
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
| 66 |
+
init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)]
|
| 67 |
+
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
| 68 |
+
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
| 69 |
+
|
| 70 |
+
# This was used when we had a single embedding matrix for positions and tokens
|
| 71 |
+
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
| 72 |
+
# del init_params[1]
|
| 73 |
+
init_params = [arr.squeeze() for arr in init_params]
|
| 74 |
+
|
| 75 |
+
# Check that the token and position embeddings weight dimensions map those of the init parameters.
|
| 76 |
+
if model.tokens_embed.weight.shape != init_params[1].shape:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
|
| 79 |
+
f" {init_params[1].shape}"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if model.positions_embed.weight.shape != init_params[0].shape:
|
| 83 |
+
raise ValueError(
|
| 84 |
+
f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
|
| 85 |
+
f" {init_params[0].shape}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
|
| 89 |
+
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
|
| 90 |
+
names.pop(0)
|
| 91 |
+
# Pop position and token embedding arrays
|
| 92 |
+
init_params.pop(0)
|
| 93 |
+
init_params.pop(0)
|
| 94 |
+
|
| 95 |
+
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
| 96 |
+
name = name[6:] # skip "model/"
|
| 97 |
+
if name[-2:] != ":0":
|
| 98 |
+
raise ValueError(f"Layer {name} does not end with :0")
|
| 99 |
+
name = name[:-2]
|
| 100 |
+
name = name.split("/")
|
| 101 |
+
pointer = model
|
| 102 |
+
for m_name in name:
|
| 103 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
| 104 |
+
scope_names = re.split(r"(\d+)", m_name)
|
| 105 |
+
else:
|
| 106 |
+
scope_names = [m_name]
|
| 107 |
+
if scope_names[0] == "g":
|
| 108 |
+
pointer = getattr(pointer, "weight")
|
| 109 |
+
elif scope_names[0] == "b":
|
| 110 |
+
pointer = getattr(pointer, "bias")
|
| 111 |
+
elif scope_names[0] == "w":
|
| 112 |
+
pointer = getattr(pointer, "weight")
|
| 113 |
+
else:
|
| 114 |
+
pointer = getattr(pointer, scope_names[0])
|
| 115 |
+
if len(scope_names) >= 2:
|
| 116 |
+
num = int(scope_names[1])
|
| 117 |
+
pointer = pointer[num]
|
| 118 |
+
|
| 119 |
+
# Ensure that the pointer and array have compatible shapes.
|
| 120 |
+
if pointer.shape != array.shape:
|
| 121 |
+
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
| 122 |
+
|
| 123 |
+
logger.info(f"Initialize PyTorch weight {name}")
|
| 124 |
+
pointer.data = torch.from_numpy(array)
|
| 125 |
+
return model
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Attention(nn.Module):
|
| 132 |
+
def __init__(self, nx, n_positions, config, scale=False):
|
| 133 |
+
super().__init__()
|
| 134 |
+
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
| 135 |
+
# [switch nx => n_state from Block to Attention to keep identical to TF implementation]
|
| 136 |
+
if n_state % config.n_head != 0:
|
| 137 |
+
raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
|
| 138 |
+
self.register_buffer(
|
| 139 |
+
"bias",
|
| 140 |
+
torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions),
|
| 141 |
+
persistent=False,
|
| 142 |
+
)
|
| 143 |
+
self.n_head = config.n_head
|
| 144 |
+
self.split_size = n_state
|
| 145 |
+
self.scale = scale
|
| 146 |
+
|
| 147 |
+
self.c_attn = Conv1D(n_state * 3, nx)
|
| 148 |
+
self.c_proj = Conv1D(n_state, nx)
|
| 149 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
| 150 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 151 |
+
self.pruned_heads = set()
|
| 152 |
+
|
| 153 |
+
def prune_heads(self, heads):
|
| 154 |
+
if len(heads) == 0:
|
| 155 |
+
return
|
| 156 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 157 |
+
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
|
| 158 |
+
)
|
| 159 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
| 160 |
+
# Prune conv1d layers
|
| 161 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
| 162 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
| 163 |
+
# Update hyper params
|
| 164 |
+
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
| 165 |
+
self.n_head = self.n_head - len(heads)
|
| 166 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 167 |
+
|
| 168 |
+
def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
|
| 169 |
+
w = torch.matmul(q, k)
|
| 170 |
+
if self.scale:
|
| 171 |
+
w = w / math.sqrt(v.size(-1))
|
| 172 |
+
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights
|
| 173 |
+
# XD: self.b may be larger than w, so we need to crop it
|
| 174 |
+
b = self.bias[:, :, : w.size(-2), : w.size(-1)]
|
| 175 |
+
w = w * b + -1e4 * (1 - b)
|
| 176 |
+
|
| 177 |
+
if attention_mask is not None:
|
| 178 |
+
# Apply the attention mask
|
| 179 |
+
w = w + attention_mask
|
| 180 |
+
|
| 181 |
+
w = nn.functional.softmax(w, dim=-1)
|
| 182 |
+
w = self.attn_dropout(w)
|
| 183 |
+
|
| 184 |
+
# Mask heads if we want to
|
| 185 |
+
if head_mask is not None:
|
| 186 |
+
w = w * head_mask
|
| 187 |
+
|
| 188 |
+
outputs = [torch.matmul(w, v)]
|
| 189 |
+
if output_attentions:
|
| 190 |
+
outputs.append(w)
|
| 191 |
+
return outputs
|
| 192 |
+
|
| 193 |
+
def merge_heads(self, x):
|
| 194 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 195 |
+
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
| 196 |
+
return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states
|
| 197 |
+
|
| 198 |
+
def split_heads(self, x, k=False):
|
| 199 |
+
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
| 200 |
+
x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states
|
| 201 |
+
if k:
|
| 202 |
+
return x.permute(0, 2, 3, 1)
|
| 203 |
+
else:
|
| 204 |
+
return x.permute(0, 2, 1, 3)
|
| 205 |
+
|
| 206 |
+
def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
|
| 207 |
+
x = self.c_attn(x)
|
| 208 |
+
query, key, value = x.split(self.split_size, dim=2)
|
| 209 |
+
query = self.split_heads(query)
|
| 210 |
+
key = self.split_heads(key, k=True)
|
| 211 |
+
value = self.split_heads(value)
|
| 212 |
+
|
| 213 |
+
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
|
| 214 |
+
a = attn_outputs[0]
|
| 215 |
+
|
| 216 |
+
a = self.merge_heads(a)
|
| 217 |
+
a = self.c_proj(a)
|
| 218 |
+
a = self.resid_dropout(a)
|
| 219 |
+
|
| 220 |
+
outputs = [a] + attn_outputs[1:]
|
| 221 |
+
return outputs # a, (attentions)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class MLP(nn.Module):
|
| 225 |
+
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
| 226 |
+
super().__init__()
|
| 227 |
+
nx = config.n_embd
|
| 228 |
+
self.c_fc = Conv1D(n_state, nx)
|
| 229 |
+
self.c_proj = Conv1D(nx, n_state)
|
| 230 |
+
self.act = ACT_FNS[config.afn]
|
| 231 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 232 |
+
|
| 233 |
+
def forward(self, x):
|
| 234 |
+
h = self.act(self.c_fc(x))
|
| 235 |
+
h2 = self.c_proj(h)
|
| 236 |
+
return self.dropout(h2)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class Block(nn.Module):
|
| 240 |
+
def __init__(self, n_positions, config, scale=False):
|
| 241 |
+
super().__init__()
|
| 242 |
+
nx = config.n_embd
|
| 243 |
+
self.attn = Attention(nx, n_positions, config, scale)
|
| 244 |
+
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
| 245 |
+
self.mlp = MLP(4 * nx, config)
|
| 246 |
+
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
| 247 |
+
|
| 248 |
+
def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
|
| 249 |
+
attn_outputs = self.attn(
|
| 250 |
+
x,
|
| 251 |
+
attention_mask=attention_mask,
|
| 252 |
+
head_mask=head_mask,
|
| 253 |
+
output_attentions=output_attentions,
|
| 254 |
+
)
|
| 255 |
+
a = attn_outputs[0]
|
| 256 |
+
|
| 257 |
+
n = self.ln_1(x + a)
|
| 258 |
+
m = self.mlp(n)
|
| 259 |
+
h = self.ln_2(n + m)
|
| 260 |
+
|
| 261 |
+
outputs = [h] + attn_outputs[1:]
|
| 262 |
+
return outputs
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT
|
| 266 |
+
class OpenAIGPTSequenceSummary(nn.Module):
|
| 267 |
+
r"""
|
| 268 |
+
Compute a single vector summary of a sequence hidden states.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
config ([`OpenAIGPTConfig`]):
|
| 272 |
+
The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
|
| 273 |
+
config class of your model for the default values it uses):
|
| 274 |
+
|
| 275 |
+
- **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
|
| 276 |
+
|
| 277 |
+
- `"last"` -- Take the last token hidden state (like XLNet)
|
| 278 |
+
- `"first"` -- Take the first token hidden state (like Bert)
|
| 279 |
+
- `"mean"` -- Take the mean of all tokens hidden states
|
| 280 |
+
- `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
|
| 281 |
+
- `"attn"` -- Not implemented now, use multi-head attention
|
| 282 |
+
|
| 283 |
+
- **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
|
| 284 |
+
- **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
|
| 285 |
+
(otherwise to `config.hidden_size`).
|
| 286 |
+
- **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
|
| 287 |
+
another string or `None` will add no activation.
|
| 288 |
+
- **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
|
| 289 |
+
- **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
def __init__(self, config: OpenAIGPTConfig):
|
| 293 |
+
super().__init__()
|
| 294 |
+
|
| 295 |
+
self.summary_type = getattr(config, "summary_type", "last")
|
| 296 |
+
if self.summary_type == "attn":
|
| 297 |
+
# We should use a standard multi-head attention module with absolute positional embedding for that.
|
| 298 |
+
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
|
| 299 |
+
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
|
| 300 |
+
raise NotImplementedError
|
| 301 |
+
|
| 302 |
+
self.summary = nn.Identity()
|
| 303 |
+
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
|
| 304 |
+
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
|
| 305 |
+
num_classes = config.num_labels
|
| 306 |
+
else:
|
| 307 |
+
num_classes = config.hidden_size
|
| 308 |
+
self.summary = nn.Linear(config.hidden_size, num_classes)
|
| 309 |
+
|
| 310 |
+
activation_string = getattr(config, "summary_activation", None)
|
| 311 |
+
self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
|
| 312 |
+
|
| 313 |
+
self.first_dropout = nn.Identity()
|
| 314 |
+
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
|
| 315 |
+
self.first_dropout = nn.Dropout(config.summary_first_dropout)
|
| 316 |
+
|
| 317 |
+
self.last_dropout = nn.Identity()
|
| 318 |
+
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
|
| 319 |
+
self.last_dropout = nn.Dropout(config.summary_last_dropout)
|
| 320 |
+
|
| 321 |
+
def forward(
|
| 322 |
+
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
|
| 323 |
+
) -> torch.FloatTensor:
|
| 324 |
+
"""
|
| 325 |
+
Compute a single vector summary of a sequence hidden states.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
|
| 329 |
+
The hidden states of the last layer.
|
| 330 |
+
cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
|
| 331 |
+
Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
`torch.FloatTensor`: The summary of the sequence hidden states.
|
| 335 |
+
"""
|
| 336 |
+
if self.summary_type == "last":
|
| 337 |
+
output = hidden_states[:, -1]
|
| 338 |
+
elif self.summary_type == "first":
|
| 339 |
+
output = hidden_states[:, 0]
|
| 340 |
+
elif self.summary_type == "mean":
|
| 341 |
+
output = hidden_states.mean(dim=1)
|
| 342 |
+
elif self.summary_type == "cls_index":
|
| 343 |
+
if cls_index is None:
|
| 344 |
+
cls_index = torch.full_like(
|
| 345 |
+
hidden_states[..., :1, :],
|
| 346 |
+
hidden_states.shape[-2] - 1,
|
| 347 |
+
dtype=torch.long,
|
| 348 |
+
)
|
| 349 |
+
else:
|
| 350 |
+
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
|
| 351 |
+
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
|
| 352 |
+
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
|
| 353 |
+
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
|
| 354 |
+
elif self.summary_type == "attn":
|
| 355 |
+
raise NotImplementedError
|
| 356 |
+
|
| 357 |
+
output = self.first_dropout(output)
|
| 358 |
+
output = self.summary(output)
|
| 359 |
+
output = self.activation(output)
|
| 360 |
+
output = self.last_dropout(output)
|
| 361 |
+
|
| 362 |
+
return output
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
| 366 |
+
"""
|
| 367 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 368 |
+
models.
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
config_class = OpenAIGPTConfig
|
| 372 |
+
load_tf_weights = load_tf_weights_in_openai_gpt
|
| 373 |
+
base_model_prefix = "transformer"
|
| 374 |
+
|
| 375 |
+
def _init_weights(self, module):
|
| 376 |
+
"""Initialize the weights."""
|
| 377 |
+
if isinstance(module, (nn.Linear, Conv1D)):
|
| 378 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 379 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 380 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 381 |
+
if module.bias is not None:
|
| 382 |
+
module.bias.data.zero_()
|
| 383 |
+
elif isinstance(module, nn.Embedding):
|
| 384 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 385 |
+
if module.padding_idx is not None:
|
| 386 |
+
module.weight.data[module.padding_idx].zero_()
|
| 387 |
+
elif isinstance(module, nn.LayerNorm):
|
| 388 |
+
module.bias.data.zero_()
|
| 389 |
+
module.weight.data.fill_(1.0)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@dataclass
|
| 393 |
+
class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
| 394 |
+
"""
|
| 395 |
+
Base class for outputs of models predicting if two sentences are consecutive or not.
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 399 |
+
Language modeling loss.
|
| 400 |
+
mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
|
| 401 |
+
Multiple choice classification loss.
|
| 402 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
| 403 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 404 |
+
mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
|
| 405 |
+
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
| 406 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 407 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 408 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 409 |
+
|
| 410 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 411 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 412 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 413 |
+
sequence_length)`.
|
| 414 |
+
|
| 415 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 416 |
+
heads.
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
loss: Optional[torch.FloatTensor] = None
|
| 420 |
+
mc_loss: Optional[torch.FloatTensor] = None
|
| 421 |
+
logits: Optional[torch.FloatTensor] = None
|
| 422 |
+
mc_logits: Optional[torch.FloatTensor] = None
|
| 423 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 424 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
OPENAI_GPT_START_DOCSTRING = r"""
|
| 428 |
+
|
| 429 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 430 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 431 |
+
etc.)
|
| 432 |
+
|
| 433 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 434 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 435 |
+
and behavior.
|
| 436 |
+
|
| 437 |
+
Parameters:
|
| 438 |
+
config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model.
|
| 439 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 440 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 441 |
+
"""
|
| 442 |
+
|
| 443 |
+
OPENAI_GPT_INPUTS_DOCSTRING = r"""
|
| 444 |
+
Args:
|
| 445 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 446 |
+
Indices of input sequence tokens in the vocabulary.
|
| 447 |
+
|
| 448 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 449 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 450 |
+
|
| 451 |
+
[What are input IDs?](../glossary#input-ids)
|
| 452 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 453 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 454 |
+
|
| 455 |
+
- 1 for tokens that are **not masked**,
|
| 456 |
+
- 0 for tokens that are **masked**.
|
| 457 |
+
|
| 458 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 459 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 460 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 461 |
+
1]`:
|
| 462 |
+
|
| 463 |
+
- 0 corresponds to a *sentence A* token,
|
| 464 |
+
- 1 corresponds to a *sentence B* token.
|
| 465 |
+
|
| 466 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 467 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 468 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 469 |
+
config.max_position_embeddings - 1]`.
|
| 470 |
+
|
| 471 |
+
[What are position IDs?](../glossary#position-ids)
|
| 472 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 473 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 474 |
+
|
| 475 |
+
- 1 indicates the head is **not masked**,
|
| 476 |
+
- 0 indicates the head is **masked**.
|
| 477 |
+
|
| 478 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 479 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 480 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 481 |
+
model's internal embedding lookup matrix.
|
| 482 |
+
output_attentions (`bool`, *optional*):
|
| 483 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 484 |
+
tensors for more detail.
|
| 485 |
+
output_hidden_states (`bool`, *optional*):
|
| 486 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 487 |
+
more detail.
|
| 488 |
+
return_dict (`bool`, *optional*):
|
| 489 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 490 |
+
"""
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@add_start_docstrings(
|
| 494 |
+
"The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
|
| 495 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 496 |
+
)
|
| 497 |
+
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
| 498 |
+
def __init__(self, config):
|
| 499 |
+
super().__init__(config)
|
| 500 |
+
|
| 501 |
+
self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
|
| 502 |
+
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
|
| 503 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 504 |
+
self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
|
| 505 |
+
|
| 506 |
+
self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
|
| 507 |
+
# Initialize weights and apply final processing
|
| 508 |
+
self.post_init()
|
| 509 |
+
|
| 510 |
+
def get_input_embeddings(self):
|
| 511 |
+
return self.tokens_embed
|
| 512 |
+
|
| 513 |
+
def set_input_embeddings(self, new_embeddings):
|
| 514 |
+
self.tokens_embed = new_embeddings
|
| 515 |
+
|
| 516 |
+
def _prune_heads(self, heads_to_prune):
|
| 517 |
+
"""
|
| 518 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 519 |
+
"""
|
| 520 |
+
for layer, heads in heads_to_prune.items():
|
| 521 |
+
self.h[layer].attn.prune_heads(heads)
|
| 522 |
+
|
| 523 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 524 |
+
@add_code_sample_docstrings(
|
| 525 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 526 |
+
output_type=BaseModelOutput,
|
| 527 |
+
config_class=_CONFIG_FOR_DOC,
|
| 528 |
+
)
|
| 529 |
+
def forward(
|
| 530 |
+
self,
|
| 531 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 532 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 533 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 534 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 535 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 536 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 537 |
+
output_attentions: Optional[bool] = None,
|
| 538 |
+
output_hidden_states: Optional[bool] = None,
|
| 539 |
+
return_dict: Optional[bool] = None,
|
| 540 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
|
| 541 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 542 |
+
output_hidden_states = (
|
| 543 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 544 |
+
)
|
| 545 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 546 |
+
|
| 547 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 548 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 549 |
+
elif input_ids is not None:
|
| 550 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 551 |
+
input_shape = input_ids.size()
|
| 552 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 553 |
+
elif inputs_embeds is not None:
|
| 554 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 555 |
+
else:
|
| 556 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 557 |
+
|
| 558 |
+
if position_ids is None:
|
| 559 |
+
# Code is different from when we had a single embedding matrix from position and token embeddings
|
| 560 |
+
position_ids = self.position_ids[None, : input_shape[-1]]
|
| 561 |
+
|
| 562 |
+
# Attention mask.
|
| 563 |
+
if attention_mask is not None:
|
| 564 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 565 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 566 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 567 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 568 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 569 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 570 |
+
|
| 571 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 572 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 573 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
| 574 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 575 |
+
# effectively the same as removing these entirely.
|
| 576 |
+
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
| 577 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
| 578 |
+
|
| 579 |
+
# Prepare head mask if needed
|
| 580 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 581 |
+
|
| 582 |
+
if inputs_embeds is None:
|
| 583 |
+
inputs_embeds = self.tokens_embed(input_ids)
|
| 584 |
+
position_embeds = self.positions_embed(position_ids)
|
| 585 |
+
if token_type_ids is not None:
|
| 586 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
| 587 |
+
token_type_embeds = self.tokens_embed(token_type_ids)
|
| 588 |
+
else:
|
| 589 |
+
token_type_embeds = 0
|
| 590 |
+
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
| 591 |
+
hidden_states = self.drop(hidden_states)
|
| 592 |
+
|
| 593 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
| 594 |
+
|
| 595 |
+
all_attentions = () if output_attentions else None
|
| 596 |
+
all_hidden_states = () if output_hidden_states else None
|
| 597 |
+
for i, block in enumerate(self.h):
|
| 598 |
+
if output_hidden_states:
|
| 599 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 600 |
+
|
| 601 |
+
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
|
| 602 |
+
hidden_states = outputs[0]
|
| 603 |
+
if output_attentions:
|
| 604 |
+
all_attentions = all_attentions + (outputs[1],)
|
| 605 |
+
|
| 606 |
+
hidden_states = hidden_states.view(*output_shape)
|
| 607 |
+
# Add last layer
|
| 608 |
+
if output_hidden_states:
|
| 609 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 610 |
+
|
| 611 |
+
if not return_dict:
|
| 612 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 613 |
+
|
| 614 |
+
return BaseModelOutput(
|
| 615 |
+
last_hidden_state=hidden_states,
|
| 616 |
+
hidden_states=all_hidden_states,
|
| 617 |
+
attentions=all_attentions,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
@add_start_docstrings(
|
| 622 |
+
"""
|
| 623 |
+
OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 624 |
+
embeddings).
|
| 625 |
+
""",
|
| 626 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 627 |
+
)
|
| 628 |
+
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin):
|
| 629 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 630 |
+
|
| 631 |
+
def __init__(self, config):
|
| 632 |
+
super().__init__(config)
|
| 633 |
+
self.transformer = OpenAIGPTModel(config)
|
| 634 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 635 |
+
|
| 636 |
+
# Initialize weights and apply final processing
|
| 637 |
+
self.post_init()
|
| 638 |
+
|
| 639 |
+
def get_output_embeddings(self):
|
| 640 |
+
return self.lm_head
|
| 641 |
+
|
| 642 |
+
def set_output_embeddings(self, new_embeddings):
|
| 643 |
+
self.lm_head = new_embeddings
|
| 644 |
+
|
| 645 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 646 |
+
@add_code_sample_docstrings(
|
| 647 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 648 |
+
output_type=CausalLMOutput,
|
| 649 |
+
config_class=_CONFIG_FOR_DOC,
|
| 650 |
+
)
|
| 651 |
+
def forward(
|
| 652 |
+
self,
|
| 653 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 654 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 655 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 656 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 657 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 658 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 659 |
+
labels: Optional[torch.LongTensor] = None,
|
| 660 |
+
output_attentions: Optional[bool] = None,
|
| 661 |
+
output_hidden_states: Optional[bool] = None,
|
| 662 |
+
return_dict: Optional[bool] = None,
|
| 663 |
+
**kwargs,
|
| 664 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutput]:
|
| 665 |
+
r"""
|
| 666 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 667 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 668 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
| 669 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 670 |
+
"""
|
| 671 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 672 |
+
|
| 673 |
+
transformer_outputs = self.transformer(
|
| 674 |
+
input_ids,
|
| 675 |
+
attention_mask=attention_mask,
|
| 676 |
+
token_type_ids=token_type_ids,
|
| 677 |
+
position_ids=position_ids,
|
| 678 |
+
head_mask=head_mask,
|
| 679 |
+
inputs_embeds=inputs_embeds,
|
| 680 |
+
output_attentions=output_attentions,
|
| 681 |
+
output_hidden_states=output_hidden_states,
|
| 682 |
+
return_dict=return_dict,
|
| 683 |
+
)
|
| 684 |
+
hidden_states = transformer_outputs[0]
|
| 685 |
+
lm_logits = self.lm_head(hidden_states)
|
| 686 |
+
|
| 687 |
+
loss = None
|
| 688 |
+
if labels is not None:
|
| 689 |
+
# Flatten the tokens
|
| 690 |
+
loss = self.loss_function(
|
| 691 |
+
lm_logits,
|
| 692 |
+
labels,
|
| 693 |
+
vocab_size=self.config.vocab_size,
|
| 694 |
+
**kwargs,
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
if not return_dict:
|
| 698 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
| 699 |
+
return ((loss,) + output) if loss is not None else output
|
| 700 |
+
|
| 701 |
+
return CausalLMOutput(
|
| 702 |
+
loss=loss,
|
| 703 |
+
logits=lm_logits,
|
| 704 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 705 |
+
attentions=transformer_outputs.attentions,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
|
| 709 |
+
# Overwritten -- old model with reduced inputs
|
| 710 |
+
return {"input_ids": input_ids}
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
@add_start_docstrings(
|
| 714 |
+
"""
|
| 715 |
+
OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
|
| 716 |
+
RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
|
| 717 |
+
input embeddings, the classification head takes as input the input of a specified classification token index in the
|
| 718 |
+
input sequence).
|
| 719 |
+
""",
|
| 720 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 721 |
+
)
|
| 722 |
+
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
| 723 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 724 |
+
|
| 725 |
+
def __init__(self, config):
|
| 726 |
+
super().__init__(config)
|
| 727 |
+
|
| 728 |
+
config.num_labels = 1
|
| 729 |
+
self.transformer = OpenAIGPTModel(config)
|
| 730 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 731 |
+
self.multiple_choice_head = OpenAIGPTSequenceSummary(config)
|
| 732 |
+
|
| 733 |
+
# Initialize weights and apply final processing
|
| 734 |
+
self.post_init()
|
| 735 |
+
|
| 736 |
+
def get_output_embeddings(self):
|
| 737 |
+
return self.lm_head
|
| 738 |
+
|
| 739 |
+
def set_output_embeddings(self, new_embeddings):
|
| 740 |
+
self.lm_head = new_embeddings
|
| 741 |
+
|
| 742 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 743 |
+
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 744 |
+
def forward(
|
| 745 |
+
self,
|
| 746 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 747 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 748 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 749 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 750 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 751 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 752 |
+
mc_token_ids: Optional[torch.LongTensor] = None,
|
| 753 |
+
labels: Optional[torch.LongTensor] = None,
|
| 754 |
+
mc_labels: Optional[torch.LongTensor] = None,
|
| 755 |
+
output_attentions: Optional[bool] = None,
|
| 756 |
+
output_hidden_states: Optional[bool] = None,
|
| 757 |
+
return_dict: Optional[bool] = None,
|
| 758 |
+
) -> Union[Tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]:
|
| 759 |
+
r"""
|
| 760 |
+
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
|
| 761 |
+
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
| 762 |
+
1]`.
|
| 763 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 764 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
| 765 |
+
`labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are
|
| 766 |
+
ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
| 767 |
+
mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
|
| 768 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
|
| 769 |
+
where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
|
| 770 |
+
|
| 771 |
+
Return:
|
| 772 |
+
|
| 773 |
+
Examples:
|
| 774 |
+
|
| 775 |
+
```python
|
| 776 |
+
>>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel
|
| 777 |
+
>>> import torch
|
| 778 |
+
|
| 779 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
| 780 |
+
>>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
|
| 781 |
+
>>> tokenizer.add_special_tokens(
|
| 782 |
+
... {"cls_token": "[CLS]"}
|
| 783 |
+
... ) # Add a [CLS] to the vocabulary (we should train it also!)
|
| 784 |
+
>>> model.resize_token_embeddings(len(tokenizer))
|
| 785 |
+
|
| 786 |
+
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
|
| 787 |
+
>>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
| 788 |
+
>>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1
|
| 789 |
+
|
| 790 |
+
>>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
|
| 791 |
+
>>> lm_logits = outputs.logits
|
| 792 |
+
>>> mc_logits = outputs.mc_logits
|
| 793 |
+
```"""
|
| 794 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 795 |
+
|
| 796 |
+
transformer_outputs = self.transformer(
|
| 797 |
+
input_ids,
|
| 798 |
+
attention_mask=attention_mask,
|
| 799 |
+
token_type_ids=token_type_ids,
|
| 800 |
+
position_ids=position_ids,
|
| 801 |
+
head_mask=head_mask,
|
| 802 |
+
inputs_embeds=inputs_embeds,
|
| 803 |
+
output_attentions=output_attentions,
|
| 804 |
+
output_hidden_states=output_hidden_states,
|
| 805 |
+
return_dict=return_dict,
|
| 806 |
+
)
|
| 807 |
+
hidden_states = transformer_outputs[0]
|
| 808 |
+
|
| 809 |
+
lm_logits = self.lm_head(hidden_states)
|
| 810 |
+
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
|
| 811 |
+
|
| 812 |
+
lm_loss, mc_loss = None, None
|
| 813 |
+
if mc_labels is not None:
|
| 814 |
+
loss_fct = CrossEntropyLoss()
|
| 815 |
+
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
|
| 816 |
+
if labels is not None:
|
| 817 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 818 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 819 |
+
loss_fct = CrossEntropyLoss()
|
| 820 |
+
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 821 |
+
|
| 822 |
+
if not return_dict:
|
| 823 |
+
output = (lm_logits, mc_logits) + transformer_outputs[1:]
|
| 824 |
+
if mc_loss is not None:
|
| 825 |
+
output = (mc_loss,) + output
|
| 826 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
| 827 |
+
|
| 828 |
+
return OpenAIGPTDoubleHeadsModelOutput(
|
| 829 |
+
loss=lm_loss,
|
| 830 |
+
mc_loss=mc_loss,
|
| 831 |
+
logits=lm_logits,
|
| 832 |
+
mc_logits=mc_logits,
|
| 833 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 834 |
+
attentions=transformer_outputs.attentions,
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
@add_start_docstrings(
|
| 839 |
+
"""
|
| 840 |
+
The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
|
| 841 |
+
[`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
|
| 842 |
+
models (e.g. GPT-2) do. Since it does classification on the last token, it needs to know the position of the
|
| 843 |
+
last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
|
| 844 |
+
token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
|
| 845 |
+
it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
|
| 846 |
+
the last value in each row of the batch).
|
| 847 |
+
""",
|
| 848 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 849 |
+
)
|
| 850 |
+
class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
|
| 851 |
+
def __init__(self, config):
|
| 852 |
+
super().__init__(config)
|
| 853 |
+
self.num_labels = config.num_labels
|
| 854 |
+
self.transformer = OpenAIGPTModel(config)
|
| 855 |
+
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
| 856 |
+
|
| 857 |
+
# Initialize weights and apply final processing
|
| 858 |
+
self.post_init()
|
| 859 |
+
|
| 860 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 861 |
+
@add_code_sample_docstrings(
|
| 862 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 863 |
+
output_type=SequenceClassifierOutput,
|
| 864 |
+
config_class=_CONFIG_FOR_DOC,
|
| 865 |
+
)
|
| 866 |
+
def forward(
|
| 867 |
+
self,
|
| 868 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 869 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 870 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 871 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 872 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 873 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 874 |
+
labels: Optional[torch.LongTensor] = None,
|
| 875 |
+
output_attentions: Optional[bool] = None,
|
| 876 |
+
output_hidden_states: Optional[bool] = None,
|
| 877 |
+
return_dict: Optional[bool] = None,
|
| 878 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 879 |
+
r"""
|
| 880 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 881 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 882 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 883 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 884 |
+
"""
|
| 885 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 886 |
+
|
| 887 |
+
transformer_outputs = self.transformer(
|
| 888 |
+
input_ids,
|
| 889 |
+
attention_mask=attention_mask,
|
| 890 |
+
token_type_ids=token_type_ids,
|
| 891 |
+
position_ids=position_ids,
|
| 892 |
+
head_mask=head_mask,
|
| 893 |
+
inputs_embeds=inputs_embeds,
|
| 894 |
+
output_attentions=output_attentions,
|
| 895 |
+
output_hidden_states=output_hidden_states,
|
| 896 |
+
return_dict=return_dict,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
hidden_states = transformer_outputs[0]
|
| 900 |
+
logits = self.score(hidden_states)
|
| 901 |
+
|
| 902 |
+
if input_ids is not None:
|
| 903 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
| 904 |
+
else:
|
| 905 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
| 906 |
+
|
| 907 |
+
# When no padding token is defined, only a batch size of 1 can be handled.
|
| 908 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
| 909 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 910 |
+
if self.config.pad_token_id is None:
|
| 911 |
+
last_non_pad_token = -1
|
| 912 |
+
elif input_ids is not None:
|
| 913 |
+
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
| 914 |
+
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
| 915 |
+
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
| 916 |
+
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
| 917 |
+
else:
|
| 918 |
+
last_non_pad_token = -1
|
| 919 |
+
logger.warning_once(
|
| 920 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 921 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
| 925 |
+
|
| 926 |
+
loss = None
|
| 927 |
+
if labels is not None:
|
| 928 |
+
if self.config.problem_type is None:
|
| 929 |
+
if self.num_labels == 1:
|
| 930 |
+
self.config.problem_type = "regression"
|
| 931 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
| 932 |
+
self.config.problem_type = "single_label_classification"
|
| 933 |
+
else:
|
| 934 |
+
self.config.problem_type = "multi_label_classification"
|
| 935 |
+
|
| 936 |
+
if self.config.problem_type == "regression":
|
| 937 |
+
loss_fct = MSELoss()
|
| 938 |
+
if self.num_labels == 1:
|
| 939 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
| 940 |
+
else:
|
| 941 |
+
loss = loss_fct(pooled_logits, labels)
|
| 942 |
+
elif self.config.problem_type == "single_label_classification":
|
| 943 |
+
loss_fct = CrossEntropyLoss()
|
| 944 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
| 945 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 946 |
+
loss_fct = BCEWithLogitsLoss()
|
| 947 |
+
loss = loss_fct(pooled_logits, labels)
|
| 948 |
+
if not return_dict:
|
| 949 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 950 |
+
return ((loss,) + output) if loss is not None else output
|
| 951 |
+
|
| 952 |
+
return SequenceClassifierOutput(
|
| 953 |
+
loss=loss,
|
| 954 |
+
logits=pooled_logits,
|
| 955 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 956 |
+
attentions=transformer_outputs.attentions,
|
| 957 |
+
)
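The index arithmetic in the forward pass above (pick, per row, the rightmost token that is not padding) is compact; here is a minimal standalone sketch of the same idea, assuming a hypothetical `pad_token_id` of 0 and random logits:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],    # right-padded row
                          [0, 0, 8, 9, 10]])  # left-padded row
logits = torch.randn(2, 5, 3)                 # (batch, seq_len, num_labels)

# rightmost non-padding position per row
non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)  # tensor([2, 4])

pooled_logits = logits[torch.arange(input_ids.shape[0]), last_non_pad_token]
print(pooled_logits.shape)  # torch.Size([2, 3])
```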
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
__all__ = [
|
| 961 |
+
"OpenAIGPTDoubleHeadsModel",
|
| 962 |
+
"OpenAIGPTForSequenceClassification",
|
| 963 |
+
"OpenAIGPTLMHeadModel",
|
| 964 |
+
"OpenAIGPTModel",
|
| 965 |
+
"OpenAIGPTPreTrainedModel",
|
| 966 |
+
"load_tf_weights_in_openai_gpt",
|
| 967 |
+
]
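The `problem_type` dispatch in `OpenAIGPTForSequenceClassification.forward` above (regression vs. single-label vs. multi-label) can be summarised as a small sketch; the shapes and helper name below are made up for illustration:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def classification_loss(pooled_logits, labels, num_labels):
    # Mirrors the problem_type rules used above.
    if num_labels == 1:                                      # regression
        return MSELoss()(pooled_logits.squeeze(), labels.squeeze().float())
    if labels.dtype in (torch.long, torch.int):              # single-label classification
        return CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(pooled_logits, labels)        # multi-label classification

logits = torch.randn(4, 3)
print(classification_loss(logits, torch.tensor([0, 2, 1, 1]), num_labels=3))
```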
|
docs/transformers/src/transformers/models/openai/modeling_tf_openai.py
ADDED
|
@@ -0,0 +1,937 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""TF 2.0 OpenAI GPT model."""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
|
| 26 |
+
from ...activations_tf import get_tf_activation
|
| 27 |
+
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
|
| 28 |
+
from ...modeling_tf_utils import (
|
| 29 |
+
TFCausalLanguageModelingLoss,
|
| 30 |
+
TFConv1D,
|
| 31 |
+
TFModelInputType,
|
| 32 |
+
TFPreTrainedModel,
|
| 33 |
+
TFSequenceClassificationLoss,
|
| 34 |
+
TFSequenceSummary,
|
| 35 |
+
TFSharedEmbeddings,
|
| 36 |
+
get_initializer,
|
| 37 |
+
keras,
|
| 38 |
+
keras_serializable,
|
| 39 |
+
unpack_inputs,
|
| 40 |
+
)
|
| 41 |
+
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
|
| 42 |
+
from ...utils import (
|
| 43 |
+
ModelOutput,
|
| 44 |
+
add_code_sample_docstrings,
|
| 45 |
+
add_start_docstrings,
|
| 46 |
+
add_start_docstrings_to_model_forward,
|
| 47 |
+
logging,
|
| 48 |
+
replace_return_docstrings,
|
| 49 |
+
)
|
| 50 |
+
from .configuration_openai import OpenAIGPTConfig
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = logging.get_logger(__name__)
|
| 54 |
+
|
| 55 |
+
_CHECKPOINT_FOR_DOC = "openai-community/openai-gpt"
|
| 56 |
+
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TFAttention(keras.layers.Layer):
|
| 60 |
+
def __init__(self, nx, config, scale=False, **kwargs):
|
| 61 |
+
super().__init__(**kwargs)
|
| 62 |
+
|
| 63 |
+
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
| 64 |
+
# [switch nx => n_state from Block to Attention to keep identical to TF implementation]
|
| 65 |
+
assert n_state % config.n_head == 0, (
|
| 66 |
+
f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}"
|
| 67 |
+
)
|
| 68 |
+
self.n_head = config.n_head
|
| 69 |
+
self.split_size = n_state
|
| 70 |
+
self.scale = scale
|
| 71 |
+
self.output_attentions = config.output_attentions
|
| 72 |
+
|
| 73 |
+
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
| 74 |
+
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
| 75 |
+
self.attn_dropout = keras.layers.Dropout(config.attn_pdrop)
|
| 76 |
+
self.resid_dropout = keras.layers.Dropout(config.resid_pdrop)
|
| 77 |
+
self.n_state = n_state
|
| 78 |
+
self.pruned_heads = set()
|
| 79 |
+
|
| 80 |
+
def prune_heads(self, heads):
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def causal_attention_mask(nd, ns):
|
| 85 |
+
"""
|
| 86 |
+
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
|
| 87 |
+
-1, ns-nd), but doesn't produce garbage on TPUs.
|
| 88 |
+
"""
|
| 89 |
+
i = tf.range(nd)[:, None]
|
| 90 |
+
j = tf.range(ns)
|
| 91 |
+
m = i >= j - ns + nd
|
| 92 |
+
return m
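For concreteness, the mask produced by `causal_attention_mask` for a 3-token query attending over a 5-token key sequence (nd=3, ns=5) can be checked in isolation; a quick sketch:

```python
import tensorflow as tf

nd, ns = 3, 5
i = tf.range(nd)[:, None]
j = tf.range(ns)
# 1's in the lower triangle, counting from the lower right corner
print(tf.cast(i >= j - ns + nd, tf.int32).numpy())
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```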
|
| 93 |
+
|
| 94 |
+
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
| 95 |
+
# q, k, v have shape [batch, heads, sequence, features]
|
| 96 |
+
w = tf.matmul(q, k, transpose_b=True)
|
| 97 |
+
if self.scale:
|
| 98 |
+
dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
|
| 99 |
+
w = w / tf.math.sqrt(dk)
|
| 100 |
+
|
| 101 |
+
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
|
| 102 |
+
_, _, nd, ns = shape_list(w)
|
| 103 |
+
b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
|
| 104 |
+
b = tf.reshape(b, [1, 1, nd, ns])
|
| 105 |
+
w = w * b - 1e4 * (1 - b)
|
| 106 |
+
|
| 107 |
+
if attention_mask is not None:
|
| 108 |
+
# Apply the attention mask
|
| 109 |
+
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
|
| 110 |
+
w = w + attention_mask
|
| 111 |
+
|
| 112 |
+
w = stable_softmax(w, axis=-1)
|
| 113 |
+
w = self.attn_dropout(w, training=training)
|
| 114 |
+
|
| 115 |
+
# Mask heads if we want to
|
| 116 |
+
if head_mask is not None:
|
| 117 |
+
w = w * head_mask
|
| 118 |
+
|
| 119 |
+
outputs = [tf.matmul(w, v)]
|
| 120 |
+
if output_attentions:
|
| 121 |
+
outputs.append(w)
|
| 122 |
+
return outputs
|
| 123 |
+
|
| 124 |
+
def merge_heads(self, x):
|
| 125 |
+
x = tf.transpose(x, [0, 2, 1, 3])
|
| 126 |
+
x_shape = shape_list(x)
|
| 127 |
+
new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
|
| 128 |
+
return tf.reshape(x, new_x_shape)
|
| 129 |
+
|
| 130 |
+
def split_heads(self, x):
|
| 131 |
+
x_shape = shape_list(x)
|
| 132 |
+
new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
|
| 133 |
+
x = tf.reshape(x, new_x_shape)
|
| 134 |
+
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
| 135 |
+
|
| 136 |
+
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
|
| 137 |
+
x = self.c_attn(x)
|
| 138 |
+
query, key, value = tf.split(x, 3, axis=2)
|
| 139 |
+
query = self.split_heads(query)
|
| 140 |
+
key = self.split_heads(key)
|
| 141 |
+
value = self.split_heads(value)
|
| 142 |
+
|
| 143 |
+
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
|
| 144 |
+
a = attn_outputs[0]
|
| 145 |
+
|
| 146 |
+
a = self.merge_heads(a)
|
| 147 |
+
a = self.c_proj(a)
|
| 148 |
+
a = self.resid_dropout(a, training=training)
|
| 149 |
+
|
| 150 |
+
outputs = [a] + attn_outputs[1:]
|
| 151 |
+
return outputs # a, (attentions)
|
| 152 |
+
|
| 153 |
+
def build(self, input_shape=None):
|
| 154 |
+
if self.built:
|
| 155 |
+
return
|
| 156 |
+
self.built = True
|
| 157 |
+
if getattr(self, "c_attn", None) is not None:
|
| 158 |
+
with tf.name_scope(self.c_attn.name):
|
| 159 |
+
self.c_attn.build([None, None, self.n_state * 3])
|
| 160 |
+
if getattr(self, "c_proj", None) is not None:
|
| 161 |
+
with tf.name_scope(self.c_proj.name):
|
| 162 |
+
self.c_proj.build([None, None, self.n_state])
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class TFMLP(keras.layers.Layer):
|
| 166 |
+
def __init__(self, n_state, config, **kwargs):
|
| 167 |
+
super().__init__(**kwargs)
|
| 168 |
+
nx = config.n_embd
|
| 169 |
+
self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
|
| 170 |
+
self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
|
| 171 |
+
self.act = get_tf_activation("gelu")
|
| 172 |
+
self.dropout = keras.layers.Dropout(config.resid_pdrop)
|
| 173 |
+
self.nx = nx
|
| 174 |
+
self.n_state = n_state
|
| 175 |
+
|
| 176 |
+
def call(self, x, training=False):
|
| 177 |
+
h = self.act(self.c_fc(x))
|
| 178 |
+
h2 = self.c_proj(h)
|
| 179 |
+
h2 = self.dropout(h2, training=training)
|
| 180 |
+
return h2
|
| 181 |
+
|
| 182 |
+
def build(self, input_shape=None):
|
| 183 |
+
if self.built:
|
| 184 |
+
return
|
| 185 |
+
self.built = True
|
| 186 |
+
if getattr(self, "c_fc", None) is not None:
|
| 187 |
+
with tf.name_scope(self.c_fc.name):
|
| 188 |
+
self.c_fc.build([None, None, self.n_state])
|
| 189 |
+
if getattr(self, "c_proj", None) is not None:
|
| 190 |
+
with tf.name_scope(self.c_proj.name):
|
| 191 |
+
self.c_proj.build([None, None, self.nx])
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class TFBlock(keras.layers.Layer):
|
| 195 |
+
def __init__(self, config, scale=False, **kwargs):
|
| 196 |
+
super().__init__(**kwargs)
|
| 197 |
+
nx = config.n_embd
|
| 198 |
+
self.attn = TFAttention(nx, config, scale, name="attn")
|
| 199 |
+
self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
|
| 200 |
+
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
| 201 |
+
self.ln_2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
| 202 |
+
self.nx = nx
|
| 203 |
+
|
| 204 |
+
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
|
| 205 |
+
output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
|
| 206 |
+
a = output_attn[0] # output_attn: a, (attentions)
|
| 207 |
+
|
| 208 |
+
n = self.ln_1(x + a)
|
| 209 |
+
m = self.mlp(n, training=training)
|
| 210 |
+
h = self.ln_2(n + m)
|
| 211 |
+
|
| 212 |
+
outputs = [h] + output_attn[1:]
|
| 213 |
+
return outputs # x, (attentions)
|
| 214 |
+
|
| 215 |
+
def build(self, input_shape=None):
|
| 216 |
+
if self.built:
|
| 217 |
+
return
|
| 218 |
+
self.built = True
|
| 219 |
+
if getattr(self, "attn", None) is not None:
|
| 220 |
+
with tf.name_scope(self.attn.name):
|
| 221 |
+
self.attn.build(None)
|
| 222 |
+
if getattr(self, "ln_1", None) is not None:
|
| 223 |
+
with tf.name_scope(self.ln_1.name):
|
| 224 |
+
self.ln_1.build([None, None, self.nx])
|
| 225 |
+
if getattr(self, "mlp", None) is not None:
|
| 226 |
+
with tf.name_scope(self.mlp.name):
|
| 227 |
+
self.mlp.build(None)
|
| 228 |
+
if getattr(self, "ln_2", None) is not None:
|
| 229 |
+
with tf.name_scope(self.ln_2.name):
|
| 230 |
+
self.ln_2.build([None, None, self.nx])
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@keras_serializable
|
| 234 |
+
class TFOpenAIGPTMainLayer(keras.layers.Layer):
|
| 235 |
+
config_class = OpenAIGPTConfig
|
| 236 |
+
|
| 237 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 238 |
+
super().__init__(*inputs, **kwargs)
|
| 239 |
+
|
| 240 |
+
self.config = config
|
| 241 |
+
self.output_hidden_states = config.output_hidden_states
|
| 242 |
+
self.output_attentions = config.output_attentions
|
| 243 |
+
self.return_dict = config.use_return_dict
|
| 244 |
+
self.num_hidden_layers = config.n_layer
|
| 245 |
+
self.n_embd = config.n_embd
|
| 246 |
+
self.n_positions = config.n_positions
|
| 247 |
+
self.initializer_range = config.initializer_range
|
| 248 |
+
|
| 249 |
+
self.tokens_embed = TFSharedEmbeddings(
|
| 250 |
+
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
|
| 251 |
+
)
|
| 252 |
+
self.drop = keras.layers.Dropout(config.embd_pdrop)
|
| 253 |
+
self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
|
| 254 |
+
|
| 255 |
+
def build(self, input_shape=None):
|
| 256 |
+
with tf.name_scope("positions_embed"):
|
| 257 |
+
self.positions_embed = self.add_weight(
|
| 258 |
+
name="embeddings",
|
| 259 |
+
shape=[self.n_positions, self.n_embd],
|
| 260 |
+
initializer=get_initializer(self.initializer_range),
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
if self.built:
|
| 264 |
+
return
|
| 265 |
+
self.built = True
|
| 266 |
+
if getattr(self, "tokens_embed", None) is not None:
|
| 267 |
+
with tf.name_scope(self.tokens_embed.name):
|
| 268 |
+
self.tokens_embed.build(None)
|
| 269 |
+
if getattr(self, "h", None) is not None:
|
| 270 |
+
for layer in self.h:
|
| 271 |
+
with tf.name_scope(layer.name):
|
| 272 |
+
layer.build(None)
|
| 273 |
+
|
| 274 |
+
def get_input_embeddings(self):
|
| 275 |
+
return self.tokens_embed
|
| 276 |
+
|
| 277 |
+
def set_input_embeddings(self, value):
|
| 278 |
+
self.tokens_embed.weight = value
|
| 279 |
+
self.tokens_embed.vocab_size = shape_list(value)[0]
|
| 280 |
+
|
| 281 |
+
def _prune_heads(self, heads_to_prune):
|
| 282 |
+
"""
|
| 283 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 284 |
+
"""
|
| 285 |
+
raise NotImplementedError
|
| 286 |
+
|
| 287 |
+
@unpack_inputs
|
| 288 |
+
def call(
|
| 289 |
+
self,
|
| 290 |
+
input_ids: TFModelInputType | None = None,
|
| 291 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 292 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 293 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 294 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 295 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 296 |
+
output_attentions: Optional[bool] = None,
|
| 297 |
+
output_hidden_states: Optional[bool] = None,
|
| 298 |
+
return_dict: Optional[bool] = None,
|
| 299 |
+
training: Optional[bool] = False,
|
| 300 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 301 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 302 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 303 |
+
elif input_ids is not None:
|
| 304 |
+
input_shape = shape_list(input_ids)
|
| 305 |
+
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
| 306 |
+
elif inputs_embeds is not None:
|
| 307 |
+
input_shape = shape_list(inputs_embeds)[:-1]
|
| 308 |
+
else:
|
| 309 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 310 |
+
|
| 311 |
+
if position_ids is None:
|
| 312 |
+
position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
|
| 313 |
+
|
| 314 |
+
if attention_mask is not None:
|
| 315 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 316 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 317 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 318 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 319 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 320 |
+
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
|
| 321 |
+
|
| 322 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 323 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 324 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 325 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 326 |
+
# effectively the same as removing these entirely.
|
| 327 |
+
|
| 328 |
+
one_cst = tf.constant(1.0)
|
| 329 |
+
attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
|
| 330 |
+
attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
|
| 331 |
+
else:
|
| 332 |
+
attention_mask = None
|
| 333 |
+
|
| 334 |
+
# Prepare head mask if needed
|
| 335 |
+
# 1.0 in head_mask indicate we keep the head
|
| 336 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 337 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 338 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 339 |
+
if head_mask is not None:
|
| 340 |
+
raise NotImplementedError
|
| 341 |
+
else:
|
| 342 |
+
head_mask = [None] * self.num_hidden_layers
|
| 343 |
+
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
| 344 |
+
|
| 345 |
+
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
| 346 |
+
|
| 347 |
+
if inputs_embeds is None:
|
| 348 |
+
check_embeddings_within_bounds(input_ids, self.config.vocab_size)
|
| 349 |
+
inputs_embeds = self.tokens_embed(input_ids, mode="embedding")
|
| 350 |
+
position_embeds = tf.gather(self.positions_embed, position_ids)
|
| 351 |
+
if token_type_ids is not None:
|
| 352 |
+
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
| 353 |
+
check_embeddings_within_bounds(token_type_ids, self.config.vocab_size, "token_type_ids")
|
| 354 |
+
token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
|
| 355 |
+
else:
|
| 356 |
+
token_type_embeds = 0
|
| 357 |
+
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
| 358 |
+
hidden_states = self.drop(hidden_states, training=training)
|
| 359 |
+
|
| 360 |
+
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
| 361 |
+
|
| 362 |
+
all_attentions = () if output_attentions else None
|
| 363 |
+
all_hidden_states = () if output_hidden_states else None
|
| 364 |
+
for i, block in enumerate(self.h):
|
| 365 |
+
if output_hidden_states:
|
| 366 |
+
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
| 367 |
+
|
| 368 |
+
outputs = block(
|
| 369 |
+
hidden_states,
|
| 370 |
+
attention_mask,
|
| 371 |
+
head_mask[i],
|
| 372 |
+
output_attentions,
|
| 373 |
+
training=training,
|
| 374 |
+
)
|
| 375 |
+
hidden_states = outputs[0]
|
| 376 |
+
if output_attentions:
|
| 377 |
+
all_attentions = all_attentions + (outputs[1],)
|
| 378 |
+
|
| 379 |
+
hidden_states = tf.reshape(hidden_states, output_shape)
|
| 380 |
+
# Add last hidden state
|
| 381 |
+
if output_hidden_states:
|
| 382 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 383 |
+
|
| 384 |
+
if output_attentions:
|
| 385 |
+
# let the number of heads free (-1) so we can extract attention even after head pruning
|
| 386 |
+
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
| 387 |
+
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
| 388 |
+
|
| 389 |
+
if not return_dict:
|
| 390 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 391 |
+
|
| 392 |
+
return TFBaseModelOutput(
|
| 393 |
+
last_hidden_state=hidden_states,
|
| 394 |
+
hidden_states=all_hidden_states,
|
| 395 |
+
attentions=all_attentions,
|
| 396 |
+
)
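The additive-mask trick described in the comments inside `call` above is easy to verify standalone; a minimal sketch with a made-up 3-token padding mask:

```python
import tensorflow as tf

attention_mask = tf.constant([[1.0, 1.0, 0.0]])              # 1 = attend, 0 = padding
attention_mask = tf.reshape(attention_mask, (1, 1, 1, -1))    # broadcastable to (batch, heads, from, to)
# attended positions get 0.0 added to the raw scores, masked positions get -10000.0
additive_mask = tf.multiply(tf.subtract(1.0, attention_mask), tf.constant(-10000.0))
print(additive_mask.numpy().ravel())                          # zeros (up to sign) and -10000. for the pad slot
```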
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
|
| 400 |
+
"""
|
| 401 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 402 |
+
models.
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
config_class = OpenAIGPTConfig
|
| 406 |
+
base_model_prefix = "transformer"
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
@dataclass
|
| 410 |
+
class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
|
| 411 |
+
"""
|
| 412 |
+
Base class for outputs of models predicting if two sentences are consecutive or not.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
logits (`tf.Tensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
| 416 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 417 |
+
mc_logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
|
| 418 |
+
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
| 419 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 420 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 421 |
+
`(batch_size, sequence_length, hidden_size)`.
|
| 422 |
+
|
| 423 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 424 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 425 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 426 |
+
sequence_length)`.
|
| 427 |
+
|
| 428 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 429 |
+
heads.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
logits: Optional[tf.Tensor] = None
|
| 433 |
+
mc_logits: Optional[tf.Tensor] = None
|
| 434 |
+
hidden_states: Tuple[tf.Tensor] | None = None
|
| 435 |
+
attentions: Tuple[tf.Tensor] | None = None
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
OPENAI_GPT_START_DOCSTRING = r"""
|
| 439 |
+
|
| 440 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 441 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 442 |
+
etc.)
|
| 443 |
+
|
| 444 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 445 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 446 |
+
behavior.
|
| 447 |
+
|
| 448 |
+
<Tip>
|
| 449 |
+
|
| 450 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 451 |
+
|
| 452 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 453 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 454 |
+
|
| 455 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 456 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 457 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 458 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 459 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 460 |
+
positional argument:
|
| 461 |
+
|
| 462 |
+
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
|
| 463 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 464 |
+
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
|
| 465 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 466 |
+
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
|
| 467 |
+
|
| 468 |
+
Note that when creating models and layers with
|
| 469 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 470 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 471 |
+
|
| 472 |
+
</Tip>
|
| 473 |
+
|
| 474 |
+
Parameters:
|
| 475 |
+
config ([`OpenAIGPTConfig`]): Model configuration class with all the parameters of the model.
|
| 476 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 477 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
OPENAI_GPT_INPUTS_DOCSTRING = r"""
|
| 481 |
+
Args:
|
| 482 |
+
input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 483 |
+
Indices of input sequence tokens in the vocabulary.
|
| 484 |
+
|
| 485 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
|
| 486 |
+
[`PreTrainedTokenizer.encode`] for details.
|
| 487 |
+
|
| 488 |
+
[What are input IDs?](../glossary#input-ids)
|
| 489 |
+
attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 490 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 491 |
+
|
| 492 |
+
- 1 for tokens that are **not masked**,
|
| 493 |
+
- 0 for tokens that are **masked**.
|
| 494 |
+
|
| 495 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 496 |
+
token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 497 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 498 |
+
1]`:
|
| 499 |
+
|
| 500 |
+
- 0 corresponds to a *sentence A* token,
|
| 501 |
+
- 1 corresponds to a *sentence B* token.
|
| 502 |
+
|
| 503 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 504 |
+
position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
|
| 505 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 506 |
+
config.max_position_embeddings - 1]`.
|
| 507 |
+
|
| 508 |
+
[What are position IDs?](../glossary#position-ids)
|
| 509 |
+
head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 510 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 511 |
+
|
| 512 |
+
- 1 indicates the head is **not masked**,
|
| 513 |
+
- 0 indicates the head is **masked**.
|
| 514 |
+
|
| 515 |
+
inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 516 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 517 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 518 |
+
model's internal embedding lookup matrix.
|
| 519 |
+
output_attentions (`bool`, *optional*):
|
| 520 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 521 |
+
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
| 522 |
+
config will be used instead.
|
| 523 |
+
output_hidden_states (`bool`, *optional*):
|
| 524 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 525 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 526 |
+
used instead.
|
| 527 |
+
return_dict (`bool`, *optional*):
|
| 528 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
|
| 529 |
+
eager mode, in graph mode the value will always be set to True.
|
| 530 |
+
training (`bool`, *optional*, defaults to `False`):
|
| 531 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 532 |
+
behaviors between training and evaluation).
|
| 533 |
+
"""
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
@add_start_docstrings(
|
| 537 |
+
"The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
|
| 538 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 539 |
+
)
|
| 540 |
+
class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
|
| 541 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 542 |
+
super().__init__(config, *inputs, **kwargs)
|
| 543 |
+
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
|
| 544 |
+
|
| 545 |
+
@unpack_inputs
|
| 546 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 547 |
+
@add_code_sample_docstrings(
|
| 548 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 549 |
+
output_type=TFBaseModelOutput,
|
| 550 |
+
config_class=_CONFIG_FOR_DOC,
|
| 551 |
+
)
|
| 552 |
+
def call(
|
| 553 |
+
self,
|
| 554 |
+
input_ids: TFModelInputType | None = None,
|
| 555 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 556 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 557 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 558 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 559 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 560 |
+
output_attentions: Optional[bool] = None,
|
| 561 |
+
output_hidden_states: Optional[bool] = None,
|
| 562 |
+
return_dict: Optional[bool] = None,
|
| 563 |
+
training: Optional[bool] = False,
|
| 564 |
+
) -> Union[Tuple, TFBaseModelOutput]:
|
| 565 |
+
outputs = self.transformer(
|
| 566 |
+
input_ids=input_ids,
|
| 567 |
+
attention_mask=attention_mask,
|
| 568 |
+
token_type_ids=token_type_ids,
|
| 569 |
+
position_ids=position_ids,
|
| 570 |
+
head_mask=head_mask,
|
| 571 |
+
inputs_embeds=inputs_embeds,
|
| 572 |
+
output_attentions=output_attentions,
|
| 573 |
+
output_hidden_states=output_hidden_states,
|
| 574 |
+
return_dict=return_dict,
|
| 575 |
+
training=training,
|
| 576 |
+
)
|
| 577 |
+
return outputs
|
| 578 |
+
|
| 579 |
+
def build(self, input_shape=None):
|
| 580 |
+
if self.built:
|
| 581 |
+
return
|
| 582 |
+
self.built = True
|
| 583 |
+
if getattr(self, "transformer", None) is not None:
|
| 584 |
+
with tf.name_scope(self.transformer.name):
|
| 585 |
+
self.transformer.build(None)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
@add_start_docstrings(
|
| 589 |
+
"""
|
| 590 |
+
OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
| 591 |
+
embeddings).
|
| 592 |
+
""",
|
| 593 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 594 |
+
)
|
| 595 |
+
class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss):
|
| 596 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 597 |
+
super().__init__(config, *inputs, **kwargs)
|
| 598 |
+
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
|
| 599 |
+
# OpenAIGPT does not have past caching features
|
| 600 |
+
self.supports_xla_generation = False
|
| 601 |
+
|
| 602 |
+
def get_output_embeddings(self):
|
| 603 |
+
return self.get_input_embeddings()
|
| 604 |
+
|
| 605 |
+
def set_output_embeddings(self, value):
|
| 606 |
+
self.set_input_embeddings(value)
|
| 607 |
+
|
| 608 |
+
@unpack_inputs
|
| 609 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 610 |
+
@add_code_sample_docstrings(
|
| 611 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 612 |
+
output_type=TFCausalLMOutput,
|
| 613 |
+
config_class=_CONFIG_FOR_DOC,
|
| 614 |
+
)
|
| 615 |
+
def call(
|
| 616 |
+
self,
|
| 617 |
+
input_ids: TFModelInputType | None = None,
|
| 618 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 619 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 620 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 621 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 622 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 623 |
+
output_attentions: Optional[bool] = None,
|
| 624 |
+
output_hidden_states: Optional[bool] = None,
|
| 625 |
+
return_dict: Optional[bool] = None,
|
| 626 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 627 |
+
training: Optional[bool] = False,
|
| 628 |
+
) -> Union[Tuple, TFCausalLMOutput]:
|
| 629 |
+
r"""
|
| 630 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 631 |
+
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
|
| 632 |
+
config.vocab_size - 1]`.
|
| 633 |
+
"""
|
| 634 |
+
|
| 635 |
+
transformer_outputs = self.transformer(
|
| 636 |
+
input_ids=input_ids,
|
| 637 |
+
attention_mask=attention_mask,
|
| 638 |
+
token_type_ids=token_type_ids,
|
| 639 |
+
position_ids=position_ids,
|
| 640 |
+
head_mask=head_mask,
|
| 641 |
+
inputs_embeds=inputs_embeds,
|
| 642 |
+
output_attentions=output_attentions,
|
| 643 |
+
output_hidden_states=output_hidden_states,
|
| 644 |
+
return_dict=return_dict,
|
| 645 |
+
training=training,
|
| 646 |
+
)
|
| 647 |
+
hidden_states = transformer_outputs[0]
|
| 648 |
+
|
| 649 |
+
logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
| 650 |
+
|
| 651 |
+
loss = None
|
| 652 |
+
if labels is not None:
|
| 653 |
+
# shift labels to the left and cut last logit token
|
| 654 |
+
shifted_logits = logits[:, :-1]
|
| 655 |
+
labels = labels[:, 1:]
|
| 656 |
+
loss = self.hf_compute_loss(labels, shifted_logits)
|
| 657 |
+
|
| 658 |
+
if not return_dict:
|
| 659 |
+
output = (logits,) + transformer_outputs[1:]
|
| 660 |
+
return ((loss,) + output) if loss is not None else output
|
| 661 |
+
|
| 662 |
+
return TFCausalLMOutput(
|
| 663 |
+
loss=loss,
|
| 664 |
+
logits=logits,
|
| 665 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 666 |
+
attentions=transformer_outputs.attentions,
|
| 667 |
+
)
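The "shift labels to the left and cut last logit token" step above aligns the logit at position t with the token at position t+1; a small sketch with a toy vocabulary of 50:

```python
import tensorflow as tf

labels = tf.constant([[11, 12, 13, 14]])
logits = tf.random.normal((1, 4, 50))          # (batch, seq_len, vocab)
shifted_logits = logits[:, :-1]                # predictions for positions 1..3
shifted_labels = labels[:, 1:]                 # targets [12, 13, 14]
loss = tf.keras.losses.sparse_categorical_crossentropy(shifted_labels, shifted_logits, from_logits=True)
print(loss.shape)  # (1, 3)
```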
|
| 668 |
+
|
| 669 |
+
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
| 670 |
+
return {"input_ids": inputs}
|
| 671 |
+
|
| 672 |
+
def build(self, input_shape=None):
|
| 673 |
+
if self.built:
|
| 674 |
+
return
|
| 675 |
+
self.built = True
|
| 676 |
+
if getattr(self, "transformer", None) is not None:
|
| 677 |
+
with tf.name_scope(self.transformer.name):
|
| 678 |
+
self.transformer.build(None)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
@add_start_docstrings(
|
| 682 |
+
"""
|
| 683 |
+
OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
|
| 684 |
+
RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
|
| 685 |
+
input embeddings, and the classification head takes as input the hidden state at a specified classification token index in the
|
| 686 |
+
input sequence.
|
| 687 |
+
""",
|
| 688 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 689 |
+
)
|
| 690 |
+
class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
| 691 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 692 |
+
super().__init__(config, *inputs, **kwargs)
|
| 693 |
+
config.num_labels = 1
|
| 694 |
+
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
|
| 695 |
+
self.multiple_choice_head = TFSequenceSummary(
|
| 696 |
+
config, initializer_range=config.initializer_range, name="multiple_choice_head"
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
@unpack_inputs
|
| 700 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 701 |
+
@replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 702 |
+
def call(
|
| 703 |
+
self,
|
| 704 |
+
input_ids: TFModelInputType | None = None,
|
| 705 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 706 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 707 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 708 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 709 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 710 |
+
mc_token_ids: np.ndarray | tf.Tensor | None = None,
|
| 711 |
+
output_attentions: Optional[bool] = None,
|
| 712 |
+
output_hidden_states: Optional[bool] = None,
|
| 713 |
+
return_dict: Optional[bool] = None,
|
| 714 |
+
training: Optional[bool] = False,
|
| 715 |
+
) -> Union[Tuple, TFOpenAIGPTDoubleHeadsModelOutput]:
|
| 716 |
+
r"""
|
| 717 |
+
mc_token_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
|
| 718 |
+
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
| 719 |
+
1]`.
|
| 720 |
+
|
| 721 |
+
Return:
|
| 722 |
+
|
| 723 |
+
Examples:
|
| 724 |
+
|
| 725 |
+
```python
|
| 726 |
+
>>> import tensorflow as tf
|
| 727 |
+
>>> from transformers import AutoTokenizer, TFOpenAIGPTDoubleHeadsModel
|
| 728 |
+
|
| 729 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
|
| 730 |
+
>>> model = TFOpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
|
| 731 |
+
|
| 732 |
+
>>> # Add a [CLS] to the vocabulary (we should train it also!)
|
| 733 |
+
>>> tokenizer.add_special_tokens({"cls_token": "[CLS]"})
|
| 734 |
+
>>> model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
|
| 735 |
+
>>> print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary
|
| 736 |
+
|
| 737 |
+
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
|
| 738 |
+
>>> encoding = tokenizer(choices, return_tensors="tf")
|
| 739 |
+
>>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}
|
| 740 |
+
>>> inputs["mc_token_ids"] = tf.constant(
|
| 741 |
+
... [inputs["input_ids"].shape[-1] - 1, inputs["input_ids"].shape[-1] - 1]
|
| 742 |
+
... )[
|
| 743 |
+
... None, :
|
| 744 |
+
... ] # Batch size 1
|
| 745 |
+
>>> outputs = model(inputs)
|
| 746 |
+
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
|
| 747 |
+
```"""
|
| 748 |
+
|
| 749 |
+
if input_ids is not None:
|
| 750 |
+
input_shapes = shape_list(input_ids)
|
| 751 |
+
else:
|
| 752 |
+
input_shapes = shape_list(inputs_embeds)[:-1]
|
| 753 |
+
|
| 754 |
+
seq_length = input_shapes[-1]
|
| 755 |
+
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
| 756 |
+
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
| 757 |
+
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
| 758 |
+
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
| 759 |
+
transformer_outputs = self.transformer(
|
| 760 |
+
flat_input_ids,
|
| 761 |
+
flat_attention_mask,
|
| 762 |
+
flat_token_type_ids,
|
| 763 |
+
flat_position_ids,
|
| 764 |
+
head_mask,
|
| 765 |
+
inputs_embeds,
|
| 766 |
+
output_attentions,
|
| 767 |
+
output_hidden_states,
|
| 768 |
+
return_dict=return_dict,
|
| 769 |
+
training=training,
|
| 770 |
+
)
|
| 771 |
+
hidden_states = transformer_outputs[0]
|
| 772 |
+
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
| 773 |
+
if return_dict and output_hidden_states:
|
| 774 |
+
# We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
|
| 775 |
+
# input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
|
| 776 |
+
all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
|
| 777 |
+
else:
|
| 778 |
+
all_hidden_states = None
|
| 779 |
+
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
| 780 |
+
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
| 781 |
+
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
| 782 |
+
|
| 783 |
+
if not return_dict:
|
| 784 |
+
return (lm_logits, mc_logits) + transformer_outputs[1:]
|
| 785 |
+
|
| 786 |
+
return TFOpenAIGPTDoubleHeadsModelOutput(
|
| 787 |
+
logits=lm_logits,
|
| 788 |
+
mc_logits=mc_logits,
|
| 789 |
+
hidden_states=all_hidden_states,
|
| 790 |
+
attentions=transformer_outputs.attentions,
|
| 791 |
+
)
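The reshaping done at the start and end of the multiple-choice `call` above (fold the choices into the batch dimension for the transformer, then unfold the hidden states) can be illustrated with made-up shapes:

```python
import tensorflow as tf

batch_size, num_choices, seq_length, hidden = 2, 3, 5, 8
input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)

flat_input_ids = tf.reshape(input_ids, (-1, seq_length))            # (6, 5) fed to the transformer
flat_hidden = tf.random.normal((batch_size * num_choices, seq_length, hidden))
hidden_states = tf.reshape(flat_hidden, (batch_size, num_choices, seq_length, hidden))
print(flat_input_ids.shape, hidden_states.shape)  # (6, 5) (2, 3, 5, 8)
```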
|
| 792 |
+
|
| 793 |
+
@property
|
| 794 |
+
def input_signature(self):
|
| 795 |
+
return {
|
| 796 |
+
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
|
| 797 |
+
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
|
| 798 |
+
"mc_token_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
def build(self, input_shape=None):
|
| 802 |
+
if self.built:
|
| 803 |
+
return
|
| 804 |
+
self.built = True
|
| 805 |
+
if getattr(self, "transformer", None) is not None:
|
| 806 |
+
with tf.name_scope(self.transformer.name):
|
| 807 |
+
self.transformer.build(None)
|
| 808 |
+
if getattr(self, "multiple_choice_head", None) is not None:
|
| 809 |
+
with tf.name_scope(self.multiple_choice_head.name):
|
| 810 |
+
self.multiple_choice_head.build(None)
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
@add_start_docstrings(
|
| 814 |
+
"""
|
| 815 |
+
The OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
|
| 816 |
+
|
| 817 |
+
[`TFOpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
|
| 818 |
+
models (e.g. GPT-2) do.
|
| 819 |
+
|
| 820 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
| 821 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
| 822 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
| 823 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 824 |
+
each row of the batch).
|
| 825 |
+
""",
|
| 826 |
+
OPENAI_GPT_START_DOCSTRING,
|
| 827 |
+
)
|
| 828 |
+
class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss):
|
| 829 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 830 |
+
super().__init__(config, *inputs, **kwargs)
|
| 831 |
+
self.num_labels = config.num_labels
|
| 832 |
+
self.score = keras.layers.Dense(
|
| 833 |
+
config.num_labels,
|
| 834 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 835 |
+
name="score",
|
| 836 |
+
use_bias=False,
|
| 837 |
+
)
|
| 838 |
+
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
|
| 839 |
+
self.config = config
|
| 840 |
+
|
| 841 |
+
@unpack_inputs
|
| 842 |
+
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
| 843 |
+
@add_code_sample_docstrings(
|
| 844 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 845 |
+
output_type=TFSequenceClassifierOutput,
|
| 846 |
+
config_class=_CONFIG_FOR_DOC,
|
| 847 |
+
)
|
| 848 |
+
def call(
|
| 849 |
+
self,
|
| 850 |
+
input_ids: TFModelInputType | None = None,
|
| 851 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 852 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 853 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 854 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 855 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 856 |
+
output_attentions: Optional[bool] = None,
|
| 857 |
+
output_hidden_states: Optional[bool] = None,
|
| 858 |
+
return_dict: Optional[bool] = None,
|
| 859 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 860 |
+
training: Optional[bool] = False,
|
| 861 |
+
) -> Union[Tuple, TFSequenceClassifierOutput]:
|
| 862 |
+
r"""
|
| 863 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 864 |
+
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
|
| 865 |
+
config.num_labels - 1]`.
|
| 866 |
+
"""
|
| 867 |
+
transformer_outputs = self.transformer(
|
| 868 |
+
input_ids=input_ids,
|
| 869 |
+
attention_mask=attention_mask,
|
| 870 |
+
token_type_ids=token_type_ids,
|
| 871 |
+
position_ids=position_ids,
|
| 872 |
+
head_mask=head_mask,
|
| 873 |
+
inputs_embeds=inputs_embeds,
|
| 874 |
+
output_attentions=output_attentions,
|
| 875 |
+
output_hidden_states=output_hidden_states,
|
| 876 |
+
return_dict=return_dict,
|
| 877 |
+
training=training,
|
| 878 |
+
)
|
| 879 |
+
hidden_states = transformer_outputs[0]
|
| 880 |
+
logits = self.score(hidden_states)
|
| 881 |
+
logits_shape = shape_list(logits)
|
| 882 |
+
batch_size = logits_shape[0]
|
| 883 |
+
|
| 884 |
+
if self.config.pad_token_id is None:
|
| 885 |
+
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
| 886 |
+
else:
|
| 887 |
+
if input_ids is not None:
|
| 888 |
+
token_indices = tf.range(shape_list(input_ids)[-1])
|
| 889 |
+
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
| 890 |
+
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
| 891 |
+
else:
|
| 892 |
+
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
| 893 |
+
logger.warning_once(
|
| 894 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
| 895 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
| 896 |
+
)
|
| 897 |
+
loss = None
|
| 898 |
+
|
| 899 |
+
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
| 900 |
+
|
| 901 |
+
if labels is not None:
|
| 902 |
+
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
| 903 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
| 904 |
+
|
| 905 |
+
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
| 906 |
+
|
| 907 |
+
if not return_dict:
|
| 908 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
| 909 |
+
return ((loss,) + output) if loss is not None else output
|
| 910 |
+
|
| 911 |
+
return TFSequenceClassifierOutput(
|
| 912 |
+
loss=loss,
|
| 913 |
+
logits=pooled_logits,
|
| 914 |
+
hidden_states=transformer_outputs.hidden_states,
|
| 915 |
+
attentions=transformer_outputs.attentions,
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
def build(self, input_shape=None):
|
| 919 |
+
if self.built:
|
| 920 |
+
return
|
| 921 |
+
self.built = True
|
| 922 |
+
if getattr(self, "score", None) is not None:
|
| 923 |
+
with tf.name_scope(self.score.name):
|
| 924 |
+
self.score.build([None, None, self.config.n_embd])
|
| 925 |
+
if getattr(self, "transformer", None) is not None:
|
| 926 |
+
with tf.name_scope(self.transformer.name):
|
| 927 |
+
self.transformer.build(None)
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
__all__ = [
|
| 931 |
+
"TFOpenAIGPTDoubleHeadsModel",
|
| 932 |
+
"TFOpenAIGPTForSequenceClassification",
|
| 933 |
+
"TFOpenAIGPTLMHeadModel",
|
| 934 |
+
"TFOpenAIGPTMainLayer",
|
| 935 |
+
"TFOpenAIGPTModel",
|
| 936 |
+
"TFOpenAIGPTPreTrainedModel",
|
| 937 |
+
]
|
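The only non-boilerplate step in `call` above is the pooling: for every example, the logits at the last non-padding position are selected. A minimal, self-contained sketch of that index computation, separate from the file itself and using made-up tensors (plain TensorFlow only):

import tensorflow as tf

pad_token_id = 0
input_ids = tf.constant([[5, 7, 9, 0, 0],    # 3 real tokens, 2 pads
                         [4, 6, 0, 0, 0]])   # 2 real tokens, 3 pads
logits = tf.random.normal((2, 5, 3))         # (batch, seq_len, num_labels), dummy values

# masked arg-max over positions: index of the last token that is not padding
token_indices = tf.range(tf.shape(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)     # [2, 1]

# one logit row per example, as in `pooled_logits` above
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)   # shape (2, 3)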
docs/transformers/src/transformers/models/openai/tokenization_openai.py
ADDED
@@ -0,0 +1,396 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""

import json
import os
import re
import unicodedata
from typing import Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer:
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


def get_pairs(word):
    """
    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
    strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus also does some whitespace standardization
    """
    text = text.replace("—", "-")
    text = text.replace("–", "-")
    text = text.replace("―", "-")
    text = text.replace("…", "...")
    text = text.replace("´", "'")
    text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
    text = re.sub(r"\s*\n\s*", " \n ", text)
    text = re.sub(r"[^\S\n]+", " ", text)
    return text.strip()


class OpenAIGPTTokenizer(PreTrainedTokenizer):
    """
    Construct a GPT Tokenizer. Based on Byte-Pair-Encoding with the following peculiarities:

    - lowercases all inputs,
    - uses `SpaCy` tokenizer and `ftfy` for pre-BPE tokenization if they are installed, fallback to BERT's
      `BasicTokenizer` if not.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
        try:
            import ftfy
            from spacy.lang.en import English

            _nlp = English()
            self.nlp = _nlp.tokenizer
            self.fix_text = ftfy.fix_text
        except ImportError:
            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
            self.nlp = BasicTokenizer(do_lower_case=True)
            self.fix_text = None

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

        super().__init__(unk_token=unk_token, **kwargs)

    @property
    def do_lower_case(self):
        return True

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        if word == "\n  </w>":
            word = "\n</w>"
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """Tokenize a string."""
        split_tokens = []
        if self.fix_text is None:
            # Using BERT's BasicTokenizer
            text = self.nlp.tokenize(text)
            for token in text:
                split_tokens.extend(list(self.bpe(token).split(" ")))
        else:
            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
            text = self.nlp(text_standardize(self.fix_text(text)))
            for token in text:
                split_tokens.extend(list(self.bpe(token.text.lower()).split(" ")))
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an id in a token (BPE) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = "".join(tokens).replace("</w>", " ").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file


__all__ = ["OpenAIGPTTokenizer"]
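As a standalone illustration of the bookkeeping `bpe()` above relies on: a token becomes a tuple of symbols whose last element carries the `</w>` end-of-word marker, and `get_pairs()` lists the adjacent symbol pairs that are merge candidates. A small sketch, separate from the file, that reuses the same helper:

def get_pairs(word):
    """Adjacent symbol pairs of a word given as a tuple of variable-length strings."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

token = "lower"
word = tuple(token[:-1]) + (token[-1] + "</w>",)   # ('l', 'o', 'w', 'e', 'r</w>')
print(get_pairs(word))
# {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r</w>')}
# bpe() repeatedly merges the pair with the lowest rank in merges.txt until none of the
# remaining pairs appears in self.bpe_ranks, then joins the surviving symbols with spaces.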
docs/transformers/src/transformers/models/openai/tokenization_openai_fast.py
ADDED
@@ -0,0 +1,66 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for OpenAI GPT."""

from typing import Optional, Tuple

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_openai import OpenAIGPTTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}


class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" GPT Tokenizer (backed by HuggingFace's *tokenizers* library). Based on Byte-Pair-Encoding with
    the following peculiarities:

    - lower case all inputs
    - uses BERT's BasicTokenizer for pre-BPE tokenization

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = OpenAIGPTTokenizer

    def __init__(self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<unk>", **kwargs):
        super().__init__(vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, **kwargs)

    @property
    def do_lower_case(self):
        return True

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)


__all__ = ["OpenAIGPTTokenizerFast"]
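A hedged usage sketch, separate from the file above: it assumes the public `openai-community/openai-gpt` checkpoint is reachable, and the printed tokens depend on the downloaded vocabulary, so no exact output is shown. Both the slow and fast tokenizers lowercase their input, as the `do_lower_case` property advertises.

from transformers import OpenAIGPTTokenizerFast

tok = OpenAIGPTTokenizerFast.from_pretrained("openai-community/openai-gpt")
enc = tok("Hello WORLD!")
print(tok.convert_ids_to_tokens(enc["input_ids"]))  # lowercased, '</w>'-suffixed BPE tokens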
docs/transformers/src/transformers/models/opt/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_opt import *
    from .modeling_flax_opt import *
    from .modeling_opt import *
    from .modeling_tf_opt import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
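The `else` branch above swaps the package's module object for a `_LazyModule`, so importing `transformers.models.opt` stays cheap and the framework-specific modeling files are only loaded on first attribute access. A hedged sketch of what that buys a caller (class names are the public OPT exports; exact import cost depends on which backends are installed):

import transformers.models.opt as opt   # no torch/tf/flax modeling code imported yet

config = opt.OPTConfig(hidden_size=16, num_hidden_layers=2)  # first access triggers the real import
print(config.hidden_size)  # 16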