from typing import Any, Optional

from transformers.models.qwen3 import Qwen3Config


class DockGenConfig(Qwen3Config):
    """Configuration for DockGen: a Qwen3Config extended with protein-embedding
    fields (`prot_embedding_dim`, `mm_token_id`)."""

    model_type = "dockgen"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Tensor-parallel plan, kept identical to the Qwen3 layout.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        prot_embedding_dim: int = 1024,
        mm_token_id: int = 151655,
        vocab_size: int = 151936,
        hidden_size: int = 4096,
        intermediate_size: int = 22016,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        head_dim: int = 128,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        tie_word_embeddings: bool = True,
        rope_theta: float = 10000.0,
        rope_scaling: Optional[dict] = None,
        attention_bias: bool = False,
        use_sliding_window: bool = False,
        sliding_window: int = 4096,
        max_window_layers: int = 28,
        layer_types: Optional[list[str]] = None,
        attention_dropout: float = 0.0,
        **kwargs: Any,
    ):
        # DockGen-specific fields: the width of the incoming protein embeddings and
        # the token id used as the multimodal placeholder in the text sequence.
        self.prot_embedding_dim = prot_embedding_dim
        self.mm_token_id = mm_token_id
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            layer_types=layer_types,
            attention_dropout=attention_dropout,
            **kwargs,
        )

    @classmethod
    def from_qwen3_config(
        cls,
        qwen3_config: Qwen3Config,
        prot_embedding_dim: int = 1024,
        mm_token_id: int = 151655,
        **kwargs: Any,
    ) -> "DockGenConfig":
        """Create a DockGenConfig from an existing Qwen3Config, copying all backbone fields."""
        return cls(
            prot_embedding_dim=prot_embedding_dim,
            mm_token_id=mm_token_id,
            vocab_size=qwen3_config.vocab_size,
            hidden_size=qwen3_config.hidden_size,
            intermediate_size=qwen3_config.intermediate_size,
            num_hidden_layers=qwen3_config.num_hidden_layers,
            num_attention_heads=qwen3_config.num_attention_heads,
            num_key_value_heads=qwen3_config.num_key_value_heads,
            head_dim=qwen3_config.head_dim,
            hidden_act=qwen3_config.hidden_act,
            max_position_embeddings=qwen3_config.max_position_embeddings,
            initializer_range=qwen3_config.initializer_range,
            rms_norm_eps=qwen3_config.rms_norm_eps,
            use_cache=qwen3_config.use_cache,
            tie_word_embeddings=qwen3_config.tie_word_embeddings,
            rope_theta=qwen3_config.rope_theta,
            rope_scaling=qwen3_config.rope_scaling,
            attention_bias=qwen3_config.attention_bias,
            use_sliding_window=qwen3_config.use_sliding_window,
            sliding_window=qwen3_config.sliding_window,
            max_window_layers=qwen3_config.max_window_layers,
            layer_types=qwen3_config.layer_types,
            attention_dropout=qwen3_config.attention_dropout,
            **kwargs,
        )
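

# Minimal usage sketch (illustrative only, not part of the module's API): derive a
# DockGen configuration from a small, hand-built Qwen3 config. The hyperparameter
# values below, including prot_embedding_dim=1280, are placeholders chosen only to
# keep the example light.
if __name__ == "__main__":
    base = Qwen3Config(
        hidden_size=1024,
        intermediate_size=3072,
        num_hidden_layers=4,
        num_attention_heads=16,
        num_key_value_heads=8,
    )
    config = DockGenConfig.from_qwen3_config(base, prot_embedding_dim=1280)
    print(config.model_type, config.hidden_size, config.prot_embedding_dim, config.mm_token_id)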