Commit ·
aff01db
1
Parent(s): abaa449
Apply pre-commit formatting (yapf, isort) [skip-build]
Browse files
torch-ext/optimizer/newton_schulz.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
-
import torch
|
| 2 |
from itertools import repeat
|
| 3 |
from math import inf, sqrt
|
|
|
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
|
| 6 |
from .matmul_transpose_triton import matmul_transpose_assign
|
| 7 |
|
| 8 |
COMM_DTYPE = torch.bfloat16
|
| 9 |
DEFAULT_CHUNK_SIZE_RATIO = 4
|
| 10 |
|
|
|
|
| 11 |
def _optimal_quintic(l, u):
|
| 12 |
"""
|
| 13 |
Use the simplified Remez algorithm to find the optimal odd quintic approximant
|
|
@@ -20,9 +22,9 @@ def _optimal_quintic(l, u):
|
|
| 20 |
"""
|
| 21 |
assert 0 <= l <= u
|
| 22 |
if 1 - 5e-6 <= l / u:
|
| 23 |
-
return (15/8)/u, (-10/8)/(u**3), (3/8)/(u**5)
|
| 24 |
-
q = (3*l + u) / 4
|
| 25 |
-
r = (l + 3*u) / 4
|
| 26 |
E, old_E = inf, None
|
| 27 |
while not old_E or abs(old_E - E) > 1e-15:
|
| 28 |
old_E = E
|
|
@@ -33,8 +35,9 @@ def _optimal_quintic(l, u):
|
|
| 33 |
[u, u**3, u**5, -1],
|
| 34 |
])
|
| 35 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 36 |
-
q, r = np.sqrt(
|
| 37 |
-
|
|
|
|
| 38 |
return float(a), float(b), float(c)
|
| 39 |
|
| 40 |
|
|
@@ -63,16 +66,20 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 63 |
safety_factor = 1 + safety_factor_eps
|
| 64 |
coefficients = []
|
| 65 |
for iter in range(num_iters):
|
| 66 |
-
a, b, c = _optimal_quintic(max(l, cushion*u), u)
|
| 67 |
-
if cushion*u > l:
|
| 68 |
-
pl = a*l + b*l**3 + c*l**5
|
| 69 |
-
pu = a*u + b*u**3 + c*u**5
|
| 70 |
-
rescaler = 2/(pl + pu)
|
| 71 |
-
a *= rescaler
|
|
|
|
|
|
|
| 72 |
if iter < num_iters - 1:
|
| 73 |
-
a /= safety_factor
|
|
|
|
|
|
|
| 74 |
coefficients.append((a, b, c))
|
| 75 |
-
l = a*l + b*l**3 + c*l**5
|
| 76 |
u = 2 - l
|
| 77 |
return coefficients
|
| 78 |
|
|
@@ -89,7 +96,11 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 89 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 90 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 91 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 92 |
-
_coeffs_list = _optimal_composition(l=1e-3,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# This code is adapted from:
|
| 95 |
# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
|
|
|
|
|
|
|
| 1 |
from itertools import repeat
|
| 2 |
from math import inf, sqrt
|
| 3 |
+
|
| 4 |
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
|
| 7 |
from .matmul_transpose_triton import matmul_transpose_assign
|
| 8 |
|
| 9 |
COMM_DTYPE = torch.bfloat16
|
| 10 |
DEFAULT_CHUNK_SIZE_RATIO = 4
|
| 11 |
|
| 12 |
+
|
| 13 |
def _optimal_quintic(l, u):
|
| 14 |
"""
|
| 15 |
Use the simplified Remez algorithm to find the optimal odd quintic approximant
|
|
|
|
| 22 |
"""
|
| 23 |
assert 0 <= l <= u
|
| 24 |
if 1 - 5e-6 <= l / u:
|
| 25 |
+
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
|
| 26 |
+
q = (3 * l + u) / 4
|
| 27 |
+
r = (l + 3 * u) / 4
|
| 28 |
E, old_E = inf, None
|
| 29 |
while not old_E or abs(old_E - E) > 1e-15:
|
| 30 |
old_E = E
|
|
|
|
| 35 |
[u, u**3, u**5, -1],
|
| 36 |
])
|
| 37 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 38 |
+
q, r = np.sqrt(
|
| 39 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 40 |
+
(10 * c))
|
| 41 |
return float(a), float(b), float(c)
|
| 42 |
|
| 43 |
|
|
|
|
| 66 |
safety_factor = 1 + safety_factor_eps
|
| 67 |
coefficients = []
|
| 68 |
for iter in range(num_iters):
|
| 69 |
+
a, b, c = _optimal_quintic(max(l, cushion * u), u)
|
| 70 |
+
if cushion * u > l:
|
| 71 |
+
pl = a * l + b * l**3 + c * l**5
|
| 72 |
+
pu = a * u + b * u**3 + c * u**5
|
| 73 |
+
rescaler = 2 / (pl + pu)
|
| 74 |
+
a *= rescaler
|
| 75 |
+
b *= rescaler
|
| 76 |
+
c *= rescaler
|
| 77 |
if iter < num_iters - 1:
|
| 78 |
+
a /= safety_factor
|
| 79 |
+
b /= safety_factor**3
|
| 80 |
+
c /= safety_factor**5
|
| 81 |
coefficients.append((a, b, c))
|
| 82 |
+
l = a * l + b * l**3 + c * l**5
|
| 83 |
u = 2 - l
|
| 84 |
return coefficients
|
| 85 |
|
|
|
|
| 96 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 97 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 98 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 99 |
+
_coeffs_list = _optimal_composition(l=1e-3,
|
| 100 |
+
num_iters=10,
|
| 101 |
+
safety_factor_eps=1e-2,
|
| 102 |
+
cushion=0.02)
|
| 103 |
+
|
| 104 |
|
| 105 |
# This code is adapted from:
|
| 106 |
# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
|