Step-3.7-Flash / configuration_step3p7.py
danielhanchen's picture
Upload folder using huggingface_hub
6c0b16e verified
from typing import Any, Optional, Sequence, Union
from transformers.configuration_utils import PretrainedConfig
class StepRoboticsVisionEncoderConfig(PretrainedConfig):
    """Configuration holder for the vision encoder.

    Each constructor argument is stored as an attribute of the same name, and
    any remaining keyword arguments are forwarded to ``PretrainedConfig``.

    ``ues_cls_token`` is a misspelled alias of ``use_cls_token``. When
    ``use_cls_token`` is left as ``None`` the alias value is used instead, and
    the resolved flag is stored under *both* attribute names.
    """

    model_type = "perception_encoder"

    def __init__(
        self,
        width=1536,
        layers=47,
        heads=16,
        num_channels=3,
        image_size=728,
        mlp_ratio=8960 / 1536,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        ues_cls_token=False,
        use_cls_token: Optional[bool] = None,
        use_ln_pre=True,
        use_ln_post=False,
        use_abs_posemb=True,
        use_rope2d=True,
        ls_init_value=0.1,
        **kwargs,
    ):
        # An explicit ``use_cls_token`` wins; otherwise fall back to the
        # misspelled alias.
        cls_token_flag = (
            ues_cls_token if use_cls_token is None else use_cls_token
        )

        # Encoder geometry.
        self.width = width
        self.layers = layers
        self.heads = heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.mlp_ratio = mlp_ratio

        # Normalisation / activation settings.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

        # Both spellings are kept so that readers of either name agree.
        self.ues_cls_token = cls_token_flag
        self.use_cls_token = cls_token_flag

        # Remaining feature switches, stored verbatim.
        self.use_ln_pre = use_ln_pre
        self.ls_init_value = ls_init_value
        self.use_ln_post = use_ln_post
        self.use_abs_posemb = use_abs_posemb
        self.use_rope2d = use_rope2d

        super().__init__(**kwargs)
class Step3p7TextConfig(PretrainedConfig):
    """Configuration holder for the text (language model) part.

    Constructor arguments are stored as attributes of the same name, with
    these exceptions visible in the initializer:

    * ``share_expert_dims`` is only a fallback for ``share_expert_dim``; just
      ``share_expert_dim`` is stored.
    * ``rope_scaling`` is shallow-copied when it is a ``dict`` so the stored
      value does not alias the caller's object.
    * ``layer_types`` is padded/trimmed to ``num_hidden_layers`` entries while
      the base-class initializer runs, and the raw argument is put back
      afterwards.
    * ``torch_dtype`` (if present in ``kwargs``) is re-assigned after the
      base-class initializer and re-emitted by :meth:`to_dict`.
    """

    model_type = "step3p5"
    architectures = ["Step3p5ForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 4096,
        intermediate_size: int = 11264,
        num_attention_heads: int = 64,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 45,
        max_seq_len: int = 128000,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 1280,
        moe_num_experts: int = 288,
        moe_top_k: int = 8,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 128000,
        share_expert_dims: int = 1280,
        share_expert_dim: Optional[int] = None,
        head_dim: int = 128,
        norm_expert_weight: bool = True,
        layer_types: list[str] = None,
        sliding_window: Optional[int] = None,
        pad_token_id: int = 1,
        attention_dropout: float = 0.0,
        use_head_wise_attn_gate: bool = False,
        use_moe_router_bias: bool = False,
        moe_router_activation: str = "softmax",
        moe_router_scaling_factor: float = 1.0,
        need_fp32_gate: bool = False,
        attention_other_setting: Optional[dict[str, Any]] = None,
        swiglu_limits: Optional[list[Optional[float]]] = None,
        swiglu_limits_shared: Optional[list[Optional[float]]] = None,
        use_rope_layers: Optional[list[bool]] = None,
        yarn_only_types: Optional[list[str]] = None,
        moe_layers_enum: tuple[int] = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                       15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
                                       25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
                                       35, 36, 37, 38, 39, 40, 41, 42, 43, 44),
        **kwargs,
    ) -> None:
        # Remember the dtype as the caller supplied it; it stays in ``kwargs``
        # and is therefore also seen by the base-class initializer.
        requested_dtype = kwargs.get("torch_dtype")

        # One entry per decoder layer: short lists are padded with their last
        # value, long lists are cut to ``num_hidden_layers``.
        decoder_layer_types = _normalize_per_layer_values(
            layer_types, num_hidden_layers
        )

        # Shallow copy so the stored dict is not the caller's object.
        if isinstance(rope_scaling, dict):
            rope_scaling = dict(rope_scaling)

        # ``share_expert_dims`` is only consulted when the singular form is
        # not given.
        resolved_share_expert_dim = (
            share_expert_dims if share_expert_dim is None else share_expert_dim
        )

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k

        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embeddings = max_position_embeddings

        self.share_expert_dim = resolved_share_expert_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum

        # The normalised list is what is visible while the base-class
        # initializer runs below.
        self.layer_types = decoder_layer_types
        self.sliding_window = sliding_window
        self.pad_token_id = pad_token_id
        self.attention_dropout = attention_dropout

        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        self.use_moe_router_bias = use_moe_router_bias
        self.moe_router_activation = moe_router_activation
        self.moe_router_scaling_factor = moe_router_scaling_factor
        self.need_fp32_gate = need_fp32_gate

        self.attention_other_setting = attention_other_setting
        self.swiglu_limits = swiglu_limits
        self.swiglu_limits_shared = swiglu_limits_shared
        self.use_rope_layers = use_rope_layers
        self.yarn_only_types = yarn_only_types

        super().__init__(**kwargs)

        # Re-apply the caller's dtype value after the base class has handled
        # ``kwargs``.
        if requested_dtype is not None:
            self.torch_dtype = requested_dtype

        # NOTE(review): the raw, un-normalised ``layer_types`` is restored
        # here, replacing the trimmed list set above. Presumably the trimmed
        # list only exists to satisfy checks in the base initializer - confirm.
        self.layer_types = layer_types

    def to_dict(self):
        """Return the base-class dict, adding ``torch_dtype`` when it is set."""
        serialized = super().to_dict()
        dtype_value = getattr(self, "torch_dtype", None)
        if dtype_value is not None:
            serialized["torch_dtype"] = dtype_value
        return serialized
