paultltc commited on
Commit
6ce858b
·
1 Parent(s): 8699dec

clean modeling + fix config double loading

Browse files
config.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
  "additional_vocab_size": 40,
3
  "architectures": [
4
- "VBertForMaskedLM"
5
  ],
6
  "auto_map": {
7
- "AutoConfig": "configuration_vbert.VBertConfig",
8
- "AutoModel": "modeling_vbert.VBertModel",
9
- "AutoModelForMaskedLM": "modeling_vbert.VBertForMaskedLM"
10
  },
11
  "freeze_config": {
12
  "freeze_lm_head": true,
@@ -27,7 +27,6 @@
27
  "hidden_size": 768,
28
  "intermediate_size": 1152,
29
  "mlp_bias": false,
30
- "model_type": "vbert",
31
  "num_hidden_layers": 22,
32
  "text_model_name": "jhu-clsp/ettin-encoder-150m",
33
  "vocab_size": 50368
@@ -41,7 +40,6 @@
41
  "embed_dim": 768,
42
  "image_size": 512,
43
  "intermediate_size": 3072,
44
- "model_type": "vbert",
45
  "num_hidden_layers": 12,
46
  "patch_size": 16,
47
  "vision_model_name": "google/siglip2-base-patch16-512"
 
1
  {
2
  "additional_vocab_size": 40,
3
  "architectures": [
4
+ "ModernVBertForMaskedLM"
5
  ],
6
  "auto_map": {
7
+ "AutoConfig": "configuration_modernvbert.ModernVBertConfig",
8
+ "AutoModel": "modeling_modernvbert.ModernVBertModel",
9
+ "AutoModelForMaskedLM": "modeling_modernvbert.ModernVBertForMaskedLM"
10
  },
11
  "freeze_config": {
12
  "freeze_lm_head": true,
 
27
  "hidden_size": 768,
28
  "intermediate_size": 1152,
29
  "mlp_bias": false,
 
30
  "num_hidden_layers": 22,
31
  "text_model_name": "jhu-clsp/ettin-encoder-150m",
32
  "vocab_size": 50368
 
40
  "embed_dim": 768,
41
  "image_size": 512,
42
  "intermediate_size": 3072,
 
43
  "num_hidden_layers": 12,
44
  "patch_size": 16,
45
  "vision_model_name": "google/siglip2-base-patch16-512"
configuration_modernvbert.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ from typing import Any, Dict, Union
4
+
5
+ from transformers import AutoConfig
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.utils import logging
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+ DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
12
+ DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"
13
+
14
+ def collect_arg_in_candidates(config, candidates, default=None) -> Any:
15
+ """Gets the first available argument in a config given a list of candidate names."""
16
+ for c in candidates:
17
+ if hasattr(config, c):
18
+ return getattr(config, c)
19
+ elif c in config:
20
+ return config[c]
21
+ if default is not None:
22
+ return default
23
+ raise ValueError(
24
+ f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}"
25
+ )
26
+
27
+ class ModernVBertTextConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate an ModernBERT
30
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
+ defaults will yield a similar configuration to that of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+ """
36
+ model_type = "modernvbert_text"
37
+
38
+ def __init__(
39
+ self,
40
+ text_model_name=DEFAULT_TEXT_MODEL_NAME,
41
+ hidden_size=768,
42
+ num_hidden_layers=22,
43
+ intermediate_size=1152,
44
+ mlp_bias=False,
45
+ vocab_size=50368,
46
+ **kwargs,
47
+ ):
48
+ super().__init__(
49
+ text_model_name=text_model_name,
50
+ hidden_size=hidden_size,
51
+ num_hidden_layers=num_hidden_layers,
52
+ intermediate_size=intermediate_size,
53
+ mlp_bias=mlp_bias,
54
+ vocab_size=vocab_size,
55
+ **kwargs,
56
+ )
57
+
58
+ @classmethod
59
+ def from_base_model(
60
+ cls,
61
+ text_model_name=DEFAULT_TEXT_MODEL_NAME,
62
+ **kwargs,
63
+ ):
64
+ text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
65
+ if hasattr(text_config, "text_config"):
66
+ text_config = text_config.text_config
67
+
68
+ hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
69
+ num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
70
+ intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
71
+ mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False)
72
+ vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])
73
+
74
+ return cls(
75
+ text_model_name=text_model_name,
76
+ hidden_size=hidden_size,
77
+ num_hidden_layers=num_hidden_layers,
78
+ intermediate_size=intermediate_size,
79
+ mlp_bias=mlp_bias,
80
+ vocab_size=vocab_size,
81
+ **kwargs,
82
+ )
83
+
84
+ class ModernVBertVisionConfig(PretrainedConfig):
85
+ r"""
86
+ This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate the vision encoder part of the ModernVBERT
87
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
88
+ defaults will yield a similar configuration to that of the SigLIP.
89
+
90
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
91
+ documentation from [`PretrainedConfig`] for more information.
92
+ """
93
+ model_type = "modernvbert_vision"
94
+
95
+ attribute_map = {
96
+ "hidden_size": "embed_dim",
97
+ }
98
+
99
+ def __init__(
100
+ self,
101
+ vision_model_name=DEFAULT_VISION_MODEL_NAME,
102
+ embed_dim=768,
103
+ image_size=512,
104
+ patch_size=16,
105
+ num_hidden_layers=12,
106
+ intermediate_size=3072,
107
+ **kwargs,
108
+ ):
109
+ super().__init__(
110
+ vision_model_name=vision_model_name,
111
+ embed_dim=embed_dim,
112
+ image_size=image_size,
113
+ patch_size=patch_size,
114
+ num_hidden_layers=num_hidden_layers,
115
+ intermediate_size=intermediate_size,
116
+ **kwargs,
117
+ )
118
+
119
+ @classmethod
120
+ def from_base_model(
121
+ cls,
122
+ vision_model_name=DEFAULT_VISION_MODEL_NAME,
123
+ **kwargs,
124
+ ):
125
+ vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
126
+ if hasattr(vision_config, "vision_config"):
127
+ vision_config = vision_config.vision_config
128
+
129
+ embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
130
+ image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
131
+ patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
132
+ num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
133
+ intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])
134
+
135
+ return cls(
136
+ vision_model_name=vision_model_name,
137
+ embed_dim=embed_dim,
138
+ image_size=image_size,
139
+ patch_size=patch_size,
140
+ num_hidden_layers=num_hidden_layers,
141
+ intermediate_size=intermediate_size,
142
+ **kwargs,
143
+ )
144
+
145
+
146
+ class ModernVBertConfig(PretrainedConfig):
147
+ r"""
148
+ This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
149
+ instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
150
+
151
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
152
+ See the documentation for [`PretrainedConfig`] for more details.
153
+
154
+ Args:
155
+ text_config (`PretrainedConfig` or `dict`, optional):
156
+ Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
157
+ default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
158
+ vision_config (`PretrainedConfig` or `dict`, optional):
159
+ Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
160
+ default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
161
+ image_token_id (`int`, optional, defaults to 128257):
162
+ Token id reserved for image tokens inserted into the text stream.
163
+ vocab_size (`int`, optional, defaults to 128256):
164
+ Vocabulary size used by the text embeddings.
165
+ use_cache (`bool`, optional, defaults to `True`):
166
+ Whether to cache key/value tensors for attention (relevant for decoder architectures).
167
+ tie_word_embeddings (`bool`, optional, defaults to `False`):
168
+ Whether to tie input token embeddings and output token embeddings.
169
+ pixel_shuffle_factor (`int`, optional, defaults to 4):
170
+ Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
171
+ additional_vocab_size (`int`, optional, defaults to 0):
172
+ Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
173
+ pad_token_id (`int`, optional):
174
+ Padding token id.
175
+ initializer_range (`float`, optional, defaults to 0.02):
176
+ Stddev used for weight initialization.
177
+ freeze_config (`Any`, optional):
178
+ Optional config describing which submodules to freeze during training.
179
+ use_resampler (`bool`, optional, defaults to `False`):
180
+ Whether to enable an additional resampler on visual features.
181
+ neftune_noise_alpha (`float`, optional, defaults to 0.0):
182
+ Alpha parameter for neftune noise injection.
183
+
184
+ Example:
185
+ ```python
186
+ >>> from modernvbert import ModernVBertConfig
187
+ >>> # Initializing configuration
188
+ >>> configuration = ModernVBertConfig()
189
+ >>> # Initializing a model from the configuration (model class is implemented in
190
+ >>> # `modernvbert.modeling_modernvbert`)
191
+ >>> # from modernvbert import ModernVBertModel
192
+ >>> # model = ModernVBertModel(configuration)
193
+ >>> # Accessing the model configuration
194
+ >>> # cfg = model.config
195
+ ```"""
196
+
197
+ model_type = "modernvbert"
198
+ is_composition = True
199
+
200
+ def __init__(
201
+ self,
202
+ text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
203
+ vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
204
+ image_token_id: int = 128_257,
205
+ vocab_size=50368,
206
+ use_cache=True,
207
+ tie_word_embeddings=False,
208
+ freeze_config=None,
209
+ pad_token_id=None,
210
+ initializer_range=0.02,
211
+ pixel_shuffle_factor=4,
212
+ use_resampler=False,
213
+ additional_vocab_size=0,
214
+ neftune_noise_alpha=0.0,
215
+ **kwargs,
216
+ ):
217
+ self.image_token_id = image_token_id
218
+ self.use_cache = use_cache
219
+ self.tie_word_embeddings = tie_word_embeddings
220
+ self.scale_factor = pixel_shuffle_factor
221
+ self.additional_vocab_size = additional_vocab_size
222
+
223
+ if text_config is None:
224
+ base_text_config = AutoConfig.from_pretrained(DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True)
225
+ text_config = ModernVBertTextConfig(base_text_config)
226
+ elif isinstance(text_config, dict):
227
+ text_config = ModernVBertTextConfig.from_dict(text_config)
228
+ self.text_config = text_config
229
+
230
+ if vision_config is None:
231
+ base_vision_config = AutoConfig.from_pretrained(DEFAULT_VISION_MODEL_NAME, trust_remote_code=True)
232
+ vision_config = ModernVBertVisionConfig(base_vision_config)
233
+ elif isinstance(vision_config, dict):
234
+ vision_config = ModernVBertVisionConfig.from_dict(vision_config)
235
+ self.vision_config = vision_config
236
+
237
+ self.freeze_config = freeze_config
238
+ self.pixel_shuffle_factor = pixel_shuffle_factor
239
+ self.use_resampler = use_resampler
240
+ self.neftune_noise_alpha = neftune_noise_alpha
241
+ self.initializer_range = initializer_range
242
+
243
+ hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
244
+
245
+ super().__init__(
246
+ **kwargs,
247
+ pad_token_id=pad_token_id,
248
+ tie_word_embeddings=tie_word_embeddings,
249
+ vocab_size=vocab_size,
250
+ hidden_size=hidden_size,
251
+ )
252
+
253
+ def to_dict(self):
254
+ output = copy.deepcopy(self.__dict__)
255
+ output["model_type"] = self.__class__.model_type
256
+ output["vision_config"] = self.vision_config.to_dict()
257
+ output["text_config"] = self.text_config.to_dict()
258
+ return output
259
+
260
+ @classmethod
261
+ def from_pretrained_models(
262
+ cls,
263
+ text_model_name: Union[str, os.PathLike],
264
+ vision_model_name: Union[str, os.PathLike],
265
+ **kwargs,
266
+ ) -> "PretrainedConfig":
267
+ text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
268
+ vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
269
+ return cls(
270
+ text_config=text_model_config,
271
+ vision_config=vision_model_config,
272
+ **kwargs,
273
+ )import copy
274
+ import os
275
+ from typing import Any, Dict, Union
276
+
277
+ from transformers import AutoConfig
278
+ from transformers.configuration_utils import PretrainedConfig
279
+ from transformers.utils import logging
280
+
281
+ logger = logging.get_logger(__name__)
282
+
283
+ DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
284
+ DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"
285
+
286
+ def collect_arg_in_candidates(config, candidates, default=None) -> Any:
287
+ """Gets the first available argument in a config given a list of candidate names."""
288
+ for c in candidates:
289
+ if hasattr(config, c):
290
+ return getattr(config, c)
291
+ elif c in config:
292
+ return config[c]
293
+ if default is not None:
294
+ return default
295
+ raise ValueError(
296
+ f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}"
297
+ )
298
+
299
+ class ModernVBertTextConfig(PretrainedConfig):
300
+ r"""
301
+ This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate an ModernBERT
302
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
303
+ defaults will yield a similar configuration to that of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.
304
+
305
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
306
+ documentation from [`PretrainedConfig`] for more information.
307
+ """
308
+ model_type = "modernvbert_text"
309
+
310
+ def __init__(
311
+ self,
312
+ text_model_name=DEFAULT_TEXT_MODEL_NAME,
313
+ hidden_size=768,
314
+ num_hidden_layers=22,
315
+ intermediate_size=1152,
316
+ mlp_bias=False,
317
+ vocab_size=50368,
318
+ **kwargs,
319
+ ):
320
+ super().__init__(
321
+ text_model_name=text_model_name,
322
+ hidden_size=hidden_size,
323
+ num_hidden_layers=num_hidden_layers,
324
+ intermediate_size=intermediate_size,
325
+ mlp_bias=mlp_bias,
326
+ vocab_size=vocab_size,
327
+ **kwargs,
328
+ )
329
+
330
+ @classmethod
331
+ def from_base_model(
332
+ cls,
333
+ text_model_name=DEFAULT_TEXT_MODEL_NAME,
334
+ **kwargs,
335
+ ):
336
+ text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
337
+ if hasattr(text_config, "text_config"):
338
+ text_config = text_config.text_config
339
+
340
+ hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
341
+ num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
342
+ intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
343
+ mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False)
344
+ vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])
345
+
346
+ return cls(
347
+ text_model_name=text_model_name,
348
+ hidden_size=hidden_size,
349
+ num_hidden_layers=num_hidden_layers,
350
+ intermediate_size=intermediate_size,
351
+ mlp_bias=mlp_bias,
352
+ vocab_size=vocab_size,
353
+ **kwargs,
354
+ )
355
+
356
+ class ModernVBertVisionConfig(PretrainedConfig):
357
+ r"""
358
+ This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate the vision encoder part of the ModernVBERT
359
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
360
+ defaults will yield a similar configuration to that of the SigLIP.
361
+
362
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
363
+ documentation from [`PretrainedConfig`] for more information.
364
+ """
365
+ model_type = "modernvbert_vision"
366
+
367
+ attribute_map = {
368
+ "hidden_size": "embed_dim",
369
+ }
370
+
371
+ def __init__(
372
+ self,
373
+ vision_model_name=DEFAULT_VISION_MODEL_NAME,
374
+ embed_dim=768,
375
+ image_size=512,
376
+ patch_size=16,
377
+ num_hidden_layers=12,
378
+ intermediate_size=3072,
379
+ **kwargs,
380
+ ):
381
+ super().__init__(
382
+ vision_model_name=vision_model_name,
383
+ embed_dim=embed_dim,
384
+ image_size=image_size,
385
+ patch_size=patch_size,
386
+ num_hidden_layers=num_hidden_layers,
387
+ intermediate_size=intermediate_size,
388
+ **kwargs,
389
+ )
390
+
391
+ @classmethod
392
+ def from_base_model(
393
+ cls,
394
+ vision_model_name=DEFAULT_VISION_MODEL_NAME,
395
+ **kwargs,
396
+ ):
397
+ vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
398
+ if hasattr(vision_config, "vision_config"):
399
+ vision_config = vision_config.vision_config
400
+
401
+ embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
402
+ image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
403
+ patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
404
+ num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
405
+ intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])
406
+
407
+ return cls(
408
+ vision_model_name=vision_model_name,
409
+ embed_dim=embed_dim,
410
+ image_size=image_size,
411
+ patch_size=patch_size,
412
+ num_hidden_layers=num_hidden_layers,
413
+ intermediate_size=intermediate_size,
414
+ **kwargs,
415
+ )
416
+
417
+
418
+ class ModernVBertConfig(PretrainedConfig):
419
+ r"""
420
+ This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
421
+ instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
422
+
423
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
424
+ See the documentation for [`PretrainedConfig`] for more details.
425
+
426
+ Args:
427
+ text_config (`PretrainedConfig` or `dict`, optional):
428
+ Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
429
+ default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
430
+ vision_config (`PretrainedConfig` or `dict`, optional):
431
+ Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
432
+ default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
433
+ image_token_id (`int`, optional, defaults to 128257):
434
+ Token id reserved for image tokens inserted into the text stream.
435
+ vocab_size (`int`, optional, defaults to 128256):
436
+ Vocabulary size used by the text embeddings.
437
+ use_cache (`bool`, optional, defaults to `True`):
438
+ Whether to cache key/value tensors for attention (relevant for decoder architectures).
439
+ tie_word_embeddings (`bool`, optional, defaults to `False`):
440
+ Whether to tie input token embeddings and output token embeddings.
441
+ pixel_shuffle_factor (`int`, optional, defaults to 4):
442
+ Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
443
+ additional_vocab_size (`int`, optional, defaults to 0):
444
+ Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
445
+ pad_token_id (`int`, optional):
446
+ Padding token id.
447
+ initializer_range (`float`, optional, defaults to 0.02):
448
+ Stddev used for weight initialization.
449
+ freeze_config (`Any`, optional):
450
+ Optional config describing which submodules to freeze during training.
451
+ use_resampler (`bool`, optional, defaults to `False`):
452
+ Whether to enable an additional resampler on visual features.
453
+ neftune_noise_alpha (`float`, optional, defaults to 0.0):
454
+ Alpha parameter for neftune noise injection.
455
+
456
+ Example:
457
+ ```python
458
+ >>> from modernvbert import ModernVBertConfig
459
+ >>> # Initializing configuration
460
+ >>> configuration = ModernVBertConfig()
461
+ >>> # Initializing a model from the configuration (model class is implemented in
462
+ >>> # `modernvbert.modeling_modernvbert`)
463
+ >>> # from modernvbert import ModernVBertModel
464
+ >>> # model = ModernVBertModel(configuration)
465
+ >>> # Accessing the model configuration
466
+ >>> # cfg = model.config
467
+ ```"""
468
+
469
+ model_type = "modernvbert"
470
+ is_composition = True
471
+
472
+ def __init__(
473
+ self,
474
+ text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
475
+ vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
476
+ image_token_id: int = 128_257,
477
+ vocab_size=50368,
478
+ use_cache=True,
479
+ tie_word_embeddings=False,
480
+ freeze_config=None,
481
+ pad_token_id=None,
482
+ initializer_range=0.02,
483
+ pixel_shuffle_factor=4,
484
+ use_resampler=False,
485
+ additional_vocab_size=0,
486
+ neftune_noise_alpha=0.0,
487
+ **kwargs,
488
+ ):
489
+ self.image_token_id = image_token_id
490
+ self.use_cache = use_cache
491
+ self.tie_word_embeddings = tie_word_embeddings
492
+ self.scale_factor = pixel_shuffle_factor
493
+ self.additional_vocab_size = additional_vocab_size
494
+
495
+ if text_config is None:
496
+ base_text_config = AutoConfig.from_pretrained(DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True)
497
+ text_config = ModernVBertTextConfig(base_text_config)
498
+ elif isinstance(text_config, dict):
499
+ text_config = ModernVBertTextConfig.from_dict(text_config)
500
+ self.text_config = text_config
501
+
502
+ if vision_config is None:
503
+ base_vision_config = AutoConfig.from_pretrained(DEFAULT_VISION_MODEL_NAME, trust_remote_code=True)
504
+ vision_config = ModernVBertVisionConfig(base_vision_config)
505
+ elif isinstance(vision_config, dict):
506
+ vision_config = ModernVBertVisionConfig.from_dict(vision_config)
507
+ self.vision_config = vision_config
508
+
509
+ self.freeze_config = freeze_config
510
+ self.pixel_shuffle_factor = pixel_shuffle_factor
511
+ self.use_resampler = use_resampler
512
+ self.neftune_noise_alpha = neftune_noise_alpha
513
+ self.initializer_range = initializer_range
514
+
515
+ hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
516
+
517
+ super().__init__(
518
+ **kwargs,
519
+ pad_token_id=pad_token_id,
520
+ tie_word_embeddings=tie_word_embeddings,
521
+ vocab_size=vocab_size,
522
+ hidden_size=hidden_size,
523
+ )
524
+
525
+ def to_dict(self):
526
+ output = copy.deepcopy(self.__dict__)
527
+ output["model_type"] = self.__class__.model_type
528
+ output["vision_config"] = self.vision_config.to_dict()
529
+ output["text_config"] = self.text_config.to_dict()
530
+ return output
531
+
532
+ @classmethod
533
+ def from_pretrained_models(
534
+ cls,
535
+ text_model_name: Union[str, os.PathLike],
536
+ vision_model_name: Union[str, os.PathLike],
537
+ **kwargs,
538
+ ) -> "PretrainedConfig":
539
+ text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
540
+ vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
541
+ return cls(
542
+ text_config=text_model_config,
543
+ vision_config=vision_model_config,
544
+ **kwargs,
545
+ )
configuration_vbert.py DELETED
@@ -1,233 +0,0 @@
1
- import copy
2
- import os
3
-
4
- from typing import Union, Any, Dict
5
-
6
- from transformers.configuration_utils import PretrainedConfig
7
- from transformers.utils import logging
8
- from transformers import CONFIG_MAPPING, AutoConfig
9
-
10
- logger = logging.get_logger(__name__)
11
-
12
- def collect_arg_in_candidates(config, candidates, default = None) -> Any:
13
- """ Gets the argument in a config given a list of candidates """
14
- for c in candidates:
15
- if hasattr(config, c):
16
- return getattr(config, c)
17
- elif c in config:
18
- return config[c]
19
- if default is not None:
20
- return default
21
- raise ValueError("No matching arguments found in candidates. Candidates: {}, Config: {}".format(candidates, config))
22
-
23
- class VBertTextConfig(PretrainedConfig):
24
- r"""
25
- This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
26
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
27
- defaults will yield a similar configuration to that of the LLaMA-7B.
28
-
29
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
- documentation from [`PretrainedConfig`] for more information.
31
-
32
- Args:
33
- embed_dim (`int`, *optional*, defaults to 1152):
34
- Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`)
35
- image_size (`int`, *optional*, defaults to 384):
36
- The size (resolution) of each image.
37
- """
38
- model_type = "vbert"
39
-
40
- def __init__(
41
- self,
42
- # Case for when vllama3 is from the hub with no vision_model_name
43
- text_model_name="EuroBERT/EuroBERT-210m",
44
- **kwargs,
45
- ):
46
- self.text_model_name = text_model_name
47
- text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
48
- if hasattr(text_config, "text_config"):
49
- text_config = text_config.text_config
50
-
51
- self.hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
52
- self.num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
53
- self.intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
54
- self.mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default = False)
55
- self.vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])
56
-
57
- super().__init__(text_model_name=text_model_name, **kwargs)
58
-
59
- class VBertVisionConfig(PretrainedConfig):
60
- r"""
61
- This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
62
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
63
- defaults will yield a similar configuration to that of the LLaMA-7B.
64
-
65
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
66
- documentation from [`PretrainedConfig`] for more information.
67
-
68
- Args:
69
- embed_dim (`int`, *optional*, defaults to 1152):
70
- Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `embed_dim`)
71
- image_size (`int`, *optional*, defaults to 384):
72
- The size (resolution) of each image.
73
- """
74
- model_type = "vbert"
75
- attribute_map = {
76
- "hidden_size": "embed_dim",
77
- }
78
-
79
- def __init__(
80
- self,
81
- # Case for when vllama3 is from the hub with no vision_model_name
82
- vision_model_name="google/siglip2-base-patch16-512",
83
- **kwargs,
84
- ):
85
- self.vision_model_name = vision_model_name
86
- vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
87
- if hasattr(vision_config, "vision_config"):
88
- vision_config = vision_config.vision_config
89
-
90
- self.embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
91
- self.image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
92
- self.patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
93
- self.num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
94
- self.intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])
95
-
96
- super().__init__(vision_model_name=vision_model_name, **kwargs)
97
-
98
- class VBertConfig(PretrainedConfig):
99
- r"""
100
- This is the configuration class to store the configuration of a [`SmolVLMModel`]. It is used to instantiate a
101
- SmolVLM model according to the specified arguments, defining the model architecture. Instantiating a
102
- configuration with the defaults will yield a similar configuration to that of the model of the SmolVLM
103
- [HuggingFaceTB/SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) architecture.
104
-
105
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
106
- documentation from [`PretrainedConfig`] for more information.
107
-
108
- Args:
109
- use_cache (`bool`, *optional*, defaults to `True`):
110
- Whether or not the model should cache the key/value pairs of the attention mechanism. Only
111
- relevant if `config.is_decoder=True`.
112
- image_token_id (`int`, *optional*, defaults to 128257):
113
- The id of the "image" token.
114
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
115
- Whether or not to tie the word embeddings with the token embeddings.
116
- vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`):
117
- Custom vision config or dict for the vision tower
118
- text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
119
- Custom text config or dict for the text model
120
- scale_factor (`int`, *optional*, defaults to 2):
121
- The scale factor for the image encoder.
122
- pad_token_id (`int`, *optional*, defaults to 128002):
123
- The id of the padding token.
124
-
125
- Example:
126
- ```python
127
- >>> from transformers import SmolVLMModel, SmolVLMConfig
128
- >>> # Initializing configuration
129
- >>> configuration = SmolVLMConfig()
130
- >>> # Initializing a model from the configuration
131
- >>> model = SmolVLMModel(configuration)
132
- >>> # Accessing the model configuration
133
- >>> configuration = model.config
134
- ```"""
135
-
136
- model_type = "vbert"
137
- is_composition = True
138
- # sub_configs = {"text_config": VBertTextConfig, "vision_config": VBertVisionConfig}
139
-
140
- DEFAULT_TEXT_MODEL_NAME = "EuroBERT/EuroBERT-210m"
141
- DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"
142
-
143
- def __init__(
144
- self,
145
- text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
146
- vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
147
- image_token_id: int = 128_257,
148
- vocab_size=128_256,
149
- use_cache = True,
150
- tie_word_embeddings = False,
151
- freeze_config = None,
152
- pad_token_id = None,
153
- initializer_range = 0.02,
154
- pixel_shuffle_factor = 4,
155
- use_resampler = False,
156
- additional_vocab_size = 0,
157
- neftune_noise_alpha = 0.0,
158
- **kwargs,
159
- ):
160
- self.image_token_id = image_token_id
161
- self.use_cache = use_cache
162
- self.tie_word_embeddings = tie_word_embeddings
163
- self.scale_factor = pixel_shuffle_factor
164
- self.additional_vocab_size = additional_vocab_size
165
-
166
- if text_config is None:
167
- text_config = AutoConfig.from_pretrained(self.DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True)
168
- elif isinstance(text_config, dict):
169
- text_config = VBertTextConfig(text_config["text_model_name"])
170
- self.text_config = text_config
171
-
172
- if vision_config is None:
173
- vision_config = AutoConfig.from_pretrained(self.DEFAULT_VISION_MODEL_NAME, trust_remote_code=True)
174
- elif isinstance(vision_config, dict):
175
- vision_config = VBertVisionConfig(vision_config["vision_model_name"])
176
- self.vision_config = vision_config
177
-
178
- self.freeze_config = freeze_config
179
-
180
- # Pixel shuffle factor
181
- self.pixel_shuffle_factor = pixel_shuffle_factor
182
- self.use_resampler = use_resampler
183
-
184
- self.neftune_noise_alpha = neftune_noise_alpha
185
-
186
- self.initializer_range = initializer_range
187
-
188
- hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
189
-
190
- super().__init__(
191
- **kwargs,
192
- pad_token_id=pad_token_id,
193
- tie_word_embeddings=tie_word_embeddings,
194
- vocab_size=vocab_size,
195
- hidden_size=hidden_size,
196
- )
197
-
198
- def to_dict(self):
199
- """
200
- Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
201
- Returns:
202
- `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
203
- """
204
- output = copy.deepcopy(self.__dict__)
205
-
206
- output["model_type"] = self.__class__.model_type
207
- output["vision_config"] = self.vision_config.to_dict()
208
- output["text_config"] = self.text_config.to_dict()
209
- # output["freeze_config"] = self.freeze_config.to_dict()
210
-
211
- return output
212
-
213
- # @classmethod
214
- # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
215
- # outputs = super(VBertConfig, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
216
- # return outputs
217
-
218
- @classmethod
219
- def from_pretrained_models(
220
- cls,
221
- text_model_name: Union[str, os.PathLike],
222
- vision_model_name: Union[str, os.PathLike],
223
- **kwargs
224
- ) -> "PretrainedConfig":
225
- # text_model_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
226
- # vision_model_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
227
- text_model_config = VBertTextConfig(text_model_name)
228
- vision_model_config = VBertVisionConfig(vision_model_name)
229
- return cls(
230
- text_config=text_model_config,
231
- vision_config=vision_model_config,
232
- **kwargs
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_vbert.py → modeling_modernvbert.py RENAMED
@@ -1,25 +1,15 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from torch.nn import CrossEntropyLoss
5
- from typing import Optional, Tuple, Union, List
6
-
7
- from transformers.cache_utils import DynamicCache
8
-
9
- from .configuration_vbert import VBertConfig
10
-
11
- from transformers import AutoModel, AutoConfig, AutoModelForMaskedLM, PreTrainedModel
12
  from transformers.modeling_outputs import BaseModelOutput
13
  from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput
14
 
15
- from typing import List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.utils.checkpoint
19
-
20
- from dataclasses import dataclass
21
-
22
- from transformers import logging
23
 
24
  logger = logging.get_logger(__name__)
25
 
@@ -51,6 +41,7 @@ class DecoupledEmbedding(nn.Embedding):
51
  """
52
  if padding_idx is not None and padding_idx > num_embeddings:
53
  raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
 
54
  super().__init__(
55
  num_embeddings=num_embeddings,
56
  embedding_dim=embedding_dim,
@@ -60,7 +51,6 @@ class DecoupledEmbedding(nn.Embedding):
60
  **kwargs,
61
  )
62
  self.num_embeddings = num_embeddings
63
- self.padding_idx = padding_idx
64
  self.num_additional_embeddings = num_additional_embeddings
65
  self.partially_freeze = partially_freeze
66
 
@@ -69,7 +59,7 @@ class DecoupledEmbedding(nn.Embedding):
69
 
70
  if self.num_additional_embeddings > 0:
71
  self.additional_embedding = nn.Embedding(
72
- num_embeddings=self.num_additional_embeddings,
73
  embedding_dim=embedding_dim,
74
  device=device,
75
  dtype=dtype,
@@ -97,9 +87,8 @@ class DecoupledEmbedding(nn.Embedding):
97
 
98
  """
99
  if self.num_additional_embeddings == 0:
100
- return self.additional_embedding(input_ids)
101
 
102
- # Clone so that we don't modify the original input_ids later on
103
  input_ids = input_ids.clone()
104
  additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
105
  input_ids_additional_vocab = input_ids[additional_vocab_indices]
@@ -108,37 +97,19 @@ class DecoupledEmbedding(nn.Embedding):
108
  # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
109
  input_ids[additional_vocab_indices] = 0
110
  full_vector = F.embedding(input_ids, self.weight)
111
-
112
- # overwrite the records with high indices
113
- full_vector[additional_vocab_indices] = additional_embeddings
114
-
115
  return full_vector
116
 
117
- def extra_repr(self) -> str:
118
- return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
119
- self.num_embeddings,
120
- self.num_additional_embeddings,
121
- self.embedding_dim,
122
- self.partially_freeze,
123
- )
124
-
125
  @dataclass
126
- class VBertBaseModelOutput(BaseModelOutput):
127
  """
128
- Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding).
129
  Args:
130
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
131
  Sequence of hidden-states at the output of the last layer of the model.
132
  If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
133
  hidden_size)` is output.
134
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
135
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
136
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
137
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
138
- encoder_sequence_length, embed_size_per_head)`.
139
- Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
140
- `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
141
- input) to speed up sequential decoding.
142
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
143
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
144
  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -153,16 +124,16 @@ class VBertBaseModelOutput(BaseModelOutput):
