shuffled-linear-regression / experiments.py
abidlabs's picture
abidlabs HF Staff
Add shuffled-regression inconsistency demo + scripts, proof, paper, figures
652b4df verified
Raw
History Blame Contribute Delete
9.68 kB
"""
Simulations confirming the generalized Theorem 1 (see PROOF.md).
Run: python experiments.py
Outputs figures to ./figures and prints a numeric report.
Experiments
-----------
1. Loss convergence: empirical sorted loss L_n(w) -> population ell(w),
pointwise over random w, as n grows (validates Sec. 5).
2. d=1 formula: LS estimate -> Eq. (4) across sigma_E and mu_X sweeps.
3. Invariants (d>1): mean-match and variance-amplification invariants
converge to their predicted targets as n grows.
4. Norm amplification: ||w_hat||^2_Sigma - ||w0||^2_Sigma -> sigma_E^2,
independent of d.
"""
from __future__ import annotations
import os
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import shuffled_ls as sl
# Figures go to a "figures" folder next to this script; it is created at
# import time when missing.
_SCRIPT_DIR = os.path.dirname(__file__)
FIGDIR = os.path.join(_SCRIPT_DIR, "figures")
os.makedirs(FIGDIR, exist_ok=True)
# Shared plot colours: empirical points, theoretical prediction, grid lines.
C = dict(emp="#1f77b4", thy="#d62728", grid="#cccccc")
def banner(title):
    """Print *title* preceded by a blank line and framed by 70-char '=' rules."""
    rule = "=" * 70
    print(f"\n{rule}\n{title}\n{rule}")
