"""
Simulations confirming the generalized Theorem 1 (see PROOF.md).

Run:  python experiments.py
Writes figures to the ``figures/`` directory next to this script and prints a
numeric report.

Experiments
-----------
1. Loss convergence: the empirical sorted loss L_n(w) converges to the
   population loss ell(w), pointwise over random probe weights w, as n grows
   (validates Sec. 5).
2. d=1 formula: the LS estimate converges to Eq. (4) across sigma_E and mu_X
   sweeps.
3. Invariants (d>1): the mean-match and variance-amplification invariants
   converge to their predicted targets as n grows.
4. Norm amplification: ||w_hat||^2_Sigma - ||w0||^2_Sigma -> sigma_E^2,
   independent of d.
"""
| from __future__ import annotations | |
| import os | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import shuffled_ls as sl | |
# Output directory for every figure: a "figures" folder in the directory that
# contains this file, created at import time if it does not exist yet.
FIGDIR = os.path.join(os.path.split(__file__)[0], "figures")
os.makedirs(FIGDIR, exist_ok=True)

# Plot palette shared by all experiments.
C = dict(
    emp="#1f77b4",   # empirical / simulated quantities
    thy="#d62728",   # theoretical predictions
    grid="#cccccc",  # grid lines
)
def banner(title):
    """Print *title* on its own line, framed above and below by 70 '=' signs.

    A blank line is emitted first so consecutive report sections are visually
    separated.
    """
    rule = "=" * 70
    print(f"\n{rule}\n{title}\n{rule}")
# --------------------------------------------------------------------------- #
# Experiment 1: empirical loss -> population loss, pointwise
# --------------------------------------------------------------------------- #
def exp_loss_convergence(seed=1):
    """Experiment 1: the sorted empirical loss approaches ell(w) pointwise.

    One random d=4 problem is drawn (w0 and mu_X standard normal, Sigma_X from
    ``sl.random_spd``, sigma_E = 0.8) together with 200 probe weights that stay
    fixed across all sample sizes.  For each n, ``sl.ls_loss`` evaluated on the
    sorted responses is scattered against ``sl.population_loss`` at every
    probe.  The per-n RMSE is printed and the panels are saved as
    ``fig1_loss_convergence.png`` in FIGDIR.

    Parameters
    ----------
    seed : int
        Seed for ``np.random.default_rng``; all randomness comes from it.
    """
    banner("Experiment 1: L_n(w) -> ell(w) pointwise (d=4, general Sigma_X)")
    rng = np.random.default_rng(seed)

    dim = 4
    sigma_E = 0.8
    w0 = rng.standard_normal(dim)
    mu_X = rng.standard_normal(dim)
    Sigma_X = sl.random_spd(dim, rng)

    # Probe weights are drawn once so every panel compares the same points.
    probes = 1.5 * rng.standard_normal((200, dim))
    ell = np.array([sl.population_loss(p, w0, mu_X, Sigma_X, sigma_E)
                    for p in probes])

    sample_sizes = [200, 1000, 5000, 50000]
    n_panels = len(sample_sizes)
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4),
                             sharex=True, sharey=True)
    for idx, n in enumerate(sample_sizes):
        ax = axes[idx]
        x, y = sl.make_data(n, w0, mu_X, Sigma_X, sigma_E, rng)
        # The empirical loss is evaluated against the sorted responses.
        y_sorted = np.sort(y)
        L_n = np.array([sl.ls_loss(p, x, y_sorted) for p in probes])
        rmse = np.sqrt(np.mean((L_n - ell) ** 2))

        # Identity line spanning [0, 5% above the largest loss in this panel].
        upper = 1.05 * max(ell.max(), L_n.max())
        diag = [0, upper]
        ax.plot(diag, diag, "--", color=C["thy"], lw=1.5, label="y = x")
        ax.scatter(ell, L_n, s=14, alpha=0.6, color=C["emp"])
        ax.set_title(f"n = {n:,}\nRMSE = {rmse:.4f}")
        ax.set_xlabel(r"population loss $\ell(w)$")
        ax.grid(True, color=C["grid"], lw=0.5)
        ax.set_aspect("equal", "box")
        print(f" n={n:6d} RMSE(L_n, ell) = {rmse:.5f}")

    first = axes[0]
    first.set_ylabel(r"empirical loss $L_n(w)$")
    first.legend(loc="upper left")
    fig.suptitle(r"Empirical sorted loss converges to the population loss $\ell(w)$ (d=4)",
                 fontweight="bold")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig1_loss_convergence.png"), dpi=130)
    plt.close(fig)
