Merge pull request #17 from MotifTechnologies/optimal-ns-coefficients
Browse filesReplace hardcoded NS coefficients with analytically optimal ones [ski…
- torch-ext/optimizer/newton_schulz.py +134 -20
torch-ext/optimizer/newton_schulz.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
from .matmul_transpose_triton import matmul_transpose_assign
|
|
@@ -6,21 +10,134 @@ COMM_DTYPE = torch.bfloat16
|
|
| 6 |
DEFAULT_CHUNK_SIZE_RATIO = 4
|
| 7 |
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
@torch.no_grad()
|
| 14 |
-
# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
|
| 15 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""
|
| 25 |
assert len(G.shape) == 2
|
| 26 |
assert G.dtype == COMM_DTYPE
|
|
@@ -28,18 +145,14 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 28 |
|
| 29 |
if G.size(0) > G.size(1):
|
| 30 |
X = X.T
|
| 31 |
-
|
| 32 |
X = X / (X.norm() + 1e-7)
|
|
|
|
|
|
|
| 33 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 34 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 35 |
# Perform the NS iterations
|
| 36 |
-
for a, b, c in
|
| 37 |
-
(4.0848, -6.8946, 2.9270),
|
| 38 |
-
(3.9505, -6.3029, 2.6377),
|
| 39 |
-
(3.7418, -5.5913, 2.3037),
|
| 40 |
-
(2.8769, -3.1427, 1.2046),
|
| 41 |
-
(2.8366, -3.0525, 1.2012),
|
| 42 |
-
]:
|
| 43 |
matmul_transpose_assign(X, buf1)
|
| 44 |
matmul_transpose_assign(buf1, buf2)
|
| 45 |
buf1.mul_(b).add_(buf2, alpha=c)
|
|
@@ -47,4 +160,5 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 47 |
|
| 48 |
if G.size(0) > G.size(1):
|
| 49 |
X = X.T
|
|
|
|
| 50 |
return X
|
|
|
|
| 1 |
+
from itertools import repeat
|
| 2 |
+
from math import inf, sqrt
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
import torch
|
| 6 |
|
| 7 |
from .matmul_transpose_triton import matmul_transpose_assign
|
|
|
|
| 10 |
DEFAULT_CHUNK_SIZE_RATIO = 4
|
| 11 |
|
| 12 |
|
| 13 |
+
def _optimal_quintic(l, u, max_iter=1000):
|
| 14 |
+
"""
|
| 15 |
+
Use the simplified Remez algorithm to find the optimal odd quintic approximant
|
| 16 |
+
to the constant function x -> 1 over the interval [l, u].
|
| 17 |
+
|
| 18 |
+
Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
|
| 19 |
+
approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
|
| 20 |
+
two interior equioscillation nodes q, r until convergence. Returns the
|
| 21 |
+
closed-form equioscillating solution when l ≈ u.
|
| 22 |
+
|
| 23 |
+
Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
|
| 24 |
+
(NaN or inf). Raises RuntimeError if convergence is not reached within
|
| 25 |
+
max_iter iterations.
|
| 26 |
+
"""
|
| 27 |
+
assert 0 <= l <= u
|
| 28 |
+
if 1 - 5e-6 <= l / u:
|
| 29 |
+
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
|
| 30 |
+
q = (3 * l + u) / 4
|
| 31 |
+
r = (l + 3 * u) / 4
|
| 32 |
+
E = inf
|
| 33 |
+
for _ in range(max_iter):
|
| 34 |
+
old_E = E
|
| 35 |
+
LHS = np.array([
|
| 36 |
+
[l, l**3, l**5, 1],
|
| 37 |
+
[q, q**3, q**5, -1],
|
| 38 |
+
[r, r**3, r**5, 1],
|
| 39 |
+
[u, u**3, u**5, -1],
|
| 40 |
+
])
|
| 41 |
+
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
+
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
+
raise ValueError(f"_optimal_quintic: non-finite solve result "
|
| 44 |
+
f"a={a}, b={b}, c={c}, E={E}")
|
| 45 |
+
q, r = np.sqrt(
|
| 46 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
+
(10 * c))
|
| 48 |
+
if not np.all(np.isfinite([q, r])):
|
| 49 |
+
raise ValueError(
|
| 50 |
+
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
+
if abs(old_E - E) <= 1e-15:
|
| 52 |
+
break
|
| 53 |
+
else:
|
| 54 |
+
raise RuntimeError(
|
| 55 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations")
|
| 56 |
+
return float(a), float(b), float(c)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
| 60 |
+
"""
|
| 61 |
+
Compute the Polar Express coefficient series for `num_iters` quintic iterations.
|
| 62 |
+
|
| 63 |
+
Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
|
| 64 |
+
compose to map singular values from [l, 1] toward 1. At each step:
|
| 65 |
+
1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
|
| 66 |
+
prevents near-zero singular values from stalling by raising the effective
|
| 67 |
+
lower bound; if it is active (cushion*u > l), the coefficients are
|
| 68 |
+
rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
|
| 69 |
+
2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
|
| 70 |
+
last iteration, providing numerical headroom at the cost of a slightly slower
|
| 71 |
+
final convergence step.
|
| 72 |
+
3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).
|
| 73 |
+
|
| 74 |
+
Returns a list of (a, b, c) tuples, one per iteration.
|
| 75 |
+
|
| 76 |
+
Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
|
| 77 |
+
Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
|
| 78 |
+
"""
|
| 79 |
+
u = 1
|
| 80 |
+
assert 0 <= l <= u
|
| 81 |
+
safety_factor = 1 + safety_factor_eps
|
| 82 |
+
coefficients = []
|
| 83 |
+
for iter in range(num_iters):
|
| 84 |
+
a, b, c = _optimal_quintic(max(l, cushion * u), u)
|
| 85 |
+
if cushion * u > l:
|
| 86 |
+
pl = a * l + b * l**3 + c * l**5
|
| 87 |
+
pu = a * u + b * u**3 + c * u**5
|
| 88 |
+
rescaler = 2 / (pl + pu)
|
| 89 |
+
a *= rescaler
|
| 90 |
+
b *= rescaler
|
| 91 |
+
c *= rescaler
|
| 92 |
+
if iter < num_iters - 1:
|
| 93 |
+
a /= safety_factor
|
| 94 |
+
b /= safety_factor**3
|
| 95 |
+
c /= safety_factor**5
|
| 96 |
+
coefficients.append((a, b, c))
|
| 97 |
+
l = a * l + b * l**3 + c * l**5
|
| 98 |
+
u = 2 - l
|
| 99 |
+
return coefficients
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
|
| 103 |
+
# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
|
| 104 |
+
# approximant to x->1 over the current singular-value interval, computed once at
|
| 105 |
+
# import time and reused across all optimizer steps.
|
| 106 |
+
#
|
| 107 |
+
# Contrast with the former hardcoded NS coefficients (5 fixed tuples):
|
| 108 |
+
# - Former: empirically tuned to maximize slope at zero; did not converge
|
| 109 |
+
# singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
|
| 110 |
+
# of the true polar factor UV^T.
|
| 111 |
+
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
+
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
+
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
+
_coeffs_list = _optimal_composition(l=1e-3,
|
| 115 |
+
num_iters=10,
|
| 116 |
+
safety_factor_eps=1e-2,
|
| 117 |
+
cushion=0.02)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# This code is adapted from:
|
| 121 |
+
# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
|
| 122 |
+
# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
|
| 123 |
+
# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
|
| 124 |
@torch.no_grad()
|
|
|
|
| 125 |
def _zeropower_via_newtonschulz5(G, steps):
|
| 126 |
"""
|
| 127 |
+
Compute the polar factor of G via the Polar Express method.
|
| 128 |
+
|
| 129 |
+
Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
|
| 130 |
+
are the Polar Express coefficients from `_coeffs_list`. Each step is the
|
| 131 |
+
optimal odd quintic approximant to x -> 1 over the current singular-value
|
| 132 |
+
interval, minimizing the maximum approximation error (Remez / minimax criterion).
|
| 133 |
+
The composition maps singular values from [l, 1] to near 1, producing the
|
| 134 |
+
polar factor (orthogonal factor in the polar decomposition G = UP).
|
| 135 |
+
|
| 136 |
+
`_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
|
| 137 |
+
cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.
|
| 138 |
+
|
| 139 |
+
Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
|
| 140 |
+
Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
|
| 141 |
"""
|
| 142 |
assert len(G.shape) == 2
|
| 143 |
assert G.dtype == COMM_DTYPE
|
|
|
|
| 145 |
|
| 146 |
if G.size(0) > G.size(1):
|
| 147 |
X = X.T
|
| 148 |
+
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
+
hs = _coeffs_list[:steps] + list(
|
| 151 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
| 155 |
+
for a, b, c in hs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
matmul_transpose_assign(X, buf1)
|
| 157 |
matmul_transpose_assign(buf1, buf2)
|
| 158 |
buf1.mul_(b).add_(buf2, alpha=c)
|
|
|
|
| 160 |
|
| 161 |
if G.size(0) > G.size(1):
|
| 162 |
X = X.T
|
| 163 |
+
|
| 164 |
return X
|