File size: 2,459 Bytes
"""Three smoke tests so users can verify the install before running the
reproduce scripts.

Run with:

    pip install -e ".[test]"
    pytest -q tests/
"""
import sys
from pathlib import Path
# Resolve this file's absolute location, then climb two directory levels to
# reach the kit root; its "src" directory is put first on sys.path so the
# imports inside the tests resolve against it.
_THIS_FILE = Path(__file__).resolve()
KIT_ROOT = _THIS_FILE.parent.parent
sys.path.insert(0, str(KIT_ROOT.joinpath("src")))
def test_kit_imports_cleanly():
    """Check that the public modules import and the tokenizer round-trips.

    The imports marked ``noqa: F401`` are performed only to prove that they
    succeed; the imported names are not used afterwards. The byte tokenizer
    is then exercised by encoding a short string and decoding it back.
    """
    # Import-only checks: any failure here surfaces as an ImportError.
    from tilelli.core import TilelliLiteLM, PATHWAY_NAMES_LITE  # noqa: F401
    from tilelli.core.ternary_linear import TernaryLinear  # noqa: F401
    from tilelli.core.ternary_conv import TernaryCausalConv1d  # noqa: F401
    from tilelli.core.sparse_attention import SparseCausalAttention  # noqa: F401
    from tilelli.distillery.tokenize import ByteTokenizer
    from tilelli.eval.metacog_probe import load_bridge  # noqa: F401

    codec = ByteTokenizer()
    # encode() returns an object exposing .tolist(); decode() is fed the
    # resulting plain list.
    token_ids = codec.encode("hello world").tolist()
    out = codec.decode(token_ids)
    assert out == "hello world", f"tokenizer roundtrip broke: {out!r}"
def test_bundled_checkpoint_loads():
    """Load the bundled v4 checkpoint and sanity-check size and forward pass.

    Verifies that the checkpoint file exists under ``KIT_ROOT/checkpoints``,
    that the loaded model's total parameter count lies between 9.5M and 11M,
    and that a forward pass on a short prompt returns a 3-D output whose last
    dimension is 256.
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt = KIT_ROOT.joinpath("checkpoints", "tilelli_chat_v4.pt")
    assert ckpt.exists(), f"bundled ckpt missing at {ckpt}"

    # load_bridge returns three values; only the model and tokenizer are
    # needed by this test.
    net, _abstain, tokenizer = load_bridge(str(ckpt))

    # Total element count across every parameter tensor of the model.
    n_params = 0
    for weight in net.parameters():
        n_params += weight.numel()
    assert 9_500_000 <= n_params <= 11_000_000, (
        f"param count {n_params} outside 9.5–11M; ckpt may be wrong file"
    )

    # Encode the prompt, cast to long, and add a leading dimension of size 1.
    prompt = tokenizer.encode("USER: hi\nTILELLI:")
    ids = prompt.long().unsqueeze(0)
    with torch.no_grad():
        out = net(ids)
    assert out.ndim == 3 and out.shape[-1] == 256, (
        f"unexpected forward shape: {tuple(out.shape)}"
    )
def test_one_generation_step_runs():
    """Greedy generation runs end-to-end without crashing.

    Loads the bundled checkpoint, asks ``generate_with_cache`` for up to 8
    new tokens after a short prompt, and checks that the returned sequence is
    not shorter than the prompt, that at least one id was generated, and that
    every reported confidence lies in ``[0.0, 1.0]``.
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"
    # Same guard as test_bundled_checkpoint_loads: a missing file should fail
    # with a readable message instead of an error raised inside load_bridge.
    assert ckpt.exists(), f"bundled ckpt missing at {ckpt}"
    model, _abstain, tok = load_bridge(str(ckpt))

    # Encode the prompt, cast to long, and add a leading dimension of size 1.
    ids = tok.encode("USER: hello\nTILELLI:").long().unsqueeze(0)
    with torch.no_grad():
        # presumably stop ids 10 and 0 are the newline and NUL bytes of the
        # byte-level tokenizer — TODO confirm against ByteTokenizer.
        full, generated, confs = model.generate_with_cache(
            ids, n_new_tokens=8, stop_ids=(10, 0)
        )

    # NOTE(review): ">=" only rejects an output shorter than the prompt; the
    # "at least one token" requirement is enforced by the next assertion.
    assert full.size(1) >= ids.size(1), "no tokens generated"
    assert len(generated) >= 1, "no generated_id_list entries"
    assert all(0.0 <= c <= 1.0 for c in confs), "confidence values out of range"
|