Commit 1fb9eb6 by Jackmin108 (parent: ad420a7)

use tt moe

Files changed (3)
  1. config.json +5 -0
  2. configuration_glm4_moe.py +2 -2
  3. modeling_glm4_moe.py +29 -15
config.json CHANGED
@@ -2,6 +2,11 @@
   "architectures": [
     "Glm4MoeForCausalLM"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_glm4_moe.Glm4MoeConfig",
+    "AutoModelForCausalLM": "modeling_glm4_moe.Glm4MoeForCausalLM",
+    "AutoModel": "modeling_glm4_moe.Glm4MoeModel"
+  },
   "attention_bias": true,
   "attention_dropout": 0.0,
   "pad_token_id": 151329,
configuration_glm4_moe.py CHANGED
@@ -19,8 +19,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
 
 
 class Glm4MoeConfig(PretrainedConfig):
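
The relative `from ...` imports only resolve while the file lives inside the `transformers` source tree; shipped as remote code next to the checkpoint, the module has to import from the installed package instead. A quick sanity check that the absolute imports used above resolve (both names are taken directly from the patched import lines; nothing here is specific to this repo):

```python
# Verifies only that the installed transformers package exposes the imported names.
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation

print(PretrainedConfig.__module__)       # transformers.configuration_utils
print(callable(rope_config_validation))  # True
```
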
modeling_glm4_moe.py CHANGED
@@ -25,22 +25,24 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.deprecation import deprecate_kwarg
-from ...utils.generic import check_model_inputs
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils.deprecation import deprecate_kwarg
+from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig
 
+from torchtitan.models.moe import MoE, MoEArgs
+
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
@@ -354,8 +356,20 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
 
         self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)
 
+        moe_args = MoEArgs(
+            num_experts=config.n_routed_experts,
+            num_shared_experts=config.n_shared_experts,
+            score_func="sigmoid",
+            route_norm=config.norm_topk_prob,
+            route_scale=config.routed_scaling_factor,
+            score_before_experts=False,
+            top_k=config.num_experts_per_tok,
+            use_grouped_mm=True,
+            load_balance_coeff=1e-3,
+        )
+
         if layer_idx >= config.first_k_dense_replace:
-            self.mlp = Glm4MoeMoE(config)
+            self.mlp = MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
         else:
             self.mlp = Glm4MoeMLP(config)
 
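
With this patch, only the decoder layers at or beyond `first_k_dense_replace` swap their MLP for the torchtitan `MoE`; earlier layers keep the dense `Glm4MoeMLP`. An illustrative sketch of that selection rule with hypothetical config values (the real counts come from `config.json`):

```python
# Hypothetical values for illustration only; the checkpoint's config defines the real ones.
first_k_dense_replace = 1
num_hidden_layers = 4

for layer_idx in range(num_hidden_layers):
    mlp_kind = "torchtitan MoE" if layer_idx >= first_k_dense_replace else "dense Glm4MoeMLP"
    print(f"layer {layer_idx}: {mlp_kind}")
# layer 0: dense Glm4MoeMLP
# layer 1: torchtitan MoE
# layer 2: torchtitan MoE
# layer 3: torchtitan MoE
```
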