File size: 11,463 Bytes
a3050b3
5fa3f26
a3050b3
5fa3f26
 
 
 
 
 
a3050b3
 
5fa3f26
 
 
 
 
 
a3050b3
5fa3f26
 
 
 
 
 
 
 
 
 
a3050b3
 
5fa3f26
 
 
 
 
 
 
a3050b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fa3f26
 
 
 
 
 
 
 
 
a3050b3
 
 
 
 
 
5fa3f26
a3050b3
5fa3f26
 
a3050b3
5fa3f26
 
 
a3050b3
 
5fa3f26
 
 
 
a3050b3
 
5fa3f26
 
 
 
a3050b3
 
 
5fa3f26
a3050b3
 
 
 
 
 
 
5fa3f26
 
 
 
 
 
 
 
a3050b3
 
 
 
5fa3f26
a3050b3
 
 
5fa3f26
 
 
 
a3050b3
 
 
 
5fa3f26
a3050b3
 
5fa3f26
a3050b3
 
 
 
 
 
5fa3f26
a3050b3
5fa3f26
a3050b3
 
 
 
 
 
 
5fa3f26
a3050b3
 
 
5fa3f26
 
 
a3050b3
5fa3f26
 
a3050b3
 
 
5fa3f26
 
 
 
a3050b3
5fa3f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3050b3
 
 
 
5fa3f26
 
a3050b3
5fa3f26
a3050b3
 
5fa3f26
a3050b3
 
5fa3f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3050b3
5fa3f26
 
 
 
 
a3050b3
 
 
 
 
5fa3f26
 
 
 
 
 
 
a3050b3
 
5fa3f26
a3050b3
5fa3f26
a3050b3
5fa3f26
 
a3050b3
 
5fa3f26
a3050b3
 
 
5fa3f26
a3050b3
 
 
5fa3f26
a3050b3
 
 
 
5fa3f26
a3050b3
 
 
 
 
 
 
 
 
 
 
5fa3f26
 
 
 
a3050b3
 
 
5fa3f26
a3050b3
 
 
5fa3f26
a3050b3
 
5fa3f26
 
a3050b3
 
 
5fa3f26
 
 
a3050b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
benchmark_loader.py — Load and sample from the Phase B benchmark prompt pool.

The pool (benchmark_pool.yaml) contains 17 categories × 4 chains × 3 layers
= 204 prompts, with 8 conceptual threads woven through in a near-uniform
bipartite layout (each thread spans 7-10 categories). Subversion is in
priority_categories so it's force-included in every per-run sample.

This loader:

  1. Loads the pool YAML
  2. Builds (Q1, Q2, Q3) chain triples — each Q2 and Q3 references the
     same parent Q1 by `parent` field (Q3 is a sibling to Q2 under Q1,
     not a Q1→Q2→Q3 lineage)
  3. Samples N chains per run with stratification discipline:
     - Force-include 1 chain from each priority category (subversion)
     - At most 2 chains per category (prevents category dominance)
     - At least 3 threads with 2+ representatives (cross-category co-firing)
     - Multi-complexity coverage across all 3 layers
  4. Returns interleaved Q1/Q2/Q3 turn sequence:
       turns 0..N-1:    Q1s (one per sampled chain, in chain order)
       turns N..2N-1:   matching Q2s (same chain order)
       turns 2N..3N-1:  matching Q3s (same chain order)
  5. Returns same-cat pair indices for the heatmap math. Phase A semantics
     preserved: pairs are Q1↔Q2 only `[(i, i+N) for i in range(N)]`.
     Q3 turns contribute to substrate but aren't part of the strict same-
     cat-reselect calculation. Future work (Option B) can add Q1↔Q3 and
     Q2↔Q3 pairings.

