# reka-edge-2603 / configuration_yasa2.py
"""Yasa2 model configuration."""
from typing import Any, Dict, Optional, Union
from transformers import PretrainedConfig
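# rope_config_validation is only available in newer transformers releases; the
# try/except below falls back to None so older versions can still import this module.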
try:
from transformers.modeling_rope_utils import (
rope_config_validation as _rope_validate,
)
except Exception:
_rope_validate = None
class ConvNextConfig(PretrainedConfig):
"""Configuration for ConvNeXt vision backbones used by Yasa2."""
model_type = "convnext"
def __init__(
self,
num_channels: int = 3,
patch_size: int = 4,
hidden_sizes: list[int] | None = None,
hidden_size: Optional[int] = None,
depths: list[int] | None = None,
hidden_act: str = "gelu",
initializer_range: float = 0.02,
layer_norm_eps: float = 1e-12,
layer_scale_init_value: float = 0, # Different default for ConvNeXt V2
drop_path_rate: float = 0.0,
image_size: int = 512,
use_grn: bool = True, # Different default for ConvNeXt V2
**kwargs: Any,
):
"""Initialize ConvNeXt vision configuration.
Args:
num_channels: Number of input image channels.
patch_size: Patch size for the stem.
hidden_sizes: Channel sizes per stage.
            hidden_size: Accepted for backward compatibility and ignored; the
                effective hidden size is taken from the last entry of hidden_sizes.
depths: Number of blocks per stage.
hidden_act: Activation function name.
initializer_range: Weight init range.
layer_norm_eps: Layer norm epsilon.
layer_scale_init_value: Layer scale init value.
drop_path_rate: Stochastic depth rate.
image_size: Base input image size.
use_grn: Whether to use GRN (ConvNeXt V2).
**kwargs: Passed to PretrainedConfig.
"""
super().__init__(**kwargs)
self.num_channels: int = num_channels
self.patch_size: int = patch_size
self.hidden_sizes: list[int] = (
[96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
)
self.depths: list[int] = [3, 3, 9, 3] if depths is None else depths
self.hidden_size: int = self.hidden_sizes[-1]
self.hidden_act: str = hidden_act
self.initializer_range: float = initializer_range
self.layer_norm_eps: float = layer_norm_eps
self.layer_scale_init_value: float = layer_scale_init_value
self.drop_path_rate: float = drop_path_rate
self.image_size: int = image_size
self.use_grn: bool = use_grn
@staticmethod
def convnextv2_base() -> "ConvNextConfig":
"""Return ConvNeXt V2 Base config.
Returns:
ConvNextConfig: Preconfigured base model kwargs.
"""
return ConvNextConfig(
hidden_sizes=[128, 256, 512, 1024],
depths=[3, 3, 27, 3],
use_grn=True,
)
@staticmethod
def convnextv2_large() -> "ConvNextConfig":
"""Return ConvNeXt V2 Large config.
Returns:
ConvNextConfig: Preconfigured large-scale model kwargs.
"""
return ConvNextConfig(
hidden_sizes=[192, 384, 768, 1536],
depths=[3, 3, 27, 3],
use_grn=True,
)
@staticmethod
def convnextv2_huge() -> "ConvNextConfig":
"""Return ConvNeXt V2 Huge config.
Returns:
ConvNextConfig: Preconfigured huge-scale model kwargs.
"""
return ConvNextConfig(
hidden_sizes=[352, 704, 1408, 2816],
depths=[3, 3, 27, 3],
use_grn=True,
)
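# Usage sketch (informal, based only on the factory defaults above): the static
# constructors return ready-made ConvNeXt V2 variants, e.g.
#
#     vision_cfg = ConvNextConfig.convnextv2_large()
#     vision_cfg.hidden_size  # 1536, i.e. hidden_sizes[-1]
#     vision_cfg.depths       # [3, 3, 27, 3]
#
# Custom backbones can be described by passing hidden_sizes/depths directly.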
class YasaConfig(PretrainedConfig):
"""Configuration for the Yasa language model block."""
model_type = "yasa_model"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `YasaModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
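    # Default pipeline parallel plan: maps each stage to its (input, output) tensor names.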
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size: Optional[int] = 100352,
hidden_size: Optional[int] = 4096,
intermediate_size: Optional[int] = 10880,
num_hidden_layers: Optional[int] = 32,
num_attention_heads: Optional[int] = 32,
num_key_value_heads: Optional[int] = 8,
hidden_act: Optional[str] = "silu",
max_position_embeddings: Optional[int] = 8192,
initializer_range: Optional[float] = 0.02,
        rms_norm_eps: Optional[float] = 1e-05,
use_cache: Optional[bool] = True,
pad_token_id: Optional[int] = 100257,
bos_token_id: Optional[int] = 100257,
eos_token_id: Optional[int] = 100257,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_theta: Optional[float] = 10000.0,
rope_scaling: Optional[dict] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
head_dim: Optional[int] = None,
**kwargs: Any,
):
"""Initialize the Yasa language model configuration.
Args:
vocab_size: Vocabulary size for the tokenizer.
hidden_size: Transformer hidden size.
intermediate_size: MLP intermediate size.
num_hidden_layers: Number of Transformer layers.
num_attention_heads: Number of attention heads.
num_key_value_heads: Number of key/value heads for GQA/MQA.
hidden_act: Activation function name.
max_position_embeddings: Maximum supported sequence length.
initializer_range: Weight init range.
rms_norm_eps: RMS norm epsilon.
use_cache: Whether to return KV cache.
pad_token_id: Padding token id.
bos_token_id: Beginning-of-sequence token id.
eos_token_id: End-of-sequence token id.
pretraining_tp: Tensor parallel shards used during pretraining.
tie_word_embeddings: Whether to tie input/output embeddings.
rope_theta: Rotary embedding base.
rope_scaling: Rotary embedding scaling configuration.
attention_bias: Whether attention layers use bias.
attention_dropout: Dropout rate for attention probabilities.
mlp_bias: Whether MLP layers use bias.
head_dim: Per-head dimension override.
**kwargs: Passed to PretrainedConfig.
"""
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = (
head_dim
if head_dim is not None
else self.hidden_size // self.num_attention_heads
)
# Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
if hasattr(self, "standardize_rope_params"):
self.standardize_rope_params()
if hasattr(self, "validate_rope"):
self.validate_rope()
elif _rope_validate is not None:
_rope_validate(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
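# Illustrative note (not a checkpoint default): rope_scaling follows the standard
# transformers rotary-scaling schema, e.g.
#
#     cfg = YasaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
#
# A legacy "type" key is copied to "rope_type" in __init__ above, and the result is
# checked with rope_config_validation when that helper is importable.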
class Yasa2Config(PretrainedConfig):
"""Top-level configuration for the Yasa2 multimodal model."""
model_type = "yasa2"
is_composition = True
sub_configs = {"vision_config": ConvNextConfig, "text_config": YasaConfig}
def __init__(
self,
        vision_config: Optional[Union[Dict[str, Any], ConvNextConfig]] = None,
        text_config: Optional[Union[Dict[str, Any], YasaConfig]] = None,
num_query_tokens: int = 64,
vision_pooling: str = "adaptive_avg",
use_vision_pos_embed: bool = True,
apply_patch_attention_mask: bool = True,
image_token_id: int = 100278,
label_ignore_index: int = 100257,
**kwargs: Any,
):
"""Initialize Yasa2 multimodal configuration.
Args:
            vision_config: Vision backbone config (dict or ConvNextConfig);
                a default ConvNextConfig is used when None.
            text_config: Text model config (dict or YasaConfig); a default
                YasaConfig is used when None.
num_query_tokens: Number of query tokens.
vision_pooling: Vision pooling strategy.
use_vision_pos_embed: Whether to use vision pos embed.
apply_patch_attention_mask: Whether to apply patch attention masks to vision features.
image_token_id: Token ID for image content tokens.
            label_ignore_index: Index ignored when computing the loss; the
                default matches the pad token id (100257).
**kwargs: Passed to PretrainedConfig.
"""
# Drop None auto_map keys to keep JSON serialization with sort_keys stable.
auto_map = kwargs.get("auto_map")
if isinstance(auto_map, dict):
kwargs["auto_map"] = {
k: v for k, v in auto_map.items() if k is not None
}
super().__init__(**kwargs)
        if isinstance(vision_config, ConvNextConfig):
            self.vision_config = vision_config
        else:
            vision_config = dict(vision_config) if vision_config is not None else {}
            # hidden_size is derived from hidden_sizes[-1]; drop any stale value.
            vision_config.pop("hidden_size", None)
            self.vision_config = ConvNextConfig(**vision_config)
        if isinstance(text_config, YasaConfig):
            self.text_config = text_config
        else:
            self.text_config = YasaConfig(**(text_config or {}))
self.num_query_tokens: int = num_query_tokens
self.vision_pooling: str = vision_pooling
self.use_vision_pos_embed: bool = use_vision_pos_embed
self.apply_patch_attention_mask: bool = apply_patch_attention_mask
self.image_token_id: int = image_token_id
self.label_ignore_index: int = label_ignore_index
def to_dict(self) -> Dict[str, Any]:
"""Serialize configuration to a Python dict.
Returns:
Dict[str, Any]: Configuration mapping ready for JSON serialization.
"""
output = super().to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
@property
def num_attention_heads(self) -> int:
"""Return the attention head count from the text config.
Returns:
int: Number of attention heads.
"""
return self.text_config.num_attention_heads
@property
def num_key_value_heads(self) -> int:
"""Return the key/value head count from the text config.
Returns:
int: Number of key/value heads.
"""
return self.text_config.num_key_value_heads
@property
    def num_hidden_layers(self) -> int:
"""Return the number of hidden layers from the text config.
Returns:
int: Number of Transformer layers.
"""
return self.text_config.num_hidden_layers
@property
    def hidden_size(self) -> int:
"""Return the hidden size from the text config.
Returns:
int: Transformer hidden dimension.
"""
return self.text_config.hidden_size
@property
    def vocab_size(self) -> int:
"""Return the vocabulary size from the text config.
Returns:
int: Vocabulary size.
"""
return self.text_config.vocab_size
@property
    def max_position_embeddings(self) -> int:
"""Return the max position embeddings from the text config.
Returns:
int: Maximum sequence length supported.
"""
return self.text_config.max_position_embeddings
@property
    def initializer_range(self) -> float:
"""Return the initializer range from the text config.
Returns:
float: Value used for weight initialization scaling.
"""
return self.text_config.initializer_range
ConvNextConfig.register_for_auto_class()
YasaConfig.register_for_auto_class()
Yasa2Config.register_for_auto_class()
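# End-to-end sketch (hypothetical values; nothing below is executed on import):
#
#     config = Yasa2Config(
#         vision_config=ConvNextConfig.convnextv2_large().to_dict(),
#         text_config={"hidden_size": 2048, "num_hidden_layers": 24},
#         num_query_tokens=64,
#     )
#     config.hidden_size              # 2048, proxied from text_config
#     config.to_dict()["model_type"]  # "yasa2"
#
# The register_for_auto_class() calls above let AutoConfig resolve these classes
# when the repository is loaded with trust_remote_code=True.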