def _normalize_per_layer_values(
values: Optional[Sequence[Any]],
num_hidden_layers: int,
) -> Optional[list[Any]]:
if values is None:
return None
normalized = list(values)
if not normalized:
return normalized
if len(normalized) < num_hidden_layers:
normalized.extend([normalized[-1]] *
(num_hidden_layers - len(normalized)))
# Some checkpoints keep MTP/spec layer entries after the decoder layers.
# This config only builds num_hidden_layers decoder layers, and HF strict
# validation requires per-layer fields to match that decoder count.
return normalized[:num_hidden_layers]
class Step3p7Config(PretrainedConfig):
    """Top-level config combining a vision config and a text config.

    A top-level ``rope_scaling`` keyword, when given, is handed to the text
    config unless the text config already carries its own value. The
    ``hidden_size`` and ``max_position_embeddings`` attributes mirror the
    text config.
    """

    # This loader is a compatibility shim for original Step VL checkpoints
    # whose top-level config model_type is `step3p7`.
    model_type = "step3p7"

    def __init__(
        self,
        vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3p7TextConfig]] = None,
        understand_projector_stride: int = 2,
        projector_bias: bool = False,
        image_token_id: int = 151679,
        **kwargs,
    ) -> None:
        # Private copy of the top-level rope settings (when a dict) so that
        # sharing them with the text config never aliases the caller's dict.
        inherited_rope = kwargs.get("rope_scaling")
        if isinstance(inherited_rope, dict):
            inherited_rope = dict(inherited_rope)

        # Vision side: accept a ready config, a plain dict, or nothing.
        if isinstance(vision_config, dict):
            vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = StepRoboticsVisionEncoderConfig()
        self.vision_config = vision_config

        # Text side: same three cases, each inheriting the top-level rope
        # settings only if it does not define its own.
        if text_config is None:
            text_config = Step3p7TextConfig(rope_scaling=inherited_rope)
        elif isinstance(text_config, dict):
            text_kwargs = dict(text_config)
            if inherited_rope is not None:
                # An existing "rope_scaling" key (even if None) is respected.
                text_kwargs.setdefault("rope_scaling", inherited_rope)
            text_config = Step3p7TextConfig(**text_kwargs)
        elif inherited_rope is not None and text_config.rope_scaling is None:
            text_config.rope_scaling = dict(inherited_rope)
        self.text_config = text_config

        # The base class receives its own copy of the rope settings.
        caller_rope = kwargs.get("rope_scaling")
        if isinstance(caller_rope, dict):
            kwargs["rope_scaling"] = dict(caller_rope)

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirrored from the text config for top-level access.
        self.hidden_size = text_config.hidden_size
        self.max_position_embeddings = text_config.max_position_embeddings
        self.image_token_id = image_token_id

        super().__init__(**kwargs)