robot-folding / compute_cld.py
pepijn223's picture
pepijn223 HF Staff
Improve DAgger explainer, add conclusion and expand references
f0f3d44 unverified
"""
Compute CLD (Compact Letter Display) letters and verify Beta posteriors
for all experiments using the TRI STEP sequential testing framework.
Parameters:
- global_confidence_level = 0.90 (α=0.10)
- n_max = 50
- shuffle = False (each rollout is independent)
- 11 experiments → C(11,2) = 55 pairwise comparisons → Bonferroni correction
"""
import json
import numpy as np
from scipy import stats
from sequentialized_barnard_tests.tools.plotting import compare_success_and_get_cld
# Raw rollout outcomes per experiment. Each level ("total", "L1", "L2") holds a
# [k, n] pair that the code below unpacks as k successes out of n rollouts.
# In every row the "total" pair is the element-wise sum of the "L1" and "L2" pairs.
# NOTE(review): some keys contain what looks like mis-encoded text (e.g. "Ο€"
# is presumably "π", and the trailing glyph of the "2.5" key presumably "★").
# They are left as-is because HTML_RAW_PCT and the spot-check look them up
# byte-for-byte; fix all occurrences together or not at all.
EXPERIMENTS = {
"1.1 Ο€0": {"total": [8, 20], "L1": [8, 10], "L2": [0, 10]},
"1.2 Ο€0.5": {"total": [4, 20], "L1": [4, 10], "L2": [0, 10]},
"1.3 Relative": {"total": [7, 20], "L1": [7, 10], "L2": [0, 10]},
"1.4 RABC low": {"total": [3, 20], "L1": [3, 10], "L2": [0, 10]},
"1.5 RABC high": {"total": [0, 20], "L1": [0, 10], "L2": [0, 10]},
"1.7 Rel+RABC": {"total": [8, 20], "L1": [8, 10], "L2": [0, 10]},
"2.1 HQ": {"total": [8, 20], "L1": [7, 10], "L2": [1, 10]},
"2.2 HQ+RABC+Rel": {"total": [15, 20], "L1": [10, 10], "L2": [5, 10]},
"2.3 HQ+mirror": {"total": [1, 20], "L1": [0, 10], "L2": [1, 10]},
"2.4 HQ chunk45": {"total": [4, 20], "L1": [4, 10], "L2": [0, 10]},
"2.5 HQ+RABC+Relβ˜…": {"total": [18, 20], "L1": [10, 10], "L2": [8, 10]},
}
# What the HTML files use (percentages) — for cross-checking the round-trip.
# Keys must match EXPERIMENTS exactly; the level keys here are lower-case
# ("l1"/"l2") whereas EXPERIMENTS uses "L1"/"L2".
HTML_RAW_PCT = {
"1.1 Ο€0": {"total": 40, "l1": 80, "l2": 0},
"1.2 Ο€0.5": {"total": 20, "l1": 40, "l2": 0},
"1.3 Relative": {"total": 35, "l1": 70, "l2": 0},
"1.4 RABC low": {"total": 15, "l1": 30, "l2": 0},
"1.5 RABC high": {"total": 0, "l1": 0, "l2": 0},
"1.7 Rel+RABC": {"total": 40, "l1": 80, "l2": 0},
"2.1 HQ": {"total": 40, "l1": 70, "l2": 10},
"2.2 HQ+RABC+Rel": {"total": 75, "l1": 100, "l2": 50},
"2.3 HQ+mirror": {"total": 5, "l1": 0, "l2": 10},
"2.4 HQ chunk45": {"total": 20, "l1": 40, "l2": 0},
"2.5 HQ+RABC+Relβ˜…": {"total": 90, "l1": 100, "l2": 80},
}
# Denominator used to turn an HTML percentage back into a success count
# (k = round(pct / 100 * n)), keyed by the lower-case HTML level name.
HTML_N = {"total": 20, "l1": 10, "l2": 10}
# The next three values are passed positionally to compare_success_and_get_cld;
# the module docstring names them global_confidence_level, n_max and shuffle.
GLOBAL_CONFIDENCE = 0.90
N_MAX = 50
SHUFFLE = False
# Experiment names in EXPERIMENTS insertion order; this fixes the iteration
# order of every loop over experiments in this script.
model_names = list(EXPERIMENTS.keys())
def draw_samples_from_beta_posterior(
    success_array: np.ndarray,
    rng: np.random.Generator,
    num_samples: int = 10000,
    alpha_prior: float = 1,
    beta_prior: float = 1,
) -> np.ndarray:
    """Draw samples from the Beta posterior of a success rate.

    Described in the original source as TRI's function from their notebook.
    The arithmetic is unchanged for 1-D input; the only difference is that
    trials are counted with ``size`` instead of ``len`` so the trial count
    always agrees with the success count (``len`` only sees the first axis
    of a multi-dimensional array, which could make the failure count
    negative).

    Args:
        success_array: Boolean (or 0/1) outcomes, one entry per trial. Any
            array-like is accepted; every element counts as one trial.
        rng: NumPy random generator used as the sampling source.
        num_samples: Number of posterior draws to return.
        alpha_prior: Alpha parameter of the Beta prior (1 = uniform).
        beta_prior: Beta parameter of the Beta prior (1 = uniform).

    Returns:
        Array of ``num_samples`` draws from
        ``Beta(alpha_prior + successes, beta_prior + failures)``.
    """
    outcomes = np.asarray(success_array)
    n_trials = outcomes.size
    n_successes = np.sum(outcomes)
    n_failures = n_trials - n_successes
    posterior = stats.beta(alpha_prior + n_successes, beta_prior + n_failures)
    return posterior.rvs(num_samples, random_state=rng)
