drbh committed on
Commit
cf66620
·
1 Parent(s): 733f7f4

feat: align outputs and support backward method

Browse files
compare_example.py CHANGED
@@ -130,6 +130,10 @@ model.down_proj_bias.data = down_proj_bias
130
  model = model.cuda()
131
  model.eval()
132
 
 
 
 
 
133
  torch.cuda.synchronize()
134
  torch.cuda.reset_peak_memory_stats()
135
  start = time.perf_counter()
@@ -151,6 +155,19 @@ ref_output_reshaped = ref_output.view(kernel_output.shape)
151
  # Test yamoe_ref implementation
152
  expert_capacity = batch_seq * top_k // num_experts * 2 # Generous capacity
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  torch.cuda.synchronize()
155
  torch.cuda.reset_peak_memory_stats()
156
  start = time.perf_counter()
 
130
  model = model.cuda()
131
  model.eval()
132
 
133
+ # Warmup
134
+ for _ in range(5):
135
+ _ = model(hidden_states, router_indices, routing_weights)
136
+
137
  torch.cuda.synchronize()
138
  torch.cuda.reset_peak_memory_stats()
139
  start = time.perf_counter()
 
155
  # Test yamoe_ref implementation
156
  expert_capacity = batch_seq * top_k // num_experts * 2 # Generous capacity
157
 
158
+ # Warmup
159
+ for _ in range(5):
160
+ _ = binned_experts_ref(
161
+ hidden_states,
162
+ router_indices,
163
+ routing_weights,
164
+ gate_up_proj,
165
+ gate_up_proj_bias,
166
+ down_proj,
167
+ down_proj_bias,
168
+ expert_capacity,
169
+ )
170
+
171
  torch.cuda.synchronize()
172
  torch.cuda.reset_peak_memory_stats()
173
  start = time.perf_counter()