# --------------------------------------------------------------------------- #
# Experiment 2: d=1 estimator matches Eq. (4)
# --------------------------------------------------------------------------- #
def _sweep_1d(settings, n, n_trials, rng):
    """Monte-Carlo the d=1 shuffled-LS estimate over a sequence of settings.

    ``settings`` is an iterable of ``(w0, mu_X, sigma_X, sigma_E)`` tuples,
    processed in order.  For each setting, ``n_trials`` independent datasets of
    size ``n`` are drawn with ``sl.make_data`` (consuming ``rng`` in the same
    order as a hand-written nested loop) and fitted with ``sl.fit_ls_1d``.

    Returns three 1-D arrays with one entry per setting: the mean and the
    ddof=0 standard deviation of the estimates, and the Eq. (4) prediction
    from ``sl.theorem1_limit_1d``.
    """
    emp_mean, emp_std, pred = [], [], []
    for w0, mu_X, sigma_X, sigma_E in settings:
        ws = []
        for _ in range(n_trials):
            x, y = sl.make_data(n, [w0], [mu_X], [[sigma_X ** 2]], sigma_E, rng)
            ws.append(sl.fit_ls_1d(x, y))
        emp_mean.append(np.mean(ws))
        emp_std.append(np.std(ws))
        pred.append(sl.theorem1_limit_1d(w0, mu_X, sigma_X, sigma_E))
    return np.array(emp_mean), np.array(emp_std), np.array(pred)


def _plot_sweep_1d(ax, xs, emp_mean, emp_std, pred, w0, xlabel, title):
    """Draw one sweep panel: empirical mean +/- std, Eq. (4) curve, and w0."""
    ax.errorbar(xs, emp_mean, yerr=emp_std, fmt="o", color=C["emp"],
                capsize=3, label="empirical $\\hat w_{LS}$")
    ax.plot(xs, pred, "-", color=C["thy"], lw=2, label="Eq. (4)")
    ax.axhline(w0, ls=":", color="gray", label="$w_0$")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(r"$\hat w_{LS}$")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, color=C["grid"], lw=0.5)