# ── 1. CLD letters ──────────────────────────────────────────────────────────
# For each evaluation level, expand every experiment's [k, n] counts into a
# boolean rollout array (k True values followed by n - k False values), hand
# all arrays to compare_success_and_get_cld in one call, and print the result
# as a JSON object keyed by experiment name in model_names order.
for level in ["total", "L1", "L2"]:
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" CLD — LEVEL: {level}")
    print(banner)

    success_arrays = []
    for name in model_names:
        k, n = EXPERIMENTS[name][level]
        if not 0 <= k <= n:
            # [False] * (n - k) would silently become [] for k > n and the
            # array would then hold fewer than n rollouts.
            raise ValueError(f"{name} / {level}: invalid counts k={k}, n={n}")
        success_arrays.append(np.array([True] * k + [False] * (n - k)))

    # Arguments are positional, in the order the original call used. The
    # return value is indexed by experiment name below, so it is expected to
    # be a mapping keyed by the entries of model_names.
    cld_dict = compare_success_and_get_cld(
        model_names,
        success_arrays,
        GLOBAL_CONFIDENCE,
        N_MAX,
        SHUFFLE,
        verbose=True,
    )

    print("\nJSON for HTML embed:")
    json_obj = {name: cld_dict[name] for name in model_names}
    # ensure_ascii=False keeps non-ASCII experiment names readable in the output.
    print(json.dumps(json_obj, ensure_ascii=False))
# ── 2. Verify Beta posteriors ────────────────────────────────────────────────
# For every level and experiment: build the analytic Beta(1+k, 1+n-k)
# posterior, print its mean and central 90% interval, and check that the
# percentage stored for the HTML pages maps back to the same k (and the same
# n). A sampled spot-check of one experiment is compared against the analytic
# posterior so that the final summary only claims what was actually tested.
print(f"\n\n{'#'*70}")
print(" POSTERIOR VERIFICATION")
print(" Prior: Beta(1,1) (uniform). Posterior: Beta(1+k, 1+n-k)")
print(f"{'#'*70}")

