Kernels
dongseokmotif committed on
Commit
206b280
·
1 Parent(s): aff01db

Add max_iter cap and non-finite checks to _optimal_quintic [skip-build]

Browse files

Adds a max_iter parameter (default 1000) to replace the open-ended while
loop. Raises ValueError if the linear solve or node update produces
non-finite values, and RuntimeError if convergence is not reached within
max_iter iterations.

torch-ext/optimizer/newton_schulz.py CHANGED
@@ -10,7 +10,7 @@ COMM_DTYPE = torch.bfloat16
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
- def _optimal_quintic(l, u):
14
  """
15
  Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
  to the constant function x -> 1 over the interval [l, u].
@@ -19,14 +19,18 @@ def _optimal_quintic(l, u):
19
  approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
  two interior equioscillation nodes q, r until convergence. Returns the
21
  closed-form equioscillating solution when l ≈ u.
 
 
 
 
22
  """
23
  assert 0 <= l <= u
24
  if 1 - 5e-6 <= l / u:
25
  return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
26
  q = (3 * l + u) / 4
27
  r = (l + 3 * u) / 4
28
- E, old_E = inf, None
29
- while not old_E or abs(old_E - E) > 1e-15:
30
  old_E = E
31
  LHS = np.array([
32
  [l, l**3, l**5, 1],
@@ -35,9 +39,21 @@ def _optimal_quintic(l, u):
35
  [u, u**3, u**5, -1],
36
  ])
37
  a, b, c, E = np.linalg.solve(LHS, np.ones(4))
 
 
 
 
38
  q, r = np.sqrt(
39
  (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
40
  (10 * c))
 
 
 
 
 
 
 
 
41
  return float(a), float(b), float(c)
42
 
43
 
 
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
 
13
+ def _optimal_quintic(l, u, max_iter=1000):
14
  """
15
  Use the simplified Remez algorithm to find the optimal odd quintic approximant
16
  to the constant function x -> 1 over the interval [l, u].
 
19
  approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
20
  two interior equioscillation nodes q, r until convergence. Returns the
21
  closed-form equioscillating solution when l ≈ u.
22
+
23
+ Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
24
+ (NaN or inf). Raises RuntimeError if convergence is not reached within
25
+ max_iter iterations.
26
  """
27
  assert 0 <= l <= u
28
  if 1 - 5e-6 <= l / u:
29
  return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
30
  q = (3 * l + u) / 4
31
  r = (l + 3 * u) / 4
32
+ E = inf
33
+ for _ in range(max_iter):
34
  old_E = E
35
  LHS = np.array([
36
  [l, l**3, l**5, 1],
 
39
  [u, u**3, u**5, -1],
40
  ])
41
  a, b, c, E = np.linalg.solve(LHS, np.ones(4))
42
+ if not np.all(np.isfinite([a, b, c, E])):
43
+ raise ValueError(
44
+ f"_optimal_quintic: non-finite solve result "
45
+ f"a={a}, b={b}, c={c}, E={E}")
46
  q, r = np.sqrt(
47
  (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
48
  (10 * c))
49
+ if not np.all(np.isfinite([q, r])):
50
+ raise ValueError(
51
+ f"_optimal_quintic: non-finite node update q={q}, r={r}")
52
+ if abs(old_E - E) <= 1e-15:
53
+ break
54
+ else:
55
+ raise RuntimeError(
56
+ f"_optimal_quintic: did not converge after {max_iter} iterations")
57
  return float(a), float(b), float(c)
58
 
59