153
  sequence_length, hidden_size)`.
154
  image_hidden_states of the model produced by the vision encoder
155
  """
156
-
157
  last_hidden_state: torch.FloatTensor = None
158
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
159
  attentions: Optional[Tuple[torch.FloatTensor]] = None
160
  image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
161
 
 
162
  @dataclass
163
- class VBertMaskedLMOutput(MaskedLMOutput):
164
  """
165
- Base class for SmolVLM model's outputs that may also contain a past key/values (to speed up sequential decoding).
166
  Args:
167
  loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
168
  Masked language modeling (MLM) loss.
@@ -188,7 +159,9 @@ class VBertMaskedLMOutput(MaskedLMOutput):
188
  attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
189
  image_hidden_states: Optional[torch.FloatTensor] = None
190
 
191
- class VBertSimpleMLP(nn.Module):
 
 
192
  def __init__(self, input_size, output_size):
193
  super().__init__()
194
  self.proj = nn.Linear(input_size, output_size, bias=False)
@@ -196,13 +169,18 @@ class VBertSimpleMLP(nn.Module):
196
  def forward(self, x):
197
  return self.proj(x)
198
 
199
- class VBertConnector(nn.Module):
 
 
 
 
 
200
  def __init__(self, config):
201
  super().__init__()
202
  self.scale_factor = config.pixel_shuffle_factor
203
- self.modality_projection = VBertSimpleMLP(
204
  input_size=config.vision_config.hidden_size * (config.scale_factor**2),
205
- output_size=config.text_config.hidden_size
206
  )
207
 
208
  def pixel_shuffle(self, x, scale_factor):
@@ -213,36 +191,25 @@ class VBertConnector(nn.Module):
213
  x = x.permute(0, 2, 1, 3)
214
  x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
215
  x = x.permute(0, 2, 1, 3)
216
- x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
217
- return x
218
 
219
  def forward(self, image_hidden_states):
220
  image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
221
- image_hidden_states = self.modality_projection(image_hidden_states)
222
- return image_hidden_states
223
 
224
- class VBertPreTrainedModel(PreTrainedModel):
225
- config_class = VBertConfig
226
  base_model_prefix = "model"
227
  supports_gradient_checkpointing = True
228
- _no_split_modules = ["VBertDecoderLayer"]
229
  _skip_keys_device_placement = "past_key_values"
230
  _supports_flash_attn_2 = True
231
  _supports_sdpa = True
232
  _supports_cache_class = True
233
 
234
  def _init_weights(self, module):
235
- """Initialize the weights."""
236
-
237
- std = (
238
- self.config.initializer_range
239
- if hasattr(self.config, "initializer_range")
240
- else self.config.text_config.initializer_range
241
- )
242
-
243
- if hasattr(module, "class_embedding"):
244
- module.class_embedding.data.normal_(mean=0.0, std=std)
245
-
246
  if isinstance(module, (nn.Linear, nn.Conv2d)):
247
  module.weight.data.normal_(mean=0.0, std=std)
248
  if module.bias is not None:
@@ -252,53 +219,41 @@ class VBertPreTrainedModel(PreTrainedModel):
252
  if module.padding_idx is not None:
253
  module.weight.data[module.padding_idx].zero_()
254
 
255
- class VBertModel(VBertPreTrainedModel):
256
- """
257
- A subclass of Idefics3Model. We do *not* remove or block the call to inputs_merger
258
- in forward. Instead, we override inputs_merger here with custom logic.
259
- """
260
 
261
- def __init__(self, config: VBertConfig, **kwargs):
 
262
  super().__init__(config)
263
-
264
- self.vision_model = VBertModel.init_vision_model(config, **kwargs)
265
- self.connector = VBertConnector(config)
266
- self.text_model = VBertModel.init_language_model(config, **kwargs)
267
-
268
  self.image_seq_len = int(
269
  ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
270
  )
271
- self.image_token_id = self.config.image_token_id
272
-
273
  self.post_init()
274
 
275
  @staticmethod
276
- def init_vision_model(config: VBertConfig, **kwargs):
277
  vision_model_config = AutoConfig.from_pretrained(
278
  config.vision_config.vision_model_name,
279
- trust_remote_code=True,
 
280
  **kwargs,
281
  )
282
-
283
  vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs)
284
-
285
- if hasattr(vision_model, "vision_model"):
286
- # If the model has a vision_model attribute, it means it's a wrapper around another model
287
- vision_model = vision_model.vision_model
288
-
289
- return vision_model
290
 
291
  @staticmethod
292
- def init_language_model(config: VBertConfig, **kwargs):
293
  text_model_config = AutoConfig.from_pretrained(
294
  config.text_config.text_model_name,
 
 
295
  trust_remote_code=True,
296
  **kwargs,
297
  )
298
-
299
  text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs)
300
- # extractor = regex_lookup(language_model_name, language_model_name2model)
301
-
302
  embed_layer = DecoupledEmbedding(
303
  num_embeddings=text_model_config.vocab_size,
304
  num_additional_embeddings=config.additional_vocab_size,
@@ -306,11 +261,9 @@ class VBertModel(VBertPreTrainedModel):
306
  partially_freeze=config.freeze_config["freeze_text_layers"],
307
  padding_idx=config.pad_token_id,
308
  )
309
-
310
  text_model.set_input_embeddings(embed_layer)
311
-
312
  return text_model
313
-
314
  def enable_input_require_grads(self):
315
  """
316
  Enables the gradients for the input embeddings.
@@ -337,20 +290,15 @@ class VBertModel(VBertPreTrainedModel):
337
  make_inputs_require_grads
338
  )
339
 
340
- def disable_input_require_grads(self):
341
- self._text_require_grads_hook.remove()
342
- self._vision_require_grads_hook.remove()
343
-
344
  def get_input_embeddings(self):
345
  return self.text_model.get_input_embeddings()
346
 
347
  def set_input_embeddings(self, value):
348
  self.text_model.set_input_embeddings(value)
349
 
350
- def inputs_merger(
351
- self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_hidden_states: torch.Tensor
352
- ):
353
- """
354
  This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
355
  The merging happens as follows:
356
  - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
@@ -359,135 +307,57 @@ class VBertModel(VBertPreTrainedModel):
359
  - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
360
  - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
361
  """
362
- _, patch_size, _ = image_hidden_states.shape
363
 
 
364
  image_mask = input_ids == self.image_token_id
365
  num_image_tokens = image_mask.sum(dim=1)
366
  if not torch.all(num_image_tokens % patch_size == 0):
367
- raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
368
-
369
  blocks_per_sample = num_image_tokens // patch_size
370
-
371
  offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
372
  block_offset = offsets[:-1]
373
  row_cum = image_mask.cumsum(dim=-1)
374
  chunk_idx = (row_cum - 1) // patch_size
375
  local_idx = (row_cum - 1) % patch_size
376
  block_idx = block_offset.unsqueeze(1) + chunk_idx
377
-
378
  image_embeds = torch.zeros_like(inputs_embeds)
379
  image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
380
-
381
- merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
382
- return merged_embeds
383
 
384
  def forward(
385
  self,
386
  input_ids: torch.LongTensor = None,
387
  attention_mask: Optional[torch.Tensor] = None,
388
  position_ids: Optional[torch.LongTensor] = None,
389
- past_key_values: Optional[List[torch.FloatTensor]] = None,
390
  inputs_embeds: Optional[torch.FloatTensor] = None,
391
  pixel_values: Optional[torch.FloatTensor] = None,
392
  pixel_attention_mask: Optional[torch.BoolTensor] = None,
393
  image_hidden_states: Optional[torch.FloatTensor] = None,
394
- use_cache: Optional[bool] = None,
395
  output_attentions: Optional[bool] = None,
396
  output_hidden_states: Optional[bool] = None,
397
  return_dict: Optional[bool] = None,
398
- cache_position: Optional[torch.LongTensor] = None,
399
- ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
400
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
401
  output_hidden_states = (
402
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
403
  )
404
- use_cache = use_cache if use_cache is not None else self.config.use_cache
405
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
406
-
407
- if self.training and self.text_model.gradient_checkpointing and use_cache:
408
- logger.warning_once(
409
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
410
- )
411
- use_cache = False
412
-
413
- # retrieve input_ids and inputs_embeds
414
- if input_ids is not None:
415
- batch_size, seq_length = input_ids.shape
416
- elif inputs_embeds is not None:
417
- batch_size, seq_length, _ = inputs_embeds.shape
418
- else:
419
- raise ValueError("You have to specify either input_ids or inputs_embeds")
420
-
421
- past_seen_tokens = 0
422
- if use_cache:
423
- if past_key_values is None:
424
- past_key_values = DynamicCache()
425
- past_seen_tokens = past_key_values.get_seq_length()
426
-
427
- if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
428
- raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
429
-
430
  if inputs_embeds is None:
431
  inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
432
-
433
- # START VISUAL INPUTS INTEGRATION
434
- if pixel_values is not None and image_hidden_states is not None:
435
- raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
436
- elif pixel_values is not None:
437
- batch_size, num_images, num_channels, height, width = pixel_values.shape
438
- pixel_values = pixel_values
439
  pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
440
-
441
- # Remove padding images - padding images are full 0.
442
  nb_values_per_image = pixel_values.shape[1:].numel()
443
  real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
444
-
445
  if not any(real_images_inds):
446
- # no images, leave one empty image.
447
  real_images_inds[0] = True
448
-
449
  pixel_values = pixel_values[real_images_inds].contiguous()
450
-
451
- # Handle the vision attention mask
452
- if pixel_attention_mask is None:
453
- pixel_attention_mask = torch.ones(
454
- size=[pixel_values.shape[i] for i in (0, 2, 3)],
455
- dtype=torch.bool,
456
- device=pixel_values.device,
457
- )
458
- else:
459
- # Remove padding images from the mask
460
- pixel_attention_mask = pixel_attention_mask.view(
461
- batch_size * num_images, *pixel_attention_mask.shape[2:]
462
- )
463
- pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
464
-
465
- # patch_size = self.config.vision_config.patch_size
466
- # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
467
- # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
468
- # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
469
-
470
- # Get sequence from the vision encoder
471
- image_hidden_states = self.vision_model(
472
- pixel_values=pixel_values,
473
- # patch_attention_mask=patch_attention_mask,
474
- ).last_hidden_state
475
-
476
- # Modality projection & resampling
477
  image_hidden_states = self.connector(image_hidden_states)
478
-
479
  elif image_hidden_states is not None:
480
  image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
481
-
482
  if inputs_embeds is not None and image_hidden_states is not None:
483
- # When we embed, we don't want to replace the potential image_token_id that we generated by images
484
- # that simply don't exist
485
- inputs_embeds = self.inputs_merger(
486
- input_ids=input_ids,
487
- inputs_embeds=inputs_embeds,
488
- image_hidden_states=image_hidden_states,
489
- )
490
-
491
  outputs = self.text_model(
492
  inputs_embeds=inputs_embeds,
493
  attention_mask=attention_mask,
@@ -495,138 +365,88 @@ class VBertModel(VBertPreTrainedModel):
495
  output_attentions=output_attentions,
496
  output_hidden_states=output_hidden_states,
497
  return_dict=return_dict,
498
- # past_key_values=past_key_values,
499
- # use_cache=use_cache,
500
- # cache_position=cache_position,
501
  )
502
-
503
  if not return_dict:
504
  return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
505
-
506
- return VBertBaseModelOutput(
507
  last_hidden_state=outputs.last_hidden_state,
508
  hidden_states=outputs.hidden_states,
509
  attentions=outputs.attentions,
510
  image_hidden_states=image_hidden_states,
511
  )
512
 
513
- class VBertLMHead(nn.Module):
514
  def __init__(self, config, **kwargs):
515
  super().__init__()
516
- pretrained_config = AutoConfig.from_pretrained(
517
- config.text_config.text_model_name,
518
- trust_remote_code=True,
519
- **kwargs,
520
- )
521
  pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs)
522
-
523
  self.head = pretrained_model.head
524
  self.decoder = pretrained_model.decoder
525
 
526
  def forward(self, hidden_states):
527
- hidden_states = self.head(hidden_states)
528
- hidden_states = self.decoder(hidden_states)
529
- return hidden_states
530
 
531
- class VBertForMaskedLM(VBertPreTrainedModel):
532
- # _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
533
 
 
534
  def __init__(self, config, **kwargs):
535
  super().__init__(config)
536
-
537
  self.image_token_id = config.image_token_id
538
  self.in_features = config.hidden_size
539
  self.out_additional_features = config.additional_vocab_size
540
  self.vocab_size = config.vocab_size
541
-
542
- if config.is_decoder:
543
- logger.warning(
544
- "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
545
- "bi-directional self-attention."
546
- )
547
-
548
- self.model = VBertModel(config, **kwargs)
549
- self.lm_head = VBertLMHead(config, **kwargs)
550
  if self.out_additional_features > 0:
551
- self.additional_fc = nn.Linear(
552
- in_features=self.in_features,
553
- out_features=self.out_additional_features,
554
- bias=False,
555
- )
556
-
557
- # Initialize weights and apply final processing
558
  self.post_init()
559
 
560
  def forward(
561
- self,
562
- input_ids: torch.LongTensor = None,
563
- attention_mask: Optional[torch.Tensor] = None,
564
- position_ids: Optional[torch.LongTensor] = None,
565
- past_key_values: Optional[List[torch.FloatTensor]] = None,
566
- inputs_embeds: Optional[torch.FloatTensor] = None,
567
- pixel_values: Optional[torch.FloatTensor] = None,
568
- pixel_attention_mask: Optional[torch.BoolTensor] = None,
569
- image_hidden_states: Optional[torch.FloatTensor] = None,
570
- labels: Optional[torch.LongTensor] = None,
571
- use_cache: Optional[bool] = None,
572
- output_attentions: Optional[bool] = None,
573
- output_hidden_states: Optional[bool] = None,
574
- return_dict: Optional[bool] = None,
575
- ) -> Union[Tuple, VBertMaskedLMOutput]:
576
- r"""
577
- Args:
578
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
579
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
580
- config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
581
- Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
582
- computed for the tokens with labels in `[0, ..., config.vocab_size]`.
583
- ```"""
584
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
585
  output_hidden_states = (
586
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
587
  )
588
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
589
 
590
-
591
- # Pass the inputs to VBertModel
592
  outputs = self.model(
593
  input_ids=input_ids,
594
  attention_mask=attention_mask,
595
  position_ids=position_ids,
596
- past_key_values=past_key_values,
597
  inputs_embeds=inputs_embeds,
598
  pixel_values=pixel_values,
599
  pixel_attention_mask=pixel_attention_mask,
600
  image_hidden_states=image_hidden_states,
601
- use_cache=use_cache,
602
  output_attentions=output_attentions,
603
  output_hidden_states=output_hidden_states,
604
  return_dict=return_dict,
605
  )
606
-
607
- # Pass the outputs to the MLM head
608
  hidden_states = outputs[0]
609
-
610
  logits = self.lm_head(hidden_states)
611
  if self.out_additional_features > 0:
612
  proj_states = self.lm_head.head(hidden_states)
613
  additional_features = self.additional_fc(proj_states)
614
  logits = torch.cat((logits, additional_features), -1)
615
- logits = logits.float()
616
-
617
- masked_lm_loss = None
618
  if labels is not None:
619
- # print the ratio of not ignored tokens
620
- loss_fct = CrossEntropyLoss()
621
- masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1))
622
-
623
  if not return_dict:
624
  output = (logits,) + outputs[2:]
625
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
626
-
627
- return VBertMaskedLMOutput(
628
- loss=masked_lm_loss,
629
- logits=logits,
630
  hidden_states=outputs.hidden_states,
631
  attentions=outputs.attentions,
632
  image_hidden_states=outputs.image_hidden_states,
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  from torch.nn import CrossEntropyLoss
8
+ from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, PreTrainedModel, logging
 
 
 
 
 
 
9
  from transformers.modeling_outputs import BaseModelOutput
10
  from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput
11
 
12
+ from .configuration_modernvbert import ModernVBertConfig
 
 
 
 
 
 
 
13
 
14
  logger = logging.get_logger(__name__)
15
 
 
41
  """