# ---- Changelog ----
# [2026-05-10] Claude Opus 4.7 — Phase A loader (Q1/Q2 pairs, 10 cats)
# [2026-05-11] Claude Opus 4.7 — Phase B loader (Q1/Q2/Q3 chains, 17 cats,
#              priority_categories). Function renamed sample_pairs →
#              sample_chains. Returns 24-turn interleave (3 layers × 8
#              chains). Subversion is forced in every sample to give
#              substrate consistent expectation-subverting content
#              exposure for the surprise-axis hypothesis test.
# -------------------
"""

from __future__ import annotations

import os
import random
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

import yaml


# Default pool location: "benchmark_pool.yaml" in the same directory as this
# module. The directory is resolved to an absolute path once, at import time.
_DEFAULT_POOL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "benchmark_pool.yaml",
)


def load_pool(path: Optional[str] = None) -> Dict[str, Any]:
    """Load the benchmark pool YAML.

    Args:
        path: Location of the pool file. When None (or empty), the
              module-level ``_DEFAULT_POOL_PATH`` is used instead.

    Returns:
        The object produced by ``yaml.safe_load``. For a well-formed pool
        this is a dict with keys:
            threads:             list of thread names (8 entries)
            complexity_levels:   list of complexity tags (6 entries)
            priority_categories: list of categories that must appear in
                                 every per-run sample (typically just
                                 ["subversion"])
            q1_layer:            list of 68 Q1 dicts (id, category, thread,
                                 complexity, text)
            q2_layer:            list of 68 Q2 dicts (adds: parent -> Q1 id)
            q3_layer:            list of 68 Q3 dicts (parent -> Q1 id; Q3 is
                                 sibling to Q2 under Q1)

    Raises:
        OSError: the file cannot be opened.
        yaml.YAMLError: the file is not valid YAML.
    """
    p = path or _DEFAULT_POOL_PATH
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # default text encoding: on a non-UTF-8 locale the default would
    # mis-decode any non-ASCII characters in the prompt text.
    with open(p, encoding="utf-8") as f:
        return yaml.safe_load(f)


def _build_chains(
    pool: Dict[str, Any],
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
    """Build (Q1, Q2, Q3) chain triples from the pool.

    Each Q2's and Q3's `parent` field references its Q1's `id`. Chains
    without a complete (Q1, Q2, Q3) triple are skipped; Phase B
    discipline guarantees full triples but defensive code stays.
    """
    q2_by_parent = {q["parent"]: q for q in pool.get("q2_layer", [])}
    q3_by_parent = {q["parent"]: q for q in pool.get("q3_layer", [])}
    chains: List[
        Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]
    ] = []
    for q1 in pool.get("q1_layer", []):
        q2 = q2_by_parent.get(q1["id"])
        q3 = q3_by_parent.get(q1["id"])
        if q2 is not None and q3 is not None:
            chains.append((q1, q2, q3))
    return chains


def _validate_sample(
    sample: List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]],
    max_per_category: int = 2,
    min_threads_with_dups: int = 3,
    min_distinct_complexity_levels: int = 3,
) -> bool:
    """Stratification discipline check.

    Returns True if the sample respects:
        - max `max_per_category` chains per category (default 2)
        - at least `min_threads_with_dups` threads with 2+ instances (3)
        - at least `min_distinct_complexity_levels` distinct complexity
          tags across the sample's combined Q1+Q2+Q3 levels (3)

    Counted across (Q1, Q2, Q3) chain triples. Each chain contributes
    its (single) thread once and contributes 3 complexity tags (one per
    layer).
    """
    if not sample:
        return False

    cats = Counter(q1["category"] for q1, _q2, _q3 in sample)
    if any(count > max_per_category for count in cats.values()):
        return False

    threads = Counter(q1["thread"] for q1, _q2, _q3 in sample)
    threads_with_dups = sum(
        1 for count in threads.values() if count >= 2
    )
    if threads_with_dups < min_threads_with_dups:
        return False

    complexities: set = set()
    for q1, q2, q3 in sample:
        complexities.add(q1["complexity"])
        complexities.add(q2["complexity"])
        complexities.add(q3["complexity"])
    if len(complexities) < min_distinct_complexity_levels:
        return False

    return True


def sample_chains(
    pool: Optional[Dict[str, Any]] = None,
    n_chains: int = 8,
    seed: Optional[int] = None,
    max_attempts: int = 200,
) -> Tuple[
    List[Tuple[str, str]],
    List[Tuple[int, int]],
    List[Dict[str, Any]],
]:
    """Draw `n_chains` (Q1, Q2, Q3) chains and lay them out as turns.

    Selection works in two stages. First, for every entry of
    pool["priority_categories"] that has at least one chain, one chain
    of that category is picked at random and forced into the sample.
    Second, the remaining slots are filled by rejection sampling from
    all chains that were not forced; each candidate (forced + drawn) is
    checked with `_validate_sample` at its default thresholds, so the
    forced chains count toward the constraints. If no candidate passes
    within `max_attempts` tries, one further unchecked draw is used
    instead of raising.

    Args:
        pool: Pre-loaded pool dict. If None, `load_pool()` is called.
        n_chains: Number of chains to sample. Each chain contributes
                  three turns, so the run has 3 * n_chains turns.
        seed: Seed for the private `random.Random` instance. None
              gives a nondeterministic draw.
        max_attempts: Retry budget for the rejection-sampling loop.

    Returns:
        interleaved_questions: list of (category, prompt_text) tuples,
                               3 * n_chains entries, laid out as
                                  0..n-1:    Q1s
                                  n..2n-1:   Q2s (same chain order)
                                  2n..3n-1:  Q3s (same chain order)
        same_cat_pairs:        [(i, i + n_chains) for i in range(n)],
                               i.e. each Q1 turn paired with its Q2
                               turn. Q3 turns are not paired.
        sample_meta:           one dict per chain with q1_id, q2_id,
                               q3_id, category, thread, q1_complexity,
                               q2_complexity, q3_complexity (category
                               and thread are taken from the Q1).

    Raises:
        ValueError: the pool holds fewer complete chains than
                    `n_chains`, or more chains were forced than
                    `n_chains` allows.
    """
    source = load_pool() if pool is None else pool

    chains = _build_chains(source)
    if len(chains) < n_chains:
        raise ValueError(
            f"Pool has {len(chains)} chains, cannot sample {n_chains}"
        )

    must_have: List[str] = source.get("priority_categories", []) or []
    rng = random.Random(seed)

    # Stage 1: one randomly chosen chain per priority category that
    # actually has chains in the pool.
    forced: List[
        Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]
    ] = []
    for wanted in must_have:
        matches = [c for c in chains if c[0]["category"] == wanted]
        if not matches:
            continue
        forced.append(rng.choice(matches))

    # Stage 2: fill whatever room is left.
    open_slots = n_chains - len(forced)
    if open_slots < 0:
        raise ValueError(
            f"More priority categories ({len(forced)}) than n_chains "
            f"({n_chains}); reduce priority list or raise n_chains"
        )
    taken_ids = {c[0]["id"] for c in forced}
    fill_pool = [c for c in chains if c[0]["id"] not in taken_ids]

    selected: List[
        Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]
    ] = []
    for _ in range(max_attempts):
        drawn = rng.sample(fill_pool, open_slots) if open_slots > 0 else []
        selected = forced + drawn
        if _validate_sample(selected):
            break
    else:
        # Retry budget exhausted: settle for a best-effort draw that
        # keeps the forced chains but is not validated.
        if open_slots > 0:
            selected = forced + rng.sample(fill_pool, open_slots)
        else:
            selected = list(forced)

    # Layer-major layout: every Q1 first, then every Q2, then every Q3,
    # each layer in the same chain order.
    interleaved: List[Tuple[str, str]] = [
        (chain[layer]["category"], chain[layer]["text"])
        for layer in range(3)
        for chain in selected
    ]

    # Turn i (a Q1) is paired with turn i + n_chains (its Q2).
    same_cat_pairs: List[Tuple[int, int]] = list(
        zip(range(n_chains), range(n_chains, 2 * n_chains))
    )

    sample_meta: List[Dict[str, Any]] = [
        {
            "q1_id": first["id"],
            "q2_id": second["id"],
            "q3_id": third["id"],
            "category": first["category"],
            "thread": first["thread"],
            "q1_complexity": first["complexity"],
            "q2_complexity": second["complexity"],
            "q3_complexity": third["complexity"],
        }
        for first, second, third in selected
    ]

    return interleaved, same_cat_pairs, sample_meta


def describe_sample(
    sample_meta: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Summarize a sample's metadata as a small dict suitable for logging.

    Args:
        sample_meta: Per-chain dicts carrying q1_id, q2_id, q3_id,
                     category, thread, q1_complexity, q2_complexity and
                     q3_complexity.

    Returns:
        A dict with:
            n_chains:                number of entries in `sample_meta`
            n_turns:                 three times `n_chains`
            categories_sampled:      category -> chain count
            threads_sampled:         thread -> chain count
            complexity_distribution: complexity tag -> count, pooled
                                     over the Q1, Q2 and Q3 tags
            chain_ids:               list of (q1_id, q2_id, q3_id)
    """
    complexity_keys = ("q1_complexity", "q2_complexity", "q3_complexity")
    id_keys = ("q1_id", "q2_id", "q3_id")

    category_counts = Counter([entry["category"] for entry in sample_meta])
    thread_counts = Counter([entry["thread"] for entry in sample_meta])
    # One tally across all three layers of every chain.
    level_counts = Counter(
        entry[key] for entry in sample_meta for key in complexity_keys
    )

    chain_total = len(sample_meta)
    id_triples = [
        tuple(entry[key] for key in id_keys) for entry in sample_meta
    ]

    summary: Dict[str, Any] = {}
    summary["n_chains"] = chain_total
    summary["n_turns"] = 3 * chain_total
    summary["categories_sampled"] = dict(category_counts)
    summary["threads_sampled"] = dict(thread_counts)
    summary["complexity_distribution"] = dict(level_counts)
    summary["chain_ids"] = id_triples
    return summary