File size: 2,459 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Three smoke tests so users can verify the install before the reproduce
scripts.

Run with:
    pip install -e ".[test]"
    pytest -q tests/
"""
import sys
from pathlib import Path

KIT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(KIT_ROOT / "src"))


def test_kit_imports_cleanly():
    """Import the public-facing modules, then round-trip text through ByteTokenizer.

    The imported names are intentionally unused: importing them without an
    error is the check.
    """
    from tilelli.core import TilelliLiteLM, PATHWAY_NAMES_LITE  # noqa: F401
    from tilelli.core.ternary_linear import TernaryLinear  # noqa: F401
    from tilelli.core.ternary_conv import TernaryCausalConv1d  # noqa: F401
    from tilelli.core.sparse_attention import SparseCausalAttention  # noqa: F401
    from tilelli.distillery.tokenize import ByteTokenizer
    from tilelli.eval.metacog_probe import load_bridge  # noqa: F401

    sample = "hello world"
    tokenizer = ByteTokenizer()
    # encode() returns an object exposing .tolist(); decode() takes the plain list.
    decoded = tokenizer.decode(tokenizer.encode(sample).tolist())
    assert decoded == sample, f"tokenizer roundtrip broke: {decoded!r}"


def test_bundled_checkpoint_loads():
    """Load the bundled v4 checkpoint, sanity-check its size, and run one forward pass.

    The total parameter count must lie in [9.5M, 11M] (nominally ~10M). The
    forward output must be rank 3 with a last dimension of 256 (presumably
    one logit per byte value — confirm against the tokenizer).
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt_path = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"
    assert ckpt_path.exists(), f"bundled ckpt missing at {ckpt_path}"

    model, _abstain, tokenizer = load_bridge(str(ckpt_path))

    min_params, max_params = 9_500_000, 11_000_000
    total = sum(param.numel() for param in model.parameters())
    assert min_params <= total <= max_params, (
        f"param count {total} outside 9.5–11M; ckpt may be wrong file"
    )

    # Add a leading batch dimension of 1 to the encoded prompt.
    prompt = tokenizer.encode("USER: hi\nTILELLI:").long().unsqueeze(0)
    with torch.no_grad():
        logits = model(prompt)
    shape_ok = logits.ndim == 3 and logits.shape[-1] == 256
    assert shape_ok, f"unexpected forward shape: {tuple(logits.shape)}"


def test_one_generation_step_runs():
    """Run a short generation end-to-end and sanity-check its outputs.

    Asks ``generate_with_cache`` for up to 8 new tokens with
    ``stop_ids=(10, 0)``, then checks three things:

    - the returned sequence is not shorter than the prompt;
    - at least one generated id was returned;
    - every confidence value lies in [0, 1].
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"
    # Fail with a clear message (as the checkpoint-loading test does) rather
    # than an opaque error from inside the loader.
    assert ckpt.exists(), f"bundled ckpt missing at {ckpt}"

    model, _abstain, tok = load_bridge(str(ckpt))
    # Add a leading batch dimension of 1 to the encoded prompt.
    ids = tok.encode("USER: hello\nTILELLI:").long().unsqueeze(0)

    with torch.no_grad():
        # stop_ids 10 / 0 are presumably the newline and NUL bytes — confirm.
        full, generated, confs = model.generate_with_cache(
            ids, n_new_tokens=8, stop_ids=(10, 0)
        )

    # This only guards against the returned sequence being truncated below the
    # prompt length; it cannot detect "no new tokens". That is checked on
    # `generated` just below.
    assert full.size(1) >= ids.size(1), (
        "generated sequence is shorter than the prompt"
    )
    assert len(generated) >= 1, "no generated_id_list entries"
    assert all(0.0 <= c <= 1.0 for c in confs), "confidence values out of range"