gpt_oss_backward.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = "==3.10"
3
+ # dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
4
+ # [tool.uv.sources]
5
+ # kernels = { git = "https://github.com/huggingface/kernels.git" }
6
+ # ///
7
+
8
+ import torch
9
+ import sys
10
+ import time
11
+ from kernels import get_kernel, get_local_kernel
12
+ from pathlib import Path
13
+
14
+ load_method = 2 # 1: sym, 2: local, 3: hf
15
+
16
+ if load_method == 1:
17
+ sys.path.insert(0, "./torch-ext")
18
+ import yamoe
19
+ elif load_method == 2:
20
+ yamoe = get_local_kernel(Path("result"), "yamoe")
21
+ elif load_method == 3:
22
+ yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
23
+
24
+ torch.manual_seed(42)
25
+
26
+
27
+ def zero_grads(m):
28
+ for p in m.parameters():
29
+ if p.grad is not None:
30
+ p.grad = None
31
+
32
+
33
+ def benchmark_backward(model, x, tag: str, iters: int = 10, warmup: int = 10):
34
+ x_local = x.detach().clone().requires_grad_(True)
35
+ x_local.retain_grad()
36
+
37
+ # Warmup
38
+ for _ in range(warmup):
39
+ out = model(x_local)
40
+ out = out[0] if isinstance(out, tuple) else out
41
+ loss = out.mean()
42
+ zero_grads(model)
43
+ if x_local.grad is not None:
44
+ x_local.grad = None
45
+ loss.backward()
46
+
47
+ # Benchmark
48
+ torch.cuda.reset_peak_memory_stats()
49
+ torch.cuda.synchronize()
50
+ start = time.perf_counter()
51
+ for _ in range(iters):
52
+ out = model(x_local)
53
+ out = out[0] if isinstance(out, tuple) else out
54
+ loss = out.mean()
55
+ zero_grads(model)
56
+ if x_local.grad is not None:
57
+ x_local.grad = None
58
+ loss.backward()
59
+ torch.cuda.synchronize()
60
+ bwd_ms = (time.perf_counter() - start) * 1e3 / iters
61
+ peak_mem = torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB
62
+
63
+ print(f"[{tag}] backward: {bwd_ms:.2f} ms | peak mem: {peak_mem:.2f} GB")
64
+ return bwd_ms
65
+
66
+
67
+ def print_grad_norms(m, x, tag):
68
+ print(f"\n[{tag}] Gradient norms:")
69
+ xg = x.grad.norm().item() if x.grad is not None else 0.0
70
+ print(f" input grad: {xg:.6f}")
71
+ for name, p in m.named_parameters():
72
+ if p.grad is None:
73
+ print(f" {name}: None")
74
+ else:
75
+ print(f" {name}: {p.grad.norm().item():.6f}")
76
+
77
+
78
+ def main():
79
+ ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
80
+ new_moe_cls = yamoe.Yamoe
81
+
82
+ batch_size, seq_len, hidden_dim = 4, 1024, 2880
83
+ num_experts, top_k = 8, 2
84
+
85
+ config = type("Config", (), {})()
86
+ config.hidden_size = hidden_dim
87
+ config.intermediate_size = hidden_dim
88
+ config.num_local_experts = num_experts
89
+ config.num_experts_per_tok = top_k
90
+ ref_moe = ref_moe_cls(config)
91
+
92
+ print(ref_moe)
93
+
94
+ for p in ref_moe.parameters():
95
+ if p.dim() > 1:
96
+ torch.nn.init.xavier_uniform_(p)
97
+ else:
98
+ torch.nn.init.zeros_(p)
99
+
100
+ x = torch.randn(batch_size, seq_len, hidden_dim, device="cuda")
101
+ ref_moe = ref_moe.cuda()
102
+
103
+ # Test reference implementation backward
104
+ print("\nReference Implementation Backward")
105
+
106
+ # Small warmup
107
+ print(" Warming up...")
108
+ x_warmup = x.detach().requires_grad_(True)
109
+ for _ in range(3):
110
+ out = ref_moe(x_warmup)
111
+ out = out[0] if isinstance(out, tuple) else out
112
+ loss = out.mean()
113
+ zero_grads(ref_moe)
114
+ if x_warmup.grad is not None:
115
+ x_warmup.grad = None
116
+ loss.backward()
117
+ torch.cuda.synchronize()
118
+
119
+ # Run once to get gradient info
120
+ x_ref = x.detach().requires_grad_(True)
121
+ x_ref.retain_grad()
122
+ ref_output = ref_moe(x_ref)
123
+ out = ref_output[0] if isinstance(ref_output, tuple) else ref_output
124
+ print(f" Input shape: {x_ref.shape}")
125
+ print(f" Output shape: {out.shape}")
126
+ print(
127
+ f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
128
+ )
129
+
130
+ loss = out.mean()
131
+ zero_grads(ref_moe)
132
+ loss.backward()
133
+ print_grad_norms(ref_moe, x_ref, "reference")
134
+
135
+ benchmark_backward(ref_moe, x, tag="reference", warmup=10, iters=20)
136
+
137
+ # Switch to YAMOE-backed backward
138
+ print("\nYAMOE-backed Implementation Backward")
139
+ ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
140
+ ref_moe._routing_weights_buffer = None
141
+ ref_moe._batch_indices_buffer = None
142
+ ref_moe._last_batch_seq = None
143
+ ref_moe._last_num_experts = None
144
+ ref_moe.enable_router_grads = True
145
+ ref_moe.num_experts = num_experts
146
+ ref_moe.top_k = top_k
147
+
148
+ # Small warmup
149
+ print(" Warming up...")
150
+ x_warmup = x.detach().requires_grad_(True)
151
+ for _ in range(3):
152
+ out = ref_moe(x_warmup)
153
+ out = out[0] if isinstance(out, tuple) else out
154
+ loss = out.mean()
155
+ zero_grads(ref_moe)
156
+ if x_warmup.grad is not None:
157
+ x_warmup.grad = None
158
+ loss.backward()
159
+ torch.cuda.synchronize()
160
+
161
+ # Run once to get gradient info
162
+ x_cuda = x.detach().requires_grad_(True)
163
+ x_cuda.retain_grad()
164
+ cuda_output = ref_moe(x_cuda)
165
+ out = cuda_output[0] if isinstance(cuda_output, tuple) else cuda_output
166
+ print(f" Input shape: {x_cuda.shape}")
167
+ print(f" Output shape: {out.shape}")
168
+ print(
169
+ f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
170
+ )
171
+
172
+ loss = out.mean()
173
+ zero_grads(ref_moe)
174
+ loss.backward()
175
+ print_grad_norms(ref_moe, x_cuda, "yamoe-backed")
176
+
177
+ benchmark_backward(ref_moe, x, tag="yamoe-backed", warmup=10, iters=20)
178
+
179
+
180
+ if __name__ == "__main__":
181
+ main()
gpt_oss_match.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = "==3.10"
3
+ # dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
4
+ # [tool.uv.sources]
5
+ # kernels = { git = "https://github.com/huggingface/kernels.git" }
6
+ # ///
7
+
8
+ import torch
9
+ import sys
10
+ import time
11
+ from kernels import get_kernel, get_local_kernel
12
+ from pathlib import Path
13
+
14
+ load_method = 2 # 1: sym, 2: local, 3: hf
15
+
16
+ if load_method == 1:
17
+ sys.path.insert(0, "./torch-ext")
18
+ import yamoe
19
+ elif load_method == 2:
20
+ yamoe = get_local_kernel(Path("result"), "yamoe")
21
+ elif load_method == 3:
22
+ yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
23
+
24
+ torch.manual_seed(42)
25
+
26
+
27
+ def benchmark_forward(model, x, tag: str, iters: int = 10, warmup: int = 10):
28
+ x_local = x.detach().clone().requires_grad_(False)
29
+
30
+ for _ in range(warmup):
31
+ out = model(x_local)
32
+ out = out[0] if isinstance(out, tuple) else out
33
+
34
+ torch.cuda.reset_peak_memory_stats()
35
+ torch.cuda.synchronize()
36
+ start = time.perf_counter()
37
+ for _ in range(iters):
38
+ out = model(x_local)
39
+ out = out[0] if isinstance(out, tuple) else out
40
+ torch.cuda.synchronize()
41
+ fwd_ms = (time.perf_counter() - start) * 1e3 / iters
42
+ peak_mem = torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB
43
+
44
+ print(f"[{tag}] fwd: {fwd_ms:.2f} ms | peak mem: {peak_mem:.2f} GB")
45
+ return fwd_ms
46
+
47
+
48
+ def main():
49
+ ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
50
+ new_moe_cls = yamoe.Yamoe
51
+
52
+ batch_size, seq_len, hidden_dim = 4, 1024, 2880
53
+ num_experts, top_k = 8, 2
54
+
55
+ config = type("Config", (), {})()
56
+ config.hidden_size = hidden_dim
57
+ config.intermediate_size = hidden_dim
58
+ config.num_local_experts = num_experts
59
+ config.num_experts_per_tok = top_k
60
+ ref_moe = ref_moe_cls(config)
61
+
62
+ print(ref_moe)
63
+
64
+ for p in ref_moe.parameters():
65
+ if p.dim() > 1:
66
+ torch.nn.init.xavier_uniform_(p)
67
+ else:
68
+ torch.nn.init.zeros_(p)
69
+
70
+ x = torch.randn(batch_size, seq_len, hidden_dim, device="cuda")
71
+ ref_moe = ref_moe.cuda()
72
+ ref_moe = ref_moe.eval()
73
+
74
+ # Test reference implementation
75
+ print("\nReference Implementation")
76
+
77
+ # Small warmup
78
+ print(" Warming up...")
79
+ for _ in range(3):
80
+ _ = ref_moe(x)
81
+ torch.cuda.synchronize()
82
+
83
+ x_ref = x.detach().requires_grad_(False)
84
+ ref_output = ref_moe(x_ref)
85
+ out = ref_output[0] if isinstance(ref_output, tuple) else ref_output
86
+ print(f" Input shape: {x_ref.shape}")
87
+ print(f" Output shape: {out.shape}")
88
+ print(
89
+ f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
90
+ )
91
+
92
+ benchmark_forward(ref_moe, x, tag="reference", warmup=10, iters=20)
93
+
94
+ # Switch to YAMOE-backed forward
95
+ print("\nYAMOE-backed Implementation")
96
+ ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
97
+ ref_moe._routing_weights_buffer = None
98
+ ref_moe._batch_indices_buffer = None
99
+ ref_moe._last_batch_seq = None
100
+ ref_moe._last_num_experts = None
101
+ ref_moe.enable_router_grads = False
102
+ ref_moe.num_experts = num_experts
103
+ ref_moe.top_k = top_k
104
+
105
+ # Small warmup
106
+ print(" Warming up...")
107
+ for _ in range(3):
108
+ _ = ref_moe(x)
109
+ torch.cuda.synchronize()
110
+
111
+ x_cuda = x.detach().requires_grad_(False)
112
+ cuda_output = ref_moe(x_cuda)
113
+ out = cuda_output[0] if isinstance(cuda_output, tuple) else cuda_output
114
+ print(f" Input shape: {x_cuda.shape}")
115
+ print(f" Output shape: {out.shape}")
116
+ print(
117
+ f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
118
+ )
119
+
120
+ benchmark_forward(ref_moe, x, tag="yamoe-backed", warmup=10, iters=20)
121
+
122
+
123
+ if __name__ == "__main__":
124
+ main()
torch-ext/yamoe/layers.py CHANGED
@@ -1,95 +1,190 @@
1
  import torch
 
