Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
|
@@ -23,9 +23,9 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
| 23 |
final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
|
| 24 |
router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
|
| 25 |
"""
|
| 26 |
-
def
|
| 27 |
"""
|
| 28 |
-
|
| 29 |
"""
|
| 30 |
batch_size, seq_length, hidden_dim = hidden_states.shape
|
| 31 |
dtype = hidden_states.dtype
|
|
@@ -37,22 +37,17 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
| 37 |
# Compute routing logic
|
| 38 |
router_logits = self.gate(flat_hidden).to(dtype=dtype) # (B*L, num_experts)
|
| 39 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*L, num_experts)
|
| 40 |
-
routing_weights_tau = F.softmax(router_logits / 1.1, dim=1, dtype=torch.float) # (B*L, num_experts)
|
| 41 |
|
| 42 |
# Select top-k experts
|
| 43 |
routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
| 44 |
-
routing_weights_topk_tau, selected_experts_tau = torch.topk(routing_weights_tau, self.top_k, dim=-1)
|
| 45 |
if self.norm_topk_prob:
|
| 46 |
norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
|
| 47 |
# Normalize top-k routing weights
|
| 48 |
routing_weights_topk = routing_weights_topk / norm_ratio
|
| 49 |
# Only scale the selected top-k positions in routing_weights
|
| 50 |
-
mask = F.one_hot(
|
| 51 |
-
routing_weights_topk_tau = routing_weights_tau * mask
|
| 52 |
-
norm_ratio_dense = routing_weights_topk_tau.sum(dim=-1, keepdim=True)
|
| 53 |
# ------------------------------------Choose Section-----------------------------------------------
|
| 54 |
# current --> partscale_fix_expert implementation
|
| 55 |
-
routing_weights_tau = routing_weights_tau * (1.0 - mask) / norm_ratio_dense.detach() + routing_weights_topk_tau / norm_ratio_dense
|
| 56 |
routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio
|
| 57 |
|
| 58 |
# should be --> the gated implemenation, by comment out the line above and uncomment the two lines below
|
|
@@ -63,8 +58,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
| 63 |
routing_weights_topk = routing_weights_topk.to(dtype=dtype)
|
| 64 |
|
| 65 |
# Convert full routing_weights to consistent dtype for dense accumulation
|
| 66 |
-
|
| 67 |
-
routing_weights_tau = routing_weights_tau.to(dtype=dtype)
|
| 68 |
|
| 69 |
# Prepare accumulators: one for dense_outputs, one for sparse_outputs
|
| 70 |
dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
|
|
@@ -77,14 +71,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
| 77 |
expert_layer = self.experts[expert_idx]
|
| 78 |
# Compute current expert output for all tokens
|
| 79 |
expert_output = expert_layer(flat_hidden).to(dtype=dtype) # (N_tokens, hidden_dim)
|
| 80 |
-
activation_mask = (
|
| 81 |
if expert_output.requires_grad:
|
| 82 |
expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
|
| 83 |
expert_output = expert_output.to(dtype=dtype)
|
| 84 |
# Dense accumulation: multiply by full routing weight and add
|
| 85 |
-
weight_full_tau = routing_weights_tau[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
|
| 86 |
weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
|
| 87 |
-
dense_outputs = dense_outputs + expert_output *
|
| 88 |
|
| 89 |
# Sparse accumulation: find tokens where this expert is among top_k
|
| 90 |
# matches: Boolean mask where selected_experts == expert_idx → shape (N_tokens, top_k)
|
|
|
|
| 23 |
final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
|
| 24 |
router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
|
| 25 |
"""
|
| 26 |
+
def forward_partscale_fixep_norm_dtch(self, hidden_states: torch.Tensor):
|
| 27 |
"""
|
| 28 |
+
forward_partscale_fixep_norm_dtch
|
| 29 |
"""
|
| 30 |
batch_size, seq_length, hidden_dim = hidden_states.shape
|
| 31 |
dtype = hidden_states.dtype
|
|
|
|
| 37 |
# Compute routing logic
|
| 38 |
router_logits = self.gate(flat_hidden).to(dtype=dtype) # (B*L, num_experts)
|
| 39 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*L, num_experts)
|
|
|
|
| 40 |
|
| 41 |
# Select top-k experts
|
| 42 |
routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
|
|
|
| 43 |
if self.norm_topk_prob:
|
| 44 |
norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
|
| 45 |
# Normalize top-k routing weights
|
| 46 |
routing_weights_topk = routing_weights_topk / norm_ratio
|
| 47 |
# Only scale the selected top-k positions in routing_weights
|
| 48 |
+
mask = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype)
|
|
|
|
|
|
|
| 49 |
# ------------------------------------Choose Section-----------------------------------------------
|
| 50 |
# current --> partscale_fix_expert implementation
|
|
|
|
| 51 |
routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio
|
| 52 |
|
| 53 |
# should be --> the gated implemenation, by comment out the line above and uncomment the two lines below
|
|
|
|
| 58 |
routing_weights_topk = routing_weights_topk.to(dtype=dtype)
|
| 59 |
|
| 60 |
# Convert full routing_weights to consistent dtype for dense accumulation
|
| 61 |
+
routing_weights = routing_weights.to(dtype=dtype)
|
|
|
|
| 62 |
|
| 63 |
# Prepare accumulators: one for dense_outputs, one for sparse_outputs
|
| 64 |
dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
|
|
|
|
| 71 |
expert_layer = self.experts[expert_idx]
|
| 72 |
# Compute current expert output for all tokens
|
| 73 |
expert_output = expert_layer(flat_hidden).to(dtype=dtype) # (N_tokens, hidden_dim)
|
| 74 |
+
activation_mask = (selected_experts == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
|
| 75 |
if expert_output.requires_grad:
|
| 76 |
expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
|
| 77 |
expert_output = expert_output.to(dtype=dtype)
|
| 78 |
# Dense accumulation: multiply by full routing weight and add
|
|
|
|
| 79 |
weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
|
| 80 |
+
dense_outputs = dense_outputs + expert_output * weight_full
|
| 81 |
|
| 82 |
# Sparse accumulation: find tokens where this expert is among top_k
|
| 83 |
# matches: Boolean mask where selected_experts == expert_idx → shape (N_tokens, top_k)
|