markupdm / configuration_markupdm.py
ktrk115's picture
Update configuration_markupdm.py
6d6c47a verified
"""MarkupDM configuration"""
from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging
# Module-level logger created through the transformers logging utility.
logger = logging.get_logger(__name__)
class MarkupDMConfig(PretrainedConfig):
    """Composite configuration holding a text and a vision sub-configuration.

    Both sub-configurations are mandatory and must be supplied through the
    ``text_model`` and ``vision_model`` keyword arguments. Each may be either

    * a ``PretrainedConfig`` instance, which is stored as-is, or
    * a dict containing a ``"_name_or_path"`` entry. The path is handed to
      ``AutoConfig.from_pretrained`` and the remaining entries are forwarded
      to it as keyword arguments.

    Args:
        vocab_size: Stored as ``self.vocab_size``.
        image_size: Stored as ``self.image_size``.
        image_pos_size: Stored as ``self.image_pos_size``.
        image_pos_sigma: Stored as ``self.image_pos_sigma``.
        image_loss_weight: Stored as ``self.image_loss_weight``.
        freeze_text_embeddings: Stored as ``self.freeze_text_embeddings``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. Must contain
            ``text_model`` and ``vision_model``.

    Raises:
        ValueError: If ``text_model`` or ``vision_model`` is missing.
        KeyError: If a dict sub-configuration has no ``"_name_or_path"`` key.
    """

    model_type = "markupdm"
    is_composition = True
    has_no_defaults_at_init = True

    def __init__(
        self,
        vocab_size: int = 49156,
        image_size: int = 256,
        image_pos_size: int = 4,
        image_pos_sigma: float = 10.0,
        image_loss_weight: float = 1.0,
        freeze_text_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        if "text_model" not in kwargs or "vision_model" not in kwargs:
            # Adjacent literals are concatenated verbatim, so each fragment
            # must carry its own trailing space.
            raise ValueError(
                f"A configuration of type {self.model_type} cannot be "
                "instantiated because not both `text_model` and `vision_model` "
                f"sub-configurations are passed, but only {kwargs}"
            )

        self.vocab_size = vocab_size
        self.image_size = image_size
        self.image_pos_size = image_pos_size
        self.image_pos_sigma = image_pos_sigma
        # The loss type is fixed here and is not exposed as a parameter.
        self.loss_type = "WeightedCausalLMLoss"
        self.image_loss_weight = image_loss_weight
        self.freeze_text_embeddings = freeze_text_embeddings

        text_config = kwargs.pop("text_model")
        vision_config = kwargs.pop("vision_model")

        if isinstance(text_config, PretrainedConfig):
            self.text_model = text_config
        else:
            # Work on a shallow copy so the caller's dict keeps its
            # "_name_or_path" entry and can be reused.
            text_config = dict(text_config)
            path = text_config.pop("_name_or_path")
            self.text_model = AutoConfig.from_pretrained(
                path,
                **text_config,
            )

        if isinstance(vision_config, PretrainedConfig):
            self.vision_model = vision_config
        else:
            # Same as above: never mutate the dict the caller passed in.
            vision_config = dict(vision_config)
            path = vision_config.pop("_name_or_path")
            # NOTE(review): trust_remote_code=True lets transformers execute
            # code shipped with the referenced repository; make sure the
            # vision model path comes from a trusted source.
            self.vision_model = AutoConfig.from_pretrained(
                path,
                trust_remote_code=True,
                **vision_config,
            )

        # Update config: copy selected text-model settings onto this
        # composite config (presumably for code that reads them from the
        # top-level config -- confirm against the modeling code).
        self.initializer_range = self.text_model.initializer_range
        self.num_hidden_layers = self.text_model.num_hidden_layers
        self.is_decoder = True