# NuWave — nuwave/benchmark_loader.py
# Synced from GitHub: 2d1f218e331d6687f61ad13796b3478380b66891 (5fa3f26, verified)
"""
benchmark_loader.py — Load and sample from the Phase B benchmark prompt pool.
The pool (benchmark_pool.yaml) contains 17 categories × 4 chains × 3 layers
= 204 prompts, with 8 conceptual threads woven through in a near-uniform
bipartite mapping (each thread spans 7-10 categories). Subversion is in
priority_categories, so it is force-included in every per-run sample.
This loader:
1. Loads the pool YAML
2. Builds (Q1, Q2, Q3) chain triples — each Q2 and Q3 references the
same parent Q1 by `parent` field (Q3 is a sibling to Q2 under Q1,
not a Q1→Q2→Q3 lineage)
3. Samples N chains per run with stratification discipline:
- Force-include 1 chain from each priority category (subversion)
- At most 2 chains per category (prevents category dominance)
- At least 3 threads with 2+ representatives (cross-category co-firing)
- Multi-complexity coverage across all 3 layers
4. Returns interleaved Q1/Q2/Q3 turn sequence:
turns 0..N-1: Q1s (one per sampled chain, in chain order)
turns N..2N-1: matching Q2s (same chain order)
turns 2N..3N-1: matching Q3s (same chain order)
5. Returns same-cat pair indices for the heatmap math. Phase A semantics
preserved: pairs are Q1↔Q2 only `[(i, i+N) for i in range(N)]`.
Q3 turns contribute to substrate but aren't part of the strict same-
cat-reselect calculation. Future work (Option B) can add Q1↔Q3 and
Q2↔Q3 pairings.
# ---- Changelog ----
# [2026-05-10] Claude Opus 4.7 — Phase A loader (Q1/Q2 pairs, 10 cats)
# [2026-05-11] Claude Opus 4.7 — Phase B loader (Q1/Q2/Q3 chains, 17 cats,
# priority_categories). Function renamed sample_pairs →
# sample_chains. Returns 24-turn interleave (3 layers × 8
# chains). Subversion is forced in every sample to give
# substrate consistent expectation-subverting content
# exposure for the surprise-axis hypothesis test.
# -------------------
"""
from __future__ import annotations
import os
import random
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
import yaml
# Default pool location: benchmark_pool.yaml sitting next to this module,
# resolved via this file's absolute path so it works regardless of CWD.
_DEFAULT_POOL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "benchmark_pool.yaml",
)
def load_pool(path: Optional[str] = None) -> Dict[str, Any]:
    """Load the benchmark pool YAML.

    Args:
        path: Explicit path to a pool file. None (default) loads
            benchmark_pool.yaml from next to this module.

    Returns a dict with keys:
        threads:              list of thread names (8 entries)
        complexity_levels:    list of complexity tags (6 entries)
        priority_categories:  list of categories that must appear in every
                              per-run sample (typically just ["subversion"])
        q1_layer:             list of 68 Q1 dicts (id, category, thread,
                              complexity, text)
        q2_layer:             list of 68 Q2 dicts (adds: parent → Q1 id)
        q3_layer:             list of 68 Q3 dicts (parent → Q1 id; Q3 is
                              sibling to Q2 under Q1)

    An empty / all-comment YAML file yields {} (not None) so downstream
    `pool.get(...)` calls stay safe.
    """
    p = path or _DEFAULT_POOL_PATH
    # Explicit encoding: the pool file is UTF-8 (prompts contain non-ASCII
    # characters such as arrows); never rely on the locale default.
    with open(p, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
def _build_chains(
pool: Dict[str, Any],
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
"""Build (Q1, Q2, Q3) chain triples from the pool.
Each Q2's and Q3's `parent` field references its Q1's `id`. Chains
without a complete (Q1, Q2, Q3) triple are skipped; Phase B
discipline guarantees full triples but defensive code stays.
"""
q2_by_parent = {q["parent"]: q for q in pool.get("q2_layer", [])}
q3_by_parent = {q["parent"]: q for q in pool.get("q3_layer", [])}
chains: List[
Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]
] = []
for q1 in pool.get("q1_layer", []):
q2 = q2_by_parent.get(q1["id"])
q3 = q3_by_parent.get(q1["id"])
if q2 is not None and q3 is not None:
chains.append((q1, q2, q3))
return chains
def _validate_sample(
sample: List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]],
max_per_category: int = 2,
min_threads_with_dups: int = 3,
min_distinct_complexity_levels: int = 3,
) -> bool:
"""Stratification discipline check.
Returns True if the sample respects:
- max `max_per_category` chains per category (default 2)
- at least `min_threads_with_dups` threads with 2+ instances (3)
- at least `min_distinct_complexity_levels` distinct complexity
tags across the sample's combined Q1+Q2+Q3 levels (3)
Counted across (Q1, Q2, Q3) chain triples. Each chain contributes
its (single) thread once and contributes 3 complexity tags (one per
layer).
"""
if not sample:
return False
cats = Counter(q1["category"] for q1, _q2, _q3 in sample)
if any(count > max_per_category for count in cats.values()):
return False
threads = Counter(q1["thread"] for q1, _q2, _q3 in sample)
threads_with_dups = sum(
1 for count in threads.values() if count >= 2
)
if threads_with_dups < min_threads_with_dups:
return False
complexities: set = set()
for q1, q2, q3 in sample:
complexities.add(q1["complexity"])
complexities.add(q2["complexity"])
complexities.add(q3["complexity"])
if len(complexities) < min_distinct_complexity_levels:
return False
return True
def sample_chains(
    pool: Optional[Dict[str, Any]] = None,
    n_chains: int = 8,
    seed: Optional[int] = None,
    max_attempts: int = 200,
) -> Tuple[
    List[Tuple[str, str]],
    List[Tuple[int, int]],
    List[Dict[str, Any]],
]:
    """Sample `n_chains` chains with stratification + priority discipline.

    Args:
        pool: Pre-loaded pool dict. If None, loads from default path.
        n_chains: Total number of (Q1, Q2, Q3) chains to sample. Each
            chain contributes 3 turns, so total turns = 3 * n_chains.
            Phase B default 8 chains → 24 turns/run.
        seed: RNG seed for reproducibility. None = nondeterministic.
        max_attempts: Rejection-sampling retry budget on the non-priority
            portion of the sample.

    Returns:
        interleaved_questions: list of (category, prompt_text) tuples,
            3*n_chains entries. Turn structure:
                0..n-1: Q1s
                n..2n-1: Q2s (matching, same order)
                2n..3n-1: Q3s (matching, same order)
        same_cat_pairs: list of (q1_turn_idx, q2_turn_idx) tuples,
            n_chains entries. Phase A semantics:
            always [(i, i+n_chains) for i in range(n)].
            Q3 turns aren't paired here (Option A from
            2026-05-11; future Option B can add Q1↔Q3
            and Q2↔Q3 pairs).
        sample_meta: list of n_chains dicts with q1_id, q2_id,
            q3_id, category, thread, q1_complexity,
            q2_complexity, q3_complexity.

    Raises:
        ValueError: if the pool holds fewer complete chains than
            `n_chains`, or if there are more forced priority chains
            than `n_chains` slots.

    Priority categories (from pool["priority_categories"]) are force-
    included: one chain from each priority category is pre-selected
    before rejection sampling fills the remaining slots from the non-
    priority pool. Stratification is checked on the COMBINED final
    sample, so the forced chain's thread/complexity contribute to the
    constraint accounting.

    Reproducibility note: with a fixed `seed` the output is fully
    deterministic because the RNG is consumed in a fixed order — one
    `choice` per priority category, then one `sample` per rejection
    attempt.
    """
    if pool is None:
        pool = load_pool()
    chains = _build_chains(pool)
    if len(chains) < n_chains:
        raise ValueError(
            f"Pool has {len(chains)} chains, cannot sample {n_chains}"
        )
    priority_cats: List[str] = pool.get("priority_categories", []) or []
    rng = random.Random(seed)
    # Step 1 — Pre-select forced chains from priority categories
    forced: List[
        Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]
    ] = []
    for cat in priority_cats:
        cat_chains = [c for c in chains if c[0]["category"] == cat]
        # A priority category with no chains in the pool is skipped
        # (best-effort) rather than aborting the run.
        if cat_chains:
            forced.append(rng.choice(cat_chains))
    # Step 2 — Fill remaining slots from non-priority chains via
    # rejection sampling against the COMBINED (forced + sampled) total
    n_remaining = n_chains - len(forced)
    if n_remaining < 0:
        raise ValueError(
            f"More priority categories ({len(forced)}) than n_chains "
            f"({n_chains}); reduce priority list or raise n_chains"
        )
    forced_ids = {c[0]["id"] for c in forced}
    # NOTE: exclusion here is by chain id, not by category — additional
    # chains from a priority category remain eligible for the random
    # fill, bounded only by _validate_sample's max-per-category rule.
    non_priority = [c for c in chains if c[0]["id"] not in forced_ids]
    selected: Optional[
        List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]
    ] = None
    for _attempt in range(max_attempts):
        if n_remaining > 0:
            # len(non_priority) = len(chains) - len(forced) >= n_remaining
            # (guaranteed by the len(chains) >= n_chains check above),
            # so rng.sample cannot raise here.
            candidate_remaining = rng.sample(non_priority, n_remaining)
        else:
            candidate_remaining = []
        candidate = forced + candidate_remaining
        if _validate_sample(candidate):
            selected = candidate
            break
    if selected is None:
        # Fallback — accept partial constraint satisfaction rather than
        # aborting. Forced chains still included; remaining slots filled
        # by best-effort random draw.
        if n_remaining > 0:
            selected = forced + rng.sample(non_priority, n_remaining)
        else:
            selected = list(forced)
    # Interleave layer by layer: all Q1s, then all Q2s, then all Q3s,
    # each in the same chain order so turns i, i+n, i+2n share a chain.
    interleaved: List[Tuple[str, str]] = []
    for q1, _q2, _q3 in selected:
        interleaved.append((q1["category"], q1["text"]))
    for _q1, q2, _q3 in selected:
        interleaved.append((q2["category"], q2["text"]))
    for _q1, _q2, q3 in selected:
        interleaved.append((q3["category"], q3["text"]))
    # Phase A pairing semantics preserved: Q1↔Q2 pairs only.
    same_cat_pairs: List[Tuple[int, int]] = [
        (i, i + n_chains) for i in range(n_chains)
    ]
    sample_meta: List[Dict[str, Any]] = []
    for q1, q2, q3 in selected:
        sample_meta.append({
            "q1_id": q1["id"],
            "q2_id": q2["id"],
            "q3_id": q3["id"],
            "category": q1["category"],
            "thread": q1["thread"],
            "q1_complexity": q1["complexity"],
            "q2_complexity": q2["complexity"],
            "q3_complexity": q3["complexity"],
        })
    return interleaved, same_cat_pairs, sample_meta
def describe_sample(
    sample_meta: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Summarize one run's sample as a small structured dict for logging.

    Surfaces in the benchmark's JSON output what was actually drawn this
    run — categories, threads, complexity registers, and concrete chain
    ids — so per-run substrate behavior can be correlated with sampled
    content and priority_categories enforcement can be audited.
    """
    complexity_counts: Counter = Counter(
        meta[key]
        for meta in sample_meta
        for key in ("q1_complexity", "q2_complexity", "q3_complexity")
    )
    return {
        "n_chains": len(sample_meta),
        "n_turns": 3 * len(sample_meta),
        "categories_sampled": dict(
            Counter(meta["category"] for meta in sample_meta)
        ),
        "threads_sampled": dict(
            Counter(meta["thread"] for meta in sample_meta)
        ),
        "complexity_distribution": dict(complexity_counts),
        "chain_ids": [
            (meta["q1_id"], meta["q2_id"], meta["q3_id"])
            for meta in sample_meta
        ],
    }