autoprogrammer committed on
Commit
996d29a
·
verified ·
1 Parent(s): 599977c

Update modeling_minicpm.py

Browse files
Files changed (1) hide show
  1. modeling_minicpm.py +81 -24
modeling_minicpm.py CHANGED
@@ -314,36 +314,93 @@ class MiniCPMMoE(nn.Module):
314
  )
315
  self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
316
  self.intermediate_size = config.intermediate_size
317
-
318
    def forward(self, hidden_states):
        """Top-k sparse MoE forward pass.

        Routes every token to its ``num_experts_per_tok`` highest-scoring
        experts (per the linear ``self.gate``), mixes the expert outputs with
        the renormalized top-k routing weights, and in training mode attaches
        a load-balancing auxiliary loss.

        Args:
            hidden_states: input activations; internally flattened to
                ``(num_tokens, hidden_size)``, so any leading batch/seq shape
                is accepted and restored on return.

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        orig_dtype = hidden_states.dtype
        # Collapse all leading dims: one row per token.
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        token_num = hidden_states.shape[0]
        # Router logits -> probabilities; softmax in float32 for numerical stability.
        scores = self.gate(hidden_states)
        scores_prob = F.softmax(scores, dim=-1, dtype=torch.float32)
        expert_weights, expert_indices = torch.topk(scores_prob, self.num_experts_per_tok, dim=-1)
        # Renormalize so each token's selected k weights sum to 1.
        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)
        topk_idx_flat = expert_indices.view(-1)
        expert_weights = expert_weights.to(orig_dtype)

        if self.training:
            # Duplicate each token once per selected expert so the k routed
            # copies line up one-to-one with topk_idx_flat.
            hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
            y = torch.empty_like(hidden_states)
            # Each expert processes only the rows routed to it.
            for i in range(self.num_experts):
                y[topk_idx_flat == i] = self.experts[i](hidden_states[topk_idx_flat == i])
            # Weighted sum of the k expert outputs per token, then restore shape.
            y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)

            # Load-balancing auxiliary loss: penalizes correlation between the
            # router's probability mass (importance) and the actual hard
            # assignment counts (load), encouraging uniform expert usage.
            load = expert_indices.view(-1).bincount(minlength=self.num_experts)
            load_mean = load / (token_num * self.num_experts_per_tok)
            importance_mean = scores_prob.mean(dim=0)
            balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)
            # AddAuxiliaryLoss (defined elsewhere in this file) presumably
            # injects balance_loss into the autograd graph without altering
            # the forward value of y — confirm against its definition.
            y = AddAuxiliaryLoss.apply(y, balance_loss)
        else:
            # Inference path: grouped dispatch without duplicating tokens.
            y = self.moe_infer(hidden_states, topk_idx_flat, expert_weights.view(-1, 1)).view(*orig_shape)
        return y
 
 
 
347
 
348
  @torch.no_grad()
349
  def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
 
314
  )
315
  self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
316
  self.intermediate_size = config.intermediate_size
317
+
318
    def forward(self, hidden_states):
        """DenseMixer forward for the MiniCPM MoE layer.

        - Sparse forward value: the returned activations equal the top-k
          sparse mixture (the same value the inference path would compute).
        - Dense backward: via a straight-through estimator the gradient flows
          through a dense mixture over all experts, giving the router a full
          gradient signal.
        - Per-expert gradient hooks zero the gradient for tokens that did not
          select the expert, so expert parameters still only learn from their
          routed tokens.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` activations.

        Returns:
            Tensor of the same shape and dtype as ``hidden_states``; in
            training mode the balance loss is attached via AddAuxiliaryLoss.
        """
        batch_size, seq_length, hidden_dim = hidden_states.shape
        dtype = hidden_states.dtype
        device = hidden_states.device

        flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
        N_tokens = flat_hidden.size(0)

        # Router logits and the full softmax routing distribution
        # (softmax computed in float32 for numerical stability).
        router_logits = self.gate(flat_hidden).to(dtype=dtype)  # (N_tokens, num_experts)
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # (N_tokens, num_experts)

        # Select top-k experts per token.
        routing_weights_topk, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)

        # MiniCPM always normalizes the top-k weights to sum to 1.
        norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
        routing_weights_topk = routing_weights_topk / norm_ratio

        # Rescale the dense distribution so its top-k entries match the
        # normalized sparse weights (keeping dense and sparse mixtures
        # consistent). Non-selected entries are divided by a *detached* norm
        # so they do not feed extra gradient into the normalizer.
        mask = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype)
        routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio

        routing_weights_topk = routing_weights_topk.to(dtype=dtype)
        routing_weights = routing_weights.to(dtype=dtype)

        if self.training:
            # DenseMixer training mode: sparse forward value, dense backward.
            # Accumulators for the dense (gradient) and sparse (value) mixtures.
            dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
            sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)

            for expert_idx in range(self.num_experts):
                expert_layer = self.experts[expert_idx]
                # Dense pass: every expert processes every token (needed so
                # the dense backward path has gradients for all experts).
                expert_output = expert_layer(flat_hidden).to(dtype=dtype)  # (N_tokens, hidden_dim)

                # Gradient hook: zero the gradient for tokens that did not
                # route to this expert, preserving expert sparsity in backward.
                activation_mask = (selected_experts == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
                if expert_output.requires_grad:
                    # Mask bound as a default argument to avoid the
                    # late-binding-closure pitfall inside the loop.
                    expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)

                # Dense accumulation with the full (rescaled) routing weight.
                weight_full = routing_weights[:, expert_idx].unsqueeze(-1)  # (N_tokens, 1)
                dense_outputs = dense_outputs + expert_output * weight_full

                # Sparse accumulation: only tokens whose top-k contains this
                # expert contribute, weighted by the normalized top-k weight.
                matches = (selected_experts == expert_idx)
                if matches.any():
                    token_indices, k_indices = torch.where(matches)
                    weights_topk = routing_weights_topk[token_indices, k_indices].unsqueeze(-1)  # (num_matches, 1)
                    sparse_outputs[token_indices] = sparse_outputs[token_indices] + expert_output[token_indices] * weights_topk

            # Straight-through estimator:
            #   forward value  = sparse_outputs
            #   backward grads = d(dense_outputs)
            final_flat = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
            final_flat = final_flat.to(dtype=dtype)
            final_output = final_flat.view(batch_size, seq_length, hidden_dim)

            # Load-balancing loss: importance (mean routing probability) times
            # load (hard-assignment frequency), summed over experts.
            load = selected_experts.view(-1).bincount(minlength=self.num_experts)
            load_mean = load.float() / (N_tokens * self.num_experts_per_tok)
            importance_mean = F.softmax(router_logits, dim=-1, dtype=torch.float32).mean(dim=0)
            balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)

            # AddAuxiliaryLoss (defined elsewhere in this file) presumably
            # attaches balance_loss to the autograd graph without changing
            # the forward value — confirm against its definition.
            final_output = AddAuxiliaryLoss.apply(final_output, balance_loss)
            return final_output
        else:
            # Inference mode: original sparse dispatch for efficiency.
            topk_idx_flat = selected_experts.view(-1)
            expert_weights = routing_weights_topk
            y = self.moe_infer(flat_hidden, topk_idx_flat, expert_weights.view(-1, 1)).view(batch_size, seq_length, hidden_dim)
            return y
404
 
405
  @torch.no_grad()
406
  def moe_infer(self, x, flat_expert_indices, flat_expert_weights):