Add torch.compile, CUDA graph, and compiled momentum [skip-build]
Browse files- Newton-Schulz: per-shape torch.compile caching + CUDA graph replay
- Batched momentum: separately compiled nesterov/non-nesterov functions
- Batched Newton-Schulz for MoE experts (bmm/baddbmm)
- Triton matmul_transpose cleanup
- Inline uneven shard handling, remove small_param_numel_threshold
- Raise dynamo recompile_limit for test suite
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
test/conftest.py
CHANGED
|
@@ -9,6 +9,11 @@ from transformers import AutoModelForCausalLM
|
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
logging.basicConfig(level=logging.INFO)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
SEED = 0xdeadbeef
|
| 13 |
|
| 14 |
|
|
|
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
logging.basicConfig(level=logging.INFO)
|
| 11 |
|
| 12 |
+
# Raise dynamo recompile limit so that compiled momentum (batch_pre_ortho)
|
| 13 |
+
# does not fall back to eager mode when the test suite runs 30+ model
|
| 14 |
+
# configurations with different tensor shapes in a single process.
|
| 15 |
+
torch._dynamo.config.recompile_limit = 64
|
| 16 |
+
|
| 17 |
SEED = 0xdeadbeef
|
| 18 |
|
| 19 |
|
torch-ext/optimizer/core.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import logging
|
| 2 |
import math
|
| 3 |
from dataclasses import dataclass
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
-
import torch.distributed as dist
|
| 7 |
from torch.distributed import ProcessGroup
|
| 8 |
from torch.distributed.tensor import DTensor
|
| 9 |
|
|
@@ -31,26 +31,71 @@ class _muon_state:
|
|
| 31 |
qk_clip_state: torch.Tensor | None = None
|
| 32 |
|
| 33 |
|
| 34 |
-
def
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
Args:
|
| 38 |
-
optimizer_state: The optimizer's state dict (self.state in Muon).
|
| 39 |
-
p: Parameter tensor.
|
| 40 |
-
g: Gradient tensor.
|
| 41 |
-
group: Parameter group dict.
|
| 42 |
-
momentum: Momentum coefficient.
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
"""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def update_p(p, u, lr, adjusted_lr, weight_decay):
|
|
@@ -63,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
|
|
| 63 |
adjusted_lr: Size-adjusted learning rate.
|
| 64 |
weight_decay: Weight decay coefficient.
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
else
|
| 72 |
-
|
| 73 |
-
p.add_(u, alpha=-adjusted_lr)
|
| 74 |
|
| 75 |
|
| 76 |
def adjust_lr_for_muon(lr, param_shape):
|
|
@@ -147,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
|
|
| 147 |
is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
|
| 148 |
|
| 149 |
muon_params, muon_names = [], []
|
| 150 |
-
non_muon_params = []
|
| 151 |
|
| 152 |
for n, p in model.named_parameters():
|
| 153 |
if not p.requires_grad:
|
|
@@ -157,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
|
|
| 157 |
muon_names.append(n)
|
| 158 |
else:
|
| 159 |
non_muon_params.append(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
return [
|
| 162 |
{
|
|
|
|
| 1 |
import logging
|
| 2 |
import math
|
| 3 |
from dataclasses import dataclass
|
| 4 |
+
from typing import List
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 7 |
from torch.distributed import ProcessGroup
|
| 8 |
from torch.distributed.tensor import DTensor
|
| 9 |
|
|
|
|
| 31 |
qk_clip_state: torch.Tensor | None = None
|
| 32 |
|
| 33 |
|
| 34 |
+
def _batch_momentum(
|
| 35 |
+
grads: List[torch.Tensor],
|
| 36 |
+
momentum_bufs: List[torch.Tensor],
|
| 37 |
+
momentum: torch.Tensor,
|
| 38 |
+
) -> None:
|
| 39 |
+
"""Batched momentum update (no nesterov)."""
|
| 40 |
+
torch._foreach_mul_(momentum_bufs, momentum)
|
| 41 |
+
torch._foreach_add_(momentum_bufs, grads)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
def _batch_momentum_nesterov(
|
| 45 |
+
grads: List[torch.Tensor],
|
| 46 |
+
momentum_bufs: List[torch.Tensor],
|
| 47 |
+
momentum: torch.Tensor,
|
| 48 |
+
) -> None:
|
| 49 |
+
"""Batched momentum update with nesterov correction."""
|
| 50 |
+
torch._foreach_mul_(momentum_bufs, momentum)
|
| 51 |
+
torch._foreach_add_(momentum_bufs, grads)
|
| 52 |
+
nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
|
| 53 |
+
torch._foreach_add_(grads, nesterov_terms)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
_compiled_momentum: dict[bool, callable] = {}
|
| 57 |
+
_use_momentum_compile = True
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def set_momentum_compile(enabled: bool):
|
| 61 |
+
"""Toggle torch.compile for batched momentum."""
|
| 62 |
+
global _use_momentum_compile
|
| 63 |
+
_use_momentum_compile = enabled
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def batch_pre_ortho(
|
| 67 |
+
grads: List[torch.Tensor],
|
| 68 |
+
momentum_bufs: List[torch.Tensor],
|
| 69 |
+
momentum: torch.Tensor,
|
| 70 |
+
nesterov: bool,
|
| 71 |
+
) -> None:
|
| 72 |
+
"""Batched momentum update on lists of plain tensors.
|
| 73 |
+
|
| 74 |
+
Mirrors dion's ``muon_update_pre_orthogonalize``.
|
| 75 |
+
Inputs must be plain CUDA tensors (not DTensor).
|
| 76 |
+
Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place.
|
| 77 |
+
|
| 78 |
+
When compile is enabled, uses separately compiled functions for
|
| 79 |
+
nesterov=True/False to avoid graph breaks from the branch.
|
| 80 |
"""
|
| 81 |
+
fn = _batch_momentum_nesterov if nesterov else _batch_momentum
|
| 82 |
+
if _use_momentum_compile:
|
| 83 |
+
if nesterov not in _compiled_momentum:
|
| 84 |
+
_compiled_momentum[nesterov] = torch.compile(fn)
|
| 85 |
+
fn = _compiled_momentum[nesterov]
|
| 86 |
+
fn(grads, momentum_bufs, momentum)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
|
| 90 |
+
"""Weight-decay + update on plain tensors.
|
| 91 |
+
|
| 92 |
+
Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
|
| 93 |
+
lookup per call × 256+ params = massive overhead. The pipeline path uses
|
| 94 |
+
batched _foreach_* ops instead; this function remains for base() and
|
| 95 |
+
distributed_muon().
|
| 96 |
+
"""
|
| 97 |
+
p_data.mul_(1 - lr * weight_decay)
|
| 98 |
+
p_data.add_(u_data, alpha=-adjusted_lr)
|
| 99 |
|
| 100 |
|
| 101 |
def update_p(p, u, lr, adjusted_lr, weight_decay):
|
|
|
|
| 108 |
adjusted_lr: Size-adjusted learning rate.
|
| 109 |
weight_decay: Weight decay coefficient.
|
| 110 |
"""
|
| 111 |
+
# Unwrap Parameter -> underlying data tensor.
|
| 112 |
+
p_data = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 113 |
+
# Unwrap DTensor -> local CUDA tensor for compiled kernel.
|
| 114 |
+
if isinstance(p_data, DTensor):
|
| 115 |
+
p_data = p_data._local_tensor
|
| 116 |
+
u_data = u._local_tensor if isinstance(u, DTensor) else u
|
| 117 |
+
_update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
def adjust_lr_for_muon(lr, param_shape):
|
|
|
|
| 191 |
is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
|
| 192 |
|
| 193 |
muon_params, muon_names = [], []
|
| 194 |
+
non_muon_params, non_muon_names = [], []
|
| 195 |
|
| 196 |
for n, p in model.named_parameters():
|
| 197 |
if not p.requires_grad:
|
|
|
|
| 201 |
muon_names.append(n)
|
| 202 |
else:
|
| 203 |
non_muon_params.append(p)
|
| 204 |
+
non_muon_names.append(n)
|
| 205 |
+
|
| 206 |
+
logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
|
| 207 |
+
expert_keys, len(muon_names), len(non_muon_names))
|
| 208 |
|
| 209 |
return [
|
| 210 |
{
|
torch-ext/optimizer/distributed/utils.py
CHANGED
|
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
|
|
| 72 |
else:
|
| 73 |
curr_size = target.size()[shard_dim]
|
| 74 |
|
| 75 |
-
if curr_size % num_chunks != 0:
|
| 76 |
-
raise NotImplementedError(
|
| 77 |
-
f"Dimension size {curr_size} is not divisible "
|
| 78 |
-
f"by number of ranks {num_chunks} for shard "
|
| 79 |
-
f"placement on dim {shard_dim}. (shape: {target.shape})")
|
| 80 |
-
|
| 81 |
# Compute indices for this level of sharding
|
| 82 |
if isinstance(placement, _StridedShard):
|
| 83 |
_shard_size, offsets = _StridedShard.local_shard_size_and_offset(
|
|
|
|
| 72 |
else:
|
| 73 |
curr_size = target.size()[shard_dim]
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
# Compute indices for this level of sharding
|
| 76 |
if isinstance(placement, _StridedShard):
|
| 77 |
_shard_size, offsets = _StridedShard.local_shard_size_and_offset(
|
torch-ext/optimizer/matmul_transpose_triton.py
CHANGED
|
@@ -43,6 +43,7 @@ def get_autotune_config():
|
|
| 43 |
@triton.autotune(
|
| 44 |
configs=get_autotune_config(),
|
| 45 |
key=['M', 'K'],
|
|
|
|
| 46 |
)
|
| 47 |
@triton.jit
|
| 48 |
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
|
|
@@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
|
|
| 102 |
tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
|
| 103 |
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
|
| 110 |
-
assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
|
| 111 |
-
assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
|
| 112 |
-
assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
|
| 113 |
-
"First dimension of `d_in` must match first and second dimension of `d_out`"
|
| 114 |
-
|
| 115 |
d_in = d_in.contiguous()
|
| 116 |
M, K = d_in.shape
|
| 117 |
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
|
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out):
|
|
| 119 |
with torch.cuda.device(d_in.device.index):
|
| 120 |
mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
|
| 121 |
d_out.stride(0), d_out.stride(1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
@triton.autotune(
|
| 44 |
configs=get_autotune_config(),
|
| 45 |
key=['M', 'K'],
|
| 46 |
+
restore_value=['y'],
|
| 47 |
)
|
| 48 |
@triton.jit
|
| 49 |
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
|
|
|
|
| 103 |
tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
|
| 104 |
|
| 105 |
|
| 106 |
+
@torch.library.custom_op("muon::matmul_transpose_assign",
|
| 107 |
+
mutates_args=("d_out", ))
|
| 108 |
+
def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
|
| 109 |
+
"""Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
d_in = d_in.contiguous()
|
| 111 |
M, K = d_in.shape
|
| 112 |
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
|
|
|
|
| 114 |
with torch.cuda.device(d_in.device.index):
|
| 115 |
mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
|
| 116 |
d_out.stride(0), d_out.stride(1))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@matmul_transpose_assign.register_fake
|
| 120 |
+
def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
|
| 121 |
+
"""FakeTensor impl: d_out is already allocated, mutation is declared."""
|
| 122 |
+
pass
|
torch-ext/optimizer/newton_schulz.py
CHANGED
|
@@ -162,3 +162,75 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 162 |
X = X.T
|
| 163 |
|
| 164 |
return X
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
X = X.T
|
| 163 |
|
| 164 |
return X
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@torch.no_grad()
|
| 168 |
+
def _zeropower_via_newtonschulz5_batched(G, steps):
|
| 169 |
+
"""Batched polar factor computation for 3D (E, out, in) tensors.
|
| 170 |
+
|
| 171 |
+
Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
|
| 172 |
+
``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
|
| 173 |
+
processing all E expert matrices in a single batched call.
|
| 174 |
+
"""
|
| 175 |
+
assert len(G.shape) == 3
|
| 176 |
+
assert G.dtype == COMM_DTYPE
|
| 177 |
+
X = G
|
| 178 |
+
|
| 179 |
+
if G.size(1) > G.size(2):
|
| 180 |
+
X = X.transpose(-2, -1)
|
| 181 |
+
|
| 182 |
+
# Per-expert Frobenius norm.
|
| 183 |
+
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
+
|
| 185 |
+
hs = _coeffs_list[:steps] + list(
|
| 186 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
|
| 187 |
+
for a, b, c in hs:
|
| 188 |
+
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
+
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
| 190 |
+
buf1.mul_(b).add_(buf2, alpha=c)
|
| 191 |
+
X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)
|
| 192 |
+
|
| 193 |
+
if G.size(1) > G.size(2):
|
| 194 |
+
X = X.transpose(-2, -1)
|
| 195 |
+
|
| 196 |
+
return X
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
_ns_per_shape: dict[tuple[int, ...], callable] = {}
|
| 200 |
+
_use_compile = True
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def set_ns_compile(enabled: bool):
|
| 204 |
+
"""Toggle torch.compile for Newton-Schulz iteration."""
|
| 205 |
+
global _use_compile
|
| 206 |
+
_use_compile = enabled
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def zeropower_via_newtonschulz5(G, steps=5):
|
| 210 |
+
if not _use_compile:
|
| 211 |
+
return _zeropower_via_newtonschulz5(G, steps)
|
| 212 |
+
key = G.shape
|
| 213 |
+
if key not in _ns_per_shape:
|
| 214 |
+
_ns_per_shape[key] = torch.compile(_zeropower_via_newtonschulz5,
|
| 215 |
+
options={
|
| 216 |
+
"triton.cudagraphs": True,
|
| 217 |
+
"shape_padding": False
|
| 218 |
+
})
|
| 219 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 220 |
+
return _ns_per_shape[key](G, steps).clone()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def zeropower_via_newtonschulz5_batched(G, steps=5):
|
| 224 |
+
"""Compile-cached batched Newton-Schulz for 3D expert tensors."""
|
| 225 |
+
if not _use_compile:
|
| 226 |
+
return _zeropower_via_newtonschulz5_batched(G, steps)
|
| 227 |
+
key = G.shape
|
| 228 |
+
if key not in _ns_per_shape:
|
| 229 |
+
_ns_per_shape[key] = torch.compile(
|
| 230 |
+
_zeropower_via_newtonschulz5_batched,
|
| 231 |
+
options={
|
| 232 |
+
"triton.cudagraphs": True,
|
| 233 |
+
"shape_padding": False
|
| 234 |
+
})
|
| 235 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 236 |
+
return _ns_per_shape[key](G, steps).clone()
|
torch-ext/optimizer/qk_clip.py
CHANGED
|
@@ -102,23 +102,27 @@ def compute_scales(p, qk_clip_state):
|
|
| 102 |
threshold = qk_clip_state.threshold
|
| 103 |
logit = qk_clip_state.logit
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
scaling = 0
|
| 108 |
-
|
| 109 |
for logit_idx, head_idx in enumerate(indices):
|
| 110 |
v_ele = float(logit[logit_idx])
|
| 111 |
if v_ele > threshold:
|
| 112 |
new_scale = math.sqrt(threshold / v_ele)
|
| 113 |
-
if new_scale <
|
| 114 |
-
|
| 115 |
logger.info(
|
| 116 |
f"[{kind}] Head {head_idx} exceeded threshold "
|
| 117 |
f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
|
| 118 |
)
|
| 119 |
-
scaling += 1
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
def qk_clip(p, scales, head_dim):
|
|
|
|
| 102 |
threshold = qk_clip_state.threshold
|
| 103 |
logit = qk_clip_state.logit
|
| 104 |
|
| 105 |
+
# Check if any head exceeds threshold before allocating.
|
| 106 |
+
head_scales = {}
|
|
|
|
|
|
|
| 107 |
for logit_idx, head_idx in enumerate(indices):
|
| 108 |
v_ele = float(logit[logit_idx])
|
| 109 |
if v_ele > threshold:
|
| 110 |
new_scale = math.sqrt(threshold / v_ele)
|
| 111 |
+
if head_idx not in head_scales or new_scale < head_scales[head_idx]:
|
| 112 |
+
head_scales[head_idx] = new_scale
|
| 113 |
logger.info(
|
| 114 |
f"[{kind}] Head {head_idx} exceeded threshold "
|
| 115 |
f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
|
| 116 |
)
|
|
|
|
| 117 |
|
| 118 |
+
if not head_scales:
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
H_global = p.shape[0] // head_dim
|
| 122 |
+
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
+
for head_idx, scale in head_scales.items():
|
| 124 |
+
scales_full[head_idx] = scale
|
| 125 |
+
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
def qk_clip(p, scales, head_dim):
|