Student0809 committed
Commit c076144 · verified · 1 Parent(s): 0e9a03e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. docs/transformers/src/transformers/models/blip_2/__init__.py +28 -0
  2. docs/transformers/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py +390 -0
  3. docs/transformers/src/transformers/models/blip_2/modeling_blip_2.py +0 -0
  4. docs/transformers/src/transformers/models/blip_2/processing_blip_2.py +193 -0
  5. docs/transformers/src/transformers/models/bloom/__init__.py +29 -0
  6. docs/transformers/src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py +254 -0
  7. docs/transformers/src/transformers/models/bloom/modeling_bloom.py +1397 -0
  8. docs/transformers/src/transformers/models/bloom/modeling_flax_bloom.py +737 -0
  9. docs/transformers/src/transformers/models/bloom/tokenization_bloom_fast.py +152 -0
  10. docs/transformers/src/transformers/models/bridgetower/__init__.py +30 -0
  11. docs/transformers/src/transformers/models/bridgetower/configuration_bridgetower.py +319 -0
  12. docs/transformers/src/transformers/models/bridgetower/image_processing_bridgetower.py +541 -0
  13. docs/transformers/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py +345 -0
  14. docs/transformers/src/transformers/models/bridgetower/modeling_bridgetower.py +1984 -0
  15. docs/transformers/src/transformers/models/bridgetower/processing_bridgetower.py +114 -0
  16. docs/transformers/src/transformers/models/bros/__init__.py +28 -0
  17. docs/transformers/src/transformers/models/bros/configuration_bros.py +138 -0
  18. docs/transformers/src/transformers/models/bros/convert_bros_to_pytorch.py +145 -0
  19. docs/transformers/src/transformers/models/bros/modeling_bros.py +1323 -0
  20. docs/transformers/src/transformers/models/bros/processing_bros.py +112 -0
  21. docs/transformers/src/transformers/models/byt5/__init__.py +26 -0
  22. docs/transformers/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py +59 -0
  23. docs/transformers/src/transformers/models/byt5/tokenization_byt5.py +236 -0
  24. docs/transformers/src/transformers/models/camembert/__init__.py +30 -0
  25. docs/transformers/src/transformers/models/camembert/configuration_camembert.py +155 -0
  26. docs/transformers/src/transformers/models/camembert/modeling_camembert.py +1716 -0
  27. docs/transformers/src/transformers/models/camembert/modeling_tf_camembert.py +1801 -0
  28. docs/transformers/src/transformers/models/camembert/tokenization_camembert.py +323 -0
  29. docs/transformers/src/transformers/models/camembert/tokenization_camembert_fast.py +201 -0
  30. docs/transformers/src/transformers/models/canine/__init__.py +28 -0
  31. docs/transformers/src/transformers/models/canine/configuration_canine.py +141 -0
  32. docs/transformers/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py +65 -0
  33. docs/transformers/src/transformers/models/canine/modeling_canine.py +1653 -0
  34. docs/transformers/src/transformers/models/canine/tokenization_canine.py +244 -0
  35. docs/transformers/src/transformers/models/chameleon/__init__.py +29 -0
  36. docs/transformers/src/transformers/models/chameleon/configuration_chameleon.py +281 -0
  37. docs/transformers/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py +478 -0
  38. docs/transformers/src/transformers/models/chameleon/image_processing_chameleon.py +344 -0
  39. docs/transformers/src/transformers/models/chameleon/modeling_chameleon.py +1673 -0
  40. docs/transformers/src/transformers/models/chameleon/processing_chameleon.py +177 -0
  41. docs/transformers/src/transformers/models/chinese_clip/__init__.py +31 -0
  42. docs/transformers/src/transformers/models/chinese_clip/configuration_chinese_clip.py +434 -0
  43. docs/transformers/src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py +134 -0
  44. docs/transformers/src/transformers/models/chinese_clip/feature_extraction_chinese_clip.py +38 -0
  45. docs/transformers/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +314 -0
  46. docs/transformers/src/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +40 -0
  47. docs/transformers/src/transformers/models/chinese_clip/modeling_chinese_clip.py +1630 -0
  48. docs/transformers/src/transformers/models/chinese_clip/processing_chinese_clip.py +163 -0
  49. docs/transformers/src/transformers/models/clap/__init__.py +29 -0
  50. docs/transformers/src/transformers/models/clap/configuration_clap.py +394 -0
docs/transformers/src/transformers/models/blip_2/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_blip_2 import *
+     from .modeling_blip_2 import *
+     from .processing_blip_2 import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
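The `TYPE_CHECKING` / `_LazyModule` split above keeps `import transformers` cheap: a submodule is only imported when one of its attributes is first accessed. A minimal standalone sketch of the same idea via module-level `__getattr__` (PEP 562); the mapping below is a hypothetical example, not the Transformers implementation:

import importlib

# Map public names to the submodule that defines them (hypothetical example).
_LAZY_ATTRS = {"Blip2Processor": ".processing_blip_2"}

def __getattr__(name):
    # Only triggered for attributes not found through normal lookup.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")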
docs/transformers/src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py ADDED
@@ -0,0 +1,390 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Convert BLIP-2 checkpoints from the original repository.
+
+ URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
+ """
+
+ import argparse
+
+ import requests
+ import torch
+
+ # pip3 install salesforce-lavis
+ # I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
+ # to make sure we can compare both the original and HF implementations in float32
+ from lavis.models import load_model_and_preprocess
+ from PIL import Image
+
+ from transformers import (
+     AutoTokenizer,
+     BertTokenizer,
+     Blip2Config,
+     Blip2ForConditionalGeneration,
+     Blip2ForImageTextRetrieval,
+     Blip2Processor,
+     Blip2QFormerConfig,
+     Blip2VisionConfig,
+     BlipImageProcessor,
+     OPTConfig,
+     T5Config,
+     set_seed,
+ )
+ from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+
+
+ def load_demo_image():
+     url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
+     image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+     return image
+
+
+ # here we list all keys to be renamed (original name on the left, our name on the right)
+ def create_rename_keys(config, model_name):
+     rename_keys = []
+     # fmt: off
+
+     # vision encoder
+     rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
+     rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
+     rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
+     rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
+     rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
+     rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
+
+     for i in range(config.vision_config.num_hidden_layers):
+         rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
+         rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
+         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
+
+     # QFormer
+     rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
+     rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
+     if "itm" in model_name:
+         rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
+         rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
+         rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
+         rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
+         rename_keys.append(("text_proj.weight", "text_projection.weight"))
+         rename_keys.append(("text_proj.bias", "text_projection.bias"))
+
+     # fmt: on
+     return rename_keys
+
+
+ def rename_key(dct, old, new):
+     val = dct.pop(old)
+     dct[new] = val
+
+
+ def read_in_q_v_bias(state_dict, config):
+     for i in range(config.vision_config.num_hidden_layers):
+         # read in original q and v biases
+         q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
+         v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
+
+         # next, set the fused qkv bias in the state dict (the k bias is fixed to zero)
+         qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
+         state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
+
+
+ def get_blip2_config(model_name, eos_token_id):
+     image_size = 364 if "coco" in model_name else 224
+     vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
+
+     # make sure the models have proper bos_token_id and eos_token_id set (important for generation)
+     # seems like flan-T5 models don't have bos_token_id properly set?
+     if "opt-2.7b" in model_name:
+         text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
+     elif "opt-6.7b" in model_name:
+         text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
+     elif "t5-xl" in model_name:
+         text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
+     elif "t5-xxl" in model_name:
+         text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
+     elif "itm" in model_name:
+         text_config = {}
+     else:
+         raise ValueError("Model name not supported")
+
+     if "itm" in model_name:
+         config = Blip2Config(
+             vision_config=vision_config,
+             qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
+         )
+     else:
+         config = Blip2Config(vision_config=vision_config, text_config=text_config)
+
+     return config, image_size
+
+
+ @torch.no_grad()
+ def convert_blip2_checkpoint(
+     model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
+ ):
+     """
+     Copy/paste/tweak the model's weights to the Transformers design.
+     """
+     if "opt" in model_name:
+         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
+     elif "itm" in model_name:
+         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
+         tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+     else:
+         tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+
+     if "itm" in model_name:
+         eos_token_id = None
+     else:
+         eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
+     config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
+
+     if "itm" in model_name:
+         hf_model = Blip2ForImageTextRetrieval(config).eval()
+     else:
+         hf_model = Blip2ForConditionalGeneration(config).eval()
+
+     model_name_to_original = {
+         "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
+         "blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
+         "blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
+         "blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
+         "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
+         "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
+         "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
+         "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
+         "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
+     }
+
+     name, model_type = model_name_to_original[model_name]
+
+     # load original model
+     print("Loading original model...")
+     original_model, vis_processors, _ = load_model_and_preprocess(
+         name=name, model_type=model_type, is_eval=True, device=lavis_device
+     )
+     original_model.eval()
+     print("Done!")
+
+     # update state dict keys
+     state_dict = original_model.state_dict()
+     rename_keys = create_rename_keys(config, model_name)
+     for src, dest in rename_keys:
+         rename_key(state_dict, src, dest)
+
+     # some keys can be renamed efficiently
+     for key, val in state_dict.copy().items():
+         val = state_dict.pop(key)
+         if key.startswith("Qformer.bert"):
+             key = key.replace("Qformer.bert", "qformer")
+         if "attention.self" in key:
+             key = key.replace("self", "attention")
+         if "opt_proj" in key:
+             key = key.replace("opt_proj", "language_projection")
+         if "t5_proj" in key:
+             key = key.replace("t5_proj", "language_projection")
+         if key.startswith("opt"):
+             key = key.replace("opt", "language")
+         if key.startswith("t5"):
+             key = key.replace("t5", "language")
+         state_dict[key] = val
+
+     # read in qv biases
+     read_in_q_v_bias(state_dict, config)
+
+     missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
+     assert len(missing_keys) == 0
+
+     if "itm" in model_name:
+         unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
+         assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
+     else:
+         assert unexpected_keys == ["qformer.embeddings.position_ids"]
+
+     image = load_demo_image()
+     original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
+
+     # create processor
+     image_processor = BlipImageProcessor(
+         size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
+     )
+     processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
+     pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
+
+     # make sure the processor creates the exact same pixel values
+     assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
+
+     original_model.to(lavis_device)
+     hf_model.to(hf_model_device)
+
+     if "itm" in model_name:
+         caption = "a large fountain spewing water into the air"
+         input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
+         attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
+
+         with torch.no_grad():
+             original_logits = original_model(
+                 {"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
+             )
+             logits = hf_model(
+                 pixel_values=pixel_values,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 use_image_text_matching_head=True,
+             )
+
+         assert original_logits.shape == logits.logits_per_image.shape
+         print("First values of original logits:", original_logits[0, :3])
+         print("First values of HF logits:", logits.logits_per_image[0, :3])
+
+         # assert values
+         # cast to same type
+         target_dtype = logits.logits_per_image.dtype
+         assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
+
+         original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
+         itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
+         assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
+         print("Looks ok!")
+
+         with torch.no_grad():
+             original_logits = original_model(
+                 {"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
+             )
+             logits = hf_model(
+                 pixel_values=pixel_values,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 use_image_text_matching_head=False,
+             )
+
+         assert original_logits.shape == logits.logits_per_image.shape
+         print("First values of original logits:", original_logits[0, :3])
+         print("First values of HF logits:", logits.logits_per_image[0, :3])
+
+         # assert values
+         # cast to same type
+         target_dtype = logits.logits_per_image.dtype
+         assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
+         print("Looks ok!")
+
+     else:
+         input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
+
+         with torch.no_grad():
+             if "opt" in model_name:
+                 original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
+                 logits = hf_model(pixel_values, input_ids).logits
+             else:
+                 original_logits = original_model(
+                     {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
+                 ).logits
+                 labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
+                 logits = hf_model(pixel_values, input_ids, labels=labels).logits
+
+         assert original_logits.shape == logits.shape
+         print("First values of original logits:", original_logits[0, :3, :3])
+         print("First values of HF logits:", logits[0, :3, :3])
+
+         # assert values
+         assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
+         print("Looks ok!")
+
+         print("Generating a caption...")
+         prompt = "Question: what object is in this image? Answer:"
+         input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
+
+         set_seed(42)
+
+         original_outputs = original_model.generate(
+             {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
+         )
+         outputs = hf_model.generate(
+             pixel_values,
+             input_ids,
+             do_sample=True,
+             num_beams=5,
+             max_length=30,
+             min_length=1,
+             top_p=0.9,
+             repetition_penalty=1.0,
+             length_penalty=1.0,
+             temperature=1,
+         )
+         output_text = processor.batch_decode(outputs, skip_special_tokens=True)
+         output_text = [text.strip() for text in output_text]
+         print("Original generation:", original_outputs)
+         print("HF generation:", output_text)
+
+     if pytorch_dump_folder_path is not None:
+         processor.save_pretrained(pytorch_dump_folder_path)
+         hf_model.save_pretrained(pytorch_dump_folder_path)
+
+     if push_to_hub:
+         processor.push_to_hub(f"nielsr/{model_name}")
+         hf_model.push_to_hub(f"nielsr/{model_name}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     choices = [
+         "blip2-opt-2.7b",
+         "blip2-opt-6.7b",
+         "blip2-opt-2.7b-coco",
+         "blip2-opt-6.7b-coco",
+         "blip2-flan-t5-xl",
+         "blip2-flan-t5-xl-coco",
+         "blip2-flan-t5-xxl",
+         "blip2-itm-vit-g",
+         "blip2-itm-vit-g-coco",
+     ]
+     parser.add_argument(
+         "--model_name",
+         default="blip2-opt-2.7b",
+         choices=choices,
+         type=str,
+         help="Name of the model to convert.",
+     )
+     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+     parser.add_argument(
+         "--push_to_hub",
+         action="store_true",
+         help="Whether to push the model and processor to the hub after converting",
+     )
+     # note: this script is tested on 2 GPUs, since the models are compared in float32,
+     # which requires quite some memory. Hence loading each model on a separate device
+     # is the easiest way to compare them
+     parser.add_argument(
+         "--lavis_device", default="cpu", type=str, help="Torch device to run the conversion on, either cpu or cuda."
+     )
+     parser.add_argument(
+         "--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion on, either cpu or cuda."
+     )
+
+     args = parser.parse_args()
+
+     convert_blip2_checkpoint(
+         args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
+     )
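For reference, a hedged sketch of how the converter above could be invoked programmatically instead of via the CLI; the dump folder is a placeholder, and the LAVIS install noted in the module docstring is assumed:

# Convert the smallest OPT variant on CPU and write the result to a local folder.
from convert_blip_2_original_to_pytorch import convert_blip2_checkpoint

convert_blip2_checkpoint(
    model_name="blip2-opt-2.7b",
    pytorch_dump_folder_path="./blip2-opt-2.7b-converted",
    push_to_hub=False,
    lavis_device="cpu",
    hf_model_device="cpu",
)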
docs/transformers/src/transformers/models/blip_2/modeling_blip_2.py ADDED
The diff for this file is too large to render.
docs/transformers/src/transformers/models/blip_2/processing_blip_2.py ADDED
@@ -0,0 +1,193 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Processor class for BLIP-2.
+ """
+
+ from typing import List, Optional, Union
+
+ from ...image_processing_utils import BatchFeature
+ from ...image_utils import ImageInput
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+ from ...tokenization_utils_base import (
+     AddedToken,
+     BatchEncoding,
+     PreTokenizedInput,
+     TextInput,
+ )
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Blip2ProcessorKwargs(ProcessingKwargs, total=False):
+     _defaults = {
+         "text_kwargs": {
+             "add_special_tokens": True,
+             "padding": False,
+             "stride": 0,
+             "return_overflowing_tokens": False,
+             "return_special_tokens_mask": False,
+             "return_offsets_mapping": False,
+             "return_token_type_ids": False,
+             "return_length": False,
+             "verbose": True,
+         },
+         "images_kwargs": {},
+     }
+
+
+ class Blip2Processor(ProcessorMixin):
+     r"""
+     Constructs a BLIP-2 processor which wraps a BLIP image processor and an OPT/T5 tokenizer into a single processor.
+
+     [`Blip2Processor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the docstring
+     of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.
+
+     Args:
+         image_processor (`BlipImageProcessor`):
+             An instance of [`BlipImageProcessor`]. The image processor is a required input.
+         tokenizer (`AutoTokenizer`):
+             An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
+         num_query_tokens (`int`, *optional*):
+             Number of tokens used by the Q-Former as queries; should be the same as in the model's config.
+     """
+
+     attributes = ["image_processor", "tokenizer"]
+     valid_kwargs = ["num_query_tokens"]
+     image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
+         tokenizer.return_token_type_ids = False
+         self.current_processor = image_processor
+         if not hasattr(tokenizer, "image_token"):
+             self.image_token = AddedToken("<image>", normalized=False, special=True)
+             tokenizer.add_tokens([self.image_token], special_tokens=True)
+         else:
+             self.image_token = tokenizer.image_token
+         self.num_query_tokens = num_query_tokens
+
+         super().__init__(image_processor, tokenizer)
+
+     def __call__(
+         self,
+         images: ImageInput = None,
+         text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+         audio=None,
+         videos=None,
+         **kwargs: Unpack[Blip2ProcessorKwargs],
+     ) -> BatchEncoding:
+         """
+         This method uses [`BlipImageProcessor.__call__`] to prepare image(s) for the model, and
+         [`BertTokenizerFast.__call__`] to prepare text for the model.
+
+         Please refer to the docstrings of the two methods above for more information.
+
+         Args:
+             images (`ImageInput`):
+                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                 tensor. Both channels-first and channels-last formats are supported.
+             text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
+                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                 (pretokenized string). If the sequences are provided as lists of strings (pretokenized), you must set
+                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors of a particular framework. Acceptable values are:
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return NumPy `np.ndarray` objects.
+                 - `'jax'`: Return JAX `jnp.ndarray` objects.
+         """
+         if images is None and text is None:
+             raise ValueError("You have to specify either images or text.")
+         output_kwargs = self._merge_kwargs(
+             Blip2ProcessorKwargs,
+             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+             **kwargs,
+         )
+         # BC for explicit return_tensors
+         if "return_tensors" in output_kwargs["common_kwargs"]:
+             return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
+         else:
+             return_tensors = None
+         encoding = BatchFeature(tensor_type=return_tensors)
+         if text is not None:
+             if isinstance(text, str):
+                 text = [text]
+             elif not isinstance(text, list) or not isinstance(text[0], str):
+                 raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+             text_encoding = {}
+
+             return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+             _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
+             output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
+
+             # if we know how many query tokens there are, expand the text inside the processor. We need this hacky
+             # manipulation because BLIP expects image tokens to be at the beginning, even before the BOS token
+             if self.num_query_tokens is not None:
+                 image_tokens = self.image_token.content * self.num_query_tokens
+                 image_token_encoding = self.tokenizer(
+                     [image_tokens] * len(text), add_special_tokens=False, return_tensors=None
+                 )
+                 for k in _text_encoding:
+                     text_encoding[k] = [
+                         img_encoding + txt_encoding
+                         for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
+                     ]
+             else:
+                 text_encoding = _text_encoding
+                 logger.warning_once(
+                     "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
+                     "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
+                     "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+                 )
+
+             # cast to the desired return tensor type
+             encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors))
+         # add the pixel_values encoding. If we also have a text encoding, update the image encoding and return it;
+         # otherwise, return the text encoding.
+
+         if images is not None:
+             image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
+             encoding.update(image_encoding)
+         return encoding
+
+     # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+
+     @property
+     # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+ __all__ = ["Blip2Processor"]
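A short usage sketch of the processor above; the checkpoint name is assumed to be available on the Hub, and the demo image matches the one used by the conversion script:

import requests
from PIL import Image
from transformers import Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# The text is prefixed with num_query_tokens copies of the <image> token (see __call__ above).
inputs = processor(images=image, text="Question: where is this photo taken? Answer:", return_tensors="pt")
print(sorted(inputs.keys()))  # expected: ['attention_mask', 'input_ids', 'pixel_values']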
docs/transformers/src/transformers/models/bloom/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_bloom import *
+     from .modeling_bloom import *
+     from .modeling_flax_bloom import *
+     from .tokenization_bloom_fast import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,254 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert BigScience BLOOM checkpoint."""
+
+ import argparse
+ import json
+ import os
+ import re
+
+ import torch
+
+ from transformers import BloomConfig, BloomModel
+ from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+ from transformers.utils import logging
+
+
+ logging.set_verbosity_info()
+
+ WEIGHTS_TO_AVERAGE_ENDSWITH = [
+     "word_embeddings_layernorm.weight",
+     "word_embeddings_layernorm.bias",
+     "input_layernorm.weight",
+     "input_layernorm.bias",
+     "post_attention_layernorm.weight",
+     "post_attention_layernorm.bias",
+     "self_attention.dense.bias",
+     "mlp.dense_4h_to_h.bias",
+     "ln_f.weight",
+     "ln_f.bias",
+ ]
+
+ WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
+     "mlp.dense_4h_to_h.weight",
+     "self_attention.dense.weight",
+ ]
+
+
+ def layer_name_mapping(key, file):
+     """Convert Megatron-DeepSpeed TP/PP weight names to the Transformers (PP-only) naming."""
+     # Handle first and last layers
+     layer_rename_map = {
+         "word_embeddings.weight": "word_embeddings.weight",
+         "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
+         "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
+         "weight": "ln_f.weight",
+         "bias": "ln_f.bias",
+     }
+
+     if key in layer_rename_map:
+         return layer_rename_map[key]
+
+     # Handle transformer blocks
+     layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
+     layer_number -= 3
+     return f"h.{layer_number}." + key
+
+
+ def get_dtype_size(dtype):
+     if dtype == torch.bool:
+         return 1 / 8
+     bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+     if bit_search is None:
+         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+     bit_size = int(bit_search.groups()[0])
+     return bit_size // 8
+
+
+ def convert_bloom_checkpoint_to_pytorch(
+     bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
+ ):
+     # Construct model
+     if bloom_config_file == "":
+         config = BloomConfig()
+     else:
+         config = BloomConfig.from_json_file(bloom_config_file)
+
+     if shard_model:
+         file_names = os.listdir(bloom_checkpoint_path)
+         file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
+
+         index_dict = {"weight_map": {}, "metadata": {}}
+         total_size = 0
+
+         missing_keys = None
+
+         config = BloomConfig()
+
+         for j, file in enumerate(file_names):
+             print("Processing file: {}".format(file))
+             tensors = None
+
+             for i in range(pretraining_tp):
+                 # load all TP files
+                 f_name = file.replace("model_00", f"model_0{i}")
+                 temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
+
+                 # Rename keys to the transformers names
+                 keys = list(temp.keys())
+                 for key in keys:
+                     temp[layer_name_mapping(key, file)] = temp.pop(key)
+
+                 if tensors is None:
+                     tensors = temp
+                 else:
+                     for key in tensors.keys():
+                         if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+                             # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
+                             tensors[key] += temp[key]
+                         else:
+                             # Some weights are RowParallelLinear in Megatron-DeepSpeed, others are ColumnParallel
+                             cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
+                             # We concatenate these weights across TP ranks
+                             tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
+
+             # Divide the weights we want to average by the number of TP ranks
+             for key in tensors.keys():
+                 if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+                     tensors[key] = tensors[key] / pretraining_tp
+             torch.save(
+                 tensors,
+                 os.path.join(
+                     pytorch_dump_folder_path,
+                     "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
+                 ),
+             )
+
+             for key in tensors.keys():
+                 value = tensors[key]
+                 total_size += value.numel() * get_dtype_size(value.dtype)
+                 if key not in index_dict["weight_map"]:
+                     index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
+                         str(j + 1).zfill(5), str(len(file_names)).zfill(5)
+                     )
+
+         config = BloomConfig()
+         pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+         index_dict["metadata"]["total_size"] = total_size
+         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+             f.write(config.to_json_string())
+         with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
+             json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
+             f.write(json_config)
+     else:
+         model = BloomModel(config)
+
+         file_names = os.listdir(bloom_checkpoint_path)
+         file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
+
+         missing_keys = None
+         for file in file_names:
+             tensors = None
+             for i in range(pretraining_tp):
+                 # load all TP files
+                 f_name = file.replace("model_00", f"model_0{i}")
+                 temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
+
+                 # Rename keys to the transformers names
+                 keys = list(temp.keys())
+                 for key in keys:
+                     temp[layer_name_mapping(key, file)] = temp.pop(key)
+
+                 if tensors is None:
+                     tensors = temp
+                 else:
+                     for key in tensors.keys():
+                         # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
+                         if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+                             tensors[key] += temp[key]
+                         else:
+                             # Some weights are RowParallelLinear in Megatron-DeepSpeed, others are ColumnParallel
+                             cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
+                             # We concatenate these weights across TP ranks
+                             tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
+
+             # Divide the weights we want to average by the number of TP ranks
+             for key in tensors.keys():
+                 if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+                     tensors[key] = tensors[key] / pretraining_tp
+
+             other_keys = model.load_state_dict(tensors, strict=False)
+             assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
+             if missing_keys is None:
+                 missing_keys = set(other_keys.missing_keys)
+             else:
+                 missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
+
+         assert not missing_keys, f"The keys {missing_keys} are missing"
+
+         # Save the PyTorch model
+         os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+         pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+         pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+         print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
+         if config.torch_dtype is not None:
+             model = model.to(config.torch_dtype)
+         torch.save(model.state_dict(), pytorch_weights_dump_path)
+         print(f"Save configuration file to {pytorch_config_dump_path}")
+         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+             f.write(config.to_json_string())
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--bloom_checkpoint_path",
+         default=None,
+         type=str,
+         required=True,
+         help="Path to the Megatron-LM checkpoint.",
+     )
+     parser.add_argument(
+         "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+     )
+     parser.add_argument(
+         "--bloom_config_file",
+         default="",
+         type=str,
+         help=(
+             "An optional config json file corresponding to the pre-trained model. \n"
+             "This specifies the model architecture."
+         ),
+     )
+     parser.add_argument(
+         "--shard_model",
+         action="store_true",
+         help="An optional setting to shard the output model. \nThis enables sharding the converted checkpoint",
+     )
+     parser.add_argument(
+         "--pretraining_tp",
+         default=4,
+         type=int,
+         help="Tensor-parallel (TP) degree that was used when training the model with Megatron-LM.",
+     )
+     args = parser.parse_args()
+     convert_bloom_checkpoint_to_pytorch(
+         args.bloom_checkpoint_path,
+         args.bloom_config_file,
+         args.pytorch_dump_folder_path,
+         args.shard_model,
+         args.pretraining_tp,
+     )
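A toy illustration of the merge rule implemented above, with invented shapes: layernorm-style weights (those in WEIGHTS_TO_AVERAGE_ENDSWITH) are averaged across TP ranks, while parallel linear weights are concatenated along dim 1 (row-parallel) or dim 0 (column-parallel):

import torch

# Two fake TP shards: a column-parallel weight (split on dim 0) and a layernorm weight.
rank0 = {"h.0.mlp.dense_h_to_4h.weight": torch.ones(2, 4), "h.0.input_layernorm.weight": torch.full((4,), 1.0)}
rank1 = {"h.0.mlp.dense_h_to_4h.weight": torch.zeros(2, 4), "h.0.input_layernorm.weight": torch.full((4,), 3.0)}

merged = dict(rank0)
for key, val in rank1.items():
    if key.endswith("input_layernorm.weight"):  # averaged across ranks
        merged[key] = (merged[key] + val) / 2
    else:  # concatenated; dense_h_to_4h is column-parallel, so dim 0
        merged[key] = torch.cat([merged[key], val], dim=0)

print(merged["h.0.mlp.dense_h_to_4h.weight"].shape)  # torch.Size([4, 4])
print(merged["h.0.input_layernorm.weight"])  # tensor([2., 2., 2., 2.])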
docs/transformers/src/transformers/models/bloom/modeling_bloom.py ADDED
@@ -0,0 +1,1397 @@
+ # coding=utf-8
+ # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch BLOOM model."""
+
+ import math
+ import warnings
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+ from torch.nn import functional as F
+
+ from ...cache_utils import Cache, DynamicCache, StaticCache
+ from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+ from ...generation import GenerationMixin
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
+ from ...modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from ...modeling_utils import PreTrainedModel
+ from ...utils import (
+     is_torch_flex_attn_available,
+     logging,
+ )
+ from .configuration_bloom import BloomConfig
+
+
+ if is_torch_flex_attn_available():
+     from torch.nn.attention.flex_attention import BlockMask
+
+     from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+ logger = logging.get_logger(__name__)
+
+ _CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
+ _CONFIG_FOR_DOC = "BloomConfig"
+
+
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+     """
+     Link to paper: https://arxiv.org/abs/2108.12409. The ALiBi tensor built here is not causal (unlike in the original
+     paper); the implementation relies on the translation invariance of softmax for speed: for a tensor `l` and a fixed
+     value `a`, `softmax(l + a) = softmax(l)`. Based on
+     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+     TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+     Args:
+         attention_mask (`torch.Tensor`):
+             Token-wise attention mask; this should be of shape (batch_size, max_seq_len).
+         num_heads (`int`):
+             number of heads
+         dtype (`torch.dtype`):
+             dtype of the output tensor
+
+     Returns:
+         A tensor of shape (batch_size * num_heads, 1, max_seq_len).
+     """
+     batch_size, seq_length = attention_mask.shape
+     closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+     base = torch.tensor(
+         2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+     )
+     powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+     slopes = torch.pow(base, powers)
+
+     if closest_power_of_2 != num_heads:
+         extra_base = torch.tensor(
+             2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+         )
+         num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+         extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+         slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+     # Note: alibi will be added to the attention bias that is applied to the query-key product of attention
+     # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+     # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+     # => the query_length dimension will then be broadcast correctly
+     # This is more or less identical to T5's relative position bias:
+     # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+     arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+     alibi = slopes[..., None] * arange_tensor
+     return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
+
+
+ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
103
+ """
104
+ Dropout add function
105
+
106
+ Args:
107
+ x (`torch.tensor`):
108
+ input tensor
109
+ residual (`torch.tensor`):
110
+ residual tensor
111
+ prob (`float`):
112
+ dropout probability
113
+ training (`bool`):
114
+ training mode
115
+ """
116
+ out = F.dropout(x, p=prob, training=training)
117
+ out = residual + out
118
+ return out
119
+
120
+
121
+ def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
122
+ """
123
+ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
124
+ make the model jitable.
125
+
126
+ Args:
127
+ x (`torch.tensor`):
128
+ input hidden states
129
+ """
130
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
131
+
132
+
133
+ def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
134
+ """
135
+ gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
136
+ 0.3989423 * x * torch.exp(-0.5 * x * x)
137
+
138
+ Args:
139
+ g (`torch.tensor`):
140
+ gradient output tensor
141
+ x (`torch.tensor`):
142
+ input tensor
143
+ """
144
+ x = x[0] # x is a tuple of 1 element, needs to unpack it first
145
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
146
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
147
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
148
+ return ff * g
149
+
150
+
151
+ class GeLUFunction(torch.autograd.Function):
152
+ @staticmethod
153
+ def forward(ctx, input: torch.Tensor) -> torch.Tensor:
154
+ ctx.save_for_backward(input)
155
+ return bloom_gelu_forward(input)
156
+
157
+ @staticmethod
158
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
159
+ input = ctx.saved_tensors
160
+ tmp = bloom_gelu_back(grad_output, input)
161
+ return tmp
162
+
163
+
164
+ class BloomGelu(nn.Module):
165
+ """
166
+ BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
167
+ torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
168
+ copied from Megatron-DeepSpeed code and adapted for our needs
169
+
170
+ See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
171
+ """
172
+
173
+ def __init__(self):
174
+ super().__init__()
175
+
176
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
177
+ if self.training:
178
+ return GeLUFunction.apply(x)
179
+ else:
180
+ return bloom_gelu_forward(x)
181
+
182
+
183
+ class BloomAttention(nn.Module):
184
+ def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
185
+ super().__init__()
186
+
187
+ self.pretraining_tp = config.pretraining_tp
188
+ self.slow_but_exact = config.slow_but_exact
189
+
190
+ self.hidden_size = config.hidden_size
191
+ self.num_heads = config.n_head
192
+ self.head_dim = self.hidden_size // self.num_heads
193
+ self.split_size = self.hidden_size
194
+ self.hidden_dropout = config.hidden_dropout
195
+
196
+ if self.head_dim * self.num_heads != self.hidden_size:
197
+ raise ValueError(
198
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
199
+ f" {self.num_heads})."
200
+ )
201
+
202
+ # Layer-wise attention scaling
203
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
204
+ self.beta = 1.0
205
+ self.layer_idx = layer_idx
206
+ if layer_idx is None:
207
+ logger.warning_once(
208
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
209
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
210
+ "when creating this class."
211
+ )
212
+
213
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
214
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
215
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
216
+
217
+ def _reshape(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
218
+ """
219
+ Split the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) shape
220
+ without making any copies, results share same memory storage as `fused_qkv`
221
+
222
+ Args:
223
+ fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
224
+
225
+ Returns:
226
+ query: [batch_size, num_heads, seq_length, head_dim]
227
+ key: [batch_size, num_heads, seq_length, head_dim]
228
+ value: [batch_size, num_heads, seq_length, head_dim]
229
+ """
230
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
231
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
232
+ query_layer = fused_qkv[..., 0, :].transpose(1, 2)
233
+ key_layer = fused_qkv[..., 1, :].transpose(1, 2)
234
+ value_layer = fused_qkv[..., 2, :].transpose(1, 2)
235
+ return query_layer, key_layer, value_layer
236
+
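For intuition, here is an editorial sketch (toy sizes, not part of the source) of the copy-free split performed by `_reshape`; note that q/k/v are interleaved per head inside the fused projection rather than laid out in thirds:

import torch

batch, seq, heads, dim = 2, 5, 4, 8
fused = torch.randn(batch, seq, heads * 3 * dim)     # [batch, seq, 3 * hidden]
qkv = fused.view(batch, seq, heads, 3, dim)          # a view only, no copy
q, k, v = (qkv[..., i, :].transpose(1, 2) for i in range(3))
print(q.shape)  # torch.Size([2, 4, 5, 8]) == [batch, heads, seq, dim]
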
237
+ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
238
+ """
239
+ Merge heads together over the last dimension
240
+
241
+ Args:
242
+ x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
243
+
244
+ Returns:
245
+ torch.tensor: [batch_size, seq_length, num_heads * head_dim]
246
+ """
247
+ # What we want to achieve is:
248
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
249
+ batch_size_and_num_heads, seq_length, _ = x.shape
250
+ batch_size = batch_size_and_num_heads // self.num_heads
251
+
252
+ # First view to decompose the batch size
253
+ # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
254
+ x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
255
+
256
+ # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
257
+ x = x.permute(0, 2, 1, 3)
258
+
259
+ # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
260
+ return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
261
+
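`_merge_heads` is the inverse operation; an editorial sketch of the same view/permute/reshape chain with toy sizes:

import torch

batch, heads, seq, dim = 2, 4, 5, 8
context = torch.randn(batch * heads, seq, dim)
merged = context.view(batch, heads, seq, dim).permute(0, 2, 1, 3).reshape(batch, seq, heads * dim)
print(merged.shape)  # torch.Size([2, 5, 32]) == [batch, seq, heads * dim]
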
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ residual: torch.Tensor,
266
+ alibi: torch.Tensor,
267
+ attention_mask: torch.Tensor,
268
+ layer_past: Optional[Cache] = None,
269
+ head_mask: Optional[torch.Tensor] = None,
270
+ use_cache: bool = False,
271
+ output_attentions: bool = False,
272
+ cache_position: Optional[torch.LongTensor] = None,
273
+ ):
274
+ batch_size, q_length, _ = hidden_states.shape
275
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
276
+ # 3 x [batch_size, num_heads, seq_length, head_dim]
277
+ query_layer, key_layer, value_layer = self._reshape(fused_qkv)
278
+
279
+ if layer_past is not None:
280
+ cache_kwargs = {"cache_position": cache_position}
281
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
282
+
283
+ # reshape qkv for further computations
284
+ query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
285
+ key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
286
+ value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
287
+
288
+ # [batch_size * num_heads, q_length, kv_length]
289
+ attention_scores = alibi.baddbmm(
290
+ batch1=query_layer,
291
+ batch2=key_layer,
292
+ beta=self.beta,
293
+ alpha=self.inv_norm_factor,
294
+ )
295
+
296
+ # change view to [batch_size, num_heads, q_length, kv_length]
297
+ attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
298
+ if attention_mask is not None: # no matter the length, we just slice it
299
+ causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
300
+ attn_weights = attn_weights + causal_mask
301
+
302
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
303
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
304
+
305
+ # [batch_size, num_heads, q_length, kv_length]
306
+ attention_probs = self.attention_dropout(attention_probs)
307
+
308
+ if head_mask is not None:
309
+ attention_probs = attention_probs * head_mask
310
+
311
+ # change view [batch_size x num_heads, q_length, kv_length]
312
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
313
+
314
+ # matmul: [batch_size * num_heads, q_length, head_dim]
315
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer)
316
+
317
+ # change view [batch_size, q_length, num_heads * head_dim]
318
+ context_layer = self._merge_heads(context_layer)
319
+
320
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
321
+ if self.pretraining_tp > 1 and self.slow_but_exact:
322
+ slices = self.hidden_size / self.pretraining_tp
323
+ output_tensor = torch.zeros_like(context_layer)
324
+ for i in range(self.pretraining_tp):
325
+ output_tensor = output_tensor + F.linear(
326
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
327
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
328
+ )
329
+ else:
330
+ output_tensor = self.dense(context_layer)
331
+
332
+ output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
333
+
334
+ outputs = (output_tensor, layer_past)
335
+ if output_attentions:
336
+ outputs += (attention_probs,)
337
+
338
+ return outputs
339
+
340
+
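The `baddbmm` call in the forward pass above fuses the ALiBi bias into the scaled query-key product, i.e. `scores = beta * alibi + alpha * (q @ k_t)`. An editorial sketch of the equivalence with random tensors (shapes illustrative):

import torch

bh, q_len, kv_len, dim = 8, 5, 5, 16
q = torch.randn(bh, q_len, dim)
k_t = torch.randn(bh, dim, kv_len)   # keys already transposed, as in forward()
alibi = torch.randn(bh, 1, kv_len)   # per-head linear biases, broadcast over q_len
fused = alibi.baddbmm(batch1=q, batch2=k_t, beta=1.0, alpha=dim ** -0.5)
manual = alibi + torch.bmm(q, k_t) * dim ** -0.5
assert torch.allclose(fused, manual, atol=1e-6)
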
341
+ class BloomMLP(nn.Module):
342
+ def __init__(self, config: BloomConfig):
343
+ super().__init__()
344
+ hidden_size = config.hidden_size
345
+
346
+ self.pretraining_tp = config.pretraining_tp
347
+ self.slow_but_exact = config.slow_but_exact
348
+ self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
349
+ self.gelu_impl = BloomGelu()
350
+ self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
351
+ self.hidden_dropout = config.hidden_dropout
352
+
353
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
354
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
355
+
356
+ if self.pretraining_tp > 1 and self.slow_but_exact:
357
+ intermediate_output = torch.zeros_like(residual)
358
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
359
+ for i in range(self.pretraining_tp):
360
+ intermediate_output = intermediate_output + F.linear(
361
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
362
+ self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
363
+ )
364
+ else:
365
+ intermediate_output = self.dense_4h_to_h(hidden_states)
366
+
367
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
368
+
369
+ return output
370
+
371
+
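The `slow_but_exact` branch above reproduces the exact floating-point reduction order of tensor-parallel pretraining by slicing the input features and accumulating partial matmuls. An editorial sketch (bias omitted) showing that the slices sum to the full projection:

import torch
import torch.nn.functional as F

tp, in_features, out_features = 4, 16, 8
x = torch.randn(2, 3, in_features)
w = torch.randn(out_features, in_features)
step = in_features // tp
# accumulate one partial matmul per feature slice, as in the loop above
partial = sum(
    F.linear(x[..., i * step : (i + 1) * step], w[:, i * step : (i + 1) * step])
    for i in range(tp)
)
assert torch.allclose(partial, F.linear(x, w), atol=1e-5)
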
372
+ class BloomBlock(nn.Module):
373
+ def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
374
+ super().__init__()
375
+ hidden_size = config.hidden_size
376
+
377
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
378
+ self.num_heads = config.n_head
379
+ self.self_attention = BloomAttention(config, layer_idx)
380
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
381
+
382
+ self.mlp = BloomMLP(config)
383
+
384
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
385
+ self.hidden_dropout = config.hidden_dropout
386
+
387
+ def forward(
388
+ self,
389
+ hidden_states: torch.Tensor,
390
+ alibi: torch.Tensor,
391
+ attention_mask: torch.Tensor,
392
+ layer_past: Optional[Cache] = None,
393
+ head_mask: Optional[torch.Tensor] = None,
394
+ use_cache: bool = False,
395
+ output_attentions: bool = False,
396
+ cache_position: Optional[torch.LongTensor] = None,
397
+ ):
398
+ # hidden_states: [batch_size, seq_length, hidden_size]
399
+
400
+ # Layer norm at the beginning of the transformer layer.
401
+ layernorm_output = self.input_layernorm(hidden_states)
402
+
403
+ # Layer norm post the self attention.
404
+ if self.apply_residual_connection_post_layernorm:
405
+ residual = layernorm_output
406
+ else:
407
+ residual = hidden_states
408
+
409
+ # Self attention.
410
+ attn_outputs = self.self_attention(
411
+ layernorm_output,
412
+ residual,
413
+ layer_past=layer_past,
414
+ attention_mask=attention_mask,
415
+ alibi=alibi,
416
+ head_mask=head_mask,
417
+ use_cache=use_cache,
418
+ output_attentions=output_attentions,
419
+ cache_position=cache_position,
420
+ )
421
+
422
+ attention_output = attn_outputs[0]
423
+
424
+ outputs = attn_outputs[1:]
425
+
426
+ layernorm_output = self.post_attention_layernorm(attention_output)
427
+
428
+ # Get residual
429
+ if self.apply_residual_connection_post_layernorm:
430
+ residual = layernorm_output
431
+ else:
432
+ residual = attention_output
433
+
434
+ # MLP.
435
+ output = self.mlp(layernorm_output, residual)
436
+
437
+ if use_cache:
438
+ outputs = (output,) + outputs
439
+ else:
440
+ outputs = (output,) + outputs[1:]
441
+
442
+ return outputs # hidden_states, past_kv, attentions
443
+
444
+
445
+ class BloomPreTrainedModel(PreTrainedModel):
446
+ config_class = BloomConfig
447
+ base_model_prefix = "transformer"
448
+ supports_gradient_checkpointing = True
449
+ _no_split_modules = ["BloomBlock"]
450
+ _skip_keys_device_placement = "past_key_values"
451
+ _supports_cache_class = True
452
+ _supports_static_cache = True
453
+ _supports_quantized_cache = True
454
+
455
+ def __init__(self, *inputs, **kwargs):
456
+ super().__init__(*inputs, **kwargs)
457
+
458
+ def _init_weights(self, module: nn.Module):
459
+ """Initialize the weights."""
460
+ if isinstance(module, nn.Linear):
461
+ # Slightly different from the TF version which uses truncated_normal for initialization
462
+ # cf https://github.com/pytorch/pytorch/pull/5617
463
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
464
+ if module.bias is not None:
465
+ module.bias.data.zero_()
466
+ elif isinstance(module, nn.Embedding):
467
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
468
+ if module.padding_idx is not None:
469
+ module.weight.data[module.padding_idx].zero_()
470
+ elif isinstance(module, LayerNorm):
471
+ module.bias.data.zero_()
472
+ module.weight.data.fill_(1.0)
473
+
474
+
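With the pieces above in place, a tiny randomly-initialized instance of the `BloomModel` defined below can be created for smoke testing; an editorial sketch (all sizes illustrative):

import torch
from transformers import BloomConfig, BloomModel

config = BloomConfig(vocab_size=128, hidden_size=32, n_layer=2, n_head=4)
model = BloomModel(config)
outputs = model(input_ids=torch.randint(0, 128, (1, 6)))
print(outputs.last_hidden_state.shape)  # torch.Size([1, 6, 32])
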
475
+ BLOOM_START_DOCSTRING = r"""
476
+
477
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
478
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.)
479
+
480
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
481
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
482
+ and behavior.
483
+
484
+ Parameters:
485
+ config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
486
+ Initializing with a config file does not load the weights associated with the model, only the
487
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
488
+ """
489
+
490
+ BLOOM_INPUTS_DOCSTRING = r"""
491
+ Args:
492
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
493
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
494
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
495
+
496
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
497
+ `input_ids`.
498
+
499
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
500
+ [`PreTrainedTokenizer.__call__`] for details.
501
+
502
+ [What are input IDs?](../glossary#input-ids)
503
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
504
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
505
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
506
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
507
+
508
+ Two formats are allowed:
509
+ - a [`~cache_utils.Cache`] instance, see our
510
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
511
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
512
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
513
+ cache format.
514
+
515
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
516
+ legacy cache format will be returned.
517
+
518
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
519
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
520
+ of shape `(batch_size, sequence_length)`.
521
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
522
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
523
+
524
+ - 1 for tokens that are **not masked**,
525
+ - 0 for tokens that are **masked**.
526
+
527
+ [What are attention masks?](../glossary#attention-mask)
528
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
529
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
530
+
531
+ - 1 indicates the head is **not masked**,
532
+ - 0 indicates the head is **masked**.
533
+
534
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
535
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
536
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
537
+ model's internal embedding lookup matrix.
538
+
539
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
540
+ `past_key_values`).
541
+ use_cache (`bool`, *optional*):
542
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
543
+ `past_key_values`).
544
+ output_attentions (`bool`, *optional*):
545
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
546
+ tensors for more detail.
547
+ output_hidden_states (`bool`, *optional*):
548
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
549
+ more detail.
550
+ return_dict (`bool`, *optional*):
551
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
552
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
553
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
554
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
555
+ the complete sequence length.
556
+ """
557
+
558
+
559
+ @add_start_docstrings(
560
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
561
+ BLOOM_START_DOCSTRING,
562
+ )
563
+ class BloomModel(BloomPreTrainedModel):
564
+ def __init__(self, config: BloomConfig):
565
+ super().__init__(config)
566
+
567
+ self.embed_dim = config.hidden_size
568
+ self.num_heads = config.n_head
569
+
570
+ # Embedding + LN Embedding
571
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
572
+ self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
573
+
574
+ # Transformer blocks
575
+ self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
576
+
577
+ # Final Layer Norm
578
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
579
+
580
+ self.gradient_checkpointing = False
581
+
582
+ # Initialize weights and apply final processing
583
+ self.post_init()
584
+
585
+ def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
586
+ return build_alibi_tensor(attention_mask, num_heads, dtype)
587
+
588
+ def get_input_embeddings(self):
589
+ return self.word_embeddings
590
+
591
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
592
+ self.word_embeddings = new_embeddings
593
+
594
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
595
+ @add_code_sample_docstrings(
596
+ checkpoint=_CHECKPOINT_FOR_DOC,
597
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
598
+ config_class=_CONFIG_FOR_DOC,
599
+ )
600
+ def forward(
601
+ self,
602
+ input_ids: Optional[torch.LongTensor] = None,
603
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
604
+ attention_mask: Optional[torch.Tensor] = None,
605
+ head_mask: Optional[torch.LongTensor] = None,
606
+ inputs_embeds: Optional[torch.LongTensor] = None,
607
+ use_cache: Optional[bool] = None,
608
+ output_attentions: Optional[bool] = None,
609
+ output_hidden_states: Optional[bool] = None,
610
+ return_dict: Optional[bool] = None,
611
+ cache_position: Optional[torch.LongTensor] = None,
612
+ **deprecated_arguments,
613
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
614
+ if deprecated_arguments.pop("position_ids", False) is not False:
615
+ # `position_ids` could have been a `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
616
+ warnings.warn(
617
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
618
+ " passing `position_ids`.",
619
+ FutureWarning,
620
+ )
621
+ if len(deprecated_arguments) > 0:
622
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
623
+
624
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
625
+ output_hidden_states = (
626
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
627
+ )
628
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
629
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
630
+
631
+ if (input_ids is None) ^ (inputs_embeds is not None):
632
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
633
+
634
+ if self.gradient_checkpointing and self.training and use_cache:
635
+ logger.warning_once(
636
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
637
+ )
638
+ use_cache = False
639
+
640
+ if inputs_embeds is None:
641
+ inputs_embeds = self.word_embeddings(input_ids)
642
+
643
+ # kept for BC (non `Cache` `past_key_values` inputs)
644
+ return_legacy_cache = False
645
+ if use_cache and not isinstance(past_key_values, Cache):
646
+ return_legacy_cache = True
647
+ if past_key_values is None:
648
+ past_key_values = DynamicCache()
649
+ else:
650
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
651
+ logger.warning_once(
652
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
653
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
654
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
655
+ )
656
+
657
+ batch_size, seq_length, _ = inputs_embeds.shape
658
+ past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
659
+ seq_length_with_past = seq_length + past_length
660
+ if cache_position is None:
661
+ cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
662
+
663
+ # Prepare head mask if needed
664
+ # 1.0 in head_mask indicate we keep the head
665
+ # attention_probs has shape batch_size x num_heads x N x N
666
+ # head_mask has shape n_layer x batch x num_heads x N x N
667
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
668
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
669
+
670
+ next_decoder_cache = None
671
+ all_self_attentions = () if output_attentions else None
672
+ all_hidden_states = () if output_hidden_states else None
673
+
674
+ # Compute alibi tensor: check build_alibi_tensor documentation
675
+ if attention_mask is None:
676
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
677
+ else:
678
+ attention_mask = attention_mask.to(hidden_states.device)
679
+
680
+ alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
681
+ causal_mask = self._update_causal_mask(
682
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
683
+ )
684
+
685
+ for i, block in enumerate(self.h):
686
+ if output_hidden_states:
687
+ all_hidden_states = all_hidden_states + (hidden_states,)
688
+
689
+ if self.gradient_checkpointing and self.training:
690
+ outputs = self._gradient_checkpointing_func(
691
+ block.__call__,
692
+ hidden_states,
693
+ alibi,
694
+ causal_mask,
695
+ past_key_values,
696
+ head_mask[i],
697
+ use_cache,
698
+ output_attentions,
699
+ cache_position,
700
+ )
701
+ else:
702
+ outputs = block(
703
+ hidden_states,
704
+ layer_past=past_key_values,
705
+ attention_mask=causal_mask,
706
+ head_mask=head_mask[i],
707
+ use_cache=use_cache,
708
+ output_attentions=output_attentions,
709
+ alibi=alibi,
710
+ cache_position=cache_position,
711
+ )
712
+
713
+ hidden_states = outputs[0]
714
+ if use_cache:
715
+ next_decoder_cache = outputs[1]
716
+
717
+ if output_attentions:
718
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
719
+
720
+ # Apply the final layer norm, then add the last hidden state
721
+ hidden_states = self.ln_f(hidden_states)
722
+
723
+ if output_hidden_states:
724
+ all_hidden_states = all_hidden_states + (hidden_states,)
725
+
726
+ next_cache = next_decoder_cache if use_cache else None
727
+ if return_legacy_cache:
728
+ next_cache = next_cache.to_legacy_cache()
729
+
730
+ if not return_dict:
731
+ return tuple(
732
+ v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
733
+ )
734
+
735
+ return BaseModelOutputWithPastAndCrossAttentions(
736
+ last_hidden_state=hidden_states,
737
+ past_key_values=next_cache,
738
+ hidden_states=all_hidden_states,
739
+ attentions=all_self_attentions,
740
+ )
741
+
742
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
743
+ def _update_causal_mask(
744
+ self,
745
+ attention_mask: Union[torch.Tensor, "BlockMask"],
746
+ input_tensor: torch.Tensor,
747
+ cache_position: torch.Tensor,
748
+ past_key_values: Cache,
749
+ output_attentions: bool = False,
750
+ ):
751
+ if self.config._attn_implementation == "flash_attention_2":
752
+ if attention_mask is not None and (attention_mask == 0.0).any():
753
+ return attention_mask
754
+ return None
755
+ if self.config._attn_implementation == "flex_attention":
756
+ if isinstance(attention_mask, torch.Tensor):
757
+ attention_mask = make_flex_block_causal_mask(attention_mask)
758
+ return attention_mask
759
+
760
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
761
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
762
+ # to infer the attention mask.
763
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
764
+ using_static_cache = isinstance(past_key_values, StaticCache)
765
+
766
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
767
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
768
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
769
+ attention_mask,
770
+ inputs_embeds=input_tensor,
771
+ past_key_values_length=past_seen_tokens,
772
+ is_training=self.training,
773
+ ):
774
+ return None
775
+
776
+ dtype, device = input_tensor.dtype, input_tensor.device
777
+ sequence_length = input_tensor.shape[1]
778
+ if using_static_cache:
779
+ target_length = past_key_values.get_max_cache_shape()
780
+ else:
781
+ target_length = (
782
+ attention_mask.shape[-1]
783
+ if isinstance(attention_mask, torch.Tensor)
784
+ else past_seen_tokens + sequence_length + 1
785
+ )
786
+
787
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
788
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
789
+ attention_mask,
790
+ sequence_length=sequence_length,
791
+ target_length=target_length,
792
+ dtype=dtype,
793
+ device=device,
794
+ cache_position=cache_position,
795
+ batch_size=input_tensor.shape[0],
796
+ )
797
+
798
+ if (
799
+ self.config._attn_implementation == "sdpa"
800
+ and attention_mask is not None
801
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
802
+ and not output_attentions
803
+ ):
804
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
805
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
806
+ # Details: https://github.com/pytorch/pytorch/issues/110213
807
+ min_dtype = torch.finfo(dtype).min
808
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
809
+
810
+ return causal_mask
811
+
812
+ @staticmethod
813
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
814
+ def _prepare_4d_causal_attention_mask_with_cache_position(
815
+ attention_mask: torch.Tensor,
816
+ sequence_length: int,
817
+ target_length: int,
818
+ dtype: torch.dtype,
819
+ device: torch.device,
820
+ cache_position: torch.Tensor,
821
+ batch_size: int,
822
+ **kwargs,
823
+ ):
824
+ """
825
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
826
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
827
+
828
+ Args:
829
+ attention_mask (`torch.Tensor`):
830
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
831
+ `(batch_size, 1, query_length, key_value_length)`.
832
+ sequence_length (`int`):
833
+ The sequence length being processed.
834
+ target_length (`int`):
835
+ The target length: when generating with static cache, the mask should be as long as the static cache,
836
+ to account for the 0 padding, the part of the cache that is not filled yet.
837
+ dtype (`torch.dtype`):
838
+ The dtype to use for the 4D attention mask.
839
+ device (`torch.device`):
840
+ The device to place the 4D attention mask on.
841
+ cache_position (`torch.Tensor`):
842
+ Indices depicting the position of the input sequence tokens in the sequence.
843
+ batch_size (`torch.Tensor`):
844
+ Batch size.
845
+ """
846
+ if attention_mask is not None and attention_mask.dim() == 4:
847
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
848
+ causal_mask = attention_mask
849
+ else:
850
+ min_dtype = torch.finfo(dtype).min
851
+ causal_mask = torch.full(
852
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
853
+ )
854
+ if sequence_length != 1:
855
+ causal_mask = torch.triu(causal_mask, diagonal=1)
856
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
857
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
858
+ if attention_mask is not None:
859
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
860
+ mask_length = attention_mask.shape[-1]
861
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
862
+ causal_mask.device
863
+ )
864
+ padding_mask = padding_mask == 0
865
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
866
+ padding_mask, min_dtype
867
+ )
868
+
869
+ return causal_mask
870
+
871
+
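An editorial sketch of what the static helper above produces for a single right-padded row (toy sizes; entries are 0.0 where attention is allowed and the dtype minimum elsewhere):

import torch

mask_2d = torch.tensor([[1, 1, 0]])  # last position is padding
mask_4d = BloomModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=mask_2d,
    sequence_length=3,
    target_length=3,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=torch.arange(3),
    batch_size=1,
)
print(mask_4d.shape)  # torch.Size([1, 1, 3, 3])
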
872
+ @add_start_docstrings(
873
+ """
874
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
875
+ embeddings).
876
+ """,
877
+ BLOOM_START_DOCSTRING,
878
+ )
879
+ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
880
+ _tied_weights_keys = ["lm_head.weight"]
881
+
882
+ def __init__(self, config: BloomConfig):
883
+ super().__init__(config)
884
+ self.transformer = BloomModel(config)
885
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
886
+
887
+ # Initialize weights and apply final processing
888
+ self.post_init()
889
+
890
+ def get_output_embeddings(self):
891
+ return self.lm_head
892
+
893
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
894
+ self.lm_head = new_embeddings
895
+
896
+ def prepare_inputs_for_generation(
897
+ self,
898
+ input_ids,
899
+ past_key_values=None,
900
+ attention_mask=None,
901
+ inputs_embeds=None,
902
+ cache_position=None,
903
+ use_cache=True,
904
+ **kwargs,
905
+ ):
906
+ # Overwritten because of the fixed-shape attention mask creation
907
+
908
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
909
+ # Exception 1: when passing `inputs_embeds`, `input_ids` may be missing entries
910
+ # Exception 2: some generation methods do special slicing of `input_ids`, so we don't need to do it here
911
+ # Exception 3: with synced GPUs, `cache_position` may go out of bounds, but we only want a dummy token in that case
912
+ # (we can't check exception 3 while compiling)
913
+ # Exception 4: if `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens and
914
+ # generate the first token for each sequence. Later, use the generated `input_ids` for continuation.
915
+ if past_key_values is not None:
916
+ if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
917
+ inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
918
+ elif (
919
+ inputs_embeds is not None # Exception 1
920
+ or cache_position[-1] >= input_ids.shape[1] # Exception 3
921
+ ):
922
+ input_ids = input_ids[:, -cache_position.shape[0] :]
923
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
924
+ input_ids = input_ids[:, cache_position]
925
+
926
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
927
+ if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
928
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
929
+ else:
930
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the
931
+ # input `input_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient, as in
932
+ # the batch size = 1 case `input_ids` is already contiguous but with a varying stride, which retriggers a capture.
933
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
934
+
935
+ # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
936
+ # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
937
+ if isinstance(past_key_values, StaticCache) and attention_mask is not None:
938
+ target_length = past_key_values.get_max_cache_shape()
939
+ batch_size, seq_length = attention_mask.shape
940
+ diff = target_length - seq_length
941
+
942
+ new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
943
+ attention_mask = torch.cat(
944
+ [attention_mask, new_attn_mask],
945
+ dim=-1,
946
+ )
947
+
948
+ model_inputs.update(
949
+ {
950
+ "cache_position": cache_position,
951
+ "past_key_values": past_key_values,
952
+ "use_cache": use_cache,
953
+ "attention_mask": attention_mask,
954
+ }
955
+ )
956
+ return model_inputs
957
+
958
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
959
+ @add_code_sample_docstrings(
960
+ checkpoint=_CHECKPOINT_FOR_DOC,
961
+ output_type=CausalLMOutputWithCrossAttentions,
962
+ config_class=_CONFIG_FOR_DOC,
963
+ )
964
+ def forward(
965
+ self,
966
+ input_ids: Optional[torch.LongTensor] = None,
967
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
968
+ attention_mask: Optional[torch.Tensor] = None,
969
+ head_mask: Optional[torch.Tensor] = None,
970
+ inputs_embeds: Optional[torch.Tensor] = None,
971
+ labels: Optional[torch.Tensor] = None,
972
+ use_cache: Optional[bool] = None,
973
+ output_attentions: Optional[bool] = None,
974
+ output_hidden_states: Optional[bool] = None,
975
+ return_dict: Optional[bool] = None,
976
+ cache_position: Optional[torch.LongTensor] = None,
977
+ **deprecated_arguments,
978
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
979
+ r"""
980
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
981
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
982
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
984
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
984
+ """
985
+ # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
986
+ num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
987
+ if deprecated_arguments.pop("position_ids", False) is not False:
988
+ # `position_ids` could have been a `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
989
+ warnings.warn(
990
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
991
+ " passing `position_ids`.",
992
+ FutureWarning,
993
+ )
994
+ if len(deprecated_arguments) > 0:
995
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
996
+
997
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
998
+
999
+ transformer_outputs = self.transformer(
1000
+ input_ids,
1001
+ past_key_values=past_key_values,
1002
+ attention_mask=attention_mask,
1003
+ head_mask=head_mask,
1004
+ inputs_embeds=inputs_embeds,
1005
+ use_cache=use_cache,
1006
+ output_attentions=output_attentions,
1007
+ output_hidden_states=output_hidden_states,
1008
+ return_dict=return_dict,
1009
+ cache_position=cache_position,
1010
+ )
1011
+ hidden_states = transformer_outputs[0]
1012
+
1013
+ lm_logits = self.lm_head(hidden_states)
1014
+
1015
+ loss = None
1016
+ if labels is not None:
1017
+ # move labels to correct device to enable model parallelism
1018
+ labels = labels.to(lm_logits.device)
1019
+ # Flatten the tokens
1020
+ loss = self.loss_function(
1021
+ lm_logits,
1022
+ labels,
1023
+ vocab_size=self.config.vocab_size,
1024
+ num_items_in_batch=num_items_in_batch,
1025
+ )
1026
+
1027
+ if not return_dict:
1028
+ output = (lm_logits,) + transformer_outputs[1:]
1029
+ return ((loss,) + output) if loss is not None else output
1030
+
1031
+ return CausalLMOutputWithCrossAttentions(
1032
+ loss=loss,
1033
+ logits=lm_logits,
1034
+ past_key_values=transformer_outputs.past_key_values,
1035
+ hidden_states=transformer_outputs.hidden_states,
1036
+ attentions=transformer_outputs.attentions,
1037
+ )
1038
+
1039
+ def _reorder_cache(
1040
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1041
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1042
+ """
1043
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1044
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1045
+ beam_idx at every generation step.
1046
+
1047
+ Output shares the same memory storage as `past`.
1048
+ """
1049
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
1050
+ device_to_beam_idx = {
1051
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
1052
+ }
1053
+ reordered_past = tuple(
1054
+ (
1055
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
1056
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
1057
+ )
1058
+ for layer_past in past
1059
+ )
1060
+ return reordered_past
1061
+
1062
+
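Typical usage of the causal-LM head defined above, as an editorial sketch (the checkpoint name is illustrative; any BLOOM checkpoint such as `bigscience/bloom-560m` should work):

from transformers import AutoTokenizer, BloomForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
inputs = tokenizer("The BLOOM language model was trained", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
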
1063
+ @add_start_docstrings(
1064
+ """
1065
+ The Bloom Model transformer with a sequence classification head on top (linear layer).
1066
+
1067
+ [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1068
+ (e.g. GPT-1) do.
1069
+
1070
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1071
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1072
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1073
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1074
+ each row of the batch).
1075
+ """,
1076
+ BLOOM_START_DOCSTRING,
1077
+ )
1078
+ class BloomForSequenceClassification(BloomPreTrainedModel):
1079
+ def __init__(self, config: BloomConfig):
1080
+ super().__init__(config)
1081
+ self.num_labels = config.num_labels
1082
+ self.transformer = BloomModel(config)
1083
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
1084
+
1085
+ # Initialize weights and apply final processing
1086
+ self.post_init()
1087
+
1088
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
1089
+ @add_code_sample_docstrings(
1090
+ checkpoint=_CHECKPOINT_FOR_DOC,
1091
+ output_type=SequenceClassifierOutputWithPast,
1092
+ config_class=_CONFIG_FOR_DOC,
1093
+ )
1094
+ def forward(
1095
+ self,
1096
+ input_ids: Optional[torch.LongTensor] = None,
1097
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ head_mask: Optional[torch.Tensor] = None,
1100
+ inputs_embeds: Optional[torch.Tensor] = None,
1101
+ labels: Optional[torch.Tensor] = None,
1102
+ use_cache: Optional[bool] = None,
1103
+ output_attentions: Optional[bool] = None,
1104
+ output_hidden_states: Optional[bool] = None,
1105
+ return_dict: Optional[bool] = None,
1106
+ **deprecated_arguments,
1107
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1108
+ r"""
1109
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1110
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1111
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1112
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1113
+ """
1114
+ if deprecated_arguments.pop("position_ids", False) is not False:
1115
+ # `position_ids` could have been a `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
1116
+ warnings.warn(
1117
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
1118
+ " passing `position_ids`.",
1119
+ FutureWarning,
1120
+ )
1121
+ if len(deprecated_arguments) > 0:
1122
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1123
+
1124
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1125
+
1126
+ transformer_outputs = self.transformer(
1127
+ input_ids,
1128
+ past_key_values=past_key_values,
1129
+ attention_mask=attention_mask,
1130
+ head_mask=head_mask,
1131
+ inputs_embeds=inputs_embeds,
1132
+ use_cache=use_cache,
1133
+ output_attentions=output_attentions,
1134
+ output_hidden_states=output_hidden_states,
1135
+ return_dict=return_dict,
1136
+ )
1137
+
1138
+ hidden_states = transformer_outputs[0]
1139
+ logits = self.score(hidden_states)
1140
+
1141
+ if input_ids is not None:
1142
+ batch_size = input_ids.shape[0]
1143
+ else:
1144
+ batch_size = inputs_embeds.shape[0]
1145
+
1146
+ if self.config.pad_token_id is None and batch_size != 1:
1147
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1148
+ if self.config.pad_token_id is None:
1149
+ last_non_pad_token = -1
1150
+ elif input_ids is not None:
1151
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1152
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1153
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1154
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1155
+ else:
1156
+ last_non_pad_token = -1
1157
+ logger.warning_once(
1158
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1159
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1160
+ )
1161
+
1162
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1163
+
1164
+ loss = None
1165
+ if labels is not None:
1166
+ if self.config.problem_type is None:
1167
+ if self.num_labels == 1:
1168
+ self.config.problem_type = "regression"
1169
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1170
+ self.config.problem_type = "single_label_classification"
1171
+ else:
1172
+ self.config.problem_type = "multi_label_classification"
1173
+
1174
+ if self.config.problem_type == "regression":
1175
+ loss_fct = MSELoss()
1176
+ if self.num_labels == 1:
1177
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1178
+ else:
1179
+ loss = loss_fct(pooled_logits, labels)
1180
+ elif self.config.problem_type == "single_label_classification":
1181
+ loss_fct = CrossEntropyLoss()
1182
+ loss = loss_fct(pooled_logits, labels)
1183
+ elif self.config.problem_type == "multi_label_classification":
1184
+ loss_fct = BCEWithLogitsLoss()
1185
+ loss = loss_fct(pooled_logits, labels)
1186
+ if not return_dict:
1187
+ output = (pooled_logits,) + transformer_outputs[1:]
1188
+ return ((loss,) + output) if loss is not None else output
1189
+
1190
+ return SequenceClassifierOutputWithPast(
1191
+ loss=loss,
1192
+ logits=pooled_logits,
1193
+ past_key_values=transformer_outputs.past_key_values,
1194
+ hidden_states=transformer_outputs.hidden_states,
1195
+ attentions=transformer_outputs.attentions,
1196
+ )
1197
+
1198
+
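The pooling logic above selects, per row, the rightmost token that is not padding, which handles both left- and right-padding. An editorial sketch assuming `pad_token_id = 0` (an assumption for illustration):

import torch

input_ids = torch.tensor([[5, 6, 0, 0],   # right-padded -> pool index 1
                          [0, 0, 7, 8]])  # left-padded  -> pool index 3
non_pad_mask = (input_ids != 0).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([1, 3])
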
1199
+ @add_start_docstrings(
1200
+ """
1201
+ Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1202
+ Named-Entity-Recognition (NER) tasks.
1203
+ """,
1204
+ BLOOM_START_DOCSTRING,
1205
+ )
1206
+ class BloomForTokenClassification(BloomPreTrainedModel):
1207
+ def __init__(self, config: BloomConfig):
1208
+ super().__init__(config)
1209
+ self.num_labels = config.num_labels
1210
+
1211
+ self.transformer = BloomModel(config)
1212
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1213
+ classifier_dropout = config.classifier_dropout
1214
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1215
+ classifier_dropout = config.hidden_dropout
1216
+ else:
1217
+ classifier_dropout = 0.1
1218
+ self.dropout = nn.Dropout(classifier_dropout)
1219
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1220
+
1221
+ # Initialize weights and apply final processing
1222
+ self.post_init()
1223
+
1224
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
1225
+ @add_code_sample_docstrings(
1226
+ checkpoint=_CHECKPOINT_FOR_DOC,
1227
+ output_type=TokenClassifierOutput,
1228
+ config_class=_CONFIG_FOR_DOC,
1229
+ )
1230
+ def forward(
1231
+ self,
1232
+ input_ids: Optional[torch.LongTensor] = None,
1233
+ past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
1234
+ attention_mask: Optional[torch.Tensor] = None,
1235
+ head_mask: Optional[torch.Tensor] = None,
1236
+ inputs_embeds: Optional[torch.Tensor] = None,
1237
+ labels: Optional[torch.Tensor] = None,
1238
+ use_cache: Optional[bool] = None,
1239
+ output_attentions: Optional[bool] = None,
1240
+ output_hidden_states: Optional[bool] = None,
1241
+ return_dict: Optional[bool] = None,
1242
+ **deprecated_arguments,
1243
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1244
+ r"""
1245
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1246
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1247
+ config.num_labels - 1]`. A token-level classification loss (Cross-Entropy) is computed
1248
+ over the sequence.
1249
+ """
1250
+ if deprecated_arguments.pop("position_ids", False) is not False:
1251
+ # `position_ids` could have been a `torch.Tensor` or `None`, so defaulting the pop to `False` allows us to detect whether users explicitly passed `None`
1252
+ warnings.warn(
1253
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
1254
+ " passing `position_ids`.",
1255
+ FutureWarning,
1256
+ )
1257
+ if len(deprecated_arguments) > 0:
1258
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
1259
+
1260
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1261
+
1262
+ transformer_outputs = self.transformer(
1263
+ input_ids,
1264
+ past_key_values=past_key_values,
1265
+ attention_mask=attention_mask,
1266
+ head_mask=head_mask,
1267
+ inputs_embeds=inputs_embeds,
1268
+ use_cache=use_cache,
1269
+ output_attentions=output_attentions,
1270
+ output_hidden_states=output_hidden_states,
1271
+ return_dict=return_dict,
1272
+ )
1273
+
1274
+ hidden_states = transformer_outputs[0]
1275
+ hidden_states = self.dropout(hidden_states)
1276
+ logits = self.classifier(hidden_states)
1277
+
1278
+ loss = None
1279
+ if labels is not None:
1280
+ # move labels to correct device to enable model parallelism
1281
+ labels = labels.to(logits.device)
1282
+ batch_size, seq_length = labels.shape
1283
+ loss_fct = CrossEntropyLoss()
1284
+ loss = loss_fct(
1285
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1286
+ )
1287
+
1288
+ if not return_dict:
1289
+ output = (logits,) + transformer_outputs[2:]
1290
+ return ((loss,) + output) if loss is not None else output
1291
+
1292
+ return TokenClassifierOutput(
1293
+ loss=loss,
1294
+ logits=logits,
1295
+ hidden_states=transformer_outputs.hidden_states,
1296
+ attentions=transformer_outputs.attentions,
1297
+ )
1298
+
1299
+
1300
+ @add_start_docstrings(
1301
+ """
1302
+ The BLOOM Model transformer with a span classification head on top for extractive question-answering tasks like
1303
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1304
+ """,
1305
+ BLOOM_START_DOCSTRING,
1306
+ )
1307
+ class BloomForQuestionAnswering(BloomPreTrainedModel):
1308
+ def __init__(self, config):
1309
+ super().__init__(config)
1310
+ self.transformer = BloomModel(config)
1311
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1312
+
1313
+ # Initialize weights and apply final processing
1314
+ self.post_init()
1315
+
1316
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1317
+ def forward(
1318
+ self,
1319
+ input_ids: Optional[torch.LongTensor] = None,
1320
+ attention_mask: Optional[torch.FloatTensor] = None,
1321
+ position_ids: Optional[torch.LongTensor] = None,
1322
+ head_mask: Optional[torch.FloatTensor] = None,
1323
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1324
+ start_positions: Optional[torch.LongTensor] = None,
1325
+ end_positions: Optional[torch.LongTensor] = None,
1326
+ output_attentions: Optional[bool] = None,
1327
+ output_hidden_states: Optional[bool] = None,
1328
+ return_dict: Optional[bool] = None,
1329
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1330
+ r"""
1331
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1332
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1333
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1334
+ are not taken into account for computing the loss.
1335
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1336
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1337
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1338
+ are not taken into account for computing the loss.
1339
+ """
1340
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1341
+
1342
+ outputs = self.transformer(
1343
+ input_ids,
1344
+ attention_mask=attention_mask,
1345
+ position_ids=position_ids,
1346
+ head_mask=head_mask,
1347
+ inputs_embeds=inputs_embeds,
1348
+ output_attentions=output_attentions,
1349
+ output_hidden_states=output_hidden_states,
1350
+ return_dict=return_dict,
1351
+ )
1352
+
1353
+ sequence_output = outputs[0]
1354
+
1355
+ logits = self.qa_outputs(sequence_output)
1356
+ start_logits, end_logits = logits.split(1, dim=-1)
1357
+ start_logits = start_logits.squeeze(-1).contiguous()
1358
+ end_logits = end_logits.squeeze(-1).contiguous()
1359
+
1360
+ total_loss = None
1361
+ if start_positions is not None and end_positions is not None:
1362
+ # If we are on multi-GPU, splitting adds a dimension; squeeze it
1363
+ if len(start_positions.size()) > 1:
1364
+ start_positions = start_positions.squeeze(-1)
1365
+ if len(end_positions.size()) > 1:
1366
+ end_positions = end_positions.squeeze(-1)
1367
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1368
+ ignored_index = start_logits.size(1)
1369
+ start_positions = start_positions.clamp(0, ignored_index)
1370
+ end_positions = end_positions.clamp(0, ignored_index)
1371
+
1372
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1373
+ start_loss = loss_fct(start_logits, start_positions)
1374
+ end_loss = loss_fct(end_logits, end_positions)
1375
+ total_loss = (start_loss + end_loss) / 2
1376
+
1377
+ if not return_dict:
1378
+ output = (start_logits, end_logits) + outputs[2:]
1379
+ return ((total_loss,) + output) if total_loss is not None else output
1380
+
1381
+ return QuestionAnsweringModelOutput(
1382
+ loss=total_loss,
1383
+ start_logits=start_logits,
1384
+ end_logits=end_logits,
1385
+ hidden_states=outputs.hidden_states,
1386
+ attentions=outputs.attentions,
1387
+ )
1388
+
1389
+
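An editorial sketch of span extraction with the QA head above (checkpoint illustrative; the base BLOOM checkpoints do not ship a fine-tuned QA head, so `qa_outputs` starts randomly initialized and the predicted span is meaningless until fine-tuned):

import torch
from transformers import AutoTokenizer, BloomForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = BloomForQuestionAnswering.from_pretrained("bigscience/bloom-560m")
inputs = tokenizer("BLOOM was trained by the BigScience workshop. Who trained BLOOM?", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))
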
1390
+ __all__ = [
1391
+ "BloomForCausalLM",
1392
+ "BloomModel",
1393
+ "BloomPreTrainedModel",
1394
+ "BloomForSequenceClassification",
1395
+ "BloomForTokenClassification",
1396
+ "BloomForQuestionAnswering",
1397
+ ]
docs/transformers/src/transformers/models/bloom/modeling_flax_bloom.py ADDED
@@ -0,0 +1,737 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Flax BLOOM model."""
16
+
17
+ import math
18
+ from functools import partial
19
+ from typing import Optional, Tuple
20
+
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
25
+ from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
26
+ from flax.linen.activation import tanh
27
+ from flax.traverse_util import flatten_dict, unflatten_dict
28
+ from jax import lax
29
+
30
+ from ...modeling_flax_outputs import (
31
+ FlaxBaseModelOutput,
32
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
33
+ FlaxCausalLMOutput,
34
+ )
35
+ from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
36
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
37
+ from .configuration_bloom import BloomConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "bigscience/bloom"
43
+ _CONFIG_FOR_DOC = "BloomConfig"
44
+
45
+
46
+ BLOOM_START_DOCSTRING = r"""
47
+
48
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
49
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
50
+ etc.)
51
+
52
+ This model is also a Flax Linen
53
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
54
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
55
+
56
+ Finally, this model supports inherent JAX features such as:
57
+
58
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
59
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
60
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
61
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
62
+
63
+ Parameters:
64
+ config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
65
+ Initializing with a config file does not load the weights associated with the model, only the
66
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
67
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
68
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
69
+ `jax.numpy.bfloat16` (on TPUs).
70
+
71
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
72
+ specified all the computation will be performed with the given `dtype`.
73
+
74
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
75
+ parameters.**
76
+
77
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
78
+ [`~FlaxPreTrainedModel.to_bf16`].
79
+ """
80
+
81
+ BLOOM_INPUTS_DOCSTRING = r"""
82
+ Args:
83
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
84
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
85
+
86
+ Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
87
+ [`PreTrainedTokenizer.__call__`] for details.
88
+
89
+ [What are input IDs?](../glossary#input-ids)
90
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
91
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
92
+
93
+ - 1 for tokens that are **not masked**,
94
+ - 0 for tokens that are **masked**.
95
+
96
+ [What are attention masks?](../glossary#attention-mask)
97
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
98
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
99
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
100
+ output_attentions (`bool`, *optional*):
101
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
102
+ tensors for more detail.
103
+ output_hidden_states (`bool`, *optional*):
104
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
105
+ more detail.
106
+ return_dict (`bool`, *optional*):
107
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
108
+ """
109
+
110
+
111
+ def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
112
+ """
113
+ Flax implementation of the BLOOM Alibi tensor. Unlike in the original paper, the BLOOM Alibi tensor is not causal;
114
+ it relies on the translation invariance of softmax for a quick implementation: for a tensor l and a fixed value a,
115
+ `softmax(l + a) = softmax(l)`. Based on
116
+ https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
117
+ Link to paper: https://arxiv.org/abs/2108.12409
118
+
119
+ Args:
120
+ attention_mask (`jnp.ndarray`):
121
+ Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
122
+ num_heads (`int`):
123
+ Number of attention heads.
124
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
125
+ The data type (dtype) of the output tensor.
126
+
127
+ Returns: Alibi tensor of shape `(batch_size, num_heads, 1, max_seq_len)`.
128
+ """
129
+ batch_size, seq_length = attention_mask.shape
130
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
131
+ base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
132
+ powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
133
+ slopes = jax.lax.pow(base, powers)
134
+
135
+ if closest_power_of_2 != num_heads:
136
+ extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
137
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
138
+ extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
139
+ slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)
140
+
141
+ # Note: the Alibi tensor will be added to the attention bias that will be applied to the query-key product of attention
142
+ # therefore, Alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
143
+ # => here we set (batch_size=batch_size, num_heads=num_heads, query_length=1, key_length=max_length)
144
+ # so that the query_length dimension will then be broadcast correctly.
145
+ # This is more or less identical to T5's relative position bias:
146
+ # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
147
+ arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
148
+ alibi = slopes[..., None] * arange_tensor
149
+ alibi = jnp.expand_dims(alibi, axis=2)
150
+ return jnp.asarray(alibi, dtype)
151
+
152
+
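A quick sanity check of the shape contract above, as a minimal sketch assuming the `build_alibi_tensor` defined above is in scope (the toy batch and head counts are illustrative):

```python
import jax.numpy as jnp

# hypothetical toy inputs: batch_size=2, seq_len=5, num_heads=8 (a power of two)
mask = jnp.ones((2, 5), dtype=jnp.int32)
alibi = build_alibi_tensor(mask, num_heads=8)
print(alibi.shape)  # (2, 8, 1, 5) -- (batch, heads, query=1, key), broadcast over query_length
```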
153
+ class FlaxBloomAttention(nn.Module):
154
+ config: BloomConfig
155
+ dtype: jnp.dtype = jnp.float32
156
+
157
+ def setup(self):
158
+ self.hidden_size = self.config.hidden_size
159
+ self.num_heads = self.config.n_head
160
+ self.head_dim = self.hidden_size // self.num_heads
161
+ self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
162
+
163
+ if self.head_dim * self.num_heads != self.hidden_size:
164
+ raise ValueError(
165
+ f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
166
+ f"`num_heads`: {self.num_heads})."
167
+ )
168
+
169
+ dense = partial(
170
+ nn.Dense,
171
+ dtype=self.dtype,
172
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
173
+ )
174
+
175
+ self.query_key_value = dense(self.hidden_size * 3)
176
+ self.dense = dense(self.hidden_size)
177
+ self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)
178
+
179
+ def _split_heads(self, hidden_states):
180
+ return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))
181
+
182
+ def _merge_heads(self, hidden_states):
183
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
184
+
185
+ @nn.compact
186
+ # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache
187
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
188
+ """
189
+ This function takes projected key, value states from a single input token and concatenates the states to cached
190
+ states from previous steps. This function is slightly adapted from the official Flax repository:
191
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
192
+ """
193
+ # detect if we're initializing by absence of existing cache data.
194
+ is_initialized = self.has_variable("cache", "cached_key")
195
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
196
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
197
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
198
+
199
+ if is_initialized:
200
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
201
+ # update key, value caches with our new 1d spatial slices
202
+ cur_index = cache_index.value
203
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
204
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
205
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
206
+ cached_key.value = key
207
+ cached_value.value = value
208
+ num_updated_cache_vectors = query.shape[1]
209
+ cache_index.value = cache_index.value + num_updated_cache_vectors
210
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key
211
+ # positions that have already been generated and cached, not the remaining zero elements.
212
+ pad_mask = jnp.broadcast_to(
213
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
214
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
215
+ )
216
+ attention_mask = combine_masks(pad_mask, attention_mask)
217
+ return key, value, attention_mask
218
+
219
+ def __call__(
220
+ self,
221
+ hidden_states,
222
+ residual,
223
+ alibi,
224
+ attention_mask=None,
225
+ deterministic: bool = True,
226
+ init_cache: bool = False,
227
+ output_attentions: bool = False,
228
+ ):
229
+ batch_size, seq_length = hidden_states.shape[:2]
230
+
231
+ # proj q, k, v
232
+ fused_qkv = self.query_key_value(hidden_states)
233
+ fused_qkv = self._split_heads(fused_qkv)
234
+ query, key, value = jnp.split(fused_qkv, 3, axis=-1)
235
+
236
+ causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")
237
+
238
+ # for fast decoding causal attention mask should be shifted
239
+ causal_attention_mask_shift = (
240
+ self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0
241
+ )
242
+
243
+ # fast decoding for generate requires special attention_mask
244
+ if self.has_variable("cache", "cached_key"):
245
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
246
+ causal_attention_mask = jax.lax.dynamic_slice(
247
+ causal_attention_mask,
248
+ (0, 0, causal_attention_mask_shift, 0),
249
+ (1, 1, seq_length, max_decoder_length),
250
+ )
251
+
252
+ # broadcast causal attention mask & attention mask to fit for merge
253
+ causal_attention_mask = jnp.broadcast_to(
254
+ causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
255
+ )
256
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape)
257
+ attention_mask = combine_masks(attention_mask, causal_attention_mask)
258
+
259
+ dropout_rng = None
260
+ if not deterministic and self.config.attention_dropout > 0.0:
261
+ dropout_rng = self.make_rng("dropout")
262
+
263
+ # During fast autoregressive decoding, we feed one position at a time,
264
+ # and cache the keys and values step by step.
265
+ if self.has_variable("cache", "cached_key") or init_cache:
266
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
267
+
268
+ # transform boolean mask into float mask
269
+ mask_value = jnp.finfo(self.dtype).min
270
+ attention_bias = lax.select(
271
+ attention_mask > 0,
272
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
273
+ jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
274
+ )
275
+
276
+ attention_bias = attention_bias + alibi
277
+
278
+ # Cast in fp32 if the original dtype is different from fp32
279
+ attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
280
+
281
+ attn_weights = dot_product_attention_weights(
282
+ query,
283
+ key,
284
+ bias=attention_bias,
285
+ dropout_rng=dropout_rng,
286
+ dropout_rate=self.config.attention_dropout,
287
+ deterministic=deterministic,
288
+ dtype=attention_dtype,
289
+ )
290
+
291
+ # Cast back in the original dtype if the native dtype is not fp32
292
+ if self.attention_softmax_in_fp32:
293
+ attn_weights = attn_weights.astype(self.dtype)
294
+
295
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
296
+ attn_output = self._merge_heads(attn_output)
297
+ attn_output = self.dense(attn_output)
298
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
299
+
300
+ attn_output = attn_output + residual
301
+
302
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
303
+ return outputs
304
+
305
+
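The boolean-to-additive mask conversion in `__call__` above is a common JAX pattern; here is a small standalone sketch of just the `lax.select` step:

```python
import jax.numpy as jnp
from jax import lax

# True (1) = attend, False (0) = masked out
mask = jnp.array([[True, True, False]])
bias = lax.select(
    mask,
    jnp.zeros(mask.shape, jnp.float32),                # keep these logits unchanged
    jnp.full(mask.shape, jnp.finfo(jnp.float32).min),  # push masked logits toward -inf
)
# bias is added to the attention logits before softmax, so masked
# positions receive (close to) zero probability
```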
306
+ class BloomGELU(nn.Module):
307
+ def setup(self):
308
+ self.dtype = jnp.float32
309
+
310
+ def __call__(self, x):
311
+ return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
312
+
313
+
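`BloomGELU` is the standard tanh approximation of GELU, `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))`; the magic constant is just `sqrt(2/pi)` rounded to 8 decimal places:

```python
import math

# 0.79788456 is sqrt(2 / pi) rounded to 8 decimal places
assert abs(0.79788456 - math.sqrt(2 / math.pi)) < 1e-8
```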
314
+ class FlaxBloomMLP(nn.Module):
315
+ config: BloomConfig
316
+ dtype: jnp.dtype = jnp.float32
317
+
318
+ def setup(self):
319
+ hidden_size = self.config.hidden_size
320
+
321
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
322
+
323
+ self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init)
324
+ self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init)
325
+ self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
326
+ self.act = BloomGELU()
327
+
328
+ def __call__(self, hidden_states, residual, deterministic: bool = True):
329
+ hidden_states = self.dense_h_to_4h(hidden_states)
330
+ hidden_states = self.act(hidden_states)
331
+
332
+ intermediate_output = self.dense_4h_to_h(hidden_states)
333
+
334
+ intermediate_output = intermediate_output + residual
335
+ hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic)
336
+
337
+ return hidden_states
338
+
339
+
340
+ class FlaxBloomBlock(nn.Module):
341
+ config: BloomConfig
342
+ dtype: jnp.dtype = jnp.float32
343
+
344
+ def setup(self):
345
+ self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
346
+
347
+ self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype)
348
+ self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
349
+
350
+ self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype)
351
+
352
+ self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm
353
+ self.hidden_dropout = self.config.hidden_dropout
354
+
355
+ def __call__(
356
+ self,
357
+ hidden_states,
358
+ alibi,
359
+ attention_mask=None,
360
+ deterministic: bool = True,
361
+ init_cache: bool = False,
362
+ output_attentions: bool = False,
363
+ ):
364
+ layernorm_output = self.input_layernorm(hidden_states)
365
+
366
+ # layer norm before saving residual if config calls for it
367
+ if self.apply_residual_connection_post_layernorm:
368
+ residual = layernorm_output
369
+ else:
370
+ residual = hidden_states
371
+
372
+ # self-attention
373
+ attn_outputs = self.self_attention(
374
+ layernorm_output,
375
+ residual=residual,
376
+ alibi=alibi,
377
+ attention_mask=attention_mask,
378
+ deterministic=deterministic,
379
+ init_cache=init_cache,
380
+ output_attentions=output_attentions,
381
+ )
382
+
383
+ attention_output = attn_outputs[0]
384
+
385
+ outputs = attn_outputs[1:]
386
+
387
+ post_layernorm = self.post_attention_layernorm(attention_output)
388
+
389
+ # set residual based on config
390
+ if self.apply_residual_connection_post_layernorm:
391
+ residual = post_layernorm
392
+ else:
393
+ residual = attention_output
394
+
395
+ output = self.mlp(post_layernorm, residual, deterministic=deterministic)
396
+
397
+ outputs = (output,) + outputs
398
+
399
+ return outputs
400
+
401
+
402
+ class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
403
+ """
404
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
405
+ models.
406
+ """
407
+
408
+ config_class = BloomConfig
409
+ base_model_prefix = "transformer"
410
+ module_class: nn.Module = None
411
+
412
+ def __init__(
413
+ self,
414
+ config: BloomConfig,
415
+ input_shape: Tuple = (1, 1),
416
+ seed: int = 0,
417
+ dtype: jnp.dtype = jnp.float32,
418
+ _do_init: bool = True,
419
+ **kwargs,
420
+ ):
421
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
422
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
423
+
424
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
425
+ # init input tensors
426
+ input_ids = jnp.zeros(input_shape, dtype="i4")
427
+ attention_mask = jnp.ones_like(input_ids)
428
+ params_rng, dropout_rng = jax.random.split(rng)
429
+ rngs = {"params": params_rng, "dropout": dropout_rng}
430
+
431
+ random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
432
+
433
+ if params is not None:
434
+ random_params = flatten_dict(unfreeze(random_params))
435
+ params = flatten_dict(unfreeze(params))
436
+ for missing_key in self._missing_keys:
437
+ params[missing_key] = random_params[missing_key]
438
+ self._missing_keys = set()
439
+ return freeze(unflatten_dict(params))
440
+ else:
441
+ return random_params
442
+
443
+ def init_cache(self, batch_size, max_length):
444
+ r"""
445
+ Args:
446
+ batch_size (`int`):
447
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
448
+ max_length (`int`):
449
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
450
+ cache.
451
+ """
452
+ # init input variables to retrieve cache
453
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
454
+ attention_mask = jnp.ones_like(input_ids)
455
+
456
+ init_variables = self.module.init(
457
+ jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True
458
+ )
459
+ return unfreeze(init_variables["cache"])
460
+
461
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
462
+ def __call__(
463
+ self,
464
+ input_ids,
465
+ attention_mask=None,
466
+ past_key_values: dict = None,
467
+ params: dict = None,
468
+ dropout_rng: jax.random.PRNGKey = None,
469
+ train: bool = False,
470
+ output_attentions: Optional[bool] = None,
471
+ output_hidden_states: Optional[bool] = None,
472
+ return_dict: Optional[bool] = None,
473
+ ):
474
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
475
+ output_hidden_states = (
476
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
477
+ )
478
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
479
+
480
+ batch_size, sequence_length = input_ids.shape
481
+
482
+ if attention_mask is None:
483
+ attention_mask = jnp.ones((batch_size, sequence_length))
484
+
485
+ # Handle any PRNG if needed
486
+ rngs = {}
487
+ if dropout_rng is not None:
488
+ rngs["dropout"] = dropout_rng
489
+
490
+ inputs = {"params": params or self.params}
491
+
492
+ # If past_key_values are passed, the cache is already initialized and a private flag init_cache has to be passed
493
+ # down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
494
+ # changed by the FlaxBloomAttention module
495
+ if past_key_values:
496
+ inputs["cache"] = past_key_values
497
+ mutable = ["cache"]
498
+ else:
499
+ mutable = False
500
+
501
+ outputs = self.module.apply(
502
+ inputs,
503
+ jnp.array(input_ids, dtype="i4"),
504
+ jnp.array(attention_mask, dtype="i4"),
505
+ not train,
506
+ False,
507
+ output_attentions,
508
+ output_hidden_states,
509
+ return_dict,
510
+ rngs=rngs,
511
+ mutable=mutable,
512
+ )
513
+
514
+ # add updated cache to model output
515
+ if past_key_values is not None and return_dict:
516
+ outputs, past_key_values = outputs
517
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
518
+ return outputs
519
+ elif past_key_values is not None and not return_dict:
520
+ outputs, past_key_values = outputs
521
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
522
+
523
+ return outputs
524
+
525
+
526
+ class FlaxBloomBlockCollection(nn.Module):
527
+ config: BloomConfig
528
+ dtype: jnp.dtype = jnp.float32
529
+
530
+ def setup(self):
531
+ self.layers = [
532
+ FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype)
533
+ for layer_number in range(self.config.num_hidden_layers)
534
+ ]
535
+
536
+ def __call__(
537
+ self,
538
+ hidden_states,
539
+ alibi,
540
+ attention_mask=None,
541
+ deterministic: bool = True,
542
+ init_cache: bool = False,
543
+ output_attentions: bool = False,
544
+ output_hidden_states: bool = False,
545
+ ):
546
+ all_attentions = () if output_attentions else None
547
+ all_hidden_states = () if output_hidden_states else None
548
+
549
+ for layer_number in range(self.config.num_hidden_layers):
550
+ if output_hidden_states:
551
+ all_hidden_states += (hidden_states,)
552
+
553
+ layer_outputs = self.layers[layer_number](
554
+ hidden_states,
555
+ alibi=alibi,
556
+ attention_mask=attention_mask,
557
+ deterministic=deterministic,
558
+ init_cache=init_cache,
559
+ output_attentions=output_attentions,
560
+ )
561
+ hidden_states = layer_outputs[0]
562
+
563
+ if output_attentions:
564
+ all_attentions += (layer_outputs[1],)
565
+
566
+ # this contains possible `None` values - `FlaxBloomModule` will filter them out
567
+ outputs = (hidden_states, all_hidden_states, all_attentions)
568
+
569
+ return outputs
570
+
571
+
572
+ class FlaxBloomModule(nn.Module):
573
+ config: BloomConfig
574
+ dtype: jnp.dtype = jnp.float32
575
+
576
+ def setup(self):
577
+ self.embed_dim = self.config.hidden_size
578
+
579
+ # word embeddings (no positional embedding layer)
580
+ self.word_embeddings = nn.Embed(
581
+ self.config.vocab_size,
582
+ self.embed_dim,
583
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
584
+ dtype=self.dtype,
585
+ )
586
+
587
+ # post-embedding layernorm
588
+ self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
589
+
590
+ # transformer layers
591
+ self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)
592
+
593
+ # final layernorm
594
+ self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
595
+
596
+ def __call__(
597
+ self,
598
+ input_ids=None,
599
+ attention_mask=None,
600
+ deterministic=True,
601
+ init_cache: bool = False,
602
+ output_attentions: bool = False,
603
+ output_hidden_states: bool = False,
604
+ return_dict: bool = True,
605
+ ):
606
+ inputs_embeds = self.word_embeddings(input_ids)
607
+ # do post-embedding layernorm
608
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
609
+
610
+ # build alibi depending on `attention_mask`
611
+ alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
612
+
613
+ outputs = self.h(
614
+ hidden_states,
615
+ alibi=alibi,
616
+ attention_mask=attention_mask,
617
+ deterministic=deterministic,
618
+ init_cache=init_cache,
619
+ output_hidden_states=output_hidden_states,
620
+ output_attentions=output_attentions,
621
+ )
622
+
623
+ hidden_states = outputs[0]
624
+ hidden_states = self.ln_f(hidden_states)
625
+
626
+ if output_hidden_states:
627
+ all_hidden_states = outputs[1] + (hidden_states,)
628
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
629
+ else:
630
+ outputs = (hidden_states,) + outputs[1:]
631
+
632
+ if not return_dict:
633
+ return tuple(v for v in [outputs[0], outputs[-1]] if v is not None)
634
+
635
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
636
+ last_hidden_state=hidden_states,
637
+ hidden_states=outputs[1],
638
+ attentions=outputs[-1],
639
+ )
640
+
641
+
642
+ @add_start_docstrings(
643
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
644
+ BLOOM_START_DOCSTRING,
645
+ )
646
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom
647
+ class FlaxBloomModel(FlaxBloomPreTrainedModel):
648
+ module_class = FlaxBloomModule
649
+
650
+
651
+ append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
652
+
653
+
654
+ class FlaxBloomForCausalLMModule(nn.Module):
655
+ config: BloomConfig
656
+ dtype: jnp.dtype = jnp.float32
657
+
658
+ def setup(self):
659
+ self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)
660
+ self.lm_head = nn.Dense(
661
+ self.config.vocab_size,
662
+ use_bias=False,
663
+ dtype=self.dtype,
664
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
665
+ )
666
+
667
+ def __call__(
668
+ self,
669
+ input_ids,
670
+ attention_mask,
671
+ deterministic: bool = True,
672
+ init_cache: bool = False,
673
+ output_attentions: bool = False,
674
+ output_hidden_states: bool = False,
675
+ return_dict: bool = True,
676
+ ):
677
+ outputs = self.transformer(
678
+ input_ids,
679
+ attention_mask=attention_mask,
680
+ deterministic=deterministic,
681
+ init_cache=init_cache,
682
+ output_attentions=output_attentions,
683
+ output_hidden_states=output_hidden_states,
684
+ return_dict=return_dict,
685
+ )
686
+
687
+ hidden_states = outputs[0]
688
+
689
+ if self.config.tie_word_embeddings:
690
+ shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T
691
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
692
+ else:
693
+ lm_logits = self.lm_head(hidden_states)
694
+
695
+ if not return_dict:
696
+ return (lm_logits,) + outputs[1:]
697
+
698
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
699
+
700
+
701
+ @add_start_docstrings(
702
+ """
703
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
704
+ embeddings).
705
+ """,
706
+ BLOOM_START_DOCSTRING,
707
+ )
708
+ class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
709
+ module_class = FlaxBloomForCausalLMModule
710
+
711
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
712
+ # initializing the cache
713
+ batch_size, seq_length = input_ids.shape
714
+
715
+ past_key_values = self.init_cache(batch_size, max_length)
716
+ # Note that usually one would have to put 0's in the attention_mask for
717
+ # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask,
718
+ # those positions are masked anyway. Thus, we can create a single static attention_mask here,
719
+ # which is more efficient for compilation
720
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
721
+ if attention_mask is not None:
722
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
723
+
724
+ return {
725
+ "past_key_values": past_key_values,
726
+ "attention_mask": extended_attention_mask,
727
+ }
728
+
729
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
730
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
731
+ return model_kwargs
732
+
733
+
734
+ append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
735
+
736
+
737
+ __all__ = ["FlaxBloomForCausalLM", "FlaxBloomModel", "FlaxBloomPreTrainedModel"]
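A minimal end-to-end usage sketch for the causal LM class above. The checkpoint name is illustrative, and it assumes Flax weights are available for that checkpoint; `generate()` calls `prepare_inputs_for_generation()` above to pre-allocate the cache:

```python
from transformers import AutoTokenizer, FlaxBloomForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = FlaxBloomForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("Hello, my name is", return_tensors="np")
outputs = model.generate(**inputs, max_length=20)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```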
docs/transformers/src/transformers/models/bloom/tokenization_bloom_fast.py ADDED
@@ -0,0 +1,152 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Bloom."""
16
+
17
+ import pickle
18
+ from typing import Optional, Tuple
19
+
20
+ from ...tokenization_utils_base import BatchEncoding
21
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
28
+
29
+
30
+ class BloomTokenizerFast(PreTrainedTokenizerFast):
31
+ """
32
+ Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
33
+ Byte-Pair-Encoding.
34
+
35
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
36
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
37
+
38
+ ```python
39
+ >>> from transformers import BloomTokenizerFast
40
+
41
+ >>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
42
+ >>> tokenizer("Hello world")["input_ids"]
43
+ [59414, 8876]
44
+
45
+ >>> tokenizer(" Hello world")["input_ids"]
46
+ [86153, 8876]
47
+ ```
48
+
49
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
50
+ the model was not pretrained this way, it might yield a decrease in performance.
51
+
52
+ <Tip>
53
+
54
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
55
+
56
+ </Tip>
57
+
58
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
59
+ refer to this superclass for more information regarding those methods.
60
+
61
+ Args:
62
+ vocab_file (`str`):
63
+ Path to the vocabulary file.
64
+ merges_file (`str`):
65
+ Path to the merges file.
66
+ errors (`str`, *optional*, defaults to `"replace"`):
67
+ Paradigm to follow when decoding bytes to UTF-8. See
68
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
69
+ unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
70
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
71
+ token instead.
72
+ bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
73
+ The beginning of sequence token.
74
+ eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
75
+ The end of sequence token.
76
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
77
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
78
+ other word. (Bloom tokenizer detect beginning of words by the preceding space).
79
+ trim_offsets (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the post-processing step should trim offsets to avoid including whitespaces.
81
+ """
82
+
83
+ vocab_files_names = VOCAB_FILES_NAMES
84
+ model_input_names = ["input_ids", "attention_mask"]
85
+ slow_tokenizer_class = None
86
+ # No `max_model_input_sizes` as BLOOM uses ALiBi positional embeddings
87
+
88
+ def __init__(
89
+ self,
90
+ vocab_file=None,
91
+ merges_file=None,
92
+ tokenizer_file=None,
93
+ unk_token="<unk>",
94
+ bos_token="<s>",
95
+ eos_token="</s>",
96
+ pad_token="<pad>",
97
+ add_prefix_space=False,
98
+ clean_up_tokenization_spaces=False,
99
+ **kwargs,
100
+ ):
101
+ super().__init__(
102
+ vocab_file=vocab_file,
103
+ merges_file=merges_file,
104
+ tokenizer_file=tokenizer_file,
105
+ unk_token=unk_token,
106
+ bos_token=bos_token,
107
+ eos_token=eos_token,
108
+ pad_token=pad_token,
109
+ add_prefix_space=add_prefix_space,
110
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
111
+ **kwargs,
112
+ )
113
+ # TODO @ArthurZucker this can only work one way for now, to update later-on. Tests should also properly
114
+ # check this as they were green before.
115
+ pre_tok_state = pickle.dumps(self.backend_tokenizer.pre_tokenizer)
116
+ decoder_state = pickle.dumps(self.backend_tokenizer.decoder)
117
+
118
+ if add_prefix_space:
119
+ pre_tok_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
120
+ decoder_state = decoder_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
121
+ self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tok_state)
122
+ self.backend_tokenizer.decoder = pickle.loads(decoder_state)
123
+
124
+ self.add_prefix_space = add_prefix_space
125
+
126
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
127
+ is_split_into_words = kwargs.get("is_split_into_words", False)
128
+ if not (self.add_prefix_space or not is_split_into_words):
129
+ raise Exception(
130
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
131
+ " pretokenized inputs."
132
+ )
133
+
134
+ return super()._batch_encode_plus(*args, **kwargs)
135
+
136
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
137
+ is_split_into_words = kwargs.get("is_split_into_words", False)
138
+
139
+ if not (self.add_prefix_space or not is_split_into_words):
140
+ raise Exception(
141
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
142
+ " pretokenized inputs."
143
+ )
144
+
145
+ return super()._encode_plus(*args, **kwargs)
146
+
147
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
148
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
149
+ return tuple(files)
150
+
151
+
152
+ __all__ = ["BloomTokenizerFast"]
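As the guards in `_batch_encode_plus` and `_encode_plus` above enforce, pre-tokenized inputs require `add_prefix_space=True`; a short usage sketch:

```python
from transformers import BloomTokenizerFast

# without add_prefix_space=True, passing is_split_into_words=True raises
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom", add_prefix_space=True)
encoding = tokenizer(["Hello", "world"], is_split_into_words=True)
print(encoding["input_ids"])
```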
docs/transformers/src/transformers/models/bridgetower/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_bridgetower import *
22
+ from .image_processing_bridgetower import *
23
+ from .image_processing_bridgetower_fast import *
24
+ from .modeling_bridgetower import *
25
+ from .processing_bridgetower import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/bridgetower/configuration_bridgetower.py ADDED
@@ -0,0 +1,319 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS=,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BridgeTower model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class BridgeTowerVisionConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the vision configuration of a [`BridgeTowerModel`]. Instantiating a
27
+ configuration with the defaults will yield a similar configuration to that of the bridgetower-base
28
+ [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/) architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ hidden_size (`int`, *optional*, defaults to 768):
35
+ Dimensionality of the encoder layers and the pooler layer.
36
+ num_hidden_layers (`int`, *optional*, defaults to 12):
37
+ Number of hidden layers in the visual encoder model.
38
+ patch_size (`int`, *optional*, defaults to 16):
39
+ The size (resolution) of each patch.
40
+ image_size (`int`, *optional*, defaults to 288):
41
+ The size (resolution) of each image.
42
+ initializer_factor (`float`, *optional*, defaults to 1):
43
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
44
+ testing).
45
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
46
+ The epsilon used by the layer normalization layers.
47
+ stop_gradient (`bool`, *optional*, defaults to `False`):
48
+ Whether to stop gradient for training.
49
+ share_layernorm (`bool`, *optional*, defaults to `True`):
50
+ Whether LayerNorm layers are shared.
51
+ remove_last_layer (`bool`, *optional*, defaults to `False`):
52
+ Whether to remove the last layer from the vision encoder.
53
+
54
+
55
+ Example:
56
+
57
+ ```python
58
+ >>> from transformers import BridgeTowerVisionConfig
59
+
60
+ >>> # Initializing a BridgeTower BridgeTower/bridgetower-base style configuration for the vision model
61
+ >>> configuration = BridgeTowerVisionConfig()
62
+
63
+ >>> # Accessing the configuration
64
+ >>> configuration
65
+ ```"""
66
+
67
+ model_type = "bridgetower_vision_model"
68
+ base_config_key = "vision_config"
69
+
70
+ def __init__(
71
+ self,
72
+ hidden_size=768,
73
+ num_hidden_layers=12,
74
+ num_channels=3,
75
+ patch_size=16,
76
+ image_size=288,
77
+ initializer_factor=1,
78
+ layer_norm_eps=1e-05,
79
+ stop_gradient=False,
80
+ share_layernorm=True,
81
+ remove_last_layer=False,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(**kwargs)
85
+ self.hidden_size = hidden_size
86
+ self.num_hidden_layers = num_hidden_layers
87
+ self.num_channels = num_channels
88
+ self.patch_size = patch_size
89
+ self.image_size = image_size
90
+ self.initializer_factor = initializer_factor
91
+ self.layer_norm_eps = layer_norm_eps
92
+ self.stop_gradient = stop_gradient
93
+ self.share_layernorm = share_layernorm
94
+ self.remove_last_layer = remove_last_layer
95
+
96
+
97
+ class BridgeTowerTextConfig(PretrainedConfig):
98
+ r"""
99
+ This is the configuration class to store the text configuration of a [`BridgeTowerModel`]. The default values here
100
+ are copied from RoBERTa. Instantiating a configuration with the defaults will yield a similar configuration to that
101
+ of the bridgetower-base [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/)
102
+ architecture.
103
+
104
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
105
+ documentation from [`PretrainedConfig`] for more information.
106
+
107
+ Args:
108
+ vocab_size (`int`, *optional*, defaults to 50265):
109
+ Vocabulary size of the text part of the model. Defines the number of different tokens that can be
110
+ represented by the `inputs_ids` passed when calling [`BridgeTowerModel`].
111
+ hidden_size (`int`, *optional*, defaults to 768):
112
+ Dimensionality of the encoder layers and the pooler layer.
113
+ num_hidden_layers (`int`, *optional*, defaults to 12):
114
+ Number of hidden layers in the Transformer encoder.
115
+ num_attention_heads (`int`, *optional*, defaults to 12):
116
+ Number of attention heads for each attention layer in the Transformer encoder.
117
+ intermediate_size (`int`, *optional*, defaults to 3072):
118
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
119
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
120
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
121
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
122
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
123
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
124
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
125
+ The dropout ratio for the attention probabilities.
126
+ max_position_embeddings (`int`, *optional*, defaults to 514):
127
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
128
+ just in case (e.g., 512 or 1024 or 2048).
129
+ type_vocab_size (`int`, *optional*, defaults to 1):
130
+ The vocabulary size of the `token_type_ids`.
131
+ initializer_factor (`float`, *optional*, defaults to 1):
132
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
133
+ testing).
134
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
135
+ The epsilon used by the layer normalization layers.
136
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
137
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
138
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
139
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
140
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
141
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
142
+ is_decoder (`bool`, *optional*, defaults to `False`):
143
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
144
+ use_cache (`bool`, *optional*, defaults to `True`):
145
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
146
+ relevant if `config.is_decoder=True`.
147
+
148
+ Example:
149
+
150
+ ```python
151
+ >>> from transformers import BridgeTowerTextConfig
152
+
153
+ >>> # Initializing a BridgeTower BridgeTower/bridgetower-base style configuration for the text model
154
+ >>> configuration = BridgeTowerTextConfig()
155
+
156
+ >>> # Accessing the configuration
157
+ >>> configuration
158
+ ```"""
159
+
160
+ model_type = "bridgetower_text_model"
161
+ base_config_key = "text_config"
162
+
163
+ def __init__(
164
+ self,
165
+ vocab_size=50265,
166
+ hidden_size=768,
167
+ num_hidden_layers=12,
168
+ num_attention_heads=12,
169
+ initializer_factor=1,
170
+ intermediate_size=3072,
171
+ hidden_act="gelu",
172
+ hidden_dropout_prob=0.1,
173
+ attention_probs_dropout_prob=0.1,
174
+ max_position_embeddings=514,
175
+ type_vocab_size=1,
176
+ layer_norm_eps=1e-05,
177
+ pad_token_id=1,
178
+ bos_token_id=0,
179
+ eos_token_id=2,
180
+ position_embedding_type="absolute",
181
+ use_cache=True,
182
+ **kwargs,
183
+ ):
184
+ super().__init__(**kwargs)
185
+
186
+ self.vocab_size = vocab_size
187
+ self.hidden_size = hidden_size
188
+ self.num_hidden_layers = num_hidden_layers
189
+ self.num_attention_heads = num_attention_heads
190
+ self.hidden_act = hidden_act
191
+ self.initializer_factor = initializer_factor
192
+ self.intermediate_size = intermediate_size
193
+ self.hidden_dropout_prob = hidden_dropout_prob
194
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.type_vocab_size = type_vocab_size
197
+ self.layer_norm_eps = layer_norm_eps
198
+ self.position_embedding_type = position_embedding_type
199
+ self.use_cache = use_cache
200
+ self.pad_token_id = pad_token_id
201
+ self.bos_token_id = bos_token_id
202
+ self.eos_token_id = eos_token_id
203
+
204
+
205
+ class BridgeTowerConfig(PretrainedConfig):
206
+ r"""
207
+ This is the configuration class to store the configuration of a [`BridgeTowerModel`]. It is used to instantiate a
208
+ BridgeTower model according to the specified arguments, defining the model architecture. Instantiating a
209
+ configuration with the defaults will yield a similar configuration to that of the bridgetower-base
210
+ [BridgeTower/bridgetower-base](https://huggingface.co/BridgeTower/bridgetower-base/) architecture.
211
+
212
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
213
+ documentation from [`PretrainedConfig`] for more information.
214
+
215
+ Args:
216
+ share_cross_modal_transformer_layers (`bool`, *optional*, defaults to `True`):
217
+ Whether cross modal transformer layers are shared.
218
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
219
+ The non-linear activation function (function or string) in the encoder and pooler.
220
+ hidden_size (`int`, *optional*, defaults to 768):
221
+ Dimensionality of the encoder layers and the pooler layer.
222
+ initializer_factor (`float`, *optional*, defaults to 1):
223
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
224
+ testing).
225
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
226
+ The epsilon used by the layer normalization layers.
227
+ share_link_tower_layers (`bool`, *optional*, defaults to `False`):
228
+ Whether the bridge/link tower layers are shared.
229
+ link_tower_type (`str`, *optional*, defaults to `"add"`):
230
+ Type of the bridge/link layer.
231
+ num_attention_heads (`int`, *optional*, defaults to 12):
232
+ Number of attention heads for each attention layer in the Transformer encoder.
233
+ num_hidden_layers (`int`, *optional*, defaults to 6):
234
+ Number of hidden layers in the Transformer encoder.
235
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
236
+ Whether to tie input and output embeddings.
237
+ init_layernorm_from_vision_encoder (`bool`, *optional*, defaults to `False`):
238
+ Whether to init LayerNorm from the vision encoder.
239
+ text_config (`dict`, *optional*):
240
+ Dictionary of configuration options used to initialize [`BridgeTowerTextConfig`].
241
+ vision_config (`dict`, *optional*):
242
+ Dictionary of configuration options used to initialize [`BridgeTowerVisionConfig`].
243
+
244
+ Example:
245
+
246
+ ```python
247
+ >>> from transformers import BridgeTowerModel, BridgeTowerConfig
248
+
249
+ >>> # Initializing a BridgeTower BridgeTower/bridgetower-base style configuration
250
+ >>> configuration = BridgeTowerConfig()
251
+
252
+ >>> # Initializing a model from the BridgeTower/bridgetower-base style configuration
253
+ >>> model = BridgeTowerModel(configuration)
254
+
255
+ >>> # Accessing the model configuration
256
+ >>> configuration = model.config
257
+ ```"""
258
+
259
+ model_type = "bridgetower"
260
+ sub_configs = {"text_config": BridgeTowerTextConfig, "vision_config": BridgeTowerVisionConfig}
261
+
262
+ def __init__(
263
+ self,
264
+ share_cross_modal_transformer_layers=True,
265
+ hidden_act="gelu",
266
+ hidden_size=768,
267
+ initializer_factor=1,
268
+ layer_norm_eps=1e-05,
269
+ share_link_tower_layers=False,
270
+ link_tower_type="add",
271
+ num_attention_heads=12,
272
+ num_hidden_layers=6,
273
+ tie_word_embeddings=False,
274
+ init_layernorm_from_vision_encoder=False,
275
+ text_config=None,
276
+ vision_config=None,
277
+ **kwargs,
278
+ ):
279
+ # TODO: remove this once the Hub files are updated.
280
+ _ = kwargs.pop("text_config_dict", None)
281
+ _ = kwargs.pop("vision_config_dict", None)
282
+
283
+ super().__init__(**kwargs)
284
+ self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
285
+ self.hidden_act = hidden_act
286
+ self.hidden_size = hidden_size
287
+ self.initializer_factor = initializer_factor
288
+ self.layer_norm_eps = layer_norm_eps
289
+ self.share_link_tower_layers = share_link_tower_layers
290
+ self.link_tower_type = link_tower_type
291
+ self.num_attention_heads = num_attention_heads
292
+ self.num_hidden_layers = num_hidden_layers
293
+ self.tie_word_embeddings = tie_word_embeddings
294
+ self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
295
+
296
+ if text_config is None:
297
+ text_config = {}
298
+ logger.info("`text_config` is `None`. Initializing the `BridgeTowerTextConfig` with default values.")
299
+
300
+ if vision_config is None:
301
+ vision_config = {}
302
+ logger.info("`vision_config` is `None`. Initializing the `BridgeTowerVisionConfig` with default values.")
303
+
304
+ self.text_config = BridgeTowerTextConfig(**text_config)
305
+ self.vision_config = BridgeTowerVisionConfig(**vision_config)
306
+
307
+ @classmethod
308
+ def from_text_vision_configs(
309
+ cls, text_config: BridgeTowerTextConfig, vision_config: BridgeTowerVisionConfig, **kwargs
310
+ ):
311
+ r"""
312
+ Instantiate a [`BridgeTowerConfig`] (or a derived class) from a BridgeTower text model configuration and a
313
+ BridgeTower vision model configuration. Returns: [`BridgeTowerConfig`]: An instance of a configuration object
314
+ """
315
+
316
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
317
+
318
+
319
+ __all__ = ["BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig"]
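A short sketch of composing the three configuration classes via `from_text_vision_configs` (the layer counts are arbitrary illustrative values):

```python
from transformers import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig

text_config = BridgeTowerTextConfig(num_hidden_layers=6)
vision_config = BridgeTowerVisionConfig(num_hidden_layers=6)
config = BridgeTowerConfig.from_text_vision_configs(text_config, vision_config)
print(config.text_config.num_hidden_layers, config.vision_config.num_hidden_layers)  # 6 6
```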
docs/transformers/src/transformers/models/bridgetower/image_processing_bridgetower.py ADDED
@@ -0,0 +1,541 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for BridgeTower."""
16
+
17
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import PaddingMode, center_crop, pad, resize, to_channel_dimension_format
23
+ from ...image_utils import (
24
+ OPENAI_CLIP_MEAN,
25
+ OPENAI_CLIP_STD,
26
+ ChannelDimension,
27
+ ImageInput,
28
+ PILImageResampling,
29
+ get_image_size,
30
+ infer_channel_dimension_format,
31
+ is_scaled_image,
32
+ make_flat_list_of_images,
33
+ to_numpy_array,
34
+ valid_images,
35
+ validate_preprocess_arguments,
36
+ )
37
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
38
+
39
+
40
+ if is_vision_available():
41
+ import PIL
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ # Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
47
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
48
+ """
49
+ Return the maximum value across all indices of an iterable of values.
50
+ """
51
+ return [max(values_i) for values_i in zip(*values)]
52
+
53
+
54
+ # Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
55
+ def make_pixel_mask(
56
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
57
+ ) -> np.ndarray:
58
+ """
59
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
60
+
61
+ Args:
62
+ image (`np.ndarray`):
63
+ Image to make the pixel mask for.
64
+ output_size (`Tuple[int, int]`):
65
+ Output size of the mask.
66
+ """
67
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
68
+ mask = np.zeros(output_size, dtype=np.int64)
69
+ mask[:input_height, :input_width] = 1
70
+ return mask
71
+
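A tiny worked example of `make_pixel_mask`, padding a channels-first image into a larger canvas (shapes are illustrative):

```python
import numpy as np

image = np.zeros((3, 2, 5))  # channels-first: a 2x5 image
print(make_pixel_mask(image, output_size=(4, 6)))
# [[1 1 1 1 1 0]
#  [1 1 1 1 1 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 0 0]]
```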
72
+
73
+ # Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
74
+ def get_max_height_width(
75
+ images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
76
+ ) -> List[int]:
77
+ """
78
+ Get the maximum height and width across all images in a batch.
79
+ """
80
+ if input_data_format is None:
81
+ input_data_format = infer_channel_dimension_format(images[0])
82
+
83
+ if input_data_format == ChannelDimension.FIRST:
84
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
85
+ elif input_data_format == ChannelDimension.LAST:
86
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
87
+ else:
88
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
89
+ return (max_height, max_width)
90
+
91
+
92
+ # Copied from transformers.models.vilt.image_processing_vilt.get_resize_output_image_size
93
+ def get_resize_output_image_size(
94
+ input_image: np.ndarray,
95
+ shorter: int = 800,
96
+ longer: int = 1333,
97
+ size_divisor: int = 32,
98
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
99
+ ) -> Tuple[int, int]:
100
+ input_height, input_width = get_image_size(input_image, input_data_format)
101
+ min_size, max_size = shorter, longer
102
+
103
+ scale = min_size / min(input_height, input_width)
104
+
105
+ if input_height < input_width:
106
+ new_height = min_size
107
+ new_width = scale * input_width
108
+ else:
109
+ new_height = scale * input_height
110
+ new_width = min_size
111
+
112
+ if max(new_height, new_width) > max_size:
113
+ scale = max_size / max(new_height, new_width)
114
+ new_height = scale * new_height
115
+ new_width = scale * new_width
116
+
117
+ new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
118
+ new_height = new_height // size_divisor * size_divisor
119
+ new_width = new_width // size_divisor * size_divisor
120
+
121
+ return new_height, new_width
122
+
123
+
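A worked example of the resize computation above, assuming the function is in scope. With BridgeTower's defaults, the longer-side cap is `int(1333 / 800 * 288) = 479`:

```python
import numpy as np

image = np.zeros((3, 480, 640))  # channels-first 480x640 image
# scale = 288 / 480 = 0.6 -> (288.0, 384.0); 384 <= 479, and both sides are
# already multiples of size_divisor=32, so nothing further changes
print(get_resize_output_image_size(image, shorter=288, longer=479, size_divisor=32))  # (288, 384)
```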
124
+ class BridgeTowerImageProcessor(BaseImageProcessor):
125
+ r"""
126
+ Constructs a BridgeTower image processor.
127
+
128
+ Args:
129
+ do_resize (`bool`, *optional*, defaults to `True`):
130
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
131
+ `do_resize` parameter in the `preprocess` method.
132
+ size (`Dict[str, int]` *optional*, defaults to `{'shortest_edge': 288}`):
133
+ Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
134
+ `int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
135
+ `do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
136
+ size_divisor (`int`, *optional*, defaults to 32):
137
+ The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
138
+ is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
139
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
140
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
141
+ overridden by the `resample` parameter in the `preprocess` method.
142
+ do_rescale (`bool`, *optional*, defaults to `True`):
143
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
144
+ parameter in the `preprocess` method.
145
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
146
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
147
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
148
+ do_normalize (`bool`, *optional*, defaults to `True`):
149
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
150
+ method.
151
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
152
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
153
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
155
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
156
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
157
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
159
+ do_center_crop (`bool`, *optional*, defaults to `True`):
160
+ Whether to center crop the image. Can be overridden by the `do_center_crop` parameter in the `preprocess`
161
+ method.
162
+ crop_size (`Dict[str, int]`, *optional*):
163
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
164
+ Can be overridden by the `crop_size` parameter in the `preprocess` method. If unset, defaults to `size`.
165
+ do_pad (`bool`, *optional*, defaults to `True`):
166
+ Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
167
+ the `do_pad` parameter in the `preprocess` method.
168
+ """
169
+
170
+ model_input_names = ["pixel_values"]
171
+
172
+ def __init__(
173
+ self,
174
+ do_resize: bool = True,
175
+ size: Dict[str, int] = None,
176
+ size_divisor: int = 32,
177
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
178
+ do_rescale: bool = True,
179
+ rescale_factor: Union[int, float] = 1 / 255,
180
+ do_normalize: bool = True,
181
+ image_mean: Optional[Union[float, List[float]]] = None,
182
+ image_std: Optional[Union[float, List[float]]] = None,
183
+ do_center_crop: bool = True,
184
+ crop_size: Dict[str, int] = None,
185
+ do_pad: bool = True,
186
+ **kwargs,
187
+ ) -> None:
188
+ if "pad_and_return_pixel_mask" in kwargs:
189
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
190
+
191
+ super().__init__(**kwargs)
192
+ size = size if size is not None else {"shortest_edge": 288}
193
+ size = get_size_dict(size, default_to_square=False)
194
+
195
+ self.do_resize = do_resize
196
+ self.size = size
197
+ self.size_divisor = size_divisor
198
+ self.resample = resample
199
+ self.do_rescale = do_rescale
200
+ self.rescale_factor = rescale_factor
201
+ self.do_normalize = do_normalize
202
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
203
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
204
+ self.do_pad = do_pad
205
+ self.do_center_crop = do_center_crop
206
+ self.crop_size = crop_size
207
+
208
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.resize
209
+ def resize(
210
+ self,
211
+ image: np.ndarray,
212
+ size: Dict[str, int],
213
+ size_divisor: int = 32,
214
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
215
+ data_format: Optional[Union[str, ChannelDimension]] = None,
216
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
217
+ **kwargs,
218
+ ) -> np.ndarray:
219
+ """
220
+ Resize an image.
221
+
222
+ Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
223
+ longer side is larger than the max size `(int(`size["shortest_edge"]` * 1333 / 800))`, the longer side is then
224
+ resized to the max size while preserving the aspect ratio.
225
+
226
+ Args:
227
+ image (`np.ndarray`):
228
+ Image to resize.
229
+ size (`Dict[str, int]`):
230
+ Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
231
+ size_divisor (`int`, *optional*, defaults to 32):
232
+ The image is resized to a size that is a multiple of this value.
233
+ resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
234
+ Resampling filter to use when resizing the image.
235
+ data_format (`str` or `ChannelDimension`, *optional*):
236
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
237
+ input_data_format (`str` or `ChannelDimension`, *optional*):
238
+ The channel dimension format of the input image. If not provided, it will be inferred.
239
+ """
240
+ size = get_size_dict(size, default_to_square=False)
241
+ if "shortest_edge" not in size:
242
+ raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
243
+ shorter = size["shortest_edge"]
244
+ longer = int(1333 / 800 * shorter)
245
+ output_size = get_resize_output_image_size(
246
+ image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
247
+ )
248
+ return resize(
249
+ image,
250
+ size=output_size,
251
+ resample=resample,
252
+ data_format=data_format,
253
+ input_data_format=input_data_format,
254
+ **kwargs,
255
+ )
256
+
257
+ def center_crop(
258
+ self,
259
+ image: np.ndarray,
260
+ size: Dict[str, int],
261
+ data_format: Optional[Union[str, ChannelDimension]] = None,
262
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
263
+ **kwargs,
264
+ ) -> np.ndarray:
265
+ """
266
+ Center crop an image to a square of side `size["shortest_edge"]`. If the input size is smaller than the crop size along
267
+ any edge, the image is padded with 0's and then center cropped.
268
+
269
+ Args:
270
+ image (`np.ndarray`):
271
+ Image to center crop.
272
+ size (`Dict[str, int]`):
273
+ Size of the output image. For BridgeTower this is `{"shortest_edge": s}`; the crop is a square of side `s`.
274
+ data_format (`str` or `ChannelDimension`, *optional*):
275
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
276
+ input_data_format (`ChannelDimension` or `str`, *optional*):
277
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
278
+ image.
279
+ """
280
+ output_size = size["shortest_edge"]
281
+ return center_crop(
282
+ image,
283
+ size=(output_size, output_size),
284
+ data_format=data_format,
285
+ input_data_format=input_data_format,
286
+ **kwargs,
287
+ )
288
+
289
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
290
+ def _pad_image(
291
+ self,
292
+ image: np.ndarray,
293
+ output_size: Tuple[int, int],
294
+ constant_values: Union[float, Iterable[float]] = 0,
295
+ data_format: Optional[ChannelDimension] = None,
296
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
297
+ ) -> np.ndarray:
298
+ """
299
+ Pad an image with zeros to the given size.
300
+ """
301
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
302
+ output_height, output_width = output_size
303
+
304
+ pad_bottom = output_height - input_height
305
+ pad_right = output_width - input_width
306
+ padding = ((0, pad_bottom), (0, pad_right))
307
+ padded_image = pad(
308
+ image,
309
+ padding,
310
+ mode=PaddingMode.CONSTANT,
311
+ constant_values=constant_values,
312
+ data_format=data_format,
313
+ input_data_format=input_data_format,
314
+ )
315
+ return padded_image
316
+
317
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad
318
+ def pad(
319
+ self,
320
+ images: List[np.ndarray],
321
+ constant_values: Union[float, Iterable[float]] = 0,
322
+ return_pixel_mask: bool = True,
323
+ return_tensors: Optional[Union[str, TensorType]] = None,
324
+ data_format: Optional[ChannelDimension] = None,
325
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
326
+ ) -> BatchFeature:
327
+ """
328
+ Pads a batch of images at the bottom and right with zeros, up to the largest height and width
330
+ in the batch, and optionally returns their corresponding pixel mask.
330
+
331
+ Args:
332
+ images (`List[np.ndarray]`):
333
+ Images to pad.
334
+ constant_values (`float` or `Iterable[float]`, *optional*):
335
+ The value to use for the padding if `mode` is `"constant"`.
336
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
337
+ Whether to return a pixel mask.
338
+ return_tensors (`str` or `TensorType`, *optional*):
339
+ The type of tensors to return. Can be one of:
340
+ - Unset: Return a list of `np.ndarray`.
341
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
342
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
343
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
344
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
345
+ data_format (`str` or `ChannelDimension`, *optional*):
346
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
347
+ input_data_format (`ChannelDimension` or `str`, *optional*):
348
+ The channel dimension format of the input image. If not provided, it will be inferred.
349
+ """
350
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
351
+
352
+ padded_images = [
353
+ self._pad_image(
354
+ image,
355
+ pad_size,
356
+ constant_values=constant_values,
357
+ data_format=data_format,
358
+ input_data_format=input_data_format,
359
+ )
360
+ for image in images
361
+ ]
362
+ data = {"pixel_values": padded_images}
363
+
364
+ if return_pixel_mask:
365
+ masks = [
366
+ make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
367
+ for image in images
368
+ ]
369
+ data["pixel_mask"] = masks
370
+
371
+ return BatchFeature(data=data, tensor_type=return_tensors)
372
+
373
+ @filter_out_non_signature_kwargs()
374
+ def preprocess(
375
+ self,
376
+ images: ImageInput,
377
+ do_resize: Optional[bool] = None,
378
+ size: Optional[Dict[str, int]] = None,
379
+ size_divisor: Optional[int] = None,
380
+ resample: PILImageResampling = None,
381
+ do_rescale: Optional[bool] = None,
382
+ rescale_factor: Optional[float] = None,
383
+ do_normalize: Optional[bool] = None,
384
+ image_mean: Optional[Union[float, List[float]]] = None,
385
+ image_std: Optional[Union[float, List[float]]] = None,
386
+ do_pad: Optional[bool] = None,
387
+ do_center_crop: Optional[bool] = None,
388
+ crop_size: Dict[str, int] = None,
389
+ return_tensors: Optional[Union[str, TensorType]] = None,
390
+ data_format: ChannelDimension = ChannelDimension.FIRST,
391
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
392
+ ) -> BatchFeature:
393
+ """
394
+ Preprocess an image or batch of images.
395
+
396
+ Args:
397
+ images (`ImageInput`):
398
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
399
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
400
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
401
+ Whether to resize the image.
402
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
403
+ Controls the size of the image after `resize`. The shortest edge of the image is resized to
404
+ `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
405
+ is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
406
+ edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
407
+ size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
408
+ The image is resized to a size that is a multiple of this value.
409
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
410
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
411
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
412
+ Whether to rescale the image values to the range [0, 1].
413
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
414
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
415
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
416
+ Whether to normalize the image.
417
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
418
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
419
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
420
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
421
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
422
+ Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also
423
+ created and returned.
424
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
425
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the
426
+ image is padded with 0's and then center cropped.
427
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
428
+ Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
429
+ padded with zeros and then cropped.
430
+ return_tensors (`str` or `TensorType`, *optional*):
431
+ The type of tensors to return. Can be one of:
432
+ - Unset: Return a list of `np.ndarray`.
433
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
434
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
435
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
436
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
437
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
438
+ The channel dimension format for the output image. Can be one of:
439
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
440
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
441
+ - Unset: Use the channel dimension format of the input image.
442
+ input_data_format (`ChannelDimension` or `str`, *optional*):
443
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
444
+ from the input image. Can be one of:
445
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
446
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
447
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
448
+ """
449
+ do_resize = do_resize if do_resize is not None else self.do_resize
450
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
451
+ resample = resample if resample is not None else self.resample
452
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
453
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
454
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
455
+ image_mean = image_mean if image_mean is not None else self.image_mean
456
+ image_std = image_std if image_std is not None else self.image_std
457
+ do_pad = do_pad if do_pad is not None else self.do_pad
458
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
459
+ # For backwards compatibility. Initial version of this processor was cropping to the "size" argument, which
460
+ # it should default to if crop_size is undefined.
461
+ crop_size = (
462
+ crop_size if crop_size is not None else (self.crop_size if self.crop_size is not None else self.size)
463
+ )
464
+
465
+ size = size if size is not None else self.size
466
+ size = get_size_dict(size, default_to_square=False)
467
+ images = make_flat_list_of_images(images)
468
+
469
+ if not valid_images(images):
470
+ raise ValueError(
471
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
472
+ "torch.Tensor, tf.Tensor or jax.ndarray."
473
+ )
474
+ # Here, crop_size is used only if it is set, else size will be used.
475
+ validate_preprocess_arguments(
476
+ do_rescale=do_rescale,
477
+ rescale_factor=rescale_factor,
478
+ do_normalize=do_normalize,
479
+ image_mean=image_mean,
480
+ image_std=image_std,
481
+ do_pad=do_pad,
482
+ size_divisibility=size_divisor,
483
+ do_center_crop=do_center_crop,
484
+ crop_size=crop_size,
485
+ do_resize=do_resize,
486
+ size=size,
487
+ resample=resample,
488
+ )
489
+ # All transformations expect numpy arrays.
490
+ images = [to_numpy_array(image) for image in images]
491
+
492
+ if do_rescale and is_scaled_image(images[0]):
493
+ logger.warning_once(
494
+ "It looks like you are trying to rescale already rescaled images. If the input"
495
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
496
+ )
497
+
498
+ if do_resize:
499
+ images = [
500
+ self.resize(
501
+ image=image,
502
+ size=size,
503
+ size_divisor=size_divisor,
504
+ resample=resample,
505
+ input_data_format=input_data_format,
506
+ )
507
+ for image in images
508
+ ]
509
+
510
+ if do_center_crop:
511
+ images = [
512
+ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
513
+ ]
514
+
515
+ if do_rescale:
516
+ images = [
517
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
518
+ for image in images
519
+ ]
520
+
521
+ if do_normalize:
522
+ images = [
523
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
524
+ for image in images
525
+ ]
526
+
527
+ images = [
528
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
529
+ ]
530
+
531
+ if do_pad:
532
+ encoded_outputs = self.pad(
533
+ images, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=data_format
534
+ )
535
+ else:
536
+ encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
537
+
538
+ return encoded_outputs
539
+
540
+
541
+ __all__ = ["BridgeTowerImageProcessor"]
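
Editor's note: a hedged usage sketch for the processor defined above. It assumes `transformers` (with this processor available), `numpy`, and `torch` are installed. With the defaults (`shortest_edge=288`, `do_center_crop=True`, `do_pad=True`), every image in the batch comes out 288x288, so the final padding step is a no-op here.

import numpy as np
from transformers import BridgeTowerImageProcessor

processor = BridgeTowerImageProcessor()
images = [
    np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8),
    np.random.randint(0, 256, (360, 360, 3), dtype=np.uint8),
]
encoding = processor(images, return_tensors="pt")
print(encoding["pixel_values"].shape)  # torch.Size([2, 3, 288, 288])
print(encoding["pixel_mask"].shape)    # torch.Size([2, 288, 288])
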
docs/transformers/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py ADDED
@@ -0,0 +1,345 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for BridgeTower."""
16
+
17
+ from typing import Dict, Iterable, Optional, Tuple, Union
18
+
19
+ from ...image_processing_utils_fast import (
20
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
21
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
22
+ BaseImageProcessorFast,
23
+ BatchFeature,
24
+ DefaultFastImageProcessorKwargs,
25
+ ImageInput,
26
+ SizeDict,
27
+ TensorType,
28
+ Unpack,
29
+ get_max_height_width,
30
+ group_images_by_shape,
31
+ reorder_images,
32
+ )
33
+ from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
34
+ from ...utils import add_start_docstrings, is_torch_available, is_torchvision_available, is_torchvision_v2_available
35
+
36
+
37
+ if is_torch_available():
38
+ import torch
39
+
40
+ if is_torchvision_available():
41
+ if is_torchvision_v2_available():
42
+ from torchvision.transforms.v2 import functional as F
43
+ else:
44
+ from torchvision.transforms import functional as F
45
+
46
+
47
+ def make_pixel_mask(
48
+ image: "torch.Tensor",
49
+ output_size: Tuple[int, int],
50
+ ) -> "torch.Tensor":
51
+ """
52
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
53
+
54
+ Args:
55
+ image (`torch.Tensor`):
56
+ Image to make the pixel mask for.
57
+ output_size (`Tuple[int, int]`):
58
+ Output size of the mask.
59
+ """
60
+ input_height, input_width = image.shape[-2:]
61
+ batch_size = image.size(0)
62
+ mask = torch.zeros((batch_size, *output_size), dtype=torch.long)
63
+ mask[:, :input_height, :input_width] = 1  # keep the batch dimension; 1 marks valid pixels
64
+ return mask
65
+
66
+
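
Editor's note: a minimal torch-only check of the mask semantics above (with the batch-dimension fix applied): ones mark the valid top-left region, zeros the bottom/right padding.

import torch

batch = torch.zeros(2, 3, 4, 6)  # two 4x6 images, channels first
output_size = (8, 8)
mask = torch.zeros((batch.size(0), *output_size), dtype=torch.long)
mask[:, :4, :6] = 1              # same rule as make_pixel_mask above
assert mask.sum().item() == 2 * 4 * 6
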
67
+ def get_resize_output_image_size(
68
+ input_image: "torch.Tensor",
69
+ shorter: int = 800,
70
+ longer: int = 1333,
71
+ size_divisor: int = 32,
72
+ ) -> Tuple[int, int]:
73
+ input_height, input_width = input_image.shape[-2:]
74
+ min_size, max_size = shorter, longer
75
+
76
+ scale = min_size / min(input_height, input_width)
77
+
78
+ if input_height < input_width:
79
+ new_height = min_size
80
+ new_width = scale * input_width
81
+ else:
82
+ new_height = scale * input_height
83
+ new_width = min_size
84
+
85
+ if max(new_height, new_width) > max_size:
86
+ scale = max_size / max(new_height, new_width)
87
+ new_height = scale * new_height
88
+ new_width = scale * new_width
89
+
90
+ new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
91
+ new_height = new_height // size_divisor * size_divisor
92
+ new_width = new_width // size_divisor * size_divisor
93
+
94
+ return new_height, new_width
95
+
96
+
97
+ class BridgeTowerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
98
+ size_divisor: Optional[int]
99
+ do_pad: Optional[bool]
100
+
101
+
102
+ @add_start_docstrings(
103
+ "Constructs a fast BridgeTower image processor.",
104
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
105
+ """
106
+ size_divisor (`int`, *optional*, defaults to 32):
107
+ The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
108
+ is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
109
+ do_pad (`bool`, *optional*, defaults to `True`):
110
+ Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
111
+ the `do_pad` parameter in the `preprocess` method.
112
+ """,
113
+ )
114
+ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
115
+ resample = PILImageResampling.BICUBIC
116
+ image_mean = OPENAI_CLIP_MEAN
117
+ image_std = OPENAI_CLIP_STD
118
+ size = {"shortest_edge": 288}
119
+ default_to_square = False
120
+ crop_size = {"shortest_edge": 288}
121
+ do_resize = True
122
+ do_center_crop = True
123
+ do_rescale = True
124
+ do_normalize = True
125
+ do_pad = True
126
+ size_divisor = 32
127
+ valid_kwargs = BridgeTowerFastImageProcessorKwargs
128
+
129
+ def __init__(self, **kwargs: Unpack[BridgeTowerFastImageProcessorKwargs]):
130
+ super().__init__(**kwargs)
131
+
132
+ @add_start_docstrings(
133
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
134
+ """
135
+ size_divisor (`int`, *optional*, defaults to 32):
136
+ The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
137
+ is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
138
+ do_pad (`bool`, *optional*, defaults to `True`):
139
+ Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
140
+ the `do_pad` parameter in the `preprocess` method.
141
+ """,
142
+ )
143
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[BridgeTowerFastImageProcessorKwargs]) -> BatchFeature:
144
+ return super().preprocess(images, **kwargs)
145
+
146
+ def resize(
147
+ self,
148
+ image: "torch.Tensor",
149
+ size: SizeDict,
150
+ size_divisor: int = 32,
151
+ interpolation: "F.InterpolationMode" = None,
152
+ antialias: bool = True,
153
+ **kwargs,
154
+ ) -> "torch.Tensor":
155
+ """
156
+ Resize an image.
157
+
158
+ Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
159
+ longer side is larger than the max size `(int(`size["shortest_edge"]` * 1333 / 800))`, the longer side is then
160
+ resized to the max size while preserving the aspect ratio.
161
+
162
+ Args:
163
+ image (`torch.Tensor`):
164
+ Image to resize.
165
+ size (`SizeDict`):
166
+ Size dictionary; must contain the key `"shortest_edge"`, which sets the target for the shorter side.
167
+ size_divisor (`int`, *optional*, defaults to 32):
168
+ The image is resized to a size that is a multiple of this value.
169
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
170
+ `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
171
+
172
+ Returns:
173
+ `torch.Tensor`: The resized image.
174
+ """
175
+ interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
176
+ if not size.shortest_edge:
177
+ raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
178
+ shorter = size.shortest_edge
179
+ longer = int(1333 / 800 * shorter)
180
+ output_size = get_resize_output_image_size(
181
+ image,
182
+ shorter=shorter,
183
+ longer=longer,
184
+ size_divisor=size_divisor,
185
+ )
186
+ return F.resize(image, output_size, interpolation=interpolation, antialias=antialias)
187
+
188
+ def center_crop(
189
+ self,
190
+ image: "torch.Tensor",
191
+ size: SizeDict,
192
+ **kwargs,
193
+ ) -> "torch.Tensor":
194
+ """
195
+ Center crop an image to a square of side `size.shortest_edge`. If the input size is smaller than the crop size along
196
+ any edge, the image is padded with 0's and then center cropped.
197
+
198
+ Args:
199
+ image (`torch.Tensor`):
200
+ Image to center crop.
201
+ size (`SizeDict`):
202
+ Size of the output image; the crop is a square of side `size.shortest_edge`.
203
+ """
204
+ output_size = size.shortest_edge
205
+ return F.center_crop(
206
+ image,
207
+ output_size=(output_size, output_size),
208
+ **kwargs,
209
+ )
210
+
211
+ def _pad_image(
212
+ self,
213
+ image: "torch.Tensor",
214
+ output_size: Tuple[int, int],
215
+ constant_values: Union[float, Iterable[float]] = 0,
216
+ ) -> "torch.Tensor":
217
+ """
218
+ Pad an image with zeros to the given size.
219
+ """
220
+ input_height, input_width = image.shape[-2:]
221
+ output_height, output_width = output_size
222
+
223
+ pad_bottom = output_height - input_height
224
+ pad_right = output_width - input_width
225
+ padding = (0, 0, pad_right, pad_bottom)
226
+ padded_image = F.pad(
227
+ image,
228
+ padding,
229
+ fill=constant_values,
230
+ )
231
+ return padded_image
232
+
233
+ def pad(
234
+ self,
235
+ images: list["torch.Tensor"],
236
+ constant_values: Union[float, Iterable[float]] = 0,
237
+ return_pixel_mask: bool = True,
238
+ ) -> tuple:
239
+ """
240
+ Pads a batch of images at the bottom and right with zeros, up to the largest height and width
241
+ in the batch, and optionally returns their corresponding pixel mask.
242
+
243
+ Args:
244
+ images (`list[torch.Tensor]`):
245
+ Images to pad.
246
+ constant_values (`float` or `Iterable[float]`, *optional*):
247
+ The value to use for the padding if `mode` is `"constant"`.
248
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
249
+ Whether to return a pixel mask.
250
+
+ Returns:
+ `tuple`: the padded images and, when `return_pixel_mask=True`, their pixel masks (`None` otherwise).
+ """
258
+ pad_size = get_max_height_width(images)
259
+
260
+ grouped_images, grouped_images_index = group_images_by_shape(images)
261
+ processed_images_grouped = {}
262
+ processed_masks_grouped = {}
263
+ for shape, stacked_images in grouped_images.items():
+     # Build the pixel mask from the *unpadded* images first: once padded to
+     # `pad_size`, every pixel would look valid and the mask would be all ones.
+     if return_pixel_mask:
+         stacked_masks = make_pixel_mask(image=stacked_images, output_size=pad_size)
+         processed_masks_grouped[shape] = stacked_masks
+
+     stacked_images = self._pad_image(
+         stacked_images,
+         pad_size,
+         constant_values=constant_values,
+     )
+     processed_images_grouped[shape] = stacked_images
274
+
275
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
276
+
277
+ processed_masks = None
278
+ if return_pixel_mask:
279
+ processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
280
+
281
+ return processed_images, processed_masks
282
+
283
+ def _preprocess(
284
+ self,
285
+ images: list["torch.Tensor"],
286
+ do_resize: bool,
287
+ size: SizeDict,
288
+ size_divisor: Optional[int],
289
+ interpolation: Optional["F.InterpolationMode"],
290
+ do_pad: bool,
291
+ do_center_crop: bool,
292
+ crop_size: SizeDict,
293
+ do_rescale: bool,
294
+ rescale_factor: float,
295
+ do_normalize: bool,
296
+ image_mean: Optional[Union[float, list[float]]],
297
+ image_std: Optional[Union[float, list[float]]],
298
+ return_tensors: Optional[Union[str, TensorType]],
299
+ **kwargs,
300
+ ) -> BatchFeature:
301
+ # Group images by size for batched resizing
302
+ grouped_images, grouped_images_index = group_images_by_shape(images)
303
+ resized_images_grouped = {}
304
+ for shape, stacked_images in grouped_images.items():
305
+ if do_resize:
306
+ stacked_images = self.resize(
307
+ image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation
308
+ )
309
+ resized_images_grouped[shape] = stacked_images
310
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
311
+
312
+ # Group images by size for further processing
313
+ # Needed in case do_resize is False, or resize returns images with different sizes
314
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images)
315
+ processed_images_grouped = {}
316
+ for shape, stacked_images in grouped_images.items():
317
+ if do_center_crop:
318
+ stacked_images = self.center_crop(stacked_images, crop_size)
319
+ # Fused rescale and normalize
320
+ stacked_images = self.rescale_and_normalize(
321
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
322
+ )
323
+ processed_images_grouped[shape] = stacked_images
324
+
325
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
326
+
327
+ data = {}
328
+ if do_pad:
329
+ processed_images, processed_masks = self.pad(processed_images, return_pixel_mask=True)
330
+ processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
331
+ data["pixel_mask"] = processed_masks
332
+
333
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
334
+ data["pixel_values"] = processed_images
335
+
336
+ return BatchFeature(data=data, tensor_type=return_tensors)
337
+
338
+ def to_dict(self):
339
+ encoder_dict = super().to_dict()
340
+ encoder_dict.pop("_valid_processor_keys", None)
341
+ encoder_dict.pop("crop_size", None)
342
+ return encoder_dict
343
+
344
+
345
+ __all__ = ["BridgeTowerImageProcessorFast"]
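
Editor's note: a hedged usage sketch for the fast processor above; it assumes a recent `transformers` with `torch` and `torchvision` available. The fast path accepts channels-first `torch.Tensor` inputs and groups same-shape images for batched transforms.

import torch
from transformers import BridgeTowerImageProcessorFast

processor = BridgeTowerImageProcessorFast()
images = [
    torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8),
    torch.randint(0, 256, (3, 360, 360), dtype=torch.uint8),
]
encoding = processor(images, return_tensors="pt")
print(encoding["pixel_values"].shape)  # torch.Size([2, 3, 288, 288])
print(encoding["pixel_mask"].shape)    # torch.Size([2, 288, 288])
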
docs/transformers/src/transformers/models/bridgetower/modeling_bridgetower.py ADDED
@@ -0,0 +1,1984 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch BridgeTower Model"""
16
+
17
+ import math
18
+ from collections import OrderedDict
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from ...activations import ACT2FN, QuickGELUActivation
28
+ from ...modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ MaskedLMOutput,
32
+ ModelOutput,
33
+ SequenceClassifierOutput,
34
+ )
35
+ from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
36
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
37
+ from ...utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ torch_int,
43
+ )
44
+ from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CONFIG_FOR_DOC = "BridgeTowerConfig"
50
+ _CHECKPOINT_FOR_DOC = "BridgeTower/bridgetower-base"
51
+ _TOKENIZER_FOR_DOC = "RobertaTokenizer"
52
+
53
+
54
+ BRIDGETOWER_START_DOCSTRING = r"""
55
+ This model is a PyTorch [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use
56
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
57
+ behavior.
58
+
59
+ Parameters:
60
+ config ([`BridgeTowerConfig`]): Model configuration class with all the parameters of the model.
61
+ Initializing with a config file does not load the weights associated with the model, only the
62
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
63
+ """
64
+
65
+ BRIDGETOWER_INPUTS_DOCSTRING = r"""
66
+ Args:
67
+ input_ids (`torch.LongTensor` of shape `({0})`):
68
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
69
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
70
+ IDs?](../glossary#input-ids)
71
+
72
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
73
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
74
+ - 1 for tokens that are **not masked**,
75
+ - 0 for tokens that are **masked**.
76
+ [What are attention masks?](../glossary#attention-mask)
77
+
78
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
79
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
80
+ 1]`:
81
+ - 0 corresponds to a *sentence A* token,
82
+ - 1 corresponds to a *sentence B* token.
83
+ [What are token type IDs?](../glossary#token-type-ids)
84
+
85
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
86
+ Pixel values. Pixel values can be obtained using [`BridgeTowerImageProcessor`]. See
87
+ [`BridgeTowerImageProcessor.__call__`] for details.
88
+
89
+ pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
90
+ Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
91
+
92
+ - 1 for pixels that are real (i.e. **not masked**),
93
+ - 0 for pixels that are padding (i.e. **masked**).
94
+ [What are attention masks?](../glossary#attention-mask)
95
+
96
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
97
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
98
+ - 1 indicates the head is **not masked**,
99
+ - 0 indicates the head is **masked**.
100
+
101
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
102
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
103
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
104
+ model's internal embedding lookup matrix.
105
+
106
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
107
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
108
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
109
+
110
+ image_token_type_idx (`int`, *optional*):
111
+ - The token type ids for images.
112
+
113
+ output_attentions (`bool`, *optional*):
114
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
115
+ tensors for more detail.
116
+
117
+ output_hidden_states (`bool`, *optional*):
118
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
119
+ more detail.
120
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
121
+ Whether to interpolate the pre-trained position encodings.
122
+ return_dict (`bool`, *optional*):
123
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
124
+ """
125
+
126
+
127
+ @dataclass
128
+ class BridgeTowerModelOutput(ModelOutput):
129
+ """
130
+ Output type of [`BridgeTowerModel`].
131
+
132
+ Args:
133
+ text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
134
+ Sequence of hidden-states at the text output of the last layer of the model.
135
+ image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
136
+ Sequence of hidden-states at the image output of the last layer of the model.
137
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
138
+ Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
139
+ token), respectively, after further processing through layers used for auxiliary pretraining tasks.
140
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
141
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
142
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
143
+ the model at the output of each layer plus the optional initial embedding outputs.
144
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
145
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
146
+ sequence_length)`.
147
+
148
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
149
+ heads.
150
+ """
151
+
152
+ text_features: Optional[torch.FloatTensor] = None
153
+ image_features: Optional[torch.FloatTensor] = None
154
+ pooler_output: Optional[torch.FloatTensor] = None
155
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
156
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
157
+
158
+
159
+ @dataclass
160
+ class BridgeTowerContrastiveOutput(ModelOutput):
161
+ """
162
+ Output type of [`BridgeTowerForContrastiveLearning`].
163
+
164
+ Args:
165
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
166
+ Image-text contrastive loss.
167
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
168
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
169
+ text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
170
+ The text embeddings obtained by applying the projection layer to the pooler_output.
171
+ image_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
172
+ The image embeddings obtained by applying the projection layer to the pooler_output.
173
+ cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
174
+ The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
175
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
176
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
177
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
178
+ the model at the output of each layer plus the optional initial embedding outputs.
179
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
180
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
181
+ sequence_length)`.
182
+ """
183
+
184
+ loss: Optional[torch.FloatTensor] = None
185
+ logits: Optional[torch.FloatTensor] = None
186
+ text_embeds: Optional[Tuple[torch.FloatTensor]] = None
187
+ image_embeds: Optional[Tuple[torch.FloatTensor]] = None
188
+ cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
189
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
190
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
191
+
192
+
193
+ class BridgeTowerResidualAttention(nn.Module):
194
+ def __init__(self, config):
195
+ super().__init__()
196
+
197
+ self.attn = nn.MultiheadAttention(config.hidden_size, config.hidden_size // 64)
198
+ self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
199
+ self.mlp = nn.ModuleDict(
200
+ OrderedDict(
201
+ [
202
+ ("c_fc", nn.Linear(config.hidden_size, config.hidden_size * 4)),
203
+ ("gelu", QuickGELUActivation()),
204
+ ("c_proj", nn.Linear(config.hidden_size * 4, config.hidden_size)),
205
+ ]
206
+ )
207
+ )
208
+ self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
209
+ self.attn_mask = None
210
+
211
+ def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor):
212
+ if attention_mask is not None:
213
+ attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device)
214
+ self.attn_mask = (
215
+ self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device)
216
+ if self.attn_mask is not None
217
+ else None
218
+ )
219
+ return self.attn(
220
+ hidden_state,
221
+ hidden_state,
222
+ hidden_state,
223
+ need_weights=False,
224
+ attn_mask=self.attn_mask,
225
+ key_padding_mask=attention_mask,
226
+ )[0]
227
+
228
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
229
+ residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
230
+ hidden_state = self.ln_2(residual_state)
231
+ for _, layer in self.mlp.items():
232
+ hidden_state = layer(hidden_state)
233
+ hidden_state = residual_state + hidden_state
234
+ return hidden_state
235
+
236
+
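
Editor's note: a self-contained sketch of the pre-LayerNorm residual pattern the block above implements (LN -> attention -> residual, then LN -> MLP -> residual). Sizes are illustrative, `hidden_size // 64` is the head count, and `nn.GELU` stands in for the QuickGELU activation used above.

import torch
from torch import nn

hidden_size, seq_len, batch = 128, 10, 2
attn = nn.MultiheadAttention(hidden_size, hidden_size // 64)  # 2 heads of dim 64
ln_1, ln_2 = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
mlp = nn.Sequential(
    nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size)
)

x = torch.randn(seq_len, batch, hidden_size)  # nn.MultiheadAttention default layout is (L, N, D)
residual = x + attn(ln_1(x), ln_1(x), ln_1(x), need_weights=False)[0]
out = residual + mlp(ln_2(residual))
assert out.shape == x.shape
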
237
+ class BridgeTowerTransformer(nn.Module):
238
+ def __init__(self, config):
239
+ super().__init__()
240
+ self.hidden_size = config.hidden_size
241
+ self.num_hidden_layers = config.num_hidden_layers
242
+ if config.remove_last_layer:
243
+ self.resblocks = nn.ModuleList(
244
+ [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers - 1)]
245
+ )
246
+ else:
247
+ self.resblocks = nn.ModuleList(
248
+ [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers)]
249
+ )
250
+ self.stop_gradient = config.stop_gradient
251
+
252
+ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
253
+ hidden_states = []
254
+ for block in self.resblocks:
255
+ hidden_state = block(hidden_state, attention_mask)
256
+ if self.stop_gradient:
257
+ hidden_states.append(hidden_state.detach())
258
+ else:
259
+ hidden_states.append(hidden_state)
260
+ return hidden_states
261
+
262
+
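
Editor's note: unlike a standard encoder, the transformer above returns one hidden state per block (optionally detached when `stop_gradient` is set), which the link towers below consume layer by layer. A toy sketch of that collection pattern:

import torch

blocks = [lambda x: x + 1.0, lambda x: x * 2.0]  # stand-ins for residual blocks
hidden, collected = torch.zeros(3), []
for block in blocks:
    hidden = block(hidden)
    collected.append(hidden.detach())            # detach mirrors stop_gradient=True
assert [t[0].item() for t in collected] == [1.0, 2.0]
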
263
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->BridgeTower
264
+ class BridgeTowerVisionEmbeddings(nn.Module):
265
+ def __init__(self, config: BridgeTowerVisionConfig):
266
+ super().__init__()
267
+ self.config = config
268
+ self.embed_dim = config.hidden_size
269
+ self.image_size = config.image_size
270
+ self.patch_size = config.patch_size
271
+
272
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
273
+
274
+ self.patch_embedding = nn.Conv2d(
275
+ in_channels=config.num_channels,
276
+ out_channels=self.embed_dim,
277
+ kernel_size=self.patch_size,
278
+ stride=self.patch_size,
279
+ bias=False,
280
+ )
281
+
282
+ self.num_patches = (self.image_size // self.patch_size) ** 2
283
+ self.num_positions = self.num_patches + 1
284
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
285
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
286
+
287
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
288
+ """
289
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
290
+ images. This method is also adapted to support torch.jit tracing.
291
+
292
+ Adapted from:
293
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
294
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
295
+ """
296
+
297
+ num_patches = embeddings.shape[1] - 1
298
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
299
+ num_positions = position_embedding.shape[1] - 1
300
+
301
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
302
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
303
+ return self.position_embedding(self.position_ids)
304
+
305
+ class_pos_embed = position_embedding[:, :1]
306
+ patch_pos_embed = position_embedding[:, 1:]
307
+
308
+ dim = embeddings.shape[-1]
309
+
310
+ new_height = height // self.patch_size
311
+ new_width = width // self.patch_size
312
+
313
+ sqrt_num_positions = torch_int(num_positions**0.5)
314
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
315
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
316
+
317
+ patch_pos_embed = nn.functional.interpolate(
318
+ patch_pos_embed,
319
+ size=(new_height, new_width),
320
+ mode="bicubic",
321
+ align_corners=False,
322
+ )
323
+
324
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
325
+
326
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
327
+
328
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
329
+ batch_size, _, height, width = pixel_values.shape
330
+ if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
331
+ raise ValueError(
332
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
333
+ )
334
+ target_dtype = self.patch_embedding.weight.dtype
335
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
336
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
337
+
338
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
339
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
340
+ if interpolate_pos_encoding:
341
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
342
+ else:
343
+ embeddings = embeddings + self.position_embedding(self.position_ids)
344
+ return embeddings
345
+
346
+
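
Editor's note: quick arithmetic for the patch and position bookkeeping above, with illustrative sizes. With `interpolate_pos_encoding=True`, a larger input reuses the same table by bicubically resizing the square grid of patch position embeddings.

image_size, patch_size = 288, 16
num_patches = (image_size // patch_size) ** 2   # 18 * 18 = 324 patch tokens
num_positions = num_patches + 1                 # plus one class token
assert (num_patches, num_positions) == (324, 325)

# A 384x384 input needs a 24x24 grid, so the stored 18x18 grid is interpolated.
assert (384 // patch_size) ** 2 == 576
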
347
+ class BridgeTowerVisionTransformer(nn.Module):
348
+ def __init__(self, config):
349
+ super().__init__()
350
+
351
+ self.embeddings = BridgeTowerVisionEmbeddings(config)
352
+ self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
353
+ self.transformer = BridgeTowerTransformer(config)
354
+ self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
355
+ self.share_layernorm = config.share_layernorm
356
+ if not config.share_layernorm:
357
+ self.ln_separate = nn.ModuleList(
358
+ [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
359
+ )
360
+
361
+ def forward(
362
+ self,
363
+ pixel_values: torch.Tensor,
364
+ attention_mask,
365
+ interpolate_pos_encoding: bool = False,
366
+ ):
367
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding)
368
+ hidden_states = self.ln_pre(hidden_states)
369
+ # NLD -> LND
370
+ hidden_states = hidden_states.permute(1, 0, 2)
371
+
372
+ hidden_states = self.transformer(hidden_states, attention_mask)
373
+ # shape = (num_hidden_layers, seq_len, batch_size, hidden_size)
374
+ hidden_states = torch.stack(hidden_states, dim=0)
375
+ # shape = (num_hidden_layers, batch_size, seq_len, hidden_size)
376
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
377
+ if self.share_layernorm:
378
+ hidden_states = self.ln_post(hidden_states)
379
+ else:
380
+ hidden_states_stack = []
381
+ for hidden_states, ln in zip(hidden_states, self.ln_separate):
382
+ hidden_states = ln(hidden_states)
383
+ hidden_states_stack.append(hidden_states)
384
+ # shape = (num_hidden_layers, batch_size, seq_len, hidden_size)
385
+ hidden_states = torch.stack(hidden_states_stack, dim=0)
386
+ return hidden_states
387
+
388
+ def forward_pre(
389
+ self,
390
+ pixel_values: torch.Tensor,
391
+ interpolate_pos_encoding: bool = False,
392
+ ):
393
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
394
+ hidden_states = self.ln_pre(hidden_states)
395
+ # NLD -> LND
396
+ hidden_states = hidden_states.permute(1, 0, 2)
397
+ return hidden_states
398
+
399
+ def forward_post(self, hidden_state: torch.Tensor):
400
+ visual_output_post = hidden_state.permute(1, 0, 2)
401
+ visual_output_post = self.ln_post(visual_output_post)
402
+ return visual_output_post
403
+
404
+
405
+ class BridgeTowerLinkTower(nn.Module):
406
+ def __init__(self, config):
407
+ super().__init__()
408
+ self.link_tower_type = config.link_tower_type
409
+ self.hidden_size = config.hidden_size
410
+ if config.link_tower_type in ["add", "scaled_add", "interpolate"]:
411
+ if config.link_tower_type == "scaled_add":
412
+ self.scaled_factor = nn.Parameter(torch.tensor(1.0))
413
+ elif config.link_tower_type == "interpolate":
414
+ self.beta = nn.Parameter(torch.tensor(0.5))
415
+ self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
416
+ else:
417
+ raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented")
418
+
419
+ def forward(self, hidden_states, cross_modal_hidden_states, attention_mask):
420
+ if self.link_tower_type == "add":
421
+ return self.LayerNorm(hidden_states + cross_modal_hidden_states)
422
+ elif self.link_tower_type == "scaled_add":
423
+ return self.LayerNorm(hidden_states * self.scaled_factor + cross_modal_hidden_states)
424
+ elif self.link_tower_type == "interpolate":
425
+ return self.LayerNorm(hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta)
426
+ else:
427
+ raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented")
428
+
429
+
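
Editor's note: the three fusion rules implemented above, one line each (a minimal sketch; the trailing LayerNorm is omitted and the tensors are illustrative).

import torch

tower, cross = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
scaled_factor, beta = torch.tensor(1.0), torch.tensor(0.5)

add = tower + cross                              # link_tower_type == "add"
scaled_add = tower * scaled_factor + cross       # link_tower_type == "scaled_add"
interpolate = tower * (1 - beta) + cross * beta  # link_tower_type == "interpolate"
assert add.shape == scaled_add.shape == interpolate.shape
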
430
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BridgeTower
+ class BridgeTowerSelfOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BridgeTower
+ class BridgeTowerIntermediate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+         if isinstance(config.hidden_act, str):
+             self.intermediate_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.intermediate_act_fn = config.hidden_act
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.intermediate_act_fn(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BridgeTower
+ class BridgeTowerOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BridgeTower
+ class BridgeTowerPooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[:, 0]
+         pooled_output = self.dense(first_token_tensor)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
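+ # Editor's sketch (illustrative, not part of the upstream file): the pooler maps the
+ # first ([CLS]) token through Linear + Tanh, e.g. for a hypothetical input of shape
+ # (batch, seq, hidden):
+ #     pooled = BridgeTowerPooler(config)(hidden_states)  # -> (batch, hidden)
+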
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
+ class BridgeTowerSelfAttention(nn.Module):
+     def __init__(self, config, position_embedding_type=None):
+         super().__init__()
+         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+             raise ValueError(
+                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                 f"heads ({config.num_attention_heads})"
+             )
+
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+         self.position_embedding_type = position_embedding_type or getattr(
+             config, "position_embedding_type", "absolute"
+         )
+         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+             self.max_position_embeddings = config.max_position_embeddings
+             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+         self.is_decoder = config.is_decoder
+
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         mixed_query_layer = self.query(hidden_states)
+
+         # If this is instantiated as a cross-attention module, the keys
+         # and values come from an encoder; the attention mask needs to be
+         # such that the encoder's padding tokens are not attended to.
+         is_cross_attention = encoder_hidden_states is not None
+
+         if is_cross_attention and past_key_value is not None:
+             # reuse k,v, cross_attentions
+             key_layer = past_key_value[0]
+             value_layer = past_key_value[1]
+             attention_mask = encoder_attention_mask
+         elif is_cross_attention:
+             key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+             attention_mask = encoder_attention_mask
+         elif past_key_value is not None:
+             key_layer = self.transpose_for_scores(self.key(hidden_states))
+             value_layer = self.transpose_for_scores(self.value(hidden_states))
+             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+         else:
+             key_layer = self.transpose_for_scores(self.key(hidden_states))
+             value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+
+         use_cache = past_key_value is not None
+         if self.is_decoder:
+             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+             # Further calls to cross_attention layer can then reuse all cross-attention
+             # key/value_states (first "if" case)
+             # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+             # all previous decoder key/value_states. Further calls to uni-directional self-attention
+             # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+             # if encoder bi-directional self-attention `past_key_value` is always `None`
+             past_key_value = (key_layer, value_layer)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+             query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+             if use_cache:
+                 position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                     -1, 1
+                 )
+             else:
+                 position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+             position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+             distance = position_ids_l - position_ids_r
+
+             positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+             positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+             if self.position_embedding_type == "relative_key":
+                 relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                 attention_scores = attention_scores + relative_position_scores
+             elif self.position_embedding_type == "relative_key_query":
+                 relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                 relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                 attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the BridgeTowerModel forward() function)
+             attention_scores = attention_scores + attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             attention_probs = attention_probs * head_mask
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(new_context_layer_shape)
+
+         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+         if self.is_decoder:
+             outputs = outputs + (past_key_value,)
+         return outputs
+
+
+ BRIDGE_TOWER_SELF_ATTENTION_CLASSES = {
+     "eager": BridgeTowerSelfAttention,
+ }
+
+
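+ # Editor's sketch (illustrative, not part of the upstream file): shape walkthrough
+ # for hypothetical sizes batch=2, seq=5, hidden=768, heads=12 (head size 64):
+ #     query/key/value projection: (2, 5, 768)
+ #     transpose_for_scores:       (2, 12, 5, 64)
+ #     attention_scores:           (2, 12, 5, 5)
+ #     context_layer (re-merged):  (2, 5, 768)
+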
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
+ class BridgeTowerAttention(nn.Module):
+     def __init__(self, config, position_embedding_type=None):
+         super().__init__()
+         self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation](
+             config, position_embedding_type=position_embedding_type
+         )
+         self.output = BridgeTowerSelfOutput(config)
+         self.pruned_heads = set()
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+         )
+
+         # Prune linear layers
+         self.self.query = prune_linear_layer(self.self.query, index)
+         self.self.key = prune_linear_layer(self.self.key, index)
+         self.self.value = prune_linear_layer(self.self.value, index)
+         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+         # Update hyper params and store pruned heads
+         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         self_outputs = self.self(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             past_key_value,
+             output_attentions,
+         )
+         attention_output = self.output(self_outputs[0], hidden_states)
+         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+         return outputs
+
+
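+ # Editor's note (illustrative, not part of the upstream file): pruning heads {0, 2}
+ # of a hypothetical 12-head block leaves num_attention_heads == 10 and shrinks the
+ # query/key/value projections to 10 * 64 = 640 output features:
+ #     attention.prune_heads({0, 2})
+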
+ class BridgeTowerBertCrossLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = BridgeTowerAttention(config)
+         self.is_decoder = config.is_decoder
+         self.add_cross_attention = config.add_cross_attention
+         self.crossattention = BridgeTowerAttention(config)
+         self.intermediate = BridgeTowerIntermediate(config)
+         self.output = BridgeTowerOutput(config)
+
+     def forward(
+         self,
+         hidden_states,
+         encoder_hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         encoder_attention_mask=None,
+         past_key_value=None,
+         output_attentions=False,
+     ):
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attention_outputs = self.attention(
+             hidden_states,
+             attention_mask=attention_mask,
+             head_mask=None,
+             output_attentions=output_attentions,
+             past_key_value=None,
+         )
+         attention_output = self_attention_outputs[0]
+
+         # if decoder, the last output is tuple of self-attn cache
+         # add self attentions if we output attention weights
+         outputs = self_attention_outputs[1:]
+
+         cross_attention_outputs = self.crossattention(
+             attention_output,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+         )
+         attention_output = cross_attention_outputs[0]
+         # add cross attentions if we output attention weights
+         outputs = outputs + cross_attention_outputs[1:-1]
+
+         layer_output = apply_chunking_to_forward(
+             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+         )
+         outputs = (layer_output,) + outputs
+
+         return outputs
+
+     def feed_forward_chunk(self, attention_output):
+         intermediate_output = self.intermediate(attention_output)
+         layer_output = self.output(intermediate_output, attention_output)
+         return layer_output
+
+
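+ # Editor's sketch (illustrative, not part of the upstream file): a cross layer takes
+ # one modality as `hidden_states` and the other as `encoder_hidden_states`, e.g.
+ #     layer(text_states, image_states,
+ #           attention_mask=text_mask, encoder_attention_mask=image_mask)
+ # runs self-attention on the first modality, then cross-attention into the second,
+ # then the feed-forward block.
+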
+ class BridgeTowerTextLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = BridgeTowerAttention(config)
+         self.is_decoder = config.is_decoder
+         self.add_cross_attention = config.add_cross_attention
+         if self.add_cross_attention:
+             if not self.is_decoder:
+                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+             self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute")
+         self.intermediate = BridgeTowerIntermediate(config)
+         self.output = BridgeTowerOutput(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+         self_attention_outputs = self.attention(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             output_attentions=output_attentions,
+             past_key_value=self_attn_past_key_value,
+         )
+         attention_output = self_attention_outputs[0]
+
+         # if decoder, the last output is tuple of self-attn cache
+         if self.is_decoder:
+             outputs = self_attention_outputs[1:-1]
+             present_key_value = self_attention_outputs[-1]
+         else:
+             outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+         cross_attn_present_key_value = None
+         if self.is_decoder and encoder_hidden_states is not None:
+             if not hasattr(self, "crossattention"):
+                 raise ValueError(
+                     f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+                     " by setting `config.add_cross_attention=True`"
+                 )
+
+             # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+             cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+             cross_attention_outputs = self.crossattention(
+                 attention_output,
+                 attention_mask,
+                 head_mask,
+                 encoder_hidden_states,
+                 encoder_attention_mask,
+                 cross_attn_past_key_value,
+                 output_attentions,
+             )
+             attention_output = cross_attention_outputs[0]
+             outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+             # add cross-attn cache to positions 3,4 of present_key_value tuple
+             cross_attn_present_key_value = cross_attention_outputs[-1]
+             present_key_value = present_key_value + cross_attn_present_key_value
+
+         layer_output = apply_chunking_to_forward(
+             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+         )
+         outputs = (layer_output,) + outputs
+
+         # if decoder, return the attn key/values as the last output
+         if self.is_decoder:
+             outputs = outputs + (present_key_value,)
+
+         return outputs
+
+     def feed_forward_chunk(self, attention_output):
+         intermediate_output = self.intermediate(attention_output)
+         layer_output = self.output(intermediate_output, attention_output)
+         return layer_output
+
+
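+ # Editor's note (illustrative, not part of the upstream file): when the layer is a
+ # decoder, `past_key_value` packs four tensors -- self-attention key/value at
+ # positions 0,1 and cross-attention key/value at positions 2,3 -- matching the
+ # slicing `past_key_value[:2]` / `past_key_value[-2:]` above.
+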
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
+ class BridgeTowerTextEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.layer = nn.ModuleList([BridgeTowerTextLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         return_dict: Optional[bool] = True,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         next_decoder_cache = () if use_cache else None
+         for i, layer_module in enumerate(self.layer):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+             past_key_value = past_key_values[i] if past_key_values is not None else None
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     past_key_value,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     past_key_value,
+                     output_attentions,
+                 )
+
+             hidden_states = layer_outputs[0]
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[-1],)
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                 if self.config.add_cross_attention:
+                     all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     next_decoder_cache,
+                     all_hidden_states,
+                     all_self_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_decoder_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
+
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->BridgeTowerText
+ class BridgeTowerTextEmbeddings(nn.Module):
+     """
+     Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+     """
+
+     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+         # any TensorFlow checkpoint file
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+         self.register_buffer(
+             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+         )
+         self.register_buffer(
+             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+         )
+
+         # End copy
+         self.padding_idx = config.pad_token_id
+         self.position_embeddings = nn.Embedding(
+             config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+         )
+
+     def forward(
+         self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+     ):
+         if position_ids is None:
+             if input_ids is not None:
+                 # Create the position ids from the input token ids. Any padded tokens remain padded.
+                 position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+             else:
+                 position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             input_shape = inputs_embeds.size()[:-1]
+
+         seq_length = input_shape[1]
+
+         # When token_type_ids is not provided, fall back to the registered all-zeros buffer. The registered
+         # buffer lets users trace the model without passing token_type_ids and solves issue #5664.
+         if token_type_ids is None:
+             if hasattr(self, "token_type_ids"):
+                 buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                 buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                 token_type_ids = buffered_token_type_ids_expanded
+             else:
+                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.word_embeddings(input_ids)
+         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+         embeddings = inputs_embeds + token_type_embeddings
+         if self.position_embedding_type == "absolute":
+             position_embeddings = self.position_embeddings(position_ids)
+             embeddings += position_embeddings
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
+
+     def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+         """
+         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+         Args:
+             inputs_embeds: torch.Tensor
+
+         Returns: torch.Tensor
+         """
+         input_shape = inputs_embeds.size()[:-1]
+         sequence_length = input_shape[1]
+
+         position_ids = torch.arange(
+             self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+         )
+         return position_ids.unsqueeze(0).expand(input_shape)
+
+
+ # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+     """
+     Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+     are ignored. This is modified from fairseq's `utils.make_positions`.
+
+     Args:
+         input_ids: torch.Tensor
+
+     Returns: torch.Tensor
+     """
+     # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+     mask = input_ids.ne(padding_idx).int()
+     incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+     return incremental_indices.long() + padding_idx
+
+
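+ # Editor's worked example (illustrative, not part of the upstream file): with
+ # padding_idx = 1 and input_ids = [[5, 6, 1, 1]]:
+ #     mask                  = [[1, 1, 0, 0]]
+ #     cumsum(mask) * mask   = [[1, 2, 0, 0]]
+ #     result (+padding_idx) = [[2, 3, 1, 1]]
+ # so real tokens get positions 2 and 3, while padded slots stay at padding_idx.
+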
+ class BridgeTowerPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = BridgeTowerConfig
+     base_model_prefix = "bridgetower"
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+     _skip_keys_device_placement = "past_key_values"
+
+     def _init_weights(self, module):
+         if isinstance(module, BridgeTowerVisionModel):
+             proj_std = (module.visual.transformer.hidden_size**-0.5) * (
+                 (2 * module.visual.transformer.num_hidden_layers) ** -0.5
+             )
+             attn_std = module.visual.transformer.hidden_size**-0.5
+             fc_std = (2 * module.visual.transformer.hidden_size) ** -0.5
+             for block in module.visual.transformer.resblocks:
+                 nn.init.normal_(block.attn.in_proj_weight, std=attn_std * self.config.initializer_factor)
+                 nn.init.normal_(block.attn.out_proj.weight, std=proj_std * self.config.initializer_factor)
+                 nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * self.config.initializer_factor)
+                 nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * self.config.initializer_factor)
+
+             nn.init.normal_(module.visual.embeddings.class_embedding, std=attn_std * self.config.initializer_factor)
+             nn.init.normal_(
+                 module.visual.embeddings.position_embedding.weight, std=attn_std * self.config.initializer_factor
+             )
+         elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
+             module.weight.data.normal_(mean=0.0, std=0.05 * self.config.initializer_factor)
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+         if isinstance(module, nn.Linear) and module.bias is not None:
+             module.bias.data.zero_()
+
+
+ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
+     config_class = BridgeTowerVisionConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.visual = BridgeTowerVisionTransformer(config)
+
+     @property
+     def dtype(self):
+         return self.visual.embeddings.patch_embedding.weight.dtype
+
+     def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
+         return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding)
+
+
+ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
+     """
+
+     The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+     cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+     all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+     Kaiser and Illia Polosukhin.
+
+     To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+     to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
+     and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward
+     pass.
+
+     .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+
+     """
+
+     config_class = BridgeTowerTextConfig
+
+     def __init__(self, config, add_pooling_layer=True):
+         super().__init__(config)
+         self.config = config
+
+         self.embeddings = BridgeTowerTextEmbeddings(config)
+         self.encoder = BridgeTowerTextEncoder(config)
+
+         self.pooler = BridgeTowerPooler(config) if add_pooling_layer else None
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+         class PreTrainedModel
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         r"""
+         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+             the model is configured as a decoder.
+         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if self.config.is_decoder:
+             use_cache = use_cache if use_cache is not None else self.config.use_cache
+         else:
+             use_cache = False
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         # past_key_values_length
+         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+         if attention_mask is None:
+             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+         if token_type_ids is None:
+             if hasattr(self.embeddings, "token_type_ids"):
+                 buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                 buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                 token_type_ids = buffered_token_type_ids_expanded
+             else:
+                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+         # ourselves in which case we just need to make it broadcastable to all heads.
+         extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.config.is_decoder and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_extended_attention_mask = None
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds,
+             past_key_values_length=past_key_values_length,
+         )
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             head_mask=head_mask,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_extended_attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             past_key_values=encoder_outputs.past_key_values,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+             cross_attentions=encoder_outputs.cross_attentions,
+         )
+
+
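+ # Editor's usage sketch (illustrative, not part of the upstream file):
+ #     text_model = BridgeTowerTextModel(BridgeTowerTextConfig())
+ #     out = text_model(input_ids=ids, attention_mask=mask)
+ #     out.last_hidden_state  # (batch, seq, hidden)
+ #     out.pooler_output      # (batch, hidden), Tanh-pooled [CLS] state
+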
+ @add_start_docstrings(
+     "The bare BridgeTower Model transformer outputting BridgeTowerModelOutput object without any specific head on"
+     " top.",
+     BRIDGETOWER_START_DOCSTRING,
+ )
+ class BridgeTowerModel(BridgeTowerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         vision_config = config.vision_config
+         text_config = config.text_config
+
+         if config.share_cross_modal_transformer_layers:
+             self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size)
+             self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size)
+         else:
+             self.cross_modal_text_transform = nn.ModuleList(
+                 [nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
+             )
+             self.cross_modal_image_transform = nn.ModuleList(
+                 [nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
+             )
+
+         self.token_type_embeddings = nn.Embedding(2, config.hidden_size)
+
+         self.vision_model = BridgeTowerVisionModel(vision_config)
+
+         self.text_model = BridgeTowerTextModel(text_config)
+
+         if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder:
+             for ln in self.vision_model.visual.cross_modal_ln_separate:
+                 ln.weight.data = self.vision_model.visual.ln_post.weight.data
+                 ln.bias.data = self.vision_model.visual.ln_post.bias.data
+
+         self.cross_modal_image_layers = nn.ModuleList(
+             [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
+         )
+         self.cross_modal_text_layers = nn.ModuleList(
+             [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
+         )
+
+         # Class token => Linear => Tanh
+         self.cross_modal_image_pooler = BridgeTowerPooler(config)
+         self.cross_modal_text_pooler = BridgeTowerPooler(config)
+
+         # Initialize BridgeTower Components
+         self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         if config.share_link_tower_layers:
+             self.cross_modal_text_link_tower = BridgeTowerLinkTower(config)
+             self.cross_modal_image_link_tower = BridgeTowerLinkTower(config)
+         else:
+             self.cross_modal_text_link_tower = nn.ModuleList(
+                 [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
+             )
+             self.cross_modal_image_link_tower = nn.ModuleList(
+                 [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
+             )
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.text_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.text_model.set_input_embeddings(value)
+
+     @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BridgeTowerModelOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         pixel_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         image_embeds: Optional[torch.FloatTensor] = None,
+         image_token_type_idx: Optional[int] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+         interpolate_pos_encoding: bool = False,
+     ) -> Union[Tuple[torch.Tensor], BridgeTowerModelOutput]:
+         r"""
+         output_hidden_states (`bool`, *optional*):
+             If set to `True`, hidden states are returned as a list containing the hidden states of text, image, and
+             cross-modal components respectively. i.e. `(hidden_states_text, hidden_states_image,
+             hidden_states_cross_modal)` where each element is a list of the hidden states of the corresponding
+             modality. `hidden_states_txt/img` are a list of tensors corresponding to unimodal hidden states and
+             `hidden_states_cross_modal` is a list of tuples containing `cross_modal_text_hidden_states` and
+             `cross_modal_image_hidden_states` of each bridge layer.
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels are currently not supported.
+         Returns:
+
+         Examples:
+
+         ```python
+         >>> from transformers import BridgeTowerProcessor, BridgeTowerModel
+         >>> from PIL import Image
+         >>> import requests
+
+         >>> # prepare image and text
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+         >>> text = "hello world"
+         >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
+         >>> model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")
+
+         >>> inputs = processor(image, text, return_tensors="pt")
+         >>> outputs = model(**inputs)
+         >>> outputs.keys()
+         odict_keys(['text_features', 'image_features', 'pooler_output'])
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         all_hidden_states_text = () if output_hidden_states else None
+         all_hidden_states_image = () if output_hidden_states else None
+         all_hidden_states_cross = () if output_hidden_states else None
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+
+         if inputs_embeds is not None and input_ids is None:
+             raise NotImplementedError(
+                 "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
+             )
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         image_token_type_idx = image_token_type_idx if image_token_type_idx else 1
+         input_shape = input_ids.size()
+         text_embeds = self.text_model.embeddings(input_ids=input_ids)
+
+         if output_hidden_states:
+             all_hidden_states_text += (text_embeds,)
+
+         if attention_mask is None:
+             attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device)
+         extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, input_shape).to(
+             input_ids.device
+         )
+
+         # The split_index determines how many layers of the uni-modal encoder are applied before the cross-modal encoder
+         split_index = len(self.text_model.encoder.layer) - self.config.num_hidden_layers + 1
+
+         # Run the first 'split_index' layers of the textual encoder
+         for layer in self.text_model.encoder.layer[:split_index]:
+             text_embeds = layer(text_embeds, extend_text_masks)[0]
+
+             if output_hidden_states:
+                 all_hidden_states_text += (text_embeds,)
+
+         if image_embeds is None:
+             image_embeds = self.vision_model.visual.forward_pre(
+                 pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
+             )
+         else:
+             # Permute as BridgeTowerResidualAttention has batch_first=True
+             image_embeds = image_embeds.permute(1, 0, 2)
+
+         if output_hidden_states:
+             all_hidden_states_image += (image_embeds,)
+
+         # Run the first 'split_index' layers of the visual encoder
+         for block in self.vision_model.visual.transformer.resblocks[:split_index]:
+             image_embeds = block(image_embeds)
+             if output_hidden_states:
+                 all_hidden_states_image += (image_embeds,)
+
+         image_embeds_with_ln = self.vision_model.visual.forward_post(image_embeds.type(self.vision_model.dtype))
+
+         # first layer is a special case because we don't have the output from the cross-encoder yet
+         cross_modal_text = self.cross_modal_text_transform(text_embeds)
+
+         text_token_type_embeddings = self.token_type_embeddings(
+             torch.zeros(1, dtype=torch.long, device=input_ids.device)
+         ).expand_as(cross_modal_text)
+
+         cross_modal_text = self.cross_modal_text_layernorm(cross_modal_text + text_token_type_embeddings)
+
+         image_embeds_with_ln = self.cross_modal_image_transform(image_embeds_with_ln)
+         image_token_type_embeddings = self.token_type_embeddings(
+             torch.full((1,), image_token_type_idx, dtype=torch.long, device=input_ids.device)
+         ).expand_as(image_embeds_with_ln)
+
+         image_embeds_with_ln = image_embeds_with_ln + image_token_type_embeddings
+         cross_modal_image = self.cross_modal_image_layernorm(image_embeds_with_ln)
+
+         pixel_mask = torch.ones(
+             (cross_modal_image.size(0), cross_modal_image.size(1)),
+             dtype=torch.long,
+             device=input_ids.device,
+         )
+         extend_image_masks = self.text_model.get_extended_attention_mask(pixel_mask, pixel_mask.size()).to(
+             input_ids.device
+         )
+
+         layer_outputs_text = self.cross_modal_text_layers[0](
+             cross_modal_text,
+             cross_modal_image,
+             attention_mask=extend_text_masks,
+             encoder_attention_mask=extend_image_masks,
+             output_attentions=output_attentions,
+         )
+         cross_text_features = layer_outputs_text[0]
+
+         layer_outputs_image = self.cross_modal_image_layers[0](
+             cross_modal_image,
+             cross_modal_text,
+             attention_mask=extend_image_masks,
+             encoder_attention_mask=extend_text_masks,
+             output_attentions=output_attentions,
+         )
+         cross_image_features = layer_outputs_image[0]
+
+         if output_hidden_states:
+             all_hidden_states_cross += ((cross_text_features, cross_image_features),)
+
+         if output_attentions:
+             all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)
+
+         link_layer_index = 0
+
+         # Each of the top 6 layers of the visual and textual encoders ([split_index:]) is connected to each layer of
+         # the cross-modal encoder via bridge layers, which brings bottom-up alignment and fusion to the cross-modal encoder.
+         for i in range(split_index, len(self.text_model.encoder.layer)):
+             text_embeds = self.text_model.encoder.layer[i](text_embeds, extend_text_masks)[0]
+             image_embeds = self.vision_model.visual.transformer.resblocks[i](image_embeds).type(
+                 self.vision_model.dtype
+             )
+             image_embeds_with_ln = (
+                 self.cross_modal_image_transform(self.vision_model.visual.forward_post(image_embeds))
+                 + image_token_type_embeddings
+             )
+
+             text_link_tower = self.cross_modal_text_link_tower[link_layer_index]
+             image_link_tower = self.cross_modal_image_link_tower[link_layer_index]
+
+             # Bridge layers for textual and visual encoders
+             cross_text_features_ = text_link_tower(
+                 self.cross_modal_text_transform(text_embeds) + text_token_type_embeddings,
+                 cross_text_features,
+                 extend_text_masks,
+             )
+             cross_image_features_ = image_link_tower(image_embeds_with_ln, cross_image_features, extend_image_masks)
+
+             # Cross-modal encoder via bridge layers of textual and visual encoders
+             layer_outputs_text = self.cross_modal_text_layers[link_layer_index + 1](
+                 cross_text_features_,
+                 cross_image_features_,
+                 attention_mask=extend_text_masks,
+                 encoder_attention_mask=extend_image_masks,
+                 output_attentions=output_attentions,
+             )
+             cross_text_features = layer_outputs_text[0]
+
+             layer_outputs_image = self.cross_modal_image_layers[link_layer_index + 1](
+                 cross_image_features_,
+                 cross_text_features_,
+                 attention_mask=extend_image_masks,
+                 encoder_attention_mask=extend_text_masks,
+                 output_attentions=output_attentions,
+             )
+             cross_image_features = layer_outputs_image[0]
+
+             link_layer_index += 1
+
+             if output_hidden_states:
+                 all_hidden_states_text += (text_embeds,)
+                 all_hidden_states_image += (image_embeds,)
+                 all_hidden_states_cross += ((cross_text_features, cross_image_features),)
+
+             if output_attentions:
+                 all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)
+
+         # Concatenate the cls token of the text and image features to get the final representation
+         text_features, image_features = cross_text_features, cross_image_features
+         cls_features = self.get_cls_features(text_features, image_features)
+
+         if output_hidden_states:
+             all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
+                 if v is not None
+             )
+
+         return BridgeTowerModelOutput(
+             text_features=text_features,
+             image_features=image_features,
+             pooler_output=cls_features,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
+
+     def get_cls_features(self, text_features, image_features):
+         cls_features_text = self.cross_modal_text_pooler(text_features)
+         cls_features_image = self.cross_modal_image_pooler(image_features)
+         return torch.cat([cls_features_text, cls_features_image], dim=-1)
+
+
+ # Copied from transformers.models.vilt.modeling_vilt.ViltPredictionHeadTransform with Vilt->BridgeTower
+ class BridgeTowerPredictionHeadTransform(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         if isinstance(config.hidden_act, str):
+             self.transform_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.transform_act_fn = config.hidden_act
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states):
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states)
+         return hidden_states
+
+
+ class BridgeTowerMLMHead(nn.Module):
+     def __init__(self, config, weight=None):
+         super().__init__()
+         self.config = config
+         self.transform = BridgeTowerPredictionHeadTransform(config)
+         self.decoder = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
+         self.bias = nn.Parameter(torch.zeros(config.text_config.vocab_size))
+         if weight is not None:
+             self.decoder.weight = weight
+
+     def forward(self, x):
+         mlm_score = self.transform(x)
+         mlm_score = self.decoder(mlm_score) + self.bias
+         return mlm_score
+
+
+ class BridgeTowerITMHead(nn.Module):
+     def __init__(self, hidden_size):
+         super().__init__()
+         self.fc = nn.Linear(hidden_size, 2)
+
+     def forward(self, x):
+         itm_score = self.fc(x)
+         return itm_score
+
+
+ @add_start_docstrings(
1616
+ """
1617
+ BridgeTower Model with a language modeling head on top as done during pretraining.
1618
+ """,
1619
+ BRIDGETOWER_START_DOCSTRING,
1620
+ )
1621
+ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
1622
+ _tied_weights_keys = ["mlm_score.decoder.weight"]
1623
+
1624
+ def __init__(self, config):
1625
+ super().__init__(config)
1626
+
1627
+ self.bridgetower = BridgeTowerModel(config)
1628
+ self.mlm_score = BridgeTowerMLMHead(config)
1629
+
1630
+ # Initialize weights and apply final processing
1631
+ self.post_init()
1632
+
1633
+ def get_output_embeddings(self):
1634
+ return self.mlm_score.decoder
1635
+
1636
+ def set_output_embeddings(self, new_embeddings):
1637
+ self.mlm_score.decoder = new_embeddings
1638
+
1639
+ @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1640
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1641
+ def forward(
1642
+ self,
1643
+ input_ids: Optional[torch.LongTensor] = None,
1644
+ attention_mask: Optional[torch.FloatTensor] = None,
1645
+ token_type_ids: Optional[torch.LongTensor] = None,
1646
+ pixel_values: Optional[torch.FloatTensor] = None,
1647
+ pixel_mask: Optional[torch.LongTensor] = None,
1648
+ head_mask: Optional[torch.FloatTensor] = None,
1649
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1650
+ image_embeds: Optional[torch.FloatTensor] = None,
1651
+ output_attentions: Optional[bool] = None,
1652
+ output_hidden_states: Optional[bool] = None,
1653
+ return_dict: Optional[bool] = None,
1654
+ labels: Optional[torch.LongTensor] = None,
1655
+ ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
1656
+ r"""
1657
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1658
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1659
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1660
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1661
+ Returns:
1662
+
1663
+ Examples:
1664
+
1665
+ ```python
1666
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM
1667
+ >>> from PIL import Image
1668
+ >>> import requests
1669
+
1670
+ >>> url = "http://images.cocodataset.org/val2017/000000360943.jpg"
1671
+ >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
1672
+ >>> text = "a <mask> looking out of the window"
1673
+
1674
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
1675
+ >>> model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
1676
+
1677
+ >>> # prepare inputs
1678
+ >>> encoding = processor(image, text, return_tensors="pt")
1679
+
1680
+ >>> # forward pass
1681
+ >>> outputs = model(**encoding)
1682
+
1683
+ >>> results = processor.decode(outputs.logits.argmax(dim=-1).squeeze(0).tolist())
1684
+
1685
+ >>> print(results)
1686
+ .a cat looking out of the window.
1687
+ ```"""
1688
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1689
+ outputs = self.bridgetower(
1690
+ input_ids,
1691
+ attention_mask=attention_mask,
1692
+ token_type_ids=token_type_ids,
1693
+ pixel_values=pixel_values,
1694
+ pixel_mask=pixel_mask,
1695
+ head_mask=head_mask,
1696
+ inputs_embeds=inputs_embeds,
1697
+ image_embeds=image_embeds,
1698
+ output_attentions=output_attentions,
1699
+ output_hidden_states=output_hidden_states,
1700
+ return_dict=return_dict,
1701
+ )
1702
+
1703
+ mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
1704
+ masked_lm_loss = None
1705
+ if labels is not None:
1706
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1707
+
1708
+ labels = labels.to(mlm_logits.device)
1709
+ masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
1710
+
1711
+ if not return_dict:
1712
+ output = tuple(mlm_logits)
1713
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1714
+
1715
+ return MaskedLMOutput(
1716
+ loss=masked_lm_loss,
1717
+ logits=mlm_logits,
1718
+ hidden_states=outputs.hidden_states,
1719
+ attentions=outputs.attentions,
1720
+ )
1721
+
1722
+
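The `-100` label convention used in the loss above is easy to miss: `CrossEntropyLoss` ignores every position whose label is `-100`, so only the masked tokens contribute. A minimal, self-contained sketch with toy tensors (not part of the model file):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)       # (batch, seq_len, vocab)
labels = torch.tensor([[-100, 3, -100, 7]])  # only positions 1 and 3 are scored

loss_fct = CrossEntropyLoss()                # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
print(loss)                                  # mean loss over the two labeled tokens
```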
1723
+ @add_start_docstrings(
1724
+ """
1725
+ BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the
1726
+ [CLS] token) for image-to-text matching.
1727
+ """,
1728
+ BRIDGETOWER_START_DOCSTRING,
1729
+ )
1730
+ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
1731
+ def __init__(self, config):
1732
+ super().__init__(config)
1733
+
1734
+ self.bridgetower = BridgeTowerModel(config)
1735
+
1736
+ self.itm_score = BridgeTowerITMHead(config.hidden_size * 2)
1737
+
1738
+ # Initialize weights and apply final processing
1739
+ self.post_init()
1740
+
1741
+ @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
1742
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1743
+ def forward(
1744
+ self,
1745
+ input_ids: Optional[torch.LongTensor] = None,
1746
+ attention_mask: Optional[torch.FloatTensor] = None,
1747
+ token_type_ids: Optional[torch.LongTensor] = None,
1748
+ pixel_values: Optional[torch.FloatTensor] = None,
1749
+ pixel_mask: Optional[torch.LongTensor] = None,
1750
+ head_mask: Optional[torch.FloatTensor] = None,
1751
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1752
+ image_embeds: Optional[torch.FloatTensor] = None,
1753
+ output_attentions: Optional[bool] = None,
1754
+ output_hidden_states: Optional[bool] = None,
1755
+ return_dict: Optional[bool] = None,
1756
+ labels: Optional[torch.LongTensor] = None,
1757
+ ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
1758
+ r"""
1759
+ labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
1760
+ Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
1761
+ The pairs with 0 will be skipped for calculation.
1762
+ Returns:
1763
+
1764
+ Examples:
1765
+
1766
+ ```python
1767
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
1768
+ >>> import requests
1769
+ >>> from PIL import Image
1770
+
1771
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1772
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1773
+ >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
1774
+
1775
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
1776
+ >>> model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
1777
+
1778
+ >>> # forward pass
1779
+ >>> scores = dict()
1780
+ >>> for text in texts:
1781
+ ... # prepare inputs
1782
+ ... encoding = processor(image, text, return_tensors="pt")
1783
+ ... outputs = model(**encoding)
1784
+ ... scores[text] = outputs.logits[0, 1].item()
1785
+ ```"""
1786
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1787
+
1788
+ outputs = self.bridgetower(
1789
+ input_ids,
1790
+ attention_mask=attention_mask,
1791
+ token_type_ids=token_type_ids,
1792
+ pixel_values=pixel_values,
1793
+ pixel_mask=pixel_mask,
1794
+ head_mask=head_mask,
1795
+ inputs_embeds=inputs_embeds,
1796
+ image_embeds=image_embeds,
1797
+ output_attentions=output_attentions,
1798
+ output_hidden_states=output_hidden_states,
1799
+ return_dict=return_dict,
1800
+ )
1801
+
1802
+ pooler_output = outputs.pooler_output if return_dict else outputs[2]
1803
+
1804
+ logits = self.itm_score(pooler_output)
1805
+
1806
+ itm_loss = None
1807
+ if labels is not None:
1808
+ loss_fct = CrossEntropyLoss()
1809
+
1810
+ labels = labels.to(logits.device)
1811
+ itm_loss = loss_fct(logits, labels)
1812
+
1813
+ if not return_dict:
1814
+ output = tuple(logits)
1815
+ return ((itm_loss,) + output) if itm_loss is not None else output
1816
+
1817
+ return SequenceClassifierOutput(
1818
+ loss=itm_loss,
1819
+ logits=logits,
1820
+ hidden_states=outputs.hidden_states,
1821
+ attentions=outputs.attentions,
1822
+ )
1823
+
1824
+
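The doctest above reads the raw match logit at `[0, 1]`; to report a probability instead, a softmax over the two ITM logits works. A sketch with hypothetical logit values:

```python
import torch

# Hypothetical ITM logits for one image-text pair: [no-match, match]
itm_logits = torch.tensor([[-1.2, 3.4]])
probs = torch.softmax(itm_logits, dim=-1)
print(f"P(match) = {probs[0, 1].item():.3f}")
```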
1825
+ class BridgeTowerContrastiveHead(nn.Module):
1826
+ def __init__(self, hidden_size, embed_size):
1827
+ super().__init__()
1828
+ self.fc = nn.Linear(hidden_size, embed_size)
1829
+
1830
+ def forward(self, x):
1831
+ x = self.fc(x)
1832
+ return x
1833
+
1834
+
1835
+ @add_start_docstrings(
1836
+ """
1837
+ BridgeTower Model with an image-text contrastive head on top computing image-text contrastive loss.
1838
+ """,
1839
+ BRIDGETOWER_START_DOCSTRING,
1840
+ )
1841
+ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
1842
+ def __init__(self, config):
1843
+ super().__init__(config)
1844
+
1845
+ self.bridgetower = BridgeTowerModel(config)
1846
+
1847
+ self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
1848
+ self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
1849
+ self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
1850
+
1851
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1852
+ # Initialize weights and apply final processing
1853
+ self.post_init()
1854
+
1855
+ @add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
1856
+ @replace_return_docstrings(output_type=BridgeTowerContrastiveOutput, config_class=_CONFIG_FOR_DOC)
1857
+ def forward(
1858
+ self,
1859
+ input_ids: Optional[torch.LongTensor] = None,
1860
+ attention_mask: Optional[torch.FloatTensor] = None,
1861
+ token_type_ids: Optional[torch.LongTensor] = None,
1862
+ pixel_values: Optional[torch.FloatTensor] = None,
1863
+ pixel_mask: Optional[torch.LongTensor] = None,
1864
+ head_mask: Optional[torch.FloatTensor] = None,
1865
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1866
+ image_embeds: Optional[torch.FloatTensor] = None,
1867
+ output_attentions: Optional[bool] = None,
1868
+ output_hidden_states: Optional[bool] = True,
1869
+ return_dict: Optional[bool] = None,
1870
+ return_loss: Optional[bool] = None,
1871
+ ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
1872
+ r"""
1873
+ return_loss (`bool`, *optional*):
1874
+ Whether or not to return the contrastive loss.
1875
+ Returns:
1876
+
1877
+ Examples:
1878
+
1879
+ ```python
1880
+ >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
1881
+ >>> import requests
1882
+ >>> from PIL import Image
1883
+ >>> import torch
1884
+
1885
+ >>> image_urls = [
1886
+ ... "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
1887
+ ... "http://images.cocodataset.org/val2017/000000039769.jpg",
1888
+ ... ]
1889
+ >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
1890
+ >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
1891
+
1892
+ >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
1893
+ >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
1894
+
1895
+ >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
1896
+ >>> loss = model(**inputs, return_loss=True).loss
1897
+
1898
+ >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
1899
+ >>> loss_swapped = model(**inputs, return_loss=True).loss
1900
+
1901
+ >>> print("Loss", round(loss.item(), 4))
1902
+ Loss 0.0019
1903
+
1904
+ >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
1905
+ Loss with swapped images 2.126
1906
+ ```"""
1907
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1908
+
1909
+ outputs = self.bridgetower(
1910
+ input_ids,
1911
+ attention_mask=attention_mask,
1912
+ token_type_ids=token_type_ids,
1913
+ pixel_values=pixel_values,
1914
+ pixel_mask=pixel_mask,
1915
+ head_mask=head_mask,
1916
+ inputs_embeds=inputs_embeds,
1917
+ image_embeds=image_embeds,
1918
+ output_attentions=output_attentions,
1919
+ output_hidden_states=True,
1920
+ return_dict=return_dict,
1921
+ )
1922
+
1923
+ pooler_output = outputs.pooler_output if return_dict else outputs[2]
1924
+ hidden_states_txt, hidden_states_img, hidden_states_cross_modal = (
1925
+ outputs.hidden_states if return_dict else outputs[3]
1926
+ )
1927
+
1928
+ text_embeds = hidden_states_txt[-1]
1929
+ image_embeds = hidden_states_img[-1]
1930
+
1931
+ image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
1932
+ image_token_type_embeddings = self.bridgetower.token_type_embeddings(
1933
+ torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
1934
+ ).expand_as(image_embeds_with_ln)
1935
+
1936
+ image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings
1937
+
1938
+ # normalized features
1939
+ text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
1940
+ image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
1941
+ device=text_embeds.device
1942
+ )
1943
+ cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
1944
+ device=text_embeds.device
1945
+ )
1946
+
1947
+ logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
1948
+
1949
+ logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
1950
+ logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1951
+ logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
1952
+ logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
1953
+
1954
+ itc_loss = None
1955
+
1956
+ if return_loss:
1957
+ labels = torch.arange(len(logits), device=logits.device)
1958
+ text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
1959
+ text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
1960
+ image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
1961
+ itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
1962
+
1963
+ if not return_dict:
1964
+ output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
1965
+ return ((itc_loss,) + output) if itc_loss is not None else output
1966
+
1967
+ return BridgeTowerContrastiveOutput(
1968
+ loss=itc_loss,
1969
+ logits=logits,
1970
+ text_embeds=text_embeds,
1971
+ image_embeds=image_embeds,
1972
+ cross_embeds=cross_embeds,
1973
+ hidden_states=outputs.hidden_states,
1974
+ attentions=outputs.attentions,
1975
+ )
1976
+
1977
+
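For reference, a standalone sketch of the three-pair contrastive objective computed in the forward above, using random unit-norm embeddings and an assumed logit-scale init in place of real model outputs:

```python
import torch
import torch.nn.functional as F

batch, dim = 4, 8
text = F.normalize(torch.randn(batch, dim), dim=-1)
image = F.normalize(torch.randn(batch, dim), dim=-1)
cross = F.normalize(torch.randn(batch, dim), dim=-1)

logit_scale = torch.tensor(2.6592).exp()  # assumed logit_scale_init_value
labels = torch.arange(batch)              # matching pairs sit on the diagonal

itc_loss = (
    F.cross_entropy(text @ image.t() * logit_scale, labels)
    + F.cross_entropy(text @ cross.t() * logit_scale, labels)
    + F.cross_entropy(image @ cross.t() * logit_scale, labels)
) / 3.0
print(itc_loss)
```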
1978
+ __all__ = [
1979
+ "BridgeTowerForContrastiveLearning",
1980
+ "BridgeTowerForImageAndTextRetrieval",
1981
+ "BridgeTowerForMaskedLM",
1982
+ "BridgeTowerModel",
1983
+ "BridgeTowerPreTrainedModel",
1984
+ ]
docs/transformers/src/transformers/models/bridgetower/processing_bridgetower.py ADDED
@@ -0,0 +1,114 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for BridgeTower.
17
+ """
18
+
19
+ from typing import List, Union
20
+
21
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
22
+ from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
23
+
24
+
25
+ class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
26
+ _defaults = {
27
+ "text_kwargs": {
28
+ "add_special_tokens": True,
29
+ "padding": False,
30
+ "stride": 0,
31
+ "return_overflowing_tokens": False,
32
+ "return_special_tokens_mask": False,
33
+ "return_offsets_mapping": False,
34
+ "return_length": False,
35
+ "verbose": True,
36
+ },
37
+ "images_kwargs": {
38
+ "do_normalize": True,
39
+ "do_center_crop": True,
40
+ },
41
+ }
42
+
43
+
44
+ class BridgeTowerProcessor(ProcessorMixin):
45
+ r"""
46
+ Constructs a BridgeTower processor which wraps a Roberta tokenizer and BridgeTower image processor into a single
47
+ processor.
48
+
49
+ [`BridgeTowerProcessor`] offers all the functionalities of [`BridgeTowerImageProcessor`] and
50
+ [`RobertaTokenizerFast`]. See the docstring of [`~BridgeTowerProcessor.__call__`] and
51
+ [`~BridgeTowerProcessor.decode`] for more information.
52
+
53
+ Args:
54
+ image_processor (`BridgeTowerImageProcessor`):
55
+ An instance of [`BridgeTowerImageProcessor`]. The image processor is a required input.
56
+ tokenizer (`RobertaTokenizerFast`):
57
+ An instance of [`RobertaTokenizerFast`]. The tokenizer is a required input.
58
+ """
59
+
60
+ attributes = ["image_processor", "tokenizer"]
61
+ image_processor_class = "BridgeTowerImageProcessor"
62
+ tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
63
+
64
+ def __init__(self, image_processor, tokenizer):
65
+ super().__init__(image_processor, tokenizer)
66
+
67
+ def __call__(
68
+ self,
69
+ images,
70
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
71
+ audio=None,
72
+ videos=None,
73
+ **kwargs: Unpack[BridgeTowerProcessorKwargs],
74
+ ) -> BatchEncoding:
75
+ """
76
+ This method uses [`BridgeTowerImageProcessor.__call__`] method to prepare image(s) for the model, and
77
+ [`RobertaTokenizerFast.__call__`] to prepare text for the model.
78
+
79
+ Please refer to the docstring of the above two methods for more information.
80
+ """
81
+ output_kwargs = self._merge_kwargs(
82
+ BridgeTowerProcessorKwargs,
83
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
84
+ **kwargs,
85
+ )
86
+ encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
87
+ # add pixel_values + pixel_mask
88
+ encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
89
+ encoding.update(encoding_image_processor)
90
+
91
+ return encoding
92
+
93
+ def batch_decode(self, *args, **kwargs):
94
+ """
95
+ This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
96
+ refer to the docstring of this method for more information.
97
+ """
98
+ return self.tokenizer.batch_decode(*args, **kwargs)
99
+
100
+ def decode(self, *args, **kwargs):
101
+ """
102
+ This method forwards all its arguments to RobertaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
103
+ to the docstring of this method for more information.
104
+ """
105
+ return self.tokenizer.decode(*args, **kwargs)
106
+
107
+ @property
108
+ def model_input_names(self):
109
+ tokenizer_input_names = self.tokenizer.model_input_names
110
+ image_processor_input_names = self.image_processor.model_input_names
111
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
112
+
113
+
114
+ __all__ = ["BridgeTowerProcessor"]
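A short usage sketch for the processor defined above; the checkpoint id comes from the doctests earlier in this diff, and the image and caption are illustrative:

```python
import requests
from PIL import Image
from transformers import BridgeTowerProcessor

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

encoding = processor(image, "two cats sleeping on a couch", return_tensors="pt")
print(sorted(encoding.keys()))  # tokenizer keys plus pixel_values / pixel_mask
```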
docs/transformers/src/transformers/models/bros/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_bros import *
22
+ from .modeling_bros import *
23
+ from .processing_bros import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
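Illustrative only: with the `_LazyModule` pattern above, importing the package is cheap, and the heavy submodule imports are deferred until an attribute is first accessed:

```python
from transformers.models import bros  # fast: nothing heavy imported yet

config_cls = bros.BrosConfig  # first access triggers the real import
print(config_cls.model_type)  # "bros"
```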
docs/transformers/src/transformers/models/bros/configuration_bros.py ADDED
@@ -0,0 +1,138 @@
1
+ # coding=utf-8
2
+ # Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Bros model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class BrosConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`BrosModel`] or a [`TFBrosModel`]. It is used to
27
+ instantiate a Bros model according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of the Bros
29
+ [jinho8345/bros-base-uncased](https://huggingface.co/jinho8345/bros-base-uncased) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 30522):
36
+ Vocabulary size of the Bros model. Defines the number of different tokens that can be represented by the
37
+ `inputs_ids` passed when calling [`BrosModel`] or [`TFBrosModel`].
38
+ hidden_size (`int`, *optional*, defaults to 768):
39
+ Dimensionality of the encoder layers and the pooler layer.
40
+ num_hidden_layers (`int`, *optional*, defaults to 12):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 12):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ intermediate_size (`int`, *optional*, defaults to 3072):
45
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
46
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
49
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
52
+ The dropout ratio for the attention probabilities.
53
+ max_position_embeddings (`int`, *optional*, defaults to 512):
54
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
55
+ just in case (e.g., 512 or 1024 or 2048).
56
+ type_vocab_size (`int`, *optional*, defaults to 2):
57
+ The vocabulary size of the `token_type_ids` passed when calling [`BrosModel`] or [`TFBrosModel`].
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
61
+ The epsilon used by the layer normalization layers.
62
+ pad_token_id (`int`, *optional*, defaults to 0):
63
+ The index of the padding token in the token vocabulary.
64
+ dim_bbox (`int`, *optional*, defaults to 8):
65
+ The dimension of the bounding box coordinates, i.e. the (x, y) pairs of the four corners: (x0, y0, x1, y0, x1, y1, x0, y1).
66
+ bbox_scale (`float`, *optional*, defaults to 100.0):
67
+ The scale factor of the bounding box coordinates.
68
+ n_relations (`int`, *optional*, defaults to 1):
69
+ The number of relations for the SpadeEE (entity extraction) and SpadeEL (entity linking) heads.
70
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
71
+ The dropout ratio for the classifier head.
72
+
73
+
74
+ Examples:
75
+
76
+ ```python
77
+ >>> from transformers import BrosConfig, BrosModel
78
+
79
+ >>> # Initializing a BROS jinho8345/bros-base-uncased style configuration
80
+ >>> configuration = BrosConfig()
81
+
82
+ >>> # Initializing a model from the jinho8345/bros-base-uncased style configuration
83
+ >>> model = BrosModel(configuration)
84
+
85
+ >>> # Accessing the model configuration
86
+ >>> configuration = model.config
87
+ ```"""
88
+
89
+ model_type = "bros"
90
+
91
+ def __init__(
92
+ self,
93
+ vocab_size=30522,
94
+ hidden_size=768,
95
+ num_hidden_layers=12,
96
+ num_attention_heads=12,
97
+ intermediate_size=3072,
98
+ hidden_act="gelu",
99
+ hidden_dropout_prob=0.1,
100
+ attention_probs_dropout_prob=0.1,
101
+ max_position_embeddings=512,
102
+ type_vocab_size=2,
103
+ initializer_range=0.02,
104
+ layer_norm_eps=1e-12,
105
+ pad_token_id=0,
106
+ dim_bbox=8,
107
+ bbox_scale=100.0,
108
+ n_relations=1,
109
+ classifier_dropout_prob=0.1,
110
+ **kwargs,
111
+ ):
112
+ super().__init__(
113
+ vocab_size=vocab_size,
114
+ hidden_size=hidden_size,
115
+ num_hidden_layers=num_hidden_layers,
116
+ num_attention_heads=num_attention_heads,
117
+ intermediate_size=intermediate_size,
118
+ hidden_act=hidden_act,
119
+ hidden_dropout_prob=hidden_dropout_prob,
120
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
121
+ max_position_embeddings=max_position_embeddings,
122
+ type_vocab_size=type_vocab_size,
123
+ initializer_range=initializer_range,
124
+ layer_norm_eps=layer_norm_eps,
125
+ pad_token_id=pad_token_id,
126
+ **kwargs,
127
+ )
128
+
129
+ self.dim_bbox = dim_bbox
130
+ self.bbox_scale = bbox_scale
131
+ self.n_relations = n_relations
132
+ self.dim_bbox_sinusoid_emb_2d = self.hidden_size // 4
133
+ self.dim_bbox_sinusoid_emb_1d = self.dim_bbox_sinusoid_emb_2d // self.dim_bbox
134
+ self.dim_bbox_projection = self.hidden_size // self.num_attention_heads
135
+ self.classifier_dropout_prob = classifier_dropout_prob
136
+
137
+
138
+ __all__ = ["BrosConfig"]
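A quick sanity-check sketch of the derived bounding-box embedding sizes computed in `BrosConfig.__init__` for the default hyperparameters (`hidden_size=768`, `num_attention_heads=12`, `dim_bbox=8`):

```python
from transformers import BrosConfig

config = BrosConfig()
assert config.dim_bbox_sinusoid_emb_2d == 768 // 4  # hidden_size // 4 -> 192
assert config.dim_bbox_sinusoid_emb_1d == 192 // 8  # ... // dim_bbox -> 24
assert config.dim_bbox_projection == 768 // 12      # hidden_size // num_heads -> 64
```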
docs/transformers/src/transformers/models/bros/convert_bros_to_pytorch.py ADDED
@@ -0,0 +1,145 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Bros checkpoints."""
16
+
17
+ import argparse
18
+
19
+ import bros # original repo
20
+ import torch
21
+
22
+ from transformers import BrosConfig, BrosModel, BrosProcessor
23
+ from transformers.utils import logging
24
+
25
+
26
+ logging.set_verbosity_info()
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ def get_configs(model_name):
31
+ bros_config = BrosConfig.from_pretrained(model_name)
32
+ return bros_config
33
+
34
+
35
+ def remove_ignore_keys_(state_dict):
36
+ ignore_keys = [
37
+ "embeddings.bbox_sinusoid_emb.inv_freq",
38
+ ]
39
+ for k in ignore_keys:
40
+ state_dict.pop(k, None)
41
+
42
+
43
+ def rename_key(name):
44
+ if name == "embeddings.bbox_projection.weight":
45
+ name = "bbox_embeddings.bbox_projection.weight"
46
+
47
+ if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
48
+ name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
49
+
50
+ if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
51
+ name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
52
+
53
+ return name
54
+
55
+
56
+ def convert_state_dict(orig_state_dict, model):
57
+ # rename keys
58
+ for key in orig_state_dict.copy().keys():
59
+ val = orig_state_dict.pop(key)
60
+ orig_state_dict[rename_key(key)] = val
61
+
62
+ # remove ignore keys
63
+ remove_ignore_keys_(orig_state_dict)
64
+
65
+ return orig_state_dict
66
+
67
+
68
+ def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
69
+ # load original model
70
+ original_model = bros.BrosModel.from_pretrained(model_name).eval()
71
+
72
+ # load HuggingFace Model
73
+ bros_config = get_configs(model_name)
74
+ model = BrosModel.from_pretrained(model_name, config=bros_config)
75
+ model.eval()
76
+
77
+ state_dict = original_model.state_dict()
78
+ new_state_dict = convert_state_dict(state_dict, model)
79
+ model.load_state_dict(new_state_dict)
80
+
81
+ # verify results
82
+
83
+ # the original BROS model requires 4 points (8 float values) per bbox, so prepare a bbox tensor of shape [batch_size, seq_len, 8]
84
+ bbox = torch.tensor(
85
+ [
86
+ [
87
+ [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
88
+ [0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
89
+ [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
90
+ [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
91
+ [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
92
+ [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
93
+ [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
94
+ ]
95
+ ]
96
+ )
97
+
98
+ processor = BrosProcessor.from_pretrained(model_name)
99
+
100
+ encoding = processor("His name is Rocco.", return_tensors="pt")
101
+ encoding["bbox"] = bbox
102
+
103
+ original_hidden_states = original_model(**encoding).last_hidden_state
104
+ # pixel_values = processor(image, return_tensors="pt").pixel_values
105
+
106
+ last_hidden_states = model(**encoding).last_hidden_state
107
+
108
+ assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
109
+
110
+ if pytorch_dump_folder_path is not None:
111
+ print(f"Saving model and processor to {pytorch_dump_folder_path}")
112
+ model.save_pretrained(pytorch_dump_folder_path)
113
+ processor.save_pretrained(pytorch_dump_folder_path)
114
+
115
+ if push_to_hub:
116
+ model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
117
+ processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
118
+
119
+
120
+ if __name__ == "__main__":
121
+ parser = argparse.ArgumentParser()
122
+
123
+ # Required parameters
124
+ parser.add_argument(
125
+ "--model_name",
126
+ default="jinho8345/bros-base-uncased",
127
+ required=False,
128
+ type=str,
129
+ help="Name of the original model you'd like to convert.",
130
+ )
131
+ parser.add_argument(
132
+ "--pytorch_dump_folder_path",
133
+ default=None,
134
+ required=False,
135
+ type=str,
136
+ help="Path to the output PyTorch model directory.",
137
+ )
138
+ parser.add_argument(
139
+ "--push_to_hub",
140
+ action="store_true",
141
+ help="Whether or not to push the converted model and processor to the 🤗 hub.",
142
+ )
143
+
144
+ args = parser.parse_args()
145
+ convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
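For readers who prefer driving the converter from Python rather than the CLI, a hedged sketch; it requires the original `bros` package to be installed, and the output directory is hypothetical:

```python
from transformers.models.bros.convert_bros_to_pytorch import convert_bros_checkpoint

convert_bros_checkpoint(
    model_name="jinho8345/bros-base-uncased",
    pytorch_dump_folder_path="./bros-base-uncased",  # hypothetical output directory
    push_to_hub=False,
)
```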
docs/transformers/src/transformers/models/bros/modeling_bros.py ADDED
@@ -0,0 +1,1323 @@
1
+ # coding=utf-8
2
+ # Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Bros model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ TokenClassifierOutput,
31
+ )
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
34
+ from ...utils import (
35
+ ModelOutput,
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from .configuration_bros import BrosConfig
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CHECKPOINT_FOR_DOC = "jinho8345/bros-base-uncased"
47
+ _CONFIG_FOR_DOC = "BrosConfig"
48
+
49
+
50
+ BROS_START_DOCSTRING = r"""
51
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
52
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
53
+ and behavior.
54
+
55
+ Parameters:
56
+ config ([`BrosConfig`]): Model configuration class with all the parameters of the model.
57
+ Initializing with a config file does not load the weights associated with the model, only the
58
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
59
+ """
60
+
61
+ BROS_INPUTS_DOCSTRING = r"""
62
+ Args:
63
+ input_ids (`torch.LongTensor` of shape `({0})`):
64
+ Indices of input sequence tokens in the vocabulary.
65
+
66
+ Indices can be obtained using [`BrosProcessor`]. See [`PreTrainedTokenizer.encode`] and
67
+ [`PreTrainedTokenizer.__call__`] for details.
68
+
69
+ [What are input IDs?](../glossary#input-ids)
70
+
71
+ bbox (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
72
+ Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
73
+ (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
74
+ bounding box.
75
+
76
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
77
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
78
+
79
+ - 1 for tokens that are **not masked**,
80
+ - 0 for tokens that are **masked**.
81
+
82
+ [What are attention masks?](../glossary#attention-mask)
83
+
84
+ bbox_first_token_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
85
+ Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
86
+
87
+ - 1 for tokens that are **not masked**,
88
+ - 0 for tokens that are **masked**.
89
+
90
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
91
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
92
+ 1]`:
93
+
94
+ - 0 corresponds to a *sentence A* token,
95
+ - 1 corresponds to a *sentence B* token.
96
+
97
+ [What are token type IDs?](../glossary#token-type-ids)
98
+
99
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
100
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
101
+ config.max_position_embeddings - 1]`.
102
+
103
+ [What are position IDs?](../glossary#position-ids)
104
+
105
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
106
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
107
+
108
+ - 1 indicates the head is **not masked**,
109
+ - 0 indicates the head is **masked**.
110
+
111
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
112
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
113
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
114
+ model's internal embedding lookup matrix.
115
+
116
+ output_attentions (`bool`, *optional*):
117
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
118
+ tensors for more detail.
119
+
120
+ output_hidden_states (`bool`, *optional*):
121
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
122
+ more detail.
123
+
124
+ return_dict (`bool`, *optional*):
125
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
126
+ """
127
+
128
+
129
+ @dataclass
130
+ class BrosSpadeOutput(ModelOutput):
131
+ """
132
+ Base class for outputs of token classification models.
133
+
134
+ Args:
135
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
136
+ Classification loss.
137
+ initial_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
138
+ Classification scores for entity initial tokens (before SoftMax).
139
+ subsequent_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length+1)`):
140
+ Classification scores for entity sequence tokens (before SoftMax).
141
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
142
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
143
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
144
+
145
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
146
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
147
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
148
+ sequence_length)`.
149
+
150
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
151
+ heads.
152
+ """
153
+
154
+ loss: Optional[torch.FloatTensor] = None
155
+ initial_token_logits: Optional[torch.FloatTensor] = None
156
+ subsequent_token_logits: Optional[torch.FloatTensor] = None
157
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
158
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
159
+
160
+
161
+ class BrosPositionalEmbedding1D(nn.Module):
162
+ # Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15
163
+
164
+ def __init__(self, config):
165
+ super(BrosPositionalEmbedding1D, self).__init__()
166
+
167
+ self.dim_bbox_sinusoid_emb_1d = config.dim_bbox_sinusoid_emb_1d
168
+
169
+ inv_freq = 1 / (
170
+ 10000 ** (torch.arange(0.0, self.dim_bbox_sinusoid_emb_1d, 2.0) / self.dim_bbox_sinusoid_emb_1d)
171
+ )
172
+ self.register_buffer("inv_freq", inv_freq)
173
+
174
+ def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
175
+ seq_size = pos_seq.size()
176
+ b1, b2, b3 = seq_size
177
+ sinusoid_inp = pos_seq.view(b1, b2, b3, 1) * self.inv_freq.view(1, 1, 1, self.dim_bbox_sinusoid_emb_1d // 2)
178
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
179
+ return pos_emb
180
+
181
+
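A minimal sketch of the sinusoidal scheme implemented above: each scalar coordinate is expanded into concatenated sin/cos features at geometrically spaced frequencies (toy sizes):

```python
import torch

dim = 24  # plays the role of dim_bbox_sinusoid_emb_1d
inv_freq = 1 / (10000 ** (torch.arange(0.0, dim, 2.0) / dim))  # (dim // 2,)
pos = torch.tensor([[[0.25, 0.5]]])                # (1, 1, 2) toy positions
sinusoid = pos.unsqueeze(-1) * inv_freq            # broadcast to (1, 1, 2, dim // 2)
emb = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)
print(emb.shape)                                   # torch.Size([1, 1, 2, 24])
```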
182
+ class BrosPositionalEmbedding2D(nn.Module):
183
+ def __init__(self, config):
184
+ super(BrosPositionalEmbedding2D, self).__init__()
185
+
186
+ self.dim_bbox = config.dim_bbox
187
+ self.x_pos_emb = BrosPositionalEmbedding1D(config)
188
+ self.y_pos_emb = BrosPositionalEmbedding1D(config)
189
+
190
+ def forward(self, bbox: torch.Tensor) -> torch.Tensor:
191
+ stack = []
192
+ for i in range(self.dim_bbox):
193
+ if i % 2 == 0:
194
+ stack.append(self.x_pos_emb(bbox[..., i]))
195
+ else:
196
+ stack.append(self.y_pos_emb(bbox[..., i]))
197
+ bbox_pos_emb = torch.cat(stack, dim=-1)
198
+ return bbox_pos_emb
199
+
200
+
201
+ class BrosBboxEmbeddings(nn.Module):
202
+ def __init__(self, config):
203
+ super(BrosBboxEmbeddings, self).__init__()
204
+ self.bbox_sinusoid_emb = BrosPositionalEmbedding2D(config)
205
+ self.bbox_projection = nn.Linear(config.dim_bbox_sinusoid_emb_2d, config.dim_bbox_projection, bias=False)
206
+
207
+ def forward(self, bbox: torch.Tensor):
208
+ bbox_t = bbox.transpose(0, 1)
209
+ bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]
210
+ bbox_pos_emb = self.bbox_sinusoid_emb(bbox_pos)
211
+ bbox_pos_emb = self.bbox_projection(bbox_pos_emb)
212
+
213
+ return bbox_pos_emb
214
+
215
+
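A shape sketch of the pairwise relative-position construction in `BrosBboxEmbeddings.forward` above (toy sizes): every token pair (i, j) gets the difference of their bounding boxes before the sinusoidal encoding is applied.

```python
import torch

batch, seq_len, dim_bbox = 2, 5, 8
bbox = torch.rand(batch, seq_len, dim_bbox)

bbox_t = bbox.transpose(0, 1)                        # (seq_len, batch, dim_bbox)
rel = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]  # (seq_len, seq_len, batch, dim_bbox)
print(rel.shape)                                     # torch.Size([5, 5, 2, 8])
```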
216
+ class BrosTextEmbeddings(nn.Module):
217
+ """Construct the embeddings from word, position and token_type embeddings."""
218
+
219
+ def __init__(self, config):
220
+ super().__init__()
221
+
222
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
223
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
224
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
225
+
226
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
227
+ # any TensorFlow checkpoint file
228
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
229
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
230
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
231
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
232
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
233
+ self.register_buffer(
234
+ "token_type_ids",
235
+ torch.zeros(
236
+ self.position_ids.size(),
237
+ dtype=torch.long,
238
+ device=self.position_ids.device,
239
+ ),
240
+ persistent=False,
241
+ )
242
+
243
+ def forward(
244
+ self,
245
+ input_ids: Optional[torch.Tensor] = None,
246
+ token_type_ids: Optional[torch.Tensor] = None,
247
+ position_ids: Optional[torch.Tensor] = None,
248
+ inputs_embeds: Optional[torch.Tensor] = None,
249
+ past_key_values_length: int = 0,
250
+ ) -> torch.Tensor:
251
+ if input_ids is not None:
252
+ input_shape = input_ids.size()
253
+ else:
254
+ input_shape = inputs_embeds.size()[:-1]
255
+
256
+ seq_length = input_shape[1]
257
+
258
+ if position_ids is None:
259
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
260
+
261
+ if token_type_ids is None:
262
+ if hasattr(self, "token_type_ids"):
263
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
264
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
265
+ token_type_ids = buffered_token_type_ids_expanded
266
+ else:
267
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
268
+
269
+ if inputs_embeds is None:
270
+ inputs_embeds = self.word_embeddings(input_ids)
271
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
272
+
273
+ embeddings = inputs_embeds + token_type_embeddings
274
+ if self.position_embedding_type == "absolute":
275
+ position_embeddings = self.position_embeddings(position_ids)
276
+ embeddings += position_embeddings
277
+ embeddings = self.LayerNorm(embeddings)
278
+ embeddings = self.dropout(embeddings)
279
+ return embeddings
280
+
281
+
282
+ class BrosSelfAttention(nn.Module):
283
+ def __init__(self, config):
284
+ super().__init__()
285
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
286
+ raise ValueError(
287
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
288
+ f"heads ({config.num_attention_heads})"
289
+ )
290
+
291
+ self.num_attention_heads = config.num_attention_heads
292
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
293
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
294
+
295
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
296
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
297
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
298
+
299
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
300
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
301
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
302
+ self.max_position_embeddings = config.max_position_embeddings
303
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
304
+
305
+ self.is_decoder = config.is_decoder
306
+
307
+ def transpose_for_scores(self, x: torch.Tensor):
308
+ new_x_shape = x.size()[:-1] + (
309
+ self.num_attention_heads,
310
+ self.attention_head_size,
311
+ )
312
+ x = x.view(*new_x_shape)
313
+ return x.permute(0, 2, 1, 3)
314
+
315
+ def forward(
316
+ self,
317
+ hidden_states: torch.Tensor,
318
+ bbox_pos_emb: torch.Tensor,
319
+ attention_mask: Optional[torch.Tensor] = None,
320
+ head_mask: Optional[torch.Tensor] = None,
321
+ encoder_hidden_states: Optional[torch.Tensor] = None,
322
+ encoder_attention_mask: Optional[torch.Tensor] = None,
323
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
324
+ output_attentions: Optional[torch.Tensor] = False,
325
+ ) -> Tuple[torch.Tensor]:
326
+ mixed_query_layer = self.query(hidden_states)
327
+
328
+ # If this is instantiated as a cross-attention module, the keys
329
+ # and values come from an encoder; the attention mask needs to be
330
+ # such that the encoder's padding tokens are not attended to.
331
+ is_cross_attention = encoder_hidden_states is not None
332
+
333
+ if is_cross_attention and past_key_value is not None:
334
+ # reuse k,v, cross_attentions
335
+ key_layer = past_key_value[0]
336
+ value_layer = past_key_value[1]
337
+ attention_mask = encoder_attention_mask
338
+ elif is_cross_attention:
339
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
340
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
341
+ attention_mask = encoder_attention_mask
342
+ elif past_key_value is not None:
343
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
344
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
345
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
346
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
347
+ else:
348
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
349
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
350
+
351
+ query_layer = self.transpose_for_scores(mixed_query_layer)
352
+
353
+ if self.is_decoder:
354
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
355
+ # Further calls to cross_attention layer can then reuse all cross-attention
356
+ # key/value_states (first "if" case)
357
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
358
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
359
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
360
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
361
+ past_key_value = (key_layer, value_layer)
362
+
363
+ # Take the dot product between "query" and "key" to get the raw attention scores.
364
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
365
+
366
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
367
+ seq_length = hidden_states.size()[1]
368
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
369
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
370
+ distance = position_ids_l - position_ids_r
371
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
372
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
373
+
374
+ if self.position_embedding_type == "relative_key":
375
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
376
+ attention_scores = attention_scores + relative_position_scores
377
+ elif self.position_embedding_type == "relative_key_query":
378
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
379
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
380
+
381
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
382
+
383
+ # bbox positional encoding
384
+ batch_size, n_head, seq_length, d_head = query_layer.shape
385
+ bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head)
386
+ bbox_pos_emb = bbox_pos_emb.permute([2, 0, 1, 3])
387
+ bbox_pos_scores = torch.einsum("bnid,bijd->bnij", (query_layer, bbox_pos_emb))
388
+
389
+ attention_scores = attention_scores + bbox_pos_scores
390
+
391
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
392
+ if attention_mask is not None:
393
+ # Apply the attention mask is (precomputed for all layers in BrosModel forward() function)
394
+ attention_scores = attention_scores + attention_mask
395
+
396
+ # Normalize the attention scores to probabilities.
397
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
398
+
399
+ # This is actually dropping out entire tokens to attend to, which might
400
+ # seem a bit unusual, but is taken from the original Transformer paper.
401
+ attention_probs = self.dropout(attention_probs)
402
+
403
+ # Mask heads if we want to
404
+ if head_mask is not None:
405
+ attention_probs = attention_probs * head_mask
406
+
407
+ context_layer = torch.matmul(attention_probs, value_layer)
408
+
409
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
410
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
411
+ context_layer = context_layer.view(*new_context_layer_shape)
412
+
413
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
414
+
415
+ if self.is_decoder:
416
+ outputs = outputs + (past_key_value,)
417
+ return outputs
418
+
419
+
420
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Bros
421
+ class BrosSelfOutput(nn.Module):
422
+ def __init__(self, config):
423
+ super().__init__()
424
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
425
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
426
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
427
+
428
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
429
+ hidden_states = self.dense(hidden_states)
430
+ hidden_states = self.dropout(hidden_states)
431
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
432
+ return hidden_states
433
+
434
+
435
+ class BrosAttention(nn.Module):
436
+ def __init__(self, config):
437
+ super().__init__()
438
+ self.self = BrosSelfAttention(config)
439
+ self.output = BrosSelfOutput(config)
440
+ self.pruned_heads = set()
441
+
442
+ def prune_heads(self, heads):
443
+ if len(heads) == 0:
444
+ return
445
+ heads, index = find_pruneable_heads_and_indices(
446
+ heads,
447
+ self.self.num_attention_heads,
448
+ self.self.attention_head_size,
449
+ self.pruned_heads,
450
+ )
451
+
452
+ # Prune linear layers
453
+ self.self.query = prune_linear_layer(self.self.query, index)
454
+ self.self.key = prune_linear_layer(self.self.key, index)
455
+ self.self.value = prune_linear_layer(self.self.value, index)
456
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
457
+
458
+ # Update hyper params and store pruned heads
459
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
460
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
461
+ self.pruned_heads = self.pruned_heads.union(heads)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ bbox_pos_emb: torch.Tensor,
467
+ attention_mask: Optional[torch.Tensor] = None,
468
+ head_mask: Optional[torch.Tensor] = None,
469
+ encoder_hidden_states: Optional[torch.Tensor] = None,
470
+ encoder_attention_mask: Optional[torch.Tensor] = None,
471
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
472
+ output_attentions: Optional[bool] = False,
473
+ ) -> Tuple[torch.Tensor]:
474
+ self_outputs = self.self(
475
+ hidden_states=hidden_states,
476
+ bbox_pos_emb=bbox_pos_emb,
477
+ attention_mask=attention_mask,
478
+ head_mask=head_mask,
479
+ encoder_hidden_states=encoder_hidden_states,
480
+ encoder_attention_mask=encoder_attention_mask,
481
+ past_key_value=past_key_value,
482
+ output_attentions=output_attentions,
483
+ )
484
+ attention_output = self.output(self_outputs[0], hidden_states)
485
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
486
+ return outputs
487
+
488
+
489
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Bros
490
+ class BrosIntermediate(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
494
+ if isinstance(config.hidden_act, str):
495
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
496
+ else:
497
+ self.intermediate_act_fn = config.hidden_act
498
+
499
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
500
+ hidden_states = self.dense(hidden_states)
501
+ hidden_states = self.intermediate_act_fn(hidden_states)
502
+ return hidden_states
503
+
504
+
505
+ class BrosOutput(nn.Module):
506
+ def __init__(self, config):
507
+ super().__init__()
508
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
511
+
512
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
513
+ hidden_states = self.dense(hidden_states)
514
+ hidden_states = self.dropout(hidden_states)
515
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
516
+ return hidden_states
517
+
518
+
519
+ class BrosLayer(nn.Module):
520
+ def __init__(self, config):
521
+ super().__init__()
522
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
523
+ self.seq_len_dim = 1
524
+ self.attention = BrosAttention(config)
525
+ self.is_decoder = config.is_decoder
526
+ self.add_cross_attention = config.add_cross_attention
527
+ if self.add_cross_attention:
528
+ if not self.is_decoder:
529
+ raise Exception(f"{self} should be used as a decoder model if cross attention is added")
530
+ self.crossattention = BrosAttention(config)
531
+ self.intermediate = BrosIntermediate(config)
532
+ self.output = BrosOutput(config)
533
+
534
+ def forward(
535
+ self,
536
+ hidden_states: torch.Tensor,
537
+ bbox_pos_emb: torch.Tensor,
538
+ attention_mask: Optional[torch.FloatTensor] = None,
539
+ head_mask: Optional[torch.FloatTensor] = None,
540
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
541
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
542
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
543
+ output_attentions: Optional[bool] = False,
544
+ ) -> Tuple[torch.Tensor]:
545
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
546
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
547
+ self_attention_outputs = self.attention(
548
+ hidden_states,
549
+ bbox_pos_emb=bbox_pos_emb,
550
+ attention_mask=attention_mask,
551
+ head_mask=head_mask,
552
+ output_attentions=output_attentions,
553
+ past_key_value=self_attn_past_key_value,
554
+ )
555
+ attention_output = self_attention_outputs[0]
556
+
557
+ # if decoder, the last output is tuple of self-attn cache
558
+ if self.is_decoder:
559
+ outputs = self_attention_outputs[1:-1]
560
+ present_key_value = self_attention_outputs[-1]
561
+ else:
562
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
563
+
564
+ cross_attn_present_key_value = None
565
+ if self.is_decoder and encoder_hidden_states is not None:
566
+ if not hasattr(self, "crossattention"):
567
+ raise Exception(
568
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
569
+ )
570
+
571
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
572
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
573
+ # Pass arguments by keyword: BrosAttention.forward takes `bbox_pos_emb` as its
+ # second positional argument, so a positional call here would misroute the
+ # attention mask into it. Reusing the self-attention bbox position embeddings
+ # for cross-attention is an assumption.
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ bbox_pos_emb=bbox_pos_emb,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
582
+ attention_output = cross_attention_outputs[0]
583
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
584
+
585
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
586
+ cross_attn_present_key_value = cross_attention_outputs[-1]
587
+ present_key_value = present_key_value + cross_attn_present_key_value
588
+
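+ # apply_chunking_to_forward splits attention_output into chunks of chunk_size_feed_forward along dim 1 (seq_len_dim) and runs feed_forward_chunk on each chunk; a chunk size of 0 applies the feed-forward in one pass.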
589
+ layer_output = apply_chunking_to_forward(
590
+ self.feed_forward_chunk,
591
+ self.chunk_size_feed_forward,
592
+ self.seq_len_dim,
593
+ attention_output,
594
+ )
595
+ outputs = (layer_output,) + outputs
596
+
597
+ # if decoder, return the attn key/values as the last output
598
+ if self.is_decoder:
599
+ outputs = outputs + (present_key_value,)
600
+
601
+ return outputs
602
+
603
+ def feed_forward_chunk(self, attention_output):
604
+ intermediate_output = self.intermediate(attention_output)
605
+ layer_output = self.output(intermediate_output, attention_output)
606
+ return layer_output
607
+
608
+
609
+ class BrosEncoder(nn.Module):
610
+ def __init__(self, config):
611
+ super().__init__()
612
+ self.config = config
613
+ self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
614
+
615
+ def forward(
616
+ self,
617
+ hidden_states: torch.Tensor,
618
+ bbox_pos_emb: torch.Tensor,
619
+ attention_mask: Optional[torch.FloatTensor] = None,
620
+ head_mask: Optional[torch.FloatTensor] = None,
621
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
622
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
623
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
624
+ use_cache: Optional[bool] = None,
625
+ output_attentions: Optional[bool] = False,
626
+ output_hidden_states: Optional[bool] = False,
627
+ return_dict: Optional[bool] = True,
628
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
629
+ all_hidden_states = () if output_hidden_states else None
630
+ all_self_attentions = () if output_attentions else None
631
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
632
+
633
+ next_decoder_cache = () if use_cache else None
634
+ for i, layer_module in enumerate(self.layer):
635
+ if output_hidden_states:
636
+ all_hidden_states = all_hidden_states + (hidden_states,)
637
+
638
+ layer_head_mask = head_mask[i] if head_mask is not None else None
639
+ past_key_value = past_key_values[i] if past_key_values is not None else None
640
+
641
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
642
+ if use_cache:
643
+ logger.warning(
644
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
645
+ "`use_cache=False`..."
646
+ )
647
+ use_cache = False
648
+ layer_outputs = self._gradient_checkpointing_func(
649
+ layer_module.__call__,
650
+ hidden_states,
651
+ bbox_pos_emb,
652
+ attention_mask,
653
+ layer_head_mask,
654
+ encoder_hidden_states,
655
+ encoder_attention_mask,
656
+ output_attentions,
657
+ )
658
+ else:
659
+ layer_outputs = layer_module(
660
+ hidden_states=hidden_states,
661
+ bbox_pos_emb=bbox_pos_emb,
662
+ attention_mask=attention_mask,
663
+ head_mask=layer_head_mask,
664
+ encoder_hidden_states=encoder_hidden_states,
665
+ encoder_attention_mask=encoder_attention_mask,
666
+ past_key_value=past_key_value,
667
+ output_attentions=output_attentions,
668
+ )
669
+
670
+ hidden_states = layer_outputs[0]
671
+ if use_cache:
672
+ next_decoder_cache += (layer_outputs[-1],)
673
+ if output_attentions:
674
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
675
+ if self.config.add_cross_attention:
676
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
677
+
678
+ if output_hidden_states:
679
+ all_hidden_states = all_hidden_states + (hidden_states,)
680
+
681
+ if not return_dict:
682
+ return tuple(
683
+ v
684
+ for v in [
685
+ hidden_states,
686
+ next_decoder_cache,
687
+ all_hidden_states,
688
+ all_self_attentions,
689
+ all_cross_attentions,
690
+ ]
691
+ if v is not None
692
+ )
693
+ return BaseModelOutputWithPastAndCrossAttentions(
694
+ last_hidden_state=hidden_states,
695
+ past_key_values=next_decoder_cache,
696
+ hidden_states=all_hidden_states,
697
+ attentions=all_self_attentions,
698
+ cross_attentions=all_cross_attentions,
699
+ )
700
+
701
+
702
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Bros
703
+ class BrosPooler(nn.Module):
704
+ def __init__(self, config):
705
+ super().__init__()
706
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
707
+ self.activation = nn.Tanh()
708
+
709
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
710
+ # We "pool" the model by simply taking the hidden state corresponding
711
+ # to the first token.
712
+ first_token_tensor = hidden_states[:, 0]
713
+ pooled_output = self.dense(first_token_tensor)
714
+ pooled_output = self.activation(pooled_output)
715
+ return pooled_output
716
+
717
+
718
+ class BrosRelationExtractor(nn.Module):
719
+ def __init__(self, config):
720
+ super().__init__()
721
+ self.n_relations = config.n_relations
722
+ self.backbone_hidden_size = config.hidden_size
723
+ self.head_hidden_size = config.hidden_size
724
+ self.classifier_dropout_prob = config.classifier_dropout_prob
725
+
726
+ self.drop = nn.Dropout(self.classifier_dropout_prob)
727
+ self.query = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
728
+
729
+ self.key = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
730
+
731
+ self.dummy_node = nn.Parameter(torch.zeros(1, self.backbone_hidden_size))
732
+
733
+ def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
734
+ query_layer = self.query(self.drop(query_layer))
735
+
736
+ dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, key_layer.size(1), 1)
737
+ key_layer = torch.cat([key_layer, dummy_vec], axis=0)
738
+ key_layer = self.key(self.drop(key_layer))
739
+
740
+ query_layer = query_layer.view(
741
+ query_layer.size(0), query_layer.size(1), self.n_relations, self.head_hidden_size
742
+ )
743
+ key_layer = key_layer.view(key_layer.size(0), key_layer.size(1), self.n_relations, self.head_hidden_size)
744
+
745
+ relation_score = torch.matmul(
746
+ query_layer.permute(2, 1, 0, 3), key_layer.permute(2, 1, 3, 0)
747
+ ) # equivalent to torch.einsum("ibnd,jbnd->nbij", (query_layer, key_layer))
748
+
749
+ return relation_score
750
+
751
+
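The permute/matmul in `BrosRelationExtractor.forward` is the einsum quoted in its inline comment; a minimal sanity-check sketch, with toy shapes assumed for illustration:

```python
import torch

seq_len, batch, n_relations, head_dim = 5, 2, 3, 4
q = torch.randn(seq_len, batch, n_relations, head_dim)
k = torch.randn(seq_len + 1, batch, n_relations, head_dim)  # +1 for the appended dummy node

via_matmul = torch.matmul(q.permute(2, 1, 0, 3), k.permute(2, 1, 3, 0))
via_einsum = torch.einsum("ibnd,jbnd->nbij", q, k)
assert torch.allclose(via_matmul, via_einsum)  # shape: (n_relations, batch, seq_len, seq_len + 1)
```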
752
+ class BrosPreTrainedModel(PreTrainedModel):
753
+ """
754
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
755
+ models.
756
+ """
757
+
758
+ config_class = BrosConfig
759
+ base_model_prefix = "bros"
760
+
761
+ def _init_weights(self, module):
762
+ """Initialize the weights"""
763
+ if isinstance(module, nn.Linear):
764
+ # Slightly different from the TF version which uses truncated_normal for initialization
765
+ # cf https://github.com/pytorch/pytorch/pull/5617
766
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
767
+ if module.bias is not None:
768
+ module.bias.data.zero_()
769
+ elif isinstance(module, nn.Embedding):
770
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
771
+ if module.padding_idx is not None:
772
+ module.weight.data[module.padding_idx].zero_()
773
+ elif isinstance(module, nn.LayerNorm):
774
+ module.bias.data.zero_()
775
+ module.weight.data.fill_(1.0)
776
+
777
+
778
+ @add_start_docstrings(
779
+ "The bare Bros Model transformer outputting raw hidden-states without any specific head on top.",
780
+ BROS_START_DOCSTRING,
781
+ )
782
+ class BrosModel(BrosPreTrainedModel):
783
+ def __init__(self, config, add_pooling_layer=True):
784
+ super().__init__(config)
785
+ self.config = config
786
+
787
+ self.embeddings = BrosTextEmbeddings(config)
788
+ self.bbox_embeddings = BrosBboxEmbeddings(config)
789
+ self.encoder = BrosEncoder(config)
790
+
791
+ self.pooler = BrosPooler(config) if add_pooling_layer else None
792
+
793
+ self.init_weights()
794
+
795
+ def get_input_embeddings(self):
796
+ return self.embeddings.word_embeddings
797
+
798
+ def set_input_embeddings(self, value):
799
+ self.embeddings.word_embeddings = value
800
+
801
+ def _prune_heads(self, heads_to_prune):
802
+ """
803
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
804
+ base class PreTrainedModel for details.
805
+ """
806
+ for layer, heads in heads_to_prune.items():
807
+ self.encoder.layer[layer].attention.prune_heads(heads)
808
+
809
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
810
+ @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
811
+ def forward(
812
+ self,
813
+ input_ids: Optional[torch.Tensor] = None,
814
+ bbox: Optional[torch.Tensor] = None,
815
+ attention_mask: Optional[torch.Tensor] = None,
816
+ token_type_ids: Optional[torch.Tensor] = None,
817
+ position_ids: Optional[torch.Tensor] = None,
818
+ head_mask: Optional[torch.Tensor] = None,
819
+ inputs_embeds: Optional[torch.Tensor] = None,
820
+ encoder_hidden_states: Optional[torch.Tensor] = None,
821
+ encoder_attention_mask: Optional[torch.Tensor] = None,
822
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
823
+ use_cache: Optional[bool] = None,
824
+ output_attentions: Optional[bool] = None,
825
+ output_hidden_states: Optional[bool] = None,
826
+ return_dict: Optional[bool] = None,
827
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
828
+ r"""
829
+ Returns:
830
+
831
+ Examples:
832
+
833
+ ```python
834
+ >>> import torch
835
+ >>> from transformers import BrosProcessor, BrosModel
836
+
837
+ >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
838
+
839
+ >>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased")
840
+
841
+ >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
842
+ >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
843
+ >>> encoding["bbox"] = bbox
844
+
845
+ >>> outputs = model(**encoding)
846
+ >>> last_hidden_states = outputs.last_hidden_state
847
+ ```"""
848
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
849
+ output_hidden_states = (
850
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
851
+ )
852
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
853
+
854
+ if self.config.is_decoder:
855
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
856
+ else:
857
+ use_cache = False
858
+
859
+ if input_ids is not None and inputs_embeds is not None:
860
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
861
+ elif input_ids is not None:
862
+ input_shape = input_ids.size()
863
+ elif inputs_embeds is not None:
864
+ input_shape = inputs_embeds.size()[:-1]
865
+ else:
866
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
867
+
868
+ if bbox is None:
869
+ raise ValueError("You have to specify bbox")
870
+
871
+ batch_size, seq_length = input_shape
872
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
873
+
874
+ # past_key_values_length
875
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
876
+
877
+ if attention_mask is None:
878
+ attention_mask = torch.ones(input_shape, device=device)
879
+
880
+ if token_type_ids is None:
881
+ if hasattr(self.embeddings, "token_type_ids"):
882
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
883
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
884
+ token_type_ids = buffered_token_type_ids_expanded
885
+ else:
886
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
887
+
888
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
889
+ # ourselves in which case we just need to make it broadcastable to all heads.
890
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
891
+
892
+ # If a 2D or 3D attention mask is provided for the cross-attention
893
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
894
+ if self.config.is_decoder and encoder_hidden_states is not None:
895
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
896
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
897
+ if encoder_attention_mask is None:
898
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
899
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
900
+ else:
901
+ encoder_extended_attention_mask = None
902
+
903
+ # Prepare head mask if needed
904
+ # 1.0 in head_mask indicate we keep the head
905
+ # attention_probs has shape bsz x n_heads x N x N
906
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
907
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
908
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
909
+
910
+ embedding_output = self.embeddings(
911
+ input_ids=input_ids,
912
+ position_ids=position_ids,
913
+ token_type_ids=token_type_ids,
914
+ inputs_embeds=inputs_embeds,
915
+ past_key_values_length=past_key_values_length,
916
+ )
917
+
918
+ # if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
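+ # the index pattern [0, 1, 2, 1, 2, 3, 0, 3] expands (x0, y0, x1, y1) to the corners (x0, y0), (x1, y0), (x1, y1), (x0, y1)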
919
+ if bbox.shape[-1] == 4:
920
+ bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
921
+ scaled_bbox = bbox * self.config.bbox_scale
922
+ bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
923
+
924
+ encoder_outputs = self.encoder(
925
+ embedding_output,
926
+ bbox_pos_emb=bbox_position_embeddings,
927
+ attention_mask=extended_attention_mask,
928
+ head_mask=head_mask,
929
+ encoder_hidden_states=encoder_hidden_states,
930
+ encoder_attention_mask=encoder_extended_attention_mask,
931
+ past_key_values=past_key_values,
932
+ use_cache=use_cache,
933
+ output_attentions=output_attentions,
934
+ output_hidden_states=output_hidden_states,
935
+ return_dict=return_dict,
936
+ )
937
+ sequence_output = encoder_outputs[0]
938
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
939
+
940
+ if not return_dict:
941
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
942
+
943
+ return BaseModelOutputWithPoolingAndCrossAttentions(
944
+ last_hidden_state=sequence_output,
945
+ pooler_output=pooled_output,
946
+ past_key_values=encoder_outputs.past_key_values,
947
+ hidden_states=encoder_outputs.hidden_states,
948
+ attentions=encoder_outputs.attentions,
949
+ cross_attentions=encoder_outputs.cross_attentions,
950
+ )
951
+
952
+
953
+ @add_start_docstrings(
954
+ """
955
+ Bros Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
956
+ Named-Entity-Recognition (NER) tasks.
957
+ """,
958
+ BROS_START_DOCSTRING,
959
+ )
960
+ class BrosForTokenClassification(BrosPreTrainedModel):
961
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
962
+
963
+ def __init__(self, config):
964
+ super().__init__(config)
965
+ self.num_labels = config.num_labels
966
+
967
+ self.bros = BrosModel(config)
968
+ classifier_dropout = (
969
+ config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
970
+ )
971
+ self.dropout = nn.Dropout(classifier_dropout)
972
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
973
+
974
+ self.init_weights()
975
+
976
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
977
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
978
+ def forward(
979
+ self,
980
+ input_ids: Optional[torch.Tensor] = None,
981
+ bbox: Optional[torch.Tensor] = None,
982
+ attention_mask: Optional[torch.Tensor] = None,
983
+ bbox_first_token_mask: Optional[torch.Tensor] = None,
984
+ token_type_ids: Optional[torch.Tensor] = None,
985
+ position_ids: Optional[torch.Tensor] = None,
986
+ head_mask: Optional[torch.Tensor] = None,
987
+ inputs_embeds: Optional[torch.Tensor] = None,
988
+ labels: Optional[torch.Tensor] = None,
989
+ output_attentions: Optional[bool] = None,
990
+ output_hidden_states: Optional[bool] = None,
991
+ return_dict: Optional[bool] = None,
992
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
993
+ r"""
994
+
995
+ Returns:
996
+
997
+ Examples:
998
+
999
+ ```python
1000
+ >>> import torch
1001
+ >>> from transformers import BrosProcessor, BrosForTokenClassification
1002
+
1003
+ >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
1004
+
1005
+ >>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
1006
+
1007
+ >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
1008
+ >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
1009
+ >>> encoding["bbox"] = bbox
1010
+
1011
+ >>> outputs = model(**encoding)
1012
+ ```"""
1013
+
1014
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1015
+
1016
+ outputs = self.bros(
1017
+ input_ids,
1018
+ bbox=bbox,
1019
+ attention_mask=attention_mask,
1020
+ token_type_ids=token_type_ids,
1021
+ position_ids=position_ids,
1022
+ head_mask=head_mask,
1023
+ inputs_embeds=inputs_embeds,
1024
+ output_attentions=output_attentions,
1025
+ output_hidden_states=output_hidden_states,
1026
+ return_dict=return_dict,
1027
+ )
1028
+
1029
+ sequence_output = outputs[0]
1030
+
1031
+ sequence_output = self.dropout(sequence_output)
1032
+ logits = self.classifier(sequence_output)
1033
+
1034
+ loss = None
1035
+ if labels is not None:
1036
+ loss_fct = CrossEntropyLoss()
1037
+ if bbox_first_token_mask is not None:
1038
+ bbox_first_token_mask = bbox_first_token_mask.view(-1)
1039
+ loss = loss_fct(
1040
+ logits.view(-1, self.num_labels)[bbox_first_token_mask], labels.view(-1)[bbox_first_token_mask]
1041
+ )
1042
+ else:
1043
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1044
+
1045
+ if not return_dict:
1046
+ output = (logits,) + outputs[2:]
1047
+ return ((loss,) + output) if loss is not None else output
1048
+
1049
+ return TokenClassifierOutput(
1050
+ loss=loss,
1051
+ logits=logits,
1052
+ hidden_states=outputs.hidden_states,
1053
+ attentions=outputs.attentions,
1054
+ )
1055
+
1056
+
1057
+ @add_start_docstrings(
1058
+ """
1059
+ Bros Model with a token classification head on top (initial_token_layers and subsequent_token_layer on top of the
1060
+ hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier is used to
1061
+ predict the first token of each entity, and the subsequent_token_classifier is used to predict the subsequent
1062
+ tokens within an entity. Compared to BrosForTokenClassification, this model is more robust to serialization errors
1063
+ since it predicts the next token of an entity from each token, rather than relying on the serialized token order.
1064
+ """,
1065
+ BROS_START_DOCSTRING,
1066
+ )
1067
+ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
1068
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1069
+
1070
+ def __init__(self, config):
1071
+ super().__init__(config)
1072
+ self.config = config
1073
+ self.num_labels = config.num_labels
1074
+ self.n_relations = config.n_relations
1075
+ self.backbone_hidden_size = config.hidden_size
1076
+
1077
+ self.bros = BrosModel(config)
1078
+ classifier_dropout = (
1079
+ config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
1080
+ )
1081
+
1082
+ # Initial token classification for Entity Extraction (NER)
1083
+ self.initial_token_classifier = nn.Sequential(
1084
+ nn.Dropout(classifier_dropout),
1085
+ nn.Linear(config.hidden_size, config.hidden_size),
1086
+ nn.Dropout(classifier_dropout),
1087
+ nn.Linear(config.hidden_size, config.num_labels),
1088
+ )
1089
+
1090
+ # Subsequent token classification for Entity Extraction (NER)
1091
+ self.subsequent_token_classifier = BrosRelationExtractor(config)
1092
+
1093
+ self.init_weights()
1094
+
1095
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1096
+ @replace_return_docstrings(output_type=BrosSpadeOutput, config_class=_CONFIG_FOR_DOC)
1097
+ def forward(
1098
+ self,
1099
+ input_ids: Optional[torch.Tensor] = None,
1100
+ bbox: Optional[torch.Tensor] = None,
1101
+ attention_mask: Optional[torch.Tensor] = None,
1102
+ bbox_first_token_mask: Optional[torch.Tensor] = None,
1103
+ token_type_ids: Optional[torch.Tensor] = None,
1104
+ position_ids: Optional[torch.Tensor] = None,
1105
+ head_mask: Optional[torch.Tensor] = None,
1106
+ inputs_embeds: Optional[torch.Tensor] = None,
1107
+ initial_token_labels: Optional[torch.Tensor] = None,
1108
+ subsequent_token_labels: Optional[torch.Tensor] = None,
1109
+ output_attentions: Optional[bool] = None,
1110
+ output_hidden_states: Optional[bool] = None,
1111
+ return_dict: Optional[bool] = None,
1112
+ ) -> Union[Tuple[torch.Tensor], BrosSpadeOutput]:
1113
+ r"""
1114
+ Returns:
1115
+
1116
+ Examples:
1117
+
1118
+ ```python
1119
+ >>> import torch
1120
+ >>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification
1121
+
1122
+ >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
1123
+
1124
+ >>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
1125
+
1126
+ >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
1127
+ >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
1128
+ >>> encoding["bbox"] = bbox
1129
+
1130
+ >>> outputs = model(**encoding)
1131
+ ```"""
1132
+
1133
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1134
+
1135
+ outputs = self.bros(
1136
+ input_ids=input_ids,
1137
+ bbox=bbox,
1138
+ attention_mask=attention_mask,
1139
+ token_type_ids=token_type_ids,
1140
+ position_ids=position_ids,
1141
+ head_mask=head_mask,
1142
+ inputs_embeds=inputs_embeds,
1143
+ output_attentions=output_attentions,
1144
+ output_hidden_states=output_hidden_states,
1145
+ return_dict=return_dict,
1146
+ )
1147
+
1148
+ last_hidden_states = outputs[0]
1149
+ last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
1150
+ initial_token_logits = self.initial_token_classifier(last_hidden_states).transpose(0, 1).contiguous()
1151
+ subsequent_token_logits = self.subsequent_token_classifier(last_hidden_states, last_hidden_states).squeeze(0)
1152
+
1153
+ # make subsequent token (sequence token classification) mask
1154
+ inv_attention_mask = 1 - attention_mask
1155
+ batch_size, max_seq_length = inv_attention_mask.shape
1156
+ device = inv_attention_mask.device
1157
+ invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([batch_size, 1]).to(device)], axis=1).bool()
1158
+ subsequent_token_logits = subsequent_token_logits.masked_fill(
1159
+ invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min
1160
+ )
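+ # the eye mask below spans (seq_length, seq_length + 1), covering the dummy column, and keeps a token from being predicted as its own subsequent token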
1161
+ self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
1162
+ subsequent_token_logits = subsequent_token_logits.masked_fill(
1163
+ self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
1164
+ )
1165
+ subsequent_token_mask = attention_mask.view(-1).bool()
1166
+
1167
+ loss = None
1168
+ if initial_token_labels is not None and subsequent_token_labels is not None:
1169
+ loss_fct = CrossEntropyLoss()
1170
+
1171
+ # get initial token loss
1172
+ initial_token_labels = initial_token_labels.view(-1)
1173
+ if bbox_first_token_mask is not None:
1174
+ bbox_first_token_mask = bbox_first_token_mask.view(-1)
1175
+ initial_token_loss = loss_fct(
1176
+ initial_token_logits.view(-1, self.num_labels)[bbox_first_token_mask],
1177
+ initial_token_labels[bbox_first_token_mask],
1178
+ )
1179
+ else:
1180
+ initial_token_loss = loss_fct(initial_token_logits.view(-1, self.num_labels), initial_token_labels)
1181
+
1182
+ subsequent_token_labels = subsequent_token_labels.view(-1)
1183
+ subsequent_token_loss = loss_fct(
1184
+ subsequent_token_logits.view(-1, max_seq_length + 1)[subsequent_token_mask],
1185
+ subsequent_token_labels[subsequent_token_mask],
1186
+ )
1187
+
1188
+ loss = initial_token_loss + subsequent_token_loss
1189
+
1190
+ if not return_dict:
1191
+ output = (initial_token_logits, subsequent_token_logits) + outputs[2:]
1192
+ return ((loss,) + output) if loss is not None else output
1193
+
1194
+ return BrosSpadeOutput(
1195
+ loss=loss,
1196
+ initial_token_logits=initial_token_logits,
1197
+ subsequent_token_logits=subsequent_token_logits,
1198
+ hidden_states=outputs.hidden_states,
1199
+ attentions=outputs.attentions,
1200
+ )
1201
+
1202
+
1203
+ @add_start_docstrings(
1204
+ """
1205
+ Bros Model with a token classification head on top (an entity_linker layer on top of the hidden-states output) e.g.
1206
+ for Entity Linking. The entity_linker is used to predict links from one entity to another.
1207
+ """,
1208
+ BROS_START_DOCSTRING,
1209
+ )
1210
+ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
1211
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1212
+
1213
+ def __init__(self, config):
1214
+ super().__init__(config)
1215
+ self.config = config
1216
+ self.num_labels = config.num_labels
1217
+ self.n_relations = config.n_relations
1218
+ self.backbone_hidden_size = config.hidden_size
1219
+
1220
+ self.bros = BrosModel(config)
1221
+ # NOTE: no separate classifier dropout here; BrosRelationExtractor applies its own dropout (config.classifier_dropout_prob).
1222
+
1223
+ self.entity_linker = BrosRelationExtractor(config)
1224
+
1225
+ self.init_weights()
1226
+
1227
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1228
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
1229
+ def forward(
1230
+ self,
1231
+ input_ids: Optional[torch.Tensor] = None,
1232
+ bbox: Optional[torch.Tensor] = None,
1233
+ attention_mask: Optional[torch.Tensor] = None,
1234
+ bbox_first_token_mask: Optional[torch.Tensor] = None,
1235
+ token_type_ids: Optional[torch.Tensor] = None,
1236
+ position_ids: Optional[torch.Tensor] = None,
1237
+ head_mask: Optional[torch.Tensor] = None,
1238
+ inputs_embeds: Optional[torch.Tensor] = None,
1239
+ labels: Optional[torch.Tensor] = None,
1240
+ output_attentions: Optional[bool] = None,
1241
+ output_hidden_states: Optional[bool] = None,
1242
+ return_dict: Optional[bool] = None,
1243
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1244
+ r"""
1245
+ Returns:
1246
+
1247
+ Examples:
1248
+
1249
+ ```python
1250
+ >>> import torch
1251
+ >>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification
1252
+
1253
+ >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
1254
+
1255
+ >>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
1256
+
1257
+ >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
1258
+ >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
1259
+ >>> encoding["bbox"] = bbox
1260
+
1261
+ >>> outputs = model(**encoding)
1262
+ ```"""
1263
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1264
+
1265
+ outputs = self.bros(
1266
+ input_ids=input_ids,
1267
+ bbox=bbox,
1268
+ attention_mask=attention_mask,
1269
+ token_type_ids=token_type_ids,
1270
+ position_ids=position_ids,
1271
+ head_mask=head_mask,
1272
+ inputs_embeds=inputs_embeds,
1273
+ output_attentions=output_attentions,
1274
+ output_hidden_states=output_hidden_states,
1275
+ return_dict=return_dict,
1276
+ )
1277
+
1278
+ last_hidden_states = outputs[0]
1279
+ last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
1280
+
1281
+ logits = self.entity_linker(last_hidden_states, last_hidden_states).squeeze(0)
1282
+
1283
+ loss = None
1284
+ if labels is not None:
1285
+ loss_fct = CrossEntropyLoss()
1286
+
1287
+ batch_size, max_seq_length = attention_mask.shape
1288
+ device = attention_mask.device
1289
+
1290
+ self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
1291
+
1292
+ mask = bbox_first_token_mask.view(-1)
1293
+ bbox_first_token_mask = torch.cat(
1294
+ [
1295
+ ~bbox_first_token_mask,
1296
+ torch.zeros([batch_size, 1], dtype=torch.bool, device=device),
1297
+ ],
1298
+ axis=1,
1299
+ )
1300
+ logits = logits.masked_fill(bbox_first_token_mask[:, None, :], torch.finfo(logits.dtype).min)
1301
+ logits = logits.masked_fill(self_token_mask[None, :, :], torch.finfo(logits.dtype).min)
1302
+
1303
+ loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
1304
+
1305
+ if not return_dict:
1306
+ output = (logits,) + outputs[2:]
1307
+ return ((loss,) + output) if loss is not None else output
1308
+
1309
+ return TokenClassifierOutput(
1310
+ loss=loss,
1311
+ logits=logits,
1312
+ hidden_states=outputs.hidden_states,
1313
+ attentions=outputs.attentions,
1314
+ )
1315
+
1316
+
1317
+ __all__ = [
1318
+ "BrosPreTrainedModel",
1319
+ "BrosModel",
1320
+ "BrosForTokenClassification",
1321
+ "BrosSpadeEEForTokenClassification",
1322
+ "BrosSpadeELForTokenClassification",
1323
+ ]
docs/transformers/src/transformers/models/bros/processing_bros.py ADDED
@@ -0,0 +1,112 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Bros.
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ from ...processing_utils import ProcessorMixin
22
+ from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
23
+ from ...utils import TensorType
24
+
25
+
26
+ class BrosProcessor(ProcessorMixin):
27
+ r"""
28
+ Constructs a Bros processor which wraps a BERT tokenizer.
29
+
30
+ [`BrosProcessor`] offers all the functionalities of [`BertTokenizerFast`]. See the docstring of
31
+ [`~BrosProcessor.__call__`] and [`~BrosProcessor.decode`] for more information.
32
+
33
+ Args:
34
+ tokenizer (`BertTokenizerFast`, *optional*):
35
+ An instance of [`BertTokenizerFast`]. The tokenizer is a required input.
36
+ """
37
+
38
+ attributes = ["tokenizer"]
39
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
40
+
41
+ def __init__(self, tokenizer=None, **kwargs):
42
+ if tokenizer is None:
43
+ raise ValueError("You need to specify a `tokenizer`.")
44
+
45
+ super().__init__(tokenizer)
46
+
47
+ def __call__(
48
+ self,
49
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
50
+ add_special_tokens: bool = True,
51
+ padding: Union[bool, str, PaddingStrategy] = False,
52
+ truncation: Union[bool, str, TruncationStrategy] = None,
53
+ max_length: Optional[int] = None,
54
+ stride: int = 0,
55
+ pad_to_multiple_of: Optional[int] = None,
56
+ return_token_type_ids: Optional[bool] = None,
57
+ return_attention_mask: Optional[bool] = None,
58
+ return_overflowing_tokens: bool = False,
59
+ return_special_tokens_mask: bool = False,
60
+ return_offsets_mapping: bool = False,
61
+ return_length: bool = False,
62
+ verbose: bool = True,
63
+ return_tensors: Optional[Union[str, TensorType]] = None,
64
+ **kwargs,
65
+ ) -> BatchEncoding:
66
+ """
67
+ This method uses [`BertTokenizerFast.__call__`] to prepare text for the model.
68
+
69
+ Please refer to the docstring of the above two methods for more information.
70
+ """
71
+ encoding = self.tokenizer(
72
+ text=text,
73
+ add_special_tokens=add_special_tokens,
74
+ padding=padding,
75
+ truncation=truncation,
76
+ max_length=max_length,
77
+ stride=stride,
78
+ pad_to_multiple_of=pad_to_multiple_of,
79
+ return_token_type_ids=return_token_type_ids,
80
+ return_attention_mask=return_attention_mask,
81
+ return_overflowing_tokens=return_overflowing_tokens,
82
+ return_special_tokens_mask=return_special_tokens_mask,
83
+ return_offsets_mapping=return_offsets_mapping,
84
+ return_length=return_length,
85
+ verbose=verbose,
86
+ return_tensors=return_tensors,
87
+ **kwargs,
88
+ )
89
+
90
+ return encoding
91
+
92
+ def batch_decode(self, *args, **kwargs):
93
+ """
94
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
95
+ refer to the docstring of this method for more information.
96
+ """
97
+ return self.tokenizer.batch_decode(*args, **kwargs)
98
+
99
+ def decode(self, *args, **kwargs):
100
+ """
101
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
102
+ the docstring of this method for more information.
103
+ """
104
+ return self.tokenizer.decode(*args, **kwargs)
105
+
106
+ @property
107
+ def model_input_names(self):
108
+ tokenizer_input_names = self.tokenizer.model_input_names
109
+ return list(dict.fromkeys(tokenizer_input_names))
110
+
111
+
112
+ __all__ = ["BrosProcessor"]
docs/transformers/src/transformers/models/byt5/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_byt5 import *
22
+ else:
23
+ import sys
24
+
25
+ _file = globals()["__file__"]
26
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/byt5/convert_byt5_original_tf_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,59 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The T5 authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert T5 checkpoint."""
16
+
17
+ import argparse
18
+
19
+ from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
20
+ from transformers.utils import logging
21
+
22
+
23
+ logging.set_verbosity_info()
24
+
25
+
26
+ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
27
+ # Initialise PyTorch model
28
+ config = T5Config.from_json_file(config_file)
29
+ print(f"Building PyTorch model from configuration: {config}")
30
+ model = T5ForConditionalGeneration(config)
31
+
32
+ # Load weights from tf checkpoint
33
+ load_tf_weights_in_t5(model, config, tf_checkpoint_path)
34
+
35
+ # Save pytorch-model
36
+ print(f"Save PyTorch model to {pytorch_dump_path}")
37
+ model.save_pretrained(pytorch_dump_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ # Required parameters
43
+ parser.add_argument(
44
+ "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
45
+ )
46
+ parser.add_argument(
47
+ "--config_file",
48
+ default=None,
49
+ type=str,
50
+ required=True,
51
+ help=(
52
+ "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
53
+ ),
54
+ )
55
+ parser.add_argument(
56
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
57
+ )
58
+ args = parser.parse_args()
59
+ convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
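For reference, the same conversion can be driven from Python rather than the CLI; a minimal sketch, assuming a source checkout where this script is importable, with placeholder paths:

```python
from transformers.models.byt5.convert_byt5_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

# Placeholder paths for illustration only.
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/tf_checkpoint",
    config_file="/path/to/config.json",
    pytorch_dump_path="/path/to/pytorch_dump",
)
```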
docs/transformers/src/transformers/models/byt5/tokenization_byt5.py ADDED
@@ -0,0 +1,236 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 T5 Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization class for model ByT5."""
16
+
17
+ import warnings
18
+ from typing import List, Optional, Tuple
19
+
20
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
21
+ from ...utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class ByT5Tokenizer(PreTrainedTokenizer):
28
+ """
29
+ Construct a ByT5 tokenizer. ByT5 simply uses raw bytes utf-8 encoding.
30
+
31
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
32
+ this superclass for more information regarding those methods.
33
+
34
+ Args:
35
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
36
+ The end of sequence token.
37
+
38
+ <Tip>
39
+
40
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
41
+ The token used is the `sep_token`.
42
+
43
+ </Tip>
44
+
45
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
46
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
47
+ token instead.
48
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
49
+ The token used for padding, for example when batching sequences of different lengths.
50
+ extra_ids (`int`, *optional*, defaults to 125):
51
+ Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
52
+ accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
53
+ indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
54
+ like in ByT5 preprocessing see
55
+ [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
56
+ additional_special_tokens (`List[str]`, *optional*):
57
+ Additional special tokens used by the tokenizer.
58
+ """
59
+
60
+ model_input_names = ["input_ids", "attention_mask"]
61
+
62
+ def __init__(
63
+ self,
64
+ eos_token="</s>",
65
+ unk_token="<unk>",
66
+ pad_token="<pad>",
67
+ extra_ids=125,
68
+ additional_special_tokens=None,
69
+ **kwargs,
70
+ ) -> None:
71
+ # Add extra_ids to the special token list
72
+ if extra_ids > 0 and additional_special_tokens is None:
73
+ additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
74
+ elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
75
+ # Check that we have the right number of extra_id special tokens
76
+ extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
77
+ if extra_tokens != extra_ids:
78
+ raise ValueError(
79
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
80
+ " provided to ByT5Tokenizer. In this case the additional_special_tokens must include the"
81
+ " extra_ids tokens"
82
+ )
83
+
84
+ pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
85
+ # we force left and right stripping for backward compatibility. The ByT5 tests depend on this.
86
+ eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
87
+ unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
88
+ # unk token needs to be in the vocab with correct index
89
+ self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
90
+ self.offset = len(self._added_tokens_decoder)
91
+ self._utf_vocab_size = 2**8 # utf is 8 bits
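+ # consequently, a raw byte value b maps to token id b + 3, since ids 0-2 are reserved for pad/eos/unk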
92
+ super().__init__(
93
+ eos_token=eos_token,
94
+ unk_token=unk_token,
95
+ pad_token=pad_token,
96
+ extra_ids=0,
97
+ additional_special_tokens=additional_special_tokens, # TODO extra ids are not used :sweat_smile:
98
+ **kwargs,
99
+ )
100
+
101
+ @property
102
+ def vocab_size(self):
103
+ return self._utf_vocab_size
104
+
105
+ def get_vocab(self):
106
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
107
+ vocab.update(self.added_tokens_encoder)
108
+ return vocab
109
+
110
+ def get_special_tokens_mask(
111
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
112
+ ) -> List[int]:
113
+ """
114
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
115
+ special tokens using the tokenizer `prepare_for_model` method.
116
+
117
+ Args:
118
+ token_ids_0 (`List[int]`):
119
+ List of IDs.
120
+ token_ids_1 (`List[int]`, *optional*):
121
+ Optional second list of IDs for sequence pairs.
122
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
123
+ Whether or not the token list is already formatted with special tokens for the model.
124
+
125
+ Returns:
126
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
127
+ """
128
+ if already_has_special_tokens:
129
+ return super().get_special_tokens_mask(
130
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
131
+ )
132
+
133
+ # normal case: some special tokens
134
+ if token_ids_1 is None:
135
+ return ([0] * len(token_ids_0)) + [1]
136
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
137
+
138
+ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
139
+ """Do not add eos again if user already added it."""
140
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
141
+ warnings.warn(
142
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
143
+ " eos tokens being added."
144
+ )
145
+ return token_ids
146
+ else:
147
+ return token_ids + [self.eos_token_id]
148
+
149
+ def create_token_type_ids_from_sequences(
150
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
151
+ ) -> List[int]:
152
+ """
153
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. ByT5 does not
154
+ make use of token type ids, therefore a list of zeros is returned.
155
+
156
+ Args:
157
+ token_ids_0 (`List[int]`):
158
+ List of IDs.
159
+ token_ids_1 (`List[int]`, *optional*):
160
+ Optional second list of IDs for sequence pairs.
161
+
162
+ Returns:
163
+ `List[int]`: List of zeros.
164
+ """
165
+ eos = [self.eos_token_id]
166
+
167
+ if token_ids_1 is None:
168
+ return len(token_ids_0 + eos) * [0]
169
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
170
+
171
+ def build_inputs_with_special_tokens(
172
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
173
+ ) -> List[int]:
174
+ """
175
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
176
+ adding special tokens. A sequence has the following format:
177
+
178
+ - single sequence: `X </s>`
179
+ - pair of sequences: `A </s> B </s>`
180
+
181
+ Args:
182
+ token_ids_0 (`List[int]`):
183
+ List of IDs to which the special tokens will be added.
184
+ token_ids_1 (`List[int]`, *optional*):
185
+ Optional second list of IDs for sequence pairs.
186
+
187
+ Returns:
188
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
189
+ """
190
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
191
+ if token_ids_1 is None:
192
+ return token_ids_0
193
+ else:
194
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
195
+ return token_ids_0 + token_ids_1
196
+
197
+ def _tokenize(self, text: str) -> List[str]:
198
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
199
+ tokens = [chr(i) for i in text.encode("utf-8")]
200
+ return tokens
201
+
202
+ def _convert_token_to_id(self, token):
203
+ """Converts a token (str) in an id using the vocab."""
204
+
205
+ if len(token) != 1:
206
+ token_id = None
207
+ else:
208
+ token_id = ord(token) + self.offset
209
+
210
+ return token_id
211
+
212
+ def _convert_id_to_token(self, index):
213
+ """Converts an index (integer) in a token (str) using the vocab."""
214
+ token = chr(index - self.offset)
215
+ return token
216
+
217
+ def convert_tokens_to_string(self, tokens):
218
+ """Converts a sequence of tokens (string) in a single string."""
219
+ bstring = b""
220
+ for token in tokens:
221
+ if token in self.added_tokens_decoder:
222
+ tok_string = self.added_tokens_decoder[token].encode("utf-8")
223
+ elif token in self.added_tokens_encoder:
224
+ tok_string = token.encode("utf-8")
225
+ else:
226
+ tok_string = bytes([ord(token)])
227
+ bstring += tok_string
228
+ string = bstring.decode("utf-8", errors="ignore")
229
+ return string
230
+
231
+ # ByT5Tokenizer has no vocab file
232
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
233
+ return ()
234
+
235
+
236
+ __all__ = ["ByT5Tokenizer"]
docs/transformers/src/transformers/models/camembert/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_camembert import *
22
+ from .modeling_camembert import *
23
+ from .modeling_tf_camembert import *
24
+ from .tokenization_camembert import *
25
+ from .tokenization_camembert_fast import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/camembert/configuration_camembert.py ADDED
@@ -0,0 +1,155 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """CamemBERT configuration"""
17
+
18
+ from collections import OrderedDict
19
+ from typing import Mapping
20
+
21
+ from ...configuration_utils import PretrainedConfig
22
+ from ...onnx import OnnxConfig
23
+ from ...utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class CamembertConfig(PretrainedConfig):
30
+ """
31
+ This is the configuration class to store the configuration of a [`CamembertModel`] or a [`TFCamembertModel`]. It is
32
+ used to instantiate a Camembert model according to the specified arguments, defining the model architecture.
33
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Camembert
34
+ [almanach/camembert-base](https://huggingface.co/almanach/camembert-base) architecture.
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+
40
+ Args:
41
+ vocab_size (`int`, *optional*, defaults to 30522):
42
+ Vocabulary size of the CamemBERT model. Defines the number of different tokens that can be represented by the
43
+ `input_ids` passed when calling [`CamembertModel`] or [`TFCamembertModel`].
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ num_hidden_layers (`int`, *optional*, defaults to 12):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 12):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ intermediate_size (`int`, *optional*, defaults to 3072):
51
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
52
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
53
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
54
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
55
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
56
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
57
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
58
+ The dropout ratio for the attention probabilities.
59
+ max_position_embeddings (`int`, *optional*, defaults to 512):
60
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
61
+ just in case (e.g., 512 or 1024 or 2048).
62
+ type_vocab_size (`int`, *optional*, defaults to 2):
63
+ The vocabulary size of the `token_type_ids` passed when calling [`CamembertModel`] or [`TFCamembertModel`].
64
+ initializer_range (`float`, *optional*, defaults to 0.02):
65
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
67
+ The epsilon used by the layer normalization layers.
68
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
69
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
70
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
71
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
72
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
73
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
74
+ is_decoder (`bool`, *optional*, defaults to `False`):
75
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
76
+ use_cache (`bool`, *optional*, defaults to `True`):
77
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
78
+ relevant if `config.is_decoder=True`.
79
+ classifier_dropout (`float`, *optional*):
80
+ The dropout ratio for the classification head.
81
+
82
+ Example:
83
+
84
+ ```python
85
+ >>> from transformers import CamembertConfig, CamembertModel
86
+
87
+ >>> # Initializing a Camembert almanach/camembert-base style configuration
88
+ >>> configuration = CamembertConfig()
89
+
90
+ >>> # Initializing a model (with random weights) from the almanach/camembert-base style configuration
91
+ >>> model = CamembertModel(configuration)
92
+
93
+ >>> # Accessing the model configuration
94
+ >>> configuration = model.config
95
+ ```"""
96
+
97
+ model_type = "camembert"
98
+
99
+ def __init__(
100
+ self,
101
+ vocab_size=30522,
102
+ hidden_size=768,
103
+ num_hidden_layers=12,
104
+ num_attention_heads=12,
105
+ intermediate_size=3072,
106
+ hidden_act="gelu",
107
+ hidden_dropout_prob=0.1,
108
+ attention_probs_dropout_prob=0.1,
109
+ max_position_embeddings=512,
110
+ type_vocab_size=2,
111
+ initializer_range=0.02,
112
+ layer_norm_eps=1e-12,
113
+ pad_token_id=1,
114
+ bos_token_id=0,
115
+ eos_token_id=2,
116
+ position_embedding_type="absolute",
117
+ use_cache=True,
118
+ classifier_dropout=None,
119
+ **kwargs,
120
+ ):
121
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
122
+
123
+ self.vocab_size = vocab_size
124
+ self.hidden_size = hidden_size
125
+ self.num_hidden_layers = num_hidden_layers
126
+ self.num_attention_heads = num_attention_heads
127
+ self.hidden_act = hidden_act
128
+ self.intermediate_size = intermediate_size
129
+ self.hidden_dropout_prob = hidden_dropout_prob
130
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
131
+ self.max_position_embeddings = max_position_embeddings
132
+ self.type_vocab_size = type_vocab_size
133
+ self.initializer_range = initializer_range
134
+ self.layer_norm_eps = layer_norm_eps
135
+ self.position_embedding_type = position_embedding_type
136
+ self.use_cache = use_cache
137
+ self.classifier_dropout = classifier_dropout
138
+
139
+
140
+ class CamembertOnnxConfig(OnnxConfig):
141
+ @property
142
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
143
+ if self.task == "multiple-choice":
144
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
145
+ else:
146
+ dynamic_axis = {0: "batch", 1: "sequence"}
147
+ return OrderedDict(
148
+ [
149
+ ("input_ids", dynamic_axis),
150
+ ("attention_mask", dynamic_axis),
151
+ ]
152
+ )
153
+
154
+
155
+ __all__ = ["CamembertConfig", "CamembertOnnxConfig"]
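+
+
+ # A minimal usage sketch for the two classes above (not part of the upstream module; the
+ # helper name `_sketch_onnx_inputs` is ours). It assumes only the classes defined in this
+ # file and a standard `transformers` install.
+ def _sketch_onnx_inputs():
+     # A smaller-than-default configuration: every architecture knob is a constructor argument.
+     config = CamembertConfig(hidden_size=384, num_hidden_layers=6, num_attention_heads=6)
+
+     # The ONNX export spec marks the batch and sequence axes as dynamic for each input.
+     onnx_config = CamembertOnnxConfig(config)
+     print(onnx_config.inputs)  # OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}), ...])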
docs/transformers/src/transformers/models/camembert/modeling_camembert.py ADDED
@@ -0,0 +1,1716 @@
1
+ # coding=utf-8
2
+ # Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch CamemBERT model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from ...activations import ACT2FN, gelu
28
+ from ...generation import GenerationMixin
29
+ from ...modeling_attn_mask_utils import (
30
+ _prepare_4d_attention_mask_for_sdpa,
31
+ _prepare_4d_causal_attention_mask_for_sdpa,
32
+ )
33
+ from ...modeling_outputs import (
34
+ BaseModelOutputWithPastAndCrossAttentions,
35
+ BaseModelOutputWithPoolingAndCrossAttentions,
36
+ CausalLMOutputWithCrossAttentions,
37
+ MaskedLMOutput,
38
+ MultipleChoiceModelOutput,
39
+ QuestionAnsweringModelOutput,
40
+ SequenceClassifierOutput,
41
+ TokenClassifierOutput,
42
+ )
43
+ from ...modeling_utils import PreTrainedModel
44
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
45
+ from ...utils import (
46
+ add_code_sample_docstrings,
47
+ add_start_docstrings,
48
+ add_start_docstrings_to_model_forward,
49
+ get_torch_version,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
+ from .configuration_camembert import CamembertConfig
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CHECKPOINT_FOR_DOC = "almanach/camembert-base"
59
+ _CONFIG_FOR_DOC = "CamembertConfig"
60
+
61
+
62
+ CAMEMBERT_START_DOCSTRING = r"""
63
+
64
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
65
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
66
+ etc.).
67
+
68
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
69
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
70
+ and behavior.
71
+
72
+ Parameters:
73
+ config ([`CamembertConfig`]): Model configuration class with all the parameters of the
74
+ model. Initializing with a config file does not load the weights associated with the model, only the
75
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
76
+ """
77
+
78
+
79
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Camembert
80
+ class CamembertEmbeddings(nn.Module):
81
+ """
82
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
83
+ """
84
+
85
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
86
+ def __init__(self, config):
87
+ super().__init__()
88
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
89
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
90
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
91
+
92
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
93
+ # any TensorFlow checkpoint file
94
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
95
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
96
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
97
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
98
+ self.register_buffer(
99
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
100
+ )
101
+ self.register_buffer(
102
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
103
+ )
104
+
105
+ # End copy
106
+ self.padding_idx = config.pad_token_id
107
+ self.position_embeddings = nn.Embedding(
108
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
109
+ )
110
+
111
+ def forward(
112
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
113
+ ):
114
+ if position_ids is None:
115
+ if input_ids is not None:
116
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
117
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
118
+ else:
119
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
120
+
121
+ if input_ids is not None:
122
+ input_shape = input_ids.size()
123
+ else:
124
+ input_shape = inputs_embeds.size()[:-1]
125
+
126
+ seq_length = input_shape[1]
127
+
128
+ # When token_type_ids is not passed, fall back to the all-zeros buffer registered in the constructor. This
129
+ # typically happens when the ids are auto-generated; the registered buffer lets users trace the model without
130
+ # passing token_type_ids, and solves issue #5664.
131
+ if token_type_ids is None:
132
+ if hasattr(self, "token_type_ids"):
133
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
134
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
135
+ token_type_ids = buffered_token_type_ids_expanded
136
+ else:
137
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
138
+
139
+ if inputs_embeds is None:
140
+ inputs_embeds = self.word_embeddings(input_ids)
141
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
142
+
143
+ embeddings = inputs_embeds + token_type_embeddings
144
+ if self.position_embedding_type == "absolute":
145
+ position_embeddings = self.position_embeddings(position_ids)
146
+ embeddings += position_embeddings
147
+ embeddings = self.LayerNorm(embeddings)
148
+ embeddings = self.dropout(embeddings)
149
+ return embeddings
150
+
151
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
152
+ """
153
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
154
+
155
+ Args:
156
+ inputs_embeds: torch.Tensor
157
+
158
+ Returns: torch.Tensor
159
+ """
160
+ input_shape = inputs_embeds.size()[:-1]
161
+ sequence_length = input_shape[1]
162
+
163
+ position_ids = torch.arange(
164
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
165
+ )
166
+ return position_ids.unsqueeze(0).expand(input_shape)
167
+
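+ # A reference sketch of `create_position_ids_from_input_ids`, which the forward pass above
+ # calls but which is defined further down in this file; reproduced here assuming it matches
+ # the RoBERTa original this module is copied from. Padding positions keep `padding_idx` and
+ # real tokens count up from `padding_idx + 1`, which is why `max_position_embeddings` must
+ # leave room for the offset.
+ def _sketch_create_position_ids(input_ids, padding_idx, past_key_values_length=0):
+     mask = input_ids.ne(padding_idx).int()  # 1 for real tokens, 0 for padding
+     # Cumulative count of real tokens, shifted by any cached prefix, zeroed out at padding.
+     incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+     return incremental_indices.long() + padding_idx
+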
168
+
169
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Camembert
170
+ class CamembertSelfAttention(nn.Module):
171
+ def __init__(self, config, position_embedding_type=None):
172
+ super().__init__()
173
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
174
+ raise ValueError(
175
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
176
+ f"heads ({config.num_attention_heads})"
177
+ )
178
+
179
+ self.num_attention_heads = config.num_attention_heads
180
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
181
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
182
+
183
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
184
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
185
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
186
+
187
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
188
+ self.position_embedding_type = position_embedding_type or getattr(
189
+ config, "position_embedding_type", "absolute"
190
+ )
191
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
192
+ self.max_position_embeddings = config.max_position_embeddings
193
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
194
+
195
+ self.is_decoder = config.is_decoder
196
+
197
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
198
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
199
+ x = x.view(new_x_shape)
200
+ return x.permute(0, 2, 1, 3)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.Tensor,
205
+ attention_mask: Optional[torch.FloatTensor] = None,
206
+ head_mask: Optional[torch.FloatTensor] = None,
207
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
208
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
209
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
210
+ output_attentions: Optional[bool] = False,
211
+ ) -> Tuple[torch.Tensor]:
212
+ mixed_query_layer = self.query(hidden_states)
213
+
214
+ # If this is instantiated as a cross-attention module, the keys
215
+ # and values come from an encoder; the attention mask needs to be
216
+ # such that the encoder's padding tokens are not attended to.
217
+ is_cross_attention = encoder_hidden_states is not None
218
+
219
+ if is_cross_attention and past_key_value is not None:
220
+ # reuse the cached cross-attention key/value states
221
+ key_layer = past_key_value[0]
222
+ value_layer = past_key_value[1]
223
+ attention_mask = encoder_attention_mask
224
+ elif is_cross_attention:
225
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
226
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
227
+ attention_mask = encoder_attention_mask
228
+ elif past_key_value is not None:
229
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
230
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
231
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
232
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
233
+ else:
234
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
235
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
236
+
237
+ query_layer = self.transpose_for_scores(mixed_query_layer)
238
+
239
+ use_cache = past_key_value is not None
240
+ if self.is_decoder:
241
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
242
+ # Further calls to cross_attention layer can then reuse all cross-attention
243
+ # key/value_states (first "if" case)
244
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
245
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
246
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
247
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
248
+ past_key_value = (key_layer, value_layer)
249
+
250
+ # Take the dot product between "query" and "key" to get the raw attention scores.
251
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
252
+
253
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
254
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
255
+ if use_cache:
256
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
257
+ -1, 1
258
+ )
259
+ else:
260
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
261
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
262
+ distance = position_ids_l - position_ids_r
263
+
264
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
265
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
266
+
267
+ if self.position_embedding_type == "relative_key":
268
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
269
+ attention_scores = attention_scores + relative_position_scores
270
+ elif self.position_embedding_type == "relative_key_query":
271
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
272
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
273
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
274
+
275
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
276
+ if attention_mask is not None:
277
+ # Apply the attention mask (precomputed for all layers in CamembertModel's forward() function)
278
+ attention_scores = attention_scores + attention_mask
279
+
280
+ # Normalize the attention scores to probabilities.
281
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
282
+
283
+ # This is actually dropping out entire tokens to attend to, which might
284
+ # seem a bit unusual, but is taken from the original Transformer paper.
285
+ attention_probs = self.dropout(attention_probs)
286
+
287
+ # Mask heads if we want to
288
+ if head_mask is not None:
289
+ attention_probs = attention_probs * head_mask
290
+
291
+ context_layer = torch.matmul(attention_probs, value_layer)
292
+
293
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
294
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
295
+ context_layer = context_layer.view(new_context_layer_shape)
296
+
297
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
298
+
299
+ if self.is_decoder:
300
+ outputs = outputs + (past_key_value,)
301
+ return outputs
302
+
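+ # A shape walkthrough of the eager attention above (an illustrative sketch with our own
+ # values; it relies on the module-level `torch` and `math` imports). It mirrors what
+ # `transpose_for_scores` and the matmuls do, reusing one tensor for q, k and v for brevity.
+ def _sketch_attention_shapes():
+     bsz, seq, hidden, heads = 2, 5, 768, 12
+     head_dim = hidden // heads  # 64
+     x = torch.randn(bsz, seq, hidden)
+     q = x.view(bsz, seq, heads, head_dim).permute(0, 2, 1, 3)             # (bsz, heads, seq, head_dim)
+     scores = torch.matmul(q, q.transpose(-1, -2)) / math.sqrt(head_dim)   # (bsz, heads, seq, seq)
+     probs = torch.nn.functional.softmax(scores, dim=-1)
+     ctx = torch.matmul(probs, q).permute(0, 2, 1, 3).reshape(bsz, seq, hidden)
+     assert ctx.shape == (bsz, seq, hidden)
+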
303
+
304
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert
305
+ class CamembertSdpaSelfAttention(CamembertSelfAttention):
306
+ def __init__(self, config, position_embedding_type=None):
307
+ super().__init__(config, position_embedding_type=position_embedding_type)
308
+ self.dropout_prob = config.attention_probs_dropout_prob
309
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
310
+
311
+ # Adapted from CamembertSelfAttention
312
+ def forward(
313
+ self,
314
+ hidden_states: torch.Tensor,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ head_mask: Optional[torch.FloatTensor] = None,
317
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
318
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
319
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
320
+ output_attentions: Optional[bool] = False,
321
+ ) -> Tuple[torch.Tensor]:
322
+ if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
323
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
324
+ logger.warning_once(
325
+ "CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
326
+ "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
327
+ "the manual attention implementation, but specifying the manual implementation will be required from "
328
+ "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
329
+ '`attn_implementation="eager"` when loading the model.'
330
+ )
331
+ return super().forward(
332
+ hidden_states,
333
+ attention_mask,
334
+ head_mask,
335
+ encoder_hidden_states,
336
+ encoder_attention_mask,
337
+ past_key_value,
338
+ output_attentions,
339
+ )
340
+
341
+ bsz, tgt_len, _ = hidden_states.size()
342
+
343
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
344
+
345
+ # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
346
+ # mask needs to be such that the encoder's padding tokens are not attended to.
347
+ is_cross_attention = encoder_hidden_states is not None
348
+
349
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
350
+ attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
351
+
352
+ # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
353
+ if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
354
+ key_layer, value_layer = past_key_value
355
+ else:
356
+ key_layer = self.transpose_for_scores(self.key(current_states))
357
+ value_layer = self.transpose_for_scores(self.value(current_states))
358
+ if past_key_value is not None and not is_cross_attention:
359
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
360
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
361
+
362
+ if self.is_decoder:
363
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
364
+ # Further calls to cross_attention layer can then reuse all cross-attention
365
+ # key/value_states (first "if" case)
366
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
367
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
368
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
369
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
370
+ past_key_value = (key_layer, value_layer)
371
+
372
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
373
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
374
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
375
+ if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
376
+ query_layer = query_layer.contiguous()
377
+ key_layer = key_layer.contiguous()
378
+ value_layer = value_layer.contiguous()
379
+
380
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
381
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
382
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
383
+ # a causal mask in case tgt_len == 1.
384
+ is_causal = (
385
+ True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
386
+ )
387
+
388
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
389
+ query_layer,
390
+ key_layer,
391
+ value_layer,
392
+ attn_mask=attention_mask,
393
+ dropout_p=self.dropout_prob if self.training else 0.0,
394
+ is_causal=is_causal,
395
+ )
396
+
397
+ attn_output = attn_output.transpose(1, 2)
398
+ attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
399
+
400
+ outputs = (attn_output,)
401
+ if self.is_decoder:
402
+ outputs = outputs + (past_key_value,)
403
+ return outputs
404
+
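+ # A sketch of how the fallback above is avoided in practice (assumes a standard
+ # `transformers` install; `attn_implementation` is a real `from_pretrained` argument).
+ # SDPA is picked by default when available; request the eager path explicitly when you
+ # need `output_attentions=True`, a `head_mask`, or a non-absolute position embedding type.
+ def _sketch_pick_eager_attention():
+     from transformers import CamembertModel
+
+     model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="eager")
+     return model
+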
405
+
406
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert
407
+ class CamembertSelfOutput(nn.Module):
408
+ def __init__(self, config):
409
+ super().__init__()
410
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
411
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
412
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
413
+
414
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
415
+ hidden_states = self.dense(hidden_states)
416
+ hidden_states = self.dropout(hidden_states)
417
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
418
+ return hidden_states
419
+
420
+
421
+ CAMEMBERT_SELF_ATTENTION_CLASSES = {
422
+ "eager": CamembertSelfAttention,
423
+ "sdpa": CamembertSdpaSelfAttention,
424
+ }
425
+
426
+
427
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT
428
+ class CamembertAttention(nn.Module):
429
+ def __init__(self, config, position_embedding_type=None):
430
+ super().__init__()
431
+ self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
432
+ config, position_embedding_type=position_embedding_type
433
+ )
434
+ self.output = CamembertSelfOutput(config)
435
+ self.pruned_heads = set()
436
+
437
+ def prune_heads(self, heads):
438
+ if len(heads) == 0:
439
+ return
440
+ heads, index = find_pruneable_heads_and_indices(
441
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
442
+ )
443
+
444
+ # Prune linear layers
445
+ self.self.query = prune_linear_layer(self.self.query, index)
446
+ self.self.key = prune_linear_layer(self.self.key, index)
447
+ self.self.value = prune_linear_layer(self.self.value, index)
448
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
449
+
450
+ # Update hyper params and store pruned heads
451
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
452
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
453
+ self.pruned_heads = self.pruned_heads.union(heads)
454
+
455
+ def forward(
456
+ self,
457
+ hidden_states: torch.Tensor,
458
+ attention_mask: Optional[torch.FloatTensor] = None,
459
+ head_mask: Optional[torch.FloatTensor] = None,
460
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
461
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
462
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
463
+ output_attentions: Optional[bool] = False,
464
+ ) -> Tuple[torch.Tensor]:
465
+ self_outputs = self.self(
466
+ hidden_states,
467
+ attention_mask,
468
+ head_mask,
469
+ encoder_hidden_states,
470
+ encoder_attention_mask,
471
+ past_key_value,
472
+ output_attentions,
473
+ )
474
+ attention_output = self.output(self_outputs[0], hidden_states)
475
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
476
+ return outputs
477
+
478
+
479
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Roberta->Camembert
480
+ class CamembertIntermediate(nn.Module):
481
+ def __init__(self, config):
482
+ super().__init__()
483
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
484
+ if isinstance(config.hidden_act, str):
485
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
486
+ else:
487
+ self.intermediate_act_fn = config.hidden_act
488
+
489
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
490
+ hidden_states = self.dense(hidden_states)
491
+ hidden_states = self.intermediate_act_fn(hidden_states)
492
+ return hidden_states
493
+
494
+
495
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Roberta->Camembert
496
+ class CamembertOutput(nn.Module):
497
+ def __init__(self, config):
498
+ super().__init__()
499
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
500
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
501
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
502
+
503
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
504
+ hidden_states = self.dense(hidden_states)
505
+ hidden_states = self.dropout(hidden_states)
506
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
507
+ return hidden_states
508
+
509
+
510
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert
511
+ class CamembertLayer(nn.Module):
512
+ def __init__(self, config):
513
+ super().__init__()
514
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
515
+ self.seq_len_dim = 1
516
+ self.attention = CamembertAttention(config)
517
+ self.is_decoder = config.is_decoder
518
+ self.add_cross_attention = config.add_cross_attention
519
+ if self.add_cross_attention:
520
+ if not self.is_decoder:
521
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
522
+ self.crossattention = CamembertAttention(config, position_embedding_type="absolute")
523
+ self.intermediate = CamembertIntermediate(config)
524
+ self.output = CamembertOutput(config)
525
+
526
+ def forward(
527
+ self,
528
+ hidden_states: torch.Tensor,
529
+ attention_mask: Optional[torch.FloatTensor] = None,
530
+ head_mask: Optional[torch.FloatTensor] = None,
531
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
532
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
533
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
534
+ output_attentions: Optional[bool] = False,
535
+ ) -> Tuple[torch.Tensor]:
536
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
537
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
538
+ self_attention_outputs = self.attention(
539
+ hidden_states,
540
+ attention_mask,
541
+ head_mask,
542
+ output_attentions=output_attentions,
543
+ past_key_value=self_attn_past_key_value,
544
+ )
545
+ attention_output = self_attention_outputs[0]
546
+
547
+ # if decoder, the last output is tuple of self-attn cache
548
+ if self.is_decoder:
549
+ outputs = self_attention_outputs[1:-1]
550
+ present_key_value = self_attention_outputs[-1]
551
+ else:
552
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
553
+
554
+ cross_attn_present_key_value = None
555
+ if self.is_decoder and encoder_hidden_states is not None:
556
+ if not hasattr(self, "crossattention"):
557
+ raise ValueError(
558
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
559
+ " by setting `config.add_cross_attention=True`"
560
+ )
561
+
562
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
563
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
564
+ cross_attention_outputs = self.crossattention(
565
+ attention_output,
566
+ attention_mask,
567
+ head_mask,
568
+ encoder_hidden_states,
569
+ encoder_attention_mask,
570
+ cross_attn_past_key_value,
571
+ output_attentions,
572
+ )
573
+ attention_output = cross_attention_outputs[0]
574
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
575
+
576
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
577
+ cross_attn_present_key_value = cross_attention_outputs[-1]
578
+ present_key_value = present_key_value + cross_attn_present_key_value
579
+
580
+ layer_output = apply_chunking_to_forward(
581
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
582
+ )
583
+ outputs = (layer_output,) + outputs
584
+
585
+ # if decoder, return the attn key/values as the last output
586
+ if self.is_decoder:
587
+ outputs = outputs + (present_key_value,)
588
+
589
+ return outputs
590
+
591
+ def feed_forward_chunk(self, attention_output):
592
+ intermediate_output = self.intermediate(attention_output)
593
+ layer_output = self.output(intermediate_output, attention_output)
594
+ return layer_output
595
+
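+ # A sketch of what `apply_chunking_to_forward` does in the layer above: it splits the
+ # input along `seq_len_dim` into chunks of `chunk_size_feed_forward`, applies the
+ # feed-forward to each chunk, and concatenates -- same output, lower peak memory. The
+ # stand-in function and tensor sizes are ours.
+ def _sketch_feed_forward_chunking():
+     def fake_feed_forward(t):
+         return t * 2.0  # stand-in for feed_forward_chunk
+
+     x = torch.randn(2, 8, 4)
+     full = fake_feed_forward(x)
+     chunked = apply_chunking_to_forward(fake_feed_forward, 4, 1, x)  # chunks of 4 along dim 1
+     assert torch.allclose(full, chunked)
+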
596
+
597
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->Camembert
598
+ class CamembertEncoder(nn.Module):
599
+ def __init__(self, config):
600
+ super().__init__()
601
+ self.config = config
602
+ self.layer = nn.ModuleList([CamembertLayer(config) for _ in range(config.num_hidden_layers)])
603
+ self.gradient_checkpointing = False
604
+
605
+ def forward(
606
+ self,
607
+ hidden_states: torch.Tensor,
608
+ attention_mask: Optional[torch.FloatTensor] = None,
609
+ head_mask: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
612
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
613
+ use_cache: Optional[bool] = None,
614
+ output_attentions: Optional[bool] = False,
615
+ output_hidden_states: Optional[bool] = False,
616
+ return_dict: Optional[bool] = True,
617
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
618
+ all_hidden_states = () if output_hidden_states else None
619
+ all_self_attentions = () if output_attentions else None
620
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
621
+
622
+ if self.gradient_checkpointing and self.training:
623
+ if use_cache:
624
+ logger.warning_once(
625
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
626
+ )
627
+ use_cache = False
628
+
629
+ next_decoder_cache = () if use_cache else None
630
+ for i, layer_module in enumerate(self.layer):
631
+ if output_hidden_states:
632
+ all_hidden_states = all_hidden_states + (hidden_states,)
633
+
634
+ layer_head_mask = head_mask[i] if head_mask is not None else None
635
+ past_key_value = past_key_values[i] if past_key_values is not None else None
636
+
637
+ if self.gradient_checkpointing and self.training:
638
+ layer_outputs = self._gradient_checkpointing_func(
639
+ layer_module.__call__,
640
+ hidden_states,
641
+ attention_mask,
642
+ layer_head_mask,
643
+ encoder_hidden_states,
644
+ encoder_attention_mask,
645
+ past_key_value,
646
+ output_attentions,
647
+ )
648
+ else:
649
+ layer_outputs = layer_module(
650
+ hidden_states,
651
+ attention_mask,
652
+ layer_head_mask,
653
+ encoder_hidden_states,
654
+ encoder_attention_mask,
655
+ past_key_value,
656
+ output_attentions,
657
+ )
658
+
659
+ hidden_states = layer_outputs[0]
660
+ if use_cache:
661
+ next_decoder_cache += (layer_outputs[-1],)
662
+ if output_attentions:
663
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
664
+ if self.config.add_cross_attention:
665
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
666
+
667
+ if output_hidden_states:
668
+ all_hidden_states = all_hidden_states + (hidden_states,)
669
+
670
+ if not return_dict:
671
+ return tuple(
672
+ v
673
+ for v in [
674
+ hidden_states,
675
+ next_decoder_cache,
676
+ all_hidden_states,
677
+ all_self_attentions,
678
+ all_cross_attentions,
679
+ ]
680
+ if v is not None
681
+ )
682
+ return BaseModelOutputWithPastAndCrossAttentions(
683
+ last_hidden_state=hidden_states,
684
+ past_key_values=next_decoder_cache,
685
+ hidden_states=all_hidden_states,
686
+ attentions=all_self_attentions,
687
+ cross_attentions=all_cross_attentions,
688
+ )
689
+
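+ # A sketch of enabling the gradient checkpointing path used in the loop above (assumes a
+ # standard `transformers` install). Activations are recomputed during the backward pass to
+ # save memory, and as the warning above notes, this forces `use_cache=False` in training.
+ def _sketch_gradient_checkpointing():
+     from transformers import CamembertModel
+
+     model = CamembertModel.from_pretrained("almanach/camembert-base")
+     model.gradient_checkpointing_enable()
+     return model
+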
690
+
691
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
692
+ class CamembertPooler(nn.Module):
693
+ def __init__(self, config):
694
+ super().__init__()
695
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
696
+ self.activation = nn.Tanh()
697
+
698
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
699
+ # We "pool" the model by simply taking the hidden state corresponding
700
+ # to the first token.
701
+ first_token_tensor = hidden_states[:, 0]
702
+ pooled_output = self.dense(first_token_tensor)
703
+ pooled_output = self.activation(pooled_output)
704
+ return pooled_output
705
+
706
+
707
+ class CamembertPreTrainedModel(PreTrainedModel):
708
+ """
709
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
710
+ models.
711
+ """
712
+
713
+ config_class = CamembertConfig
714
+ base_model_prefix = "roberta"
715
+ supports_gradient_checkpointing = True
716
+ _supports_sdpa = True
717
+
718
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->CamembertLMHead
719
+ def _init_weights(self, module):
720
+ """Initialize the weights"""
721
+ if isinstance(module, nn.Linear):
722
+ # Slightly different from the TF version which uses truncated_normal for initialization
723
+ # cf https://github.com/pytorch/pytorch/pull/5617
724
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
725
+ if module.bias is not None:
726
+ module.bias.data.zero_()
727
+ elif isinstance(module, nn.Embedding):
728
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
729
+ if module.padding_idx is not None:
730
+ module.weight.data[module.padding_idx].zero_()
731
+ elif isinstance(module, nn.LayerNorm):
732
+ module.bias.data.zero_()
733
+ module.weight.data.fill_(1.0)
734
+ elif isinstance(module, CamembertLMHead):
735
+ module.bias.data.zero_()
736
+
737
+
738
+ CAMEMBERT_INPUTS_DOCSTRING = r"""
739
+ Args:
740
+ input_ids (`torch.LongTensor` of shape `({0})`):
741
+ Indices of input sequence tokens in the vocabulary.
742
+
743
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
744
+ [`PreTrainedTokenizer.__call__`] for details.
745
+
746
+ [What are input IDs?](../glossary#input-ids)
747
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
748
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
749
+
750
+ - 1 for tokens that are **not masked**,
751
+ - 0 for tokens that are **masked**.
752
+
753
+ [What are attention masks?](../glossary#attention-mask)
754
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
755
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
756
+ 1]`:
757
+
758
+ - 0 corresponds to a *sentence A* token,
759
+ - 1 corresponds to a *sentence B* token.
760
+
761
+ [What are token type IDs?](../glossary#token-type-ids)
762
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
763
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
764
+ config.max_position_embeddings - 1]`.
765
+
766
+ [What are position IDs?](../glossary#position-ids)
767
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
768
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
769
+
770
+ - 1 indicates the head is **not masked**,
771
+ - 0 indicates the head is **masked**.
772
+
773
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
774
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
775
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
776
+ model's internal embedding lookup matrix.
777
+ output_attentions (`bool`, *optional*):
778
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
779
+ tensors for more detail.
780
+ output_hidden_states (`bool`, *optional*):
781
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
782
+ more detail.
783
+ return_dict (`bool`, *optional*):
784
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
785
+ """
786
+
787
+
788
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Camembert
789
+ class CamembertClassificationHead(nn.Module):
790
+ """Head for sentence-level classification tasks."""
791
+
792
+ def __init__(self, config):
793
+ super().__init__()
794
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
795
+ classifier_dropout = (
796
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
797
+ )
798
+ self.dropout = nn.Dropout(classifier_dropout)
799
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
800
+
801
+ def forward(self, features, **kwargs):
802
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
803
+ x = self.dropout(x)
804
+ x = self.dense(x)
805
+ x = torch.tanh(x)
806
+ x = self.dropout(x)
807
+ x = self.out_proj(x)
808
+ return x
809
+
810
+
811
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Camembert
812
+ class CamembertLMHead(nn.Module):
813
+ """Camembert Head for masked language modeling."""
814
+
815
+ def __init__(self, config):
816
+ super().__init__()
817
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
818
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
819
+
820
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
821
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
822
+ self.decoder.bias = self.bias
823
+
824
+ def forward(self, features, **kwargs):
825
+ x = self.dense(features)
826
+ x = gelu(x)
827
+ x = self.layer_norm(x)
828
+
829
+ # project back to size of vocabulary with bias
830
+ x = self.decoder(x)
831
+
832
+ return x
833
+
834
+ def _tie_weights(self):
835
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
836
+ # For accelerate compatibility and to not break backward compatibility
837
+ if self.decoder.bias.device.type == "meta":
838
+ self.decoder.bias = self.bias
839
+ else:
840
+ self.bias = self.decoder.bias
841
+
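+ # A sketch verifying the tying that `_tie_weights` maintains (assumes a standard
+ # `transformers` install and the default `tie_word_embeddings=True`): the LM head decoder
+ # shares its weight matrix with the input embeddings, so only the bias is head-specific.
+ def _sketch_check_tied_weights():
+     from transformers import CamembertForMaskedLM
+
+     model = CamembertForMaskedLM.from_pretrained("almanach/camembert-base")
+     assert model.get_output_embeddings().weight is model.get_input_embeddings().weight
+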
842
+
843
+ @add_start_docstrings(
844
+ "The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
845
+ CAMEMBERT_START_DOCSTRING,
846
+ )
847
+ class CamembertModel(CamembertPreTrainedModel):
848
+ """
849
+
850
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
851
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
852
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
853
+ Kaiser and Illia Polosukhin.
854
+
855
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to
856
+ `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
857
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
858
+
859
+ .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
860
+
861
+ """
862
+
863
+ _no_split_modules = []
864
+
865
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.__init__ with Roberta->Camembert
866
+ def __init__(self, config, add_pooling_layer=True):
867
+ super().__init__(config)
868
+ self.config = config
869
+
870
+ self.embeddings = CamembertEmbeddings(config)
871
+ self.encoder = CamembertEncoder(config)
872
+
873
+ self.pooler = CamembertPooler(config) if add_pooling_layer else None
874
+
875
+ self.attn_implementation = config._attn_implementation
876
+ self.position_embedding_type = config.position_embedding_type
877
+
878
+ # Initialize weights and apply final processing
879
+ self.post_init()
880
+
881
+ def get_input_embeddings(self):
882
+ return self.embeddings.word_embeddings
883
+
884
+ def set_input_embeddings(self, value):
885
+ self.embeddings.word_embeddings = value
886
+
887
+ def _prune_heads(self, heads_to_prune):
888
+ """
889
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
890
+ base class PreTrainedModel.
891
+ """
892
+ for layer, heads in heads_to_prune.items():
893
+ self.encoder.layer[layer].attention.prune_heads(heads)
894
+
895
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
896
+ @add_code_sample_docstrings(
897
+ checkpoint=_CHECKPOINT_FOR_DOC,
898
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
899
+ config_class=_CONFIG_FOR_DOC,
900
+ )
901
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward
902
+ def forward(
903
+ self,
904
+ input_ids: Optional[torch.Tensor] = None,
905
+ attention_mask: Optional[torch.Tensor] = None,
906
+ token_type_ids: Optional[torch.Tensor] = None,
907
+ position_ids: Optional[torch.Tensor] = None,
908
+ head_mask: Optional[torch.Tensor] = None,
909
+ inputs_embeds: Optional[torch.Tensor] = None,
910
+ encoder_hidden_states: Optional[torch.Tensor] = None,
911
+ encoder_attention_mask: Optional[torch.Tensor] = None,
912
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
913
+ use_cache: Optional[bool] = None,
914
+ output_attentions: Optional[bool] = None,
915
+ output_hidden_states: Optional[bool] = None,
916
+ return_dict: Optional[bool] = None,
917
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
918
+ r"""
919
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
920
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
921
+ the model is configured as a decoder.
922
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
923
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
924
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
925
+
926
+ - 1 for tokens that are **not masked**,
927
+ - 0 for tokens that are **masked**.
928
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
929
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
930
+
931
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
932
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
933
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
934
+ use_cache (`bool`, *optional*):
935
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
936
+ `past_key_values`).
937
+ """
938
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
939
+ output_hidden_states = (
940
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
941
+ )
942
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
943
+
944
+ if self.config.is_decoder:
945
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
946
+ else:
947
+ use_cache = False
948
+
949
+ if input_ids is not None and inputs_embeds is not None:
950
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
951
+ elif input_ids is not None:
952
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
953
+ input_shape = input_ids.size()
954
+ elif inputs_embeds is not None:
955
+ input_shape = inputs_embeds.size()[:-1]
956
+ else:
957
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
958
+
959
+ batch_size, seq_length = input_shape
960
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
961
+
962
+ # past_key_values_length
963
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
964
+
965
+ if token_type_ids is None:
966
+ if hasattr(self.embeddings, "token_type_ids"):
967
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
968
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
969
+ token_type_ids = buffered_token_type_ids_expanded
970
+ else:
971
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
972
+
973
+ embedding_output = self.embeddings(
974
+ input_ids=input_ids,
975
+ position_ids=position_ids,
976
+ token_type_ids=token_type_ids,
977
+ inputs_embeds=inputs_embeds,
978
+ past_key_values_length=past_key_values_length,
979
+ )
980
+
981
+ if attention_mask is None:
982
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
983
+
984
+ use_sdpa_attention_masks = (
985
+ self.attn_implementation == "sdpa"
986
+ and self.position_embedding_type == "absolute"
987
+ and head_mask is None
988
+ and not output_attentions
989
+ )
990
+
991
+ # Expand the attention mask
992
+ if use_sdpa_attention_masks and attention_mask.dim() == 2:
993
+ # Expand the attention mask for SDPA.
994
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
995
+ if self.config.is_decoder:
996
+ extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
997
+ attention_mask,
998
+ input_shape,
999
+ embedding_output,
1000
+ past_key_values_length,
1001
+ )
1002
+ else:
1003
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1004
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
1005
+ )
1006
+ else:
1007
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1008
+ # ourselves in which case we just need to make it broadcastable to all heads.
1009
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
1010
+
1011
+ # If a 2D or 3D attention mask is provided for the cross-attention
1012
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1013
+ if self.config.is_decoder and encoder_hidden_states is not None:
1014
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1015
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1016
+ if encoder_attention_mask is None:
1017
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1018
+
1019
+ if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
1020
+ # Expand the attention mask for SDPA.
1021
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
1022
+ encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1023
+ encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
1024
+ )
1025
+ else:
1026
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1027
+ else:
1028
+ encoder_extended_attention_mask = None
1029
+
1030
+ # Prepare head mask if needed
1031
+ # 1.0 in head_mask indicate we keep the head
1032
+ # attention_probs has shape bsz x n_heads x N x N
1033
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1034
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1035
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1036
+
1037
+ encoder_outputs = self.encoder(
1038
+ embedding_output,
1039
+ attention_mask=extended_attention_mask,
1040
+ head_mask=head_mask,
1041
+ encoder_hidden_states=encoder_hidden_states,
1042
+ encoder_attention_mask=encoder_extended_attention_mask,
1043
+ past_key_values=past_key_values,
1044
+ use_cache=use_cache,
1045
+ output_attentions=output_attentions,
1046
+ output_hidden_states=output_hidden_states,
1047
+ return_dict=return_dict,
1048
+ )
1049
+ sequence_output = encoder_outputs[0]
1050
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1051
+
1052
+ if not return_dict:
1053
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1054
+
1055
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1056
+ last_hidden_state=sequence_output,
1057
+ pooler_output=pooled_output,
1058
+ past_key_values=encoder_outputs.past_key_values,
1059
+ hidden_states=encoder_outputs.hidden_states,
1060
+ attentions=encoder_outputs.attentions,
1061
+ cross_attentions=encoder_outputs.cross_attentions,
1062
+ )
1063
+
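+ # A sketch of the decoder setup described in the class docstring above (the helper name is
+ # ours; assumes a standard `transformers` install). With both flags set, cross-attention
+ # layers are added and `encoder_hidden_states` becomes a valid input to `forward()`.
+ def _sketch_decoder_setup():
+     config = CamembertConfig.from_pretrained(
+         "almanach/camembert-base", is_decoder=True, add_cross_attention=True
+     )
+     decoder = CamembertModel(config)  # cross-attention weights are randomly initialized
+     return decoder
+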
1064
+
1065
+ @add_start_docstrings(
1066
+ """CamemBERT Model with a `language modeling` head on top.""",
1067
+ CAMEMBERT_START_DOCSTRING,
1068
+ )
1069
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT
1070
+ class CamembertForMaskedLM(CamembertPreTrainedModel):
1071
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
1072
+
1073
+ def __init__(self, config):
1074
+ super().__init__(config)
1075
+
1076
+ if config.is_decoder:
1077
+ logger.warning(
1078
+ "If you want to use `CamembertForMaskedLM` make sure `config.is_decoder=False` for "
1079
+ "bi-directional self-attention."
1080
+ )
1081
+
1082
+ self.roberta = CamembertModel(config, add_pooling_layer=False)
1083
+ self.lm_head = CamembertLMHead(config)
1084
+
1085
+ # Initialize weights and apply final processing
1086
+ self.post_init()
1087
+
1088
+ def get_output_embeddings(self):
1089
+ return self.lm_head.decoder
1090
+
1091
+ def set_output_embeddings(self, new_embeddings):
1092
+ self.lm_head.decoder = new_embeddings
1093
+
1094
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1095
+ @add_code_sample_docstrings(
1096
+ checkpoint=_CHECKPOINT_FOR_DOC,
1097
+ output_type=MaskedLMOutput,
1098
+ config_class=_CONFIG_FOR_DOC,
1099
+ mask="<mask>",
1100
+ expected_output="' Paris'",
1101
+ expected_loss=0.1,
1102
+ )
1103
+ def forward(
1104
+ self,
1105
+ input_ids: Optional[torch.LongTensor] = None,
1106
+ attention_mask: Optional[torch.FloatTensor] = None,
1107
+ token_type_ids: Optional[torch.LongTensor] = None,
1108
+ position_ids: Optional[torch.LongTensor] = None,
1109
+ head_mask: Optional[torch.FloatTensor] = None,
1110
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1111
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1112
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1113
+ labels: Optional[torch.LongTensor] = None,
1114
+ output_attentions: Optional[bool] = None,
1115
+ output_hidden_states: Optional[bool] = None,
1116
+ return_dict: Optional[bool] = None,
1117
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1118
+ r"""
1119
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1120
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1121
+ config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
1122
+ (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1123
+ kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
1124
+ Used to hide legacy arguments that have been deprecated.
1125
+ """
1126
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1127
+
1128
+ outputs = self.roberta(
1129
+ input_ids,
1130
+ attention_mask=attention_mask,
1131
+ token_type_ids=token_type_ids,
1132
+ position_ids=position_ids,
1133
+ head_mask=head_mask,
1134
+ inputs_embeds=inputs_embeds,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ encoder_attention_mask=encoder_attention_mask,
1137
+ output_attentions=output_attentions,
1138
+ output_hidden_states=output_hidden_states,
1139
+ return_dict=return_dict,
1140
+ )
1141
+ sequence_output = outputs[0]
1142
+ prediction_scores = self.lm_head(sequence_output)
1143
+
1144
+ masked_lm_loss = None
1145
+ if labels is not None:
1146
+ # move labels to correct device to enable model parallelism
1147
+ labels = labels.to(prediction_scores.device)
1148
+ loss_fct = CrossEntropyLoss()
1149
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1150
+
1151
+ if not return_dict:
1152
+ output = (prediction_scores,) + outputs[2:]
1153
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1154
+
1155
+ return MaskedLMOutput(
1156
+ loss=masked_lm_loss,
1157
+ logits=prediction_scores,
1158
+ hidden_states=outputs.hidden_states,
1159
+ attentions=outputs.attentions,
1160
+ )
1161
+
1162
+
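+ # Illustrative usage sketch for the masked-LM head above (not part of the modeling file);
+ # the checkpoint name matches the one referenced elsewhere in this module:
+ #
+ #     from transformers import AutoTokenizer, CamembertForMaskedLM
+ #     import torch
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
+ #     model = CamembertForMaskedLM.from_pretrained("almanach/camembert-base")
+ #     inputs = tokenizer("Le camembert est <mask> !", return_tensors="pt")
+ #     with torch.no_grad():
+ #         logits = model(**inputs).logits
+ #     mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+ #     print(tokenizer.decode(logits[0, mask_index].argmax(dim=-1)))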
1163
+ @add_start_docstrings(
1164
+ """
1165
+ CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1166
+ pooled output) e.g. for GLUE tasks.
1167
+ """,
1168
+ CAMEMBERT_START_DOCSTRING,
1169
+ )
1170
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
1171
+ class CamembertForSequenceClassification(CamembertPreTrainedModel):
1172
+ def __init__(self, config):
1173
+ super().__init__(config)
1174
+ self.num_labels = config.num_labels
1175
+ self.config = config
1176
+
1177
+ self.roberta = CamembertModel(config, add_pooling_layer=False)
1178
+ self.classifier = CamembertClassificationHead(config)
1179
+
1180
+ # Initialize weights and apply final processing
1181
+ self.post_init()
1182
+
1183
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1184
+ @add_code_sample_docstrings(
1185
+ checkpoint="cardiffnlp/twitter-roberta-base-emotion",
1186
+ output_type=SequenceClassifierOutput,
1187
+ config_class=_CONFIG_FOR_DOC,
1188
+ expected_output="'optimism'",
1189
+ expected_loss=0.08,
1190
+ )
1191
+ def forward(
1192
+ self,
1193
+ input_ids: Optional[torch.LongTensor] = None,
1194
+ attention_mask: Optional[torch.FloatTensor] = None,
1195
+ token_type_ids: Optional[torch.LongTensor] = None,
1196
+ position_ids: Optional[torch.LongTensor] = None,
1197
+ head_mask: Optional[torch.FloatTensor] = None,
1198
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1199
+ labels: Optional[torch.LongTensor] = None,
1200
+ output_attentions: Optional[bool] = None,
1201
+ output_hidden_states: Optional[bool] = None,
1202
+ return_dict: Optional[bool] = None,
1203
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1204
+ r"""
1205
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1206
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1207
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1208
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1209
+ """
1210
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1211
+
1212
+ outputs = self.roberta(
1213
+ input_ids,
1214
+ attention_mask=attention_mask,
1215
+ token_type_ids=token_type_ids,
1216
+ position_ids=position_ids,
1217
+ head_mask=head_mask,
1218
+ inputs_embeds=inputs_embeds,
1219
+ output_attentions=output_attentions,
1220
+ output_hidden_states=output_hidden_states,
1221
+ return_dict=return_dict,
1222
+ )
1223
+ sequence_output = outputs[0]
1224
+ logits = self.classifier(sequence_output)
1225
+
1226
+ loss = None
1227
+ if labels is not None:
1228
+ # move labels to correct device to enable model parallelism
1229
+ labels = labels.to(logits.device)
1230
+ if self.config.problem_type is None:
1231
+ if self.num_labels == 1:
1232
+ self.config.problem_type = "regression"
1233
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1234
+ self.config.problem_type = "single_label_classification"
1235
+ else:
1236
+ self.config.problem_type = "multi_label_classification"
1237
+
1238
+ if self.config.problem_type == "regression":
1239
+ loss_fct = MSELoss()
1240
+ if self.num_labels == 1:
1241
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1242
+ else:
1243
+ loss = loss_fct(logits, labels)
1244
+ elif self.config.problem_type == "single_label_classification":
1245
+ loss_fct = CrossEntropyLoss()
1246
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1247
+ elif self.config.problem_type == "multi_label_classification":
1248
+ loss_fct = BCEWithLogitsLoss()
1249
+ loss = loss_fct(logits, labels)
1250
+
1251
+ if not return_dict:
1252
+ output = (logits,) + outputs[2:]
1253
+ return ((loss,) + output) if loss is not None else output
1254
+
1255
+ return SequenceClassifierOutput(
1256
+ loss=loss,
1257
+ logits=logits,
1258
+ hidden_states=outputs.hidden_states,
1259
+ attentions=outputs.attentions,
1260
+ )
1261
+
1262
+
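+ # Sketch of the `problem_type` dispatch above (illustrative label tensors, not part of the
+ # file): float labels with `num_labels == 1` select regression (MSE), integer labels with
+ # `num_labels > 1` select single-label classification (cross-entropy), and float multi-hot
+ # labels select multi-label classification (BCE with logits):
+ #
+ #     torch.tensor([0.7, 1.2])           # num_labels == 1, float  -> MSELoss
+ #     torch.tensor([2, 0])               # num_labels > 1, long    -> CrossEntropyLoss
+ #     torch.tensor([[1.0, 0.0, 1.0]])    # num_labels > 1, float   -> BCEWithLogitsLoss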
1263
+ @add_start_docstrings(
1264
+ """
1265
+ CamemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1266
+ softmax) e.g. for RocStories/SWAG tasks.
1267
+ """,
1268
+ CAMEMBERT_START_DOCSTRING,
1269
+ )
1270
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT
1271
+ class CamembertForMultipleChoice(CamembertPreTrainedModel):
1272
+ def __init__(self, config):
1273
+ super().__init__(config)
1274
+
1275
+ self.roberta = CamembertModel(config)
1276
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1277
+ self.classifier = nn.Linear(config.hidden_size, 1)
1278
+
1279
+ # Initialize weights and apply final processing
1280
+ self.post_init()
1281
+
1282
+ @add_start_docstrings_to_model_forward(
1283
+ CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1284
+ )
1285
+ @add_code_sample_docstrings(
1286
+ checkpoint=_CHECKPOINT_FOR_DOC,
1287
+ output_type=MultipleChoiceModelOutput,
1288
+ config_class=_CONFIG_FOR_DOC,
1289
+ )
1290
+ def forward(
1291
+ self,
1292
+ input_ids: Optional[torch.LongTensor] = None,
1293
+ token_type_ids: Optional[torch.LongTensor] = None,
1294
+ attention_mask: Optional[torch.FloatTensor] = None,
1295
+ labels: Optional[torch.LongTensor] = None,
1296
+ position_ids: Optional[torch.LongTensor] = None,
1297
+ head_mask: Optional[torch.FloatTensor] = None,
1298
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1299
+ output_attentions: Optional[bool] = None,
1300
+ output_hidden_states: Optional[bool] = None,
1301
+ return_dict: Optional[bool] = None,
1302
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1303
+ r"""
1304
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1305
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1306
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1307
+ `input_ids` above)
1308
+ """
1309
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1310
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1311
+
1312
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1313
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1314
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1315
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1316
+ flat_inputs_embeds = (
1317
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1318
+ if inputs_embeds is not None
1319
+ else None
1320
+ )
1321
+
1322
+ outputs = self.roberta(
1323
+ flat_input_ids,
1324
+ position_ids=flat_position_ids,
1325
+ token_type_ids=flat_token_type_ids,
1326
+ attention_mask=flat_attention_mask,
1327
+ head_mask=head_mask,
1328
+ inputs_embeds=flat_inputs_embeds,
1329
+ output_attentions=output_attentions,
1330
+ output_hidden_states=output_hidden_states,
1331
+ return_dict=return_dict,
1332
+ )
1333
+ pooled_output = outputs[1]
1334
+
1335
+ pooled_output = self.dropout(pooled_output)
1336
+ logits = self.classifier(pooled_output)
1337
+ reshaped_logits = logits.view(-1, num_choices)
1338
+
1339
+ loss = None
1340
+ if labels is not None:
1341
+ # move labels to correct device to enable model parallelism
1342
+ labels = labels.to(reshaped_logits.device)
1343
+ loss_fct = CrossEntropyLoss()
1344
+ loss = loss_fct(reshaped_logits, labels)
1345
+
1346
+ if not return_dict:
1347
+ output = (reshaped_logits,) + outputs[2:]
1348
+ return ((loss,) + output) if loss is not None else output
1349
+
1350
+ return MultipleChoiceModelOutput(
1351
+ loss=loss,
1352
+ logits=reshaped_logits,
1353
+ hidden_states=outputs.hidden_states,
1354
+ attentions=outputs.attentions,
1355
+ )
1356
+
1357
+
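+ # Shape walk-through for the flattening above (illustrative values): inputs of shape
+ # (batch_size, num_choices, seq_len) are folded into the batch dimension for the encoder,
+ # and the per-choice scores are unfolded again for the loss:
+ #
+ #     input_ids: (2, 4, 16) -> flat_input_ids: (8, 16)
+ #     classifier logits: (8, 1) -> reshaped_logits: (2, 4) -> CrossEntropyLoss over 4 choices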
1358
+ @add_start_docstrings(
1359
+ """
1360
+ CamemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
1361
+ for Named-Entity-Recognition (NER) tasks.
1362
+ """,
1363
+ CAMEMBERT_START_DOCSTRING,
1364
+ )
1365
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
1366
+ class CamembertForTokenClassification(CamembertPreTrainedModel):
1367
+ def __init__(self, config):
1368
+ super().__init__(config)
1369
+ self.num_labels = config.num_labels
1370
+
1371
+ self.roberta = CamembertModel(config, add_pooling_layer=False)
1372
+ classifier_dropout = (
1373
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1374
+ )
1375
+ self.dropout = nn.Dropout(classifier_dropout)
1376
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1377
+
1378
+ # Initialize weights and apply final processing
1379
+ self.post_init()
1380
+
1381
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1382
+ @add_code_sample_docstrings(
1383
+ checkpoint="Jean-Baptiste/roberta-large-ner-english",
1384
+ output_type=TokenClassifierOutput,
1385
+ config_class=_CONFIG_FOR_DOC,
1386
+ expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']",
1387
+ expected_loss=0.01,
1388
+ )
1389
+ def forward(
1390
+ self,
1391
+ input_ids: Optional[torch.LongTensor] = None,
1392
+ attention_mask: Optional[torch.FloatTensor] = None,
1393
+ token_type_ids: Optional[torch.LongTensor] = None,
1394
+ position_ids: Optional[torch.LongTensor] = None,
1395
+ head_mask: Optional[torch.FloatTensor] = None,
1396
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1397
+ labels: Optional[torch.LongTensor] = None,
1398
+ output_attentions: Optional[bool] = None,
1399
+ output_hidden_states: Optional[bool] = None,
1400
+ return_dict: Optional[bool] = None,
1401
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1402
+ r"""
1403
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1404
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1405
+ """
1406
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1407
+
1408
+ outputs = self.roberta(
1409
+ input_ids,
1410
+ attention_mask=attention_mask,
1411
+ token_type_ids=token_type_ids,
1412
+ position_ids=position_ids,
1413
+ head_mask=head_mask,
1414
+ inputs_embeds=inputs_embeds,
1415
+ output_attentions=output_attentions,
1416
+ output_hidden_states=output_hidden_states,
1417
+ return_dict=return_dict,
1418
+ )
1419
+
1420
+ sequence_output = outputs[0]
1421
+
1422
+ sequence_output = self.dropout(sequence_output)
1423
+ logits = self.classifier(sequence_output)
1424
+
1425
+ loss = None
1426
+ if labels is not None:
1427
+ # move labels to correct device to enable model parallelism
1428
+ labels = labels.to(logits.device)
1429
+ loss_fct = CrossEntropyLoss()
1430
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1431
+
1432
+ if not return_dict:
1433
+ output = (logits,) + outputs[2:]
1434
+ return ((loss,) + output) if loss is not None else output
1435
+
1436
+ return TokenClassifierOutput(
1437
+ loss=loss,
1438
+ logits=logits,
1439
+ hidden_states=outputs.hidden_states,
1440
+ attentions=outputs.attentions,
1441
+ )
1442
+
1443
+
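+ # Illustrative post-processing sketch for the token-classification head above (not part of
+ # the file): per-token predictions are read off the logits and mapped through `id2label`:
+ #
+ #     predictions = logits.argmax(dim=-1)  # (batch_size, seq_len)
+ #     tags = [model.config.id2label[i.item()] for i in predictions[0]]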
1444
+ @add_start_docstrings(
1445
+ """
1446
+ CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1447
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1448
+ """,
1449
+ CAMEMBERT_START_DOCSTRING,
1450
+ )
1451
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT
1452
+ class CamembertForQuestionAnswering(CamembertPreTrainedModel):
1453
+ def __init__(self, config):
1454
+ super().__init__(config)
1455
+ self.num_labels = config.num_labels
1456
+
1457
+ self.roberta = CamembertModel(config, add_pooling_layer=False)
1458
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1459
+
1460
+ # Initialize weights and apply final processing
1461
+ self.post_init()
1462
+
1463
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1464
+ @add_code_sample_docstrings(
1465
+ checkpoint="deepset/roberta-base-squad2",
1466
+ output_type=QuestionAnsweringModelOutput,
1467
+ config_class=_CONFIG_FOR_DOC,
1468
+ expected_output="' puppet'",
1469
+ expected_loss=0.86,
1470
+ )
1471
+ def forward(
1472
+ self,
1473
+ input_ids: Optional[torch.LongTensor] = None,
1474
+ attention_mask: Optional[torch.FloatTensor] = None,
1475
+ token_type_ids: Optional[torch.LongTensor] = None,
1476
+ position_ids: Optional[torch.LongTensor] = None,
1477
+ head_mask: Optional[torch.FloatTensor] = None,
1478
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1479
+ start_positions: Optional[torch.LongTensor] = None,
1480
+ end_positions: Optional[torch.LongTensor] = None,
1481
+ output_attentions: Optional[bool] = None,
1482
+ output_hidden_states: Optional[bool] = None,
1483
+ return_dict: Optional[bool] = None,
1484
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1485
+ r"""
1486
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1487
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1488
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1489
+ are not taken into account for computing the loss.
1490
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1491
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1492
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1493
+ are not taken into account for computing the loss.
1494
+ """
1495
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1496
+
1497
+ outputs = self.roberta(
1498
+ input_ids,
1499
+ attention_mask=attention_mask,
1500
+ token_type_ids=token_type_ids,
1501
+ position_ids=position_ids,
1502
+ head_mask=head_mask,
1503
+ inputs_embeds=inputs_embeds,
1504
+ output_attentions=output_attentions,
1505
+ output_hidden_states=output_hidden_states,
1506
+ return_dict=return_dict,
1507
+ )
1508
+
1509
+ sequence_output = outputs[0]
1510
+
1511
+ logits = self.qa_outputs(sequence_output)
1512
+ start_logits, end_logits = logits.split(1, dim=-1)
1513
+ start_logits = start_logits.squeeze(-1).contiguous()
1514
+ end_logits = end_logits.squeeze(-1).contiguous()
1515
+
1516
+ total_loss = None
1517
+ if start_positions is not None and end_positions is not None:
1518
+ # If we are on multi-GPU, splitting can add a dimension; squeeze it away
1519
+ if len(start_positions.size()) > 1:
1520
+ start_positions = start_positions.squeeze(-1)
1521
+ if len(end_positions.size()) > 1:
1522
+ end_positions = end_positions.squeeze(-1)
1523
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1524
+ ignored_index = start_logits.size(1)
1525
+ start_positions = start_positions.clamp(0, ignored_index)
1526
+ end_positions = end_positions.clamp(0, ignored_index)
1527
+
1528
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1529
+ start_loss = loss_fct(start_logits, start_positions)
1530
+ end_loss = loss_fct(end_logits, end_positions)
1531
+ total_loss = (start_loss + end_loss) / 2
1532
+
1533
+ if not return_dict:
1534
+ output = (start_logits, end_logits) + outputs[2:]
1535
+ return ((total_loss,) + output) if total_loss is not None else output
1536
+
1537
+ return QuestionAnsweringModelOutput(
1538
+ loss=total_loss,
1539
+ start_logits=start_logits,
1540
+ end_logits=end_logits,
1541
+ hidden_states=outputs.hidden_states,
1542
+ attentions=outputs.attentions,
1543
+ )
1544
+
1545
+
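+ # Illustrative span decoding for the QA head above (single example, not part of the file):
+ # the most likely answer span is read off the start/end logits and mapped back through the
+ # tokenizer:
+ #
+ #     start = int(start_logits[0].argmax())
+ #     end = int(end_logits[0].argmax())
+ #     answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])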
1546
+ @add_start_docstrings(
1547
+ """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING
1548
+ )
1549
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, FacebookAI/roberta-base->almanach/camembert-base
1550
+ class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin):
1551
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
1552
+
1553
+ def __init__(self, config):
1554
+ super().__init__(config)
1555
+
1556
+ if not config.is_decoder:
1557
+ logger.warning("If you want to use `CamembertLMHeadModel` as a standalone, add `is_decoder=True.`")
1558
+
1559
+ self.roberta = CamembertModel(config, add_pooling_layer=False)
1560
+ self.lm_head = CamembertLMHead(config)
1561
+
1562
+ # Initialize weights and apply final processing
1563
+ self.post_init()
1564
+
1565
+ def get_output_embeddings(self):
1566
+ return self.lm_head.decoder
1567
+
1568
+ def set_output_embeddings(self, new_embeddings):
1569
+ self.lm_head.decoder = new_embeddings
1570
+
1571
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1572
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1573
+ def forward(
1574
+ self,
1575
+ input_ids: Optional[torch.LongTensor] = None,
1576
+ attention_mask: Optional[torch.FloatTensor] = None,
1577
+ token_type_ids: Optional[torch.LongTensor] = None,
1578
+ position_ids: Optional[torch.LongTensor] = None,
1579
+ head_mask: Optional[torch.FloatTensor] = None,
1580
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1581
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1582
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1583
+ labels: Optional[torch.LongTensor] = None,
1584
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
1585
+ use_cache: Optional[bool] = None,
1586
+ output_attentions: Optional[bool] = None,
1587
+ output_hidden_states: Optional[bool] = None,
1588
+ return_dict: Optional[bool] = None,
1589
+ **kwargs,
1590
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1591
+ r"""
1592
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1593
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1594
+ the model is configured as a decoder.
1595
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1596
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1597
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1598
+
1599
+ - 1 for tokens that are **not masked**,
1600
+ - 0 for tokens that are **masked**.
1601
+
1602
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1603
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1604
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1605
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1606
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1607
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1608
+
1609
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1610
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1611
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1612
+ use_cache (`bool`, *optional*):
1613
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1614
+ `past_key_values`).
1615
+
1616
+ Returns:
1617
+
1618
+ Example:
1619
+
1620
+ ```python
1621
+ >>> from transformers import AutoTokenizer, CamembertForCausalLM, AutoConfig
1622
+ >>> import torch
1623
+
1624
+ >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
1625
+ >>> config = AutoConfig.from_pretrained("almanach/camembert-base")
1626
+ >>> config.is_decoder = True
1627
+ >>> model = CamembertForCausalLM.from_pretrained("almanach/camembert-base", config=config)
1628
+
1629
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1630
+ >>> outputs = model(**inputs)
1631
+
1632
+ >>> prediction_logits = outputs.logits
1633
+ ```"""
1634
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1635
+ if labels is not None:
1636
+ use_cache = False
1637
+
1638
+ outputs = self.roberta(
1639
+ input_ids,
1640
+ attention_mask=attention_mask,
1641
+ token_type_ids=token_type_ids,
1642
+ position_ids=position_ids,
1643
+ head_mask=head_mask,
1644
+ inputs_embeds=inputs_embeds,
1645
+ encoder_hidden_states=encoder_hidden_states,
1646
+ encoder_attention_mask=encoder_attention_mask,
1647
+ past_key_values=past_key_values,
1648
+ use_cache=use_cache,
1649
+ output_attentions=output_attentions,
1650
+ output_hidden_states=output_hidden_states,
1651
+ return_dict=return_dict,
1652
+ )
1653
+
1654
+ sequence_output = outputs[0]
1655
+ prediction_scores = self.lm_head(sequence_output)
1656
+
1657
+ lm_loss = None
1658
+ if labels is not None:
1659
+ # move labels to correct device to enable model parallelism
1660
+ labels = labels.to(prediction_scores.device)
1661
+ lm_loss = self.loss_function(
1662
+ prediction_scores,
1663
+ labels,
1664
+ vocab_size=self.config.vocab_size,
1665
+ **kwargs,
1666
+ )
1667
+
1668
+ if not return_dict:
1669
+ output = (prediction_scores,) + outputs[2:]
1670
+ return ((lm_loss,) + output) if lm_loss is not None else output
1671
+
1672
+ return CausalLMOutputWithCrossAttentions(
1673
+ loss=lm_loss,
1674
+ logits=prediction_scores,
1675
+ past_key_values=outputs.past_key_values,
1676
+ hidden_states=outputs.hidden_states,
1677
+ attentions=outputs.attentions,
1678
+ cross_attentions=outputs.cross_attentions,
1679
+ )
1680
+
1681
+ def _reorder_cache(self, past_key_values, beam_idx):
1682
+ reordered_past = ()
1683
+ for layer_past in past_key_values:
1684
+ reordered_past += (
1685
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1686
+ )
1687
+ return reordered_past
1688
+
1689
+
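+ # Worked example for `_reorder_cache` (hypothetical beam indices): during beam search,
+ # `beam_idx` names the surviving beams, so every cached tensor of shape
+ # (batch_size, num_heads, seq_len, head_dim) is gathered along dim 0:
+ #
+ #     beam_idx = torch.tensor([2, 0, 0, 1])
+ #     past_state.index_select(0, beam_idx)  # cache rows now follow the surviving beams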
1690
+ # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
1691
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
1692
+ """
1693
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
1694
+ are ignored. This is modified from fairseq's `utils.make_positions`.
1695
+
1696
+ Args:
1697
+ input_ids (`torch.Tensor`): The input token ids.
+ padding_idx (`int`): The index of the padding token.
+ past_key_values_length (`int`, *optional*, defaults to 0): The length of any cached past key values.
1698
+
1699
+ Returns: torch.Tensor
1700
+ """
1701
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
1702
+ mask = input_ids.ne(padding_idx).int()
1703
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
1704
+ return incremental_indices.long() + padding_idx
1705
+
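+ # Worked example (padding_idx = 1): real tokens count up from padding_idx + 1 while padded
+ # positions stay at the padding index:
+ #
+ #     input_ids = torch.tensor([[5, 8, 9, 1, 1]])
+ #     create_position_ids_from_input_ids(input_ids, padding_idx=1)
+ #     # -> tensor([[2, 3, 4, 1, 1]])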
1706
+
1707
+ __all__ = [
1708
+ "CamembertForCausalLM",
1709
+ "CamembertForMaskedLM",
1710
+ "CamembertForMultipleChoice",
1711
+ "CamembertForQuestionAnswering",
1712
+ "CamembertForSequenceClassification",
1713
+ "CamembertForTokenClassification",
1714
+ "CamembertModel",
1715
+ "CamembertPreTrainedModel",
1716
+ ]
docs/transformers/src/transformers/models/camembert/modeling_tf_camembert.py ADDED
@@ -0,0 +1,1801 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """TF 2.0 CamemBERT model."""
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ import warnings
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ import tensorflow as tf
26
+
27
+ from ...activations_tf import get_tf_activation
28
+ from ...modeling_tf_outputs import (
29
+ TFBaseModelOutputWithPastAndCrossAttentions,
30
+ TFBaseModelOutputWithPoolingAndCrossAttentions,
31
+ TFCausalLMOutputWithCrossAttentions,
32
+ TFMaskedLMOutput,
33
+ TFMultipleChoiceModelOutput,
34
+ TFQuestionAnsweringModelOutput,
35
+ TFSequenceClassifierOutput,
36
+ TFTokenClassifierOutput,
37
+ )
38
+ from ...modeling_tf_utils import (
39
+ TFCausalLanguageModelingLoss,
40
+ TFMaskedLanguageModelingLoss,
41
+ TFModelInputType,
42
+ TFMultipleChoiceLoss,
43
+ TFPreTrainedModel,
44
+ TFQuestionAnsweringLoss,
45
+ TFSequenceClassificationLoss,
46
+ TFTokenClassificationLoss,
47
+ get_initializer,
48
+ keras,
49
+ keras_serializable,
50
+ unpack_inputs,
51
+ )
52
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
53
+ from ...utils import (
54
+ add_code_sample_docstrings,
55
+ add_start_docstrings,
56
+ add_start_docstrings_to_model_forward,
57
+ logging,
58
+ )
59
+ from .configuration_camembert import CamembertConfig
60
+
61
+
62
+ logger = logging.get_logger(__name__)
63
+
64
+ _CHECKPOINT_FOR_DOC = "almanach/camembert-base"
65
+ _CONFIG_FOR_DOC = "CamembertConfig"
66
+
67
+
68
+ CAMEMBERT_START_DOCSTRING = r"""
69
+
70
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
71
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
72
+ etc.)
73
+
74
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
75
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
76
+ behavior.
77
+
78
+ <Tip>
79
+
80
+ TensorFlow models and layers in `transformers` accept two formats as input:
81
+
82
+ - having all inputs as keyword arguments (like PyTorch models), or
83
+ - having all inputs as a list, tuple or dict in the first positional argument.
84
+
85
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
86
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
87
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
88
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
89
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
90
+ positional argument:
91
+
92
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
93
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
94
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
95
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
96
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
97
+
98
+ Note that when creating models and layers with
99
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
100
+ about any of this, as you can just pass inputs like you would to any other Python function!
101
+
102
+ </Tip>
103
+
104
+ Parameters:
105
+ config ([`CamembertConfig`]): Model configuration class with all the parameters of the
106
+ model. Initializing with a config file does not load the weights associated with the model, only the
107
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
108
+ """
109
+
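+ # Input-format sketch for the Tip above (illustrative tensors): the same forward call in
+ # the three accepted TF formats:
+ #
+ #     model(input_ids)                                                    # single tensor
+ #     model([input_ids, attention_mask])                                  # list, docstring order
+ #     model({"input_ids": input_ids, "attention_mask": attention_mask})  # dict keyed by name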
110
+ CAMEMBERT_INPUTS_DOCSTRING = r"""
111
+ Args:
112
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
113
+ Indices of input sequence tokens in the vocabulary.
114
+
115
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
116
+ [`PreTrainedTokenizer.encode`] for details.
117
+
118
+ [What are input IDs?](../glossary#input-ids)
119
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
120
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
121
+
122
+ - 1 for tokens that are **not masked**,
123
+ - 0 for tokens that are **masked**.
124
+
125
+ [What are attention masks?](../glossary#attention-mask)
126
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
127
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
128
+ 1]`:
129
+
130
+ - 0 corresponds to a *sentence A* token,
131
+ - 1 corresponds to a *sentence B* token.
132
+
133
+ [What are token type IDs?](../glossary#token-type-ids)
134
+ position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
135
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
136
+ config.max_position_embeddings - 1]`.
137
+
138
+ [What are position IDs?](../glossary#position-ids)
139
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
140
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
141
+
142
+ - 1 indicates the head is **not masked**,
143
+ - 0 indicates the head is **masked**.
144
+
145
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
146
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
147
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
148
+ model's internal embedding lookup matrix.
149
+ output_attentions (`bool`, *optional*):
150
+ Whether or not to return the attention tensors of all attention layers. See `attentions` under returned
151
+ tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
152
+ config will be used instead.
153
+ output_hidden_states (`bool`, *optional*):
154
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
155
+ more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
156
+ used instead.
157
+ return_dict (`bool`, *optional*):
158
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
159
+ eager mode; in graph mode the value will always be set to `True`.
160
+ training (`bool`, *optional*, defaults to `False`):
161
+ Whether or not to use the model in training mode (some modules like dropout modules have different
162
+ behaviors between training and evaluation).
163
+ """
164
+
165
+
166
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings
167
+ class TFCamembertEmbeddings(keras.layers.Layer):
168
+ """
169
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
170
+ """
171
+
172
+ def __init__(self, config, **kwargs):
173
+ super().__init__(**kwargs)
174
+
175
+ self.padding_idx = 1
176
+ self.config = config
177
+ self.hidden_size = config.hidden_size
178
+ self.max_position_embeddings = config.max_position_embeddings
179
+ self.initializer_range = config.initializer_range
180
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
181
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
182
+
183
+ def build(self, input_shape=None):
184
+ with tf.name_scope("word_embeddings"):
185
+ self.weight = self.add_weight(
186
+ name="weight",
187
+ shape=[self.config.vocab_size, self.hidden_size],
188
+ initializer=get_initializer(self.initializer_range),
189
+ )
190
+
191
+ with tf.name_scope("token_type_embeddings"):
192
+ self.token_type_embeddings = self.add_weight(
193
+ name="embeddings",
194
+ shape=[self.config.type_vocab_size, self.hidden_size],
195
+ initializer=get_initializer(self.initializer_range),
196
+ )
197
+
198
+ with tf.name_scope("position_embeddings"):
199
+ self.position_embeddings = self.add_weight(
200
+ name="embeddings",
201
+ shape=[self.max_position_embeddings, self.hidden_size],
202
+ initializer=get_initializer(self.initializer_range),
203
+ )
204
+
205
+ if self.built:
206
+ return
207
+ self.built = True
208
+ if getattr(self, "LayerNorm", None) is not None:
209
+ with tf.name_scope(self.LayerNorm.name):
210
+ self.LayerNorm.build([None, None, self.config.hidden_size])
211
+
212
+ def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
213
+ """
214
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
215
+ symbols are ignored. This is modified from fairseq's `utils.make_positions`.
216
+
217
+ Args:
218
+ input_ids: tf.Tensor
219
+ Returns: tf.Tensor
220
+ """
221
+ mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
222
+ incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
223
+
224
+ return incremental_indices + self.padding_idx
225
+
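+ # Mirrors the PyTorch helper above: for input_ids [[5, 8, 9, 1, 1]] and padding_idx 1,
+ # the returned position ids are [[2, 3, 4, 1, 1]] (illustrative values).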
226
+ def call(
227
+ self,
228
+ input_ids=None,
229
+ position_ids=None,
230
+ token_type_ids=None,
231
+ inputs_embeds=None,
232
+ past_key_values_length=0,
233
+ training=False,
234
+ ):
235
+ """
236
+ Applies embedding based on inputs tensor.
237
+
238
+ Returns:
239
+ final_embeddings (`tf.Tensor`): output embedding tensor.
240
+ """
241
+ assert not (input_ids is None and inputs_embeds is None)
242
+
243
+ if input_ids is not None:
244
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
245
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
246
+
247
+ input_shape = shape_list(inputs_embeds)[:-1]
248
+
249
+ if token_type_ids is None:
250
+ token_type_ids = tf.fill(dims=input_shape, value=0)
251
+
252
+ if position_ids is None:
253
+ if input_ids is not None:
254
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
255
+ position_ids = self.create_position_ids_from_input_ids(
256
+ input_ids=input_ids, past_key_values_length=past_key_values_length
257
+ )
258
+ else:
259
+ position_ids = tf.expand_dims(
260
+ tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
261
+ )
262
+
263
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
264
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
265
+ final_embeddings = inputs_embeds + position_embeds + token_type_embeds
266
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
267
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
268
+
269
+ return final_embeddings
270
+
271
+
272
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Camembert
273
+ class TFCamembertPooler(keras.layers.Layer):
274
+ def __init__(self, config: CamembertConfig, **kwargs):
275
+ super().__init__(**kwargs)
276
+
277
+ self.dense = keras.layers.Dense(
278
+ units=config.hidden_size,
279
+ kernel_initializer=get_initializer(config.initializer_range),
280
+ activation="tanh",
281
+ name="dense",
282
+ )
283
+ self.config = config
284
+
285
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
286
+ # We "pool" the model by simply taking the hidden state corresponding
287
+ # to the first token.
288
+ first_token_tensor = hidden_states[:, 0]
289
+ pooled_output = self.dense(inputs=first_token_tensor)
290
+
291
+ return pooled_output
292
+
293
+ def build(self, input_shape=None):
294
+ if self.built:
295
+ return
296
+ self.built = True
297
+ if getattr(self, "dense", None) is not None:
298
+ with tf.name_scope(self.dense.name):
299
+ self.dense.build([None, None, self.config.hidden_size])
300
+
301
+
302
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Camembert
303
+ class TFCamembertSelfAttention(keras.layers.Layer):
304
+ def __init__(self, config: CamembertConfig, **kwargs):
305
+ super().__init__(**kwargs)
306
+
307
+ if config.hidden_size % config.num_attention_heads != 0:
308
+ raise ValueError(
309
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
310
+ f"of attention heads ({config.num_attention_heads})"
311
+ )
312
+
313
+ self.num_attention_heads = config.num_attention_heads
314
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
315
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
316
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
317
+
318
+ self.query = keras.layers.Dense(
319
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
320
+ )
321
+ self.key = keras.layers.Dense(
322
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
323
+ )
324
+ self.value = keras.layers.Dense(
325
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
326
+ )
327
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
328
+
329
+ self.is_decoder = config.is_decoder
330
+ self.config = config
331
+
332
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
333
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
334
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
335
+
336
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
337
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
338
+
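+ # Shape example for transpose_for_scores (illustrative, base-sized config: 12 heads of 64):
+ #
+ #     (batch, seq, 768) --reshape--> (batch, seq, 12, 64) --transpose--> (batch, 12, seq, 64)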
339
+ def call(
340
+ self,
341
+ hidden_states: tf.Tensor,
342
+ attention_mask: tf.Tensor,
343
+ head_mask: tf.Tensor,
344
+ encoder_hidden_states: tf.Tensor,
345
+ encoder_attention_mask: tf.Tensor,
346
+ past_key_value: Tuple[tf.Tensor],
347
+ output_attentions: bool,
348
+ training: bool = False,
349
+ ) -> Tuple[tf.Tensor]:
350
+ batch_size = shape_list(hidden_states)[0]
351
+ mixed_query_layer = self.query(inputs=hidden_states)
352
+
353
+ # If this is instantiated as a cross-attention module, the keys
354
+ # and values come from an encoder; the attention mask needs to be
355
+ # such that the encoder's padding tokens are not attended to.
356
+ is_cross_attention = encoder_hidden_states is not None
357
+
358
+ if is_cross_attention and past_key_value is not None:
359
+ # reuse the cached cross-attention key/value states
360
+ key_layer = past_key_value[0]
361
+ value_layer = past_key_value[1]
362
+ attention_mask = encoder_attention_mask
363
+ elif is_cross_attention:
364
+ key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
365
+ value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
366
+ attention_mask = encoder_attention_mask
367
+ elif past_key_value is not None:
368
+ key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
369
+ value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
370
+ key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
371
+ value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
372
+ else:
373
+ key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
374
+ value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
375
+
376
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
377
+
378
+ if self.is_decoder:
379
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
380
+ # Further calls to cross_attention layer can then reuse all cross-attention
381
+ # key/value_states (first "if" case)
382
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
383
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
384
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
385
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
386
+ past_key_value = (key_layer, value_layer)
387
+
388
+ # Take the dot product between "query" and "key" to get the raw attention scores.
389
+ # (batch size, num_heads, seq_len_q, seq_len_k)
390
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
391
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
392
+ attention_scores = tf.divide(attention_scores, dk)
393
+
394
+ if attention_mask is not None:
395
+ # Apply the attention mask (precomputed for all layers in the TFCamembertModel call() function)
396
+ attention_scores = tf.add(attention_scores, attention_mask)
397
+
398
+ # Normalize the attention scores to probabilities.
399
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
400
+
401
+ # This is actually dropping out entire tokens to attend to, which might
402
+ # seem a bit unusual, but is taken from the original Transformer paper.
403
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
404
+
405
+ # Mask heads if we want to
406
+ if head_mask is not None:
407
+ attention_probs = tf.multiply(attention_probs, head_mask)
408
+
409
+ attention_output = tf.matmul(attention_probs, value_layer)
410
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
411
+
412
+ # (batch_size, seq_len_q, all_head_size)
413
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
414
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
415
+
416
+ if self.is_decoder:
417
+ outputs = outputs + (past_key_value,)
418
+ return outputs
419
+
420
+ def build(self, input_shape=None):
421
+ if self.built:
422
+ return
423
+ self.built = True
424
+ if getattr(self, "query", None) is not None:
425
+ with tf.name_scope(self.query.name):
426
+ self.query.build([None, None, self.config.hidden_size])
427
+ if getattr(self, "key", None) is not None:
428
+ with tf.name_scope(self.key.name):
429
+ self.key.build([None, None, self.config.hidden_size])
430
+ if getattr(self, "value", None) is not None:
431
+ with tf.name_scope(self.value.name):
432
+ self.value.build([None, None, self.config.hidden_size])
433
+
434
+
435
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Camembert
436
+ class TFCamembertSelfOutput(keras.layers.Layer):
437
+ def __init__(self, config: CamembertConfig, **kwargs):
438
+ super().__init__(**kwargs)
439
+
440
+ self.dense = keras.layers.Dense(
441
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
442
+ )
443
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
444
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
445
+ self.config = config
446
+
447
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
448
+ hidden_states = self.dense(inputs=hidden_states)
449
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
450
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
451
+
452
+ return hidden_states
453
+
454
+ def build(self, input_shape=None):
455
+ if self.built:
456
+ return
457
+ self.built = True
458
+ if getattr(self, "dense", None) is not None:
459
+ with tf.name_scope(self.dense.name):
460
+ self.dense.build([None, None, self.config.hidden_size])
461
+ if getattr(self, "LayerNorm", None) is not None:
462
+ with tf.name_scope(self.LayerNorm.name):
463
+ self.LayerNorm.build([None, None, self.config.hidden_size])
464
+
465
+
466
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Camembert
467
+ class TFCamembertAttention(keras.layers.Layer):
468
+ def __init__(self, config: CamembertConfig, **kwargs):
469
+ super().__init__(**kwargs)
470
+
471
+ self.self_attention = TFCamembertSelfAttention(config, name="self")
472
+ self.dense_output = TFCamembertSelfOutput(config, name="output")
473
+
474
+ def prune_heads(self, heads):
475
+ raise NotImplementedError
476
+
477
+ def call(
478
+ self,
479
+ input_tensor: tf.Tensor,
480
+ attention_mask: tf.Tensor,
481
+ head_mask: tf.Tensor,
482
+ encoder_hidden_states: tf.Tensor,
483
+ encoder_attention_mask: tf.Tensor,
484
+ past_key_value: Tuple[tf.Tensor],
485
+ output_attentions: bool,
486
+ training: bool = False,
487
+ ) -> Tuple[tf.Tensor]:
488
+ self_outputs = self.self_attention(
489
+ hidden_states=input_tensor,
490
+ attention_mask=attention_mask,
491
+ head_mask=head_mask,
492
+ encoder_hidden_states=encoder_hidden_states,
493
+ encoder_attention_mask=encoder_attention_mask,
494
+ past_key_value=past_key_value,
495
+ output_attentions=output_attentions,
496
+ training=training,
497
+ )
498
+ attention_output = self.dense_output(
499
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
500
+ )
501
+ # add attentions (possibly with past_key_value) if we output them
502
+ outputs = (attention_output,) + self_outputs[1:]
503
+
504
+ return outputs
505
+
506
+ def build(self, input_shape=None):
507
+ if self.built:
508
+ return
509
+ self.built = True
510
+ if getattr(self, "self_attention", None) is not None:
511
+ with tf.name_scope(self.self_attention.name):
512
+ self.self_attention.build(None)
513
+ if getattr(self, "dense_output", None) is not None:
514
+ with tf.name_scope(self.dense_output.name):
515
+ self.dense_output.build(None)
516
+
517
+
518
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Camembert
519
+ class TFCamembertIntermediate(keras.layers.Layer):
520
+ def __init__(self, config: CamembertConfig, **kwargs):
521
+ super().__init__(**kwargs)
522
+
523
+ self.dense = keras.layers.Dense(
524
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
525
+ )
526
+
527
+ if isinstance(config.hidden_act, str):
528
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
529
+ else:
530
+ self.intermediate_act_fn = config.hidden_act
531
+ self.config = config
532
+
533
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
534
+ hidden_states = self.dense(inputs=hidden_states)
535
+ hidden_states = self.intermediate_act_fn(hidden_states)
536
+
537
+ return hidden_states
538
+
539
+ def build(self, input_shape=None):
540
+ if self.built:
541
+ return
542
+ self.built = True
543
+ if getattr(self, "dense", None) is not None:
544
+ with tf.name_scope(self.dense.name):
545
+ self.dense.build([None, None, self.config.hidden_size])
546
+
547
+
548
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Camembert
549
+ class TFCamembertOutput(keras.layers.Layer):
550
+ def __init__(self, config: CamembertConfig, **kwargs):
551
+ super().__init__(**kwargs)
552
+
553
+ self.dense = keras.layers.Dense(
554
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
555
+ )
556
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
557
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
558
+ self.config = config
559
+
560
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
561
+ hidden_states = self.dense(inputs=hidden_states)
562
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
563
+ hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
564
+
565
+ return hidden_states
566
+
567
+ def build(self, input_shape=None):
568
+ if self.built:
569
+ return
570
+ self.built = True
571
+ if getattr(self, "dense", None) is not None:
572
+ with tf.name_scope(self.dense.name):
573
+ self.dense.build([None, None, self.config.intermediate_size])
574
+ if getattr(self, "LayerNorm", None) is not None:
575
+ with tf.name_scope(self.LayerNorm.name):
576
+ self.LayerNorm.build([None, None, self.config.hidden_size])
577
+
578
+
579
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Camembert
580
+ class TFCamembertLayer(keras.layers.Layer):
581
+ def __init__(self, config: CamembertConfig, **kwargs):
582
+ super().__init__(**kwargs)
583
+
584
+ self.attention = TFCamembertAttention(config, name="attention")
585
+ self.is_decoder = config.is_decoder
586
+ self.add_cross_attention = config.add_cross_attention
587
+ if self.add_cross_attention:
588
+ if not self.is_decoder:
589
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
590
+ self.crossattention = TFCamembertAttention(config, name="crossattention")
591
+ self.intermediate = TFCamembertIntermediate(config, name="intermediate")
592
+ self.bert_output = TFCamembertOutput(config, name="output")
593
+
594
+ def call(
595
+ self,
596
+ hidden_states: tf.Tensor,
597
+ attention_mask: tf.Tensor,
598
+ head_mask: tf.Tensor,
599
+ encoder_hidden_states: tf.Tensor | None,
600
+ encoder_attention_mask: tf.Tensor | None,
601
+ past_key_value: Tuple[tf.Tensor] | None,
602
+ output_attentions: bool,
603
+ training: bool = False,
604
+ ) -> Tuple[tf.Tensor]:
605
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
606
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
607
+ self_attention_outputs = self.attention(
608
+ input_tensor=hidden_states,
609
+ attention_mask=attention_mask,
610
+ head_mask=head_mask,
611
+ encoder_hidden_states=None,
612
+ encoder_attention_mask=None,
613
+ past_key_value=self_attn_past_key_value,
614
+ output_attentions=output_attentions,
615
+ training=training,
616
+ )
617
+ attention_output = self_attention_outputs[0]
618
+
619
+ # if decoder, the last output is tuple of self-attn cache
620
+ if self.is_decoder:
621
+ outputs = self_attention_outputs[1:-1]
622
+ present_key_value = self_attention_outputs[-1]
623
+ else:
624
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
625
+
626
+ cross_attn_present_key_value = None
627
+ if self.is_decoder and encoder_hidden_states is not None:
628
+ if not hasattr(self, "crossattention"):
629
+ raise ValueError(
630
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
631
+ " by setting `config.add_cross_attention=True`"
632
+ )
633
+
634
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
635
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
636
+ cross_attention_outputs = self.crossattention(
637
+ input_tensor=attention_output,
638
+ attention_mask=attention_mask,
639
+ head_mask=head_mask,
640
+ encoder_hidden_states=encoder_hidden_states,
641
+ encoder_attention_mask=encoder_attention_mask,
642
+ past_key_value=cross_attn_past_key_value,
643
+ output_attentions=output_attentions,
644
+ training=training,
645
+ )
646
+ attention_output = cross_attention_outputs[0]
647
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
648
+
649
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
650
+ cross_attn_present_key_value = cross_attention_outputs[-1]
651
+ present_key_value = present_key_value + cross_attn_present_key_value
652
+
653
+ intermediate_output = self.intermediate(hidden_states=attention_output)
654
+ layer_output = self.bert_output(
655
+ hidden_states=intermediate_output, input_tensor=attention_output, training=training
656
+ )
657
+ outputs = (layer_output,) + outputs # add attentions if we output them
658
+
659
+ # if decoder, return the attn key/values as the last output
660
+ if self.is_decoder:
661
+ outputs = outputs + (present_key_value,)
662
+
663
+ return outputs
664
+
665
+ def build(self, input_shape=None):
666
+ if self.built:
667
+ return
668
+ self.built = True
669
+ if getattr(self, "attention", None) is not None:
670
+ with tf.name_scope(self.attention.name):
671
+ self.attention.build(None)
672
+ if getattr(self, "intermediate", None) is not None:
673
+ with tf.name_scope(self.intermediate.name):
674
+ self.intermediate.build(None)
675
+ if getattr(self, "bert_output", None) is not None:
676
+ with tf.name_scope(self.bert_output.name):
677
+ self.bert_output.build(None)
678
+ if getattr(self, "crossattention", None) is not None:
679
+ with tf.name_scope(self.crossattention.name):
680
+ self.crossattention.build(None)
681
+
682
+
683
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Camembert
684
+ class TFCamembertEncoder(keras.layers.Layer):
685
+ def __init__(self, config: CamembertConfig, **kwargs):
686
+ super().__init__(**kwargs)
687
+ self.config = config
688
+ self.layer = [TFCamembertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
689
+
690
+ def call(
691
+ self,
692
+ hidden_states: tf.Tensor,
693
+ attention_mask: tf.Tensor,
694
+ head_mask: tf.Tensor,
695
+ encoder_hidden_states: tf.Tensor | None,
696
+ encoder_attention_mask: tf.Tensor | None,
697
+ past_key_values: Tuple[Tuple[tf.Tensor]] | None,
698
+ use_cache: Optional[bool],
699
+ output_attentions: bool,
700
+ output_hidden_states: bool,
701
+ return_dict: bool,
702
+ training: bool = False,
703
+ ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
704
+ all_hidden_states = () if output_hidden_states else None
705
+ all_attentions = () if output_attentions else None
706
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
707
+
708
+ next_decoder_cache = () if use_cache else None
709
+ for i, layer_module in enumerate(self.layer):
710
+ if output_hidden_states:
711
+ all_hidden_states = all_hidden_states + (hidden_states,)
712
+
713
+ past_key_value = past_key_values[i] if past_key_values is not None else None
714
+
715
+ layer_outputs = layer_module(
716
+ hidden_states=hidden_states,
717
+ attention_mask=attention_mask,
718
+ head_mask=head_mask[i],
719
+ encoder_hidden_states=encoder_hidden_states,
720
+ encoder_attention_mask=encoder_attention_mask,
721
+ past_key_value=past_key_value,
722
+ output_attentions=output_attentions,
723
+ training=training,
724
+ )
725
+ hidden_states = layer_outputs[0]
726
+
727
+ if use_cache:
728
+ next_decoder_cache += (layer_outputs[-1],)
729
+
730
+ if output_attentions:
731
+ all_attentions = all_attentions + (layer_outputs[1],)
732
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
733
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
734
+
735
+ # Add last layer
736
+ if output_hidden_states:
737
+ all_hidden_states = all_hidden_states + (hidden_states,)
738
+
739
+ if not return_dict:
740
+ return tuple(
741
+ v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
742
+ )
743
+
744
+ return TFBaseModelOutputWithPastAndCrossAttentions(
745
+ last_hidden_state=hidden_states,
746
+ past_key_values=next_decoder_cache,
747
+ hidden_states=all_hidden_states,
748
+ attentions=all_attentions,
749
+ cross_attentions=all_cross_attentions,
750
+ )
751
+
752
+ def build(self, input_shape=None):
753
+ if self.built:
754
+ return
755
+ self.built = True
756
+ if getattr(self, "layer", None) is not None:
757
+ for layer in self.layer:
758
+ with tf.name_scope(layer.name):
759
+ layer.build(None)
760
+
761
+
762
+ @keras_serializable
763
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->Camembert
764
+ class TFCamembertMainLayer(keras.layers.Layer):
765
+ config_class = CamembertConfig
766
+
767
+ def __init__(self, config, add_pooling_layer=True, **kwargs):
768
+ super().__init__(**kwargs)
769
+
770
+ self.config = config
771
+ self.is_decoder = config.is_decoder
772
+
773
+ self.num_hidden_layers = config.num_hidden_layers
774
+ self.initializer_range = config.initializer_range
775
+ self.output_attentions = config.output_attentions
776
+ self.output_hidden_states = config.output_hidden_states
777
+ self.return_dict = config.use_return_dict
778
+ self.encoder = TFCamembertEncoder(config, name="encoder")
779
+ self.pooler = TFCamembertPooler(config, name="pooler") if add_pooling_layer else None
780
+ # The embeddings must be the last declaration in order to follow the weights order
781
+ self.embeddings = TFCamembertEmbeddings(config, name="embeddings")
782
+
783
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
784
+ def get_input_embeddings(self) -> keras.layers.Layer:
785
+ return self.embeddings
786
+
787
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
788
+ def set_input_embeddings(self, value: tf.Variable):
789
+ self.embeddings.weight = value
790
+ self.embeddings.vocab_size = shape_list(value)[0]
791
+
792
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
793
+ def _prune_heads(self, heads_to_prune):
794
+ """
795
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
796
+ class PreTrainedModel
797
+ """
798
+ raise NotImplementedError
799
+
800
+ @unpack_inputs
801
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
802
+ def call(
803
+ self,
804
+ input_ids: TFModelInputType | None = None,
805
+ attention_mask: np.ndarray | tf.Tensor | None = None,
806
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
807
+ position_ids: np.ndarray | tf.Tensor | None = None,
808
+ head_mask: np.ndarray | tf.Tensor | None = None,
809
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
810
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
811
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
812
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
813
+ use_cache: Optional[bool] = None,
814
+ output_attentions: Optional[bool] = None,
815
+ output_hidden_states: Optional[bool] = None,
816
+ return_dict: Optional[bool] = None,
817
+ training: bool = False,
818
+ ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
819
+ if not self.config.is_decoder:
820
+ use_cache = False
821
+
822
+ if input_ids is not None and inputs_embeds is not None:
823
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
824
+ elif input_ids is not None:
825
+ input_shape = shape_list(input_ids)
826
+ elif inputs_embeds is not None:
827
+ input_shape = shape_list(inputs_embeds)[:-1]
828
+ else:
829
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
830
+
831
+ batch_size, seq_length = input_shape
832
+
833
+ if past_key_values is None:
834
+ past_key_values_length = 0
835
+ past_key_values = [None] * len(self.encoder.layer)
836
+ else:
837
+ past_key_values_length = shape_list(past_key_values[0][0])[-2]
838
+
839
+ if attention_mask is None:
840
+ attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
841
+
842
+ if token_type_ids is None:
843
+ token_type_ids = tf.fill(dims=input_shape, value=0)
844
+
845
+ embedding_output = self.embeddings(
846
+ input_ids=input_ids,
847
+ position_ids=position_ids,
848
+ token_type_ids=token_type_ids,
849
+ inputs_embeds=inputs_embeds,
850
+ past_key_values_length=past_key_values_length,
851
+ training=training,
852
+ )
853
+
854
+ # We create a 3D attention mask from a 2D tensor mask.
855
+ # Sizes are [batch_size, 1, 1, to_seq_length]
856
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
857
+ # this attention mask is simpler than the triangular masking of causal attention
858
+ # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
859
+ attention_mask_shape = shape_list(attention_mask)
860
+
861
+ mask_seq_length = seq_length + past_key_values_length
862
+ # Copied from `modeling_tf_t5.py`
863
+ # Provided a padding mask of dimensions [batch_size, mask_seq_length]
864
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
865
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
866
+ if self.is_decoder:
867
+ seq_ids = tf.range(mask_seq_length)
868
+ causal_mask = tf.less_equal(
869
+ tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
870
+ seq_ids[None, :, None],
871
+ )
872
+ causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
873
+ extended_attention_mask = causal_mask * attention_mask[:, None, :]
874
+ attention_mask_shape = shape_list(extended_attention_mask)
875
+ extended_attention_mask = tf.reshape(
876
+ extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
877
+ )
878
+ if past_key_values[0] is not None:
879
+ # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
880
+ extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
881
+ else:
882
+ extended_attention_mask = tf.reshape(
883
+ attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
884
+ )
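+ # A worked illustration of the decoder causal mask built above (comment only,
+ # assuming mask_seq_length == 3 and batch_size == 1): seq_ids = [0, 1, 2], so
+ # tf.less_equal(tf.tile(seq_ids[None, None, :], (1, 3, 1)), seq_ids[None, :, None])
+ # yields
+ # [[1, 0, 0],
+ #  [1, 1, 0],
+ #  [1, 1, 1]]
+ # i.e. query position i may attend only to key positions j <= i.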
885
+
886
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
887
+ # masked positions, this operation will create a tensor which is 0.0 for
888
+ # positions we want to attend and -10000.0 for masked positions.
889
+ # Since we are adding it to the raw scores before the softmax, this is
890
+ # effectively the same as removing these entirely.
891
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
892
+ one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
893
+ ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
894
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
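+ # A minimal sketch of the conversion above (comment only): a padding mask of
+ # [1., 1., 0.] becomes (1 - mask) * -10000 = [0., 0., -10000.], so masked
+ # positions receive a large negative score and contribute ~0 after the softmax.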
895
+
896
+ # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
897
+ if self.is_decoder and encoder_attention_mask is not None:
898
+ # If a 2D or 3D attention mask is provided for the cross-attention
899
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
902
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
903
+ if num_dims_encoder_attention_mask == 3:
904
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
905
+ if num_dims_encoder_attention_mask == 2:
906
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
907
+
908
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
909
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
910
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
911
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
912
+
913
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
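+ # Shape sketch (comment only): a 2D encoder mask of shape [batch_size, src_len]
+ # becomes [batch_size, 1, 1, src_len] via `[:, None, None, :]`, which then
+ # broadcasts against cross-attention scores of shape
+ # [batch_size, num_heads, tgt_len, src_len].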
914
+ else:
915
+ encoder_extended_attention_mask = None
916
+
917
+ # Prepare head mask if needed
918
+ # 1.0 in head_mask indicate we keep the head
919
+ # attention_probs has shape bsz x n_heads x N x N
920
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
921
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
922
+ if head_mask is not None:
923
+ raise NotImplementedError
924
+ else:
925
+ head_mask = [None] * self.config.num_hidden_layers
926
+
927
+ encoder_outputs = self.encoder(
928
+ hidden_states=embedding_output,
929
+ attention_mask=extended_attention_mask,
930
+ head_mask=head_mask,
931
+ encoder_hidden_states=encoder_hidden_states,
932
+ encoder_attention_mask=encoder_extended_attention_mask,
933
+ past_key_values=past_key_values,
934
+ use_cache=use_cache,
935
+ output_attentions=output_attentions,
936
+ output_hidden_states=output_hidden_states,
937
+ return_dict=return_dict,
938
+ training=training,
939
+ )
940
+
941
+ sequence_output = encoder_outputs[0]
942
+ pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
943
+
944
+ if not return_dict:
945
+ return (
946
+ sequence_output,
947
+ pooled_output,
948
+ ) + encoder_outputs[1:]
949
+
950
+ return TFBaseModelOutputWithPoolingAndCrossAttentions(
951
+ last_hidden_state=sequence_output,
952
+ pooler_output=pooled_output,
953
+ past_key_values=encoder_outputs.past_key_values,
954
+ hidden_states=encoder_outputs.hidden_states,
955
+ attentions=encoder_outputs.attentions,
956
+ cross_attentions=encoder_outputs.cross_attentions,
957
+ )
958
+
959
+ def build(self, input_shape=None):
960
+ if self.built:
961
+ return
962
+ self.built = True
963
+ if getattr(self, "encoder", None) is not None:
964
+ with tf.name_scope(self.encoder.name):
965
+ self.encoder.build(None)
966
+ if getattr(self, "pooler", None) is not None:
967
+ with tf.name_scope(self.pooler.name):
968
+ self.pooler.build(None)
969
+ if getattr(self, "embeddings", None) is not None:
970
+ with tf.name_scope(self.embeddings.name):
971
+ self.embeddings.build(None)
972
+
973
+
974
+ class TFCamembertPreTrainedModel(TFPreTrainedModel):
975
+ """
976
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
977
+ models.
978
+ """
979
+
980
+ config_class = CamembertConfig
981
+ base_model_prefix = "roberta"
982
+
983
+
984
+ @add_start_docstrings(
985
+ "The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
986
+ CAMEMBERT_START_DOCSTRING,
987
+ )
988
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with Roberta->Camembert, ROBERTA->CAMEMBERT
989
+ class TFCamembertModel(TFCamembertPreTrainedModel):
990
+ def __init__(self, config, *inputs, **kwargs):
991
+ super().__init__(config, *inputs, **kwargs)
992
+ self.roberta = TFCamembertMainLayer(config, name="roberta")
993
+
994
+ @unpack_inputs
995
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
996
+ @add_code_sample_docstrings(
997
+ checkpoint=_CHECKPOINT_FOR_DOC,
998
+ output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
999
+ config_class=_CONFIG_FOR_DOC,
1000
+ )
1001
+ def call(
1002
+ self,
1003
+ input_ids: TFModelInputType | None = None,
1004
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1005
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1006
+ position_ids: np.ndarray | tf.Tensor | None = None,
1007
+ head_mask: np.ndarray | tf.Tensor | None = None,
1008
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1009
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
1010
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
1011
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1012
+ use_cache: Optional[bool] = None,
1013
+ output_attentions: Optional[bool] = None,
1014
+ output_hidden_states: Optional[bool] = None,
1015
+ return_dict: Optional[bool] = None,
1016
+ training: Optional[bool] = False,
1017
+ ) -> Union[Tuple, TFBaseModelOutputWithPoolingAndCrossAttentions]:
1018
+ r"""
1019
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1020
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1021
+ the model is configured as a decoder.
1022
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1023
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1024
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1025
+
1026
+ - 1 for tokens that are **not masked**,
1027
+ - 0 for tokens that are **masked**.
1028
+
1029
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
1030
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1031
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1032
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1033
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1034
+ use_cache (`bool`, *optional*, defaults to `True`):
1035
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1036
+ `past_key_values`). Set to `False` during training, `True` during generation.
1037
+ """
1038
+ outputs = self.roberta(
1039
+ input_ids=input_ids,
1040
+ attention_mask=attention_mask,
1041
+ token_type_ids=token_type_ids,
1042
+ position_ids=position_ids,
1043
+ head_mask=head_mask,
1044
+ inputs_embeds=inputs_embeds,
1045
+ encoder_hidden_states=encoder_hidden_states,
1046
+ encoder_attention_mask=encoder_attention_mask,
1047
+ past_key_values=past_key_values,
1048
+ use_cache=use_cache,
1049
+ output_attentions=output_attentions,
1050
+ output_hidden_states=output_hidden_states,
1051
+ return_dict=return_dict,
1052
+ training=training,
1053
+ )
1054
+
1055
+ return outputs
1056
+
1057
+ def build(self, input_shape=None):
1058
+ if self.built:
1059
+ return
1060
+ self.built = True
1061
+ if getattr(self, "roberta", None) is not None:
1062
+ with tf.name_scope(self.roberta.name):
1063
+ self.roberta.build(None)
1064
+
1065
+
1066
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Camembert
1067
+ class TFCamembertLMHead(keras.layers.Layer):
1068
+ """Camembert Head for masked language modeling."""
1069
+
1070
+ def __init__(self, config, input_embeddings, **kwargs):
1071
+ super().__init__(**kwargs)
1072
+
1073
+ self.config = config
1074
+ self.hidden_size = config.hidden_size
1075
+ self.dense = keras.layers.Dense(
1076
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
1077
+ )
1078
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
1079
+ self.act = get_tf_activation("gelu")
1080
+
1081
+ # The output weights are the same as the input embeddings, but there is
1082
+ # an output-only bias for each token.
1083
+ self.decoder = input_embeddings
1084
+
1085
+ def build(self, input_shape=None):
1086
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
1087
+
1088
+ if self.built:
1089
+ return
1090
+ self.built = True
1091
+ if getattr(self, "dense", None) is not None:
1092
+ with tf.name_scope(self.dense.name):
1093
+ self.dense.build([None, None, self.config.hidden_size])
1094
+ if getattr(self, "layer_norm", None) is not None:
1095
+ with tf.name_scope(self.layer_norm.name):
1096
+ self.layer_norm.build([None, None, self.config.hidden_size])
1097
+
1098
+ def get_output_embeddings(self):
1099
+ return self.decoder
1100
+
1101
+ def set_output_embeddings(self, value):
1102
+ self.decoder.weight = value
1103
+ self.decoder.vocab_size = shape_list(value)[0]
1104
+
1105
+ def get_bias(self):
1106
+ return {"bias": self.bias}
1107
+
1108
+ def set_bias(self, value):
1109
+ self.bias = value["bias"]
1110
+ self.config.vocab_size = shape_list(value["bias"])[0]
1111
+
1112
+ def call(self, hidden_states):
1113
+ hidden_states = self.dense(hidden_states)
1114
+ hidden_states = self.act(hidden_states)
1115
+ hidden_states = self.layer_norm(hidden_states)
1116
+
1117
+ # project back to size of vocabulary with bias
1118
+ seq_length = shape_list(tensor=hidden_states)[1]
1119
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
1120
+ hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
1121
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
1122
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
1123
+
1124
+ return hidden_states
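+ # A minimal sketch of the tied projection above (comment only; shapes are
+ # illustrative): with hidden_states of shape [batch, seq, hidden] and the shared
+ # embedding matrix E = self.decoder.weight of shape [vocab, hidden], the head
+ # computes logits = hidden_states @ E^T + bias, i.e. shape [batch, seq, vocab].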
1125
+
1126
+
1127
+ @add_start_docstrings(
1128
+ """CamemBERT Model with a `language modeling` head on top.""",
1129
+ CAMEMBERT_START_DOCSTRING,
1130
+ )
1131
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT
1132
+ class TFCamembertForMaskedLM(TFCamembertPreTrainedModel, TFMaskedLanguageModelingLoss):
1133
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1134
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
1135
+
1136
+ def __init__(self, config, *inputs, **kwargs):
1137
+ super().__init__(config, *inputs, **kwargs)
1138
+
1139
+ self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
1140
+ self.lm_head = TFCamembertLMHead(config, self.roberta.embeddings, name="lm_head")
1141
+
1142
+ def get_lm_head(self):
1143
+ return self.lm_head
1144
+
1145
+ def get_prefix_bias_name(self):
1146
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
1147
+ return self.name + "/" + self.lm_head.name
1148
+
1149
+ @unpack_inputs
1150
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1151
+ @add_code_sample_docstrings(
1152
+ checkpoint=_CHECKPOINT_FOR_DOC,
1153
+ output_type=TFMaskedLMOutput,
1154
+ config_class=_CONFIG_FOR_DOC,
1155
+ mask="<mask>",
1156
+ expected_output="' Paris'",
1157
+ expected_loss=0.1,
1158
+ )
1159
+ def call(
1160
+ self,
1161
+ input_ids: TFModelInputType | None = None,
1162
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1163
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1164
+ position_ids: np.ndarray | tf.Tensor | None = None,
1165
+ head_mask: np.ndarray | tf.Tensor | None = None,
1166
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1167
+ output_attentions: Optional[bool] = None,
1168
+ output_hidden_states: Optional[bool] = None,
1169
+ return_dict: Optional[bool] = None,
1170
+ labels: np.ndarray | tf.Tensor | None = None,
1171
+ training: Optional[bool] = False,
1172
+ ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
1173
+ r"""
1174
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1175
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1176
+ config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
1177
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1178
+ """
1179
+ outputs = self.roberta(
1180
+ input_ids,
1181
+ attention_mask=attention_mask,
1182
+ token_type_ids=token_type_ids,
1183
+ position_ids=position_ids,
1184
+ head_mask=head_mask,
1185
+ inputs_embeds=inputs_embeds,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ return_dict=return_dict,
1189
+ training=training,
1190
+ )
1191
+
1192
+ sequence_output = outputs[0]
1193
+ prediction_scores = self.lm_head(sequence_output)
1194
+
1195
+ loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
1196
+
1197
+ if not return_dict:
1198
+ output = (prediction_scores,) + outputs[2:]
1199
+ return ((loss,) + output) if loss is not None else output
1200
+
1201
+ return TFMaskedLMOutput(
1202
+ loss=loss,
1203
+ logits=prediction_scores,
1204
+ hidden_states=outputs.hidden_states,
1205
+ attentions=outputs.attentions,
1206
+ )
1207
+
1208
+ def build(self, input_shape=None):
1209
+ if self.built:
1210
+ return
1211
+ self.built = True
1212
+ if getattr(self, "roberta", None) is not None:
1213
+ with tf.name_scope(self.roberta.name):
1214
+ self.roberta.build(None)
1215
+ if getattr(self, "lm_head", None) is not None:
1216
+ with tf.name_scope(self.lm_head.name):
1217
+ self.lm_head.build(None)
1218
+
1219
+
1220
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead
1221
+ class TFCamembertClassificationHead(keras.layers.Layer):
1222
+ """Head for sentence-level classification tasks."""
1223
+
1224
+ def __init__(self, config, **kwargs):
1225
+ super().__init__(**kwargs)
1226
+ self.dense = keras.layers.Dense(
1227
+ config.hidden_size,
1228
+ kernel_initializer=get_initializer(config.initializer_range),
1229
+ activation="tanh",
1230
+ name="dense",
1231
+ )
1232
+ classifier_dropout = (
1233
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1234
+ )
1235
+ self.dropout = keras.layers.Dropout(classifier_dropout)
1236
+ self.out_proj = keras.layers.Dense(
1237
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
1238
+ )
1239
+ self.config = config
1240
+
1241
+ def call(self, features, training=False):
1242
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1243
+ x = self.dropout(x, training=training)
1244
+ x = self.dense(x)
1245
+ x = self.dropout(x, training=training)
1246
+ x = self.out_proj(x)
1247
+ return x
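+ # Shape walk-through (comment only): features [batch, seq, hidden] -> take the
+ # first (<s>) token [batch, hidden] -> dropout -> dense + tanh [batch, hidden]
+ # -> dropout -> out_proj [batch, num_labels].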
1248
+
1249
+ def build(self, input_shape=None):
1250
+ if self.built:
1251
+ return
1252
+ self.built = True
1253
+ if getattr(self, "dense", None) is not None:
1254
+ with tf.name_scope(self.dense.name):
1255
+ self.dense.build([None, None, self.config.hidden_size])
1256
+ if getattr(self, "out_proj", None) is not None:
1257
+ with tf.name_scope(self.out_proj.name):
1258
+ self.out_proj.build([None, None, self.config.hidden_size])
1259
+
1260
+
1261
+ @add_start_docstrings(
1262
+ """
1263
+ CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1264
+ pooled output) e.g. for GLUE tasks.
1265
+ """,
1266
+ CAMEMBERT_START_DOCSTRING,
1267
+ )
1268
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
1269
+ class TFCamembertForSequenceClassification(TFCamembertPreTrainedModel, TFSequenceClassificationLoss):
1270
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1271
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
1272
+
1273
+ def __init__(self, config, *inputs, **kwargs):
1274
+ super().__init__(config, *inputs, **kwargs)
1275
+ self.num_labels = config.num_labels
1276
+
1277
+ self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
1278
+ self.classifier = TFCamembertClassificationHead(config, name="classifier")
1279
+
1280
+ @unpack_inputs
1281
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1282
+ @add_code_sample_docstrings(
1283
+ checkpoint="cardiffnlp/twitter-roberta-base-emotion",
1284
+ output_type=TFSequenceClassifierOutput,
1285
+ config_class=_CONFIG_FOR_DOC,
1286
+ expected_output="'optimism'",
1287
+ expected_loss=0.08,
1288
+ )
1289
+ def call(
1290
+ self,
1291
+ input_ids: TFModelInputType | None = None,
1292
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1293
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1294
+ position_ids: np.ndarray | tf.Tensor | None = None,
1295
+ head_mask: np.ndarray | tf.Tensor | None = None,
1296
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1297
+ output_attentions: Optional[bool] = None,
1298
+ output_hidden_states: Optional[bool] = None,
1299
+ return_dict: Optional[bool] = None,
1300
+ labels: np.ndarray | tf.Tensor | None = None,
1301
+ training: Optional[bool] = False,
1302
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
1303
+ r"""
1304
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1305
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1306
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1307
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1308
+ """
1309
+ outputs = self.roberta(
1310
+ input_ids,
1311
+ attention_mask=attention_mask,
1312
+ token_type_ids=token_type_ids,
1313
+ position_ids=position_ids,
1314
+ head_mask=head_mask,
1315
+ inputs_embeds=inputs_embeds,
1316
+ output_attentions=output_attentions,
1317
+ output_hidden_states=output_hidden_states,
1318
+ return_dict=return_dict,
1319
+ training=training,
1320
+ )
1321
+ sequence_output = outputs[0]
1322
+ logits = self.classifier(sequence_output, training=training)
1323
+
1324
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
1325
+
1326
+ if not return_dict:
1327
+ output = (logits,) + outputs[2:]
1328
+ return ((loss,) + output) if loss is not None else output
1329
+
1330
+ return TFSequenceClassifierOutput(
1331
+ loss=loss,
1332
+ logits=logits,
1333
+ hidden_states=outputs.hidden_states,
1334
+ attentions=outputs.attentions,
1335
+ )
1336
+
1337
+ def build(self, input_shape=None):
1338
+ if self.built:
1339
+ return
1340
+ self.built = True
1341
+ if getattr(self, "roberta", None) is not None:
1342
+ with tf.name_scope(self.roberta.name):
1343
+ self.roberta.build(None)
1344
+ if getattr(self, "classifier", None) is not None:
1345
+ with tf.name_scope(self.classifier.name):
1346
+ self.classifier.build(None)
1347
+
1348
+
1349
+ @add_start_docstrings(
1350
+ """
1351
+ CamemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
1352
+ for Named-Entity-Recognition (NER) tasks.
1353
+ """,
1354
+ CAMEMBERT_START_DOCSTRING,
1355
+ )
1356
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
1357
+ class TFCamembertForTokenClassification(TFCamembertPreTrainedModel, TFTokenClassificationLoss):
1358
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1359
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
1360
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1361
+
1362
+ def __init__(self, config, *inputs, **kwargs):
1363
+ super().__init__(config, *inputs, **kwargs)
1364
+ self.num_labels = config.num_labels
1365
+
1366
+ self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
1367
+ classifier_dropout = (
1368
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1369
+ )
1370
+ self.dropout = keras.layers.Dropout(classifier_dropout)
1371
+ self.classifier = keras.layers.Dense(
1372
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1373
+ )
1374
+ self.config = config
1375
+
1376
+ @unpack_inputs
1377
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1378
+ @add_code_sample_docstrings(
1379
+ checkpoint="ydshieh/roberta-large-ner-english",
1380
+ output_type=TFTokenClassifierOutput,
1381
+ config_class=_CONFIG_FOR_DOC,
1382
+ expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']",
1383
+ expected_loss=0.01,
1384
+ )
1385
+ def call(
1386
+ self,
1387
+ input_ids: TFModelInputType | None = None,
1388
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1389
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1390
+ position_ids: np.ndarray | tf.Tensor | None = None,
1391
+ head_mask: np.ndarray | tf.Tensor | None = None,
1392
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1393
+ output_attentions: Optional[bool] = None,
1394
+ output_hidden_states: Optional[bool] = None,
1395
+ return_dict: Optional[bool] = None,
1396
+ labels: np.ndarray | tf.Tensor | None = None,
1397
+ training: Optional[bool] = False,
1398
+ ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
1399
+ r"""
1400
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1401
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1402
+ """
1403
+ outputs = self.roberta(
1404
+ input_ids,
1405
+ attention_mask=attention_mask,
1406
+ token_type_ids=token_type_ids,
1407
+ position_ids=position_ids,
1408
+ head_mask=head_mask,
1409
+ inputs_embeds=inputs_embeds,
1410
+ output_attentions=output_attentions,
1411
+ output_hidden_states=output_hidden_states,
1412
+ return_dict=return_dict,
1413
+ training=training,
1414
+ )
1415
+ sequence_output = outputs[0]
1416
+
1417
+ sequence_output = self.dropout(sequence_output, training=training)
1418
+ logits = self.classifier(sequence_output)
1419
+
1420
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
1421
+
1422
+ if not return_dict:
1423
+ output = (logits,) + outputs[2:]
1424
+ return ((loss,) + output) if loss is not None else output
1425
+
1426
+ return TFTokenClassifierOutput(
1427
+ loss=loss,
1428
+ logits=logits,
1429
+ hidden_states=outputs.hidden_states,
1430
+ attentions=outputs.attentions,
1431
+ )
1432
+
1433
+ def build(self, input_shape=None):
1434
+ if self.built:
1435
+ return
1436
+ self.built = True
1437
+ if getattr(self, "roberta", None) is not None:
1438
+ with tf.name_scope(self.roberta.name):
1439
+ self.roberta.build(None)
1440
+ if getattr(self, "classifier", None) is not None:
1441
+ with tf.name_scope(self.classifier.name):
1442
+ self.classifier.build([None, None, self.config.hidden_size])
1443
+
1444
+
1445
+ @add_start_docstrings(
1446
+ """
1447
+ CamemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1448
+ softmax) e.g. for RocStories/SWAG tasks.
1449
+ """,
1450
+ CAMEMBERT_START_DOCSTRING,
1451
+ )
1452
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT
1453
+ class TFCamembertForMultipleChoice(TFCamembertPreTrainedModel, TFMultipleChoiceLoss):
1454
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1455
+ _keys_to_ignore_on_load_unexpected = [r"lm_head"]
1456
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1457
+
1458
+ def __init__(self, config, *inputs, **kwargs):
1459
+ super().__init__(config, *inputs, **kwargs)
1460
+
1461
+ self.roberta = TFCamembertMainLayer(config, name="roberta")
1462
+ self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
1463
+ self.classifier = keras.layers.Dense(
1464
+ 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1465
+ )
1466
+ self.config = config
1467
+
1468
+ @unpack_inputs
1469
+ @add_start_docstrings_to_model_forward(
1470
+ CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1471
+ )
1472
+ @add_code_sample_docstrings(
1473
+ checkpoint=_CHECKPOINT_FOR_DOC,
1474
+ output_type=TFMultipleChoiceModelOutput,
1475
+ config_class=_CONFIG_FOR_DOC,
1476
+ )
1477
+ def call(
1478
+ self,
1479
+ input_ids: TFModelInputType | None = None,
1480
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1481
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1482
+ position_ids: np.ndarray | tf.Tensor | None = None,
1483
+ head_mask: np.ndarray | tf.Tensor | None = None,
1484
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1485
+ output_attentions: Optional[bool] = None,
1486
+ output_hidden_states: Optional[bool] = None,
1487
+ return_dict: Optional[bool] = None,
1488
+ labels: np.ndarray | tf.Tensor | None = None,
1489
+ training: Optional[bool] = False,
1490
+ ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
1491
+ r"""
1492
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1493
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
1494
+ where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above.)
1495
+ """
1496
+
1497
+ if input_ids is not None:
1498
+ num_choices = shape_list(input_ids)[1]
1499
+ seq_length = shape_list(input_ids)[2]
1500
+ else:
1501
+ num_choices = shape_list(inputs_embeds)[1]
1502
+ seq_length = shape_list(inputs_embeds)[2]
1503
+
1504
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
1505
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
1506
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
1507
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
1508
+ outputs = self.roberta(
1509
+ flat_input_ids,
1510
+ flat_attention_mask,
1511
+ flat_token_type_ids,
1512
+ flat_position_ids,
1513
+ head_mask,
1514
+ inputs_embeds,
1515
+ output_attentions,
1516
+ output_hidden_states,
1517
+ return_dict=return_dict,
1518
+ training=training,
1519
+ )
1520
+ pooled_output = outputs[1]
1521
+ pooled_output = self.dropout(pooled_output, training=training)
1522
+ logits = self.classifier(pooled_output)
1523
+ reshaped_logits = tf.reshape(logits, (-1, num_choices))
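+ # Shape sketch (comment only): inputs arrive as [batch, num_choices, seq_len],
+ # are flattened to [batch * num_choices, seq_len] for the encoder, and the
+ # per-choice scores of shape [batch * num_choices, 1] are reshaped back to
+ # [batch, num_choices] so the loss can compare choices within each example.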
1524
+
1525
+ loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
1526
+
1527
+ if not return_dict:
1528
+ output = (reshaped_logits,) + outputs[2:]
1529
+ return ((loss,) + output) if loss is not None else output
1530
+
1531
+ return TFMultipleChoiceModelOutput(
1532
+ loss=loss,
1533
+ logits=reshaped_logits,
1534
+ hidden_states=outputs.hidden_states,
1535
+ attentions=outputs.attentions,
1536
+ )
1537
+
1538
+ def build(self, input_shape=None):
1539
+ if self.built:
1540
+ return
1541
+ self.built = True
1542
+ if getattr(self, "roberta", None) is not None:
1543
+ with tf.name_scope(self.roberta.name):
1544
+ self.roberta.build(None)
1545
+ if getattr(self, "classifier", None) is not None:
1546
+ with tf.name_scope(self.classifier.name):
1547
+ self.classifier.build([None, None, self.config.hidden_size])
1548
+
1549
+
1550
+ @add_start_docstrings(
1551
+ """
1552
+ CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1553
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1554
+ """,
1555
+ CAMEMBERT_START_DOCSTRING,
1556
+ )
1557
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT
1558
+ class TFCamembertForQuestionAnswering(TFCamembertPreTrainedModel, TFQuestionAnsweringLoss):
1559
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1560
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"]
1561
+
1562
+ def __init__(self, config, *inputs, **kwargs):
1563
+ super().__init__(config, *inputs, **kwargs)
1564
+ self.num_labels = config.num_labels
1565
+
1566
+ self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
1567
+ self.qa_outputs = keras.layers.Dense(
1568
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
1569
+ )
1570
+ self.config = config
1571
+
1572
+ @unpack_inputs
1573
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1574
+ @add_code_sample_docstrings(
1575
+ checkpoint="ydshieh/roberta-base-squad2",
1576
+ output_type=TFQuestionAnsweringModelOutput,
1577
+ config_class=_CONFIG_FOR_DOC,
1578
+ expected_output="' puppet'",
1579
+ expected_loss=0.86,
1580
+ )
1581
+ def call(
1582
+ self,
1583
+ input_ids: TFModelInputType | None = None,
1584
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1585
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1586
+ position_ids: np.ndarray | tf.Tensor | None = None,
1587
+ head_mask: np.ndarray | tf.Tensor | None = None,
1588
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1589
+ output_attentions: Optional[bool] = None,
1590
+ output_hidden_states: Optional[bool] = None,
1591
+ return_dict: Optional[bool] = None,
1592
+ start_positions: np.ndarray | tf.Tensor | None = None,
1593
+ end_positions: np.ndarray | tf.Tensor | None = None,
1594
+ training: Optional[bool] = False,
1595
+ ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
1596
+ r"""
1597
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1598
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1599
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1600
+ are not taken into account for computing the loss.
1601
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1602
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1603
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1604
+ are not taken into account for computing the loss.
1605
+ """
1606
+ outputs = self.roberta(
1607
+ input_ids,
1608
+ attention_mask=attention_mask,
1609
+ token_type_ids=token_type_ids,
1610
+ position_ids=position_ids,
1611
+ head_mask=head_mask,
1612
+ inputs_embeds=inputs_embeds,
1613
+ output_attentions=output_attentions,
1614
+ output_hidden_states=output_hidden_states,
1615
+ return_dict=return_dict,
1616
+ training=training,
1617
+ )
1618
+ sequence_output = outputs[0]
1619
+
1620
+ logits = self.qa_outputs(sequence_output)
1621
+ start_logits, end_logits = tf.split(logits, 2, axis=-1)
1622
+ start_logits = tf.squeeze(start_logits, axis=-1)
1623
+ end_logits = tf.squeeze(end_logits, axis=-1)
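+ # Shape sketch (comment only): qa_outputs maps [batch, seq, hidden] to
+ # [batch, seq, 2]; the split yields two [batch, seq, 1] tensors and the
+ # squeeze leaves [batch, seq] start and end logits.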
1624
+
1625
+ loss = None
1626
+ if start_positions is not None and end_positions is not None:
1627
+ labels = {"start_position": start_positions}
1628
+ labels["end_position"] = end_positions
1629
+ loss = self.hf_compute_loss(labels, (start_logits, end_logits))
1630
+
1631
+ if not return_dict:
1632
+ output = (start_logits, end_logits) + outputs[2:]
1633
+ return ((loss,) + output) if loss is not None else output
1634
+
1635
+ return TFQuestionAnsweringModelOutput(
1636
+ loss=loss,
1637
+ start_logits=start_logits,
1638
+ end_logits=end_logits,
1639
+ hidden_states=outputs.hidden_states,
1640
+ attentions=outputs.attentions,
1641
+ )
1642
+
1643
+ def build(self, input_shape=None):
1644
+ if self.built:
1645
+ return
1646
+ self.built = True
1647
+ if getattr(self, "roberta", None) is not None:
1648
+ with tf.name_scope(self.roberta.name):
1649
+ self.roberta.build(None)
1650
+ if getattr(self, "qa_outputs", None) is not None:
1651
+ with tf.name_scope(self.qa_outputs.name):
1652
+ self.qa_outputs.build([None, None, self.config.hidden_size])
1653
+
1654
+
1655
+ @add_start_docstrings(
1656
+ """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING
1657
+ )
1658
+ # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT
1659
+ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelingLoss):
1660
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1661
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"]
1662
+
1663
+ def __init__(self, config: CamembertConfig, *inputs, **kwargs):
1664
+ super().__init__(config, *inputs, **kwargs)
1665
+
1666
+ if not config.is_decoder:
1667
+ logger.warning("If you want to use `TFCamembertForCausalLM` as a standalone, add `is_decoder=True`.")
1668
+
1669
+ self.roberta = TFCamembertMainLayer(config, add_pooling_layer=False, name="roberta")
1670
+ self.lm_head = TFCamembertLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head")
1671
+
1672
+ def get_lm_head(self):
1673
+ return self.lm_head
1674
+
1675
+ def get_prefix_bias_name(self):
1676
+ warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
1677
+ return self.name + "/" + self.lm_head.name
1678
+
1679
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
1680
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
1681
+ input_shape = input_ids.shape
1682
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1683
+ if attention_mask is None:
1684
+ attention_mask = tf.ones(input_shape)
1685
+
1686
+ # cut decoder_input_ids if past is used
1687
+ if past_key_values is not None:
1688
+ input_ids = input_ids[:, -1:]
1689
+
1690
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
1691
+
1692
+ @unpack_inputs
1693
+ @add_start_docstrings_to_model_forward(CAMEMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1694
+ @add_code_sample_docstrings(
1695
+ checkpoint=_CHECKPOINT_FOR_DOC,
1696
+ output_type=TFCausalLMOutputWithCrossAttentions,
1697
+ config_class=_CONFIG_FOR_DOC,
1698
+ )
1699
+ def call(
1700
+ self,
1701
+ input_ids: TFModelInputType | None = None,
1702
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1703
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1704
+ position_ids: np.ndarray | tf.Tensor | None = None,
1705
+ head_mask: np.ndarray | tf.Tensor | None = None,
1706
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1707
+ encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
1708
+ encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
1709
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
1710
+ use_cache: Optional[bool] = None,
1711
+ output_attentions: Optional[bool] = None,
1712
+ output_hidden_states: Optional[bool] = None,
1713
+ return_dict: Optional[bool] = None,
1714
+ labels: np.ndarray | tf.Tensor | None = None,
1715
+ training: Optional[bool] = False,
1716
+ ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
1717
+ r"""
1718
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1719
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1720
+ the model is configured as a decoder.
1721
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1722
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1723
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1724
+
1725
+ - 1 for tokens that are **not masked**,
1726
+ - 0 for tokens that are **masked**.
1727
+
1728
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
1729
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1730
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1731
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1732
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1733
+ use_cache (`bool`, *optional*, defaults to `True`):
1734
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1735
+ `past_key_values`). Set to `False` during training, `True` during generation.
1736
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
1737
+ Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
1738
+ config.vocab_size - 1]`.
1739
+ """
1740
+ outputs = self.roberta(
1741
+ input_ids=input_ids,
1742
+ attention_mask=attention_mask,
1743
+ token_type_ids=token_type_ids,
1744
+ position_ids=position_ids,
1745
+ head_mask=head_mask,
1746
+ inputs_embeds=inputs_embeds,
1747
+ encoder_hidden_states=encoder_hidden_states,
1748
+ encoder_attention_mask=encoder_attention_mask,
1749
+ past_key_values=past_key_values,
1750
+ use_cache=use_cache,
1751
+ output_attentions=output_attentions,
1752
+ output_hidden_states=output_hidden_states,
1753
+ return_dict=return_dict,
1754
+ training=training,
1755
+ )
1756
+
1757
+ sequence_output = outputs[0]
1758
+ logits = self.lm_head(hidden_states=sequence_output, training=training)
1759
+ loss = None
1760
+
1761
+ if labels is not None:
1762
+ # shift labels to the left and cut last logit token
1763
+ shifted_logits = logits[:, :-1]
1764
+ labels = labels[:, 1:]
1765
+ loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
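+ # A worked example of the shift above (comment only): for labels [A, B, C, D],
+ # shifted_logits keeps the logits at positions 0..2 and scores them against
+ # labels[1:] = [B, C, D], so each position is trained to predict the next token.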
1766
+
1767
+ if not return_dict:
1768
+ output = (logits,) + outputs[2:]
1769
+ return ((loss,) + output) if loss is not None else output
1770
+
1771
+ return TFCausalLMOutputWithCrossAttentions(
1772
+ loss=loss,
1773
+ logits=logits,
1774
+ past_key_values=outputs.past_key_values,
1775
+ hidden_states=outputs.hidden_states,
1776
+ attentions=outputs.attentions,
1777
+ cross_attentions=outputs.cross_attentions,
1778
+ )
1779
+
1780
+ def build(self, input_shape=None):
1781
+ if self.built:
1782
+ return
1783
+ self.built = True
1784
+ if getattr(self, "roberta", None) is not None:
1785
+ with tf.name_scope(self.roberta.name):
1786
+ self.roberta.build(None)
1787
+ if getattr(self, "lm_head", None) is not None:
1788
+ with tf.name_scope(self.lm_head.name):
1789
+ self.lm_head.build(None)
1790
+
1791
+
1792
+ __all__ = [
1793
+ "TFCamembertForCausalLM",
1794
+ "TFCamembertForMaskedLM",
1795
+ "TFCamembertForMultipleChoice",
1796
+ "TFCamembertForQuestionAnswering",
1797
+ "TFCamembertForSequenceClassification",
1798
+ "TFCamembertForTokenClassification",
1799
+ "TFCamembertModel",
1800
+ "TFCamembertPreTrainedModel",
1801
+ ]
docs/transformers/src/transformers/models/camembert/tokenization_camembert.py ADDED
@@ -0,0 +1,323 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ """Tokenization classes for Camembert model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ import sentencepiece as spm
22
+
23
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
24
+ from ...utils import logging
25
+ from ...utils.import_utils import requires
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
31
+
32
+
33
+ SPIECE_UNDERLINE = "▁"
34
+
35
+
36
+ @requires(backends=("sentencepiece",))
37
+ class CamembertTokenizer(PreTrainedTokenizer):
38
+ """
39
+ Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Construct a CamemBERT tokenizer. Based on
40
+ [SentencePiece](https://github.com/google/sentencepiece).
41
+
42
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
43
+ this superclass for more information regarding those methods.
44
+
45
+ Args:
46
+ vocab_file (`str`):
47
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
48
+ contains the vocabulary necessary to instantiate a tokenizer.
49
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
50
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
51
+
52
+ <Tip>
53
+
54
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
55
+ sequence. The token used is the `cls_token`.
56
+
57
+ </Tip>
58
+
59
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
60
+ The end of sequence token.
61
+
62
+ <Tip>
63
+
64
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
65
+ The token used is the `sep_token`.
66
+
67
+ </Tip>
68
+
69
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
70
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
71
+ sequence classification or for a text and a question for question answering. It is also used as the last
72
+ token of a sequence built with special tokens.
73
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
74
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
75
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
76
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
77
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
78
+ token instead.
79
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
80
+ The token used for padding, for example when batching sequences of different lengths.
81
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
82
+ The token used for masking values. This is the token used when training this model with masked language
83
+ modeling. This is the token which the model will try to predict.
84
+ additional_special_tokens (`List[str]`, *optional*, defaults to `['<s>NOTUSED', '</s>NOTUSED', '<unk>NOTUSED']`):
85
+ Additional special tokens used by the tokenizer.
86
+ sp_model_kwargs (`dict`, *optional*):
87
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
88
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
89
+ to set:
90
+
91
+ - `enable_sampling`: Enable subword regularization.
92
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
93
+
94
+ - `nbest_size = {0,1}`: No sampling is performed.
95
+ - `nbest_size > 1`: samples from the nbest_size results.
96
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
97
+ using forward-filtering-and-backward-sampling algorithm.
98
+
99
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
100
+ BPE-dropout.
101
+
102
+ Attributes:
103
+ sp_model (`SentencePieceProcessor`):
104
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
105
+ """
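The `sp_model_kwargs` described above are forwarded straight to SentencePiece, so subword regularization can be switched on at construction time. A minimal sketch, assuming a local SentencePiece model at the hypothetical path `sentencepiece.bpe.model`:

```python
# Sketch: enabling subword regularization via sp_model_kwargs
# ("sentencepiece.bpe.model" is an assumed local file path).
from transformers import CamembertTokenizer

tokenizer = CamembertTokenizer(
    "sentencepiece.bpe.model",
    sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1},
)
# With sampling enabled, repeated calls may segment the same text differently,
# which acts as a regularizer during training.
print(tokenizer.tokenize("J'aime le camembert"))
print(tokenizer.tokenize("J'aime le camembert"))
```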
106
+
107
+ vocab_files_names = VOCAB_FILES_NAMES
108
+ model_input_names = ["input_ids", "attention_mask"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_file,
113
+ bos_token="<s>",
114
+ eos_token="</s>",
115
+ sep_token="</s>",
116
+ cls_token="<s>",
117
+ unk_token="<unk>",
118
+ pad_token="<pad>",
119
+ mask_token="<mask>",
120
+ additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED", "<unk>NOTUSED"],
121
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
122
+ **kwargs,
123
+ ) -> None:
124
+ # Mask token behaves like a normal word, i.e. it includes the space before it
125
+ mask_token = (
126
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False, special=True)
127
+ if isinstance(mask_token, str)
128
+ else mask_token
129
+ )
130
+
131
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
132
+
133
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
134
+ self.sp_model.Load(str(vocab_file))
135
+ self.vocab_file = vocab_file
136
+
137
+ # HACK: These tokens were added by the authors even though they were already part of the
138
+ # sentencepiece vocabulary (this is the case for <s>, </s> and <unk>).
139
+ # In this case it is recommended to set the tokens properly by hand.
140
+ self._added_tokens_decoder = {
141
+ 0: AddedToken("<s>NOTUSED", special=True),
142
+ 1: AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token,
143
+ 2: AddedToken("</s>NOTUSED", special=True),
144
+ 3: AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token,
145
+ 4: AddedToken("<unk>NOTUSED", special=True),
146
+ }
147
+
148
+ self.fairseq_offset = 4 # 3 tokens are newly added, but the offset starts from 4
149
+
150
+ # legacy: camembert is a particular case where we have to make sure `"<unk>NOTUSED"` is here
151
+ if "added_tokens_decoder" in kwargs:
152
+ # this is the only class that requires this workaround, unfortunately.
153
+ # The reason is that the fast version's vocabulary has a hole at these indices.
154
+ kwargs["added_tokens_decoder"].update(self._added_tokens_decoder)
155
+
156
+ super().__init__(
157
+ bos_token=bos_token,
158
+ eos_token=eos_token,
159
+ unk_token=unk_token,
160
+ sep_token=sep_token,
161
+ cls_token=cls_token,
162
+ pad_token=pad_token,
163
+ mask_token=mask_token,
164
+ additional_special_tokens=additional_special_tokens,
165
+ sp_model_kwargs=self.sp_model_kwargs,
166
+ **kwargs,
167
+ )
168
+
169
+ @property
170
+ def vocab_size(self):
171
+ # The length of the vocabulary without added tokens is len(self.sp_model) but the added tokens are added at the beginning.
172
+ return len(self.sp_model)
173
+
174
+ def get_vocab(self):
175
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.fairseq_offset)}
176
+ vocab.update(self.added_tokens_encoder)
177
+ return vocab
178
+
179
+ def _tokenize(self, text: str) -> List[str]:
180
+ return self.sp_model.encode(text, out_type=str)
181
+
182
+ def _convert_token_to_id(self, token):
183
+ """Converts a token (str) to an id using the vocab."""
184
+ # specific to camembert, both 3 and 4 point to the unk token.
185
+ if self.sp_model.PieceToId(token) == 0:
186
+ # Convert sentence piece unk token to fairseq unk token index
187
+ return self.unk_token_id
188
+ return self.fairseq_offset + self.sp_model.PieceToId(token)
189
+
190
+ def _convert_id_to_token(self, index):
191
+ """Converts an index (integer) to a token (str) using the vocab."""
192
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
193
+
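These two conversion methods implement the fairseq-style index shift: raw SentencePiece ids are offset by `self.fairseq_offset` (4) because the low ids are occupied by the added special tokens, and SentencePiece's unk id 0 is remapped to `unk_token_id`. A minimal sketch, assuming `tokenizer` is an instantiated `CamembertTokenizer` and that the piece `"▁bonjour"` exists in its vocabulary:

```python
# Illustrative sketch of the fairseq offset mapping (assumptions: `tokenizer`
# is a loaded CamembertTokenizer and "▁bonjour" is a known vocabulary piece).
piece_id = tokenizer.sp_model.PieceToId("▁bonjour")  # raw SentencePiece id (> 0)
token_id = tokenizer.fairseq_offset + piece_id       # shifted by 4
assert tokenizer._convert_token_to_id("▁bonjour") == token_id
# The inverse conversion subtracts the same offset:
assert tokenizer._convert_id_to_token(token_id) == "▁bonjour"
```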
194
+ def convert_tokens_to_string(self, tokens):
195
+ """Converts a sequence of tokens (string) to a single string."""
196
+ # TODO decode outputs do not match between fast and slow
197
+ current_sub_tokens = []
198
+ out_string = ""
199
+ prev_is_special = False
200
+ for token in tokens:
201
+ # make sure that special tokens are not decoded using sentencepiece model
202
+ if token in self.all_special_tokens:
203
+ if not prev_is_special:
204
+ out_string += " "
205
+ out_string += self.sp_model.decode(current_sub_tokens) + token
206
+ prev_is_special = True
207
+ current_sub_tokens = []
208
+ else:
209
+ current_sub_tokens.append(token)
210
+ prev_is_special = False
211
+ out_string += self.sp_model.decode(current_sub_tokens)
212
+ return out_string.strip()
213
+
214
+ def __getstate__(self):
215
+ state = self.__dict__.copy()
216
+ state["sp_model"] = None
217
+ return state
218
+
219
+ def __setstate__(self, d):
220
+ self.__dict__ = d
221
+
222
+ # for backward compatibility
223
+ if not hasattr(self, "sp_model_kwargs"):
224
+ self.sp_model_kwargs = {}
225
+
226
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
227
+ self.sp_model.Load(self.vocab_file)
228
+
229
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
230
+ if not os.path.isdir(save_directory):
231
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
232
+ return
233
+ out_vocab_file = os.path.join(
234
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
235
+ )
236
+
237
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
238
+ copyfile(self.vocab_file, out_vocab_file)
239
+ elif not os.path.isfile(self.vocab_file):
240
+ with open(out_vocab_file, "wb") as fi:
241
+ content_spiece_model = self.sp_model.serialized_model_proto()
242
+ fi.write(content_spiece_model)
243
+
244
+ return (out_vocab_file,)
245
+
246
+ def build_inputs_with_special_tokens(
247
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
248
+ ) -> List[int]:
249
+ """
250
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
251
+ adding special tokens. A CamemBERT sequence has the following format:
252
+
253
+ - single sequence: `<s> X </s>`
254
+ - pair of sequences: `<s> A </s></s> B </s>`
255
+
256
+ Args:
257
+ token_ids_0 (`List[int]`):
258
+ List of IDs to which the special tokens will be added.
259
+ token_ids_1 (`List[int]`, *optional*):
260
+ Optional second list of IDs for sequence pairs.
261
+
262
+ Returns:
263
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
264
+ """
265
+
266
+ if token_ids_1 is None:
267
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
268
+ cls = [self.cls_token_id]
269
+ sep = [self.sep_token_id]
270
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
271
+
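The two layouts documented above can be verified directly; a quick sketch with placeholder ids, assuming `tokenizer` is a loaded `CamembertTokenizer`:

```python
# Sketch: special-token layout produced by build_inputs_with_special_tokens.
# ids 100/101/200 are placeholders, not real vocabulary entries.
ids_a, ids_b = [100, 101], [200]
single = tokenizer.build_inputs_with_special_tokens(ids_a)
pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)
# single == [cls_token_id, 100, 101, sep_token_id]
# pair   == [cls_token_id, 100, 101, sep_token_id, sep_token_id, 200, sep_token_id]
assert single == [tokenizer.cls_token_id, 100, 101, tokenizer.sep_token_id]
```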
272
+ def get_special_tokens_mask(
273
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
274
+ ) -> List[int]:
275
+ """
276
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
277
+ special tokens using the tokenizer `prepare_for_model` method.
278
+
279
+ Args:
280
+ token_ids_0 (`List[int]`):
281
+ List of IDs.
282
+ token_ids_1 (`List[int]`, *optional*):
283
+ Optional second list of IDs for sequence pairs.
284
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
285
+ Whether or not the token list is already formatted with special tokens for the model.
286
+
287
+ Returns:
288
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
289
+ """
290
+ if already_has_special_tokens:
291
+ return super().get_special_tokens_mask(
292
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
293
+ )
294
+
295
+ if token_ids_1 is None:
296
+ return [1] + ([0] * len(token_ids_0)) + [1]
297
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
298
+
299
+ def create_token_type_ids_from_sequences(
300
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
301
+ ) -> List[int]:
302
+ """
303
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. CamemBERT, like
304
+ RoBERTa, does not make use of token type ids, therefore a list of zeros is returned.
305
+
306
+ Args:
307
+ token_ids_0 (`List[int]`):
308
+ List of IDs.
309
+ token_ids_1 (`List[int]`, *optional*):
310
+ Optional second list of IDs for sequence pairs.
311
+
312
+ Returns:
313
+ `List[int]`: List of zeros.
314
+ """
315
+ sep = [self.sep_token_id]
316
+ cls = [self.cls_token_id]
317
+
318
+ if token_ids_1 is None:
319
+ return len(cls + token_ids_0 + sep) * [0]
320
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
321
+
322
+
323
+ __all__ = ["CamembertTokenizer"]
docs/transformers/src/transformers/models/camembert/tokenization_camembert_fast.py ADDED
@@ -0,0 +1,201 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ """Fast tokenization classes for Camembert model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import List, Optional, Tuple
20
+
21
+ from ...tokenization_utils import AddedToken
22
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
23
+ from ...utils import is_sentencepiece_available, logging
24
+
25
+
26
+ if is_sentencepiece_available():
27
+ from .tokenization_camembert import CamembertTokenizer
28
+ else:
29
+ CamembertTokenizer = None
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
35
+
36
+
37
+ SPIECE_UNDERLINE = "▁"
38
+
39
+
40
+ class CamembertTokenizerFast(PreTrainedTokenizerFast):
41
+ """
42
+ Construct a "fast" CamemBERT tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from
43
+ [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
44
+ [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
45
+
46
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
47
+ refer to this superclass for more information regarding those methods.
48
+
49
+ Args:
50
+ vocab_file (`str`):
51
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
52
+ contains the vocabulary necessary to instantiate a tokenizer.
53
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
54
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
55
+
56
+ <Tip>
57
+
58
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
59
+ sequence. The token used is the `cls_token`.
60
+
61
+ </Tip>
62
+
63
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
64
+ The end of sequence token.
65
+
66
+ <Tip>
67
+
68
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
69
+ The token used is the `sep_token`.
70
+
71
+ </Tip>
72
+
73
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
74
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
75
+ sequence classification or for a text and a question for question answering. It is also used as the last
76
+ token of a sequence built with special tokens.
77
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
78
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
79
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
80
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
81
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
82
+ token instead.
83
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
84
+ The token used for padding, for example when batching sequences of different lengths.
85
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
86
+ The token used for masking values. This is the token used when training this model with masked language
87
+ modeling. This is the token which the model will try to predict.
88
+ additional_special_tokens (`List[str]`, *optional*, defaults to `['<s>NOTUSED', '</s>NOTUSED', '<unk>NOTUSED']`):
89
+ Additional special tokens used by the tokenizer.
90
+ """
91
+
92
+ vocab_files_names = VOCAB_FILES_NAMES
93
+ model_input_names = ["input_ids", "attention_mask"]
94
+ slow_tokenizer_class = CamembertTokenizer
95
+
96
+ def __init__(
97
+ self,
98
+ vocab_file=None,
99
+ tokenizer_file=None,
100
+ bos_token="<s>",
101
+ eos_token="</s>",
102
+ sep_token="</s>",
103
+ cls_token="<s>",
104
+ unk_token="<unk>",
105
+ pad_token="<pad>",
106
+ mask_token="<mask>",
107
+ additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED", "<unk>NOTUSED"],
108
+ **kwargs,
109
+ ):
110
+ # Mask token behaves like a normal word, i.e. it includes the space before it; it is created with normalized=False
111
+ mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
112
+ super().__init__(
113
+ vocab_file,
114
+ tokenizer_file=tokenizer_file,
115
+ bos_token=bos_token,
116
+ eos_token=eos_token,
117
+ sep_token=sep_token,
118
+ cls_token=cls_token,
119
+ unk_token=unk_token,
120
+ pad_token=pad_token,
121
+ mask_token=mask_token,
122
+ additional_special_tokens=additional_special_tokens,
123
+ **kwargs,
124
+ )
125
+
126
+ self.vocab_file = vocab_file
127
+
128
+ @property
129
+ def can_save_slow_tokenizer(self) -> bool:
130
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
131
+
132
+ def build_inputs_with_special_tokens(
133
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
134
+ ) -> List[int]:
135
+ """
136
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
137
+ adding special tokens. A CamemBERT sequence has the following format:
138
+
139
+ - single sequence: `<s> X </s>`
140
+ - pair of sequences: `<s> A </s></s> B </s>`
141
+
142
+ Args:
143
+ token_ids_0 (`List[int]`):
144
+ List of IDs to which the special tokens will be added.
145
+ token_ids_1 (`List[int]`, *optional*):
146
+ Optional second list of IDs for sequence pairs.
147
+
148
+ Returns:
149
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
150
+ """
151
+
152
+ if token_ids_1 is None:
153
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
154
+ cls = [self.cls_token_id]
155
+ sep = [self.sep_token_id]
156
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
157
+
158
+ def create_token_type_ids_from_sequences(
159
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
160
+ ) -> List[int]:
161
+ """
162
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. CamemBERT, like
163
+ RoBERTa, does not make use of token type ids, therefore a list of zeros is returned.
164
+
165
+ Args:
166
+ token_ids_0 (`List[int]`):
167
+ List of IDs.
168
+ token_ids_1 (`List[int]`, *optional*):
169
+ Optional second list of IDs for sequence pairs.
170
+
171
+ Returns:
172
+ `List[int]`: List of zeros.
173
+ """
174
+ sep = [self.sep_token_id]
175
+ cls = [self.cls_token_id]
176
+
177
+ if token_ids_1 is None:
178
+ return len(cls + token_ids_0 + sep) * [0]
179
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
180
+
181
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
182
+ if not self.can_save_slow_tokenizer:
183
+ raise ValueError(
184
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
185
+ "tokenizer."
186
+ )
187
+
188
+ if not os.path.isdir(save_directory):
189
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
190
+ return
191
+ out_vocab_file = os.path.join(
192
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
193
+ )
194
+
195
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
196
+ copyfile(self.vocab_file, out_vocab_file)
197
+
198
+ return (out_vocab_file,)
199
+
200
+
201
+ __all__ = ["CamembertTokenizerFast"]
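In practice the fast tokenizer is usually loaded from a pretrained checkpoint rather than built from raw files; a minimal sketch (the `almanach/camembert-base` repo id on the Hub is an assumption, and network access is required):

```python
# Minimal sketch: loading the fast CamemBERT tokenizer from the Hub
# ("almanach/camembert-base" is the assumed checkpoint id).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
print(type(tokenizer).__name__)  # CamembertTokenizerFast when tokenizers is installed
enc = tokenizer("J'aime le camembert", return_tensors="pt")
print(enc["input_ids"].shape, enc["attention_mask"].shape)
```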
docs/transformers/src/transformers/models/canine/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_canine import *
22
+ from .modeling_canine import *
23
+ from .tokenization_canine import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/src/transformers/models/canine/configuration_canine.py ADDED
@@ -0,0 +1,141 @@
1
+ # coding=utf-8
2
+ # Copyright Google AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """CANINE model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class CanineConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`CanineModel`]. It is used to instantiate a
27
+ CANINE model according to the specified arguments, defining the model architecture. Instantiating a configuration
28
+ with the defaults will yield a similar configuration to that of the CANINE
29
+ [google/canine-s](https://huggingface.co/google/canine-s) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ hidden_size (`int`, *optional*, defaults to 768):
37
+ Dimension of the encoder layers and the pooler layer.
38
+ num_hidden_layers (`int`, *optional*, defaults to 12):
39
+ Number of hidden layers in the deep Transformer encoder.
40
+ num_attention_heads (`int`, *optional*, defaults to 12):
41
+ Number of attention heads for each attention layer in the Transformer encoders.
42
+ intermediate_size (`int`, *optional*, defaults to 3072):
43
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoders.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
47
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
48
+ The dropout probability for all fully connected layers in the embeddings, encoders, and pooler.
49
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
50
+ The dropout ratio for the attention probabilities.
51
+ max_position_embeddings (`int`, *optional*, defaults to 16384):
52
+ The maximum sequence length that this model might ever be used with.
53
+ type_vocab_size (`int`, *optional*, defaults to 16):
54
+ The vocabulary size of the `token_type_ids` passed when calling [`CanineModel`].
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
58
+ The epsilon used by the layer normalization layers.
59
+ pad_token_id (`int`, *optional*, defaults to 0):
60
+ Padding token id.
61
+ bos_token_id (`int`, *optional*, defaults to 57344):
62
+ Beginning of stream token id.
63
+ eos_token_id (`int`, *optional*, defaults to 57345):
64
+ End of stream token id.
65
+ downsampling_rate (`int`, *optional*, defaults to 4):
66
+ The rate at which to downsample the original character sequence length before applying the deep Transformer
67
+ encoder.
68
+ upsampling_kernel_size (`int`, *optional*, defaults to 4):
69
+ The kernel size (i.e. the number of characters in each window) of the convolutional projection layer when
70
+ projecting back from `hidden_size`*2 to `hidden_size`.
71
+ num_hash_functions (`int`, *optional*, defaults to 8):
72
+ The number of hash functions to use. Each hash function has its own embedding matrix.
73
+ num_hash_buckets (`int`, *optional*, defaults to 16384):
74
+ The number of hash buckets to use.
75
+ local_transformer_stride (`int`, *optional*, defaults to 128):
76
+ The stride of the local attention of the first shallow Transformer encoder. Defaults to 128 for good
77
+ TPU/XLA memory alignment.
78
+
79
+ Example:
80
+
81
+ ```python
82
+ >>> from transformers import CanineConfig, CanineModel
83
+
84
+ >>> # Initializing a CANINE google/canine-s style configuration
85
+ >>> configuration = CanineConfig()
86
+
87
+ >>> # Initializing a model (with random weights) from the google/canine-s style configuration
88
+ >>> model = CanineModel(configuration)
89
+
90
+ >>> # Accessing the model configuration
91
+ >>> configuration = model.config
92
+ ```"""
93
+
94
+ model_type = "canine"
95
+
96
+ def __init__(
97
+ self,
98
+ hidden_size=768,
99
+ num_hidden_layers=12,
100
+ num_attention_heads=12,
101
+ intermediate_size=3072,
102
+ hidden_act="gelu",
103
+ hidden_dropout_prob=0.1,
104
+ attention_probs_dropout_prob=0.1,
105
+ max_position_embeddings=16384,
106
+ type_vocab_size=16,
107
+ initializer_range=0.02,
108
+ layer_norm_eps=1e-12,
109
+ pad_token_id=0,
110
+ bos_token_id=0xE000,
111
+ eos_token_id=0xE001,
112
+ downsampling_rate=4,
113
+ upsampling_kernel_size=4,
114
+ num_hash_functions=8,
115
+ num_hash_buckets=16384,
116
+ local_transformer_stride=128, # Good TPU/XLA memory alignment.
117
+ **kwargs,
118
+ ):
119
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
120
+
121
+ self.max_position_embeddings = max_position_embeddings
122
+ self.hidden_size = hidden_size
123
+ self.num_hidden_layers = num_hidden_layers
124
+ self.num_attention_heads = num_attention_heads
125
+ self.intermediate_size = intermediate_size
126
+ self.hidden_act = hidden_act
127
+ self.hidden_dropout_prob = hidden_dropout_prob
128
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
129
+ self.initializer_range = initializer_range
130
+ self.type_vocab_size = type_vocab_size
131
+ self.layer_norm_eps = layer_norm_eps
132
+
133
+ # Character config:
134
+ self.downsampling_rate = downsampling_rate
135
+ self.upsampling_kernel_size = upsampling_kernel_size
136
+ self.num_hash_functions = num_hash_functions
137
+ self.num_hash_buckets = num_hash_buckets
138
+ self.local_transformer_stride = local_transformer_stride
139
+
140
+
141
+ __all__ = ["CanineConfig"]
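Since `downsampling_rate` controls how many characters collapse into one molecule, the deep encoder sees a sequence that is `downsampling_rate` times shorter than the character input; a short sketch of the arithmetic with the default config:

```python
# Sketch: character-to-molecule length arithmetic implied by CanineConfig.
from transformers import CanineConfig

config = CanineConfig()                 # downsampling_rate defaults to 4
char_seq_len = 2048                     # assumed input length (a multiple of the rate)
molecule_seq_len = char_seq_len // config.downsampling_rate
print(molecule_seq_len)                 # 512 positions in the deep Transformer encoder
```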
docs/transformers/src/transformers/models/canine/convert_canine_original_tf_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,65 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert CANINE checkpoint."""
16
+
17
+ import argparse
18
+
19
+ from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
20
+ from transformers.utils import logging
21
+
22
+
23
+ logging.set_verbosity_info()
24
+
25
+
26
+ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
27
+ # Initialize PyTorch model
28
+ config = CanineConfig()
29
+ model = CanineModel(config)
30
+ model.eval()
31
+
32
+ print(f"Building PyTorch model from configuration: {config}")
33
+
34
+ # Load weights from tf checkpoint
35
+ load_tf_weights_in_canine(model, config, tf_checkpoint_path)
36
+
37
+ # Save pytorch-model (weights and configuration)
38
+ print(f"Save PyTorch model to {pytorch_dump_path}")
39
+ model.save_pretrained(pytorch_dump_path)
40
+
41
+ # Save tokenizer files
42
+ tokenizer = CanineTokenizer()
43
+ print(f"Save tokenizer files to {pytorch_dump_path}")
44
+ tokenizer.save_pretrained(pytorch_dump_path)
45
+
46
+
47
+ if __name__ == "__main__":
48
+ parser = argparse.ArgumentParser()
49
+ # Required parameters
50
+ parser.add_argument(
51
+ "--tf_checkpoint_path",
52
+ default=None,
53
+ type=str,
54
+ required=True,
55
+ help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
56
+ )
57
+ parser.add_argument(
58
+ "--pytorch_dump_path",
59
+ default=None,
60
+ type=str,
61
+ required=True,
62
+ help="Path to a folder where the PyTorch model will be placed.",
63
+ )
64
+ args = parser.parse_args()
65
+ convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)
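The conversion is driven entirely by the two required arguments above. It can also be invoked programmatically; a sketch, assuming the script is importable from an installed copy of transformers and using placeholder paths:

```python
# Sketch: programmatic use of the conversion entry point (paths are placeholders).
from transformers.models.canine.convert_canine_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/canine/model.ckpt",
    pytorch_dump_path="/path/to/output_dir",
)
```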
docs/transformers/src/transformers/models/canine/modeling_canine.py ADDED
@@ -0,0 +1,1653 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CANINE model."""
16
+
17
+ import copy
18
+ import math
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from ...activations import ACT2FN
29
+ from ...modeling_outputs import (
30
+ BaseModelOutput,
31
+ ModelOutput,
32
+ MultipleChoiceModelOutput,
33
+ QuestionAnsweringModelOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
39
+ from ...utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ logging,
44
+ replace_return_docstrings,
45
+ )
46
+ from .configuration_canine import CanineConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "google/canine-s"
52
+ _CONFIG_FOR_DOC = "CanineConfig"
53
+
54
+
55
+ # Support up to 16 hash functions.
56
+ _PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
57
+
58
+
59
+ @dataclass
60
+ class CanineModelOutputWithPooling(ModelOutput):
61
+ """
62
+ Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly
63
+ different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow
64
+ Transformer encoders.
65
+
66
+ Args:
67
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
68
+ Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final
69
+ shallow Transformer encoder).
70
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
71
+ Hidden-state of the first token of the sequence (classification token) at the last layer of the deep
72
+ Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer
73
+ weights are trained from the next sentence prediction (classification) objective during pretraining.
74
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
75
+ Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each
76
+ encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length //
77
+ config.downsampling_rate, hidden_size)`. Hidden-states of the model at the output of each layer plus the
78
+ initial input to each Transformer encoder. The hidden states of the shallow encoders have length
79
+ `sequence_length`, but the hidden states of the deep encoder have length `sequence_length` //
80
+ `config.downsampling_rate`.
81
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
82
+ Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size,
83
+ num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length //
84
+ config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attentions weights after the
85
+ attention softmax, used to compute the weighted average in the self-attention heads.
86
+ """
87
+
88
+ last_hidden_state: Optional[torch.FloatTensor] = None
89
+ pooler_output: Optional[torch.FloatTensor] = None
90
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
91
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
92
+
93
+
94
+ def load_tf_weights_in_canine(model, config, tf_checkpoint_path):
95
+ """Load tf checkpoints in a pytorch model."""
96
+ try:
97
+ import re
98
+
99
+ import numpy as np
100
+ import tensorflow as tf
101
+ except ImportError:
102
+ logger.error(
103
+ "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
104
+ "https://www.tensorflow.org/install/ for installation instructions."
105
+ )
106
+ raise
107
+ tf_path = os.path.abspath(tf_checkpoint_path)
108
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
109
+ # Load weights from TF model
110
+ init_vars = tf.train.list_variables(tf_path)
111
+ names = []
112
+ arrays = []
113
+ for name, shape in init_vars:
114
+ logger.info(f"Loading TF weight {name} with shape {shape}")
115
+ array = tf.train.load_variable(tf_path, name)
116
+ names.append(name)
117
+ arrays.append(array)
118
+
119
+ for name, array in zip(names, arrays):
120
+ name = name.split("/")
121
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
122
+ # which are not required when using the pretrained model
123
+ # also discard the cls weights (which were used for the next sentence prediction pre-training task)
124
+ if any(
125
+ n
126
+ in [
127
+ "adam_v",
128
+ "adam_m",
129
+ "AdamWeightDecayOptimizer",
130
+ "AdamWeightDecayOptimizer_1",
131
+ "global_step",
132
+ "cls",
133
+ "autoregressive_decoder",
134
+ "char_output_weights",
135
+ ]
136
+ for n in name
137
+ ):
138
+ logger.info(f"Skipping {'/'.join(name)}")
139
+ continue
140
+ # if first scope name starts with "bert", change it to "encoder"
141
+ if name[0] == "bert":
142
+ name[0] = "encoder"
143
+ # remove "embeddings" middle name of HashBucketCodepointEmbedders
144
+ elif name[1] == "embeddings":
145
+ name.remove(name[1])
146
+ # rename segment_embeddings to token_type_embeddings
147
+ elif name[1] == "segment_embeddings":
148
+ name[1] = "token_type_embeddings"
149
+ # rename initial convolutional projection layer
150
+ elif name[1] == "initial_char_encoder":
151
+ name = ["chars_to_molecules"] + name[-2:]
152
+ # rename final convolutional projection layer
153
+ elif name[0] == "final_char_encoder" and name[1] in ["LayerNorm", "conv"]:
154
+ name = ["projection"] + name[1:]
155
+ pointer = model
156
+ for m_name in name:
157
+ if (re.fullmatch(r"[A-Za-z]+_\d+", m_name)) and "Embedder" not in m_name:
158
+ scope_names = re.split(r"_(\d+)", m_name)
159
+ else:
160
+ scope_names = [m_name]
161
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
162
+ pointer = getattr(pointer, "weight")
163
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
164
+ pointer = getattr(pointer, "bias")
165
+ elif scope_names[0] == "output_weights":
166
+ pointer = getattr(pointer, "weight")
167
+ else:
168
+ try:
169
+ pointer = getattr(pointer, scope_names[0])
170
+ except AttributeError:
171
+ logger.info(f"Skipping {'/'.join(name)}")
172
+ continue
173
+ if len(scope_names) >= 2:
174
+ num = int(scope_names[1])
175
+ pointer = pointer[num]
176
+ if m_name[-11:] == "_embeddings":
177
+ pointer = getattr(pointer, "weight")
178
+ elif m_name[-10:] in [f"Embedder_{i}" for i in range(8)]:
179
+ pointer = getattr(pointer, "weight")
180
+ elif m_name == "kernel":
181
+ array = np.transpose(array)
182
+
183
+ if pointer.shape != array.shape:
184
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
185
+
186
+ logger.info(f"Initialize PyTorch weight {name}")
187
+ pointer.data = torch.from_numpy(array)
188
+ return model
189
+
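The name-mapping loop above relies on one regex to split indexed scope names into an attribute name and an index; a small standalone sketch of what it produces:

```python
# Sketch: how load_tf_weights_in_canine splits TF scope names.
import re

for m_name in ["layer_3", "kernel", "HashBucketCodepointEmbedder_7"]:
    if re.fullmatch(r"[A-Za-z]+_\d+", m_name) and "Embedder" not in m_name:
        scope_names = re.split(r"_(\d+)", m_name)
    else:
        scope_names = [m_name]
    print(m_name, "->", scope_names)
# layer_3 -> ['layer', '3', '']   (attribute "layer", index 3)
# kernel  -> ['kernel']           (mapped to "weight" later in the loop)
# HashBucketCodepointEmbedder_7 -> ['HashBucketCodepointEmbedder_7']  (kept whole)
```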
190
+
191
+ class CanineEmbeddings(nn.Module):
192
+ """Construct the character, position and token_type embeddings."""
193
+
194
+ def __init__(self, config):
195
+ super().__init__()
196
+
197
+ self.config = config
198
+
199
+ # character embeddings
200
+ shard_embedding_size = config.hidden_size // config.num_hash_functions
201
+ for i in range(config.num_hash_functions):
202
+ name = f"HashBucketCodepointEmbedder_{i}"
203
+ setattr(self, name, nn.Embedding(config.num_hash_buckets, shard_embedding_size))
204
+ self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size)
205
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
206
+
207
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
208
+ # any TensorFlow checkpoint file
209
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
210
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
211
+
212
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
213
+ self.register_buffer(
214
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
215
+ )
216
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
217
+
218
+ def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
219
+ """
220
+ Converts ids to hash bucket ids via multiple hashing.
221
+
222
+ Args:
223
+ input_ids: The codepoints or other IDs to be hashed.
224
+ num_hashes: The number of hash functions to use.
225
+ num_buckets: The number of hash buckets (i.e. embeddings in each table).
226
+
227
+ Returns:
228
+ A list of tensors, each of which is the hash bucket IDs from one hash function.
229
+ """
230
+ if num_hashes > len(_PRIMES):
231
+ raise ValueError(f"`num_hashes` must be <= {len(_PRIMES)}")
232
+
233
+ primes = _PRIMES[:num_hashes]
234
+
235
+ result_tensors = []
236
+ for prime in primes:
237
+ hashed = ((input_ids + 1) * prime) % num_buckets
238
+ result_tensors.append(hashed)
239
+ return result_tensors
240
+
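Each hash function above is just `((id + 1) * prime) % num_buckets` with a distinct prime; a tiny numeric sketch:

```python
# Numeric sketch of the multi-hash bucketing used by CanineEmbeddings
# (the codepoints and bucket count below are illustrative).
import torch

input_ids = torch.tensor([99, 101, 932])   # example Unicode codepoints
num_buckets = 16384
for prime in [31, 43, 59]:                 # first three entries of _PRIMES
    hashed = ((input_ids + 1) * prime) % num_buckets
    print(prime, hashed.tolist())
# Each prime yields an independent bucket id per codepoint; one embedding shard
# is looked up per hash and the shards are concatenated along the last dimension.
```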
241
+ def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, num_buckets: int):
242
+ """Converts IDs (e.g. codepoints) into embeddings via multiple hashing."""
243
+ if embedding_size % num_hashes != 0:
244
+ raise ValueError(f"Expected `embedding_size` ({embedding_size}) % `num_hashes` ({num_hashes}) == 0")
245
+
246
+ hash_bucket_tensors = self._hash_bucket_tensors(input_ids, num_hashes=num_hashes, num_buckets=num_buckets)
247
+ embedding_shards = []
248
+ for i, hash_bucket_ids in enumerate(hash_bucket_tensors):
249
+ name = f"HashBucketCodepointEmbedder_{i}"
250
+ shard_embeddings = getattr(self, name)(hash_bucket_ids)
251
+ embedding_shards.append(shard_embeddings)
252
+
253
+ return torch.cat(embedding_shards, dim=-1)
254
+
255
+ def forward(
256
+ self,
257
+ input_ids: Optional[torch.LongTensor] = None,
258
+ token_type_ids: Optional[torch.LongTensor] = None,
259
+ position_ids: Optional[torch.LongTensor] = None,
260
+ inputs_embeds: Optional[torch.FloatTensor] = None,
261
+ ) -> torch.FloatTensor:
262
+ if input_ids is not None:
263
+ input_shape = input_ids.size()
264
+ else:
265
+ input_shape = inputs_embeds.size()[:-1]
266
+
267
+ seq_length = input_shape[1]
268
+
269
+ if position_ids is None:
270
+ position_ids = self.position_ids[:, :seq_length]
271
+
272
+ if token_type_ids is None:
273
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
274
+
275
+ if inputs_embeds is None:
276
+ inputs_embeds = self._embed_hash_buckets(
277
+ input_ids, self.config.hidden_size, self.config.num_hash_functions, self.config.num_hash_buckets
278
+ )
279
+
280
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
281
+
282
+ embeddings = inputs_embeds + token_type_embeddings
283
+
284
+ if self.position_embedding_type == "absolute":
285
+ position_embeddings = self.char_position_embeddings(position_ids)
286
+ embeddings += position_embeddings
287
+ embeddings = self.LayerNorm(embeddings)
288
+ embeddings = self.dropout(embeddings)
289
+ return embeddings
290
+
291
+
292
+ class CharactersToMolecules(nn.Module):
293
+ """Convert character sequence to initial molecule sequence (i.e. downsample) using strided convolutions."""
294
+
295
+ def __init__(self, config):
296
+ super().__init__()
297
+
298
+ self.conv = nn.Conv1d(
299
+ in_channels=config.hidden_size,
300
+ out_channels=config.hidden_size,
301
+ kernel_size=config.downsampling_rate,
302
+ stride=config.downsampling_rate,
303
+ )
304
+ self.activation = ACT2FN[config.hidden_act]
305
+
306
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
307
+ # any TensorFlow checkpoint file
308
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
309
+
310
+ def forward(self, char_encoding: torch.Tensor) -> torch.Tensor:
311
+ # `cls_encoding`: [batch, 1, hidden_size]
312
+ cls_encoding = char_encoding[:, 0:1, :]
313
+
314
+ # char_encoding has shape [batch, char_seq, hidden_size]
315
+ # We transpose it to be [batch, hidden_size, char_seq]
316
+ char_encoding = torch.transpose(char_encoding, 1, 2)
317
+ downsampled = self.conv(char_encoding)
318
+ downsampled = torch.transpose(downsampled, 1, 2)
319
+ downsampled = self.activation(downsampled)
320
+
321
+ # Truncate the last molecule in order to reserve a position for [CLS].
322
+ # Often, the last position is never used (unless we completely fill the
323
+ # text buffer). This is important in order to maintain alignment on TPUs
324
+ # (i.e. a multiple of 128).
325
+ downsampled_truncated = downsampled[:, 0:-1, :]
326
+
327
+ # We also keep [CLS] as a separate sequence position since we always
328
+ # want to reserve a position (and the model capacity that goes along
329
+ # with that) in the deep BERT stack.
330
+ # `result`: [batch, molecule_seq, molecule_dim]
331
+ result = torch.cat([cls_encoding, downsampled_truncated], dim=1)
332
+
333
+ result = self.LayerNorm(result)
334
+
335
+ return result
336
+
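End to end, the strided convolution plus the [CLS] re-insertion keeps the molecule count at exactly `char_seq_len // downsampling_rate`; a shape sketch under assumed sizes:

```python
# Shape sketch for CharactersToMolecules (batch/sequence sizes are assumptions).
import torch
from transformers import CanineConfig
from transformers.models.canine.modeling_canine import CharactersToMolecules

config = CanineConfig()                                    # downsampling_rate = 4
module = CharactersToMolecules(config)
char_encoding = torch.randn(2, 2048, config.hidden_size)   # (batch, char_seq, hidden)
molecules = module(char_encoding)
# conv(stride=4) -> 512, drop last molecule -> 511, prepend [CLS] slice -> 512
print(molecules.shape)                                     # torch.Size([2, 512, 768])
```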
337
+
338
+ class ConvProjection(nn.Module):
339
+ """
340
+ Project representations from hidden_size*2 back to hidden_size across a window of w = config.upsampling_kernel_size
341
+ characters.
342
+ """
343
+
344
+ def __init__(self, config):
345
+ super().__init__()
346
+ self.config = config
347
+ self.conv = nn.Conv1d(
348
+ in_channels=config.hidden_size * 2,
349
+ out_channels=config.hidden_size,
350
+ kernel_size=config.upsampling_kernel_size,
351
+ stride=1,
352
+ )
353
+ self.activation = ACT2FN[config.hidden_act]
354
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
355
+ # any TensorFlow checkpoint file
356
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
357
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
358
+
359
+ def forward(
360
+ self,
361
+ inputs: torch.Tensor,
362
+ final_seq_char_positions: Optional[torch.Tensor] = None,
363
+ ) -> torch.Tensor:
364
+ # inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final]
365
+ # we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq]
366
+ inputs = torch.transpose(inputs, 1, 2)
367
+
368
+ # PyTorch < 1.9 does not support padding="same" (which is used in the original implementation),
369
+ # so we pad the tensor manually before passing it to the conv layer
370
+ # based on https://github.com/google-research/big_transfer/blob/49afe42338b62af9fbe18f0258197a33ee578a6b/bit_tf2/models.py#L36-L38
371
+ pad_total = self.config.upsampling_kernel_size - 1
372
+ pad_beg = pad_total // 2
373
+ pad_end = pad_total - pad_beg
374
+
375
+ pad = nn.ConstantPad1d((pad_beg, pad_end), 0)
376
+ # `result`: shape (batch_size, char_seq_len, hidden_size)
377
+ result = self.conv(pad(inputs))
378
+ result = torch.transpose(result, 1, 2)
379
+ result = self.activation(result)
380
+ result = self.LayerNorm(result)
381
+ result = self.dropout(result)
382
+ final_char_seq = result
383
+
384
+ if final_seq_char_positions is not None:
385
+ # Limit transformer query seq and attention mask to these character
386
+ # positions to greatly reduce the compute cost. Typically, this is just
387
+ # done for the MLM training task.
388
+ # TODO add support for MLM
389
+ raise NotImplementedError("CanineForMaskedLM is currently not supported")
390
+ else:
391
+ query_seq = final_char_seq
392
+
393
+ return query_seq
394
+
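The manual padding above reproduces TensorFlow's `padding="same"` behavior for a stride-1 convolution; a quick check of the arithmetic:

```python
# Sketch: manual "same" padding for a stride-1 conv of kernel size k, as in
# ConvProjection.forward. The output length equals the input length.
k = 4                             # config.upsampling_kernel_size default
pad_total = k - 1                 # 3
pad_beg = pad_total // 2          # 1
pad_end = pad_total - pad_beg     # 2
seq_len = 100                     # assumed character sequence length
out_len = seq_len + pad_beg + pad_end - (k - 1)
print(pad_beg, pad_end, out_len)  # 1 2 100
```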
395
+
396
+ class CanineSelfAttention(nn.Module):
397
+ def __init__(self, config):
398
+ super().__init__()
399
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
400
+ raise ValueError(
401
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
402
+ f"heads ({config.num_attention_heads})"
403
+ )
404
+
405
+ self.num_attention_heads = config.num_attention_heads
406
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
407
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
408
+
409
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
410
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
411
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
412
+
413
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
414
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
415
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
416
+ self.max_position_embeddings = config.max_position_embeddings
417
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
418
+
419
+ def transpose_for_scores(self, x):
420
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
421
+ x = x.view(*new_x_shape)
422
+ return x.permute(0, 2, 1, 3)
423
+
424
+ def forward(
425
+ self,
426
+ from_tensor: torch.Tensor,
427
+ to_tensor: torch.Tensor,
428
+ attention_mask: Optional[torch.FloatTensor] = None,
429
+ head_mask: Optional[torch.FloatTensor] = None,
430
+ output_attentions: Optional[bool] = False,
431
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
432
+ mixed_query_layer = self.query(from_tensor)
433
+
434
+ # If this is instantiated as a cross-attention module, the keys
435
+ # and values come from an encoder; the attention mask needs to be
436
+ # such that the encoder's padding tokens are not attended to.
437
+
438
+ key_layer = self.transpose_for_scores(self.key(to_tensor))
439
+ value_layer = self.transpose_for_scores(self.value(to_tensor))
440
+
441
+ query_layer = self.transpose_for_scores(mixed_query_layer)
442
+
443
+ # Take the dot product between "query" and "key" to get the raw attention scores.
444
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
445
+
446
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
447
+ seq_length = from_tensor.size()[1]
448
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(-1, 1)
449
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(1, -1)
450
+ distance = position_ids_l - position_ids_r
451
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
452
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
453
+
454
+ if self.position_embedding_type == "relative_key":
455
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
456
+ attention_scores = attention_scores + relative_position_scores
457
+ elif self.position_embedding_type == "relative_key_query":
458
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
459
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
460
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
461
+
462
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
463
+ if attention_mask is not None:
464
+ if attention_mask.ndim == 3:
465
+ # if attention_mask is 3D, broadcast it over the attention-head dimension:
466
+ attention_mask = torch.unsqueeze(attention_mask, dim=1)
467
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
468
+ # masked positions, this operation will create a tensor which is 0.0 for
469
+ # positions we want to attend and the dtype's smallest value for masked positions.
470
+ attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
471
+ # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
472
+ attention_scores = attention_scores + attention_mask
473
+
474
+ # Normalize the attention scores to probabilities.
475
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
476
+
477
+ # This is actually dropping out entire tokens to attend to, which might
478
+ # seem a bit unusual, but is taken from the original Transformer paper.
479
+ attention_probs = self.dropout(attention_probs)
480
+
481
+ # Mask heads if we want to
482
+ if head_mask is not None:
483
+ attention_probs = attention_probs * head_mask
484
+
485
+ context_layer = torch.matmul(attention_probs, value_layer)
486
+
487
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
488
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
489
+ context_layer = context_layer.view(*new_context_layer_shape)
490
+
491
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
492
+
493
+ return outputs
494
+
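The masking above converts a 0/1 attention mask into a large negative additive bias before the softmax; a minimal standalone sketch:

```python
# Sketch: turning a 0/1 padding mask into an additive attention bias,
# mirroring the logic in CanineSelfAttention.forward.
import torch

attention_mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = attend, 0 = masked
scores = torch.zeros(1, 1, 3, 3)                  # (batch, heads, queries, keys)
bias = (1.0 - attention_mask) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + bias, dim=-1)
print(probs[0, 0, 0])                             # masked key receives ~0 probability
```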
495
+
496
+ class CanineSelfOutput(nn.Module):
497
+ def __init__(self, config):
498
+ super().__init__()
499
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
500
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
501
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
502
+
503
+ def forward(
504
+ self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
505
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
506
+ hidden_states = self.dense(hidden_states)
507
+ hidden_states = self.dropout(hidden_states)
508
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
509
+ return hidden_states
510
+
511
+
512
+ class CanineAttention(nn.Module):
513
+ """
514
+ Additional arguments related to local attention:
515
+
516
+ - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
517
+ - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able to
518
+ attend
519
+ to the `to_tensor`'s first position (e.g. a [CLS] position)? - **first_position_attends_to_all** (`bool`,
520
+ *optional*, defaults to `False`) -- Should the *from_tensor*'s first position be able to attend to all
521
+ positions within the *from_tensor*? - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The
522
+ width of each block-wise chunk in `from_tensor`. - **attend_from_chunk_stride** (`int`, *optional*, defaults to
523
+ 128) -- The number of elements to skip when moving to the next block in `from_tensor`. -
524
+ **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
525
+ *to_tensor*. - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to
526
+ skip when moving to the next block in `to_tensor`.
527
+ """
528
+
529
+ def __init__(
530
+ self,
531
+ config,
532
+ local=False,
533
+ always_attend_to_first_position: bool = False,
534
+ first_position_attends_to_all: bool = False,
535
+ attend_from_chunk_width: int = 128,
536
+ attend_from_chunk_stride: int = 128,
537
+ attend_to_chunk_width: int = 128,
538
+ attend_to_chunk_stride: int = 128,
539
+ ):
540
+ super().__init__()
541
+ self.self = CanineSelfAttention(config)
542
+ self.output = CanineSelfOutput(config)
543
+ self.pruned_heads = set()
544
+
545
+ # additional arguments related to local attention
546
+ self.local = local
547
+ if attend_from_chunk_width < attend_from_chunk_stride:
548
+ raise ValueError(
549
+ "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
550
+ )
551
+ if attend_to_chunk_width < attend_to_chunk_stride:
552
+ raise ValueError(
553
+ "`attend_to_chunk_width` < `attend_to_chunk_stride` would cause sequence positions to get skipped."
554
+ )
555
+ self.always_attend_to_first_position = always_attend_to_first_position
556
+ self.first_position_attends_to_all = first_position_attends_to_all
557
+ self.attend_from_chunk_width = attend_from_chunk_width
558
+ self.attend_from_chunk_stride = attend_from_chunk_stride
559
+ self.attend_to_chunk_width = attend_to_chunk_width
560
+ self.attend_to_chunk_stride = attend_to_chunk_stride
561
+
562
+ def prune_heads(self, heads):
563
+ if len(heads) == 0:
564
+ return
565
+ heads, index = find_pruneable_heads_and_indices(
566
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
567
+ )
568
+
569
+ # Prune linear layers
570
+ self.self.query = prune_linear_layer(self.self.query, index)
571
+ self.self.key = prune_linear_layer(self.self.key, index)
572
+ self.self.value = prune_linear_layer(self.self.value, index)
573
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
574
+
575
+ # Update hyper params and store pruned heads
576
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
577
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
578
+ self.pruned_heads = self.pruned_heads.union(heads)
579
+
580
+ def forward(
581
+ self,
582
+ hidden_states: torch.FloatTensor,
583
+ attention_mask: Optional[torch.FloatTensor] = None,
584
+ head_mask: Optional[torch.FloatTensor] = None,
585
+ output_attentions: Optional[bool] = False,
586
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
587
+ if not self.local:
588
+ self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
589
+ attention_output = self_outputs[0]
590
+ else:
591
+ from_seq_length = to_seq_length = hidden_states.shape[1]
592
+ from_tensor = to_tensor = hidden_states
593
+
594
+ # Create chunks (windows) that we will attend *from* and then concatenate them.
595
+ from_chunks = []
596
+ if self.first_position_attends_to_all:
597
+ from_chunks.append((0, 1))
598
+ # We must skip this first position so that our output sequence is the
599
+ # correct length (this matters in the *from* sequence only).
600
+ from_start = 1
601
+ else:
602
+ from_start = 0
603
+ for chunk_start in range(from_start, from_seq_length, self.attend_from_chunk_stride):
604
+ chunk_end = min(from_seq_length, chunk_start + self.attend_from_chunk_width)
605
+ from_chunks.append((chunk_start, chunk_end))
606
+
607
+ # Determine the chunks (windows) that we will attend *to*.
608
+ to_chunks = []
609
+ if self.first_position_attends_to_all:
610
+ to_chunks.append((0, to_seq_length))
611
+ for chunk_start in range(0, to_seq_length, self.attend_to_chunk_stride):
612
+ chunk_end = min(to_seq_length, chunk_start + self.attend_to_chunk_width)
613
+ to_chunks.append((chunk_start, chunk_end))
614
+
615
+ if len(from_chunks) != len(to_chunks):
616
+ raise ValueError(
617
+ f"Expected to have same number of `from_chunks` ({from_chunks}) and "
618
+ f"`to_chunks` ({to_chunks}). Check strides."
619
+ )
620
+
621
+ # next, compute attention scores for each pair of windows and concatenate
622
+ attention_output_chunks = []
623
+ attention_probs_chunks = []
624
+ for (from_start, from_end), (to_start, to_end) in zip(from_chunks, to_chunks):
625
+ from_tensor_chunk = from_tensor[:, from_start:from_end, :]
626
+ to_tensor_chunk = to_tensor[:, to_start:to_end, :]
627
+ # `attention_mask`: <float>[batch_size, from_seq, to_seq]
628
+ # `attention_mask_chunk`: <float>[batch_size, from_seq_chunk, to_seq_chunk]
629
+ attention_mask_chunk = attention_mask[:, from_start:from_end, to_start:to_end]
630
+ if self.always_attend_to_first_position:
631
+ cls_attention_mask = attention_mask[:, from_start:from_end, 0:1]
632
+ attention_mask_chunk = torch.cat([cls_attention_mask, attention_mask_chunk], dim=2)
633
+
634
+ cls_position = to_tensor[:, 0:1, :]
635
+ to_tensor_chunk = torch.cat([cls_position, to_tensor_chunk], dim=1)
636
+
637
+ attention_outputs_chunk = self.self(
638
+ from_tensor_chunk, to_tensor_chunk, attention_mask_chunk, head_mask, output_attentions
639
+ )
640
+ attention_output_chunks.append(attention_outputs_chunk[0])
641
+ if output_attentions:
642
+ attention_probs_chunks.append(attention_outputs_chunk[1])
643
+
644
+ attention_output = torch.cat(attention_output_chunks, dim=1)
645
+
646
+ attention_output = self.output(attention_output, hidden_states)
647
+ outputs = (attention_output,)
648
+ if not self.local:
649
+ outputs = outputs + self_outputs[1:] # add attentions if we output them
650
+ else:
651
+ outputs = outputs + tuple(attention_probs_chunks) # add attentions if we output them
652
+ return outputs
653
+
654
+
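The window bookkeeping in `CanineAttention.forward` is easier to follow on a toy sequence. Below is a minimal, standalone sketch (the sizes are hypothetical; the model's defaults are width/stride 128) of how `from_chunks` is built when `first_position_attends_to_all=False`:

```python
# Sketch of CanineAttention's chunk computation for local attention.
# seq_len, width and stride here are made up; real configs use 128/128.
seq_len, width, stride = 10, 4, 4

from_chunks = []
for chunk_start in range(0, seq_len, stride):
    chunk_end = min(seq_len, chunk_start + width)
    from_chunks.append((chunk_start, chunk_end))

print(from_chunks)  # [(0, 4), (4, 8), (8, 10)]
# With equal width and stride, `to_chunks` is the same list, so each block of
# characters attends only to itself (plus, optionally, position 0 for [CLS]).
```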
655
+ class CanineIntermediate(nn.Module):
656
+ def __init__(self, config):
657
+ super().__init__()
658
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
659
+ if isinstance(config.hidden_act, str):
660
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
661
+ else:
662
+ self.intermediate_act_fn = config.hidden_act
663
+
664
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
665
+ hidden_states = self.dense(hidden_states)
666
+ hidden_states = self.intermediate_act_fn(hidden_states)
667
+ return hidden_states
668
+
669
+
670
+ class CanineOutput(nn.Module):
671
+ def __init__(self, config):
672
+ super().__init__()
673
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
674
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
675
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
676
+
677
+ def forward(self, hidden_states: torch.FloatTensor, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
678
+ hidden_states = self.dense(hidden_states)
679
+ hidden_states = self.dropout(hidden_states)
680
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
681
+ return hidden_states
682
+
683
+
684
+ class CanineLayer(nn.Module):
685
+ def __init__(
686
+ self,
687
+ config,
688
+ local,
689
+ always_attend_to_first_position,
690
+ first_position_attends_to_all,
691
+ attend_from_chunk_width,
692
+ attend_from_chunk_stride,
693
+ attend_to_chunk_width,
694
+ attend_to_chunk_stride,
695
+ ):
696
+ super().__init__()
697
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
698
+ self.seq_len_dim = 1
699
+ self.attention = CanineAttention(
700
+ config,
701
+ local,
702
+ always_attend_to_first_position,
703
+ first_position_attends_to_all,
704
+ attend_from_chunk_width,
705
+ attend_from_chunk_stride,
706
+ attend_to_chunk_width,
707
+ attend_to_chunk_stride,
708
+ )
709
+ self.intermediate = CanineIntermediate(config)
710
+ self.output = CanineOutput(config)
711
+
712
+ def forward(
713
+ self,
714
+ hidden_states: torch.FloatTensor,
715
+ attention_mask: Optional[torch.FloatTensor] = None,
716
+ head_mask: Optional[torch.FloatTensor] = None,
717
+ output_attentions: Optional[bool] = False,
718
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
719
+ self_attention_outputs = self.attention(
720
+ hidden_states,
721
+ attention_mask,
722
+ head_mask,
723
+ output_attentions=output_attentions,
724
+ )
725
+ attention_output = self_attention_outputs[0]
726
+
727
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
728
+
729
+ layer_output = apply_chunking_to_forward(
730
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
731
+ )
732
+ outputs = (layer_output,) + outputs
733
+
734
+ return outputs
735
+
736
+ def feed_forward_chunk(self, attention_output):
737
+ intermediate_output = self.intermediate(attention_output)
738
+ layer_output = self.output(intermediate_output, attention_output)
739
+ return layer_output
740
+
741
+
742
+ class CanineEncoder(nn.Module):
743
+ def __init__(
744
+ self,
745
+ config,
746
+ local=False,
747
+ always_attend_to_first_position=False,
748
+ first_position_attends_to_all=False,
749
+ attend_from_chunk_width=128,
750
+ attend_from_chunk_stride=128,
751
+ attend_to_chunk_width=128,
752
+ attend_to_chunk_stride=128,
753
+ ):
754
+ super().__init__()
755
+ self.config = config
756
+ self.layer = nn.ModuleList(
757
+ [
758
+ CanineLayer(
759
+ config,
760
+ local,
761
+ always_attend_to_first_position,
762
+ first_position_attends_to_all,
763
+ attend_from_chunk_width,
764
+ attend_from_chunk_stride,
765
+ attend_to_chunk_width,
766
+ attend_to_chunk_stride,
767
+ )
768
+ for _ in range(config.num_hidden_layers)
769
+ ]
770
+ )
771
+ self.gradient_checkpointing = False
772
+
773
+ def forward(
774
+ self,
775
+ hidden_states: torch.FloatTensor,
776
+ attention_mask: Optional[torch.FloatTensor] = None,
777
+ head_mask: Optional[torch.FloatTensor] = None,
778
+ output_attentions: Optional[bool] = False,
779
+ output_hidden_states: Optional[bool] = False,
780
+ return_dict: Optional[bool] = True,
781
+ ) -> Union[Tuple, BaseModelOutput]:
782
+ all_hidden_states = () if output_hidden_states else None
783
+ all_self_attentions = () if output_attentions else None
784
+
785
+ for i, layer_module in enumerate(self.layer):
786
+ if output_hidden_states:
787
+ all_hidden_states = all_hidden_states + (hidden_states,)
788
+
789
+ layer_head_mask = head_mask[i] if head_mask is not None else None
790
+
791
+ if self.gradient_checkpointing and self.training:
792
+ layer_outputs = self._gradient_checkpointing_func(
793
+ layer_module.__call__,
794
+ hidden_states,
795
+ attention_mask,
796
+ layer_head_mask,
797
+ output_attentions,
798
+ )
799
+ else:
800
+ layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
801
+
802
+ hidden_states = layer_outputs[0]
803
+ if output_attentions:
804
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
805
+
806
+ if output_hidden_states:
807
+ all_hidden_states = all_hidden_states + (hidden_states,)
808
+
809
+ if not return_dict:
810
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
811
+ return BaseModelOutput(
812
+ last_hidden_state=hidden_states,
813
+ hidden_states=all_hidden_states,
814
+ attentions=all_self_attentions,
815
+ )
816
+
817
+
818
+ class CaninePooler(nn.Module):
819
+ def __init__(self, config):
820
+ super().__init__()
821
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
822
+ self.activation = nn.Tanh()
823
+
824
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
825
+ # We "pool" the model by simply taking the hidden state corresponding
826
+ # to the first token.
827
+ first_token_tensor = hidden_states[:, 0]
828
+ pooled_output = self.dense(first_token_tensor)
829
+ pooled_output = self.activation(pooled_output)
830
+ return pooled_output
831
+
832
+
833
+ class CaninePredictionHeadTransform(nn.Module):
834
+ def __init__(self, config):
835
+ super().__init__()
836
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
837
+ if isinstance(config.hidden_act, str):
838
+ self.transform_act_fn = ACT2FN[config.hidden_act]
839
+ else:
840
+ self.transform_act_fn = config.hidden_act
841
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
842
+
843
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
844
+ hidden_states = self.dense(hidden_states)
845
+ hidden_states = self.transform_act_fn(hidden_states)
846
+ hidden_states = self.LayerNorm(hidden_states)
847
+ return hidden_states
848
+
849
+
850
+ class CanineLMPredictionHead(nn.Module):
851
+ def __init__(self, config):
852
+ super().__init__()
853
+ self.transform = CaninePredictionHeadTransform(config)
854
+
855
+ # The output weights are the same as the input embeddings, but there is
856
+ # an output-only bias for each token.
857
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
858
+
859
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
860
+
861
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
862
+ self.decoder.bias = self.bias
863
+
864
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
865
+ hidden_states = self.transform(hidden_states)
866
+ hidden_states = self.decoder(hidden_states)
867
+ return hidden_states
868
+
869
+
870
+ class CanineOnlyMLMHead(nn.Module):
871
+ def __init__(self, config):
872
+ super().__init__()
873
+ self.predictions = CanineLMPredictionHead(config)
874
+
875
+ def forward(
876
+ self,
877
+ sequence_output: torch.Tensor,
878
+ ) -> torch.Tensor:
879
+ prediction_scores = self.predictions(sequence_output)
880
+ return prediction_scores
881
+
882
+
883
+ class CaninePreTrainedModel(PreTrainedModel):
884
+ """
885
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
886
+ models.
887
+ """
888
+
889
+ config_class = CanineConfig
890
+ load_tf_weights = load_tf_weights_in_canine
891
+ base_model_prefix = "canine"
892
+ supports_gradient_checkpointing = True
893
+
894
+ def _init_weights(self, module):
895
+ """Initialize the weights"""
896
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
897
+ # Slightly different from the TF version which uses truncated_normal for initialization
898
+ # cf https://github.com/pytorch/pytorch/pull/5617
899
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
900
+ if module.bias is not None:
901
+ module.bias.data.zero_()
902
+ elif isinstance(module, nn.Embedding):
903
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
904
+ if module.padding_idx is not None:
905
+ module.weight.data[module.padding_idx].zero_()
906
+ elif isinstance(module, nn.LayerNorm):
907
+ module.bias.data.zero_()
908
+ module.weight.data.fill_(1.0)
909
+
910
+
911
+ CANINE_START_DOCSTRING = r"""
912
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
913
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
914
+ behavior.
915
+
916
+ Parameters:
917
+ config ([`CanineConfig`]): Model configuration class with all the parameters of the model.
918
+ Initializing with a config file does not load the weights associated with the model, only the
919
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
920
+ """
921
+
922
+ CANINE_INPUTS_DOCSTRING = r"""
923
+ Args:
924
+ input_ids (`torch.LongTensor` of shape `({0})`):
925
+ Indices of input sequence tokens in the vocabulary.
926
+
927
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
928
+ [`PreTrainedTokenizer.__call__`] for details.
929
+
930
+ [What are input IDs?](../glossary#input-ids)
931
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
932
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
933
+
934
+ - 1 for tokens that are **not masked**,
935
+ - 0 for tokens that are **masked**.
936
+
937
+ [What are attention masks?](../glossary#attention-mask)
938
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
939
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
940
+ 1]`:
941
+
942
+ - 0 corresponds to a *sentence A* token,
943
+ - 1 corresponds to a *sentence B* token.
944
+
945
+ [What are token type IDs?](../glossary#token-type-ids)
946
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
947
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
948
+ config.max_position_embeddings - 1]`.
949
+
950
+ [What are position IDs?](../glossary#position-ids)
951
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
952
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
953
+
954
+ - 1 indicates the head is **not masked**,
955
+ - 0 indicates the head is **masked**.
956
+
957
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
958
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
959
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
960
+ model's internal embedding lookup matrix.
961
+ output_attentions (`bool`, *optional*):
962
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
963
+ tensors for more detail.
964
+ output_hidden_states (`bool`, *optional*):
965
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
966
+ more detail.
967
+ return_dict (`bool`, *optional*):
968
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
969
+ """
970
+
971
+
972
+ @add_start_docstrings(
973
+ "The bare CANINE Model transformer outputting raw hidden-states without any specific head on top.",
974
+ CANINE_START_DOCSTRING,
975
+ )
976
+ class CanineModel(CaninePreTrainedModel):
977
+ def __init__(self, config, add_pooling_layer=True):
978
+ super().__init__(config)
979
+ self.config = config
980
+ shallow_config = copy.deepcopy(config)
981
+ shallow_config.num_hidden_layers = 1
982
+
983
+ self.char_embeddings = CanineEmbeddings(config)
984
+ # shallow/low-dim transformer encoder to get an initial character encoding
985
+ self.initial_char_encoder = CanineEncoder(
986
+ shallow_config,
987
+ local=True,
988
+ always_attend_to_first_position=False,
989
+ first_position_attends_to_all=False,
990
+ attend_from_chunk_width=config.local_transformer_stride,
991
+ attend_from_chunk_stride=config.local_transformer_stride,
992
+ attend_to_chunk_width=config.local_transformer_stride,
993
+ attend_to_chunk_stride=config.local_transformer_stride,
994
+ )
995
+ self.chars_to_molecules = CharactersToMolecules(config)
996
+ # deep transformer encoder
997
+ self.encoder = CanineEncoder(config)
998
+ self.projection = ConvProjection(config)
999
+ # shallow/low-dim transformer encoder to get a final character encoding
1000
+ self.final_char_encoder = CanineEncoder(shallow_config)
1001
+
1002
+ self.pooler = CaninePooler(config) if add_pooling_layer else None
1003
+
1004
+ # Initialize weights and apply final processing
1005
+ self.post_init()
1006
+
1007
+ def _prune_heads(self, heads_to_prune):
1008
+ """
1009
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1010
+ class PreTrainedModel
1011
+ """
1012
+ for layer, heads in heads_to_prune.items():
1013
+ self.encoder.layer[layer].attention.prune_heads(heads)
1014
+
1015
+ def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask):
1016
+ """
1017
+ Create 3D attention mask from a 2D tensor mask.
1018
+
1019
+ Args:
1020
+ from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
1021
+ to_mask: int32 Tensor of shape [batch_size, to_seq_length].
1022
+
1023
+ Returns:
1024
+ float Tensor of shape [batch_size, from_seq_length, to_seq_length].
1025
+ """
1026
+ batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1]
1027
+
1028
+ to_seq_length = to_mask.shape[1]
1029
+
1030
+ to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()
1031
+
1032
+ # We don't assume that `from_tensor` is a mask (although it could be). We
1033
+ # don't actually care if we attend *from* padding tokens (only *to* padding
1034
+ # tokens), so we create a tensor of all ones.
1035
+ broadcast_ones = torch.ones(size=(batch_size, from_seq_length, 1), dtype=torch.float32, device=to_mask.device)
1036
+
1037
+ # Here we broadcast along two dimensions to create the mask.
1038
+ mask = broadcast_ones * to_mask
1039
+
1040
+ return mask
1041
+
1042
+ def _downsample_attention_mask(self, char_attention_mask: torch.Tensor, downsampling_rate: int):
1043
+ """Downsample 2D character attention mask to 2D molecule attention mask using MaxPool1d layer."""
1044
+
1045
+ # first, make char_attention_mask 3D by adding a channel dim
1046
+ batch_size, char_seq_len = char_attention_mask.shape
1047
+ poolable_char_mask = torch.reshape(char_attention_mask, (batch_size, 1, char_seq_len))
1048
+
1049
+ # next, apply MaxPool1d to get pooled_molecule_mask of shape (batch_size, 1, mol_seq_len)
1050
+ pooled_molecule_mask = torch.nn.MaxPool1d(kernel_size=downsampling_rate, stride=downsampling_rate)(
1051
+ poolable_char_mask.float()
1052
+ )
1053
+
1054
+ # finally, squeeze to get tensor of shape (batch_size, mol_seq_len)
1055
+ molecule_attention_mask = torch.squeeze(pooled_molecule_mask, dim=1)
1056
+
1057
+ return molecule_attention_mask
1058
+
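To make the pooling above concrete, here is a standalone sketch with a made-up mask (not part of the model code): max-pooling with kernel and stride equal to the downsampling rate keeps a molecule whenever any of its characters is real.

```python
import torch

# Hypothetical character mask: 8 characters, the last 3 are padding.
char_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])
rate = 4  # downsampling_rate

pooled = torch.nn.MaxPool1d(kernel_size=rate, stride=rate)(
    char_mask.reshape(1, 1, -1).float()
)
molecule_mask = torch.squeeze(pooled, dim=1)
print(molecule_mask)  # tensor([[1., 1.]]) -- 2nd molecule kept: it still covers one real char
```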
1059
+ def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: int) -> torch.Tensor:
1060
+ """Repeats molecules to make them the same length as the char sequence."""
1061
+
1062
+ rate = self.config.downsampling_rate
1063
+
1064
+ molecules_without_extra_cls = molecules[:, 1:, :]
1065
+ # `repeated`: [batch_size, almost_char_seq_len, molecule_hidden_size]
1066
+ repeated = torch.repeat_interleave(molecules_without_extra_cls, repeats=rate, dim=-2)
1067
+
1068
+ # So far, we've repeated the elements sufficient for any `char_seq_length`
1069
+ # that's a multiple of `downsampling_rate`. Now we account for the last
1070
+ # n elements (n < `downsampling_rate`), i.e. the remainder of floor
1071
+ # division. We do this by repeating the last molecule a few extra times.
1072
+ last_molecule = molecules[:, -1:, :]
1073
+ remainder_length = char_seq_length % rate
1074
+ remainder_repeated = torch.repeat_interleave(
1075
+ last_molecule,
1076
+ # +1 molecule to compensate for truncation.
1077
+ repeats=remainder_length + rate,
1078
+ dim=-2,
1079
+ )
1080
+
1081
+ # `repeated`: [batch_size, char_seq_len, molecule_hidden_size]
1082
+ return torch.cat([repeated, remainder_repeated], dim=-2)
1083
+
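A toy run of the upsampling above (hypothetical shapes, not model code): each molecule after the CLS-derived one is repeated `downsampling_rate` times, and the last molecule is repeated a few extra times to cover a character length that is not an exact multiple of the rate.

```python
import torch

rate = 4
char_seq_length = 10                      # deliberately not a multiple of `rate`
mol_seq_length = char_seq_length // rate  # 2 molecules after downsampling
# Hypothetical molecule encodings: batch 1, 2 molecules, hidden dim 2.
molecules = torch.arange(mol_seq_length * 2, dtype=torch.float).reshape(1, mol_seq_length, 2)

# Drop the first (CLS-derived) molecule and repeat the rest `rate` times ...
repeated = torch.repeat_interleave(molecules[:, 1:, :], repeats=rate, dim=-2)
# ... then pad out the tail with extra copies of the last molecule.
remainder = torch.repeat_interleave(
    molecules[:, -1:, :], repeats=char_seq_length % rate + rate, dim=-2
)
out = torch.cat([repeated, remainder], dim=-2)
print(out.shape)  # torch.Size([1, 10, 2]) -- back to the character length
```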
1084
+ @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1085
+ @add_code_sample_docstrings(
1086
+ checkpoint=_CHECKPOINT_FOR_DOC,
1087
+ output_type=CanineModelOutputWithPooling,
1088
+ config_class=_CONFIG_FOR_DOC,
1089
+ )
1090
+ def forward(
1091
+ self,
1092
+ input_ids: Optional[torch.LongTensor] = None,
1093
+ attention_mask: Optional[torch.FloatTensor] = None,
1094
+ token_type_ids: Optional[torch.LongTensor] = None,
1095
+ position_ids: Optional[torch.LongTensor] = None,
1096
+ head_mask: Optional[torch.FloatTensor] = None,
1097
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1098
+ output_attentions: Optional[bool] = None,
1099
+ output_hidden_states: Optional[bool] = None,
1100
+ return_dict: Optional[bool] = None,
1101
+ ) -> Union[Tuple, CanineModelOutputWithPooling]:
1102
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1103
+ output_hidden_states = (
1104
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1105
+ )
1106
+ all_hidden_states = () if output_hidden_states else None
1107
+ all_self_attentions = () if output_attentions else None
1108
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1109
+
1110
+ if input_ids is not None and inputs_embeds is not None:
1111
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1112
+ elif input_ids is not None:
1113
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1114
+ input_shape = input_ids.size()
1115
+ elif inputs_embeds is not None:
1116
+ input_shape = inputs_embeds.size()[:-1]
1117
+ else:
1118
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1119
+
1120
+ batch_size, seq_length = input_shape
1121
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1122
+
1123
+ if attention_mask is None:
1124
+ attention_mask = torch.ones((batch_size, seq_length), device=device)
1125
+ if token_type_ids is None:
1126
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1127
+
1128
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1129
+ # ourselves in which case we just need to make it broadcastable to all heads.
1130
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1131
+ molecule_attention_mask = self._downsample_attention_mask(
1132
+ attention_mask, downsampling_rate=self.config.downsampling_rate
1133
+ )
1134
+ extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
1135
+ molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
1136
+ )
1137
+
1138
+ # Prepare head mask if needed
1139
+ # 1.0 in head_mask indicate we keep the head
1140
+ # attention_probs has shape bsz x n_heads x N x N
1141
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1142
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1143
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1144
+
1145
+ # `input_char_embeddings`: shape (batch_size, char_seq, char_dim)
1146
+ input_char_embeddings = self.char_embeddings(
1147
+ input_ids=input_ids,
1148
+ position_ids=position_ids,
1149
+ token_type_ids=token_type_ids,
1150
+ inputs_embeds=inputs_embeds,
1151
+ )
1152
+
1153
+ # Contextualize character embeddings using shallow Transformer.
1154
+ # We use a 3D attention mask for the local attention.
1155
+ # `input_char_encoding`: shape (batch_size, char_seq_len, char_dim)
1156
+ char_attention_mask = self._create_3d_attention_mask_from_input_mask(
1157
+ input_ids if input_ids is not None else inputs_embeds, attention_mask
1158
+ )
1159
+ init_chars_encoder_outputs = self.initial_char_encoder(
1160
+ input_char_embeddings,
1161
+ attention_mask=char_attention_mask,
1162
+ output_attentions=output_attentions,
1163
+ output_hidden_states=output_hidden_states,
1164
+ )
1165
+ input_char_encoding = init_chars_encoder_outputs.last_hidden_state
1166
+
1167
+ # Downsample chars to molecules.
1168
+ # The following lines have dimensions: [batch, molecule_seq, molecule_dim].
1169
+ # In this transformation, we change the dimensionality from `char_dim` to
1170
+ # `molecule_dim`, but do *NOT* add a resnet connection. Instead, we rely on
1171
+ # the resnet connections (a) from the final char transformer stack back into
1172
+ # the original char transformer stack and (b) the resnet connections from
1173
+ # the final char transformer stack back into the deep BERT stack of
1174
+ # molecules.
1175
+ #
1176
+ # Empirically, it is critical to use a powerful enough transformation here:
1177
+ # mean pooling causes training to diverge with huge gradient norms in this
1178
+ # region of the model; using a convolution here resolves this issue. From
1179
+ # this, it seems that molecules and characters require a very different
1180
+ # feature space; intuitively, this makes sense.
1181
+ init_molecule_encoding = self.chars_to_molecules(input_char_encoding)
1182
+
1183
+ # Deep BERT encoder
1184
+ # `molecule_sequence_output`: shape (batch_size, mol_seq_len, mol_dim)
1185
+ encoder_outputs = self.encoder(
1186
+ init_molecule_encoding,
1187
+ attention_mask=extended_molecule_attention_mask,
1188
+ head_mask=head_mask,
1189
+ output_attentions=output_attentions,
1190
+ output_hidden_states=output_hidden_states,
1191
+ return_dict=return_dict,
1192
+ )
1193
+ molecule_sequence_output = encoder_outputs[0]
1194
+ pooled_output = self.pooler(molecule_sequence_output) if self.pooler is not None else None
1195
+
1196
+ # Upsample molecules back to characters.
1197
+ # `repeated_molecules`: shape (batch_size, char_seq_len, mol_hidden_size)
1198
+ repeated_molecules = self._repeat_molecules(molecule_sequence_output, char_seq_length=input_shape[-1])
1199
+
1200
+ # Concatenate representations (contextualized char embeddings and repeated molecules):
1201
+ # `concat`: shape [batch_size, char_seq_len, molecule_hidden_size+char_hidden_final]
1202
+ concat = torch.cat([input_char_encoding, repeated_molecules], dim=-1)
1203
+
1204
+ # Project representation dimension back to hidden_size
1205
+ # `sequence_output`: shape (batch_size, char_seq_len, hidden_size)
1206
+ sequence_output = self.projection(concat)
1207
+
1208
+ # Apply final shallow Transformer
1209
+ # `sequence_output`: shape (batch_size, char_seq_len, hidden_size)
1210
+ final_chars_encoder_outputs = self.final_char_encoder(
1211
+ sequence_output,
1212
+ attention_mask=extended_attention_mask,
1213
+ output_attentions=output_attentions,
1214
+ output_hidden_states=output_hidden_states,
1215
+ )
1216
+ sequence_output = final_chars_encoder_outputs.last_hidden_state
1217
+
1218
+ if output_hidden_states:
1219
+ deep_encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1]
1220
+ all_hidden_states = (
1221
+ all_hidden_states
1222
+ + init_chars_encoder_outputs.hidden_states
1223
+ + deep_encoder_hidden_states
1224
+ + final_chars_encoder_outputs.hidden_states
1225
+ )
1226
+
1227
+ if output_attentions:
1228
+ deep_encoder_self_attentions = encoder_outputs.attentions if return_dict else encoder_outputs[-1]
1229
+ all_self_attentions = (
1230
+ all_self_attentions
1231
+ + init_chars_encoder_outputs.attentions
1232
+ + deep_encoder_self_attentions
1233
+ + final_chars_encoder_outputs.attentions
1234
+ )
1235
+
1236
+ if not return_dict:
1237
+ output = (sequence_output, pooled_output)
1238
+ output += tuple(v for v in [all_hidden_states, all_self_attentions] if v is not None)
1239
+ return output
1240
+
1241
+ return CanineModelOutputWithPooling(
1242
+ last_hidden_state=sequence_output,
1243
+ pooler_output=pooled_output,
1244
+ hidden_states=all_hidden_states,
1245
+ attentions=all_self_attentions,
1246
+ )
1247
+
1248
+
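For orientation, a minimal usage sketch of the full pipeline above (checkpoint name taken from the doctest later in this file; output shapes follow the config defaults):

```python
import torch
from transformers import AutoTokenizer, CanineModel

tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
model = CanineModel.from_pretrained("google/canine-s")

inputs = tokenizer("Life is like a box of chocolates.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Characters in, characters out: one hidden state per input codepoint.
print(outputs.last_hidden_state.shape)  # (1, num_chars_with_specials, hidden_size)
print(outputs.pooler_output.shape)      # (1, hidden_size)
```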
1249
+ @add_start_docstrings(
1250
+ """
1251
+ CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1252
+ output) e.g. for GLUE tasks.
1253
+ """,
1254
+ CANINE_START_DOCSTRING,
1255
+ )
1256
+ class CanineForSequenceClassification(CaninePreTrainedModel):
1257
+ def __init__(self, config):
1258
+ super().__init__(config)
1259
+ self.num_labels = config.num_labels
1260
+
1261
+ self.canine = CanineModel(config)
1262
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1263
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1264
+
1265
+ # Initialize weights and apply final processing
1266
+ self.post_init()
1267
+
1268
+ @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1269
+ @add_code_sample_docstrings(
1270
+ checkpoint=_CHECKPOINT_FOR_DOC,
1271
+ output_type=SequenceClassifierOutput,
1272
+ config_class=_CONFIG_FOR_DOC,
1273
+ )
1274
+ def forward(
1275
+ self,
1276
+ input_ids: Optional[torch.LongTensor] = None,
1277
+ attention_mask: Optional[torch.FloatTensor] = None,
1278
+ token_type_ids: Optional[torch.LongTensor] = None,
1279
+ position_ids: Optional[torch.LongTensor] = None,
1280
+ head_mask: Optional[torch.FloatTensor] = None,
1281
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1282
+ labels: Optional[torch.LongTensor] = None,
1283
+ output_attentions: Optional[bool] = None,
1284
+ output_hidden_states: Optional[bool] = None,
1285
+ return_dict: Optional[bool] = None,
1286
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1287
+ r"""
1288
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1289
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1290
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1291
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1292
+ """
1293
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1294
+
1295
+ outputs = self.canine(
1296
+ input_ids,
1297
+ attention_mask=attention_mask,
1298
+ token_type_ids=token_type_ids,
1299
+ position_ids=position_ids,
1300
+ head_mask=head_mask,
1301
+ inputs_embeds=inputs_embeds,
1302
+ output_attentions=output_attentions,
1303
+ output_hidden_states=output_hidden_states,
1304
+ return_dict=return_dict,
1305
+ )
1306
+
1307
+ pooled_output = outputs[1]
1308
+
1309
+ pooled_output = self.dropout(pooled_output)
1310
+ logits = self.classifier(pooled_output)
1311
+
1312
+ loss = None
1313
+ if labels is not None:
1314
+ if self.config.problem_type is None:
1315
+ if self.num_labels == 1:
1316
+ self.config.problem_type = "regression"
1317
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1318
+ self.config.problem_type = "single_label_classification"
1319
+ else:
1320
+ self.config.problem_type = "multi_label_classification"
1321
+
1322
+ if self.config.problem_type == "regression":
1323
+ loss_fct = MSELoss()
1324
+ if self.num_labels == 1:
1325
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1326
+ else:
1327
+ loss = loss_fct(logits, labels)
1328
+ elif self.config.problem_type == "single_label_classification":
1329
+ loss_fct = CrossEntropyLoss()
1330
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1331
+ elif self.config.problem_type == "multi_label_classification":
1332
+ loss_fct = BCEWithLogitsLoss()
1333
+ loss = loss_fct(logits, labels)
1334
+ if not return_dict:
1335
+ output = (logits,) + outputs[2:]
1336
+ return ((loss,) + output) if loss is not None else output
1337
+
1338
+ return SequenceClassifierOutput(
1339
+ loss=loss,
1340
+ logits=logits,
1341
+ hidden_states=outputs.hidden_states,
1342
+ attentions=outputs.attentions,
1343
+ )
1344
+
1345
+
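The `problem_type` dispatch above follows the usual Transformers convention (regression for one label, cross-entropy for integer labels, BCE otherwise). A self-contained sketch of the three loss branches, with made-up logits and labels:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

logits = torch.randn(2, 3)

# single_label_classification: integer class ids -> cross-entropy
ce = CrossEntropyLoss()(logits.view(-1, 3), torch.tensor([0, 2]).view(-1))

# multi_label_classification: float multi-hot targets -> BCE on raw logits
bce = BCEWithLogitsLoss()(logits, torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]))

# regression (num_labels == 1): mean-squared error on squeezed logits
reg_logits = torch.randn(2, 1)
mse = MSELoss()(reg_logits.squeeze(), torch.tensor([0.5, -1.2]))
```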
1346
+ @add_start_docstrings(
1347
+ """
1348
+ CANINE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1349
+ softmax) e.g. for RocStories/SWAG tasks.
1350
+ """,
1351
+ CANINE_START_DOCSTRING,
1352
+ )
1353
+ class CanineForMultipleChoice(CaninePreTrainedModel):
1354
+ def __init__(self, config):
1355
+ super().__init__(config)
1356
+
1357
+ self.canine = CanineModel(config)
1358
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1359
+ self.classifier = nn.Linear(config.hidden_size, 1)
1360
+
1361
+ # Initialize weights and apply final processing
1362
+ self.post_init()
1363
+
1364
+ @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1365
+ @add_code_sample_docstrings(
1366
+ checkpoint=_CHECKPOINT_FOR_DOC,
1367
+ output_type=MultipleChoiceModelOutput,
1368
+ config_class=_CONFIG_FOR_DOC,
1369
+ )
1370
+ def forward(
1371
+ self,
1372
+ input_ids: Optional[torch.LongTensor] = None,
1373
+ attention_mask: Optional[torch.FloatTensor] = None,
1374
+ token_type_ids: Optional[torch.LongTensor] = None,
1375
+ position_ids: Optional[torch.LongTensor] = None,
1376
+ head_mask: Optional[torch.FloatTensor] = None,
1377
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1378
+ labels: Optional[torch.LongTensor] = None,
1379
+ output_attentions: Optional[bool] = None,
1380
+ output_hidden_states: Optional[bool] = None,
1381
+ return_dict: Optional[bool] = None,
1382
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
1383
+ r"""
1384
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1385
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1386
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1387
+ `input_ids` above)
1388
+ """
1389
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1390
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1391
+
1392
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1393
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1394
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1395
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1396
+ inputs_embeds = (
1397
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1398
+ if inputs_embeds is not None
1399
+ else None
1400
+ )
1401
+
1402
+ outputs = self.canine(
1403
+ input_ids,
1404
+ attention_mask=attention_mask,
1405
+ token_type_ids=token_type_ids,
1406
+ position_ids=position_ids,
1407
+ head_mask=head_mask,
1408
+ inputs_embeds=inputs_embeds,
1409
+ output_attentions=output_attentions,
1410
+ output_hidden_states=output_hidden_states,
1411
+ return_dict=return_dict,
1412
+ )
1413
+
1414
+ pooled_output = outputs[1]
1415
+
1416
+ pooled_output = self.dropout(pooled_output)
1417
+ logits = self.classifier(pooled_output)
1418
+ reshaped_logits = logits.view(-1, num_choices)
1419
+
1420
+ loss = None
1421
+ if labels is not None:
1422
+ loss_fct = CrossEntropyLoss()
1423
+ loss = loss_fct(reshaped_logits, labels)
1424
+
1425
+ if not return_dict:
1426
+ output = (reshaped_logits,) + outputs[2:]
1427
+ return ((loss,) + output) if loss is not None else output
1428
+
1429
+ return MultipleChoiceModelOutput(
1430
+ loss=loss,
1431
+ logits=reshaped_logits,
1432
+ hidden_states=outputs.hidden_states,
1433
+ attentions=outputs.attentions,
1434
+ )
1435
+
1436
+
1437
+ @add_start_docstrings(
1438
+ """
1439
+ CANINE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1440
+ Named-Entity-Recognition (NER) tasks.
1441
+ """,
1442
+ CANINE_START_DOCSTRING,
1443
+ )
1444
+ class CanineForTokenClassification(CaninePreTrainedModel):
1445
+ def __init__(self, config):
1446
+ super().__init__(config)
1447
+ self.num_labels = config.num_labels
1448
+
1449
+ self.canine = CanineModel(config)
1450
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1451
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1452
+
1453
+ # Initialize weights and apply final processing
1454
+ self.post_init()
1455
+
1456
+ @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1457
+ @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
1458
+ def forward(
1459
+ self,
1460
+ input_ids: Optional[torch.LongTensor] = None,
1461
+ attention_mask: Optional[torch.FloatTensor] = None,
1462
+ token_type_ids: Optional[torch.LongTensor] = None,
1463
+ position_ids: Optional[torch.LongTensor] = None,
1464
+ head_mask: Optional[torch.FloatTensor] = None,
1465
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1466
+ labels: Optional[torch.LongTensor] = None,
1467
+ output_attentions: Optional[bool] = None,
1468
+ output_hidden_states: Optional[bool] = None,
1469
+ return_dict: Optional[bool] = None,
1470
+ ) -> Union[Tuple, TokenClassifierOutput]:
1471
+ r"""
1472
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1473
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1474
+
1475
+ Returns:
1476
+
1477
+ Example:
1478
+
1479
+ ```python
1480
+ >>> from transformers import AutoTokenizer, CanineForTokenClassification
1481
+ >>> import torch
1482
+
1483
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
1484
+ >>> model = CanineForTokenClassification.from_pretrained("google/canine-s")
1485
+
1486
+ >>> inputs = tokenizer(
1487
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
1488
+ ... )
1489
+
1490
+ >>> with torch.no_grad():
1491
+ ... logits = model(**inputs).logits
1492
+
1493
+ >>> predicted_token_class_ids = logits.argmax(-1)
1494
+
1495
+ >>> # Note that tokens are classified rather than input words, which means that
1496
+ >>> # there might be more predicted token classes than words.
1497
+ >>> # Multiple token classes might account for the same word
1498
+ >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
1499
+ >>> predicted_tokens_classes # doctest: +SKIP
1500
+ ```
1501
+
1502
+ ```python
1503
+ >>> labels = predicted_token_class_ids
1504
+ >>> loss = model(**inputs, labels=labels).loss
1505
+ >>> round(loss.item(), 2) # doctest: +SKIP
1506
+ ```"""
1507
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1508
+
1509
+ outputs = self.canine(
1510
+ input_ids,
1511
+ attention_mask=attention_mask,
1512
+ token_type_ids=token_type_ids,
1513
+ position_ids=position_ids,
1514
+ head_mask=head_mask,
1515
+ inputs_embeds=inputs_embeds,
1516
+ output_attentions=output_attentions,
1517
+ output_hidden_states=output_hidden_states,
1518
+ return_dict=return_dict,
1519
+ )
1520
+
1521
+ sequence_output = outputs[0]
1522
+
1523
+ sequence_output = self.dropout(sequence_output)
1524
+ logits = self.classifier(sequence_output)
1525
+
1526
+ loss = None
1527
+ if labels is not None:
1528
+ loss_fct = CrossEntropyLoss()
1529
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1530
+
1531
+ if not return_dict:
1532
+ output = (logits,) + outputs[2:]
1533
+ return ((loss,) + output) if loss is not None else output
1534
+
1535
+ return TokenClassifierOutput(
1536
+ loss=loss,
1537
+ logits=logits,
1538
+ hidden_states=outputs.hidden_states,
1539
+ attentions=outputs.attentions,
1540
+ )
1541
+
1542
+
1543
+ @add_start_docstrings(
1544
+ """
1545
+ CANINE Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1546
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1547
+ """,
1548
+ CANINE_START_DOCSTRING,
1549
+ )
1550
+ class CanineForQuestionAnswering(CaninePreTrainedModel):
1551
+ def __init__(self, config):
1552
+ super().__init__(config)
1553
+ self.num_labels = config.num_labels
1554
+
1555
+ self.canine = CanineModel(config)
1556
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1557
+
1558
+ # Initialize weights and apply final processing
1559
+ self.post_init()
1560
+
1561
+ @add_start_docstrings_to_model_forward(CANINE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1562
+ @add_code_sample_docstrings(
1563
+ checkpoint="Splend1dchan/canine-c-squad",
1564
+ output_type=QuestionAnsweringModelOutput,
1565
+ config_class=_CONFIG_FOR_DOC,
1566
+ expected_output="'nice puppet'",
1567
+ expected_loss=8.81,
1568
+ )
1569
+ def forward(
1570
+ self,
1571
+ input_ids: Optional[torch.LongTensor] = None,
1572
+ attention_mask: Optional[torch.FloatTensor] = None,
1573
+ token_type_ids: Optional[torch.LongTensor] = None,
1574
+ position_ids: Optional[torch.LongTensor] = None,
1575
+ head_mask: Optional[torch.FloatTensor] = None,
1576
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1577
+ start_positions: Optional[torch.LongTensor] = None,
1578
+ end_positions: Optional[torch.LongTensor] = None,
1579
+ output_attentions: Optional[bool] = None,
1580
+ output_hidden_states: Optional[bool] = None,
1581
+ return_dict: Optional[bool] = None,
1582
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1583
+ r"""
1584
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1585
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1586
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1587
+ are not taken into account for computing the loss.
1588
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1589
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1590
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1591
+ are not taken into account for computing the loss.
1592
+ """
1593
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1594
+
1595
+ outputs = self.canine(
1596
+ input_ids,
1597
+ attention_mask=attention_mask,
1598
+ token_type_ids=token_type_ids,
1599
+ position_ids=position_ids,
1600
+ head_mask=head_mask,
1601
+ inputs_embeds=inputs_embeds,
1602
+ output_attentions=output_attentions,
1603
+ output_hidden_states=output_hidden_states,
1604
+ return_dict=return_dict,
1605
+ )
1606
+
1607
+ sequence_output = outputs[0]
1608
+
1609
+ logits = self.qa_outputs(sequence_output)
1610
+ start_logits, end_logits = logits.split(1, dim=-1)
1611
+ start_logits = start_logits.squeeze(-1)
1612
+ end_logits = end_logits.squeeze(-1)
1613
+
1614
+ total_loss = None
1615
+ if start_positions is not None and end_positions is not None:
1616
+ # If we are on multi-GPU, split adds a dimension
1617
+ if len(start_positions.size()) > 1:
1618
+ start_positions = start_positions.squeeze(-1)
1619
+ if len(end_positions.size()) > 1:
1620
+ end_positions = end_positions.squeeze(-1)
1621
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1622
+ ignored_index = start_logits.size(1)
1623
+ start_positions.clamp_(0, ignored_index)
1624
+ end_positions.clamp_(0, ignored_index)
1625
+
1626
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1627
+ start_loss = loss_fct(start_logits, start_positions)
1628
+ end_loss = loss_fct(end_logits, end_positions)
1629
+ total_loss = (start_loss + end_loss) / 2
1630
+
1631
+ if not return_dict:
1632
+ output = (start_logits, end_logits) + outputs[2:]
1633
+ return ((total_loss,) + output) if total_loss is not None else output
1634
+
1635
+ return QuestionAnsweringModelOutput(
1636
+ loss=total_loss,
1637
+ start_logits=start_logits,
1638
+ end_logits=end_logits,
1639
+ hidden_states=outputs.hidden_states,
1640
+ attentions=outputs.attentions,
1641
+ )
1642
+
1643
+
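The clamping above is the standard trick for span labels that fall outside the (possibly truncated) input: positions are clamped to `sequence_length`, and that same index is passed as `ignore_index` so those examples contribute nothing to the loss. A standalone sketch with made-up values:

```python
import torch
from torch.nn import CrossEntropyLoss

start_logits = torch.randn(2, 16)         # batch of 2, sequence length 16
start_positions = torch.tensor([3, 100])  # second label lies outside the inputs

ignored_index = start_logits.size(1)      # 16
start_positions = start_positions.clamp(0, ignored_index)  # -> [3, 16]

# Index 16 == ignore_index, so the out-of-range example is masked out of the loss.
loss = CrossEntropyLoss(ignore_index=ignored_index)(start_logits, start_positions)
```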
1644
+ __all__ = [
1645
+ "CanineForMultipleChoice",
1646
+ "CanineForQuestionAnswering",
1647
+ "CanineForSequenceClassification",
1648
+ "CanineForTokenClassification",
1649
+ "CanineLayer",
1650
+ "CanineModel",
1651
+ "CaninePreTrainedModel",
1652
+ "load_tf_weights_in_canine",
1653
+ ]
docs/transformers/src/transformers/models/canine/tokenization_canine.py ADDED
@@ -0,0 +1,244 @@
1
+ # coding=utf-8
2
+ # Copyright Google AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for CANINE."""
16
+
17
+ from typing import Dict, List, Optional
18
+
19
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ # Unicode defines 1,114,112 total “codepoints”
27
+ UNICODE_VOCAB_SIZE = 1114112
28
+
29
+ # Below: Constants defining canonical codepoints for special, pseudo-characters.
30
+ # Copied from https://github.com/google-research/language/blob/master/language/canine/special_codepoints.py
31
+ PAD = 0
32
+ CLS = 0xE000
33
+ SEP = 0xE001
34
+ BOS = 0xE002
35
+ MASK = 0xE003
36
+ RESERVED = 0xE004
37
+
38
+ # Maps special codepoints to human-readable names.
39
+ SPECIAL_CODEPOINTS: Dict[int, str] = {
40
+ # Special symbols are represented using codepoints values that are valid,
41
+ # but designated as "Private Use", meaning that they will never be assigned
42
+ # characters by the Unicode Consortium, and are thus safe for use here.
43
+ #
44
+ # NOTE: Do *NOT* add any sort of [UNK_CHAR] here. They are explicitly
45
+ # excluded and should fail with a hard error.
46
+ CLS: "[CLS]",
47
+ SEP: "[SEP]",
48
+ BOS: "[BOS]",
49
+ MASK: "[MASK]",
50
+ PAD: "[PAD]",
51
+ RESERVED: "[RESERVED]",
52
+ }
53
+
54
+ # Maps special codepoint human-readable names to their codepoint values.
55
+ SPECIAL_CODEPOINTS_BY_NAME: Dict[str, int] = {name: codepoint for codepoint, name in SPECIAL_CODEPOINTS.items()}
56
+
57
+
58
+ class CanineTokenizer(PreTrainedTokenizer):
59
+ r"""
60
+ Construct a CANINE tokenizer (i.e. a character splitter). It turns text into a sequence of characters, and then
61
+ converts each character into its Unicode code point.
62
+
63
+ [`CanineTokenizer`] inherits from [`PreTrainedTokenizer`].
64
+
65
+ Refer to superclass [`PreTrainedTokenizer`] for usage examples and documentation concerning parameters.
66
+
67
+ Args:
68
+ model_max_length (`int`, *optional*, defaults to 2048):
69
+ The maximum sentence length the model accepts.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ bos_token=chr(CLS),
75
+ eos_token=chr(SEP),
76
+ sep_token=chr(SEP),
77
+ cls_token=chr(CLS),
78
+ pad_token=chr(PAD),
79
+ mask_token=chr(MASK),
80
+ add_prefix_space=False,
81
+ model_max_length=2048,
82
+ **kwargs,
83
+ ):
84
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
85
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
86
+ sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
87
+ cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
88
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
89
+
90
+ # Mask token behave like a normal word, i.e. include the space before it
91
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
92
+
93
+ # Creates a mapping for looking up the IDs of special symbols.
94
+ self._special_codepoints: Dict[str, int] = {}
95
+ for codepoint, name in SPECIAL_CODEPOINTS.items():
96
+ self._special_codepoints[name] = codepoint
97
+
98
+ # Creates a mapping for looking up the string forms of special symbol IDs.
99
+ self._special_codepoint_strings: Dict[int, str] = {
100
+ codepoint: name for name, codepoint in self._special_codepoints.items()
101
+ }
102
+
103
+ self._unicode_vocab_size = UNICODE_VOCAB_SIZE
104
+ self._num_special_tokens = len(self._special_codepoints)
105
+
106
+ super().__init__(
107
+ bos_token=bos_token,
108
+ eos_token=eos_token,
109
+ sep_token=sep_token,
110
+ cls_token=cls_token,
111
+ pad_token=pad_token,
112
+ mask_token=mask_token,
113
+ add_prefix_space=add_prefix_space,
114
+ model_max_length=model_max_length,
115
+ **kwargs,
116
+ )
117
+
118
+ @property
119
+ def vocab_size(self) -> int:
120
+ return self._unicode_vocab_size
121
+
122
+ def get_vocab(self):
123
+ vocab = {chr(i): i for i in range(self.vocab_size)}
124
+ vocab.update(self.added_tokens_encoder)
125
+ return vocab
126
+
127
+ def _tokenize(self, text: str) -> List[str]:
128
+ """Tokenize a string (i.e. perform character splitting)."""
129
+ return list(text)
130
+
131
+ def _convert_token_to_id(self, token: str) -> int:
132
+ """Converts a token (i.e. a Unicode character) in an id (i.e. its integer Unicode code point value)."""
133
+ try:
134
+ return ord(token)
135
+ except TypeError:
136
+ raise ValueError(f"invalid token: '{token}'")
137
+
138
+ def _convert_id_to_token(self, index: int) -> str:
139
+ """
140
+ Converts a Unicode code point (integer) in a token (str). In case it's a special code point, convert to
141
+ human-readable format.
142
+ """
143
+ try:
144
+ if index in SPECIAL_CODEPOINTS:
145
+ return SPECIAL_CODEPOINTS[index]
146
+ return chr(index)
147
+ except TypeError:
148
+ raise ValueError(f"invalid id: {index}")
149
+
150
+ def convert_tokens_to_string(self, tokens):
151
+ return "".join(tokens)
152
+
153
+ def build_inputs_with_special_tokens(
154
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
155
+ ) -> List[int]:
156
+ """
157
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
158
+ adding special tokens. A CANINE sequence has the following format:
159
+
160
+ - single sequence: `[CLS] X [SEP]`
161
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
162
+
163
+ Args:
164
+ token_ids_0 (`List[int]`):
165
+ List of IDs to which the special tokens will be added.
166
+ token_ids_1 (`List[int]`, *optional*):
167
+ Optional second list of IDs for sequence pairs.
168
+
169
+ Returns:
170
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
171
+ """
172
+ sep = [self.sep_token_id]
173
+ cls = [self.cls_token_id]
174
+
175
+ result = cls + token_ids_0 + sep
176
+ if token_ids_1 is not None:
177
+ result += token_ids_1 + sep
178
+ return result
179
+
180
+ def get_special_tokens_mask(
181
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
182
+ ) -> List[int]:
183
+ """
184
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
185
+ special tokens using the tokenizer `prepare_for_model` method.
186
+
187
+ Args:
188
+ token_ids_0 (`List[int]`):
189
+ List of IDs.
190
+ token_ids_1 (`List[int]`, *optional*):
191
+ Optional second list of IDs for sequence pairs.
192
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
193
+ Whether or not the token list is already formatted with special tokens for the model.
194
+
195
+ Returns:
196
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
197
+ """
198
+ if already_has_special_tokens:
199
+ return super().get_special_tokens_mask(
200
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
201
+ )
202
+
203
+ result = [1] + ([0] * len(token_ids_0)) + [1]
204
+ if token_ids_1 is not None:
205
+ result += ([0] * len(token_ids_1)) + [1]
206
+ return result
207
+
208
+ def create_token_type_ids_from_sequences(
209
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
210
+ ) -> List[int]:
211
+ """
212
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A CANINE
213
+ sequence pair mask has the following format:
214
+
215
+ ```
216
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
217
+ | first sequence | second sequence |
218
+ ```
219
+
220
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
221
+
222
+ Args:
223
+ token_ids_0 (`List[int]`):
224
+ List of IDs.
225
+ token_ids_1 (`List[int]`, *optional*):
226
+ Optional second list of IDs for sequence pairs.
227
+
228
+ Returns:
229
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
230
+ """
231
+ sep = [self.sep_token_id]
232
+ cls = [self.cls_token_id]
233
+
234
+ result = len(cls + token_ids_0 + sep) * [0]
235
+ if token_ids_1 is not None:
236
+ result += len(token_ids_1 + sep) * [1]
237
+ return result
238
+
239
+ # CanineTokenizer has no vocab file
240
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
241
+ return ()
242
+
243
+
244
+ __all__ = ["CanineTokenizer"]
docs/transformers/src/transformers/models/chameleon/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_chameleon import *
22
+ from .image_processing_chameleon import *
23
+ from .modeling_chameleon import *
24
+ from .processing_chameleon import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
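The `TYPE_CHECKING` branch above gives static type checkers the eager imports, while at runtime the package's module object is replaced by a `_LazyModule` that resolves attributes only when first accessed. A toy sketch of the same deferred-import idea (this `LazyModule` is a simplified stand-in, not the library's actual `_LazyModule`):

```python
import importlib
import types

class LazyModule(types.ModuleType):
    """Toy stand-in for a lazy module: defers submodule imports until attribute access."""

    def __init__(self, name, submodules):
        super().__init__(name)
        self._submodules = set(submodules)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails, i.e. before first import.
        if attr in self._submodules:
            module = importlib.import_module(f"{self.__name__}.{attr}")
            setattr(self, attr, module)  # cache so later lookups skip __getattr__
            return module
        raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
```

Deferring the imports keeps `import transformers` cheap: heavy submodules such as the modeling code are only imported if something actually touches them.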
docs/transformers/src/transformers/models/chameleon/configuration_chameleon.py ADDED
@@ -0,0 +1,281 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """chameleon model configuration"""
16
+
17
+ from typing import List
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class ChameleonVQVAEConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`ChameleonVQModel`]. It is used to instantiate a
29
+ `ChameleonVQModel` according to the specified arguments, defining the model architecture.
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to the VQModel of the
33
+ [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B).
34
+
35
+ Args:
36
+ embed_dim (`int`, *optional*, defaults to 256):
37
+ Dimensionality of each embedding vector.
38
+ num_embeddings (`int`, *optional*, defaults to 8192):
39
+ Number of codebook embeddings.
40
+ double_latent (`bool`, *optional*, defaults to `False`):
41
+ Whether to use double z channels.
42
+ latent_channels (`int`, *optional*, defaults to 256):
43
+ Number of channels for the latent space.
44
+ resolution (`int`, *optional*, defaults to 512):
45
+ Resolution of the input images.
46
+ in_channels (`int`, *optional*, defaults to 3):
47
+ Number of input channels.
48
+ base_channels (`int`, *optional*, defaults to 128):
49
+ Base channel count.
50
+ channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
51
+ Channel multipliers for each resolution.
52
+ num_res_blocks (`int`, *optional*, defaults to 2):
53
+ Number of residual blocks.
54
+ attn_resolutions (`List[int]`, *optional*):
55
+ Resolutions to apply attention.
56
+ dropout (`float`, *optional*, defaults to 0.0):
57
+ Dropout rate.
58
+ attn_type (`str`, *optional*, defaults to `"vanilla"`):
59
+ Attention type used in VQ-GAN encoder. Can be "vanilla" or None.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ """
63
+
64
+ model_type = "chameleon_vqgan"
65
+ base_config_key = "vq_config"
66
+
67
+ def __init__(
68
+ self,
69
+ embed_dim: int = 256,
70
+ num_embeddings: int = 8192,
71
+ double_latent: bool = False,
72
+ latent_channels: int = 256,
73
+ resolution: int = 512,
74
+ in_channels: int = 3,
75
+ base_channels: int = 128,
76
+ channel_multiplier: List[int] = [1, 1, 2, 2, 4],
77
+ num_res_blocks: int = 2,
78
+ attn_resolutions: List[int] = None,
79
+ dropout: float = 0.0,
80
+ attn_type: str = "vanilla",
81
+ initializer_range=0.02,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(**kwargs)
85
+ self.embed_dim = embed_dim
86
+ self.num_embeddings = num_embeddings
87
+ self.double_latent = double_latent
88
+ self.latent_channels = latent_channels
89
+ self.resolution = resolution
90
+ self.in_channels = in_channels
91
+ self.base_channels = base_channels
92
+ self.channel_multiplier = channel_multiplier
93
+ self.num_res_blocks = num_res_blocks
94
+ self.attn_resolutions = attn_resolutions
95
+ self.dropout = dropout
96
+ self.attn_type = attn_type
97
+ self.initializer_range = initializer_range
98
+
99
+
100
+ class ChameleonConfig(PretrainedConfig):
101
+ r"""
102
+ This is the configuration class to store the configuration of a [`ChameleonModel`]. It is used to instantiate a
103
+ chameleon model according to the specified arguments, defining the model architecture. Instantiating a
104
+ configuration with the defaults will yield a similar configuration to that of the
105
+ [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B).
106
+
107
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
108
+ documentation from [`PretrainedConfig`] for more information.
109
+
110
+
111
+ Args:
112
+ vocab_size (`int`, *optional*, defaults to 65536):
113
+ Vocabulary size of the chameleon model. Defines the number of different tokens that can be represented by the
114
+ `inputs_ids` passed when calling [`ChameleonModel`]; this includes text and image tokens.
115
+ hidden_size (`int`, *optional*, defaults to 4096):
116
+ Dimension of the hidden representations.
117
+ intermediate_size (`int`, *optional*, defaults to 11008):
118
+ Dimension of the MLP representations.
119
+ num_hidden_layers (`int`, *optional*, defaults to 32):
120
+ Number of hidden layers in the Transformer decoder.
121
+ num_attention_heads (`int`, *optional*, defaults to 32):
122
+ Number of attention heads for each attention layer in the Transformer decoder.
123
+ num_key_value_heads (`int`, *optional*, defaults to 32):
124
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
125
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
126
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
127
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
128
+ by mean-pooling all the original heads within that group. For more details, check out [this
129
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
130
+ `num_attention_heads`.
131
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
132
+ The non-linear activation function (function or string) in the decoder.
133
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
134
+ The maximum sequence length that this model might ever be used with. Chameleon supports up to 4096 tokens.
135
+ initializer_range (`float`, *optional*, defaults to 0.02):
136
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
137
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
138
+ The epsilon used by the rms normalization layers.
139
+ use_cache (`bool`, *optional*, defaults to `True`):
140
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
141
+ relevant if `config.is_decoder=True`.
142
+ pad_token_id (`int`, *optional*):
143
+ Padding token id.
144
+ bos_token_id (`int`, *optional*, defaults to 1):
145
+ Beginning of stream token id.
146
+ eos_token_id (`int`, *optional*, defaults to 2):
147
+ End of stream token id.
148
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
149
+ Whether to tie the input and output word embeddings.
150
+ rope_theta (`float`, *optional*, defaults to 10000.0):
151
+ The base period of the RoPE embeddings.
152
+ rope_scaling (`Dict`, *optional*):
153
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
154
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
155
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
156
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
157
+ these scaling strategies behave:
158
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
159
+ experimental feature, subject to breaking API changes in future versions.
160
+ attention_bias (`bool`, *optional*, defaults to `False`):
161
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
162
+ attention_dropout (`float`, *optional*, defaults to 0.0):
163
+ The dropout ratio for the attention probabilities.
164
+ model_parallel_size (`int`, *optional*, defaults to 1):
165
+ Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference
166
+ doesn't do reduction in those layers and each rank has its own biases.
167
+ swin_norm (`bool`, *optional*, defaults to `False`):
168
+ Use Swin Transformer normalization.
169
+ vq_config (`dict`, *optional*):
170
+ ChameleonVQVAEConfig instance containing the configuration for the VQ-VAE model.
171
+ vocabulary_map (`dict`, *optional*):
172
+ A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs.
173
+ mlp_bias (`bool`, *optional*, defaults to `False`):
174
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
175
+
176
+
177
+ ```python
178
+ >>> from transformers import ChameleonModel, ChameleonConfig
179
+
180
+ >>> # Initializing a Chameleon chameleon-7b style configuration
181
+ >>> configuration = ChameleonConfig()
182
+
183
+ >>> # Initializing a model from the chameleon-7b style configuration
184
+ >>> model = ChameleonModel(configuration)
185
+
186
+ >>> # Accessing the model configuration
187
+ >>> configuration = model.config
188
+ ```"""
189
+
190
+ model_type = "chameleon"
191
+ sub_configs = {"vq_config": ChameleonVQVAEConfig}
192
+ keys_to_ignore_at_inference = ["past_key_values"]
193
+
194
+ def __init__(
195
+ self,
196
+ vocab_size=65536,
197
+ hidden_size=4096,
198
+ intermediate_size=11008,
199
+ num_hidden_layers=32,
200
+ num_attention_heads=32,
201
+ num_key_value_heads=32,
202
+ hidden_act="silu",
203
+ max_position_embeddings=4096,
204
+ initializer_range=0.02,
205
+ rms_norm_eps=1e-05,
206
+ use_cache=True,
207
+ pad_token_id=None,
208
+ bos_token_id=1,
209
+ eos_token_id=2,
210
+ tie_word_embeddings=False,
211
+ rope_theta=10000.0,
212
+ rope_scaling=None,
213
+ attention_bias=False,
214
+ attention_dropout=0.0,
215
+ model_parallel_size=1,
216
+ swin_norm=False,
217
+ vq_config=None,
218
+ vocabulary_map=None,
219
+ mlp_bias=False,
220
+ **kwargs,
221
+ ):
222
+ self.vocab_size = vocab_size
223
+ self.max_position_embeddings = max_position_embeddings
224
+ self.hidden_size = hidden_size
225
+ self.intermediate_size = intermediate_size
226
+ self.num_hidden_layers = num_hidden_layers
227
+ self.num_attention_heads = num_attention_heads
228
+ self.mlp_bias = mlp_bias
229
+
230
+ self.num_key_value_heads = num_key_value_heads
231
+ self.hidden_act = hidden_act
232
+ self.initializer_range = initializer_range
233
+ self.rms_norm_eps = rms_norm_eps
234
+ self.use_cache = use_cache
235
+ self.rope_theta = rope_theta
236
+ self.rope_scaling = rope_scaling
237
+ self._rope_scaling_validation()
238
+ self.attention_bias = attention_bias
239
+ self.attention_dropout = attention_dropout
240
+ self.model_parallel_size = model_parallel_size
241
+ self.swin_norm = swin_norm
242
+
243
+ if vq_config is None:
244
+ vq_config = {}
245
+ logger.info("vq_config is None. initializing the ChameleonVQConfig with default values.")
246
+
247
+ self.vq_config = ChameleonVQVAEConfig(**vq_config)
248
+
249
+ self.vocabulary_map = vocabulary_map
250
+
251
+ super().__init__(
252
+ pad_token_id=pad_token_id,
253
+ bos_token_id=bos_token_id,
254
+ eos_token_id=eos_token_id,
255
+ tie_word_embeddings=tie_word_embeddings,
256
+ **kwargs,
257
+ )
258
+
259
+ def _rope_scaling_validation(self):
260
+ """
261
+ Validate the `rope_scaling` configuration.
262
+ """
263
+ if self.rope_scaling is None:
264
+ return
265
+
266
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
267
+ raise ValueError(
268
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
269
+ f"got {self.rope_scaling}"
270
+ )
271
+ rope_scaling_type = self.rope_scaling.get("type", None)
272
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
273
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
274
+ raise ValueError(
275
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
276
+ )
277
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
278
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
279
+
280
+
281
+ __all__ = ["ChameleonConfig", "ChameleonVQVAEConfig"]
docs/transformers/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py ADDED
@@ -0,0 +1,478 @@
1
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import os
18
+
19
+ import requests
20
+ import torch
21
+ import yaml
22
+ from accelerate import init_empty_weights
23
+ from PIL import Image
24
+
25
+ from transformers import (
26
+ ChameleonConfig,
27
+ ChameleonForConditionalGeneration,
28
+ ChameleonImageProcessor,
29
+ ChameleonProcessor,
30
+ )
31
+
32
+
33
+ try:
34
+ from transformers import LlamaTokenizerFast
35
+ except ImportError:
36
+ raise ValueError(
37
+ "Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
38
+ "Update your `tokenizers` library and re-run the tokenizer conversion."
39
+ )
40
+
41
+ """
42
+ Sample usage:
43
+
44
+ ```
45
+ python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
46
+ --input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
47
+ ```
48
+
49
+ Thereafter, models can be loaded via:
50
+
51
+ ```py
52
+ from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
53
+
54
+ model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
55
+ tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
56
+ ```
57
+
58
+ Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
59
+ come in several checkpoints, they each contain a part of each weight of the model, so we need to load them all in RAM).
60
+ """
61
+
62
+ NUM_SHARDS = {
63
+ "7B": 1,
64
+ "30B": 4,
65
+ }
66
+
67
+ VOCAB_SIZE = 65536
68
+
69
+
70
+ def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
71
+ return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
72
+
73
+
74
+ def read_json(path):
75
+ with open(path, "r") as f:
76
+ return json.load(f)
77
+
78
+
79
+ def write_json(text, path):
80
+ with open(path, "w") as f:
81
+ json.dump(text, f)
82
+
83
+
84
+ def write_model(model_path, input_base_path, model_size, chameleon_version=1):
85
+ os.makedirs(model_path, exist_ok=True)
86
+ input_model_path = os.path.join(input_base_path, "models", model_size.lower())
87
+ params_path = os.path.join(input_model_path, "params.json")
88
+ consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
89
+
90
+ params = read_json(params_path)
91
+ if os.path.isfile(consolidate_params_path):
92
+ params = {**params, **read_json(consolidate_params_path)}
93
+ num_shards = NUM_SHARDS[model_size]
94
+ model_parallel_size = params["model_parallel_size"]
95
+ params = params.get("model", params)
96
+ n_layers = params["n_layers"]
97
+ n_heads = params["n_heads"]
98
+ n_heads_per_shard = n_heads // num_shards
99
+ dim = params["dim"]
100
+ dims_per_head = dim // n_heads
101
+ base = params.get("rope_theta", 10000.0)
102
+ swin_norm = params["swin_norm"]
103
+ if base > 10000.0:
104
+ max_position_embeddings = 16384
105
+ else:
106
+ # Depending on the Chameleon version, the default max_position_embeddings has different values.
107
+ if chameleon_version == 1:
108
+ max_position_embeddings = 4096
109
+ else:
110
+ raise NotImplementedError(
111
+ f"Version {chameleon_version} of chameleon is not supported yet. "
112
+ "Current supported versions of chameleon are [1]."
113
+ )
114
+
115
+ if params.get("n_kv_heads", None) is not None:
116
+ num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
117
+ num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
118
+ key_value_dim = dim // num_key_value_heads
119
+ else: # compatibility with other checkpoints
120
+ num_key_value_heads = n_heads
121
+ num_local_key_value_heads = n_heads_per_shard
122
+ key_value_dim = dim
123
+
124
+ print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
125
+ # Load weights
126
+ if num_shards == 1:
127
+ # Not sharded
128
+ # (The sharded implementation would also work, but this is simpler.)
129
+ loaded = None
130
+ for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
131
+ possible_path = os.path.join(input_model_path, possible_name)
132
+ if os.path.exists(possible_path):
133
+ loaded = torch.load(possible_path, map_location="cpu", weights_only=True)
134
+ break
135
+ assert loaded is not None
136
+ else:
137
+ # Sharded
138
+ loaded = [
139
+ torch.load(
140
+ os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=True
141
+ )
142
+ for i in range(num_shards)
143
+ ]
144
+
145
+ # permute for sliced rotary
146
+ def permute(w, n_heads, dim1=dim, dim2=dim):
147
+ return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
148
+
149
+ # Load weights to the state dict
150
+ state_dict = {}
151
+ for layer_i in range(n_layers):
152
+ if num_shards == 1:
153
+ # Unsharded
154
+ state_dict.update(
155
+ {
156
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
157
+ loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
158
+ ),
159
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
160
+ loaded[f"layers.{layer_i}.attention.wk.weight"],
161
+ n_heads=num_key_value_heads,
162
+ dim1=key_value_dim,
163
+ ),
164
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
165
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
166
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
167
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
168
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
169
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[
170
+ f"layers.{layer_i}.attention_norm.weight"
171
+ ],
172
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
173
+ f"layers.{layer_i}.ffn_norm.weight"
174
+ ],
175
+ }
176
+ )
177
+ # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
178
+ state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
179
+ loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
180
+ .view(dims_per_head // 2, 2)
181
+ .t()
182
+ .reshape(1, -1)
183
+ .repeat_interleave(n_heads, 0)
184
+ )
185
+ state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
186
+ loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
187
+ .view(dims_per_head // 2, 2)
188
+ .t()
189
+ .reshape(1, -1)
190
+ .repeat_interleave(n_heads, 0)
191
+ )
192
+ state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
193
+ loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
194
+ .view(dims_per_head // 2, 2)
195
+ .t()
196
+ .reshape(1, -1)
197
+ .repeat_interleave(num_key_value_heads, 0)
198
+ )
199
+ state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
200
+ loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
201
+ .view(dims_per_head // 2, 2)
202
+ .t()
203
+ .reshape(1, -1)
204
+ .repeat_interleave(num_key_value_heads, 0)
205
+ )
206
+
207
+ else:
208
+ # Sharded
209
+ state_dict.update(
210
+ {
211
+ f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
212
+ [l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
213
+ ).mean(dim=0),
214
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
215
+ [l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
216
+ ).mean(dim=0),
217
+ }
218
+ )
219
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
220
+ torch.cat(
221
+ [
222
+ loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
223
+ for i in range(num_shards)
224
+ ],
225
+ dim=0,
226
+ ).reshape(dim, dim),
227
+ n_heads=n_heads,
228
+ )
229
+
230
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
231
+ torch.cat(
232
+ [
233
+ loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
234
+ num_local_key_value_heads, dims_per_head, dim
235
+ )
236
+ for i in range(num_shards)
237
+ ],
238
+ dim=0,
239
+ ).reshape(key_value_dim, dim),
240
+ n_heads=num_key_value_heads,
241
+ dim1=key_value_dim,
242
+ )
243
+
244
+ # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
245
+ state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
246
+ torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
247
+ .view(num_shards, dims_per_head // 2, 2)
248
+ .transpose(1, 2)
249
+ .reshape(num_shards, -1)
250
+ .repeat_interleave(n_heads // num_shards, 0)
251
+ )
252
+ state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
253
+ torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
254
+ .view(num_shards, dims_per_head // 2, 2)
255
+ .transpose(1, 2)
256
+ .reshape(num_shards, -1)
257
+ .repeat_interleave(n_heads // num_shards, 0)
258
+ )
259
+ state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
260
+ torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
261
+ .view(num_shards, dims_per_head // 2, 2)
262
+ .transpose(1, 2)
263
+ .reshape(num_shards, -1)
264
+ .repeat_interleave(num_key_value_heads // num_shards, 0)
265
+ )
266
+ state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
267
+ torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
268
+ .view(num_shards, dims_per_head // 2, 2)
269
+ .transpose(1, 2)
270
+ .reshape(num_shards, -1)
271
+ .repeat_interleave(num_key_value_heads // num_shards, 0)
272
+ )
273
+
274
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
275
+ [
276
+ loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
277
+ num_local_key_value_heads, dims_per_head, dim
278
+ )
279
+ for i in range(num_shards)
280
+ ],
281
+ dim=0,
282
+ ).reshape(key_value_dim, dim)
283
+
284
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
285
+ [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
286
+ )
287
+ state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
288
+ [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
289
+ )
290
+ state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
291
+ [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
292
+ )
293
+ state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
294
+ [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
295
+ )
296
+
297
+ if num_shards == 1:
298
+ # Unsharded
299
+ state_dict.update(
300
+ {
301
+ "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
302
+ "model.norm.weight": loaded["norm.weight"],
303
+ "lm_head.weight": loaded["output.weight"],
304
+ }
305
+ )
306
+ else:
307
+ state_dict.update(
308
+ {
309
+ "model.embed_tokens.weight": torch.cat(
310
+ [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
311
+ ),
312
+ "model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
313
+ "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
314
+ }
315
+ )
316
+
317
+ # Load VQGAN weights
318
+ vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
319
+ vqgan_state_dict = torch.load(vqgan_path, map_location="cpu", weights_only=True)["state_dict"]
320
+ for k, v in vqgan_state_dict.items():
321
+ if "decoder" in k:
322
+ continue # we don't do image generation yet
323
+ state_dict[f"model.vqmodel.{k}"] = v
324
+
325
+ # Write configs
326
+ ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
327
+ multiple_of = params["multiple_of"] if "multiple_of" in params else 256
328
+
329
+ with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
330
+ tokenizer_config = json.load(tokenizer_file)
331
+ vocabulary_map = tokenizer_config["model"]["vocab"]
332
+ vocabulary_map["<image>"] = vocabulary_map[
333
+ "<reserved08707>"
334
+ ] # use a reserved token instead of adding a new one
335
+ del vocabulary_map["<reserved08707>"]
336
+
337
+ for token in tokenizer_config["added_tokens"]:
338
+ if token["content"] == "<reserved08707>":
339
+ token["content"] = "<image>"
340
+
341
+ with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
342
+ json.dump(tokenizer_config, f) # save the new file to init tokenizer later
343
+
344
+ vq_keys_to_replace = [
345
+ ("ch", "base_channels"),
346
+ ("out_ch", "out_channels"),
347
+ ("n_embed", "num_embeddings"),
348
+ ("ch_mult", "channel_multiplier"),
349
+ ("double_z", "double_latent"),
350
+ ("z_channels", "latent_channels"),
351
+ ]
352
+ with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
353
+ vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
354
+ vq_config.update(**vq_config["ddconfig"])
355
+ for old, new in vq_keys_to_replace:
356
+ vq_config[new] = vq_config[old]
357
+ del vq_config["ddconfig"]
358
+ del vq_config["ckpt_path"]
359
+ del vq_config["lossconfig"]
360
+
361
+ config = ChameleonConfig(
362
+ hidden_size=dim,
363
+ intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
364
+ num_attention_heads=params["n_heads"],
365
+ num_hidden_layers=params["n_layers"],
366
+ rms_norm_eps=params["norm_eps"],
367
+ num_key_value_heads=num_key_value_heads,
368
+ vocab_size=VOCAB_SIZE,
369
+ rope_theta=base,
370
+ max_position_embeddings=max_position_embeddings,
371
+ model_parallel_size=model_parallel_size,
372
+ swin_norm=swin_norm,
373
+ vq_config=vq_config,
374
+ vocabulary_map=vocabulary_map,
375
+ )
376
+ with init_empty_weights():
377
+ model = ChameleonForConditionalGeneration(config)
378
+
379
+ model.load_state_dict(state_dict, assign=True, strict=False)
380
+ model.save_pretrained(model_path, safe_serialization=True)
381
+
382
+ # Load and save the processor
383
+ tokenizer = LlamaTokenizerFast(
384
+ tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
385
+ )
386
+ tokenizer.sep_token_id = 8710 # assign <reserved08706> to sep so that we can append it after input text
387
+ tokenizer.pad_token_id = 1 # assign <pad> to the special pad_token
388
+ image_processor = ChameleonImageProcessor()
389
+ processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
390
+ processor.save_pretrained(model_path)
391
+
392
+ # Make space so we can load the model properly now.
393
+ del state_dict
394
+ del loaded
395
+ del vqgan_state_dict
396
+ gc.collect()
397
+
398
+ # Short inference on a few examples to check if generation makes sense
399
+ # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
400
+ print("Loading the checkpoint in a Chameleon model...")
401
+ print("*" * 100)
402
+ model = ChameleonForConditionalGeneration.from_pretrained(
403
+ model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
404
+ )
405
+ processor = ChameleonProcessor.from_pretrained(model_path)
406
+
407
+ prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
408
+ image = Image.open(
409
+ requests.get(
410
+ "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
411
+ ).raw
412
+ )
413
+ inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
414
+ length = inputs.input_ids.shape[1]
415
+
416
+ out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
417
+ generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
418
+
419
+ print(f"Generation for single-image: {generated_text}")
420
+ print("*" * 100)
421
+
422
+ # Multi-image example
423
+ prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
424
+ image = Image.open(
425
+ requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
426
+ )
427
+ image_2 = Image.open(
428
+ requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
429
+ )
430
+
431
+ inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
432
+ length = inputs.input_ids.shape[1]
433
+ out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
434
+ generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
435
+
436
+ print(f"Generation for multi-image: {generated_text}")
437
+
438
+
439
+ def main():
440
+ parser = argparse.ArgumentParser()
441
+ parser.add_argument(
442
+ "--input_dir",
443
+ help="Location of Chameleon weights",
444
+ )
445
+ parser.add_argument(
446
+ "--model_size",
447
+ choices=["7B", "30B"],
448
+ help=""
449
+ " models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, checkout the original repo: https://github.com/facebookresearch/chameleon",
450
+ )
451
+ parser.add_argument(
452
+ "--output_dir",
453
+ help="Location to write HF model",
454
+ )
455
+ parser.add_argument(
456
+ "--test_inference",
457
+ action="store_true",
458
+ help="Whether to load the model for generation to test it's converted correctly.",
459
+ )
460
+ # Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
461
+ parser.add_argument(
462
+ "--chameleon_version",
463
+ choices=[1],
464
+ default=1,
465
+ type=int,
466
+ help="Version of the Chameleon model to convert",
467
+ )
468
+ args = parser.parse_args()
469
+ write_model(
470
+ model_path=args.output_dir,
471
+ input_base_path=args.input_dir,
472
+ model_size=args.model_size,
473
+ chameleon_version=args.chameleon_version,
474
+ )
475
+
476
+
477
+ if __name__ == "__main__":
478
+ main()
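The FFN sizing in `compute_intermediate_size` follows the Llama recipe: take roughly 8/3 of the hidden dimension, scale by `ffn_dim_multiplier`, and round up to a multiple of `multiple_of`. A quick check that the 7B defaults reproduce the `intermediate_size=11008` default of `ChameleonConfig`:

```python
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # Same formula as in the conversion script above.
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)

# hidden_size 4096: int(8 * 4096 / 3) = 10922, rounded up to the next multiple of 256:
assert compute_intermediate_size(4096) == 11008  # 43 * 256
```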
docs/transformers/src/transformers/models/chameleon/image_processing_chameleon.py ADDED
@@ -0,0 +1,344 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Chameleon."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import (
23
+ get_resize_output_image_size,
24
+ resize,
25
+ to_channel_dimension_format,
26
+ )
27
+ from ...image_utils import (
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ infer_channel_dimension_format,
32
+ is_scaled_image,
33
+ make_flat_list_of_images,
34
+ to_numpy_array,
35
+ valid_images,
36
+ validate_preprocess_arguments,
37
+ )
38
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ if is_vision_available():
44
+ import PIL
45
+
46
+
47
+ class ChameleonImageProcessor(BaseImageProcessor):
48
+ r"""
49
+ Constructs a Chameleon image processor.
50
+
51
+ Args:
52
+ do_resize (`bool`, *optional*, defaults to `True`):
53
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
54
+ `do_resize` in the `preprocess` method.
55
+ size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 512}`):
56
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
57
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
58
+ method.
59
+ resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.LANCZOS`):
60
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
61
+ do_center_crop (`bool`, *optional*, defaults to `True`):
62
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
63
+ `preprocess` method.
64
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 512, "width": 512}`):
65
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
66
+ method.
67
+ do_rescale (`bool`, *optional*, defaults to `True`):
68
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
69
+ the `preprocess` method.
70
+ rescale_factor (`int` or `float`, *optional*, defaults to 0.0078):
71
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
72
+ method.
73
+ do_normalize (`bool`, *optional*, defaults to `True`):
74
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
75
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
76
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
77
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
78
+ image_std (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
79
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
80
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
82
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
83
+ Whether to convert the image to RGB.
84
+ """
85
+
86
+ model_input_names = ["pixel_values"]
87
+
88
+ def __init__(
89
+ self,
90
+ do_resize: bool = True,
91
+ size: Dict[str, int] = None,
92
+ resample: PILImageResampling = PIL.Image.LANCZOS,
93
+ do_center_crop: bool = True,
94
+ crop_size: Dict[str, int] = None,
95
+ do_rescale: bool = True,
96
+ rescale_factor: Union[int, float] = 0.0078,
97
+ do_normalize: bool = True,
98
+ image_mean: Optional[Union[float, List[float]]] = None,
99
+ image_std: Optional[Union[float, List[float]]] = None,
100
+ do_convert_rgb: bool = True,
101
+ **kwargs,
102
+ ) -> None:
103
+ super().__init__(**kwargs)
104
+ size = size if size is not None else {"shortest_edge": 512}
105
+ size = get_size_dict(size, default_to_square=False)
106
+ crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512}
107
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
108
+
109
+ self.do_resize = do_resize
110
+ self.size = size
111
+ self.resample = resample
112
+ self.do_center_crop = do_center_crop
113
+ self.crop_size = crop_size
114
+ self.do_rescale = do_rescale
115
+ self.rescale_factor = rescale_factor
116
+ self.do_normalize = do_normalize
117
+ self.image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0]
118
+ self.image_std = image_std if image_std is not None else [1.0, 1.0, 1.0]
119
+ self.do_convert_rgb = do_convert_rgb
120
+
121
+ # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
122
+ def resize(
123
+ self,
124
+ image: np.ndarray,
125
+ size: Dict[str, int],
126
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
127
+ data_format: Optional[Union[str, ChannelDimension]] = None,
128
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
129
+ **kwargs,
130
+ ) -> np.ndarray:
131
+ """
132
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
133
+ resized to keep the input aspect ratio.
134
+
135
+ Args:
136
+ image (`np.ndarray`):
137
+ Image to resize.
138
+ size (`Dict[str, int]`):
139
+ Size of the output image.
140
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
141
+ Resampling filter to use when resizing the image.
142
+ data_format (`str` or `ChannelDimension`, *optional*):
143
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
144
+ input_data_format (`ChannelDimension` or `str`, *optional*):
145
+ The channel dimension format of the input image. If not provided, it will be inferred.
146
+ """
147
+ default_to_square = True
148
+ if "shortest_edge" in size:
149
+ size = size["shortest_edge"]
150
+ default_to_square = False
151
+ elif "height" in size and "width" in size:
152
+ size = (size["height"], size["width"])
153
+ else:
154
+ raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
155
+
156
+ output_size = get_resize_output_image_size(
157
+ image,
158
+ size=size,
159
+ default_to_square=default_to_square,
160
+ input_data_format=input_data_format,
161
+ )
162
+ return resize(
163
+ image,
164
+ size=output_size,
165
+ resample=resample,
166
+ data_format=data_format,
167
+ input_data_format=input_data_format,
168
+ **kwargs,
169
+ )
170
+
171
+ @filter_out_non_signature_kwargs()
172
+ def preprocess(
173
+ self,
174
+ images: ImageInput,
175
+ do_resize: Optional[bool] = None,
176
+ size: Dict[str, int] = None,
177
+ resample: PILImageResampling = None,
178
+ do_center_crop: Optional[bool] = None,
179
+ crop_size: Optional[int] = None,
180
+ do_rescale: Optional[bool] = None,
181
+ rescale_factor: Optional[float] = None,
182
+ do_normalize: Optional[bool] = None,
183
+ image_mean: Optional[Union[float, List[float]]] = None,
184
+ image_std: Optional[Union[float, List[float]]] = None,
185
+ do_convert_rgb: Optional[bool] = None,
186
+ return_tensors: Optional[Union[str, TensorType]] = None,
187
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
188
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
189
+ ) -> PIL.Image.Image:
190
+ """
191
+ Preprocess an image or batch of images.
192
+
193
+ Args:
194
+ images (`ImageInput`):
195
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
196
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
197
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
198
+ Whether to resize the image.
199
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
200
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
201
+ the longest edge resized to keep the input aspect ratio.
202
+ resample (`int`, *optional*, defaults to `self.resample`):
203
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
204
+ has an effect if `do_resize` is set to `True`.
205
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
206
+ Whether to center crop the image.
207
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
208
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
209
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
210
+ Whether to rescale the image.
211
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
212
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
213
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
214
+ Whether to normalize the image.
215
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
216
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
217
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
218
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
219
+ `True`.
220
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
221
+ Whether to convert the image to RGB.
222
+ return_tensors (`str` or `TensorType`, *optional*):
223
+ The type of tensors to return. Can be one of:
224
+ - Unset: Return a list of `np.ndarray`.
225
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
226
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
227
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
228
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
229
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
230
+ The channel dimension format for the output image. Can be one of:
231
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
232
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
233
+ - Unset: Use the channel dimension format of the input image.
234
+ input_data_format (`ChannelDimension` or `str`, *optional*):
235
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
236
+ from the input image. Can be one of:
237
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
238
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
239
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
240
+ """
241
+ do_resize = do_resize if do_resize is not None else self.do_resize
242
+ size = size if size is not None else self.size
243
+ size = get_size_dict(size, param_name="size", default_to_square=False)
244
+ resample = resample if resample is not None else self.resample
245
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
246
+ crop_size = crop_size if crop_size is not None else self.crop_size
247
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
248
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
249
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
250
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
251
+ image_mean = image_mean if image_mean is not None else self.image_mean
252
+ image_std = image_std if image_std is not None else self.image_std
253
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
254
+
255
+ images = make_flat_list_of_images(images)
256
+
257
+ if not valid_images(images):
258
+ raise ValueError(
259
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
260
+ "torch.Tensor, tf.Tensor or jax.ndarray."
261
+ )
262
+
263
+ validate_preprocess_arguments(
264
+ do_rescale=do_rescale,
265
+ rescale_factor=rescale_factor,
266
+ do_normalize=do_normalize,
267
+ image_mean=image_mean,
268
+ image_std=image_std,
269
+ do_center_crop=do_center_crop,
270
+ crop_size=crop_size,
271
+ do_resize=do_resize,
272
+ size=size,
273
+ resample=resample,
274
+ )
275
+
276
+ if do_convert_rgb:
277
+ images = [self.blend_rgba(image) for image in images]
278
+
279
+ # All transformations expect numpy arrays.
280
+ images = [to_numpy_array(image) for image in images]
281
+
282
+ if do_rescale and is_scaled_image(images[0]):
283
+ logger.warning_once(
284
+ "It looks like you are trying to rescale already rescaled images. If the input"
285
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
286
+ )
287
+
288
+ if input_data_format is None:
289
+ # We assume that all images have the same channel dimension format.
290
+ input_data_format = infer_channel_dimension_format(images[0])
291
+ all_images = []
292
+ for image in images:
293
+ if do_resize:
294
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
295
+
296
+ if do_center_crop:
297
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
298
+
299
+ if do_rescale:
300
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
301
+
302
+ if do_normalize:
303
+ image = self.normalize(
304
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
305
+ )
306
+
307
+ all_images.append(image)
308
+ images = [
309
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
310
+ for image in all_images
311
+ ]
312
+
313
+ data = {"pixel_values": images}
314
+ return BatchFeature(data=data, tensor_type=return_tensors)
315
+
316
+ def blend_rgba(self, image: ImageInput) -> ImageInput:
317
+ """
318
+ Convert image to RGB by blending the transparency layer if it's in RGBA format.
319
+ If image is not `PIL.Image`, it si simply returned without modifications.
320
+
321
+ Args:
322
+ image (`ImageInput`):
323
+ Image to convert.
324
+ """
325
+
326
+ if not isinstance(image, PIL.Image.Image):
327
+ return image
328
+ elif image.mode == "RGB":
329
+ return image
330
+
331
+ img_rgba = np.array(image.convert("RGBA"))
332
+
333
+ # If there is no transparency layer, simply convert and return.
334
+ if not (img_rgba[:, :, 3] < 255).any():
335
+ return image.convert("RGB")
336
+
337
+ # There is a transparency layer, blend it with a white background.
338
+ # Calculate the alpha proportion for blending.
339
+ alpha = img_rgba[:, :, 3] / 255.0
340
+ img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3]
341
+ return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB")
342
+
343
+
344
+ __all__ = ["ChameleonImageProcessor"]
docs/transformers/src/transformers/models/chameleon/modeling_chameleon.py ADDED
@@ -0,0 +1,1673 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Chameleon model."""
16
+
17
+ import math
18
+ from functools import cached_property
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from ...activations import ACT2FN
28
+ from ...cache_utils import Cache, DynamicCache, StaticCache
29
+ from ...generation import GenerationMixin
30
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
31
+ from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
32
+ from ...modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ )
36
+ from ...modeling_utils import PreTrainedModel
37
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
38
+ from ...utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ is_torch_flex_attn_available,
43
+ is_torchdynamo_compiling,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
48
+
49
+
50
+ if is_torch_flex_attn_available():
51
+ from torch.nn.attention.flex_attention import BlockMask
52
+
53
+ from ...integrations.flex_attention import make_flex_block_causal_mask
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CONFIG_FOR_DOC = "ChameleonConfig"
59
+ _CHECKPOINT_FOR_DOC = "meta/chameleon-7b"
60
+ _EXPECTED_OUTPUT_SHAPE = [1, 7, 4096]
61
+ _SEQ_CLASS_EXPECTED_LOSS = 1.03
62
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
66
+ class ChameleonRMSNorm(nn.Module):
67
+ def __init__(self, hidden_size, eps=1e-6):
68
+ """
69
+ ChameleonRMSNorm is equivalent to T5LayerNorm
70
+ """
71
+ super().__init__()
72
+ self.weight = nn.Parameter(torch.ones(hidden_size))
73
+ self.variance_epsilon = eps
74
+
75
+ def forward(self, hidden_states):
76
+ input_dtype = hidden_states.dtype
77
+ hidden_states = hidden_states.to(torch.float32)
78
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
79
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
80
+ return self.weight * hidden_states.to(input_dtype)
81
+
82
+ def extra_repr(self):
83
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
84
+
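
As the docstring notes, this matches T5's LayerNorm: the activation is divided by its root-mean-square over the hidden dimension, with no mean subtraction and no bias. A minimal sketch of that formula with toy shapes (illustrative only, not part of the file above):

```python
import torch

x = torch.randn(2, 5, 8)                    # (batch, seq, hidden)
variance = x.pow(2).mean(-1, keepdim=True)  # mean of squares, as in forward()
y = x * torch.rsqrt(variance + 1e-6)        # weight is all-ones at initialization

print(y.pow(2).mean(-1).sqrt())             # per-token RMS is ~1 after normalization
```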
85
+
86
+ ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm)
87
+
88
+
89
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
90
+ # TODO(joao): add me back asap :)
91
+ class ChameleonRotaryEmbedding(nn.Module):
92
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
93
+ super().__init__()
94
+ self.scaling_factor = scaling_factor
95
+ self.dim = dim
96
+ self.max_position_embeddings = max_position_embeddings
97
+ self.base = base
98
+ inv_freq = 1.0 / (
99
+ self.base
100
+ ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
101
+ )
102
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
103
+ # For BC we keep track of the maximum cached sequence length; cos/sin are computed on the fly in forward
104
+ self.max_seq_len_cached = max_position_embeddings
105
+
106
+ @torch.no_grad()
107
+ def forward(self, x, position_ids):
108
+ # x: [bs, num_attention_heads, seq_len, head_size]
109
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
110
+ position_ids_expanded = position_ids[:, None, :].float()
111
+ # Force float32 since bfloat16 loses precision on long contexts
112
+ # See https://github.com/huggingface/transformers/pull/29285
113
+ device_type = x.device.type
114
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
115
+ with torch.autocast(device_type=device_type, enabled=False):
116
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
117
+ emb = torch.cat((freqs, freqs), dim=-1)
118
+ cos = emb.cos()
119
+ sin = emb.sin()
120
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
121
+
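
The frequency table built in `__init__` and the shapes produced by this forward can be reproduced in a few lines. A sketch with assumed toy sizes (`dim=8`, six positions), illustrative only:

```python
import torch

# inv_freq[i] = 1 / base**(2i / dim): one frequency per pair of channels.
dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

# forward() takes the outer product with the position ids and duplicates it,
# so cos/sin come out with shape (batch, seq_len, dim).
position_ids = torch.arange(6)[None, :].float()                    # (1, 6)
freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.cos().shape, emb.sin().shape)                            # (1, 6, 8) each
```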
122
+
123
+ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
124
+ """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
125
+
126
+ def forward(self, x, position_ids):
127
+ # difference to the original RoPE: a scaling factor is applied to the position ids
128
+ position_ids = position_ids.float() / self.scaling_factor
129
+ cos, sin = super().forward(x, position_ids)
130
+ return cos, sin
131
+
132
+
133
+ class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
134
+ """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
135
+
136
+ def forward(self, x, position_ids):
137
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
138
+ seq_len = torch.max(position_ids) + 1
139
+ if seq_len > self.max_position_embeddings:
140
+ base = self.base * (
141
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
142
+ ) ** (self.dim / (self.dim - 2))
143
+ inv_freq = 1.0 / (
144
+ base
145
+ ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim)
146
+ )
147
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
148
+
149
+ cos, sin = super().forward(x, position_ids)
150
+ return cos, sin
151
+
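
The base adjustment above follows `base' = base * ((s * L / L0) - (s - 1))^(d / (d - 2))`. A worked example with assumed toy values (head_dim 128, trained context 4096, scaling factor 2.0):

```python
base, dim = 10000.0, 128
max_position_embeddings, scaling_factor = 4096, 2.0
seq_len = 8192  # twice the trained length

new_base = base * (
    (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
) ** (dim / (dim - 2))
print(new_base)  # ≈ 30528: a larger base lowers the frequencies, stretching the context
```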
152
+
153
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
162
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
163
+ """Applies Rotary Position Embedding to the query and key tensors.
164
+
165
+ Args:
166
+ q (`torch.Tensor`): The query tensor.
167
+ k (`torch.Tensor`): The key tensor.
168
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
169
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
170
+ position_ids (`torch.Tensor`, *optional*):
171
+ Deprecated and unused.
172
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
173
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
174
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
175
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
176
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
177
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
178
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
179
+ Returns:
180
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
181
+ """
182
+ cos = cos.unsqueeze(unsqueeze_dim)
183
+ sin = sin.unsqueeze(unsqueeze_dim)
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
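
Because each channel pair is rotated by an angle proportional to its position, RoPE preserves vector norms (and query/key dot products end up depending only on relative offsets). A small self-contained check, illustrative only:

```python
import torch


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


dim = 8
q = torch.randn(1, 2, 6, dim)                              # (batch, heads, seq, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(6).float(), inv_freq)     # (seq, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                    # (seq, dim)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]    # broadcast over batch/heads

q_rot = (q * cos) + (rotate_half(q) * sin)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)  # rotations preserve norms
```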
188
+
189
+ # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
190
+ class ChameleonMLP(nn.Module):
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.config = config
194
+ self.hidden_size = config.hidden_size
195
+ self.intermediate_size = config.intermediate_size
196
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
197
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
198
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
199
+ self.act_fn = ACT2FN[config.hidden_act]
200
+
201
+ # Ignore copy
202
+ def forward(self, x):
203
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
204
+ return down_proj
205
+
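
The forward is the usual gated (SwiGLU-style) MLP: the activated gate branch multiplies the up branch elementwise before the down projection. A minimal sketch with toy sizes, assuming a SiLU activation as is typical for Llama-family configs:

```python
import torch
from torch import nn
import torch.nn.functional as F

hidden, intermediate = 8, 16
gate = nn.Linear(hidden, intermediate, bias=False)
up = nn.Linear(hidden, intermediate, bias=False)
down = nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(2, hidden)
y = down(F.silu(gate(x)) * up(x))  # act_fn(gate_proj(x)) * up_proj(x), then down_proj
print(y.shape)                     # torch.Size([2, 8])
```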
206
+
207
+ class ChameleonLayerNorm(nn.LayerNorm):
208
+ """
209
+ LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
210
+ from each shard separately to each head, instead of reducing. We can apply each head's own
211
+ gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
212
+ in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
213
+ """
214
+
215
+ def __init__(self, hidden_size, *args, **kwargs):
216
+ super().__init__(hidden_size, *args, **kwargs)
217
+ self.normalized_shape = (hidden_size[-1],)
218
+
219
+ def forward(self, hidden_states):
220
+ hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5)
221
+ hidden_states = hidden_states * self.weight + self.bias
222
+ return hidden_states
223
+
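
In other words, statistics are computed per head over `head_dim` only, while gamma/beta keep a separate value per `(head, channel)`. A sketch showing this is equivalent to normalizing each head with its own affine parameters (toy sizes, illustrative only):

```python
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 16
x = torch.randn(32, num_heads, head_dim)   # (tokens, heads, head_dim)
weight = torch.randn(num_heads, head_dim)  # per-head gamma
bias = torch.randn(num_heads, head_dim)    # per-head beta

# Normalize over the last dim only, then apply gamma/beta manually (as above).
y = F.layer_norm(x, (head_dim,), None, None, 1e-5) * weight + bias

# Reference: run a standard LayerNorm per head with that head's own parameters.
ref = torch.stack(
    [F.layer_norm(x[:, h], (head_dim,), weight[h], bias[h], 1e-5) for h in range(num_heads)],
    dim=1,
)
assert torch.allclose(y, ref, atol=1e-5)
```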
224
+
225
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
226
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
227
+ """
228
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
229
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
230
+ """
231
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
232
+ if n_rep == 1:
233
+ return hidden_states
234
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
235
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
236
+
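
A quick check of the docstring's claim that the expand-and-reshape trick matches `torch.repeat_interleave` along the head dimension (toy sizes, illustrative only):

```python
import torch

x = torch.randn(2, 4, 6, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3

expanded = x[:, :, None, :, :].expand(2, 4, n_rep, 6, 16).reshape(2, 4 * n_rep, 6, 16)
assert torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=1))
```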
237
+
238
+ class ChameleonAttention(nn.Module):
239
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
240
+
241
+ def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None):
242
+ super().__init__()
243
+ self.config = config
244
+ self.layer_idx = layer_idx
245
+ if layer_idx is None:
246
+ logger.warning_once(
247
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
248
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
249
+ "when creating this class."
250
+ )
251
+
252
+ self.attention_dropout = config.attention_dropout
253
+ self.hidden_size = config.hidden_size
254
+ self.num_heads = config.num_attention_heads
255
+ self.head_dim = self.hidden_size // self.num_heads
256
+ self.num_key_value_heads = config.num_key_value_heads
257
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
258
+ self.max_position_embeddings = config.max_position_embeddings
259
+ self.rope_theta = config.rope_theta
260
+ self.is_causal = True
261
+ self.model_parallel_size = config.model_parallel_size
262
+
263
+ if (self.head_dim * self.num_heads) != self.hidden_size:
264
+ raise ValueError(
265
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
266
+ f" and `num_heads`: {self.num_heads})."
267
+ )
268
+
269
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
270
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
271
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
272
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
273
+ self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
274
+ self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
275
+ self._init_rope()
276
+
277
+ # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
278
+ # TODO(joao): add me back asap :)
279
+ def _init_rope(self):
280
+ if self.config.rope_scaling is None:
281
+ self.rotary_emb = ChameleonRotaryEmbedding(
282
+ self.head_dim,
283
+ max_position_embeddings=self.max_position_embeddings,
284
+ base=self.rope_theta,
285
+ )
286
+ else:
287
+ scaling_type = self.config.rope_scaling["type"]
288
+ scaling_factor = self.config.rope_scaling["factor"]
289
+ if scaling_type == "linear":
290
+ self.rotary_emb = ChameleonLinearScalingRotaryEmbedding(
291
+ self.head_dim,
292
+ max_position_embeddings=self.max_position_embeddings,
293
+ scaling_factor=scaling_factor,
294
+ base=self.rope_theta,
295
+ )
296
+ elif scaling_type == "dynamic":
297
+ self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding(
298
+ self.head_dim,
299
+ max_position_embeddings=self.max_position_embeddings,
300
+ scaling_factor=scaling_factor,
301
+ base=self.rope_theta,
302
+ )
303
+ else:
304
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ attention_mask: Optional[torch.Tensor] = None,
310
+ position_ids: Optional[torch.LongTensor] = None,
311
+ past_key_value: Optional[Cache] = None,
312
+ output_attentions: bool = False,
313
+ use_cache: bool = False,
314
+ cache_position: Optional[torch.LongTensor] = None,
315
+ **kwargs,
316
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
317
+ bsz, q_len, _ = hidden_states.size()
318
+
319
+ query_states = self.q_proj(hidden_states)
320
+ key_states = self.k_proj(hidden_states)
321
+ value_states = self.v_proj(hidden_states)
322
+
323
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
324
+ query_states = self.q_norm(query_states)
325
+
326
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
327
+ key_states = self.k_norm(key_states)
328
+
329
+ query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
330
+ key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
331
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
332
+
333
+ cos, sin = self.rotary_emb(value_states, position_ids)
334
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
335
+
336
+ if past_key_value is not None:
337
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
338
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
339
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
340
+
341
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
342
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
343
+
344
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
345
+
346
+ if attention_mask is not None: # no matter the length, we just slice it
347
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
348
+ attn_weights = attn_weights + causal_mask
349
+
350
+ # upcast attention to fp32
351
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
352
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
353
+ attn_output = torch.matmul(attn_weights, value_states)
354
+
355
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
356
+ raise ValueError(
357
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
358
+ f" {attn_output.size()}"
359
+ )
360
+
361
+ attn_output = attn_output.transpose(1, 2).contiguous()
362
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
363
+ attn_output = self.o_proj(attn_output)
364
+
365
+ if not output_attentions:
366
+ attn_weights = None
367
+
368
+ return attn_output, attn_weights, past_key_value
369
+
370
+
371
+ # NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
372
+ # TODO(joao): add me back asap :)
373
+ class ChameleonFlashAttention2(ChameleonAttention):
374
+ """
375
+ Chameleon flash attention module. This module inherits from `ChameleonAttention`, as the weights of the module stay
376
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
377
+ flash attention and deal with padding tokens in case the input contains any of them.
378
+ """
379
+
380
+ def __init__(self, *args, **kwargs):
381
+ super().__init__(*args, **kwargs)
382
+
383
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
384
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
385
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
386
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
387
+
388
+ # Ignore copy
389
+ def forward(
390
+ self,
391
+ hidden_states: torch.Tensor,
392
+ attention_mask: Optional[torch.LongTensor] = None,
393
+ position_ids: Optional[torch.LongTensor] = None,
394
+ past_key_value: Optional[Cache] = None,
395
+ output_attentions: bool = False,
396
+ use_cache: bool = False,
397
+ cache_position: Optional[torch.LongTensor] = None,
398
+ **kwargs,
399
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
400
+ if isinstance(past_key_value, StaticCache):
401
+ raise ValueError(
402
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
403
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
404
+ )
405
+
406
+ output_attentions = False
407
+
408
+ bsz, q_len, _ = hidden_states.size()
409
+
410
+ query_states = self.q_proj(hidden_states)
411
+ key_states = self.k_proj(hidden_states)
412
+ value_states = self.v_proj(hidden_states)
413
+
414
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
415
+ query_states = self.q_norm(query_states)
416
+
417
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
418
+ key_states = self.k_norm(key_states)
419
+
420
+ # Flash attention requires the input to have the shape
421
+ # batch_size x seq_length x num_heads x head_dim
422
+ # therefore we just need to keep the original shape
423
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
424
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
425
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
426
+
427
+ cos, sin = self.rotary_emb(value_states, position_ids)
428
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
429
+
430
+ if past_key_value is not None:
431
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
432
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
433
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
434
+
435
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
436
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
437
+ query_states = query_states.transpose(1, 2)
438
+ key_states = key_states.transpose(1, 2)
439
+ value_states = value_states.transpose(1, 2)
440
+
441
+ dropout_rate = self.attention_dropout if self.training else 0.0
442
+
443
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
444
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
445
+ # cast them back in the correct dtype just to be sure everything works as expected.
446
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
447
+ # in fp32. (ChameleonRMSNorm handles it correctly)
448
+
449
+ input_dtype = query_states.dtype
450
+ if input_dtype == torch.float32:
451
+ if torch.is_autocast_enabled():
452
+ target_dtype = torch.get_autocast_gpu_dtype()
453
+ # Handle the case where the model is quantized
454
+ elif hasattr(self.config, "_pre_quantization_dtype"):
455
+ target_dtype = self.config._pre_quantization_dtype
456
+ else:
457
+ target_dtype = self.q_proj.weight.dtype
458
+
459
+ logger.warning_once(
460
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
461
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
462
+ f" {target_dtype}."
463
+ )
464
+
465
+ query_states = query_states.to(target_dtype)
466
+ key_states = key_states.to(target_dtype)
467
+ value_states = value_states.to(target_dtype)
468
+
469
+ attn_output = _flash_attention_forward(
470
+ query_states,
471
+ key_states,
472
+ value_states,
473
+ attention_mask,
474
+ q_len,
475
+ dropout=dropout_rate,
476
+ sliding_window=getattr(self, "sliding_window", None),
477
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
478
+ is_causal=self.is_causal,
479
+ )
480
+
481
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
482
+ attn_output = self.o_proj(attn_output)
483
+
484
+ if not output_attentions:
485
+ attn_weights = None
486
+
487
+ return attn_output, attn_weights, past_key_value
488
+
489
+
490
+ class ChameleonSdpaAttention(ChameleonAttention):
491
+ """
492
+ Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
493
+ `ChameleonAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, adapting it to the
494
+ SDPA API.
495
+ """
496
+
497
+ # Adapted from ChameleonAttention.forward
498
+ def forward(
499
+ self,
500
+ hidden_states: torch.Tensor,
501
+ attention_mask: Optional[torch.Tensor] = None,
502
+ position_ids: Optional[torch.LongTensor] = None,
503
+ past_key_value: Optional[Cache] = None,
504
+ output_attentions: bool = False,
505
+ use_cache: bool = False,
506
+ cache_position: Optional[torch.LongTensor] = None,
507
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
508
+ if output_attentions:
509
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
510
+ logger.warning_once(
511
+ "ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
512
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
513
+ )
514
+ return super().forward(
515
+ hidden_states=hidden_states,
516
+ attention_mask=attention_mask,
517
+ position_ids=position_ids,
518
+ past_key_value=past_key_value,
519
+ output_attentions=output_attentions,
520
+ use_cache=use_cache,
521
+ cache_position=cache_position,
522
+ )
523
+
524
+ bsz, q_len, _ = hidden_states.size()
525
+
526
+ query_states = self.q_proj(hidden_states)
527
+ key_states = self.k_proj(hidden_states)
528
+ value_states = self.v_proj(hidden_states)
529
+
530
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
531
+ query_states = self.q_norm(query_states)
532
+
533
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
534
+ key_states = self.k_norm(key_states)
535
+
536
+ query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
537
+ key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
538
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
539
+
540
+ cos, sin = self.rotary_emb(value_states, position_ids)
541
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
542
+
543
+ if past_key_value is not None:
544
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
545
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
546
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
547
+
548
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
549
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
550
+
551
+ causal_mask = attention_mask
552
+ if attention_mask is not None and cache_position is not None:
553
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
554
+
555
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
556
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
557
+ if query_states.device.type == "cuda" and causal_mask is not None:
558
+ query_states = query_states.contiguous()
559
+ key_states = key_states.contiguous()
560
+ value_states = value_states.contiguous()
561
+
562
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
563
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
564
+ is_causal = True if causal_mask is None and q_len > 1 else False
565
+
566
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
567
+ query_states,
568
+ key_states,
569
+ value_states,
570
+ attn_mask=causal_mask,
571
+ dropout_p=self.attention_dropout if self.training else 0.0,
572
+ is_causal=is_causal,
573
+ )
574
+
575
+ attn_output = attn_output.transpose(1, 2).contiguous()
576
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
577
+
578
+ attn_output = self.o_proj(attn_output)
579
+
580
+ return attn_output, None, past_key_value
581
+
582
+
583
+ CHAMELEON_ATTENTION_CLASSES = {
584
+ "eager": ChameleonAttention,
585
+ "flash_attention_2": ChameleonFlashAttention2,
586
+ "sdpa": ChameleonSdpaAttention,
587
+ }
588
+
589
+
590
+ # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
591
+ # TODO(joao): add me back asap :)
592
+ class ChameleonDecoderLayer(nn.Module):
593
+ def __init__(self, config: ChameleonConfig, layer_idx: int):
594
+ super().__init__()
595
+ self.hidden_size = config.hidden_size
596
+
597
+ self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
598
+
599
+ self.mlp = ChameleonMLP(config)
600
+ self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
601
+ self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
+
603
+ def forward(
604
+ self,
605
+ hidden_states: torch.Tensor,
606
+ attention_mask: Optional[torch.Tensor] = None,
607
+ position_ids: Optional[torch.LongTensor] = None,
608
+ past_key_value: Optional[Cache] = None,
609
+ output_attentions: Optional[bool] = False,
610
+ use_cache: Optional[bool] = False,
611
+ cache_position: Optional[torch.LongTensor] = None,
612
+ **kwargs,
613
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
614
+ """
615
+ Args:
616
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
617
+ attention_mask (`torch.FloatTensor`, *optional*):
618
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
619
+ query_sequence_length, key_sequence_length)` if default attention is used.
620
+ output_attentions (`bool`, *optional*):
621
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
622
+ returned tensors for more detail.
623
+ use_cache (`bool`, *optional*):
624
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
625
+ (see `past_key_values`).
626
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
627
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
628
+ Indices depicting the position of the input sequence tokens in the sequence
629
+ kwargs (`dict`, *optional*):
630
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
631
+ into the model
632
+ """
633
+ residual = hidden_states
634
+
635
+ hidden_states = self.input_layernorm(hidden_states)
636
+
637
+ # Self Attention
638
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
639
+ hidden_states=hidden_states,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ past_key_value=past_key_value,
643
+ output_attentions=output_attentions,
644
+ use_cache=use_cache,
645
+ cache_position=cache_position,
646
+ **kwargs,
647
+ )
648
+ hidden_states = residual + hidden_states
649
+
650
+ # Fully Connected
651
+ residual = hidden_states
652
+ hidden_states = self.post_attention_layernorm(hidden_states)
653
+ hidden_states = self.mlp(hidden_states)
654
+ hidden_states = residual + hidden_states
655
+
656
+ outputs = (hidden_states,)
657
+
658
+ if output_attentions:
659
+ outputs += (self_attn_weights,)
660
+
661
+ if use_cache:
662
+ outputs += (present_key_value,)
663
+
664
+ return outputs
665
+
666
+
667
+ class ChameleonSwinDecoderLayer(nn.Module):
668
+ def __init__(self, config: ChameleonConfig, layer_idx: int):
669
+ super().__init__()
670
+ self.hidden_size = config.hidden_size
671
+
672
+ self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
673
+
674
+ self.mlp = ChameleonMLP(config)
675
+ self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
676
+ self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states: torch.Tensor,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ position_ids: Optional[torch.LongTensor] = None,
683
+ past_key_value: Optional[Cache] = None,
684
+ output_attentions: Optional[bool] = False,
685
+ use_cache: Optional[bool] = False,
686
+ cache_position: Optional[torch.LongTensor] = None,
687
+ **kwargs,
688
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
689
+ """
690
+ Args:
691
+ hidden_states (`torch.FloatTensor`):
692
+ input to the layer of shape `(batch, seq_len, embed_dim)`
693
+ attention_mask (`torch.FloatTensor`, *optional*):
694
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
695
+ query_sequence_length, key_sequence_length)` if default attention is used.
696
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
697
+ Indices of positions of each input sequence tokens in the position embeddings
698
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
699
+ output_attentions (`bool`, *optional*):
700
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
701
+ returned tensors for more detail.
702
+ use_cache (`bool`, *optional*):
703
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
704
+ (see `past_key_values`).
705
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
706
+ Indices depicting the position of the input sequence tokens in the sequence.
707
+ """
708
+
709
+ residual = hidden_states
710
+
711
+ # Self Attention
712
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
713
+ hidden_states=hidden_states,
714
+ attention_mask=attention_mask,
715
+ position_ids=position_ids,
716
+ past_key_value=past_key_value,
717
+ output_attentions=output_attentions,
718
+ use_cache=use_cache,
719
+ cache_position=cache_position,
720
+ **kwargs,
721
+ )
722
+ hidden_states = self.input_layernorm(hidden_states)
723
+ hidden_states = residual + hidden_states
724
+ # Fully Connected
725
+ residual = hidden_states
726
+ hidden_states = self.mlp(hidden_states)
727
+ hidden_states = self.post_attention_layernorm(hidden_states)
728
+ hidden_states = residual + hidden_states
729
+ outputs = (hidden_states,)
730
+
731
+ if output_attentions:
732
+ outputs += (self_attn_weights,)
733
+
734
+ if use_cache:
735
+ outputs += (present_key_value,)
736
+
737
+ return outputs
738
+
739
+
740
+ class ChameleonVQVAEVectorQuantizer(nn.Module):
741
+ """
742
+ A module for vector quantization using learned embedding vectors.
743
+
744
+ This module implements a quantization process similar to the one described in
745
+ the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
746
+ input vectors into discrete codebook vectors, which are learned during training.
747
+ Current implementation improves over previous ones by avoiding costly matrix multiplications
748
+ and allowing for post-hoc remapping of indices.
749
+ """
750
+
751
+ def __init__(self, config):
752
+ super().__init__()
753
+ self.num_embeddings = config.num_embeddings
754
+ self.embedding_dim = config.embed_dim
755
+ self.beta = getattr(config, "beta", 0.25)
756
+
757
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
758
+
759
+ def forward(self, hidden_state: torch.Tensor):
760
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
761
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
762
+
763
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
764
+ distances = (
765
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
766
+ + torch.sum(self.embedding.weight**2, dim=1)
767
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
768
+ )
769
+
770
+ min_encoding_indices = torch.argmin(distances, dim=1)
771
+ hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)
772
+
773
+ # compute loss for embedding
774
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
775
+ (hidden_state_quant - hidden_state.detach()) ** 2
776
+ )
777
+
778
+ # preserve gradients
779
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
780
+
781
+ # reshape back to match original input shape
782
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
783
+
784
+ return hidden_state_quant, loss, min_encoding_indices
785
+
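
Two details worth unpacking: the pairwise distances use the expansion `||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e`, which avoids materializing every difference vector, and the `hidden_state + (quant - hidden_state).detach()` line is the straight-through estimator that passes gradients from the quantized output back to the encoder. A sketch of the distance computation, checked against `torch.cdist` (toy sizes, illustrative only):

```python
import torch

z = torch.randn(10, 8)   # flattened continuous latents
e = torch.randn(32, 8)   # codebook embeddings

# ||z - e||^2 expanded, without forming the (10, 32, 8) difference tensor.
d = (z**2).sum(1, keepdim=True) + (e**2).sum(1) - 2 * z @ e.t()
assert torch.allclose(d, torch.cdist(z, e) ** 2, atol=1e-4)

indices = d.argmin(dim=1)  # nearest codebook entry per latent, as in forward()
```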
786
+
787
+ class ChameleonVQVAEEncoderConvDownsample(nn.Module):
788
+ def __init__(self, in_channels):
789
+ super().__init__()
790
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
791
+
792
+ def forward(self, hidden_states):
793
+ # no asymmetric padding in torch conv, must do it ourselves
794
+ hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
795
+ hidden_states = self.conv(hidden_states)
796
+ return hidden_states
797
+
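
Padding only the right and bottom edges by one, followed by a stride-2 kernel-3 convolution with no padding, halves the spatial resolution exactly. A sketch with assumed toy shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 32, 32)
conv = torch.nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)

y = conv(F.pad(x, pad=(0, 1, 0, 1)))  # pad order: (left, right, top, bottom)
print(y.shape)                        # torch.Size([1, 8, 16, 16])
```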
798
+
799
+ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
800
+ def __init__(
801
+ self,
802
+ config,
803
+ in_channels,
804
+ out_channels=None,
805
+ conv_shortcut=False,
806
+ ):
807
+ super().__init__()
808
+ self.in_channels = in_channels
809
+ self.out_channels = in_channels if out_channels is None else out_channels
810
+ self.use_conv_shortcut = conv_shortcut
811
+
812
+ self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
813
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
814
+ self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
815
+ self.dropout = torch.nn.Dropout(config.dropout)
816
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
817
+ if self.in_channels != self.out_channels:
818
+ if self.use_conv_shortcut:
819
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
820
+ else:
821
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
822
+
823
+ def forward(self, hidden_states):
824
+ residual = hidden_states
825
+ hidden_states = self.norm1(hidden_states)
826
+ hidden_states *= torch.sigmoid(hidden_states)
827
+ hidden_states = self.conv1(hidden_states)
828
+
829
+ hidden_states = self.norm2(hidden_states)
830
+ hidden_states *= torch.sigmoid(hidden_states)
831
+ hidden_states = self.dropout(hidden_states)
832
+ hidden_states = self.conv2(hidden_states)
833
+
834
+ if self.in_channels != self.out_channels:
835
+ if self.use_conv_shortcut:
836
+ residual = self.conv_shortcut(residual)
837
+ else:
838
+ residual = self.nin_shortcut(residual)
839
+
840
+ return residual + hidden_states
841
+
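
The repeated `hidden_states *= torch.sigmoid(hidden_states)` lines are an in-place spelling of the SiLU/swish activation, `x * sigmoid(x)`:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x))
```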
842
+
843
+ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
844
+ def __init__(self, in_channels):
845
+ super().__init__()
846
+ self.in_channels = in_channels
847
+
848
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
849
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
850
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
851
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
852
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
853
+
854
+ def forward(self, hidden_states):
855
+ residual = hidden_states
856
+ hidden_states = self.norm(hidden_states)
857
+ query_states = self.q(hidden_states)
858
+ key_states = self.k(hidden_states)
859
+ value_states = self.v(hidden_states)
860
+
861
+ # compute attention
862
+ batch_size, channels, height, width = query_states.shape
863
+ query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
864
+ key_states = key_states.reshape(batch_size, channels, height * width)
865
+ attn_weights = torch.bmm(query_states, key_states)
866
+ attn_weights = attn_weights * (int(channels) ** (-0.5))
867
+ attn_weights = F.softmax(attn_weights, dim=2)
868
+
869
+ # attend to values
870
+ value_states = value_states.reshape(batch_size, channels, height * width)
871
+ attn_weights = attn_weights.permute(0, 2, 1)
872
+ attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)
873
+
874
+ attn_output = self.proj_out(attn_output)
875
+ return residual + attn_output
876
+
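
This is plain single-head self-attention over the `H*W` spatial positions, with 1x1 convolutions standing in for the q/k/v and output projections. A shape-level sketch of the bmm-based attention math (toy sizes, illustrative only):

```python
import torch
import torch.nn.functional as F

b, c, h, w = 1, 8, 4, 4
q = torch.randn(b, c, h, w).reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
k = torch.randn(b, c, h, w).reshape(b, c, h * w)                   # (b, c, hw)
v = torch.randn(b, c, h, w).reshape(b, c, h * w)                   # (b, c, hw)

attn = F.softmax(torch.bmm(q, k) * c**-0.5, dim=2)                 # (b, hw, hw), softmax over keys
out = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, h, w)      # back to (b, c, h, w)
print(out.shape)
```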
877
+
878
+ class ChameleonVQVAEEncoder(nn.Module):
879
+ def __init__(self, config):
880
+ super().__init__()
881
+
882
+ self.num_resolutions = len(config.channel_multiplier)
883
+ self.num_res_blocks = config.num_res_blocks
884
+ base_channels = config.base_channels
885
+ resolution = config.resolution
886
+ in_channels = config.in_channels
887
+ double_latent = config.double_latent
888
+ latent_channels = config.latent_channels
889
+ channel_multiplier = config.channel_multiplier
890
+
891
+ self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
892
+
893
+ curr_res = resolution
894
+ in_channel_multiplier = (1,) + tuple(channel_multiplier)
895
+ self.in_channel_multiplier = in_channel_multiplier
896
+ self.down = nn.ModuleList()
897
+ for i_level in range(self.num_resolutions):
898
+ block = nn.ModuleList()
899
+ attn = nn.ModuleList()
900
+ block_in = base_channels * in_channel_multiplier[i_level]
901
+ block_out = base_channels * channel_multiplier[i_level]
902
+ for i_block in range(self.num_res_blocks):
903
+ block.append(
904
+ ChameleonVQVAEEncoderResnetBlock(
905
+ config=config,
906
+ in_channels=block_in,
907
+ out_channels=block_out,
908
+ )
909
+ )
910
+ block_in = block_out
911
+ if (
912
+ config.attn_resolutions is not None
913
+ and curr_res in config.attn_resolutions
914
+ and config.attn_type == "vanilla"
915
+ ):
916
+ attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
917
+
918
+ down = nn.Module()
919
+ down.block = block
920
+ down.attn = attn
921
+ if i_level != self.num_resolutions - 1:
922
+ down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
923
+ curr_res = curr_res // 2
924
+ self.down.append(down)
925
+
926
+ self.mid = nn.Module()
927
+ self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
928
+ config=config,
929
+ in_channels=block_in,
930
+ out_channels=block_in,
931
+ )
932
+ self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
933
+ self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
934
+ config=config,
935
+ in_channels=block_in,
936
+ out_channels=block_in,
937
+ )
938
+
939
+ self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
940
+ self.conv_out = torch.nn.Conv2d(
941
+ block_in,
942
+ 2 * latent_channels if double_latent else latent_channels,
943
+ kernel_size=3,
944
+ stride=1,
945
+ padding=1,
946
+ )
947
+
948
+ def forward(self, pixel_values: torch.LongTensor):
949
+ # downsampling
950
+ hidden_states = [self.conv_in(pixel_values)]
951
+ for i_level in range(self.num_resolutions):
952
+ for i_block in range(self.num_res_blocks):
953
+ hidden_state = self.down[i_level].block[i_block](
954
+ hidden_states[-1],
955
+ )
956
+ if len(self.down[i_level].attn) > 0:
957
+ hidden_state = self.down[i_level].attn[i_block](hidden_state)
958
+ hidden_states.append(hidden_state)
959
+ if i_level != self.num_resolutions - 1:
960
+ hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
961
+
962
+ # middle
963
+ last_hidden_state = hidden_states[-1]
964
+ last_hidden_state = self.mid.block_1(last_hidden_state)
965
+ last_hidden_state = self.mid.attn_1(last_hidden_state)
966
+ last_hidden_state = self.mid.block_2(last_hidden_state)
967
+
968
+ # end
969
+ last_hidden_state = self.norm_out(last_hidden_state)
970
+ last_hidden_state *= torch.sigmoid(last_hidden_state)
971
+ last_hidden_state = self.conv_out(last_hidden_state)
972
+ return last_hidden_state
973
+
974
+
975
+ class ChameleonImageVocabularyMapping:
976
+ """
977
+ A class for mapping discrete image tokens from VQGAN to BPE tokens.
978
+ """
979
+
980
+ def __init__(self, vocab_map):
981
+ self.vocab_map = vocab_map
982
+ self.image_token_id = vocab_map.get("<image>")
983
+
984
+ @cached_property
985
+ def val2name(self):
986
+ return {v: k for k, v in self.vocab_map.items()}
987
+
988
+ @cached_property
989
+ def image_tokens(self):
990
+ return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")])
991
+
992
+ @cached_property
993
+ def bpe2img(self):
994
+ img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
995
+
996
+ def remap(old_name: str) -> str:
997
+ return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
998
+
999
+ return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}
1000
+
1001
+ @cached_property
1002
+ def img2bpe(self):
1003
+ return {v: k for k, v in self.bpe2img.items()}
1004
+
1005
+ @cached_property
1006
+ def bpe2img_search_tensors(self):
1007
+ return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values()))
1008
+
1009
+ @cached_property
1010
+ def img2bpe_mapping_tensor(self):
1011
+ mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
1012
+ for k, v in self.img2bpe.items():
1013
+ mapping[k] = v
1014
+ return mapping
1015
+
1016
+ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
1017
+ device = img_batch.device
1018
+ img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
1019
+ return img_tokens.to(device)
1020
+
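
The `bpe2img` remapping decodes codebook indices out of BPE token names: in an `IMGIMG...` name the characters 'A'..'J' stand for digits 0..9 and the trailing character is dropped. A sketch with a hypothetical vocabulary entry (the token name below is made up for illustration):

```python
img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}


def remap(old_name: str) -> str:
    return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])


print(remap("IMGIMGBDZ"))  # "13": hypothetical BPE name IMGIMGBDZ -> codebook index 13
```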
1021
+
1022
+ CHAMELEON_START_DOCSTRING = r"""
1023
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1024
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1025
+ etc.)
1026
+
1027
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1028
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1029
+ and behavior.
1030
+
1031
+ Parameters:
1032
+ config ([`ChameleonConfig`]):
1033
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1034
+ load the weights associated with the model, only the configuration. Check out the
1035
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1036
+ """
1037
+
1038
+
1039
+ @add_start_docstrings(
1040
+ "The bare chameleon Model outputting raw hidden-states without any specific head on top.",
1041
+ CHAMELEON_START_DOCSTRING,
1042
+ )
1043
+ class ChameleonPreTrainedModel(PreTrainedModel):
1044
+ config_class = ChameleonConfig
1045
+ base_model_prefix = "model"
1046
+ supports_gradient_checkpointing = True
1047
+ _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
1048
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
1049
+ _supports_flash_attn_2 = True
1050
+ _supports_sdpa = True
1051
+ _supports_quantized_cache = True
1052
+ _supports_cache_class = True
1053
+ _supports_static_cache = True
1054
+ _supports_param_buffer_assignment = False
1055
+
1056
+ def _init_weights(self, module):
1057
+ std = self.config.initializer_range
1058
+
1059
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
1060
+ module.weight.data.normal_(mean=0.0, std=std)
1061
+ if module.bias is not None:
1062
+ module.bias.data.zero_()
1063
+ elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
1064
+ module.bias.data.zero_()
1065
+ module.weight.data.fill_(1.0)
1066
+ elif isinstance(module, ChameleonRMSNorm):
1067
+ module.weight.data.fill_(1.0)
1068
+ elif isinstance(module, nn.Embedding):
1069
+ module.weight.data.normal_(mean=0.0, std=std)
1070
+ if module.padding_idx is not None:
1071
+ module.weight.data[module.padding_idx].zero_()
1072
+
1073
+
1074
+ CHAMELEON_VQ_START_DOCSTRING = r"""
1075
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1076
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1077
+ etc.)
1078
+
1079
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1080
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1081
+ and behavior.
1082
+
1083
+ Parameters:
1084
+ config ([`ChameleonVQVAEConfig`]):
1085
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1086
+ load the weights associated with the model, only the configuration. Check out the
1087
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1088
+ """
1089
+
1090
+
1091
+ @add_start_docstrings(
1092
+ """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
1093
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
1094
+ [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
1095
+ """,
1096
+ CHAMELEON_VQ_START_DOCSTRING,
1097
+ )
1098
+ class ChameleonVQVAE(ChameleonPreTrainedModel):
1099
+ config_class = ChameleonVQVAEConfig
1100
+ _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
1101
+
1102
+ def __init__(self, config: ChameleonVQVAEConfig):
1103
+ super().__init__(config)
1104
+
1105
+ self.encoder = ChameleonVQVAEEncoder(config)
1106
+ self.quantize = ChameleonVQVAEVectorQuantizer(config)
1107
+ self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
1108
+ self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
1109
+ self.eval() # Chameleon's VQ model is frozen
1110
+
1111
+ def encode(self, pixel_values: torch.LongTensor):
1112
+ hidden_states = self.encoder(pixel_values)
1113
+ hidden_states = self.quant_conv(hidden_states)
1114
+ quant, emb_loss, indices = self.quantize(hidden_states)
1115
+ return quant, emb_loss, indices
1116
+
1117
+
1118
+ CHAMELEON_INPUTS_DOCSTRING = r"""
1119
+ Args:
1120
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1121
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1122
+ it.
1123
+
1124
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1125
+ [`PreTrainedTokenizer.__call__`] for details.
1126
+
1127
+ [What are input IDs?](../glossary#input-ids)
1128
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1129
+ The tensors corresponding to the input images. Pixel values can be obtained using
1130
+ [`AutoImageProcessor`]. See [`ChameleonImageProcessor.__call__`] for details.
1131
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1132
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1133
+
1134
+ - 1 for tokens that are **not masked**,
1135
+ - 0 for tokens that are **masked**.
1136
+
1137
+ [What are attention masks?](../glossary#attention-mask)
1138
+
1139
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1140
+ [`PreTrainedTokenizer.__call__`] for details.
1141
+
1142
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1143
+ `past_key_values`).
1144
+
1145
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1146
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1147
+ information on the default strategy.
1148
+
1149
+ - 1 indicates the head is **not masked**,
1150
+ - 0 indicates the head is **masked**.
1151
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1153
+ config.n_positions - 1]`.
1154
+
1155
+ [What are position IDs?](../glossary#position-ids)
1156
+ past_key_values (`Cache`, *optional*):
1157
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1158
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1159
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1160
+
1161
+ Should always be a [`~cache_utils.Cache`] instance and the model will output the same cache instance.
1162
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1163
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1164
+ of shape `(batch_size, sequence_length)`.
1165
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1166
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1167
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1168
+ model's internal embedding lookup matrix.
1169
+ use_cache (`bool`, *optional*):
1170
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1171
+ `past_key_values`).
1172
+ output_attentions (`bool`, *optional*):
1173
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1174
+ tensors for more detail.
1175
+ output_hidden_states (`bool`, *optional*):
1176
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1177
+ more detail.
1178
+ return_dict (`bool`, *optional*):
1179
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1180
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1181
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1182
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1183
+ the complete sequence length.
1184
+ """
1185
+
1186
+
1187
+ @add_start_docstrings(
1188
+ "The bare chameleon Model outputting raw hidden-states without any specific head on top.",
1189
+ CHAMELEON_START_DOCSTRING,
1190
+ )
1191
+ class ChameleonModel(ChameleonPreTrainedModel):
1192
+ """
1193
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ChameleonDecoderLayer`]
1194
+
1195
+ Args:
1196
+ config: ChameleonConfig
1197
+ """
1198
+
1199
+ def __init__(self, config: ChameleonConfig):
1200
+ super().__init__(config)
1201
+ self.padding_idx = config.pad_token_id
1202
+ self.vocab_size = config.vocab_size
1203
+
1204
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1205
+ self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
1206
+ decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
1207
+ self.layers = nn.ModuleList(
1208
+ [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1209
+ )
1210
+ self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1211
+ self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
1212
+ self.gradient_checkpointing = False
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ self.post_init()
1216
+
1217
+ def get_input_embeddings(self):
1218
+ return self.embed_tokens
1219
+
1220
+ def set_input_embeddings(self, value):
1221
+ self.embed_tokens = value
1222
+
1223
+ def get_image_tokens(self, pixel_values: torch.FloatTensor):
1224
+ """
1225
+ Tokenizes images into discrete tokens with VQGAN module. Converts
1226
+ obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
1227
+ special tokens.
1228
+
1229
+ Args:
1230
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1231
+ The tensors corresponding to the input images.
1232
+ """
1233
+ batch_size = pixel_values.shape[0]
1234
+ _, _, image_toks = self.vqmodel.encode(pixel_values)
1235
+ bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
1236
+ bpe_toks = bpe_toks.view(batch_size, -1)
1237
+ return bpe_toks
1238
+
1239
+ @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING)
1240
+ @add_code_sample_docstrings(
1241
+ checkpoint=_CHECKPOINT_FOR_DOC,
1242
+ output_type=BaseModelOutputWithPast,
1243
+ config_class=_CONFIG_FOR_DOC,
1244
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1245
+ )
1246
+ def forward(
1247
+ self,
1248
+ input_ids: Optional[torch.LongTensor] = None,
1249
+ pixel_values: Optional[torch.FloatTensor] = None,
1250
+ attention_mask: Optional[torch.Tensor] = None,
1251
+ position_ids: Optional[torch.LongTensor] = None,
1252
+ past_key_values: Optional[Cache] = None,
1253
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1254
+ use_cache: Optional[bool] = None,
1255
+ output_attentions: Optional[bool] = None,
1256
+ output_hidden_states: Optional[bool] = None,
1257
+ return_dict: Optional[bool] = None,
1258
+ cache_position: Optional[torch.LongTensor] = None,
1259
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1260
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1261
+ output_hidden_states = (
1262
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1263
+ )
1264
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1265
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1266
+
1267
+ if self.gradient_checkpointing and self.training and use_cache:
1268
+ logger.warning_once(
1269
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1270
+ )
1271
+ use_cache = False
1272
+
1273
+ if (input_ids is None) ^ (inputs_embeds is not None):
1274
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1275
+
1276
+ if pixel_values is not None and inputs_embeds is not None:
1277
+ raise ValueError(
1278
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
1279
+ )
1280
+
1281
+ if pixel_values is not None:
1282
+ image_tokens = self.get_image_tokens(pixel_values)
1283
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
1284
+ if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
1285
+ n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
1286
+ n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
1287
+ raise ValueError(
1288
+ f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
1289
+ )
1290
+ image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
1291
+ input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
1292
+
1293
+ if inputs_embeds is None:
1294
+ inputs_embeds = self.embed_tokens(input_ids)
1295
+
1296
+ # torch.jit.trace() doesn't support cache objects in the output
1297
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
1298
+ past_key_values = DynamicCache()
1299
+
1300
+ if cache_position is None:
1301
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1302
+ cache_position = torch.arange(
1303
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1304
+ )
1305
+
1306
+ if position_ids is None:
1307
+ position_ids = cache_position.unsqueeze(0)
1308
+
1309
+ causal_mask = self._update_causal_mask(
1310
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1311
+ )
1312
+
1313
+ # embed positions
1314
+ hidden_states = inputs_embeds
1315
+
1316
+ # decoder layers
1317
+ all_hidden_states = () if output_hidden_states else None
1318
+ all_self_attns = () if output_attentions else None
1319
+ next_decoder_cache = None
1320
+
1321
+ for decoder_layer in self.layers:
1322
+ if output_hidden_states:
1323
+ all_hidden_states += (hidden_states,)
1324
+
1325
+ if self.gradient_checkpointing and self.training:
1326
+ layer_outputs = self._gradient_checkpointing_func(
1327
+ decoder_layer.__call__,
1328
+ hidden_states,
1329
+ causal_mask,
1330
+ position_ids,
1331
+ past_key_values,
1332
+ output_attentions,
1333
+ use_cache,
1334
+ cache_position,
1335
+ )
1336
+ else:
1337
+ layer_outputs = decoder_layer(
1338
+ hidden_states,
1339
+ attention_mask=causal_mask,
1340
+ position_ids=position_ids,
1341
+ past_key_value=past_key_values,
1342
+ output_attentions=output_attentions,
1343
+ use_cache=use_cache,
1344
+ cache_position=cache_position,
1345
+ )
1346
+
1347
+ hidden_states = layer_outputs[0]
1348
+
1349
+ if use_cache:
1350
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1351
+
1352
+ if output_attentions:
1353
+ all_self_attns += (layer_outputs[1],)
1354
+
1355
+ hidden_states = self.norm(hidden_states)
1356
+
1357
+ # add hidden states from the last decoder layer
1358
+ if output_hidden_states:
1359
+ all_hidden_states += (hidden_states,)
1360
+
1361
+ next_cache = None
1362
+ if use_cache:
1363
+ next_cache = next_decoder_cache
1364
+
1365
+ if not return_dict:
1366
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1367
+
1368
+ return BaseModelOutputWithPast(
1369
+ last_hidden_state=hidden_states,
1370
+ past_key_values=next_cache,
1371
+ hidden_states=all_hidden_states,
1372
+ attentions=all_self_attns,
1373
+ )
1374
+
1375
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1376
+ def _update_causal_mask(
1377
+ self,
1378
+ attention_mask: Union[torch.Tensor, "BlockMask"],
1379
+ input_tensor: torch.Tensor,
1380
+ cache_position: torch.Tensor,
1381
+ past_key_values: Cache,
1382
+ output_attentions: bool = False,
1383
+ ):
1384
+ if self.config._attn_implementation == "flash_attention_2":
1385
+ if attention_mask is not None and (attention_mask == 0.0).any():
1386
+ return attention_mask
1387
+ return None
1388
+ if self.config._attn_implementation == "flex_attention":
1389
+ if isinstance(attention_mask, torch.Tensor):
1390
+ attention_mask = make_flex_block_causal_mask(attention_mask)
1391
+ return attention_mask
1392
+
1393
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1394
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1395
+ # to infer the attention mask.
1396
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1397
+ using_static_cache = isinstance(past_key_values, StaticCache)
1398
+
1399
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1400
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1401
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1402
+ attention_mask,
1403
+ inputs_embeds=input_tensor,
1404
+ past_key_values_length=past_seen_tokens,
1405
+ is_training=self.training,
1406
+ ):
1407
+ return None
1408
+
1409
+ dtype, device = input_tensor.dtype, input_tensor.device
1410
+ sequence_length = input_tensor.shape[1]
1411
+ if using_static_cache:
1412
+ target_length = past_key_values.get_max_cache_shape()
1413
+ else:
1414
+ target_length = (
1415
+ attention_mask.shape[-1]
1416
+ if isinstance(attention_mask, torch.Tensor)
1417
+ else past_seen_tokens + sequence_length + 1
1418
+ )
1419
+
1420
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1421
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1422
+ attention_mask,
1423
+ sequence_length=sequence_length,
1424
+ target_length=target_length,
1425
+ dtype=dtype,
1426
+ device=device,
1427
+ cache_position=cache_position,
1428
+ batch_size=input_tensor.shape[0],
1429
+ )
1430
+
1431
+ if (
1432
+ self.config._attn_implementation == "sdpa"
1433
+ and attention_mask is not None
1434
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1435
+ and not output_attentions
1436
+ ):
1437
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1438
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1439
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1440
+ min_dtype = torch.finfo(dtype).min
1441
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1442
+
1443
+ return causal_mask
1444
+
1445
+ @staticmethod
1446
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
1447
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1448
+ attention_mask: torch.Tensor,
1449
+ sequence_length: int,
1450
+ target_length: int,
1451
+ dtype: torch.dtype,
1452
+ device: torch.device,
1453
+ cache_position: torch.Tensor,
1454
+ batch_size: int,
1455
+ **kwargs,
1456
+ ):
1457
+ """
1458
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1459
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1460
+
1461
+ Args:
1462
+ attention_mask (`torch.Tensor`):
1463
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1464
+ `(batch_size, 1, query_length, key_value_length)`.
1465
+ sequence_length (`int`):
1466
+ The sequence length being processed.
1467
+ target_length (`int`):
1468
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1469
+ to account for the 0 padding, the part of the cache that is not filled yet.
1470
+ dtype (`torch.dtype`):
1471
+ The dtype to use for the 4D attention mask.
1472
+ device (`torch.device`):
1473
+ The device to place the 4D attention mask on.
1474
+ cache_position (`torch.Tensor`):
1475
+ Indices depicting the position of the input sequence tokens in the sequence.
1476
+ batch_size (`torch.Tensor`):
1477
+ Batch size.
1478
+ """
1479
+ if attention_mask is not None and attention_mask.dim() == 4:
1480
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1481
+ causal_mask = attention_mask
1482
+ else:
1483
+ min_dtype = torch.finfo(dtype).min
1484
+ causal_mask = torch.full(
1485
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1486
+ )
1487
+ if sequence_length != 1:
1488
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1489
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1490
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1491
+ if attention_mask is not None:
1492
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1493
+ mask_length = attention_mask.shape[-1]
1494
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1495
+ causal_mask.device
1496
+ )
1497
+ padding_mask = padding_mask == 0
1498
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1499
+ padding_mask, min_dtype
1500
+ )
1501
+
1502
+ return causal_mask
1503
+
1504
+
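
To see what `_prepare_4d_causal_attention_mask_with_cache_position` produces, here is a minimal self-contained sketch that repeats the same `triu`/`masked_fill` steps on a toy left-padded batch (no transformers imports; the shapes and padding pattern are made up):

```python
import torch

batch_size, sequence_length, target_length = 1, 4, 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(sequence_length)  # no past tokens in this toy case

# Start fully masked, then zero out the on/below-diagonal (attendable) positions.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# Fold in a 2D padding mask; position 0 is left padding here.
attention_mask = torch.tensor([[0, 1, 1, 1]])
causal_mask = causal_mask.clone()
padding_mask = (causal_mask[:, :, :, :4] + attention_mask[:, None, None, :]) == 0
causal_mask[:, :, :, :4] = causal_mask[:, :, :, :4].masked_fill(padding_mask, min_dtype)
print(causal_mask[0, 0])  # 0.0 where attention is allowed, a large negative value elsewhere
```
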
1505
+ @add_start_docstrings(
1506
+ "Chameleon Model with a head on top used for outputting logits for next token prediction.",
1507
+ CHAMELEON_START_DOCSTRING,
1508
+ )
1509
+ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
1510
+ _tied_weights_keys = ["lm_head.weight"]
1511
+
1512
+ def __init__(self, config):
1513
+ super().__init__(config)
1514
+ self.model = ChameleonModel(config)
1515
+ self.vocab_size = config.vocab_size
1516
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1517
+
1518
+ # Initialize weights and apply final processing
1519
+ self.post_init()
1520
+
1521
+ def get_input_embeddings(self):
1522
+ return self.model.embed_tokens
1523
+
1524
+ def set_input_embeddings(self, value):
1525
+ self.model.embed_tokens = value
1526
+
1527
+ def get_output_embeddings(self):
1528
+ return self.lm_head
1529
+
1530
+ def set_output_embeddings(self, new_embeddings):
1531
+ self.lm_head = new_embeddings
1532
+
1533
+ def set_decoder(self, decoder):
1534
+ self.model = decoder
1535
+
1536
+ def get_decoder(self):
1537
+ return self.model
1538
+
1539
+ @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING)
1540
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1541
+ def forward(
1542
+ self,
1543
+ input_ids: Optional[torch.LongTensor] = None,
1544
+ pixel_values: Optional[torch.FloatTensor] = None,
1545
+ attention_mask: Optional[torch.Tensor] = None,
1546
+ position_ids: Optional[torch.LongTensor] = None,
1547
+ past_key_values: Optional[Cache] = None,
1548
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1549
+ labels: Optional[torch.LongTensor] = None,
1550
+ use_cache: Optional[bool] = None,
1551
+ output_attentions: Optional[bool] = None,
1552
+ output_hidden_states: Optional[bool] = None,
1553
+ return_dict: Optional[bool] = None,
1554
+ cache_position: Optional[torch.LongTensor] = None,
1555
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1556
+ r"""
1557
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1558
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1559
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1560
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1561
+
1562
+ Returns:
1563
+
1564
+ Example:
1565
+
1566
+ ```python
1567
+ >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
1568
+ >>> import torch
1569
+ >>> import requests
1570
+ >>> from PIL import Image
1571
+
1572
+ >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16)
1573
+ >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
1574
+
1575
+ >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
1576
+ >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
1577
+ >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)
1578
+
1579
+ >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
1580
+
1581
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
1582
+ >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1583
+ ```"""
1584
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1585
+ output_hidden_states = (
1586
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1587
+ )
1588
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1589
+
1590
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1591
+ outputs = self.model(
1592
+ input_ids=input_ids,
1593
+ pixel_values=pixel_values,
1594
+ attention_mask=attention_mask,
1595
+ position_ids=position_ids,
1596
+ past_key_values=past_key_values,
1597
+ inputs_embeds=inputs_embeds,
1598
+ use_cache=use_cache,
1599
+ output_attentions=output_attentions,
1600
+ output_hidden_states=output_hidden_states,
1601
+ return_dict=return_dict,
1602
+ cache_position=cache_position,
1603
+ )
1604
+
1605
+ hidden_states = outputs[0]
1606
+ logits = self.lm_head(hidden_states)
1607
+
1608
+ # Disallow image tokens, which do not include the special begin-image and end-image tokens
1609
+ image_tokens = self.model.vocabulary_mapping.image_tokens
1610
+ logits[:, :, image_tokens] = torch.finfo(logits.dtype).min
1611
+
1612
+ loss = None
1613
+ if labels is not None:
1614
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1615
+ logits = logits.float()
1616
+ # Shift so that tokens < n predict n
1617
+ shift_logits = logits[..., :-1, :].contiguous()
1618
+ shift_labels = labels[..., 1:].contiguous()
1619
+ # Flatten the tokens
1620
+ loss_fct = CrossEntropyLoss()
1621
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1622
+ shift_labels = shift_labels.view(-1)
1623
+ # Enable model parallelism
1624
+ shift_labels = shift_labels.to(shift_logits.device)
1625
+ loss = loss_fct(shift_logits, shift_labels)
1626
+
1627
+ if not return_dict:
1628
+ output = (logits,) + outputs[1:]
1629
+ return (loss,) + output if loss is not None else output
1630
+
1631
+ return CausalLMOutputWithPast(
1632
+ loss=loss,
1633
+ logits=logits,
1634
+ past_key_values=outputs.past_key_values,
1635
+ hidden_states=outputs.hidden_states,
1636
+ attentions=outputs.attentions,
1637
+ )
1638
+
1639
+ def prepare_inputs_for_generation(
1640
+ self,
1641
+ input_ids,
1642
+ pixel_values=None,
1643
+ past_key_values=None,
1644
+ attention_mask=None,
1645
+ inputs_embeds=None,
1646
+ cache_position=None,
1647
+ position_ids=None,
1648
+ use_cache=True,
1649
+ **kwargs,
1650
+ ):
1651
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1652
+
1653
+ model_inputs = super().prepare_inputs_for_generation(
1654
+ input_ids,
1655
+ pixel_values=pixel_values,
1656
+ past_key_values=past_key_values,
1657
+ attention_mask=attention_mask,
1658
+ inputs_embeds=inputs_embeds,
1659
+ cache_position=cache_position,
1660
+ position_ids=position_ids,
1661
+ use_cache=use_cache,
1662
+ **kwargs,
1663
+ )
1664
+
1665
+ if cache_position[0] != 0:
1666
+ # If we're in the cached decoding stage, pixel values should be `None` because the input ids no longer contain the special image tokens
1667
+ # Otherwise we need pixel values to be passed to model
1668
+ model_inputs["pixel_values"] = None
1669
+
1670
+ return model_inputs
1671
+
1672
+
1673
+ __all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]
docs/transformers/src/transformers/models/chameleon/processing_chameleon.py ADDED
@@ -0,0 +1,177 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Chameleon.
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ from ...feature_extraction_utils import BatchFeature
22
+ from ...image_utils import ImageInput
23
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
24
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
25
+
26
+
27
+ class ChameleonTextKwargs(TextKwargs, total=False):
28
+ return_for_text_completion: bool
29
+
30
+
31
+ class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
32
+ text_kwargs: ChameleonTextKwargs
33
+ _defaults = {
34
+ "text_kwargs": {
35
+ "padding": False,
36
+ "return_for_text_completion": False,
37
+ },
38
+ "common_kwargs": {
39
+ "return_tensors": "pt",
40
+ },
41
+ }
42
+
43
+
44
+ class ChameleonProcessor(ProcessorMixin):
45
+ r"""
46
+ Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
47
+ processor.
48
+
49
+ [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
50
+ See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
51
+
52
+ Args:
53
+ image_processor ([`ChameleonImageProcessor`]):
54
+ The image processor is a required input.
55
+ tokenizer ([`LlamaTokenizerFast`]):
56
+ The tokenizer is a required input.
57
+ image_seq_length (`int`, *optional*, defaults to 1024):
58
+ Sequence length of one image embedding.
59
+ image_token (`str`, *optional*, defaults to `"<image>"`):
60
+ The special token used to indicate image in the text.
61
+ """
62
+
63
+ attributes = ["image_processor", "tokenizer"]
64
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
65
+ valid_kwargs = ["image_seq_length", "image_token"]
66
+ image_processor_class = "ChameleonImageProcessor"
67
+
68
+ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
69
+ self.image_seq_length = image_seq_length
70
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
71
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
72
+ self.image_start_token = (
73
+ tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
74
+ ) # fixed tokens for start and end, so we can hardcode them
75
+ self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
76
+
77
+ super().__init__(image_processor, tokenizer)
78
+
79
+ def __call__(
80
+ self,
81
+ images: Optional[ImageInput] = None,
82
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
83
+ audio=None,
84
+ videos=None,
85
+ **kwargs: Unpack[ChameleonProcessorKwargs],
86
+ ) -> BatchFeature:
87
+ """
88
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
89
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
90
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
91
+ ChameleonImageProcessor's [`~ChameleonImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
92
+ of the above two methods for more information.
93
+
94
+ Args:
95
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
96
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
97
+ tensor. Both channels-first and channels-last formats are supported.
98
+ text (`str`, `List[str]`, `List[List[str]]`):
99
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
100
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
101
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
102
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
103
+ If set, will return tensors of a particular framework. Acceptable values are:
104
+
105
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
106
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
107
+ - `'np'`: Return NumPy `np.ndarray` objects.
108
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
109
+
110
+ Returns:
111
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
112
+
113
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
114
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
115
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
116
+ `None`).
117
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
118
+ """
119
+ # check if images and text inputs are reversed for BC
120
+ images, text = _validate_images_text_input_order(images, text)
121
+ if text is None and images is None:
122
+ raise ValueError("You must provide either text or images")
123
+ if isinstance(text, str):
124
+ text = [text]
125
+ elif text is not None and not (isinstance(text, list) and isinstance(text[0], str)):
126
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
127
+
128
+ output_kwargs = self._merge_kwargs(
129
+ ChameleonProcessorKwargs,
130
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
131
+ **kwargs,
132
+ )
133
+ return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
134
+
135
+ # Replace the image token with the expanded image token sequence
136
+ prompt_strings = []
137
+ one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
138
+ for sample in text:
139
+ sample = sample.replace(self.image_token, one_img_tokens)
140
+ if not return_for_text_completion:
141
+ sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
142
+ prompt_strings.append(sample)
143
+
144
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
145
+ data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
146
+ self._check_special_mm_tokens(prompt_strings, data, modalities=["image"])
147
+
148
+ if images is not None:
149
+ data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
150
+
151
+ return BatchFeature(data=data, tensor_type=return_tensors)
152
+
153
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
154
+ def batch_decode(self, *args, **kwargs):
155
+ """
156
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
157
+ refer to the docstring of this method for more information.
158
+ """
159
+ return self.tokenizer.batch_decode(*args, **kwargs)
160
+
161
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
162
+ def decode(self, *args, **kwargs):
163
+ """
164
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
165
+ the docstring of this method for more information.
166
+ """
167
+ return self.tokenizer.decode(*args, **kwargs)
168
+
169
+ @property
170
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
171
+ def model_input_names(self):
172
+ tokenizer_input_names = self.tokenizer.model_input_names
173
+ image_processor_input_names = self.image_processor.model_input_names
174
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
175
+
176
+
177
+ __all__ = ["ChameleonProcessor"]
docs/transformers/src/transformers/models/chinese_clip/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_chinese_clip import *
22
+ from .feature_extraction_chinese_clip import *
23
+ from .image_processing_chinese_clip import *
24
+ from .image_processing_chinese_clip_fast import *
25
+ from .modeling_chinese_clip import *
26
+ from .processing_chinese_clip import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
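
`_LazyModule` defers the expensive submodule imports until an attribute is first accessed. A rough stand-in for the idea, using a module-level `__getattr__` (PEP 562) instead of the actual `_LazyModule` internals, which differ:

```python
# hypothetical my_package/__init__.py
import importlib

_SUBMODULES = {"ChineseCLIPConfig": "configuration_chinese_clip"}

def __getattr__(name):
    # Only triggered for attributes not found the normal way.
    if name in _SUBMODULES:
        module = importlib.import_module(f".{_SUBMODULES[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```
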
docs/transformers/src/transformers/models/chinese_clip/configuration_chinese_clip.py ADDED
@@ -0,0 +1,434 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Chinese-CLIP model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import TYPE_CHECKING, Any, Mapping, Optional
19
+
20
+
21
+ if TYPE_CHECKING:
22
+ from ...processing_utils import ProcessorMixin
23
+ from ...utils import TensorType
24
+
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...onnx import OnnxConfig
27
+ from ...utils import logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class ChineseCLIPTextConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate a
36
+ Chinese CLIP model according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the Chinese CLIP
38
+ [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16)
39
+ architecture.
40
+
41
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
42
+ documentation from [`PretrainedConfig`] for more information.
43
+
44
+
45
+ Args:
46
+ vocab_size (`int`, *optional*, defaults to 30522):
47
+ Vocabulary size of the CHINESE_CLIP model. Defines the number of different tokens that can be represented
48
+ by the `inputs_ids` passed when calling [`ChineseCLIPModel`].
49
+ hidden_size (`int`, *optional*, defaults to 768):
50
+ Dimensionality of the encoder layers and the pooler layer.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (`int`, *optional*, defaults to 3072):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention probabilities.
64
+ max_position_embeddings (`int`, *optional*, defaults to 512):
65
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
66
+ just in case (e.g., 512 or 1024 or 2048).
67
+ type_vocab_size (`int`, *optional*, defaults to 2):
68
+ The vocabulary size of the `token_type_ids` passed when calling [`ChineseCLIPModel`].
69
+ initializer_range (`float`, *optional*, defaults to 0.02):
70
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
71
+ initializer_factor (`float`, *optional*, defaults to 1.0):
72
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
73
+ testing).
74
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
75
+ The epsilon used by the layer normalization layers.
76
+ pad_token_id (`int`, *optional*, defaults to 0):
77
+ Padding token id.
78
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
79
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
80
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
81
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
82
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
83
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
84
+ use_cache (`bool`, *optional*, defaults to `True`):
85
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
86
+ relevant if `config.is_decoder=True`.
87
+
88
+ Example:
89
+
90
+ ```python
91
+ >>> from transformers import ChineseCLIPTextConfig, ChineseCLIPTextModel
92
+
93
+ >>> # Initializing a ChineseCLIPTextConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration
94
+ >>> configuration = ChineseCLIPTextConfig()
95
+
96
+ >>> # Initializing a ChineseCLIPTextModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration
97
+ >>> model = ChineseCLIPTextModel(configuration)
98
+
99
+ >>> # Accessing the model configuration
100
+ >>> configuration = model.config
101
+ ```"""
102
+
103
+ model_type = "chinese_clip_text_model"
104
+ base_config_key = "text_config"
105
+
106
+ def __init__(
107
+ self,
108
+ vocab_size=30522,
109
+ hidden_size=768,
110
+ num_hidden_layers=12,
111
+ num_attention_heads=12,
112
+ intermediate_size=3072,
113
+ hidden_act="gelu",
114
+ hidden_dropout_prob=0.1,
115
+ attention_probs_dropout_prob=0.1,
116
+ max_position_embeddings=512,
117
+ type_vocab_size=2,
118
+ initializer_range=0.02,
119
+ initializer_factor=1.0,
120
+ layer_norm_eps=1e-12,
121
+ pad_token_id=0,
122
+ position_embedding_type="absolute",
123
+ use_cache=True,
124
+ **kwargs,
125
+ ):
126
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
127
+
128
+ self.vocab_size = vocab_size
129
+ self.hidden_size = hidden_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+ self.hidden_act = hidden_act
133
+ self.intermediate_size = intermediate_size
134
+ self.hidden_dropout_prob = hidden_dropout_prob
135
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
136
+ self.max_position_embeddings = max_position_embeddings
137
+ self.type_vocab_size = type_vocab_size
138
+ self.initializer_range = initializer_range
139
+ self.initializer_factor = initializer_factor
140
+ self.layer_norm_eps = layer_norm_eps
141
+ self.position_embedding_type = position_embedding_type
142
+ self.use_cache = use_cache
143
+
144
+
145
+ class ChineseCLIPVisionConfig(PretrainedConfig):
146
+ r"""
147
+ This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
148
+ ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
149
+ configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
150
+ [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
151
+
152
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
153
+ documentation from [`PretrainedConfig`] for more information.
154
+
155
+
156
+ Args:
157
+ hidden_size (`int`, *optional*, defaults to 768):
158
+ Dimensionality of the encoder layers and the pooler layer.
159
+ intermediate_size (`int`, *optional*, defaults to 3072):
160
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
161
+ projection_dim (`int`, *optional*, defaults to 512):
162
+ Dimensionality of text and vision projection layers.
163
+ num_hidden_layers (`int`, *optional*, defaults to 12):
164
+ Number of hidden layers in the Transformer encoder.
165
+ num_attention_heads (`int`, *optional*, defaults to 12):
166
+ Number of attention heads for each attention layer in the Transformer encoder.
167
+ num_channels (`int`, *optional*, defaults to 3):
168
+ The number of input channels.
169
+ image_size (`int`, *optional*, defaults to 224):
170
+ The size (resolution) of each image.
171
+ patch_size (`int`, *optional*, defaults to 32):
172
+ The size (resolution) of each patch.
173
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
174
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
175
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
176
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
177
+ The epsilon used by the layer normalization layers.
178
+ attention_dropout (`float`, *optional*, defaults to 0.0):
179
+ The dropout ratio for the attention probabilities.
180
+ initializer_range (`float`, *optional*, defaults to 0.02):
181
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
182
+ initializer_factor (`float`, *optional*, defaults to 1.0):
183
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
184
+ testing).
185
+ Example:
186
+ ```python
187
+ >>> from transformers import ChineseCLIPVisionConfig, ChineseCLIPVisionModel
188
+
189
+ >>> # Initializing a ChineseCLIPVisionConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration
190
+ >>> configuration = ChineseCLIPVisionConfig()
191
+
192
+ >>> # Initializing a ChineseCLIPVisionModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration
193
+ >>> model = ChineseCLIPVisionModel(configuration)
194
+
195
+ >>> # Accessing the model configuration
196
+ >>> configuration = model.config
197
+ ```"""
198
+
199
+ model_type = "chinese_clip_vision_model"
200
+ base_config_key = "vision_config"
201
+
202
+ def __init__(
203
+ self,
204
+ hidden_size=768,
205
+ intermediate_size=3072,
206
+ projection_dim=512,
207
+ num_hidden_layers=12,
208
+ num_attention_heads=12,
209
+ num_channels=3,
210
+ image_size=224,
211
+ patch_size=32,
212
+ hidden_act="quick_gelu",
213
+ layer_norm_eps=1e-5,
214
+ attention_dropout=0.0,
215
+ initializer_range=0.02,
216
+ initializer_factor=1.0,
217
+ **kwargs,
218
+ ):
219
+ super().__init__(**kwargs)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.intermediate_size = intermediate_size
223
+ self.projection_dim = projection_dim
224
+ self.num_hidden_layers = num_hidden_layers
225
+ self.num_attention_heads = num_attention_heads
226
+ self.num_channels = num_channels
227
+ self.patch_size = patch_size
228
+ self.image_size = image_size
229
+ self.initializer_range = initializer_range
230
+ self.initializer_factor = initializer_factor
231
+ self.attention_dropout = attention_dropout
232
+ self.layer_norm_eps = layer_norm_eps
233
+ self.hidden_act = hidden_act
234
+
235
+
236
+ class ChineseCLIPConfig(PretrainedConfig):
237
+ r"""
238
+ [`ChineseCLIPConfig`] is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used
239
+ to instantiate Chinese-CLIP model according to the specified arguments, defining the text model and vision model
240
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
241
+ Chinese-CLIP [OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16)
242
+ architecture.
243
+
244
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
245
+ documentation from [`PretrainedConfig`] for more information.
246
+
247
+ Args:
248
+ text_config (`dict`, *optional*):
249
+ Dictionary of configuration options used to initialize [`ChineseCLIPTextConfig`].
250
+ vision_config (`dict`, *optional*):
251
+ Dictionary of configuration options used to initialize [`ChineseCLIPVisionConfig`].
252
+ projection_dim (`int`, *optional*, defaults to 512):
253
+ Dimensionality of text and vision projection layers.
254
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
255
+ The initial value of the *logit_scale* parameter. Default is used as per the original ChineseCLIP
256
+ implementation.
257
+ kwargs (*optional*):
258
+ Dictionary of keyword arguments.
259
+
260
+ Example:
261
+
262
+ ```python
263
+ >>> from transformers import ChineseCLIPConfig, ChineseCLIPModel
264
+
265
+ >>> # Initializing a ChineseCLIPConfig with OFA-Sys/chinese-clip-vit-base-patch16 style configuration
266
+ >>> configuration = ChineseCLIPConfig()
267
+
268
+ >>> # Initializing a ChineseCLIPModel (with random weights) from the OFA-Sys/chinese-clip-vit-base-patch16 style configuration
269
+ >>> model = ChineseCLIPModel(configuration)
270
+
271
+ >>> # Accessing the model configuration
272
+ >>> configuration = model.config
273
+
274
+ >>> # We can also initialize a ChineseCLIPConfig from a ChineseCLIPTextConfig and a ChineseCLIPVisionConfig
275
+
276
+ >>> # Initializing a ChineseCLIPTextConfig and ChineseCLIPVisionConfig configuration
277
+ >>> config_text = ChineseCLIPTextConfig()
278
+ >>> config_vision = ChineseCLIPVisionConfig()
279
+
280
+ >>> config = ChineseCLIPConfig.from_text_vision_configs(config_text, config_vision)
281
+ ```"""
282
+
283
+ model_type = "chinese_clip"
284
+ sub_configs = {"text_config": ChineseCLIPTextConfig, "vision_config": ChineseCLIPVisionConfig}
285
+
286
+ def __init__(
287
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
288
+ ):
289
+ # If `_config_dict` exist, we use them for the backward compatibility.
290
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
291
+ # of confusion!).
292
+ text_config_dict = kwargs.pop("text_config_dict", None)
293
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
294
+
295
+ super().__init__(**kwargs)
296
+
297
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
298
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
299
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
300
+ if text_config_dict is not None:
301
+ if text_config is None:
302
+ text_config = {}
303
+
304
+ # This is the complete result when using `text_config_dict`.
305
+ _text_config_dict = ChineseCLIPTextConfig(**text_config_dict).to_dict()
306
+
307
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
308
+ for key, value in _text_config_dict.items():
309
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
310
+ # If specified in `text_config_dict`
311
+ if key in text_config_dict:
312
+ message = (
313
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
314
+ f'The value `text_config_dict["{key}"]` will be used instead.'
315
+ )
316
+ # If inferred from default argument values (just to be super careful)
317
+ else:
318
+ message = (
319
+ f"`text_config_dict` is provided which will be used to initialize `ChineseCLIPTextConfig`. "
320
+ f'The value `text_config["{key}"]` will be overridden.'
321
+ )
322
+ logger.info(message)
323
+
324
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
325
+ text_config.update(_text_config_dict)
326
+
327
+ if vision_config_dict is not None:
328
+ if vision_config is None:
329
+ vision_config = {}
330
+
331
+ # This is the complete result when using `vision_config_dict`.
332
+ _vision_config_dict = ChineseCLIPVisionConfig(**vision_config_dict).to_dict()
333
+ # convert keys to string instead of integer
334
+ if "id2label" in _vision_config_dict:
335
+ _vision_config_dict["id2label"] = {
336
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
337
+ }
338
+
339
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
340
+ for key, value in _vision_config_dict.items():
341
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
342
+ # If specified in `vision_config_dict`
343
+ if key in vision_config_dict:
344
+ message = (
345
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
346
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
347
+ )
348
+ # If inferred from default argument values (just to be super careful)
349
+ else:
350
+ message = (
351
+ f"`vision_config_dict` is provided which will be used to initialize "
352
+ f'`ChineseCLIPVisionConfig`. The value `vision_config["{key}"]` will be overridden.'
353
+ )
354
+ logger.info(message)
355
+
356
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
357
+ vision_config.update(_vision_config_dict)
358
+
359
+ if text_config is None:
360
+ text_config = {}
361
+ logger.info("`text_config` is `None`. Initializing the `ChineseCLIPTextConfig` with default values.")
362
+
363
+ if vision_config is None:
364
+ vision_config = {}
365
+ logger.info("`vision_config` is `None`. initializing the `ChineseCLIPVisionConfig` with default values.")
366
+
367
+ self.text_config = ChineseCLIPTextConfig(**text_config)
368
+ self.vision_config = ChineseCLIPVisionConfig(**vision_config)
369
+
370
+ self.projection_dim = projection_dim
371
+ self.logit_scale_init_value = logit_scale_init_value
372
+ self.initializer_factor = 1.0
373
+ self.initializer_range = 0.02
374
+
375
+ @classmethod
376
+ def from_text_vision_configs(
377
+ cls, text_config: ChineseCLIPTextConfig, vision_config: ChineseCLIPVisionConfig, **kwargs
378
+ ):
379
+ r"""
380
+ Instantiate a [`ChineseCLIPConfig`] (or a derived class) from Chinese-CLIP text model configuration and
381
+ Chinese-CLIP vision model configuration. Returns:
382
+ [`ChineseCLIPConfig`]: An instance of a configuration object
383
+ """
384
+
385
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
386
+
387
+
388
+ class ChineseCLIPOnnxConfig(OnnxConfig):
389
+ @property
390
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
391
+ return OrderedDict(
392
+ [
393
+ ("input_ids", {0: "batch", 1: "sequence"}),
394
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
395
+ ("attention_mask", {0: "batch", 1: "sequence"}),
396
+ ]
397
+ )
398
+
399
+ @property
400
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
401
+ return OrderedDict(
402
+ [
403
+ ("logits_per_image", {0: "batch"}),
404
+ ("logits_per_text", {0: "batch"}),
405
+ ("text_embeds", {0: "batch"}),
406
+ ("image_embeds", {0: "batch"}),
407
+ ]
408
+ )
409
+
410
+ @property
411
+ def atol_for_validation(self) -> float:
412
+ return 1e-4
413
+
414
+ def generate_dummy_inputs(
415
+ self,
416
+ processor: "ProcessorMixin",
417
+ batch_size: int = -1,
418
+ seq_length: int = -1,
419
+ framework: Optional["TensorType"] = None,
420
+ ) -> Mapping[str, Any]:
421
+ text_input_dict = super().generate_dummy_inputs(
422
+ processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
423
+ )
424
+ image_input_dict = super().generate_dummy_inputs(
425
+ processor.image_processor, batch_size=batch_size, framework=framework
426
+ )
427
+ return {**text_input_dict, **image_input_dict}
428
+
429
+ @property
430
+ def default_onnx_opset(self) -> int:
431
+ return 14
432
+
433
+
434
+ __all__ = ["ChineseCLIPConfig", "ChineseCLIPOnnxConfig", "ChineseCLIPTextConfig", "ChineseCLIPVisionConfig"]
docs/transformers/src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py ADDED
@@ -0,0 +1,134 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+
18
+ import torch
19
+
20
+ from transformers import ChineseCLIPConfig, ChineseCLIPModel
21
+
22
+
23
+ def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
24
+ q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
25
+ q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
26
+
27
+ out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
28
+ out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
29
+
30
+ hf_attn_layer.q_proj.weight.data = q_proj
31
+ hf_attn_layer.q_proj.bias.data = q_proj_bias
32
+
33
+ hf_attn_layer.k_proj.weight.data = k_proj
34
+ hf_attn_layer.k_proj.bias.data = k_proj_bias
35
+
36
+ hf_attn_layer.v_proj.weight.data = v_proj
37
+ hf_attn_layer.v_proj.bias.data = v_proj_bias
38
+
39
+ hf_attn_layer.out_proj.weight.data = out_proj_weights
40
+ hf_attn_layer.out_proj.bias.data = out_proj_bias
41
+
42
+
43
+ def copy_mlp(hf_mlp, pt_weights, prefix):
44
+ copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
45
+ copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
46
+
47
+
48
+ def copy_linear(hf_linear, pt_weights, prefix):
49
+ hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
50
+ hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
51
+
52
+
53
+ def copy_layer(hf_layer, pt_weights, prefix):
54
+ # copy layer norms
55
+ copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
56
+ copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
57
+
58
+ # copy MLP
59
+ copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
60
+
61
+ # copy attn
62
+ copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
63
+
64
+
65
+ def copy_layers(hf_layers, pt_weights, prefix):
66
+ for layer_id, hf_layer in enumerate(hf_layers):
67
+ copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
68
+
69
+
70
+ def copy_text_model_and_projection(hf_model, pt_weights):
71
+ # copy projection
72
+ hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
73
+
74
+ # copy text encoder
75
+ for name, param in hf_model.text_model.named_parameters():
76
+ param.data = pt_weights[f"bert.{name}"].data
77
+
78
+
79
+ def copy_vision_model_and_projection(hf_model, pt_weights):
80
+ # copy projection
81
+ hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
82
+
83
+ # copy layer norms
84
+ copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
85
+ copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
86
+
87
+ # copy embeddings
88
+ hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
89
+ hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
90
+ hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
91
+
92
+ # copy encoder
93
+ copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
94
+
95
+
96
+ @torch.no_grad()
97
+ def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
98
+ """
99
+ Copy/paste/tweak the model's weights to the transformers design.
100
+ """
101
+
102
+ assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
103
+ config = ChineseCLIPConfig.from_pretrained(config_path)
104
+
105
+ hf_model = ChineseCLIPModel(config).eval()
106
+
107
+ pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
108
+ pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
109
+
110
+ copy_text_model_and_projection(hf_model, pt_weights)
111
+ copy_vision_model_and_projection(hf_model, pt_weights)
112
+ hf_model.logit_scale.data = pt_weights["logit_scale"].data
113
+
114
+ hf_model.save_pretrained(pytorch_dump_folder_path)
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument(
120
+ "--pytorch_dump_folder_path",
121
+ default=None,
122
+ type=str,
123
+ help="Path to the output folder storing converted hf PyTorch model.",
124
+ )
125
+ parser.add_argument(
126
+ "--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
127
+ )
128
+ parser.add_argument(
129
+ "--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
130
+ )
131
+ args = parser.parse_args()
132
+
133
+ convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
134
+ print("The conversion is finished!")
docs/transformers/src/transformers/models/chinese_clip/feature_extraction_chinese_clip.py ADDED
@@ -0,0 +1,38 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for Chinese-CLIP."""
16
+
17
+ import warnings
18
+
19
+ from ...utils import logging
20
+ from ...utils.import_utils import requires
21
+ from .image_processing_chinese_clip import ChineseCLIPImageProcessor
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ @requires(backends=("vision",))
28
+ class ChineseCLIPFeatureExtractor(ChineseCLIPImageProcessor):
29
+ def __init__(self, *args, **kwargs) -> None:
30
+ warnings.warn(
31
+ "The class ChineseCLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
32
+ " Please use ChineseCLIPImageProcessor instead.",
33
+ FutureWarning,
34
+ )
35
+ super().__init__(*args, **kwargs)
36
+
37
+
38
+ __all__ = ["ChineseCLIPFeatureExtractor"]
docs/transformers/src/transformers/models/chinese_clip/image_processing_chinese_clip.py ADDED
@@ -0,0 +1,314 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Chinese-CLIP."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import (
23
+ convert_to_rgb,
24
+ get_resize_output_image_size,
25
+ resize,
26
+ to_channel_dimension_format,
27
+ )
28
+ from ...image_utils import (
29
+ OPENAI_CLIP_MEAN,
30
+ OPENAI_CLIP_STD,
31
+ ChannelDimension,
32
+ ImageInput,
33
+ PILImageResampling,
34
+ infer_channel_dimension_format,
35
+ is_scaled_image,
36
+ make_list_of_images,
37
+ to_numpy_array,
38
+ valid_images,
39
+ validate_preprocess_arguments,
40
+ )
41
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
42
+
43
+
44
+ if is_vision_available():
45
+ import PIL
46
+
47
+
48
+ from ...utils.import_utils import requires
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ @requires(backends=("vision",))
55
+ class ChineseCLIPImageProcessor(BaseImageProcessor):
56
+ r"""
57
+ Constructs a Chinese-CLIP image processor.
58
+
59
+ Args:
60
+ do_resize (`bool`, *optional*, defaults to `True`):
61
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
62
+ `do_resize` in the `preprocess` method.
63
+ size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
64
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
65
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
66
+ method.
67
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
68
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
69
+ do_center_crop (`bool`, *optional*, defaults to `True`):
70
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
71
+ `preprocess` method.
72
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
73
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
74
+ method.
75
+ do_rescale (`bool`, *optional*, defaults to `True`):
76
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
77
+ the `preprocess` method.
78
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
79
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
80
+ method.
81
+ do_normalize (`bool`, *optional*, defaults to `True`):
82
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
83
+ image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
84
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
85
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
86
+ image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
87
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
88
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
90
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
91
+ Whether to convert the image to RGB.
92
+ """
93
+
94
+ model_input_names = ["pixel_values"]
95
+
96
+ def __init__(
97
+ self,
98
+ do_resize: bool = True,
99
+ size: Dict[str, int] = None,
100
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
101
+ do_center_crop: bool = True,
102
+ crop_size: Dict[str, int] = None,
103
+ do_rescale: bool = True,
104
+ rescale_factor: Union[int, float] = 1 / 255,
105
+ do_normalize: bool = True,
106
+ image_mean: Optional[Union[float, List[float]]] = None,
107
+ image_std: Optional[Union[float, List[float]]] = None,
108
+ do_convert_rgb: bool = True,
109
+ **kwargs,
110
+ ) -> None:
111
+ super().__init__(**kwargs)
112
+ size = size if size is not None else {"shortest_edge": 224}
113
+ size = get_size_dict(size, default_to_square=False)
114
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
115
+ crop_size = get_size_dict(crop_size)
116
+
117
+ self.do_resize = do_resize
118
+ self.size = size
119
+ self.resample = resample
120
+ self.do_center_crop = do_center_crop
121
+ self.crop_size = crop_size
122
+ self.do_rescale = do_rescale
123
+ self.rescale_factor = rescale_factor
124
+ self.do_normalize = do_normalize
125
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
126
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
127
+ self.do_convert_rgb = do_convert_rgb
128
+
129
+ def resize(
130
+ self,
131
+ image: np.ndarray,
132
+ size: Dict[str, int],
133
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
134
+ data_format: Optional[Union[str, ChannelDimension]] = None,
135
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
136
+ **kwargs,
137
+ ) -> np.ndarray:
138
+ """
139
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
140
+ resized to keep the input aspect ratio.
141
+
142
+ Args:
143
+ image (`np.ndarray`):
144
+ Image to resize.
145
+ size (`Dict[str, int]`):
146
+ Size of the output image.
147
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
148
+ Resampling filter to use when resiizing the image.
149
+ data_format (`str` or `ChannelDimension`, *optional*):
150
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
151
+ input_data_format (`ChannelDimension` or `str`, *optional*):
152
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
153
+ image.
154
+ """
155
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" in size:
+ # resize the shortest edge to size["shortest_edge"], keeping the aspect ratio (as documented above)
+ output_size = get_resize_output_image_size(
+ image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
+ )
+ else:
+ output_size = (size["height"], size["width"])
159
+ return resize(
160
+ image,
161
+ size=output_size,
162
+ resample=resample,
163
+ data_format=data_format,
164
+ input_data_format=input_data_format,
165
+ **kwargs,
166
+ )
167
+
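The shortest-edge arithmetic that `get_resize_output_image_size` applies for a `{"shortest_edge": 224}` request can be sketched in plain Python (illustrative only; the library helper additionally handles channel layout):

```python
# e.g. a 480x640 (height x width) input with shortest_edge = 224
height, width = 480, 640
short, long = min(height, width), max(height, width)
new_short, new_long = 224, int(224 * long / short)  # 224, 298

output_size = (new_short, new_long) if height <= width else (new_long, new_short)
print(output_size)  # (224, 298)
```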
168
+ @filter_out_non_signature_kwargs()
169
+ def preprocess(
170
+ self,
171
+ images: ImageInput,
172
+ do_resize: Optional[bool] = None,
173
+ size: Dict[str, int] = None,
174
+ resample: PILImageResampling = None,
175
+ do_center_crop: Optional[bool] = None,
176
+ crop_size: Optional[int] = None,
177
+ do_rescale: Optional[bool] = None,
178
+ rescale_factor: Optional[float] = None,
179
+ do_normalize: Optional[bool] = None,
180
+ image_mean: Optional[Union[float, List[float]]] = None,
181
+ image_std: Optional[Union[float, List[float]]] = None,
182
+ do_convert_rgb: Optional[bool] = None,
183
+ return_tensors: Optional[Union[str, TensorType]] = None,
184
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
185
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
186
+ ) -> PIL.Image.Image:
187
+ """
188
+ Preprocess an image or batch of images.
189
+
190
+ Args:
191
+ images (`ImageInput`):
192
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
193
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
194
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
195
+ Whether to resize the image.
196
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
197
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
198
+ the longest edge resized to keep the input aspect ratio.
199
+ resample (`int`, *optional*, defaults to `self.resample`):
200
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
201
+ has an effect if `do_resize` is set to `True`.
202
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
203
+ Whether to center crop the image.
204
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
205
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
206
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
207
+ Whether to rescale the image.
208
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
209
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
210
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
211
+ Whether to normalize the image.
212
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
213
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
214
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
215
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
216
+ `True`.
217
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
218
+ Whether to convert the image to RGB.
219
+ return_tensors (`str` or `TensorType`, *optional*):
220
+ The type of tensors to return. Can be one of:
221
+ - Unset: Return a list of `np.ndarray`.
222
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
223
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
224
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
225
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
226
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
227
+ The channel dimension format for the output image. Can be one of:
228
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
229
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
230
+ - Unset: Use the channel dimension format of the input image.
231
+ input_data_format (`ChannelDimension` or `str`, *optional*):
232
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
233
+ from the input image. Can be one of:
234
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
235
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
236
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
237
+ """
238
+
239
+ do_resize = do_resize if do_resize is not None else self.do_resize
240
+ size = size if size is not None else self.size
241
+ size = get_size_dict(size, default_to_square=False)
242
+ resample = resample if resample is not None else self.resample
243
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
244
+ crop_size = crop_size if crop_size is not None else self.crop_size
245
+ crop_size = get_size_dict(crop_size)
246
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
247
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
248
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
249
+ image_mean = image_mean if image_mean is not None else self.image_mean
250
+ image_std = image_std if image_std is not None else self.image_std
251
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
252
+
253
+ images = make_list_of_images(images)
254
+
255
+ if not valid_images(images):
256
+ raise ValueError(
257
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
258
+ "torch.Tensor, tf.Tensor or jax.ndarray."
259
+ )
260
+ validate_preprocess_arguments(
261
+ do_rescale=do_rescale,
262
+ rescale_factor=rescale_factor,
263
+ do_normalize=do_normalize,
264
+ image_mean=image_mean,
265
+ image_std=image_std,
266
+ do_center_crop=do_center_crop,
267
+ crop_size=crop_size,
268
+ do_resize=do_resize,
269
+ size=size,
270
+ resample=resample,
271
+ )
272
+ if do_convert_rgb:
273
+ images = [convert_to_rgb(image) for image in images]
274
+
275
+ # All transformations expect numpy arrays.
276
+ images = [to_numpy_array(image) for image in images]
277
+
278
+ if do_rescale and is_scaled_image(images[0]):
279
+ logger.warning_once(
280
+ "It looks like you are trying to rescale already rescaled images. If the input"
281
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
282
+ )
283
+
284
+ if input_data_format is None:
285
+ # We assume that all images have the same channel dimension format.
286
+ input_data_format = infer_channel_dimension_format(images[0])
287
+
288
+ all_images = []
289
+ for image in images:
290
+ if do_resize:
291
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
292
+
293
+ if do_center_crop:
294
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
295
+
296
+ if do_rescale:
297
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
298
+
299
+ if do_normalize:
300
+ image = self.normalize(
301
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
302
+ )
303
+
304
+ all_images.append(image)
305
+ images = [
306
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
307
+ for image in all_images
308
+ ]
309
+
310
+ data = {"pixel_values": images}
311
+ return BatchFeature(data=data, tensor_type=return_tensors)
312
+
313
+
314
+ __all__ = ["ChineseCLIPImageProcessor"]
docs/transformers/src/transformers/models/chinese_clip/image_processing_chinese_clip_fast.py ADDED
@@ -0,0 +1,40 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for Chinese-CLIP."""
16
+
17
+ from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
18
+ from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
19
+ from ...utils import add_start_docstrings
20
+
21
+
22
+ @add_start_docstrings(
23
+ "Constructs a fast ChineseCLIP image processor.",
24
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
25
+ )
26
+ class ChineseCLIPImageProcessorFast(BaseImageProcessorFast):
27
+ resample = PILImageResampling.BICUBIC
28
+ image_mean = OPENAI_CLIP_MEAN
29
+ image_std = OPENAI_CLIP_STD
30
+ size = {"shortest_edge": 224}
31
+ default_to_square = False
32
+ crop_size = {"height": 224, "width": 224}
33
+ do_resize = True
34
+ do_center_crop = True
35
+ do_rescale = True
36
+ do_normalize = True
37
+ do_convert_rgb = True
38
+
39
+
40
+ __all__ = ["ChineseCLIPImageProcessorFast"]
docs/transformers/src/transformers/models/chinese_clip/modeling_chinese_clip.py ADDED
@@ -0,0 +1,1630 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Chinese-CLIP model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPooling,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ )
32
+ from ...modeling_utils import PreTrainedModel
33
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
34
+ from ...utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ torch_int,
42
+ )
43
+ from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CHECKPOINT_FOR_DOC = "OFA-Sys/chinese-clip-vit-base-patch16"
49
+ _CONFIG_FOR_DOC = "ChineseCLIPConfig"
50
+
51
+
52
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
53
+ # Copied from transformers.models.clip.modeling_clip.contrastive_loss
54
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
55
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
56
+
57
+
58
+ def chinese_clip_loss(similarity: torch.Tensor) -> torch.Tensor:
59
+ caption_loss = contrastive_loss(similarity)
60
+ image_loss = contrastive_loss(similarity.t())
61
+ return (caption_loss + image_loss) / 2.0
62
+
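A numeric sketch of the symmetric loss: with in-batch negatives, the i-th caption is the positive for the i-th image, and the text-to-image and image-to-text cross-entropies are averaged exactly as above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
text_embeds = F.normalize(torch.randn(4, 512), dim=-1)
image_embeds = F.normalize(torch.randn(4, 512), dim=-1)
logits_per_text = 100.0 * text_embeds @ image_embeds.t()  # 100.0 stands in for logit_scale.exp()

labels = torch.arange(4)  # i-th caption matches i-th image
caption_loss = F.cross_entropy(logits_per_text, labels)
image_loss = F.cross_entropy(logits_per_text.t(), labels)
loss = (caption_loss + image_loss) / 2.0
```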
63
+
64
+ @dataclass
65
+ class ChineseCLIPOutput(ModelOutput):
66
+ """
67
+ Args:
68
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
69
+ Contrastive loss for image-text similarity.
70
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
71
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
72
+ similarity scores.
73
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
74
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
75
+ similarity scores.
76
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
77
+ The text embeddings obtained by applying the projection layer to the pooled output of
78
+ [`ChineseCLIPTextModel`].
79
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
80
+ The image embeddings obtained by applying the projection layer to the pooled output of
81
+ [`ChineseCLIPVisionModel`].
82
+ text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
83
+ The output of the [`ChineseCLIPTextModel`].
84
+ vision_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
85
+ The output of the [`ChineseCLIPVisionModel`].
86
+ """
87
+
88
+ loss: Optional[torch.FloatTensor] = None
89
+ logits_per_image: Optional[torch.FloatTensor] = None
90
+ logits_per_text: Optional[torch.FloatTensor] = None
91
+ text_embeds: Optional[torch.FloatTensor] = None
92
+ image_embeds: Optional[torch.FloatTensor] = None
93
+ text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
94
+ vision_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
95
+
96
+ def to_tuple(self) -> Tuple[Any]:
97
+ return tuple(
98
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
99
+ for k in self.keys()
100
+ )
101
+
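These fields in use, following the documented checkpoint (standard zero-shot sketch; the Chinese prompts mean "a photo of a cat" / "a photo of a dog"):

```python
import requests
import torch
from PIL import Image

from transformers import ChineseCLIPModel, ChineseCLIPProcessor

model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=["一张猫的照片", "一张狗的照片"], images=image, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)  # per-image match probabilities over the texts
```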
102
+
103
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText
104
+ class ChineseCLIPTextEmbeddings(nn.Module):
105
+ """Construct the embeddings from word, position and token_type embeddings."""
106
+
107
+ def __init__(self, config):
108
+ super().__init__()
109
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
110
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
111
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
112
+
113
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
114
+ # any TensorFlow checkpoint file
115
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
116
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
117
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
118
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
119
+ self.register_buffer(
120
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
121
+ )
122
+ self.register_buffer(
123
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
124
+ )
125
+
126
+ def forward(
127
+ self,
128
+ input_ids: Optional[torch.LongTensor] = None,
129
+ token_type_ids: Optional[torch.LongTensor] = None,
130
+ position_ids: Optional[torch.LongTensor] = None,
131
+ inputs_embeds: Optional[torch.FloatTensor] = None,
132
+ past_key_values_length: int = 0,
133
+ ) -> torch.Tensor:
134
+ if input_ids is not None:
135
+ input_shape = input_ids.size()
136
+ else:
137
+ input_shape = inputs_embeds.size()[:-1]
138
+
139
+ seq_length = input_shape[1]
140
+
141
+ if position_ids is None:
142
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
143
+
144
+ # Set token_type_ids to the registered buffer from the constructor, where it is all zeros. This usually
+ # happens when token_type_ids are auto-generated; the registered buffer lets users trace the model without
+ # passing token_type_ids explicitly, and solves issue #5664.
147
+ if token_type_ids is None:
148
+ if hasattr(self, "token_type_ids"):
149
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
150
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
151
+ token_type_ids = buffered_token_type_ids_expanded
152
+ else:
153
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
154
+
155
+ if inputs_embeds is None:
156
+ inputs_embeds = self.word_embeddings(input_ids)
157
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
158
+
159
+ embeddings = inputs_embeds + token_type_embeddings
160
+ if self.position_embedding_type == "absolute":
161
+ position_embeddings = self.position_embeddings(position_ids)
162
+ embeddings += position_embeddings
163
+ embeddings = self.LayerNorm(embeddings)
164
+ embeddings = self.dropout(embeddings)
165
+ return embeddings
166
+
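A minimal sketch of the additive composition the forward pass performs (LayerNorm and dropout omitted, toy sizes):

```python
import torch
import torch.nn as nn

vocab_size, max_positions, hidden = 100, 16, 8
word = nn.Embedding(vocab_size, hidden)
position = nn.Embedding(max_positions, hidden)
token_type = nn.Embedding(2, hidden)

input_ids = torch.tensor([[5, 7, 9]])
position_ids = torch.arange(3).unsqueeze(0)
token_type_ids = torch.zeros_like(input_ids)

# word + token-type embeddings, plus position embeddings in "absolute" mode
embeddings = word(input_ids) + token_type(token_type_ids) + position(position_ids)
print(embeddings.shape)  # torch.Size([1, 3, 8])
```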
167
+
168
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->ChineseCLIP
169
+ class ChineseCLIPVisionEmbeddings(nn.Module):
170
+ def __init__(self, config: ChineseCLIPVisionConfig):
171
+ super().__init__()
172
+ self.config = config
173
+ self.embed_dim = config.hidden_size
174
+ self.image_size = config.image_size
175
+ self.patch_size = config.patch_size
176
+
177
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
178
+
179
+ self.patch_embedding = nn.Conv2d(
180
+ in_channels=config.num_channels,
181
+ out_channels=self.embed_dim,
182
+ kernel_size=self.patch_size,
183
+ stride=self.patch_size,
184
+ bias=False,
185
+ )
186
+
187
+ self.num_patches = (self.image_size // self.patch_size) ** 2
188
+ self.num_positions = self.num_patches + 1
189
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
190
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
191
+
192
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
193
+ """
194
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
195
+ images. This method is also adapted to support torch.jit tracing.
196
+
197
+ Adapted from:
198
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
199
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
200
+ """
201
+
202
+ num_patches = embeddings.shape[1] - 1
203
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
204
+ num_positions = position_embedding.shape[1] - 1
205
+
206
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
207
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
208
+ return self.position_embedding(self.position_ids)
209
+
210
+ class_pos_embed = position_embedding[:, :1]
211
+ patch_pos_embed = position_embedding[:, 1:]
212
+
213
+ dim = embeddings.shape[-1]
214
+
215
+ new_height = height // self.patch_size
216
+ new_width = width // self.patch_size
217
+
218
+ sqrt_num_positions = torch_int(num_positions**0.5)
219
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
220
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
221
+
222
+ patch_pos_embed = nn.functional.interpolate(
223
+ patch_pos_embed,
224
+ size=(new_height, new_width),
225
+ mode="bicubic",
226
+ align_corners=False,
227
+ )
228
+
229
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
230
+
231
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
232
+
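The core of the interpolation, sketched with toy tensors: a 224px, patch-16 model has a 14x14 patch grid, so a 336px input needs the grid resampled to 21x21 (the class-token embedding is concatenated back unchanged):

```python
import torch
import torch.nn.functional as F

patch_pos = torch.randn(1, 14 * 14, 768)                       # pre-trained 14x14 grid
grid = patch_pos.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)   # (1, 768, 14, 14)
grid = F.interpolate(grid, size=(21, 21), mode="bicubic", align_corners=False)
patch_pos_336 = grid.permute(0, 2, 3, 1).reshape(1, 21 * 21, 768)
print(patch_pos_336.shape)  # torch.Size([1, 441, 768])
```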
233
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
234
+ batch_size, _, height, width = pixel_values.shape
235
+ if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
236
+ raise ValueError(
237
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
238
+ )
239
+ target_dtype = self.patch_embedding.weight.dtype
240
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
241
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
242
+
243
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
244
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
245
+ if interpolate_pos_encoding:
246
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
247
+ else:
248
+ embeddings = embeddings + self.position_embedding(self.position_ids)
249
+ return embeddings
250
+
251
+
252
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText
253
+ class ChineseCLIPTextSelfAttention(nn.Module):
254
+ def __init__(self, config, position_embedding_type=None):
255
+ super().__init__()
256
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
257
+ raise ValueError(
258
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
259
+ f"heads ({config.num_attention_heads})"
260
+ )
261
+
262
+ self.num_attention_heads = config.num_attention_heads
263
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
264
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
265
+
266
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
267
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
268
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
269
+
270
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
271
+ self.position_embedding_type = position_embedding_type or getattr(
272
+ config, "position_embedding_type", "absolute"
273
+ )
274
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
275
+ self.max_position_embeddings = config.max_position_embeddings
276
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
277
+
278
+ self.is_decoder = config.is_decoder
279
+
280
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
281
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
282
+ x = x.view(new_x_shape)
283
+ return x.permute(0, 2, 1, 3)
284
+
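The layout change in isolation: `(batch, seq, heads * head_dim)` becomes `(batch, heads, seq, head_dim)` so attention scores can be computed per head:

```python
import torch

batch, seq_len, num_heads, head_size = 2, 5, 12, 64
x = torch.randn(batch, seq_len, num_heads * head_size)

# the same reshape/permute that transpose_for_scores performs
per_head = x.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)
print(per_head.shape)  # torch.Size([2, 12, 5, 64])
```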
285
+ def forward(
286
+ self,
287
+ hidden_states: torch.Tensor,
288
+ attention_mask: Optional[torch.FloatTensor] = None,
289
+ head_mask: Optional[torch.FloatTensor] = None,
290
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
291
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
292
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
293
+ output_attentions: Optional[bool] = False,
294
+ ) -> Tuple[torch.Tensor]:
295
+ mixed_query_layer = self.query(hidden_states)
296
+
297
+ # If this is instantiated as a cross-attention module, the keys
298
+ # and values come from an encoder; the attention mask needs to be
299
+ # such that the encoder's padding tokens are not attended to.
300
+ is_cross_attention = encoder_hidden_states is not None
301
+
302
+ if is_cross_attention and past_key_value is not None:
303
+ # reuse k,v, cross_attentions
304
+ key_layer = past_key_value[0]
305
+ value_layer = past_key_value[1]
306
+ attention_mask = encoder_attention_mask
307
+ elif is_cross_attention:
308
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
309
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
310
+ attention_mask = encoder_attention_mask
311
+ elif past_key_value is not None:
312
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
313
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
314
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
315
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
316
+ else:
317
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
318
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
319
+
320
+ query_layer = self.transpose_for_scores(mixed_query_layer)
321
+
322
+ use_cache = past_key_value is not None
323
+ if self.is_decoder:
324
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
325
+ # Further calls to cross_attention layer can then reuse all cross-attention
326
+ # key/value_states (first "if" case)
327
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
328
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
329
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
330
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
331
+ past_key_value = (key_layer, value_layer)
332
+
333
+ # Take the dot product between "query" and "key" to get the raw attention scores.
334
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
335
+
336
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
337
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
338
+ if use_cache:
339
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
340
+ -1, 1
341
+ )
342
+ else:
343
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
344
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
345
+ distance = position_ids_l - position_ids_r
346
+
347
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
348
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
349
+
350
+ if self.position_embedding_type == "relative_key":
351
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
352
+ attention_scores = attention_scores + relative_position_scores
353
+ elif self.position_embedding_type == "relative_key_query":
354
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
355
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
356
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
357
+
358
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
359
+ if attention_mask is not None:
360
+ # Apply the attention mask (precomputed for all layers in the ChineseCLIPTextModel forward() function)
361
+ attention_scores = attention_scores + attention_mask
362
+
363
+ # Normalize the attention scores to probabilities.
364
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
365
+
366
+ # This is actually dropping out entire tokens to attend to, which might
367
+ # seem a bit unusual, but is taken from the original Transformer paper.
368
+ attention_probs = self.dropout(attention_probs)
369
+
370
+ # Mask heads if we want to
371
+ if head_mask is not None:
372
+ attention_probs = attention_probs * head_mask
373
+
374
+ context_layer = torch.matmul(attention_probs, value_layer)
375
+
376
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
377
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
378
+ context_layer = context_layer.view(new_context_layer_shape)
379
+
380
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
381
+
382
+ if self.is_decoder:
383
+ outputs = outputs + (past_key_value,)
384
+ return outputs
385
+
386
+
387
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->ChineseCLIPText
388
+ class ChineseCLIPTextSelfOutput(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
392
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
393
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
394
+
395
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
396
+ hidden_states = self.dense(hidden_states)
397
+ hidden_states = self.dropout(hidden_states)
398
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
399
+ return hidden_states
400
+
401
+
402
+ CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
403
+ "eager": ChineseCLIPTextSelfAttention,
404
+ }
405
+
406
+
407
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
408
+ class ChineseCLIPTextAttention(nn.Module):
409
+ def __init__(self, config, position_embedding_type=None):
410
+ super().__init__()
411
+ self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
412
+ config, position_embedding_type=position_embedding_type
413
+ )
414
+ self.output = ChineseCLIPTextSelfOutput(config)
415
+ self.pruned_heads = set()
416
+
417
+ def prune_heads(self, heads):
418
+ if len(heads) == 0:
419
+ return
420
+ heads, index = find_pruneable_heads_and_indices(
421
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
422
+ )
423
+
424
+ # Prune linear layers
425
+ self.self.query = prune_linear_layer(self.self.query, index)
426
+ self.self.key = prune_linear_layer(self.self.key, index)
427
+ self.self.value = prune_linear_layer(self.self.value, index)
428
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
429
+
430
+ # Update hyper params and store pruned heads
431
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
432
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
433
+ self.pruned_heads = self.pruned_heads.union(heads)
434
+
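The index bookkeeping behind `prune_heads`, shown with the helper it relies on (a sketch with 12 heads of size 64; pruning two heads leaves 640 of 768 feature columns):

```python
from transformers.pytorch_utils import find_pruneable_heads_and_indices

heads, index = find_pruneable_heads_and_indices(
    heads=[0, 2], n_heads=12, head_size=64, already_pruned_heads=set()
)
print(sorted(heads))  # [0, 2]
print(index.numel())  # 640 surviving columns, fed to prune_linear_layer
```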
435
+ def forward(
436
+ self,
437
+ hidden_states: torch.Tensor,
438
+ attention_mask: Optional[torch.FloatTensor] = None,
439
+ head_mask: Optional[torch.FloatTensor] = None,
440
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
441
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
442
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
443
+ output_attentions: Optional[bool] = False,
444
+ ) -> Tuple[torch.Tensor]:
445
+ self_outputs = self.self(
446
+ hidden_states,
447
+ attention_mask,
448
+ head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ )
454
+ attention_output = self.output(self_outputs[0], hidden_states)
455
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
456
+ return outputs
457
+
458
+
459
+ class ChineseCLIPVisionAttention(nn.Module):
460
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
461
+
462
+ def __init__(self, config):
463
+ super().__init__()
464
+ self.config = config
465
+ self.embed_dim = config.hidden_size
466
+ self.num_heads = config.num_attention_heads
467
+ self.head_dim = self.embed_dim // self.num_heads
468
+ if self.head_dim * self.num_heads != self.embed_dim:
469
+ raise ValueError(
470
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
471
+ f" {self.num_heads})."
472
+ )
473
+ self.scale = self.head_dim**-0.5
474
+ self.dropout = config.attention_dropout
475
+
476
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
477
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
478
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
479
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
480
+
481
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
482
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states: torch.Tensor,
487
+ output_attentions: Optional[bool] = False,
488
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
489
+ """Input shape: Batch x Time x Channel"""
490
+
491
+ bsz, tgt_len, embed_dim = hidden_states.size()
492
+
493
+ # get query proj
494
+ query_states = self.q_proj(hidden_states) * self.scale
495
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
496
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
497
+
498
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
499
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
500
+ key_states = key_states.view(*proj_shape)
501
+ value_states = value_states.view(*proj_shape)
502
+
503
+ src_len = key_states.size(1)
504
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
505
+
506
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
507
+ raise ValueError(
508
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
509
+ f" {attn_weights.size()}"
510
+ )
511
+
512
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
513
+
514
+ if output_attentions:
515
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and reused in the following computation
519
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
520
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
521
+ else:
522
+ attn_weights_reshaped = None
523
+
524
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
525
+
526
+ attn_output = torch.bmm(attn_probs, value_states)
527
+
528
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
529
+ raise ValueError(
530
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
531
+ f" {attn_output.size()}"
532
+ )
533
+
534
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
535
+ attn_output = attn_output.transpose(1, 2)
536
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
537
+
538
+ attn_output = self.out_proj(attn_output)
539
+
540
+ return attn_output, attn_weights_reshaped
541
+
542
+
543
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
544
+ class ChineseCLIPTextIntermediate(nn.Module):
545
+ def __init__(self, config):
546
+ super().__init__()
547
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
548
+ if isinstance(config.hidden_act, str):
549
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
550
+ else:
551
+ self.intermediate_act_fn = config.hidden_act
552
+
553
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
554
+ hidden_states = self.dense(hidden_states)
555
+ hidden_states = self.intermediate_act_fn(hidden_states)
556
+ return hidden_states
557
+
558
+
559
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->ChineseCLIPText
560
+ class ChineseCLIPTextOutput(nn.Module):
561
+ def __init__(self, config):
562
+ super().__init__()
563
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
564
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
565
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
566
+
567
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
568
+ hidden_states = self.dense(hidden_states)
569
+ hidden_states = self.dropout(hidden_states)
570
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
571
+ return hidden_states
572
+
573
+
574
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->ChineseCLIPVision
575
+ class ChineseCLIPVisionMLP(nn.Module):
576
+ def __init__(self, config):
577
+ super().__init__()
578
+ self.config = config
579
+ self.activation_fn = ACT2FN[config.hidden_act]
580
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
581
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
582
+
583
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
584
+ hidden_states = self.fc1(hidden_states)
585
+ hidden_states = self.activation_fn(hidden_states)
586
+ hidden_states = self.fc2(hidden_states)
587
+ return hidden_states
588
+
589
+
590
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
591
+ class ChineseCLIPTextLayer(nn.Module):
592
+ def __init__(self, config):
593
+ super().__init__()
594
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
595
+ self.seq_len_dim = 1
596
+ self.attention = ChineseCLIPTextAttention(config)
597
+ self.is_decoder = config.is_decoder
598
+ self.add_cross_attention = config.add_cross_attention
599
+ if self.add_cross_attention:
600
+ if not self.is_decoder:
601
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
602
+ self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
603
+ self.intermediate = ChineseCLIPTextIntermediate(config)
604
+ self.output = ChineseCLIPTextOutput(config)
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.Tensor,
609
+ attention_mask: Optional[torch.FloatTensor] = None,
610
+ head_mask: Optional[torch.FloatTensor] = None,
611
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
612
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
613
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
614
+ output_attentions: Optional[bool] = False,
615
+ ) -> Tuple[torch.Tensor]:
616
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
617
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
618
+ self_attention_outputs = self.attention(
619
+ hidden_states,
620
+ attention_mask,
621
+ head_mask,
622
+ output_attentions=output_attentions,
623
+ past_key_value=self_attn_past_key_value,
624
+ )
625
+ attention_output = self_attention_outputs[0]
626
+
627
+ # if decoder, the last output is tuple of self-attn cache
628
+ if self.is_decoder:
629
+ outputs = self_attention_outputs[1:-1]
630
+ present_key_value = self_attention_outputs[-1]
631
+ else:
632
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
633
+
634
+ cross_attn_present_key_value = None
635
+ if self.is_decoder and encoder_hidden_states is not None:
636
+ if not hasattr(self, "crossattention"):
637
+ raise ValueError(
638
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
639
+ " by setting `config.add_cross_attention=True`"
640
+ )
641
+
642
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
643
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
644
+ cross_attention_outputs = self.crossattention(
645
+ attention_output,
646
+ attention_mask,
647
+ head_mask,
648
+ encoder_hidden_states,
649
+ encoder_attention_mask,
650
+ cross_attn_past_key_value,
651
+ output_attentions,
652
+ )
653
+ attention_output = cross_attention_outputs[0]
654
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
655
+
656
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
657
+ cross_attn_present_key_value = cross_attention_outputs[-1]
658
+ present_key_value = present_key_value + cross_attn_present_key_value
659
+
660
+ layer_output = apply_chunking_to_forward(
661
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
662
+ )
663
+ outputs = (layer_output,) + outputs
664
+
665
+ # if decoder, return the attn key/values as the last output
666
+ if self.is_decoder:
667
+ outputs = outputs + (present_key_value,)
668
+
669
+ return outputs
670
+
671
+ def feed_forward_chunk(self, attention_output):
672
+ intermediate_output = self.intermediate(attention_output)
673
+ layer_output = self.output(intermediate_output, attention_output)
674
+ return layer_output
675
+
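Feed-forward chunking trades batching for memory: the sequence dimension is split into chunks, each chunk is run through the feed-forward block, and the results are concatenated. Because the block acts position-wise, the output is identical either way. A small sketch:

```python
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

def feed_forward(x):  # stand-in for the intermediate + output sublayers
    return x * 2.0

hidden = torch.randn(2, 8, 16)                                   # (batch, seq, hidden)
full = apply_chunking_to_forward(feed_forward, 0, 1, hidden)     # chunk size 0: no chunking
chunked = apply_chunking_to_forward(feed_forward, 4, 1, hidden)  # seq dim in chunks of 4
assert torch.allclose(full, chunked)
```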
676
+
677
+ class ChineseCLIPVisionLayer(nn.Module):
678
+ def __init__(self, config: ChineseCLIPVisionConfig):
679
+ super().__init__()
680
+ self.embed_dim = config.hidden_size
681
+ self.self_attn = ChineseCLIPVisionAttention(config)
682
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
683
+ self.mlp = ChineseCLIPVisionMLP(config)
684
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
685
+
686
+ def forward(
687
+ self,
688
+ hidden_states: torch.Tensor,
689
+ output_attentions: Optional[bool] = False,
690
+ ) -> Tuple[torch.FloatTensor]:
691
+ """
692
+ Args:
693
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
694
+ output_attentions (`bool`, *optional*):
695
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
696
+ returned tensors for more detail.
697
+ """
698
+ residual = hidden_states
699
+
700
+ hidden_states = self.layer_norm1(hidden_states)
701
+ hidden_states, attn_weights = self.self_attn(
702
+ hidden_states=hidden_states,
703
+ output_attentions=output_attentions,
704
+ )
705
+ hidden_states = residual + hidden_states
706
+
707
+ residual = hidden_states
708
+ hidden_states = self.layer_norm2(hidden_states)
709
+ hidden_states = self.mlp(hidden_states)
710
+ hidden_states = residual + hidden_states
711
+
712
+ outputs = (hidden_states,)
713
+
714
+ if output_attentions:
715
+ outputs += (attn_weights,)
716
+
717
+ return outputs
718
+
719
+
720
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->ChineseCLIPText
721
+ class ChineseCLIPTextPooler(nn.Module):
722
+ def __init__(self, config):
723
+ super().__init__()
724
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
725
+ self.activation = nn.Tanh()
726
+
727
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
728
+ # We "pool" the model by simply taking the hidden state corresponding
729
+ # to the first token.
730
+ first_token_tensor = hidden_states[:, 0]
731
+ pooled_output = self.dense(first_token_tensor)
732
+ pooled_output = self.activation(pooled_output)
733
+ return pooled_output
734
+
735
+
736
+ class ChineseCLIPPreTrainedModel(PreTrainedModel):
737
+ """
738
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
739
+ models.
740
+ """
741
+
742
+ config_class = ChineseCLIPConfig
743
+ base_model_prefix = "chinese_clip"
744
+ supports_gradient_checkpointing = True
745
+
746
+ def _init_weights(self, module):
747
+ """Initialize the weights"""
748
+ factor = self.config.initializer_factor
749
+ if isinstance(module, ChineseCLIPVisionEmbeddings):
750
+ factor = self.config.initializer_factor
751
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
752
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
753
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
754
+ elif isinstance(module, ChineseCLIPTextEmbeddings):
755
+ nn.init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range)
756
+ nn.init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range)
757
+ nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range)
758
+ for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]:
759
+ if embedding.padding_idx is not None:
760
+ embedding.weight.data[embedding.padding_idx].zero_()
761
+ elif isinstance(module, ChineseCLIPVisionAttention):
762
+ factor = self.config.initializer_factor
763
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
764
+ out_proj_std = (module.embed_dim**-0.5) * factor
765
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
766
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
767
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
768
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
769
+ elif isinstance(module, ChineseCLIPVisionMLP):
770
+ factor = self.config.initializer_factor
771
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
772
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
773
+ nn.init.normal_(module.fc1.weight, std=fc_std)
774
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
775
+ elif isinstance(module, ChineseCLIPModel):
776
+ nn.init.normal_(
777
+ module.text_projection.weight,
778
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
779
+ )
780
+ nn.init.normal_(
781
+ module.visual_projection.weight,
782
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
783
+ )
784
+
785
+ if isinstance(module, nn.LayerNorm):
786
+ module.bias.data.zero_()
787
+ module.weight.data.fill_(1.0)
788
+ if isinstance(module, nn.Linear):
789
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
790
+ if module.bias is not None:
791
+ module.bias.data.zero_()
792
+
793
+
794
+ CHINESE_CLIP_START_DOCSTRING = r"""
795
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
796
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
797
+ behavior.
798
+
799
+ Parameters:
800
+ config ([`ChineseCLIPConfig`]): Model configuration class with all the parameters of the model.
801
+ Initializing with a config file does not load the weights associated with the model, only the
802
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
803
+ """
804
+
805
+ CHINESE_CLIP_TEXT_INPUTS_DOCSTRING = r"""
806
+ Args:
807
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
808
+ Indices of input sequence tokens in the vocabulary.
809
+
810
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
811
+ [`PreTrainedTokenizer.__call__`] for details.
812
+
813
+ [What are input IDs?](../glossary#input-ids)
814
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
815
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
816
+
817
+ - 1 for tokens that are **not masked**,
818
+ - 0 for tokens that are **masked**.
819
+
820
+ [What are attention masks?](../glossary#attention-mask)
821
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
822
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
823
+ 1]`:
824
+
825
+ - 0 corresponds to a *sentence A* token,
826
+ - 1 corresponds to a *sentence B* token.
827
+
828
+ [What are token type IDs?](../glossary#token-type-ids)
829
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
830
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
831
+ config.max_position_embeddings - 1]`.
832
+
833
+ [What are position IDs?](../glossary#position-ids)
834
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
835
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
836
+
837
+ - 1 indicates the head is **not masked**,
838
+ - 0 indicates the head is **masked**.
839
+
840
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
841
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
842
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
843
+ model's internal embedding lookup matrix.
844
+ output_attentions (`bool`, *optional*):
845
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
846
+ tensors for more detail.
847
+ output_hidden_states (`bool`, *optional*):
848
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
849
+ more detail.
850
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
851
+ Whether to interpolate the pre-trained position encodings.
852
+ return_dict (`bool`, *optional*):
853
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
854
+ """
855
+
856
+ CHINESE_CLIP_VISION_INPUTS_DOCSTRING = r"""
857
+ Args:
858
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
859
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
860
+ [`AutoImageProcessor`]. See [`ChineseCLIPImageProcessor.__call__`] for details.
861
+ output_attentions (`bool`, *optional*):
862
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
863
+ tensors for more detail.
864
+ output_hidden_states (`bool`, *optional*):
865
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
866
+ more detail.
867
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
868
+ Whether to interpolate the pre-trained position encodings.
869
+ return_dict (`bool`, *optional*):
870
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
871
+ """
872
+
873
+ CHINESE_CLIP_INPUTS_DOCSTRING = r"""
874
+ Args:
875
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
876
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
877
+ it.
878
+
879
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
880
+ [`PreTrainedTokenizer.__call__`] for details.
881
+
882
+ [What are input IDs?](../glossary#input-ids)
883
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
884
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
885
+
886
+ - 1 for tokens that are **not masked**,
887
+ - 0 for tokens that are **masked**.
888
+
889
+ [What are attention masks?](../glossary#attention-mask)
890
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
891
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
892
+ 1]`:
893
+
894
+ - 0 corresponds to a *sentence A* token,
895
+ - 1 corresponds to a *sentence B* token.
896
+
897
+ [What are token type IDs?](../glossary#token-type-ids)
898
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
899
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
900
+ config.max_position_embeddings - 1]`.
901
+
902
+ [What are position IDs?](../glossary#position-ids)
903
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
904
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
905
+ [`AutoImageProcessor`]. See [`ChineseCLIPImageProcessor.__call__`] for details.
906
+ return_loss (`bool`, *optional*):
907
+ Whether or not to return the contrastive loss.
908
+ output_attentions (`bool`, *optional*):
909
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
910
+ tensors for more detail.
911
+ output_hidden_states (`bool`, *optional*):
912
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
913
+ more detail.
914
+ return_dict (`bool`, *optional*):
915
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
916
+ """
917
+
918
+
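The input docstrings above describe the tensors produced by the tokenizer and image processor. A small sketch of what the text inputs look like in practice (the shapes depend on tokenization and padding):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
enc = tokenizer(["杰尼龟", "皮卡丘"], padding=True, return_tensors="pt")

# Keys line up with the arguments documented above:
# input_ids, token_type_ids and attention_mask, all of shape (batch, seq_len).
print({k: tuple(v.shape) for k, v in enc.items()})
```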
919
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText
920
+ class ChineseCLIPTextEncoder(nn.Module):
921
+ def __init__(self, config):
922
+ super().__init__()
923
+ self.config = config
924
+ self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)])
925
+ self.gradient_checkpointing = False
926
+
927
+ def forward(
928
+ self,
929
+ hidden_states: torch.Tensor,
930
+ attention_mask: Optional[torch.FloatTensor] = None,
931
+ head_mask: Optional[torch.FloatTensor] = None,
932
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
933
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
934
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
935
+ use_cache: Optional[bool] = None,
936
+ output_attentions: Optional[bool] = False,
937
+ output_hidden_states: Optional[bool] = False,
938
+ return_dict: Optional[bool] = True,
939
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
940
+ all_hidden_states = () if output_hidden_states else None
941
+ all_self_attentions = () if output_attentions else None
942
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
943
+
944
+ if self.gradient_checkpointing and self.training:
945
+ if use_cache:
946
+ logger.warning_once(
947
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
948
+ )
949
+ use_cache = False
950
+
951
+ next_decoder_cache = () if use_cache else None
952
+ for i, layer_module in enumerate(self.layer):
953
+ if output_hidden_states:
954
+ all_hidden_states = all_hidden_states + (hidden_states,)
955
+
956
+ layer_head_mask = head_mask[i] if head_mask is not None else None
957
+ past_key_value = past_key_values[i] if past_key_values is not None else None
958
+
959
+ if self.gradient_checkpointing and self.training:
960
+ layer_outputs = self._gradient_checkpointing_func(
961
+ layer_module.__call__,
962
+ hidden_states,
963
+ attention_mask,
964
+ layer_head_mask,
965
+ encoder_hidden_states,
966
+ encoder_attention_mask,
967
+ past_key_value,
968
+ output_attentions,
969
+ )
970
+ else:
971
+ layer_outputs = layer_module(
972
+ hidden_states,
973
+ attention_mask,
974
+ layer_head_mask,
975
+ encoder_hidden_states,
976
+ encoder_attention_mask,
977
+ past_key_value,
978
+ output_attentions,
979
+ )
980
+
981
+ hidden_states = layer_outputs[0]
982
+ if use_cache:
983
+ next_decoder_cache += (layer_outputs[-1],)
984
+ if output_attentions:
985
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
986
+ if self.config.add_cross_attention:
987
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
988
+
989
+ if output_hidden_states:
990
+ all_hidden_states = all_hidden_states + (hidden_states,)
991
+
992
+ if not return_dict:
993
+ return tuple(
994
+ v
995
+ for v in [
996
+ hidden_states,
997
+ next_decoder_cache,
998
+ all_hidden_states,
999
+ all_self_attentions,
1000
+ all_cross_attentions,
1001
+ ]
1002
+ if v is not None
1003
+ )
1004
+ return BaseModelOutputWithPastAndCrossAttentions(
1005
+ last_hidden_state=hidden_states,
1006
+ past_key_values=next_decoder_cache,
1007
+ hidden_states=all_hidden_states,
1008
+ attentions=all_self_attentions,
1009
+ cross_attentions=all_cross_attentions,
1010
+ )
1011
+
1012
+
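The `use_cache` guard at the top of the text encoder's forward is worth noting: gradient checkpointing recomputes activations during the backward pass, which is incompatible with returning cached key/value states, so caching is disabled with a one-time warning. A short usage sketch (the checkpoint name is the public Chinese-CLIP release; the rest is illustrative):

```python
from transformers import ChineseCLIPTextModel

model = ChineseCLIPTextModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
model.gradient_checkpointing_enable()  # trade compute for memory during training
model.train()
# Any forward pass during training now recomputes activations in backward,
# and the encoder above forces use_cache=False.
```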
1013
+ class ChineseCLIPVisionEncoder(nn.Module):
1014
+ """
1015
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
1016
+ [`ChineseCLIPVisionEncoderLayer`].
1017
+
1018
+ Args:
1019
+ config: ChineseCLIPConfig
1020
+ """
1021
+
1022
+ def __init__(self, config: ChineseCLIPConfig):
1023
+ super().__init__()
1024
+ self.config = config
1025
+ self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
1026
+ self.gradient_checkpointing = False
1027
+
1028
+ def forward(
1029
+ self,
1030
+ inputs_embeds,
1031
+ output_attentions: Optional[bool] = None,
1032
+ output_hidden_states: Optional[bool] = None,
1033
+ return_dict: Optional[bool] = None,
1034
+ ) -> Union[Tuple, BaseModelOutput]:
1035
+ r"""
1036
+ Args:
1037
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1038
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1039
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1040
+ than the model's internal embedding lookup matrix.
1041
+ output_attentions (`bool`, *optional*):
1042
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1043
+ returned tensors for more detail.
1044
+ output_hidden_states (`bool`, *optional*):
1045
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1046
+ for more detail.
1047
+ return_dict (`bool`, *optional*):
1048
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1049
+ """
1050
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1051
+ output_hidden_states = (
1052
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1053
+ )
1054
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1055
+
1056
+ encoder_states = () if output_hidden_states else None
1057
+ all_attentions = () if output_attentions else None
1058
+
1059
+ hidden_states = inputs_embeds
1060
+ for idx, encoder_layer in enumerate(self.layers):
1061
+ if output_hidden_states:
1062
+ encoder_states = encoder_states + (hidden_states,)
1063
+ if self.gradient_checkpointing and self.training:
1064
+ layer_outputs = self._gradient_checkpointing_func(
1065
+ encoder_layer.__call__,
1066
+ hidden_states,
1067
+ output_attentions,
1068
+ )
1069
+ else:
1070
+ layer_outputs = encoder_layer(
1071
+ hidden_states,
1072
+ output_attentions=output_attentions,
1073
+ )
1074
+
1075
+ hidden_states = layer_outputs[0]
1076
+
1077
+ if output_attentions:
1078
+ all_attentions = all_attentions + (layer_outputs[1],)
1079
+
1080
+ if output_hidden_states:
1081
+ encoder_states = encoder_states + (hidden_states,)
1082
+
1083
+ if not return_dict:
1084
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
1085
+ return BaseModelOutput(
1086
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
1087
+ )
1088
+
1089
+
1090
+ class ChineseCLIPVisionTransformer(nn.Module):
1091
+ def __init__(self, config: ChineseCLIPVisionConfig):
1092
+ super().__init__()
1093
+ self.config = config
1094
+ embed_dim = config.hidden_size
1095
+
1096
+ self.embeddings = ChineseCLIPVisionEmbeddings(config)
1097
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1098
+ self.encoder = ChineseCLIPVisionEncoder(config)
1099
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1100
+
1101
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING)
1102
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ChineseCLIPVisionConfig)
1103
+ def forward(
1104
+ self,
1105
+ pixel_values: Optional[torch.FloatTensor] = None,
1106
+ output_attentions: Optional[bool] = None,
1107
+ output_hidden_states: Optional[bool] = None,
1108
+ interpolate_pos_encoding: bool = False,
1109
+ return_dict: Optional[bool] = None,
1110
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1111
+ r"""
1112
+ Returns:
1113
+ """
1114
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1115
+ output_hidden_states = (
1116
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1117
+ )
1118
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1119
+
1120
+ if pixel_values is None:
1121
+ raise ValueError("You have to specify pixel_values")
1122
+
1123
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1124
+ hidden_states = self.pre_layrnorm(hidden_states)
1125
+
1126
+ encoder_outputs = self.encoder(
1127
+ inputs_embeds=hidden_states,
1128
+ output_attentions=output_attentions,
1129
+ output_hidden_states=output_hidden_states,
1130
+ return_dict=return_dict,
1131
+ )
1132
+
1133
+ last_hidden_state = encoder_outputs[0]
1134
+ pooled_output = last_hidden_state[:, 0, :]
1135
+ pooled_output = self.post_layernorm(pooled_output)
1136
+
1137
+ if not return_dict:
1138
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1139
+
1140
+ return BaseModelOutputWithPooling(
1141
+ last_hidden_state=last_hidden_state,
1142
+ pooler_output=pooled_output,
1143
+ hidden_states=encoder_outputs.hidden_states,
1144
+ attentions=encoder_outputs.attentions,
1145
+ )
1146
+
1147
+
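Note the pooling choice above: the vision transformer takes the first (class-token) position of the last hidden state and applies `post_layernorm`, with no separate pooler head. An equivalent stand-alone sketch on dummy tensors (sizes are illustrative):

```python
import torch
import torch.nn as nn

batch, seq_len, dim = 2, 197, 768  # e.g. 196 patches + 1 class token
last_hidden_state = torch.randn(batch, seq_len, dim)
post_layernorm = nn.LayerNorm(dim)

# Same pooling as in the forward above.
pooled_output = post_layernorm(last_hidden_state[:, 0, :])
print(pooled_output.shape)  # torch.Size([2, 768])
```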
1148
+ @add_start_docstrings(
1149
+ "The text model from CHINESE_CLIP without any head or projection on top.",
1150
+ CHINESE_CLIP_START_DOCSTRING,
1151
+ )
1152
+ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
1153
+ """
1154
+
1155
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
1156
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
1157
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
1158
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
1159
+
1160
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
1161
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
1162
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
1163
+ """
1164
+
1165
+ config_class = ChineseCLIPTextConfig
1166
+ _no_split_modules = ["ChineseCLIPTextEmbeddings"]
1167
+
1168
+ def __init__(self, config, add_pooling_layer=True):
1169
+ super().__init__(config)
1170
+ self.config = config
1171
+
1172
+ self.embeddings = ChineseCLIPTextEmbeddings(config)
1173
+ self.encoder = ChineseCLIPTextEncoder(config)
1174
+
1175
+ self.pooler = ChineseCLIPTextPooler(config) if add_pooling_layer else None
1176
+
1177
+ # Initialize weights and apply final processing
1178
+ self.post_init()
1179
+
1180
+ def get_input_embeddings(self):
1181
+ return self.embeddings.word_embeddings
1182
+
1183
+ def set_input_embeddings(self, value):
1184
+ self.embeddings.word_embeddings = value
1185
+
1186
+ def _prune_heads(self, heads_to_prune):
1187
+ """
1188
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1189
+ class PreTrainedModel
1190
+ """
1191
+ for layer, heads in heads_to_prune.items():
1192
+ self.encoder.layer[layer].attention.prune_heads(heads)
1193
+
1194
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1195
+ @add_code_sample_docstrings(
1196
+ checkpoint=_CHECKPOINT_FOR_DOC,
1197
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
1198
+ config_class=_CONFIG_FOR_DOC,
1199
+ )
1200
+ def forward(
1201
+ self,
1202
+ input_ids: Optional[torch.Tensor] = None,
1203
+ attention_mask: Optional[torch.Tensor] = None,
1204
+ token_type_ids: Optional[torch.Tensor] = None,
1205
+ position_ids: Optional[torch.Tensor] = None,
1206
+ head_mask: Optional[torch.Tensor] = None,
1207
+ inputs_embeds: Optional[torch.Tensor] = None,
1208
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1209
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1210
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1211
+ use_cache: Optional[bool] = None,
1212
+ output_attentions: Optional[bool] = None,
1213
+ output_hidden_states: Optional[bool] = None,
1214
+ return_dict: Optional[bool] = None,
1215
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1216
+ r"""
1217
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1218
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1219
+ the model is configured as a decoder.
1220
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1221
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1222
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1223
+
1224
+ - 1 for tokens that are **not masked**,
1225
+ - 0 for tokens that are **masked**.
1226
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1227
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1228
+
1229
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1230
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1231
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1232
+ use_cache (`bool`, *optional*):
1233
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1234
+ `past_key_values`).
1235
+ """
1236
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1237
+ output_hidden_states = (
1238
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1239
+ )
1240
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1241
+
1242
+ if self.config.is_decoder:
1243
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1244
+ else:
1245
+ use_cache = False
1246
+
1247
+ if input_ids is not None and inputs_embeds is not None:
1248
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1249
+ elif input_ids is not None:
1250
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1251
+ input_shape = input_ids.size()
1252
+ elif inputs_embeds is not None:
1253
+ input_shape = inputs_embeds.size()[:-1]
1254
+ else:
1255
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1256
+
1257
+ batch_size, seq_length = input_shape
1258
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1259
+
1260
+ # past_key_values_length
1261
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1262
+
1263
+ if attention_mask is None:
1264
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1265
+
1266
+ if token_type_ids is None:
1267
+ if hasattr(self.embeddings, "token_type_ids"):
1268
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1269
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1270
+ token_type_ids = buffered_token_type_ids_expanded
1271
+ else:
1272
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1273
+
1274
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1275
+ # ourselves in which case we just need to make it broadcastable to all heads.
1276
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1277
+
1278
+ # If a 2D or 3D attention mask is provided for the cross-attention
1279
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1280
+ if self.config.is_decoder and encoder_hidden_states is not None:
1281
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1282
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1283
+ if encoder_attention_mask is None:
1284
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1285
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1286
+ else:
1287
+ encoder_extended_attention_mask = None
1288
+
1289
+ # Prepare head mask if needed
1290
+ # 1.0 in head_mask indicate we keep the head
1291
+ # attention_probs has shape bsz x n_heads x N x N
1292
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1293
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1294
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1295
+
1296
+ embedding_output = self.embeddings(
1297
+ input_ids=input_ids,
1298
+ position_ids=position_ids,
1299
+ token_type_ids=token_type_ids,
1300
+ inputs_embeds=inputs_embeds,
1301
+ past_key_values_length=past_key_values_length,
1302
+ )
1303
+ encoder_outputs = self.encoder(
1304
+ embedding_output,
1305
+ attention_mask=extended_attention_mask,
1306
+ head_mask=head_mask,
1307
+ encoder_hidden_states=encoder_hidden_states,
1308
+ encoder_attention_mask=encoder_extended_attention_mask,
1309
+ past_key_values=past_key_values,
1310
+ use_cache=use_cache,
1311
+ output_attentions=output_attentions,
1312
+ output_hidden_states=output_hidden_states,
1313
+ return_dict=return_dict,
1314
+ )
1315
+ sequence_output = encoder_outputs[0]
1316
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1317
+
1318
+ if not return_dict:
1319
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1320
+
1321
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1322
+ last_hidden_state=sequence_output,
1323
+ pooler_output=pooled_output,
1324
+ past_key_values=encoder_outputs.past_key_values,
1325
+ hidden_states=encoder_outputs.hidden_states,
1326
+ attentions=encoder_outputs.attentions,
1327
+ cross_attentions=encoder_outputs.cross_attentions,
1328
+ )
1329
+
1330
+
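In the text model's forward above, `get_extended_attention_mask` (inherited from the base class) turns the 2D padding mask into an additive mask that broadcasts over heads and query positions. A stripped-down sketch of that transformation, not a verbatim copy of the helper:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float)  # 1 = token, 0 = padding
# 0.0 where attention is allowed, a very large negative number where masked;
# the result is added to the raw attention scores before the softmax.
extended_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(torch.float32).min
print(extended_mask.shape)  # torch.Size([1, 1, 1, 4])
```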
1331
+ @add_start_docstrings(
1332
+ """The vision model from CHINESE_CLIP without any head or projection on top.""",
1333
+ CHINESE_CLIP_START_DOCSTRING,
1334
+ )
1335
+ class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
1336
+ config_class = ChineseCLIPVisionConfig
1337
+ main_input_name = "pixel_values"
1338
+ _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]
1339
+
1340
+ def __init__(self, config: ChineseCLIPVisionConfig):
1341
+ super().__init__(config)
1342
+ self.vision_model = ChineseCLIPVisionTransformer(config)
1343
+ # Initialize weights and apply final processing
1344
+ self.post_init()
1345
+
1346
+ def get_input_embeddings(self) -> nn.Module:
1347
+ return self.vision_model.embeddings.patch_embedding
1348
+
1349
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING)
1350
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=ChineseCLIPVisionConfig)
1351
+ def forward(
1352
+ self,
1353
+ pixel_values: Optional[torch.FloatTensor] = None,
1354
+ output_attentions: Optional[bool] = None,
1355
+ output_hidden_states: Optional[bool] = None,
1356
+ interpolate_pos_encoding: bool = False,
1357
+ return_dict: Optional[bool] = None,
1358
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1359
+ r"""
1360
+ Returns:
1361
+
1362
+ Examples:
1363
+
1364
+ ```python
1365
+ >>> from PIL import Image
1366
+ >>> import requests
1367
+ >>> from transformers import AutoProcessor, ChineseCLIPVisionModel
1368
+
1369
+ >>> model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1370
+ >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1371
+
1372
+ >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
1373
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1374
+
1375
+ >>> inputs = processor(images=image, return_tensors="pt")
1376
+
1377
+ >>> outputs = model(**inputs)
1378
+ >>> last_hidden_state = outputs.last_hidden_state
1379
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1380
+ ```"""
1381
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1382
+
1383
+ return self.vision_model(
1384
+ pixel_values=pixel_values,
1385
+ output_attentions=output_attentions,
1386
+ output_hidden_states=output_hidden_states,
1387
+ interpolate_pos_encoding=interpolate_pos_encoding,
1388
+ return_dict=return_dict,
1389
+ )
1390
+
1391
+
1392
+ @add_start_docstrings(CHINESE_CLIP_START_DOCSTRING)
1393
+ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
1394
+ config_class = ChineseCLIPConfig
1395
+
1396
+ def __init__(self, config: ChineseCLIPConfig):
1397
+ super().__init__(config)
1398
+
1399
+ if not isinstance(config.text_config, ChineseCLIPTextConfig):
1400
+ raise TypeError(
1401
+ "config.text_config is expected to be of type ChineseCLIPTextConfig but is of type"
1402
+ f" {type(config.text_config)}."
1403
+ )
1404
+
1405
+ if not isinstance(config.vision_config, ChineseCLIPVisionConfig):
1406
+ raise TypeError(
1407
+ "config.vision_config is expected to be of type ChineseCLIPVisionConfig but is of type"
1408
+ f" {type(config.vision_config)}."
1409
+ )
1410
+
1411
+ text_config = config.text_config
1412
+ vision_config = config.vision_config
1413
+
1414
+ self.projection_dim = config.projection_dim
1415
+ self.text_embed_dim = text_config.hidden_size
1416
+ self.vision_embed_dim = vision_config.hidden_size
1417
+
1418
+ self.text_model = ChineseCLIPTextModel(text_config, add_pooling_layer=False)
1419
+ self.vision_model = ChineseCLIPVisionTransformer(vision_config)
1420
+
1421
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1422
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
1423
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1424
+
1425
+ # Initialize weights and apply final processing
1426
+ self.post_init()
1427
+
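`logit_scale` above is a learnable log-temperature, as in CLIP: the config default `logit_scale_init_value` is 2.6592 ≈ log(1/0.07), and the forward pass exponentiates it before scaling the cosine similarities. A quick numeric check:

```python
import torch

logit_scale = torch.tensor(2.6592)  # the ChineseCLIPConfig default
print(logit_scale.exp())  # ≈ 14.29, i.e. a softmax temperature of ~0.07
```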
1428
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_TEXT_INPUTS_DOCSTRING)
1429
+ def get_text_features(
1430
+ self,
1431
+ input_ids: Optional[torch.Tensor] = None,
1432
+ attention_mask: Optional[torch.Tensor] = None,
1433
+ token_type_ids: Optional[torch.Tensor] = None,
1434
+ position_ids: Optional[torch.Tensor] = None,
1435
+ output_attentions: Optional[bool] = None,
1436
+ output_hidden_states: Optional[bool] = None,
1437
+ return_dict: Optional[bool] = None,
1438
+ ) -> torch.FloatTensor:
1439
+ r"""
1440
+ Returns:
1441
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1442
+ applying the projection layer to the final [CLS] hidden state of Text-Transformer.
1443
+
1444
+ Examples:
1445
+
1446
+ ```python
1447
+ >>> from transformers import AutoTokenizer, ChineseCLIPModel
1448
+
1449
+ >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1450
+ >>> tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1451
+
1452
+ >>> inputs = tokenizer(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], padding=True, return_tensors="pt")
1453
+ >>> text_features = model.get_text_features(**inputs)
1454
+ >>> text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
1455
+ ```"""
1456
+ # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components.
1457
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1458
+ output_hidden_states = (
1459
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1460
+ )
1461
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1462
+
1463
+ text_outputs = self.text_model(
1464
+ input_ids=input_ids,
1465
+ attention_mask=attention_mask,
1466
+ token_type_ids=token_type_ids,
1467
+ position_ids=position_ids,
1468
+ output_attentions=output_attentions,
1469
+ output_hidden_states=output_hidden_states,
1470
+ return_dict=return_dict,
1471
+ )
1472
+
1473
+ pooled_output = text_outputs[0][:, 0, :]
1474
+ text_features = self.text_projection(pooled_output)
1475
+
1476
+ return text_features
1477
+
1478
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_VISION_INPUTS_DOCSTRING)
1479
+ def get_image_features(
1480
+ self,
1481
+ pixel_values: Optional[torch.FloatTensor] = None,
1482
+ output_attentions: Optional[bool] = None,
1483
+ output_hidden_states: Optional[bool] = None,
1484
+ interpolate_pos_encoding: bool = False,
1485
+ return_dict: Optional[bool] = None,
1486
+ ) -> torch.FloatTensor:
1487
+ r"""
1488
+ Returns:
1489
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1490
+ applying the projection layer to the final [CLS] hidden state of Vision-Transformer.
1491
+
1492
+ Examples:
1493
+
1494
+ ```python
1495
+ >>> from PIL import Image
1496
+ >>> import requests
1497
+ >>> from transformers import AutoProcessor, ChineseCLIPModel
1498
+
1499
+ >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1500
+ >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1501
+
1502
+ >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
1503
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1504
+
1505
+ >>> inputs = processor(images=image, return_tensors="pt")
1506
+
1507
+ >>> image_features = model.get_image_features(**inputs)
1508
+ >>> image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
1509
+ ```"""
1510
+ # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components.
1511
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1512
+ output_hidden_states = (
1513
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1514
+ )
1515
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1516
+
1517
+ vision_outputs = self.vision_model(
1518
+ pixel_values=pixel_values,
1519
+ output_attentions=output_attentions,
1520
+ output_hidden_states=output_hidden_states,
1521
+ interpolate_pos_encoding=interpolate_pos_encoding,
1522
+ return_dict=return_dict,
1523
+ )
1524
+
1525
+ pooled_output = vision_outputs[1] # pooled_output
1526
+ image_features = self.visual_projection(pooled_output)
1527
+
1528
+ return image_features
1529
+
1530
+ @add_start_docstrings_to_model_forward(CHINESE_CLIP_INPUTS_DOCSTRING)
1531
+ @replace_return_docstrings(output_type=ChineseCLIPOutput, config_class=ChineseCLIPConfig)
1532
+ def forward(
1533
+ self,
1534
+ input_ids: Optional[torch.LongTensor] = None,
1535
+ pixel_values: Optional[torch.FloatTensor] = None,
1536
+ attention_mask: Optional[torch.Tensor] = None,
1537
+ token_type_ids: Optional[torch.Tensor] = None,
1538
+ position_ids: Optional[torch.LongTensor] = None,
1539
+ return_loss: Optional[bool] = None,
1540
+ output_attentions: Optional[bool] = None,
1541
+ output_hidden_states: Optional[bool] = None,
1542
+ interpolate_pos_encoding: bool = False,
1543
+ return_dict: Optional[bool] = None,
1544
+ ) -> Union[Tuple, ChineseCLIPOutput]:
1545
+ r"""
1546
+ Returns:
1547
+
1548
+ Examples:
1549
+
1550
+ ```python
1551
+ >>> from PIL import Image
1552
+ >>> import requests
1553
+ >>> from transformers import AutoProcessor, ChineseCLIPModel
1554
+
1555
+ >>> model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1556
+ >>> processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
1557
+
1558
+ >>> url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
1559
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1560
+
1561
+ >>> inputs = processor(text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, return_tensors="pt", padding=True)
1562
+
1563
+ >>> outputs = model(**inputs)
1564
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1565
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1566
+ ```"""
1567
+ # Use CHINESE_CLIP model's config for some fields (if specified) instead of those of vision & text components.
1568
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1569
+ output_hidden_states = (
1570
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1571
+ )
1572
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1573
+
1574
+ vision_outputs = self.vision_model(
1575
+ pixel_values=pixel_values,
1576
+ output_attentions=output_attentions,
1577
+ output_hidden_states=output_hidden_states,
1578
+ interpolate_pos_encoding=interpolate_pos_encoding,
1579
+ return_dict=return_dict,
1580
+ )
1581
+
1582
+ text_outputs = self.text_model(
1583
+ input_ids=input_ids,
1584
+ attention_mask=attention_mask,
1585
+ token_type_ids=token_type_ids,
1586
+ position_ids=position_ids,
1587
+ output_attentions=output_attentions,
1588
+ output_hidden_states=output_hidden_states,
1589
+ return_dict=return_dict,
1590
+ )
1591
+
1592
+ image_embeds = vision_outputs[1]
1593
+ image_embeds = self.visual_projection(image_embeds)
1594
+
1595
+ text_embeds = text_outputs[0][:, 0, :]
1596
+ text_embeds = self.text_projection(text_embeds)
1597
+
1598
+ # normalized features
1599
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1600
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1601
+
1602
+ # cosine similarity as logits
1603
+ logit_scale = self.logit_scale.exp()
1604
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1605
+ logits_per_image = logits_per_text.t()
1606
+
1607
+ loss = None
1608
+ if return_loss:
1609
+ loss = chinese_clip_loss(logits_per_text)
1610
+
1611
+ if not return_dict:
1612
+ # fix the None pooled_output of text_outputs to conform with dict_output
1613
+ pooled_output = text_outputs[1]
1614
+ if pooled_output is None:
1615
+ text_outputs = (text_outputs[0],) + text_outputs[2:]
1616
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1617
+ return ((loss,) + output) if loss is not None else output
1618
+
1619
+ return ChineseCLIPOutput(
1620
+ loss=loss,
1621
+ logits_per_image=logits_per_image,
1622
+ logits_per_text=logits_per_text,
1623
+ text_embeds=text_embeds,
1624
+ image_embeds=image_embeds,
1625
+ text_model_output=text_outputs,
1626
+ vision_model_output=vision_outputs,
1627
+ )
1628
+
1629
+
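`chinese_clip_loss`, called in the forward above, is defined earlier in this file and implements the standard symmetric CLIP objective: cross-entropy over the text-to-image similarity matrix plus the same over its transpose, averaged. A reference sketch of that objective (not the file's verbatim code):

```python
import torch
import torch.nn.functional as F


def clip_style_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # Matching image/text pairs sit on the diagonal, so the targets are 0..N-1.
    labels = torch.arange(logits_per_text.size(0), device=logits_per_text.device)
    text_loss = F.cross_entropy(logits_per_text, labels)
    image_loss = F.cross_entropy(logits_per_text.t(), labels)
    return (text_loss + image_loss) / 2.0
```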
1630
+ __all__ = ["ChineseCLIPModel", "ChineseCLIPPreTrainedModel", "ChineseCLIPTextModel", "ChineseCLIPVisionModel"]
docs/transformers/src/transformers/models/chinese_clip/processing_chinese_clip.py ADDED
@@ -0,0 +1,163 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Image/Text processor class for Chinese-CLIP
17
+ """
18
+
19
+ import warnings
20
+ from typing import List, Union
21
+
22
+ from ...image_utils import ImageInput
23
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
24
+ from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
25
+
26
+
27
+ class ChineseClipProcessorKwargs(ProcessingKwargs, total=False):
28
+ _defaults = {}
29
+
30
+
31
+ class ChineseCLIPProcessor(ProcessorMixin):
32
+ r"""
33
+ Constructs a Chinese-CLIP processor which wraps a Chinese-CLIP image processor and a Chinese-CLIP tokenizer into a
34
+ single processor.
35
+
36
+ [`ChineseCLIPProcessor`] offers all the functionalities of [`ChineseCLIPImageProcessor`] and [`BertTokenizerFast`].
37
+ See the [`~ChineseCLIPProcessor.__call__`] and [`~ChineseCLIPProcessor.decode`] for more information.
38
+
39
+ Args:
40
+ image_processor ([`ChineseCLIPImageProcessor`], *optional*):
41
+ The image processor is a required input.
42
+ tokenizer ([`BertTokenizerFast`], *optional*):
43
+ The tokenizer is a required input.
44
+ """
45
+
46
+ attributes = ["image_processor", "tokenizer"]
47
+ image_processor_class = ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")
48
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
49
+
50
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
51
+ feature_extractor = None
52
+ if "feature_extractor" in kwargs:
53
+ warnings.warn(
54
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
55
+ " instead.",
56
+ FutureWarning,
57
+ )
58
+ feature_extractor = kwargs.pop("feature_extractor")
59
+
60
+ image_processor = image_processor if image_processor is not None else feature_extractor
61
+ if image_processor is None:
62
+ raise ValueError("You need to specify an `image_processor`.")
63
+ if tokenizer is None:
64
+ raise ValueError("You need to specify a `tokenizer`.")
65
+
66
+ super().__init__(image_processor, tokenizer)
67
+ self.current_processor = self.image_processor
68
+
69
+ def __call__(
70
+ self,
71
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
72
+ images: ImageInput = None,
73
+ audio=None,
74
+ videos=None,
75
+ **kwargs: Unpack[ChineseClipProcessorKwargs],
76
+ ) -> BatchEncoding:
77
+ """
78
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
79
+ and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
80
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
81
+ ChineseCLIPImageProcessor's [`~ChineseCLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
82
+ of the above two methods for more information.
83
+
84
+ Args:
85
+ text (`str`, `List[str]`, `List[List[str]]`):
86
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
87
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
88
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
89
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
90
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
91
+ tensor. Both channels-first and channels-last formats are supported.
92
+
93
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
94
+ If set, will return tensors of a particular framework. Acceptable values are:
95
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
96
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
97
+ - `'np'`: Return NumPy `np.ndarray` objects.
98
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
99
+ Returns:
100
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
101
+
102
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
103
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
104
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
105
+ `None`).
106
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
107
+ """
108
+
109
+ if text is None and images is None:
110
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
111
+ output_kwargs = self._merge_kwargs(
112
+ ChineseClipProcessorKwargs,
113
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
114
+ **kwargs,
115
+ )
116
+
117
+ if text is not None:
118
+ encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
119
+ if images is not None:
120
+ image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
121
+
122
+ # BC for explicit return_tensors
123
+ if "return_tensors" in output_kwargs["common_kwargs"]:
124
+ return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
125
+
126
+ if text is not None and images is not None:
127
+ encoding["pixel_values"] = image_features.pixel_values
128
+ return encoding
129
+ elif text is not None:
130
+ return encoding
131
+ else:
132
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
133
+
134
+ def batch_decode(self, *args, **kwargs):
135
+ """
136
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
137
+ refer to the docstring of this method for more information.
138
+ """
139
+ return self.tokenizer.batch_decode(*args, **kwargs)
140
+
141
+ def decode(self, *args, **kwargs):
142
+ """
143
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
144
+ the docstring of this method for more information.
145
+ """
146
+ return self.tokenizer.decode(*args, **kwargs)
147
+
148
+ @property
149
+ def model_input_names(self):
150
+ tokenizer_input_names = self.tokenizer.model_input_names
151
+ image_processor_input_names = self.image_processor.model_input_names
152
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
153
+
154
+ @property
155
+ def feature_extractor_class(self):
156
+ warnings.warn(
157
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
158
+ FutureWarning,
159
+ )
160
+ return self.image_processor_class
161
+
162
+
163
+ __all__ = ["ChineseCLIPProcessor"]
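To round out the processor, a brief usage sketch; the checkpoint and image URL are the ones used in the modeling docstrings above, and the exact key set depends on the tokenizer:

```python
import requests
from PIL import Image

from transformers import ChineseCLIPProcessor

processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

# Text-only, image-only, and joint calls are all supported; a joint call
# merges pixel_values into the text encoding, as implemented above.
batch = processor(text=["皮卡丘"], images=image, padding=True, return_tensors="pt")
print(sorted(batch.keys()))
```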
docs/transformers/src/transformers/models/clap/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_clap import *
22
+ from .feature_extraction_clap import *
23
+ from .modeling_clap import *
24
+ from .processing_clap import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
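The `_LazyModule` pattern above keeps the package import cheap: submodules are only loaded when one of their attributes is first accessed. A small illustration (assumes `transformers` and its torch dependencies are installed):

```python
import transformers.models.clap as clap  # cheap: nothing heavy imported yet

config_cls = clap.ClapTextConfig  # attribute access triggers the real import
print(config_cls.model_type)  # "clap_text_model"
```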
docs/transformers/src/transformers/models/clap/configuration_clap.py ADDED
@@ -0,0 +1,394 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """CLAP model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class ClapTextConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`ClapTextModel`]. It is used to instantiate a CLAP
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the CLAP
29
+ [calp-hsat-fused](https://huggingface.co/laion/clap-hsat-fused) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 30522):
37
+ Vocabulary size of the CLAP model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`ClapTextModel`].
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimensionality of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"relu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"relu"`,
49
+ `"relu"`, `"silu"` and `"relu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
53
+ The dropout ratio for the attention probabilities.
54
+ max_position_embeddings (`int`, *optional*, defaults to 512):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 2048).
57
+ type_vocab_size (`int`, *optional*, defaults to 2):
58
+ The vocabulary size of the `token_type_ids` passed when calling [`ClapTextModel`].
59
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
60
+ The epsilon used by the layer normalization layers.
61
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
62
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
63
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
64
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
65
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
66
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
67
+ is_decoder (`bool`, *optional*, defaults to `False`):
68
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
69
+ use_cache (`bool`, *optional*, defaults to `True`):
70
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
71
+ relevant if `config.is_decoder=True`.
72
+ projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
73
+ The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`,
74
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
75
+ projection_dim (`int`, *optional*, defaults to 512)
76
+ Dimension of the projection head of the `ClapTextModelWithProjection`.
77
+
78
+ Examples:
79
+
80
+ ```python
81
+ >>> from transformers import ClapTextConfig, ClapTextModel
82
+
83
+ >>> # Initializing a CLAP text configuration
84
+ >>> configuration = ClapTextConfig()
85
+
86
+ >>> # Initializing a model (with random weights) from the configuration
87
+ >>> model = ClapTextModel(configuration)
88
+
89
+ >>> # Accessing the model configuration
90
+ >>> configuration = model.config
91
+ ```"""
92
+
93
+ model_type = "clap_text_model"
94
+ base_config_key = "text_config"
95
+
96
+ def __init__(
97
+ self,
98
+ vocab_size=50265,
99
+ hidden_size=768,
100
+ num_hidden_layers=12,
101
+ num_attention_heads=12,
102
+ intermediate_size=3072,
103
+ hidden_act="gelu",
104
+ hidden_dropout_prob=0.1,
105
+ attention_probs_dropout_prob=0.1,
106
+ max_position_embeddings=514,
107
+ type_vocab_size=1,
108
+ initializer_factor=1.0,
109
+ layer_norm_eps=1e-12,
110
+ projection_dim=512,
111
+ pad_token_id=1,
112
+ bos_token_id=0,
113
+ eos_token_id=2,
114
+ position_embedding_type="absolute",
115
+ use_cache=True,
116
+ projection_hidden_act="relu",
117
+ **kwargs,
118
+ ):
119
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
120
+
121
+ self.vocab_size = vocab_size
122
+ self.hidden_size = hidden_size
123
+ self.num_hidden_layers = num_hidden_layers
124
+ self.num_attention_heads = num_attention_heads
125
+ self.hidden_act = hidden_act
126
+ self.intermediate_size = intermediate_size
127
+ self.hidden_dropout_prob = hidden_dropout_prob
128
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
129
+ self.max_position_embeddings = max_position_embeddings
130
+ self.type_vocab_size = type_vocab_size
131
+ self.initializer_factor = initializer_factor
132
+ self.layer_norm_eps = layer_norm_eps
133
+ self.position_embedding_type = position_embedding_type
134
+ self.use_cache = use_cache
135
+ self.projection_hidden_act = projection_hidden_act
136
+ self.projection_dim = projection_dim
137
+
138
+
139
+ class ClapAudioConfig(PretrainedConfig):
140
+ r"""
141
+ This is the configuration class to store the configuration of a [`ClapAudioModel`]. It is used to instantiate a
142
+ CLAP audio encoder according to the specified arguments, defining the model architecture. Instantiating a
143
+ configuration with the defaults will yield a similar configuration to that of the audio encoder of the CLAP
144
+ [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture.
145
+
146
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
147
+ documentation from [`PretrainedConfig`] for more information.
148
+
149
+ Args:
150
+ window_size (`int`, *optional*, defaults to 8):
151
+ Image size of the spectrogram
152
+ num_mel_bins (`int`, *optional*, defaults to 64):
153
+ Number of mel features used per frames. Should correspond to the value used in the `ClapProcessor` class.
154
+ spec_size (`int`, *optional*, defaults to 256):
155
+ Desired input size of the spectrogram that the model supports. It can be different from the output of the
156
+ `ClapFeatureExtractor`, in which case the input features will be resized. Corresponds to the `image_size`
157
+ of the audio models.
158
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
159
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
160
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
161
+ patch_size (`int`, *optional*, defaults to 4):
162
+ Patch size for the audio spectrogram
163
+ patch_stride (`list`, *optional*, defaults to `[4, 4]`):
164
+ Patch stride for the audio spectrogram
165
+ num_classes (`int`, *optional*, defaults to 527):
166
+ Number of classes used for the head training
167
+ hidden_size (`int`, *optional*, defaults to 768):
168
+ Hidden size of the output of the audio encoder. Correspond to the dimension of the penultimate layer's
169
+ output,which is sent to the projection MLP layer.
170
+ projection_dim (`int`, *optional*, defaults to 512):
171
+ Hidden size of the projection layer.
172
+ depths (`list`, *optional*, defaults to `[2, 2, 6, 2]`):
173
+ Depths used for the Swin Layers of the audio model
174
+ num_attention_heads (`list`, *optional*, defaults to `[4, 8, 16, 32]`):
175
+ Number of attention heads used for the Swin Layers of the audio model
176
+ enable_fusion (`bool`, *optional*, defaults to `False`):
177
+ Whether or not to enable patch fusion. This is the main contribution of the authors, and should give the
178
+ best results.
179
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
180
+ The dropout probability for all fully connected layers in the encoder.
181
+ fusion_type (`[type]`, *optional*):
182
+ Fusion type used for the patch fusion.
183
+ patch_embed_input_channels (`int`, *optional*, defaults to 1):
184
+ Number of channels used for the input spectrogram
185
+ flatten_patch_embeds (`bool`, *optional*, defaults to `True`):
186
+ Whether or not to flatten the patch embeddings
187
+ patch_embeds_hidden_size (`int`, *optional*, defaults to 96):
188
+ Hidden size of the patch embeddings. It is used as the number of output channels.
189
+ enable_patch_layer_norm (`bool`, *optional*, defaults to `True`):
190
+ Whether or not to enable layer normalization for the patch embeddings
191
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
192
+ Drop path rate for the patch fusion
193
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
194
+ The dropout ratio for the attention probabilities.
195
+ qkv_bias (`bool`, *optional*, defaults to `True`):
196
+ Whether or not to add a bias to the query, key, value projections.
197
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
198
+ Ratio of the mlp hidden dim to embedding dim.
199
+ aff_block_r (`int`, *optional*, defaults to 4):
200
+ downsize_ratio used in the AudioFF block
201
+ num_hidden_layers (`int`, *optional*, defaults to 4):
202
+ Number of hidden layers in the Transformer encoder.
203
+ projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
204
+ The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`,
205
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
206
+ layer_norm_eps (`[type]`, *optional*, defaults to 1e-05):
207
+ The epsilon used by the layer normalization layers.
208
+ initializer_factor (`float`, *optional*, defaults to 1.0):
209
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
210
+ testing).
211
+
212
+ Example:
213
+
214
+ ```python
215
+ >>> from transformers import ClapAudioConfig, ClapAudioModel
216
+
217
+ >>> # Initializing a ClapAudioConfig with laion/clap-htsat-fused style configuration
218
+ >>> configuration = ClapAudioConfig()
219
+
220
+ >>> # Initializing a ClapAudioModel (with random weights) from the laion/clap-htsat-fused style configuration
221
+ >>> model = ClapAudioModel(configuration)
222
+
223
+ >>> # Accessing the model configuration
224
+ >>> configuration = model.config
225
+ ```"""
226
+
+    model_type = "clap_audio_model"
+    base_config_key = "audio_config"
+
+    def __init__(
+        self,
+        window_size=8,
+        num_mel_bins=64,
+        spec_size=256,
+        hidden_act="gelu",
+        patch_size=4,
+        patch_stride=[4, 4],
+        num_classes=527,
+        hidden_size=768,
+        projection_dim=512,
+        depths=[2, 2, 6, 2],
+        num_attention_heads=[4, 8, 16, 32],
+        enable_fusion=False,
+        hidden_dropout_prob=0.1,
+        fusion_type=None,
+        patch_embed_input_channels=1,
+        flatten_patch_embeds=True,
+        patch_embeds_hidden_size=96,
+        enable_patch_layer_norm=True,
+        drop_path_rate=0.0,
+        attention_probs_dropout_prob=0.0,
+        qkv_bias=True,
+        mlp_ratio=4.0,
+        aff_block_r=4,
+        num_hidden_layers=4,
+        projection_hidden_act="relu",
+        layer_norm_eps=1e-5,
+        initializer_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.window_size = window_size
+        self.num_mel_bins = num_mel_bins
+        self.spec_size = spec_size
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.num_classes = num_classes
+        self.hidden_size = hidden_size
+        self.depths = depths
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.projection_dim = projection_dim
+        self.flatten_patch_embeds = flatten_patch_embeds
+        self.patch_embeds_hidden_size = patch_embeds_hidden_size
+        self.enable_patch_layer_norm = enable_patch_layer_norm
+        self.drop_path_rate = drop_path_rate
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.qkv_bias = qkv_bias
+        self.mlp_ratio = mlp_ratio
+        self.patch_embed_input_channels = patch_embed_input_channels
+        self.aff_block_r = aff_block_r
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_factor = initializer_factor
+        self.projection_hidden_act = projection_hidden_act
+
+
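Since the audio encoder is a Swin-style stack, the `depths` and `num_attention_heads` lists are expected to line up, one entry per stage. A small sketch of a non-default audio config; `enable_fusion=True` mirrors the fused checkpoint, but the exact `fusion_type` value is checkpoint-specific, so it is left at its `None` default here:

```python
from transformers import ClapAudioConfig

# Four Swin stages: depths and num_attention_heads have one entry per stage.
audio_config = ClapAudioConfig(
    spec_size=256,                       # target spectrogram size; inputs are resized to this
    depths=[2, 2, 6, 2],                 # blocks per Swin stage
    num_attention_heads=[4, 8, 16, 32],  # heads per Swin stage
    enable_fusion=True,                  # patch fusion, as recommended in the docstring
)

print(len(audio_config.depths) == len(audio_config.num_attention_heads))  # True
```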
+class ClapConfig(PretrainedConfig):
+    r"""
+    [`ClapConfig`] is the configuration class to store the configuration of a [`ClapModel`]. It is used to
+    instantiate a CLAP model according to the specified arguments, defining the text model and audio model configs.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the CLAP
+    [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`ClapTextConfig`].
+        audio_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`ClapAudioConfig`].
+        logit_scale_init_value (`float`, *optional*, defaults to 14.29):
+            The initial value of the *logit_scale* parameter (1 / 0.07, as in the original CLAP implementation).
+        projection_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of the text and audio projection layers.
+        projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
+            Activation function for the projection layers.
+        initializer_factor (`float`, *optional*, defaults to 1.0):
+            Factor to scale the initialization of the model weights.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+
+    Example:
+
+    ```python
+    >>> from transformers import ClapConfig, ClapModel
+
+    >>> # Initializing a ClapConfig with laion/clap-htsat-fused style configuration
+    >>> configuration = ClapConfig()
+
+    >>> # Initializing a ClapModel (with random weights) from the laion/clap-htsat-fused style configuration
+    >>> model = ClapModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+
+    >>> # We can also initialize a ClapConfig from a ClapTextConfig and a ClapAudioConfig
+    >>> from transformers import ClapTextConfig, ClapAudioConfig
+
+    >>> # Initializing ClapTextConfig and ClapAudioConfig configurations
+    >>> config_text = ClapTextConfig()
+    >>> config_audio = ClapAudioConfig()
+
+    >>> config = ClapConfig.from_text_audio_configs(config_text, config_audio)
+    ```"""
+
+    model_type = "clap"
+    sub_configs = {"text_config": ClapTextConfig, "audio_config": ClapAudioConfig}
+
+    def __init__(
+        self,
+        text_config=None,
+        audio_config=None,
+        logit_scale_init_value=(1 / 0.07),
+        projection_dim=512,
+        projection_hidden_act="relu",
+        initializer_factor=1.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if text_config is None:
+            text_config = {}
+            logger.info("text_config is None. Initializing the ClapTextConfig with default values.")
+
+        if audio_config is None:
+            audio_config = {}
+            logger.info("audio_config is None. Initializing the ClapAudioConfig with default values.")
+
+        self.text_config = ClapTextConfig(**text_config)
+        self.audio_config = ClapAudioConfig(**audio_config)
+        self.text_config.projection_dim = projection_dim
+        self.audio_config.projection_dim = projection_dim
+
+        self.text_config.projection_hidden_act = projection_hidden_act
+        self.audio_config.projection_hidden_act = projection_hidden_act
+
+        self.projection_dim = projection_dim
+        self.projection_hidden_act = projection_hidden_act
+        self.hidden_size = self.text_config.hidden_size
+
+        self.logit_scale_init_value = logit_scale_init_value
+        self.initializer_factor = initializer_factor
+        self.num_hidden_layers = self.text_config.num_hidden_layers + len(self.audio_config.depths)
+
+    @classmethod
+    def from_text_audio_configs(cls, text_config: ClapTextConfig, audio_config: ClapAudioConfig, **kwargs):
+        r"""
+        Instantiate a [`ClapConfig`] (or a derived class) from a CLAP text model configuration and a CLAP audio
+        model configuration.
+
+        Returns:
+            [`ClapConfig`]: An instance of a configuration object
+        """
+
+        return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)
+
+
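Note that `ClapConfig.__init__` pushes `projection_dim` and `projection_hidden_act` down into both sub-configs after they are built, so a top-level override wins over whatever the sub-config dictionaries say. A minimal sketch of that behavior; the values are illustrative:

```python
from transformers import ClapConfig

# Top-level projection settings are copied into both sub-configs.
config = ClapConfig(projection_dim=256, projection_hidden_act="gelu")

assert config.text_config.projection_dim == 256
assert config.audio_config.projection_dim == 256
assert config.text_config.projection_hidden_act == "gelu"

# The composite depth is the text layer count plus one per audio Swin stage.
assert config.num_hidden_layers == config.text_config.num_hidden_layers + len(config.audio_config.depths)
```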
+__all__ = ["ClapAudioConfig", "ClapConfig", "ClapTextConfig"]
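As with any `PretrainedConfig`, all three classes round-trip through `to_dict`/`from_dict` and `save_pretrained`/`from_pretrained`. A short sketch; the directory name is arbitrary:

```python
from transformers import ClapConfig

config = ClapConfig(projection_dim=256)

# Dict round-trip keeps the nested text/audio sub-configs.
as_dict = config.to_dict()
restored = ClapConfig.from_dict(as_dict)
assert restored.projection_dim == 256

# Disk round-trip writes and reads config.json under the given directory.
config.save_pretrained("clap-config-demo")
reloaded = ClapConfig.from_pretrained("clap-config-demo")
assert reloaded.audio_config.projection_dim == 256
```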