42
  if padding_idx is not None and padding_idx > num_embeddings:
43
  raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
44
+
45
  super().__init__(
46
  num_embeddings=num_embeddings,
47
  embedding_dim=embedding_dim,
 
51
  **kwargs,
52
  )
53
  self.num_embeddings = num_embeddings
 
54
  self.num_additional_embeddings = num_additional_embeddings
55
  self.partially_freeze = partially_freeze
56
 
 
59
 
60
  if self.num_additional_embeddings > 0:
61
  self.additional_embedding = nn.Embedding(
62
+ num_embeddings=num_additional_embeddings,
63
  embedding_dim=embedding_dim,
64
  device=device,
65
  dtype=dtype,
 
87
 
88
  """
89
  if self.num_additional_embeddings == 0:
90
+ return super().forward(input_ids)
91
 
 
92
  input_ids = input_ids.clone()
93
  additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
94
  input_ids_additional_vocab = input_ids[additional_vocab_indices]
 
97
  # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
98
  input_ids[additional_vocab_indices] = 0
99
  full_vector = F.embedding(input_ids, self.weight)
100
+ full_vector[additional_vocab_indices] = additional_embeddings # overwrite the records with high indices
 
 
 
101
  return full_vector
102
 
103
+
 
 
 
 
 
 
 
104
  @dataclass
105
+ class ModernVBertBaseModelOutput(BaseModelOutput):
106
  """
107
+ Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding).
108
  Args:
