"""Three smoke tests so users can verify the install before the reproduce scripts.

Run with:
    pip install -e ".[test]"
    pytest -q tests/
"""

import sys
from pathlib import Path

# This file sits one directory below the kit root, so the root is two levels up
# from the resolved path of this file.
KIT_ROOT = Path(__file__).resolve().parent.parent
# Put the in-repo ``src/`` first on sys.path so ``tilelli`` resolves to the local
# sources even when the package has not been installed.
sys.path.insert(0, str(KIT_ROOT / "src"))

# Checkpoint shipped with the kit; both model-level tests load it.
CKPT_PATH = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"


def _load_bundled_bridge():
    """Load the bundled checkpoint with ``load_bridge``.

    Fails with a clear "bundled ckpt missing" message, rather than whatever
    error the loader raises, when the checkpoint file is absent.

    Returns:
        The ``(model, abstain, tok)`` tuple produced by ``load_bridge``.
    """
    from tilelli.eval.metacog_probe import load_bridge

    assert CKPT_PATH.exists(), f"bundled ckpt missing at {CKPT_PATH}"
    return load_bridge(str(CKPT_PATH))


def test_kit_imports_cleanly():
    """All public-facing modules import without error and the tokenizer round-trips."""
    from tilelli.core import TilelliLiteLM, PATHWAY_NAMES_LITE  # noqa: F401
    from tilelli.core.ternary_linear import TernaryLinear  # noqa: F401
    from tilelli.core.ternary_conv import TernaryCausalConv1d  # noqa: F401
    from tilelli.core.sparse_attention import SparseCausalAttention  # noqa: F401
    from tilelli.distillery.tokenize import ByteTokenizer
    from tilelli.eval.metacog_probe import load_bridge  # noqa: F401

    tok = ByteTokenizer()
    enc = tok.encode("hello world")
    # encode() returns something with .tolist() (presumably a tensor); decode()
    # is fed the plain Python list of ids.
    out = tok.decode(enc.tolist())
    assert out == "hello world", f"tokenizer roundtrip broke: {out!r}"


def test_bundled_checkpoint_loads():
    """The bundled v4 checkpoint loads, has ~10M params, and produces a forward."""
    import torch

    model, _abstain, tok = _load_bundled_bridge()

    n_params = sum(p.numel() for p in model.parameters())
    # A count far from ~10M most likely means the wrong checkpoint file.
    assert 9_500_000 <= n_params <= 11_000_000, (
        f"param count {n_params} outside 9.5–11M; ckpt may be wrong file"
    )

    # unsqueeze(0) adds a leading batch dimension (presumably (seq,) -> (1, seq)).
    ids = tok.encode("USER: hi\nTILELLI:").long().unsqueeze(0)
    with torch.no_grad():
        out = model(ids)
    # Expect a 3-D output whose last dim is 256, presumably the byte vocabulary.
    assert out.ndim == 3 and out.shape[-1] == 256, (
        f"unexpected forward shape: {tuple(out.shape)}"
    )


def test_one_generation_step_runs():
    """Greedy generation runs end-to-end and emits at least one new token."""
    import torch

    model, _abstain, tok = _load_bundled_bridge()

    ids = tok.encode("USER: hello\nTILELLI:").long().unsqueeze(0)
    with torch.no_grad():
        # stop_ids (10, 0) are presumably the newline and NUL byte ids.
        full, generated, confs = model.generate_with_cache(
            ids, n_new_tokens=8, stop_ids=(10, 0)
        )
    # Assuming ``full`` is prompt + continuation, ``>=`` would pass even with zero
    # new tokens; require strict growth so the check matches its message.
    assert full.size(1) > ids.size(1), "no tokens generated"
    assert len(generated) >= 1, "no generated_id_list entries"
    assert all(0.0 <= c <= 1.0 for c in confs), "confidence values out of range"