2
  from ._ops import ops
3
 
4
 
5
- class Yamoe(torch.nn.Module):
6
- """Yamoe MoE layer with routing and expert computation"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- can_torch_compile: bool = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def __init__(self):
11
- super().__init__()
12
- # Pre-allocate buffers to avoid repeated allocations
13
- self._routing_weights_buffer = None
14
- self._batch_indices_buffer = None
15
- self._last_batch_seq = None
16
- self._last_num_experts = None
 
 
17
 
18
  def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
19
  batch_size, seq_len, hidden_dim = hidden_states.shape
20
  batch_seq = batch_size * seq_len
21
 
22
  num_experts = getattr(self, "num_experts", 128)
23
  top_k = getattr(self, "top_k", 4)
24
 
 
 
 
 
 
 
 
25
  # Route tokens to experts
26
  x_flat = hidden_states.view(-1, hidden_dim)
27
  logits = torch.nn.functional.linear(
28
  x_flat, self.router.weight, self.router.bias
29
  )
30
 
 
 
 
 
 
31
  # Compute top-k
32
  if top_k == 1:
33
- routing_weights, router_indices = logits.max(dim=-1, keepdim=True)
34
  else:
35
- routing_weights, router_indices = torch.topk(logits, top_k, dim=-1)
36
 
37
- routing_weights = routing_weights.softmax(dim=-1)
 
 
 
38
 
39
  # Create router scores
40
- router_scores = (
41
- torch.zeros_like(logits)
42
- .scatter_(1, router_indices, routing_weights)
43
- .transpose(0, 1)
44
  )
45
 
46
- # Convert routing_weights to sparse format [batch_seq, num_experts]
47
- # Reuse buffer if possible to reduce allocations
48
- if (
49
- self._routing_weights_buffer is None
50
- or self._last_batch_seq != batch_seq
51
- or self._last_num_experts != num_experts
52
- or self._routing_weights_buffer.device != routing_weights.device
53
- ):
54
- self._routing_weights_buffer = torch.zeros(
55
- batch_seq,
56
- num_experts,
57
- device=routing_weights.device,
58
- dtype=routing_weights.dtype,
59
- )
60
- self._batch_indices_buffer = (
61
- torch.arange(batch_seq, device=routing_weights.device)
62
- .unsqueeze(1)
63
- .expand(-1, top_k)
64
- )
65
- self._last_batch_seq = batch_seq
66
- self._last_num_experts = num_experts
67
- else:
68
- self._routing_weights_buffer.zero_()
69
 
70
- # Fill sparse routing weights
71
- flat_indices = router_indices.view(batch_seq, top_k)
72
- flat_weights = routing_weights.view(batch_seq, top_k)
73
- self._routing_weights_buffer[self._batch_indices_buffer, flat_indices] = (
74
- flat_weights
75
- )
76
-
77
- # FIX: Use the correct expert projections
78
- gate_up = self.experts.gate_up_proj[:, :, : hidden_dim * top_k].contiguous()
79
- gate_up_bias = self.experts.gate_up_proj_bias[
80
- :, : hidden_dim * top_k
81
- ].contiguous()
82
 
 
 
 
83
  down_proj = self.experts.down_proj[:, :hidden_dim, :].contiguous()
84
-
85
  expert_capacity = batch_seq * top_k // num_experts * 2
86
 
87
- with torch.no_grad():
88
- # Compute expert output
89
- output = ops.experts(
 
 
 
 
90
  hidden_states.view(-1, hidden_dim),
91
  router_indices,
92
- self._routing_weights_buffer,
93
  gate_up,
94
  gate_up_bias,
95
  down_proj,
@@ -97,8 +192,39 @@ class Yamoe(torch.nn.Module):
97
  expert_capacity,
98
  num_experts,
99
  top_k,
 
100
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  # Reshape output back to [B, S, H]
103
  output = output.view(batch_size, seq_len, hidden_dim)
 
 
 
 
 
 
 
 
 
 
 
104
  return output, router_scores
 
1
  import torch
2
+ import time
3
  from ._ops import ops
4
 
5
 
6
+ class _ExpertsFn(torch.autograd.Function):
7
+ @staticmethod
8
+ def forward(
9
+ ctx,
10
+ hidden_flat: torch.Tensor,
11
+ router_indices: torch.Tensor,
12
+ routing_weights_dense: torch.Tensor,
13
+ gate_up_proj: torch.Tensor,
14
+ gate_up_proj_bias: torch.Tensor,
15
+ down_proj: torch.Tensor,
16
+ down_proj_bias: torch.Tensor,
17
+ expert_capacity: int,
18
+ num_experts: int,
19
+ top_k: int,
20
+ enable_router_grads: bool,
21
+ ):
22
+ out = ops.experts(
23
+ hidden_flat,
24
+ router_indices,
25
+ routing_weights_dense,
26
+ gate_up_proj,
27
+ gate_up_proj_bias,
28
+ down_proj,
29
+ down_proj_bias,
30
+ expert_capacity,
31
+ num_experts,
32
+ top_k,
33
+ )
34
+ ctx.expert_capacity = expert_capacity
35
+ ctx.num_experts = num_experts
36
+ ctx.top_k = top_k
37
+ ctx.enable_router_grads = bool(enable_router_grads)
38
+ ctx.save_for_backward(
39
+ hidden_flat,
40
+ router_indices,
41
+ routing_weights_dense,
42
+ gate_up_proj,
43
+ gate_up_proj_bias,
44
+ down_proj,
45
+ down_proj_bias,
46
+ )
47
+ return out
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_output: torch.Tensor):
51
+ (
52
+ hidden_flat,
53
+ router_indices,
54
+ routing_weights_dense,
55
+ gate_up_proj,
56
+ gate_up_proj_bias,
57
+ down_proj,
58
+ down_proj_bias,
59
+ ) = ctx.saved_tensors
60
+
61
+ (
62
+ grad_hidden_flat,
63
+ grad_routing_weights,
64
+ grad_gate_up_proj,
65
+ grad_gate_up_proj_bias,
66
+ grad_down_proj,
67
+ grad_down_proj_bias,
68
+ ) = ops.experts_backward(
69
+ grad_output,
70
+ hidden_flat,
71
+ router_indices,
72
+ routing_weights_dense,
73
+ gate_up_proj,
74
+ gate_up_proj_bias,
75
+ down_proj,
76
+ down_proj_bias,
77
+ ctx.expert_capacity,
78
+ ctx.num_experts,
79
+ ctx.top_k,
80
+ )
81
 
82
+ # Return grad for dense routing; autograd handles scatter->softmax->linear
83
+ grad_routing_weights_dense = (
84
+ grad_routing_weights if ctx.enable_router_grads else None
85
+ )
86
+
87
+ # Return gradients per input (None for non-differentiable)
88
+ return (
89
+ grad_hidden_flat, # hidden_flat
90
+ None, # router_indices
91
+ grad_routing_weights_dense, # routing_weights_dense
92
+ grad_gate_up_proj,
93
+ grad_gate_up_proj_bias,
94
+ grad_down_proj,
95
+ grad_down_proj_bias,
96
+ None, # expert_capacity
97
+ None, # num_experts
98
+ None, # top_k
99
+ None, # enable_router_grads
100
+ )
101
 
102
+
103
+ class Yamoe(torch.nn.Module):
104
+ can_torch_compile: bool = False
105
+
106
+ _routing_weights_buffer: torch.Tensor = None
107
+ _batch_indices_buffer: torch.Tensor = None
108
+ _last_batch_seq: int = None
109
+ _last_num_experts: int = None
110
+ enable_router_grads: bool = True
111
 
112
  def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
113
+ # Initialize runtime attrs when this forward is bound onto a different module (e.g., GptOssMLP)
114
+ if not hasattr(self, "enable_router_grads"):
115
+ # self.enable_router_grads = True
116
+ self.enable_router_grads = False
117
+ if not hasattr(self, "_routing_weights_buffer"):
118
+ self._routing_weights_buffer = None
119
+ self._batch_indices_buffer = None
120
+ self._last_batch_seq = None
121
+ self._last_num_experts = None
122
+ self._timing_enabled = False
123
+ self._timing_stats = {}
124
+
125
  batch_size, seq_len, hidden_dim = hidden_states.shape
126
  batch_seq = batch_size * seq_len
127
 
128
  num_experts = getattr(self, "num_experts", 128)
129
  top_k = getattr(self, "top_k", 4)
130
 
131
+ # Enable timing if requested
132
+ timing = getattr(self, "_timing_enabled", False)
133
+
134
+ if timing:
135
+ torch.cuda.synchronize()
136
+ t0 = time.perf_counter()
137
+
138
  # Route tokens to experts
139
  x_flat = hidden_states.view(-1, hidden_dim)
140
  logits = torch.nn.functional.linear(
141
  x_flat, self.router.weight, self.router.bias
142
  )
143
 
144
+ if timing:
145
+ torch.cuda.synchronize()
146
+ t1 = time.perf_counter()
147
+ self._timing_stats["router"] = (t1 - t0) * 1000
148
+
149
  # Compute top-k
150
  if top_k == 1:
151
+ routing_logits_topk, router_indices = logits.max(dim=-1, keepdim=True)
152
  else:
153
+ routing_logits_topk, router_indices = torch.topk(logits, top_k, dim=-1)
154
 
155
+ # Match reference path exactly: use F.softmax with explicit dtype
156
+ routing_weights_topk = torch.nn.functional.softmax(
157
+ routing_logits_topk, dim=-1, dtype=routing_logits_topk.dtype
158
+ )
159
 
160
  # Create router scores
161
+ router_scores = torch.zeros_like(logits).scatter_(
162
+ 1, router_indices, routing_weights_topk
 
 
163
  )
164
 
165
+ if timing:
166
+ torch.cuda.synchronize()
167
+ t2 = time.perf_counter()
168
+ self._timing_stats["topk_softmax"] = (t2 - t1) * 1000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ dense_routing = router_scores # [B*S, E]
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ # Kernel expects shapes: [E, H, 2H] and [E, H, H]
173
+ gate_up = self.experts.gate_up_proj[:, :, : 2 * hidden_dim].contiguous()
174
+ gate_up_bias = self.experts.gate_up_proj_bias[:, : 2 * hidden_dim].contiguous()
175
  down_proj = self.experts.down_proj[:, :hidden_dim, :].contiguous()
 
176
  expert_capacity = batch_seq * top_k // num_experts * 2
177
 
178
+ if timing:
179
+ torch.cuda.synchronize()
180
+ t3 = time.perf_counter()
181
+
182
+ # Compute expert output with custom backward
183
+ if self.enable_router_grads:
184
+ output = _ExpertsFn.apply(
185
  hidden_states.view(-1, hidden_dim),
186
  router_indices,
187
+ dense_routing,
188
  gate_up,
189
  gate_up_bias,
190
  down_proj,
 
192
  expert_capacity,
193
  num_experts,
194
  top_k,
195
+ self.enable_router_grads,
196
  )
197
+ else:
198
+ with torch.no_grad():
199
+ output = ops.experts(
200
+ hidden_states.view(-1, hidden_dim),
201
+ router_indices,
202
+ dense_routing,
203
+ gate_up,
204
+ gate_up_bias,
205
+ down_proj,
206
+ self.experts.down_proj_bias,
207
+ expert_capacity,
208
+ num_experts,
209
+ top_k,
210
+ )
211
+
212
+ if timing:
213
+ torch.cuda.synchronize()
214
+ t4 = time.perf_counter()
215
+ self._timing_stats["experts_kernel"] = (t4 - t3) * 1000
216
 
217
  # Reshape output back to [B, S, H]
218
  output = output.view(batch_size, seq_len, hidden_dim)
219
+
220
+ if timing:
221
+ torch.cuda.synchronize()
222
+ t5 = time.perf_counter()
223
+ self._timing_stats["total"] = (t5 - t0) * 1000
224
+ print(f"\n[Yamoe.forward timing in ms]")
225
+ print(f" Router linear: {self._timing_stats['router']:.3f}")
226
+ print(f" TopK + Softmax: {self._timing_stats['topk_softmax']:.3f}")
227
+ print(f" Experts kernel: {self._timing_stats['experts_kernel']:.3f}")
228
+ print(f" Total: {self._timing_stats['total']:.3f}")
229
+
230
  return output, router_scores