paultltc committed on
Commit
57e35bd
·
verified ·
1 Parent(s): 2d5ce86

Update modeling_modernvbert.py

Browse files
Files changed (1) hide show
  1. modeling_modernvbert.py +23 -20
modeling_modernvbert.py CHANGED
@@ -202,11 +202,8 @@ class ModernVBertPreTrainedModel(PreTrainedModel):
202
  config_class = ModernVBertConfig
203
  base_model_prefix = "model"
204
  supports_gradient_checkpointing = True
205
- _no_split_modules = ["ModernVBertDecoderLayer"]
206
- _skip_keys_device_placement = "past_key_values"
207
  _supports_flash_attn_2 = True
208
  _supports_sdpa = True
209
- _supports_cache_class = True
210
 
211
  def _init_weights(self, module):
212
  std = getattr(self.config, "initializer_range", 0.02)
@@ -221,39 +218,44 @@ class ModernVBertPreTrainedModel(PreTrainedModel):
221
 
222
 
223
  class ModernVBertModel(ModernVBertPreTrainedModel):
224
- def __init__(self, config: ModernVBertConfig, **kwargs):
225
  super().__init__(config)
226
- self.vision_model = ModernVBertModel.init_vision_model(config, **kwargs)
227
  self.connector = ModernVBertConnector(config)
228
- self.text_model = ModernVBertModel.init_language_model(config, **kwargs)
229
  self.image_seq_len = int(
230
  ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
231
  )
232
  self.image_token_id = config.image_token_id
233
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
 
 
234
  self.post_init()
235
 
236
  @staticmethod
237
- def init_vision_model(config: ModernVBertConfig, **kwargs):
238
  vision_model_config = AutoConfig.from_pretrained(
239
  config.vision_config.vision_model_name,
240
  _attn_implementation=config._attn_implementation,
241
- dtype=config.torch_dtype,
242
- **kwargs,
243
  )
244
- vision_model = AutoModel.from_config(vision_model_config, trust_remote_code=True, **kwargs)
 
 
 
245
  return getattr(vision_model, "vision_model", vision_model)
246
 
247
  @staticmethod
248
- def init_language_model(config: ModernVBertConfig, **kwargs):
249
  text_model_config = AutoConfig.from_pretrained(
250
  config.text_config.text_model_name,
251
  _attn_implementation=config._attn_implementation,
252
- dtype=config.torch_dtype,
253
  trust_remote_code=True,
254
- **kwargs,
255
  )
256
- text_model = AutoModel.from_config(text_model_config, trust_remote_code=True, **kwargs)
 
 
 
257
  embed_layer = DecoupledEmbedding(
258
  num_embeddings=text_model_config.vocab_size,
259
  num_additional_embeddings=config.additional_vocab_size,
@@ -376,10 +378,10 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
376
  )
377
 
378
  class ModernVBertLMHead(nn.Module):
379
- def __init__(self, config, **kwargs):
380
  super().__init__()
381
- pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True, **kwargs)
382
- pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True, **kwargs)
383
  self.head = pretrained_model.head
384
  self.decoder = pretrained_model.decoder
385
 
@@ -388,16 +390,17 @@ class ModernVBertLMHead(nn.Module):
388
 
389
 
390
  class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
391
- def __init__(self, config, **kwargs):
392
  super().__init__(config)
393
  self.image_token_id = config.image_token_id
394
  self.in_features = config.hidden_size
395
  self.out_additional_features = config.additional_vocab_size
396
  self.vocab_size = config.vocab_size
397
- self.model = ModernVBertModel(config, **kwargs)
398
- self.lm_head = ModernVBertLMHead(config, **kwargs)
399
  if self.out_additional_features > 0:
400
  self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
 
401
  self.post_init()
402
 
403
  def forward(
 
202
  config_class = ModernVBertConfig
203
  base_model_prefix = "model"
204
  supports_gradient_checkpointing = True
 
 
205
  _supports_flash_attn_2 = True
206
  _supports_sdpa = True
 
207
 
208
  def _init_weights(self, module):
209
  std = getattr(self.config, "initializer_range", 0.02)
 
218
 
219
 
220
  class ModernVBertModel(ModernVBertPreTrainedModel):
221
+ def __init__(self, config: ModernVBertConfig):
222
  super().__init__(config)
223
+ self.vision_model = ModernVBertModel.init_vision_model(config)
224
  self.connector = ModernVBertConnector(config)
225
+ self.text_model = ModernVBertModel.init_language_model(config)
226
  self.image_seq_len = int(
227
  ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
228
  )
229
  self.image_token_id = config.image_token_id
230
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
231
+ # set the correct dtype for vision and text models
232
+ self.vision_model.to(self.dtype)
233
+ self.text_model.to(self.dtype)
234
  self.post_init()
235
 
236
  @staticmethod
237
+ def init_vision_model(config: ModernVBertConfig):
238
  vision_model_config = AutoConfig.from_pretrained(
239
  config.vision_config.vision_model_name,
240
  _attn_implementation=config._attn_implementation,
 
 
241
  )
242
+ vision_model = AutoModel.from_config(
243
+ vision_model_config,
244
+ trust_remote_code=True,
245
+ )
246
  return getattr(vision_model, "vision_model", vision_model)
247
 
248
  @staticmethod
249
+ def init_language_model(config: ModernVBertConfig):
250
  text_model_config = AutoConfig.from_pretrained(
251
  config.text_config.text_model_name,
252
  _attn_implementation=config._attn_implementation,
 
253
  trust_remote_code=True,
 
254
  )
255
+ text_model = AutoModel.from_config(
256
+ text_model_config,
257
+ trust_remote_code=True
258
+ )
259
  embed_layer = DecoupledEmbedding(
260
  num_embeddings=text_model_config.vocab_size,
261
  num_additional_embeddings=config.additional_vocab_size,
 
378
  )
379
 
380
  class ModernVBertLMHead(nn.Module):
381
+ def __init__(self, config):
382
  super().__init__()
383
+ pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True)
384
+ pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True)
385
  self.head = pretrained_model.head
386
  self.decoder = pretrained_model.decoder
387
 
 
390
 
391
 
392
  class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
393
+ def __init__(self, config):
394
  super().__init__(config)
395
  self.image_token_id = config.image_token_id
396
  self.in_features = config.hidden_size
397
  self.out_additional_features = config.additional_vocab_size
398
  self.vocab_size = config.vocab_size
399
+ self.model = ModernVBertModel(config)
400
+ self.lm_head = ModernVBertLMHead(config)
401
  if self.out_additional_features > 0:
402
  self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
403
+ self.lm_head.to(self.dtype)
404
  self.post_init()
405
 
406
  def forward(