# EXPERIMENTS uses "L1"/"L2" while the HTML tables use lower-case keys.
level_map = {"total": "total", "L1": "l1", "L2": "l2"}
rng = np.random.default_rng(42)
all_ok = True
# Absolute tolerance (as a fraction) for sampled vs analytic mean/std. With
# 100k draws the Monte-Carlo error of the mean is below 0.001, so 0.005 only
# trips on a genuine parameter mismatch.
sample_tol = 5e-3

for level in ["total", "L1", "L2"]:
    html_key = level_map[level]
    print(f"\n{'─'*60}")
    print(f" Level: {level} (HTML key: '{html_key}', n={HTML_N[html_key]})")
    print(f"{'─'*60}")
    print(f" {'Experiment':<24s} {'k/n':>6s} {'α':>4s} {'β':>4s} {'Mean':>7s} {'90% CI':>16s} {'HTML%→k':>8s} {'Match':>5s}")

    for name in model_names:
        k, n = EXPERIMENTS[name][level]
        alpha_post = 1 + k
        beta_post = 1 + (n - k)
        dist = stats.beta(alpha_post, beta_post)
        mean = dist.mean()
        ci_lo, ci_hi = dist.ppf(0.05), dist.ppf(0.95)

        # Verify HTML percentage round-trip: the percentage must map back to
        # the same success count, and the HTML denominator must equal n,
        # otherwise an equal k would describe a different posterior.
        html_pct = HTML_RAW_PCT[name][html_key]
        html_n = HTML_N[html_key]
        html_k = round(html_pct / 100 * html_n)
        match = html_k == k and html_n == n
        if not match:
            all_ok = False

        print(
            f" {name:<24s} {k:>2d}/{n:<2d} {alpha_post:>4d} {beta_post:>4d} "
            f"{mean*100:>6.1f}% [{ci_lo*100:>5.1f}% – {ci_hi*100:>5.1f}%] "
            f"{html_pct}%→{html_k:>2d} {'✓' if match else '✗ MISMATCH'}"
        )

    # Also run TRI's draw_samples_from_beta_posterior for a spot-check.
    # Runs once per level because it reads `level`.
    # NOTE(review): this literal must equal the EXPERIMENTS key byte-for-byte;
    # its trailing glyph looks mis-encoded there and is kept identical here.
    spot_name = "2.5 HQ+RABC+Relβ˜…"
    k_spot, n_spot = EXPERIMENTS[spot_name][level]
    arr_spot = np.array([True] * k_spot + [False] * (n_spot - k_spot))
    samples = draw_samples_from_beta_posterior(arr_spot, rng, num_samples=100_000)
    alpha_s, beta_s = 1 + k_spot, 1 + (n_spot - k_spot)
    analytic = stats.beta(alpha_s, beta_s)

    # Actually compare sampled and analytic moments so a disagreement is
    # reflected in the final verdict instead of only being printed.
    spot_ok = bool(
        np.isclose(np.mean(samples), analytic.mean(), rtol=0.0, atol=sample_tol)
        and np.isclose(np.std(samples), analytic.std(), rtol=0.0, atol=sample_tol)
    )
    if not spot_ok:
        all_ok = False

    print(f"\n Spot-check ({spot_name}, {level}):")
    print(f" TRI samples: mean={np.mean(samples)*100:.1f}%, std={np.std(samples)*100:.1f}%")
    print(f" Analytic: mean={analytic.mean()*100:.1f}%, std={analytic.std()*100:.1f}%")
    print(f" HTML params: Beta({alpha_s}, {beta_s})")
    print(f" Sampled vs analytic: {'✓' if spot_ok else '✗ MISMATCH'}")

print(f"\n{'='*60}")
if all_ok:
    print(" ALL POSTERIORS VERIFIED ✓")
    print(" - Beta(1+k, 1+n-k) with uniform prior Beta(1,1)")
    print(" - HTML percentage→k round-trip: all match")
    print(" - Matches TRI draw_samples_from_beta_posterior()")
else:
    print(" ✗ SOME MISMATCHES FOUND — see above")
print(f"{'='*60}")