# --------------------------------------------------------------------------- #
# Experiment 1: empirical loss -> population loss, pointwise
# --------------------------------------------------------------------------- #
def exp_loss_convergence(seed=1):
    """Experiment 1: compare the empirical sorted loss L_n(w) with ell(w).

    Builds one random d=4 problem (Gaussian ``w0`` and ``mu_X``, a covariance
    from ``sl.random_spd``, ``sigma_E = 0.8``) and a fixed set of 200 probe
    weights. For each sample size in ``ns`` it draws data with
    ``sl.make_data``, sorts the responses, evaluates ``sl.ls_loss`` at every
    probe and compares the result with ``sl.population_loss``. The RMSE per
    ``n`` is printed, and one scatter panel per ``n`` is saved to
    ``FIGDIR/fig1_loss_convergence.png``.

    Parameters
    ----------
    seed : int, default 1
        Seed for the single ``np.random.default_rng`` generator from which
        every random draw in this experiment is taken.
    """
    banner("Experiment 1: L_n(w) -> ell(w) pointwise (d=4, general Sigma_X)")
    rng = np.random.default_rng(seed)
    d = 4
    w0 = rng.standard_normal(d)
    mu_X = rng.standard_normal(d)
    Sigma_X = sl.random_spd(d, rng)
    sigma_E = 0.8
    # 200 random probe weights, fixed across n so the panels are comparable.
    n_probes = 200
    W = rng.standard_normal((n_probes, d)) * 1.5
    ell = np.array([sl.population_loss(w, w0, mu_X, Sigma_X, sigma_E) for w in W])
    ns = [200, 1000, 5000, 50000]

    # Evaluate every panel before plotting (same RNG draw order as a single
    # pass). That lets all panels use one common range for the y = x line:
    # the axes are shared, so a per-panel segment would stop short of the
    # shared limits in every panel except the one with the largest loss.
    losses, rmses = [], []
    for n in ns:
        x, y = sl.make_data(n, w0, mu_X, Sigma_X, sigma_E, rng)
        ys = np.sort(y)
        Ln = np.array([sl.ls_loss(w, x, ys) for w in W])
        rmse = np.sqrt(np.mean((Ln - ell) ** 2))
        print(f" n={n:6d} RMSE(L_n, ell) = {rmse:.5f}")
        losses.append(Ln)
        rmses.append(rmse)
    # Lower end 0, as in the original; this presumes the losses are
    # non-negative (squared-error style) - TODO confirm against sl.ls_loss.
    top = max([ell.max()] + [Ln.max() for Ln in losses]) * 1.05
    lim = [0, top]

    # squeeze=False keeps `axes` indexable even if `ns` has a single entry.
    fig, axes = plt.subplots(1, len(ns), figsize=(4 * len(ns), 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes[0]
    for ax, n, Ln, rmse in zip(axes, ns, losses, rmses):
        ax.plot(lim, lim, "--", color=C["thy"], lw=1.5, label="y = x")
        ax.scatter(ell, Ln, s=14, alpha=0.6, color=C["emp"])
        ax.set_title(f"n = {n:,}\nRMSE = {rmse:.4f}")
        ax.set_xlabel(r"population loss $\ell(w)$")
        ax.grid(True, color=C["grid"], lw=0.5)
        ax.set_aspect("equal", "box")
    axes[0].set_ylabel(r"empirical loss $L_n(w)$")
    axes[0].legend(loc="upper left")
    fig.suptitle(r"Empirical sorted loss converges to the population loss $\ell(w)$ (d=4)",
                 fontweight="bold")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig1_loss_convergence.png"), dpi=130)
    plt.close(fig)
# --------------------------------------------------------------------------- #
# Experiment 2: d=1 estimator matches Eq. (4)
# --------------------------------------------------------------------------- #
def exp_d1_formula(seed=2):
    """Experiment 2: compare the d=1 shuffled-LS estimate with Eq. (4).

    Runs two parameter sweeps. At each grid point, ``n_trials`` fits of
    ``sl.fit_ls_1d`` on ``n`` samples are averaged and overlaid with the
    prediction ``sl.theorem1_limit_1d``:

    (a) ``sigma_E`` over [0, 3] with ``w0=1.5, mu_X=1, sigma_X=1``;
    (b) ``mu_X`` over [0.1, 3] with ``w0=1, sigma_X=1, sigma_E=1``.

    Saves ``FIGDIR/fig2_d1_formula.png`` and prints the largest
    |empirical mean - prediction| per sweep. Finally, it prints the mean
    estimate versus the prediction for increasing ``n`` at one fixed setting.

    Parameters
    ----------
    seed : int, default 2
        Seed for the single generator used for all data in this experiment.
    """
    banner("Experiment 2: d=1 LS estimate -> Eq. (4)")
    rng = np.random.default_rng(seed)
    n = 200000
    n_trials = 12
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.6))

    # (a) sweep sigma_E with fixed mu_X, sigma_X, w0
    base_a = dict(w0=1.5, mu_X=1.0, sigma_X=1.0)
    sigmaEs = np.linspace(0.0, 3.0, 13)
    emp_mean, emp_std, pred = _d1_sweep(sigmaEs, "sigma_E", base_a, n, n_trials, rng)
    _plot_d1_sweep(axes[0], sigmaEs, emp_mean, emp_std, pred, base_a["w0"],
                   xlabel=r"noise std $\sigma_E$",
                   title=r"Sweep $\sigma_E$ ($w_0=1.5,\ \mu_X=1,\ \sigma_X=1$)")
    print(" (a) sigma_E sweep: max|emp-pred| =", f"{np.max(np.abs(emp_mean-pred)):.4f}")

    # (b) sweep mu_X with fixed sigma_E
    base_b = dict(w0=1.0, sigma_X=1.0, sigma_E=1.0)
    muXs = np.linspace(0.1, 3.0, 13)
    emp_mean, emp_std, pred = _d1_sweep(muXs, "mu_X", base_b, n, n_trials, rng)
    _plot_d1_sweep(axes[1], muXs, emp_mean, emp_std, pred, base_b["w0"],
                   xlabel=r"feature mean $\mu_X$",
                   title=r"Sweep $\mu_X$ ($w_0=1,\ \sigma_X=1,\ \sigma_E=1$)")
    print(" (b) mu_X sweep: max|emp-pred| =", f"{np.max(np.abs(emp_mean-pred)):.4f}")

    fig.suptitle(r"$d=1$: shuffled-LS estimate matches Theorem 1, Eq. (4) (amplification bias)",
                 fontweight="bold")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig2_d1_formula.png"), dpi=130)
    plt.close(fig)

    # Also confirm convergence in n at one setting. The header is formatted
    # from the same values that are passed to make_data / theorem1_limit_1d,
    # so the two cannot drift apart (":g" renders 1.0 as "1" and 1.5 as "1.5").
    w0_c, mu_c, sig_c, sE_c = 1.5, 1.0, 1.0, 1.5
    print(f" convergence in n (w0={w0_c:g}, muX={mu_c:g}, sigX={sig_c:g}, sigE={sE_c:g}):")
    pred1 = sl.theorem1_limit_1d(w0_c, mu_c, sig_c, sE_c)
    for n_conv in [1000, 10000, 100000, 1000000]:
        ws = [sl.fit_ls_1d(*sl.make_data(n_conv, [w0_c], [mu_c], [[sig_c ** 2]], sE_c, rng))
              for _ in range(8)]
        w_mean = np.mean(ws)
        print(f" n={n_conv:8d} mean what={w_mean:.4f} pred={pred1:.4f}"
              f" |err|={abs(w_mean-pred1):.4f}")


