Kernels
dongseokmotif Claude Sonnet 4.6 committed on
Commit abaa449 · 1 Parent(s): 573242f

Add comment explaining _coeffs_list and Polar Express vs former NS [skip-build]

Documents what _coeffs_list is (precomputed Polar Express coefficients, minimax-optimal
via Remez/equioscillation), and contrasts with the former hardcoded NS coefficients:
former produced US'V^T with scattered singular values; Polar Express converges to
the exact polar factor UV^T. Also removes unused loguru import.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
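
For readers unfamiliar with why a scalar polynomial determines the matrix result, the relationship the message relies on can be written out as below. This is a sketch of the idealized statement only: the symbols `l` and `u` are the interval endpoints named in the added comment, and the adjustments made by `safety_factor_eps` and `cushion` are not modelled here.

```latex
X = U S V^{\top}
\;\Longrightarrow\;
aX + b\,(XX^{\top})X + c\,(XX^{\top})^{2}X
  = U\,\bigl(aS + bS^{3} + cS^{5}\bigr)\,V^{\top},
\qquad
(a, b, c) \in \arg\min_{a,\,b,\,c}\ \max_{x \in [l,\,u]}\
  \bigl|\,1 - (ax + bx^{3} + cx^{5})\,\bigr|.
```

Each iteration therefore leaves U and V untouched and only moves the singular values, which is why driving every singular value to 1 yields UV^T.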

torch-ext/optimizer/newton_schulz.py CHANGED
@@ -4,7 +4,6 @@ from math import inf, sqrt
 import numpy as np
 
 from .matmul_transpose_triton import matmul_transpose_assign
-from loguru import logger
 
 COMM_DTYPE = torch.bfloat16
 DEFAULT_CHUNK_SIZE_RATIO = 4
@@ -78,6 +77,18 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
     return coefficients
 
 
+# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
+# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
+# approximant to x->1 over the current singular-value interval, computed once at
+# import time and reused across all optimizer steps.
+#
+# Contrast with the former hardcoded NS coefficients (5 fixed tuples):
+# - Former: empirically tuned to maximize slope at zero; did not converge
+#   singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
+#   of the true polar factor UV^T.
+# - Polar Express: analytically optimal per step, adapting to the shrinking
+#   singular-value interval [l, u] as iterations progress; converges all
+#   singular values to 1, producing the exact polar factor UV^T.
 _coeffs_list = _optimal_composition(l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02)
 
 # This code is adapted from:
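
The contrast drawn in the new comment can be checked numerically with a small NumPy sketch. Several things here are assumptions rather than code from this repository: `quintic_iteration` is a hypothetical helper, the "fixed" tuple is the widely circulated Muon-style coefficient set (not necessarily the tuples this file used to hardcode), and the convergent schedule uses the classical quintic Newton-Schulz coefficients (15/8, -10/8, 3/8) as a stand-in, because reproducing `_optimal_composition` is out of scope. The input is built with singular values already in [0.1, 1], so no normalization step is shown.

```python
import numpy as np


def quintic_iteration(X, coeffs):
    """Apply X <- a*X + b*(X X^T) X + c*(X X^T)^2 X once per (a, b, c) tuple.

    Hypothetical helper for illustration; not the repository's implementation.
    """
    for a, b, c in coeffs:
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X


rng = np.random.default_rng(0)

# Build a 6x8 matrix with known singular values so the polar factor is known.
U, _ = np.linalg.qr(rng.standard_normal((6, 6)))   # 6x6 orthogonal
V, _ = np.linalg.qr(rng.standard_normal((8, 6)))   # 8x6, orthonormal columns
S = np.linspace(0.1, 1.0, 6)
X0 = U @ np.diag(S) @ V.T
polar = U @ V.T                                    # target polar factor UV^T

# Assumed stand-ins (see lead-in): one fixed tuple repeated 5 times versus a
# classical convergent quintic repeated 10 times.
fixed = [(3.4445, -4.7750, 2.0315)] * 5
classical = [(15 / 8, -10 / 8, 3 / 8)] * 10

s_fixed = np.linalg.svd(quintic_iteration(X0, fixed), compute_uv=False)
X_classical = quintic_iteration(X0, classical)
s_classical = np.linalg.svd(X_classical, compute_uv=False)

print("fixed tuple, singular values:    ", np.round(s_fixed, 3))
print("classical quintic, singular values:", np.round(s_classical, 6))
print("max |X - UV^T| (classical):", np.abs(X_classical - polar).max())
```

With the fixed tuple the singular values end up in a band around 1 rather than at 1, so the result is of the form US'V^T. The convergent schedule drives them to 1 up to floating-point tolerance and recovers UV^T. Per-step optimized coefficients such as those in `_coeffs_list` aim for the same limit, with the coefficients changing from step to step.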