109
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
110
  Sequence of hidden-states at the output of the last layer of the model.
111
  If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
112
  hidden_size)` is output.
 
 
 
 
 
 
 
 
113
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
114
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
115
  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
 
124
  sequence_length, hidden_size)`.
125
  image_hidden_states of the model produced by the vision encoder
126
  """
 
127
  last_hidden_state: torch.FloatTensor = None
128
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
129
  attentions: Optional[Tuple[torch.FloatTensor]] = None
130
  image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
131
 
132
+
133
  @dataclass
134
+ class ModernVBertMaskedLMOutput(MaskedLMOutput):
135
  """
136
+ Base class for ModernVBERT model's outputs that may also contain a past key/values (to speed up sequential decoding).
137
  Args:
138
  loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
139
  Masked language modeling (MLM) loss.
 
159
  attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
160
  image_hidden_states: Optional[torch.FloatTensor] = None
161
 
162
+
163
+ class ModernVBertSimpleMLP(nn.Module):
164
+ """A simple linear projection layer to project the vision hidden states to the text hidden states."""
165
  def __init__(self, input_size, output_size):
166
  super().__init__()
167
  self.proj = nn.Linear(input_size, output_size, bias=False)
 
169
  def forward(self, x):
170
  return self.proj(x)
171
 
