autoprogrammer commited on
Commit
5aa8b74
·
verified ·
1 Parent(s): 827f943

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +6 -13
modeling_densebackward_olmoe0125.py CHANGED
@@ -23,9 +23,9 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
23
  final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
24
  router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
25
  """
26
- def forward(self, hidden_states: torch.Tensor):
27
  """
28
- forward_partscale_fixep_norm_dtch_tau
29
  """
30
  batch_size, seq_length, hidden_dim = hidden_states.shape
31
  dtype = hidden_states.dtype
@@ -37,22 +37,17 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
37
  # Compute routing logic
38
  router_logits = self.gate(flat_hidden).to(dtype=dtype) # (B*L, num_experts)
39
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*L, num_experts)
40
- routing_weights_tau = F.softmax(router_logits / 1.1, dim=1, dtype=torch.float) # (B*L, num_experts)
41
 
42
  # Select top-k experts
43
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
44
- routing_weights_topk_tau, selected_experts_tau = torch.topk(routing_weights_tau, self.top_k, dim=-1)
45
  if self.norm_topk_prob:
46
  norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
47
  # Normalize top-k routing weights
48
  routing_weights_topk = routing_weights_topk / norm_ratio
49
  # Only scale the selected top-k positions in routing_weights
50
- mask = F.one_hot(selected_experts_tau, num_classes=self.num_experts).sum(dim=1).to(dtype)
51
- routing_weights_topk_tau = routing_weights_tau * mask
52
- norm_ratio_dense = routing_weights_topk_tau.sum(dim=-1, keepdim=True)
53
  # ------------------------------------Choose Section-----------------------------------------------
54
  # current --> partscale_fix_expert implementation
55
- routing_weights_tau = routing_weights_tau * (1.0 - mask) / norm_ratio_dense.detach() + routing_weights_topk_tau / norm_ratio_dense
56
  routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio
57
 
58
  # should be --> the gated implemenation, by comment out the line above and uncomment the two lines below
@@ -63,8 +58,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
63
  routing_weights_topk = routing_weights_topk.to(dtype=dtype)
64
 
65
  # Convert full routing_weights to consistent dtype for dense accumulation
66
- # routing_weights = routing_weights.to(dtype=dtype)
67
- routing_weights_tau = routing_weights_tau.to(dtype=dtype)
68
 
69
  # Prepare accumulators: one for dense_outputs, one for sparse_outputs
70
  dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
@@ -77,14 +71,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
77
  expert_layer = self.experts[expert_idx]
78
  # Compute current expert output for all tokens
79
  expert_output = expert_layer(flat_hidden).to(dtype=dtype) # (N_tokens, hidden_dim)
80
- activation_mask = (selected_experts_tau == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
81
  if expert_output.requires_grad:
82
  expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
83
  expert_output = expert_output.to(dtype=dtype)
84
  # Dense accumulation: multiply by full routing weight and add
85
- weight_full_tau = routing_weights_tau[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
86
  weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
87
- dense_outputs = dense_outputs + expert_output * (weight_full_tau-weight_full_tau.detach()) + expert_output * weight_full.detach()
88
 
89
  # Sparse accumulation: find tokens where this expert is among top_k
90
  # matches: Boolean mask where selected_experts == expert_idx → shape (N_tokens, top_k)
 
23
  final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
24
  router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
25
  """
26
+ def forward_partscale_fixep_norm_dtch(self, hidden_states: torch.Tensor):
27
  """
28
+ forward_partscale_fixep_norm_dtch
29
  """
30
  batch_size, seq_length, hidden_dim = hidden_states.shape
31
  dtype = hidden_states.dtype
 
37
  # Compute routing logic
38
  router_logits = self.gate(flat_hidden).to(dtype=dtype) # (B*L, num_experts)
39
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*L, num_experts)
 
40
 
41
  # Select top-k experts
42
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
 
43
  if self.norm_topk_prob:
44
  norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
45
  # Normalize top-k routing weights
46
  routing_weights_topk = routing_weights_topk / norm_ratio
47
  # Only scale the selected top-k positions in routing_weights
48
+ mask = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype)
 
 
49
  # ------------------------------------Choose Section-----------------------------------------------
50
  # current --> partscale_fix_expert implementation
 
51
  routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio
52
 
53
  # should be --> the gated implemenation, by comment out the line above and uncomment the two lines below
 
58
  routing_weights_topk = routing_weights_topk.to(dtype=dtype)
59
 
60
  # Convert full routing_weights to consistent dtype for dense accumulation
61
+ routing_weights = routing_weights.to(dtype=dtype)
 
62
 
63
  # Prepare accumulators: one for dense_outputs, one for sparse_outputs
64
  dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
 
71
  expert_layer = self.experts[expert_idx]
72
  # Compute current expert output for all tokens
73
  expert_output = expert_layer(flat_hidden).to(dtype=dtype) # (N_tokens, hidden_dim)
74
+ activation_mask = (selected_experts == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
75
  if expert_output.requires_grad:
76
  expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
77
  expert_output = expert_output.to(dtype=dtype)
78
  # Dense accumulation: multiply by full routing weight and add
 
79
  weight_full = routing_weights[:, expert_idx].unsqueeze(-1) # (N_tokens, 1)
80
+ dense_outputs = dense_outputs + expert_output * weight_full
81
 
82
  # Sparse accumulation: find tokens where this expert is among top_k
83
  # matches: Boolean mask where selected_experts == expert_idx → shape (N_tokens, top_k)