| """Yasa2 model configuration.""" |
|
|
| from typing import Any, Dict, Optional, Union |
|
|
| from transformers import PretrainedConfig |
| try: |
| from transformers.modeling_rope_utils import ( |
| rope_config_validation as _rope_validate, |
| ) |
| except Exception: |
| _rope_validate = None |
|
|
|
|
| class ConvNextConfig(PretrainedConfig): |
| """Configuration for ConvNeXt vision backbones used by Yasa2.""" |
|
|
| model_type = "convnext" |
|
|
| def __init__( |
| self, |
| num_channels: int = 3, |
| patch_size: int = 4, |
        hidden_sizes: Optional[list[int]] = None,
        hidden_size: Optional[int] = None,
        depths: Optional[list[int]] = None,
| hidden_act: str = "gelu", |
| initializer_range: float = 0.02, |
| layer_norm_eps: float = 1e-12, |
        layer_scale_init_value: float = 0.0,
| drop_path_rate: float = 0.0, |
| image_size: int = 512, |
| use_grn: bool = True, |
| **kwargs: Any, |
| ): |
| """Initialize ConvNeXt vision configuration. |
| |
| Args: |
| num_channels: Number of input image channels. |
| patch_size: Patch size for the stem. |
| hidden_sizes: Channel sizes per stage. |
            hidden_size: Accepted for checkpoint compatibility; the value is
                ignored and recomputed as the last entry of hidden_sizes.
| depths: Number of blocks per stage. |
| hidden_act: Activation function name. |
| initializer_range: Weight init range. |
| layer_norm_eps: Layer norm epsilon. |
| layer_scale_init_value: Layer scale init value. |
| drop_path_rate: Stochastic depth rate. |
| image_size: Base input image size. |
| use_grn: Whether to use GRN (ConvNeXt V2). |
| **kwargs: Passed to PretrainedConfig. |
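
        Example (illustrative sketch):

            >>> config = ConvNextConfig(
            ...     hidden_sizes=[64, 128, 256, 512], depths=[2, 2, 6, 2]
            ... )
            >>> config.hidden_size  # always the last stage width
            512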
| """ |
| super().__init__(**kwargs) |
| self.num_channels: int = num_channels |
| self.patch_size: int = patch_size |
| self.hidden_sizes: list[int] = ( |
| [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes |
| ) |
| self.depths: list[int] = [3, 3, 9, 3] if depths is None else depths |
| self.hidden_size: int = self.hidden_sizes[-1] |
| self.hidden_act: str = hidden_act |
| self.initializer_range: float = initializer_range |
| self.layer_norm_eps: float = layer_norm_eps |
| self.layer_scale_init_value: float = layer_scale_init_value |
| self.drop_path_rate: float = drop_path_rate |
| self.image_size: int = image_size |
| self.use_grn: bool = use_grn |
|
|
| @staticmethod |
| def convnextv2_base() -> "ConvNextConfig": |
| """Return ConvNeXt V2 Base config. |
| |
| Returns: |
| ConvNextConfig: Preconfigured base model kwargs. |
| """ |
| return ConvNextConfig( |
| hidden_sizes=[128, 256, 512, 1024], |
| depths=[3, 3, 27, 3], |
| use_grn=True, |
| ) |
|
|
| @staticmethod |
| def convnextv2_large() -> "ConvNextConfig": |
| """Return ConvNeXt V2 Large config. |
| |
| Returns: |
| ConvNextConfig: Preconfigured large-scale model kwargs. |
| """ |
| return ConvNextConfig( |
| hidden_sizes=[192, 384, 768, 1536], |
| depths=[3, 3, 27, 3], |
| use_grn=True, |
| ) |
|
|
| @staticmethod |
| def convnextv2_huge() -> "ConvNextConfig": |
| """Return ConvNeXt V2 Huge config. |
| |
| Returns: |
| ConvNextConfig: Preconfigured huge-scale model kwargs. |
| """ |
| return ConvNextConfig( |
| hidden_sizes=[352, 704, 1408, 2816], |
| depths=[3, 3, 27, 3], |
| use_grn=True, |
| ) |
|
|
|
|
| class YasaConfig(PretrainedConfig): |
| """Configuration for the Yasa language model block.""" |
|
|
| model_type = "yasa_model" |
| keys_to_ignore_at_inference = ["past_key_values"] |
| |
| base_model_tp_plan = { |
| "layers.*.self_attn.q_proj": "colwise", |
| "layers.*.self_attn.k_proj": "colwise", |
| "layers.*.self_attn.v_proj": "colwise", |
| "layers.*.self_attn.o_proj": "rowwise", |
| "layers.*.mlp.gate_proj": "colwise", |
| "layers.*.mlp.up_proj": "colwise", |
| "layers.*.mlp.down_proj": "rowwise", |
| } |
| base_model_pp_plan = { |
| "embed_tokens": (["input_ids"], ["inputs_embeds"]), |
| "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), |
| "norm": (["hidden_states"], ["hidden_states"]), |
| } |
|
|
| def __init__( |
| self, |
        vocab_size: int = 100352,
        hidden_size: int = 4096,
        intermediate_size: int = 10880,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: Optional[int] = 8,
        hidden_act: str = "silu",
        max_position_embeddings: int = 8192,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-05,
        use_cache: bool = True,
        pad_token_id: int = 100257,
        bos_token_id: int = 100257,
        eos_token_id: int = 100257,
        pretraining_tp: int = 1,
        tie_word_embeddings: bool = False,
        rope_theta: float = 10000.0,
        rope_scaling: Optional[Dict[str, Any]] = None,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        mlp_bias: bool = False,
        head_dim: Optional[int] = None,
| **kwargs: Any, |
| ): |
| """Initialize the Yasa language model configuration. |
| |
| Args: |
| vocab_size: Vocabulary size for the tokenizer. |
| hidden_size: Transformer hidden size. |
| intermediate_size: MLP intermediate size. |
| num_hidden_layers: Number of Transformer layers. |
| num_attention_heads: Number of attention heads. |
| num_key_value_heads: Number of key/value heads for GQA/MQA. |
| hidden_act: Activation function name. |
| max_position_embeddings: Maximum supported sequence length. |
| initializer_range: Weight init range. |
| rms_norm_eps: RMS norm epsilon. |
| use_cache: Whether to return KV cache. |
| pad_token_id: Padding token id. |
| bos_token_id: Beginning-of-sequence token id. |
| eos_token_id: End-of-sequence token id. |
| pretraining_tp: Tensor parallel shards used during pretraining. |
| tie_word_embeddings: Whether to tie input/output embeddings. |
| rope_theta: Rotary embedding base. |
| rope_scaling: Rotary embedding scaling configuration. |
| attention_bias: Whether attention layers use bias. |
| attention_dropout: Dropout rate for attention probabilities. |
| mlp_bias: Whether MLP layers use bias. |
| head_dim: Per-head dimension override. |
| **kwargs: Passed to PretrainedConfig. |
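
        Example (illustrative sketch; the non-default values are arbitrary):

            >>> config = YasaConfig(num_attention_heads=32, num_key_value_heads=8)
            >>> config.head_dim  # hidden_size // num_attention_heads
            128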
| |
        """
| self.vocab_size = vocab_size |
| self.max_position_embeddings = max_position_embeddings |
| self.hidden_size = hidden_size |
| self.intermediate_size = intermediate_size |
| self.num_hidden_layers = num_hidden_layers |
| self.num_attention_heads = num_attention_heads |
|
|
| |
| if num_key_value_heads is None: |
| num_key_value_heads = num_attention_heads |
|
|
| self.num_key_value_heads = num_key_value_heads |
| self.hidden_act = hidden_act |
| self.initializer_range = initializer_range |
| self.rms_norm_eps = rms_norm_eps |
| self.pretraining_tp = pretraining_tp |
| self.use_cache = use_cache |
| self.rope_theta = rope_theta |
| self.rope_scaling = rope_scaling |
| self.attention_bias = attention_bias |
| self.attention_dropout = attention_dropout |
| self.mlp_bias = mlp_bias |
| self.head_dim = ( |
| head_dim |
| if head_dim is not None |
| else self.hidden_size // self.num_attention_heads |
| ) |
| |
| |
        # Normalize the legacy "type" key to "rope_type", then validate the
        # rotary scaling configuration using whichever helper this version of
        # transformers provides.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        if hasattr(self, "standardize_rope_params"):
            self.standardize_rope_params()
        if hasattr(self, "validate_rope"):
            self.validate_rope()
        elif _rope_validate is not None:
            _rope_validate(self)
|
|
| super().__init__( |
| pad_token_id=pad_token_id, |
| bos_token_id=bos_token_id, |
| eos_token_id=eos_token_id, |
| tie_word_embeddings=tie_word_embeddings, |
| **kwargs, |
| ) |
|
|
|
|
| class Yasa2Config(PretrainedConfig): |
| """Top-level configuration for the Yasa2 multimodal model.""" |
|
|
| model_type = "yasa2" |
| is_composition = True |
| sub_configs = {"vision_config": ConvNextConfig, "text_config": YasaConfig} |
|
|
| def __init__( |
| self, |
        vision_config: Optional[Union[Dict[str, Any], ConvNextConfig]] = None,
        text_config: Optional[Union[Dict[str, Any], YasaConfig]] = None,
| num_query_tokens: int = 64, |
| vision_pooling: str = "adaptive_avg", |
| use_vision_pos_embed: bool = True, |
| apply_patch_attention_mask: bool = True, |
| image_token_id: int = 100278, |
| label_ignore_index: int = 100257, |
| **kwargs: Any, |
| ): |
| """Initialize Yasa2 multimodal configuration. |
| |
| Args: |
            vision_config: Vision backbone config or dict; defaults to a
                fresh ConvNextConfig when None.
            text_config: Text model config or dict; defaults to a fresh
                YasaConfig when None.
| num_query_tokens: Number of query tokens. |
| vision_pooling: Vision pooling strategy. |
| use_vision_pos_embed: Whether to use vision pos embed. |
| apply_patch_attention_mask: Whether to apply patch attention masks to vision features. |
| image_token_id: Token ID for image content tokens. |
            label_ignore_index: Label value ignored by the loss (defaults to
                100257, the pad token id).
| **kwargs: Passed to PretrainedConfig. |
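
        Example (illustrative sketch using the ConvNeXt V2 Large preset):

            >>> config = Yasa2Config(
            ...     vision_config=ConvNextConfig.convnextv2_large(),
            ...     text_config={"hidden_size": 2048, "num_hidden_layers": 16},
            ... )
            >>> config.hidden_size  # delegated to the text config
            2048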
| """ |
| |
        # Drop auto_map entries with a None key before PretrainedConfig
        # consumes the mapping.
        auto_map = kwargs.get("auto_map")
        if isinstance(auto_map, dict):
            kwargs["auto_map"] = {
                k: v for k, v in auto_map.items() if k is not None
            }
| super().__init__(**kwargs) |
|
|
        if vision_config is None:
            vision_config = ConvNextConfig()
        if isinstance(vision_config, ConvNextConfig):
            self.vision_config = vision_config
        else:
            # ConvNextConfig always derives hidden_size from hidden_sizes[-1],
            # so drop any serialized value before reconstructing the config.
            vision_config = dict(vision_config)
            vision_config.pop("hidden_size", None)
            self.vision_config = ConvNextConfig(**vision_config)

        if text_config is None:
            text_config = YasaConfig()
        if isinstance(text_config, YasaConfig):
            self.text_config = text_config
        else:
            self.text_config = YasaConfig(**text_config)
|
|
| self.num_query_tokens: int = num_query_tokens |
| self.vision_pooling: str = vision_pooling |
| self.use_vision_pos_embed: bool = use_vision_pos_embed |
| self.apply_patch_attention_mask: bool = apply_patch_attention_mask |
| self.image_token_id: int = image_token_id |
| self.label_ignore_index: int = label_ignore_index |
|
|
| def to_dict(self) -> Dict[str, Any]: |
| """Serialize configuration to a Python dict. |
| |
| Returns: |
| Dict[str, Any]: Configuration mapping ready for JSON serialization. |
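
        Example (illustrative):

            >>> Yasa2Config().to_dict()["model_type"]
            'yasa2'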
| """ |
| output = super().to_dict() |
| output["vision_config"] = self.vision_config.to_dict() |
| output["text_config"] = self.text_config.to_dict() |
| output["model_type"] = self.__class__.model_type |
| return output |
|
|
| @property |
| def num_attention_heads(self) -> int: |
| """Return the attention head count from the text config. |
| |
| Returns: |
| int: Number of attention heads. |
| """ |
| return self.text_config.num_attention_heads |
|
|
| @property |
| def num_key_value_heads(self) -> int: |
| """Return the key/value head count from the text config. |
| |
| Returns: |
| int: Number of key/value heads. |
| """ |
| return self.text_config.num_key_value_heads |
|
|
| @property |
    def num_hidden_layers(self) -> int:
| """Return the number of hidden layers from the text config. |
| |
| Returns: |
| int: Number of Transformer layers. |
| """ |
| return self.text_config.num_hidden_layers |
|
|
| @property |
    def hidden_size(self) -> int:
| """Return the hidden size from the text config. |
| |
| Returns: |
| int: Transformer hidden dimension. |
| """ |
| return self.text_config.hidden_size |
|
|
| @property |
    def vocab_size(self) -> int:
| """Return the vocabulary size from the text config. |
| |
| Returns: |
| int: Vocabulary size. |
| """ |
| return self.text_config.vocab_size |
|
|
| @property |
    def max_position_embeddings(self) -> int:
| """Return the max position embeddings from the text config. |
| |
| Returns: |
| int: Maximum sequence length supported. |
| """ |
| return self.text_config.max_position_embeddings |
|
|
| @property |
    def initializer_range(self) -> float:
| """Return the initializer range from the text config. |
| |
| Returns: |
| float: Value used for weight initialization scaling. |
| """ |
| return self.text_config.initializer_range |
|
|
|
|
| ConvNextConfig.register_for_auto_class() |
| YasaConfig.register_for_auto_class() |
| Yasa2Config.register_for_auto_class() |
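

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the library API):
    # compose a deliberately tiny Yasa2 configuration from the sub-configs
    # defined above and round-trip it through to_dict().
    tiny_text = YasaConfig(
        vocab_size=1024,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
    )
    tiny_vision = ConvNextConfig(
        hidden_sizes=[32, 64, 128, 256], depths=[1, 1, 2, 1]
    )
    config = Yasa2Config(vision_config=tiny_vision, text_config=tiny_text)

    serialized = config.to_dict()
    assert serialized["model_type"] == "yasa2"
    assert config.hidden_size == 256  # delegated to the text config
    assert config.vision_config.hidden_size == 256  # last ConvNeXt stage width
    print("Yasa2Config smoke test passed.")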
|
|