|
|
from typing import Any, Optional, Union |
|
|
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
from transformers import Qwen3Config |
|
|
|
|
|
|
|
|
class StepRoboticsVisionEncoderConfig(PretrainedConfig):
    """Configuration for the StepRobotics vision encoder.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. How each value is interpreted is decided by
    the model code that consumes this config (not shown here), so the
    per-argument notes below are hedged accordingly.

    Args:
        width: Encoder width (presumably the hidden/embedding dimension).
        layers: Number of encoder layers (presumably transformer blocks).
        heads: Number of attention heads.
        num_channels: Number of input image channels.
        image_size: Input image side length.
        mlp_ratio: MLP expansion ratio. The default is ``8960 / 1536``, i.e.
            presumably an MLP hidden size of 8960 at the default width 1536.
        patch_size: Patch side length.
        hidden_act: Name of the activation function.
        layer_norm_eps: Epsilon used by layer norms.
        ues_cls_token: Whether to use a class token. The spelling is kept for
            compatibility with existing configs and serialized files; the
            correctly spelled ``use_cls_token`` is accepted as an alias both
            as a constructor keyword and as an attribute (see
            ``attribute_map``).
        use_ln_pre: Flag, presumably enabling a layer norm before the encoder.
        use_ln_post: Flag, presumably enabling a layer norm after the encoder.
        use_abs_posemb: Flag, presumably enabling absolute position embeddings.
        use_rope2d: Flag, presumably enabling 2-D rotary position embeddings.
        ls_init_value: Presumably the initial LayerScale value — confirm
            against the model code.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # PretrainedConfig routes attribute reads/writes through ``attribute_map``,
    # so ``config.use_cls_token`` reads/writes ``ues_cls_token``. A
    # ``use_cls_token=...`` keyword ends up in ``**kwargs``; when
    # ``PretrainedConfig.__init__`` assigns leftover kwargs it is therefore
    # stored on ``ues_cls_token`` (overriding the named argument) instead of
    # becoming a separate, unused attribute.
    attribute_map = {"use_cls_token": "ues_cls_token"}

    def __init__(
        self,
        width=1536,
        layers=47,
        heads=16,
        num_channels=3,
        image_size=728,
        mlp_ratio=8960 / 1536,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        ues_cls_token=False,
        use_ln_pre=True,
        use_ln_post=False,
        use_abs_posemb=True,
        use_rope2d=True,
        ls_init_value=0.1,
        **kwargs,
    ):
        self.width = width
        self.layers = layers
        self.heads = heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.mlp_ratio = mlp_ratio
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        # Misspelled name retained: it is the key written to / read from
        # serialized configs.
        self.ues_cls_token = ues_cls_token
        self.use_ln_pre = use_ln_pre
        self.ls_init_value = ls_init_value
        self.use_ln_post = use_ln_post
        self.use_abs_posemb = use_abs_posemb
        self.use_rope2d = use_rope2d

        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
class StepRoboticsConfig(PretrainedConfig):
    """Composite configuration for the StepRobotics vision-language model.

    Bundles a vision-encoder config and a Qwen3 text config together with the
    projector settings and the image placeholder token id.

    Args:
        vision_config: ``None`` (use ``StepRoboticsVisionEncoderConfig``
            defaults), a dict of ``StepRoboticsVisionEncoderConfig``
            constructor kwargs, or a config instance. Any other value is
            stored as-is.
        text_config: ``None`` (use ``Qwen3Config`` defaults), a dict of
            ``Qwen3Config`` constructor kwargs, or a config instance. The
            result must expose ``hidden_size``.
        understand_projector_stride: Stride setting for the vision-to-text
            projector (interpreted by the model code, not shown here).
        projector_bias: Whether the projector uses a bias (presumably on its
            linear layers — confirm against the model code).
        image_token_id: Token id stored as ``image_token_id``; presumably the
            image placeholder token in the text vocabulary.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Note:
        ``hidden_size`` is copied from ``text_config.hidden_size`` at
        construction time and is not kept in sync if ``text_config`` is
        mutated afterwards.
    """

    model_type = "step_robotics"

    # Default for the standard ``architectures`` config field. It only takes
    # effect because ``__init__`` feeds it to ``PretrainedConfig.__init__``.
    architectures = ["StepVLForConditionalGeneration"]

    def __init__(
        self,
        vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Qwen3Config]] = None,
        understand_projector_stride: int = 2,
        projector_bias: bool = False,
        image_token_id: int = 151679,
        **kwargs,
    ) -> None:
        # Normalize the vision sub-config to a config object.
        if vision_config is None:
            vision_config = StepRoboticsVisionEncoderConfig()
        elif isinstance(vision_config, dict):
            vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
        self.vision_config = vision_config

        # Normalize the text sub-config to a config object.
        if text_config is None:
            text_config = Qwen3Config()
        elif isinstance(text_config, dict):
            text_config = Qwen3Config(**text_config)
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Snapshot of the text model width, exposed at the top level.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id

        # PretrainedConfig.__init__ assigns ``self.architectures`` from its
        # ``architectures`` kwarg (None when absent). That instance attribute
        # would shadow the class-level default above. Seed the kwarg with a
        # copy of the default (so the shared class list is never aliased); an
        # explicitly passed value, including None, still takes precedence.
        default_architectures = type(self).architectures
        kwargs.setdefault(
            "architectures",
            list(default_architectures) if default_architectures is not None else None,
        )

        super().__init__(**kwargs)
|
|
|