File size: 4,811 Bytes
e51ccda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pytest

torch = pytest.importorskip("torch")

from sgjm.eval.checkpoint import load_checkpoint
from sgjm.eval.metrics import (
    BaselineEvalMetrics,
    SGJMEvalMetrics,
    compare,
    evaluate_baseline,
    evaluate_sgjm,
)
from sgjm.training.config import TrainingConfig
from sgjm.training.data import ByteDataset, synthetic_corpus
from sgjm.training.torch_backend.baseline import BaselineLM
from sgjm.training.torch_backend.model import SGJM
from sgjm.training.torch_backend.trainer import train


def _smoke_cfg(arch: str, ckpt_dir) -> TrainingConfig:
    """Return a smoke-sized ``TrainingConfig`` for *arch*.

    The config starts from ``TrainingConfig.smoke()``; its ``arch`` is set to
    *arch* and its ``checkpoint_dir`` to ``str(ckpt_dir)`` so that path-like
    objects (e.g. pytest's ``tmp_path``) can be passed directly.
    """
    config = TrainingConfig.smoke()
    config.arch = arch
    config.checkpoint_dir = str(ckpt_dir)
    return config


def test_baseline_model_param_count_matches_sgjm_total():
    """BaselineLM's parameter count must be within +/-10% of SGJM's.

    Both models are built from the same ``TrainingConfig.sgjm_25m()`` model
    config; keeping them near parameter parity keeps the comparison fair.
    """
    model_cfg = TrainingConfig.sgjm_25m().model
    sgjm_params = SGJM(model_cfg).num_parameters()
    baseline_params = BaselineLM(model_cfg).num_parameters()
    lower, upper = 0.9 * sgjm_params, 1.1 * sgjm_params
    assert lower <= baseline_params <= upper, (
        f"baseline {baseline_params/1e6:.2f}M vs sgjm "
        f"{sgjm_params/1e6:.2f}M too different"
    )


def test_evaluate_sgjm_returns_finite_metrics():
    """``evaluate_sgjm`` on a freshly built smoke SGJM returns sane metrics.

    Rates must lie in [0, 1], compute per accepted token must be positive,
    and no float metric outside the JS-divergence pair may be NaN.
    """
    cfg = TrainingConfig.smoke()
    model = SGJM(cfg.model)
    dataset = ByteDataset(synthetic_corpus(4096, seed=0), cfg.optim.seq_len)
    metrics = evaluate_sgjm(
        model, cfg, dataset,
        n_batches=2, n_distractors=4, n_merge_pairs=128, drafts_per_step=2,
        device="cpu",
    )

    for rate in (metrics.branch_acceptance_rate, metrics.jepa_top1_acc):
        assert 0.0 <= rate <= 1.0
    assert metrics.compute_per_accepted_token > 0

    # merge_precision_js and random_pair_js may legitimately be NaN when too
    # few sampled pairs fall inside the merge radius; every other float
    # metric must not be NaN.
    # NOTE(review): `value == value` rejects NaN only; +/-inf would still pass.
    nan_allowed = {"merge_precision_js", "random_pair_js"}
    for name, value in metrics.to_dict().items():
        if name in nan_allowed or not isinstance(value, float):
            continue
        assert value == value, f"{name} is NaN"


def test_evaluate_baseline_returns_finite_metrics():
    """``evaluate_baseline`` on a fresh smoke BaselineLM gives positive metrics."""
    cfg = TrainingConfig.smoke()
    baseline = BaselineLM(cfg.model)
    dataset = ByteDataset(synthetic_corpus(4096, seed=0), cfg.optim.seq_len)
    result = evaluate_baseline(baseline, cfg, dataset, n_batches=2, device="cpu")
    # For float metrics, `NaN > 0` is False, so these also reject NaN.
    assert result.token_nll > 0
    assert result.compute_per_token > 0


def test_compare_constructs_report():
    """Equal NLL at half the baseline's compute passes the gate.

    The SGJM side uses 1e6 compute per accepted token against the baseline's
    2e6 compute per token, so the compute advantage must exceed 1.
    """
    sgjm_fields = dict(
        n_tokens=100,
        n_positions=50,
        token_nll=1.5,
        token_ppl=4.5,
        branch_acceptance_rate=0.6,
        jepa_top1_acc=0.4,
        jepa_chance_top1=0.1,
        merge_precision_js=0.01,
        random_pair_js=0.05,
        merge_precision_advantage=5.0,
        compute_per_accepted_token=1e6,
    )
    baseline_fields = dict(
        n_tokens=100,
        token_nll=1.5,
        token_ppl=4.5,
        compute_per_token=2e6,
    )
    sgjm = SGJMEvalMetrics(**sgjm_fields)
    baseline = BaselineEvalMetrics(**baseline_fields)

    report = compare(sgjm, baseline)
    assert report.gate_passed
    assert report.compute_advantage > 1.0


def test_compare_fails_gate_on_high_nll():
    """A much higher SGJM NLL (3.0 vs 1.5) fails the gate.

    At least one gate reason must mention "nll".
    """
    sgjm_fields = dict(
        n_tokens=100,
        n_positions=50,
        token_nll=3.0,
        token_ppl=20.0,
        branch_acceptance_rate=0.6,
        jepa_top1_acc=0.4,
        jepa_chance_top1=0.1,
        merge_precision_js=0.01,
        random_pair_js=0.05,
        merge_precision_advantage=5.0,
        compute_per_accepted_token=1e6,
    )
    baseline_fields = dict(
        n_tokens=100,
        token_nll=1.5,
        token_ppl=4.5,
        compute_per_token=2e6,
    )
    sgjm = SGJMEvalMetrics(**sgjm_fields)
    baseline = BaselineEvalMetrics(**baseline_fields)

    report = compare(sgjm, baseline)
    assert not report.gate_passed
    assert any("nll" in reason for reason in report.gate_reasons)


def test_end_to_end_train_then_eval(tmp_path):
    """Train both smoke models, reload their checkpoints, evaluate and compare.

    Untrained smoke models are not expected to pass the gate, so only the
    report's shape and positivity of both NLLs are checked.
    """
    sgjm_cfg = _smoke_cfg("sgjm", tmp_path / "sgjm")
    base_cfg = _smoke_cfg("baseline", tmp_path / "baseline")

    # Train both architectures before loading anything back, so that the
    # train -> load ordering is identical for the two runs.
    sgjm_ckpt = train(sgjm_cfg, backend="cpu").checkpoint_path
    base_ckpt = train(base_cfg, backend="cpu").checkpoint_path
    assert sgjm_ckpt is not None
    assert base_ckpt is not None

    sgjm_loaded = load_checkpoint(sgjm_ckpt, device="cpu")
    base_loaded = load_checkpoint(base_ckpt, device="cpu")
    assert (sgjm_loaded.arch, base_loaded.arch) == ("sgjm", "baseline")

    # Both models are scored on the same synthetic byte stream.
    dataset = ByteDataset(synthetic_corpus(2048, seed=0), sgjm_cfg.optim.seq_len)
    report = compare(
        evaluate_sgjm(
            sgjm_loaded.model, sgjm_cfg, dataset, n_batches=2,
            n_distractors=4, n_merge_pairs=64, drafts_per_step=2, device="cpu",
        ),
        evaluate_baseline(
            base_loaded.model, base_cfg, dataset, n_batches=2, device="cpu",
        ),
    )

    assert isinstance(report.gate_passed, bool)
    assert report.sgjm.token_nll > 0
    assert report.baseline.token_nll > 0