katuni4ka committed on
Commit
d8419e7
·
verified ·
1 Parent(s): 91af5d5

Update modeling_mistral4.py

Browse files
Files changed (1) hide show
  1. modeling_mistral4.py +8 -4
modeling_mistral4.py CHANGED
@@ -116,9 +116,12 @@ class Mistral4MoE(nn.Module):
116
  self.config = config
117
  self.experts = Mistral4NaiveMoe(config)
118
  self.gate = Mistral4TopkRouter(config)
119
- self.shared_experts = Mistral4MLP(
120
- config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
121
- )
 
 
 
122
  self.n_routed_experts = config.n_routed_experts
123
  self.n_group = config.n_group
124
  self.topk_group = config.topk_group
@@ -155,7 +158,8 @@ class Mistral4MoE(nn.Module):
155
  topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
156
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
157
  hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
158
- hidden_states = hidden_states + self.shared_experts(residuals)
 
159
  return hidden_states
160
 
161
 
 
116
  self.config = config
117
  self.experts = Mistral4NaiveMoe(config)
118
  self.gate = Mistral4TopkRouter(config)
119
+ if config.n_shared_experts > 0:
120
+ self.shared_experts = Mistral4MLP(
121
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
122
+ )
123
+ else:
124
+ self.shared_experts = None
125
  self.n_routed_experts = config.n_routed_experts
126
  self.n_group = config.n_group
127
  self.topk_group = config.topk_group
 
158
  topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
159
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
160
  hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
161
+ if self.shared_experts is not None:
162
+ hidden_states = hidden_states + self.shared_experts(residuals)
163
  return hidden_states
164
 
165