Kernels
dongseokmotif committed on
Commit
aff01db
·
1 Parent(s): abaa449

Apply pre-commit formatting (yapf, isort) [skip-build]

Browse files
Files changed (1) hide show
  1. torch-ext/optimizer/newton_schulz.py +26 -15
torch-ext/optimizer/newton_schulz.py CHANGED
@@ -1,13 +1,15 @@
1
- import torch
2
  from itertools import repeat
3
  from math import inf, sqrt
 
4
  import numpy as np
 
5
 
6
  from .matmul_transpose_triton import matmul_transpose_assign
7
 
8
  COMM_DTYPE = torch.bfloat16
9
  DEFAULT_CHUNK_SIZE_RATIO = 4
10
 
 
11
  def _optimal_quintic(l, u):
12
  """
13
  Use the simplified Remez algorithm to find the optimal odd quintic approximant
@@ -20,9 +22,9 @@ def _optimal_quintic(l, u):
20
  """
21
  assert 0 <= l <= u
22
  if 1 - 5e-6 <= l / u:
23
- return (15/8)/u, (-10/8)/(u**3), (3/8)/(u**5)
24
- q = (3*l + u) / 4
25
- r = (l + 3*u) / 4
26
  E, old_E = inf, None
27
  while not old_E or abs(old_E - E) > 1e-15:
28
  old_E = E
@@ -33,8 +35,9 @@ def _optimal_quintic(l, u):
33
  [u, u**3, u**5, -1],
34
  ])
35
  a, b, c, E = np.linalg.solve(LHS, np.ones(4))
36
- q, r = np.sqrt((-3*b + np.array([-1, 1]) *
37
- sqrt(9*b**2 - 20*a*c)) / (10*c))
 
38
  return float(a), float(b), float(c)
39
 
40
 
@@ -63,16 +66,20 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
63
  safety_factor = 1 + safety_factor_eps
64
  coefficients = []
65
  for iter in range(num_iters):
66
- a, b, c = _optimal_quintic(max(l, cushion*u), u)
67
- if cushion*u > l:
68
- pl = a*l + b*l**3 + c*l**5
69
- pu = a*u + b*u**3 + c*u**5
70
- rescaler = 2/(pl + pu)
71
- a *= rescaler; b *= rescaler; c *= rescaler
 
 
72
  if iter < num_iters - 1:
73
- a /= safety_factor; b /= safety_factor**3; c /= safety_factor**5
 
 
74
  coefficients.append((a, b, c))
75
- l = a*l + b*l**3 + c*l**5
76
  u = 2 - l
77
  return coefficients
78
 
@@ -89,7 +96,11 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
89
  # - Polar Express: analytically optimal per step, adapting to the shrinking
90
  # singular-value interval [l, u] as iterations progress; converges all
91
  # singular values to 1, producing the exact polar factor UV^T.
92
- _coeffs_list = _optimal_composition(l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02)
 
 
 
 
93
 
94
  # This code is adapted from:
95
  # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
 
 
1
  from itertools import repeat
2
  from math import inf, sqrt
3
+
4
  import numpy as np
5
+ import torch
6
 
7
  from .matmul_transpose_triton import matmul_transpose_assign
8
 
9
  COMM_DTYPE = torch.bfloat16
10
  DEFAULT_CHUNK_SIZE_RATIO = 4
11
 
12
+
13
  def _optimal_quintic(l, u):
14
  """
15
  Use the simplified Remez algorithm to find the optimal odd quintic approximant
 
22
  """
23
  assert 0 <= l <= u
24
  if 1 - 5e-6 <= l / u:
25
+ return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
26
+ q = (3 * l + u) / 4
27
+ r = (l + 3 * u) / 4
28
  E, old_E = inf, None
29
  while not old_E or abs(old_E - E) > 1e-15:
30
  old_E = E
 
35
  [u, u**3, u**5, -1],
36
  ])
37
  a, b, c, E = np.linalg.solve(LHS, np.ones(4))
38
+ q, r = np.sqrt(
39
+ (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
40
+ (10 * c))
41
  return float(a), float(b), float(c)
42
 
43
 
 
66
  safety_factor = 1 + safety_factor_eps
67
  coefficients = []
68
  for iter in range(num_iters):
69
+ a, b, c = _optimal_quintic(max(l, cushion * u), u)
70
+ if cushion * u > l:
71
+ pl = a * l + b * l**3 + c * l**5
72
+ pu = a * u + b * u**3 + c * u**5
73
+ rescaler = 2 / (pl + pu)
74
+ a *= rescaler
75
+ b *= rescaler
76
+ c *= rescaler
77
  if iter < num_iters - 1:
78
+ a /= safety_factor
79
+ b /= safety_factor**3
80
+ c /= safety_factor**5
81
  coefficients.append((a, b, c))
82
+ l = a * l + b * l**3 + c * l**5
83
  u = 2 - l
84
  return coefficients
85
 
 
96
  # - Polar Express: analytically optimal per step, adapting to the shrinking
97
  # singular-value interval [l, u] as iterations progress; converges all
98
  # singular values to 1, producing the exact polar factor UV^T.
99
+ _coeffs_list = _optimal_composition(l=1e-3,
100
+ num_iters=10,
101
+ safety_factor_eps=1e-2,
102
+ cushion=0.02)
103
+
104
 
105
  # This code is adapted from:
106
  # KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)