Update modeling_minicpm.py
Browse files- modeling_minicpm.py +81 -24
modeling_minicpm.py
CHANGED
|
@@ -314,36 +314,93 @@ class MiniCPMMoE(nn.Module):
|
|
| 314 |
)
|
| 315 |
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
| 316 |
self.intermediate_size = config.intermediate_size
|
| 317 |
-
|
| 318 |
def forward(self, hidden_states):
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
if self.training:
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)
|
| 342 |
|
| 343 |
-
|
|
|
|
| 344 |
else:
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
@torch.no_grad()
|
| 349 |
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
|
|
|
| 314 |
)
|
| 315 |
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
| 316 |
self.intermediate_size = config.intermediate_size
|
| 317 |
+
|
def forward(self, hidden_states):
    """
    DenseMixer-style forward for the MiniCPM MoE layer.

    Sparse forward / dense backward:
    - Forward value comes from the sparse top-k computation (only selected
      experts contribute to the output for each token).
    - Backward gradient flows through a dense accumulation over ALL experts,
      combined with the sparse value via a straight-through estimator, so the
      router (gate) receives gradient for every expert, not just the top-k.
    - A gradient hook on each expert's output zeroes the gradient for tokens
      that did not select that expert, so expert parameters still only update
      from their activated tokens (expert-side sparsity is preserved).

    In eval mode the layer falls back to the efficient sparse `moe_infer` path.

    Args:
        hidden_states: input activations of shape (batch, seq_len, hidden_dim)
            — shape established by the `.shape` unpack and final `.view` below.

    Returns:
        Tensor of shape (batch, seq_len, hidden_dim). In training mode the
        auxiliary load-balance loss is attached via `AddAuxiliaryLoss.apply`
        (a custom autograd fn defined elsewhere in this file — its output is
        value-identical to its first input; TODO confirm against that class).
    """
    batch_size, seq_length, hidden_dim = hidden_states.shape
    dtype = hidden_states.dtype
    device = hidden_states.device

    # Flatten batch and sequence dims: one router decision per token.
    flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
    N_tokens = flat_hidden.size(0)

    # Router logits and full (all-expert) softmax probabilities.
    # Softmax is done in float32 for numerical stability, cast back later.
    router_logits = self.gate(flat_hidden).to(dtype=dtype)  # (N_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # (N_tokens, num_experts)

    # Select the top-k experts per token. topk returns distinct experts,
    # so each token/expert pair appears at most once below.
    routing_weights_topk, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1)

    # MiniCPM normalizes the top-k weights to sum to 1; norm_ratio is the
    # per-token sum of the selected probabilities.
    norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
    # Normalize top-k routing weights (used by the sparse path).
    routing_weights_topk = routing_weights_topk / norm_ratio

    # Rescale the FULL routing matrix by the same normalizer so the dense
    # accumulation matches the sparse one on the selected positions.
    # mask is 1 at (token, expert) positions that are in the top-k, else 0.
    mask = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype)
    # Both selected and non-selected weights are divided by norm_ratio; the
    # difference is gradient flow: the non-selected term uses
    # norm_ratio.detach(), so the normalizer only receives gradient through
    # the selected (top-k) positions. NOTE(review): an earlier comment said
    # non-selected weights "remain unchanged" — the code divides them too;
    # forward values of non-selected experts never reach the output anyway
    # (the STE takes the sparse value), so only the gradient scale matters.
    routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio

    # Cast routing tensors back to the activation dtype.
    routing_weights_topk = routing_weights_topk.to(dtype=dtype)
    routing_weights = routing_weights.to(dtype=dtype)

    if self.training:
        # DenseMixer training mode: sparse forward value, dense backward.
        # Two accumulators: dense_outputs carries the gradient path,
        # sparse_outputs carries the forward value.
        dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
        sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)

        # Every expert runs on every token (dense compute — this is the
        # training-time cost of DenseMixer).
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # Dense forward through this expert for ALL tokens.
            expert_output = expert_layer(flat_hidden).to(dtype=dtype)  # (N_tokens, hidden_dim)

            # Hook: zero the incoming gradient on tokens that did not route
            # to this expert. This masks gradient into the expert's
            # parameters (and into flat_hidden via this expert) but does NOT
            # affect the router gradient, which is d(out)/d(weight) =
            # expert_output and is computed from forward values.
            activation_mask = (selected_experts == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
            # requires_grad guard: hooks cannot be registered on tensors
            # outside the autograd graph (e.g. under no_grad).
            if expert_output.requires_grad:
                # mask is bound as a default arg to avoid the late-binding
                # closure pitfall inside this loop.
                expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)

            # Dense accumulation: full routing weight column for this expert.
            weight_full = routing_weights[:, expert_idx].unsqueeze(-1)  # (N_tokens, 1)
            dense_outputs = dense_outputs + expert_output * weight_full

            # Sparse accumulation: only tokens that selected this expert.
            matches = (selected_experts == expert_idx)
            if matches.any():
                token_indices, k_indices = torch.where(matches)
                weights_topk = routing_weights_topk[token_indices, k_indices].unsqueeze(-1)  # (num_matches, 1)
                # In-place indexed add is safe here: sparse_outputs is
                # detached below, so it never participates in backward.
                sparse_outputs[token_indices] = sparse_outputs[token_indices] + expert_output[token_indices] * weights_topk

        # Straight-through estimator:
        #   forward value  = sparse_outputs
        #   backward grad  = d(dense_outputs)
        final_flat = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
        final_flat = final_flat.to(dtype=dtype)
        final_output = final_flat.view(batch_size, seq_length, hidden_dim)

        # Load-balance auxiliary loss (importance * load, scaled by
        # num_experts): load_mean is the empirical routing frequency,
        # importance_mean the mean softmax probability per expert.
        load = selected_experts.view(-1).bincount(minlength=self.num_experts)
        load_mean = load.float() / (N_tokens * self.num_experts_per_tok)
        importance_mean = F.softmax(router_logits, dim=-1, dtype=torch.float32).mean(dim=0)
        balance_loss = self.num_experts * torch.sum(importance_mean * load_mean)

        # Attach the aux loss to the output's autograd graph without
        # changing its forward value.
        final_output = AddAuxiliaryLoss.apply(final_output, balance_loss)
        return final_output
    else:
        # Inference mode: original sparse dispatch path (no_grad inside
        # moe_infer), for efficiency.
        topk_idx_flat = selected_experts.view(-1)
        expert_weights = routing_weights_topk
        y = self.moe_infer(flat_hidden, topk_idx_flat, expert_weights.view(-1, 1)).view(batch_size, seq_length, hidden_dim)
        return y
| 404 |
|
| 405 |
@torch.no_grad()
|
| 406 |
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|