def exp_d1_formula(seed=2, *, n=200000, n_trials=12):
    """Experiment 2: the d=1 shuffled-LS estimate matches Theorem 1, Eq. (4).

    Panel (a) sweeps sigma_E over ``np.linspace(0, 3, 13)`` with w0=1.5,
    mu_X=1, sigma_X=1.  Panel (b) sweeps mu_X over ``np.linspace(0.1, 3, 13)``
    with w0=1, sigma_X=1, sigma_E=1.  Each point averages ``n_trials`` fits of
    ``sl.fit_ls_1d`` on fresh datasets of size ``n`` and is compared with
    ``sl.theorem1_limit_1d``.  The figure is saved as ``fig2_d1_formula.png``
    in FIGDIR.  Finally, convergence in n is printed at one fixed setting
    (w0=1.5, mu_X=1, sigma_X=1, sigma_E=1.5; 8 fits per n).

    Parameters
    ----------
    seed : int
        Seed for ``np.random.default_rng``; all randomness comes from it.
    n, n_trials : int, keyword-only
        Sample size and number of repeats per sweep point.  The defaults
        reproduce the original experiment; smaller values give a quick smoke
        run.  They do not affect the final convergence-in-n check.
    """
    banner("Experiment 2: d=1 LS estimate -> Eq. (4)")
    rng = np.random.default_rng(seed)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.6))

    # (a) sweep sigma_E with fixed mu_X, sigma_X, w0
    w0, mu_X, sigma_X = 1.5, 1.0, 1.0
    sigmaEs = np.linspace(0.0, 3.0, 13)
    emp_mean, emp_std, pred = _sweep_1d(
        [(w0, mu_X, sigma_X, sE) for sE in sigmaEs], n, n_trials, rng)
    _plot_sweep_1d(axes[0], sigmaEs, emp_mean, emp_std, pred, w0,
                   r"noise std $\sigma_E$",
                   r"Sweep $\sigma_E$ ($w_0=1.5,\ \mu_X=1,\ \sigma_X=1$)")
    print(" (a) sigma_E sweep: max|emp-pred| =", f"{np.max(np.abs(emp_mean-pred)):.4f}")

    # (b) sweep mu_X with fixed sigma_E
    w0, sigma_X, sigma_E = 1.0, 1.0, 1.0
    muXs = np.linspace(0.1, 3.0, 13)
    emp_mean, emp_std, pred = _sweep_1d(
        [(w0, mX, sigma_X, sigma_E) for mX in muXs], n, n_trials, rng)
    _plot_sweep_1d(axes[1], muXs, emp_mean, emp_std, pred, w0,
                   r"feature mean $\mu_X$",
                   r"Sweep $\mu_X$ ($w_0=1,\ \sigma_X=1,\ \sigma_E=1$)")
    print(" (b) mu_X sweep: max|emp-pred| =", f"{np.max(np.abs(emp_mean-pred)):.4f}")

    fig.suptitle(r"$d=1$: shuffled-LS estimate matches Theorem 1, Eq. (4) (amplification bias)",
                 fontweight="bold")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig2_d1_formula.png"), dpi=130)
    plt.close(fig)

    # Also confirm convergence in n at one setting.  The parameters are named
    # once so the printed header cannot drift from the values actually used.
    w0_c, mu_c, sig_c, sE_c = 1.5, 1.0, 1.0, 1.5
    print(f" convergence in n (w0={w0_c:g}, muX={mu_c:g}, sigX={sig_c:g}, sigE={sE_c:g}):")
    pred1 = sl.theorem1_limit_1d(w0_c, mu_c, sig_c, sE_c)
    for n_conv in [1000, 10000, 100000, 1000000]:
        ws = [sl.fit_ls_1d(*sl.make_data(n_conv, [w0_c], [mu_c], [[sig_c ** 2]], sE_c, rng))
              for _ in range(8)]
        w_mean = np.mean(ws)
        print(f" n={n_conv:8d} mean what={w_mean:.4f} pred={pred1:.4f}"
              f" |err|={abs(w_mean - pred1):.4f}")
# --------------------------------------------------------------------------- #
# Experiment 3 + 4: invariants and norm amplification for d>1
# --------------------------------------------------------------------------- #
def _invariant_stats(d, ns, sigma_E, rng, n_reps=5):
    """Draw one random d-dim problem and track invariant errors as n grows.

    For each n in ``ns``, ``n_reps`` datasets are drawn with ``sl.make_data``
    and fitted with ``sl.fit_ls`` (8 starts, hinted with the true w0).  The
    fitted invariants from ``sl.invariants`` are compared with the targets
    from ``sl.target_invariants``, and one progress line is printed per n.

    Returns a dict of 1-D arrays (one entry per n):
      mean_err -- mean |mS - mT| (mean-match invariant error)
      var_err  -- mean |vS - vT| (variance invariant error)
      infl     -- mean of vS - w0^T Sigma_X w0 (norm amplification)
    """
    w0 = rng.standard_normal(d)
    mu_X = rng.standard_normal(d)
    Sigma_X = sl.random_spd(d, rng)
    mT, vT = sl.target_invariants(w0, mu_X, Sigma_X, sigma_E)
    # ||w0||^2_Sigma straight from its definition.  Back-solving it as
    # vT - sigma_E**2 made the Fig. 4 inflation identical to the signed
    # variance error of Fig. 3 shifted by sigma_E**2, so experiment 4 could
    # never catch a mistake in target_invariants.
    w0_norm2 = float(w0 @ np.asarray(Sigma_X) @ w0)

    mean_err, var_err, infl = [], [], []
    for n in ns:
        abs_err_m, abs_err_v, infl_n = [], [], []
        for _ in range(n_reps):
            x, y = sl.make_data(n, w0, mu_X, Sigma_X, sigma_E, rng)
            w_hat, _ = sl.fit_ls(x, y, n_starts=8, w0_hint=w0, rng=rng)
            mS, vS = sl.invariants(w_hat, mu_X, Sigma_X)
            abs_err_m.append(abs(mS - mT))
            abs_err_v.append(abs(vS - vT))
            infl_n.append(vS - w0_norm2)
        mean_err.append(np.mean(abs_err_m))
        var_err.append(np.mean(abs_err_v))
        infl.append(np.mean(infl_n))
        print(f" d={d} n={n:6d} |mean-match err|={mean_err[-1]:.4f}"
              f" |var err|={var_err[-1]:.4f} inflation={infl[-1]:.4f}"
              f" (target sigma_E^2={sigma_E**2:.3f})")
    return dict(mean_err=np.array(mean_err), var_err=np.array(var_err),
                infl=np.array(infl))


