Update modeling_modernvbert.py
Browse files — modeling_modernvbert.py (+23 −20)
modeling_modernvbert.py
CHANGED
|
@@ -202,11 +202,8 @@ class ModernVBertPreTrainedModel(PreTrainedModel):
|
|
| 202 |
config_class = ModernVBertConfig
|
| 203 |
base_model_prefix = "model"
|
| 204 |
supports_gradient_checkpointing = True
|
| 205 |
-
_no_split_modules = ["ModernVBertDecoderLayer"]
|
| 206 |
-
_skip_keys_device_placement = "past_key_values"
|
| 207 |
_supports_flash_attn_2 = True
|
| 208 |
_supports_sdpa = True
|
| 209 |
-
_supports_cache_class = True
|
| 210 |
|
| 211 |
def _init_weights(self, module):
|
| 212 |
std = getattr(self.config, "initializer_range", 0.02)
|
|
@@ -221,39 +218,44 @@ class ModernVBertPreTrainedModel(PreTrainedModel):
|
|
| 221 |
|
| 222 |
|
| 223 |
class ModernVBertModel(ModernVBertPreTrainedModel):
|
| 224 |
-
def __init__(self, config: ModernVBertConfig
|
| 225 |
super().__init__(config)
|
| 226 |
-
self.vision_model = ModernVBertModel.init_vision_model(config
|
| 227 |
self.connector = ModernVBertConnector(config)
|
| 228 |
-
self.text_model = ModernVBertModel.init_language_model(config
|
| 229 |
self.image_seq_len = int(
|
| 230 |
((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
|
| 231 |
)
|
| 232 |
self.image_token_id = config.image_token_id
|
| 233 |
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
|
|
|
|
|
|
|
|
|
| 234 |
self.post_init()
|
| 235 |
|
| 236 |
@staticmethod
|
| 237 |
-
def init_vision_model(config: ModernVBertConfig
|
| 238 |
vision_model_config = AutoConfig.from_pretrained(
|
| 239 |
config.vision_config.vision_model_name,
|
| 240 |
_attn_implementation=config._attn_implementation,
|
| 241 |
-
dtype=config.torch_dtype,
|
| 242 |
-
**kwargs,
|
| 243 |
)
|
| 244 |
-
vision_model = AutoModel.from_config(
|
|
|
|
|
|
|
|
|
|
| 245 |
return getattr(vision_model, "vision_model", vision_model)
|
| 246 |
|
| 247 |
@staticmethod
|
| 248 |
-
def init_language_model(config: ModernVBertConfig
|
| 249 |
text_model_config = AutoConfig.from_pretrained(
|
| 250 |
config.text_config.text_model_name,
|
| 251 |
_attn_implementation=config._attn_implementation,
|
| 252 |
-
dtype=config.torch_dtype,
|
| 253 |
trust_remote_code=True,
|
| 254 |
-
**kwargs,
|
| 255 |
)
|
| 256 |
-
text_model = AutoModel.from_config(
|
|
|
|
|
|
|
|
|
|
| 257 |
embed_layer = DecoupledEmbedding(
|
| 258 |
num_embeddings=text_model_config.vocab_size,
|
| 259 |
num_additional_embeddings=config.additional_vocab_size,
|
|
@@ -376,10 +378,10 @@ class ModernVBertModel(ModernVBertPreTrainedModel):
|
|
| 376 |
)
|
| 377 |
|
| 378 |
class ModernVBertLMHead(nn.Module):
|
| 379 |
-
def __init__(self, config
|
| 380 |
super().__init__()
|
| 381 |
-
pretrained_config = AutoConfig.from_pretrained(config.text_config.text_model_name, trust_remote_code=True
|
| 382 |
-
pretrained_model = AutoModelForMaskedLM.from_config(pretrained_config, trust_remote_code=True
|
| 383 |
self.head = pretrained_model.head
|
| 384 |
self.decoder = pretrained_model.decoder
|
| 385 |
|
|
@@ -388,16 +390,17 @@ class ModernVBertLMHead(nn.Module):
|
|
| 388 |
|
| 389 |
|
| 390 |
class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
|
| 391 |
-
def __init__(self, config
|
| 392 |
super().__init__(config)
|
| 393 |
self.image_token_id = config.image_token_id
|
| 394 |
self.in_features = config.hidden_size
|
| 395 |
self.out_additional_features = config.additional_vocab_size
|
| 396 |
self.vocab_size = config.vocab_size
|
| 397 |
-
self.model = ModernVBertModel(config
|
| 398 |
-
self.lm_head = ModernVBertLMHead(config
|
| 399 |
if self.out_additional_features > 0:
|
| 400 |
self.additional_fc = nn.Linear(self.in_features, self.out_additional_features, bias=False)
|
|
|
|
| 401 |
self.post_init()
|
| 402 |
|
| 403 |
def forward(
|
|
|
|
| 202 |
config_class = ModernVBertConfig
|
| 203 |
base_model_prefix = "model"
|
| 204 |
supports_gradient_checkpointing = True
|
|
|
|
|
|
|
| 205 |
_supports_flash_attn_2 = True
|
| 206 |
_supports_sdpa = True
|
|
|
|
| 207 |
|
| 208 |
def _init_weights(self, module):
|
| 209 |
std = getattr(self.config, "initializer_range", 0.02)
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
class ModernVBertModel(ModernVBertPreTrainedModel):
    def __init__(self, config: ModernVBertConfig):
        """Assemble the multimodal backbone: vision tower, connector, and text tower.

        All sub-modules are derived from ``config``; the vision and text towers are
        built by the static factory helpers on this class.
        """
        super().__init__(config)
        self.vision_model = ModernVBertModel.init_vision_model(config)
        self.connector = ModernVBertConnector(config)
        self.text_model = ModernVBertModel.init_language_model(config)
        # Number of tokens one image contributes: squared patch-grid side,
        # reduced by the (squared) connector scale factor.
        patches_per_side = config.vision_config.image_size // config.vision_config.patch_size
        self.image_seq_len = int((patches_per_side**2) / (config.scale_factor**2))
        self.image_token_id = config.image_token_id
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        # Bring both sub-models to this wrapper's dtype before post_init runs,
        # since they were instantiated with their own default dtypes.
        self.vision_model.to(self.dtype)
        self.text_model.to(self.dtype)
        self.post_init()
|
| 235 |
|
| 236 |
@staticmethod
|
| 237 |
+
def init_vision_model(config: ModernVBertConfig):
|
| 238 |
vision_model_config = AutoConfig.from_pretrained(
|
| 239 |
config.vision_config.vision_model_name,
|
| 240 |
_attn_implementation=config._attn_implementation,
|
|
|
|
|
|
|
| 241 |
)
|
| 242 |
+
vision_model = AutoModel.from_config(
|
| 243 |
+
vision_model_config,
|
| 244 |
+
trust_remote_code=True,
|
| 245 |
+
)
|
| 246 |
return getattr(vision_model, "vision_model", vision_model)
|
| 247 |
|
| 248 |
@staticmethod
|
| 249 |
+
def init_language_model(config: ModernVBertConfig):
|
| 250 |
text_model_config = AutoConfig.from_pretrained(
|
| 251 |
config.text_config.text_model_name,
|
| 252 |
_attn_implementation=config._attn_implementation,
|
|
|
|
| 253 |
trust_remote_code=True,
|
|
|
|
| 254 |
)
|
| 255 |
+
text_model = AutoModel.from_config(
|
| 256 |
+
text_model_config,
|
| 257 |
+
trust_remote_code=True
|
| 258 |
+
)
|
| 259 |
embed_layer = DecoupledEmbedding(
|
| 260 |
num_embeddings=text_model_config.vocab_size,
|
| 261 |
num_additional_embeddings=config.additional_vocab_size,
|
|
|
|
| 378 |
)
|
| 379 |
|
| 380 |
class ModernVBertLMHead(nn.Module):
    """Masked-LM prediction head reused from the pretrained text model."""

    def __init__(self, config):
        super().__init__()
        # Build a donor masked-LM model from config (random weights) and keep
        # only its head/decoder modules; the donor itself is discarded.
        donor_config = AutoConfig.from_pretrained(
            config.text_config.text_model_name, trust_remote_code=True
        )
        donor = AutoModelForMaskedLM.from_config(donor_config, trust_remote_code=True)
        self.head = donor.head
        self.decoder = donor.decoder
|
| 387 |
|
|
|
|
| 390 |
|
| 391 |
|
| 392 |
class ModernVBertForMaskedLM(ModernVBertPreTrainedModel):
    def __init__(self, config):
        """Multimodal backbone plus a masked-LM head (and optional extra-vocab projection)."""
        super().__init__(config)
        # Cache frequently used config scalars on the instance.
        self.image_token_id = config.image_token_id
        self.in_features = config.hidden_size
        self.out_additional_features = config.additional_vocab_size
        self.vocab_size = config.vocab_size
        self.model = ModernVBertModel(config)
        self.lm_head = ModernVBertLMHead(config)
        # Tokens added on top of the base vocabulary get their own bias-free projection.
        if self.out_additional_features > 0:
            self.additional_fc = nn.Linear(
                self.in_features,
                self.out_additional_features,
                bias=False,
            )
        # The head was instantiated at its donor model's default dtype; align it.
        self.lm_head.to(self.dtype)
        self.post_init()
|
| 405 |
|
| 406 |
def forward(
|