drbh
commited on
Commit
·
cf66620
1
Parent(s):
733f7f4
feat: align outputs and support backward method
Browse files- compare_example.py +17 -0
- gpt_oss_backward.py +181 -0
- gpt_oss_match.py +124 -0
- torch-ext/yamoe/layers.py +183 -57
compare_example.py
CHANGED
|
@@ -130,6 +130,10 @@ model.down_proj_bias.data = down_proj_bias
|
|
| 130 |
model = model.cuda()
|
| 131 |
model.eval()
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
torch.cuda.synchronize()
|
| 134 |
torch.cuda.reset_peak_memory_stats()
|
| 135 |
start = time.perf_counter()
|
|
@@ -151,6 +155,19 @@ ref_output_reshaped = ref_output.view(kernel_output.shape)
|
|
| 151 |
# Test yamoe_ref implementation
|
| 152 |
expert_capacity = batch_seq * top_k // num_experts * 2 # Generous capacity
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
torch.cuda.synchronize()
|
| 155 |
torch.cuda.reset_peak_memory_stats()
|
| 156 |
start = time.perf_counter()
|
|
|
|
| 130 |
model = model.cuda()
|
| 131 |
model.eval()
|
| 132 |
|
| 133 |
+
# Warmup
|
| 134 |
+
for _ in range(5):
|
| 135 |
+
_ = model(hidden_states, router_indices, routing_weights)
|
| 136 |
+
|
| 137 |
torch.cuda.synchronize()
|
| 138 |
torch.cuda.reset_peak_memory_stats()
|
| 139 |
start = time.perf_counter()
|
|
|
|
| 155 |
# Test yamoe_ref implementation
|
| 156 |
expert_capacity = batch_seq * top_k // num_experts * 2 # Generous capacity
|
| 157 |
|
| 158 |
+
# Warmup
|
| 159 |
+
for _ in range(5):
|
| 160 |
+
_ = binned_experts_ref(
|
| 161 |
+
hidden_states,
|
| 162 |
+
router_indices,
|
| 163 |
+
routing_weights,
|
| 164 |
+
gate_up_proj,
|
| 165 |
+
gate_up_proj_bias,
|
| 166 |
+
down_proj,
|
| 167 |
+
down_proj_bias,
|
| 168 |
+
expert_capacity,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
torch.cuda.synchronize()
|
| 172 |
torch.cuda.reset_peak_memory_stats()
|
| 173 |
start = time.perf_counter()
|
gpt_oss_backward.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = "==3.10"
|
| 3 |
+
# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
|
| 4 |
+
# [tool.uv.sources]
|
| 5 |
+
# kernels = { git = "https://github.com/huggingface/kernels.git" }
|
| 6 |
+
# ///
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from kernels import get_kernel, get_local_kernel
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
load_method = 2 # 1: sym, 2: local, 3: hf
|
| 15 |
+
|
| 16 |
+
if load_method == 1:
|
| 17 |
+
sys.path.insert(0, "./torch-ext")
|
| 18 |
+
import yamoe
|
| 19 |
+
elif load_method == 2:
|
| 20 |
+
yamoe = get_local_kernel(Path("result"), "yamoe")
|
| 21 |
+
elif load_method == 3:
|
| 22 |
+
yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
|
| 23 |
+
|
| 24 |
+
torch.manual_seed(42)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def zero_grads(m):
|
| 28 |
+
for p in m.parameters():
|
| 29 |
+
if p.grad is not None:
|
| 30 |
+
p.grad = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def benchmark_backward(model, x, tag: str, iters: int = 10, warmup: int = 10):
|
| 34 |
+
x_local = x.detach().clone().requires_grad_(True)
|
| 35 |
+
x_local.retain_grad()
|
| 36 |
+
|
| 37 |
+
# Warmup
|
| 38 |
+
for _ in range(warmup):
|
| 39 |
+
out = model(x_local)
|
| 40 |
+
out = out[0] if isinstance(out, tuple) else out
|
| 41 |
+
loss = out.mean()
|
| 42 |
+
zero_grads(model)
|
| 43 |
+
if x_local.grad is not None:
|
| 44 |
+
x_local.grad = None
|
| 45 |
+
loss.backward()
|
| 46 |
+
|
| 47 |
+
# Benchmark
|
| 48 |
+
torch.cuda.reset_peak_memory_stats()
|
| 49 |
+
torch.cuda.synchronize()
|
| 50 |
+
start = time.perf_counter()
|
| 51 |
+
for _ in range(iters):
|
| 52 |
+
out = model(x_local)
|
| 53 |
+
out = out[0] if isinstance(out, tuple) else out
|
| 54 |
+
loss = out.mean()
|
| 55 |
+
zero_grads(model)
|
| 56 |
+
if x_local.grad is not None:
|
| 57 |
+
x_local.grad = None
|
| 58 |
+
loss.backward()
|
| 59 |
+
torch.cuda.synchronize()
|
| 60 |
+
bwd_ms = (time.perf_counter() - start) * 1e3 / iters
|
| 61 |
+
peak_mem = torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB
|
| 62 |
+
|
| 63 |
+
print(f"[{tag}] backward: {bwd_ms:.2f} ms | peak mem: {peak_mem:.2f} GB")
|
| 64 |
+
return bwd_ms
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def print_grad_norms(m, x, tag):
|
| 68 |
+
print(f"\n[{tag}] Gradient norms:")
|
| 69 |
+
xg = x.grad.norm().item() if x.grad is not None else 0.0
|
| 70 |
+
print(f" input grad: {xg:.6f}")
|
| 71 |
+
for name, p in m.named_parameters():
|
| 72 |
+
if p.grad is None:
|
| 73 |
+
print(f" {name}: None")
|
| 74 |
+
else:
|
| 75 |
+
print(f" {name}: {p.grad.norm().item():.6f}")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main():
|
| 79 |
+
ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
|
| 80 |
+
new_moe_cls = yamoe.Yamoe
|
| 81 |
+
|
| 82 |
+
batch_size, seq_len, hidden_dim = 4, 1024, 2880
|
| 83 |
+
num_experts, top_k = 8, 2
|
| 84 |
+
|
| 85 |
+
config = type("Config", (), {})()
|
| 86 |
+
config.hidden_size = hidden_dim
|
| 87 |
+
config.intermediate_size = hidden_dim
|
| 88 |
+
config.num_local_experts = num_experts
|
| 89 |
+
config.num_experts_per_tok = top_k
|
| 90 |
+
ref_moe = ref_moe_cls(config)
|
| 91 |
+
|
| 92 |
+
print(ref_moe)
|
| 93 |
+
|
| 94 |
+
for p in ref_moe.parameters():
|
| 95 |
+
if p.dim() > 1:
|
| 96 |
+
torch.nn.init.xavier_uniform_(p)
|
| 97 |
+
else:
|
| 98 |
+
torch.nn.init.zeros_(p)
|
| 99 |
+
|
| 100 |
+
x = torch.randn(batch_size, seq_len, hidden_dim, device="cuda")
|
| 101 |
+
ref_moe = ref_moe.cuda()
|
| 102 |
+
|
| 103 |
+
# Test reference implementation backward
|
| 104 |
+
print("\nReference Implementation Backward")
|
| 105 |
+
|
| 106 |
+
# Small warmup
|
| 107 |
+
print(" Warming up...")
|
| 108 |
+
x_warmup = x.detach().requires_grad_(True)
|
| 109 |
+
for _ in range(3):
|
| 110 |
+
out = ref_moe(x_warmup)
|
| 111 |
+
out = out[0] if isinstance(out, tuple) else out
|
| 112 |
+
loss = out.mean()
|
| 113 |
+
zero_grads(ref_moe)
|
| 114 |
+
if x_warmup.grad is not None:
|
| 115 |
+
x_warmup.grad = None
|
| 116 |
+
loss.backward()
|
| 117 |
+
torch.cuda.synchronize()
|
| 118 |
+
|
| 119 |
+
# Run once to get gradient info
|
| 120 |
+
x_ref = x.detach().requires_grad_(True)
|
| 121 |
+
x_ref.retain_grad()
|
| 122 |
+
ref_output = ref_moe(x_ref)
|
| 123 |
+
out = ref_output[0] if isinstance(ref_output, tuple) else ref_output
|
| 124 |
+
print(f" Input shape: {x_ref.shape}")
|
| 125 |
+
print(f" Output shape: {out.shape}")
|
| 126 |
+
print(
|
| 127 |
+
f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
loss = out.mean()
|
| 131 |
+
zero_grads(ref_moe)
|
| 132 |
+
loss.backward()
|
| 133 |
+
print_grad_norms(ref_moe, x_ref, "reference")
|
| 134 |
+
|
| 135 |
+
benchmark_backward(ref_moe, x, tag="reference", warmup=10, iters=20)
|
| 136 |
+
|
| 137 |
+
# Switch to YAMOE-backed backward
|
| 138 |
+
print("\nYAMOE-backed Implementation Backward")
|
| 139 |
+
ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
|
| 140 |
+
ref_moe._routing_weights_buffer = None
|
| 141 |
+
ref_moe._batch_indices_buffer = None
|
| 142 |
+
ref_moe._last_batch_seq = None
|
| 143 |
+
ref_moe._last_num_experts = None
|
| 144 |
+
ref_moe.enable_router_grads = True
|
| 145 |
+
ref_moe.num_experts = num_experts
|
| 146 |
+
ref_moe.top_k = top_k
|
| 147 |
+
|
| 148 |
+
# Small warmup
|
| 149 |
+
print(" Warming up...")
|
| 150 |
+
x_warmup = x.detach().requires_grad_(True)
|
| 151 |
+
for _ in range(3):
|
| 152 |
+
out = ref_moe(x_warmup)
|
| 153 |
+
out = out[0] if isinstance(out, tuple) else out
|
| 154 |
+
loss = out.mean()
|
| 155 |
+
zero_grads(ref_moe)
|
| 156 |
+
if x_warmup.grad is not None:
|
| 157 |
+
x_warmup.grad = None
|
| 158 |
+
loss.backward()
|
| 159 |
+
torch.cuda.synchronize()
|
| 160 |
+
|
| 161 |
+
# Run once to get gradient info
|
| 162 |
+
x_cuda = x.detach().requires_grad_(True)
|
| 163 |
+
x_cuda.retain_grad()
|
| 164 |
+
cuda_output = ref_moe(x_cuda)
|
| 165 |
+
out = cuda_output[0] if isinstance(cuda_output, tuple) else cuda_output
|
| 166 |
+
print(f" Input shape: {x_cuda.shape}")
|
| 167 |
+
print(f" Output shape: {out.shape}")
|
| 168 |
+
print(
|
| 169 |
+
f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
loss = out.mean()
|
| 173 |
+
zero_grads(ref_moe)
|
| 174 |
+
loss.backward()
|
| 175 |
+
print_grad_norms(ref_moe, x_cuda, "yamoe-backed")
|
| 176 |
+
|
| 177 |
+
benchmark_backward(ref_moe, x, tag="yamoe-backed", warmup=10, iters=20)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == "__main__":
|
| 181 |
+
main()
|
gpt_oss_match.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = "==3.10"
|
| 3 |
+
# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
|
| 4 |
+
# [tool.uv.sources]
|
| 5 |
+
# kernels = { git = "https://github.com/huggingface/kernels.git" }
|
| 6 |
+
# ///
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from kernels import get_kernel, get_local_kernel
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
load_method = 2 # 1: sym, 2: local, 3: hf
|
| 15 |
+
|
| 16 |
+
if load_method == 1:
|
| 17 |
+
sys.path.insert(0, "./torch-ext")
|
| 18 |
+
import yamoe
|
| 19 |
+
elif load_method == 2:
|
| 20 |
+
yamoe = get_local_kernel(Path("result"), "yamoe")
|
| 21 |
+
elif load_method == 3:
|
| 22 |
+
yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")
|
| 23 |
+
|
| 24 |
+
torch.manual_seed(42)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def benchmark_forward(model, x, tag: str, iters: int = 10, warmup: int = 10):
|
| 28 |
+
x_local = x.detach().clone().requires_grad_(False)
|
| 29 |
+
|
| 30 |
+
for _ in range(warmup):
|
| 31 |
+
out = model(x_local)
|
| 32 |
+
out = out[0] if isinstance(out, tuple) else out
|
| 33 |
+
|
| 34 |
+
torch.cuda.reset_peak_memory_stats()
|
| 35 |
+
torch.cuda.synchronize()
|
| 36 |
+
start = time.perf_counter()
|
| 37 |
+
for _ in range(iters):
|
| 38 |
+
out = model(x_local)
|
| 39 |
+
out = out[0] if isinstance(out, tuple) else out
|
| 40 |
+
torch.cuda.synchronize()
|
| 41 |
+
fwd_ms = (time.perf_counter() - start) * 1e3 / iters
|
| 42 |
+
peak_mem = torch.cuda.max_memory_allocated() / 1024**3 # Convert to GB
|
| 43 |
+
|
| 44 |
+
print(f"[{tag}] fwd: {fwd_ms:.2f} ms | peak mem: {peak_mem:.2f} GB")
|
| 45 |
+
return fwd_ms
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main():
|
| 49 |
+
ref_moe_cls = yamoe.vendored.gpt_oss_mlp.GptOssMLP
|
| 50 |
+
new_moe_cls = yamoe.Yamoe
|
| 51 |
+
|
| 52 |
+
batch_size, seq_len, hidden_dim = 4, 1024, 2880
|
| 53 |
+
num_experts, top_k = 8, 2
|
| 54 |
+
|
| 55 |
+
config = type("Config", (), {})()
|
| 56 |
+
config.hidden_size = hidden_dim
|
| 57 |
+
config.intermediate_size = hidden_dim
|
| 58 |
+
config.num_local_experts = num_experts
|
| 59 |
+
config.num_experts_per_tok = top_k
|
| 60 |
+
ref_moe = ref_moe_cls(config)
|
| 61 |
+
|
| 62 |
+
print(ref_moe)
|
| 63 |
+
|
| 64 |
+
for p in ref_moe.parameters():
|
| 65 |
+
if p.dim() > 1:
|
| 66 |
+
torch.nn.init.xavier_uniform_(p)
|
| 67 |
+
else:
|
| 68 |
+
torch.nn.init.zeros_(p)
|
| 69 |
+
|
| 70 |
+
x = torch.randn(batch_size, seq_len, hidden_dim, device="cuda")
|
| 71 |
+
ref_moe = ref_moe.cuda()
|
| 72 |
+
ref_moe = ref_moe.eval()
|
| 73 |
+
|
| 74 |
+
# Test reference implementation
|
| 75 |
+
print("\nReference Implementation")
|
| 76 |
+
|
| 77 |
+
# Small warmup
|
| 78 |
+
print(" Warming up...")
|
| 79 |
+
for _ in range(3):
|
| 80 |
+
_ = ref_moe(x)
|
| 81 |
+
torch.cuda.synchronize()
|
| 82 |
+
|
| 83 |
+
x_ref = x.detach().requires_grad_(False)
|
| 84 |
+
ref_output = ref_moe(x_ref)
|
| 85 |
+
out = ref_output[0] if isinstance(ref_output, tuple) else ref_output
|
| 86 |
+
print(f" Input shape: {x_ref.shape}")
|
| 87 |
+
print(f" Output shape: {out.shape}")
|
| 88 |
+
print(
|
| 89 |
+
f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
benchmark_forward(ref_moe, x, tag="reference", warmup=10, iters=20)
|
| 93 |
+
|
| 94 |
+
# Switch to YAMOE-backed forward
|
| 95 |
+
print("\nYAMOE-backed Implementation")
|
| 96 |
+
ref_moe.forward = new_moe_cls.forward.__get__(ref_moe)
|
| 97 |
+
ref_moe._routing_weights_buffer = None
|
| 98 |
+
ref_moe._batch_indices_buffer = None
|
| 99 |
+
ref_moe._last_batch_seq = None
|
| 100 |
+
ref_moe._last_num_experts = None
|
| 101 |
+
ref_moe.enable_router_grads = False
|
| 102 |
+
ref_moe.num_experts = num_experts
|
| 103 |
+
ref_moe.top_k = top_k
|
| 104 |
+
|
| 105 |
+
# Small warmup
|
| 106 |
+
print(" Warming up...")
|
| 107 |
+
for _ in range(3):
|
| 108 |
+
_ = ref_moe(x)
|
| 109 |
+
torch.cuda.synchronize()
|
| 110 |
+
|
| 111 |
+
x_cuda = x.detach().requires_grad_(False)
|
| 112 |
+
cuda_output = ref_moe(x_cuda)
|
| 113 |
+
out = cuda_output[0] if isinstance(cuda_output, tuple) else cuda_output
|
| 114 |
+
print(f" Input shape: {x_cuda.shape}")
|
| 115 |
+
print(f" Output shape: {out.shape}")
|
| 116 |
+
print(
|
| 117 |
+
f" Output mean: {out.mean():.6f}, std: {out.std():.6f}, norm: {out.norm():.6f}"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
benchmark_forward(ref_moe, x, tag="yamoe-backed", warmup=10, iters=20)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
main()
|
torch-ext/yamoe/layers.py
CHANGED
|
@@ -1,95 +1,190 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from ._ops import ops
|
| 3 |
|
| 4 |
|
| 5 |
-
class
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
batch_size, seq_len, hidden_dim = hidden_states.shape
|
| 20 |
batch_seq = batch_size * seq_len
|
| 21 |
|
| 22 |
num_experts = getattr(self, "num_experts", 128)
|
| 23 |
top_k = getattr(self, "top_k", 4)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Route tokens to experts
|
| 26 |
x_flat = hidden_states.view(-1, hidden_dim)
|
| 27 |
logits = torch.nn.functional.linear(
|
| 28 |
x_flat, self.router.weight, self.router.bias
|
| 29 |
)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# Compute top-k
|
| 32 |
if top_k == 1:
|
| 33 |
-
|
| 34 |
else:
|
| 35 |
-
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Create router scores
|
| 40 |
-
router_scores = (
|
| 41 |
-
|
| 42 |
-
.scatter_(1, router_indices, routing_weights)
|
| 43 |
-
.transpose(0, 1)
|
| 44 |
)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
self.
|
| 50 |
-
or self._last_batch_seq != batch_seq
|
| 51 |
-
or self._last_num_experts != num_experts
|
| 52 |
-
or self._routing_weights_buffer.device != routing_weights.device
|
| 53 |
-
):
|
| 54 |
-
self._routing_weights_buffer = torch.zeros(
|
| 55 |
-
batch_seq,
|
| 56 |
-
num_experts,
|
| 57 |
-
device=routing_weights.device,
|
| 58 |
-
dtype=routing_weights.dtype,
|
| 59 |
-
)
|
| 60 |
-
self._batch_indices_buffer = (
|
| 61 |
-
torch.arange(batch_seq, device=routing_weights.device)
|
| 62 |
-
.unsqueeze(1)
|
| 63 |
-
.expand(-1, top_k)
|
| 64 |
-
)
|
| 65 |
-
self._last_batch_seq = batch_seq
|
| 66 |
-
self._last_num_experts = num_experts
|
| 67 |
-
else:
|
| 68 |
-
self._routing_weights_buffer.zero_()
|
| 69 |
|
| 70 |
-
|
| 71 |
-
flat_indices = router_indices.view(batch_seq, top_k)
|
| 72 |
-
flat_weights = routing_weights.view(batch_seq, top_k)
|
| 73 |
-
self._routing_weights_buffer[self._batch_indices_buffer, flat_indices] = (
|
| 74 |
-
flat_weights
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
# FIX: Use the correct expert projections
|
| 78 |
-
gate_up = self.experts.gate_up_proj[:, :, : hidden_dim * top_k].contiguous()
|
| 79 |
-
gate_up_bias = self.experts.gate_up_proj_bias[
|
| 80 |
-
:, : hidden_dim * top_k
|
| 81 |
-
].contiguous()
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
down_proj = self.experts.down_proj[:, :hidden_dim, :].contiguous()
|
| 84 |
-
|
| 85 |
expert_capacity = batch_seq * top_k // num_experts * 2
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
hidden_states.view(-1, hidden_dim),
|
| 91 |
router_indices,
|
| 92 |
-
|
| 93 |
gate_up,
|
| 94 |
gate_up_bias,
|
| 95 |
down_proj,
|
|
@@ -97,8 +192,39 @@ class Yamoe(torch.nn.Module):
|
|
| 97 |
expert_capacity,
|
| 98 |
num_experts,
|
| 99 |
top_k,
|
|
|
|
| 100 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# Reshape output back to [B, S, H]
|
| 103 |
output = output.view(batch_size, seq_len, hidden_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
return output, router_scores
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import time
|
| 3 |
from ._ops import ops
|
| 4 |
|
| 5 |
|
| 6 |
+
class _ExpertsFn(torch.autograd.Function):
|
| 7 |
+
@staticmethod
|
| 8 |
+
def forward(
|
| 9 |
+
ctx,
|
| 10 |
+
hidden_flat: torch.Tensor,
|
| 11 |
+
router_indices: torch.Tensor,
|
| 12 |
+
routing_weights_dense: torch.Tensor,
|
| 13 |
+
gate_up_proj: torch.Tensor,
|
| 14 |
+
gate_up_proj_bias: torch.Tensor,
|
| 15 |
+
down_proj: torch.Tensor,
|
| 16 |
+
down_proj_bias: torch.Tensor,
|
| 17 |
+
expert_capacity: int,
|
| 18 |
+
num_experts: int,
|
| 19 |
+
top_k: int,
|
| 20 |
+
enable_router_grads: bool,
|
| 21 |
+
):
|
| 22 |
+
out = ops.experts(
|
| 23 |
+
hidden_flat,
|
| 24 |
+
router_indices,
|
| 25 |
+
routing_weights_dense,
|
| 26 |
+
gate_up_proj,
|
| 27 |
+
gate_up_proj_bias,
|
| 28 |
+
down_proj,
|
| 29 |
+
down_proj_bias,
|
| 30 |
+
expert_capacity,
|
| 31 |
+
num_experts,
|
| 32 |
+
top_k,
|
| 33 |
+
)
|
| 34 |
+
ctx.expert_capacity = expert_capacity
|
| 35 |
+
ctx.num_experts = num_experts
|
| 36 |
+
ctx.top_k = top_k
|
| 37 |
+
ctx.enable_router_grads = bool(enable_router_grads)
|
| 38 |
+
ctx.save_for_backward(
|
| 39 |
+
hidden_flat,
|
| 40 |
+
router_indices,
|
| 41 |
+
routing_weights_dense,
|
| 42 |
+
gate_up_proj,
|
| 43 |
+
gate_up_proj_bias,
|
| 44 |
+
down_proj,
|
| 45 |
+
down_proj_bias,
|
| 46 |
+
)
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def backward(ctx, grad_output: torch.Tensor):
|
| 51 |
+
(
|
| 52 |
+
hidden_flat,
|
| 53 |
+
router_indices,
|
| 54 |
+
routing_weights_dense,
|
| 55 |
+
gate_up_proj,
|
| 56 |
+
gate_up_proj_bias,
|
| 57 |
+
down_proj,
|
| 58 |
+
down_proj_bias,
|
| 59 |
+
) = ctx.saved_tensors
|
| 60 |
+
|
| 61 |
+
(
|
| 62 |
+
grad_hidden_flat,
|
| 63 |
+
grad_routing_weights,
|
| 64 |
+
grad_gate_up_proj,
|
| 65 |
+
grad_gate_up_proj_bias,
|
| 66 |
+
grad_down_proj,
|
| 67 |
+
grad_down_proj_bias,
|
| 68 |
+
) = ops.experts_backward(
|
| 69 |
+
grad_output,
|
| 70 |
+
hidden_flat,
|
| 71 |
+
router_indices,
|
| 72 |
+
routing_weights_dense,
|
| 73 |
+
gate_up_proj,
|
| 74 |
+
gate_up_proj_bias,
|
| 75 |
+
down_proj,
|
| 76 |
+
down_proj_bias,
|
| 77 |
+
ctx.expert_capacity,
|
| 78 |
+
ctx.num_experts,
|
| 79 |
+
ctx.top_k,
|
| 80 |
+
)
|
| 81 |
|
| 82 |
+
# Return grad for dense routing; autograd handles scatter->softmax->linear
|
| 83 |
+
grad_routing_weights_dense = (
|
| 84 |
+
grad_routing_weights if ctx.enable_router_grads else None
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Return gradients per input (None for non-differentiable)
|
| 88 |
+
return (
|
| 89 |
+
grad_hidden_flat, # hidden_flat
|
| 90 |
+
None, # router_indices
|
| 91 |
+
grad_routing_weights_dense, # routing_weights_dense
|
| 92 |
+
grad_gate_up_proj,
|
| 93 |
+
grad_gate_up_proj_bias,
|
| 94 |
+
grad_down_proj,
|
| 95 |
+
grad_down_proj_bias,
|
| 96 |
+
None, # expert_capacity
|
| 97 |
+
None, # num_experts
|
| 98 |
+
None, # top_k
|
| 99 |
+
None, # enable_router_grads
|
| 100 |
+
)
|
| 101 |
|
| 102 |
+
|
| 103 |
+
class Yamoe(torch.nn.Module):
|
| 104 |
+
can_torch_compile: bool = False
|
| 105 |
+
|
| 106 |
+
_routing_weights_buffer: torch.Tensor = None
|
| 107 |
+
_batch_indices_buffer: torch.Tensor = None
|
| 108 |
+
_last_batch_seq: int = None
|
| 109 |
+
_last_num_experts: int = None
|
| 110 |
+
enable_router_grads: bool = True
|
| 111 |
|
| 112 |
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 113 |
+
# Initialize runtime attrs when this forward is bound onto a different module (e.g., GptOssMLP)
|
| 114 |
+
if not hasattr(self, "enable_router_grads"):
|
| 115 |
+
# self.enable_router_grads = True
|
| 116 |
+
self.enable_router_grads = False
|
| 117 |
+
if not hasattr(self, "_routing_weights_buffer"):
|
| 118 |
+
self._routing_weights_buffer = None
|
| 119 |
+
self._batch_indices_buffer = None
|
| 120 |
+
self._last_batch_seq = None
|
| 121 |
+
self._last_num_experts = None
|
| 122 |
+
self._timing_enabled = False
|
| 123 |
+
self._timing_stats = {}
|
| 124 |
+
|
| 125 |
batch_size, seq_len, hidden_dim = hidden_states.shape
|
| 126 |
batch_seq = batch_size * seq_len
|
| 127 |
|
| 128 |
num_experts = getattr(self, "num_experts", 128)
|
| 129 |
top_k = getattr(self, "top_k", 4)
|
| 130 |
|
| 131 |
+
# Enable timing if requested
|
| 132 |
+
timing = getattr(self, "_timing_enabled", False)
|
| 133 |
+
|
| 134 |
+
if timing:
|
| 135 |
+
torch.cuda.synchronize()
|
| 136 |
+
t0 = time.perf_counter()
|
| 137 |
+
|
| 138 |
# Route tokens to experts
|
| 139 |
x_flat = hidden_states.view(-1, hidden_dim)
|
| 140 |
logits = torch.nn.functional.linear(
|
| 141 |
x_flat, self.router.weight, self.router.bias
|
| 142 |
)
|
| 143 |
|
| 144 |
+
if timing:
|
| 145 |
+
torch.cuda.synchronize()
|
| 146 |
+
t1 = time.perf_counter()
|
| 147 |
+
self._timing_stats["router"] = (t1 - t0) * 1000
|
| 148 |
+
|
| 149 |
# Compute top-k
|
| 150 |
if top_k == 1:
|
| 151 |
+
routing_logits_topk, router_indices = logits.max(dim=-1, keepdim=True)
|
| 152 |
else:
|
| 153 |
+
routing_logits_topk, router_indices = torch.topk(logits, top_k, dim=-1)
|
| 154 |
|
| 155 |
+
# Match reference path exactly: use F.softmax with explicit dtype
|
| 156 |
+
routing_weights_topk = torch.nn.functional.softmax(
|
| 157 |
+
routing_logits_topk, dim=-1, dtype=routing_logits_topk.dtype
|
| 158 |
+
)
|
| 159 |
|
| 160 |
# Create router scores
|
| 161 |
+
router_scores = torch.zeros_like(logits).scatter_(
|
| 162 |
+
1, router_indices, routing_weights_topk
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
+
if timing:
|
| 166 |
+
torch.cuda.synchronize()
|
| 167 |
+
t2 = time.perf_counter()
|
| 168 |
+
self._timing_stats["topk_softmax"] = (t2 - t1) * 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
dense_routing = router_scores # [B*S, E]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
# Kernel expects shapes: [E, H, 2H] and [E, H, H]
|
| 173 |
+
gate_up = self.experts.gate_up_proj[:, :, : 2 * hidden_dim].contiguous()
|
| 174 |
+
gate_up_bias = self.experts.gate_up_proj_bias[:, : 2 * hidden_dim].contiguous()
|
| 175 |
down_proj = self.experts.down_proj[:, :hidden_dim, :].contiguous()
|
|
|
|
| 176 |
expert_capacity = batch_seq * top_k // num_experts * 2
|
| 177 |
|
| 178 |
+
if timing:
|
| 179 |
+
torch.cuda.synchronize()
|
| 180 |
+
t3 = time.perf_counter()
|
| 181 |
+
|
| 182 |
+
# Compute expert output with custom backward
|
| 183 |
+
if self.enable_router_grads:
|
| 184 |
+
output = _ExpertsFn.apply(
|
| 185 |
hidden_states.view(-1, hidden_dim),
|
| 186 |
router_indices,
|
| 187 |
+
dense_routing,
|
| 188 |
gate_up,
|
| 189 |
gate_up_bias,
|
| 190 |
down_proj,
|
|
|
|
| 192 |
expert_capacity,
|
| 193 |
num_experts,
|
| 194 |
top_k,
|
| 195 |
+
self.enable_router_grads,
|
| 196 |
)
|
| 197 |
+
else:
|
| 198 |
+
with torch.no_grad():
|
| 199 |
+
output = ops.experts(
|
| 200 |
+
hidden_states.view(-1, hidden_dim),
|
| 201 |
+
router_indices,
|
| 202 |
+
dense_routing,
|
| 203 |
+
gate_up,
|
| 204 |
+
gate_up_bias,
|
| 205 |
+
down_proj,
|
| 206 |
+
self.experts.down_proj_bias,
|
| 207 |
+
expert_capacity,
|
| 208 |
+
num_experts,
|
| 209 |
+
top_k,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
if timing:
|
| 213 |
+
torch.cuda.synchronize()
|
| 214 |
+
t4 = time.perf_counter()
|
| 215 |
+
self._timing_stats["experts_kernel"] = (t4 - t3) * 1000
|
| 216 |
|
| 217 |
# Reshape output back to [B, S, H]
|
| 218 |
output = output.view(batch_size, seq_len, hidden_dim)
|
| 219 |
+
|
| 220 |
+
if timing:
|
| 221 |
+
torch.cuda.synchronize()
|
| 222 |
+
t5 = time.perf_counter()
|
| 223 |
+
self._timing_stats["total"] = (t5 - t0) * 1000
|
| 224 |
+
print(f"\n[Yamoe.forward timing in ms]")
|
| 225 |
+
print(f" Router linear: {self._timing_stats['router']:.3f}")
|
| 226 |
+
print(f" TopK + Softmax: {self._timing_stats['topk_softmax']:.3f}")
|
| 227 |
+
print(f" Experts kernel: {self._timing_stats['experts_kernel']:.3f}")
|
| 228 |
+
print(f" Total: {self._timing_stats['total']:.3f}")
|
| 229 |
+
|
| 230 |
return output, router_scores
|