Student0809 committed on
Commit d377feb · verified · 1 Parent(s): da4d9dc

Add files using upload-large-folder tool

Files changed (20)
  1. docs/transformers/build/lib/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py +308 -0
  2. docs/transformers/build/lib/transformers/models/regnet/convert_regnet_to_pytorch.py +458 -0
  3. docs/transformers/build/lib/transformers/models/rembert/configuration_rembert.py +162 -0
  4. docs/transformers/build/lib/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py +62 -0
  5. docs/transformers/build/lib/transformers/models/rembert/modeling_rembert.py +1525 -0
  6. docs/transformers/build/lib/transformers/models/roberta/tokenization_roberta.py +402 -0
  7. docs/transformers/build/lib/transformers/models/roberta_prelayernorm/__init__.py +29 -0
  8. docs/transformers/build/lib/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py +79 -0
  9. docs/transformers/build/lib/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py +1527 -0
  10. docs/transformers/build/lib/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +1558 -0
  11. docs/transformers/build/lib/transformers/models/roc_bert/__init__.py +28 -0
  12. docs/transformers/build/lib/transformers/models/roc_bert/configuration_roc_bert.py +163 -0
  13. docs/transformers/build/lib/transformers/models/roc_bert/modeling_roc_bert.py +2017 -0
  14. docs/transformers/build/lib/transformers/models/roformer/__init__.py +31 -0
  15. docs/transformers/build/lib/transformers/models/roformer/configuration_roformer.py +150 -0
  16. docs/transformers/build/lib/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py +62 -0
  17. docs/transformers/build/lib/transformers/models/roformer/modeling_roformer.py +1660 -0
  18. docs/transformers/build/lib/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +379 -0
  19. docs/transformers/build/lib/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +0 -0
  20. docs/transformers/build/lib/transformers/models/sam/convert_sam_to_hf.py +251 -0
docs/transformers/build/lib/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py ADDED
@@ -0,0 +1,308 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert RegNet 10B checkpoints vissl."""
16
+ # You need to install a specific version of classy vision
17
+ # pip install git+https://github.com/FrancescoSaverioZuppichini/ClassyVision.git@convert_weights
18
+
19
+ import argparse
20
+ import json
21
+ import os
22
+ import re
23
+ from collections import OrderedDict
24
+ from dataclasses import dataclass, field
25
+ from functools import partial
26
+ from pathlib import Path
27
+ from pprint import pprint
28
+ from typing import Dict, List, Optional, Tuple
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ from classy_vision.models.regnet import RegNet, RegNetParams
33
+ from huggingface_hub import hf_hub_download
34
+ from torch import Tensor
35
+ from vissl.models.model_helpers import get_trunk_forward_outputs
36
+
37
+ from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
38
+ from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
39
+ from transformers.utils import logging
40
+
41
+
42
+ logging.set_verbosity_info()
43
+ logger = logging.get_logger()
44
+
45
+
46
+ @dataclass
47
+ class Tracker:
48
+ module: nn.Module
49
+ traced: List[nn.Module] = field(default_factory=list)
50
+ handles: list = field(default_factory=list)
51
+ name2module: Dict[str, nn.Module] = field(default_factory=OrderedDict)
52
+
53
+ def _forward_hook(self, m, inputs: Tensor, outputs: Tensor, name: str):
54
+ has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
55
+ if has_not_submodules:
56
+ self.traced.append(m)
57
+ self.name2module[name] = m
58
+
59
+ def __call__(self, x: Tensor):
60
+ for name, m in self.module.named_modules():
61
+ self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name)))
62
+ self.module(x)
63
+ [x.remove() for x in self.handles]
64
+ return self
65
+
66
+ @property
67
+ def parametrized(self):
68
+ # check the len of the state_dict keys to see if we have learnable params
69
+ return {k: v for k, v in self.name2module.items() if len(list(v.state_dict().keys())) > 0}
70
+
71
+
72
+ class FakeRegNetVisslWrapper(nn.Module):
73
+ """
74
+ Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file.
75
+ """
76
+
77
+ def __init__(self, model: nn.Module):
78
+ super().__init__()
79
+
80
+ feature_blocks: List[Tuple[str, nn.Module]] = []
81
+ # - get the stem
82
+ feature_blocks.append(("conv1", model.stem))
83
+ # - get all the feature blocks
84
+ for k, v in model.trunk_output.named_children():
85
+ assert k.startswith("block"), f"Unexpected layer name {k}"
86
+ block_index = len(feature_blocks) + 1
87
+ feature_blocks.append((f"res{block_index}", v))
88
+
89
+ self._feature_blocks = nn.ModuleDict(feature_blocks)
90
+
91
+ def forward(self, x: Tensor):
92
+ return get_trunk_forward_outputs(
93
+ x,
94
+ out_feat_keys=None,
95
+ feature_blocks=self._feature_blocks,
96
+ )
97
+
98
+
99
+ class FakeRegNetParams(RegNetParams):
100
+ """
101
+ Used to instantiate a RegNet model from classy vision with the same depth as the 10B one but with super small
102
+ parameters, so we can trace it in memory.
103
+ """
104
+
105
+ def get_expanded_params(self):
106
+ return [(8, 2, 2, 8, 1.0), (8, 2, 7, 8, 1.0), (8, 2, 17, 8, 1.0), (8, 2, 1, 8, 1.0)]
107
+
108
+
109
+ def get_from_to_our_keys(model_name: str) -> Dict[str, str]:
110
+ """
111
+ Returns a dictionary that maps from original model's key -> our implementation's keys
112
+ """
113
+
114
+ # create our model (with small weights)
115
+ our_config = RegNetConfig(depths=[2, 7, 17, 1], hidden_sizes=[8, 8, 8, 8], groups_width=8)
116
+ if "in1k" in model_name:
117
+ our_model = RegNetForImageClassification(our_config)
118
+ else:
119
+ our_model = RegNetModel(our_config)
120
+ # create from model (with small weights)
121
+ from_model = FakeRegNetVisslWrapper(
122
+ RegNet(FakeRegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
123
+ )
124
+
125
+ with torch.no_grad():
126
+ from_model = from_model.eval()
127
+ our_model = our_model.eval()
128
+
129
+ x = torch.randn((1, 3, 32, 32))
130
+ # trace both
131
+ dest_tracker = Tracker(our_model)
132
+ dest_traced = dest_tracker(x).parametrized
133
+
134
+ pprint(dest_tracker.name2module)
135
+ src_tracker = Tracker(from_model)
136
+ src_traced = src_tracker(x).parametrized
137
+
138
+ # convert the keys -> module dict to keys -> params
139
+ def to_params_dict(dict_with_modules):
140
+ params_dict = OrderedDict()
141
+ for name, module in dict_with_modules.items():
142
+ for param_name, param in module.state_dict().items():
143
+ params_dict[f"{name}.{param_name}"] = param
144
+ return params_dict
145
+
146
+ from_to_ours_keys = {}
147
+
148
+ src_state_dict = to_params_dict(src_traced)
149
+ dst_state_dict = to_params_dict(dest_traced)
150
+
151
+ for (src_key, src_param), (dest_key, dest_param) in zip(src_state_dict.items(), dst_state_dict.items()):
152
+ from_to_ours_keys[src_key] = dest_key
153
+ logger.info(f"{src_key} -> {dest_key}")
154
+ # if "in1k" was in the model_name it means it must have a classification head (was finetuned)
155
+ if "in1k" in model_name:
156
+ from_to_ours_keys["0.clf.0.weight"] = "classifier.1.weight"
157
+ from_to_ours_keys["0.clf.0.bias"] = "classifier.1.bias"
158
+
159
+ return from_to_ours_keys
160
+
161
+
162
+ def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True):
163
+ filename = "imagenet-1k-id2label.json"
164
+ num_labels = 1000
165
+
166
+ repo_id = "huggingface/label-files"
167
+ num_labels = num_labels
168
+ id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
169
+ id2label = {int(k): v for k, v in id2label.items()}
170
+
171
+ id2label = id2label
172
+ label2id = {v: k for k, v in id2label.items()}
173
+
174
+ ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
175
+
176
+ names_to_config = {
177
+ "regnet-y-10b-seer": ImageNetPreTrainedConfig(
178
+ depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
179
+ ),
180
+ # finetuned on imagenet
181
+ "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
182
+ depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
183
+ ),
184
+ }
185
+
186
+ # add seer weights logic
187
+ def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]:
188
+ files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
189
+ # check if we have a head, if yes add it
190
+ model_state_dict = files["classy_state_dict"]["base_model"]["model"]
191
+ return model_state_dict["trunk"], model_state_dict["heads"]
192
+
193
+ names_to_from_model = {
194
+ "regnet-y-10b-seer": partial(
195
+ load_using_classy_vision,
196
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
197
+ ),
198
+ "regnet-y-10b-seer-in1k": partial(
199
+ load_using_classy_vision,
200
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
201
+ ),
202
+ }
203
+
204
+ from_to_ours_keys = get_from_to_our_keys(model_name)
205
+
206
+ if not (save_directory / f"{model_name}.pth").exists():
207
+ logger.info("Loading original state_dict.")
208
+ from_state_dict_trunk, from_state_dict_head = names_to_from_model[model_name]()
209
+ from_state_dict = from_state_dict_trunk
210
+ if "in1k" in model_name:
211
+ # add the head
212
+ from_state_dict = {**from_state_dict_trunk, **from_state_dict_head}
213
+ logger.info("Done!")
214
+
215
+ converted_state_dict = {}
216
+
217
+ not_used_keys = list(from_state_dict.keys())
218
+ regex = r"\.block.-part."
219
+ # this is "interesting", so the original checkpoints have `block[0,1]-part` in each key name, we remove it
220
+ for key in from_state_dict.keys():
221
+ # remove the weird "block[0,1]-part" from the key
222
+ src_key = re.sub(regex, "", key)
223
+ # now src_key from the model checkpoints is the one we got from the original model after tracing, so use it to get the correct destination key
224
+ dest_key = from_to_ours_keys[src_key]
225
+ # store the parameter with our key
226
+ converted_state_dict[dest_key] = from_state_dict[key]
227
+ not_used_keys.remove(key)
228
+ # check that all keys have been updated
229
+ assert len(not_used_keys) == 0, f"Some keys were not used {','.join(not_used_keys)}"
230
+
231
+ logger.info(f"The following keys were not used: {','.join(not_used_keys)}")
232
+
233
+ # save our state dict to disk
234
+ torch.save(converted_state_dict, save_directory / f"{model_name}.pth")
235
+
236
+ del converted_state_dict
237
+ else:
238
+ logger.info("The state_dict was already stored on disk.")
239
+ if push_to_hub:
240
+ logger.info(f"Token is {os.environ['HF_TOKEN']}")
241
+ logger.info("Loading our model.")
242
+ # create our model
243
+ our_config = names_to_config[model_name]
244
+ our_model_func = RegNetModel
245
+ if "in1k" in model_name:
246
+ our_model_func = RegNetForImageClassification
247
+ with torch.device("meta"):
248
+ our_model = our_model_func(our_config)
249
+ logger.info("Loading state_dict in our model.")
250
+ # load state dict
251
+ state_dict_keys = our_model.state_dict().keys()
252
+ state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True)
253
+ fixed_state_dict = state_dict = {our_model._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
254
+ _load_state_dict_into_meta_model(
255
+ our_model,
256
+ fixed_state_dict,
257
+ start_prefix="",
258
+ expected_keys=state_dict_keys,
259
+ )
260
+ logger.info("Finally, pushing!")
261
+ # push it to hub
262
+ our_model.push_to_hub(
263
+ repo_path_or_name=save_directory / model_name,
264
+ commit_message="Add model",
265
+ output_dir=save_directory / model_name,
266
+ )
267
+ size = 384
268
+ # we can use the convnext one
269
+ image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
270
+ image_processor.push_to_hub(
271
+ repo_path_or_name=save_directory / model_name,
272
+ commit_message="Add image processor",
273
+ output_dir=save_directory / model_name,
274
+ )
275
+
276
+
277
+ if __name__ == "__main__":
278
+ parser = argparse.ArgumentParser()
279
+ # Required parameters
280
+ parser.add_argument(
281
+ "--model_name",
282
+ default=None,
283
+ type=str,
284
+ help=(
285
+ "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
286
+ " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
287
+ ),
288
+ )
289
+ parser.add_argument(
290
+ "--pytorch_dump_folder_path",
291
+ default=None,
292
+ type=Path,
293
+ required=True,
294
+ help="Path to the output PyTorch model directory.",
295
+ )
296
+ parser.add_argument(
297
+ "--push_to_hub",
298
+ default=True,
299
+ type=bool,
300
+ required=False,
301
+ help="If True, push model and image processor to the hub.",
302
+ )
303
+
304
+ args = parser.parse_args()
305
+
306
+ pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
307
+ pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
308
+ convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
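For reference, the script above can also be driven from Python rather than the CLI. The sketch below is illustrative only and is not part of this commit: it assumes the optional classy_vision/vissl dependencies noted at the top of the file are installed, and the output directory is a placeholder.

```python
# Illustrative sketch, not part of the commit: calling the 10B SEER conversion
# entry point directly. Requires the classy_vision/vissl extras and a very large download.
from pathlib import Path

from transformers.models.regnet.convert_regnet_seer_10b_to_pytorch import convert_weights_and_push

save_dir = Path("/tmp/regnet_10b")  # placeholder output directory
save_dir.mkdir(parents=True, exist_ok=True)

# Convert and save locally without pushing to the Hub.
convert_weights_and_push(save_dir, model_name="regnet-y-10b-seer-in1k", push_to_hub=False)
```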
docs/transformers/build/lib/transformers/models/regnet/convert_regnet_to_pytorch.py ADDED
@@ -0,0 +1,458 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert RegNet checkpoints from timm and vissl."""
16
+
17
+ import argparse
18
+ import json
19
+ from dataclasses import dataclass, field
20
+ from functools import partial
21
+ from pathlib import Path
22
+ from typing import Callable, Dict, List, Optional, Tuple
23
+
24
+ import timm
25
+ import torch
26
+ import torch.nn as nn
27
+ from classy_vision.models.regnet import RegNet, RegNetParams, RegNetY32gf, RegNetY64gf, RegNetY128gf
28
+ from huggingface_hub import hf_hub_download
29
+ from torch import Tensor
30
+ from vissl.models.model_helpers import get_trunk_forward_outputs
31
+
32
+ from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
33
+ from transformers.utils import logging
34
+
35
+
36
+ logging.set_verbosity_info()
37
+ logger = logging.get_logger()
38
+
39
+
40
+ @dataclass
41
+ class Tracker:
42
+ module: nn.Module
43
+ traced: List[nn.Module] = field(default_factory=list)
44
+ handles: list = field(default_factory=list)
45
+
46
+ def _forward_hook(self, m, inputs: Tensor, outputs: Tensor):
47
+ has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
48
+ if has_not_submodules:
49
+ self.traced.append(m)
50
+
51
+ def __call__(self, x: Tensor):
52
+ for m in self.module.modules():
53
+ self.handles.append(m.register_forward_hook(self._forward_hook))
54
+ self.module(x)
55
+ [x.remove() for x in self.handles]
56
+ return self
57
+
58
+ @property
59
+ def parametrized(self):
60
+ # check the len of the state_dict keys to see if we have learnable params
61
+ return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced))
62
+
63
+
64
+ @dataclass
65
+ class ModuleTransfer:
66
+ src: nn.Module
67
+ dest: nn.Module
68
+ verbose: int = 1
69
+ src_skip: List = field(default_factory=list)
70
+ dest_skip: List = field(default_factory=list)
71
+ raise_if_mismatch: bool = True
72
+
73
+ def __call__(self, x: Tensor):
74
+ """
75
+ Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the
76
+ hood we track all the operations in both modules.
77
+ """
78
+ dest_traced = Tracker(self.dest)(x).parametrized
79
+ src_traced = Tracker(self.src)(x).parametrized
80
+
81
+ src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced))
82
+ dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced))
83
+
84
+ if len(dest_traced) != len(src_traced) and self.raise_if_mismatch:
85
+ raise Exception(
86
+ f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
87
+ f" destination module has {len(dest_traced)}."
88
+ )
89
+
90
+ for dest_m, src_m in zip(dest_traced, src_traced):
91
+ dest_m.load_state_dict(src_m.state_dict())
92
+ if self.verbose == 1:
93
+ print(f"Transfered from={src_m} to={dest_m}")
94
+
95
+
96
+ class FakeRegNetVisslWrapper(nn.Module):
97
+ """
98
+ Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file.
99
+ """
100
+
101
+ def __init__(self, model: nn.Module):
102
+ super().__init__()
103
+
104
+ feature_blocks: List[Tuple[str, nn.Module]] = []
105
+ # - get the stem
106
+ feature_blocks.append(("conv1", model.stem))
107
+ # - get all the feature blocks
108
+ for k, v in model.trunk_output.named_children():
109
+ assert k.startswith("block"), f"Unexpected layer name {k}"
110
+ block_index = len(feature_blocks) + 1
111
+ feature_blocks.append((f"res{block_index}", v))
112
+
113
+ self._feature_blocks = nn.ModuleDict(feature_blocks)
114
+
115
+ def forward(self, x: Tensor):
116
+ return get_trunk_forward_outputs(
117
+ x,
118
+ out_feat_keys=None,
119
+ feature_blocks=self._feature_blocks,
120
+ )
121
+
122
+
123
+ class NameToFromModelFuncMap(dict):
124
+ """
125
+ A Dictionary with some additional logic to return a function that creates the correct original model.
126
+ """
127
+
128
+ def convert_name_to_timm(self, x: str) -> str:
129
+ x_split = x.split("-")
130
+ return x_split[0] + x_split[1] + "_" + "".join(x_split[2:])
131
+
132
+ def __getitem__(self, x: str) -> Callable[[], Tuple[nn.Module, Dict]]:
133
+ # default to timm!
134
+ if x not in self:
135
+ x = self.convert_name_to_timm(x)
136
+ val = partial(lambda: (timm.create_model(x, pretrained=True).eval(), None))
137
+
138
+ else:
139
+ val = super().__getitem__(x)
140
+
141
+ return val
142
+
143
+
144
+ class NameToOurModelFuncMap(dict):
145
+ """
146
+ A dictionary with some additional logic to return the correct Hugging Face RegNet class reference.
147
+ """
148
+
149
+ def __getitem__(self, x: str) -> Callable[[], nn.Module]:
150
+ if "seer" in x and "in1k" not in x:
151
+ val = RegNetModel
152
+ else:
153
+ val = RegNetForImageClassification
154
+ return val
155
+
156
+
157
+ def manually_copy_vissl_head(from_state_dict, to_state_dict, keys: List[Tuple[str, str]]):
158
+ for from_key, to_key in keys:
159
+ to_state_dict[to_key] = from_state_dict[from_key].clone()
160
+ print(f"Copied key={from_key} to={to_key}")
161
+ return to_state_dict
162
+
163
+
164
+ def convert_weight_and_push(
165
+ name: str,
166
+ from_model_func: Callable[[], nn.Module],
167
+ our_model_func: Callable[[], nn.Module],
168
+ config: RegNetConfig,
169
+ save_directory: Path,
170
+ push_to_hub: bool = True,
171
+ ):
172
+ print(f"Converting {name}...")
173
+ with torch.no_grad():
174
+ from_model, from_state_dict = from_model_func()
175
+ our_model = our_model_func(config).eval()
176
+ module_transfer = ModuleTransfer(src=from_model, dest=our_model, raise_if_mismatch=False)
177
+ x = torch.randn((1, 3, 224, 224))
178
+ module_transfer(x)
179
+
180
+ if from_state_dict is not None:
181
+ keys = []
182
+ # for seer - in1k finetuned we have to manually copy the head
183
+ if "seer" in name and "in1k" in name:
184
+ keys = [("0.clf.0.weight", "classifier.1.weight"), ("0.clf.0.bias", "classifier.1.bias")]
185
+ to_state_dict = manually_copy_vissl_head(from_state_dict, our_model.state_dict(), keys)
186
+ our_model.load_state_dict(to_state_dict)
187
+
188
+ our_outputs = our_model(x, output_hidden_states=True)
189
+ our_output = (
190
+ our_outputs.logits if isinstance(our_model, RegNetForImageClassification) else our_outputs.last_hidden_state
191
+ )
192
+
193
+ from_output = from_model(x)
194
+ from_output = from_output[-1] if isinstance(from_output, list) else from_output
195
+
196
+ # since we don't use any config files, the vissl seer model doesn't actually have a head, so we just check the last hidden state
197
+ if "seer" in name and "in1k" in name:
198
+ our_output = our_outputs.hidden_states[-1]
199
+
200
+ assert torch.allclose(from_output, our_output), "The model logits don't match the original one."
201
+
202
+ if push_to_hub:
203
+ our_model.push_to_hub(
204
+ repo_path_or_name=save_directory / name,
205
+ commit_message="Add model",
206
+ use_temp_dir=True,
207
+ )
208
+
209
+ size = 224 if "seer" not in name else 384
210
+ # we can use the convnext one
211
+ image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
212
+ image_processor.push_to_hub(
213
+ repo_path_or_name=save_directory / name,
214
+ commit_message="Add image processor",
215
+ use_temp_dir=True,
216
+ )
217
+
218
+ print(f"Pushed {name}")
219
+
220
+
221
+ def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True):
222
+ filename = "imagenet-1k-id2label.json"
223
+ num_labels = 1000
224
+ expected_shape = (1, num_labels)
225
+
226
+ repo_id = "huggingface/label-files"
227
+ num_labels = num_labels
228
+ id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
229
+ id2label = {int(k): v for k, v in id2label.items()}
230
+
231
+ id2label = id2label
232
+ label2id = {v: k for k, v in id2label.items()}
233
+
234
+ ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
235
+
236
+ names_to_config = {
237
+ "regnet-x-002": ImageNetPreTrainedConfig(
238
+ depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8, layer_type="x"
239
+ ),
240
+ "regnet-x-004": ImageNetPreTrainedConfig(
241
+ depths=[1, 2, 7, 12], hidden_sizes=[32, 64, 160, 384], groups_width=16, layer_type="x"
242
+ ),
243
+ "regnet-x-006": ImageNetPreTrainedConfig(
244
+ depths=[1, 3, 5, 7], hidden_sizes=[48, 96, 240, 528], groups_width=24, layer_type="x"
245
+ ),
246
+ "regnet-x-008": ImageNetPreTrainedConfig(
247
+ depths=[1, 3, 7, 5], hidden_sizes=[64, 128, 288, 672], groups_width=16, layer_type="x"
248
+ ),
249
+ "regnet-x-016": ImageNetPreTrainedConfig(
250
+ depths=[2, 4, 10, 2], hidden_sizes=[72, 168, 408, 912], groups_width=24, layer_type="x"
251
+ ),
252
+ "regnet-x-032": ImageNetPreTrainedConfig(
253
+ depths=[2, 6, 15, 2], hidden_sizes=[96, 192, 432, 1008], groups_width=48, layer_type="x"
254
+ ),
255
+ "regnet-x-040": ImageNetPreTrainedConfig(
256
+ depths=[2, 5, 14, 2], hidden_sizes=[80, 240, 560, 1360], groups_width=40, layer_type="x"
257
+ ),
258
+ "regnet-x-064": ImageNetPreTrainedConfig(
259
+ depths=[2, 4, 10, 1], hidden_sizes=[168, 392, 784, 1624], groups_width=56, layer_type="x"
260
+ ),
261
+ "regnet-x-080": ImageNetPreTrainedConfig(
262
+ depths=[2, 5, 15, 1], hidden_sizes=[80, 240, 720, 1920], groups_width=120, layer_type="x"
263
+ ),
264
+ "regnet-x-120": ImageNetPreTrainedConfig(
265
+ depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112, layer_type="x"
266
+ ),
267
+ "regnet-x-160": ImageNetPreTrainedConfig(
268
+ depths=[2, 6, 13, 1], hidden_sizes=[256, 512, 896, 2048], groups_width=128, layer_type="x"
269
+ ),
270
+ "regnet-x-320": ImageNetPreTrainedConfig(
271
+ depths=[2, 7, 13, 1], hidden_sizes=[336, 672, 1344, 2520], groups_width=168, layer_type="x"
272
+ ),
273
+ # y variant
274
+ "regnet-y-002": ImageNetPreTrainedConfig(depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8),
275
+ "regnet-y-004": ImageNetPreTrainedConfig(
276
+ depths=[1, 3, 6, 6], hidden_sizes=[48, 104, 208, 440], groups_width=8
277
+ ),
278
+ "regnet-y-006": ImageNetPreTrainedConfig(
279
+ depths=[1, 3, 7, 4], hidden_sizes=[48, 112, 256, 608], groups_width=16
280
+ ),
281
+ "regnet-y-008": ImageNetPreTrainedConfig(
282
+ depths=[1, 3, 8, 2], hidden_sizes=[64, 128, 320, 768], groups_width=16
283
+ ),
284
+ "regnet-y-016": ImageNetPreTrainedConfig(
285
+ depths=[2, 6, 17, 2], hidden_sizes=[48, 120, 336, 888], groups_width=24
286
+ ),
287
+ "regnet-y-032": ImageNetPreTrainedConfig(
288
+ depths=[2, 5, 13, 1], hidden_sizes=[72, 216, 576, 1512], groups_width=24
289
+ ),
290
+ "regnet-y-040": ImageNetPreTrainedConfig(
291
+ depths=[2, 6, 12, 2], hidden_sizes=[128, 192, 512, 1088], groups_width=64
292
+ ),
293
+ "regnet-y-064": ImageNetPreTrainedConfig(
294
+ depths=[2, 7, 14, 2], hidden_sizes=[144, 288, 576, 1296], groups_width=72
295
+ ),
296
+ "regnet-y-080": ImageNetPreTrainedConfig(
297
+ depths=[2, 4, 10, 1], hidden_sizes=[168, 448, 896, 2016], groups_width=56
298
+ ),
299
+ "regnet-y-120": ImageNetPreTrainedConfig(
300
+ depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112
301
+ ),
302
+ "regnet-y-160": ImageNetPreTrainedConfig(
303
+ depths=[2, 4, 11, 1], hidden_sizes=[224, 448, 1232, 3024], groups_width=112
304
+ ),
305
+ "regnet-y-320": ImageNetPreTrainedConfig(
306
+ depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
307
+ ),
308
+ # models created by SEER -> https://arxiv.org/abs/2202.08360
309
+ "regnet-y-320-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232),
310
+ "regnet-y-640-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328),
311
+ "regnet-y-1280-seer": RegNetConfig(
312
+ depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264
313
+ ),
314
+ "regnet-y-2560-seer": RegNetConfig(
315
+ depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640
316
+ ),
317
+ "regnet-y-10b-seer": ImageNetPreTrainedConfig(
318
+ depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
319
+ ),
320
+ # finetuned on imagenet
321
+ "regnet-y-320-seer-in1k": ImageNetPreTrainedConfig(
322
+ depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232
323
+ ),
324
+ "regnet-y-640-seer-in1k": ImageNetPreTrainedConfig(
325
+ depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328
326
+ ),
327
+ "regnet-y-1280-seer-in1k": ImageNetPreTrainedConfig(
328
+ depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264
329
+ ),
330
+ "regnet-y-2560-seer-in1k": ImageNetPreTrainedConfig(
331
+ depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640
332
+ ),
333
+ "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
334
+ depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
335
+ ),
336
+ }
337
+
338
+ names_to_ours_model_map = NameToOurModelFuncMap()
339
+ names_to_from_model_map = NameToFromModelFuncMap()
340
+ # add seer weights logic
341
+
342
+ def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> Tuple[nn.Module, Dict]:
343
+ files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
344
+ model = model_func()
345
+ # check if we have a head, if yes add it
346
+ model_state_dict = files["classy_state_dict"]["base_model"]["model"]
347
+ state_dict = model_state_dict["trunk"]
348
+ model.load_state_dict(state_dict)
349
+ return model.eval(), model_state_dict["heads"]
350
+
351
+ # pretrained
352
+ names_to_from_model_map["regnet-y-320-seer"] = partial(
353
+ load_using_classy_vision,
354
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch",
355
+ lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
356
+ )
357
+
358
+ names_to_from_model_map["regnet-y-640-seer"] = partial(
359
+ load_using_classy_vision,
360
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch",
361
+ lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
362
+ )
363
+
364
+ names_to_from_model_map["regnet-y-1280-seer"] = partial(
365
+ load_using_classy_vision,
366
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch",
367
+ lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
368
+ )
369
+
370
+ names_to_from_model_map["regnet-y-10b-seer"] = partial(
371
+ load_using_classy_vision,
372
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
373
+ lambda: FakeRegNetVisslWrapper(
374
+ RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
375
+ ),
376
+ )
377
+
378
+ # IN1K finetuned
379
+ names_to_from_model_map["regnet-y-320-seer-in1k"] = partial(
380
+ load_using_classy_vision,
381
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch",
382
+ lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
383
+ )
384
+
385
+ names_to_from_model_map["regnet-y-640-seer-in1k"] = partial(
386
+ load_using_classy_vision,
387
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch",
388
+ lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
389
+ )
390
+
391
+ names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial(
392
+ load_using_classy_vision,
393
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch",
394
+ lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
395
+ )
396
+
397
+ names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial(
398
+ load_using_classy_vision,
399
+ "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
400
+ lambda: FakeRegNetVisslWrapper(
401
+ RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
402
+ ),
403
+ )
404
+
405
+ if model_name:
406
+ convert_weight_and_push(
407
+ model_name,
408
+ names_to_from_model_map[model_name],
409
+ names_to_ours_model_map[model_name],
410
+ names_to_config[model_name],
411
+ save_directory,
412
+ push_to_hub,
413
+ )
414
+ else:
415
+ for model_name, config in names_to_config.items():
416
+ convert_weight_and_push(
417
+ model_name,
418
+ names_to_from_model_map[model_name],
419
+ names_to_ours_model_map[model_name],
420
+ config,
421
+ save_directory,
422
+ push_to_hub,
423
+ )
424
+ return config, expected_shape
425
+
426
+
427
+ if __name__ == "__main__":
428
+ parser = argparse.ArgumentParser()
429
+ # Required parameters
430
+ parser.add_argument(
431
+ "--model_name",
432
+ default=None,
433
+ type=str,
434
+ help=(
435
+ "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
436
+ " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
437
+ ),
438
+ )
439
+ parser.add_argument(
440
+ "--pytorch_dump_folder_path",
441
+ default=None,
442
+ type=Path,
443
+ required=True,
444
+ help="Path to the output PyTorch model directory.",
445
+ )
446
+ parser.add_argument(
447
+ "--push_to_hub",
448
+ default=True,
449
+ type=bool,
450
+ required=False,
451
+ help="If True, push model and image processor to the hub.",
452
+ )
453
+
454
+ args = parser.parse_args()
455
+
456
+ pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
457
+ pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
458
+ convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
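The core mechanism of this script is the hook-based `Tracker`/`ModuleTransfer` pair: leaf modules are recorded in execution order during a forward pass and their state dicts are copied pairwise. A minimal, self-contained sketch of that idea (not part of the commit) on two toy networks:

```python
# Toy sketch of the trace-and-copy mechanism used by Tracker/ModuleTransfer above.
import torch
import torch.nn as nn


def trace_leaves(module: nn.Module, x: torch.Tensor):
    traced, handles = [], []

    def hook(m, inputs, outputs):
        # keep only leaf modules that own parameters (e.g. Conv2d, BatchNorm2d)
        if len(list(m.children())) == 0 and len(m.state_dict()) > 0:
            traced.append(m)

    for m in module.modules():
        handles.append(m.register_forward_hook(hook))
    module(x)  # the forward pass populates `traced` in execution order
    for h in handles:
        h.remove()
    return traced


src = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
dest = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))

x = torch.randn(1, 3, 32, 32)
for s, d in zip(trace_leaves(src, x), trace_leaves(dest, x)):
    d.load_state_dict(s.state_dict())  # copy weights pairwise, in traced order

assert torch.allclose(src(x), dest(x))  # identical weights -> identical outputs
```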
docs/transformers/build/lib/transformers/models/rembert/configuration_rembert.py ADDED
@@ -0,0 +1,162 @@
1
+ # coding=utf-8
2
+ # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RemBERT model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class RemBertConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`RemBertModel`]. It is used to instantiate a
31
+ RemBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
+ with the defaults will yield a similar configuration to that of the RemBERT
33
+ [google/rembert](https://huggingface.co/google/rembert) architecture.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 250300):
41
+ Vocabulary size of the RemBERT model. Defines the number of different tokens that can be represented by the
42
+ `input_ids` passed when calling [`RemBertModel`] or [`TFRemBertModel`].
45
+ hidden_size (`int`, *optional*, defaults to 1152):
46
+ Dimensionality of the encoder layers and the pooler layer.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 18):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ input_embedding_size (`int`, *optional*, defaults to 256):
52
+ Dimensionality of the input embeddings.
53
+ output_embedding_size (`int`, *optional*, defaults to 1664):
54
+ Dimensionality of the output embeddings.
55
+ intermediate_size (`int`, *optional*, defaults to 4608):
56
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
60
+ hidden_dropout_prob (`float`, *optional*, defaults to 0):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
63
+ The dropout ratio for the attention probabilities.
64
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
65
+ The dropout ratio for the classifier layer when fine-tuning.
66
+ max_position_embeddings (`int`, *optional*, defaults to 512):
67
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
68
+ just in case (e.g., 512 or 1024 or 2048).
69
+ type_vocab_size (`int`, *optional*, defaults to 2):
70
+ The vocabulary size of the `token_type_ids` passed when calling [`RemBertModel`] or [`TFRemBertModel`].
71
+ initializer_range (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
74
+ The epsilon used by the layer normalization layers.
75
+ is_decoder (`bool`, *optional*, defaults to `False`):
76
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
77
+ use_cache (`bool`, *optional*, defaults to `True`):
78
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
79
+ relevant if `config.is_decoder=True`.
80
+
81
+ Example:
82
+
83
+ ```python
84
+ >>> from transformers import RemBertModel, RemBertConfig
85
+
86
+ >>> # Initializing a RemBERT rembert style configuration
87
+ >>> configuration = RemBertConfig()
88
+
89
+ >>> # Initializing a model from the rembert style configuration
90
+ >>> model = RemBertModel(configuration)
91
+
92
+ >>> # Accessing the model configuration
93
+ >>> configuration = model.config
94
+ ```"""
95
+
96
+ model_type = "rembert"
97
+
98
+ def __init__(
99
+ self,
100
+ vocab_size=250300,
101
+ hidden_size=1152,
102
+ num_hidden_layers=32,
103
+ num_attention_heads=18,
104
+ input_embedding_size=256,
105
+ output_embedding_size=1664,
106
+ intermediate_size=4608,
107
+ hidden_act="gelu",
108
+ hidden_dropout_prob=0.0,
109
+ attention_probs_dropout_prob=0.0,
110
+ classifier_dropout_prob=0.1,
111
+ max_position_embeddings=512,
112
+ type_vocab_size=2,
113
+ initializer_range=0.02,
114
+ layer_norm_eps=1e-12,
115
+ use_cache=True,
116
+ pad_token_id=0,
117
+ bos_token_id=312,
118
+ eos_token_id=313,
119
+ **kwargs,
120
+ ):
121
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
122
+
123
+ self.vocab_size = vocab_size
124
+ self.input_embedding_size = input_embedding_size
125
+ self.output_embedding_size = output_embedding_size
126
+ self.max_position_embeddings = max_position_embeddings
127
+ self.hidden_size = hidden_size
128
+ self.num_hidden_layers = num_hidden_layers
129
+ self.num_attention_heads = num_attention_heads
130
+ self.intermediate_size = intermediate_size
131
+ self.hidden_act = hidden_act
132
+ self.hidden_dropout_prob = hidden_dropout_prob
133
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
134
+ self.classifier_dropout_prob = classifier_dropout_prob
135
+ self.initializer_range = initializer_range
136
+ self.type_vocab_size = type_vocab_size
137
+ self.layer_norm_eps = layer_norm_eps
138
+ self.use_cache = use_cache
139
+ self.tie_word_embeddings = False
140
+
141
+
142
+ class RemBertOnnxConfig(OnnxConfig):
143
+ @property
144
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
145
+ if self.task == "multiple-choice":
146
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
147
+ else:
148
+ dynamic_axis = {0: "batch", 1: "sequence"}
149
+ return OrderedDict(
150
+ [
151
+ ("input_ids", dynamic_axis),
152
+ ("attention_mask", dynamic_axis),
153
+ ("token_type_ids", dynamic_axis),
154
+ ]
155
+ )
156
+
157
+ @property
158
+ def atol_for_validation(self) -> float:
159
+ return 1e-4
160
+
161
+
162
+ __all__ = ["RemBertConfig", "RemBertOnnxConfig"]
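Note that the constructor above hard-codes `tie_word_embeddings = False`: because RemBERT decouples the input and output embedding sizes from the hidden size, the word embeddings cannot be tied. A quick sketch (not part of the commit):

```python
# Sketch: RemBERT's default embedding dimensions and the forced untying.
from transformers import RemBertConfig

config = RemBertConfig()
print(config.input_embedding_size, config.hidden_size, config.output_embedding_size)  # 256 1152 1664
print(config.tie_word_embeddings)  # False
```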
docs/transformers/build/lib/transformers/models/rembert/convert_rembert_tf_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,62 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert RemBERT checkpoint."""
16
+
17
+ import argparse
18
+
19
+ import torch
20
+
21
+ from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert
22
+ from transformers.utils import logging
23
+
24
+
25
+ logging.set_verbosity_info()
26
+
27
+
28
+ def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
29
+ # Initialise PyTorch model
30
+ config = RemBertConfig.from_json_file(bert_config_file)
31
+ print("Building PyTorch model from configuration: {}".format(str(config)))
32
+ model = RemBertModel(config)
33
+
34
+ # Load weights from tf checkpoint
35
+ load_tf_weights_in_rembert(model, config, tf_checkpoint_path)
36
+
37
+ # Save pytorch-model
38
+ print("Save PyTorch model to {}".format(pytorch_dump_path))
39
+ torch.save(model.state_dict(), pytorch_dump_path)
40
+
41
+
42
+ if __name__ == "__main__":
43
+ parser = argparse.ArgumentParser()
44
+ # Required parameters
45
+ parser.add_argument(
46
+ "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
47
+ )
48
+ parser.add_argument(
49
+ "--rembert_config_file",
50
+ default=None,
51
+ type=str,
52
+ required=True,
53
+ help=(
54
+ "The config json file corresponding to the pre-trained RemBERT model. \n"
55
+ "This specifies the model architecture."
56
+ ),
57
+ )
58
+ parser.add_argument(
59
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
60
+ )
61
+ args = parser.parse_args()
62
+ convert_rembert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.rembert_config_file, args.pytorch_dump_path)
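The same conversion can be invoked from Python instead of the CLI; the sketch below is illustrative only (the paths are placeholders and TensorFlow must be installed for `load_tf_weights_in_rembert` to work):

```python
# Illustrative sketch, not part of the commit: calling the conversion function directly.
from transformers.models.rembert.convert_rembert_tf_checkpoint_to_pytorch import (
    convert_rembert_tf_checkpoint_to_pytorch,
)

convert_rembert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/rembert/model.ckpt",         # placeholder
    bert_config_file="/path/to/rembert/rembert_config.json",  # placeholder
    pytorch_dump_path="/path/to/output/pytorch_model.bin",    # placeholder
)
```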
docs/transformers/build/lib/transformers/models/rembert/modeling_rembert.py ADDED
@@ -0,0 +1,1525 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RemBERT model."""
16
+
17
+ import math
18
+ import os
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...generation import GenerationMixin
28
+ from ...modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from ...modeling_utils import PreTrainedModel
39
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from ...utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_rembert import RemBertConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "RemBertConfig"
53
+ _CHECKPOINT_FOR_DOC = "google/rembert"
54
+
55
+
56
+ def load_tf_weights_in_rembert(model, config, tf_checkpoint_path):
57
+ """Load tf checkpoints in a pytorch model."""
58
+ try:
59
+ import re
60
+
61
+ import numpy as np
62
+ import tensorflow as tf
63
+ except ImportError:
64
+ logger.error(
65
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
66
+ "https://www.tensorflow.org/install/ for installation instructions."
67
+ )
68
+ raise
69
+ tf_path = os.path.abspath(tf_checkpoint_path)
70
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
71
+ # Load weights from TF model
72
+ init_vars = tf.train.list_variables(tf_path)
73
+ names = []
74
+ arrays = []
75
+ for name, shape in init_vars:
76
+ # Checkpoint is 12Gb, save memory by not loading useless variables
77
+ # Output embedding and cls are reset at classification time
78
+ if any(deny in name for deny in ("adam_v", "adam_m", "output_embedding", "cls")):
79
+ # logger.info("Skipping loading of %s", name)
80
+ continue
81
+ logger.info(f"Loading TF weight {name} with shape {shape}")
82
+ array = tf.train.load_variable(tf_path, name)
83
+ names.append(name)
84
+ arrays.append(array)
85
+
86
+ for name, array in zip(names, arrays):
87
+ # Replace prefix with right one
88
+ name = name.replace("bert/", "rembert/")
89
+ # The pooler is a linear layer
90
+ # name = name.replace("pooler/dense", "pooler")
91
+
92
+ name = name.split("/")
93
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
94
+ # which are not required for using pretrained model
95
+ if any(
96
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
97
+ for n in name
98
+ ):
99
+ logger.info(f"Skipping {'/'.join(name)}")
100
+ continue
101
+ pointer = model
102
+ for m_name in name:
103
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
104
+ scope_names = re.split(r"_(\d+)", m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
108
+ pointer = getattr(pointer, "weight")
109
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
110
+ pointer = getattr(pointer, "bias")
111
+ elif scope_names[0] == "output_weights":
112
+ pointer = getattr(pointer, "weight")
113
+ elif scope_names[0] == "squad":
114
+ pointer = getattr(pointer, "classifier")
115
+ else:
116
+ try:
117
+ pointer = getattr(pointer, scope_names[0])
118
+ except AttributeError:
119
+ logger.info("Skipping {}".format("/".join(name)))
120
+ continue
121
+ if len(scope_names) >= 2:
122
+ num = int(scope_names[1])
123
+ pointer = pointer[num]
124
+ if m_name[-11:] == "_embeddings":
125
+ pointer = getattr(pointer, "weight")
126
+ elif m_name == "kernel":
127
+ array = np.transpose(array)
128
+ try:
129
+ if pointer.shape != array.shape:
130
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
131
+ except AssertionError as e:
132
+ e.args += (pointer.shape, array.shape)
133
+ raise
134
+ logger.info(f"Initialize PyTorch weight {name}")
135
+ pointer.data = torch.from_numpy(array)
136
+ return model
137
+
138
+
139
+ class RemBertEmbeddings(nn.Module):
140
+ """Construct the embeddings from word, position and token_type embeddings."""
141
+
142
+ def __init__(self, config):
143
+ super().__init__()
144
+ self.word_embeddings = nn.Embedding(
145
+ config.vocab_size, config.input_embedding_size, padding_idx=config.pad_token_id
146
+ )
147
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.input_embedding_size)
148
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.input_embedding_size)
149
+
150
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
151
+ # any TensorFlow checkpoint file
152
+ self.LayerNorm = nn.LayerNorm(config.input_embedding_size, eps=config.layer_norm_eps)
153
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
154
+
155
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
156
+ self.register_buffer(
157
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
158
+ )
159
+
160
+ def forward(
161
+ self,
162
+ input_ids: Optional[torch.LongTensor] = None,
163
+ token_type_ids: Optional[torch.LongTensor] = None,
164
+ position_ids: Optional[torch.LongTensor] = None,
165
+ inputs_embeds: Optional[torch.FloatTensor] = None,
166
+ past_key_values_length: int = 0,
167
+ ) -> torch.Tensor:
168
+ if input_ids is not None:
169
+ input_shape = input_ids.size()
170
+ else:
171
+ input_shape = inputs_embeds.size()[:-1]
172
+
173
+ seq_length = input_shape[1]
174
+
175
+ if position_ids is None:
176
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
177
+
178
+ if token_type_ids is None:
179
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
180
+
181
+ if inputs_embeds is None:
182
+ inputs_embeds = self.word_embeddings(input_ids)
183
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
184
+
185
+ embeddings = inputs_embeds + token_type_embeddings
186
+ position_embeddings = self.position_embeddings(position_ids)
187
+ embeddings += position_embeddings
188
+ embeddings = self.LayerNorm(embeddings)
189
+ embeddings = self.dropout(embeddings)
190
+ return embeddings
191
+
192
+
193
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->RemBert
194
+ class RemBertPooler(nn.Module):
195
+ def __init__(self, config):
196
+ super().__init__()
197
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
198
+ self.activation = nn.Tanh()
199
+
200
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
201
+ # We "pool" the model by simply taking the hidden state corresponding
202
+ # to the first token.
203
+ first_token_tensor = hidden_states[:, 0]
204
+ pooled_output = self.dense(first_token_tensor)
205
+ pooled_output = self.activation(pooled_output)
206
+ return pooled_output
207
+
208
+
209
+ class RemBertSelfAttention(nn.Module):
210
+ def __init__(self, config):
211
+ super().__init__()
212
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
213
+ raise ValueError(
214
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
215
+ f"heads ({config.num_attention_heads})"
216
+ )
217
+
218
+ self.num_attention_heads = config.num_attention_heads
219
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
220
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
221
+
222
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
223
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
224
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
225
+
226
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227
+
228
+ self.is_decoder = config.is_decoder
229
+
230
+ def transpose_for_scores(self, x):
231
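+ # Reshape (batch_size, seq_len, all_head_size) into (batch_size, num_heads, seq_len, head_size)
+ # so attention scores can be computed independently for every head.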
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
232
+ x = x.view(*new_x_shape)
233
+ return x.permute(0, 2, 1, 3)
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states: torch.Tensor,
238
+ attention_mask: Optional[torch.FloatTensor] = None,
239
+ head_mask: Optional[torch.FloatTensor] = None,
240
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
241
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
242
+ past_key_value: Tuple[Tuple[torch.FloatTensor]] = None,
243
+ output_attentions: bool = False,
244
+ ) -> Tuple:
245
+ mixed_query_layer = self.query(hidden_states)
246
+
247
+ # If this is instantiated as a cross-attention module, the keys
248
+ # and values come from an encoder; the attention mask needs to be
249
+ # such that the encoder's padding tokens are not attended to.
250
+ is_cross_attention = encoder_hidden_states is not None
251
+
252
+ if is_cross_attention and past_key_value is not None:
253
+ # reuse k,v, cross_attentions
254
+ key_layer = past_key_value[0]
255
+ value_layer = past_key_value[1]
256
+ attention_mask = encoder_attention_mask
257
+ elif is_cross_attention:
258
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
259
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
260
+ attention_mask = encoder_attention_mask
261
+ elif past_key_value is not None:
262
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
263
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
264
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
265
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
266
+ else:
267
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
268
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
269
+
270
+ query_layer = self.transpose_for_scores(mixed_query_layer)
271
+
272
+ if self.is_decoder:
273
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
274
+ # Further calls to cross_attention layer can then reuse all cross-attention
275
+ # key/value_states (first "if" case)
276
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
277
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
278
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
279
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
280
+ past_key_value = (key_layer, value_layer)
281
+
282
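+ # Standard scaled dot-product attention: softmax(Q @ K^T / sqrt(head_size)) @ V, computed per head.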
+ # Take the dot product between "query" and "key" to get the raw attention scores.
283
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
284
+
285
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
286
+ if attention_mask is not None:
287
+ # Apply the attention mask (precomputed for all layers in the RemBertModel forward() function)
288
+ attention_scores = attention_scores + attention_mask
289
+
290
+ # Normalize the attention scores to probabilities.
291
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
292
+
293
+ # This is actually dropping out entire tokens to attend to, which might
294
+ # seem a bit unusual, but is taken from the original Transformer paper.
295
+ attention_probs = self.dropout(attention_probs)
296
+
297
+ # Mask heads if we want to
298
+ if head_mask is not None:
299
+ attention_probs = attention_probs * head_mask
300
+
301
+ context_layer = torch.matmul(attention_probs, value_layer)
302
+
303
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
304
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
305
+ context_layer = context_layer.view(*new_context_layer_shape)
306
+
307
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
308
+
309
+ if self.is_decoder:
310
+ outputs = outputs + (past_key_value,)
311
+ return outputs
312
+
313
+
314
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RemBert
315
+ class RemBertSelfOutput(nn.Module):
316
+ def __init__(self, config):
317
+ super().__init__()
318
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
319
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
320
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
+
322
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
323
+ hidden_states = self.dense(hidden_states)
324
+ hidden_states = self.dropout(hidden_states)
325
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
326
+ return hidden_states
327
+
328
+
329
+ class RemBertAttention(nn.Module):
330
+ def __init__(self, config):
331
+ super().__init__()
332
+ self.self = RemBertSelfAttention(config)
333
+ self.output = RemBertSelfOutput(config)
334
+ self.pruned_heads = set()
335
+
336
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
337
+ def prune_heads(self, heads):
338
+ if len(heads) == 0:
339
+ return
340
+ heads, index = find_pruneable_heads_and_indices(
341
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
342
+ )
343
+
344
+ # Prune linear layers
345
+ self.self.query = prune_linear_layer(self.self.query, index)
346
+ self.self.key = prune_linear_layer(self.self.key, index)
347
+ self.self.value = prune_linear_layer(self.self.value, index)
348
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
349
+
350
+ # Update hyper params and store pruned heads
351
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
352
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
353
+ self.pruned_heads = self.pruned_heads.union(heads)
354
+
355
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
356
+ def forward(
357
+ self,
358
+ hidden_states: torch.Tensor,
359
+ attention_mask: Optional[torch.FloatTensor] = None,
360
+ head_mask: Optional[torch.FloatTensor] = None,
361
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
362
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
363
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
364
+ output_attentions: Optional[bool] = False,
365
+ ) -> Tuple[torch.Tensor]:
366
+ self_outputs = self.self(
367
+ hidden_states,
368
+ attention_mask,
369
+ head_mask,
370
+ encoder_hidden_states,
371
+ encoder_attention_mask,
372
+ past_key_value,
373
+ output_attentions,
374
+ )
375
+ attention_output = self.output(self_outputs[0], hidden_states)
376
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
377
+ return outputs
378
+
379
+
380
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RemBert
381
+ class RemBertIntermediate(nn.Module):
382
+ def __init__(self, config):
383
+ super().__init__()
384
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
385
+ if isinstance(config.hidden_act, str):
386
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
387
+ else:
388
+ self.intermediate_act_fn = config.hidden_act
389
+
390
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
391
+ hidden_states = self.dense(hidden_states)
392
+ hidden_states = self.intermediate_act_fn(hidden_states)
393
+ return hidden_states
394
+
395
+
396
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RemBert
397
+ class RemBertOutput(nn.Module):
398
+ def __init__(self, config):
399
+ super().__init__()
400
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
401
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
402
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
403
+
404
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
405
+ hidden_states = self.dense(hidden_states)
406
+ hidden_states = self.dropout(hidden_states)
407
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
408
+ return hidden_states
409
+
410
+
411
+ class RemBertLayer(nn.Module):
412
+ def __init__(self, config):
413
+ super().__init__()
414
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
415
+ self.seq_len_dim = 1
416
+ self.attention = RemBertAttention(config)
417
+ self.is_decoder = config.is_decoder
418
+ self.add_cross_attention = config.add_cross_attention
419
+ if self.add_cross_attention:
420
+ if not self.is_decoder:
421
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
422
+ self.crossattention = RemBertAttention(config)
423
+ self.intermediate = RemBertIntermediate(config)
424
+ self.output = RemBertOutput(config)
425
+
426
+ # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
427
+ def forward(
428
+ self,
429
+ hidden_states: torch.Tensor,
430
+ attention_mask: Optional[torch.FloatTensor] = None,
431
+ head_mask: Optional[torch.FloatTensor] = None,
432
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
433
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
434
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
435
+ output_attentions: Optional[bool] = False,
436
+ ) -> Tuple[torch.Tensor]:
437
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
438
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
439
+ self_attention_outputs = self.attention(
440
+ hidden_states,
441
+ attention_mask,
442
+ head_mask,
443
+ output_attentions=output_attentions,
444
+ past_key_value=self_attn_past_key_value,
445
+ )
446
+ attention_output = self_attention_outputs[0]
447
+
448
+ # if decoder, the last output is tuple of self-attn cache
449
+ if self.is_decoder:
450
+ outputs = self_attention_outputs[1:-1]
451
+ present_key_value = self_attention_outputs[-1]
452
+ else:
453
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
454
+
455
+ cross_attn_present_key_value = None
456
+ if self.is_decoder and encoder_hidden_states is not None:
457
+ if not hasattr(self, "crossattention"):
458
+ raise ValueError(
459
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
460
+ " by setting `config.add_cross_attention=True`"
461
+ )
462
+
463
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
464
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
465
+ cross_attention_outputs = self.crossattention(
466
+ attention_output,
467
+ attention_mask,
468
+ head_mask,
469
+ encoder_hidden_states,
470
+ encoder_attention_mask,
471
+ cross_attn_past_key_value,
472
+ output_attentions,
473
+ )
474
+ attention_output = cross_attention_outputs[0]
475
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
476
+
477
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
478
+ cross_attn_present_key_value = cross_attention_outputs[-1]
479
+ present_key_value = present_key_value + cross_attn_present_key_value
480
+
481
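+ # The feed-forward block may be applied in chunks along the sequence dimension
+ # (config.chunk_size_feed_forward) to trade a small amount of compute for lower peak memory.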
+ layer_output = apply_chunking_to_forward(
482
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
483
+ )
484
+ outputs = (layer_output,) + outputs
485
+
486
+ # if decoder, return the attn key/values as the last output
487
+ if self.is_decoder:
488
+ outputs = outputs + (present_key_value,)
489
+
490
+ return outputs
491
+
492
+ # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
493
+ def feed_forward_chunk(self, attention_output):
494
+ intermediate_output = self.intermediate(attention_output)
495
+ layer_output = self.output(intermediate_output, attention_output)
496
+ return layer_output
497
+
498
+
499
+ class RemBertEncoder(nn.Module):
500
+ def __init__(self, config):
501
+ super().__init__()
502
+ self.config = config
503
+
504
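+ # RemBERT uses an input embedding size that is smaller than the hidden size, so the embeddings
+ # are projected up to config.hidden_size before entering the first Transformer layer.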
+ self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size)
505
+ self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)])
506
+ self.gradient_checkpointing = False
507
+
508
+ def forward(
509
+ self,
510
+ hidden_states: torch.Tensor,
511
+ attention_mask: Optional[torch.FloatTensor] = None,
512
+ head_mask: Optional[torch.FloatTensor] = None,
513
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
514
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
515
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
516
+ use_cache: Optional[bool] = None,
517
+ output_attentions: bool = False,
518
+ output_hidden_states: bool = False,
519
+ return_dict: bool = True,
520
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
521
+ if self.gradient_checkpointing and self.training:
522
+ if use_cache:
523
+ logger.warning_once(
524
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
525
+ )
526
+ use_cache = False
527
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
528
+ all_hidden_states = () if output_hidden_states else None
529
+ all_self_attentions = () if output_attentions else None
530
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
531
+
532
+ next_decoder_cache = () if use_cache else None
533
+ for i, layer_module in enumerate(self.layer):
534
+ if output_hidden_states:
535
+ all_hidden_states = all_hidden_states + (hidden_states,)
536
+
537
+ layer_head_mask = head_mask[i] if head_mask is not None else None
538
+ past_key_value = past_key_values[i] if past_key_values is not None else None
539
+
540
+ if self.gradient_checkpointing and self.training:
541
+ layer_outputs = self._gradient_checkpointing_func(
542
+ layer_module.__call__,
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ past_key_value,
549
+ output_attentions,
550
+ )
551
+ else:
552
+ layer_outputs = layer_module(
553
+ hidden_states,
554
+ attention_mask,
555
+ layer_head_mask,
556
+ encoder_hidden_states,
557
+ encoder_attention_mask,
558
+ past_key_value,
559
+ output_attentions,
560
+ )
561
+
562
+ hidden_states = layer_outputs[0]
563
+ if use_cache:
564
+ next_decoder_cache += (layer_outputs[-1],)
565
+ if output_attentions:
566
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
567
+ if self.config.add_cross_attention:
568
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
569
+
570
+ if output_hidden_states:
571
+ all_hidden_states = all_hidden_states + (hidden_states,)
572
+
573
+ if not return_dict:
574
+ return tuple(
575
+ v
576
+ for v in [
577
+ hidden_states,
578
+ next_decoder_cache,
579
+ all_hidden_states,
580
+ all_self_attentions,
581
+ all_cross_attentions,
582
+ ]
583
+ if v is not None
584
+ )
585
+ return BaseModelOutputWithPastAndCrossAttentions(
586
+ last_hidden_state=hidden_states,
587
+ past_key_values=next_decoder_cache,
588
+ hidden_states=all_hidden_states,
589
+ attentions=all_self_attentions,
590
+ cross_attentions=all_cross_attentions,
591
+ )
592
+
593
+
594
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RemBert
595
+ class RemBertPredictionHeadTransform(nn.Module):
596
+ def __init__(self, config):
597
+ super().__init__()
598
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
599
+ if isinstance(config.hidden_act, str):
600
+ self.transform_act_fn = ACT2FN[config.hidden_act]
601
+ else:
602
+ self.transform_act_fn = config.hidden_act
603
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
604
+
605
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
606
+ hidden_states = self.dense(hidden_states)
607
+ hidden_states = self.transform_act_fn(hidden_states)
608
+ hidden_states = self.LayerNorm(hidden_states)
609
+ return hidden_states
610
+
611
+
612
+ class RemBertLMPredictionHead(nn.Module):
613
+ def __init__(self, config):
614
+ super().__init__()
615
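+ # Unlike BERT, RemBERT decouples the output embedding size from the hidden size: hidden states
+ # are first projected to config.output_embedding_size and only then mapped onto the vocabulary.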
+ self.dense = nn.Linear(config.hidden_size, config.output_embedding_size)
616
+ self.decoder = nn.Linear(config.output_embedding_size, config.vocab_size)
617
+ self.activation = ACT2FN[config.hidden_act]
618
+ self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)
619
+
620
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
621
+ hidden_states = self.dense(hidden_states)
622
+ hidden_states = self.activation(hidden_states)
623
+ hidden_states = self.LayerNorm(hidden_states)
624
+ hidden_states = self.decoder(hidden_states)
625
+ return hidden_states
626
+
627
+
628
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RemBert
629
+ class RemBertOnlyMLMHead(nn.Module):
630
+ def __init__(self, config):
631
+ super().__init__()
632
+ self.predictions = RemBertLMPredictionHead(config)
633
+
634
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
635
+ prediction_scores = self.predictions(sequence_output)
636
+ return prediction_scores
637
+
638
+
639
+ class RemBertPreTrainedModel(PreTrainedModel):
640
+ """
641
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
642
+ models.
643
+ """
644
+
645
+ config_class = RemBertConfig
646
+ load_tf_weights = load_tf_weights_in_rembert
647
+ base_model_prefix = "rembert"
648
+ supports_gradient_checkpointing = True
649
+
650
+ def _init_weights(self, module):
651
+ """Initialize the weights"""
652
+ if isinstance(module, nn.Linear):
653
+ # Slightly different from the TF version which uses truncated_normal for initialization
654
+ # cf https://github.com/pytorch/pytorch/pull/5617
655
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
656
+ if module.bias is not None:
657
+ module.bias.data.zero_()
658
+ elif isinstance(module, nn.Embedding):
659
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
660
+ if module.padding_idx is not None:
661
+ module.weight.data[module.padding_idx].zero_()
662
+ elif isinstance(module, nn.LayerNorm):
663
+ module.bias.data.zero_()
664
+ module.weight.data.fill_(1.0)
665
+
666
+
667
+ REMBERT_START_DOCSTRING = r"""
668
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
669
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
670
+ behavior.
671
+
672
+ Parameters:
673
+ config ([`RemBertConfig`]): Model configuration class with all the parameters of the model.
674
+ Initializing with a config file does not load the weights associated with the model, only the
675
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
676
+ """
677
+
678
+ REMBERT_INPUTS_DOCSTRING = r"""
679
+ Args:
680
+ input_ids (`torch.LongTensor` of shape `({0})`):
681
+ Indices of input sequence tokens in the vocabulary.
682
+
683
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
684
+ [`PreTrainedTokenizer.__call__`] for details.
685
+
686
+ [What are input IDs?](../glossary#input-ids)
687
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
688
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
689
+
690
+ - 1 for tokens that are **not masked**,
691
+ - 0 for tokens that are **masked**.
692
+
693
+ [What are attention masks?](../glossary#attention-mask)
694
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
695
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
696
+ 1]`:
697
+
698
+ - 0 corresponds to a *sentence A* token,
699
+ - 1 corresponds to a *sentence B* token.
700
+
701
+ [What are token type IDs?](../glossary#token-type-ids)
702
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
703
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
704
+ config.max_position_embeddings - 1]`.
705
+
706
+ [What are position IDs?](../glossary#position-ids)
707
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
708
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
709
+
710
+ - 1 indicates the head is **not masked**,
711
+ - 0 indicates the head is **masked**.
712
+
713
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
714
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
715
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
716
+ model's internal embedding lookup matrix.
717
+ output_attentions (`bool`, *optional*):
718
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
719
+ tensors for more detail.
720
+ output_hidden_states (`bool`, *optional*):
721
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
722
+ more detail.
723
+ return_dict (`bool`, *optional*):
724
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
725
+ """
726
+
727
+
728
+ @add_start_docstrings(
729
+ "The bare RemBERT Model transformer outputting raw hidden-states without any specific head on top.",
730
+ REMBERT_START_DOCSTRING,
731
+ )
732
+ class RemBertModel(RemBertPreTrainedModel):
733
+ """
734
+
735
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
736
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
737
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
738
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
739
+
740
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
741
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
742
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
743
+ """
744
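+ # A minimal usage sketch (assuming the public `google/rembert` checkpoint referenced in this file):
+ #
+ # from transformers import AutoTokenizer, RemBertModel
+ # tokenizer = AutoTokenizer.from_pretrained("google/rembert")
+ # model = RemBertModel.from_pretrained("google/rembert")
+ # inputs = tokenizer("Hello world", return_tensors="pt")
+ # last_hidden_state = model(**inputs).last_hidden_state  # (batch_size, seq_len, hidden_size)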
+
745
+ def __init__(self, config, add_pooling_layer=True):
746
+ super().__init__(config)
747
+ self.config = config
748
+
749
+ self.embeddings = RemBertEmbeddings(config)
750
+ self.encoder = RemBertEncoder(config)
751
+
752
+ self.pooler = RemBertPooler(config) if add_pooling_layer else None
753
+
754
+ # Initialize weights and apply final processing
755
+ self.post_init()
756
+
757
+ def get_input_embeddings(self):
758
+ return self.embeddings.word_embeddings
759
+
760
+ def set_input_embeddings(self, value):
761
+ self.embeddings.word_embeddings = value
762
+
763
+ def _prune_heads(self, heads_to_prune):
764
+ """
765
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
766
+ class PreTrainedModel
767
+ """
768
+ for layer, heads in heads_to_prune.items():
769
+ self.encoder.layer[layer].attention.prune_heads(heads)
770
+
771
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
772
+ @add_code_sample_docstrings(
773
+ checkpoint="google/rembert",
774
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
775
+ config_class=_CONFIG_FOR_DOC,
776
+ )
777
+ def forward(
778
+ self,
779
+ input_ids: Optional[torch.LongTensor] = None,
780
+ attention_mask: Optional[torch.LongTensor] = None,
781
+ token_type_ids: Optional[torch.LongTensor] = None,
782
+ position_ids: Optional[torch.LongTensor] = None,
783
+ head_mask: Optional[torch.FloatTensor] = None,
784
+ inputs_embeds: Optional[torch.FloatTensor] = None,
785
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
786
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
787
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
788
+ use_cache: Optional[bool] = None,
789
+ output_attentions: Optional[bool] = None,
790
+ output_hidden_states: Optional[bool] = None,
791
+ return_dict: Optional[bool] = None,
792
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
793
+ r"""
794
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
795
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
796
+ the model is configured as a decoder.
797
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
798
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
799
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
800
+
801
+ - 1 for tokens that are **not masked**,
802
+ - 0 for tokens that are **masked**.
803
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
804
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
805
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
806
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
807
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
808
+ use_cache (`bool`, *optional*):
809
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
810
+ `past_key_values`).
811
+ """
812
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
813
+ output_hidden_states = (
814
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
815
+ )
816
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
817
+
818
+ if self.config.is_decoder:
819
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
820
+ else:
821
+ use_cache = False
822
+
823
+ if input_ids is not None and inputs_embeds is not None:
824
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
825
+ elif input_ids is not None:
826
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
827
+ input_shape = input_ids.size()
828
+ elif inputs_embeds is not None:
829
+ input_shape = inputs_embeds.size()[:-1]
830
+ else:
831
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
832
+
833
+ batch_size, seq_length = input_shape
834
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
835
+
836
+ # past_key_values_length
837
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
838
+
839
+ if attention_mask is None:
840
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
841
+ if token_type_ids is None:
842
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
843
+
844
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
845
+ # ourselves in which case we just need to make it broadcastable to all heads.
846
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
847
+
848
+ # If a 2D or 3D attention mask is provided for the cross-attention
849
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
850
+ if self.config.is_decoder and encoder_hidden_states is not None:
851
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
852
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
853
+ if encoder_attention_mask is None:
854
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
855
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
856
+ else:
857
+ encoder_extended_attention_mask = None
858
+
859
+ # Prepare head mask if needed
860
+ # 1.0 in head_mask indicate we keep the head
861
+ # attention_probs has shape bsz x n_heads x N x N
862
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
863
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
864
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
865
+
866
+ embedding_output = self.embeddings(
867
+ input_ids=input_ids,
868
+ position_ids=position_ids,
869
+ token_type_ids=token_type_ids,
870
+ inputs_embeds=inputs_embeds,
871
+ past_key_values_length=past_key_values_length,
872
+ )
873
+ encoder_outputs = self.encoder(
874
+ embedding_output,
875
+ attention_mask=extended_attention_mask,
876
+ head_mask=head_mask,
877
+ encoder_hidden_states=encoder_hidden_states,
878
+ encoder_attention_mask=encoder_extended_attention_mask,
879
+ past_key_values=past_key_values,
880
+ use_cache=use_cache,
881
+ output_attentions=output_attentions,
882
+ output_hidden_states=output_hidden_states,
883
+ return_dict=return_dict,
884
+ )
885
+ sequence_output = encoder_outputs[0]
886
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
887
+
888
+ if not return_dict:
889
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
890
+
891
+ return BaseModelOutputWithPoolingAndCrossAttentions(
892
+ last_hidden_state=sequence_output,
893
+ pooler_output=pooled_output,
894
+ past_key_values=encoder_outputs.past_key_values,
895
+ hidden_states=encoder_outputs.hidden_states,
896
+ attentions=encoder_outputs.attentions,
897
+ cross_attentions=encoder_outputs.cross_attentions,
898
+ )
899
+
900
+
901
+ @add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING)
902
+ class RemBertForMaskedLM(RemBertPreTrainedModel):
903
+ _tied_weights_keys = ["cls.predictions.decoder.weight"]
904
+
905
+ def __init__(self, config):
906
+ super().__init__(config)
907
+
908
+ if config.is_decoder:
909
+ logger.warning(
910
+ "If you want to use `RemBertForMaskedLM` make sure `config.is_decoder=False` for "
911
+ "bi-directional self-attention."
912
+ )
913
+
914
+ self.rembert = RemBertModel(config, add_pooling_layer=False)
915
+ self.cls = RemBertOnlyMLMHead(config)
916
+
917
+ # Initialize weights and apply final processing
918
+ self.post_init()
919
+
920
+ def get_output_embeddings(self):
921
+ return self.cls.predictions.decoder
922
+
923
+ def set_output_embeddings(self, new_embeddings):
924
+ self.cls.predictions.decoder = new_embeddings
925
+
926
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
927
+ @add_code_sample_docstrings(
928
+ checkpoint="google/rembert",
929
+ output_type=MaskedLMOutput,
930
+ config_class=_CONFIG_FOR_DOC,
931
+ )
932
+ def forward(
933
+ self,
934
+ input_ids: Optional[torch.LongTensor] = None,
935
+ attention_mask: Optional[torch.LongTensor] = None,
936
+ token_type_ids: Optional[torch.LongTensor] = None,
937
+ position_ids: Optional[torch.LongTensor] = None,
938
+ head_mask: Optional[torch.FloatTensor] = None,
939
+ inputs_embeds: Optional[torch.FloatTensor] = None,
940
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
941
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
942
+ labels: Optional[torch.LongTensor] = None,
943
+ output_attentions: Optional[bool] = None,
944
+ output_hidden_states: Optional[bool] = None,
945
+ return_dict: Optional[bool] = None,
946
+ ) -> Union[Tuple, MaskedLMOutput]:
947
+ r"""
948
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
949
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
950
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
951
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
952
+ """
953
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
954
+
955
+ outputs = self.rembert(
956
+ input_ids,
957
+ attention_mask=attention_mask,
958
+ token_type_ids=token_type_ids,
959
+ position_ids=position_ids,
960
+ head_mask=head_mask,
961
+ inputs_embeds=inputs_embeds,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ encoder_attention_mask=encoder_attention_mask,
964
+ output_attentions=output_attentions,
965
+ output_hidden_states=output_hidden_states,
966
+ return_dict=return_dict,
967
+ )
968
+
969
+ sequence_output = outputs[0]
970
+ prediction_scores = self.cls(sequence_output)
971
+
972
+ masked_lm_loss = None
973
+ if labels is not None:
974
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
975
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
976
+
977
+ if not return_dict:
978
+ output = (prediction_scores,) + outputs[2:]
979
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
980
+
981
+ return MaskedLMOutput(
982
+ loss=masked_lm_loss,
983
+ logits=prediction_scores,
984
+ hidden_states=outputs.hidden_states,
985
+ attentions=outputs.attentions,
986
+ )
987
+
988
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
989
+ input_shape = input_ids.shape
990
+ effective_batch_size = input_shape[0]
991
+
992
+ # add a dummy token
993
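+ # A PAD token is appended as an extra position to predict, and the attention mask is extended
+ # with a 0 so that the appended dummy position is never attended to by the real tokens.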
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
994
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
995
+ dummy_token = torch.full(
996
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
997
+ )
998
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
999
+
1000
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1001
+
1002
+ @classmethod
1003
+ def can_generate(cls) -> bool:
1004
+ """
1005
+ Legacy correction: RemBertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
1006
+ `prepare_inputs_for_generation` method.
1007
+ """
1008
+ return False
1009
+
1010
+
1011
+ @add_start_docstrings(
1012
+ """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
1013
+ )
1014
+ class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin):
1015
+ _tied_weights_keys = ["cls.predictions.decoder.weight"]
1016
+
1017
+ def __init__(self, config):
1018
+ super().__init__(config)
1019
+
1020
+ if not config.is_decoder:
1021
+ logger.warning("If you want to use `RemBertForCausalLM` as a standalone, add `is_decoder=True.`")
1022
+
1023
+ self.rembert = RemBertModel(config, add_pooling_layer=False)
1024
+ self.cls = RemBertOnlyMLMHead(config)
1025
+
1026
+ # Initialize weights and apply final processing
1027
+ self.post_init()
1028
+
1029
+ def get_output_embeddings(self):
1030
+ return self.cls.predictions.decoder
1031
+
1032
+ def set_output_embeddings(self, new_embeddings):
1033
+ self.cls.predictions.decoder = new_embeddings
1034
+
1035
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1036
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1037
+ def forward(
1038
+ self,
1039
+ input_ids: Optional[torch.LongTensor] = None,
1040
+ attention_mask: Optional[torch.LongTensor] = None,
1041
+ token_type_ids: Optional[torch.LongTensor] = None,
1042
+ position_ids: Optional[torch.LongTensor] = None,
1043
+ head_mask: Optional[torch.FloatTensor] = None,
1044
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1045
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1046
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1047
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1048
+ labels: Optional[torch.LongTensor] = None,
1049
+ use_cache: Optional[bool] = None,
1050
+ output_attentions: Optional[bool] = None,
1051
+ output_hidden_states: Optional[bool] = None,
1052
+ return_dict: Optional[bool] = None,
1053
+ **kwargs,
1054
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1055
+ r"""
1056
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1057
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1058
+ the model is configured as a decoder.
1059
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1060
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1061
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1062
+
1063
+ - 1 for tokens that are **not masked**,
1064
+ - 0 for tokens that are **masked**.
1065
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1066
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1067
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1068
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1069
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1070
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1071
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1072
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
1073
+ ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1074
+ use_cache (`bool`, *optional*):
1075
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1076
+ `past_key_values`).
1077
+
1078
+ Returns:
1079
+
1080
+ Example:
1081
+
1082
+ ```python
1083
+ >>> from transformers import AutoTokenizer, RemBertForCausalLM, RemBertConfig
1084
+ >>> import torch
1085
+
1086
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/rembert")
1087
+ >>> config = RemBertConfig.from_pretrained("google/rembert")
1088
+ >>> config.is_decoder = True
1089
+ >>> model = RemBertForCausalLM.from_pretrained("google/rembert", config=config)
1090
+
1091
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1092
+ >>> outputs = model(**inputs)
1093
+
1094
+ >>> prediction_logits = outputs.logits
1095
+ ```"""
1096
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1097
+
1098
+ outputs = self.rembert(
1099
+ input_ids,
1100
+ attention_mask=attention_mask,
1101
+ token_type_ids=token_type_ids,
1102
+ position_ids=position_ids,
1103
+ head_mask=head_mask,
1104
+ inputs_embeds=inputs_embeds,
1105
+ encoder_hidden_states=encoder_hidden_states,
1106
+ encoder_attention_mask=encoder_attention_mask,
1107
+ past_key_values=past_key_values,
1108
+ use_cache=use_cache,
1109
+ output_attentions=output_attentions,
1110
+ output_hidden_states=output_hidden_states,
1111
+ return_dict=return_dict,
1112
+ )
1113
+
1114
+ sequence_output = outputs[0]
1115
+ prediction_scores = self.cls(sequence_output)
1116
+
1117
+ lm_loss = None
1118
+ if labels is not None:
1119
+ lm_loss = self.loss_function(
1120
+ prediction_scores,
1121
+ labels,
1122
+ vocab_size=self.config.vocab_size,
1123
+ **kwargs,
1124
+ )
1125
+
1126
+ if not return_dict:
1127
+ output = (prediction_scores,) + outputs[2:]
1128
+ return ((lm_loss,) + output) if lm_loss is not None else output
1129
+
1130
+ return CausalLMOutputWithCrossAttentions(
1131
+ loss=lm_loss,
1132
+ logits=prediction_scores,
1133
+ past_key_values=outputs.past_key_values,
1134
+ hidden_states=outputs.hidden_states,
1135
+ attentions=outputs.attentions,
1136
+ cross_attentions=outputs.cross_attentions,
1137
+ )
1138
+
1139
+ def _reorder_cache(self, past_key_values, beam_idx):
1140
+ reordered_past = ()
1141
+ for layer_past in past_key_values:
1142
+ reordered_past += (
1143
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
1144
+ + layer_past[2:],
1145
+ )
1146
+ return reordered_past
1147
+
1148
+
1149
+ @add_start_docstrings(
1150
+ """
1151
+ RemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1152
+ pooled output) e.g. for GLUE tasks.
1153
+ """,
1154
+ REMBERT_START_DOCSTRING,
1155
+ )
1156
+ class RemBertForSequenceClassification(RemBertPreTrainedModel):
1157
+ def __init__(self, config):
1158
+ super().__init__(config)
1159
+ self.num_labels = config.num_labels
1160
+ self.rembert = RemBertModel(config)
1161
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1162
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1163
+
1164
+ # Initialize weights and apply final processing
1165
+ self.post_init()
1166
+
1167
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1168
+ @add_code_sample_docstrings(
1169
+ checkpoint="google/rembert",
1170
+ output_type=SequenceClassifierOutput,
1171
+ config_class=_CONFIG_FOR_DOC,
1172
+ )
1173
+ def forward(
1174
+ self,
1175
+ input_ids: Optional[torch.FloatTensor] = None,
1176
+ attention_mask: Optional[torch.FloatTensor] = None,
1177
+ token_type_ids: Optional[torch.LongTensor] = None,
1178
+ position_ids: Optional[torch.FloatTensor] = None,
1179
+ head_mask: Optional[torch.FloatTensor] = None,
1180
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1181
+ labels: Optional[torch.LongTensor] = None,
1182
+ output_attentions: Optional[bool] = None,
1183
+ output_hidden_states: Optional[bool] = None,
1184
+ return_dict: Optional[bool] = None,
1185
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1186
+ r"""
1187
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1188
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1189
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1190
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1191
+ """
1192
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1193
+
1194
+ outputs = self.rembert(
1195
+ input_ids,
1196
+ attention_mask=attention_mask,
1197
+ token_type_ids=token_type_ids,
1198
+ position_ids=position_ids,
1199
+ head_mask=head_mask,
1200
+ inputs_embeds=inputs_embeds,
1201
+ output_attentions=output_attentions,
1202
+ output_hidden_states=output_hidden_states,
1203
+ return_dict=return_dict,
1204
+ )
1205
+
1206
+ pooled_output = outputs[1]
1207
+
1208
+ pooled_output = self.dropout(pooled_output)
1209
+ logits = self.classifier(pooled_output)
1210
+
1211
+ loss = None
1212
+ if labels is not None:
1213
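+ # When config.problem_type is unset, infer it: one label -> regression, integer labels ->
+ # single-label classification, otherwise multi-label classification.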
+ if self.config.problem_type is None:
1214
+ if self.num_labels == 1:
1215
+ self.config.problem_type = "regression"
1216
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1217
+ self.config.problem_type = "single_label_classification"
1218
+ else:
1219
+ self.config.problem_type = "multi_label_classification"
1220
+
1221
+ if self.config.problem_type == "regression":
1222
+ loss_fct = MSELoss()
1223
+ if self.num_labels == 1:
1224
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1225
+ else:
1226
+ loss = loss_fct(logits, labels)
1227
+ elif self.config.problem_type == "single_label_classification":
1228
+ loss_fct = CrossEntropyLoss()
1229
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1230
+ elif self.config.problem_type == "multi_label_classification":
1231
+ loss_fct = BCEWithLogitsLoss()
1232
+ loss = loss_fct(logits, labels)
1233
+ if not return_dict:
1234
+ output = (logits,) + outputs[2:]
1235
+ return ((loss,) + output) if loss is not None else output
1236
+
1237
+ return SequenceClassifierOutput(
1238
+ loss=loss,
1239
+ logits=logits,
1240
+ hidden_states=outputs.hidden_states,
1241
+ attentions=outputs.attentions,
1242
+ )
1243
+
1244
+
1245
+ @add_start_docstrings(
1246
+ """
1247
+ RemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1248
+ softmax) e.g. for RocStories/SWAG tasks.
1249
+ """,
1250
+ REMBERT_START_DOCSTRING,
1251
+ )
1252
+ class RemBertForMultipleChoice(RemBertPreTrainedModel):
1253
+ def __init__(self, config):
1254
+ super().__init__(config)
1255
+
1256
+ self.rembert = RemBertModel(config)
1257
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1258
+ self.classifier = nn.Linear(config.hidden_size, 1)
1259
+
1260
+ # Initialize weights and apply final processing
1261
+ self.post_init()
1262
+
1263
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1264
+ @add_code_sample_docstrings(
1265
+ checkpoint="google/rembert",
1266
+ output_type=MultipleChoiceModelOutput,
1267
+ config_class=_CONFIG_FOR_DOC,
1268
+ )
1269
+ def forward(
1270
+ self,
1271
+ input_ids: Optional[torch.FloatTensor] = None,
1272
+ attention_mask: Optional[torch.FloatTensor] = None,
1273
+ token_type_ids: Optional[torch.LongTensor] = None,
1274
+ position_ids: Optional[torch.FloatTensor] = None,
1275
+ head_mask: Optional[torch.FloatTensor] = None,
1276
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1277
+ labels: Optional[torch.LongTensor] = None,
1278
+ output_attentions: Optional[bool] = None,
1279
+ output_hidden_states: Optional[bool] = None,
1280
+ return_dict: Optional[bool] = None,
1281
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
1282
+ r"""
1283
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1284
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1285
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1286
+ `input_ids` above)
1287
+ """
1288
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1289
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1290
+
1291
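+ # Flatten (batch_size, num_choices, seq_len) inputs to (batch_size * num_choices, seq_len) so each
+ # choice is scored independently; the logits are reshaped back to (batch_size, num_choices) below.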
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1292
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1293
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1294
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1295
+ inputs_embeds = (
1296
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1297
+ if inputs_embeds is not None
1298
+ else None
1299
+ )
1300
+
1301
+ outputs = self.rembert(
1302
+ input_ids,
1303
+ attention_mask=attention_mask,
1304
+ token_type_ids=token_type_ids,
1305
+ position_ids=position_ids,
1306
+ head_mask=head_mask,
1307
+ inputs_embeds=inputs_embeds,
1308
+ output_attentions=output_attentions,
1309
+ output_hidden_states=output_hidden_states,
1310
+ return_dict=return_dict,
1311
+ )
1312
+
1313
+ pooled_output = outputs[1]
1314
+
1315
+ pooled_output = self.dropout(pooled_output)
1316
+ logits = self.classifier(pooled_output)
1317
+ reshaped_logits = logits.view(-1, num_choices)
1318
+
1319
+ loss = None
1320
+ if labels is not None:
1321
+ loss_fct = CrossEntropyLoss()
1322
+ loss = loss_fct(reshaped_logits, labels)
1323
+
1324
+ if not return_dict:
1325
+ output = (reshaped_logits,) + outputs[2:]
1326
+ return ((loss,) + output) if loss is not None else output
1327
+
1328
+ return MultipleChoiceModelOutput(
1329
+ loss=loss,
1330
+ logits=reshaped_logits,
1331
+ hidden_states=outputs.hidden_states,
1332
+ attentions=outputs.attentions,
1333
+ )
1334
+
1335
+
1336
+ @add_start_docstrings(
1337
+ """
1338
+ RemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1339
+ Named-Entity-Recognition (NER) tasks.
1340
+ """,
1341
+ REMBERT_START_DOCSTRING,
1342
+ )
1343
+ class RemBertForTokenClassification(RemBertPreTrainedModel):
1344
+ def __init__(self, config):
1345
+ super().__init__(config)
1346
+ self.num_labels = config.num_labels
1347
+
1348
+ self.rembert = RemBertModel(config, add_pooling_layer=False)
1349
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1350
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1351
+
1352
+ # Initialize weights and apply final processing
1353
+ self.post_init()
1354
+
1355
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1356
+ @add_code_sample_docstrings(
1357
+ checkpoint="google/rembert",
1358
+ output_type=TokenClassifierOutput,
1359
+ config_class=_CONFIG_FOR_DOC,
1360
+ )
1361
+ def forward(
1362
+ self,
1363
+ input_ids: Optional[torch.FloatTensor] = None,
1364
+ attention_mask: Optional[torch.FloatTensor] = None,
1365
+ token_type_ids: Optional[torch.LongTensor] = None,
1366
+ position_ids: Optional[torch.FloatTensor] = None,
1367
+ head_mask: Optional[torch.FloatTensor] = None,
1368
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1369
+ labels: Optional[torch.LongTensor] = None,
1370
+ output_attentions: Optional[bool] = None,
1371
+ output_hidden_states: Optional[bool] = None,
1372
+ return_dict: Optional[bool] = None,
1373
+ ) -> Union[Tuple, TokenClassifierOutput]:
1374
+ r"""
1375
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1376
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1377
+ """
1378
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1379
+
1380
+ outputs = self.rembert(
1381
+ input_ids,
1382
+ attention_mask=attention_mask,
1383
+ token_type_ids=token_type_ids,
1384
+ position_ids=position_ids,
1385
+ head_mask=head_mask,
1386
+ inputs_embeds=inputs_embeds,
1387
+ output_attentions=output_attentions,
1388
+ output_hidden_states=output_hidden_states,
1389
+ return_dict=return_dict,
1390
+ )
1391
+
1392
+ sequence_output = outputs[0]
1393
+
1394
+ sequence_output = self.dropout(sequence_output)
1395
+ logits = self.classifier(sequence_output)
1396
+
1397
+ loss = None
1398
+ if labels is not None:
1399
+ loss_fct = CrossEntropyLoss()
1400
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1401
+
1402
+ if not return_dict:
1403
+ output = (logits,) + outputs[2:]
1404
+ return ((loss,) + output) if loss is not None else output
1405
+
1406
+ return TokenClassifierOutput(
1407
+ loss=loss,
1408
+ logits=logits,
1409
+ hidden_states=outputs.hidden_states,
1410
+ attentions=outputs.attentions,
1411
+ )
1412
+
1413
+
1414
+ @add_start_docstrings(
1415
+ """
1416
+ RemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1417
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1418
+ """,
1419
+ REMBERT_START_DOCSTRING,
1420
+ )
1421
+ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
1422
+ def __init__(self, config):
1423
+ super().__init__(config)
1424
+
1425
+ self.num_labels = config.num_labels
1426
+
1427
+ self.rembert = RemBertModel(config, add_pooling_layer=False)
1428
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1429
+
1430
+ # Initialize weights and apply final processing
1431
+ self.post_init()
1432
+
1433
+ @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1434
+ @add_code_sample_docstrings(
1435
+ checkpoint="google/rembert",
1436
+ output_type=QuestionAnsweringModelOutput,
1437
+ config_class=_CONFIG_FOR_DOC,
1438
+ )
1439
+ def forward(
1440
+ self,
1441
+ input_ids: Optional[torch.FloatTensor] = None,
1442
+ attention_mask: Optional[torch.FloatTensor] = None,
1443
+ token_type_ids: Optional[torch.LongTensor] = None,
1444
+ position_ids: Optional[torch.FloatTensor] = None,
1445
+ head_mask: Optional[torch.FloatTensor] = None,
1446
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1447
+ start_positions: Optional[torch.LongTensor] = None,
1448
+ end_positions: Optional[torch.LongTensor] = None,
1449
+ output_attentions: Optional[bool] = None,
1450
+ output_hidden_states: Optional[bool] = None,
1451
+ return_dict: Optional[bool] = None,
1452
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1453
+ r"""
1454
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1455
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1456
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1457
+ are not taken into account for computing the loss.
1458
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1459
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1460
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1461
+ are not taken into account for computing the loss.
1462
+ """
1463
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1464
+
1465
+ outputs = self.rembert(
1466
+ input_ids,
1467
+ attention_mask=attention_mask,
1468
+ token_type_ids=token_type_ids,
1469
+ position_ids=position_ids,
1470
+ head_mask=head_mask,
1471
+ inputs_embeds=inputs_embeds,
1472
+ output_attentions=output_attentions,
1473
+ output_hidden_states=output_hidden_states,
1474
+ return_dict=return_dict,
1475
+ )
1476
+
1477
+ sequence_output = outputs[0]
1478
+
1479
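+ # qa_outputs maps every token to two values, which are split into span-start and span-end logits.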
+ logits = self.qa_outputs(sequence_output)
1480
+ start_logits, end_logits = logits.split(1, dim=-1)
1481
+ start_logits = start_logits.squeeze(-1)
1482
+ end_logits = end_logits.squeeze(-1)
1483
+
1484
+ total_loss = None
1485
+ if start_positions is not None and end_positions is not None:
1486
+ # If we are on multi-GPU, the split adds an extra dimension, so squeeze it
1487
+ if len(start_positions.size()) > 1:
1488
+ start_positions = start_positions.squeeze(-1)
1489
+ if len(end_positions.size()) > 1:
1490
+ end_positions = end_positions.squeeze(-1)
1491
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1492
+ ignored_index = start_logits.size(1)
1493
+ start_positions.clamp_(0, ignored_index)
1494
+ end_positions.clamp_(0, ignored_index)
1495
+
1496
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1497
+ start_loss = loss_fct(start_logits, start_positions)
1498
+ end_loss = loss_fct(end_logits, end_positions)
1499
+ total_loss = (start_loss + end_loss) / 2
1500
+
1501
+ if not return_dict:
1502
+ output = (start_logits, end_logits) + outputs[2:]
1503
+ return ((total_loss,) + output) if total_loss is not None else output
1504
+
1505
+ return QuestionAnsweringModelOutput(
1506
+ loss=total_loss,
1507
+ start_logits=start_logits,
1508
+ end_logits=end_logits,
1509
+ hidden_states=outputs.hidden_states,
1510
+ attentions=outputs.attentions,
1511
+ )
1512
+
1513
+
1514
+ __all__ = [
1515
+ "RemBertForCausalLM",
1516
+ "RemBertForMaskedLM",
1517
+ "RemBertForMultipleChoice",
1518
+ "RemBertForQuestionAnswering",
1519
+ "RemBertForSequenceClassification",
1520
+ "RemBertForTokenClassification",
1521
+ "RemBertLayer",
1522
+ "RemBertModel",
1523
+ "RemBertPreTrainedModel",
1524
+ "load_tf_weights_in_rembert",
1525
+ ]
docs/transformers/build/lib/transformers/models/roberta/tokenization_roberta.py ADDED
@@ -0,0 +1,402 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for RoBERTa."""
16
+
17
+ import json
18
+ import os
19
+ from functools import lru_cache
20
+ from typing import List, Optional, Tuple
21
+
22
+ import regex as re
23
+
24
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
25
+ from ...utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ VOCAB_FILES_NAMES = {
31
+ "vocab_file": "vocab.json",
32
+ "merges_file": "merges.txt",
33
+ }
34
+
35
+
36
+ @lru_cache()
37
+ def bytes_to_unicode():
38
+ """
39
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
40
+ characters the bpe code barfs on.
41
+
42
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
43
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
44
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
45
+ tables between utf-8 bytes and unicode strings.
46
+ """
47
+ bs = (
48
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
49
+ )
50
+ cs = bs[:]
51
+ n = 0
52
+ for b in range(2**8):
53
+ if b not in bs:
54
+ bs.append(b)
55
+ cs.append(2**8 + n)
56
+ n += 1
57
+ cs = [chr(n) for n in cs]
58
+ return dict(zip(bs, cs))
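A quick sanity check of the table built above (illustrative only, not part of the file): it covers every byte exactly once and maps each to a distinct, printable unicode character.

table = bytes_to_unicode()
assert len(table) == 256                 # one entry per possible byte value
assert len(set(table.values())) == 256   # the mapping is reversible (no collisions)
print(table[ord(" ")])                   # the space byte is remapped to 'Ġ', the visible space marker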
59
+
60
+
61
+ def get_pairs(word):
62
+ """
63
+ Return set of symbol pairs in a word.
64
+
65
+ Word is represented as tuple of symbols (symbols being variable-length strings).
66
+ """
67
+ pairs = set()
68
+ prev_char = word[0]
69
+ for char in word[1:]:
70
+ pairs.add((prev_char, char))
71
+ prev_char = char
72
+ return pairs
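For example (a small illustration, not part of the file), the candidate bigrams for the next BPE merge of a word are:

print(get_pairs(("h", "e", "l", "l", "o")))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}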
73
+
74
+
75
+ class RobertaTokenizer(PreTrainedTokenizer):
76
+ """
77
+ Constructs a RoBERTa tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
78
+
79
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
80
+ be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
81
+
82
+ ```python
83
+ >>> from transformers import RobertaTokenizer
84
+
85
+ >>> tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
86
+ >>> tokenizer("Hello world")["input_ids"]
87
+ [0, 31414, 232, 2]
88
+
89
+ >>> tokenizer(" Hello world")["input_ids"]
90
+ [0, 20920, 232, 2]
91
+ ```
92
+
93
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
94
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
95
+
96
+ <Tip>
97
+
98
+ When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
99
+
100
+ </Tip>
101
+
102
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
103
+ this superclass for more information regarding those methods.
104
+
105
+ Args:
106
+ vocab_file (`str`):
107
+ Path to the vocabulary file.
108
+ merges_file (`str`):
109
+ Path to the merges file.
110
+ errors (`str`, *optional*, defaults to `"replace"`):
111
+ Paradigm to follow when decoding bytes to UTF-8. See
112
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
113
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
114
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
115
+
116
+ <Tip>
117
+
118
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
119
+ sequence. The token used is the `cls_token`.
120
+
121
+ </Tip>
122
+
123
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
124
+ The end of sequence token.
125
+
126
+ <Tip>
127
+
128
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
129
+ The token used is the `sep_token`.
130
+
131
+ </Tip>
132
+
133
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
134
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
135
+ sequence classification or for a text and a question for question answering. It is also used as the last
136
+ token of a sequence built with special tokens.
137
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
138
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
139
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
140
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
141
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
142
+ token instead.
143
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
144
+ The token used for padding, for example when batching sequences of different lengths.
145
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
146
+ The token used for masking values. This is the token used when training this model with masked language
147
+ modeling. This is the token which the model will try to predict.
148
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
149
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
150
+ other word. (The RoBERTa tokenizer detects the beginning of words by the preceding space.)
151
+ """
152
+
153
+ vocab_files_names = VOCAB_FILES_NAMES
154
+ model_input_names = ["input_ids", "attention_mask"]
155
+
156
+ def __init__(
157
+ self,
158
+ vocab_file,
159
+ merges_file,
160
+ errors="replace",
161
+ bos_token="<s>",
162
+ eos_token="</s>",
163
+ sep_token="</s>",
164
+ cls_token="<s>",
165
+ unk_token="<unk>",
166
+ pad_token="<pad>",
167
+ mask_token="<mask>",
168
+ add_prefix_space=False,
169
+ **kwargs,
170
+ ):
171
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
172
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
173
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
174
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
175
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
176
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
177
+
178
+ # Mask token behave like a normal word, i.e. include the space before it
179
+ mask_token = (
180
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
181
+ if isinstance(mask_token, str)
182
+ else mask_token
183
+ )
184
+
185
+ # these special tokens are not part of the vocab.json, let's add them in the correct order
186
+
187
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
188
+ self.encoder = json.load(vocab_handle)
189
+ self.decoder = {v: k for k, v in self.encoder.items()}
190
+ self.errors = errors # how to handle errors in decoding
191
+ self.byte_encoder = bytes_to_unicode()
192
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
193
+ with open(merges_file, encoding="utf-8") as merges_handle:
194
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
195
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
196
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
197
+ self.cache = {}
198
+ self.add_prefix_space = add_prefix_space
199
+
200
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
201
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
202
+
203
+ super().__init__(
204
+ errors=errors,
205
+ bos_token=bos_token,
206
+ eos_token=eos_token,
207
+ unk_token=unk_token,
208
+ sep_token=sep_token,
209
+ cls_token=cls_token,
210
+ pad_token=pad_token,
211
+ mask_token=mask_token,
212
+ add_prefix_space=add_prefix_space,
213
+ **kwargs,
214
+ )
215
+
216
+ @property
217
+ def vocab_size(self):
218
+ return len(self.encoder)
219
+
220
+ def get_vocab(self):
221
+ vocab = dict(self.encoder).copy()
222
+ vocab.update(self.added_tokens_encoder)
223
+ return vocab
224
+
225
+ def bpe(self, token):
226
+ if token in self.cache:
227
+ return self.cache[token]
228
+ word = tuple(token)
229
+ pairs = get_pairs(word)
230
+
231
+ if not pairs:
232
+ return token
233
+
234
+ while True:
235
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
236
+ if bigram not in self.bpe_ranks:
237
+ break
238
+ first, second = bigram
239
+ new_word = []
240
+ i = 0
241
+ while i < len(word):
242
+ try:
243
+ j = word.index(first, i)
244
+ except ValueError:
245
+ new_word.extend(word[i:])
246
+ break
247
+ else:
248
+ new_word.extend(word[i:j])
249
+ i = j
250
+
251
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
252
+ new_word.append(first + second)
253
+ i += 2
254
+ else:
255
+ new_word.append(word[i])
256
+ i += 1
257
+ new_word = tuple(new_word)
258
+ word = new_word
259
+ if len(word) == 1:
260
+ break
261
+ else:
262
+ pairs = get_pairs(word)
263
+ word = " ".join(word)
264
+ self.cache[token] = word
265
+ return word
266
+
267
+ def _tokenize(self, text):
268
+ """Tokenize a string."""
269
+ bpe_tokens = []
270
+ for token in re.findall(self.pat, text):
271
+ token = "".join(
272
+ self.byte_encoder[b] for b in token.encode("utf-8")
273
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
274
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
275
+ return bpe_tokens
276
+
277
+ def _convert_token_to_id(self, token):
278
+ """Converts a token (str) in an id using the vocab."""
279
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
280
+
281
+ def _convert_id_to_token(self, index):
282
+ """Converts an index (integer) in a token (str) using the vocab."""
283
+ return self.decoder.get(index)
284
+
285
+ def convert_tokens_to_string(self, tokens):
286
+ """Converts a sequence of tokens (string) in a single string."""
287
+ text = "".join(tokens)
288
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
289
+ return text
290
+
291
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
292
+ if not os.path.isdir(save_directory):
293
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
294
+ return
295
+ vocab_file = os.path.join(
296
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
297
+ )
298
+ merge_file = os.path.join(
299
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
300
+ )
301
+
302
+ with open(vocab_file, "w", encoding="utf-8") as f:
303
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
304
+
305
+ index = 0
306
+ with open(merge_file, "w", encoding="utf-8") as writer:
307
+ writer.write("#version: 0.2\n")
308
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
309
+ if index != token_index:
310
+ logger.warning(
311
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
312
+ " Please check that the tokenizer is not corrupted!"
313
+ )
314
+ index = token_index
315
+ writer.write(" ".join(bpe_tokens) + "\n")
316
+ index += 1
317
+
318
+ return vocab_file, merge_file
319
+
320
+ def build_inputs_with_special_tokens(
321
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
322
+ ) -> List[int]:
323
+ """
324
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
325
+ adding special tokens. A RoBERTa sequence has the following format:
326
+
327
+ - single sequence: `<s> X </s>`
328
+ - pair of sequences: `<s> A </s></s> B </s>`
329
+
330
+ Args:
331
+ token_ids_0 (`List[int]`):
332
+ List of IDs to which the special tokens will be added.
333
+ token_ids_1 (`List[int]`, *optional*):
334
+ Optional second list of IDs for sequence pairs.
335
+
336
+ Returns:
337
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
338
+ """
339
+ if token_ids_1 is None:
340
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
341
+ cls = [self.cls_token_id]
342
+ sep = [self.sep_token_id]
343
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
344
+
345
+ def get_special_tokens_mask(
346
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
347
+ ) -> List[int]:
348
+ """
349
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
350
+ special tokens using the tokenizer `prepare_for_model` method.
351
+
352
+ Args:
353
+ token_ids_0 (`List[int]`):
354
+ List of IDs.
355
+ token_ids_1 (`List[int]`, *optional*):
356
+ Optional second list of IDs for sequence pairs.
357
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
358
+ Whether or not the token list is already formatted with special tokens for the model.
359
+
360
+ Returns:
361
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
362
+ """
363
+ if already_has_special_tokens:
364
+ return super().get_special_tokens_mask(
365
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
366
+ )
367
+
368
+ if token_ids_1 is None:
369
+ return [1] + ([0] * len(token_ids_0)) + [1]
370
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
371
+
372
+ def create_token_type_ids_from_sequences(
373
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
374
+ ) -> List[int]:
375
+ """
376
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not
377
+ make use of token type ids, therefore a list of zeros is returned.
378
+
379
+ Args:
380
+ token_ids_0 (`List[int]`):
381
+ List of IDs.
382
+ token_ids_1 (`List[int]`, *optional*):
383
+ Optional second list of IDs for sequence pairs.
384
+
385
+ Returns:
386
+ `List[int]`: List of zeros.
387
+ """
388
+ sep = [self.sep_token_id]
389
+ cls = [self.cls_token_id]
390
+
391
+ if token_ids_1 is None:
392
+ return len(cls + token_ids_0 + sep) * [0]
393
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
394
+
395
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
396
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
397
+ if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
398
+ text = " " + text
399
+ return (text, kwargs)
400
+
401
+
402
+ __all__ = ["RobertaTokenizer"]
docs/transformers/build/lib/transformers/models/roberta_prelayernorm/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+ from .configuration_roberta_prelayernorm import *
+ from .modeling_flax_roberta_prelayernorm import *
+ from .modeling_roberta_prelayernorm import *
+ from .modeling_tf_roberta_prelayernorm import *
+ else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
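Because of the `_LazyModule` registration above, the submodules are only imported when one of their symbols is first accessed; a typical access path (sketch) is:

from transformers.models.roberta_prelayernorm import RobertaPreLayerNormConfig  # triggers the lazy import

config = RobertaPreLayerNormConfig()  # default configuration object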
docs/transformers/build/lib/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,79 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert RoBERTa-PreLayerNorm checkpoint."""
+
+ import argparse
+
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from transformers import AutoTokenizer, RobertaPreLayerNormConfig, RobertaPreLayerNormForMaskedLM
+ from transformers.utils import logging
+
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger(__name__)
+
+
+ def convert_roberta_prelayernorm_checkpoint_to_pytorch(checkpoint_repo: str, pytorch_dump_folder_path: str):
+ """
+ Copy/paste/tweak roberta_prelayernorm's weights to our BERT structure.
+ """
+ # convert configuration
+ config = RobertaPreLayerNormConfig.from_pretrained(
+ checkpoint_repo, architectures=["RobertaPreLayerNormForMaskedLM"]
+ )
+
+ # convert state_dict
+ original_state_dict = torch.load(
+ hf_hub_download(repo_id=checkpoint_repo, filename="pytorch_model.bin"), weights_only=True
+ )
+ state_dict = {}
+ for tensor_key, tensor_value in original_state_dict.items():
+ # The transformers implementation gives the model a unique name, rather than overwriting 'roberta'
+ if tensor_key.startswith("roberta."):
+ tensor_key = "roberta_prelayernorm." + tensor_key[len("roberta.") :]
+
+ # The original implementation contains weights which are not used, remove them from the state_dict
+ if tensor_key.endswith(".self.LayerNorm.weight") or tensor_key.endswith(".self.LayerNorm.bias"):
+ continue
+
+ state_dict[tensor_key] = tensor_value
+
+ model = RobertaPreLayerNormForMaskedLM.from_pretrained(
+ pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+ )
+ model.save_pretrained(pytorch_dump_folder_path)
+
+ # convert tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_repo)
+ tokenizer.save_pretrained(pytorch_dump_folder_path)
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint-repo",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the official PyTorch dump, e.g. 'andreasmadsen/efficient_mlm_m0.40'.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+ )
+ args = parser.parse_args()
+ convert_roberta_prelayernorm_checkpoint_to_pytorch(args.checkpoint_repo, args.pytorch_dump_folder_path)
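A sketch of driving the same conversion from Python rather than the CLI, using the checkpoint repository already named in the argparse help text; the output folder name is an illustrative choice:

from convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch import (
    convert_roberta_prelayernorm_checkpoint_to_pytorch,
)

convert_roberta_prelayernorm_checkpoint_to_pytorch(
    checkpoint_repo="andreasmadsen/efficient_mlm_m0.40",
    pytorch_dump_folder_path="./roberta_prelayernorm_converted",  # illustrative local output directory
)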
docs/transformers/build/lib/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py ADDED
@@ -0,0 +1,1527 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Flax RoBERTa-PreLayerNorm model."""
16
+
17
+ from typing import Callable, Optional, Tuple
18
+
19
+ import flax.linen as nn
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
+ from flax.linen import combine_masks, make_causal_mask
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from flax.traverse_util import flatten_dict, unflatten_dict
28
+ from jax import lax
29
+
30
+ from ...modeling_flax_outputs import (
31
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
32
+ FlaxBaseModelOutputWithPooling,
33
+ FlaxBaseModelOutputWithPoolingAndCrossAttentions,
34
+ FlaxCausalLMOutputWithCrossAttentions,
35
+ FlaxMaskedLMOutput,
36
+ FlaxMultipleChoiceModelOutput,
37
+ FlaxQuestionAnsweringModelOutput,
38
+ FlaxSequenceClassifierOutput,
39
+ FlaxTokenClassifierOutput,
40
+ )
41
+ from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
42
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
43
+ from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40"
49
+ _CONFIG_FOR_DOC = "RobertaPreLayerNormConfig"
50
+
51
+ remat = nn_partitioning.remat
52
+
53
+
54
+ # Copied from transformers.models.roberta.modeling_flax_roberta.create_position_ids_from_input_ids
55
+ def create_position_ids_from_input_ids(input_ids, padding_idx):
56
+ """
57
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
58
+ are ignored. This is modified from fairseq's `utils.make_positions`.
59
+
60
+ Args:
61
+ input_ids: jnp.ndarray
62
+ padding_idx: int
63
+
64
+ Returns: jnp.ndarray
65
+ """
66
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
67
+ mask = (input_ids != padding_idx).astype("i4")
68
+
69
+ if mask.ndim > 2:
70
+ mask = mask.reshape((-1, mask.shape[-1]))
71
+ incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
72
+ incremental_indices = incremental_indices.reshape(input_ids.shape)
73
+ else:
74
+ incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
75
+
76
+ return incremental_indices.astype("i4") + padding_idx
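A tiny numeric check of this helper (illustrative, assuming `padding_idx=1` as in RoBERTa-style configs): non-pad positions count up from `padding_idx + 1`, pads keep `padding_idx`.

import jax.numpy as jnp

input_ids = jnp.array([[5, 6, 1, 1]])                     # two real tokens followed by two pads
print(create_position_ids_from_input_ids(input_ids, 1))   # [[2 3 1 1]]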
77
+
78
+
79
+ ROBERTA_PRELAYERNORM_START_DOCSTRING = r"""
80
+
81
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
82
+ library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
83
+
84
+ This model is also a
85
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
86
+ a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
87
+ behavior.
88
+
89
+ Finally, this model supports inherent JAX features such as:
90
+
91
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
92
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
93
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
94
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
95
+
96
+ Parameters:
97
+ config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the
98
+ model. Initializing with a config file does not load the weights associated with the model, only the
99
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
100
+ """
101
+
102
+ ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r"""
103
+ Args:
104
+ input_ids (`numpy.ndarray` of shape `({0})`):
105
+ Indices of input sequence tokens in the vocabulary.
106
+
107
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
108
+ [`PreTrainedTokenizer.__call__`] for details.
109
+
110
+ [What are input IDs?](../glossary#input-ids)
111
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
112
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
113
+
114
+ - 1 for tokens that are **not masked**,
115
+ - 0 for tokens that are **masked**.
116
+
117
+ [What are attention masks?](../glossary#attention-mask)
118
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
119
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
120
+ 1]`:
121
+
122
+ - 0 corresponds to a *sentence A* token,
123
+ - 1 corresponds to a *sentence B* token.
124
+
125
+ [What are token type IDs?](../glossary#token-type-ids)
126
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
127
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
128
+ config.max_position_embeddings - 1]`.
129
+ head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
130
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
131
+
132
+ - 1 indicates the head is **not masked**,
133
+ - 0 indicates the head is **masked**.
134
+
135
+ return_dict (`bool`, *optional*):
136
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
137
+ """
138
+
139
+
140
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->RobertaPreLayerNorm
141
+ class FlaxRobertaPreLayerNormEmbeddings(nn.Module):
142
+ """Construct the embeddings from word, position and token_type embeddings."""
143
+
144
+ config: RobertaPreLayerNormConfig
145
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
146
+
147
+ def setup(self):
148
+ self.word_embeddings = nn.Embed(
149
+ self.config.vocab_size,
150
+ self.config.hidden_size,
151
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
152
+ dtype=self.dtype,
153
+ )
154
+ self.position_embeddings = nn.Embed(
155
+ self.config.max_position_embeddings,
156
+ self.config.hidden_size,
157
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
158
+ dtype=self.dtype,
159
+ )
160
+ self.token_type_embeddings = nn.Embed(
161
+ self.config.type_vocab_size,
162
+ self.config.hidden_size,
163
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
164
+ dtype=self.dtype,
165
+ )
166
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
167
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
168
+
169
+ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
170
+ # Embed
171
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
172
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
173
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
174
+
175
+ # Sum all embeddings
176
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
177
+
178
+ # Layer Norm
179
+ hidden_states = self.LayerNorm(hidden_states)
180
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
181
+ return hidden_states
182
+
183
+
184
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->RobertaPreLayerNorm
185
+ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
186
+ config: RobertaPreLayerNormConfig
187
+ causal: bool = False
188
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
189
+
190
+ def setup(self):
191
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
192
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
193
+ raise ValueError(
194
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
195
+ " : {self.config.num_attention_heads}"
196
+ )
197
+
198
+ self.query = nn.Dense(
199
+ self.config.hidden_size,
200
+ dtype=self.dtype,
201
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
202
+ )
203
+ self.key = nn.Dense(
204
+ self.config.hidden_size,
205
+ dtype=self.dtype,
206
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
207
+ )
208
+ self.value = nn.Dense(
209
+ self.config.hidden_size,
210
+ dtype=self.dtype,
211
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
212
+ )
213
+
214
+ if self.causal:
215
+ self.causal_mask = make_causal_mask(
216
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
217
+ )
218
+
219
+ def _split_heads(self, hidden_states):
220
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
221
+
222
+ def _merge_heads(self, hidden_states):
223
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
224
+
225
+ @nn.compact
226
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
227
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
228
+ """
229
+ This function takes projected key, value states from a single input token and concatenates the states to cached
230
+ states from previous steps. This function is slightly adapted from the official Flax repository:
231
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
232
+ """
233
+ # detect if we're initializing by absence of existing cache data.
234
+ is_initialized = self.has_variable("cache", "cached_key")
235
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
236
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
237
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
238
+
239
+ if is_initialized:
240
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
241
+ # update key, value caches with our new 1d spatial slices
242
+ cur_index = cache_index.value
243
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
244
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
245
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
246
+ cached_key.value = key
247
+ cached_value.value = value
248
+ num_updated_cache_vectors = query.shape[1]
249
+ cache_index.value = cache_index.value + num_updated_cache_vectors
250
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
251
+ pad_mask = jnp.broadcast_to(
252
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
253
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
254
+ )
255
+ attention_mask = combine_masks(pad_mask, attention_mask)
256
+ return key, value, attention_mask
257
+
258
+ def __call__(
259
+ self,
260
+ hidden_states,
261
+ attention_mask,
262
+ layer_head_mask,
263
+ key_value_states: Optional[jnp.ndarray] = None,
264
+ init_cache: bool = False,
265
+ deterministic=True,
266
+ output_attentions: bool = False,
267
+ ):
268
+ # if key_value_states are provided this layer is used as a cross-attention layer
269
+ # for the decoder
270
+ is_cross_attention = key_value_states is not None
271
+ batch_size = hidden_states.shape[0]
272
+
273
+ # get query proj
274
+ query_states = self.query(hidden_states)
275
+ # get key, value proj
276
+ if is_cross_attention:
277
+ # cross_attentions
278
+ key_states = self.key(key_value_states)
279
+ value_states = self.value(key_value_states)
280
+ else:
281
+ # self_attention
282
+ key_states = self.key(hidden_states)
283
+ value_states = self.value(hidden_states)
284
+
285
+ query_states = self._split_heads(query_states)
286
+ key_states = self._split_heads(key_states)
287
+ value_states = self._split_heads(value_states)
288
+
289
+ # handle cache prepare causal attention mask
290
+ if self.causal:
291
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
292
+ if self.has_variable("cache", "cached_key"):
293
+ mask_shift = self.variables["cache"]["cache_index"]
294
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
295
+ causal_mask = lax.dynamic_slice(
296
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
297
+ )
298
+ else:
299
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
300
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
301
+
302
+ # combine masks if needed
303
+ if attention_mask is not None and self.causal:
304
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
305
+ attention_mask = combine_masks(attention_mask, causal_mask)
306
+ elif self.causal:
307
+ attention_mask = causal_mask
308
+ elif attention_mask is not None:
309
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
310
+
311
+ # During fast autoregressive decoding, we feed one position at a time,
312
+ # and cache the keys and values step by step.
313
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
314
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
315
+ key_states, value_states, query_states, attention_mask
316
+ )
317
+
318
+ # Convert the boolean attention mask to an attention bias.
319
+ if attention_mask is not None:
320
+ # attention mask in the form of attention bias
321
+ attention_bias = lax.select(
322
+ attention_mask > 0,
323
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
324
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
325
+ )
326
+ else:
327
+ attention_bias = None
328
+
329
+ dropout_rng = None
330
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
331
+ dropout_rng = self.make_rng("dropout")
332
+
333
+ attn_weights = dot_product_attention_weights(
334
+ query_states,
335
+ key_states,
336
+ bias=attention_bias,
337
+ dropout_rng=dropout_rng,
338
+ dropout_rate=self.config.attention_probs_dropout_prob,
339
+ broadcast_dropout=True,
340
+ deterministic=deterministic,
341
+ dtype=self.dtype,
342
+ precision=None,
343
+ )
344
+
345
+ # Mask heads if we want to
346
+ if layer_head_mask is not None:
347
+ attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
348
+
349
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
350
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
351
+
352
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
353
+ return outputs
354
+
355
+
356
+ class FlaxRobertaPreLayerNormSelfOutput(nn.Module):
357
+ config: RobertaPreLayerNormConfig
358
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
359
+
360
+ def setup(self):
361
+ self.dense = nn.Dense(
362
+ self.config.hidden_size,
363
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
364
+ dtype=self.dtype,
365
+ )
366
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
367
+
368
+ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
369
+ hidden_states = self.dense(hidden_states)
370
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
371
+ hidden_states = hidden_states + input_tensor
372
+ return hidden_states
373
+
374
+
375
+ class FlaxRobertaPreLayerNormAttention(nn.Module):
376
+ config: RobertaPreLayerNormConfig
377
+ causal: bool = False
378
+ dtype: jnp.dtype = jnp.float32
379
+
380
+ def setup(self):
381
+ self.self = FlaxRobertaPreLayerNormSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
382
+ self.output = FlaxRobertaPreLayerNormSelfOutput(self.config, dtype=self.dtype)
383
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
384
+
385
+ def __call__(
386
+ self,
387
+ hidden_states,
388
+ attention_mask,
389
+ layer_head_mask,
390
+ key_value_states=None,
391
+ init_cache=False,
392
+ deterministic=True,
393
+ output_attentions: bool = False,
394
+ ):
395
+ hidden_states_pre_layer_norm = self.LayerNorm(hidden_states)
396
+ # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
397
+ # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
398
+ # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
399
+ attn_outputs = self.self(
400
+ hidden_states_pre_layer_norm,
401
+ attention_mask,
402
+ layer_head_mask=layer_head_mask,
403
+ key_value_states=key_value_states,
404
+ init_cache=init_cache,
405
+ deterministic=deterministic,
406
+ output_attentions=output_attentions,
407
+ )
408
+ attn_output = attn_outputs[0]
409
+ hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
410
+
411
+ outputs = (hidden_states,)
412
+
413
+ if output_attentions:
414
+ outputs += (attn_outputs[1],)
415
+
416
+ return outputs
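The distinguishing detail of this attention block is the ordering: LayerNorm is applied to the input before self-attention, and the residual is added to the un-normalised input. As a sketch of the data flow implemented above (pseudo-code, not part of the file):

# Pre-LayerNorm ordering used here:
#   y = x + SelfOutput(SelfAttention(LayerNorm(x)))
# versus the post-LayerNorm ordering of vanilla RoBERTa/BERT:
#   y = LayerNorm(x + SelfOutput(SelfAttention(x)))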
417
+
418
+
419
+ class FlaxRobertaPreLayerNormIntermediate(nn.Module):
420
+ config: RobertaPreLayerNormConfig
421
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
422
+
423
+ def setup(self):
424
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
425
+ self.dense = nn.Dense(
426
+ self.config.intermediate_size,
427
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
428
+ dtype=self.dtype,
429
+ )
430
+ self.activation = ACT2FN[self.config.hidden_act]
431
+
432
+ def __call__(self, hidden_states):
433
+ hidden_states = self.LayerNorm(hidden_states)
434
+ hidden_states = self.dense(hidden_states)
435
+ hidden_states = self.activation(hidden_states)
436
+ return hidden_states
437
+
438
+
439
+ class FlaxRobertaPreLayerNormOutput(nn.Module):
440
+ config: RobertaPreLayerNormConfig
441
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
442
+
443
+ def setup(self):
444
+ self.dense = nn.Dense(
445
+ self.config.hidden_size,
446
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
447
+ dtype=self.dtype,
448
+ )
449
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
450
+
451
+ def __call__(self, hidden_states, attention_output, deterministic: bool = True):
452
+ hidden_states = self.dense(hidden_states)
453
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
454
+ hidden_states = hidden_states + attention_output
455
+ return hidden_states
456
+
457
+
458
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->RobertaPreLayerNorm
459
+ class FlaxRobertaPreLayerNormLayer(nn.Module):
460
+ config: RobertaPreLayerNormConfig
461
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
462
+
463
+ def setup(self):
464
+ self.attention = FlaxRobertaPreLayerNormAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
465
+ self.intermediate = FlaxRobertaPreLayerNormIntermediate(self.config, dtype=self.dtype)
466
+ self.output = FlaxRobertaPreLayerNormOutput(self.config, dtype=self.dtype)
467
+ if self.config.add_cross_attention:
468
+ self.crossattention = FlaxRobertaPreLayerNormAttention(self.config, causal=False, dtype=self.dtype)
469
+
470
+ def __call__(
471
+ self,
472
+ hidden_states,
473
+ attention_mask,
474
+ layer_head_mask,
475
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
476
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
477
+ init_cache: bool = False,
478
+ deterministic: bool = True,
479
+ output_attentions: bool = False,
480
+ ):
481
+ # Self Attention
482
+ attention_outputs = self.attention(
483
+ hidden_states,
484
+ attention_mask,
485
+ layer_head_mask=layer_head_mask,
486
+ init_cache=init_cache,
487
+ deterministic=deterministic,
488
+ output_attentions=output_attentions,
489
+ )
490
+ attention_output = attention_outputs[0]
491
+
492
+ # Cross-Attention Block
493
+ if encoder_hidden_states is not None:
494
+ cross_attention_outputs = self.crossattention(
495
+ attention_output,
496
+ attention_mask=encoder_attention_mask,
497
+ layer_head_mask=layer_head_mask,
498
+ key_value_states=encoder_hidden_states,
499
+ deterministic=deterministic,
500
+ output_attentions=output_attentions,
501
+ )
502
+ attention_output = cross_attention_outputs[0]
503
+
504
+ hidden_states = self.intermediate(attention_output)
505
+ hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
506
+
507
+ outputs = (hidden_states,)
508
+
509
+ if output_attentions:
510
+ outputs += (attention_outputs[1],)
511
+ if encoder_hidden_states is not None:
512
+ outputs += (cross_attention_outputs[1],)
513
+ return outputs
514
+
515
+
516
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->RobertaPreLayerNorm
517
+ class FlaxRobertaPreLayerNormLayerCollection(nn.Module):
518
+ config: RobertaPreLayerNormConfig
519
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
520
+ gradient_checkpointing: bool = False
521
+
522
+ def setup(self):
523
+ if self.gradient_checkpointing:
524
+ FlaxRobertaPreLayerNormCheckpointLayer = remat(FlaxRobertaPreLayerNormLayer, static_argnums=(5, 6, 7))
525
+ self.layers = [
526
+ FlaxRobertaPreLayerNormCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
527
+ for i in range(self.config.num_hidden_layers)
528
+ ]
529
+ else:
530
+ self.layers = [
531
+ FlaxRobertaPreLayerNormLayer(self.config, name=str(i), dtype=self.dtype)
532
+ for i in range(self.config.num_hidden_layers)
533
+ ]
534
+
535
+ def __call__(
536
+ self,
537
+ hidden_states,
538
+ attention_mask,
539
+ head_mask,
540
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
541
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
542
+ init_cache: bool = False,
543
+ deterministic: bool = True,
544
+ output_attentions: bool = False,
545
+ output_hidden_states: bool = False,
546
+ return_dict: bool = True,
547
+ ):
548
+ all_attentions = () if output_attentions else None
549
+ all_hidden_states = () if output_hidden_states else None
550
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
551
+
552
+ # Check if head_mask has a correct number of layers specified if desired
553
+ if head_mask is not None:
554
+ if head_mask.shape[0] != (len(self.layers)):
555
+ raise ValueError(
556
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
557
+ f" {head_mask.shape[0]}."
558
+ )
559
+
560
+ for i, layer in enumerate(self.layers):
561
+ if output_hidden_states:
562
+ all_hidden_states += (hidden_states,)
563
+
564
+ layer_outputs = layer(
565
+ hidden_states,
566
+ attention_mask,
567
+ head_mask[i] if head_mask is not None else None,
568
+ encoder_hidden_states,
569
+ encoder_attention_mask,
570
+ init_cache,
571
+ deterministic,
572
+ output_attentions,
573
+ )
574
+
575
+ hidden_states = layer_outputs[0]
576
+
577
+ if output_attentions:
578
+ all_attentions += (layer_outputs[1],)
579
+
580
+ if encoder_hidden_states is not None:
581
+ all_cross_attentions += (layer_outputs[2],)
582
+
583
+ if output_hidden_states:
584
+ all_hidden_states += (hidden_states,)
585
+
586
+ outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
587
+
588
+ if not return_dict:
589
+ return tuple(v for v in outputs if v is not None)
590
+
591
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
592
+ last_hidden_state=hidden_states,
593
+ hidden_states=all_hidden_states,
594
+ attentions=all_attentions,
595
+ cross_attentions=all_cross_attentions,
596
+ )
597
+
598
+
599
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->RobertaPreLayerNorm
600
+ class FlaxRobertaPreLayerNormEncoder(nn.Module):
601
+ config: RobertaPreLayerNormConfig
602
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
603
+ gradient_checkpointing: bool = False
604
+
605
+ def setup(self):
606
+ self.layer = FlaxRobertaPreLayerNormLayerCollection(
607
+ self.config,
608
+ dtype=self.dtype,
609
+ gradient_checkpointing=self.gradient_checkpointing,
610
+ )
611
+
612
+ def __call__(
613
+ self,
614
+ hidden_states,
615
+ attention_mask,
616
+ head_mask,
617
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
618
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
619
+ init_cache: bool = False,
620
+ deterministic: bool = True,
621
+ output_attentions: bool = False,
622
+ output_hidden_states: bool = False,
623
+ return_dict: bool = True,
624
+ ):
625
+ return self.layer(
626
+ hidden_states,
627
+ attention_mask,
628
+ head_mask=head_mask,
629
+ encoder_hidden_states=encoder_hidden_states,
630
+ encoder_attention_mask=encoder_attention_mask,
631
+ init_cache=init_cache,
632
+ deterministic=deterministic,
633
+ output_attentions=output_attentions,
634
+ output_hidden_states=output_hidden_states,
635
+ return_dict=return_dict,
636
+ )
637
+
638
+
639
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->RobertaPreLayerNorm
640
+ class FlaxRobertaPreLayerNormPooler(nn.Module):
641
+ config: RobertaPreLayerNormConfig
642
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
643
+
644
+ def setup(self):
645
+ self.dense = nn.Dense(
646
+ self.config.hidden_size,
647
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
648
+ dtype=self.dtype,
649
+ )
650
+
651
+ def __call__(self, hidden_states):
652
+ cls_hidden_state = hidden_states[:, 0]
653
+ cls_hidden_state = self.dense(cls_hidden_state)
654
+ return nn.tanh(cls_hidden_state)
655
+
656
+
657
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaLMHead with Roberta->RobertaPreLayerNorm
658
+ class FlaxRobertaPreLayerNormLMHead(nn.Module):
659
+ config: RobertaPreLayerNormConfig
660
+ dtype: jnp.dtype = jnp.float32
661
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
662
+
663
+ def setup(self):
664
+ self.dense = nn.Dense(
665
+ self.config.hidden_size,
666
+ dtype=self.dtype,
667
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
668
+ )
669
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
670
+ self.decoder = nn.Dense(
671
+ self.config.vocab_size,
672
+ dtype=self.dtype,
673
+ use_bias=False,
674
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
675
+ )
676
+ self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
677
+
678
+ def __call__(self, hidden_states, shared_embedding=None):
679
+ hidden_states = self.dense(hidden_states)
680
+ hidden_states = ACT2FN["gelu"](hidden_states)
681
+ hidden_states = self.layer_norm(hidden_states)
682
+
683
+ if shared_embedding is not None:
684
+ hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
685
+ else:
686
+ hidden_states = self.decoder(hidden_states)
687
+
688
+ bias = jnp.asarray(self.bias, self.dtype)
689
+ hidden_states += bias
690
+ return hidden_states
691
+
692
+
693
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaClassificationHead with Roberta->RobertaPreLayerNorm
694
+ class FlaxRobertaPreLayerNormClassificationHead(nn.Module):
695
+ config: RobertaPreLayerNormConfig
696
+ dtype: jnp.dtype = jnp.float32
697
+
698
+ def setup(self):
699
+ self.dense = nn.Dense(
700
+ self.config.hidden_size,
701
+ dtype=self.dtype,
702
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
703
+ )
704
+ classifier_dropout = (
705
+ self.config.classifier_dropout
706
+ if self.config.classifier_dropout is not None
707
+ else self.config.hidden_dropout_prob
708
+ )
709
+ self.dropout = nn.Dropout(rate=classifier_dropout)
710
+ self.out_proj = nn.Dense(
711
+ self.config.num_labels,
712
+ dtype=self.dtype,
713
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
714
+ )
715
+
716
+ def __call__(self, hidden_states, deterministic=True):
717
+ hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
718
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
719
+ hidden_states = self.dense(hidden_states)
720
+ hidden_states = nn.tanh(hidden_states)
721
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
722
+ hidden_states = self.out_proj(hidden_states)
723
+ return hidden_states
724
+
725
+
726
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaPreTrainedModel with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
727
+ class FlaxRobertaPreLayerNormPreTrainedModel(FlaxPreTrainedModel):
728
+ """
729
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
730
+ models.
731
+ """
732
+
733
+ config_class = RobertaPreLayerNormConfig
734
+ base_model_prefix = "roberta_prelayernorm"
735
+
736
+ module_class: nn.Module = None
737
+
738
+ def __init__(
739
+ self,
740
+ config: RobertaPreLayerNormConfig,
741
+ input_shape: Tuple = (1, 1),
742
+ seed: int = 0,
743
+ dtype: jnp.dtype = jnp.float32,
744
+ _do_init: bool = True,
745
+ gradient_checkpointing: bool = False,
746
+ **kwargs,
747
+ ):
748
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
749
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
750
+
751
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
752
+ def enable_gradient_checkpointing(self):
753
+ self._module = self.module_class(
754
+ config=self.config,
755
+ dtype=self.dtype,
756
+ gradient_checkpointing=True,
757
+ )
758
+
759
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
760
+ # init input tensors
761
+ input_ids = jnp.zeros(input_shape, dtype="i4")
762
+ token_type_ids = jnp.ones_like(input_ids)
763
+ position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
764
+ attention_mask = jnp.ones_like(input_ids)
765
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
766
+
767
+ params_rng, dropout_rng = jax.random.split(rng)
768
+ rngs = {"params": params_rng, "dropout": dropout_rng}
769
+
770
+ if self.config.add_cross_attention:
771
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
772
+ encoder_attention_mask = attention_mask
773
+ module_init_outputs = self.module.init(
774
+ rngs,
775
+ input_ids,
776
+ attention_mask,
777
+ token_type_ids,
778
+ position_ids,
779
+ head_mask,
780
+ encoder_hidden_states,
781
+ encoder_attention_mask,
782
+ return_dict=False,
783
+ )
784
+ else:
785
+ module_init_outputs = self.module.init(
786
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
787
+ )
788
+
789
+ random_params = module_init_outputs["params"]
790
+
791
+ if params is not None:
792
+ random_params = flatten_dict(unfreeze(random_params))
793
+ params = flatten_dict(unfreeze(params))
794
+ for missing_key in self._missing_keys:
795
+ params[missing_key] = random_params[missing_key]
796
+ self._missing_keys = set()
797
+ return freeze(unflatten_dict(params))
798
+ else:
799
+ return random_params
800
+
801
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
802
+ def init_cache(self, batch_size, max_length):
803
+ r"""
804
+ Args:
805
+ batch_size (`int`):
806
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
807
+ max_length (`int`):
808
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
809
+ cache.
810
+ """
811
+ # init input variables to retrieve cache
812
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
813
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
814
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
815
+
816
+ init_variables = self.module.init(
817
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
818
+ )
819
+ return unfreeze(init_variables["cache"])
820
+
821
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
822
+ def __call__(
823
+ self,
824
+ input_ids,
825
+ attention_mask=None,
826
+ token_type_ids=None,
827
+ position_ids=None,
828
+ head_mask=None,
829
+ encoder_hidden_states=None,
830
+ encoder_attention_mask=None,
831
+ params: dict = None,
832
+ dropout_rng: jax.random.PRNGKey = None,
833
+ train: bool = False,
834
+ output_attentions: Optional[bool] = None,
835
+ output_hidden_states: Optional[bool] = None,
836
+ return_dict: Optional[bool] = None,
837
+ past_key_values: dict = None,
838
+ ):
839
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
840
+ output_hidden_states = (
841
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
842
+ )
843
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
844
+
845
+ # init input tensors if not passed
846
+ if token_type_ids is None:
847
+ token_type_ids = jnp.zeros_like(input_ids)
848
+
849
+ if position_ids is None:
850
+ position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
851
+
852
+ if attention_mask is None:
853
+ attention_mask = jnp.ones_like(input_ids)
854
+
855
+ if head_mask is None:
856
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
857
+
858
+ # Handle any PRNG if needed
859
+ rngs = {}
860
+ if dropout_rng is not None:
861
+ rngs["dropout"] = dropout_rng
862
+
863
+ inputs = {"params": params or self.params}
864
+
865
+ if self.config.add_cross_attention:
866
+ # If past_key_values are passed, the cache is already initialized, and a private flag `init_cache` has to be
867
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
868
+ # changed by the FlaxRobertaPreLayerNormAttention module.
869
+ if past_key_values:
870
+ inputs["cache"] = past_key_values
871
+ mutable = ["cache"]
872
+ else:
873
+ mutable = False
874
+
875
+ outputs = self.module.apply(
876
+ inputs,
877
+ jnp.array(input_ids, dtype="i4"),
878
+ jnp.array(attention_mask, dtype="i4"),
879
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
880
+ position_ids=jnp.array(position_ids, dtype="i4"),
881
+ head_mask=jnp.array(head_mask, dtype="i4"),
882
+ encoder_hidden_states=encoder_hidden_states,
883
+ encoder_attention_mask=encoder_attention_mask,
884
+ deterministic=not train,
885
+ output_attentions=output_attentions,
886
+ output_hidden_states=output_hidden_states,
887
+ return_dict=return_dict,
888
+ rngs=rngs,
889
+ mutable=mutable,
890
+ )
891
+
892
+ # add updated cache to model output
893
+ if past_key_values is not None and return_dict:
894
+ outputs, past_key_values = outputs
895
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
896
+ return outputs
897
+ elif past_key_values is not None and not return_dict:
898
+ outputs, past_key_values = outputs
899
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
900
+
901
+ else:
902
+ outputs = self.module.apply(
903
+ inputs,
904
+ jnp.array(input_ids, dtype="i4"),
905
+ jnp.array(attention_mask, dtype="i4"),
906
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
907
+ position_ids=jnp.array(position_ids, dtype="i4"),
908
+ head_mask=jnp.array(head_mask, dtype="i4"),
909
+ deterministic=not train,
910
+ output_attentions=output_attentions,
911
+ output_hidden_states=output_hidden_states,
912
+ return_dict=return_dict,
913
+ rngs=rngs,
914
+ )
915
+
916
+ return outputs
917
+
918
+
919
+ class FlaxRobertaPreLayerNormModule(nn.Module):
920
+ config: RobertaPreLayerNormConfig
921
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
922
+ add_pooling_layer: bool = True
923
+ gradient_checkpointing: bool = False
924
+
925
+ def setup(self):
926
+ self.embeddings = FlaxRobertaPreLayerNormEmbeddings(self.config, dtype=self.dtype)
927
+ self.encoder = FlaxRobertaPreLayerNormEncoder(
928
+ self.config,
929
+ dtype=self.dtype,
930
+ gradient_checkpointing=self.gradient_checkpointing,
931
+ )
932
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
933
+ self.pooler = FlaxRobertaPreLayerNormPooler(self.config, dtype=self.dtype)
934
+
935
+ def __call__(
936
+ self,
937
+ input_ids,
938
+ attention_mask,
939
+ token_type_ids: Optional[jnp.ndarray] = None,
940
+ position_ids: Optional[jnp.ndarray] = None,
941
+ head_mask: Optional[jnp.ndarray] = None,
942
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
943
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
944
+ init_cache: bool = False,
945
+ deterministic: bool = True,
946
+ output_attentions: bool = False,
947
+ output_hidden_states: bool = False,
948
+ return_dict: bool = True,
949
+ ):
950
+ # make sure `token_type_ids` is correctly initialized when not passed
951
+ if token_type_ids is None:
952
+ token_type_ids = jnp.zeros_like(input_ids)
953
+
954
+ # make sure `position_ids` is correctly initialized when not passed
955
+ if position_ids is None:
956
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
957
+
958
+ hidden_states = self.embeddings(
959
+ input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
960
+ )
961
+ outputs = self.encoder(
962
+ hidden_states,
963
+ attention_mask,
964
+ head_mask=head_mask,
965
+ deterministic=deterministic,
966
+ encoder_hidden_states=encoder_hidden_states,
967
+ encoder_attention_mask=encoder_attention_mask,
968
+ init_cache=init_cache,
969
+ output_attentions=output_attentions,
970
+ output_hidden_states=output_hidden_states,
971
+ return_dict=return_dict,
972
+ )
973
+ hidden_states = outputs[0]
974
+ hidden_states = self.LayerNorm(hidden_states)
975
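# (Editor's note, not part of the upstream file.) Because every sub-layer in this pre-layer-norm
# architecture normalizes its own input, the encoder output is still un-normalized, so a final
# LayerNorm is applied here before pooling.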
+ pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
976
+
977
+ if not return_dict:
978
+ # if pooled is None, don't return it
979
+ if pooled is None:
980
+ return (hidden_states,) + outputs[1:]
981
+ return (hidden_states, pooled) + outputs[1:]
982
+
983
+ return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
984
+ last_hidden_state=hidden_states,
985
+ pooler_output=pooled,
986
+ hidden_states=outputs.hidden_states,
987
+ attentions=outputs.attentions,
988
+ cross_attentions=outputs.cross_attentions,
989
+ )
990
+
991
+
992
+ @add_start_docstrings(
993
+ "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.",
994
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
995
+ )
996
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaModel with Roberta->RobertaPreLayerNorm
997
+ class FlaxRobertaPreLayerNormModel(FlaxRobertaPreLayerNormPreTrainedModel):
998
+ module_class = FlaxRobertaPreLayerNormModule
999
+
1000
+
1001
+ append_call_sample_docstring(
1002
+ FlaxRobertaPreLayerNormModel,
1003
+ _CHECKPOINT_FOR_DOC,
1004
+ FlaxBaseModelOutputWithPooling,
1005
+ _CONFIG_FOR_DOC,
1006
+ )
1007
+
1008
+
1009
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLMModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
1010
+ class FlaxRobertaPreLayerNormForMaskedLMModule(nn.Module):
1011
+ config: RobertaPreLayerNormConfig
1012
+ dtype: jnp.dtype = jnp.float32
1013
+ gradient_checkpointing: bool = False
1014
+
1015
+ def setup(self):
1016
+ self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
1017
+ config=self.config,
1018
+ add_pooling_layer=False,
1019
+ dtype=self.dtype,
1020
+ gradient_checkpointing=self.gradient_checkpointing,
1021
+ )
1022
+ self.lm_head = FlaxRobertaPreLayerNormLMHead(config=self.config, dtype=self.dtype)
1023
+
1024
+ def __call__(
1025
+ self,
1026
+ input_ids,
1027
+ attention_mask,
1028
+ token_type_ids,
1029
+ position_ids,
1030
+ head_mask,
1031
+ deterministic: bool = True,
1032
+ output_attentions: bool = False,
1033
+ output_hidden_states: bool = False,
1034
+ return_dict: bool = True,
1035
+ ):
1036
+ # Model
1037
+ outputs = self.roberta_prelayernorm(
1038
+ input_ids,
1039
+ attention_mask,
1040
+ token_type_ids,
1041
+ position_ids,
1042
+ head_mask,
1043
+ deterministic=deterministic,
1044
+ output_attentions=output_attentions,
1045
+ output_hidden_states=output_hidden_states,
1046
+ return_dict=return_dict,
1047
+ )
1048
+
1049
+ hidden_states = outputs[0]
1050
+ if self.config.tie_word_embeddings:
1051
+ shared_embedding = self.roberta_prelayernorm.variables["params"]["embeddings"]["word_embeddings"][
1052
+ "embedding"
1053
+ ]
1054
+ else:
1055
+ shared_embedding = None
1056
+
1057
+ # Compute the prediction scores
1058
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
1059
+
1060
+ if not return_dict:
1061
+ return (logits,) + outputs[1:]
1062
+
1063
+ return FlaxMaskedLMOutput(
1064
+ logits=logits,
1065
+ hidden_states=outputs.hidden_states,
1066
+ attentions=outputs.attentions,
1067
+ )
1068
+
1069
+
1070
+ @add_start_docstrings(
1071
+ """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING
1072
+ )
1073
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLM with Roberta->RobertaPreLayerNorm
1074
+ class FlaxRobertaPreLayerNormForMaskedLM(FlaxRobertaPreLayerNormPreTrainedModel):
1075
+ module_class = FlaxRobertaPreLayerNormForMaskedLMModule
1076
+
1077
+
1078
+ append_call_sample_docstring(
1079
+ FlaxRobertaPreLayerNormForMaskedLM,
1080
+ _CHECKPOINT_FOR_DOC,
1081
+ FlaxBaseModelOutputWithPooling,
1082
+ _CONFIG_FOR_DOC,
1083
+ mask="<mask>",
1084
+ )
1085
+
1086
+
1087
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassificationModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
1088
+ class FlaxRobertaPreLayerNormForSequenceClassificationModule(nn.Module):
1089
+ config: RobertaPreLayerNormConfig
1090
+ dtype: jnp.dtype = jnp.float32
1091
+ gradient_checkpointing: bool = False
1092
+
1093
+ def setup(self):
1094
+ self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
1095
+ config=self.config,
1096
+ dtype=self.dtype,
1097
+ add_pooling_layer=False,
1098
+ gradient_checkpointing=self.gradient_checkpointing,
1099
+ )
1100
+ self.classifier = FlaxRobertaPreLayerNormClassificationHead(config=self.config, dtype=self.dtype)
1101
+
1102
+ def __call__(
1103
+ self,
1104
+ input_ids,
1105
+ attention_mask,
1106
+ token_type_ids,
1107
+ position_ids,
1108
+ head_mask,
1109
+ deterministic: bool = True,
1110
+ output_attentions: bool = False,
1111
+ output_hidden_states: bool = False,
1112
+ return_dict: bool = True,
1113
+ ):
1114
+ # Model
1115
+ outputs = self.roberta_prelayernorm(
1116
+ input_ids,
1117
+ attention_mask,
1118
+ token_type_ids,
1119
+ position_ids,
1120
+ head_mask,
1121
+ deterministic=deterministic,
1122
+ output_attentions=output_attentions,
1123
+ output_hidden_states=output_hidden_states,
1124
+ return_dict=return_dict,
1125
+ )
1126
+
1127
+ sequence_output = outputs[0]
1128
+ logits = self.classifier(sequence_output, deterministic=deterministic)
1129
+
1130
+ if not return_dict:
1131
+ return (logits,) + outputs[1:]
1132
+
1133
+ return FlaxSequenceClassifierOutput(
1134
+ logits=logits,
1135
+ hidden_states=outputs.hidden_states,
1136
+ attentions=outputs.attentions,
1137
+ )
1138
+
1139
+
1140
+ @add_start_docstrings(
1141
+ """
1142
+ RobertaPreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top
1143
+ of the pooled output) e.g. for GLUE tasks.
1144
+ """,
1145
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1146
+ )
1147
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassification with Roberta->RobertaPreLayerNorm
1148
+ class FlaxRobertaPreLayerNormForSequenceClassification(FlaxRobertaPreLayerNormPreTrainedModel):
1149
+ module_class = FlaxRobertaPreLayerNormForSequenceClassificationModule
1150
+
1151
+
1152
+ append_call_sample_docstring(
1153
+ FlaxRobertaPreLayerNormForSequenceClassification,
1154
+ _CHECKPOINT_FOR_DOC,
1155
+ FlaxSequenceClassifierOutput,
1156
+ _CONFIG_FOR_DOC,
1157
+ )
1158
+
1159
+
1160
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm
1161
+ class FlaxRobertaPreLayerNormForMultipleChoiceModule(nn.Module):
1162
+ config: RobertaPreLayerNormConfig
1163
+ dtype: jnp.dtype = jnp.float32
1164
+ gradient_checkpointing: bool = False
1165
+
1166
+ def setup(self):
1167
+ self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
1168
+ config=self.config,
1169
+ dtype=self.dtype,
1170
+ gradient_checkpointing=self.gradient_checkpointing,
1171
+ )
1172
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
1173
+ self.classifier = nn.Dense(1, dtype=self.dtype)
1174
+
1175
+ def __call__(
1176
+ self,
1177
+ input_ids,
1178
+ attention_mask,
1179
+ token_type_ids,
1180
+ position_ids,
1181
+ head_mask,
1182
+ deterministic: bool = True,
1183
+ output_attentions: bool = False,
1184
+ output_hidden_states: bool = False,
1185
+ return_dict: bool = True,
1186
+ ):
1187
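# (Editor's note, not part of the upstream file.) Inputs arrive as (batch_size, num_choices, seq_len);
# the choices are flattened into the batch dimension below, scored with a single Dense(1) head, and the
# logits are reshaped back to (batch_size, num_choices) before being returned.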
+ num_choices = input_ids.shape[1]
1188
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
1189
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
1190
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
1191
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
1192
+
1193
+ # Model
1194
+ outputs = self.roberta_prelayernorm(
1195
+ input_ids,
1196
+ attention_mask,
1197
+ token_type_ids,
1198
+ position_ids,
1199
+ head_mask,
1200
+ deterministic=deterministic,
1201
+ output_attentions=output_attentions,
1202
+ output_hidden_states=output_hidden_states,
1203
+ return_dict=return_dict,
1204
+ )
1205
+
1206
+ pooled_output = outputs[1]
1207
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
1208
+ logits = self.classifier(pooled_output)
1209
+
1210
+ reshaped_logits = logits.reshape(-1, num_choices)
1211
+
1212
+ if not return_dict:
1213
+ return (reshaped_logits,) + outputs[2:]
1214
+
1215
+ return FlaxMultipleChoiceModelOutput(
1216
+ logits=reshaped_logits,
1217
+ hidden_states=outputs.hidden_states,
1218
+ attentions=outputs.attentions,
1219
+ )
1220
+
1221
+
1222
+ @add_start_docstrings(
1223
+ """
1224
+ RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled
1225
+ output and a softmax) e.g. for RocStories/SWAG tasks.
1226
+ """,
1227
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1228
+ )
1229
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMultipleChoice with Roberta->RobertaPreLayerNorm
1230
+ class FlaxRobertaPreLayerNormForMultipleChoice(FlaxRobertaPreLayerNormPreTrainedModel):
1231
+ module_class = FlaxRobertaPreLayerNormForMultipleChoiceModule
1232
+
1233
+
1234
+ overwrite_call_docstring(
1235
+ FlaxRobertaPreLayerNormForMultipleChoice,
1236
+ ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"),
1237
+ )
1238
+ append_call_sample_docstring(
1239
+ FlaxRobertaPreLayerNormForMultipleChoice,
1240
+ _CHECKPOINT_FOR_DOC,
1241
+ FlaxMultipleChoiceModelOutput,
1242
+ _CONFIG_FOR_DOC,
1243
+ )
1244
+
1245
+
1246
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm
1247
+ class FlaxRobertaPreLayerNormForTokenClassificationModule(nn.Module):
1248
+ config: RobertaPreLayerNormConfig
1249
+ dtype: jnp.dtype = jnp.float32
1250
+ gradient_checkpointing: bool = False
1251
+
1252
+ def setup(self):
1253
+ self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
1254
+ config=self.config,
1255
+ dtype=self.dtype,
1256
+ add_pooling_layer=False,
1257
+ gradient_checkpointing=self.gradient_checkpointing,
1258
+ )
1259
+ classifier_dropout = (
1260
+ self.config.classifier_dropout
1261
+ if self.config.classifier_dropout is not None
1262
+ else self.config.hidden_dropout_prob
1263
+ )
1264
+ self.dropout = nn.Dropout(rate=classifier_dropout)
1265
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
1266
+
1267
+ def __call__(
1268
+ self,
1269
+ input_ids,
1270
+ attention_mask,
1271
+ token_type_ids,
1272
+ position_ids,
1273
+ head_mask,
1274
+ deterministic: bool = True,
1275
+ output_attentions: bool = False,
1276
+ output_hidden_states: bool = False,
1277
+ return_dict: bool = True,
1278
+ ):
1279
+ # Model
1280
+ outputs = self.roberta_prelayernorm(
1281
+ input_ids,
1282
+ attention_mask,
1283
+ token_type_ids,
1284
+ position_ids,
1285
+ head_mask,
1286
+ deterministic=deterministic,
1287
+ output_attentions=output_attentions,
1288
+ output_hidden_states=output_hidden_states,
1289
+ return_dict=return_dict,
1290
+ )
1291
+
1292
+ hidden_states = outputs[0]
1293
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
1294
+ logits = self.classifier(hidden_states)
1295
+
1296
+ if not return_dict:
1297
+ return (logits,) + outputs[1:]
1298
+
1299
+ return FlaxTokenClassifierOutput(
1300
+ logits=logits,
1301
+ hidden_states=outputs.hidden_states,
1302
+ attentions=outputs.attentions,
1303
+ )
1304
+
1305
+
1306
+ @add_start_docstrings(
1307
+ """
1308
+ RobertaPreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states
1309
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1310
+ """,
1311
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1312
+ )
1313
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForTokenClassification with Roberta->RobertaPreLayerNorm
1314
+ class FlaxRobertaPreLayerNormForTokenClassification(FlaxRobertaPreLayerNormPreTrainedModel):
1315
+ module_class = FlaxRobertaPreLayerNormForTokenClassificationModule
1316
+
1317
+
1318
+ append_call_sample_docstring(
1319
+ FlaxRobertaPreLayerNormForTokenClassification,
1320
+ _CHECKPOINT_FOR_DOC,
1321
+ FlaxTokenClassifierOutput,
1322
+ _CONFIG_FOR_DOC,
1323
+ )
1324
+
1325
+
1326
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->RobertaPreLayerNorm, with self.bert->self.roberta_prelayernorm
1327
+ class FlaxRobertaPreLayerNormForQuestionAnsweringModule(nn.Module):
1328
+ config: RobertaPreLayerNormConfig
1329
+ dtype: jnp.dtype = jnp.float32
1330
+ gradient_checkpointing: bool = False
1331
+
1332
+ def setup(self):
1333
+ self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
1334
+ config=self.config,
1335
+ dtype=self.dtype,
1336
+ add_pooling_layer=False,
1337
+ gradient_checkpointing=self.gradient_checkpointing,
1338
+ )
1339
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
1340
+
1341
+ def __call__(
1342
+ self,
1343
+ input_ids,
1344
+ attention_mask,
1345
+ token_type_ids,
1346
+ position_ids,
1347
+ head_mask,
1348
+ deterministic: bool = True,
1349
+ output_attentions: bool = False,
1350
+ output_hidden_states: bool = False,
1351
+ return_dict: bool = True,
1352
+ ):
1353
+ # Model
1354
+ outputs = self.roberta_prelayernorm(
1355
+ input_ids,
1356
+ attention_mask,
1357
+ token_type_ids,
1358
+ position_ids,
1359
+ head_mask,
1360
+ deterministic=deterministic,
1361
+ output_attentions=output_attentions,
1362
+ output_hidden_states=output_hidden_states,
1363
+ return_dict=return_dict,
1364
+ )
1365
+
1366
+ hidden_states = outputs[0]
1367
+
1368
+ logits = self.qa_outputs(hidden_states)
1369
+ start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
1370
+ start_logits = start_logits.squeeze(-1)
1371
+ end_logits = end_logits.squeeze(-1)
1372
+
1373
+ if not return_dict:
1374
+ return (start_logits, end_logits) + outputs[1:]
1375
+
1376
+ return FlaxQuestionAnsweringModelOutput(
1377
+ start_logits=start_logits,
1378
+ end_logits=end_logits,
1379
+ hidden_states=outputs.hidden_states,
1380
+ attentions=outputs.attentions,
1381
+ )
1382
+
1383
+
1384
+ @add_start_docstrings(
1385
+ """
1386
+ RobertaPreLayerNorm Model with a span classification head on top for extractive question-answering tasks like SQuAD
1387
+ (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1388
+ """,
1389
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1390
+ )
1391
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForQuestionAnswering with Roberta->RobertaPreLayerNorm
1392
+ class FlaxRobertaPreLayerNormForQuestionAnswering(FlaxRobertaPreLayerNormPreTrainedModel):
1393
+ module_class = FlaxRobertaPreLayerNormForQuestionAnsweringModule
1394
+
1395
+
1396
+ append_call_sample_docstring(
1397
+ FlaxRobertaPreLayerNormForQuestionAnswering,
1398
+ _CHECKPOINT_FOR_DOC,
1399
+ FlaxQuestionAnsweringModelOutput,
1400
+ _CONFIG_FOR_DOC,
1401
+ )
1402
+
1403
+
1404
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLMModule with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
1405
+ class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module):
1406
+ config: RobertaPreLayerNormConfig
1407
+ dtype: jnp.dtype = jnp.float32
1408
+ gradient_checkpointing: bool = False
1409
+
1410
+ def setup(self):
1411
+ self.roberta_prelayernorm = FlaxRobertaPreLayerNormModule(
1412
+ config=self.config,
1413
+ add_pooling_layer=False,
1414
+ dtype=self.dtype,
1415
+ gradient_checkpointing=self.gradient_checkpointing,
1416
+ )
1417
+ self.lm_head = FlaxRobertaPreLayerNormLMHead(config=self.config, dtype=self.dtype)
1418
+
1419
+ def __call__(
1420
+ self,
1421
+ input_ids,
1422
+ attention_mask,
1423
+ position_ids,
1424
+ token_type_ids: Optional[jnp.ndarray] = None,
1425
+ head_mask: Optional[jnp.ndarray] = None,
1426
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1427
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1428
+ init_cache: bool = False,
1429
+ deterministic: bool = True,
1430
+ output_attentions: bool = False,
1431
+ output_hidden_states: bool = False,
1432
+ return_dict: bool = True,
1433
+ ):
1434
+ # Model
1435
+ outputs = self.roberta_prelayernorm(
1436
+ input_ids,
1437
+ attention_mask,
1438
+ token_type_ids,
1439
+ position_ids,
1440
+ head_mask,
1441
+ encoder_hidden_states=encoder_hidden_states,
1442
+ encoder_attention_mask=encoder_attention_mask,
1443
+ init_cache=init_cache,
1444
+ deterministic=deterministic,
1445
+ output_attentions=output_attentions,
1446
+ output_hidden_states=output_hidden_states,
1447
+ return_dict=return_dict,
1448
+ )
1449
+
1450
+ hidden_states = outputs[0]
1451
+ if self.config.tie_word_embeddings:
1452
+ shared_embedding = self.roberta_prelayernorm.variables["params"]["embeddings"]["word_embeddings"][
1453
+ "embedding"
1454
+ ]
1455
+ else:
1456
+ shared_embedding = None
1457
+
1458
+ # Compute the prediction scores
1459
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
1460
+
1461
+ if not return_dict:
1462
+ return (logits,) + outputs[1:]
1463
+
1464
+ return FlaxCausalLMOutputWithCrossAttentions(
1465
+ logits=logits,
1466
+ hidden_states=outputs.hidden_states,
1467
+ attentions=outputs.attentions,
1468
+ cross_attentions=outputs.cross_attentions,
1469
+ )
1470
+
1471
+
1472
+ @add_start_docstrings(
1473
+ """
1474
+ RobertaPreLayerNorm Model with a language modeling head on top (a linear layer on top of the hidden-states output)
1475
+ e.g. for autoregressive tasks.
1476
+ """,
1477
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1478
+ )
1479
+ # Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLM with Roberta->RobertaPreLayerNorm
1480
+ class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
1481
+ module_class = FlaxRobertaPreLayerNormForCausalLMModule
1482
+
1483
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
1484
+ # initializing the cache
1485
+ batch_size, seq_length = input_ids.shape
1486
+
1487
+ past_key_values = self.init_cache(batch_size, max_length)
1488
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1489
+ # But since the decoder uses a causal mask, those positions are masked anyway.
1490
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
1491
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1492
+ if attention_mask is not None:
1493
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1494
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1495
+ else:
1496
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1497
+
1498
+ return {
1499
+ "past_key_values": past_key_values,
1500
+ "attention_mask": extended_attention_mask,
1501
+ "position_ids": position_ids,
1502
+ }
1503
+
1504
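# (Editor's note, not part of the upstream file.) `prepare_inputs_for_generation` above builds the
# cache of length `max_length` and a static attention mask once; `update_inputs_for_generation` below
# then advances `position_ids` by one after every step, so each decoding step only feeds the newly
# sampled token while earlier keys/values are read from the cache.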
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1505
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1506
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1507
+ return model_kwargs
1508
+
1509
+
1510
+ append_call_sample_docstring(
1511
+ FlaxRobertaPreLayerNormForCausalLM,
1512
+ _CHECKPOINT_FOR_DOC,
1513
+ FlaxCausalLMOutputWithCrossAttentions,
1514
+ _CONFIG_FOR_DOC,
1515
+ )
1516
+
1517
+
1518
+ __all__ = [
1519
+ "FlaxRobertaPreLayerNormForCausalLM",
1520
+ "FlaxRobertaPreLayerNormForMaskedLM",
1521
+ "FlaxRobertaPreLayerNormForMultipleChoice",
1522
+ "FlaxRobertaPreLayerNormForQuestionAnswering",
1523
+ "FlaxRobertaPreLayerNormForSequenceClassification",
1524
+ "FlaxRobertaPreLayerNormForTokenClassification",
1525
+ "FlaxRobertaPreLayerNormModel",
1526
+ "FlaxRobertaPreLayerNormPreTrainedModel",
1527
+ ]
docs/transformers/build/lib/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py ADDED
@@ -0,0 +1,1558 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RoBERTa-PreLayerNorm model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN, gelu
27
+ from ...generation import GenerationMixin
28
+ from ...modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from ...modeling_utils import PreTrainedModel
39
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from ...utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "andreasmadsen/efficient_mlm_m0.40"
53
+ _CONFIG_FOR_DOC = "RobertaPreLayerNormConfig"
54
+
55
+
56
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->RobertaPreLayerNorm
57
+ class RobertaPreLayerNormEmbeddings(nn.Module):
58
+ """
59
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
60
+ """
61
+
62
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
66
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
67
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
68
+
69
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
70
+ # any TensorFlow checkpoint file
71
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
72
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
73
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
74
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
75
+ self.register_buffer(
76
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
77
+ )
78
+ self.register_buffer(
79
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
80
+ )
81
+
82
+ # End copy
83
+ self.padding_idx = config.pad_token_id
84
+ self.position_embeddings = nn.Embedding(
85
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
86
+ )
87
+
88
+ def forward(
89
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
90
+ ):
91
+ if position_ids is None:
92
+ if input_ids is not None:
93
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
94
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
95
+ else:
96
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
97
+
98
+ if input_ids is not None:
99
+ input_shape = input_ids.size()
100
+ else:
101
+ input_shape = inputs_embeds.size()[:-1]
102
+
103
+ seq_length = input_shape[1]
104
+
105
+ # Set the token_type_ids to the registered buffer defined in the constructor, where it is all zeros. This usually
106
+ # happens when token_type_ids is auto-generated; the registered buffer helps users trace the model without passing
107
+ # token_type_ids and solves issue #5664.
108
+ if token_type_ids is None:
109
+ if hasattr(self, "token_type_ids"):
110
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
111
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
112
+ token_type_ids = buffered_token_type_ids_expanded
113
+ else:
114
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
115
+
116
+ if inputs_embeds is None:
117
+ inputs_embeds = self.word_embeddings(input_ids)
118
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
119
+
120
+ embeddings = inputs_embeds + token_type_embeddings
121
+ if self.position_embedding_type == "absolute":
122
+ position_embeddings = self.position_embeddings(position_ids)
123
+ embeddings += position_embeddings
124
+ embeddings = self.LayerNorm(embeddings)
125
+ embeddings = self.dropout(embeddings)
126
+ return embeddings
127
+
128
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
129
+ """
130
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
131
+
132
+ Args:
133
+ inputs_embeds: torch.Tensor
134
+
135
+ Returns: torch.Tensor
136
+ """
137
+ input_shape = inputs_embeds.size()[:-1]
138
+ sequence_length = input_shape[1]
139
+
140
+ position_ids = torch.arange(
141
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
142
+ )
143
+ return position_ids.unsqueeze(0).expand(input_shape)
144
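# (Editor's note, not part of the upstream file.) Position ids are offset by the padding index so that
# padding tokens keep position `padding_idx` and real tokens count up from `padding_idx + 1`. For
# example, with pad_token_id = 1 and input_ids = [[5, 6, 7, 1, 1]], create_position_ids_from_input_ids
# returns [[2, 3, 4, 1, 1]].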
+
145
+
146
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RobertaPreLayerNorm
147
+ class RobertaPreLayerNormSelfAttention(nn.Module):
148
+ def __init__(self, config, position_embedding_type=None):
149
+ super().__init__()
150
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
151
+ raise ValueError(
152
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
153
+ f"heads ({config.num_attention_heads})"
154
+ )
155
+
156
+ self.num_attention_heads = config.num_attention_heads
157
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
158
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
159
+
160
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
161
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
162
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
163
+
164
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
165
+ self.position_embedding_type = position_embedding_type or getattr(
166
+ config, "position_embedding_type", "absolute"
167
+ )
168
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
169
+ self.max_position_embeddings = config.max_position_embeddings
170
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
171
+
172
+ self.is_decoder = config.is_decoder
173
+
174
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
175
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
176
+ x = x.view(new_x_shape)
177
+ return x.permute(0, 2, 1, 3)
178
+
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.Tensor,
182
+ attention_mask: Optional[torch.FloatTensor] = None,
183
+ head_mask: Optional[torch.FloatTensor] = None,
184
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
185
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
186
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
187
+ output_attentions: Optional[bool] = False,
188
+ ) -> Tuple[torch.Tensor]:
189
+ mixed_query_layer = self.query(hidden_states)
190
+
191
+ # If this is instantiated as a cross-attention module, the keys
192
+ # and values come from an encoder; the attention mask needs to be
193
+ # such that the encoder's padding tokens are not attended to.
194
+ is_cross_attention = encoder_hidden_states is not None
195
+
196
+ if is_cross_attention and past_key_value is not None:
197
+ # reuse k,v, cross_attentions
198
+ key_layer = past_key_value[0]
199
+ value_layer = past_key_value[1]
200
+ attention_mask = encoder_attention_mask
201
+ elif is_cross_attention:
202
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
203
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
204
+ attention_mask = encoder_attention_mask
205
+ elif past_key_value is not None:
206
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
207
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
208
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
209
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
210
+ else:
211
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
212
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
213
+
214
+ query_layer = self.transpose_for_scores(mixed_query_layer)
215
+
216
+ use_cache = past_key_value is not None
217
+ if self.is_decoder:
218
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
219
+ # Further calls to cross_attention layer can then reuse all cross-attention
220
+ # key/value_states (first "if" case)
221
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
222
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
223
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
224
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
225
+ past_key_value = (key_layer, value_layer)
226
+
227
+ # Take the dot product between "query" and "key" to get the raw attention scores.
228
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
229
+
230
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
231
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
232
+ if use_cache:
233
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
234
+ -1, 1
235
+ )
236
+ else:
237
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
238
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
239
+ distance = position_ids_l - position_ids_r
240
+
241
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
242
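# (Editor's note, not part of the upstream file.) `distance` ranges over
# [-(max_position_embeddings - 1), max_position_embeddings - 1]; adding `max_position_embeddings - 1`
# shifts it into the valid index range [0, 2 * max_position_embeddings - 2] of `distance_embedding`.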
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
243
+
244
+ if self.position_embedding_type == "relative_key":
245
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
246
+ attention_scores = attention_scores + relative_position_scores
247
+ elif self.position_embedding_type == "relative_key_query":
248
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
249
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
250
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
251
+
252
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
253
+ if attention_mask is not None:
254
+ # Apply the attention mask is (precomputed for all layers in RobertaPreLayerNormModel forward() function)
255
+ attention_scores = attention_scores + attention_mask
256
+
257
+ # Normalize the attention scores to probabilities.
258
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
259
+
260
+ # This is actually dropping out entire tokens to attend to, which might
261
+ # seem a bit unusual, but is taken from the original Transformer paper.
262
+ attention_probs = self.dropout(attention_probs)
263
+
264
+ # Mask heads if we want to
265
+ if head_mask is not None:
266
+ attention_probs = attention_probs * head_mask
267
+
268
+ context_layer = torch.matmul(attention_probs, value_layer)
269
+
270
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
271
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
272
+ context_layer = context_layer.view(new_context_layer_shape)
273
+
274
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
275
+
276
+ if self.is_decoder:
277
+ outputs = outputs + (past_key_value,)
278
+ return outputs
279
+
280
+
281
+ class RobertaPreLayerNormSelfOutput(nn.Module):
282
+ def __init__(self, config):
283
+ super().__init__()
284
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
285
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
286
+
287
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
288
+ hidden_states = self.dense(hidden_states)
289
+ hidden_states = self.dropout(hidden_states)
290
+ hidden_states = hidden_states + input_tensor
291
+ return hidden_states
292
+
293
+
294
+ class RobertaPreLayerNormAttention(nn.Module):
295
+ def __init__(self, config, position_embedding_type=None):
296
+ super().__init__()
297
+ self.self = RobertaPreLayerNormSelfAttention(config, position_embedding_type=position_embedding_type)
298
+ self.output = RobertaPreLayerNormSelfOutput(config)
299
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
300
+ self.pruned_heads = set()
301
+
302
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
303
+ def prune_heads(self, heads):
304
+ if len(heads) == 0:
305
+ return
306
+ heads, index = find_pruneable_heads_and_indices(
307
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
308
+ )
309
+
310
+ # Prune linear layers
311
+ self.self.query = prune_linear_layer(self.self.query, index)
312
+ self.self.key = prune_linear_layer(self.self.key, index)
313
+ self.self.value = prune_linear_layer(self.self.value, index)
314
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
315
+
316
+ # Update hyper params and store pruned heads
317
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
318
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
319
+ self.pruned_heads = self.pruned_heads.union(heads)
320
+
321
+ def forward(
322
+ self,
323
+ hidden_states: torch.Tensor,
324
+ attention_mask: Optional[torch.FloatTensor] = None,
325
+ head_mask: Optional[torch.FloatTensor] = None,
326
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
327
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
328
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
329
+ output_attentions: Optional[bool] = False,
330
+ ) -> Tuple[torch.Tensor]:
331
+ hidden_states_pre_layer_norm = self.LayerNorm(hidden_states)
332
+ self_outputs = self.self(
333
+ hidden_states_pre_layer_norm,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
343
+ return outputs
344
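# (Editor's note, not part of the upstream file.) The attention block above captures the key difference
# from standard RoBERTa. Schematically, with LN = LayerNorm and Attn = self-attention plus output
# projection:
#     post-LN (BERT/RoBERTa):  hidden = LN(hidden + Attn(hidden))
#     pre-LN  (this model):    hidden = hidden + Attn(LN(hidden))
# RobertaPreLayerNormAttention normalizes `hidden_states` before self-attention, and
# RobertaPreLayerNormSelfOutput adds the un-normalized residual back without a LayerNorm.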
+
345
+
346
+ class RobertaPreLayerNormIntermediate(nn.Module):
347
+ def __init__(self, config):
348
+ super().__init__()
349
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
350
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
351
+ if isinstance(config.hidden_act, str):
352
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
353
+ else:
354
+ self.intermediate_act_fn = config.hidden_act
355
+
356
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
357
+ hidden_states = self.LayerNorm(hidden_states)
358
+ hidden_states = self.dense(hidden_states)
359
+ hidden_states = self.intermediate_act_fn(hidden_states)
360
+ return hidden_states
361
+
362
+
363
+ class RobertaPreLayerNormOutput(nn.Module):
364
+ def __init__(self, config):
365
+ super().__init__()
366
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
367
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
368
+
369
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
370
+ hidden_states = self.dense(hidden_states)
371
+ hidden_states = self.dropout(hidden_states)
372
+ hidden_states = hidden_states + input_tensor
373
+ return hidden_states
374
+
375
+
376
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RobertaPreLayerNorm
377
+ class RobertaPreLayerNormLayer(nn.Module):
378
+ def __init__(self, config):
379
+ super().__init__()
380
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
381
+ self.seq_len_dim = 1
382
+ self.attention = RobertaPreLayerNormAttention(config)
383
+ self.is_decoder = config.is_decoder
384
+ self.add_cross_attention = config.add_cross_attention
385
+ if self.add_cross_attention:
386
+ if not self.is_decoder:
387
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
388
+ self.crossattention = RobertaPreLayerNormAttention(config, position_embedding_type="absolute")
389
+ self.intermediate = RobertaPreLayerNormIntermediate(config)
390
+ self.output = RobertaPreLayerNormOutput(config)
391
+
392
+ def forward(
393
+ self,
394
+ hidden_states: torch.Tensor,
395
+ attention_mask: Optional[torch.FloatTensor] = None,
396
+ head_mask: Optional[torch.FloatTensor] = None,
397
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
398
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
399
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
400
+ output_attentions: Optional[bool] = False,
401
+ ) -> Tuple[torch.Tensor]:
402
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
403
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
404
+ self_attention_outputs = self.attention(
405
+ hidden_states,
406
+ attention_mask,
407
+ head_mask,
408
+ output_attentions=output_attentions,
409
+ past_key_value=self_attn_past_key_value,
410
+ )
411
+ attention_output = self_attention_outputs[0]
412
+
413
+ # if decoder, the last output is tuple of self-attn cache
414
+ if self.is_decoder:
415
+ outputs = self_attention_outputs[1:-1]
416
+ present_key_value = self_attention_outputs[-1]
417
+ else:
418
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
419
+
420
+ cross_attn_present_key_value = None
421
+ if self.is_decoder and encoder_hidden_states is not None:
422
+ if not hasattr(self, "crossattention"):
423
+ raise ValueError(
424
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
425
+ " by setting `config.add_cross_attention=True`"
426
+ )
427
+
428
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
429
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
430
+ cross_attention_outputs = self.crossattention(
431
+ attention_output,
432
+ attention_mask,
433
+ head_mask,
434
+ encoder_hidden_states,
435
+ encoder_attention_mask,
436
+ cross_attn_past_key_value,
437
+ output_attentions,
438
+ )
439
+ attention_output = cross_attention_outputs[0]
440
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
441
+
442
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
443
+ cross_attn_present_key_value = cross_attention_outputs[-1]
444
+ present_key_value = present_key_value + cross_attn_present_key_value
445
+
446
+ layer_output = apply_chunking_to_forward(
447
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
448
+ )
449
+ outputs = (layer_output,) + outputs
450
+
451
+ # if decoder, return the attn key/values as the last output
452
+ if self.is_decoder:
453
+ outputs = outputs + (present_key_value,)
454
+
455
+ return outputs
456
+
457
+ def feed_forward_chunk(self, attention_output):
458
+ intermediate_output = self.intermediate(attention_output)
459
+ layer_output = self.output(intermediate_output, attention_output)
460
+ return layer_output
461
+
462
+
463
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RobertaPreLayerNorm
464
+ class RobertaPreLayerNormEncoder(nn.Module):
465
+ def __init__(self, config):
466
+ super().__init__()
467
+ self.config = config
468
+ self.layer = nn.ModuleList([RobertaPreLayerNormLayer(config) for _ in range(config.num_hidden_layers)])
469
+ self.gradient_checkpointing = False
470
+
471
+ def forward(
472
+ self,
473
+ hidden_states: torch.Tensor,
474
+ attention_mask: Optional[torch.FloatTensor] = None,
475
+ head_mask: Optional[torch.FloatTensor] = None,
476
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
477
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
478
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
479
+ use_cache: Optional[bool] = None,
480
+ output_attentions: Optional[bool] = False,
481
+ output_hidden_states: Optional[bool] = False,
482
+ return_dict: Optional[bool] = True,
483
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
484
+ all_hidden_states = () if output_hidden_states else None
485
+ all_self_attentions = () if output_attentions else None
486
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
487
+
488
+ if self.gradient_checkpointing and self.training:
489
+ if use_cache:
490
+ logger.warning_once(
491
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
492
+ )
493
+ use_cache = False
494
+
495
+ next_decoder_cache = () if use_cache else None
496
+ for i, layer_module in enumerate(self.layer):
497
+ if output_hidden_states:
498
+ all_hidden_states = all_hidden_states + (hidden_states,)
499
+
500
+ layer_head_mask = head_mask[i] if head_mask is not None else None
501
+ past_key_value = past_key_values[i] if past_key_values is not None else None
502
+
503
+ if self.gradient_checkpointing and self.training:
504
+ layer_outputs = self._gradient_checkpointing_func(
505
+ layer_module.__call__,
506
+ hidden_states,
507
+ attention_mask,
508
+ layer_head_mask,
509
+ encoder_hidden_states,
510
+ encoder_attention_mask,
511
+ past_key_value,
512
+ output_attentions,
513
+ )
514
+ else:
515
+ layer_outputs = layer_module(
516
+ hidden_states,
517
+ attention_mask,
518
+ layer_head_mask,
519
+ encoder_hidden_states,
520
+ encoder_attention_mask,
521
+ past_key_value,
522
+ output_attentions,
523
+ )
524
+
525
+ hidden_states = layer_outputs[0]
526
+ if use_cache:
527
+ next_decoder_cache += (layer_outputs[-1],)
528
+ if output_attentions:
529
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
530
+ if self.config.add_cross_attention:
531
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
532
+
533
+ if output_hidden_states:
534
+ all_hidden_states = all_hidden_states + (hidden_states,)
535
+
536
+ if not return_dict:
537
+ return tuple(
538
+ v
539
+ for v in [
540
+ hidden_states,
541
+ next_decoder_cache,
542
+ all_hidden_states,
543
+ all_self_attentions,
544
+ all_cross_attentions,
545
+ ]
546
+ if v is not None
547
+ )
548
+ return BaseModelOutputWithPastAndCrossAttentions(
549
+ last_hidden_state=hidden_states,
550
+ past_key_values=next_decoder_cache,
551
+ hidden_states=all_hidden_states,
552
+ attentions=all_self_attentions,
553
+ cross_attentions=all_cross_attentions,
554
+ )
555
+
556
+
557
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
558
+ class RobertaPreLayerNormPooler(nn.Module):
559
+ def __init__(self, config):
560
+ super().__init__()
561
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
562
+ self.activation = nn.Tanh()
563
+
564
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
565
+ # We "pool" the model by simply taking the hidden state corresponding
566
+ # to the first token.
567
+ first_token_tensor = hidden_states[:, 0]
568
+ pooled_output = self.dense(first_token_tensor)
569
+ pooled_output = self.activation(pooled_output)
570
+ return pooled_output
571
+
572
+
573
+ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel):
574
+ """
575
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
576
+ models.
577
+ """
578
+
579
+ config_class = RobertaPreLayerNormConfig
580
+ base_model_prefix = "roberta_prelayernorm"
581
+ supports_gradient_checkpointing = True
582
+ _no_split_modules = ["RobertaPreLayerNormEmbeddings", "RobertaPreLayerNormSelfAttention"]
583
+
584
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaPreLayerNormLMHead
585
+ def _init_weights(self, module):
586
+ """Initialize the weights"""
587
+ if isinstance(module, nn.Linear):
588
+ # Slightly different from the TF version which uses truncated_normal for initialization
589
+ # cf https://github.com/pytorch/pytorch/pull/5617
590
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
591
+ if module.bias is not None:
592
+ module.bias.data.zero_()
593
+ elif isinstance(module, nn.Embedding):
594
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
595
+ if module.padding_idx is not None:
596
+ module.weight.data[module.padding_idx].zero_()
597
+ elif isinstance(module, nn.LayerNorm):
598
+ module.bias.data.zero_()
599
+ module.weight.data.fill_(1.0)
600
+ elif isinstance(module, RobertaPreLayerNormLMHead):
601
+ module.bias.data.zero_()
602
+
603
+
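+ # Note: since `supports_gradient_checkpointing = True` above, activation checkpointing can be toggled on any
+ # subclass at training time to trade extra compute for memory. A minimal sketch (assuming the
+ # "andreasmadsen/efficient_mlm_m0.40" checkpoint referenced elsewhere in this file is available):
+ #
+ #     from transformers import RobertaPreLayerNormModel
+ #
+ #     model = RobertaPreLayerNormModel.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     model.gradient_checkpointing_enable()  # encoder activations are recomputed during the backward pass
+ #     model.train()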
604
+ ROBERTA_PRELAYERNORM_START_DOCSTRING = r"""
605
+
606
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
607
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
608
+ etc.)
609
+
610
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
611
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
612
+ and behavior.
613
+
614
+ Parameters:
615
+ config ([`RobertaPreLayerNormConfig`]): Model configuration class with all the parameters of the
616
+ model. Initializing with a config file does not load the weights associated with the model, only the
617
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
618
+ """
619
+
620
+ ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING = r"""
621
+ Args:
622
+ input_ids (`torch.LongTensor` of shape `({0})`):
623
+ Indices of input sequence tokens in the vocabulary.
624
+
625
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
626
+ [`PreTrainedTokenizer.__call__`] for details.
627
+
628
+ [What are input IDs?](../glossary#input-ids)
629
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
630
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
631
+
632
+ - 1 for tokens that are **not masked**,
633
+ - 0 for tokens that are **masked**.
634
+
635
+ [What are attention masks?](../glossary#attention-mask)
636
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
637
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
638
+
639
+ - 0 corresponds to a *sentence A* token,
640
+ - 1 corresponds to a *sentence B* token.
641
+ This parameter can only be used when the model is initialized with a `type_vocab_size` >= 2. All the values
642
+ in this tensor should always be < `type_vocab_size`.
643
+
644
+ [What are token type IDs?](../glossary#token-type-ids)
645
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
646
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
647
+ config.max_position_embeddings - 1]`.
648
+
649
+ [What are position IDs?](../glossary#position-ids)
650
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
651
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
652
+
653
+ - 1 indicates the head is **not masked**,
654
+ - 0 indicates the head is **masked**.
655
+
656
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
657
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
658
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
659
+ model's internal embedding lookup matrix.
660
+ output_attentions (`bool`, *optional*):
661
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
662
+ tensors for more detail.
663
+ output_hidden_states (`bool`, *optional*):
664
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
665
+ more detail.
666
+ return_dict (`bool`, *optional*):
667
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
668
+ """
669
+
670
+
671
+ @add_start_docstrings(
672
+ "The bare RoBERTa-PreLayerNorm Model transformer outputting raw hidden-states without any specific head on top.",
673
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
674
+ )
675
+ class RobertaPreLayerNormModel(RobertaPreLayerNormPreTrainedModel):
676
+ """
677
+
678
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
679
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
680
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
681
+ Kaiser and Illia Polosukhin.
682
+
683
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
684
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
685
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
686
+
687
+ .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
688
+
689
+ """
690
+
691
+ def __init__(self, config, add_pooling_layer=True):
692
+ super().__init__(config)
693
+ self.config = config
694
+
695
+ self.embeddings = RobertaPreLayerNormEmbeddings(config)
696
+ self.encoder = RobertaPreLayerNormEncoder(config)
697
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
698
+
699
+ self.pooler = RobertaPreLayerNormPooler(config) if add_pooling_layer else None
700
+
701
+ # Initialize weights and apply final processing
702
+ self.post_init()
703
+
704
+ def get_input_embeddings(self):
705
+ return self.embeddings.word_embeddings
706
+
707
+ def set_input_embeddings(self, value):
708
+ self.embeddings.word_embeddings = value
709
+
710
+ def _prune_heads(self, heads_to_prune):
711
+ """
712
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
713
+ class PreTrainedModel
714
+ """
715
+ for layer, heads in heads_to_prune.items():
716
+ self.encoder.layer[layer].attention.prune_heads(heads)
717
+
718
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
719
+ @add_code_sample_docstrings(
720
+ checkpoint=_CHECKPOINT_FOR_DOC,
721
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
722
+ config_class=_CONFIG_FOR_DOC,
723
+ )
724
+ def forward(
725
+ self,
726
+ input_ids: Optional[torch.Tensor] = None,
727
+ attention_mask: Optional[torch.Tensor] = None,
728
+ token_type_ids: Optional[torch.Tensor] = None,
729
+ position_ids: Optional[torch.Tensor] = None,
730
+ head_mask: Optional[torch.Tensor] = None,
731
+ inputs_embeds: Optional[torch.Tensor] = None,
732
+ encoder_hidden_states: Optional[torch.Tensor] = None,
733
+ encoder_attention_mask: Optional[torch.Tensor] = None,
734
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
735
+ use_cache: Optional[bool] = None,
736
+ output_attentions: Optional[bool] = None,
737
+ output_hidden_states: Optional[bool] = None,
738
+ return_dict: Optional[bool] = None,
739
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
740
+ r"""
741
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
742
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
743
+ the model is configured as a decoder.
744
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
745
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
746
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
747
+
748
+ - 1 for tokens that are **not masked**,
749
+ - 0 for tokens that are **masked**.
750
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
751
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
752
+
753
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
754
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
755
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
756
+ use_cache (`bool`, *optional*):
757
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
758
+ `past_key_values`).
759
+ """
760
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
761
+ output_hidden_states = (
762
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
763
+ )
764
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
765
+
766
+ if self.config.is_decoder:
767
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
768
+ else:
769
+ use_cache = False
770
+
771
+ if input_ids is not None and inputs_embeds is not None:
772
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
773
+ elif input_ids is not None:
774
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
775
+ input_shape = input_ids.size()
776
+ elif inputs_embeds is not None:
777
+ input_shape = inputs_embeds.size()[:-1]
778
+ else:
779
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
780
+
781
+ batch_size, seq_length = input_shape
782
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
783
+
784
+ # past_key_values_length
785
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
786
+
787
+ if attention_mask is None:
788
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
789
+
790
+ if token_type_ids is None:
791
+ if hasattr(self.embeddings, "token_type_ids"):
792
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
793
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
794
+ token_type_ids = buffered_token_type_ids_expanded
795
+ else:
796
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
797
+
798
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
799
+ # ourselves in which case we just need to make it broadcastable to all heads.
800
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
801
+
802
+ # If a 2D or 3D attention mask is provided for the cross-attention
803
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
804
+ if self.config.is_decoder and encoder_hidden_states is not None:
805
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
806
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
807
+ if encoder_attention_mask is None:
808
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
809
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
810
+ else:
811
+ encoder_extended_attention_mask = None
812
+
813
+ # Prepare head mask if needed
814
+ # 1.0 in head_mask indicate we keep the head
815
+ # attention_probs has shape bsz x n_heads x N x N
816
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
817
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
818
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
819
+
820
+ embedding_output = self.embeddings(
821
+ input_ids=input_ids,
822
+ position_ids=position_ids,
823
+ token_type_ids=token_type_ids,
824
+ inputs_embeds=inputs_embeds,
825
+ past_key_values_length=past_key_values_length,
826
+ )
827
+ encoder_outputs = self.encoder(
828
+ embedding_output,
829
+ attention_mask=extended_attention_mask,
830
+ head_mask=head_mask,
831
+ encoder_hidden_states=encoder_hidden_states,
832
+ encoder_attention_mask=encoder_extended_attention_mask,
833
+ past_key_values=past_key_values,
834
+ use_cache=use_cache,
835
+ output_attentions=output_attentions,
836
+ output_hidden_states=output_hidden_states,
837
+ return_dict=return_dict,
838
+ )
839
+ sequence_output = encoder_outputs[0]
840
+ sequence_output = self.LayerNorm(sequence_output)
841
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
842
+
843
+ if not return_dict:
844
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
845
+
846
+ return BaseModelOutputWithPoolingAndCrossAttentions(
847
+ last_hidden_state=sequence_output,
848
+ pooler_output=pooled_output,
849
+ past_key_values=encoder_outputs.past_key_values,
850
+ hidden_states=encoder_outputs.hidden_states,
851
+ attentions=encoder_outputs.attentions,
852
+ cross_attentions=encoder_outputs.cross_attentions,
853
+ )
854
+
855
+
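+ # A minimal usage sketch for the bare model defined above (assuming the "andreasmadsen/efficient_mlm_m0.40"
+ # checkpoint used in the docstring samples is available):
+ #
+ #     from transformers import AutoTokenizer, RobertaPreLayerNormModel
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     model = RobertaPreLayerNormModel.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ #     outputs = model(**inputs)
+ #     last_hidden_state = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)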
856
+ @add_start_docstrings(
857
+ """RoBERTa-PreLayerNorm Model with a `language modeling` head on top for CLM fine-tuning.""",
858
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
859
+ )
860
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer
861
+ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel, GenerationMixin):
862
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
863
+
864
+ def __init__(self, config):
865
+ super().__init__(config)
866
+
867
+ if not config.is_decoder:
868
+ logger.warning(
869
+ "If you want to use `RobertaPreLayerNormLMHeadModel` as a standalone, add `is_decoder=True.`"
870
+ )
871
+
872
+ self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
873
+ self.lm_head = RobertaPreLayerNormLMHead(config)
874
+
875
+ # Initialize weights and apply final processing
876
+ self.post_init()
877
+
878
+ def get_output_embeddings(self):
879
+ return self.lm_head.decoder
880
+
881
+ def set_output_embeddings(self, new_embeddings):
882
+ self.lm_head.decoder = new_embeddings
883
+
884
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
885
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
886
+ def forward(
887
+ self,
888
+ input_ids: Optional[torch.LongTensor] = None,
889
+ attention_mask: Optional[torch.FloatTensor] = None,
890
+ token_type_ids: Optional[torch.LongTensor] = None,
891
+ position_ids: Optional[torch.LongTensor] = None,
892
+ head_mask: Optional[torch.FloatTensor] = None,
893
+ inputs_embeds: Optional[torch.FloatTensor] = None,
894
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
895
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
896
+ labels: Optional[torch.LongTensor] = None,
897
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
898
+ use_cache: Optional[bool] = None,
899
+ output_attentions: Optional[bool] = None,
900
+ output_hidden_states: Optional[bool] = None,
901
+ return_dict: Optional[bool] = None,
902
+ **kwargs,
903
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
904
+ r"""
905
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
906
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
907
+ the model is configured as a decoder.
908
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
910
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
911
+
912
+ - 1 for tokens that are **not masked**,
913
+ - 0 for tokens that are **masked**.
914
+
915
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
916
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
917
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
918
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
919
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
920
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
921
+
922
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
923
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
924
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
925
+ use_cache (`bool`, *optional*):
926
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
927
+ `past_key_values`).
928
+
929
+ Returns:
930
+
931
+ Example:
932
+
933
+ ```python
934
+ >>> from transformers import AutoTokenizer, RobertaPreLayerNormForCausalLM, AutoConfig
935
+ >>> import torch
936
+
937
+ >>> tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
938
+ >>> config = AutoConfig.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
939
+ >>> config.is_decoder = True
940
+ >>> model = RobertaPreLayerNormForCausalLM.from_pretrained("andreasmadsen/efficient_mlm_m0.40", config=config)
941
+
942
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
943
+ >>> outputs = model(**inputs)
944
+
945
+ >>> prediction_logits = outputs.logits
946
+ ```"""
947
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
+ if labels is not None:
949
+ use_cache = False
950
+
951
+ outputs = self.roberta_prelayernorm(
952
+ input_ids,
953
+ attention_mask=attention_mask,
954
+ token_type_ids=token_type_ids,
955
+ position_ids=position_ids,
956
+ head_mask=head_mask,
957
+ inputs_embeds=inputs_embeds,
958
+ encoder_hidden_states=encoder_hidden_states,
959
+ encoder_attention_mask=encoder_attention_mask,
960
+ past_key_values=past_key_values,
961
+ use_cache=use_cache,
962
+ output_attentions=output_attentions,
963
+ output_hidden_states=output_hidden_states,
964
+ return_dict=return_dict,
965
+ )
966
+
967
+ sequence_output = outputs[0]
968
+ prediction_scores = self.lm_head(sequence_output)
969
+
970
+ lm_loss = None
971
+ if labels is not None:
972
+ # move labels to correct device to enable model parallelism
973
+ labels = labels.to(prediction_scores.device)
974
+ lm_loss = self.loss_function(
975
+ prediction_scores,
976
+ labels,
977
+ vocab_size=self.config.vocab_size,
978
+ **kwargs,
979
+ )
980
+
981
+ if not return_dict:
982
+ output = (prediction_scores,) + outputs[2:]
983
+ return ((lm_loss,) + output) if lm_loss is not None else output
984
+
985
+ return CausalLMOutputWithCrossAttentions(
986
+ loss=lm_loss,
987
+ logits=prediction_scores,
988
+ past_key_values=outputs.past_key_values,
989
+ hidden_states=outputs.hidden_states,
990
+ attentions=outputs.attentions,
991
+ cross_attentions=outputs.cross_attentions,
992
+ )
993
+
994
+ def _reorder_cache(self, past_key_values, beam_idx):
995
+ reordered_past = ()
996
+ for layer_past in past_key_values:
997
+ reordered_past += (
998
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
999
+ )
1000
+ return reordered_past
1001
+
1002
+
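+ # `_reorder_cache` above is invoked by `generate()` during beam search: `beam_idx` maps each surviving beam to
+ # the batch row it was expanded from, so every cached key/value tensor is gathered along dim 0 to stay aligned
+ # with the reordered beams.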
1003
+ @add_start_docstrings(
1004
+ """RoBERTa-PreLayerNorm Model with a `language modeling` head on top.""", ROBERTA_PRELAYERNORM_START_DOCSTRING
1005
+ )
1006
+ class RobertaPreLayerNormForMaskedLM(RobertaPreLayerNormPreTrainedModel):
1007
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
1008
+
1009
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
1010
+ def __init__(self, config):
1011
+ super().__init__(config)
1012
+
1013
+ if config.is_decoder:
1014
+ logger.warning(
1015
+ "If you want to use `RobertaPreLayerNormForMaskedLM` make sure `config.is_decoder=False` for "
1016
+ "bi-directional self-attention."
1017
+ )
1018
+
1019
+ self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
1020
+ self.lm_head = RobertaPreLayerNormLMHead(config)
1021
+
1022
+ # Initialize weights and apply final processing
1023
+ self.post_init()
1024
+
1025
+ def get_output_embeddings(self):
1026
+ return self.lm_head.decoder
1027
+
1028
+ def set_output_embeddings(self, new_embeddings):
1029
+ self.lm_head.decoder = new_embeddings
1030
+
1031
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1032
+ @add_code_sample_docstrings(
1033
+ checkpoint=_CHECKPOINT_FOR_DOC,
1034
+ output_type=MaskedLMOutput,
1035
+ config_class=_CONFIG_FOR_DOC,
1036
+ mask="<mask>",
1037
+ expected_output="' Paris'",
1038
+ expected_loss=0.69,
1039
+ )
1040
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.forward with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
1041
+ def forward(
1042
+ self,
1043
+ input_ids: Optional[torch.LongTensor] = None,
1044
+ attention_mask: Optional[torch.FloatTensor] = None,
1045
+ token_type_ids: Optional[torch.LongTensor] = None,
1046
+ position_ids: Optional[torch.LongTensor] = None,
1047
+ head_mask: Optional[torch.FloatTensor] = None,
1048
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1049
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1050
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1051
+ labels: Optional[torch.LongTensor] = None,
1052
+ output_attentions: Optional[bool] = None,
1053
+ output_hidden_states: Optional[bool] = None,
1054
+ return_dict: Optional[bool] = None,
1055
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1056
+ r"""
1057
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1058
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1059
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1060
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1061
+ kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
1062
+ Used to hide legacy arguments that have been deprecated.
1063
+ """
1064
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1065
+
1066
+ outputs = self.roberta_prelayernorm(
1067
+ input_ids,
1068
+ attention_mask=attention_mask,
1069
+ token_type_ids=token_type_ids,
1070
+ position_ids=position_ids,
1071
+ head_mask=head_mask,
1072
+ inputs_embeds=inputs_embeds,
1073
+ encoder_hidden_states=encoder_hidden_states,
1074
+ encoder_attention_mask=encoder_attention_mask,
1075
+ output_attentions=output_attentions,
1076
+ output_hidden_states=output_hidden_states,
1077
+ return_dict=return_dict,
1078
+ )
1079
+ sequence_output = outputs[0]
1080
+ prediction_scores = self.lm_head(sequence_output)
1081
+
1082
+ masked_lm_loss = None
1083
+ if labels is not None:
1084
+ # move labels to correct device to enable model parallelism
1085
+ labels = labels.to(prediction_scores.device)
1086
+ loss_fct = CrossEntropyLoss()
1087
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1088
+
1089
+ if not return_dict:
1090
+ output = (prediction_scores,) + outputs[2:]
1091
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1092
+
1093
+ return MaskedLMOutput(
1094
+ loss=masked_lm_loss,
1095
+ logits=prediction_scores,
1096
+ hidden_states=outputs.hidden_states,
1097
+ attentions=outputs.attentions,
1098
+ )
1099
+
1100
+
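+ # A hedged sketch of masked-token prediction with the class above, matching the "<mask>" token and the
+ # " Paris" / ~0.69 loss quoted in the doc sample decorator (the exact prompt below is an assumption):
+ #
+ #     import torch
+ #     from transformers import AutoTokenizer, RobertaPreLayerNormForMaskedLM
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     model = RobertaPreLayerNormForMaskedLM.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
+ #     with torch.no_grad():
+ #         logits = model(**inputs).logits
+ #     mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+ #     print(tokenizer.decode(logits[0, mask_index].argmax(dim=-1)))  # expected: " Paris"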
1101
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->RobertaPreLayerNorm
1102
+ class RobertaPreLayerNormLMHead(nn.Module):
1103
+ """RobertaPreLayerNorm Head for masked language modeling."""
1104
+
1105
+ def __init__(self, config):
1106
+ super().__init__()
1107
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1108
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1109
+
1110
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
1111
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1112
+ self.decoder.bias = self.bias
1113
+
1114
+ def forward(self, features, **kwargs):
1115
+ x = self.dense(features)
1116
+ x = gelu(x)
1117
+ x = self.layer_norm(x)
1118
+
1119
+ # project back to size of vocabulary with bias
1120
+ x = self.decoder(x)
1121
+
1122
+ return x
1123
+
1124
+ def _tie_weights(self):
1125
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
1126
+ # For accelerate compatibility and to not break backward compatibility
1127
+ if self.decoder.bias.device.type == "meta":
1128
+ self.decoder.bias = self.bias
1129
+ else:
1130
+ self.bias = self.decoder.bias
1131
+
1132
+
1133
+ @add_start_docstrings(
1134
+ """
1135
+ RoBERTa-PreLayerNorm Model transformer with a sequence classification/regression head on top (a linear layer on top
1136
+ of the pooled output) e.g. for GLUE tasks.
1137
+ """,
1138
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1139
+ )
1140
+ class RobertaPreLayerNormForSequenceClassification(RobertaPreLayerNormPreTrainedModel):
1141
+ def __init__(self, config):
1142
+ super().__init__(config)
1143
+ self.num_labels = config.num_labels
1144
+ self.config = config
1145
+
1146
+ self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
1147
+ self.classifier = RobertaPreLayerNormClassificationHead(config)
1148
+
1149
+ # Initialize weights and apply final processing
1150
+ self.post_init()
1151
+
1152
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1153
+ @add_code_sample_docstrings(
1154
+ checkpoint=_CHECKPOINT_FOR_DOC,
1155
+ output_type=SequenceClassifierOutput,
1156
+ config_class=_CONFIG_FOR_DOC,
1157
+ )
1158
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.forward with roberta->roberta_prelayernorm
1159
+ def forward(
1160
+ self,
1161
+ input_ids: Optional[torch.LongTensor] = None,
1162
+ attention_mask: Optional[torch.FloatTensor] = None,
1163
+ token_type_ids: Optional[torch.LongTensor] = None,
1164
+ position_ids: Optional[torch.LongTensor] = None,
1165
+ head_mask: Optional[torch.FloatTensor] = None,
1166
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1167
+ labels: Optional[torch.LongTensor] = None,
1168
+ output_attentions: Optional[bool] = None,
1169
+ output_hidden_states: Optional[bool] = None,
1170
+ return_dict: Optional[bool] = None,
1171
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1172
+ r"""
1173
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1174
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1175
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1176
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1177
+ """
1178
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1179
+
1180
+ outputs = self.roberta_prelayernorm(
1181
+ input_ids,
1182
+ attention_mask=attention_mask,
1183
+ token_type_ids=token_type_ids,
1184
+ position_ids=position_ids,
1185
+ head_mask=head_mask,
1186
+ inputs_embeds=inputs_embeds,
1187
+ output_attentions=output_attentions,
1188
+ output_hidden_states=output_hidden_states,
1189
+ return_dict=return_dict,
1190
+ )
1191
+ sequence_output = outputs[0]
1192
+ logits = self.classifier(sequence_output)
1193
+
1194
+ loss = None
1195
+ if labels is not None:
1196
+ # move labels to correct device to enable model parallelism
1197
+ labels = labels.to(logits.device)
1198
+ if self.config.problem_type is None:
1199
+ if self.num_labels == 1:
1200
+ self.config.problem_type = "regression"
1201
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1202
+ self.config.problem_type = "single_label_classification"
1203
+ else:
1204
+ self.config.problem_type = "multi_label_classification"
1205
+
1206
+ if self.config.problem_type == "regression":
1207
+ loss_fct = MSELoss()
1208
+ if self.num_labels == 1:
1209
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1210
+ else:
1211
+ loss = loss_fct(logits, labels)
1212
+ elif self.config.problem_type == "single_label_classification":
1213
+ loss_fct = CrossEntropyLoss()
1214
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1215
+ elif self.config.problem_type == "multi_label_classification":
1216
+ loss_fct = BCEWithLogitsLoss()
1217
+ loss = loss_fct(logits, labels)
1218
+
1219
+ if not return_dict:
1220
+ output = (logits,) + outputs[2:]
1221
+ return ((loss,) + output) if loss is not None else output
1222
+
1223
+ return SequenceClassifierOutput(
1224
+ loss=loss,
1225
+ logits=logits,
1226
+ hidden_states=outputs.hidden_states,
1227
+ attentions=outputs.attentions,
1228
+ )
1229
+
1230
+
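+ # The loss above is selected from `config.problem_type`, inferred on the first call when unset:
+ # `num_labels == 1` -> "regression" (MSELoss); integer labels with `num_labels > 1` ->
+ # "single_label_classification" (CrossEntropyLoss); otherwise -> "multi_label_classification"
+ # (BCEWithLogitsLoss). A hedged sketch of forcing the multi-label branch (the label values are illustrative):
+ #
+ #     import torch
+ #     from transformers import RobertaPreLayerNormForSequenceClassification
+ #
+ #     model = RobertaPreLayerNormForSequenceClassification.from_pretrained(
+ #         "andreasmadsen/efficient_mlm_m0.40", num_labels=3, problem_type="multi_label_classification"
+ #     )
+ #     labels = torch.tensor([[1.0, 0.0, 1.0]])  # float multi-hot labels, one row per example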
1231
+ @add_start_docstrings(
1232
+ """
1233
+ RobertaPreLayerNorm Model with a multiple choice classification head on top (a linear layer on top of the pooled
1234
+ output and a softmax) e.g. for RocStories/SWAG tasks.
1235
+ """,
1236
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1237
+ )
1238
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
1239
+ class RobertaPreLayerNormForMultipleChoice(RobertaPreLayerNormPreTrainedModel):
1240
+ def __init__(self, config):
1241
+ super().__init__(config)
1242
+
1243
+ self.roberta_prelayernorm = RobertaPreLayerNormModel(config)
1244
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1245
+ self.classifier = nn.Linear(config.hidden_size, 1)
1246
+
1247
+ # Initialize weights and apply final processing
1248
+ self.post_init()
1249
+
1250
+ @add_start_docstrings_to_model_forward(
1251
+ ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1252
+ )
1253
+ @add_code_sample_docstrings(
1254
+ checkpoint=_CHECKPOINT_FOR_DOC,
1255
+ output_type=MultipleChoiceModelOutput,
1256
+ config_class=_CONFIG_FOR_DOC,
1257
+ )
1258
+ def forward(
1259
+ self,
1260
+ input_ids: Optional[torch.LongTensor] = None,
1261
+ token_type_ids: Optional[torch.LongTensor] = None,
1262
+ attention_mask: Optional[torch.FloatTensor] = None,
1263
+ labels: Optional[torch.LongTensor] = None,
1264
+ position_ids: Optional[torch.LongTensor] = None,
1265
+ head_mask: Optional[torch.FloatTensor] = None,
1266
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1267
+ output_attentions: Optional[bool] = None,
1268
+ output_hidden_states: Optional[bool] = None,
1269
+ return_dict: Optional[bool] = None,
1270
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1271
+ r"""
1272
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1273
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1274
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1275
+ `input_ids` above)
1276
+ """
1277
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1278
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1279
+
1280
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1281
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1282
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1283
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1284
+ flat_inputs_embeds = (
1285
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1286
+ if inputs_embeds is not None
1287
+ else None
1288
+ )
1289
+
1290
+ outputs = self.roberta_prelayernorm(
1291
+ flat_input_ids,
1292
+ position_ids=flat_position_ids,
1293
+ token_type_ids=flat_token_type_ids,
1294
+ attention_mask=flat_attention_mask,
1295
+ head_mask=head_mask,
1296
+ inputs_embeds=flat_inputs_embeds,
1297
+ output_attentions=output_attentions,
1298
+ output_hidden_states=output_hidden_states,
1299
+ return_dict=return_dict,
1300
+ )
1301
+ pooled_output = outputs[1]
1302
+
1303
+ pooled_output = self.dropout(pooled_output)
1304
+ logits = self.classifier(pooled_output)
1305
+ reshaped_logits = logits.view(-1, num_choices)
1306
+
1307
+ loss = None
1308
+ if labels is not None:
1309
+ # move labels to correct device to enable model parallelism
1310
+ labels = labels.to(reshaped_logits.device)
1311
+ loss_fct = CrossEntropyLoss()
1312
+ loss = loss_fct(reshaped_logits, labels)
1313
+
1314
+ if not return_dict:
1315
+ output = (reshaped_logits,) + outputs[2:]
1316
+ return ((loss,) + output) if loss is not None else output
1317
+
1318
+ return MultipleChoiceModelOutput(
1319
+ loss=loss,
1320
+ logits=reshaped_logits,
1321
+ hidden_states=outputs.hidden_states,
1322
+ attentions=outputs.attentions,
1323
+ )
1324
+
1325
+
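+ # Shape bookkeeping for the multiple-choice head above: inputs arrive as (batch_size, num_choices, seq_length),
+ # are flattened to (batch_size * num_choices, seq_length) before the encoder, and the per-choice scores are
+ # reshaped back to (batch_size, num_choices) for the cross-entropy over choices. Illustrative shapes only:
+ #
+ #     # input_ids: (2, 4, 16) -> flat_input_ids: (8, 16) -> logits: (8, 1) -> reshaped_logits: (2, 4)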
1326
+ @add_start_docstrings(
1327
+ """
1328
+ RobertaPreLayerNorm Model with a token classification head on top (a linear layer on top of the hidden-states
1329
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1330
+ """,
1331
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1332
+ )
1333
+ class RobertaPreLayerNormForTokenClassification(RobertaPreLayerNormPreTrainedModel):
1334
+ def __init__(self, config):
1335
+ super().__init__(config)
1336
+ self.num_labels = config.num_labels
1337
+
1338
+ self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
1339
+ classifier_dropout = (
1340
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1341
+ )
1342
+ self.dropout = nn.Dropout(classifier_dropout)
1343
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1344
+
1345
+ # Initialize weights and apply final processing
1346
+ self.post_init()
1347
+
1348
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1349
+ @add_code_sample_docstrings(
1350
+ checkpoint=_CHECKPOINT_FOR_DOC,
1351
+ output_type=TokenClassifierOutput,
1352
+ config_class=_CONFIG_FOR_DOC,
1353
+ )
1354
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.forward with roberta->roberta_prelayernorm
1355
+ def forward(
1356
+ self,
1357
+ input_ids: Optional[torch.LongTensor] = None,
1358
+ attention_mask: Optional[torch.FloatTensor] = None,
1359
+ token_type_ids: Optional[torch.LongTensor] = None,
1360
+ position_ids: Optional[torch.LongTensor] = None,
1361
+ head_mask: Optional[torch.FloatTensor] = None,
1362
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1363
+ labels: Optional[torch.LongTensor] = None,
1364
+ output_attentions: Optional[bool] = None,
1365
+ output_hidden_states: Optional[bool] = None,
1366
+ return_dict: Optional[bool] = None,
1367
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1368
+ r"""
1369
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1370
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1371
+ """
1372
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1373
+
1374
+ outputs = self.roberta_prelayernorm(
1375
+ input_ids,
1376
+ attention_mask=attention_mask,
1377
+ token_type_ids=token_type_ids,
1378
+ position_ids=position_ids,
1379
+ head_mask=head_mask,
1380
+ inputs_embeds=inputs_embeds,
1381
+ output_attentions=output_attentions,
1382
+ output_hidden_states=output_hidden_states,
1383
+ return_dict=return_dict,
1384
+ )
1385
+
1386
+ sequence_output = outputs[0]
1387
+
1388
+ sequence_output = self.dropout(sequence_output)
1389
+ logits = self.classifier(sequence_output)
1390
+
1391
+ loss = None
1392
+ if labels is not None:
1393
+ # move labels to correct device to enable model parallelism
1394
+ labels = labels.to(logits.device)
1395
+ loss_fct = CrossEntropyLoss()
1396
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1397
+
1398
+ if not return_dict:
1399
+ output = (logits,) + outputs[2:]
1400
+ return ((loss,) + output) if loss is not None else output
1401
+
1402
+ return TokenClassifierOutput(
1403
+ loss=loss,
1404
+ logits=logits,
1405
+ hidden_states=outputs.hidden_states,
1406
+ attentions=outputs.attentions,
1407
+ )
1408
+
1409
+
1410
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->RobertaPreLayerNorm
1411
+ class RobertaPreLayerNormClassificationHead(nn.Module):
1412
+ """Head for sentence-level classification tasks."""
1413
+
1414
+ def __init__(self, config):
1415
+ super().__init__()
1416
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1417
+ classifier_dropout = (
1418
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1419
+ )
1420
+ self.dropout = nn.Dropout(classifier_dropout)
1421
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1422
+
1423
+ def forward(self, features, **kwargs):
1424
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1425
+ x = self.dropout(x)
1426
+ x = self.dense(x)
1427
+ x = torch.tanh(x)
1428
+ x = self.dropout(x)
1429
+ x = self.out_proj(x)
1430
+ return x
1431
+
1432
+
1433
+ @add_start_docstrings(
1434
+ """
1435
+ RobertaPreLayerNorm Model with a span classification head on top for extractive question-answering tasks like SQuAD
1436
+ (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1437
+ """,
1438
+ ROBERTA_PRELAYERNORM_START_DOCSTRING,
1439
+ )
1440
+ class RobertaPreLayerNormForQuestionAnswering(RobertaPreLayerNormPreTrainedModel):
1441
+ def __init__(self, config):
1442
+ super().__init__(config)
1443
+ self.num_labels = config.num_labels
1444
+
1445
+ self.roberta_prelayernorm = RobertaPreLayerNormModel(config, add_pooling_layer=False)
1446
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1447
+
1448
+ # Initialize weights and apply final processing
1449
+ self.post_init()
1450
+
1451
+ @add_start_docstrings_to_model_forward(ROBERTA_PRELAYERNORM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1452
+ @add_code_sample_docstrings(
1453
+ checkpoint=_CHECKPOINT_FOR_DOC,
1454
+ output_type=QuestionAnsweringModelOutput,
1455
+ config_class=_CONFIG_FOR_DOC,
1456
+ )
1457
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.forward with roberta->roberta_prelayernorm
1458
+ def forward(
1459
+ self,
1460
+ input_ids: Optional[torch.LongTensor] = None,
1461
+ attention_mask: Optional[torch.FloatTensor] = None,
1462
+ token_type_ids: Optional[torch.LongTensor] = None,
1463
+ position_ids: Optional[torch.LongTensor] = None,
1464
+ head_mask: Optional[torch.FloatTensor] = None,
1465
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1466
+ start_positions: Optional[torch.LongTensor] = None,
1467
+ end_positions: Optional[torch.LongTensor] = None,
1468
+ output_attentions: Optional[bool] = None,
1469
+ output_hidden_states: Optional[bool] = None,
1470
+ return_dict: Optional[bool] = None,
1471
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1472
+ r"""
1473
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1474
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1475
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1476
+ are not taken into account for computing the loss.
1477
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1478
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1479
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1480
+ are not taken into account for computing the loss.
1481
+ """
1482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1483
+
1484
+ outputs = self.roberta_prelayernorm(
1485
+ input_ids,
1486
+ attention_mask=attention_mask,
1487
+ token_type_ids=token_type_ids,
1488
+ position_ids=position_ids,
1489
+ head_mask=head_mask,
1490
+ inputs_embeds=inputs_embeds,
1491
+ output_attentions=output_attentions,
1492
+ output_hidden_states=output_hidden_states,
1493
+ return_dict=return_dict,
1494
+ )
1495
+
1496
+ sequence_output = outputs[0]
1497
+
1498
+ logits = self.qa_outputs(sequence_output)
1499
+ start_logits, end_logits = logits.split(1, dim=-1)
1500
+ start_logits = start_logits.squeeze(-1).contiguous()
1501
+ end_logits = end_logits.squeeze(-1).contiguous()
1502
+
1503
+ total_loss = None
1504
+ if start_positions is not None and end_positions is not None:
1505
+ # If we are on multi-GPU, the split adds an extra dimension, so squeeze it away
1506
+ if len(start_positions.size()) > 1:
1507
+ start_positions = start_positions.squeeze(-1)
1508
+ if len(end_positions.size()) > 1:
1509
+ end_positions = end_positions.squeeze(-1)
1510
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1511
+ ignored_index = start_logits.size(1)
1512
+ start_positions = start_positions.clamp(0, ignored_index)
1513
+ end_positions = end_positions.clamp(0, ignored_index)
1514
+
1515
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1516
+ start_loss = loss_fct(start_logits, start_positions)
1517
+ end_loss = loss_fct(end_logits, end_positions)
1518
+ total_loss = (start_loss + end_loss) / 2
1519
+
1520
+ if not return_dict:
1521
+ output = (start_logits, end_logits) + outputs[2:]
1522
+ return ((total_loss,) + output) if total_loss is not None else output
1523
+
1524
+ return QuestionAnsweringModelOutput(
1525
+ loss=total_loss,
1526
+ start_logits=start_logits,
1527
+ end_logits=end_logits,
1528
+ hidden_states=outputs.hidden_states,
1529
+ attentions=outputs.attentions,
1530
+ )
1531
+
1532
+
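+ # A hedged sketch of decoding an answer span from the start/end logits produced above (greedy argmax; the
+ # question/context strings are illustrative and the QA head of this checkpoint is randomly initialized):
+ #
+ #     import torch
+ #     from transformers import AutoTokenizer, RobertaPreLayerNormForQuestionAnswering
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     model = RobertaPreLayerNormForQuestionAnswering.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ #     inputs = tokenizer("Who wrote it?", "The paper was written by Vaswani et al.", return_tensors="pt")
+ #     with torch.no_grad():
+ #         outputs = model(**inputs)
+ #     start = int(outputs.start_logits.argmax())
+ #     end = int(outputs.end_logits.argmax())
+ #     answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])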
1533
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
1534
+ """
1535
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
1536
+ are ignored. This is modified from fairseq's `utils.make_positions`.
1537
+
1538
+ Args:
1539
+ input_ids: torch.Tensor
1540
+
1541
+ Returns: torch.Tensor
1542
+ """
1543
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
1544
+ mask = input_ids.ne(padding_idx).int()
1545
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
1546
+ return incremental_indices.long() + padding_idx
1547
+
1548
+
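+ # Worked example for the helper above (assuming RoBERTa's usual padding_idx of 1 and no cached past):
+ #   input_ids = [[5, 7, 9, 1, 1]] -> mask = [[1, 1, 1, 0, 0]]
+ #   cumsum(mask) * mask = [[1, 2, 3, 0, 0]] -> + padding_idx = [[2, 3, 4, 1, 1]]
+ # i.e. real tokens get positions 2, 3, 4 while padding positions stay at padding_idx.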
1549
+ __all__ = [
1550
+ "RobertaPreLayerNormForCausalLM",
1551
+ "RobertaPreLayerNormForMaskedLM",
1552
+ "RobertaPreLayerNormForMultipleChoice",
1553
+ "RobertaPreLayerNormForQuestionAnswering",
1554
+ "RobertaPreLayerNormForSequenceClassification",
1555
+ "RobertaPreLayerNormForTokenClassification",
1556
+ "RobertaPreLayerNormModel",
1557
+ "RobertaPreLayerNormPreTrainedModel",
1558
+ ]
docs/transformers/build/lib/transformers/models/roc_bert/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_roc_bert import *
22
+ from .modeling_roc_bert import *
23
+ from .tokenization_roc_bert import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
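+ # The `_LazyModule` indirection above defers the heavy submodule imports until an attribute is actually
+ # accessed, so e.g. `from transformers.models.roc_bert import RoCBertConfig` only imports
+ # `configuration_roc_bert` at that point.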
docs/transformers/build/lib/transformers/models/roc_bert/configuration_roc_bert.py ADDED
@@ -0,0 +1,163 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 WeChatAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RoCBert model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class RoCBertConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`RoCBertModel`]. It is used to instantiate a
27
+ RoCBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
28
+ with the defaults will yield a similar configuration to that of the RoCBert
29
+ [weiweishi/roc-bert-base-zh](https://huggingface.co/weiweishi/roc-bert-base-zh) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 30522):
37
+ Vocabulary size of the RoCBert model. Defines the number of different tokens that can be represented by the
38
+ `input_ids` passed when calling [`RoCBertModel`].
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimension of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
53
+ The dropout ratio for the attention probabilities.
54
+ max_position_embeddings (`int`, *optional*, defaults to 512):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 2048).
57
+ type_vocab_size (`int`, *optional*, defaults to 2):
58
+ The vocabulary size of the `token_type_ids` passed when calling [`RoCBertModel`].
59
+ initializer_range (`float`, *optional*, defaults to 0.02):
60
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
62
+ The epsilon used by the layer normalization layers.
63
+ is_decoder (`bool`, *optional*, defaults to `False`):
64
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
69
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
70
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
71
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
72
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
73
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
74
+ classifier_dropout (`float`, *optional*):
75
+ The dropout ratio for the classification head.
76
+ enable_pronunciation (`bool`, *optional*, defaults to `True`):
77
+ Whether or not the model uses the pronunciation embedding during training.
78
+ enable_shape (`bool`, *optional*, defaults to `True`):
79
+ Whether or not the model uses the shape embedding during training.
80
+ pronunciation_embed_dim (`int`, *optional*, defaults to 768):
81
+ Dimension of the pronunciation_embed.
82
+ pronunciation_vocab_size (`int`, *optional*, defaults to 910):
83
+ Pronunciation Vocabulary size of the RoCBert model. Defines the number of different tokens that can be
84
+ represented by the `input_pronunciation_ids` passed when calling [`RoCBertModel`].
85
+ shape_embed_dim (`int`, *optional*, defaults to 512):
86
+ Dimension of the shape_embed.
87
+ shape_vocab_size (`int`, *optional*, defaults to 24858):
88
+ Shape Vocabulary size of the RoCBert model. Defines the number of different tokens that can be represented
89
+ by the `input_shape_ids` passed when calling [`RoCBertModel`].
90
+ concat_input (`bool`, *optional*, defaults to `True`):
91
+ Defines the way of merging the shape_embed, pronunciation_embed and word_embed, if the value is true,
92
+ output_embed = torch.cat((word_embed, shape_embed, pronunciation_embed), -1), else output_embed =
93
+ (word_embed + shape_embed + pronunciation_embed) / 3
94
+ Example:
95
+
96
+ ```python
97
+ >>> from transformers import RoCBertModel, RoCBertConfig
98
+
99
+ >>> # Initializing a RoCBert weiweishi/roc-bert-base-zh style configuration
100
+ >>> configuration = RoCBertConfig()
101
+
102
+ >>> # Initializing a model from the weiweishi/roc-bert-base-zh style configuration
103
+ >>> model = RoCBertModel(configuration)
104
+
105
+ >>> # Accessing the model configuration
106
+ >>> configuration = model.config
107
+ ```"""
108
+
109
+ model_type = "roc_bert"
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=30522,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ intermediate_size=3072,
118
+ hidden_act="gelu",
119
+ hidden_dropout_prob=0.1,
120
+ attention_probs_dropout_prob=0.1,
121
+ max_position_embeddings=512,
122
+ type_vocab_size=2,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-12,
125
+ use_cache=True,
126
+ pad_token_id=0,
127
+ position_embedding_type="absolute",
128
+ classifier_dropout=None,
129
+ enable_pronunciation=True,
130
+ enable_shape=True,
131
+ pronunciation_embed_dim=768,
132
+ pronunciation_vocab_size=910,
133
+ shape_embed_dim=512,
134
+ shape_vocab_size=24858,
135
+ concat_input=True,
136
+ **kwargs,
137
+ ):
138
+ self.vocab_size = vocab_size
139
+ self.max_position_embeddings = max_position_embeddings
140
+ self.hidden_size = hidden_size
141
+ self.num_hidden_layers = num_hidden_layers
142
+ self.num_attention_heads = num_attention_heads
143
+ self.intermediate_size = intermediate_size
144
+ self.hidden_act = hidden_act
145
+ self.hidden_dropout_prob = hidden_dropout_prob
146
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
147
+ self.initializer_range = initializer_range
148
+ self.type_vocab_size = type_vocab_size
149
+ self.layer_norm_eps = layer_norm_eps
150
+ self.use_cache = use_cache
151
+ self.enable_pronunciation = enable_pronunciation
152
+ self.enable_shape = enable_shape
153
+ self.pronunciation_embed_dim = pronunciation_embed_dim
154
+ self.pronunciation_vocab_size = pronunciation_vocab_size
155
+ self.shape_embed_dim = shape_embed_dim
156
+ self.shape_vocab_size = shape_vocab_size
157
+ self.concat_input = concat_input
158
+ self.position_embedding_type = position_embedding_type
159
+ self.classifier_dropout = classifier_dropout
160
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
161
+
162
+
163
+ __all__ = ["RoCBertConfig"]
docs/transformers/build/lib/transformers/models/roc_bert/modeling_roc_bert.py ADDED
@@ -0,0 +1,2017 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 WeChatAI The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RoCBert model."""
16
+
17
+ import math
18
+ import os
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...generation import GenerationMixin
28
+ from ...modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from ...modeling_utils import PreTrainedModel
39
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from ...utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_roc_bert import RoCBertConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "weiweishi/roc-bert-base-zh"
53
+ _CONFIG_FOR_DOC = "RoCBertConfig"
54
+
55
+ # Base model docstring
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
57
+
58
+ # Token Classification output
59
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "ArthurZ/dummy-rocbert-ner"
60
+ _TOKEN_CLASS_EXPECTED_OUTPUT = ["S-EVENT", "S-FAC", "I-ORDINAL", "I-ORDINAL", "E-ORG", "E-LANGUAGE", "E-ORG", "E-ORG", "E-ORG", "E-ORG", "I-EVENT", "S-TIME", "S-TIME", "E-LANGUAGE", "S-TIME", "E-DATE", "I-ORDINAL", "E-QUANTITY", "E-LANGUAGE", "S-TIME", "B-ORDINAL", "S-PRODUCT", "E-LANGUAGE", "E-LANGUAGE", "E-ORG", "E-LOC", "S-TIME", "I-ORDINAL", "S-FAC", "O", "S-GPE", "I-EVENT", "S-GPE", "E-LANGUAGE", "E-ORG", "S-EVENT", "S-FAC", "S-FAC", "S-FAC", "E-ORG", "S-FAC", "E-ORG", "S-GPE"] # fmt: skip
61
+ _TOKEN_CLASS_EXPECTED_LOSS = 3.62
62
+
63
+ # SequenceClassification docstring
64
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/dummy-rocbert-seq"
65
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'financial news'"
66
+ _SEQ_CLASS_EXPECTED_LOSS = 2.31
67
+
68
+ # QuestionAnswering docstring
69
+ _CHECKPOINT_FOR_QA = "ArthurZ/dummy-rocbert-qa"
70
+ _QA_EXPECTED_OUTPUT = "''"
71
+ _QA_EXPECTED_LOSS = 3.75
72
+ _QA_TARGET_START_INDEX = 14
73
+ _QA_TARGET_END_INDEX = 15
74
+
75
+ # Masked language modeling
76
+
77
+
78
+ # Copied from transformers.models.bert.modeling_bert.load_tf_weights_in_bert with bert->roc_bert
79
+ def load_tf_weights_in_roc_bert(model, config, tf_checkpoint_path):
80
+ """Load tf checkpoints in a pytorch model."""
81
+ try:
82
+ import re
83
+
84
+ import numpy as np
85
+ import tensorflow as tf
86
+ except ImportError:
87
+ logger.error(
88
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
89
+ "https://www.tensorflow.org/install/ for installation instructions."
90
+ )
91
+ raise
92
+ tf_path = os.path.abspath(tf_checkpoint_path)
93
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
94
+ # Load weights from TF model
95
+ init_vars = tf.train.list_variables(tf_path)
96
+ names = []
97
+ arrays = []
98
+ for name, shape in init_vars:
99
+ logger.info(f"Loading TF weight {name} with shape {shape}")
100
+ array = tf.train.load_variable(tf_path, name)
101
+ names.append(name)
102
+ arrays.append(array)
103
+
104
+ for name, array in zip(names, arrays):
105
+ name = name.split("/")
106
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
107
+ # which are not required for using pretrained model
108
+ if any(
109
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
110
+ for n in name
111
+ ):
112
+ logger.info(f"Skipping {'/'.join(name)}")
113
+ continue
114
+ pointer = model
115
+ for m_name in name:
116
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
117
+ scope_names = re.split(r"_(\d+)", m_name)
118
+ else:
119
+ scope_names = [m_name]
120
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
121
+ pointer = getattr(pointer, "weight")
122
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
123
+ pointer = getattr(pointer, "bias")
124
+ elif scope_names[0] == "output_weights":
125
+ pointer = getattr(pointer, "weight")
126
+ elif scope_names[0] == "squad":
127
+ pointer = getattr(pointer, "classifier")
128
+ else:
129
+ try:
130
+ pointer = getattr(pointer, scope_names[0])
131
+ except AttributeError:
132
+ logger.info(f"Skipping {'/'.join(name)}")
133
+ continue
134
+ if len(scope_names) >= 2:
135
+ num = int(scope_names[1])
136
+ pointer = pointer[num]
137
+ if m_name[-11:] == "_embeddings":
138
+ pointer = getattr(pointer, "weight")
139
+ elif m_name == "kernel":
140
+ array = np.transpose(array)
141
+ try:
142
+ if pointer.shape != array.shape:
143
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
144
+ except ValueError as e:
145
+ e.args += (pointer.shape, array.shape)
146
+ raise
147
+ logger.info(f"Initialize PyTorch weight {name}")
148
+ pointer.data = torch.from_numpy(array)
149
+ return model
150
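+ # Illustrative usage of this converter (a sketch only, not executed anywhere in this module; the
+ # checkpoint path below is a placeholder):
+ #   config = RoCBertConfig.from_pretrained("weiweishi/roc-bert-base-zh")
+ #   model = RoCBertModel(config)
+ #   model = load_tf_weights_in_roc_bert(model, config, "/path/to/tf_checkpoint")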
+
151
+
152
+ class RoCBertEmbeddings(nn.Module):
153
+ """Construct the embeddings from word, position, shape, pronunciation and token_type embeddings."""
154
+
155
+ def __init__(self, config):
156
+ super().__init__()
157
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
158
+ self.pronunciation_embed = nn.Embedding(
159
+ config.pronunciation_vocab_size, config.pronunciation_embed_dim, padding_idx=config.pad_token_id
160
+ )
161
+ self.shape_embed = nn.Embedding(
162
+ config.shape_vocab_size, config.shape_embed_dim, padding_idx=config.pad_token_id
163
+ )
164
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
165
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
166
+
167
+ self.enable_pronunciation = config.enable_pronunciation
168
+ self.enable_shape = config.enable_shape
169
+
170
+ if config.concat_input:
171
+ input_dim = config.hidden_size
172
+ if self.enable_pronunciation:
173
+ pronunciation_dim = config.pronunciation_embed_dim
174
+ input_dim += pronunciation_dim
175
+ if self.enable_shape:
176
+ shape_dim = config.shape_embed_dim
177
+ input_dim += shape_dim
178
+ self.map_inputs_layer = torch.nn.Linear(input_dim, config.hidden_size)
179
+ else:
180
+ self.map_inputs_layer = None
181
+
182
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
183
+ # any TensorFlow checkpoint file
184
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
185
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
186
+
187
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
188
+ self.register_buffer(
189
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
190
+ )
191
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
192
+ self.register_buffer(
193
+ "token_type_ids",
194
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
195
+ persistent=False,
196
+ )
197
+
198
+ def forward(
199
+ self,
200
+ input_ids=None,
201
+ input_shape_ids=None,
202
+ input_pronunciation_ids=None,
203
+ token_type_ids=None,
204
+ position_ids=None,
205
+ inputs_embeds=None,
206
+ past_key_values_length=0,
207
+ ):
208
+ if input_ids is not None:
209
+ input_shape = input_ids.size()
210
+ else:
211
+ input_shape = inputs_embeds.size()[:-1]
212
+
213
+ seq_length = input_shape[1]
214
+
215
+ if position_ids is None:
216
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
217
+
218
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
219
+ # when it's auto-generated; the registered buffer helps users when tracing the model without passing token_type_ids, and solves
220
+ # issue #5664
221
+ if token_type_ids is None:
222
+ if hasattr(self, "token_type_ids"):
223
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
224
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
225
+ token_type_ids = buffered_token_type_ids_expanded
226
+ else:
227
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
228
+
229
+ if self.map_inputs_layer is None:
230
+ if inputs_embeds is None:
231
+ inputs_embeds = self.word_embeddings(input_ids)
232
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
233
+ embeddings = inputs_embeds + token_type_embeddings
234
+ if self.position_embedding_type == "absolute":
235
+ position_embeddings = self.position_embeddings(position_ids)
236
+ embeddings += position_embeddings
237
+ embeddings = self.LayerNorm(embeddings)
238
+ embeddings = self.dropout(embeddings)
239
+
240
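+ # Fuse the channels by averaging: the word (+ position/type) embedding is summed with the shape and
+ # pronunciation embeddings when they are provided, then divided by the number of channels used.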
+ denominator = 1
241
+ embedding_in = torch.clone(embeddings)
242
+ if self.enable_shape and input_shape_ids is not None:
243
+ embedding_shape = self.shape_embed(input_shape_ids)
244
+ embedding_in += embedding_shape
245
+ denominator += 1
246
+ if self.enable_pronunciation and input_pronunciation_ids is not None:
247
+ embedding_pronunciation = self.pronunciation_embed(input_pronunciation_ids)
248
+ embedding_in += embedding_pronunciation
249
+ denominator += 1
250
+
251
+ embedding_in /= denominator
252
+ return embedding_in
253
+ else:
254
+ if inputs_embeds is None:
255
+ inputs_embeds = self.word_embeddings(input_ids) # embedding_word
256
+ device = inputs_embeds.device
257
+
258
+ embedding_in = torch.clone(inputs_embeds)
259
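+ # concat_input=True path: word, shape and pronunciation embeddings are concatenated along the last
+ # dimension and projected back to hidden_size by map_inputs_layer before adding position/type embeddings.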
+ if self.enable_shape:
260
+ if input_shape_ids is None:
261
+ input_shape_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
262
+ embedding_shape = self.shape_embed(input_shape_ids)
263
+ embedding_in = torch.cat((embedding_in, embedding_shape), -1)
264
+ if self.enable_pronunciation:
265
+ if input_pronunciation_ids is None:
266
+ input_pronunciation_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
267
+ embedding_pronunciation = self.pronunciation_embed(input_pronunciation_ids)
268
+ embedding_in = torch.cat((embedding_in, embedding_pronunciation), -1)
269
+
270
+ embedding_in = self.map_inputs_layer(embedding_in) # batch_size * seq_len * hidden_dim
271
+
272
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
273
+ embedding_in += token_type_embeddings
274
+ if self.position_embedding_type == "absolute":
275
+ position_embeddings = self.position_embeddings(position_ids)
276
+ embedding_in += position_embeddings
277
+
278
+ embedding_in = self.LayerNorm(embedding_in)
279
+ embedding_in = self.dropout(embedding_in)
280
+ return embedding_in
281
+
282
+
283
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RoCBert
284
+ class RoCBertSelfAttention(nn.Module):
285
+ def __init__(self, config, position_embedding_type=None):
286
+ super().__init__()
287
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
288
+ raise ValueError(
289
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
290
+ f"heads ({config.num_attention_heads})"
291
+ )
292
+
293
+ self.num_attention_heads = config.num_attention_heads
294
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
295
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
296
+
297
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
298
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
299
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
300
+
301
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
302
+ self.position_embedding_type = position_embedding_type or getattr(
303
+ config, "position_embedding_type", "absolute"
304
+ )
305
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
306
+ self.max_position_embeddings = config.max_position_embeddings
307
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
308
+
309
+ self.is_decoder = config.is_decoder
310
+
311
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
312
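+ # Reshape (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size) so attention
+ # scores can be computed per head with a single batched matmul.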
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
313
+ x = x.view(new_x_shape)
314
+ return x.permute(0, 2, 1, 3)
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states: torch.Tensor,
319
+ attention_mask: Optional[torch.FloatTensor] = None,
320
+ head_mask: Optional[torch.FloatTensor] = None,
321
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
322
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
323
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
324
+ output_attentions: Optional[bool] = False,
325
+ ) -> Tuple[torch.Tensor]:
326
+ mixed_query_layer = self.query(hidden_states)
327
+
328
+ # If this is instantiated as a cross-attention module, the keys
329
+ # and values come from an encoder; the attention mask needs to be
330
+ # such that the encoder's padding tokens are not attended to.
331
+ is_cross_attention = encoder_hidden_states is not None
332
+
333
+ if is_cross_attention and past_key_value is not None:
334
+ # reuse k,v, cross_attentions
335
+ key_layer = past_key_value[0]
336
+ value_layer = past_key_value[1]
337
+ attention_mask = encoder_attention_mask
338
+ elif is_cross_attention:
339
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
340
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
341
+ attention_mask = encoder_attention_mask
342
+ elif past_key_value is not None:
343
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
344
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
345
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
346
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
347
+ else:
348
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
349
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
350
+
351
+ query_layer = self.transpose_for_scores(mixed_query_layer)
352
+
353
+ use_cache = past_key_value is not None
354
+ if self.is_decoder:
355
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
356
+ # Further calls to cross_attention layer can then reuse all cross-attention
357
+ # key/value_states (first "if" case)
358
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
359
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
360
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
361
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
362
+ past_key_value = (key_layer, value_layer)
363
+
364
+ # Take the dot product between "query" and "key" to get the raw attention scores.
365
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
366
+
367
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
368
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
369
+ if use_cache:
370
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
371
+ -1, 1
372
+ )
373
+ else:
374
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
375
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
376
+ distance = position_ids_l - position_ids_r
377
+
378
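+ # distance lies in [-(max_position_embeddings - 1), max_position_embeddings - 1]; shifting by
+ # max_position_embeddings - 1 maps it onto valid indices of the distance_embedding table.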
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
379
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
380
+
381
+ if self.position_embedding_type == "relative_key":
382
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
383
+ attention_scores = attention_scores + relative_position_scores
384
+ elif self.position_embedding_type == "relative_key_query":
385
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
386
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
387
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
388
+
389
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
390
+ if attention_mask is not None:
391
+ # Apply the attention mask is (precomputed for all layers in RoCBertModel forward() function)
392
+ attention_scores = attention_scores + attention_mask
393
+
394
+ # Normalize the attention scores to probabilities.
395
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
396
+
397
+ # This is actually dropping out entire tokens to attend to, which might
398
+ # seem a bit unusual, but is taken from the original Transformer paper.
399
+ attention_probs = self.dropout(attention_probs)
400
+
401
+ # Mask heads if we want to
402
+ if head_mask is not None:
403
+ attention_probs = attention_probs * head_mask
404
+
405
+ context_layer = torch.matmul(attention_probs, value_layer)
406
+
407
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
408
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
409
+ context_layer = context_layer.view(new_context_layer_shape)
410
+
411
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
412
+
413
+ if self.is_decoder:
414
+ outputs = outputs + (past_key_value,)
415
+ return outputs
416
+
417
+
418
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoCBert
419
+ class RoCBertSelfOutput(nn.Module):
420
+ def __init__(self, config):
421
+ super().__init__()
422
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
423
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
425
+
426
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
427
+ hidden_states = self.dense(hidden_states)
428
+ hidden_states = self.dropout(hidden_states)
429
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
430
+ return hidden_states
431
+
432
+
433
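+ # Only the eager (pure PyTorch) self-attention implementation is available for RoCBert;
+ # config._attn_implementation selects the class from this mapping.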
+ ROC_BERT_SELF_ATTENTION_CLASSES = {
434
+ "eager": RoCBertSelfAttention,
435
+ }
436
+
437
+
438
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT
439
+ class RoCBertAttention(nn.Module):
440
+ def __init__(self, config, position_embedding_type=None):
441
+ super().__init__()
442
+ self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
443
+ config, position_embedding_type=position_embedding_type
444
+ )
445
+ self.output = RoCBertSelfOutput(config)
446
+ self.pruned_heads = set()
447
+
448
+ def prune_heads(self, heads):
449
+ if len(heads) == 0:
450
+ return
451
+ heads, index = find_pruneable_heads_and_indices(
452
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
453
+ )
454
+
455
+ # Prune linear layers
456
+ self.self.query = prune_linear_layer(self.self.query, index)
457
+ self.self.key = prune_linear_layer(self.self.key, index)
458
+ self.self.value = prune_linear_layer(self.self.value, index)
459
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
460
+
461
+ # Update hyper params and store pruned heads
462
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
463
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
464
+ self.pruned_heads = self.pruned_heads.union(heads)
465
+
466
+ def forward(
467
+ self,
468
+ hidden_states: torch.Tensor,
469
+ attention_mask: Optional[torch.FloatTensor] = None,
470
+ head_mask: Optional[torch.FloatTensor] = None,
471
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
472
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
473
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
474
+ output_attentions: Optional[bool] = False,
475
+ ) -> Tuple[torch.Tensor]:
476
+ self_outputs = self.self(
477
+ hidden_states,
478
+ attention_mask,
479
+ head_mask,
480
+ encoder_hidden_states,
481
+ encoder_attention_mask,
482
+ past_key_value,
483
+ output_attentions,
484
+ )
485
+ attention_output = self.output(self_outputs[0], hidden_states)
486
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
487
+ return outputs
488
+
489
+
490
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RoCBert
491
+ class RoCBertIntermediate(nn.Module):
492
+ def __init__(self, config):
493
+ super().__init__()
494
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
495
+ if isinstance(config.hidden_act, str):
496
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
497
+ else:
498
+ self.intermediate_act_fn = config.hidden_act
499
+
500
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
501
+ hidden_states = self.dense(hidden_states)
502
+ hidden_states = self.intermediate_act_fn(hidden_states)
503
+ return hidden_states
504
+
505
+
506
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoCBert
507
+ class RoCBertOutput(nn.Module):
508
+ def __init__(self, config):
509
+ super().__init__()
510
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
511
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
512
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
513
+
514
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
515
+ hidden_states = self.dense(hidden_states)
516
+ hidden_states = self.dropout(hidden_states)
517
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
518
+ return hidden_states
519
+
520
+
521
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RoCBert
522
+ class RoCBertLayer(nn.Module):
523
+ def __init__(self, config):
524
+ super().__init__()
525
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
526
+ self.seq_len_dim = 1
527
+ self.attention = RoCBertAttention(config)
528
+ self.is_decoder = config.is_decoder
529
+ self.add_cross_attention = config.add_cross_attention
530
+ if self.add_cross_attention:
531
+ if not self.is_decoder:
532
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
533
+ self.crossattention = RoCBertAttention(config, position_embedding_type="absolute")
534
+ self.intermediate = RoCBertIntermediate(config)
535
+ self.output = RoCBertOutput(config)
536
+
537
+ def forward(
538
+ self,
539
+ hidden_states: torch.Tensor,
540
+ attention_mask: Optional[torch.FloatTensor] = None,
541
+ head_mask: Optional[torch.FloatTensor] = None,
542
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
543
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
544
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
545
+ output_attentions: Optional[bool] = False,
546
+ ) -> Tuple[torch.Tensor]:
547
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
548
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
549
+ self_attention_outputs = self.attention(
550
+ hidden_states,
551
+ attention_mask,
552
+ head_mask,
553
+ output_attentions=output_attentions,
554
+ past_key_value=self_attn_past_key_value,
555
+ )
556
+ attention_output = self_attention_outputs[0]
557
+
558
+ # if decoder, the last output is tuple of self-attn cache
559
+ if self.is_decoder:
560
+ outputs = self_attention_outputs[1:-1]
561
+ present_key_value = self_attention_outputs[-1]
562
+ else:
563
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
564
+
565
+ cross_attn_present_key_value = None
566
+ if self.is_decoder and encoder_hidden_states is not None:
567
+ if not hasattr(self, "crossattention"):
568
+ raise ValueError(
569
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
570
+ " by setting `config.add_cross_attention=True`"
571
+ )
572
+
573
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
574
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
575
+ cross_attention_outputs = self.crossattention(
576
+ attention_output,
577
+ attention_mask,
578
+ head_mask,
579
+ encoder_hidden_states,
580
+ encoder_attention_mask,
581
+ cross_attn_past_key_value,
582
+ output_attentions,
583
+ )
584
+ attention_output = cross_attention_outputs[0]
585
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
586
+
587
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
588
+ cross_attn_present_key_value = cross_attention_outputs[-1]
589
+ present_key_value = present_key_value + cross_attn_present_key_value
590
+
591
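+ # apply_chunking_to_forward runs the feed-forward block over chunks of the sequence dimension
+ # (chunk_size_feed_forward tokens at a time) to trade a little compute for lower peak memory.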
+ layer_output = apply_chunking_to_forward(
592
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
593
+ )
594
+ outputs = (layer_output,) + outputs
595
+
596
+ # if decoder, return the attn key/values as the last output
597
+ if self.is_decoder:
598
+ outputs = outputs + (present_key_value,)
599
+
600
+ return outputs
601
+
602
+ def feed_forward_chunk(self, attention_output):
603
+ intermediate_output = self.intermediate(attention_output)
604
+ layer_output = self.output(intermediate_output, attention_output)
605
+ return layer_output
606
+
607
+
608
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RoCBert
609
+ class RoCBertEncoder(nn.Module):
610
+ def __init__(self, config):
611
+ super().__init__()
612
+ self.config = config
613
+ self.layer = nn.ModuleList([RoCBertLayer(config) for _ in range(config.num_hidden_layers)])
614
+ self.gradient_checkpointing = False
615
+
616
+ def forward(
617
+ self,
618
+ hidden_states: torch.Tensor,
619
+ attention_mask: Optional[torch.FloatTensor] = None,
620
+ head_mask: Optional[torch.FloatTensor] = None,
621
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
622
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
623
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
624
+ use_cache: Optional[bool] = None,
625
+ output_attentions: Optional[bool] = False,
626
+ output_hidden_states: Optional[bool] = False,
627
+ return_dict: Optional[bool] = True,
628
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
629
+ all_hidden_states = () if output_hidden_states else None
630
+ all_self_attentions = () if output_attentions else None
631
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
632
+
633
+ if self.gradient_checkpointing and self.training:
634
+ if use_cache:
635
+ logger.warning_once(
636
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
637
+ )
638
+ use_cache = False
639
+
640
+ next_decoder_cache = () if use_cache else None
641
+ for i, layer_module in enumerate(self.layer):
642
+ if output_hidden_states:
643
+ all_hidden_states = all_hidden_states + (hidden_states,)
644
+
645
+ layer_head_mask = head_mask[i] if head_mask is not None else None
646
+ past_key_value = past_key_values[i] if past_key_values is not None else None
647
+
648
+ if self.gradient_checkpointing and self.training:
649
+ layer_outputs = self._gradient_checkpointing_func(
650
+ layer_module.__call__,
651
+ hidden_states,
652
+ attention_mask,
653
+ layer_head_mask,
654
+ encoder_hidden_states,
655
+ encoder_attention_mask,
656
+ past_key_value,
657
+ output_attentions,
658
+ )
659
+ else:
660
+ layer_outputs = layer_module(
661
+ hidden_states,
662
+ attention_mask,
663
+ layer_head_mask,
664
+ encoder_hidden_states,
665
+ encoder_attention_mask,
666
+ past_key_value,
667
+ output_attentions,
668
+ )
669
+
670
+ hidden_states = layer_outputs[0]
671
+ if use_cache:
672
+ next_decoder_cache += (layer_outputs[-1],)
673
+ if output_attentions:
674
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
675
+ if self.config.add_cross_attention:
676
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
677
+
678
+ if output_hidden_states:
679
+ all_hidden_states = all_hidden_states + (hidden_states,)
680
+
681
+ if not return_dict:
682
+ return tuple(
683
+ v
684
+ for v in [
685
+ hidden_states,
686
+ next_decoder_cache,
687
+ all_hidden_states,
688
+ all_self_attentions,
689
+ all_cross_attentions,
690
+ ]
691
+ if v is not None
692
+ )
693
+ return BaseModelOutputWithPastAndCrossAttentions(
694
+ last_hidden_state=hidden_states,
695
+ past_key_values=next_decoder_cache,
696
+ hidden_states=all_hidden_states,
697
+ attentions=all_self_attentions,
698
+ cross_attentions=all_cross_attentions,
699
+ )
700
+
701
+
702
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->RoCBert
703
+ class RoCBertPooler(nn.Module):
704
+ def __init__(self, config):
705
+ super().__init__()
706
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
707
+ self.activation = nn.Tanh()
708
+
709
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
710
+ # We "pool" the model by simply taking the hidden state corresponding
711
+ # to the first token.
712
+ first_token_tensor = hidden_states[:, 0]
713
+ pooled_output = self.dense(first_token_tensor)
714
+ pooled_output = self.activation(pooled_output)
715
+ return pooled_output
716
+
717
+
718
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RoCBert
719
+ class RoCBertPredictionHeadTransform(nn.Module):
720
+ def __init__(self, config):
721
+ super().__init__()
722
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
723
+ if isinstance(config.hidden_act, str):
724
+ self.transform_act_fn = ACT2FN[config.hidden_act]
725
+ else:
726
+ self.transform_act_fn = config.hidden_act
727
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
728
+
729
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
730
+ hidden_states = self.dense(hidden_states)
731
+ hidden_states = self.transform_act_fn(hidden_states)
732
+ hidden_states = self.LayerNorm(hidden_states)
733
+ return hidden_states
734
+
735
+
736
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->RoCBert
737
+ class RoCBertLMPredictionHead(nn.Module):
738
+ def __init__(self, config):
739
+ super().__init__()
740
+ self.transform = RoCBertPredictionHeadTransform(config)
741
+
742
+ # The output weights are the same as the input embeddings, but there is
743
+ # an output-only bias for each token.
744
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
745
+
746
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
747
+
748
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
749
+ self.decoder.bias = self.bias
750
+
751
+ def _tie_weights(self):
752
+ self.decoder.bias = self.bias
753
+
754
+ def forward(self, hidden_states):
755
+ hidden_states = self.transform(hidden_states)
756
+ hidden_states = self.decoder(hidden_states)
757
+ return hidden_states
758
+
759
+
760
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RoCBert
761
+ class RoCBertOnlyMLMHead(nn.Module):
762
+ def __init__(self, config):
763
+ super().__init__()
764
+ self.predictions = RoCBertLMPredictionHead(config)
765
+
766
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
767
+ prediction_scores = self.predictions(sequence_output)
768
+ return prediction_scores
769
+
770
+
771
+ class RoCBertPreTrainedModel(PreTrainedModel):
772
+ """
773
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
774
+ models.
775
+ """
776
+
777
+ config_class = RoCBertConfig
778
+ load_tf_weights = load_tf_weights_in_roc_bert
779
+ base_model_prefix = "roc_bert"
780
+ supports_gradient_checkpointing = True
781
+
782
+ def _init_weights(self, module):
783
+ """Initialize the weights"""
784
+ if isinstance(module, nn.Linear):
785
+ # Slightly different from the TF version which uses truncated_normal for initialization
786
+ # cf https://github.com/pytorch/pytorch/pull/5617
787
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
788
+ if module.bias is not None:
789
+ module.bias.data.zero_()
790
+ elif isinstance(module, nn.Embedding):
791
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
792
+ if module.padding_idx is not None:
793
+ module.weight.data[module.padding_idx].zero_()
794
+ elif isinstance(module, nn.LayerNorm):
795
+ module.bias.data.zero_()
796
+ module.weight.data.fill_(1.0)
797
+ elif isinstance(module, RoCBertLMPredictionHead):
798
+ module.bias.data.zero_()
799
+
800
+
801
+ ROC_BERT_START_DOCSTRING = r"""
802
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
803
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
804
+ behavior.
805
+
806
+ Parameters:
807
+ config ([`RoCBertConfig`]): Model configuration class with all the parameters of the model.
808
+ Initializing with a config file does not load the weights associated with the model, only the
809
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
810
+ """
811
+
812
+ ROC_BERT_INPUTS_DOCSTRING = r"""
813
+ Args:
814
+ input_ids (`torch.LongTensor` of shape `({0})`):
815
+ Indices of input sequence tokens in the vocabulary.
816
+
817
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
818
+ [`PreTrainedTokenizer.__call__`] for details.
819
+
820
+ [What are input IDs?](../glossary#input-ids)
821
+ input_shape_ids (`torch.LongTensor` of shape `({0})`):
822
+ Indices of input sequence tokens in the shape vocabulary.
823
+
824
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
825
+ [`PreTrainedTokenizer.__call__`] for details.
826
+
827
+ [What are input IDs?](../glossary#input_shape_ids)
828
+ input_pronunciation_ids (`torch.LongTensor` of shape `({0})`):
829
+ Indices of input sequence tokens in the pronunciation vocabulary.
830
+
831
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
832
+ [`PreTrainedTokenizer.__call__`] for details.
833
+
834
+ [What are input IDs?](../glossary#input_pronunciation_ids)
835
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
836
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
837
+
838
+ - 1 for tokens that are **not masked**,
839
+ - 0 for tokens that are **masked**.
840
+
841
+ [What are attention masks?](../glossary#attention-mask)
842
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
843
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
844
+ 1]`:
845
+
846
+ - 0 corresponds to a *sentence A* token,
847
+ - 1 corresponds to a *sentence B* token.
848
+
849
+ [What are token type IDs?](../glossary#token-type-ids)
850
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
851
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
852
+ config.max_position_embeddings - 1]`.
853
+
854
+ [What are position IDs?](../glossary#position-ids)
855
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
856
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
857
+
858
+ - 1 indicates the head is **not masked**,
859
+ - 0 indicates the head is **masked**.
860
+
861
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
862
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
863
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
864
+ model's internal embedding lookup matrix.
865
+ output_attentions (`bool`, *optional*):
866
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
867
+ tensors for more detail.
868
+ output_hidden_states (`bool`, *optional*):
869
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
870
+ more detail.
871
+ return_dict (`bool`, *optional*):
872
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
873
+ """
874
+
875
+
876
+ @add_start_docstrings(
877
+ "The bare RoCBert Model transformer outputting raw hidden-states without any specific head on top.",
878
+ ROC_BERT_START_DOCSTRING,
879
+ )
880
+ class RoCBertModel(RoCBertPreTrainedModel):
881
+ """
882
+
883
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
884
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
885
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
886
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
887
+
888
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
889
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
890
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
891
+ """
892
+
893
+ # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->RoCBert
894
+ def __init__(self, config, add_pooling_layer=True):
895
+ super().__init__(config)
896
+ self.config = config
897
+
898
+ self.embeddings = RoCBertEmbeddings(config)
899
+ self.encoder = RoCBertEncoder(config)
900
+
901
+ self.pooler = RoCBertPooler(config) if add_pooling_layer else None
902
+
903
+ # Initialize weights and apply final processing
904
+ self.post_init()
905
+
906
+ # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings
907
+ def get_input_embeddings(self):
908
+ return self.embeddings.word_embeddings
909
+
910
+ # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings
911
+ def set_input_embeddings(self, value):
912
+ self.embeddings.word_embeddings = value
913
+
914
+ def get_pronunciation_embeddings(self):
915
+ return self.embeddings.pronunciation_embed
916
+
917
+ def set_pronunciation_embeddings(self, value):
918
+ self.embeddings.pronunciation_embed = value
919
+
920
+ def get_shape_embeddings(self):
921
+ return self.embeddings.shape_embed
922
+
923
+ def set_shape_embeddings(self, value):
924
+ self.embeddings.shape_embed = value
925
+
926
+ # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
927
+ def _prune_heads(self, heads_to_prune):
928
+ """
929
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
930
+ class PreTrainedModel
931
+ """
932
+ for layer, heads in heads_to_prune.items():
933
+ self.encoder.layer[layer].attention.prune_heads(heads)
934
+
935
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
936
+ @add_code_sample_docstrings(
937
+ checkpoint=_CHECKPOINT_FOR_DOC,
938
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
939
+ config_class=_CONFIG_FOR_DOC,
940
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
941
+ )
942
+ def forward(
943
+ self,
944
+ input_ids: Optional[torch.Tensor] = None,
945
+ input_shape_ids: Optional[torch.Tensor] = None,
946
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
947
+ attention_mask: Optional[torch.Tensor] = None,
948
+ token_type_ids: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.Tensor] = None,
950
+ head_mask: Optional[torch.Tensor] = None,
951
+ inputs_embeds: Optional[torch.Tensor] = None,
952
+ encoder_hidden_states: Optional[torch.Tensor] = None,
953
+ encoder_attention_mask: Optional[torch.Tensor] = None,
954
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
955
+ use_cache: Optional[bool] = None,
956
+ output_attentions: Optional[bool] = None,
957
+ output_hidden_states: Optional[bool] = None,
958
+ return_dict: Optional[bool] = None,
959
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
960
+ r"""
961
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
962
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
963
+ the model is configured as a decoder.
964
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
965
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
966
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
967
+
968
+ - 1 for tokens that are **not masked**,
969
+ - 0 for tokens that are **masked**.
970
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
971
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
972
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
973
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
974
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
975
+ use_cache (`bool`, *optional*):
976
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
977
+ `past_key_values`).
978
+ """
979
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
980
+ output_hidden_states = (
981
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
982
+ )
983
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
984
+
985
+ if self.config.is_decoder:
986
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
987
+ else:
988
+ use_cache = False
989
+
990
+ if input_ids is not None and inputs_embeds is not None:
991
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
992
+ elif input_ids is not None:
993
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
994
+ input_shape = input_ids.size()
995
+ elif inputs_embeds is not None:
996
+ input_shape = inputs_embeds.size()[:-1]
997
+ else:
998
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
999
+
1000
+ batch_size, seq_length = input_shape
1001
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1002
+
1003
+ # past_key_values_length
1004
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1005
+
1006
+ if attention_mask is None:
1007
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1008
+
1009
+ if token_type_ids is None:
1010
+ if hasattr(self.embeddings, "token_type_ids"):
1011
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1012
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1013
+ token_type_ids = buffered_token_type_ids_expanded
1014
+ else:
1015
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1016
+
1017
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1018
+ # ourselves in which case we just need to make it broadcastable to all heads.
1019
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1020
+
1021
+ # If a 2D or 3D attention mask is provided for the cross-attention
1022
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1023
+ if self.config.is_decoder and encoder_hidden_states is not None:
1024
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1025
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1026
+ if encoder_attention_mask is None:
1027
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1028
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1029
+ else:
1030
+ encoder_extended_attention_mask = None
1031
+
1032
+ # Prepare head mask if needed
1033
+ # 1.0 in head_mask indicate we keep the head
1034
+ # attention_probs has shape bsz x n_heads x N x N
1035
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1036
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1037
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1038
+
1039
+ embedding_output = self.embeddings(
1040
+ input_ids=input_ids,
1041
+ input_shape_ids=input_shape_ids,
1042
+ input_pronunciation_ids=input_pronunciation_ids,
1043
+ position_ids=position_ids,
1044
+ token_type_ids=token_type_ids,
1045
+ inputs_embeds=inputs_embeds,
1046
+ past_key_values_length=past_key_values_length,
1047
+ )
1048
+ encoder_outputs = self.encoder(
1049
+ embedding_output,
1050
+ attention_mask=extended_attention_mask,
1051
+ head_mask=head_mask,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_extended_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ )
1060
+ sequence_output = encoder_outputs[0]
1061
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1062
+
1063
+ if not return_dict:
1064
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1065
+
1066
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1067
+ last_hidden_state=sequence_output,
1068
+ pooler_output=pooled_output,
1069
+ past_key_values=encoder_outputs.past_key_values,
1070
+ hidden_states=encoder_outputs.hidden_states,
1071
+ attentions=encoder_outputs.attentions,
1072
+ cross_attentions=encoder_outputs.cross_attentions,
1073
+ )
1074
+
1075
+
1076
+ @add_start_docstrings(
1077
+ """
1078
+ RoCBert Model with contrastive loss and masked_lm_loss during the pretraining.
1079
+ """,
1080
+ ROC_BERT_START_DOCSTRING,
1081
+ )
1082
+ class RoCBertForPreTraining(RoCBertPreTrainedModel):
1083
+ _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
1084
+
1085
+ def __init__(self, config):
1086
+ super().__init__(config)
1087
+
1088
+ self.roc_bert = RoCBertModel(config)
1089
+ self.cls = RoCBertOnlyMLMHead(config)
1090
+
1091
+ # Initialize weights and apply final processing
1092
+ self.post_init()
1093
+
1094
+ # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.get_output_embeddings
1095
+ def get_output_embeddings(self):
1096
+ return self.cls.predictions.decoder
1097
+
1098
+ # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
1099
+ def set_output_embeddings(self, new_embeddings):
1100
+ self.cls.predictions.decoder = new_embeddings
1101
+ self.cls.predictions.bias = new_embeddings.bias
1102
+
1103
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1104
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1105
+ def forward(
1106
+ self,
1107
+ input_ids: Optional[torch.Tensor] = None,
1108
+ input_shape_ids: Optional[torch.Tensor] = None,
1109
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1110
+ attention_mask: Optional[torch.Tensor] = None,
1111
+ token_type_ids: Optional[torch.Tensor] = None,
1112
+ attack_input_ids: Optional[torch.Tensor] = None,
1113
+ attack_input_shape_ids: Optional[torch.Tensor] = None,
1114
+ attack_input_pronunciation_ids: Optional[torch.Tensor] = None,
1115
+ attack_attention_mask: Optional[torch.Tensor] = None,
1116
+ attack_token_type_ids: Optional[torch.Tensor] = None,
1117
+ position_ids: Optional[torch.Tensor] = None,
1118
+ head_mask: Optional[torch.Tensor] = None,
1119
+ inputs_embeds: Optional[torch.Tensor] = None,
1120
+ labels_input_ids: Optional[torch.Tensor] = None,
1121
+ labels_input_shape_ids: Optional[torch.Tensor] = None,
1122
+ labels_input_pronunciation_ids: Optional[torch.Tensor] = None,
1123
+ labels_attention_mask: Optional[torch.Tensor] = None,
1124
+ labels_token_type_ids: Optional[torch.Tensor] = None,
1125
+ output_attentions: Optional[bool] = None,
1126
+ output_hidden_states: Optional[bool] = None,
1127
+ return_dict: Optional[bool] = None,
1128
+ **kwargs,
1129
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1130
+ r"""
1131
+ attack_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1132
+ attack sample ids for computing the contrastive loss. Indices should be in `[-100, 0, ...,
1133
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1134
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1135
+ attack_input_shape_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1136
+ attack sample shape ids for computing the contrastive loss. Indices should be in `[-100, 0, ...,
1137
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1138
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1139
+ attack_input_pronunciation_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1140
+ attack sample pronunciation ids for computing the contrastive loss. Indices should be in `[-100, 0,
1141
+ ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
1142
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1143
+ labels_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1144
+ target ids for computing the contrastive loss and masked_lm_loss. Indices should be in `[-100, 0, ...,
1145
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1146
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1147
+ labels_input_shape_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1148
+ target shape ids for computing the contrastive loss and masked_lm_loss. Indices should be in `[-100,
1149
+ 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
1150
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1151
+ labels_input_pronunciation_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ target pronunciation ids for computing the contrastive loss and masked_lm_loss. Indices should be in
1153
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1154
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ...,
1155
+ config.vocab_size]`
1156
+
1157
+ kwargs (`Dict[str, any]`, *optional*, defaults to *{}*):
1158
+ Used to hide legacy arguments that have been deprecated.
1159
+
1160
+ Returns:
1161
+
1162
+ Example:
1163
+
1164
+ ```python
1165
+ >>> from transformers import AutoTokenizer, RoCBertForPreTraining
1166
+ >>> import torch
1167
+
1168
+ >>> tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
1169
+ >>> model = RoCBertForPreTraining.from_pretrained("weiweishi/roc-bert-base-zh")
1170
+
1171
+ >>> inputs = tokenizer("你好,很高兴认识你", return_tensors="pt")
1172
+ >>> attack_inputs = {}
1173
+ >>> for key in list(inputs.keys()):
1174
+ ... attack_inputs[f"attack_{key}"] = inputs[key]
1175
+ >>> label_inputs = {}
1176
+ >>> for key in list(inputs.keys()):
1177
+ ... label_inputs[f"labels_{key}"] = inputs[key]
1178
+
1179
+ >>> inputs.update(label_inputs)
1180
+ >>> inputs.update(attack_inputs)
1181
+ >>> outputs = model(**inputs)
1182
+
1183
+ >>> logits = outputs.logits
1184
+ >>> logits.shape
1185
+ torch.Size([1, 11, 21128])
1186
+ ```
1187
+ """
1188
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1189
+
1190
+ outputs = self.roc_bert(
1191
+ input_ids,
1192
+ input_shape_ids=input_shape_ids,
1193
+ input_pronunciation_ids=input_pronunciation_ids,
1194
+ attention_mask=attention_mask,
1195
+ token_type_ids=token_type_ids,
1196
+ position_ids=position_ids,
1197
+ head_mask=head_mask,
1198
+ inputs_embeds=inputs_embeds,
1199
+ output_attentions=output_attentions,
1200
+ output_hidden_states=output_hidden_states,
1201
+ return_dict=return_dict,
1202
+ )
1203
+
1204
+ sequence_output, pooled_output = outputs[:2]
1205
+ prediction_scores = self.cls(sequence_output)
1206
+
1207
+ loss = None
1208
+ if labels_input_ids is not None:
1209
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1210
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels_input_ids.view(-1))
1211
+
1212
+ if attack_input_ids is not None:
1213
+ batch_size, _ = labels_input_ids.shape
1214
+ device = labels_input_ids.device
1215
+
1216
+ target_inputs = torch.clone(labels_input_ids)
1217
+ target_inputs[target_inputs == -100] = self.config.pad_token_id
1218
+
1219
+ labels_output = self.roc_bert(
1220
+ target_inputs,
1221
+ input_shape_ids=labels_input_shape_ids,
1222
+ input_pronunciation_ids=labels_input_pronunciation_ids,
1223
+ attention_mask=labels_attention_mask,
1224
+ token_type_ids=labels_token_type_ids,
1225
+ return_dict=return_dict,
1226
+ )
1227
+ attack_output = self.roc_bert(
1228
+ attack_input_ids,
1229
+ input_shape_ids=attack_input_shape_ids,
1230
+ input_pronunciation_ids=attack_input_pronunciation_ids,
1231
+ attention_mask=attack_attention_mask,
1232
+ token_type_ids=attack_token_type_ids,
1233
+ return_dict=return_dict,
1234
+ )
1235
+
1236
+ labels_pooled_output = labels_output[1]
1237
+ attack_pooled_output = attack_output[1]
1238
+
1239
+ pooled_output_norm = torch.nn.functional.normalize(pooled_output, dim=-1)
1240
+ labels_pooled_output_norm = torch.nn.functional.normalize(labels_pooled_output, dim=-1)
1241
+ attack_pooled_output_norm = torch.nn.functional.normalize(attack_pooled_output, dim=-1)
1242
+
1243
+ sim_matrix = torch.matmul(pooled_output_norm, attack_pooled_output_norm.T) # batch_size x batch_size
1244
+ sim_matrix_target = torch.matmul(labels_pooled_output_norm, attack_pooled_output_norm.T)
1245
+ batch_labels = torch.tensor(list(range(batch_size)), device=device)
1246
+ contrastive_loss = (
1247
+ loss_fct(100 * sim_matrix.view(batch_size, -1), batch_labels.view(-1))
1248
+ + loss_fct(100 * sim_matrix_target.view(batch_size, -1), batch_labels.view(-1))
1249
+ ) / 2
1250
+
1251
+ loss = contrastive_loss + masked_lm_loss
1252
+ else:
1253
+ loss = masked_lm_loss
1254
+
1255
+ if not return_dict:
1256
+ output = (prediction_scores,) + outputs[2:]
1257
+ return ((loss,) + output) if loss is not None else output
1258
+
1259
+ return MaskedLMOutput(
1260
+ loss=loss,
1261
+ logits=prediction_scores,
1262
+ hidden_states=outputs.hidden_states,
1263
+ attentions=outputs.attentions,
1264
+ )
1265
+
1266
+
1267
+ @add_start_docstrings("""RoCBert Model with a `language modeling` head on top.""", ROC_BERT_START_DOCSTRING)
1268
+ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
1269
+ _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
1270
+
1271
+ # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
1272
+ def __init__(self, config):
1273
+ super().__init__(config)
1274
+
1275
+ if config.is_decoder:
1276
+ logger.warning(
1277
+ "If you want to use `RoCBertForMaskedLM` make sure `config.is_decoder=False` for "
1278
+ "bi-directional self-attention."
1279
+ )
1280
+
1281
+ self.roc_bert = RoCBertModel(config, add_pooling_layer=False)
1282
+ self.cls = RoCBertOnlyMLMHead(config)
1283
+
1284
+ # Initialize weights and apply final processing
1285
+ self.post_init()
1286
+
1287
+ # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.get_output_embeddings
1288
+ def get_output_embeddings(self):
1289
+ return self.cls.predictions.decoder
1290
+
1291
+ # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
1292
+ def set_output_embeddings(self, new_embeddings):
1293
+ self.cls.predictions.decoder = new_embeddings
1294
+ self.cls.predictions.bias = new_embeddings.bias
1295
+
1296
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1297
+ def forward(
1298
+ self,
1299
+ input_ids: Optional[torch.Tensor] = None,
1300
+ input_shape_ids: Optional[torch.Tensor] = None,
1301
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1302
+ attention_mask: Optional[torch.Tensor] = None,
1303
+ token_type_ids: Optional[torch.Tensor] = None,
1304
+ position_ids: Optional[torch.Tensor] = None,
1305
+ head_mask: Optional[torch.Tensor] = None,
1306
+ inputs_embeds: Optional[torch.Tensor] = None,
1307
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1308
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1309
+ labels: Optional[torch.Tensor] = None,
1310
+ output_attentions: Optional[bool] = None,
1311
+ output_hidden_states: Optional[bool] = None,
1312
+ return_dict: Optional[bool] = None,
1313
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1314
+ r"""
1315
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1316
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1317
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1318
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1319
+
1320
+ Example:
1321
+ ```python
1322
+ >>> from transformers import AutoTokenizer, RoCBertForMaskedLM
1323
+ >>> import torch
1324
+
1325
+ >>> tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
1326
+ >>> model = RoCBertForMaskedLM.from_pretrained("weiweishi/roc-bert-base-zh")
1327
+
1328
+ >>> inputs = tokenizer("法国是首都[MASK].", return_tensors="pt")
1329
+
1330
+ >>> with torch.no_grad():
1331
+ ... logits = model(**inputs).logits
1332
+
1333
+ >>> # retrieve index of {mask}
1334
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
1335
+
1336
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
1337
+ >>> tokenizer.decode(predicted_token_id)
1338
+ '.'
1339
+ ```
1340
+ """
1341
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1342
+
1343
+ outputs = self.roc_bert(
1344
+ input_ids,
1345
+ input_shape_ids=input_shape_ids,
1346
+ input_pronunciation_ids=input_pronunciation_ids,
1347
+ attention_mask=attention_mask,
1348
+ token_type_ids=token_type_ids,
1349
+ position_ids=position_ids,
1350
+ head_mask=head_mask,
1351
+ inputs_embeds=inputs_embeds,
1352
+ encoder_hidden_states=encoder_hidden_states,
1353
+ encoder_attention_mask=encoder_attention_mask,
1354
+ output_attentions=output_attentions,
1355
+ output_hidden_states=output_hidden_states,
1356
+ return_dict=return_dict,
1357
+ )
1358
+
1359
+ sequence_output = outputs[0]
1360
+ prediction_scores = self.cls(sequence_output)
1361
+
1362
+ masked_lm_loss = None
1363
+ if labels is not None:
1364
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1365
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1366
+
1367
+ if not return_dict:
1368
+ output = (prediction_scores,) + outputs[2:]
1369
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1370
+
1371
+ return MaskedLMOutput(
1372
+ loss=masked_lm_loss,
1373
+ logits=prediction_scores,
1374
+ hidden_states=outputs.hidden_states,
1375
+ attentions=outputs.attentions,
1376
+ )
1377
+
1378
+ def prepare_inputs_for_generation(
1379
+ self, input_ids, input_shape_ids=None, input_pronunciation_ids=None, attention_mask=None, **model_kwargs
1380
+ ):
1381
+ input_shape = input_ids.shape
1382
+ effective_batch_size = input_shape[0]
1383
+
1384
+ # add a dummy token
1385
+ if self.config.pad_token_id is None:
1386
+ raise ValueError("The PAD token should be defined for generation")
1387
+
1388
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1389
+ dummy_token = torch.full(
1390
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1391
+ )
1392
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1393
+ if input_shape_ids is not None:
1394
+ input_shape_ids = torch.cat([input_shape_ids, dummy_token], dim=1)
1395
+ if input_pronunciation_ids is not None:
1396
+ input_pronunciation_ids = torch.cat([input_pronunciation_ids, dummy_token], dim=1)
1397
+
1398
+ return {
1399
+ "input_ids": input_ids,
1400
+ "input_shape_ids": input_shape_ids,
1401
+ "input_pronunciation_ids": input_pronunciation_ids,
1402
+ "attention_mask": attention_mask,
1403
+ }
1404
+
1405
+
1406
+ @add_start_docstrings(
1407
+ """RoCBert Model with a `language modeling` head on top for CLM fine-tuning.""", ROC_BERT_START_DOCSTRING
1408
+ )
1409
+ class RoCBertForCausalLM(RoCBertPreTrainedModel, GenerationMixin):
1410
+ _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
1411
+
1412
+ # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
1413
+ def __init__(self, config):
1414
+ super().__init__(config)
1415
+
1416
+ if not config.is_decoder:
1417
+ logger.warning("If you want to use `RoCBertForCausalLM` as a standalone, add `is_decoder=True`.")
1418
+
1419
+ self.roc_bert = RoCBertModel(config, add_pooling_layer=False)
1420
+ self.cls = RoCBertOnlyMLMHead(config)
1421
+
1422
+ # Initialize weights and apply final processing
1423
+ self.post_init()
1424
+
1425
+ # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.get_output_embeddings
1426
+ def get_output_embeddings(self):
1427
+ return self.cls.predictions.decoder
1428
+
1429
+ # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
1430
+ def set_output_embeddings(self, new_embeddings):
1431
+ self.cls.predictions.decoder = new_embeddings
1432
+ self.cls.predictions.bias = new_embeddings.bias
1433
+
1434
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1435
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1436
+ def forward(
1437
+ self,
1438
+ input_ids: Optional[torch.Tensor] = None,
1439
+ input_shape_ids: Optional[torch.Tensor] = None,
1440
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1441
+ attention_mask: Optional[torch.Tensor] = None,
1442
+ token_type_ids: Optional[torch.Tensor] = None,
1443
+ position_ids: Optional[torch.Tensor] = None,
1444
+ inputs_embeds: Optional[torch.Tensor] = None,
1445
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1446
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1447
+ head_mask: Optional[torch.Tensor] = None,
1448
+ past_key_values: Optional[List[torch.Tensor]] = None,
1449
+ labels: Optional[torch.Tensor] = None,
1450
+ use_cache: Optional[bool] = None,
1451
+ output_attentions: Optional[bool] = None,
1452
+ output_hidden_states: Optional[bool] = None,
1453
+ return_dict: Optional[bool] = None,
1454
+ **kwargs,
1455
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1456
+ r"""
1457
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1458
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1459
+ the model is configured as a decoder.
1460
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1461
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1462
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1463
+
1464
+ - 1 for tokens that are **not masked**,
1465
+ - 0 for tokens that are **masked**.
1466
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1467
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1468
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1469
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
1470
+ only required when the model is used as a decoder in a Sequence to Sequence model.
1471
+
1472
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1473
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1474
+
1475
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1476
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1477
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1478
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1479
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1480
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1481
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
1482
+ use_cache (`bool`, *optional*):
1483
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1484
+ `past_key_values`).
1485
+
1486
+ Returns:
1487
+
1488
+ Example:
1489
+
1490
+ ```python
1491
+ >>> from transformers import AutoTokenizer, RoCBertForCausalLM, RoCBertConfig
1492
+ >>> import torch
1493
+
1494
+ >>> tokenizer = AutoTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
1495
+ >>> config = RoCBertConfig.from_pretrained("weiweishi/roc-bert-base-zh")
1496
+ >>> config.is_decoder = True
1497
+ >>> model = RoCBertForCausalLM.from_pretrained("weiweishi/roc-bert-base-zh", config=config)
1498
+
1499
+ >>> inputs = tokenizer("你好,很高兴认识你", return_tensors="pt")
1500
+ >>> outputs = model(**inputs)
1501
+
1502
+ >>> prediction_logits = outputs.logits
1503
+ ```
1504
+ """
1505
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1506
+
1507
+ outputs = self.roc_bert(
1508
+ input_ids,
1509
+ input_shape_ids=input_shape_ids,
1510
+ input_pronunciation_ids=input_pronunciation_ids,
1511
+ attention_mask=attention_mask,
1512
+ token_type_ids=token_type_ids,
1513
+ position_ids=position_ids,
1514
+ head_mask=head_mask,
1515
+ inputs_embeds=inputs_embeds,
1516
+ encoder_hidden_states=encoder_hidden_states,
1517
+ encoder_attention_mask=encoder_attention_mask,
1518
+ past_key_values=past_key_values,
1519
+ use_cache=use_cache,
1520
+ output_attentions=output_attentions,
1521
+ output_hidden_states=output_hidden_states,
1522
+ return_dict=return_dict,
1523
+ )
1524
+
1525
+ sequence_output = outputs[0]
1526
+ prediction_scores = self.cls(sequence_output)
1527
+
1528
+ lm_loss = None
1529
+ if labels is not None:
1530
+ lm_loss = self.loss_function(
1531
+ prediction_scores,
1532
+ labels,
1533
+ vocab_size=self.config.vocab_size,
1534
+ **kwargs,
1535
+ )
1536
+
1537
+ if not return_dict:
1538
+ output = (prediction_scores,) + outputs[2:]
1539
+ return ((lm_loss,) + output) if lm_loss is not None else output
1540
+
1541
+ return CausalLMOutputWithCrossAttentions(
1542
+ loss=lm_loss,
1543
+ logits=prediction_scores,
1544
+ past_key_values=outputs.past_key_values,
1545
+ hidden_states=outputs.hidden_states,
1546
+ attentions=outputs.attentions,
1547
+ cross_attentions=outputs.cross_attentions,
1548
+ )
1549
+
1550
+ def prepare_inputs_for_generation(
1551
+ self,
1552
+ input_ids,
1553
+ input_shape_ids=None,
1554
+ input_pronunciation_ids=None,
1555
+ past_key_values=None,
1556
+ attention_mask=None,
1557
+ **model_kwargs,
1558
+ ):
1559
+ # Overwritten -- `input_pronunciation_ids`
1560
+
1561
+ input_shape = input_ids.shape
1562
+
1563
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1564
+ if attention_mask is None:
1565
+ attention_mask = input_ids.new_ones(input_shape)
1566
+
1567
+ # cut decoder_input_ids if past_key_values is used
1568
+ if past_key_values is not None:
1569
+ past_length = past_key_values[0][0].shape[2]
1570
+
1571
+ # Some generation methods already pass only the last input ID
1572
+ if input_ids.shape[1] > past_length:
1573
+ remove_prefix_length = past_length
1574
+ else:
1575
+ # Default to old behavior: keep only final ID
1576
+ remove_prefix_length = input_ids.shape[1] - 1
1577
+
1578
+ input_ids = input_ids[:, remove_prefix_length:]
1579
+ if input_shape_ids is not None:
1580
+ input_shape_ids = input_shape_ids[:, -1:]
1581
+ if input_pronunciation_ids is not None:
1582
+ input_pronunciation_ids = input_pronunciation_ids[:, -1:]
1583
+
1584
+ return {
1585
+ "input_ids": input_ids,
1586
+ "input_shape_ids": input_shape_ids,
1587
+ "input_pronunciation_ids": input_pronunciation_ids,
1588
+ "attention_mask": attention_mask,
1589
+ "past_key_values": past_key_values,
1590
+ }
1591
+
1592
+ # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
1593
+ def _reorder_cache(self, past_key_values, beam_idx):
1594
+ reordered_past = ()
1595
+ for layer_past in past_key_values:
1596
+ reordered_past += (
1597
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1598
+ )
1599
+ return reordered_past
1600
+
1601
+
1602
+ @add_start_docstrings(
1603
+ """RoCBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
1604
+ the pooled output) e.g. for GLUE tasks.""",
1605
+ ROC_BERT_START_DOCSTRING,
1606
+ )
1607
+ class RoCBertForSequenceClassification(RoCBertPreTrainedModel):
1608
+ # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with Bert->RoCBert,bert->roc_bert
1609
+ def __init__(self, config):
1610
+ super().__init__(config)
1611
+ self.num_labels = config.num_labels
1612
+ self.config = config
1613
+
1614
+ self.roc_bert = RoCBertModel(config)
1615
+ classifier_dropout = (
1616
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1617
+ )
1618
+ self.dropout = nn.Dropout(classifier_dropout)
1619
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1620
+
1621
+ # Initialize weights and apply final processing
1622
+ self.post_init()
1623
+
1624
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1625
+ @add_code_sample_docstrings(
1626
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1627
+ output_type=SequenceClassifierOutput,
1628
+ config_class=_CONFIG_FOR_DOC,
1629
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1630
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1631
+ )
1632
+ def forward(
1633
+ self,
1634
+ input_ids: Optional[torch.Tensor] = None,
1635
+ input_shape_ids: Optional[torch.Tensor] = None,
1636
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1637
+ attention_mask: Optional[torch.Tensor] = None,
1638
+ token_type_ids: Optional[torch.Tensor] = None,
1639
+ position_ids: Optional[torch.Tensor] = None,
1640
+ head_mask: Optional[torch.Tensor] = None,
1641
+ inputs_embeds: Optional[torch.Tensor] = None,
1642
+ labels: Optional[torch.Tensor] = None,
1643
+ output_attentions: Optional[bool] = None,
1644
+ output_hidden_states: Optional[bool] = None,
1645
+ return_dict: Optional[bool] = None,
1646
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1647
+ r"""
1648
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1649
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1650
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1651
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1652
+ """
1653
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1654
+
1655
+ outputs = self.roc_bert(
1656
+ input_ids,
1657
+ input_shape_ids=input_shape_ids,
1658
+ input_pronunciation_ids=input_pronunciation_ids,
1659
+ attention_mask=attention_mask,
1660
+ token_type_ids=token_type_ids,
1661
+ position_ids=position_ids,
1662
+ head_mask=head_mask,
1663
+ inputs_embeds=inputs_embeds,
1664
+ output_attentions=output_attentions,
1665
+ output_hidden_states=output_hidden_states,
1666
+ return_dict=return_dict,
1667
+ )
1668
+
1669
+ pooled_output = outputs[1]
1670
+
1671
+ pooled_output = self.dropout(pooled_output)
1672
+ logits = self.classifier(pooled_output)
1673
+
1674
+ loss = None
1675
+ if labels is not None:
1676
+ if self.config.problem_type is None:
1677
+ if self.num_labels == 1:
1678
+ self.config.problem_type = "regression"
1679
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1680
+ self.config.problem_type = "single_label_classification"
1681
+ else:
1682
+ self.config.problem_type = "multi_label_classification"
1683
+
1684
+ if self.config.problem_type == "regression":
1685
+ loss_fct = MSELoss()
1686
+ if self.num_labels == 1:
1687
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1688
+ else:
1689
+ loss = loss_fct(logits, labels)
1690
+ elif self.config.problem_type == "single_label_classification":
1691
+ loss_fct = CrossEntropyLoss()
1692
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1693
+ elif self.config.problem_type == "multi_label_classification":
1694
+ loss_fct = BCEWithLogitsLoss()
1695
+ loss = loss_fct(logits, labels)
1696
+ if not return_dict:
1697
+ output = (logits,) + outputs[2:]
1698
+ return ((loss,) + output) if loss is not None else output
1699
+
1700
+ return SequenceClassifierOutput(
1701
+ loss=loss,
1702
+ logits=logits,
1703
+ hidden_states=outputs.hidden_states,
1704
+ attentions=outputs.attentions,
1705
+ )
1706
+
1707
+
1708
+ @add_start_docstrings(
1709
+ """RoCBert Model with a multiple choice classification head on top (a linear layer on top of
1710
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks.""",
1711
+ ROC_BERT_START_DOCSTRING,
1712
+ )
1713
+ class RoCBertForMultipleChoice(RoCBertPreTrainedModel):
1714
+ # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.__init__ with Bert->RoCBert,bert->roc_bert
1715
+ def __init__(self, config):
1716
+ super().__init__(config)
1717
+
1718
+ self.roc_bert = RoCBertModel(config)
1719
+ classifier_dropout = (
1720
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1721
+ )
1722
+ self.dropout = nn.Dropout(classifier_dropout)
1723
+ self.classifier = nn.Linear(config.hidden_size, 1)
1724
+
1725
+ # Initialize weights and apply final processing
1726
+ self.post_init()
1727
+
1728
+ @add_start_docstrings_to_model_forward(
1729
+ ROC_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1730
+ )
1731
+ @add_code_sample_docstrings(
1732
+ checkpoint=_CHECKPOINT_FOR_DOC,
1733
+ output_type=MultipleChoiceModelOutput,
1734
+ config_class=_CONFIG_FOR_DOC,
1735
+ )
1736
+ def forward(
1737
+ self,
1738
+ input_ids: Optional[torch.Tensor] = None,
1739
+ input_shape_ids: Optional[torch.Tensor] = None,
1740
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1741
+ attention_mask: Optional[torch.Tensor] = None,
1742
+ token_type_ids: Optional[torch.Tensor] = None,
1743
+ position_ids: Optional[torch.Tensor] = None,
1744
+ head_mask: Optional[torch.Tensor] = None,
1745
+ inputs_embeds: Optional[torch.Tensor] = None,
1746
+ labels: Optional[torch.Tensor] = None,
1747
+ output_attentions: Optional[bool] = None,
1748
+ output_hidden_states: Optional[bool] = None,
1749
+ return_dict: Optional[bool] = None,
1750
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1751
+ r"""
1752
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1753
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1754
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1755
+ `input_ids` above)
1756
+ """
1757
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1758
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1759
+
1760
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1761
+ input_shape_ids = input_shape_ids.view(-1, input_shape_ids.size(-1)) if input_shape_ids is not None else None
1762
+ input_pronunciation_ids = (
1763
+ input_pronunciation_ids.view(-1, input_pronunciation_ids.size(-1))
1764
+ if input_pronunciation_ids is not None
1765
+ else None
1766
+ )
1767
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1768
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1769
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1770
+ inputs_embeds = (
1771
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1772
+ if inputs_embeds is not None
1773
+ else None
1774
+ )
1775
+
1776
+ outputs = self.roc_bert(
1777
+ input_ids,
1778
+ input_shape_ids=input_shape_ids,
1779
+ input_pronunciation_ids=input_pronunciation_ids,
1780
+ attention_mask=attention_mask,
1781
+ token_type_ids=token_type_ids,
1782
+ position_ids=position_ids,
1783
+ head_mask=head_mask,
1784
+ inputs_embeds=inputs_embeds,
1785
+ output_attentions=output_attentions,
1786
+ output_hidden_states=output_hidden_states,
1787
+ return_dict=return_dict,
1788
+ )
1789
+
1790
+ pooled_output = outputs[1]
1791
+
1792
+ pooled_output = self.dropout(pooled_output)
1793
+ logits = self.classifier(pooled_output)
1794
+ reshaped_logits = logits.view(-1, num_choices)
1795
+
1796
+ loss = None
1797
+ if labels is not None:
1798
+ loss_fct = CrossEntropyLoss()
1799
+ loss = loss_fct(reshaped_logits, labels)
1800
+
1801
+ if not return_dict:
1802
+ output = (reshaped_logits,) + outputs[2:]
1803
+ return ((loss,) + output) if loss is not None else output
1804
+
1805
+ return MultipleChoiceModelOutput(
1806
+ loss=loss,
1807
+ logits=reshaped_logits,
1808
+ hidden_states=outputs.hidden_states,
1809
+ attentions=outputs.attentions,
1810
+ )
1811
+
1812
+
1813
+ @add_start_docstrings(
1814
+ """RoCBert Model with a token classification head on top (a linear layer on top of
1815
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""",
1816
+ ROC_BERT_START_DOCSTRING,
1817
+ )
1818
+ class RoCBertForTokenClassification(RoCBertPreTrainedModel):
1819
+ # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->RoCBert,bert->roc_bert
1820
+ def __init__(self, config):
1821
+ super().__init__(config)
1822
+ self.num_labels = config.num_labels
1823
+
1824
+ self.roc_bert = RoCBertModel(config, add_pooling_layer=False)
1825
+ classifier_dropout = (
1826
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1827
+ )
1828
+ self.dropout = nn.Dropout(classifier_dropout)
1829
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1830
+
1831
+ # Initialize weights and apply final processing
1832
+ self.post_init()
1833
+
1834
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1835
+ @add_code_sample_docstrings(
1836
+ checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
1837
+ output_type=TokenClassifierOutput,
1838
+ config_class=_CONFIG_FOR_DOC,
1839
+ expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
1840
+ expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
1841
+ )
1842
+ def forward(
1843
+ self,
1844
+ input_ids: Optional[torch.Tensor] = None,
1845
+ input_shape_ids: Optional[torch.Tensor] = None,
1846
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1847
+ attention_mask: Optional[torch.Tensor] = None,
1848
+ token_type_ids: Optional[torch.Tensor] = None,
1849
+ position_ids: Optional[torch.Tensor] = None,
1850
+ head_mask: Optional[torch.Tensor] = None,
1851
+ inputs_embeds: Optional[torch.Tensor] = None,
1852
+ labels: Optional[torch.Tensor] = None,
1853
+ output_attentions: Optional[bool] = None,
1854
+ output_hidden_states: Optional[bool] = None,
1855
+ return_dict: Optional[bool] = None,
1856
+ ) -> Union[Tuple, TokenClassifierOutput]:
1857
+ r"""
1858
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1859
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1860
+ """
1861
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1862
+
1863
+ outputs = self.roc_bert(
1864
+ input_ids,
1865
+ input_shape_ids=input_shape_ids,
1866
+ input_pronunciation_ids=input_pronunciation_ids,
1867
+ attention_mask=attention_mask,
1868
+ token_type_ids=token_type_ids,
1869
+ position_ids=position_ids,
1870
+ head_mask=head_mask,
1871
+ inputs_embeds=inputs_embeds,
1872
+ output_attentions=output_attentions,
1873
+ output_hidden_states=output_hidden_states,
1874
+ return_dict=return_dict,
1875
+ )
1876
+
1877
+ sequence_output = outputs[0]
1878
+
1879
+ sequence_output = self.dropout(sequence_output)
1880
+ logits = self.classifier(sequence_output)
1881
+
1882
+ loss = None
1883
+ if labels is not None:
1884
+ loss_fct = CrossEntropyLoss()
1885
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1886
+
1887
+ if not return_dict:
1888
+ output = (logits,) + outputs[2:]
1889
+ return ((loss,) + output) if loss is not None else output
1890
+
1891
+ return TokenClassifierOutput(
1892
+ loss=loss,
1893
+ logits=logits,
1894
+ hidden_states=outputs.hidden_states,
1895
+ attentions=outputs.attentions,
1896
+ )
1897
+
1898
+
1899
+ @add_start_docstrings(
1900
+ """RoCBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1901
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""",
1902
+ ROC_BERT_START_DOCSTRING,
1903
+ )
1904
+ class RoCBertForQuestionAnswering(RoCBertPreTrainedModel):
1905
+ # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->RoCBert,bert->roc_bert
1906
+ def __init__(self, config):
1907
+ super().__init__(config)
1908
+ self.num_labels = config.num_labels
1909
+
1910
+ self.roc_bert = RoCBertModel(config, add_pooling_layer=False)
1911
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1912
+
1913
+ # Initialize weights and apply final processing
1914
+ self.post_init()
1915
+
1916
+ @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1917
+ @add_code_sample_docstrings(
1918
+ checkpoint=_CHECKPOINT_FOR_QA,
1919
+ output_type=QuestionAnsweringModelOutput,
1920
+ config_class=_CONFIG_FOR_DOC,
1921
+ qa_target_start_index=_QA_TARGET_START_INDEX,
1922
+ qa_target_end_index=_QA_TARGET_END_INDEX,
1923
+ expected_output=_QA_EXPECTED_OUTPUT,
1924
+ expected_loss=_QA_EXPECTED_LOSS,
1925
+ )
1926
+ def forward(
1927
+ self,
1928
+ input_ids: Optional[torch.Tensor] = None,
1929
+ input_shape_ids: Optional[torch.Tensor] = None,
1930
+ input_pronunciation_ids: Optional[torch.Tensor] = None,
1931
+ attention_mask: Optional[torch.Tensor] = None,
1932
+ token_type_ids: Optional[torch.Tensor] = None,
1933
+ position_ids: Optional[torch.Tensor] = None,
1934
+ head_mask: Optional[torch.Tensor] = None,
1935
+ inputs_embeds: Optional[torch.Tensor] = None,
1936
+ start_positions: Optional[torch.Tensor] = None,
1937
+ end_positions: Optional[torch.Tensor] = None,
1938
+ output_attentions: Optional[bool] = None,
1939
+ output_hidden_states: Optional[bool] = None,
1940
+ return_dict: Optional[bool] = None,
1941
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1942
+ r"""
1943
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1944
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1945
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1946
+ are not taken into account for computing the loss.
1947
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1948
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1949
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1950
+ are not taken into account for computing the loss.
1951
+ """
1952
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1953
+
1954
+ outputs = self.roc_bert(
1955
+ input_ids,
1956
+ input_shape_ids=input_shape_ids,
1957
+ input_pronunciation_ids=input_pronunciation_ids,
1958
+ attention_mask=attention_mask,
1959
+ token_type_ids=token_type_ids,
1960
+ position_ids=position_ids,
1961
+ head_mask=head_mask,
1962
+ inputs_embeds=inputs_embeds,
1963
+ output_attentions=output_attentions,
1964
+ output_hidden_states=output_hidden_states,
1965
+ return_dict=return_dict,
1966
+ )
1967
+
1968
+ sequence_output = outputs[0]
1969
+
1970
+ logits = self.qa_outputs(sequence_output)
1971
+ start_logits, end_logits = logits.split(1, dim=-1)
1972
+ start_logits = start_logits.squeeze(-1)
1973
+ end_logits = end_logits.squeeze(-1)
1974
+
1975
+ total_loss = None
1976
+ if start_positions is not None and end_positions is not None:
1977
+ # If we are on multi-GPU, split adds a dimension
1978
+ if len(start_positions.size()) > 1:
1979
+ start_positions = start_positions.squeeze(-1)
1980
+ if len(end_positions.size()) > 1:
1981
+ end_positions = end_positions.squeeze(-1)
1982
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1983
+ ignored_index = start_logits.size(1)
1984
+ start_positions = start_positions.clamp(0, ignored_index)
1985
+ end_positions = end_positions.clamp(0, ignored_index)
1986
+
1987
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1988
+ start_loss = loss_fct(start_logits, start_positions)
1989
+ end_loss = loss_fct(end_logits, end_positions)
1990
+ total_loss = (start_loss + end_loss) / 2
1991
+
1992
+ if not return_dict:
1993
+ output = (start_logits, end_logits) + outputs[2:]
1994
+ return ((total_loss,) + output) if total_loss is not None else output
1995
+
1996
+ return QuestionAnsweringModelOutput(
1997
+ loss=total_loss,
1998
+ start_logits=start_logits,
1999
+ end_logits=end_logits,
2000
+ hidden_states=outputs.hidden_states,
2001
+ attentions=outputs.attentions,
2002
+ )
2003
+
2004
+
2005
+ __all__ = [
2006
+ "RoCBertForCausalLM",
2007
+ "RoCBertForMaskedLM",
2008
+ "RoCBertForMultipleChoice",
2009
+ "RoCBertForPreTraining",
2010
+ "RoCBertForQuestionAnswering",
2011
+ "RoCBertForSequenceClassification",
2012
+ "RoCBertForTokenClassification",
2013
+ "RoCBertLayer",
2014
+ "RoCBertModel",
2015
+ "RoCBertPreTrainedModel",
2016
+ "load_tf_weights_in_roc_bert",
2017
+ ]
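A minimal sketch of the contrastive objective that `RoCBertForPreTraining.forward` above combines with the masked-LM loss, using random tensors in place of the pooled `[CLS]` outputs; the tensor names are illustrative only, not part of the library API.

```python
import torch
from torch.nn import CrossEntropyLoss

# Stand-ins for the pooled outputs of the clean, label and adversarial ("attack") inputs.
batch_size, hidden_size = 4, 8
pooled = torch.nn.functional.normalize(torch.randn(batch_size, hidden_size), dim=-1)
labels_pooled = torch.nn.functional.normalize(torch.randn(batch_size, hidden_size), dim=-1)
attack_pooled = torch.nn.functional.normalize(torch.randn(batch_size, hidden_size), dim=-1)

loss_fct = CrossEntropyLoss()
# Row i of each similarity matrix should peak at column i (its own attack sample),
# so the classification target is simply the diagonal index.
sim_matrix = torch.matmul(pooled, attack_pooled.T)            # batch_size x batch_size
sim_matrix_target = torch.matmul(labels_pooled, attack_pooled.T)
batch_labels = torch.arange(batch_size)
# The factor of 100 acts as an inverse temperature, matching the forward pass above.
contrastive_loss = (
    loss_fct(100 * sim_matrix, batch_labels) + loss_fct(100 * sim_matrix_target, batch_labels)
) / 2
print(contrastive_loss.item())
```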
docs/transformers/build/lib/transformers/models/roformer/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_roformer import *
22
+ from .modeling_flax_roformer import *
23
+ from .modeling_roformer import *
24
+ from .modeling_tf_roformer import *
25
+ from .tokenization_roformer import *
26
+ from .tokenization_roformer_fast import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
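The `_LazyModule` assignment above defers importing the heavy modeling files until a symbol is actually accessed; from the user's side the import looks ordinary. A small smoke-test sketch, assuming a `transformers` installation that ships RoFormer:

```python
# The lazy module resolves RoFormerConfig / RoFormerModel only at first access.
from transformers import RoFormerConfig, RoFormerModel

# A deliberately tiny, randomly initialised configuration for a quick check.
config = RoFormerConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, intermediate_size=128)
model = RoFormerModel(config)
print(model.config.model_type)  # "roformer"
```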
docs/transformers/build/lib/transformers/models/roformer/configuration_roformer.py ADDED
@@ -0,0 +1,150 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RoFormer model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...onnx import OnnxConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class RoFormerConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`RoFormerModel`]. It is used to instantiate an
31
+ RoFormer model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
+ with the defaults will yield a similar configuration to that of the RoFormer
33
+ [junnyu/roformer_chinese_base](https://huggingface.co/junnyu/roformer_chinese_base) architecture.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50000):
41
+ Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
42
+ the `inputs_ids` passed when calling [`RoFormerModel`] or [`TFRoFormerModel`].
43
+ embedding_size (`int`, *optional*, defaults to None):
44
+ Dimensionality of the encoder layers and the pooler layer. Defaults to the `hidden_size` if not provided.
45
+ hidden_size (`int`, *optional*, defaults to 768):
46
+ Dimension of the encoder layers and the pooler layer.
47
+ num_hidden_layers (`int`, *optional*, defaults to 12):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 12):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ intermediate_size (`int`, *optional*, defaults to 3072):
52
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
53
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
54
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
55
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
56
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
57
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
59
+ The dropout ratio for the attention probabilities.
60
+ max_position_embeddings (`int`, *optional*, defaults to 1536):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 1536).
63
+ type_vocab_size (`int`, *optional*, defaults to 2):
64
+ The vocabulary size of the `token_type_ids` passed when calling [`RoFormerModel`] or [`TFRoFormerModel`].
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
68
+ The epsilon used by the layer normalization layers.
69
+ is_decoder (`bool`, *optional*, defaults to `False`):
70
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
71
+ use_cache (`bool`, *optional*, defaults to `True`):
72
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
73
+ relevant if `config.is_decoder=True`.
74
+ rotary_value (`bool`, *optional*, defaults to `False`):
75
+ Whether or not apply rotary position embeddings on value layer.
76
+
77
+ Example:
78
+
79
+ ```python
80
+ >>> from transformers import RoFormerModel, RoFormerConfig
81
+
82
+ >>> # Initializing a RoFormer junnyu/roformer_chinese_base style configuration
83
+ >>> configuration = RoFormerConfig()
84
+
85
+ >>> # Initializing a model (with random weights) from the junnyu/roformer_chinese_base style configuration
86
+ >>> model = RoFormerModel(configuration)
87
+
88
+ >>> # Accessing the model configuration
89
+ >>> configuration = model.config
90
+ ```"""
91
+
92
+ model_type = "roformer"
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=50000,
97
+ embedding_size=None,
98
+ hidden_size=768,
99
+ num_hidden_layers=12,
100
+ num_attention_heads=12,
101
+ intermediate_size=3072,
102
+ hidden_act="gelu",
103
+ hidden_dropout_prob=0.1,
104
+ attention_probs_dropout_prob=0.1,
105
+ max_position_embeddings=1536,
106
+ type_vocab_size=2,
107
+ initializer_range=0.02,
108
+ layer_norm_eps=1e-12,
109
+ pad_token_id=0,
110
+ rotary_value=False,
111
+ use_cache=True,
112
+ **kwargs,
113
+ ):
114
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
115
+
116
+ self.vocab_size = vocab_size
117
+ self.embedding_size = hidden_size if embedding_size is None else embedding_size
118
+ self.hidden_size = hidden_size
119
+ self.num_hidden_layers = num_hidden_layers
120
+ self.num_attention_heads = num_attention_heads
121
+ self.hidden_act = hidden_act
122
+ self.intermediate_size = intermediate_size
123
+ self.hidden_dropout_prob = hidden_dropout_prob
124
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
125
+ self.max_position_embeddings = max_position_embeddings
126
+ self.type_vocab_size = type_vocab_size
127
+ self.initializer_range = initializer_range
128
+ self.layer_norm_eps = layer_norm_eps
129
+ self.rotary_value = rotary_value
130
+ self.use_cache = use_cache
131
+
132
+
133
+ class RoFormerOnnxConfig(OnnxConfig):
134
+ @property
135
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
136
+ if self.task == "multiple-choice":
137
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
138
+ else:
139
+ dynamic_axis = {0: "batch", 1: "sequence"}
141
+ return OrderedDict(
142
+ [
143
+ ("input_ids", dynamic_axis),
144
+ ("attention_mask", dynamic_axis),
145
+ ("token_type_ids", dynamic_axis),
146
+ ]
147
+ )
148
+
149
+
150
+ __all__ = ["RoFormerConfig", "RoFormerOnnxConfig"]
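To make the task-dependent ONNX input shapes explicit, here is a standalone sketch of the axis-selection logic in `RoFormerOnnxConfig.inputs` above (plain Python, no `transformers` import; the helper name is made up for illustration):

```python
from collections import OrderedDict


def onnx_input_axes(task: str) -> OrderedDict:
    # Multiple-choice inputs carry an extra "choice" dimension between batch and sequence.
    if task == "multiple-choice":
        dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
    else:
        dynamic_axis = {0: "batch", 1: "sequence"}
    return OrderedDict(
        [
            ("input_ids", dynamic_axis),
            ("attention_mask", dynamic_axis),
            ("token_type_ids", dynamic_axis),
        ]
    )


print(onnx_input_axes("multiple-choice")["input_ids"])  # {0: 'batch', 1: 'choice', 2: 'sequence'}
print(onnx_input_axes("default")["input_ids"])  # {0: 'batch', 1: 'sequence'}
```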
docs/transformers/build/lib/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,62 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert RoFormer checkpoint."""
16
+
17
+ import argparse
18
+
19
+ import torch
20
+
21
+ from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer
22
+ from transformers.utils import logging
23
+
24
+
25
+ logging.set_verbosity_info()
26
+
27
+
28
+ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
29
+ # Initialise PyTorch model
30
+ config = RoFormerConfig.from_json_file(bert_config_file)
31
+ print(f"Building PyTorch model from configuration: {config}")
32
+ model = RoFormerForMaskedLM(config)
33
+
34
+ # Load weights from tf checkpoint
35
+ load_tf_weights_in_roformer(model, config, tf_checkpoint_path)
36
+
37
+ # Save pytorch-model
38
+ print(f"Save PyTorch model to {pytorch_dump_path}")
39
+ torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False)
40
+
41
+
42
+ if __name__ == "__main__":
43
+ parser = argparse.ArgumentParser()
44
+ # Required parameters
45
+ parser.add_argument(
46
+ "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
47
+ )
48
+ parser.add_argument(
49
+ "--bert_config_file",
50
+ default=None,
51
+ type=str,
52
+ required=True,
53
+ help=(
54
+ "The config json file corresponding to the pre-trained BERT model. \n"
55
+ "This specifies the model architecture."
56
+ ),
57
+ )
58
+ parser.add_argument(
59
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
60
+ )
61
+ args = parser.parse_args()
62
+ convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
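The conversion can also be driven programmatically instead of through `argparse`; a sketch with placeholder paths (the paths are hypothetical, the calls mirror `convert_tf_checkpoint_to_pytorch` above, and TensorFlow must be installed for the weight-loading step):

```python
import torch

from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer

# Hypothetical locations -- substitute your own checkpoint, config and output paths.
tf_checkpoint_path = "/path/to/roformer/bert_model.ckpt"
bert_config_file = "/path/to/roformer/bert_config.json"
pytorch_dump_path = "/path/to/roformer/pytorch_model.bin"

config = RoFormerConfig.from_json_file(bert_config_file)
model = RoFormerForMaskedLM(config)
load_tf_weights_in_roformer(model, config, tf_checkpoint_path)
torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False)
```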
docs/transformers/build/lib/transformers/models/roformer/modeling_roformer.py ADDED
@@ -0,0 +1,1660 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch RoFormer model."""
16
+
17
+ import math
18
+ import os
19
+ from typing import Callable, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from ...activations import ACT2FN, get_activation
28
+ from ...generation import GenerationMixin
29
+ from ...modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from ...modeling_utils import PreTrainedModel
39
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from ...utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_roformer import RoFormerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base"
53
+ _CONFIG_FOR_DOC = "RoFormerConfig"
54
+
55
+
56
+ # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
57
+ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
58
+ """This module produces sinusoidal positional embeddings of any length."""
59
+
60
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
61
+ super().__init__(num_positions, embedding_dim)
62
+
63
+ def _init_weight(self):
64
+ """
65
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
66
+ the 2nd half of the vector. [dim // 2:]
67
+ """
68
+ n_pos, dim = self.weight.shape
69
+ position_enc = np.array(
70
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
71
+ )
72
+ out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
73
+ sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
74
+ out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
75
+ out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
76
+ self.weight = nn.Parameter(out, requires_grad=False)
77
+
78
+ @torch.no_grad()
79
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
80
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
81
+ bsz, seq_len = input_ids_shape[:2]
82
+ positions = torch.arange(
83
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
84
+ )
85
+ return super().forward(positions)
86
+
87
+
88
+ def load_tf_weights_in_roformer(model, config, tf_checkpoint_path):
89
+ """Load tf checkpoints in a pytorch model."""
90
+ try:
91
+ import re
92
+
93
+ import numpy as np
94
+ import tensorflow as tf
95
+ except ImportError:
96
+ logger.error(
97
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
98
+ "https://www.tensorflow.org/install/ for installation instructions."
99
+ )
100
+ raise
101
+ tf_path = os.path.abspath(tf_checkpoint_path)
102
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
103
+ # Load weights from TF model
104
+ init_vars = tf.train.list_variables(tf_path)
105
+ names = []
106
+ arrays = []
107
+ for name, shape in init_vars:
108
+ logger.info(f"Loading TF weight {name} with shape {shape}")
109
+ array = tf.train.load_variable(tf_path, name)
110
+ names.append(name.replace("bert", "roformer"))
111
+ arrays.append(array)
112
+
113
+ for name, array in zip(names, arrays):
114
+ name = name.split("/")
115
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
116
+ # which are not required for using pretrained model
117
+ if any(
118
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
119
+ for n in name
120
+ ):
121
+ logger.info(f"Skipping {'/'.join(name)}")
122
+ continue
123
+ pointer = model
124
+ for m_name in name:
125
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
126
+ scope_names = re.split(r"_(\d+)", m_name)
127
+ else:
128
+ scope_names = [m_name]
129
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
130
+ pointer = getattr(pointer, "weight")
131
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
132
+ pointer = getattr(pointer, "bias")
133
+ elif scope_names[0] == "output_weights":
134
+ pointer = getattr(pointer, "weight")
135
+ elif scope_names[0] == "squad":
136
+ pointer = getattr(pointer, "classifier")
137
+ else:
138
+ try:
139
+ pointer = getattr(pointer, scope_names[0])
140
+ except AttributeError:
141
+ logger.info(f"Skipping {'/'.join(name)}")
142
+ continue
143
+ if len(scope_names) >= 2:
144
+ num = int(scope_names[1])
145
+ pointer = pointer[num]
146
+ if m_name[-11:] == "_embeddings":
147
+ pointer = getattr(pointer, "weight")
148
+ elif m_name == "kernel":
149
+ array = np.transpose(array)
150
+ try:
151
+ if not pointer.shape == array.shape:
152
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
153
+ except AssertionError as e:
154
+ e.args += (pointer.shape, array.shape)
155
+ raise
156
+ logger.info(f"Initialize PyTorch weight {name}")
157
+ pointer.data = torch.from_numpy(array)
158
+ return model
159
+
160
+
161
+ class RoFormerEmbeddings(nn.Module):
162
+ """Construct the embeddings from word and token_type embeddings."""
163
+
164
+ def __init__(self, config):
165
+ super().__init__()
166
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
167
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
168
+
169
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
170
+ # any TensorFlow checkpoint file
171
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
172
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
173
+
174
+ def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None):
175
+ if input_ids is not None:
176
+ input_shape = input_ids.size()
177
+ else:
178
+ input_shape = inputs_embeds.size()[:-1]
179
+
180
+ if inputs_embeds is None:
181
+ inputs_embeds = self.word_embeddings(input_ids)
182
+
183
+ if token_type_ids is None:
184
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
185
+
186
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
187
+
188
+ embeddings = inputs_embeds + token_type_embeddings
189
+
190
+ embeddings = self.LayerNorm(embeddings)
191
+ embeddings = self.dropout(embeddings)
192
+ return embeddings
193
+
194
+
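+ # Note: unlike BERT-style embeddings, `RoFormerEmbeddings` above deliberately adds no absolute
+ # position embeddings; position information is injected later through the rotary sinusoidal
+ # embeddings applied to the query/key (and optionally value) projections in `RoFormerSelfAttention`.
+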
195
+ class RoFormerSelfAttention(nn.Module):
196
+ def __init__(self, config):
197
+ super().__init__()
198
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
199
+ raise ValueError(
200
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
201
+ f"heads ({config.num_attention_heads})"
202
+ )
203
+
204
+ self.num_attention_heads = config.num_attention_heads
205
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
206
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
207
+
208
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
209
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
210
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
211
+
212
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
213
+
214
+ self.is_decoder = config.is_decoder
215
+ self.rotary_value = config.rotary_value
216
+
217
+ def transpose_for_scores(self, x):
218
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
219
+ x = x.view(*new_x_shape)
220
+ return x.permute(0, 2, 1, 3)
221
+
222
+ def forward(
223
+ self,
224
+ hidden_states,
225
+ attention_mask=None,
226
+ sinusoidal_pos=None,
227
+ head_mask=None,
228
+ encoder_hidden_states=None,
229
+ encoder_attention_mask=None,
230
+ past_key_value=None,
231
+ output_attentions=False,
232
+ ):
233
+ mixed_query_layer = self.query(hidden_states)
234
+ query_layer = self.transpose_for_scores(mixed_query_layer)
235
+ # If this is instantiated as a cross-attention module, the keys
236
+ # and values come from an encoder; the attention mask needs to be
237
+ # such that the encoder's padding tokens are not attended to.
238
+ is_cross_attention = encoder_hidden_states is not None
239
+
240
+ if is_cross_attention and past_key_value is not None:
241
+ # reuse k,v, cross_attentions
242
+ key_layer = past_key_value[0]
243
+ value_layer = past_key_value[1]
244
+ attention_mask = encoder_attention_mask
245
+ elif is_cross_attention:
246
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
247
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
248
+ attention_mask = encoder_attention_mask
249
+ else:
250
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
251
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
252
+ if sinusoidal_pos is not None:
253
+ if self.rotary_value:
254
+ query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
255
+ sinusoidal_pos, query_layer, key_layer, value_layer
256
+ )
257
+ else:
258
+ query_layer, key_layer = self.apply_rotary_position_embeddings(
259
+ sinusoidal_pos, query_layer, key_layer
260
+ )
261
+ if past_key_value is not None:
262
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
263
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
264
+ if self.is_decoder:
265
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
266
+ # Further calls to cross_attention layer can then reuse all cross-attention
267
+ # key/value_states (first "if" case)
268
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
269
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
270
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
271
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
272
+ past_key_value = (key_layer, value_layer)
273
+
274
+ # Take the dot product between "query" and "key" to get the raw attention scores.
275
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
276
+
277
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
278
+ if attention_mask is not None:
279
+ # Apply the attention mask (precomputed for all layers in the RoFormerModel forward() function)
280
+ attention_scores = attention_scores + attention_mask
281
+
282
+ # Normalize the attention scores to probabilities.
283
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
284
+
285
+ # This is actually dropping out entire tokens to attend to, which might
286
+ # seem a bit unusual, but is taken from the original Transformer paper.
287
+ attention_probs = self.dropout(attention_probs)
288
+
289
+ # Mask heads if we want to
290
+ if head_mask is not None:
291
+ attention_probs = attention_probs * head_mask
292
+
293
+ context_layer = torch.matmul(attention_probs, value_layer)
294
+
295
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
296
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
297
+ context_layer = context_layer.view(*new_context_layer_shape)
298
+
299
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
300
+
301
+ if self.is_decoder:
302
+ outputs = outputs + (past_key_value,)
303
+ return outputs
304
+
305
+ @staticmethod
306
+ def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
307
+ # https://kexue.fm/archives/8265
308
+ # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
309
+ # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
310
+ sin, cos = sinusoidal_pos.chunk(2, dim=-1)
311
+ # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
312
+ sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
313
+ # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
314
+ cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
315
+ # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
316
+ rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
317
+ query_layer
318
+ )
319
+ query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
320
+ # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
321
+ rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
322
+ key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
323
+ if value_layer is not None:
324
+ # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
325
+ rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
326
+ value_layer
327
+ )
328
+ value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
329
+ return query_layer, key_layer, value_layer
330
+ return query_layer, key_layer
331
+
332
+
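+ # Minimal numeric sketch of the rotation above (comments only, not part of the upstream file).
+ # For each adjacent channel pair (q_{2i}, q_{2i+1}) and angle theta_i, the update is
+ #     q_{2i}'   = q_{2i}   * cos(theta_i) - q_{2i+1} * sin(theta_i)
+ #     q_{2i+1}' = q_{2i+1} * cos(theta_i) + q_{2i}   * sin(theta_i)
+ # i.e. a 2-D rotation of each pair, which is what the interleaved `rotate_half_*` tensors implement.
+ # A hedged standalone check (shapes follow the comments in this file) could look like:
+ #
+ #     import torch
+ #     embed_positions = RoFormerSinusoidalPositionalEmbedding(512, 64)  # max positions, head dim
+ #     sinusoidal_pos = embed_positions((1, 16), 0)[None, None, :, :]  # [1, 1, seq_len, head_dim]
+ #     q = torch.randn(1, 8, 16, 64)  # [batch, heads, seq_len, head_dim]
+ #     k = torch.randn(1, 8, 16, 64)
+ #     q_rot, k_rot = RoFormerSelfAttention.apply_rotary_position_embeddings(sinusoidal_pos, q, k)
+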
333
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoFormer
334
+ class RoFormerSelfOutput(nn.Module):
335
+ def __init__(self, config):
336
+ super().__init__()
337
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
338
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
339
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
340
+
341
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
342
+ hidden_states = self.dense(hidden_states)
343
+ hidden_states = self.dropout(hidden_states)
344
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
345
+ return hidden_states
346
+
347
+
348
+ class RoFormerAttention(nn.Module):
349
+ def __init__(self, config):
350
+ super().__init__()
351
+ self.self = RoFormerSelfAttention(config)
352
+ self.output = RoFormerSelfOutput(config)
353
+ self.pruned_heads = set()
354
+
355
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
356
+ def prune_heads(self, heads):
357
+ if len(heads) == 0:
358
+ return
359
+ heads, index = find_pruneable_heads_and_indices(
360
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
361
+ )
362
+
363
+ # Prune linear layers
364
+ self.self.query = prune_linear_layer(self.self.query, index)
365
+ self.self.key = prune_linear_layer(self.self.key, index)
366
+ self.self.value = prune_linear_layer(self.self.value, index)
367
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
368
+
369
+ # Update hyper params and store pruned heads
370
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
371
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
372
+ self.pruned_heads = self.pruned_heads.union(heads)
373
+
374
+ # End Copy
375
+ def forward(
376
+ self,
377
+ hidden_states,
378
+ attention_mask=None,
379
+ sinusoidal_pos=None,
380
+ head_mask=None,
381
+ encoder_hidden_states=None,
382
+ encoder_attention_mask=None,
383
+ past_key_value=None,
384
+ output_attentions=False,
385
+ ):
386
+ self_outputs = self.self(
387
+ hidden_states,
388
+ attention_mask,
389
+ sinusoidal_pos,
390
+ head_mask,
391
+ encoder_hidden_states,
392
+ encoder_attention_mask,
393
+ past_key_value,
394
+ output_attentions,
395
+ )
396
+ attention_output = self.output(self_outputs[0], hidden_states)
397
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
398
+ return outputs
399
+
400
+
401
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RoFormer
402
+ class RoFormerIntermediate(nn.Module):
403
+ def __init__(self, config):
404
+ super().__init__()
405
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
406
+ if isinstance(config.hidden_act, str):
407
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
408
+ else:
409
+ self.intermediate_act_fn = config.hidden_act
410
+
411
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
412
+ hidden_states = self.dense(hidden_states)
413
+ hidden_states = self.intermediate_act_fn(hidden_states)
414
+ return hidden_states
415
+
416
+
417
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoFormer
418
+ class RoFormerOutput(nn.Module):
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
422
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
423
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
424
+
425
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
426
+ hidden_states = self.dense(hidden_states)
427
+ hidden_states = self.dropout(hidden_states)
428
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
429
+ return hidden_states
430
+
431
+
432
+ class RoFormerLayer(nn.Module):
433
+ def __init__(self, config):
434
+ super().__init__()
435
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
436
+ self.seq_len_dim = 1
437
+ self.attention = RoFormerAttention(config)
438
+ self.is_decoder = config.is_decoder
439
+ self.add_cross_attention = config.add_cross_attention
440
+ if self.add_cross_attention:
441
+ if not self.is_decoder:
442
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
443
+ self.crossattention = RoFormerAttention(config)
444
+ self.intermediate = RoFormerIntermediate(config)
445
+ self.output = RoFormerOutput(config)
446
+
447
+ def forward(
448
+ self,
449
+ hidden_states,
450
+ attention_mask=None,
451
+ sinusoidal_pos=None,
452
+ head_mask=None,
453
+ encoder_hidden_states=None,
454
+ encoder_attention_mask=None,
455
+ past_key_value=None,
456
+ output_attentions=False,
457
+ ):
458
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
459
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
460
+ self_attention_outputs = self.attention(
461
+ hidden_states,
462
+ attention_mask,
463
+ sinusoidal_pos,
464
+ head_mask,
465
+ output_attentions=output_attentions,
466
+ past_key_value=self_attn_past_key_value,
467
+ )
468
+ attention_output = self_attention_outputs[0]
469
+
470
+ # if decoder, the last output is tuple of self-attn cache
471
+ if self.is_decoder:
472
+ outputs = self_attention_outputs[1:-1]
473
+ present_key_value = self_attention_outputs[-1]
474
+ else:
475
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
476
+
477
+ cross_attn_present_key_value = None
478
+ if self.is_decoder and encoder_hidden_states is not None:
479
+ if not hasattr(self, "crossattention"):
480
+ raise ValueError(
481
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
482
+ "layers by setting `config.add_cross_attention=True`"
483
+ )
484
+
485
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
486
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
487
+ cross_attention_outputs = self.crossattention(
488
+ attention_output,
489
+ attention_mask,
490
+ sinusoidal_pos,
491
+ head_mask,
492
+ encoder_hidden_states,
493
+ encoder_attention_mask,
494
+ cross_attn_past_key_value,
495
+ output_attentions,
496
+ )
497
+ attention_output = cross_attention_outputs[0]
498
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
499
+
500
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
501
+ cross_attn_present_key_value = cross_attention_outputs[-1]
502
+ present_key_value = present_key_value + cross_attn_present_key_value
503
+
504
+ layer_output = apply_chunking_to_forward(
505
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
506
+ )
507
+ outputs = (layer_output,) + outputs
508
+
509
+ # if decoder, return the attn key/values as the last output
510
+ if self.is_decoder:
511
+ outputs = outputs + (present_key_value,)
512
+
513
+ return outputs
514
+
515
+ def feed_forward_chunk(self, attention_output):
516
+ intermediate_output = self.intermediate(attention_output)
517
+ layer_output = self.output(intermediate_output, attention_output)
518
+ return layer_output
519
+
520
+
521
+ class RoFormerEncoder(nn.Module):
522
+ def __init__(self, config):
523
+ super().__init__()
524
+ self.config = config
525
+ self.embed_positions = RoFormerSinusoidalPositionalEmbedding(
526
+ config.max_position_embeddings, config.hidden_size // config.num_attention_heads
527
+ )
528
+ self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)])
529
+ self.gradient_checkpointing = False
530
+
531
+ def forward(
532
+ self,
533
+ hidden_states,
534
+ attention_mask=None,
535
+ head_mask=None,
536
+ encoder_hidden_states=None,
537
+ encoder_attention_mask=None,
538
+ past_key_values=None,
539
+ use_cache=None,
540
+ output_attentions=False,
541
+ output_hidden_states=False,
542
+ return_dict=True,
543
+ ):
544
+ if self.gradient_checkpointing and self.training:
545
+ if use_cache:
546
+ logger.warning_once(
547
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
548
+ )
549
+ use_cache = False
550
+ all_hidden_states = () if output_hidden_states else None
551
+ all_self_attentions = () if output_attentions else None
552
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
553
+
554
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
555
+
556
+ # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
557
+ sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :]
558
+
559
+ next_decoder_cache = () if use_cache else None
560
+ for i, layer_module in enumerate(self.layer):
561
+ if output_hidden_states:
562
+ all_hidden_states = all_hidden_states + (hidden_states,)
563
+
564
+ layer_head_mask = head_mask[i] if head_mask is not None else None
565
+ past_key_value = past_key_values[i] if past_key_values is not None else None
566
+
567
+ if self.gradient_checkpointing and self.training:
568
+ layer_outputs = self._gradient_checkpointing_func(
569
+ layer_module.__call__,
570
+ hidden_states,
571
+ attention_mask,
572
+ sinusoidal_pos,
573
+ layer_head_mask,
574
+ encoder_hidden_states,
575
+ encoder_attention_mask,
576
+ past_key_value,
577
+ output_attentions,
578
+ )
579
+ else:
580
+ layer_outputs = layer_module(
581
+ hidden_states,
582
+ attention_mask,
583
+ sinusoidal_pos,
584
+ layer_head_mask,
585
+ encoder_hidden_states,
586
+ encoder_attention_mask,
587
+ past_key_value,
588
+ output_attentions,
589
+ )
590
+
591
+ hidden_states = layer_outputs[0]
592
+ if use_cache:
593
+ next_decoder_cache += (layer_outputs[-1],)
594
+ if output_attentions:
595
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
596
+ if self.config.add_cross_attention:
597
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
598
+
599
+ if output_hidden_states:
600
+ all_hidden_states = all_hidden_states + (hidden_states,)
601
+
602
+ if not return_dict:
603
+ return tuple(
604
+ v
605
+ for v in [
606
+ hidden_states,
607
+ next_decoder_cache,
608
+ all_hidden_states,
609
+ all_self_attentions,
610
+ all_cross_attentions,
611
+ ]
612
+ if v is not None
613
+ )
614
+ return BaseModelOutputWithPastAndCrossAttentions(
615
+ last_hidden_state=hidden_states,
616
+ past_key_values=next_decoder_cache,
617
+ hidden_states=all_hidden_states,
618
+ attentions=all_self_attentions,
619
+ cross_attentions=all_cross_attentions,
620
+ )
621
+
622
+
623
+ # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->RoFormer
624
+ class RoFormerSequenceSummary(nn.Module):
625
+ r"""
626
+ Compute a single vector summary of a sequence's hidden states.
627
+
628
+ Args:
629
+ config ([`RoFormerConfig`]):
630
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
631
+ config class of your model for the default values it uses):
632
+
633
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
634
+
635
+ - `"last"` -- Take the last token hidden state (like XLNet)
636
+ - `"first"` -- Take the first token hidden state (like Bert)
637
+ - `"mean"` -- Take the mean of all tokens hidden states
638
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
639
+ - `"attn"` -- Not implemented now, use multi-head attention
640
+
641
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
642
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
643
+ (otherwise to `config.hidden_size`).
644
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
645
+ another string or `None` will add no activation.
646
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
647
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
648
+ """
649
+
650
+ def __init__(self, config: RoFormerConfig):
651
+ super().__init__()
652
+
653
+ self.summary_type = getattr(config, "summary_type", "last")
654
+ if self.summary_type == "attn":
655
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
656
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
657
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
658
+ raise NotImplementedError
659
+
660
+ self.summary = nn.Identity()
661
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
662
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
663
+ num_classes = config.num_labels
664
+ else:
665
+ num_classes = config.hidden_size
666
+ self.summary = nn.Linear(config.hidden_size, num_classes)
667
+
668
+ activation_string = getattr(config, "summary_activation", None)
669
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
670
+
671
+ self.first_dropout = nn.Identity()
672
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
673
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
674
+
675
+ self.last_dropout = nn.Identity()
676
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
677
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
678
+
679
+ def forward(
680
+ self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
681
+ ) -> torch.FloatTensor:
682
+ """
683
+ Compute a single vector summary of a sequence's hidden states.
684
+
685
+ Args:
686
+ hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
687
+ The hidden states of the last layer.
688
+ cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
689
+ Position of the classification token for each example; used if `summary_type == "cls_index"`. If `None`, the last token of the sequence is used.
690
+
691
+ Returns:
692
+ `torch.FloatTensor`: The summary of the sequence hidden states.
693
+ """
694
+ if self.summary_type == "last":
695
+ output = hidden_states[:, -1]
696
+ elif self.summary_type == "first":
697
+ output = hidden_states[:, 0]
698
+ elif self.summary_type == "mean":
699
+ output = hidden_states.mean(dim=1)
700
+ elif self.summary_type == "cls_index":
701
+ if cls_index is None:
702
+ cls_index = torch.full_like(
703
+ hidden_states[..., :1, :],
704
+ hidden_states.shape[-2] - 1,
705
+ dtype=torch.long,
706
+ )
707
+ else:
708
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
709
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
710
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
711
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
712
+ elif self.summary_type == "attn":
713
+ raise NotImplementedError
714
+
715
+ output = self.first_dropout(output)
716
+ output = self.summary(output)
717
+ output = self.activation(output)
718
+ output = self.last_dropout(output)
719
+
720
+ return output
721
+
722
+
723
+ class RoFormerPredictionHeadTransform(nn.Module):
724
+ def __init__(self, config):
725
+ super().__init__()
726
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
727
+ if isinstance(config.hidden_act, str):
728
+ self.transform_act_fn = ACT2FN[config.hidden_act]
729
+ else:
730
+ self.transform_act_fn = config.hidden_act
731
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
732
+
733
+ def forward(self, hidden_states):
734
+ hidden_states = self.dense(hidden_states)
735
+ hidden_states = self.transform_act_fn(hidden_states)
736
+ hidden_states = self.LayerNorm(hidden_states)
737
+ return hidden_states
738
+
739
+
740
+ class RoFormerLMPredictionHead(nn.Module):
741
+ def __init__(self, config):
742
+ super().__init__()
743
+ self.transform = RoFormerPredictionHeadTransform(config)
744
+
745
+ # The output weights are the same as the input embeddings, but there is
746
+ # an output-only bias for each token.
747
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
748
+
749
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
750
+
751
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
752
+ self.decoder.bias = self.bias
753
+
754
+ def _tie_weights(self) -> None:
755
+ self.decoder.bias = self.bias
756
+
757
+ def forward(self, hidden_states):
758
+ hidden_states = self.transform(hidden_states)
759
+ hidden_states = self.decoder(hidden_states)
760
+ return hidden_states
761
+
762
+
763
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RoFormer
764
+ class RoFormerOnlyMLMHead(nn.Module):
765
+ def __init__(self, config):
766
+ super().__init__()
767
+ self.predictions = RoFormerLMPredictionHead(config)
768
+
769
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
770
+ prediction_scores = self.predictions(sequence_output)
771
+ return prediction_scores
772
+
773
+
774
+ class RoFormerPreTrainedModel(PreTrainedModel):
775
+ """
776
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
777
+ models.
778
+ """
779
+
780
+ config_class = RoFormerConfig
781
+ load_tf_weights = load_tf_weights_in_roformer
782
+ base_model_prefix = "roformer"
783
+ supports_gradient_checkpointing = True
784
+
785
+ def _init_weights(self, module):
786
+ """Initialize the weights"""
787
+ if isinstance(module, nn.Linear):
788
+ # Slightly different from the TF version which uses truncated_normal for initialization
789
+ # cf https://github.com/pytorch/pytorch/pull/5617
790
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
791
+ if module.bias is not None:
792
+ module.bias.data.zero_()
793
+ elif isinstance(module, RoFormerSinusoidalPositionalEmbedding):
794
+ module._init_weight()
795
+ elif isinstance(module, nn.Embedding):
796
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
797
+ if module.padding_idx is not None:
798
+ module.weight.data[module.padding_idx].zero_()
799
+ elif isinstance(module, nn.LayerNorm):
800
+ module.bias.data.zero_()
801
+ module.weight.data.fill_(1.0)
802
+ elif isinstance(module, RoFormerLMPredictionHead):
803
+ module.bias.data.zero_()
804
+
805
+
806
+ ROFORMER_START_DOCSTRING = r"""
807
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
808
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
809
+ behavior.
810
+
811
+ Parameters:
812
+ config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model.
813
+ Initializing with a config file does not load the weights associated with the model, only the
814
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
815
+ """
816
+
817
+ ROFORMER_INPUTS_DOCSTRING = r"""
818
+ Args:
819
+ input_ids (`torch.LongTensor` of shape `({0})`):
820
+ Indices of input sequence tokens in the vocabulary.
821
+
822
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
823
+ [`PreTrainedTokenizer.__call__`] for details.
824
+
825
+ [What are input IDs?](../glossary#input-ids)
826
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
827
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
828
+
829
+ - 1 for tokens that are **not masked**,
830
+ - 0 for tokens that are **masked**.
831
+
832
+ [What are attention masks?](../glossary#attention-mask)
833
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
834
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
835
+ 1]`:
836
+
837
+ - 0 corresponds to a *sentence A* token,
838
+ - 1 corresponds to a *sentence B* token.
839
+
840
+ [What are token type IDs?](../glossary#token-type-ids)
841
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
842
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
843
+
844
+ - 1 indicates the head is **not masked**,
845
+ - 0 indicates the head is **masked**.
846
+
847
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
848
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
849
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
850
+ model's internal embedding lookup matrix.
851
+ output_attentions (`bool`, *optional*):
852
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
853
+ tensors for more detail.
854
+ output_hidden_states (`bool`, *optional*):
855
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
856
+ more detail.
857
+ return_dict (`bool`, *optional*):
858
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
859
+ """
860
+
861
+
862
+ @add_start_docstrings(
863
+ "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
864
+ ROFORMER_START_DOCSTRING,
865
+ )
866
+ class RoFormerModel(RoFormerPreTrainedModel):
867
+ """
868
+
869
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
870
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
871
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
872
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
873
+
874
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
875
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
876
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
877
+ """
878
+
879
+ def __init__(self, config):
880
+ super().__init__(config)
881
+ self.config = config
882
+ self.embeddings = RoFormerEmbeddings(config)
883
+
884
+ if config.embedding_size != config.hidden_size:
885
+ self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
886
+
887
+ self.encoder = RoFormerEncoder(config)
888
+
889
+ # Initialize weights and apply final processing
890
+ self.post_init()
891
+
892
+ def get_input_embeddings(self):
893
+ return self.embeddings.word_embeddings
894
+
895
+ def set_input_embeddings(self, value):
896
+ self.embeddings.word_embeddings = value
897
+
898
+ def _prune_heads(self, heads_to_prune):
899
+ """
900
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
901
+ class PreTrainedModel
902
+ """
903
+ for layer, heads in heads_to_prune.items():
904
+ self.encoder.layer[layer].attention.prune_heads(heads)
905
+
906
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
907
+ @add_code_sample_docstrings(
908
+ checkpoint=_CHECKPOINT_FOR_DOC,
909
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
910
+ config_class=_CONFIG_FOR_DOC,
911
+ )
912
+ def forward(
913
+ self,
914
+ input_ids: Optional[torch.LongTensor] = None,
915
+ attention_mask: Optional[torch.FloatTensor] = None,
916
+ token_type_ids: Optional[torch.LongTensor] = None,
917
+ head_mask: Optional[torch.FloatTensor] = None,
918
+ inputs_embeds: Optional[torch.FloatTensor] = None,
919
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
920
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
921
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
922
+ use_cache: Optional[bool] = None,
923
+ output_attentions: Optional[bool] = None,
924
+ output_hidden_states: Optional[bool] = None,
925
+ return_dict: Optional[bool] = None,
926
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]:
927
+ r"""
928
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
929
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
930
+ the model is configured as a decoder.
931
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
932
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
933
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
934
+
935
+ - 1 for tokens that are **not masked**,
936
+ - 0 for tokens that are **masked**.
937
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
938
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
939
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
940
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
941
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
942
+ use_cache (`bool`, *optional*):
943
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
944
+ `past_key_values`).
945
+ """
946
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
947
+ output_hidden_states = (
948
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
949
+ )
950
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
951
+
952
+ if self.config.is_decoder:
953
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
954
+ else:
955
+ use_cache = False
956
+
957
+ if input_ids is not None and inputs_embeds is not None:
958
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
959
+ elif input_ids is not None:
960
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
961
+ input_shape = input_ids.size()
962
+ elif inputs_embeds is not None:
963
+ input_shape = inputs_embeds.size()[:-1]
964
+ else:
965
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
966
+
967
+ batch_size, seq_length = input_shape
968
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
969
+
970
+ # past_key_values_length
971
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
972
+
973
+ if attention_mask is None:
974
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
975
+ if token_type_ids is None:
976
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
977
+
978
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
979
+ # ourselves in which case we just need to make it broadcastable to all heads.
980
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
981
+
982
+ # If a 2D or 3D attention mask is provided for the cross-attention
983
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
984
+ if self.config.is_decoder and encoder_hidden_states is not None:
985
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
986
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
987
+ if encoder_attention_mask is None:
988
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
989
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
990
+ else:
991
+ encoder_extended_attention_mask = None
992
+
993
+ # Prepare head mask if needed
994
+ # 1.0 in head_mask indicate we keep the head
995
+ # attention_probs has shape bsz x n_heads x N x N
996
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
997
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
998
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
999
+
1000
+ embedding_output = self.embeddings(
1001
+ input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
1002
+ )
1003
+ if hasattr(self, "embeddings_project"):
1004
+ embedding_output = self.embeddings_project(embedding_output)
1005
+
1006
+ encoder_outputs = self.encoder(
1007
+ embedding_output,
1008
+ attention_mask=extended_attention_mask,
1009
+ head_mask=head_mask,
1010
+ encoder_hidden_states=encoder_hidden_states,
1011
+ encoder_attention_mask=encoder_extended_attention_mask,
1012
+ past_key_values=past_key_values,
1013
+ use_cache=use_cache,
1014
+ output_attentions=output_attentions,
1015
+ output_hidden_states=output_hidden_states,
1016
+ return_dict=return_dict,
1017
+ )
1018
+ sequence_output = encoder_outputs[0]
1019
+
1020
+ if not return_dict:
1021
+ return (sequence_output,) + encoder_outputs[1:]
1022
+
1023
+ return BaseModelOutputWithPastAndCrossAttentions(
1024
+ last_hidden_state=sequence_output,
1025
+ past_key_values=encoder_outputs.past_key_values,
1026
+ hidden_states=encoder_outputs.hidden_states,
1027
+ attentions=encoder_outputs.attentions,
1028
+ cross_attentions=encoder_outputs.cross_attentions,
1029
+ )
1030
+
1031
+
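+ # Hedged usage sketch (comments only; the checkpoint name is reused from the causal-LM example
+ # further down in this file, everything else is a plain feature-extraction call):
+ #
+ #     from transformers import AutoTokenizer, RoFormerModel
+ #     import torch
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
+ #     model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
+ #     inputs = tokenizer("今天天气非常好。", return_tensors="pt")
+ #     with torch.no_grad():
+ #         last_hidden_state = model(**inputs).last_hidden_state  # [batch, seq_len, hidden_size]
+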
1032
+ @add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
1033
+ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
1034
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1035
+
1036
+ def __init__(self, config):
1037
+ super().__init__(config)
1038
+
1039
+ if config.is_decoder:
1040
+ logger.warning(
1041
+ "If you want to use `RoFormerForMaskedLM` make sure `config.is_decoder=False` for "
1042
+ "bi-directional self-attention."
1043
+ )
1044
+
1045
+ self.roformer = RoFormerModel(config)
1046
+ self.cls = RoFormerOnlyMLMHead(config)
1047
+
1048
+ # Initialize weights and apply final processing
1049
+ self.post_init()
1050
+
1051
+ def get_output_embeddings(self):
1052
+ return self.cls.predictions.decoder
1053
+
1054
+ def set_output_embeddings(self, new_embeddings):
1055
+ self.cls.predictions.decoder = new_embeddings
1056
+ self.cls.predictions.bias = new_embeddings.bias
1057
+
1058
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1059
+ @add_code_sample_docstrings(
1060
+ checkpoint=_CHECKPOINT_FOR_DOC,
1061
+ output_type=MaskedLMOutput,
1062
+ config_class=_CONFIG_FOR_DOC,
1063
+ )
1064
+ def forward(
1065
+ self,
1066
+ input_ids: Optional[torch.LongTensor] = None,
1067
+ attention_mask: Optional[torch.FloatTensor] = None,
1068
+ token_type_ids: Optional[torch.LongTensor] = None,
1069
+ head_mask: Optional[torch.FloatTensor] = None,
1070
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1071
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1072
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1073
+ labels: Optional[torch.LongTensor] = None,
1074
+ output_attentions: Optional[bool] = None,
1075
+ output_hidden_states: Optional[bool] = None,
1076
+ return_dict: Optional[bool] = None,
1077
+ ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
1078
+ r"""
1079
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1080
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1081
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the
1082
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1083
+ """
1084
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1085
+
1086
+ outputs = self.roformer(
1087
+ input_ids,
1088
+ attention_mask=attention_mask,
1089
+ token_type_ids=token_type_ids,
1090
+ head_mask=head_mask,
1091
+ inputs_embeds=inputs_embeds,
1092
+ encoder_hidden_states=encoder_hidden_states,
1093
+ encoder_attention_mask=encoder_attention_mask,
1094
+ output_attentions=output_attentions,
1095
+ output_hidden_states=output_hidden_states,
1096
+ return_dict=return_dict,
1097
+ )
1098
+
1099
+ sequence_output = outputs[0]
1100
+ prediction_scores = self.cls(sequence_output)
1101
+
1102
+ masked_lm_loss = None
1103
+ if labels is not None:
1104
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1105
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1106
+
1107
+ if not return_dict:
1108
+ output = (prediction_scores,) + outputs[1:]
1109
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1110
+
1111
+ return MaskedLMOutput(
1112
+ loss=masked_lm_loss,
1113
+ logits=prediction_scores,
1114
+ hidden_states=outputs.hidden_states,
1115
+ attentions=outputs.attentions,
1116
+ )
1117
+
1118
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1119
+ input_shape = input_ids.shape
1120
+ effective_batch_size = input_shape[0]
1121
+
1122
+ # add a dummy token
1123
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1124
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1125
+ dummy_token = torch.full(
1126
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1127
+ )
1128
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1129
+
1130
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1131
+
1132
+
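+ # Hedged fill-mask sketch (comments only; assumes the tokenizer exposes the standard
+ # `mask_token`/`mask_token_id` attributes, and reuses the checkpoint name from the example below):
+ #
+ #     from transformers import AutoTokenizer, RoFormerForMaskedLM
+ #     import torch
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
+ #     model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
+ #     inputs = tokenizer(f"今天天气非常{tokenizer.mask_token}。", return_tensors="pt")
+ #     with torch.no_grad():
+ #         logits = model(**inputs).logits
+ #     mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
+ #     print(tokenizer.decode(logits[0, mask_positions].argmax(dim=-1)))
+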
1133
+ @add_start_docstrings(
1134
+ """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
1135
+ )
1136
+ class RoFormerForCausalLM(RoFormerPreTrainedModel, GenerationMixin):
1137
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1138
+
1139
+ def __init__(self, config):
1140
+ super().__init__(config)
1141
+
1142
+ if not config.is_decoder:
1143
+ logger.warning("If you want to use `RoFormerForCausalLM` as a standalone, add `is_decoder=True`.")
1144
+
1145
+ self.roformer = RoFormerModel(config)
1146
+ self.cls = RoFormerOnlyMLMHead(config)
1147
+
1148
+ # Initialize weights and apply final processing
1149
+ self.post_init()
1150
+
1151
+ def get_output_embeddings(self):
1152
+ return self.cls.predictions.decoder
1153
+
1154
+ def set_output_embeddings(self, new_embeddings):
1155
+ self.cls.predictions.decoder = new_embeddings
1156
+ self.cls.predictions.bias = new_embeddings.bias
1157
+
1158
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1159
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1160
+ def forward(
1161
+ self,
1162
+ input_ids: Optional[torch.LongTensor] = None,
1163
+ attention_mask: Optional[torch.FloatTensor] = None,
1164
+ token_type_ids: Optional[torch.LongTensor] = None,
1165
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1166
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1167
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1168
+ head_mask: Optional[torch.FloatTensor] = None,
1169
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1170
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1171
+ labels: Optional[torch.LongTensor] = None,
1172
+ use_cache: Optional[bool] = None,
1173
+ output_attentions: Optional[bool] = None,
1174
+ output_hidden_states: Optional[bool] = None,
1175
+ return_dict: Optional[bool] = None,
1176
+ **kwargs,
1177
+ ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.Tensor]]:
1178
+ r"""
1179
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1180
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1181
+ the model is configured as a decoder.
1182
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1183
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1184
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1185
+
1186
+ - 1 for tokens that are **not masked**,
1187
+ - 0 for tokens that are **masked**.
1188
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1189
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1190
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1191
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1192
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1193
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1194
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1195
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
1196
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1197
+ use_cache (`bool`, *optional*):
1198
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1199
+ `past_key_values`).
1200
+
1201
+ Returns:
1202
+
1203
+ Example:
1204
+
1205
+ ```python
1206
+ >>> from transformers import AutoTokenizer, RoFormerForCausalLM, RoFormerConfig
1207
+ >>> import torch
1208
+
1209
+ >>> tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
1210
+ >>> config = RoFormerConfig.from_pretrained("junnyu/roformer_chinese_base")
1211
+ >>> config.is_decoder = True
1212
+ >>> model = RoFormerForCausalLM.from_pretrained("junnyu/roformer_chinese_base", config=config)
1213
+
1214
+ >>> inputs = tokenizer("今天天气非常好。", return_tensors="pt")
1215
+ >>> outputs = model(**inputs)
1216
+
1217
+ >>> prediction_logits = outputs.logits
1218
+ ```"""
1219
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1220
+
1221
+ outputs = self.roformer(
1222
+ input_ids,
1223
+ attention_mask=attention_mask,
1224
+ token_type_ids=token_type_ids,
1225
+ head_mask=head_mask,
1226
+ inputs_embeds=inputs_embeds,
1227
+ encoder_hidden_states=encoder_hidden_states,
1228
+ encoder_attention_mask=encoder_attention_mask,
1229
+ past_key_values=past_key_values,
1230
+ use_cache=use_cache,
1231
+ output_attentions=output_attentions,
1232
+ output_hidden_states=output_hidden_states,
1233
+ return_dict=return_dict,
1234
+ )
1235
+
1236
+ sequence_output = outputs[0]
1237
+ prediction_scores = self.cls(sequence_output)
1238
+
1239
+ lm_loss = None
1240
+ if labels is not None:
1241
+ lm_loss = self.loss_function(
1242
+ prediction_scores,
1243
+ labels,
1244
+ vocab_size=self.config.vocab_size,
1245
+ **kwargs,
1246
+ )
1247
+
1248
+ if not return_dict:
1249
+ output = (prediction_scores,) + outputs[1:]
1250
+ return ((lm_loss,) + output) if lm_loss is not None else output
1251
+
1252
+ return CausalLMOutputWithCrossAttentions(
1253
+ loss=lm_loss,
1254
+ logits=prediction_scores,
1255
+ past_key_values=outputs.past_key_values,
1256
+ hidden_states=outputs.hidden_states,
1257
+ attentions=outputs.attentions,
1258
+ cross_attentions=outputs.cross_attentions,
1259
+ )
1260
+
1261
+ def _reorder_cache(self, past_key_values, beam_idx):
1262
+ reordered_past = ()
1263
+ for layer_past in past_key_values:
1264
+ reordered_past += (
1265
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
1266
+ + layer_past[2:],
1267
+ )
1268
+ return reordered_past
1269
+
1270
+
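+ # Because the class mixes in `GenerationMixin`, the docstring example above can be extended with a
+ # hedged generation sketch (reusing `model`, `tokenizer` and `inputs` from that example):
+ #
+ #     generated_ids = model.generate(**inputs, max_new_tokens=20, use_cache=True)
+ #     print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+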
1271
+ class RoFormerClassificationHead(nn.Module):
1272
+ """Head for sentence-level classification tasks."""
1273
+
1274
+ def __init__(self, config):
1275
+ super().__init__()
1276
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1277
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1278
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1279
+
1280
+ self.config = config
1281
+
1282
+ def forward(self, features, **kwargs):
1283
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1284
+ x = self.dropout(x)
1285
+ x = self.dense(x)
1286
+ x = ACT2FN[self.config.hidden_act](x)
1287
+ x = self.dropout(x)
1288
+ x = self.out_proj(x)
1289
+ return x
1290
+
1291
+
1292
+ @add_start_docstrings(
1293
+ """
1294
+ RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1295
+ pooled output) e.g. for GLUE tasks.
1296
+ """,
1297
+ ROFORMER_START_DOCSTRING,
1298
+ )
1299
+ class RoFormerForSequenceClassification(RoFormerPreTrainedModel):
1300
+ def __init__(self, config):
1301
+ super().__init__(config)
1302
+ self.num_labels = config.num_labels
1303
+ self.roformer = RoFormerModel(config)
1304
+ self.classifier = RoFormerClassificationHead(config)
1305
+
1306
+ # Initialize weights and apply final processing
1307
+ self.post_init()
1308
+
1309
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1310
+ @add_code_sample_docstrings(
1311
+ checkpoint=_CHECKPOINT_FOR_DOC,
1312
+ output_type=SequenceClassifierOutput,
1313
+ config_class=_CONFIG_FOR_DOC,
1314
+ )
1315
+ def forward(
1316
+ self,
1317
+ input_ids: Optional[torch.LongTensor] = None,
1318
+ attention_mask: Optional[torch.FloatTensor] = None,
1319
+ token_type_ids: Optional[torch.LongTensor] = None,
1320
+ head_mask: Optional[torch.FloatTensor] = None,
1321
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1322
+ labels: Optional[torch.LongTensor] = None,
1323
+ output_attentions: Optional[bool] = None,
1324
+ output_hidden_states: Optional[bool] = None,
1325
+ return_dict: Optional[bool] = None,
1326
+ ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor]]:
1327
+ r"""
1328
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1329
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1330
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1331
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1332
+ """
1333
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1334
+
1335
+ outputs = self.roformer(
1336
+ input_ids,
1337
+ attention_mask=attention_mask,
1338
+ token_type_ids=token_type_ids,
1339
+ head_mask=head_mask,
1340
+ inputs_embeds=inputs_embeds,
1341
+ output_attentions=output_attentions,
1342
+ output_hidden_states=output_hidden_states,
1343
+ return_dict=return_dict,
1344
+ )
1345
+
1346
+ sequence_output = outputs[0]
1347
+ logits = self.classifier(sequence_output)
1348
+
1349
+ loss = None
1350
+ if labels is not None:
1351
+ if self.config.problem_type is None:
1352
+ if self.num_labels == 1:
1353
+ self.config.problem_type = "regression"
1354
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1355
+ self.config.problem_type = "single_label_classification"
1356
+ else:
1357
+ self.config.problem_type = "multi_label_classification"
1358
+
1359
+ if self.config.problem_type == "regression":
1360
+ loss_fct = MSELoss()
1361
+ if self.num_labels == 1:
1362
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1363
+ else:
1364
+ loss = loss_fct(logits, labels)
1365
+ elif self.config.problem_type == "single_label_classification":
1366
+ loss_fct = CrossEntropyLoss()
1367
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1368
+ elif self.config.problem_type == "multi_label_classification":
1369
+ loss_fct = BCEWithLogitsLoss()
1370
+ loss = loss_fct(logits, labels)
1371
+ if not return_dict:
1372
+ output = (logits,) + outputs[1:]
1373
+ return ((loss,) + output) if loss is not None else output
1374
+
1375
+ return SequenceClassifierOutput(
1376
+ loss=loss,
1377
+ logits=logits,
1378
+ hidden_states=outputs.hidden_states,
1379
+ attentions=outputs.attentions,
1380
+ )
1381
+
1382
+
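+ # Note: there is no separate pooler module here; `RoFormerClassificationHead` above takes the
+ # hidden state of the first token and applies dropout, a dense layer, the configured activation,
+ # and an output projection. A hedged usage sketch (the checkpoint name is only a placeholder
+ # reused from the sketches above, as is `tokenizer`; the classification head would be freshly
+ # initialized):
+ #
+ #     model = RoFormerForSequenceClassification.from_pretrained("junnyu/roformer_chinese_base", num_labels=2)
+ #     inputs = tokenizer("这部电影很好看。", return_tensors="pt")
+ #     predicted_class = model(**inputs).logits.argmax(dim=-1)
+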
1383
+ @add_start_docstrings(
1384
+ """
1385
+ RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1386
+ softmax) e.g. for RocStories/SWAG tasks.
1387
+ """,
1388
+ ROFORMER_START_DOCSTRING,
1389
+ )
1390
+ class RoFormerForMultipleChoice(RoFormerPreTrainedModel):
1391
+ def __init__(self, config):
1392
+ super().__init__(config)
1393
+
1394
+ self.roformer = RoFormerModel(config)
1395
+ self.sequence_summary = RoFormerSequenceSummary(config)
1396
+ self.classifier = nn.Linear(config.hidden_size, 1)
1397
+
1398
+ # Initialize weights and apply final processing
1399
+ self.post_init()
1400
+
1401
+ @add_start_docstrings_to_model_forward(
1402
+ ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1403
+ )
1404
+ @add_code_sample_docstrings(
1405
+ checkpoint=_CHECKPOINT_FOR_DOC,
1406
+ output_type=MultipleChoiceModelOutput,
1407
+ config_class=_CONFIG_FOR_DOC,
1408
+ )
1409
+ def forward(
1410
+ self,
1411
+ input_ids: Optional[torch.LongTensor] = None,
1412
+ attention_mask: Optional[torch.FloatTensor] = None,
1413
+ token_type_ids: Optional[torch.LongTensor] = None,
1414
+ head_mask: Optional[torch.FloatTensor] = None,
1415
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1416
+ labels: Optional[torch.LongTensor] = None,
1417
+ output_attentions: Optional[bool] = None,
1418
+ output_hidden_states: Optional[bool] = None,
1419
+ return_dict: Optional[bool] = None,
1420
+ ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor]]:
1421
+ r"""
1422
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1423
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1424
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1425
+ `input_ids` above)
1426
+ """
1427
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1428
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1429
+
1430
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1431
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1432
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1433
+
1434
+ inputs_embeds = (
1435
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1436
+ if inputs_embeds is not None
1437
+ else None
1438
+ )
1439
+
1440
+ outputs = self.roformer(
1441
+ input_ids,
1442
+ attention_mask=attention_mask,
1443
+ token_type_ids=token_type_ids,
1444
+ head_mask=head_mask,
1445
+ inputs_embeds=inputs_embeds,
1446
+ output_attentions=output_attentions,
1447
+ output_hidden_states=output_hidden_states,
1448
+ return_dict=return_dict,
1449
+ )
1450
+
1451
+ sequence_output = outputs[0]
1452
+
1453
+ pooled_output = self.sequence_summary(sequence_output)
1454
+ logits = self.classifier(pooled_output)
1455
+ reshaped_logits = logits.view(-1, num_choices)
1456
+
1457
+ loss = None
1458
+ if labels is not None:
1459
+ loss_fct = CrossEntropyLoss()
1460
+ loss = loss_fct(reshaped_logits, labels)
1461
+
1462
+ if not return_dict:
1463
+ output = (reshaped_logits,) + outputs[1:]
1464
+ return ((loss,) + output) if loss is not None else output
1465
+
1466
+ return MultipleChoiceModelOutput(
1467
+ loss=loss,
1468
+ logits=reshaped_logits,
1469
+ hidden_states=outputs.hidden_states,
1470
+ attentions=outputs.attentions,
1471
+ )
1472
+
1473
+
1474
+ @add_start_docstrings(
1475
+ """
1476
+ RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1477
+ Named-Entity-Recognition (NER) tasks.
1478
+ """,
1479
+ ROFORMER_START_DOCSTRING,
1480
+ )
1481
+ class RoFormerForTokenClassification(RoFormerPreTrainedModel):
1482
+ def __init__(self, config):
1483
+ super().__init__(config)
1484
+ self.num_labels = config.num_labels
1485
+
1486
+ self.roformer = RoFormerModel(config)
1487
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1488
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1489
+
1490
+ # Initialize weights and apply final processing
1491
+ self.post_init()
1492
+
1493
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1494
+ @add_code_sample_docstrings(
1495
+ checkpoint=_CHECKPOINT_FOR_DOC,
1496
+ output_type=TokenClassifierOutput,
1497
+ config_class=_CONFIG_FOR_DOC,
1498
+ )
1499
+ def forward(
1500
+ self,
1501
+ input_ids: Optional[torch.LongTensor] = None,
1502
+ attention_mask: Optional[torch.FloatTensor] = None,
1503
+ token_type_ids: Optional[torch.LongTensor] = None,
1504
+ head_mask: Optional[torch.FloatTensor] = None,
1505
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1506
+ labels: Optional[torch.LongTensor] = None,
1507
+ output_attentions: Optional[bool] = None,
1508
+ output_hidden_states: Optional[bool] = None,
1509
+ return_dict: Optional[bool] = None,
1510
+ ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor]]:
1511
+ r"""
1512
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1513
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1514
+ """
1515
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1516
+
1517
+ outputs = self.roformer(
1518
+ input_ids,
1519
+ attention_mask=attention_mask,
1520
+ token_type_ids=token_type_ids,
1521
+ head_mask=head_mask,
1522
+ inputs_embeds=inputs_embeds,
1523
+ output_attentions=output_attentions,
1524
+ output_hidden_states=output_hidden_states,
1525
+ return_dict=return_dict,
1526
+ )
1527
+
1528
+ sequence_output = outputs[0]
1529
+
1530
+ sequence_output = self.dropout(sequence_output)
1531
+ logits = self.classifier(sequence_output)
1532
+
1533
+ loss = None
1534
+ if labels is not None:
1535
+ loss_fct = CrossEntropyLoss()
1536
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1537
+
1538
+ if not return_dict:
1539
+ output = (logits,) + outputs[1:]
1540
+ return ((loss,) + output) if loss is not None else output
1541
+
1542
+ return TokenClassifierOutput(
1543
+ loss=loss,
1544
+ logits=logits,
1545
+ hidden_states=outputs.hidden_states,
1546
+ attentions=outputs.attentions,
1547
+ )
1548
+
1549
+
1550
+ @add_start_docstrings(
1551
+ """
1552
+ RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1553
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1554
+ """,
1555
+ ROFORMER_START_DOCSTRING,
1556
+ )
1557
+ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
1558
+ def __init__(self, config):
1559
+ super().__init__(config)
1560
+
1561
+ config.num_labels = 2
1562
+ self.num_labels = config.num_labels
1563
+
1564
+ self.roformer = RoFormerModel(config)
1565
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1566
+
1567
+ # Initialize weights and apply final processing
1568
+ self.post_init()
1569
+
1570
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1571
+ @add_code_sample_docstrings(
1572
+ checkpoint=_CHECKPOINT_FOR_DOC,
1573
+ output_type=QuestionAnsweringModelOutput,
1574
+ config_class=_CONFIG_FOR_DOC,
1575
+ )
1576
+ def forward(
1577
+ self,
1578
+ input_ids: Optional[torch.LongTensor] = None,
1579
+ attention_mask: Optional[torch.FloatTensor] = None,
1580
+ token_type_ids: Optional[torch.LongTensor] = None,
1581
+ head_mask: Optional[torch.FloatTensor] = None,
1582
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1583
+ start_positions: Optional[torch.LongTensor] = None,
1584
+ end_positions: Optional[torch.LongTensor] = None,
1585
+ output_attentions: Optional[bool] = None,
1586
+ output_hidden_states: Optional[bool] = None,
1587
+ return_dict: Optional[bool] = None,
1588
+ ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor]]:
1589
+ r"""
1590
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1591
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1592
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1593
+ are not taken into account for computing the loss.
1594
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1595
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1596
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1597
+ are not taken into account for computing the loss.
1598
+ """
1599
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1600
+
1601
+ outputs = self.roformer(
1602
+ input_ids,
1603
+ attention_mask=attention_mask,
1604
+ token_type_ids=token_type_ids,
1605
+ head_mask=head_mask,
1606
+ inputs_embeds=inputs_embeds,
1607
+ output_attentions=output_attentions,
1608
+ output_hidden_states=output_hidden_states,
1609
+ return_dict=return_dict,
1610
+ )
1611
+
1612
+ sequence_output = outputs[0]
1613
+
1614
+ logits = self.qa_outputs(sequence_output)
1615
+ start_logits, end_logits = logits.split(1, dim=-1)
1616
+ start_logits = start_logits.squeeze(-1)
1617
+ end_logits = end_logits.squeeze(-1)
1618
+
1619
+ total_loss = None
1620
+ if start_positions is not None and end_positions is not None:
1621
+ # If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
1622
+ if len(start_positions.size()) > 1:
1623
+ start_positions = start_positions.squeeze(-1)
1624
+ if len(end_positions.size()) > 1:
1625
+ end_positions = end_positions.squeeze(-1)
1626
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1627
+ ignored_index = start_logits.size(1)
1628
+ start_positions = start_positions.clamp(0, ignored_index)
1629
+ end_positions = end_positions.clamp(0, ignored_index)
1630
+
1631
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1632
+ start_loss = loss_fct(start_logits, start_positions)
1633
+ end_loss = loss_fct(end_logits, end_positions)
1634
+ total_loss = (start_loss + end_loss) / 2
1635
+
1636
+ if not return_dict:
1637
+ output = (start_logits, end_logits) + outputs[1:]
1638
+ return ((total_loss,) + output) if total_loss is not None else output
1639
+
1640
+ return QuestionAnsweringModelOutput(
1641
+ loss=total_loss,
1642
+ start_logits=start_logits,
1643
+ end_logits=end_logits,
1644
+ hidden_states=outputs.hidden_states,
1645
+ attentions=outputs.attentions,
1646
+ )
1647
+
1648
+
1649
+ __all__ = [
1650
+ "RoFormerForCausalLM",
1651
+ "RoFormerForMaskedLM",
1652
+ "RoFormerForMultipleChoice",
1653
+ "RoFormerForQuestionAnswering",
1654
+ "RoFormerForSequenceClassification",
1655
+ "RoFormerForTokenClassification",
1656
+ "RoFormerLayer",
1657
+ "RoFormerModel",
1658
+ "RoFormerPreTrainedModel",
1659
+ "load_tf_weights_in_roformer",
1660
+ ]
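The heads above all wrap the same `RoFormerModel` encoder. A minimal sketch of the question-answering head's input/output contract, using a tiny randomly initialized config (all sizes below are arbitrary choices, not values from this diff):

```python
import torch
from transformers import RoFormerConfig, RoFormerForQuestionAnswering

# Tiny random model, just to exercise the head defined above.
config = RoFormerConfig(
    vocab_size=1000, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=128,
)
model = RoFormerForQuestionAnswering(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))
with torch.no_grad():
    out = model(input_ids=input_ids)

# One start and one end logit per token, as produced by qa_outputs + split/squeeze.
print(out.start_logits.shape, out.end_logits.shape)  # torch.Size([1, 12]) twice
```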
docs/transformers/build/lib/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py ADDED
@@ -0,0 +1,379 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_rt_detr_v2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 Baidu Inc and The HuggingFace Inc. team.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...utils import logging
24
+ from ...utils.backbone_utils import verify_backbone_config_arguments
25
+ from ..auto import CONFIG_MAPPING
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class RTDetrV2Config(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`RTDetrV2Model`]. It is used to instantiate an
34
+ RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
35
+ with the defaults will yield a similar configuration to that of the RT-DETR architecture.
36
+
37
+ e.g. [PekingU/rtdetr_r18vd](https://huggingface.co/PekingU/rtdetr_r18vd)
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ initializer_range (`float`, *optional*, defaults to 0.01):
44
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
45
+ initializer_bias_prior_prob (`float`, *optional*):
46
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
47
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
48
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
49
+ The epsilon used by the layer normalization layers.
50
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
51
+ The epsilon used by the batch normalization layers.
52
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrV2ResNetConfig()`):
53
+ The configuration of the backbone model.
54
+ backbone (`str`, *optional*):
55
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
56
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
57
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
58
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
59
+ Whether to use pretrained weights for the backbone.
60
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
61
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
62
+ library.
63
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
64
+ Whether to freeze the batch normalization layers in the backbone.
65
+ backbone_kwargs (`dict`, *optional*):
66
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
67
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
68
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
69
+ Dimension of the layers in hybrid encoder.
70
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
71
+ Multi level features input for encoder.
72
+ feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
73
+ Strides used in each feature map.
74
+ encoder_layers (`int`, *optional*, defaults to 1):
75
+ Total of layers to be used by the encoder.
76
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
77
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
78
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
79
+ Number of attention heads for each attention layer in the Transformer encoder.
80
+ dropout (`float`, *optional*, defaults to 0.0):
81
+ The ratio for all dropout layers.
82
+ activation_dropout (`float`, *optional*, defaults to 0.0):
83
+ The dropout ratio for activations inside the fully connected layer.
84
+ encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
85
+ Indexes of the projected layers to be used in the encoder.
86
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
87
+ The temperature parameter used to create the positional encodings.
88
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
89
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
90
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
91
+ activation_function (`str`, *optional*, defaults to `"silu"`):
92
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
93
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
94
+ eval_size (`Tuple[int, int]`, *optional*):
95
+ Height and width used to compute the effective height and width of the position embeddings after taking
96
+ into account the stride.
97
+ normalize_before (`bool`, *optional*, defaults to `False`):
98
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
99
+ feed-forward modules.
100
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
101
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
102
+ d_model (`int`, *optional*, defaults to 256):
103
+ Dimension of the layers exclude hybrid encoder.
104
+ num_queries (`int`, *optional*, defaults to 300):
105
+ Number of object queries.
106
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
107
+ Channel dimensions of the multi-level features fed to the decoder.
108
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
109
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
110
+ num_feature_levels (`int`, *optional*, defaults to 3):
111
+ The number of input feature levels.
112
+ decoder_n_points (`int`, *optional*, defaults to 4):
113
+ The number of sampled keys in each feature level for each attention head in the decoder.
114
+ decoder_layers (`int`, *optional*, defaults to 6):
115
+ Number of decoder layers.
116
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
117
+ Number of attention heads for each attention layer in the Transformer decoder.
118
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
119
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
120
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
121
+ attention_dropout (`float`, *optional*, defaults to 0.0):
122
+ The dropout ratio for the attention probabilities.
123
+ num_denoising (`int`, *optional*, defaults to 100):
124
+ The total number of denoising tasks or queries to be used for contrastive denoising.
125
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
126
+ The fraction of denoising labels to which random noise should be added.
127
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
128
+ Scale or magnitude of noise to be added to the bounding boxes.
129
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
130
+ Indicates whether the initial query embeddings for the decoder should be learned during training.
131
+ anchor_image_size (`Tuple[int, int]`, *optional*):
132
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If `None`, anchors are generated automatically.
133
+ with_box_refine (`bool`, *optional*, defaults to `True`):
134
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
135
+ based on the predictions from the previous layer.
136
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
137
+ Whether the architecture has an encoder decoder structure.
138
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
139
+ Parameter alpha used by the Hungarian Matcher.
140
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
141
+ Parameter gamma used by the Hungarian Matcher.
142
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
143
+ The relative weight of the class loss used by the Hungarian Matcher.
144
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
145
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
146
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
147
+ The relative weight of the giou loss used by the Hungarian Matcher.
148
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
149
+ Whether the focal loss should be used.
150
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
151
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
152
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
153
+ Parameter alpha used to compute the focal loss.
154
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
155
+ Parameter gamma used to compute the focal loss.
156
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
157
+ Relative weight of the varifocal loss in the object detection loss.
158
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
159
+ Relative weight of the L1 bounding box loss in the object detection loss.
160
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
161
+ Relative weight of the generalized IoU loss in the object detection loss.
162
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
163
+ Relative classification weight of the 'no-object' class in the object detection loss.
164
+ decoder_n_levels (`int`, *optional*, defaults to 3):
165
+ The number of feature levels used by the decoder.
166
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5):
167
+ Scaling factor applied to the attention offsets in the decoder.
168
+ decoder_method (`str`, *optional*, defaults to `"default"`):
169
+ The method to use for the decoder: `"default"` or `"discrete"`.
170
+
171
+ Examples:
172
+
173
+ ```python
174
+ >>> from transformers import RTDetrV2Config, RTDetrV2Model
175
+
176
+ >>> # Initializing a RT-DETR configuration
177
+ >>> configuration = RTDetrV2Config()
178
+
179
+ >>> # Initializing a model (with random weights) from the configuration
180
+ >>> model = RTDetrV2Model(configuration)
181
+
182
+ >>> # Accessing the model configuration
183
+ >>> configuration = model.config
184
+ ```
185
+ """
186
+
187
+ model_type = "rt_detr_v2"
188
+ layer_types = ["basic", "bottleneck"]
189
+ attribute_map = {
190
+ "hidden_size": "d_model",
191
+ "num_attention_heads": "encoder_attention_heads",
192
+ }
193
+
194
+ def __init__(
195
+ self,
196
+ initializer_range=0.01,
197
+ initializer_bias_prior_prob=None,
198
+ layer_norm_eps=1e-5,
199
+ batch_norm_eps=1e-5,
200
+ # backbone
201
+ backbone_config=None,
202
+ backbone=None,
203
+ use_pretrained_backbone=False,
204
+ use_timm_backbone=False,
205
+ freeze_backbone_batch_norms=True,
206
+ backbone_kwargs=None,
207
+ # encoder HybridEncoder
208
+ encoder_hidden_dim=256,
209
+ encoder_in_channels=[512, 1024, 2048],
210
+ feat_strides=[8, 16, 32],
211
+ encoder_layers=1,
212
+ encoder_ffn_dim=1024,
213
+ encoder_attention_heads=8,
214
+ dropout=0.0,
215
+ activation_dropout=0.0,
216
+ encode_proj_layers=[2],
217
+ positional_encoding_temperature=10000,
218
+ encoder_activation_function="gelu",
219
+ activation_function="silu",
220
+ eval_size=None,
221
+ normalize_before=False,
222
+ hidden_expansion=1.0,
223
+ # decoder RTDetrV2Transformer
224
+ d_model=256,
225
+ num_queries=300,
226
+ decoder_in_channels=[256, 256, 256],
227
+ decoder_ffn_dim=1024,
228
+ num_feature_levels=3,
229
+ decoder_n_points=4,
230
+ decoder_layers=6,
231
+ decoder_attention_heads=8,
232
+ decoder_activation_function="relu",
233
+ attention_dropout=0.0,
234
+ num_denoising=100,
235
+ label_noise_ratio=0.5,
236
+ box_noise_scale=1.0,
237
+ learn_initial_query=False,
238
+ anchor_image_size=None,
239
+ with_box_refine=True,
240
+ is_encoder_decoder=True,
241
+ # Loss
242
+ matcher_alpha=0.25,
243
+ matcher_gamma=2.0,
244
+ matcher_class_cost=2.0,
245
+ matcher_bbox_cost=5.0,
246
+ matcher_giou_cost=2.0,
247
+ use_focal_loss=True,
248
+ auxiliary_loss=True,
249
+ focal_loss_alpha=0.75,
250
+ focal_loss_gamma=2.0,
251
+ weight_loss_vfl=1.0,
252
+ weight_loss_bbox=5.0,
253
+ weight_loss_giou=2.0,
254
+ eos_coefficient=1e-4,
255
+ decoder_n_levels=3, # default value
256
+ decoder_offset_scale=0.5, # default value
257
+ decoder_method="default",
258
+ **kwargs,
259
+ ):
260
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
261
+ self.initializer_range = initializer_range
262
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
263
+ self.layer_norm_eps = layer_norm_eps
264
+ self.batch_norm_eps = batch_norm_eps
265
+ # backbone
266
+ if backbone_config is None and backbone is None:
267
+ logger.info(
268
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetrV2-ResNet` backbone."
269
+ )
270
+ backbone_model_type = "rt_detr_resnet"
271
+ config_class = CONFIG_MAPPING[backbone_model_type]
272
+ # this will map it to RTDetrResNetConfig
273
+ # note: we can instead create RTDetrV2ResNetConfig but it will be exactly the same as V1
274
+ # and we would need to create RTDetrV2ResNetModel
275
+ backbone_config = config_class(
276
+ num_channels=3,
277
+ embedding_size=64,
278
+ hidden_sizes=[256, 512, 1024, 2048],
279
+ depths=[3, 4, 6, 3],
280
+ layer_type="bottleneck",
281
+ hidden_act="relu",
282
+ downsample_in_first_stage=False,
283
+ downsample_in_bottleneck=False,
284
+ out_features=None,
285
+ out_indices=[2, 3, 4],
286
+ )
287
+ elif isinstance(backbone_config, dict):
288
+ backbone_model_type = backbone_config.pop("model_type")
289
+ config_class = CONFIG_MAPPING[backbone_model_type]
290
+ backbone_config = config_class.from_dict(backbone_config)
291
+
292
+ verify_backbone_config_arguments(
293
+ use_timm_backbone=use_timm_backbone,
294
+ use_pretrained_backbone=use_pretrained_backbone,
295
+ backbone=backbone,
296
+ backbone_config=backbone_config,
297
+ backbone_kwargs=backbone_kwargs,
298
+ )
299
+
300
+ self.backbone_config = backbone_config
301
+ self.backbone = backbone
302
+ self.use_pretrained_backbone = use_pretrained_backbone
303
+ self.use_timm_backbone = use_timm_backbone
304
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
305
+ self.backbone_kwargs = backbone_kwargs
306
+ # encoder
307
+ self.encoder_hidden_dim = encoder_hidden_dim
308
+ self.encoder_in_channels = encoder_in_channels
309
+ self.feat_strides = feat_strides
310
+ self.encoder_ffn_dim = encoder_ffn_dim
311
+ self.dropout = dropout
312
+ self.activation_dropout = activation_dropout
313
+ self.encode_proj_layers = encode_proj_layers
314
+ self.encoder_layers = encoder_layers
315
+ self.positional_encoding_temperature = positional_encoding_temperature
316
+ self.eval_size = eval_size
317
+ self.normalize_before = normalize_before
318
+ self.encoder_activation_function = encoder_activation_function
319
+ self.activation_function = activation_function
320
+ self.hidden_expansion = hidden_expansion
321
+ self.num_queries = num_queries
322
+ self.decoder_ffn_dim = decoder_ffn_dim
323
+ self.decoder_in_channels = decoder_in_channels
324
+ self.num_feature_levels = num_feature_levels
325
+ self.decoder_n_points = decoder_n_points
326
+ self.decoder_layers = decoder_layers
327
+ self.decoder_attention_heads = decoder_attention_heads
328
+ self.decoder_activation_function = decoder_activation_function
329
+ self.attention_dropout = attention_dropout
330
+ self.num_denoising = num_denoising
331
+ self.label_noise_ratio = label_noise_ratio
332
+ self.box_noise_scale = box_noise_scale
333
+ self.learn_initial_query = learn_initial_query
334
+ self.anchor_image_size = anchor_image_size
335
+ self.auxiliary_loss = auxiliary_loss
336
+ self.with_box_refine = with_box_refine
337
+ # Loss
338
+ self.matcher_alpha = matcher_alpha
339
+ self.matcher_gamma = matcher_gamma
340
+ self.matcher_class_cost = matcher_class_cost
341
+ self.matcher_bbox_cost = matcher_bbox_cost
342
+ self.matcher_giou_cost = matcher_giou_cost
343
+ self.use_focal_loss = use_focal_loss
344
+ self.focal_loss_alpha = focal_loss_alpha
345
+ self.focal_loss_gamma = focal_loss_gamma
346
+ self.weight_loss_vfl = weight_loss_vfl
347
+ self.weight_loss_bbox = weight_loss_bbox
348
+ self.weight_loss_giou = weight_loss_giou
349
+ self.eos_coefficient = eos_coefficient
350
+
351
+ if not hasattr(self, "d_model"):
352
+ self.d_model = d_model
353
+
354
+ if not hasattr(self, "encoder_attention_heads"):
355
+ self.encoder_attention_heads = encoder_attention_heads
356
+ # add the new attributes with the given values or defaults
357
+ self.decoder_n_levels = decoder_n_levels
358
+ self.decoder_offset_scale = decoder_offset_scale
359
+ self.decoder_method = decoder_method
360
+
361
+ @classmethod
362
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
363
+ """Instantiate a [`RTDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model
364
+ configuration.
365
+
366
+ Args:
367
+ backbone_config ([`PretrainedConfig`]):
368
+ The backbone configuration.
369
+
370
+ Returns:
371
+ [`RTDetrV2Config`]: An instance of a configuration object
372
+ """
373
+ return cls(
374
+ backbone_config=backbone_config,
375
+ **kwargs,
376
+ )
377
+
378
+
379
+ __all__ = ["RTDetrV2Config"]
docs/transformers/build/lib/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/build/lib/transformers/models/sam/convert_sam_to_hf.py ADDED
@@ -0,0 +1,251 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Convert SAM checkpoints from the original repository.
17
+
18
+ URL: https://github.com/facebookresearch/segment-anything.
19
+
20
+ Also supports converting the SlimSAM checkpoints from https://github.com/czg1225/SlimSAM/tree/master.
21
+ """
22
+
23
+ import argparse
24
+ import re
25
+
26
+ import numpy as np
27
+ import requests
28
+ import torch
29
+ from huggingface_hub import hf_hub_download
30
+ from PIL import Image
31
+
32
+ from transformers import (
33
+ SamConfig,
34
+ SamImageProcessor,
35
+ SamModel,
36
+ SamProcessor,
37
+ SamVisionConfig,
38
+ )
39
+
40
+
41
+ def get_config(model_name):
42
+ if "slimsam-50" in model_name:
43
+ vision_config = SamVisionConfig(
44
+ hidden_size=384,
45
+ mlp_dim=1536,
46
+ num_hidden_layers=12,
47
+ num_attention_heads=12,
48
+ global_attn_indexes=[2, 5, 8, 11],
49
+ )
50
+ elif "slimsam-77" in model_name:
51
+ vision_config = SamVisionConfig(
52
+ hidden_size=168,
53
+ mlp_dim=696,
54
+ num_hidden_layers=12,
55
+ num_attention_heads=12,
56
+ global_attn_indexes=[2, 5, 8, 11],
57
+ )
58
+ elif "sam_vit_b" in model_name:
59
+ vision_config = SamVisionConfig()
60
+ elif "sam_vit_l" in model_name:
61
+ vision_config = SamVisionConfig(
62
+ hidden_size=1024,
63
+ num_hidden_layers=24,
64
+ num_attention_heads=16,
65
+ global_attn_indexes=[5, 11, 17, 23],
66
+ )
67
+ elif "sam_vit_h" in model_name:
68
+ vision_config = SamVisionConfig(
69
+ hidden_size=1280,
70
+ num_hidden_layers=32,
71
+ num_attention_heads=16,
72
+ global_attn_indexes=[7, 15, 23, 31],
73
+ )
74
+
75
+ config = SamConfig(
76
+ vision_config=vision_config,
77
+ )
78
+
79
+ return config
80
+
81
+
82
+ KEYS_TO_MODIFY_MAPPING = {
83
+ "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
84
+ "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
85
+ "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
86
+ "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
87
+ "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
88
+ "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
89
+ "mask_downscaling.0": "mask_embed.conv1",
90
+ "mask_downscaling.1": "mask_embed.layer_norm1",
91
+ "mask_downscaling.3": "mask_embed.conv2",
92
+ "mask_downscaling.4": "mask_embed.layer_norm2",
93
+ "mask_downscaling.6": "mask_embed.conv3",
94
+ "point_embeddings": "point_embed",
95
+ "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
96
+ "image_encoder": "vision_encoder",
97
+ "neck.0": "neck.conv1",
98
+ "neck.1": "neck.layer_norm1",
99
+ "neck.2": "neck.conv2",
100
+ "neck.3": "neck.layer_norm2",
101
+ "patch_embed.proj": "patch_embed.projection",
102
+ ".norm": ".layer_norm",
103
+ "blocks": "layers",
104
+ }
105
+
106
+
107
+ def replace_keys(state_dict):
108
+ model_state_dict = {}
109
+ state_dict.pop("pixel_mean", None)
110
+ state_dict.pop("pixel_std", None)
111
+
112
+ output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"
113
+
114
+ for key, value in state_dict.items():
115
+ for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
116
+ if key_to_modify in key:
117
+ key = key.replace(key_to_modify, new_key)
118
+
119
+ if re.match(output_hypernetworks_mlps_pattern, key):
120
+ layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
121
+ if layer_nb == 0:
122
+ key = key.replace("layers.0", "proj_in")
123
+ elif layer_nb == 1:
124
+ key = key.replace("layers.1", "layers.0")
125
+ elif layer_nb == 2:
126
+ key = key.replace("layers.2", "proj_out")
127
+
128
+ model_state_dict[key] = value
129
+
130
+ model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
131
+ "prompt_encoder.shared_embedding.positional_embedding"
132
+ ]
133
+
134
+ return model_state_dict
135
+
136
+
137
+ def convert_sam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub):
138
+ config = get_config(model_name)
139
+
140
+ state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
141
+ state_dict = replace_keys(state_dict)
142
+
143
+ image_processor = SamImageProcessor()
144
+ processor = SamProcessor(image_processor=image_processor)
145
+ hf_model = SamModel(config)
146
+ hf_model.eval()
147
+
148
+ device = "cuda" if torch.cuda.is_available() else "cpu"
149
+
150
+ hf_model.load_state_dict(state_dict)
151
+ hf_model = hf_model.to(device)
152
+
153
+ img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
154
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
155
+
156
+ input_points = [[[500, 375]]]
157
+ input_labels = [[1]]
158
+
159
+ inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device)
160
+
161
+ with torch.no_grad():
162
+ output = hf_model(**inputs)
163
+ scores = output.iou_scores.squeeze()
164
+
165
+ if model_name == "sam_vit_b_01ec64":
166
+ inputs = processor(
167
+ images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
168
+ ).to(device)
169
+
170
+ with torch.no_grad():
171
+ output = hf_model(**inputs)
172
+ scores = output.iou_scores.squeeze()
173
+
174
+ elif model_name == "sam_vit_h_4b8939":
175
+ inputs = processor(
176
+ images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
177
+ ).to(device)
178
+
179
+ with torch.no_grad():
180
+ output = hf_model(**inputs)
181
+ scores = output.iou_scores.squeeze()
182
+
183
+ assert scores[-1].item() == 0.9712603092193604
184
+
185
+ input_boxes = ((75, 275, 1725, 850),)
186
+
187
+ inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device)
188
+
189
+ with torch.no_grad():
190
+ output = hf_model(**inputs)
191
+ scores = output.iou_scores.squeeze()
192
+
193
+ assert scores[-1].item() == 0.8686015605926514
194
+
195
+ # Test with 2 points and 1 image.
196
+ input_points = [[[400, 650], [800, 650]]]
197
+ input_labels = [[1, 1]]
198
+
199
+ inputs = processor(
200
+ images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
201
+ ).to(device)
202
+
203
+ with torch.no_grad():
204
+ output = hf_model(**inputs)
205
+ scores = output.iou_scores.squeeze()
206
+
207
+ assert scores[-1].item() == 0.9936047792434692
208
+
209
+ if pytorch_dump_folder is not None:
210
+ processor.save_pretrained(pytorch_dump_folder)
211
+ hf_model.save_pretrained(pytorch_dump_folder)
212
+
213
+ if push_to_hub:
214
+ repo_id = f"nielsr/{model_name}" if "slimsam" in model_name else f"meta/{model_name}"
215
+ processor.push_to_hub(repo_id)
216
+ hf_model.push_to_hub(repo_id)
217
+
218
+
219
+ if __name__ == "__main__":
220
+ parser = argparse.ArgumentParser()
221
+ choices = ["sam_vit_b_01ec64", "sam_vit_h_4b8939", "sam_vit_l_0b3195", "slimsam-50-uniform", "slimsam-77-uniform"]
222
+ parser.add_argument(
223
+ "--model_name",
224
+ default="sam_vit_h_4b8939",
225
+ choices=choices,
226
+ type=str,
227
+ help="Name of the original model to convert",
228
+ )
229
+ parser.add_argument(
230
+ "--checkpoint_path",
231
+ type=str,
232
+ required=False,
233
+ help="Path to the original checkpoint",
234
+ )
235
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
236
+ parser.add_argument(
237
+ "--push_to_hub",
238
+ action="store_true",
239
+ help="Whether to push the model and processor to the hub after converting",
240
+ )
241
+
242
+ args = parser.parse_args()
243
+
244
+ if "slimsam" in args.model_name:
245
+ checkpoint_path = args.checkpoint_path
246
+ if checkpoint_path is None:
247
+ raise ValueError("You need to provide a checkpoint path for SlimSAM models.")
248
+ else:
249
+ checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{args.model_name}.pth")
250
+
251
+ convert_sam_checkpoint(args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
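For reference, a sketch of driving the conversion function above directly from Python instead of the argparse entry point (the dump folder path is a placeholder):

```python
from huggingface_hub import hf_hub_download

# Same download the __main__ block performs for non-SlimSAM models.
model_name = "sam_vit_b_01ec64"
checkpoint_path = hf_hub_download("ybelkada/segment-anything", f"checkpoints/{model_name}.pth")

convert_sam_checkpoint(
    model_name,
    checkpoint_path,
    pytorch_dump_folder="./sam_vit_b_hf",  # placeholder output directory
    push_to_hub=False,
)
```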