Charlie81 committed on
Commit
353cce5
·
1 Parent(s): c83cd65

dataclass config

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +64 -93
myolmoe/modeling_myolmoe.py CHANGED
@@ -18,102 +18,73 @@ from transformers.utils import logging
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.modeling_rope_utils import rope_config_validation
20
 
 
 
 
 
 
21
class MyOlmoeConfig(PretrainedConfig):
    r"""
    Configuration class for a MyOlmoe model (an OLMoE-style mixture-of-experts
    decoder). The standard arguments mirror [`OlmoeConfig`]; the additions
    below control the "small expert" mechanism.

    Args:
        small_expert_intermediate_ratio (`int`, *optional*, defaults to 64):
            Ratio of the intermediate size of small experts relative to
            regular experts.
            NOTE(review): an earlier docstring described this as a float
            defaulting to 0.5, but the code default is 64 — confirm which
            value is intended.
        small_expert_count (`int`, *optional*, defaults to 64):
            Frequency of small experts - every Nth expert will be small.
        small_expert_sparsity_coef (`float`, *optional*, defaults to 0.1):
            Coefficient for small expert load balancing loss.
        small_expert_strategy (`str`, *optional*, defaults to `"constant"`):
            Strategy for the small-expert count; `"increment"` is the other
            value hinted at in the code.
        max_small_expert_count (`int`, *optional*, defaults to 64):
            Upper bound used when the small-expert count is incremented.
    """

    model_type = "myolmoe"
    # Cached KV state is runtime-only; never serialize it at inference.
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=2048,
        intermediate_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        num_experts_per_tok=8,
        num_experts=64,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        small_expert_intermediate_ratio=64,
        small_expert_count=64,
        small_expert_sparsity_coef=0.1,
        small_expert_strategy="constant",  # or "increment"
        max_small_expert_count=64,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Backward compatibility: older configs without grouped-query
        # attention use one KV head per attention head (plain MHA).
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob

        # Small expert parameters
        self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
        self.small_expert_count = small_expert_count
        self.small_expert_sparsity_coef = small_expert_sparsity_coef
        self.small_expert_strategy = small_expert_strategy
        self.max_small_expert_count = max_small_expert_count

        # Validate the rotary position embedding parameters. Legacy configs
        # used the key "type"; normalize it to "rope_type" before validating.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
118
 
119
  logger = logging.get_logger(__name__)
 
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.modeling_rope_utils import rope_config_validation
20
 
21
+ from dataclasses import dataclass, field
22
+ from typing import Optional, List, Any
23
+ from transformers import PretrainedConfig
24
+
25
@dataclass
class MyOlmoeConfig(PretrainedConfig):
    """
    Dataclass-style configuration for the MyOlmoe model.

    Every field below becomes an ``__init__`` parameter with the given
    default; ``__post_init__`` then forwards the fields to
    ``PretrainedConfig.__init__`` so the Hugging Face serialization and
    attribute machinery keeps working.

    NOTE(review): because ``@dataclass`` generates ``__init__``, unknown keys
    present in a saved config.json (e.g. ``transformers_version``) will raise
    ``TypeError`` when round-tripping through ``from_pretrained`` — confirm
    this trade-off is acceptable.
    """

    # Annotated, so it is also a dataclass field (and an __init__ parameter);
    # the class-level default is what the PretrainedConfig machinery reads.
    model_type: str = "myolmoe"

    # Core model parameters
    vocab_size: int = 50304
    hidden_size: int = 2048
    intermediate_size: int = 1024
    num_hidden_layers: int = 16
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    max_position_embeddings: int = 4096

    # Expert parameters
    num_experts: int = 64
    num_experts_per_tok: int = 2
    num_small_experts: int = 0
    small_expert_count: int = 64
    small_expert_intermediate_ratio: int = 16
    small_expert_intermediate_size: int = 0
    small_expert_sparsity_coef: float = 0.1
    small_expert_strategy: str = "constant"
    max_small_expert_count: int = 64

    # Attention parameters
    attention_bias: bool = False
    attention_dropout: float = 0.0
    clip_qkv: Optional[float] = None

    # Normalization and activation
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-05
    norm_topk_prob: bool = False

    # Router parameters
    router_aux_loss_coef: float = 0.01
    output_router_logits: bool = False

    # Training parameters
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False
    use_cache: bool = True

    # RoPE parameters
    rope_theta: float = 10000.0
    rope_scaling: Optional[dict] = None

    # Token IDs
    pad_token_id: int = 1
    eos_token_id: int = 50279

    # Model architecture
    architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])

    # Plain (unannotated) class attribute: ignored by @dataclass, still seen
    # by the HF generation/serialization machinery. Restored here — it was
    # silently dropped in the dataclass refactor.
    keys_to_ignore_at_inference = ["past_key_values"]

    def __post_init__(self):
        """Finish PretrainedConfig setup after the generated __init__ ran."""
        # Normalize the legacy rope_scaling key "type" to "rope_type" and
        # validate, as the pre-dataclass config did (this step was dropped
        # in the refactor).
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # Forward all public fields to PretrainedConfig. The token ids are
        # passed explicitly, so they MUST be excluded from the catch-all
        # dict: passing them twice raises
        # "TypeError: got multiple values for keyword argument".
        explicit = {"pad_token_id", "eos_token_id"}
        extra = {
            key: value
            for key, value in self.__dict__.items()
            if not key.startswith("_") and key not in explicit
        }
        super().__init__(
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            **extra,
        )
89
 
90
  logger = logging.get_logger(__name__)