172
+
173
+ class ModernVBertConnector(nn.Module):
174
+ """
175
+ Connector module for ModernVBERT. It performs a pixel shuffle operation followed by a linear projection to match the text model's hidden size.
176
+ Based on https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
177
+ """
178
  def __init__(self, config):
179
  super().__init__()
180
  self.scale_factor = config.pixel_shuffle_factor
181
+ self.modality_projection = ModernVBertSimpleMLP(
182
  input_size=config.vision_config.hidden_size * (config.scale_factor**2),
183
+ output_size=config.text_config.hidden_size,
184
  )
185
 
186
  def pixel_shuffle(self, x, scale_factor):
 
191
  x = x.permute(0, 2, 1, 3)
192
  x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
193
  x = x.permute(0, 2, 1, 3)
194
+ return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
 
195
 
196
  def forward(self, image_hidden_states):
197
  image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
198
+ return self.modality_projection(image_hidden_states)
199
+
200
 
201
+ class ModernVBertPreTrainedModel(PreTrainedModel):
202
+ config_class = ModernVBertConfig
203
  base_model_prefix = "model"
204
  supports_gradient_checkpointing = True
205
+ _no_split_modules = ["ModernVBertDecoderLayer"]
206
  _skip_keys_device_placement = "past_key_values"
