27M LinWeizheDragon committed on
Commit
2242717
·
0 Parent(s):

Duplicate from LinWeizheDragon/PreFLMR_ViT-G

Co-authored-by: Weizhe Lin <LinWeizheDragon@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,81 @@
1
+ ---
2
+ library_name: transformers
3
+ license: mit
4
+ language:
5
+ - en
6
+ tags:
7
+ - retrieval
8
+ - multi-modal
9
+ - knowledge-based visual question answering
10
+ - FLMR
11
+ - PreFLMR
12
+ ---
13
+
14
+ # PreFLMR model card
15
+
16
+ PreFLMR is an open-source model for multimodal knowledge retrieval. It is a transformer-based model that uses a combination of text and image inputs to retrieve relevant documents from a large corpus.
17
+
18
+ ## Model Details
19
+
20
+ ### Model Description
21
+
22
+ - **Model type:** FLMRModelForRetrieval
23
+ - **Language(s) (NLP):** English
24
+ - **License:** MIT License
25
+
26
+ ### Paper and resources for more detail
27
+
28
+ - **Blog Post for quick overview:** https://www.jinghong-chen.net/preflmr-sota-open-sourced-multi/
29
+ - **Paper:** https://arxiv.org/abs/2402.08327
30
+ - **Gradio Demo:** https://u60544-b8d4-53eaa55d.westx.seetacloud.com:8443/
31
+ - **Repository:** https://github.com/LinWeizheDragon/FLMR
32
+ - **Project Page:** https://preflmr.github.io/
33
+
34
+ ## Uses
35
+
36
+ ### Direct Use
37
+
38
+ This model can be used directly to retrieve documents from a large corpus using a combination of text and image input queries. The retrieval usage can be found in the [official implementation](https://github.com/LinWeizheDragon/FLMR).
39
+
40
+ ### Downstream Use
41
+
42
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
43
+
44
+ This model can be combined with a language model to build a retrieval-augmented language model. Its use for knowledge-based VQA is shown in [RAVQA](https://github.com/linweizhedragon/retrieval-augmented-visual-question-answering).
45
+
46
+ ## How to Get Started with the Model
47
+
48
+ For details on training, indexing, and performing retrieval, please refer to the [FLMR repository](https://github.com/LinWeizheDragon/FLMR).
49
+
50
+ ## Training datasets
51
+ The model is pre-trained on three types of tasks with a total of nine datasets:
52
+ 1. Image to Text retrieval: WIT, KVQA, and CC3M
53
+ 2. Question to Text retrieval: MSMARCO
54
+ 3. Image & Question to Text retrieval: LLaVA, OVEN, OKVQA, Infoseek and E-VQA
55
+
56
+ These datasets were converted to retrieval format. For details on the dataset split and conversion process, please refer to the paper [PreFLMR: Scaling Up Fine-Grained Late-Interaction Multi-modal Retrievers](https://arxiv.org/abs/2402.08327). We will release the preprocessed datasets soon.
57
+
58
+
59
+ ## Evaluation datasets
60
+ We evaluate our models on WIT, LLaVA, OVEN, KVQA, IGLUE (subset of WIT), Infoseek, E-VQA, OKVQA and MSMARCO.
61
+ | Model | Vision Encoder | Text Encoder | Checkpoint Name | No. Param. | WIT | LLaVA | OVEN | KVQA | IGLUE | Infoseek | E-VQA | OKVQA | MSMARCO |
62
+ |---------|----------------|--------------|-------------------------------------------------------------|-------|-------|--------|-------|-------|-------|----------|-------|--------|-------|
63
+ | PreFLMR | ViT-B | Base-v2 | [LinWeizheDragon/PreFLMR_ViT-B](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-B) | 327M | 41.7 | 67.2 | 46.3 | 28.6 | 57.3 | 48.8 | 67.9 | 66.1 | 79.5 |
64
+ | PreFLMR | ViT-L | Base-v2 | [LinWeizheDragon/PreFLMR_ViT-L](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L) | 543M | 60.5 | 71.8 | 59.8 | 43.6 | 69.2 | 57.9 | 70.8 | 68.5 | 78.7 |
65
+ | PreFLMR | ViT-G | Base-v2 | [LinWeizheDragon/PreFLMR_ViT-G](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-G) | 2.1B | 61.5 | 72.4 | 63.4 | 42.1 | 71.5 | 59.6 | 73.1 | 68.6 | 78.6 |
66
+
67
+ For the evaluation metrics, WIT uses Recall@10, IGLUE uses Recall@1, and all other datasets use Recall@5.
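The Recall@K figures in the table can be read with a small sketch like the one below. It assumes the "hit" style definition, where a query scores 1 if at least one relevant document appears in its top K results; the paper is the authority on the exact definition. The function names `recall_at_k` and `mean_recall_at_k` are hypothetical and not part of the FLMR codebase.

```python
from typing import List, Sequence, Set


def recall_at_k(ranked_doc_ids: Sequence[str], relevant_ids: Set[str], k: int) -> float:
    """Return 1.0 if any relevant document is among the top-k results, else 0.0.

    Assumption: hit-style Recall@K (at least one relevant document retrieved).
    """
    return 1.0 if any(doc_id in relevant_ids for doc_id in ranked_doc_ids[:k]) else 0.0


def mean_recall_at_k(
    all_rankings: List[Sequence[str]], all_relevant: List[Set[str]], k: int
) -> float:
    """Average the per-query score over all queries and report it as a percentage."""
    scores = [recall_at_k(r, rel, k) for r, rel in zip(all_rankings, all_relevant)]
    return 100.0 * sum(scores) / len(scores)


# Toy data: two queries, each with one relevant document.
rankings = [["d3", "d1", "d7"], ["d2", "d9", "d4"]]
relevant = [{"d1"}, {"d4"}]

print(mean_recall_at_k(rankings, relevant, k=1))  # 0.0
print(mean_recall_at_k(rankings, relevant, k=2))  # 50.0
print(mean_recall_at_k(rankings, relevant, k=3))  # 100.0
```

A larger K can only keep or raise the score, which is why the Recall@10 (WIT), Recall@1 (IGLUE), and Recall@5 columns are not directly comparable with each other.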
68
+
69
+
70
+ ## Citation
71
+
72
+ **BibTeX:**
73
+ ```
74
+ @article{Lin_Mei_Chen_Byrne_2024,
75
+ title={PreFLMR: Scaling Up Fine-Grained Late-Interaction Multi-modal Retrievers},
76
+ url={http://arxiv.org/abs/2402.08327},
77
+ number={arXiv:2402.08327},
78
+ publisher={arXiv},
79
+ author={Lin, Weizhe and Mei, Jingbiao and Chen, Jinghong and Byrne, Bill},
80
+ year={2024}}
81
+ ```
config.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "_name_or_path": "./PreFLMR_ViT-G",
3
+ "architectures": [
4
+ "FLMRModelForRetrieval"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_flmr.FLMRConfig",
8
+ "AutoModel": "modeling_flmr.FLMRModelForRetrieval"
9
+ },
10
+ "context_concat_output_from_text_encoder": true,
11
+ "context_concat_output_from_vision_encoder": false,
12
+ "dim": 128,
13
+ "initializer_range": 0.02,
14
+ "load_cpu_extension": false,
15
+ "mapping_network_prefix_length": 32,
16
+ "mask_instruction_token": ":",
17
+ "mask_punctuation": true,
18
+ "model_type": "flmr",
19
+ "query_concat_output_from_text_encoder": true,
20
+ "query_concat_output_from_vision_encoder": true,
21
+ "separate_query_and_context_text_encoder": true,
22
+ "separate_query_and_context_vision_encoder": false,
23
+ "text_config": {
24
+ "architectures": [
25
+ "BertForMaskedLM"
26
+ ],
27
+ "gradient_checkpointing": false,
28
+ "model_type": "flmr_text_model",
29
+ "use_cache": true
30
+ },
31
+ "torch_dtype": "float32",
32
+ "transformer_mapping_config_base": "bert-base-uncased",
33
+ "transformer_mapping_cross_attention_length": 32,
34
+ "transformer_mapping_num_hidden_layers": 1,
35
+ "transformers_version": "4.37.2",
36
+ "use_transformer_mapping_network": true,
37
+ "use_vision_encoder": true,
38
+ "vision_config": {
39
+ "dropout": 0.0,
40
+ "hidden_act": "gelu",
41
+ "hidden_size": 1664,
42
+ "intermediate_size": 8192,
43
+ "model_type": "flmr_vision_model",
44
+ "num_attention_heads": 16,
45
+ "num_hidden_layers": 48,
46
+ "patch_size": 14,
47
+ "projection_dim": 1280
48
+ },
49
+ "vision_model_version": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
50
+ }
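As a quick sanity check on the `vision_config` values above, the following sketch derives the per-head width and the number of image patches. The image size is not stored in this `config.json`; the value 224 used here is an assumption taken from the `image_size` default of `FLMRVisionConfig` in `configuration_flmr.py`.

```python
import json

# Fragment copied from the "vision_config" block of config.json.
vision_config = json.loads(
    """
    {
      "hidden_size": 1664,
      "intermediate_size": 8192,
      "num_attention_heads": 16,
      "num_hidden_layers": 48,
      "patch_size": 14,
      "projection_dim": 1280
    }
    """
)

# Assumption: FLMRVisionConfig's default image_size, not a value from config.json.
ASSUMED_IMAGE_SIZE = 224

# Width of each attention head in the vision encoder.
head_dim = vision_config["hidden_size"] // vision_config["num_attention_heads"]

# Number of non-overlapping patches the image is cut into.
patches_per_side = ASSUMED_IMAGE_SIZE // vision_config["patch_size"]
num_patches = patches_per_side ** 2

print(head_dim)     # 104
print(num_patches)  # 256
```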
configuration_flmr.py ADDED
@@ -0,0 +1,385 @@
1
+ # coding=utf-8
2
+ # Copyright 2010, FLMR authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ FLMR model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ FLMR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "LinWeizheDragon/PreFLMR_ViT-L": "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/config.json",
28
+ "LinWeizheDragon/FLMR": "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ # Modified from transformers.models.clip.configuration_clip.CLIPVisionConfig with CLIP -> FLMR
33
+ class FLMRVisionConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`FLMRVisionModel`]. It is used to instantiate a
36
+ FLMR vision encoder according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the CLIP
38
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+ Args:
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ intermediate_size (`int`, *optional*, defaults to 3072):
47
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
48
+ projection_dim (`int`, *optional*, defaults to 512):
49
+ Dimensionality of text and vision projection layers.
50
+ num_hidden_layers (`int`, *optional*, defaults to 12):
51
+ Number of hidden layers in the Transformer encoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 12):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ num_channels (`int`, *optional*, defaults to 3):
55
+ The number of input channels.
56
+ image_size (`int`, *optional*, defaults to 224):
57
+ The size (resolution) of each image.
58
+ patch_size (`int`, *optional*, defaults to 32):
59
+ The size (resolution) of each patch.
60
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
61
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
62
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
63
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
64
+ The epsilon used by the layer normalization layers.
65
+ attention_dropout (`float`, *optional*, defaults to 0.0):
66
+ The dropout ratio for the attention probabilities.
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ initializer_factor (`float`, *optional*, defaults to 1.0):
70
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
71
+ testing).
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import FLMRVisionConfig, FLMRVisionModel
77
+
78
+ >>> # Initializing a FLMRVisionConfig with LinWeizheDragon/FLMR style configuration
79
+ >>> configuration = FLMRVisionConfig()
80
+
81
+ >>> # Initializing a FLMRVisionModel (with random weights) from the LinWeizheDragon/FLMR style configuration
82
+ >>> model = FLMRVisionModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "flmr_vision_model"
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_size=768,
93
+ intermediate_size=3072,
94
+ projection_dim=512,
95
+ num_hidden_layers=12,
96
+ num_attention_heads=12,
97
+ num_channels=3,
98
+ image_size=224,
99
+ patch_size=32,
100
+ hidden_act="quick_gelu",
101
+ layer_norm_eps=1e-5,
102
+ attention_dropout=0.0,
103
+ initializer_range=0.02,
104
+ initializer_factor=1.0,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+
109
+ self.hidden_size = hidden_size
110
+ self.intermediate_size = intermediate_size
111
+ self.projection_dim = projection_dim
112
+ self.num_hidden_layers = num_hidden_layers
113
+ self.num_attention_heads = num_attention_heads
114
+ self.num_channels = num_channels
115
+ self.patch_size = patch_size
116
+ self.image_size = image_size
117
+ self.initializer_range = initializer_range
118
+ self.initializer_factor = initializer_factor
119
+ self.attention_dropout = attention_dropout
120
+ self.layer_norm_eps = layer_norm_eps
121
+ self.hidden_act = hidden_act
122
+
123
+ @classmethod
124
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
125
+ cls._set_token_in_kwargs(kwargs)
126
+
127
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
128
+
129
+ # get the vision config dict if we are loading from a CLIPConfig
130
+ if config_dict.get("model_type") == "clip":
131
+ config_dict = config_dict["vision_config"]
132
+
133
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
134
+ logger.warning(
135
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
136
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
137
+ )
138
+
139
+ return cls.from_dict(config_dict, **kwargs)
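The dictionary-selection step of `from_pretrained` above can be illustrated in isolation. The sketch below is a simplified stand-in that works on plain dictionaries and uses the standard `logging` module; `select_vision_config` is a hypothetical helper name, and the real method additionally resolves and downloads the configuration via `get_config_dict`.

```python
import logging

logger = logging.getLogger(__name__)

# Mirrors FLMRVisionConfig.model_type from the class above.
EXPECTED_MODEL_TYPE = "flmr_vision_model"


def select_vision_config(config_dict: dict) -> dict:
    """Pick the vision settings out of a loaded config dictionary (hypothetical helper)."""
    # A full CLIP config nests its vision settings under "vision_config".
    if config_dict.get("model_type") == "clip":
        config_dict = config_dict["vision_config"]

    # Warn, but do not fail, when the model type differs from the expected one.
    found = config_dict.get("model_type")
    if found is not None and found != EXPECTED_MODEL_TYPE:
        logger.warning(
            "Using a config of type %s to instantiate %s; this may yield errors.",
            found,
            EXPECTED_MODEL_TYPE,
        )
    return config_dict


clip_style = {
    "model_type": "clip",
    "vision_config": {"model_type": "clip_vision_model", "patch_size": 14},
}
flmr_style = {"model_type": "flmr_vision_model", "patch_size": 14}

print(select_vision_config(clip_style)["patch_size"])  # 14
print(select_vision_config(flmr_style) is flmr_style)  # True
```

Passing a full CLIP configuration therefore works, at the cost of a warning about the mismatched `model_type`.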
140
+
141
+
142
+ # Modified from transformers.models.dpr.configuration_dpr.DPRConfig with DPR -> FLMR
143
+ class FLMRTextConfig(PretrainedConfig):
144
+ r"""
145
+ [`FLMRTextConfig`] is the configuration class to store the configuration of a *FLMRTextModel*.
146
+
147
+ It is used to instantiate the components of the FLMR model according to the specified arguments,
148
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
149
+ configuration to that of the DPRContextEncoder
150
+ [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
151
+ architecture.
152
+
153
+ This class is a subclass of [`PretrainedConfig`], and its arguments mirror those of [`BertConfig`]. Please check these classes for the documentation of all kwargs.
154
+
155
+ Args:
156
+ vocab_size (`int`, *optional*, defaults to 30522):
157
+ Vocabulary size of the FLMR model. Defines the different tokens that can be represented by the *input_ids*
158
+ passed to the forward method of [`BertModel`].
159
+ hidden_size (`int`, *optional*, defaults to 768):
160
+ Dimensionality of the encoder layers and the pooler layer.
161
+ num_hidden_layers (`int`, *optional*, defaults to 12):
162
+ Number of hidden layers in the Transformer encoder.
163
+ num_attention_heads (`int`, *optional*, defaults to 12):
164
+ Number of attention heads for each attention layer in the Transformer encoder.
165
+ intermediate_size (`int`, *optional*, defaults to 3072):
166
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
167
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
168
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
169
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
170
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
171
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
172
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
173
+ The dropout ratio for the attention probabilities.
174
+ max_position_embeddings (`int`, *optional*, defaults to 512):
175
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
176
+ just in case (e.g., 512 or 1024 or 2048).
177
+ type_vocab_size (`int`, *optional*, defaults to 2):
178
+ The vocabulary size of the *token_type_ids* passed into [`BertModel`].
179
+ initializer_range (`float`, *optional*, defaults to 0.02):
180
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
181
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
182
+ The epsilon used by the layer normalization layers.
183
+ pad_token_id (`int`, *optional*, defaults to 0):
184
+ Padding token id.
185
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
186
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
187
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
188
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
189
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
190
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
191
+ projection_dim (`int`, *optional*, defaults to 0):
192
+ Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
193
+ projection is done.
194
+
195
+ Example:
196
+
197
+ ```python
198
+ >>> from transformers import FLMRTextConfig, FLMRTextModel
199
+
200
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
201
+ >>> configuration = FLMRTextConfig()
202
+
203
+ >>> # Initializing a model (with random weights) from the LinWeizheDragon/FLMR style configuration
204
+ >>> model = FLMRTextModel(configuration)
205
+
206
+ >>> # Accessing the model configuration
207
+ >>> configuration = model.config
208
+ ```"""
209
+
210
+ model_type = "flmr_text_model"
211
+
212
+ def __init__(
213
+ self,
214
+ vocab_size=30522,
215
+ hidden_size=768,
216
+ num_hidden_layers=12,
217
+ num_attention_heads=12,
218
+ intermediate_size=3072,
219
+ hidden_act="gelu",
220
+ hidden_dropout_prob=0.1,
221
+ attention_probs_dropout_prob=0.1,
222
+ max_position_embeddings=512,
223
+ type_vocab_size=2,
224
+ initializer_range=0.02,
225
+ layer_norm_eps=1e-12,
226
+ pad_token_id=0,
227
+ position_embedding_type="absolute",
228
+ projection_dim: int = 0,
229
+ **kwargs,
230
+ ):
231
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
232
+
233
+ self.vocab_size = vocab_size
234
+ self.hidden_size = hidden_size
235
+ self.num_hidden_layers = num_hidden_layers
236
+ self.num_attention_heads = num_attention_heads
237
+ self.hidden_act = hidden_act
238
+ self.intermediate_size = intermediate_size
239
+ self.hidden_dropout_prob = hidden_dropout_prob
240
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
241
+ self.max_position_embeddings = max_position_embeddings
242
+ self.type_vocab_size = type_vocab_size
243
+ self.initializer_range = initializer_range
244
+ self.layer_norm_eps = layer_norm_eps
245
+ self.projection_dim = projection_dim
246
+ self.position_embedding_type = position_embedding_type
247
+
248
+
249
+ class FLMRConfig(PretrainedConfig):
250
+ r"""
251
+ [`FLMRConfig`] is the configuration class to store the configuration of a *FLMRModelForRetrieval*.
252
+ It is used to instantiate the components of the FLMR model according to the specified arguments,
253
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
254
+ configuration to that of the FLMR
255
+ [LinWeizheDragon/PreFLMR_ViT-G](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-G)
256
+ architecture.
257
+
258
+ Args:
259
+ vision_config (`FLMRVisionConfig`, *optional*):
260
+ Configuration for the vision encoder.
261
+ text_config (`FLMRTextConfig`, *optional*):
262
+ Configuration for the text encoder.
263
+ mask_punctuation (`bool`, *optional*, defaults to `True`):
264
+ Whether to mask punctuation tokens in the input.
265
+ mapping_network_prefix_length (`int`, *optional*, defaults to 32):
266
+ The output length of the linear mapping network.
267
+ dim (`int`, *optional*, defaults to 128):
268
+ The late-interaction dimension of the model. The output of the text encoder, vision encoder, transformer mapping network should all be projected to this dimension for late-interaction scoring.
269
+ use_vision_encoder (`bool`, *optional*, defaults to `True`):
270
+ Whether to load the vision encoder. When no vision encoder is loaded, `image_features` should be used in the forward pass rather than `pixel_values`.
271
+ initializer_range (`float`, *optional*, defaults to 0.02):
272
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
273
+ separate_query_and_context_text_encoder (`bool`, *optional*, defaults to `False`):
274
+ Whether to use separate text encoders for query and context.
275
+ separate_query_and_context_vision_encoder (`bool`, *optional*, defaults to `False`):
276
+ Whether to use separate vision encoders for query and context.
277
+ query_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `True`):
278
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the query.
279
+ query_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
280
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the query.
281
+ context_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `False`):
282
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the context.
283
+ context_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
284
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the context.
285
+ use_transformer_mapping_network (`bool`, *optional*, defaults to `False`):
286
+ Whether to add a transformer mapping network to map the features from the vision encoder to the embedding space. This option is used in PreFLMR.
287
+ transformer_mapping_config_base (`str`, *optional*):
288
+ The base configuration for the transformer mapping network. This option is used in PreFLMR. An example of this argument is `bert-base-uncased`.
289
+ transformer_mapping_num_hidden_layers (`int`, *optional*):
290
+ The number of hidden layers in the transformer mapping network. This option is used in PreFLMR.
291
+ load_cpu_extension (`bool`, *optional*, defaults to `False`):
292
+ Whether to load the CPU extension. Only set this to `True` if a CPU is used in training and inference. In any case, GPU is recommended for training and inference.
293
+ mask_instruction_token (`str`, *optional*):
294
+ The token that indicates the end of the input instruction. All tokens before this token (the first one in a sequence) will be masked. This option is used in PreFLMR.
295
+ transformer_mapping_cross_attention_length (`int`, *optional*, defaults to 32):
296
+ The length of the cross attention in the transformer mapping network. This option is used in PreFLMR.
297
+ vision_model_version (`str`, *optional*, defaults to `"openai/clip-vit-base-patch32"`):
298
+ The version of the vision model being used in this FLMR model.
299
+ This option is used only when performing retrieval. Although it does not affect the model architecture, it is highly recommended to set this argument so that it reflects the version of the vision model used in the FLMR model. This argument is saved in the model configuration and can be read by the indexing engine, which uses it to initialize an image processor for the input image files. Find more details under `examples/research_projects/flmr-retrieval`.
300
+
301
+ Example:
302
+
303
+ ```python
304
+ >>> from transformers import FLMRConfig, FLMRModelForRetrieval
305
+
306
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
307
+ >>> configuration = FLMRConfig()
308
+
309
+ >>> # Initializing a model (with random weights) from the FLMR style configuration
310
+ >>> model = FLMRModelForRetrieval(configuration)
311
+
312
+ >>> # Accessing the model configuration
313
+ >>> configuration = model.config
314
+ ```"""
315
+
316
+ model_type = "flmr"
317
+
318
+ def __init__(
319
+ self,
320
+ vision_config: FLMRVisionConfig = None,
321
+ text_config: FLMRTextConfig = None,
322
+ mask_punctuation: bool = True,
323
+ mapping_network_prefix_length: int = 32,
324
+ dim: int = 128,
325
+ use_vision_encoder: bool = True,
326
+ initializer_range: float = 0.02,
327
+ separate_query_and_context_text_encoder: bool = False,
328
+ separate_query_and_context_vision_encoder: bool = False,
329
+ query_concat_output_from_vision_encoder: bool = True,
330
+ query_concat_output_from_text_encoder: bool = True,
331
+ context_concat_output_from_vision_encoder: bool = False,
332
+ context_concat_output_from_text_encoder: bool = True,
333
+ use_transformer_mapping_network: bool = False,
334
+ transformer_mapping_config_base: str = None,
335
+ transformer_mapping_num_hidden_layers: int = None,
336
+ load_cpu_extension: bool = False,
337
+ mask_instruction_token: str = None,
338
+ transformer_mapping_cross_attention_length: int = 32,
339
+ vision_model_version: str = "openai/clip-vit-base-patch32",
340
+ **kwargs,
341
+ ):
342
+ super().__init__(**kwargs)
343
+
344
+ if vision_config is None:
345
+ vision_config = {}
346
+ if text_config is None:
347
+ text_config = {}
348
+
349
+ if not isinstance(vision_config, FLMRVisionConfig):
350
+ vision_config = FLMRVisionConfig(**vision_config)
351
+ if not isinstance(text_config, FLMRTextConfig):
352
+ text_config = FLMRTextConfig(**text_config)
353
+
354
+ self.vision_config = vision_config
355
+ self.text_config = text_config
356
+ self.dim = dim
357
+ self.initializer_range = initializer_range
358
+ self.mask_punctuation = mask_punctuation
359
+ self.mapping_network_prefix_length = mapping_network_prefix_length
360
+ self.use_vision_encoder = use_vision_encoder
361
+ self.separate_query_and_context_text_encoder = separate_query_and_context_text_encoder
362
+ self.separate_query_and_context_vision_encoder = separate_query_and_context_vision_encoder
363
+ self.query_concat_output_from_vision_encoder = query_concat_output_from_vision_encoder
364
+ self.query_concat_output_from_text_encoder = query_concat_output_from_text_encoder
365
+ self.context_concat_output_from_vision_encoder = context_concat_output_from_vision_encoder
366
+ self.context_concat_output_from_text_encoder = context_concat_output_from_text_encoder
367
+ self.use_transformer_mapping_network = use_transformer_mapping_network
368
+ self.transformer_mapping_config_base = transformer_mapping_config_base
369
+ self.transformer_mapping_num_hidden_layers = transformer_mapping_num_hidden_layers
370
+ self.load_cpu_extension = load_cpu_extension
371
+ self.mask_instruction_token = mask_instruction_token
372
+ self.transformer_mapping_cross_attention_length = transformer_mapping_cross_attention_length
373
+ self.vision_model_version = vision_model_version
374
+
375
+ @classmethod
376
+ def from_text_vision_configs(cls, text_config: FLMRTextConfig, vision_config: FLMRVisionConfig, **kwargs):
377
+ r"""
378
+ Instantiate a [`FLMRConfig`] (or a derived class) from FLMR text model configuration and FLMR vision model
379
+ configuration.
380
+
381
+ Returns:
382
+ [`FLMRConfig`]: An instance of a configuration object
383
+ """
384
+
385
+ return cls(text_config=text_config, vision_config=vision_config, **kwargs)
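The `dim` argument documented above is described as the late-interaction dimension. The source does not spell out the scoring rule, so the following sketch assumes ColBERT-style MaxSim: each query token vector is matched with its best-scoring document token vector, and the maxima are summed. The function names are hypothetical, and the toy vectors use 2 dimensions rather than 128 to stay readable.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def late_interaction_score(query_vecs: List[Vector], doc_vecs: List[Vector]) -> float:
    """Sum, over query token vectors, of the best dot product with any document vector.

    Assumption: ColBERT-style MaxSim; not taken verbatim from the FLMR code.
    """
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)


# Toy example: two query token vectors scored against two documents.
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.5, 0.5]]
doc_b = [[0.0, 1.0], [1.0, 0.0]]

print(late_interaction_score(query, doc_a))  # 1.5
print(late_interaction_score(query, doc_b))  # 2.0
```

Because every query vector is compared with every document vector, all encoder outputs must share the same width, which is consistent with the docstring's requirement that text, vision, and mapping-network outputs be projected to `dim`.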
context_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
context_tokenizer/tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add a placeholder for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length is None or max_length > self.doc_maxlen:
126
+ # fall back to, and never exceed, the preset document length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add a placeholder for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
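The post-processing in `FLMRQueryEncoderTokenizer.__call__` above (overwrite the `.` placeholder with the query marker, turn padding into `[MASK]`, optionally attend to those masks) can be sketched on plain Python lists, without transformers or PyTorch. This is only an illustration, not the repository's code. The ids 0, 101, 102 and 103 are the ones listed in `added_tokens_decoder` of the context tokenizer's `tokenizer_config.json` in this commit; `Q_MARKER_ID = 1` (assumed id of `[unused0]`) and the example ids 1012 and 2054 are assumptions based on the standard BERT uncased vocabulary, and `mark_and_augment_query` is a hypothetical helper name.

```python
from typing import List, Tuple

# Special-token ids as listed in tokenizer_config.json (added_tokens_decoder).
PAD_ID, CLS_ID, SEP_ID, MASK_ID = 0, 101, 102, 103
# Assumption: "[unused0]" has id 1, as in the standard BERT uncased vocabulary.
Q_MARKER_ID = 1


def mark_and_augment_query(
    ids: List[int], mask: List[int], attend_to_mask_tokens: bool = False
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mirroring the query post-processing on one sequence."""
    ids, mask = list(ids), list(mask)  # work on copies, leave the inputs untouched
    ids[1] = Q_MARKER_ID  # position 1 holds the "." placeholder right after [CLS]
    ids = [MASK_ID if tok == PAD_ID else tok for tok in ids]  # [PAD] -> [MASK]
    if attend_to_mask_tokens:
        # let the encoder attend to the [MASK] positions as well
        mask = [1 if tok == MASK_ID else m for tok, m in zip(ids, mask)]
    return ids, mask


# "[CLS] . what [SEP] [PAD] [PAD]" (1012 and 2054 are assumed ids for "." and "what")
example_ids = [CLS_ID, 1012, 2054, SEP_ID, PAD_ID, PAD_ID]
example_mask = [1, 1, 1, 1, 0, 0]

plain_ids, plain_mask = mark_and_augment_query(example_ids, example_mask)
_, attended_mask = mark_and_augment_query(
    example_ids, example_mask, attend_to_mask_tokens=True
)
```

With `attend_to_mask_tokens` left at `False`, the attention mask is unchanged and only the ids differ from ordinary BERT padding. The context tokenizer performs the first step only, using `[unused1]` as the `[D]` marker.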
context_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "auto_map": {
45
+ "AutoTokenizer": [
46
+ "tokenization_flmr.FLMRContextEncoderTokenizer",
47
+ null
48
+ ]
49
+ },
50
+ "clean_up_tokenization_spaces": true,
51
+ "cls_token": "[CLS]",
52
+ "do_basic_tokenize": true,
53
+ "do_lower_case": true,
54
+ "doc_maxlen": 512,
55
+ "mask_token": "[MASK]",
56
+ "model_max_length": 1000000000000000019884624838656,
57
+ "never_split": null,
58
+ "pad_token": "[PAD]",
59
+ "sep_token": "[SEP]",
60
+ "strip_accents": null,
61
+ "tokenize_chinese_chars": true,
62
+ "tokenizer_class": "FLMRContextEncoderTokenizer",
63
+ "unk_token": "[UNK]"
64
+ }
context_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
flmr_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase.
3
+ """
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+
9
+ def get_rank():
10
+ return dist.get_rank()
11
+
12
+
13
+ def get_world_size():
14
+ return dist.get_world_size()
15
+
16
+
17
+ def get_default_group():
18
+ return dist.group.WORLD
19
+
20
+
21
+ # TODO: The masking below might also be applicable in the kNN part
22
+ def colbert_score_reduce(scores_padded, D_mask):
23
+ # print('D_mask', D_mask.shape, D_mask)
24
+ D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
25
+ # print('D_padding', D_padding.shape, D_padding)
26
+ # print(D_padding[0].tolist())
27
+ scores_padded[D_padding] = -9999
28
+ scores = scores_padded.max(1).values
29
+
30
+ return scores.sum(-1)
31
+
32
+
33
+ def colbert_score(Q, D_padded, D_mask, use_gpu=False):
34
+ """
35
+ Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
36
+ If Q.size(0) is 1, the matrix will be compared with all passages.
37
+ Otherwise, each query matrix will be compared against the *aligned* passage.
38
+
39
+ EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
40
+ """
41
+ if use_gpu:
42
+ Q, D_padded, D_mask = Q.cuda(), D_padded.cuda(), D_mask.cuda()
43
+ assert Q.dim() == 3, Q.size()
44
+ assert D_padded.dim() == 3, D_padded.size()
45
+ assert Q.size(0) in [1, D_padded.size(0)]
46
+
47
+ scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
48
+
49
+ return colbert_score_reduce(scores, D_mask)
50
+
51
+
52
+ def _sort_by_length(ids, mask, bsize, *args):
53
+ if ids.size(0) <= bsize:
54
+ return ids, mask, torch.arange(ids.size(0))
55
+
56
+ indices = mask.sum(-1).sort().indices
57
+ reverse_indices = indices.sort().indices
58
+
59
+ return_array = [ids[indices], mask[indices]]
60
+ for arg in args:
61
+ if isinstance(arg, torch.Tensor):
62
+ return_array.append(arg[indices])
63
+ else:
64
+ # arg is a list, and we want to sort the list according to indices
65
+ return_array.append([arg[i] for i in indices])
66
+
67
+ return *return_array, reverse_indices
68
+
69
+
70
+ def _split_into_batches(ids, mask, bsize, *args):
71
+ batches = []
72
+ for offset in range(0, ids.size(0), bsize):
73
+ batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]]
74
+ for arg in args:
75
+ batch.append(arg[offset : offset + bsize])
76
+ batches.append(batch)
77
+ return batches
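The late-interaction scoring done by `colbert_score` and `colbert_score_reduce` above can be illustrated for a single query/document pair in plain Python. This is a minimal sketch under stated assumptions, not the repository's implementation: it uses nested lists instead of tensors so it runs without PyTorch, `maxsim_score` is a hypothetical name, and padded document positions are skipped outright rather than filled with `-9999` as in the source.

```python
from typing import Sequence


def maxsim_score(
    query: Sequence[Sequence[float]],
    doc: Sequence[Sequence[float]],
    doc_mask: Sequence[int],
) -> float:
    """Late-interaction (MaxSim) score for one query/document pair.

    query:    query token embeddings, shape (num_query_tokens, dim)
    doc:      document token embeddings, shape (num_doc_tokens, dim)
    doc_mask: 1 for real document tokens, 0 for padding
    """
    total = 0.0
    for q_vec in query:
        best = float("-inf")
        for d_vec, keep in zip(doc, doc_mask):
            if not keep:
                continue  # padded positions can never win the max
            sim = sum(q * d for q, d in zip(q_vec, d_vec))  # dot product
            best = max(best, sim)
        total += best  # sum the per-query-token maxima
    return total


query = [[1.0, 0.0], [0.0, 1.0]]
doc = [[0.5, 0.5], [1.0, 0.0], [9.0, 9.0]]
doc_mask = [1, 1, 0]  # the last document position is padding
score = maxsim_score(query, doc, doc_mask)
unmasked_score = maxsim_score(query, doc, [1, 1, 1])
```

The padded vector `[9.0, 9.0]` would dominate both maxima if it were not masked, which is why the reduce step needs the document mask. The batched version above gets the same effect by overwriting padded scores before taking `max(1)` and summing over the query dimension.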
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b10b9f681dcffbb367d8f2d7d74f14974f4b42759306cbffc6e5cb2ff0d8a76b
3
+ size 4959695808
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cec398666827c116ac250eb745785aeeeb7b72da572992fbd160a8c3cf4a18ce
3
+ size 3378768360
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_flmr.py ADDED
@@ -0,0 +1,1504 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 FLMR Authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch FLMR model for Knowledge-intensive Visual Question Answering."""
16
+
17
+
18
+ import copy
19
+ import os
20
+ import pathlib
21
+ import string
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ from torch import Tensor, nn
28
+ from torch.utils.cpp_extension import load
29
+
30
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (
33
+ ModelOutput,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from transformers.models.bert.modeling_bert import BertModel
40
+ from transformers.models.clip import CLIPVisionModel
41
+ from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
42
+ from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer
43
+ from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast
44
+ from .flmr_utils import (
45
+ colbert_score,
46
+ colbert_score_reduce,
47
+ get_rank,
48
+ get_world_size,
49
+ )
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "FLMRConfig"
55
+ _CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L"
56
+
57
+
58
+ FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [
59
+ "LinWeizheDragon/PreFLMR_ViT-L",
60
+ "LinWeizheDragon/FLMR",
61
+ # See all FLMR models at https://huggingface.co/models?filter=flmr
62
+ ]
63
+
64
+
65
+ ##########
66
+ # Outputs
67
+ ##########
68
+
69
+
70
+ @dataclass
71
+ class FLMRContextEncoderOutput(ModelOutput):
72
+ """
73
+ Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`].
74
+
75
+ Args:
76
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
77
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation.
78
+ This output can be used to embed contexts for nearest-neighbor search against query embeddings.
79
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
80
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the context representation. The embeddings of all tokens are included for late-interaction retrieval.
81
+ This output is to be used to embed contexts for late-interaction retrieval with query embeddings.
82
+ context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`):
83
+ The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation.
84
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
85
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
86
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
87
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
88
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
89
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
90
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
91
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
92
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
93
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
94
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
95
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
96
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
97
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
98
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
99
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
100
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
101
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
102
+ """
103
+
104
+ pooler_output: torch.FloatTensor
105
+ late_interaction_output: torch.FloatTensor = None
106
+ context_mask: torch.FloatTensor = None
107
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
108
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
109
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
110
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
111
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
112
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
113
+
114
+
115
+ @dataclass
116
+ class FLMRQueryEncoderOutput(ModelOutput):
117
+ """
118
+ Class for outputs of the `query()` function of [`FLMRModelForRetrieval`].
119
+
120
+ Args:
121
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
122
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation.
123
+ This output can be used to embed questions for nearest neighbors queries with context embeddings.
124
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
125
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
126
+ This output is to be used to embed questions for late-interaction retrieval with context embeddings.
127
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
128
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
129
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
130
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
131
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
132
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
133
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
134
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
135
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
136
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
137
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
138
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
139
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
140
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
141
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
142
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
143
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
144
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
145
+ """
146
+
147
+ pooler_output: torch.FloatTensor
148
+ late_interaction_output: torch.FloatTensor = None
149
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
150
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
151
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
152
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
153
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
154
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
155
+
156
+
157
+ @dataclass
158
+ class FLMRModelForRetrievalOutput(ModelOutput):
159
+ """
160
+ Class for outputs of [`FLMRModelForRetrieval.forward()`].
161
+
162
+ Args:
163
+ loss (`torch.FloatTensor`):
164
+ Contrastive loss over the input queries and their positive and negative examples. This output is to be used in model training.
165
+ scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`):
166
+ The FLMR model outputs the *scores* that corresponds to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples.
167
+ in_batch_negative_loss (`torch.FloatTensor`):
168
+ The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training.
169
+ query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
170
+ The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query.
171
+ context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
172
+ The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context.
173
+ query_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
174
+ Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
175
+ query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
176
+ Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
177
+ context_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
178
+ Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
179
+ context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
180
+ Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
181
+ """
182
+
183
+ loss: torch.FloatTensor
184
+ scores: torch.FloatTensor = None
185
+ in_batch_negative_loss: torch.FloatTensor = None
186
+ query_late_interaction_output: torch.FloatTensor = None
187
+ context_late_interaction_output: torch.FloatTensor = None
188
+ query_attentions: Optional[Tuple[Tuple[Tensor]]] = None
189
+ query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
190
+ context_attentions: Optional[Tuple[Tuple[Tensor]]] = None
191
+ context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
192
+
193
+
194
+ class FLMRPreTrainedModel(PreTrainedModel):
195
+ def _init_weights(self, module):
196
+ """Initialize the weights"""
197
+ if isinstance(module, nn.Linear):
198
+ # Slightly different from the TF version which uses truncated_normal for initialization
199
+ # cf https://github.com/pytorch/pytorch/pull/5617
200
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
201
+ if module.bias is not None:
202
+ module.bias.data.zero_()
203
+ elif isinstance(module, nn.Embedding):
204
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
205
+ if module.padding_idx is not None:
206
+ module.weight.data[module.padding_idx].zero_()
207
+ elif isinstance(module, nn.LayerNorm):
208
+ module.bias.data.zero_()
209
+ module.weight.data.fill_(1.0)
210
+
211
+
212
+ ##################
213
+ # PreTrainedModel
214
+ ##################
215
+
216
+
217
+ class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel):
218
+ """
219
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
220
+ models.
221
+ """
222
+
223
+ config_class = FLMRConfig
224
+ load_tf_weights = None
225
+ base_model_prefix = "flmr"
226
+
227
+
228
+ ###############
229
+ # Actual Models
230
+ ###############
231
+
232
+
233
+ FLMR_START_DOCSTRING = r"""
234
+
235
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
236
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
237
+ etc.)
238
+
239
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
240
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
241
+ and behavior.
242
+
243
+ Parameters:
244
+ config ([`FLMRConfig`]): Model configuration class with all the parameters of the model.
245
+ Initializing with a config file does not load the weights associated with the model, only the
246
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
247
+ query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query.
248
+ The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
249
+ context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context.
250
+ The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
251
+ """
252
+
253
+
254
+ FLMR_MODEL_INPUTS_DOCSTRING = r"""
255
+ Args:
256
+ query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
257
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
258
+ formatted with [CLS] and Q marker tokens as follows:
259
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
260
+
261
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
262
+ rather than the left.
263
+
264
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
265
+ [`PreTrainedTokenizer.__call__`] for details.
266
+
267
+ [What are input IDs?](../glossary#input-ids)
268
+ query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
269
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
270
+
271
+ - 1 for tokens that are **not masked**,
272
+ - 0 for tokens that are **masked**.
273
+
274
+ [What are attention masks?](../glossary#attention-mask)
275
+ query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
276
+ Pixel values. Pixel values can be obtained using
277
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
278
+ query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
279
+ Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference: the vision encoder forward pass is skipped and the extracted image features are passed directly to the FLMR model. Image features can be obtained
280
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
281
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
282
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
283
+ formatted with [CLS] and D marker tokens as follows:
284
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
285
+
286
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
287
+ rather than the left.
288
+
289
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
290
+ [`PreTrainedTokenizer.__call__`] for details.
291
+
292
+ [What are input IDs?](../glossary#input-ids)
293
+
294
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
295
+
296
+ context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
297
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
298
+
299
+ - 1 for tokens that are **not masked**,
300
+ - 0 for tokens that are **masked**.
301
+
302
+ [What are attention masks?](../glossary#attention-mask)
303
+
304
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
305
+ context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
306
+ Pixel values. Pixel values can be obtained using
307
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
308
+ context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
309
+ Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference: the vision encoder forward pass is skipped and the extracted image features are passed directly to the FLMR model. Image features can be obtained
310
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
311
+ use_in_batch_negatives (`bool`, *optional*):
312
+ Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This input is to be used in model training.
313
+ in_batch_negatives_from_all_gpus (`bool`, *optional*):
314
+ Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training.
315
+ num_negative_examples (`int`, *optional*):
316
+ The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`.
317
+ query_concat_output_from_vision_encoder (`bool`, *optional*):
318
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR-style models.
319
+ query_concat_output_from_text_encoder (`bool`, *optional*):
320
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR-style models.
321
+
322
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
323
+ context_concat_output_from_vision_encoder (`bool`, *optional*):
324
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR-style models since the context vision encoder is not used.
325
+
326
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
327
+ context_concat_output_from_text_encoder (`bool`, *optional*):
328
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR-style models.
329
+ return_dict (`bool`, *optional*):
330
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
331
+ output_attentions (`bool`, *optional*):
332
+ Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned
333
+ tensors for more detail.
334
+ output_hidden_states (`bool`, *optional*):
335
+ Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail.
336
+ """
337
+
338
+
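The `batch_size * (1 + num_negative_examples)` layout described in the docstring above can be made concrete with a small pure-Python sketch. It assumes, as the `forward` method in this file suggests (scores are reshaped with `scores.view(-1, num_negative_examples + 1)` and the labels are all zeros), that each query's positive context comes first in its group, followed by its negatives. The helper name `context_rows_for_query` is hypothetical and exists only for illustration.

```python
def context_rows_for_query(query_index, num_negative_examples):
    """Return the rows of the context batch that belong to one query.

    Assumption: contexts are grouped per query, positive first, then negatives.
    """
    group_size = num_negative_examples + 1
    start = query_index * group_size
    return list(range(start, start + group_size))


batch_size = 2
num_negative_examples = 4

# Total number of context rows expected by `context_input_ids`.
total_context_rows = batch_size * (1 + num_negative_examples)

# Row layout for each query: the first row of each group is the positive example.
layout = [context_rows_for_query(q, num_negative_examples) for q in range(batch_size)]
positive_rows = [rows[0] for rows in layout]
```

With these values the context batch has 10 rows, and the positives sit at rows 0 and 5, which is why a label of 0 per query selects the positive after the reshape.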
339
+ FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r"""
340
+ Args:
341
+ input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
342
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
343
+ formatted with [CLS] and Q marker tokens as follows:
344
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
345
+
346
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
347
+ rather than the left.
348
+
349
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
350
+ [`PreTrainedTokenizer.__call__`] for details.
351
+
352
+ [What are input IDs?](../glossary#input-ids)
353
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
354
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
355
+
356
+ - 1 for tokens that are **not masked**,
357
+ - 0 for tokens that are **masked**.
358
+
359
+ [What are attention masks?](../glossary#attention-mask)
360
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
361
+ Pixel values. Pixel values can be obtained using
362
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
363
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
364
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference: the vision encoder forward pass is skipped and the extracted image features are passed directly to the FLMR model. Image features can be obtained
365
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
366
+ concat_output_from_vision_encoder (`bool`, *optional*):
367
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR-style models.
368
+ concat_output_from_text_encoder (`bool`, *optional*):
369
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR-style models.
370
+
371
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
372
+ """
373
+
374
+
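The marker-token formats shown in these docstrings (`[unused0]` with `[MASK]` padding for queries, `[unused1]` with `[PAD]` padding for contexts) can be sketched on plain token lists. This is an illustration only: the real tokenizers operate on WordPiece ids, the function name `format_marked_sequence` is hypothetical, and truncation is deliberately left out.

```python
def format_marked_sequence(tokens, marker, max_length, pad_token):
    """Wrap tokens with [CLS], a marker token and [SEP], then pad on the right."""
    sequence = ["[CLS]", marker] + list(tokens) + ["[SEP]"]
    if len(sequence) > max_length:
        # Truncation is not sketched here; the real tokenizers handle it.
        raise ValueError("sequence is longer than max_length")
    # Right padding, matching the advice for absolute position embeddings.
    return sequence + [pad_token] * (max_length - len(sequence))


# Queries use the Q marker and are padded with [MASK] tokens.
query = format_marked_sequence(
    ["what", "is", "the", "capital", "of", "france", "?"], "[unused0]", 12, "[MASK]"
)

# Contexts use the D marker and are padded with [PAD] tokens.
context = format_marked_sequence(
    ["paris", "is", "the", "capital", "of", "france", "."], "[unused1]", 12, "[PAD]"
)
```

Both sequences end up with the same length; only the marker and the padding token differ between the query side and the context side.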
375
+ FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r"""
376
+ Args:
377
+ input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
378
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
379
+ formatted with [CLS] and D marker tokens as follows:
380
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
381
+
382
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
383
+ rather than the left.
384
+
385
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
386
+ [`PreTrainedTokenizer.__call__`] for details.
387
+
388
+ [What are input IDs?](../glossary#input-ids)
389
+
390
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
391
+ attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
392
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
393
+
394
+ - 1 for tokens that are **not masked**,
395
+ - 0 for tokens that are **masked**.
396
+
397
+ [What are attention masks?](../glossary#attention-mask)
398
+
399
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
400
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
401
+ Pixel values. Pixel values can be obtained using
402
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
403
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
404
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference: the vision encoder forward pass is skipped and the extracted image features are passed directly to the FLMR model. Image features can be obtained
405
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
407
+ concat_output_from_vision_encoder (`bool`, *optional*):
408
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR-style models since the context vision encoder is not used.
409
+
410
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
411
+ concat_output_from_text_encoder (`bool`, *optional*):
412
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR-style models.
413
+ keep_dims (`bool`, *optional*):
414
+ Whether or not to keep the dimensions of the output. If `True`, the output is returned with the same dimensions as the input. If `False`, the output is returned with the batch size of the input and the context length. This input is to be used in model training.
415
+ return_mask (`bool`, *optional*):
416
+ Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training.
417
+ """
418
+
419
+
420
+ FLMR_TEXT_ENCODERS_START_DOCSTRING = r"""
421
+
422
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
423
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
424
+ etc.)
425
+
426
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
427
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
428
+ and behavior.
429
+
430
+ Parameters:
431
+ config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model.
432
+ Initializing with a config file does not load the weights associated with the model, only the
433
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
434
+ """
435
+
436
+
437
+ # Modified from transformers.models.dpr.modeling_dpr with DPR -> FLMR
438
+ FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r"""
439
+ Args:
440
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
441
+ Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be
442
+ formatted with [CLS] and [SEP] tokens as follows:
443
+
444
+ (a) For sequence pairs (for a pair title+text for example):
445
+
446
+ ```
447
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
448
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
449
+ ```
450
+
451
+ (b) For single sequences (for a question for example):
452
+
453
+ ```
454
+ tokens: [CLS] the dog is hairy . [SEP]
455
+ token_type_ids: 0 0 0 0 0 0 0
456
+ ```
457
+
458
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
459
+ rather than the left.
460
+
461
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
462
+ [`PreTrainedTokenizer.__call__`] for details.
463
+
464
+ [What are input IDs?](../glossary#input-ids)
465
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
466
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
467
+
468
+ - 1 for tokens that are **not masked**,
469
+ - 0 for tokens that are **masked**.
470
+
471
+ [What are attention masks?](../glossary#attention-mask)
472
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
473
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
474
+ 1]`:
475
+
476
+ - 0 corresponds to a *sentence A* token,
477
+ - 1 corresponds to a *sentence B* token.
478
+
479
+ [What are token type IDs?](../glossary#token-type-ids)
480
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
481
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
482
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
483
+ model's internal embedding lookup matrix.
484
+ output_attentions (`bool`, *optional*):
485
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
486
+ tensors for more detail.
487
+ output_hidden_states (`bool`, *optional*):
488
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
489
+ more detail.
490
+ return_dict (`bool`, *optional*):
491
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
492
+ """
493
+
494
+ FLMR_VISION_ENCODERS_START_DOCSTRING = r"""
495
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
496
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
497
+ etc.)
498
+
499
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
500
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
501
+ and behavior.
502
+
503
+ Parameters:
504
+ config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model.
505
+ Initializing with a config file does not load the weights associated with the model, only the
506
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
507
+ """
508
+
509
+ # Modified from transformers.models.clip.modeling_clip with CLIP -> FLMR
510
+ FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r"""
511
+ Args:
512
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
513
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
514
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
515
+ output_attentions (`bool`, *optional*):
516
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
517
+ tensors for more detail.
518
+ output_hidden_states (`bool`, *optional*):
519
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
520
+ more detail.
521
+ return_dict (`bool`, *optional*):
522
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
523
+ """
524
+
525
+
526
+ class FLMRMultiLayerPerceptron(nn.Module):
527
+ """
528
+ A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model.
529
+ """
530
+
531
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
532
+ return self.model(x)
533
+
534
+ def __init__(self, sizes, bias=True, act=nn.Tanh):
535
+ super().__init__()
536
+ layers = []
537
+ for i in range(len(sizes) - 1):
538
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
539
+ if i < len(sizes) - 2:
540
+ layers.append(act())
541
+ self.model = nn.Sequential(*layers)
542
+
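To see which layers the `FLMRMultiLayerPerceptron` constructor builds for a given `sizes` tuple, the loop can be mirrored without PyTorch. The sketch below only records the layer plan; the helper name `mlp_layer_plan` and the concrete sizes (1024, 128, 32) are assumptions chosen for illustration, not values read from any checkpoint.

```python
def mlp_layer_plan(sizes):
    """Return the ordered layer descriptions the constructor loop would produce."""
    plan = []
    for i in range(len(sizes) - 1):
        plan.append(("linear", sizes[i], sizes[i + 1]))
        # An activation follows every linear layer except the last one.
        if i < len(sizes) - 2:
            plan.append(("act",))
    return plan


# Hypothetical values: vision hidden size, late-interaction dim, prefix length.
vision_hidden_size, late_interaction_dim, prefix_length = 1024, 128, 32

# Same size pattern as the vision projection built in FLMRModelForRetrieval.
sizes = (
    vision_hidden_size,
    (late_interaction_dim * prefix_length) // 2,
    late_interaction_dim * prefix_length,
)
plan = mlp_layer_plan(sizes)
```

Under these assumed sizes the projection is two linear layers with a single activation between them, and its output can be viewed as `prefix_length` vectors of size `late_interaction_dim`.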
543
+
544
+ @add_start_docstrings(
545
+ "The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ",
546
+ FLMR_START_DOCSTRING,
547
+ )
548
+ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
549
+ _keys_to_ignore_on_load_unexpected = [r"cls"]
550
+ main_input_name = "query_input_ids"
551
+ _tied_weights_keys = [] # Added dynamically at initialization depending on the architecture
552
+
553
+ def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None):
554
+ super().__init__(config)
555
+ self.config = config
556
+ self.vision_model_version = config.vision_model_version
557
+
558
+ self.context_text_encoder = FLMRTextModel(config.text_config)
559
+ self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False)
560
+
561
+ self.query_tokenizer = query_tokenizer
562
+ self.context_tokenizer = context_tokenizer
563
+
564
+ if self.query_tokenizer is None:
565
+ logger.warning(
566
+ "query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
567
+ )
568
+ from transformers import FLMRQueryEncoderTokenizer
569
+
570
+ # initialize a FLMRQueryEncoderTokenizer
571
+ self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased")
572
+
573
+ if self.context_tokenizer is None:
574
+ logger.warning(
575
+ "context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
576
+ )
577
+ from transformers import FLMRContextEncoderTokenizer
578
+
579
+ # initialize a FLMRContextEncoderTokenizer
580
+ self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased")
581
+
582
+ self.mapping_network_prefix_length = self.config.mapping_network_prefix_length
583
+ self.vision_encoder_embedding_size = self.config.vision_config.hidden_size
584
+ self.text_encoder_embedding_size = self.config.text_config.hidden_size
585
+ self.late_interaction_embedding_size = self.config.dim
586
+
587
+ if self.config.use_vision_encoder:
588
+ self.context_vision_projection = FLMRMultiLayerPerceptron(
589
+ (
590
+ self.vision_encoder_embedding_size,
591
+ (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
592
+ self.late_interaction_embedding_size * self.mapping_network_prefix_length,
593
+ )
594
+ )
595
+
596
+ if self.config.use_vision_encoder:
597
+ self.context_vision_encoder = FLMRVisionModel(config.vision_config)
598
+
599
+ if self.config.use_transformer_mapping_network:
600
+ # This is a PreFLMR style model
601
+ transformer_mapping_config_base = self.config.transformer_mapping_config_base
602
+ try:
603
+ from transformers import BertConfig
604
+ from transformers.models.bert.modeling_bert import BertEncoder
605
+ except Exception as e:
606
+ raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}")
607
+
608
+ transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base)
609
+
610
+ assert (
611
+ self.config.text_config.hidden_size == transformer_mapping_config.hidden_size
612
+ ), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match."
613
+ # shallow transformer
614
+ transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers
615
+ # add cross attention
616
+ transformer_mapping_config.is_decoder = True
617
+ transformer_mapping_config.add_cross_attention = True
618
+
619
+ # The linear layer from vision encoder to transformer input
620
+ self.transformer_mapping_input_linear = nn.Linear(
621
+ self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size
622
+ )
623
+
624
+ # The transformer encoder
625
+ self.transformer_mapping_network = BertEncoder(transformer_mapping_config)
626
+
627
+ # The linear layer from transformer output to FLMR dim
628
+ self.transformer_mapping_output_linear = nn.Linear(
629
+ transformer_mapping_config.hidden_size, self.late_interaction_embedding_size
630
+ )
631
+
632
+ if self.config.separate_query_and_context_text_encoder:
633
+ self.query_text_encoder = copy.deepcopy(self.context_text_encoder)
634
+ self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear)
635
+ else:
636
+ self.query_text_encoder = self.context_text_encoder
637
+ self.query_text_encoder_linear = self.context_text_encoder_linear
638
+ self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
639
+
640
+ if self.config.use_vision_encoder:
641
+ if self.config.separate_query_and_context_vision_encoder:
642
+ self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
643
+ self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
644
+ else:
645
+ self.query_vision_encoder = self.context_vision_encoder
646
+ self.query_vision_projection = self.context_vision_projection
647
+ self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
648
+
649
+ if self.config.load_cpu_extension:
650
+ try:
651
+ FLMRModelForRetrieval.try_load_torch_extensions()
652
+ except Exception as e:
653
+ raise RuntimeError(f"Unable to load `segmented_maxsim.cpp`. hf-hub does not download this file automatically. Please download it manually from `https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/blob/main/segmented_maxsim.cpp` and put it under the same folder as the model file.\n {e}") from e
654
+
655
+ if self.config.mask_punctuation:
656
+ self.skiplist = {
657
+ w: True
658
+ for symbol in string.punctuation
659
+ for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]]
660
+ }
661
+
662
+ if self.config.mask_instruction_token is not None:
663
+ self.mask_instruction = True
664
+ # obtain the token id of the instruction token
665
+ self.instruction_token_id = self.query_tokenizer.encode(
666
+ self.config.mask_instruction_token, add_special_tokens=False
667
+ )[0]
668
+ else:
669
+ self.mask_instruction = False
670
+
671
+ self.loss_fn = torch.nn.CrossEntropyLoss()
672
+
673
+ # Initialize weights and apply final processing
674
+ self.post_init()
675
+
676
+ @property
677
+ def use_gpu(self):
678
+ return self.device.type == "cuda"
679
+
680
+ @classmethod
681
+ def from_pretrained(cls, name_or_path, **kwargs):
682
+ obj = super().from_pretrained(name_or_path, **kwargs)
683
+ return obj
684
+
685
+ @classmethod
686
+ def try_load_torch_extensions(cls):
687
+ if hasattr(cls, "loaded_extensions"):
688
+ return
689
+
690
+ logger.info(
691
+ "Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..."
692
+ )
693
+ segmented_maxsim_cpp = load(
694
+ name="segmented_maxsim_cpp",
695
+ sources=[
696
+ os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"),
697
+ ],
698
+ extra_cflags=["-O3"],
699
+ verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
700
+ )
701
+ cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp
702
+
703
+ cls.loaded_extensions = True
704
+
705
+ def query_mask(self, input_ids, skiplist):
706
+ if not self.mask_instruction:
707
+ return self.mask(input_ids, skiplist)
708
+
709
+ # find the position of end of instruction in input_ids
710
+ # mask the tokens before the position
711
+ sep_id = self.instruction_token_id
712
+ sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist()
713
+ # if any of the positions is lower than 1, set to 1
714
+ for i, x in enumerate(sep_positions):
715
+ if x < 1:
716
+ sep_positions[i] = 1
717
+ logger.error(f"Cannot find the instruction separator token in input_ids: {input_ids[i].tolist()}")
718
+ mask = [
719
+ [
720
+ (x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2)
721
+ for index, x in enumerate(d)
722
+ ]
723
+ for seq_index, d in enumerate(input_ids.cpu().tolist())
724
+ ]
725
+ return mask
726
+
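The masking rule in `query_mask` can be traced on plain Python lists. The sketch below is an illustration under assumptions, not the model's code path: the token ids, the skiplist, and the instruction token id are made up, and `list.index` stands in for `torch.argmax` over the equality indicator (which returns the first match, or 0 when the token is absent).

```python
def query_mask_sketch(input_ids, skiplist, instruction_token_id):
    """Pure-Python mirror of the instruction-aware query mask."""
    masks = []
    for seq in input_ids:
        # First occurrence of the instruction separator, or 0 when it is absent.
        sep_pos = seq.index(instruction_token_id) if instruction_token_id in seq else 0
        # Positions below 1 are clamped to 1, as in the original loop.
        sep_pos = max(sep_pos, 1)
        masks.append(
            [
                (tok not in skiplist) and (tok != 0) and (idx > sep_pos or idx < 2)
                for idx, tok in enumerate(seq)
            ]
        )
    return masks


# Made-up ids: 101/102 act as [CLS]/[SEP], 1 as the query marker,
# 9 as the instruction separator, 21 as a punctuation id, 0 as padding.
ids = [[101, 1, 7, 8, 9, 20, 21, 102, 0, 0]]
mask = query_mask_sketch(ids, skiplist={21}, instruction_token_id=9)
```

In this trace the first two positions stay active, every instruction token up to and including the separator is masked, and padding and skiplist ids are masked as well.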
727
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING)
728
+ @replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
729
+ def forward(
730
+ self,
731
+ query_input_ids: Optional[torch.Tensor] = None,
732
+ query_attention_mask: Optional[torch.Tensor] = None,
733
+ query_pixel_values: Optional[torch.Tensor] = None,
734
+ query_image_features: Optional[torch.Tensor] = None,
735
+ context_input_ids: Optional[torch.Tensor] = None,
736
+ context_attention_mask: Optional[torch.Tensor] = None,
737
+ context_pixel_values: Optional[torch.Tensor] = None,
738
+ context_image_features: Optional[torch.Tensor] = None,
739
+ use_in_batch_negatives: bool = True,
740
+ in_batch_negatives_from_all_gpus: bool = False,
741
+ num_negative_examples: int = 1,
742
+ query_concat_output_from_vision_encoder: Optional[bool] = None,
743
+ query_concat_output_from_text_encoder: Optional[bool] = None,
744
+ context_concat_output_from_vision_encoder: Optional[bool] = None,
745
+ context_concat_output_from_text_encoder: Optional[bool] = None,
746
+ return_dict: Optional[bool] = None,
747
+ output_attentions: Optional[bool] = None,
748
+ output_hidden_states: Optional[bool] = None,
749
+ ) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]:
750
+ r"""
751
+ Return:
752
+
753
+ Examples:
754
+
755
+ ```python
756
+ >>> import torch
757
+ >>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor
758
+
759
+ >>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L"
760
+ >>> image_processor_name = "openai/clip-vit-large-patch14"
761
+ >>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer")
762
+ >>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer")
763
+
764
+ >>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path,
765
+ query_tokenizer=query_tokenizer,
766
+ context_tokenizer=context_tokenizer,
767
+ )
768
+ >>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name)
769
+
770
+ >>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"])
771
+ >>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.",
772
+ "Paris is the capital of France.", "Beijing is the capital of China."])
773
+ >>> Q_pixel_values = torch.zeros(2, 3, 224, 224)
774
+ >>> inputs = dict(
775
+ query_input_ids=Q_encoding['input_ids'],
776
+ query_attention_mask=Q_encoding['attention_mask'],
777
+ query_pixel_values=Q_pixel_values,
778
+ context_input_ids=D_encoding['input_ids'],
779
+ context_attention_mask=D_encoding['attention_mask'],
780
+ use_in_batch_negatives=True,
781
+ )
782
+
783
+ >>> model.forward(**inputs)
784
+ FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16,
785
+ grad_fn=<NllLossBackward0>), scores=tensor([[44.2188, 40.6562],
786
+ [39.4375, 48.4062]], device='cuda:0', dtype=torch.float16,
787
+ grad_fn=<ViewBackward0>), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=<NllLossBackward0>), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...))
788
+ ```
789
+ """
790
+
791
+ if query_concat_output_from_vision_encoder is None:
792
+ query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
793
+
794
+ if query_concat_output_from_text_encoder is None:
795
+ query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
796
+
797
+ if context_concat_output_from_vision_encoder is None:
798
+ context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
799
+
800
+ if context_concat_output_from_text_encoder is None:
801
+ context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
802
+
803
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
804
+ output_hidden_states = (
805
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
806
+ )
807
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
+
809
+ query_outputs = self.query(
810
+ input_ids=query_input_ids,
811
+ attention_mask=query_attention_mask,
812
+ pixel_values=query_pixel_values,
813
+ image_features=query_image_features,
814
+ concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
815
+ concat_output_from_text_encoder=query_concat_output_from_text_encoder,
816
+ output_attentions=output_attentions,
817
+ output_hidden_states=output_hidden_states,
818
+ )
819
+ Q = query_outputs.late_interaction_output
820
+
821
+ context_outputs = self.doc(
822
+ input_ids=context_input_ids,
823
+ attention_mask=context_attention_mask,
824
+ pixel_values=context_pixel_values,
825
+ image_features=context_image_features,
826
+ concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
827
+ concat_output_from_text_encoder=context_concat_output_from_text_encoder,
828
+ keep_dims=True,
829
+ return_mask=True,
830
+ output_attentions=output_attentions,
831
+ output_hidden_states=output_hidden_states,
832
+ )
833
+ D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask
834
+
835
+ # Gather tensors from other GPUs
836
+ if in_batch_negatives_from_all_gpus:
837
+ Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask)
838
+ # Repeat each query encoding for every corresponding document.
839
+ Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous()
840
+
841
+ scores = self.score(Q_duplicated, D, D_mask)
842
+
843
+ # Use contrastive learning
844
+ batch_size = query_input_ids.shape[0]
845
+ scores = scores.view(-1, num_negative_examples + 1)
846
+ labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)
847
+ loss = self.loss_fn(scores, labels)
848
+
849
+ if use_in_batch_negatives:
850
+ ib_loss = self.compute_ib_loss_new(Q, D, D_mask)
851
+ else:
852
+ ib_loss = None
853
+
854
+ if output_attentions:
855
+ query_attentions = (
856
+ query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None,
857
+ query_outputs.vision_encoder_attentions
858
+ if query_outputs.vision_encoder_attentions is not None
859
+ else None,
860
+ query_outputs.transformer_mapping_network_attentions
861
+ if query_outputs.transformer_mapping_network_attentions is not None
862
+ else None,
863
+ )
864
+ context_attentions = (
865
+ context_outputs.text_encoder_attentions
866
+ if context_outputs.text_encoder_attentions is not None
867
+ else None,
868
+ context_outputs.vision_encoder_attentions
869
+ if context_outputs.vision_encoder_attentions is not None
870
+ else None,
871
+ context_outputs.transformer_mapping_network_attentions
872
+ if context_outputs.transformer_mapping_network_attentions is not None
873
+ else None,
874
+ )
875
+ else:
876
+ query_attentions = None
877
+ context_attentions = None
878
+
879
+ if output_hidden_states:
880
+ query_hidden_states = (
881
+ query_outputs.text_encoder_hidden_states
882
+ if query_outputs.text_encoder_hidden_states is not None
883
+ else None,
884
+ query_outputs.vision_encoder_hidden_states
885
+ if query_outputs.vision_encoder_hidden_states is not None
886
+ else None,
887
+ query_outputs.transformer_mapping_network_hidden_states
888
+ if query_outputs.transformer_mapping_network_hidden_states is not None
889
+ else None,
890
+ )
891
+ context_hidden_states = (
892
+ context_outputs.text_encoder_hidden_states
893
+ if context_outputs.text_encoder_hidden_states is not None
894
+ else None,
895
+ context_outputs.vision_encoder_hidden_states
896
+ if context_outputs.vision_encoder_hidden_states is not None
897
+ else None,
898
+ context_outputs.transformer_mapping_network_hidden_states
899
+ if context_outputs.transformer_mapping_network_hidden_states is not None
900
+ else None,
901
+ )
902
+ else:
903
+ query_hidden_states = None
904
+ context_hidden_states = None
905
+
906
+ if not return_dict:
907
+ if output_attentions and output_hidden_states:
908
+ return (
909
+ loss,
910
+ scores,
911
+ ib_loss,
912
+ query_outputs.late_interaction_output,
913
+ context_outputs.late_interaction_output,
914
+ query_attentions,
915
+ query_hidden_states,
916
+ context_attentions,
917
+ context_hidden_states,
918
+ )
919
+ elif output_attentions:
920
+ return (
921
+ loss,
922
+ scores,
923
+ ib_loss,
924
+ query_outputs.late_interaction_output,
925
+ context_outputs.late_interaction_output,
926
+ query_attentions,
927
+ context_attentions,
928
+ )
929
+ elif output_hidden_states:
930
+ return (
931
+ loss,
932
+ scores,
933
+ ib_loss,
934
+ query_outputs.late_interaction_output,
935
+ context_outputs.late_interaction_output,
936
+ query_hidden_states,
937
+ context_hidden_states,
938
+ )
939
+ else:
940
+ return (
941
+ loss,
942
+ scores,
943
+ ib_loss,
944
+ query_outputs.late_interaction_output,
945
+ context_outputs.late_interaction_output,
946
+ )
947
+
948
+ return FLMRModelForRetrievalOutput(
949
+ loss=loss,
950
+ scores=scores,
951
+ in_batch_negative_loss=ib_loss,
952
+ query_late_interaction_output=query_outputs.late_interaction_output,
953
+ context_late_interaction_output=context_outputs.late_interaction_output,
954
+ query_attentions=query_attentions if output_attentions else None,
955
+ query_hidden_states=query_hidden_states if output_hidden_states else None,
956
+ context_attentions=context_attentions if output_attentions else None,
957
+ context_hidden_states=context_hidden_states if output_hidden_states else None,
958
+ )
959
+
960
+ def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
961
+ # Q: batch_size x q_len x dim
962
+ # D: batch_size*n_docs x i_len x dim
963
+ # D_mask: batch_size*n_docs x i_len x 1
964
+ # 1 x batch_size*n_docs x i_len x dim matmul batch_size x 1 x dim x q_len
965
+ # = batch_size x batch_size*n_docs x i_len x q_len
966
+
967
+ scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten(
968
+ 0, 1
969
+ ) # query-major unsqueeze
970
+ scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1))
971
+
972
+ in_batch_scores = scores.reshape(Q.size(0), -1)
973
+
974
+ batch_size = Q.shape[0]
975
+ batch_size_with_pos_and_neg = D.shape[0]
976
+ num_pos_and_neg = batch_size_with_pos_and_neg // batch_size
977
+
978
+ # in_batch_scores: batch_size x (num_pos+num_neg)*batch_size
979
+ # the positive document for query i sits at column i*(num_pos+num_neg)
980
+ in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device)
981
+ step = num_pos_and_neg
982
+ for i in range(batch_size):
983
+ in_batch_labels[i, step * i] = 1
984
+ # print('in_batch_labels', in_batch_labels)
985
+ in_batch_labels = torch.argmax(in_batch_labels, dim=1)
986
+ # print('in_batch_labels', in_batch_labels)
987
+
988
+ loss = self.loss_fn(in_batch_scores, in_batch_labels)
989
+
990
+ return loss
991
+
992
+ def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask):
993
+ # print("get rank", get_rank())
994
+ # print("get world size", get_world_size())
995
+ # Gather embeddings from other GPUs
996
+ n_nodes = get_world_size()
997
+ if n_nodes == 1:
998
+ return query_embeddings, item_embeddings, item_mask
999
+ # Create placeholder to hold embeddings passed from other ranks
1000
+ global_query_embeddings_placeholder = [
1001
+ torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device)
1002
+ for _ in range(n_nodes)
1003
+ ]
1004
+ global_item_embeddings_placeholder = [
1005
+ torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device)
1006
+ for _ in range(n_nodes)
1007
+ ]
1008
+ global_item_mask_placeholder = [
1009
+ torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes)
1010
+ ]
1011
+ dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach())
1012
+ dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach())
1013
+ dist.all_gather(global_item_mask_placeholder, item_mask.detach())
1014
+
1015
+ global_query_embeddings = []
1016
+ global_item_embeddings = []
1017
+ global_item_mask = []
1018
+ # print(f"rank {get_rank()} global_query_embeddings", global_query_embeddings)
1019
+ # print(f"rank {get_rank()} global_item_embeddings", global_item_embeddings)
1020
+ # input()
1021
+ current_rank = get_rank()
1022
+ for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder):
1023
+ # Tensors gathered from other ranks are detached (no gradients); keep the local tensor for this rank so gradients still flow
1024
+ if rank_index != current_rank:
1025
+ global_query_embeddings.append(remote_q_embeddings)
1026
+ else:
1027
+ global_query_embeddings.append(query_embeddings)
1028
+
1029
+ for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder):
1030
+ # Tensors gathered from other ranks are detached (no gradients); keep the local tensor for this rank so gradients still flow
1031
+ if rank_index != current_rank:
1032
+ global_item_embeddings.append(remote_item_embeddings)
1033
+ else:
1034
+ global_item_embeddings.append(item_embeddings)
1035
+
1036
+ for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder):
1037
+ # Masks gathered from other ranks are used as-is; keep the local mask for this rank
1038
+ if rank_index != current_rank:
1039
+ global_item_mask.append(remote_item_mask)
1040
+ else:
1041
+ global_item_mask.append(item_mask)
1042
+
1043
+ # Replace the previous variables with gathered tensors
1044
+ query_embeddings = torch.cat(global_query_embeddings)
1045
+ item_embeddings = torch.cat(global_item_embeddings)
1046
+ item_mask = torch.cat(global_item_mask)
1047
+
1048
+ return query_embeddings, item_embeddings, item_mask
1049
+
1050
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING)
1051
+ @replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC)
1052
+ def query(
1053
+ self,
1054
+ input_ids: torch.Tensor,
1055
+ attention_mask: torch.Tensor,
1056
+ pixel_values: Optional[torch.Tensor] = None,
1057
+ image_features: Optional[torch.Tensor] = None,
1058
+ concat_output_from_vision_encoder: Optional[bool] = None,
1059
+ concat_output_from_text_encoder: Optional[bool] = None,
1060
+ output_attentions: Optional[bool] = None,
1061
+ output_hidden_states: Optional[bool] = None,
1062
+ ):
1063
+ r"""
1064
+ Returns:
1065
+
1066
+ """
1067
+
1068
+ if concat_output_from_vision_encoder is None:
1069
+ concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
1070
+
1071
+ if concat_output_from_text_encoder is None:
1072
+ concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
1073
+
1074
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1075
+ output_hidden_states = (
1076
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1077
+ )
1078
+
1079
+ input_modality = []
1080
+ if pixel_values is not None or image_features is not None:
1081
+ input_modality.append("image")
1082
+ if input_ids is not None and attention_mask is not None:
1083
+ input_modality.append("text")
1084
+
1085
+ text_encoder_outputs = None
1086
+ vision_encoder_outputs = None
1087
+ transformer_mapping_outputs = None
1088
+
1089
+ if "image" in input_modality:
1090
+ assert (
1091
+ pixel_values is not None or image_features is not None
1092
+ ), "pixel_values or image_features must be provided if image modality is used"
1093
+ assert (
1094
+ pixel_values is None or image_features is None
1095
+ ), "pixel_values and image_features cannot be provided at the same time"
1096
+
1097
+ if "text" in input_modality:
1098
+ assert (
1099
+ input_ids is not None and attention_mask is not None
1100
+ ), "input_ids and attention_mask must be provided if text modality is used"
1101
+ # Forward the text encoder
1102
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1103
+ text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask)
1104
+ text_encoder_hidden_states = text_encoder_outputs[0]
1105
+ text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states)
1106
+ mask = torch.tensor(self.query_mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
1107
+
1108
+ text_embeddings = text_embeddings * mask
1109
+
1110
+ if "image" in input_modality:
1111
+ if pixel_values is not None:
1112
+ batch_size = pixel_values.shape[0]
1113
+ # Forward the vision encoder
1114
+ pixel_values = pixel_values.to(self.device)
1115
+ if len(pixel_values.shape) == 5:
1116
+ # Multiple ROIs are provided
1117
+ # merge the first two dimensions
1118
+ pixel_values = pixel_values.reshape(
1119
+ -1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]
1120
+ )
1121
+ vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True)
1122
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1123
+
1124
+ if image_features is not None:
1125
+ batch_size = image_features.shape[0]
1126
+ vision_embeddings = image_features.to(self.device)
1127
+
1128
+ # Forward the vision projection / mapping network
1129
+ vision_embeddings = self.query_vision_projection(vision_embeddings)
1130
+ vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size)
1131
+
1132
+ if self.config.use_transformer_mapping_network:
1133
+ # select the second last layer
1134
+ vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:]
1135
+ # transformer_mapping
1136
+ transformer_mapping_input_features = self.transformer_mapping_input_linear(
1137
+ vision_second_last_layer_hidden_states
1138
+ )
1139
+
1140
+ # Cross attention only attends to the first `transformer_mapping_cross_attention_length` text tokens
1141
+ encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype)
1142
+ cross_attention_length = self.config.transformer_mapping_cross_attention_length
1143
+ if text_encoder_hidden_states.shape[1] > cross_attention_length:
1144
+ text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length]
1145
+ encoder_mask = encoder_mask[:, :cross_attention_length]
1146
+
1147
+ # Obtain cross attention mask
1148
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1))
1149
+ # Pass through the transformer mapping
1150
+ transformer_mapping_outputs = self.transformer_mapping_network(
1151
+ transformer_mapping_input_features,
1152
+ encoder_hidden_states=text_encoder_hidden_states,
1153
+ encoder_attention_mask=encoder_extended_attention_mask,
1154
+ )
1155
+ transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
1156
+ # Convert the dimension to FLMR dim
1157
+ transformer_mapping_output_features = self.transformer_mapping_output_linear(
1158
+ transformer_mapping_output_features
1159
+ )
1160
+ # Merge with the vision embeddings
1161
+ vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1)
1162
+
1163
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1164
+ Q = torch.cat([text_embeddings, vision_embeddings], dim=1)
1165
+ elif concat_output_from_vision_encoder:
1166
+ Q = vision_embeddings
1167
+ elif concat_output_from_text_encoder:
1168
+ Q = text_embeddings
1169
+
1170
+ vision_encoder_attentions = (
1171
+ vision_encoder_outputs.attentions
1172
+ if vision_encoder_outputs is not None
1173
+ and hasattr(vision_encoder_outputs, "attentions")
1174
+ and output_attentions
1175
+ else None
1176
+ )
1177
+ vision_encoder_hidden_states = (
1178
+ vision_encoder_outputs.hidden_states
1179
+ if vision_encoder_outputs is not None
1180
+ and hasattr(vision_encoder_outputs, "hidden_states")
1181
+ and output_hidden_states
1182
+ else None
1183
+ )
1184
+ text_encoder_attentions = (
1185
+ text_encoder_outputs.attentions
1186
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1187
+ else None
1188
+ )
1189
+ text_encoder_hidden_states = (
1190
+ text_encoder_outputs.hidden_states
1191
+ if text_encoder_outputs is not None
1192
+ and hasattr(text_encoder_outputs, "hidden_states")
1193
+ and output_hidden_states
1194
+ else None
1195
+ )
1196
+ transformer_mapping_network_attentions = (
1197
+ transformer_mapping_outputs.attentions
1198
+ if transformer_mapping_outputs is not None
1199
+ and hasattr(transformer_mapping_outputs, "attentions")
1200
+ and output_attentions
1201
+ else None
1202
+ )
1203
+ transformer_mapping_network_hidden_states = (
1204
+ transformer_mapping_outputs.hidden_states
1205
+ if transformer_mapping_outputs is not None
1206
+ and hasattr(transformer_mapping_outputs, "hidden_states")
1207
+ and output_hidden_states
1208
+ else None
1209
+ )
1210
+
1211
+ return FLMRQueryEncoderOutput(
1212
+ pooler_output=Q[:, 0, :],
1213
+ late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2),
1214
+ vision_encoder_attentions=vision_encoder_attentions,
1215
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1216
+ text_encoder_attentions=text_encoder_attentions,
1217
+ text_encoder_hidden_states=text_encoder_hidden_states,
1218
+ transformer_mapping_network_attentions=transformer_mapping_network_attentions,
1219
+ transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states,
1220
+ )
1221
+
1222
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING)
1223
+ @replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
1224
+ def doc(
1225
+ self,
1226
+ input_ids: torch.Tensor,
1227
+ attention_mask: torch.Tensor,
1228
+ pixel_values: Optional[torch.Tensor] = None,
1229
+ image_features: Optional[torch.Tensor] = None,
1230
+ concat_output_from_vision_encoder: Optional[bool] = None,
1231
+ concat_output_from_text_encoder: Optional[bool] = None,
1232
+ keep_dims: Optional[bool] = True,
1233
+ return_mask: Optional[bool] = True,
1234
+ output_attentions: Optional[bool] = None,
1235
+ output_hidden_states: Optional[bool] = None,
1236
+ ):
1237
+ r"""
1238
+ Returns:
1239
+
1240
+ """
1241
+ assert keep_dims in [True, False]
1242
+
1243
+ if concat_output_from_vision_encoder is None:
1244
+ concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
1245
+
1246
+ if concat_output_from_text_encoder is None:
1247
+ concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
1248
+
1249
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1250
+ output_hidden_states = (
1251
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1252
+ )
1253
+
1254
+ input_modality = []
1255
+ if pixel_values is not None or image_features is not None:
1256
+ input_modality.append("image")
1257
+ if input_ids is not None and attention_mask is not None:
1258
+ input_modality.append("text")
1259
+
1260
+ text_encoder_outputs = None
1261
+ vision_encoder_outputs = None
1262
+
1263
+ if "image" in input_modality:
1264
+ assert (
1265
+ pixel_values is not None or image_features is not None
1266
+ ), "pixel_values or image_features must be provided if image modality is used"
1267
+ assert (
1268
+ pixel_values is None or image_features is None
1269
+ ), "pixel_values and image_features cannot be provided at the same time"
1270
+
1271
+ if "text" in input_modality:
1272
+ assert (
1273
+ input_ids is not None and attention_mask is not None
1274
+ ), "input_ids and attention_mask must be provided if text modality is used"
1275
+ # Forward the text encoder
1276
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1277
+ text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask)
1278
+ text_embeddings = text_encoder_outputs[0]
1279
+ text_embeddings = self.context_text_encoder_linear(text_embeddings)
1280
+
1281
+ mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
1282
+ text_embeddings = text_embeddings * mask
1283
+
1284
+ if "image" in input_modality:
1285
+ if pixel_values is not None:
1286
+ # Forward the vision encoder
1287
+ pixel_values = pixel_values.to(self.device)
1288
+ vision_encoder_outputs = self.context_vision_encoder(pixel_values)
1289
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1290
+
1291
+ if image_features is not None:
1292
+ vision_embeddings = image_features.to(self.device)
1293
+
1294
+ batch_size = vision_embeddings.shape[0]
1295
+
1296
+ # Forward the vision projection / mapping network
1297
+ vision_embeddings = self.context_vision_projection(vision_embeddings)
1298
+ vision_embeddings = vision_embeddings.view(
1299
+ -1, self.mapping_network_prefix_length, self.late_interaction_embedding_size
1300
+ )
1301
+
1302
+ image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device)
1303
+
1304
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1305
+ # Note: vision embeddings must come first because the ColBERT engine only indexes embeddings up to the number of 1s in the mask
1305
+ # TODO: fix the engine to support masks with non-contiguous 0s and 1s.
1307
+ D = torch.cat([vision_embeddings, text_embeddings], dim=1)
1308
+ # concatenate the mask
1309
+ mask = torch.cat([image_mask, mask], dim=1)
1310
+ elif concat_output_from_vision_encoder:
1311
+ D = vision_embeddings
1312
+ mask = image_mask
1313
+ elif concat_output_from_text_encoder:
1314
+ D = text_embeddings
1315
+ mask = mask
1316
+
1317
+ D = torch.nn.functional.normalize(D, p=2, dim=2)
1318
+
1319
+ if self.use_gpu:
1320
+ D = D.half()
1321
+
1322
+ if keep_dims is False:
1323
+ D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
1324
+ D = [d[mask[idx]] for idx, d in enumerate(D)]
1325
+
1326
+ vision_encoder_attentions = (
1327
+ vision_encoder_outputs.attentions
1328
+ if vision_encoder_outputs is not None
1329
+ and hasattr(vision_encoder_outputs, "attentions")
1330
+ and output_attentions
1331
+ else None
1332
+ )
1333
+ vision_encoder_hidden_states = (
1334
+ vision_encoder_outputs.hidden_states
1335
+ if vision_encoder_outputs is not None
1336
+ and hasattr(vision_encoder_outputs, "hidden_states")
1337
+ and output_hidden_states
1338
+ else None
1339
+ )
1340
+ text_encoder_attentions = (
1341
+ text_encoder_outputs.attentions
1342
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1343
+ else None
1344
+ )
1345
+ text_encoder_hidden_states = (
1346
+ text_encoder_outputs.hidden_states
1347
+ if text_encoder_outputs is not None
1348
+ and hasattr(text_encoder_outputs, "hidden_states")
1349
+ and output_hidden_states
1350
+ else None
1351
+ )
1352
+
1353
+ return FLMRContextEncoderOutput(
1354
+ pooler_output=D[:, 0, :],
1355
+ late_interaction_output=D,
1356
+ context_mask=mask.bool() if return_mask else None,
1357
+ vision_encoder_attentions=vision_encoder_attentions,
1358
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1359
+ text_encoder_attentions=text_encoder_attentions,
1360
+ text_encoder_hidden_states=text_encoder_hidden_states,
1361
+ )
1362
+
1363
+ def score(self, Q, D_padded, D_mask):
1364
+ # assert self.colbert_config.similarity == 'cosine'
1365
+ # if self.colbert_config.similarity == 'l2':
1366
+ # assert self.colbert_config.interaction == 'colbert'
1367
+ # return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
1368
+ return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu)
1369
+
1370
+ def mask(self, input_ids, skiplist):
1371
+ mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
1372
+ return mask
1373
+
1374
+
1375
+ @add_start_docstrings(
1376
+ "The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.",
1377
+ FLMR_TEXT_ENCODERS_START_DOCSTRING,
1378
+ )
1379
+ class FLMRTextModel(FLMRPreTrainedModel):
1380
+ base_model_prefix = "bert_model"
1381
+ config_class = FLMRTextConfig
1382
+
1383
+ def __init__(self, config: FLMRTextConfig, *args, **kwargs):
1384
+ super().__init__(config)
1385
+ self.bert_model = BertModel(config, add_pooling_layer=True)
1386
+ if self.bert_model.config.hidden_size <= 0:
1387
+ raise ValueError("Encoder hidden_size must be positive")
1388
+ self.projection_dim = config.projection_dim
1389
+ if self.projection_dim > 0:
1390
+ self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
1391
+ # Initialize weights and apply final processing
1392
+ self.post_init()
1393
+
1394
+ @add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING)
1395
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig)
1396
+ def forward(
1397
+ self,
1398
+ input_ids: Optional[Tensor] = None,
1399
+ attention_mask: Optional[Tensor] = None,
1400
+ token_type_ids: Optional[Tensor] = None,
1401
+ inputs_embeds: Optional[Tensor] = None,
1402
+ output_attentions: Optional[bool] = None,
1402
+ output_hidden_states: Optional[bool] = None,
1403
+ return_dict: Optional[bool] = None,
1405
+ ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
1406
+ r"""
1407
+ Returns:
1408
+
1409
+ """
1410
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1411
+ output_hidden_states = (
1412
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1413
+ )
1414
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1415
+
1416
+ outputs = self.bert_model(
1417
+ input_ids=input_ids,
1418
+ attention_mask=attention_mask,
1419
+ token_type_ids=token_type_ids,
1420
+ inputs_embeds=inputs_embeds,
1421
+ output_attentions=output_attentions,
1422
+ output_hidden_states=output_hidden_states,
1423
+ return_dict=return_dict,
1424
+ )
1425
+ sequence_output = outputs[0]
1426
+ pooled_output = sequence_output[:, 0, :]
1427
+
1428
+ if self.projection_dim > 0:
1429
+ pooled_output = self.encode_proj(pooled_output)
1430
+
1431
+ if not return_dict:
1432
+ return (sequence_output, pooled_output) + outputs[2:]
1433
+
1434
+ return BaseModelOutputWithPooling(
1435
+ last_hidden_state=sequence_output,
1436
+ pooler_output=pooled_output,
1437
+ hidden_states=outputs.hidden_states,
1438
+ attentions=outputs.attentions,
1439
+ )
1440
+
1441
+ @property
1442
+ def embeddings_size(self) -> int:
1443
+ if self.projection_dim > 0:
1444
+ return self.encode_proj.out_features
1445
+ return self.bert_model.config.hidden_size
1446
+
1447
+
1448
+ @add_start_docstrings(
1449
+ "The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.",
1450
+ FLMR_VISION_ENCODERS_START_DOCSTRING,
1451
+ )
1452
+ class FLMRVisionModel(FLMRPreTrainedModel):
1453
+ base_model_prefix = "vision_model"
1454
+ config_class = FLMRVisionConfig
1455
+ main_input_name = "pixel_values"
1456
+ _no_split_modules = ["CLIPEncoderLayer"]
1457
+
1458
+ def __init__(self, config: FLMRVisionConfig):
1459
+ super().__init__(config)
1460
+ self.vision_model = CLIPVisionModel(config)
1461
+ self.post_init()
1462
+
1463
+ def get_input_embeddings(self) -> nn.Module:
1464
+ return self.vision_model.vision_model.embeddings.patch_embedding
1465
+
1466
+ @add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING)
1467
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig)
1468
+ def forward(
1469
+ self,
1470
+ pixel_values: Optional[torch.FloatTensor] = None,
1471
+ output_attentions: Optional[bool] = None,
1472
+ output_hidden_states: Optional[bool] = None,
1473
+ return_dict: Optional[bool] = None,
1474
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1475
+ r"""
1476
+ Returns:
1477
+
1478
+ Examples:
1479
+
1480
+ ```python
1481
+ >>> from PIL import Image
1482
+ >>> import requests
1483
+ >>> from transformers import AutoProcessor, FLMRVisionModel
1484
+
1485
+ >>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1486
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1487
+
1488
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1489
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1490
+
1491
+ >>> inputs = processor(images=image, return_tensors="pt")
1492
+
1493
+ >>> outputs = model(**inputs)
1494
+ >>> last_hidden_state = outputs.last_hidden_state
1495
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1496
+ ```"""
1497
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1498
+
1499
+ return self.vision_model(
1500
+ pixel_values=pixel_values,
1501
+ output_attentions=output_attentions,
1502
+ output_hidden_states=output_hidden_states,
1503
+ return_dict=return_dict,
1504
+ )
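The `score` and `compute_ib_loss_new` methods in the modeling file above delegate to `colbert_score` and `colbert_score_reduce`, whose bodies are not part of this section. The sketch below is a dependency-free illustration of the usual ColBERT-style MaxSim reduction and of how the in-batch labels are laid out. It assumes the helpers follow the standard MaxSim definition; the function names `maxsim_score` and `in_batch_labels` are hypothetical and do not appear in the repository.

```python
from typing import List

Vector = List[float]


def maxsim_score(query: List[Vector], doc: List[Vector], doc_mask: List[bool]) -> float:
    """Sum, over query tokens, of the best dot product with any unmasked doc token.

    Assumes at least one document token is unmasked.
    """
    total = 0.0
    for q in query:
        best = max(
            sum(qi * di for qi, di in zip(q, d))
            for d, keep in zip(doc, doc_mask)
            if keep
        )
        total += best
    return total


def in_batch_labels(batch_size: int, num_pos_and_neg: int) -> List[int]:
    """Column index of each query's positive document in the flattened document list.

    Mirrors the loop in compute_ib_loss_new: query i's positive sits at
    position i * num_pos_and_neg.
    """
    return [i * num_pos_and_neg for i in range(batch_size)]


# Toy example: a two-token query against a three-token document whose last
# token is masked out (for instance, padding).
query = [[1.0, 0.0], [0.0, 1.0]]
doc = [[1.0, 0.0], [0.0, 0.5], [0.0, 1.0]]
score = maxsim_score(query, doc, doc_mask=[True, True, False])

# Two queries, each with one positive and one negative document.
labels = in_batch_labels(batch_size=2, num_pos_and_neg=2)
```

In the model, the same reduction runs on batched tensors rather than Python lists, and the labels feed `self.loss_fn` so that every other document in the batch acts as a negative.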
query_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
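The `[PAD]` entry above matters for the `mask` method in the modeling code, which drops token id 0 along with any ids in the skiplist before late-interaction scoring. The sketch below reproduces that filtering on plain Python lists. It assumes `[PAD]` maps to id 0, as in the standard BERT vocabulary (`vocab.txt` is not shown here), and the token ids in the example are hypothetical.

```python
from typing import Iterable, List


def token_mask(
    input_ids: List[List[int]], skiplist: Iterable[int], pad_id: int = 0
) -> List[List[bool]]:
    """Return True for tokens that should contribute to late-interaction scoring."""
    skip = set(skiplist)
    return [[(tok not in skip) and (tok != pad_id) for tok in row] for row in input_ids]


# Hypothetical ids: 1012 stands in for a punctuation token placed on the skiplist,
# and the trailing zeros stand in for padding.
batch = [[101, 7592, 1012, 102, 0, 0]]
mask = token_mask(batch, skiplist=[1012])
```

The query side of the model calls the equivalent helper with an empty skiplist, so only padding is masked there.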
query_tokenizer/tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for FLMR."""
+
+
+ from typing import List, Optional, Union
+
+ from transformers.utils import TensorType, logging
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
+
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
+         ),
+     },
+     "tokenizer_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
+         ),
+     },
+ }
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
+         ),
+         "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
+     },
+     "tokenizer_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
+         ),
+     },
+ }
+
+
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "LinWeizheDragon/PreFLMR_ViT-L": 512,
+     "LinWeizheDragon/FLMR": 512,
+ }
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "LinWeizheDragon/PreFLMR_ViT-L": 512,
+     "LinWeizheDragon/FLMR": 512,
+ }
+
+
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
+     "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
+     "LinWeizheDragon/FLMR": {"do_lower_case": True},
+ }
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
+     "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
+     "LinWeizheDragon/FLMR": {"do_lower_case": True},
+ }
+
+
+ # Modified from colbert.modeling.tokenization
+ class FLMRContextEncoderTokenizer(BertTokenizer):
+     r"""
+     Construct a FLMRContextEncoder tokenizer.
+
+     [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+     splitting and wordpiece.
+
+     Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+     pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
+
+     def __init__(
+         self,
+         doc_maxlen: Optional[int] = 512,
+         **kwargs,
+     ):
+         super().__init__(
+             doc_maxlen=doc_maxlen,
+             **kwargs,
+         )
+
+         self.doc_maxlen = doc_maxlen
+         self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
+
+     def __call__(
+         self,
+         text: List[str],
+         padding: Optional[Union[str, bool]] = "max_length",
+         truncation: Optional[Union[bool, str]] = "longest_first",
+         max_length: Optional[int] = 512,
+         return_tensors: Optional[Union[str, TensorType]] = "pt",
+         **kwargs,
+     ):
+         # add a placeholder for the [D] marker
+         text = [". " + x for x in text]
+
+         if max_length is None or max_length > self.doc_maxlen:
+             # cannot exceed the pre-set length
+             max_length = self.doc_maxlen
+
+         encoding = super().__call__(
+             text,
+             padding=padding,
+             truncation=truncation,
+             return_tensors=return_tensors,
+             max_length=max_length,
+             **kwargs,
+         )
+
+         ids, mask = encoding["input_ids"], encoding["attention_mask"]
+
+         # postprocess for the [D] marker
+         ids[:, 1] = self.D_marker_token_id
+
+         # if bsize:
+         #     # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
+         #     if image_features is not None:
+         #         ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
+         #         batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
+         #     else:
+         #         ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
+         #         batches = _split_into_batches(ids, mask, bsize)
+
+         #     return batches, reverse_indices
+
+         encoding["input_ids"] = ids
+         encoding["attention_mask"] = mask
+
+         return encoding
+
+
+ # Modified from colbert.modeling.tokenization
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
+     r"""
+     Constructs a FLMRQueryEncoder tokenizer.
+
+     [`FLMRQueryEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+     splitting and wordpiece.
+
+     Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+     pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
+
+     def __init__(
+         self,
+         *args,
+         query_maxlen: Optional[int] = 32,
+         attend_to_mask_tokens: Optional[bool] = False,
+         **kwargs,
+     ):
+         super().__init__(
+             *args,
+             query_maxlen=query_maxlen,
+             attend_to_mask_tokens=attend_to_mask_tokens,
+             **kwargs,
+         )
+
+         self.query_maxlen = query_maxlen
+         self.background_maxlen = 512 - self.query_maxlen + 1  # FIXME: Make this configurable
+         self.attend_to_mask_tokens = attend_to_mask_tokens
+
+         self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
+
+     def __call__(
+         self,
+         text: Union[str, List[str]],
+         padding: Optional[Union[str, bool]] = "max_length",
+         truncation: Optional[Union[bool, str]] = True,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = "pt",
+         **kwargs,
+     ):
+         if isinstance(text, str):
+             # convert to list if input is a single string
+             text = [text]
+
+         # add a placeholder for the [Q] marker
+         text = [". " + x for x in text]
+
+         if max_length is not None:
+             # use user specified max_length
+             pass
+         else:
+             # use default max length
+             max_length = self.query_maxlen
+
+         encoding = super().__call__(
+             text,
+             padding=padding,
+             truncation=truncation,
+             return_tensors=return_tensors,
+             max_length=max_length,
+             **kwargs,
+         )
+
+         ids, mask = encoding["input_ids"], encoding["attention_mask"]
+
+         # postprocess for the [Q] marker and the [MASK] augmentation
+         ids[:, 1] = self.Q_marker_token_id
+         ids[ids == self.pad_token_id] = self.mask_token_id
+
+         if self.attend_to_mask_tokens:
+             # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
+             mask[ids == self.mask_token_id] = 1
+             assert mask.sum().item() == mask.size(0) * mask.size(1), mask
+
+         return {"input_ids": ids, "attention_mask": mask}
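The query-side post-processing in `FLMRQueryEncoderTokenizer.__call__` can be illustrated without loading a vocabulary. The sketch below is not part of the commit. It works on plain Python lists of token ids and assumes the usual BERT special-token ids ([PAD]=0, [CLS]=101, [SEP]=102, [MASK]=103, as listed in this repo's `tokenizer_config.json`). It also assumes that `[unused0]` maps to id 1, which holds for the standard BERT uncased vocabulary but is not confirmed here. The function name `augment_query` is hypothetical.

```python
# Special-token ids as listed in tokenizer_config.json.
PAD_ID, CLS_ID, SEP_ID, MASK_ID = 0, 101, 102, 103
# Assumption: "[unused0]" has id 1, as in the standard BERT uncased vocabulary.
Q_MARKER_ID = 1


def augment_query(token_ids, query_maxlen=32, attend_to_mask_tokens=False):
    """Mimic the query post-processing for one already-tokenized query.

    `token_ids` are the wordpiece ids of ". <query text>" without special
    tokens, so token_ids[0] is the placeholder that the [Q] marker overwrites.
    """
    # Truncate so [CLS] and [SEP] still fit, then pad to the fixed length.
    body = token_ids[: query_maxlen - 2]
    ids = [CLS_ID] + body + [SEP_ID]
    mask = [1] * len(ids)
    pad = query_maxlen - len(ids)
    ids += [PAD_ID] * pad
    mask += [0] * pad

    # Position 1 (the "." placeholder) becomes the [Q] marker.
    ids[1] = Q_MARKER_ID
    # ColBERT-style query augmentation: every [PAD] becomes [MASK].
    ids = [MASK_ID if i == PAD_ID else i for i in ids]
    if attend_to_mask_tokens:
        # Optionally let the encoder attend to the [MASK] positions as well.
        mask = [1 if i == MASK_ID else m for i, m in zip(ids, mask)]
    return ids, mask


# 1012, 2054 and 2003 are arbitrary illustrative ids; 1012 plays the "." placeholder.
ids, mask = augment_query([1012, 2054, 2003], query_maxlen=8)
print(ids)   # [101, 1, 2054, 2003, 102, 103, 103, 103]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

With `attend_to_mask_tokens=True` the returned mask is all ones, which is the condition the `assert` in the tokenizer checks.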
query_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "attend_to_mask_tokens": false,
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_flmr.FLMRQueryEncoderTokenizer",
+       null
+     ]
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "query_maxlen": 32,
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "FLMRQueryEncoderTokenizer",
+   "unk_token": "[UNK]"
+ }
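Two numbers in this config are easier to read with a little arithmetic. The sketch below is not part of the commit. It shows that the `model_max_length` value is what Python produces for `int(1e30)`. Our understanding is that `transformers` uses this value as a "no limit configured" sentinel; treat that reading as an assumption. It also evaluates the `background_maxlen` expression from `FLMRQueryEncoderTokenizer.__init__` for the configured `query_maxlen`.

```python
# Value copied from "model_max_length" in the config above.
MODEL_MAX_LENGTH = 1000000000000000019884624838656
# Value copied from "query_maxlen" in the config above.
QUERY_MAXLEN = 32

# 1e30 is not exactly representable as a double; converting the nearest
# double to an int yields exactly the digits stored in the config.
is_sentinel = MODEL_MAX_LENGTH == int(1e30)

# Same expression as in FLMRQueryEncoderTokenizer.__init__.
background_maxlen = 512 - QUERY_MAXLEN + 1

print(is_sentinel)        # True
print(background_maxlen)  # 481
```

In practice the effective query length comes from `query_maxlen`, since the tokenizer passes it as `max_length` whenever the caller does not supply one.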
query_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
segmented_maxsim.cpp ADDED
@@ -0,0 +1,104 @@
+ #include <pthread.h>
+ #include <torch/extension.h>
+
+ #include <algorithm>
+ #include <cmath>
+ #include <cstdio>
+ #include <numeric>
+
+ typedef struct {
+     int tid;
+     int nthreads;
+
+     int ndocs;
+     int ndoc_vectors;
+     int nquery_vectors;
+
+     int64_t* lengths;
+     float* scores;
+     int64_t* offsets;
+
+     float* max_scores;
+ } max_args_t;
+
+ void* max(void* args) {
+     max_args_t* max_args = (max_args_t*)args;
+
+     int ndocs_per_thread =
+         std::ceil(((float)max_args->ndocs) / max_args->nthreads);
+     int start = max_args->tid * ndocs_per_thread;
+     int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
+
+     // Threads past the last document have an empty slice; return before indexing offsets[start].
+     if (start >= end) {
+         return NULL;
+     }
+
+     auto max_scores_offset =
+         max_args->max_scores + (start * max_args->nquery_vectors);
+     auto scores_offset =
+         max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
+
+     for (int i = start; i < end; i++) {
+         for (int j = 0; j < max_args->lengths[i]; j++) {
+             std::transform(max_scores_offset,
+                            max_scores_offset + max_args->nquery_vectors,
+                            scores_offset, max_scores_offset,
+                            [](float a, float b) { return std::max(a, b); });
+             scores_offset += max_args->nquery_vectors;
+         }
+         max_scores_offset += max_args->nquery_vectors;
+     }
+
+     return NULL;
+ }
+
+ torch::Tensor segmented_maxsim(const torch::Tensor scores,
+                                const torch::Tensor lengths) {
+     auto lengths_a = lengths.data_ptr<int64_t>();
+     auto scores_a = scores.data_ptr<float>();
+     auto ndocs = lengths.size(0);
+     auto ndoc_vectors = scores.size(0);
+     auto nquery_vectors = scores.size(1);
+     auto nthreads = at::get_num_threads();
+
+     torch::Tensor max_scores =
+         torch::zeros({ndocs, nquery_vectors}, scores.options());
+
+     int64_t offsets[ndocs + 1];
+     offsets[0] = 0;
+     std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
+
+     pthread_t threads[nthreads];
+     max_args_t args[nthreads];
+
+     for (int i = 0; i < nthreads; i++) {
+         args[i].tid = i;
+         args[i].nthreads = nthreads;
+
+         args[i].ndocs = ndocs;
+         args[i].ndoc_vectors = ndoc_vectors;
+         args[i].nquery_vectors = nquery_vectors;
+
+         args[i].lengths = lengths_a;
+         args[i].scores = scores_a;
+         args[i].offsets = offsets;
+
+         args[i].max_scores = max_scores.data_ptr<float>();
+
+         int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
+         if (rc) {
+             fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
+         }
+     }
+
+     for (int i = 0; i < nthreads; i++) {
+         pthread_join(threads[i], NULL);
+     }
+
+     return max_scores.sum(1);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
+ }
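For readers who want to check the extension's output, a single-threaded reference of the same computation may help. The sketch below is not part of the commit and the name `segmented_maxsim_reference` is hypothetical. It takes nested Python lists instead of tensors: `scores` has one row per document vector and one column per query vector, and `lengths` gives the number of consecutive rows that belong to each document. Like the C++ code, it starts each per-query maximum at zero, so negative similarities never lower a document's score.

```python
def segmented_maxsim_reference(scores, lengths):
    """Return one MaxSim score per document.

    scores:  list of rows, one per document vector; each row holds the
             similarity of that vector to every query vector.
    lengths: number of consecutive rows in `scores` owned by each document.
    """
    nquery_vectors = len(scores[0]) if scores else 0
    totals = []
    row = 0
    for length in lengths:
        # Zero-initialised, mirroring torch::zeros in the C++ version.
        best = [0.0] * nquery_vectors
        for _ in range(length):
            # Element-wise running maximum over this document's vectors.
            best = [max(b, s) for b, s in zip(best, scores[row])]
            row += 1
        # Sum the per-query maxima, as max_scores.sum(1) does.
        totals.append(sum(best))
    return totals


scores = [
    [0.25, 1.0],   # document 0, vector 0
    [0.5, 0.75],   # document 0, vector 1
    [-1.0, 0.5],   # document 1, vector 0
]
print(segmented_maxsim_reference(scores, [2, 1]))  # [1.5, 0.5]
```

The second document scores 0.5 rather than -0.5 because its first query column is clamped at the initial zero.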
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "cls_token": {
+     "content": "[CLS]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "[MASK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "[PAD]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "[SEP]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "[UNK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for FLMR."""
+
+
+ from typing import List, Optional, Union
+
+ from transformers.utils import TensorType, logging
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
+
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
+         ),
+     },
+     "tokenizer_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
+         ),
+     },
+ }
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
+         ),
+         "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
+     },
+     "tokenizer_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
+         ),
+     },
+ }
+
+
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "LinWeizheDragon/PreFLMR_ViT-L": 512,
+     "LinWeizheDragon/FLMR": 512,
+ }
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "LinWeizheDragon/PreFLMR_ViT-L": 512,
+     "LinWeizheDragon/FLMR": 512,
+ }
+
+
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
+     "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
+     "LinWeizheDragon/FLMR": {"do_lower_case": True},
+ }
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
+     "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
+     "LinWeizheDragon/FLMR": {"do_lower_case": True},
+ }
+
+
+ # Modified from colbert.modeling.tokenization
+ class FLMRContextEncoderTokenizer(BertTokenizer):
+     r"""
+     Construct a FLMRContextEncoder tokenizer.
+
+     [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+     splitting and wordpiece.
+
+     Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+     pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
+
+     def __init__(
+         self,
+         doc_maxlen: Optional[int] = 512,
+         **kwargs,
+     ):
+         super().__init__(
+             doc_maxlen=doc_maxlen,
+             **kwargs,
+         )
+
+         self.doc_maxlen = doc_maxlen
+         self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
+
+     def __call__(
+         self,
+         text: List[str],
+         padding: Optional[Union[str, bool]] = "max_length",
+         truncation: Optional[Union[bool, str]] = "longest_first",
+         max_length: Optional[int] = 512,
+         return_tensors: Optional[Union[str, TensorType]] = "pt",
+         **kwargs,
+     ):
+         # add a placeholder for the [D] marker
+         text = [". " + x for x in text]
+
+         if max_length is None or max_length > self.doc_maxlen:
+             # cannot exceed the pre-set length
+             max_length = self.doc_maxlen
+
+         encoding = super().__call__(
+             text,
+             padding=padding,
+             truncation=truncation,
+             return_tensors=return_tensors,
+             max_length=max_length,
+             **kwargs,
+         )
+
+         ids, mask = encoding["input_ids"], encoding["attention_mask"]
+
+         # postprocess for the [D] marker
+         ids[:, 1] = self.D_marker_token_id
+
+         # if bsize:
+         #     # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
+         #     if image_features is not None:
+         #         ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
+         #         batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
+         #     else:
+         #         ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
+         #         batches = _split_into_batches(ids, mask, bsize)
+
+         #     return batches, reverse_indices
+
+         encoding["input_ids"] = ids
+         encoding["attention_mask"] = mask
+
+         return encoding
+
+
+ # Modified from colbert.modeling.tokenization
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
+     r"""
+     Constructs a FLMRQueryEncoder tokenizer.
+
+     [`FLMRQueryEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+     splitting and wordpiece.
+
+     Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+     pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
+
+     def __init__(
+         self,
+         *args,
+         query_maxlen: Optional[int] = 32,
+         attend_to_mask_tokens: Optional[bool] = False,
+         **kwargs,
+     ):
+         super().__init__(
+             *args,
+             query_maxlen=query_maxlen,
+             attend_to_mask_tokens=attend_to_mask_tokens,
+             **kwargs,
+         )
+
+         self.query_maxlen = query_maxlen
+         self.background_maxlen = 512 - self.query_maxlen + 1  # FIXME: Make this configurable
+         self.attend_to_mask_tokens = attend_to_mask_tokens
+
+         self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
+
+     def __call__(
+         self,
+         text: Union[str, List[str]],
+         padding: Optional[Union[str, bool]] = "max_length",
+         truncation: Optional[Union[bool, str]] = True,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = "pt",
+         **kwargs,
+     ):
+         if isinstance(text, str):
+             # convert to list if input is a single string
+             text = [text]
+
+         # add a placeholder for the [Q] marker
+         text = [". " + x for x in text]
+
+         if max_length is not None:
+             # use user specified max_length
+             pass
+         else:
+             # use default max length
+             max_length = self.query_maxlen
+
+         encoding = super().__call__(
+             text,
+             padding=padding,
+             truncation=truncation,
+             return_tensors=return_tensors,
+             max_length=max_length,
+             **kwargs,
+         )
+
+         ids, mask = encoding["input_ids"], encoding["attention_mask"]
+
+         # postprocess for the [Q] marker and the [MASK] augmentation
+         ids[:, 1] = self.Q_marker_token_id
+         ids[ids == self.pad_token_id] = self.mask_token_id
+
+         if self.attend_to_mask_tokens:
+             # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
+             mask[ids == self.mask_token_id] = 1
+             assert mask.sum().item() == mask.size(0) * mask.size(1), mask
+
+         return {"input_ids": ids, "attention_mask": mask}
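The context (document) side differs from the query side in two ways: the requested `max_length` is clamped to `doc_maxlen`, and padding positions stay `[PAD]` instead of being turned into `[MASK]`. The sketch below is not part of the commit and the function name `prepare_context` is hypothetical. It works on plain lists of token ids and assumes that `[unused1]` maps to id 2, as in the standard BERT uncased vocabulary. The other special-token ids are taken from `tokenizer_config.json`.

```python
# Special-token ids as listed in tokenizer_config.json.
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102
# Assumption: "[unused1]" has id 2, as in the standard BERT uncased vocabulary.
D_MARKER_ID = 2


def prepare_context(token_ids, max_length=512, doc_maxlen=512):
    """Mimic the document post-processing for one already-tokenized passage.

    `token_ids` are the wordpiece ids of ". <passage text>" without special
    tokens, so token_ids[0] is the placeholder that the [D] marker overwrites.
    """
    # The requested length can never exceed the tokenizer's doc_maxlen.
    max_length = min(max_length, doc_maxlen)

    # Truncate from the right so [CLS] and [SEP] still fit, then pad.
    body = token_ids[: max_length - 2]
    ids = [CLS_ID] + body + [SEP_ID]
    mask = [1] * len(ids)
    pad = max_length - len(ids)
    ids += [PAD_ID] * pad
    mask += [0] * pad

    # Position 1 (the "." placeholder) becomes the [D] marker.
    ids[1] = D_MARKER_ID
    return ids, mask


# The ids 1012, 7, 8, 9, 10 are arbitrary; 1012 plays the "." placeholder.
print(prepare_context([1012, 7, 8, 9, 10], max_length=600, doc_maxlen=6))
# ([101, 2, 7, 8, 9, 102], [1, 1, 1, 1, 1, 1])
print(prepare_context([1012, 7], max_length=5))
# ([101, 2, 7, 102, 0], [1, 1, 1, 1, 0])
```

In the first call the request for 600 tokens is clamped to 6, so the last input id is dropped to make room for `[SEP]`.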
tokenization_flmr_fast.py ADDED
@@ -0,0 +1,114 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for FLMR."""
+
+
+ from transformers.utils import logging
+ from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
+ from .tokenization_flmr import FLMRContextEncoderTokenizer, FLMRQueryEncoderTokenizer
+
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
+
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
+         ),
+     },
+     "tokenizer_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
+         ),
+     },
+ }
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
+         ),
+         "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
+     },
+     "tokenizer_file": {
+         "LinWeizheDragon/PreFLMR_ViT-L": (
+             "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
+         ),
+         "LinWeizheDragon/FLMR": (
+             "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
+         ),
+     },
+ }
+
+
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "LinWeizheDragon/PreFLMR_ViT-L": 512,
+     "LinWeizheDragon/FLMR": 512,
+ }
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     "LinWeizheDragon/PreFLMR_ViT-L": 512,
+     "LinWeizheDragon/FLMR": 512,
+ }
+
+
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
+     "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
+     "LinWeizheDragon/FLMR": {"do_lower_case": True},
+ }
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
+     "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
+     "LinWeizheDragon/FLMR": {"do_lower_case": True},
+ }
+
+
+ class FLMRContextEncoderTokenizerFast(BertTokenizerFast):
+     r"""
+     Construct a "fast" FLMRContextEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
+
+     [`FLMRContextEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+     punctuation splitting and wordpiece.
+
+     Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+     pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
+     slow_tokenizer_class = FLMRContextEncoderTokenizer
+
+
+ class FLMRQueryEncoderTokenizerFast(BertTokenizerFast):
+     r"""
+     Constructs a "fast" FLMRQueryEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
+
+     [`FLMRQueryEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+     punctuation splitting and wordpiece.
+
+     Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+     pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
+     slow_tokenizer_class = FLMRQueryEncoderTokenizer
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "doc_maxlen": 512,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "FLMRContextEncoderTokenizer",
+   "unk_token": "[UNK]"
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff