"""
Diagnostic: what exactly breaks in parallel root-finding?

Finds all N roots of the characteristic polynomials of a batch of random
symmetric matrices with several simultaneous (parallel) iterations:
  Test 1: Pure parallel Laguerre (no Aberth, no clamp, no damp)
  Test 2: Parallel Laguerre + Aberth correction (full strength)
  Test 3: Parallel Laguerre + weak Aberth correction (0.1x)
  Test 4: Pure parallel Laguerre, re-sorting the iterates every step
  Test 5: Parallel Laguerre + Aberth, strength ramped 0.1 -> 1.0, re-sorted
  Test 6: Newton + Aberth (classic Aberth-Ehrlich)
  Test 7: Pure parallel Newton
(No sequential Laguerre + deflation baseline is run.)
Prints per-iteration convergence to identify exactly where it goes wrong.
"""
import math, torch
torch.backends.cuda.matmul.allow_tf32 = False
torch.set_float32_matmul_precision('highest')
# Prefer the GPU, but fall back to the CPU so the diagnostic still runs
# on machines without CUDA.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
B = 512  # number of independent problems in the batch
N = 6    # matrix size == polynomial degree == number of roots per problem
torch.manual_seed(42)

# Batch of random symmetric matrices: symmetric => all N eigenvalues are
# real, so every characteristic polynomial below has N real roots.
A = torch.randn(B, N, N, device=dev)
A = (A + A.mT) / 2

# Faddeev-LeVerrier, phase 1: coefficients of p(z) = det(zI - As) in float64.
# Each matrix is first divided by ||A||_F / sqrt(N) so its eigenvalues (and
# therefore the polynomial coefficients) are O(1) for every batch member.
sc = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(N)).clamp(min=1e-12)
As = A / sc[:, None, None]
Ad = As.double()
I_d = torch.eye(N, device=dev, dtype=torch.float64).unsqueeze(0).expand(B, -1, -1)
# c[:, k] is the coefficient of z**k; the polynomial is monic (c[:, N] == 1).
c = torch.zeros(B, N + 1, device=dev, dtype=torch.float64)
c[:, N] = 1.0
Mk = torch.zeros(B, N, N, device=dev, dtype=torch.float64)
for k in range(1, N + 1):
    # M_k = A M_{k-1} + c_{N-k+1} I ;  c_{N-k} = -tr(A M_k) / k
    Mk = torch.bmm(Ad, Mk) + c[:, N - k + 1, None, None] * I_d
    # sum(A * M) equals tr(A M) here because Ad is exactly symmetric.
    c[:, N - k] = -(Ad * Mk).sum((-2, -1)) / k

# Reference roots: eigenvalues of the float64 matrix Ad, i.e. of exactly the
# matrix whose characteristic polynomial is c.  (float32 eigenvalues of A,
# rescaled, would cap the measurable error near float32 precision.)
true_roots = torch.linalg.eigvalsh(Ad)  # returned in ascending order

# Initial guesses: sorted diagonal of the scaled matrix plus a tiny strictly
# increasing ramp, so no two starting points coincide (coincident starts make
# the Aberth 1/(z_i - z_j) terms blow up).
z_init = Ad.diagonal(dim1=-2, dim2=-1).sort(dim=-1).values
pert = torch.linspace(-1e-3, 1e-3, N, device=dev, dtype=torch.float64).unsqueeze(0)
z_init = z_init + pert
def horner_pd(c, z):
    """Evaluate a batch of polynomials and their first two derivatives.

    Args:
        c: coefficients, shape [B, n+1]; c[:, k] multiplies z**k, so c[:, n]
           is the leading coefficient.
        z: evaluation points, shape [B, m]; row b is evaluated with c[b].

    Returns:
        (p(z), p'(z), p''(z) / 2), each shaped like z.
    """
    batch, m = z.shape
    deg = c.shape[1] - 1
    # Three coupled Horner recurrences, highest coefficient first.  The
    # second accumulator collects p''/2 rather than p'' because each stage
    # adds the previous first-derivative value once instead of twice.
    val = c[:, deg:deg + 1].expand(batch, m)
    d1 = torch.zeros_like(val)
    half_d2 = torch.zeros_like(val)
    for j in reversed(range(deg)):
        half_d2 = half_d2 * z + d1
        d1 = d1 * z + val
        val = val * z + c[:, j:j + 1]
    return val, d1, half_d2