def _d1_sweep(grid, vary, base, n, n_trials, rng):
    """Fit the d=1 estimator over *grid*, varying parameter *vary*.

    ``base`` holds the fixed parameters among ``w0``, ``mu_X``, ``sigma_X``
    and ``sigma_E``; ``vary`` names the remaining one, which takes each value
    of ``grid`` in turn. Returns ``(mean, std, prediction)`` arrays, one entry
    per grid value. ``mean`` and ``std`` are taken over ``n_trials`` fits of
    ``n`` samples each; ``prediction`` comes from ``sl.theorem1_limit_1d``.
    """
    emp_mean, emp_std, pred = [], [], []
    for value in grid:
        p = {**base, vary: value}
        ws = [sl.fit_ls_1d(*sl.make_data(n, [p["w0"]], [p["mu_X"]],
                                         [[p["sigma_X"] ** 2]], p["sigma_E"], rng))
              for _ in range(n_trials)]
        emp_mean.append(np.mean(ws))
        emp_std.append(np.std(ws))
        pred.append(sl.theorem1_limit_1d(p["w0"], p["mu_X"], p["sigma_X"], p["sigma_E"]))
    return np.array(emp_mean), np.array(emp_std), np.array(pred)


def _plot_d1_sweep(ax, grid, emp_mean, emp_std, pred, w0, xlabel, title):
    """Draw one Experiment-2 panel: empirical mean +/- std, Eq. (4), and w0."""
    ax.errorbar(grid, emp_mean, yerr=emp_std, fmt="o", color=C["emp"],
                capsize=3, label="empirical $\\hat w_{LS}$")
    ax.plot(grid, pred, "-", color=C["thy"], lw=2, label="Eq. (4)")
    ax.axhline(w0, ls=":", color="gray", label="$w_0$")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(r"$\hat w_{LS}$")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, color=C["grid"], lw=0.5)
