"""Three smoke tests that let users verify the install before running the
reproduce scripts.

Run with:
    pip install -e ".[test]"
    pytest -q tests/
"""
| import sys |
| from pathlib import Path |
|
|
# Kit root: two directory levels above this (resolved) test file.
KIT_ROOT = Path(__file__).resolve().parent.parent

# Put the in-tree ``src/`` directory first on the import path so the
# ``tilelli`` imports in the tests below resolve against it.
sys.path.insert(0, str(KIT_ROOT.joinpath("src")))
|
|
|
|
def test_kit_imports_cleanly():
    """Import every public-facing module, then round-trip text through the tokenizer.

    Most imported names are never used afterwards. Importing them is the
    check: any ImportError or import-time exception fails the test.
    """
    from tilelli.core import TilelliLiteLM, PATHWAY_NAMES_LITE  # noqa: F401
    from tilelli.core.ternary_linear import TernaryLinear  # noqa: F401
    from tilelli.core.ternary_conv import TernaryCausalConv1d  # noqa: F401
    from tilelli.core.sparse_attention import SparseCausalAttention  # noqa: F401
    from tilelli.distillery.tokenize import ByteTokenizer
    from tilelli.eval.metacog_probe import load_bridge  # noqa: F401

    sample = "hello world"
    tokenizer = ByteTokenizer()
    # encode() returns an object with .tolist() (presumably a tensor);
    # decode() is fed the plain Python list.
    token_list = tokenizer.encode(sample).tolist()
    decoded = tokenizer.decode(token_list)
    assert decoded == sample, f"tokenizer roundtrip broke: {decoded!r}"
|
|
|
|
def test_bundled_checkpoint_loads():
    """Check that the bundled v4 checkpoint exists, loads, is ~10M params, and runs a forward.

    The parameter count must fall in [9.5M, 11M]. The forward output must be
    rank 3 with a last dimension of 256.
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt_path = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"
    assert ckpt_path.exists(), f"bundled ckpt missing at {ckpt_path}"

    # The second element returned by load_bridge is not exercised by this test.
    model, _abstain, tokenizer = load_bridge(str(ckpt_path))

    # Loose bounds around the advertised ~10M parameters; a count outside
    # them suggests the wrong file was bundled.
    min_params, max_params = 9_500_000, 11_000_000
    n_params = sum(param.numel() for param in model.parameters())
    assert min_params <= n_params <= max_params, (
        f"param count {n_params} outside 9.5–11M; ckpt may be wrong file"
    )

    # Cast ids to int64 and add a leading batch dimension of 1
    # (assumes encode() returns a 1-D sequence — TODO confirm).
    prompt_ids = tokenizer.encode("USER: hi\nTILELLI:").long().unsqueeze(0)
    with torch.no_grad():
        output = model(prompt_ids)

    # 256 is presumably the byte-level vocabulary size.
    shape_ok = output.dim() == 3 and output.size(-1) == 256
    assert shape_ok, f"unexpected forward shape: {tuple(output.shape)}"
|
|
|
|
def test_one_generation_step_runs():
    """Cached generation runs end-to-end and emits at least one new token.

    Generates up to 8 new tokens from a short chat prompt with stop ids
    (10, 0), then checks that:
      * the returned full sequence is strictly longer than the prompt,
      * the generated-id list is non-empty,
      * every confidence value lies in [0, 1].
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"
    # Fail with the same clear message as test_bundled_checkpoint_loads
    # instead of whatever load_bridge raises on a missing file.
    assert ckpt.exists(), f"bundled ckpt missing at {ckpt}"
    model, _abstain, tok = load_bridge(str(ckpt))
    ids = tok.encode("USER: hello\nTILELLI:").long().unsqueeze(0)

    with torch.no_grad():
        # Stop ids 10 and 0 are presumably the newline and NUL bytes.
        full, generated, confs = model.generate_with_cache(
            ids, n_new_tokens=8, stop_ids=(10, 0)
        )

    # Strictly greater: ">=" would also pass when nothing was appended,
    # which is exactly the failure this assertion is meant to catch.
    assert full.size(1) > ids.size(1), "no tokens generated"
    assert len(generated) >= 1, "no generated_id_list entries"
    assert all(0.0 <= c <= 1.0 for c in confs), "confidence values out of range"
|
|