""" Simulations confirming the generalized Theorem 1 (see PROOF.md). Run: python experiments.py Outputs figures to ./figures and prints a numeric report. Experiments ----------- 1. Loss convergence: empirical sorted loss L_n(w) -> population ell(w), pointwise over random w, as n grows (validates Sec. 5). 2. d=1 formula: LS estimate -> Eq. (4) across sigma_E and mu_X sweeps. 3. Invariants (d>1): mean-match and variance-amplification invariants converge to their predicted targets as n grows. 4. Norm amplification: ||w_hat||^2_Sigma - ||w0||^2_Sigma -> sigma_E^2, independent of d. """ from __future__ import annotations import os import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import shuffled_ls as sl FIGDIR = os.path.join(os.path.dirname(__file__), "figures") os.makedirs(FIGDIR, exist_ok=True) C = {"emp": "#1f77b4", "thy": "#d62728", "grid": "#cccccc"} def banner(title): print("\n" + "=" * 70 + f"\n{title}\n" + "=" * 70) # --------------------------------------------------------------------------- # # Experiment 1: empirical loss -> population loss, pointwise # --------------------------------------------------------------------------- # def exp_loss_convergence(seed=1): banner("Experiment 1: L_n(w) -> ell(w) pointwise (d=4, general Sigma_X)") rng = np.random.default_rng(seed) d = 4 w0 = rng.standard_normal(d) mu_X = rng.standard_normal(d) Sigma_X = sl.random_spd(d, rng) sigma_E = 0.8 # 200 random probe weights, fixed across n n_probes = 200 W = rng.standard_normal((n_probes, d)) * 1.5 ell = np.array([sl.population_loss(w, w0, mu_X, Sigma_X, sigma_E) for w in W]) ns = [200, 1000, 5000, 50000] fig, axes = plt.subplots(1, len(ns), figsize=(4 * len(ns), 4), sharex=True, sharey=True) for ax, n in zip(axes, ns): x, y = sl.make_data(n, w0, mu_X, Sigma_X, sigma_E, rng) ys = np.sort(y) Ln = np.array([sl.ls_loss(w, x, ys) for w in W]) rmse = np.sqrt(np.mean((Ln - ell) ** 2)) lim = [0, max(ell.max(), Ln.max()) * 1.05] ax.plot(lim, lim, "--", color=C["thy"], lw=1.5, label="y = x") ax.scatter(ell, Ln, s=14, alpha=0.6, color=C["emp"]) ax.set_title(f"n = {n:,}\nRMSE = {rmse:.4f}") ax.set_xlabel(r"population loss $\ell(w)$") ax.grid(True, color=C["grid"], lw=0.5) ax.set_aspect("equal", "box") print(f" n={n:6d} RMSE(L_n, ell) = {rmse:.5f}") axes[0].set_ylabel(r"empirical loss $L_n(w)$") axes[0].legend(loc="upper left") fig.suptitle(r"Empirical sorted loss converges to the population loss $\ell(w)$ (d=4)", fontweight="bold") fig.tight_layout() fig.savefig(os.path.join(FIGDIR, "fig1_loss_convergence.png"), dpi=130) plt.close(fig) # --------------------------------------------------------------------------- # # Experiment 2: d=1 estimator matches Eq. (4) # --------------------------------------------------------------------------- # def exp_d1_formula(seed=2): banner("Experiment 2: d=1 LS estimate -> Eq. (4)") rng = np.random.default_rng(seed) n = 200000 n_trials = 12 fig, axes = plt.subplots(1, 2, figsize=(12, 4.6)) # (a) sweep sigma_E with fixed mu_X, sigma_X, w0 w0, mu_X, sigma_X = 1.5, 1.0, 1.0 sigmaEs = np.linspace(0.0, 3.0, 13) emp_mean, emp_std, pred = [], [], [] for sE in sigmaEs: ws = [] for _ in range(n_trials): x, y = sl.make_data(n, [w0], [mu_X], [[sigma_X ** 2]], sE, rng) ws.append(sl.fit_ls_1d(x, y)) emp_mean.append(np.mean(ws)); emp_std.append(np.std(ws)) pred.append(sl.theorem1_limit_1d(w0, mu_X, sigma_X, sE)) emp_mean, emp_std, pred = map(np.array, (emp_mean, emp_std, pred)) ax = axes[0] ax.errorbar(sigmaEs, emp_mean, yerr=emp_std, fmt="o", color=C["emp"], capsize=3, label="empirical $\\hat w_{LS}$") ax.plot(sigmaEs, pred, "-", color=C["thy"], lw=2, label="Eq. (4)") ax.axhline(w0, ls=":", color="gray", label="$w_0$") ax.set_xlabel(r"noise std $\sigma_E$"); ax.set_ylabel(r"$\hat w_{LS}$") ax.set_title(r"Sweep $\sigma_E$ ($w_0=1.5,\ \mu_X=1,\ \sigma_X=1$)") ax.legend(); ax.grid(True, color=C["grid"], lw=0.5) print(" (a) sigma_E sweep: max|emp-pred| =", f"{np.max(np.abs(emp_mean-pred)):.4f}") # (b) sweep mu_X with fixed sigma_E w0, sigma_X, sigma_E = 1.0, 1.0, 1.0 muXs = np.linspace(0.1, 3.0, 13) emp_mean, emp_std, pred = [], [], [] for mX in muXs: ws = [] for _ in range(n_trials): x, y = sl.make_data(n, [w0], [mX], [[sigma_X ** 2]], sigma_E, rng) ws.append(sl.fit_ls_1d(x, y)) emp_mean.append(np.mean(ws)); emp_std.append(np.std(ws)) pred.append(sl.theorem1_limit_1d(w0, mX, sigma_X, sigma_E)) emp_mean, emp_std, pred = map(np.array, (emp_mean, emp_std, pred)) ax = axes[1] ax.errorbar(muXs, emp_mean, yerr=emp_std, fmt="o", color=C["emp"], capsize=3, label="empirical $\\hat w_{LS}$") ax.plot(muXs, pred, "-", color=C["thy"], lw=2, label="Eq. (4)") ax.axhline(w0, ls=":", color="gray", label="$w_0$") ax.set_xlabel(r"feature mean $\mu_X$"); ax.set_ylabel(r"$\hat w_{LS}$") ax.set_title(r"Sweep $\mu_X$ ($w_0=1,\ \sigma_X=1,\ \sigma_E=1$)") ax.legend(); ax.grid(True, color=C["grid"], lw=0.5) print(" (b) mu_X sweep: max|emp-pred| =", f"{np.max(np.abs(emp_mean-pred)):.4f}") fig.suptitle(r"$d=1$: shuffled-LS estimate matches Theorem 1, Eq. (4) (amplification bias)", fontweight="bold") fig.tight_layout() fig.savefig(os.path.join(FIGDIR, "fig2_d1_formula.png"), dpi=130) plt.close(fig) # Also confirm convergence in n at one setting print(" convergence in n (w0=1.5, muX=1, sigX=1, sigE=1.5):") pred1 = sl.theorem1_limit_1d(1.5, 1.0, 1.0, 1.5) for n in [1000, 10000, 100000, 1000000]: ws = [sl.fit_ls_1d(*sl.make_data(n, [1.5], [1.0], [[1.0]], 1.5, rng)) for _ in range(8)] print(f" n={n:8d} mean what={np.mean(ws):.4f} pred={pred1:.4f}" f" |err|={abs(np.mean(ws)-pred1):.4f}") # --------------------------------------------------------------------------- # # Experiment 3 + 4: invariants and norm amplification for d>1 # --------------------------------------------------------------------------- # def exp_invariants(seed=3): banner("Experiment 3+4: moment-matching invariants & norm amplification (d>1)") rng = np.random.default_rng(seed) dims = [2, 3, 5] sigma_E = 0.9 ns = [500, 2000, 8000, 30000, 100000] # store per-dim convergence of the two invariant *errors* results = {} for d in dims: w0 = rng.standard_normal(d) mu_X = rng.standard_normal(d) Sigma_X = sl.random_spd(d, rng) mT, vT = sl.target_invariants(w0, mu_X, Sigma_X, sigma_E) w0_norm2 = vT - sigma_E ** 2 mean_err, var_err, infl = [], [], [] for n in ns: mse_m, mse_v, infl_n = [], [], [] for _ in range(5): x, y = sl.make_data(n, w0, mu_X, Sigma_X, sigma_E, rng) w_hat, _ = sl.fit_ls(x, y, n_starts=8, w0_hint=w0, rng=rng) mS, vS = sl.invariants(w_hat, mu_X, Sigma_X) mse_m.append(abs(mS - mT)) mse_v.append(abs(vS - vT)) infl_n.append(vS - w0_norm2) mean_err.append(np.mean(mse_m)) var_err.append(np.mean(mse_v)) infl.append(np.mean(infl_n)) print(f" d={d} n={n:6d} |mean-match err|={mean_err[-1]:.4f}" f" |var err|={var_err[-1]:.4f} inflation={infl[-1]:.4f}" f" (target sigma_E^2={sigma_E**2:.3f})") results[d] = dict(mean_err=np.array(mean_err), var_err=np.array(var_err), infl=np.array(infl)) # Figure 3: invariant errors -> 0 fig, axes = plt.subplots(1, 2, figsize=(12, 4.6)) for d in dims: axes[0].loglog(ns, results[d]["mean_err"], "o-", label=f"d={d}") axes[1].loglog(ns, results[d]["var_err"], "o-", label=f"d={d}") axes[0].set_title(r"Mean-match invariant error $|\hat w^\top\mu_X - w_0^\top\mu_X|$") axes[1].set_title(r"Variance invariant error $|\hat w^\top\Sigma_X\hat w - (w_0^\top\Sigma_X w_0+\sigma_E^2)|$") for ax in axes: ax.set_xlabel("n"); ax.set_ylabel("absolute error") ax.grid(True, which="both", color=C["grid"], lw=0.5); ax.legend() fig.suptitle(r"$d>1$: LS estimate satisfies both moment-matching invariants as $n\to\infty$", fontweight="bold") fig.tight_layout() fig.savefig(os.path.join(FIGDIR, "fig3_invariants.png"), dpi=130) plt.close(fig) # Figure 4: norm inflation -> sigma_E^2, independent of d fig, ax = plt.subplots(figsize=(7, 4.8)) for d in dims: ax.semilogx(ns, results[d]["infl"], "o-", label=f"d={d}") ax.axhline(sigma_E ** 2, ls="--", color=C["thy"], lw=2, label=r"prediction $\sigma_E^2$") ax.set_xlabel("n") ax.set_ylabel(r"$\|\hat w\|^2_{\Sigma_X} - \|w_0\|^2_{\Sigma_X}$") ax.set_title(r"Norm amplification equals $\sigma_E^2$, independent of $d$") ax.grid(True, which="both", color=C["grid"], lw=0.5); ax.legend() fig.tight_layout() fig.savefig(os.path.join(FIGDIR, "fig4_amplification.png"), dpi=130) plt.close(fig) if __name__ == "__main__": exp_loss_convergence() exp_d1_formula() exp_invariants() print("\nAll figures written to:", FIGDIR)