207
  _supports_flash_attn_2 = True
208
  _supports_sdpa = True
209
  _supports_cache_class = True
210
 
211
  def _init_weights(self, module):
212
+ std = getattr(self.config, "initializer_range", 0.02)
 
 
 
 
 
 
 
 
 
 
213
  if isinstance(module, (nn.Linear, nn.Conv2d)):
214
  module.weight.data.normal_(mean=0.0, std=std)
215
  if module.bias is not None:
 
219
  if module.padding_idx is not None:
220
  module.weight.data[module.padding_idx].zero_()
221
 
 
 
 
 
 
222
 
223
+ class ModernVBertModel(ModernVBertPreTrainedModel):
224
+ def __init__(self, config: ModernVBertConfig, **kwargs):
225
  super().__init__(config)
226
+ self.vision_model = ModernVBertModel.init_vision_model(config, **kwargs)
227
+ self.connector = ModernVBertConnector(config)
228
+ self.text_model = ModernVBertModel.init_language_model(config, **kwargs)
 
 
229
  self.image_seq_len = int(
230
  ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
231
  )
232
+ self.image_token_id = config.image_token_id
233
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
234
  self.post_init()
235
 
236
  @staticmethod
237
+ def init_vision_model(config: ModernVBertConfig, **kwargs):
238
  vision_model_config = AutoConfig.from_pretrained(
239
  config.vision_config.vision_model_name,
240
+ _attn_implementation=config._attn_implementation,
241
+ dtype=config.torch_dtype,
242
  **kwargs,
243
  )
 