# --------------------------------------------------------------------------- #
# Experiment 3 + 4: invariants and norm amplification for d>1
# --------------------------------------------------------------------------- #
def exp_invariants(seed=3):
    """Experiments 3 and 4: moment-matching invariants and norm amplification.

    For each dimension in ``dims``, a random problem is drawn (Gaussian ``w0``
    and ``mu_X``, a covariance from ``sl.random_spd``, ``sigma_E = 0.9``). For
    each sample size in ``ns``, the shuffled LS estimate ``w_hat`` is fitted
    ``n_reps`` times with ``sl.fit_ls``, and three quantities are averaged
    over the repetitions:

    * mean-match error  |mS - mT|
    * variance error    |vS - vT|
    * inflation         vS - w0' Sigma_X w0

    Here ``(mS, vS) = sl.invariants(w_hat, mu_X, Sigma_X)`` and
    ``(mT, vT) = sl.target_invariants(w0, mu_X, Sigma_X, sigma_E)``. Progress
    is printed per (d, n), and the results are saved to
    ``FIGDIR/fig3_invariants.png`` and ``FIGDIR/fig4_amplification.png``.

    Parameters
    ----------
    seed : int, default 3
        Seed for the single generator used for all draws in this experiment.
    """
    banner("Experiment 3+4: moment-matching invariants & norm amplification (d>1)")
    rng = np.random.default_rng(seed)
    dims = [2, 3, 5]
    sigma_E = 0.9
    ns = [500, 2000, 8000, 30000, 100000]
    n_reps = 5
    # per-dim convergence of the two invariant *errors* and of the inflation
    results = {}
    for d in dims:
        w0 = rng.standard_normal(d)
        mu_X = rng.standard_normal(d)
        Sigma_X = sl.random_spd(d, rng)
        mT, vT = sl.target_invariants(w0, mu_X, Sigma_X, sigma_E)
        # ||w0||^2_Sigma is computed directly from w0, not as vT - sigma_E^2.
        # The latter would make "inflation" identically (vS - vT) + sigma_E^2,
        # so Figure 4 would merely re-plot Figure 3's variance error and
        # inherit any mistake in target_invariants instead of testing it.
        w0_norm2 = float(w0 @ np.asarray(Sigma_X) @ w0)
        mean_err, var_err, infl = [], [], []
        for n in ns:
            abs_m, abs_v, infl_n = [], [], []
            for _ in range(n_reps):
                x, y = sl.make_data(n, w0, mu_X, Sigma_X, sigma_E, rng)
                # NOTE(review): the true w0 is passed as w0_hint - presumably
                # used as one of the starting points; confirm in sl.fit_ls.
                w_hat, _ = sl.fit_ls(x, y, n_starts=8, w0_hint=w0, rng=rng)
                mS, vS = sl.invariants(w_hat, mu_X, Sigma_X)
                abs_m.append(abs(mS - mT))
                abs_v.append(abs(vS - vT))
                infl_n.append(vS - w0_norm2)
            mean_err.append(np.mean(abs_m))
            var_err.append(np.mean(abs_v))
            infl.append(np.mean(infl_n))
            print(f" d={d} n={n:6d} |mean-match err|={mean_err[-1]:.4f}"
                  f" |var err|={var_err[-1]:.4f} inflation={infl[-1]:.4f}"
                  f" (target sigma_E^2={sigma_E**2:.3f})")
        results[d] = dict(mean_err=np.array(mean_err), var_err=np.array(var_err),
                          infl=np.array(infl))
    _plot_invariant_errors(ns, dims, results)
    _plot_norm_amplification(ns, dims, results, sigma_E)


def _plot_invariant_errors(ns, dims, results):
    """Save Figure 3: log-log invariant errors versus n, one line per d."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.6))
    for d in dims:
        axes[0].loglog(ns, results[d]["mean_err"], "o-", label=f"d={d}")
        axes[1].loglog(ns, results[d]["var_err"], "o-", label=f"d={d}")
    axes[0].set_title(r"Mean-match invariant error $|\hat w^\top\mu_X - w_0^\top\mu_X|$")
    axes[1].set_title(r"Variance invariant error $|\hat w^\top\Sigma_X\hat w - (w_0^\top\Sigma_X w_0+\sigma_E^2)|$")
    for ax in axes:
        ax.set_xlabel("n")
        ax.set_ylabel("absolute error")
        ax.grid(True, which="both", color=C["grid"], lw=0.5)
        ax.legend()
    fig.suptitle(r"$d>1$: LS estimate satisfies both moment-matching invariants as $n\to\infty$",
                 fontweight="bold")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig3_invariants.png"), dpi=130)
    plt.close(fig)


def _plot_norm_amplification(ns, dims, results, sigma_E):
    """Save Figure 4: norm inflation versus n, with the sigma_E^2 prediction."""
    fig, ax = plt.subplots(figsize=(7, 4.8))
    for d in dims:
        ax.semilogx(ns, results[d]["infl"], "o-", label=f"d={d}")
    ax.axhline(sigma_E ** 2, ls="--", color=C["thy"], lw=2,
               label=r"prediction $\sigma_E^2$")
    ax.set_xlabel("n")
    ax.set_ylabel(r"$\|\hat w\|^2_{\Sigma_X} - \|w_0\|^2_{\Sigma_X}$")
    ax.set_title(r"Norm amplification equals $\sigma_E^2$, independent of $d$")
    ax.grid(True, which="both", color=C["grid"], lw=0.5)
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(FIGDIR, "fig4_amplification.png"), dpi=130)
    plt.close(fig)
if __name__ == "__main__":
    # Run every experiment in order; each one prints its own report and
    # writes its figure(s) into FIGDIR.
    for experiment in (exp_loss_convergence, exp_d1_formula, exp_invariants):
        experiment()
    print("\nAll figures written to:", FIGDIR)