def _plot_invariant_errors(ns, dims, results):
    """Figure 3: log-log invariant errors vs n, one curve per dimension."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.6))
    for d in dims:
        axes[0].loglog(ns, results[d]["mean_err"], "o-", label=f"d={d}")
        axes[1].loglog(ns, results[d]["var_err"], "o-", label=f"d={d}")
    axes[0].set_title(r"Mean-match invariant error $|\hat w^\top\mu_X - w_0^\top\mu_X|$")
    axes[1].set_title(r"Variance invariant error $|\hat w^\top\Sigma_X\hat w - (w_0^\top\Sigma_X w_0+\sigma_E^2)|$")
    for ax in axes:
        ax.set_xlabel("n")
        ax.set_ylabel("absolute error")
        ax.grid(True, which="both", color=C["grid"], lw=0.5)
        ax.legend()
    fig.suptitle(r"$d>1$: LS estimate satisfies both moment-matching invariants as $n\to\infty$",
                 fontweight="bold")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig3_invariants.png"), dpi=130)
    plt.close(fig)


def _plot_norm_amplification(ns, dims, results, sigma_E):
    """Figure 4: Sigma-norm inflation vs n against the sigma_E^2 prediction."""
    fig, ax = plt.subplots(figsize=(7, 4.8))
    for d in dims:
        ax.semilogx(ns, results[d]["infl"], "o-", label=f"d={d}")
    ax.axhline(sigma_E ** 2, ls="--", color=C["thy"], lw=2,
               label=r"prediction $\sigma_E^2$")
    ax.set_xlabel("n")
    ax.set_ylabel(r"$\|\hat w\|^2_{\Sigma_X} - \|w_0\|^2_{\Sigma_X}$")
    ax.set_title(r"Norm amplification equals $\sigma_E^2$, independent of $d$")
    ax.grid(True, which="both", color=C["grid"], lw=0.5)
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig4_amplification.png"), dpi=130)
    plt.close(fig)


def exp_invariants(seed=3):
    """Experiments 3 + 4: moment-matching invariants and norm amplification.

    For d in {2, 3, 5} (sigma_E = 0.9) and n from 500 to 100000, this measures
    how far the shuffled-LS fit is from the mean-match and variance targets
    (Fig. 3, ``fig3_invariants.png``).  It also measures how much the fit's
    Sigma_X-norm exceeds ||w0||^2_Sigma (Fig. 4, ``fig4_amplification.png``),
    which Theorem 1 predicts to be sigma_E^2 regardless of d.

    Parameters
    ----------
    seed : int
        Seed for ``np.random.default_rng``; all randomness comes from it.
    """
    banner("Experiment 3+4: moment-matching invariants & norm amplification (d>1)")
    rng = np.random.default_rng(seed)
    dims = [2, 3, 5]
    sigma_E = 0.9
    ns = [500, 2000, 8000, 30000, 100000]

    # Plain loop (not a comprehension): each call prints and consumes rng.
    results = {}
    for d in dims:
        results[d] = _invariant_stats(d, ns, sigma_E, rng)

    _plot_invariant_errors(ns, dims, results)
    _plot_norm_amplification(ns, dims, results, sigma_E)
if __name__ == "__main__":
    # Run every experiment in sequence; each prints its own report section
    # and saves its figure(s) into FIGDIR.
    for run_experiment in (exp_loss_convergence, exp_d1_formula, exp_invariants):
        run_experiment()
    print("\nAll figures written to:", FIGDIR)