paultltc committed on
Commit
4804073
·
1 Parent(s): 6ce858b

clean modeling + fix config double loading

Browse files
Files changed (1) hide show
  1. configuration_modernvbert.py +0 -272
configuration_modernvbert.py CHANGED
@@ -197,278 +197,6 @@ class ModernVBertConfig(PretrainedConfig):
197
  model_type = "modernvbert"
198
  is_composition = True
199
 
200
- def __init__(
201
- self,
202
- text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
203
- vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
204
- image_token_id: int = 128_257,
205
- vocab_size=50368,
206
- use_cache=True,
207
- tie_word_embeddings=False,
208
- freeze_config=None,
209
- pad_token_id=None,
210
- initializer_range=0.02,
211
- pixel_shuffle_factor=4,
212
- use_resampler=False,
213
- additional_vocab_size=0,
214
- neftune_noise_alpha=0.0,
215
- **kwargs,
216
- ):
217
- self.image_token_id = image_token_id
218
- self.use_cache = use_cache
219
- self.tie_word_embeddings = tie_word_embeddings
220
- self.scale_factor = pixel_shuffle_factor
221
- self.additional_vocab_size = additional_vocab_size
222
-
223
- if text_config is None:
224
- base_text_config = AutoConfig.from_pretrained(DEFAULT_TEXT_MODEL_NAME, trust_remote_code=True)
225
- text_config = ModernVBertTextConfig(base_text_config)
226
- elif isinstance(text_config, dict):
227
- text_config = ModernVBertTextConfig.from_dict(text_config)
228
- self.text_config = text_config
229
-
230
- if vision_config is None:
231
- base_vision_config = AutoConfig.from_pretrained(DEFAULT_VISION_MODEL_NAME, trust_remote_code=True)
232
- vision_config = ModernVBertVisionConfig(base_vision_config)
233
- elif isinstance(vision_config, dict):
234
- vision_config = ModernVBertVisionConfig.from_dict(vision_config)
235
- self.vision_config = vision_config
236
-
237
- self.freeze_config = freeze_config
238
- self.pixel_shuffle_factor = pixel_shuffle_factor
239
- self.use_resampler = use_resampler
240
- self.neftune_noise_alpha = neftune_noise_alpha
241
- self.initializer_range = initializer_range
242
-
243
- hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)
244
-
245
- super().__init__(
246
- **kwargs,
247
- pad_token_id=pad_token_id,
248
- tie_word_embeddings=tie_word_embeddings,
249
- vocab_size=vocab_size,
250
- hidden_size=hidden_size,
251
- )
252
-
253
- def to_dict(self):
254
- output = copy.deepcopy(self.__dict__)
255
- output["model_type"] = self.__class__.model_type
256
- output["vision_config"] = self.vision_config.to_dict()
257
- output["text_config"] = self.text_config.to_dict()
258
- return output
259
-
260
- @classmethod
261
- def from_pretrained_models(
262
- cls,
263
- text_model_name: Union[str, os.PathLike],
264
- vision_model_name: Union[str, os.PathLike],
265
- **kwargs,
266
- ) -> "PretrainedConfig":
267
- text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
268
- vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
269
- return cls(
270
- text_config=text_model_config,
271
- vision_config=vision_model_config,
272
- **kwargs,
273
- )
- import copy
274
- import os
275
- from typing import Any, Dict, Union
276
-
277
- from transformers import AutoConfig
278
- from transformers.configuration_utils import PretrainedConfig
279
- from transformers.utils import logging
280
-
281
- logger = logging.get_logger(__name__)
282
-
283
- DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
284
- DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"
285
-
286
- def collect_arg_in_candidates(config, candidates, default=None) -> Any:
287
- """Gets the first available argument in a config given a list of candidate names."""
288
- for c in candidates:
289
- if hasattr(config, c):
290
- return getattr(config, c)
291
- elif c in config:
292
- return config[c]
293
- if default is not None:
294
- return default
295
- raise ValueError(
296
- f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}"
297
- )
298
-
299
- class ModernVBertTextConfig(PretrainedConfig):
300
- r"""
301
- This is the configuration class to store the configuration of a [`ModernBERT`]. It is used to instantiate an ModernBERT
302
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
303
- defaults will yield a similar configuration to that of the [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m) architecture.
304
-
305
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
306
- documentation from [`PretrainedConfig`] for more information.
307
- """
308
- model_type = "modernvbert_text"
309
-
310
- def __init__(
311
- self,
312
- text_model_name=DEFAULT_TEXT_MODEL_NAME,
313
- hidden_size=768,
314
- num_hidden_layers=22,
315
- intermediate_size=1152,
316
- mlp_bias=False,
317
- vocab_size=50368,
318
- **kwargs,
319
- ):
320
- super().__init__(
321
- text_model_name=text_model_name,
322
- hidden_size=hidden_size,
323
- num_hidden_layers=num_hidden_layers,
324
- intermediate_size=intermediate_size,
325
- mlp_bias=mlp_bias,
326
- vocab_size=vocab_size,
327
- **kwargs,
328
- )
329
-
330
- @classmethod
331
- def from_base_model(
332
- cls,
333
- text_model_name=DEFAULT_TEXT_MODEL_NAME,
334
- **kwargs,
335
- ):
336
- text_config = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
337
- if hasattr(text_config, "text_config"):
338
- text_config = text_config.text_config
339
-
340
- hidden_size = collect_arg_in_candidates(text_config, ["hidden_size", "embed_dim"])
341
- num_hidden_layers = collect_arg_in_candidates(text_config, ["num_hidden_layers", "num_hidden_blocks"])
342
- intermediate_size = collect_arg_in_candidates(text_config, ["intermediate_size", "mlp_dim"])
343
- mlp_bias = collect_arg_in_candidates(text_config, ["mlp_bias", "mlp_hidden_bias"], default=False)
344
- vocab_size = collect_arg_in_candidates(text_config, ["vocab_size"])
345
-
346
- return cls(
347
- text_model_name=text_model_name,
348
- hidden_size=hidden_size,
349
- num_hidden_layers=num_hidden_layers,
350
- intermediate_size=intermediate_size,
351
- mlp_bias=mlp_bias,
352
- vocab_size=vocab_size,
353
- **kwargs,
354
- )
355
-
356
- class ModernVBertVisionConfig(PretrainedConfig):
357
- r"""
358
- This is the configuration class to store the configuration of a [`SigLIP`]. It is used to instantiate the vision encoder part of the ModernVBERT
359
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
360
- defaults will yield a similar configuration to that of the SigLIP.
361
-
362
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
363
- documentation from [`PretrainedConfig`] for more information.
364
- """
365
- model_type = "modernvbert_vision"
366
-
367
- attribute_map = {
368
- "hidden_size": "embed_dim",
369
- }
370
-
371
- def __init__(
372
- self,
373
- vision_model_name=DEFAULT_VISION_MODEL_NAME,
374
- embed_dim=768,
375
- image_size=512,
376
- patch_size=16,
377
- num_hidden_layers=12,
378
- intermediate_size=3072,
379
- **kwargs,
380
- ):
381
- super().__init__(
382
- vision_model_name=vision_model_name,
383
- embed_dim=embed_dim,
384
- image_size=image_size,
385
- patch_size=patch_size,
386
- num_hidden_layers=num_hidden_layers,
387
- intermediate_size=intermediate_size,
388
- **kwargs,
389
- )
390
-
391
- @classmethod
392
- def from_base_model(
393
- cls,
394
- vision_model_name=DEFAULT_VISION_MODEL_NAME,
395
- **kwargs,
396
- ):
397
- vision_config = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
398
- if hasattr(vision_config, "vision_config"):
399
- vision_config = vision_config.vision_config
400
-
401
- embed_dim = collect_arg_in_candidates(vision_config, ["embed_dim", "hidden_size"])
402
- image_size = collect_arg_in_candidates(vision_config, ["image_size", "img_size"])
403
- patch_size = collect_arg_in_candidates(vision_config, ["patch_size"])
404
- num_hidden_layers = collect_arg_in_candidates(vision_config, ["num_hidden_layers", "num_hidden_blocks"])
405
- intermediate_size = collect_arg_in_candidates(vision_config, ["intermediate_size", "mlp_dim"])
406
-
407
- return cls(
408
- vision_model_name=vision_model_name,
409
- embed_dim=embed_dim,
410
- image_size=image_size,
411
- patch_size=patch_size,
412
- num_hidden_layers=num_hidden_layers,
413
- intermediate_size=intermediate_size,
414
- **kwargs,
415
- )
416
-
417
-
418
- class ModernVBertConfig(PretrainedConfig):
419
- r"""
420
- This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
421
- instantiate a ModernVBert model according to the specified arguments and defines the model architecture.
422
-
423
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
424
- See the documentation for [`PretrainedConfig`] for more details.
425
-
426
- Args:
427
- text_config (`PretrainedConfig` or `dict`, optional):
428
- Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, the
429
- default text backbone defined by `DEFAULT_TEXT_MODEL_NAME` is used.
430
- vision_config (`PretrainedConfig` or `dict`, optional):
431
- Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, the
432
- default vision backbone defined by `DEFAULT_VISION_MODEL_NAME` is used.
433
- image_token_id (`int`, optional, defaults to 128257):
434
- Token id reserved for image tokens inserted into the text stream.
435
- vocab_size (`int`, optional, defaults to 128256):
436
- Vocabulary size used by the text embeddings.
437
- use_cache (`bool`, optional, defaults to `True`):
438
- Whether to cache key/value tensors for attention (relevant for decoder architectures).
439
- tie_word_embeddings (`bool`, optional, defaults to `False`):
440
- Whether to tie input token embeddings and output token embeddings.
441
- pixel_shuffle_factor (`int`, optional, defaults to 4):
442
- Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
443
- additional_vocab_size (`int`, optional, defaults to 0):
444
- Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
445
- pad_token_id (`int`, optional):
446
- Padding token id.
447
- initializer_range (`float`, optional, defaults to 0.02):
448
- Stddev used for weight initialization.
449
- freeze_config (`Any`, optional):
450
- Optional config describing which submodules to freeze during training.
451
- use_resampler (`bool`, optional, defaults to `False`):
452
- Whether to enable an additional resampler on visual features.
453
- neftune_noise_alpha (`float`, optional, defaults to 0.0):
454
- Alpha parameter for neftune noise injection.
455
-
456
- Example:
457
- ```python
458
- >>> from modernvbert import ModernVBertConfig
459
- >>> # Initializing configuration
460
- >>> configuration = ModernVBertConfig()
461
- >>> # Initializing a model from the configuration (model class is implemented in
462
- >>> # `modernvbert.modeling_modernvbert`)
463
- >>> # from modernvbert import ModernVBertModel
464
- >>> # model = ModernVBertModel(configuration)
465
- >>> # Accessing the model configuration
466
- >>> # cfg = model.config
467
- ```"""
468
-
469
- model_type = "modernvbert"
470
- is_composition = True
471
-
472
  def __init__(
473
  self,
474
  text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
 
197
  model_type = "modernvbert"
198
  is_composition = True
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def __init__(
201
  self,
202
  text_config: Union[PretrainedConfig, Dict[str, Any]] = None,