File size: 8,456 Bytes
b43d8da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""Cell 19 — Final evaluation harness (post-training LoRA).

Implements ``docs/modules/evaluation.md`` §2.1, §3.1, §3.3 (paired-difference),
§3.5 (drift-detection latency aggregation), §3.8, §5 ``EpisodeSetLeakError``.

Hard rules (evaluation.md §3.1, §6.1, §6.3):
- Same 50 episodes as baseline (paired); ``EpisodeSetLeakError`` raised on
  mismatch.
- Bootstrap CI seed for paired-difference is ``20260428`` (evaluation.md §2.4).
- Wall-clock budget 20 minutes — same ceiling as baseline.
- No LLM-as-judge; static AST scan via ``_NO_LLM_JUDGE_FORBIDDEN_IMPORTS``.

Heavy imports (``torch``) are deferred so this module imports cleanly on
CPU-only CI. The training-eval delegate is injected (see step_18).
"""

from __future__ import annotations

import math
import time
from dataclasses import replace
from pathlib import Path
from typing import TYPE_CHECKING, Any

from cells.step_18_eval_baseline import (
    BUDGET_RUN_EVAL_SECONDS,
    DEFAULT_N_BOOT,
    DEFAULT_PAIRED_BOOTSTRAP_SEED,
    DriftDetectionLatency,
    EvalBudgetExceededError,
    EvalReport,
    EvaluationError,
    PerLanguageReport,
    TrainingEvalCallable,
    _check_catalogue_hashes,
    _episode_ids_from_breakdown,
    _validate_briefs_first_50,
    run_eval,
)

if TYPE_CHECKING:  # pragma: no cover - typing only
    from collections.abc import Callable, Sequence


# Public surface of this cell: the names defined here (EpisodeSetLeakError,
# assert_paired_episode_sets, eval_final, paired_difference_ci) plus names
# imported from ``cells.step_18_eval_baseline`` and re-exported unchanged.
__all__ = [
    "BUDGET_RUN_EVAL_SECONDS",
    "DEFAULT_PAIRED_BOOTSTRAP_SEED",
    "DriftDetectionLatency",
    "EpisodeSetLeakError",
    "EvalBudgetExceededError",
    "EvalReport",
    "PerLanguageReport",
    "assert_paired_episode_sets",
    "eval_final",
    "paired_difference_ci",
]


# ---------------------------------------------------------------------------
# Errors — evaluation.md §5
# ---------------------------------------------------------------------------


class EpisodeSetLeakError(EvaluationError):
    """Paired-comparison invariant broken: baseline and final episodes differ.

    Raised when the final run's ``episode_ids`` do not match the baseline's,
    or when two paired sample sequences have different lengths. Paired
    differences are only meaningful over an identical, identically-ordered
    episode set.
    """


# ---------------------------------------------------------------------------
# Paired-difference CI — evaluation.md §2.4
# ---------------------------------------------------------------------------


def paired_difference_ci(
    baseline_samples: tuple[float, ...],
    final_samples: tuple[float, ...],
    n_boot: int = DEFAULT_N_BOOT,
    rng_seed: int = DEFAULT_PAIRED_BOOTSTRAP_SEED,
) -> tuple[float, float, float]:
    """Bootstrap 95% CI on ``mean(final - baseline)`` — index-paired.

    evaluation.md §2.4: element ``i`` of ``baseline_samples`` is paired with
    element ``i`` of ``final_samples``. The per-pair differences are resampled
    with replacement ``n_boot`` times, and the CI bounds are percentiles of
    the resample means.

    Args:
        baseline_samples: Per-episode metric values from the baseline run.
        final_samples: Per-episode metric values from the final run, in the
            same episode order as ``baseline_samples``.
        n_boot: Number of bootstrap resamples. Must be >= 1 whenever a
            bootstrap is actually performed (see Returns).
        rng_seed: Seed for ``numpy.random.default_rng``; makes CIs reproducible.

    Returns:
        ``(mean_diff, lo, hi)``, where ``lo`` and ``hi`` are the 2.5th and
        97.5th percentiles of the bootstrap means. Degenerate inputs skip the
        bootstrap entirely:

        - empty input gives ``(nan, nan, nan)``;
        - a single pair, or all-identical differences, gives
          ``(mean, mean, mean)``.

    Raises:
        EpisodeSetLeakError: If the two sample tuples differ in length.
        ValueError: If ``n_boot < 1`` and a bootstrap is required.
    """
    if len(baseline_samples) != len(final_samples):
        raise EpisodeSetLeakError(
            f"paired-comparison invariant: len(baseline)={len(baseline_samples)} "
            f"!= len(final)={len(final_samples)}",
        )
    n = len(baseline_samples)
    if n == 0:
        nan = float("nan")
        return nan, nan, nan
    diffs = tuple(f - b for b, f in zip(baseline_samples, final_samples, strict=True))
    mean = sum(diffs) / n
    # Zero-variance shortcut: every resample mean would equal ``mean`` exactly
    # (a single pair is trivially the all-equal case).
    if n == 1 or all(d == diffs[0] for d in diffs):
        return mean, mean, mean
    # Validate explicitly. n_boot == 0 would feed an empty array of bootstrap
    # means to np.percentile, which errors or yields NaN depending on the
    # numpy version. A negative n_boot fails inside rng.integers with an
    # unrelated-looking shape error.
    if n_boot < 1:
        raise ValueError(f"n_boot must be >= 1; got {n_boot}")

    import numpy as np  # deferred heavy import (see module docstring)

    rng = np.random.default_rng(rng_seed)
    arr = np.asarray(diffs, dtype=np.float64)
    # One row of n resample indices per bootstrap replicate.
    idx = rng.integers(0, n, size=(n_boot, n))
    means = arr[idx].mean(axis=1)
    lo = float(np.percentile(means, 2.5))
    hi = float(np.percentile(means, 97.5))
    return float(mean), lo, hi


# ---------------------------------------------------------------------------
# Episode-set leak guard — evaluation.md §3.1
# ---------------------------------------------------------------------------


def assert_paired_episode_sets(baseline: EvalReport, final: EvalReport) -> None:
    """Enforce the paired-comparison invariant between two evaluation reports.

    Episode ids are extracted from each report's breakdown via
    ``_episode_ids_from_breakdown`` (baseline first, then final) and compared.

    Raises:
        EpisodeSetLeakError: If the two extracted id sequences are not equal.
    """
    same_episodes = _episode_ids_from_breakdown(baseline) == _episode_ids_from_breakdown(final)
    if same_episodes:
        return
    raise EpisodeSetLeakError(
        "paired-comparison invariant violated — baseline.episode_ids != final.episode_ids; "
        "operator must re-run baseline against the current val split.",
    )


# ---------------------------------------------------------------------------
# Drift-detection-latency point extraction — evaluation.md §3.5
# ---------------------------------------------------------------------------


def _final_latency_point(report: EvalReport) -> tuple[float, float]:
    """Return ``(p50, p95)`` of the report's Stage-3 drift-detection latency.

    Reads ``report.drift_detection_latency.stage3_median`` and
    ``.stage3_p95`` and coerces both to ``float``. A NaN value is returned
    unchanged; callers must test for NaN themselves.

    No Stage-2 fallback is applied. Substituting Stage-2 independently per
    report could pair a Stage-2 baseline value with a Stage-3 final value and
    yield a meaningless delta.

    NOTE(review): if evaluation.md §3.5 requires a fallback, apply it to both
    reports together in the caller, not here.
    """
    lat = report.drift_detection_latency
    return float(lat.stage3_median), float(lat.stage3_p95)


# ---------------------------------------------------------------------------
# Final-eval entry point — evaluation.md §2.2 ``eval_final.py``
# ---------------------------------------------------------------------------


def eval_final(
    checkpoint: Path,
    episodes: int = 50,
    *,
    baseline: EvalReport,
    training_eval: TrainingEvalCallable,
    briefs: Sequence[Any],
    catalogue_hashes: dict[str, str] | None = None,
    budget_seconds: int = BUDGET_RUN_EVAL_SECONDS,
    monotonic: Callable[[], float] | None = None,
) -> EvalReport:
    """Evaluate the trained LoRA on the same 50 paired episodes as ``baseline``.

    evaluation.md §2.1, §3.1, §3.3. The steps are:

    1. Validate the arguments and the first 50 briefs.
    2. Optionally verify catalogue hashes.
    3. Pre-flight the baseline episode ids.
    4. Run ``run_eval`` under a wall-clock budget.
    5. Re-check the paired episode sets.
    6. Attach paired-difference CIs under ``breakdown['paired_ci']``.

    Args:
        checkpoint: Trained checkpoint location; must be a ``pathlib.Path``.
        episodes: Must be exactly 50 (paired contract with the baseline).
        baseline: Baseline report whose episode ids and samples are paired
            against the final run.
        training_eval: Delegate forwarded unchanged to ``run_eval``.
        briefs: Brief rows. They are validated by
            ``_validate_briefs_first_50``, and the selected rows supply the
            expected ``episode_id`` values.
        catalogue_hashes: Optional expected hashes. When given, they are
            checked against the selected briefs and also forwarded to
            ``run_eval``.
        budget_seconds: Wall-clock ceiling. It is forwarded to ``run_eval``
            and re-checked around that call.
        monotonic: Injectable clock (defaults to ``time.monotonic``). It is
            also forwarded to ``run_eval``.

    Returns:
        The report returned by ``run_eval``, with a copied ``breakdown`` that
        additionally holds ``paired_ci``: metric name → ``(mean_diff, lo, hi)``.

    Raises:
        EvaluationError: ``checkpoint`` is not a Path, or ``episodes != 50``.
        EpisodeSetLeakError: Baseline episode ids disagree with the selected
            briefs (pre-flight) or with the final report (post-run).
        EvalBudgetExceededError: The measured wall-clock exceeded
            ``budget_seconds``.
    """
    # Argument contract: cheap checks before any rollout work.
    if not isinstance(checkpoint, Path):
        raise EvaluationError(
            f"checkpoint must be pathlib.Path; got {type(checkpoint).__name__}",
        )
    if episodes != 50:
        raise EvaluationError(
            f"eval_final expects episodes=50 (paired contract); got {episodes}",
        )

    # Briefs and optional catalogue-hash verification.
    paired_rows = _validate_briefs_first_50(briefs)
    if catalogue_hashes is not None:
        _check_catalogue_hashes(paired_rows, catalogue_hashes)

    # Pre-flight leak check, so a mismatch fails before the expensive rollout.
    # NOTE(review): when the baseline yields no episode ids (falsy), this check
    # is skipped. Any mismatch is then only caught post-run by
    # ``assert_paired_episode_sets``. Confirm this is intended.
    brief_ids = tuple(row.episode_id for row in paired_rows)
    baseline_ids = _episode_ids_from_breakdown(baseline)
    if baseline_ids and baseline_ids != brief_ids:
        raise EpisodeSetLeakError(
            "paired-comparison invariant violated at entry — baseline.episode_ids "
            "do not match val/briefs.jsonl[0:50]; re-run baseline first.",
        )

    # Timed rollout. The same clock is shared with run_eval.
    now = time.monotonic if monotonic is None else monotonic
    t_start = now()
    report = run_eval(
        checkpoint,
        episodes,
        training_eval=training_eval,
        briefs=briefs,
        catalogue_hashes=catalogue_hashes,
        budget_seconds=budget_seconds,
        monotonic=now,
    )
    spent = now() - t_start
    if spent > budget_seconds:
        raise EvalBudgetExceededError(
            f"eval_final wall-clock {spent:.1f}s exceeded {budget_seconds}s",
        )

    # Post-run leak check, then paired-difference CIs (evaluation.md §3.3).
    assert_paired_episode_sets(baseline, report)
    paired_ci = _build_paired_ci_block(baseline, report)
    return replace(report, breakdown=dict(report.breakdown, paired_ci=paired_ci))


def _build_paired_ci_block(
    baseline: EvalReport,
    final: EvalReport,
) -> dict[str, tuple[float, float, float]]:
    """Construct the ``breakdown['paired_ci']`` block for the blog narrative.

    Sample-based metrics: for each of ``reward, r1 .. r5`` present in BOTH
    reports' ``breakdown['samples']``, store
    ``paired_difference_ci(baseline, final)`` as ``(mean_diff, lo, hi)``.
    A metric missing from either side is omitted.

    Drift latency: ``drift_latency_p50`` holds final Stage-3 p50 minus
    baseline Stage-3 p50 (lower is better). It is stored as a degenerate
    ``(delta, delta, delta)`` triple with no bootstrap, and omitted when
    either p50 is NaN.
    """
    out: dict[str, tuple[float, float, float]] = {}
    # ``or {}`` also tolerates an explicit ``None`` stored under "samples".
    base_samples: dict[str, tuple[float, ...]] = baseline.breakdown.get("samples") or {}
    final_samples: dict[str, tuple[float, ...]] = final.breakdown.get("samples") or {}
    for key in ("reward", "r1", "r2", "r3", "r4", "r5"):
        if key in base_samples and key in final_samples:
            out[key] = paired_difference_ci(
                tuple(base_samples[key]),
                tuple(final_samples[key]),
            )

    # Drift-latency delta. NaN on either side means no comparable value.
    base_p50, _ = _final_latency_point(baseline)
    final_p50, _ = _final_latency_point(final)
    if not (math.isnan(base_p50) or math.isnan(final_p50)):
        delta = final_p50 - base_p50
        out["drift_latency_p50"] = (delta, delta, delta)
    return out