244
  vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs)
245
+ return getattr(vision_model, "vision_model", vision_model)
 
 
 
 
 
246
 
247
  @staticmethod
248
+ def init_language_model(config: ModernVBertConfig, **kwargs):
249
  text_model_config = AutoConfig.from_pretrained(
250
  config.text_config.text_model_name,
251
+ _attn_implementation=config._attn_implementation,
252
+ dtype=config.torch_dtype,
253
  trust_remote_code=True,
254
  **kwargs,
255
  )
 
256
  text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs)
 
 
257
  embed_layer = DecoupledEmbedding(
258
  num_embeddings=text_model_config.vocab_size,
259
  num_additional_embeddings=config.additional_vocab_size,
 
261
  partially_freeze=config.freeze_config["freeze_text_layers"],
262
  padding_idx=config.pad_token_id,
263
  )
 
264
  text_model.set_input_embeddings(embed_layer)
 
265
  return text_model
266
+
267
  def enable_input_require_grads(self):
268
  """
269
  Enables the gradients for the input embeddings.
 
290
  make_inputs_require_grads
291
  )
292
 
 
 
 
 
293
  def get_input_embeddings(self):
294
  return self.text_model.get_input_embeddings()
295
 
296
  def set_input_embeddings(self, value):
297
  self.text_model.set_input_embeddings(value)
298
 
299
+ def inputs_merger(self, input_ids, inputs_embeds, image_hidden_states):
300
+ """Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/smolvlm/modeling_smolvlm.py
301
+
 
