Commit ·
206b280
1
Parent(s): aff01db
Add max_iter cap and non-finite checks to _optimal_quintic [skip-build]
Browse filesAdds a max_iter parameter (default 1000) to replace the open-ended while
loop. Raises ValueError if the linear solve or node update produces
non-finite values, and RuntimeError if convergence is not reached within
max_iter iterations.
torch-ext/optimizer/newton_schulz.py
CHANGED
|
@@ -10,7 +10,7 @@ COMM_DTYPE = torch.bfloat16
|
|
| 10 |
DEFAULT_CHUNK_SIZE_RATIO = 4
|
| 11 |
|
| 12 |
|
| 13 |
-
def _optimal_quintic(l, u):
|
| 14 |
"""
|
| 15 |
Use the simplified Remez algorithm to find the optimal odd quintic approximant
|
| 16 |
to the constant function x -> 1 over the interval [l, u].
|
|
@@ -19,14 +19,18 @@ def _optimal_quintic(l, u):
|
|
| 19 |
approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
|
| 20 |
two interior equioscillation nodes q, r until convergence. Returns the
|
| 21 |
closed-form equioscillating solution when l ≈ u.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
"""
|
| 23 |
assert 0 <= l <= u
|
| 24 |
if 1 - 5e-6 <= l / u:
|
| 25 |
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
|
| 26 |
q = (3 * l + u) / 4
|
| 27 |
r = (l + 3 * u) / 4
|
| 28 |
-
E
|
| 29 |
-
|
| 30 |
old_E = E
|
| 31 |
LHS = np.array([
|
| 32 |
[l, l**3, l**5, 1],
|
|
@@ -35,9 +39,21 @@ def _optimal_quintic(l, u):
|
|
| 35 |
[u, u**3, u**5, -1],
|
| 36 |
])
|
| 37 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
q, r = np.sqrt(
|
| 39 |
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 40 |
(10 * c))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
return float(a), float(b), float(c)
|
| 42 |
|
| 43 |
|
|
|
|
| 10 |
DEFAULT_CHUNK_SIZE_RATIO = 4
|
| 11 |
|
| 12 |
|
| 13 |
+
def _optimal_quintic(l, u, max_iter=1000):
|
| 14 |
"""
|
| 15 |
Use the simplified Remez algorithm to find the optimal odd quintic approximant
|
| 16 |
to the constant function x -> 1 over the interval [l, u].
|
|
|
|
| 19 |
approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
|
| 20 |
two interior equioscillation nodes q, r until convergence. Returns the
|
| 21 |
closed-form equioscillating solution when l ≈ u.
|
| 22 |
+
|
| 23 |
+
Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
|
| 24 |
+
(NaN or inf). Raises RuntimeError if convergence is not reached within
|
| 25 |
+
max_iter iterations.
|
| 26 |
"""
|
| 27 |
assert 0 <= l <= u
|
| 28 |
if 1 - 5e-6 <= l / u:
|
| 29 |
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
|
| 30 |
q = (3 * l + u) / 4
|
| 31 |
r = (l + 3 * u) / 4
|
| 32 |
+
E = inf
|
| 33 |
+
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
LHS = np.array([
|
| 36 |
[l, l**3, l**5, 1],
|
|
|
|
| 39 |
[u, u**3, u**5, -1],
|
| 40 |
])
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
+
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
+
raise ValueError(
|
| 44 |
+
f"_optimal_quintic: non-finite solve result "
|
| 45 |
+
f"a={a}, b={b}, c={c}, E={E}")
|
| 46 |
q, r = np.sqrt(
|
| 47 |
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 48 |
(10 * c))
|
| 49 |
+
if not np.all(np.isfinite([q, r])):
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 52 |
+
if abs(old_E - E) <= 1e-15:
|
| 53 |
+
break
|
| 54 |
+
else:
|
| 55 |
+
raise RuntimeError(
|
| 56 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations")
|
| 57 |
return float(a), float(b), float(c)
|
| 58 |
|
| 59 |
|