def laguerre_step(c, z, n):
    """Compute one Laguerre correction for every current root estimate.

    With G = p'/p and H = G**2 - p''/p, the correction is
        n / (G +/- sqrt((n - 1) * (n*H - G**2)))
    using whichever sign gives the larger-magnitude denominator.  A negative
    radicand (round-off) is clamped to zero.  The caller subtracts it:
    z_new = z - step.

    Args:
        c: coefficients in the layout horner_pd takes, shape [B, n+1].
        z: current estimates, shape [B, m].
        n: polynomial degree used in the Laguerre formula.

    Returns:
        The step, shaped like z.  It is 0 where |p(z)| <= 1e-30 (already at a
        root) or where the chosen denominator has magnitude <= 1e-20.
    """
    p, dp, half_d2 = horner_pd(c, z)
    nonzero = p.abs() > 1e-30
    # Divide by 1 where p(z) ~ 0 and then zero G and the p''/p term, so the
    # step comes out as exactly 0 at a root instead of NaN.
    p_safe = torch.where(nonzero, p, torch.ones_like(p))
    G = torch.where(nonzero, dp / p_safe, torch.zeros_like(dp))
    pp_over_p = torch.where(nonzero, 2.0 * half_d2 / p_safe, torch.zeros_like(half_d2))
    H = G * G - pp_over_p
    root = torch.sqrt(((n - 1.0) * (n * H - G * G)).clamp(min=0.0))
    plus, minus = G + root, G - root
    den = torch.where(plus.abs() >= minus.abs(), plus, minus)
    usable = den.abs() > 1e-20
    return torch.where(usable, float(n) / den, torch.zeros_like(den))
# Boolean identity of shape (1, N, N) on `dev`; marks the i == j entries.
mask_eye = torch.eye(N, dtype=torch.bool, device=dev)[None]
def aberth_correction(z):
    """Return the Aberth interaction sum for every root estimate.

    For z of shape [..., m], element k of the result is
        sum_{j != k} 1 / (z[..., k] - z[..., j]).
    The diagonal mask is built from z itself rather than from a precomputed
    N x N mask, so any root count, batch shape and device work.
    Coincident estimates (z_k == z_j for j != k) produce non-finite values;
    they are deliberately left visible rather than masked.
    """
    eye = torch.eye(z.shape[-1], dtype=torch.bool, device=z.device)
    diffs = z.unsqueeze(-1) - z.unsqueeze(-2)  # diffs[..., k, j] = z_k - z_j
    # Put 1 on the diagonal before dividing (avoids 1/0), then zero it out.
    diffs_safe = diffs.masked_fill(eye, 1.0)
    return (1.0 / diffs_safe).masked_fill(eye, 0.0).sum(-1)
def report(label, z, iteration):
    """Print one convergence line for the iterates ``z`` (shape [B, n]).

    Printed columns:
      max_err  largest |sorted(z) - true_roots| over the whole batch;
      min_gap  smallest gap between neighbouring sorted iterates in any batch
               member; a value far below the true root spacing suggests two
               iterates have collapsed onto the same root;
      |p(z)|   largest polynomial residual |p(z)| over all iterates.
    Reads the module-level ``true_roots`` and ``c``.
    """
    zs = z.sort(dim=-1).values  # sorted once, shared by max_err and min_gap
    err = (zs - true_roots).abs().max().item()
    gaps = zs[:, 1:] - zs[:, :-1]
    # With fewer than two roots there is no gap, and min() of an empty tensor
    # raises; report +inf instead.
    min_gap = gaps.min().item() if gaps.numel() > 0 else float("inf")
    pv, _, _ = horner_pd(c, z)
    p_res = pv.abs().max().item()
    print(f" {label:>5} it={iteration:>2} max_err={err:.2e} min_gap={min_gap:.2e} |p(z)|={p_res:.2e}")