302
  This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
303
  The merging happens as follows:
304
  - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
 
307
  - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
308
  - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
309
  """
 
310
 
311
+ _, patch_size, _ = image_hidden_states.shape
312
  image_mask = input_ids == self.image_token_id
313
  num_image_tokens = image_mask.sum(dim=1)
314
  if not torch.all(num_image_tokens % patch_size == 0):
315
+ raise ValueError("Number of <image> tokens not divisible by patch_size.")
 
316
  blocks_per_sample = num_image_tokens // patch_size
 
317
  offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
318
  block_offset = offsets[:-1]
319
  row_cum = image_mask.cumsum(dim=-1)
320
  chunk_idx = (row_cum - 1) // patch_size
321
  local_idx = (row_cum - 1) % patch_size
322
  block_idx = block_offset.unsqueeze(1) + chunk_idx
 
323
  image_embeds = torch.zeros_like(inputs_embeds)
324
  image_embeds[image_mask] = image_hidden_states[block_idx[image_mask], local_idx[image_mask], :]
325
+ return torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
 
 
326
 
327
  def forward(
328
  self,
329
  input_ids: torch.LongTensor = None,
330
  attention_mask: Optional[torch.Tensor] = None,
331
  position_ids: Optional[torch.LongTensor] = None,
 
332
  inputs_embeds: Optional[torch.FloatTensor] = None,
333
  pixel_values: Optional[torch.FloatTensor] = None,
334
  pixel_attention_mask: Optional[torch.BoolTensor] = None,
335
  image_hidden_states: Optional[torch.FloatTensor] = None,
 
336
  output_attentions: Optional[bool] = None,
337
  output_hidden_states: Optional[bool] = None,
338
  return_dict: Optional[bool] = None,
339
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
 
340
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
341
  output_hidden_states = (
342
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
343
  )
 
344
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  if inputs_embeds is None:
346
  inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
347
+ if pixel_values is not None:
348
+ batch_size, num_images, _, _, _ = pixel_values.shape
 
 
 
 
 
349
  pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
 
 
350
  nb_values_per_image = pixel_values.shape[1:].numel()
351
  real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
 
352
  if not any(real_images_inds):
 
353
  real_images_inds[0] = True
 
354
  pixel_values = pixel_values[real_images_inds].contiguous()
355
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  image_hidden_states = self.connector(image_hidden_states)
 
357
  elif image_hidden_states is not None:
358
  image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
 
359
  if inputs_embeds is not None and image_hidden_states is not None:
360
+ inputs_embeds = self.inputs_merger(input_ids, inputs_embeds, image_hidden_states)
 
 
 
 
 
 
 
361
  outputs = self.text_model(
362
  inputs_embeds=inputs_embeds,
363
  attention_mask=attention_mask,
 
365
  output_attentions=output_attentions,
366
  output_hidden_states=output_hidden_states,
367
  return_dict=return_dict,
 
 
 
368
  )
 
369
  if not return_dict:
370
  return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
371
+ return ModernVBertBaseModelOutput(
 
372
  last_hidden_state=outputs.last_hidden_state,
373
  hidden_states=outputs.hidden_states,
374
  attentions=outputs.attentions,
375
  image_hidden_states=image_hidden_states,
376
  )
377
 
378
+ class ModernVBertLMHead(nn.Module):
379
  def __init__(self, config, **kwargs):
380
  super().__init__()
381
+ pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True, **kwargs)
 
 
 
 
382
  pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs)
 
383
  self.head = pretrained_model.head
384
  self.decoder = pretrained_model.decoder
385
 
386
  def forward(self, hidden_states):
387
+ return self.decoder(self.head(hidden_states))
 
 
388
 
 
 
389
 
390
+ class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
391
  def __init__(self, config, **kwargs):
392
  super().__init__(config)
 
393
  self.image_token_id = config.image_token_id
394
  self.in_features = config.hidden_size
395
  self.out_additional_features = config.additional_vocab_size
396
  self.vocab_size = config.vocab_size
397
+ self.model = ModernVBertModel(config, **kwargs)
398
+ self.lm_head = ModernVBertLMHead(config, **kwargs)
 
 
 
 
 
 
 
399
  if self.out_additional_features > 0:
400
+ self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
 
 
 
 
 
 
401
  self.post_init()
402
 
403
  def forward(
404
+ self,
405
+ input_ids: torch.LongTensor = None,
406
+ attention_mask: Optional[torch.Tensor] = None,
407
+ position_ids: Optional[torch.LongTensor] = None,
408
+ inputs_embeds: Optional[torch.FloatTensor] = None,
409
+ pixel_values: Optional[torch.FloatTensor] = None,
410
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
411
+ image_hidden_states: Optional[torch.FloatTensor] = None,
412
+ output_attentions: Optional[bool] = None,
413
+ output_hidden_states: Optional[bool] = None,
414
+ return_dict: Optional[bool] = None,
415
+ labels: Optional[torch.LongTensor] = None,
416
+ ) -> Union[Tuple, ModernVBertMaskedLMOutput]:
 
 
 
 
 
 
 
 
 
 
417
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
418
  output_hidden_states = (
419
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
420
  )
421
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
422
 
 
 
423
  outputs = self.model(
424
  input_ids=input_ids,
425
  attention_mask=attention_mask,
426
  position_ids=position_ids,
 
427
  inputs_embeds=inputs_embeds,
428
  pixel_values=pixel_values,
429
  pixel_attention_mask=pixel_attention_mask,
430
  image_hidden_states=image_hidden_states,
 
431
  output_attentions=output_attentions,
432
  output_hidden_states=output_hidden_states,
433
  return_dict=return_dict,
434
  )
 
 
435
  hidden_states = outputs[0]
 
436
  logits = self.lm_head(hidden_states)
437
  if self.out_additional_features > 0:
438
  proj_states = self.lm_head.head(hidden_states)
439
  additional_features = self.additional_fc(proj_states)
440
  logits = torch.cat((logits, additional_features), -1)
441
+ loss = None
 
 
442
  if labels is not None:
443
+ loss = CrossEntropyLoss()(logits.view(-1, self.vocab_size + self.out_additional_features), labels.view(-1))
 
 
 
444
  if not return_dict:
445
  output = (logits,) + outputs[2:]
446
+ return ((loss,) + output) if loss is not None else output
447
+ return ModernVBertMaskedLMOutput(
448
+ loss=loss,
449
+ logits=logits.float(),
 
450
  hidden_states=outputs.hidden_states,
451
  attentions=outputs.attentions,
452
  image_hidden_states=outputs.image_hidden_states,