|
|
"""MarkupDM configuration""" |
|
|
|
|
|
from transformers import AutoConfig, PretrainedConfig |
|
|
from transformers.utils import logging |
|
|
|
|
|
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class MarkupDMConfig(PretrainedConfig):
    """Composite configuration holding a text and a vision sub-configuration.

    Both ``text_model`` and ``vision_model`` must be supplied as keyword
    arguments. Each may be either a ``PretrainedConfig`` instance, which is
    stored as-is, or a dict containing a ``_name_or_path`` entry; in the
    latter case the sub-configuration is loaded with
    ``AutoConfig.from_pretrained`` and the remaining dict entries are
    forwarded as keyword arguments.

    Args:
        vocab_size: Stored on the config as ``vocab_size``.
        image_size: Stored on the config as ``image_size``.
        image_pos_size: Stored on the config as ``image_pos_size``.
        image_pos_sigma: Stored on the config as ``image_pos_sigma``.
        image_loss_weight: Stored on the config as ``image_loss_weight``.
        freeze_text_embeddings: Stored on the config as
            ``freeze_text_embeddings``.
        **kwargs: Must contain ``text_model`` and ``vision_model``; all
            keyword arguments are also forwarded to
            ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``text_model`` or ``vision_model`` is missing.
        KeyError: If a dict sub-configuration has no ``_name_or_path`` entry.
    """

    model_type = "markupdm"
    is_composition = True
    has_no_defaults_at_init = True

    def __init__(
        self,
        vocab_size: int = 49156,
        image_size: int = 256,
        image_pos_size: int = 4,
        image_pos_sigma: float = 10.0,
        image_loss_weight: float = 1.0,
        freeze_text_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        if "text_model" not in kwargs or "vision_model" not in kwargs:
            raise ValueError(
                f"A configuration of type {self.model_type} cannot be "
                "instantiated because not both `text_model` and `vision_model` "
                f"sub-configurations are passed, but only {kwargs}"
            )

        self.vocab_size = vocab_size
        self.image_size = image_size
        self.image_pos_size = image_pos_size
        self.image_pos_sigma = image_pos_sigma
        # The loss type is fixed here; it is not a constructor parameter.
        self.loss_type = "WeightedCausalLMLoss"
        self.image_loss_weight = image_loss_weight
        self.freeze_text_embeddings = freeze_text_embeddings

        text_config = kwargs.pop("text_model")
        vision_config = kwargs.pop("vision_model")

        self.text_model = self._resolve_sub_config(text_config)
        # NOTE(review): trust_remote_code=True is hard-coded for the vision
        # sub-configuration; per the transformers documentation this permits
        # running custom code shipped with the referenced model repository.
        # Confirm that this is intended for every vision model path.
        self.vision_model = self._resolve_sub_config(
            vision_config,
            trust_remote_code=True,
        )

        # Mirror selected text-model settings at the top level of this config.
        self.initializer_range = self.text_model.initializer_range
        self.num_hidden_layers = self.text_model.num_hidden_layers
        self.is_decoder = True

    @staticmethod
    def _resolve_sub_config(sub_config, **load_kwargs):
        """Return ``sub_config`` as a config object, loading it if it is a dict."""
        if isinstance(sub_config, PretrainedConfig):
            return sub_config

        # Pop from a shallow copy so the caller's dict keeps `_name_or_path`
        # and can be reused to build another configuration.
        overrides = dict(sub_config)
        path = overrides.pop("_name_or_path")
        return AutoConfig.from_pretrained(path, **load_kwargs, **overrides)
|
|
|