print("="*78)
print(" Diagnostic: Parallel Root-Finding")
print("="*78)
print(f" B={B} N={N}")
print(f" True eigenvalue range: [{true_roots.min().item():.3f}, {true_roots.max().item():.3f}]")
print(f" Diagonal init range: [{z_init.min().item():.3f}, {z_init.max().item():.3f}]")
# βββ Test 1: Pure parallel Laguerre (no Aberth) βββ
print(f"\n --- Test 1: Pure Laguerre (no Aberth) ---")
z = z_init.clone()
for it in range(20):
step = laguerre_step(c, z, N)
z = z - step
if it < 5 or it % 5 == 4:
report("PurL", z, it)
# βββ Test 2: Laguerre + Aberth (full strength) βββ
print(f"\n --- Test 2: Laguerre + Aberth (full) ---")
z = z_init.clone()
for it in range(20):
step = laguerre_step(c, z, N)
corr = aberth_correction(z)
denom = 1.0 - step * corr
denom_safe = torch.where(denom.abs() > 1e-20, denom, torch.ones_like(denom))
full_step = torch.where(denom.abs() > 1e-20, step / denom_safe, step)
z = z - full_step
if it < 5 or it % 5 == 4:
report("LA-F", z, it)
# βββ Test 3: Laguerre + weak Aberth (0.1Γ correction) βββ
print(f"\n --- Test 3: Laguerre + weak Aberth (0.1x) ---")
z = z_init.clone()
for it in range(20):
step = laguerre_step(c, z, N)
corr = aberth_correction(z)
denom = 1.0 - 0.1 * step * corr
denom_safe = torch.where(denom.abs() > 1e-20, denom, torch.ones_like(denom))
full_step = torch.where(denom.abs() > 1e-20, step / denom_safe, step)
z = z - full_step
if it < 5 or it % 5 == 4:
report("LA.1", z, it)
# βββ Test 4: Pure Laguerre + post-sort each iteration βββ
print(f"\n --- Test 4: Pure Laguerre + re-sort ---")
z = z_init.clone()
for it in range(20):
step = laguerre_step(c, z, N)
z = z - step
z = z.sort(dim=-1).values # keep sorted
if it < 5 or it % 5 == 4:
report("PL+S", z, it)
# βββ Test 5: Laguerre + Aberth + damped ramp βββ
print(f"\n --- Test 5: Laguerre + Aberth damped (0.1 β 1.0) ---")
z = z_init.clone()
for it in range(20):
step = laguerre_step(c, z, N)
corr = aberth_correction(z)
alpha = min(1.0, 0.1 + 0.1 * it)
denom = 1.0 - alpha * step * corr
denom_safe = torch.where(denom.abs() > 1e-20, denom, torch.ones_like(denom))
full_step = torch.where(denom.abs() > 1e-20, step / denom_safe, step)
z = z - full_step
z = z.sort(dim=-1).values
if it < 5 or it % 5 == 4:
report("LADa", z, it)
# βββ Test 6: Newton + Aberth (original Aberth-Ehrlich) βββ
print(f"\n --- Test 6: Newton + Aberth ---")
z = z_init.clone()
for it in range(20):
pv, dp, _ = horner_pd(c, z)
ok = dp.abs() > 1e-30
w = torch.where(ok, pv / dp, torch.zeros_like(pv))
corr = aberth_correction(z)
denom = 1.0 - w * corr
denom_safe = torch.where(denom.abs() > 1e-20, denom, torch.ones_like(denom))
full_step = torch.where(denom.abs() > 1e-20, w / denom_safe, w)
z = z - full_step
if it < 5 or it % 5 == 4:
report("NwAb", z, it)
# βββ Test 7: Pure Newton (no Aberth) βββ
print(f"\n --- Test 7: Pure Newton ---")
z = z_init.clone()
for it in range(20):
pv, dp, _ = horner_pd(c, z)
ok = dp.abs() > 1e-30
w = torch.where(ok, pv / dp, torch.zeros_like(pv))
z = z - w
if it < 5 or it % 5 == 4:
report("PurN", z, it)
print("="*78) |