Tilelli-llm / tests /test_kit_smoke.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
2.46 kB
"""Three smoke tests so users can verify the install before the reproduce
scripts.

Run with:

    pip install -e ".[test]"
    pytest -q tests/
"""
import sys
from pathlib import Path
# Resolve this test file's absolute location, then walk up two directories
# (file -> containing directory -> its parent) to get the kit root.
_THIS_FILE = Path(__file__).resolve()
KIT_ROOT = _THIS_FILE.parent.parent

# Put <kit root>/src at the front of the import search path so the
# `tilelli` imports used by the tests below are looked up there first.
_SRC_DIR = KIT_ROOT.joinpath("src")
sys.path.insert(0, str(_SRC_DIR))
def test_kit_imports_cleanly():
    """Public-facing modules import, and the byte tokenizer round-trips."""
    # Import-only checks: the names are deliberately unused (hence noqa);
    # the test passes as long as none of these imports raises.
    from tilelli.core import TilelliLiteLM, PATHWAY_NAMES_LITE  # noqa: F401
    from tilelli.core.ternary_linear import TernaryLinear  # noqa: F401
    from tilelli.core.ternary_conv import TernaryCausalConv1d  # noqa: F401
    from tilelli.core.sparse_attention import SparseCausalAttention  # noqa: F401
    from tilelli.distillery.tokenize import ByteTokenizer
    from tilelli.eval.metacog_probe import load_bridge  # noqa: F401

    # Encode -> plain Python list -> decode must give back the input text.
    expected = "hello world"
    tokenizer = ByteTokenizer()
    token_ids = tokenizer.encode(expected).tolist()
    out = tokenizer.decode(token_ids)
    assert out == expected, f"tokenizer roundtrip broke: {out!r}"
def test_bundled_checkpoint_loads():
    """The bundled v4 checkpoint loads, is ~10M params, and runs a forward."""
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt = KIT_ROOT.joinpath("checkpoints", "tilelli_chat_v4.pt")
    assert ckpt.exists(), f"bundled ckpt missing at {ckpt}"

    # load_bridge returns three values; the middle one is not needed here.
    model, _abstain, tokenizer = load_bridge(str(ckpt))

    # Total element count over all parameters; a value outside the expected
    # window suggests the wrong file was loaded.
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    lower, upper = 9_500_000, 11_000_000
    assert lower <= n_params <= upper, (
        f"param count {n_params} outside 9.5–11M; ckpt may be wrong file"
    )

    # Encode a short chat prompt and add a leading batch dimension.
    prompt = "USER: hi\nTILELLI:"
    batch = tokenizer.encode(prompt).long().unsqueeze(0)
    with torch.no_grad():
        out = model(batch)

    # Expect a rank-3 output whose last dimension is 256
    # (presumably the byte-level vocabulary size -- TODO confirm).
    shape_ok = out.ndim == 3 and out.shape[-1] == 256
    assert shape_ok, (
        f"unexpected forward shape: {tuple(out.shape)}"
    )
def test_one_generation_step_runs():
    """Greedy generation runs end-to-end without crashing.

    Loads the bundled checkpoint, calls ``generate_with_cache`` on a short
    chat prompt with ``n_new_tokens=8``, and sanity-checks the three values
    it returns (full sequence, generated ids, per-token confidences).
    """
    import torch
    from tilelli.eval.metacog_probe import load_bridge

    ckpt = KIT_ROOT / "checkpoints" / "tilelli_chat_v4.pt"
    # Same guard as test_bundled_checkpoint_loads: a missing file should
    # fail with a clear message rather than an opaque loader error.
    assert ckpt.exists(), f"bundled ckpt missing at {ckpt}"

    # The middle return value of load_bridge is not used by this test.
    model, _abstain, tok = load_bridge(str(ckpt))

    # Encode the prompt and add a leading batch dimension.
    ids = tok.encode("USER: hello\nTILELLI:").long().unsqueeze(0)
    with torch.no_grad():
        # stop_ids=(10, 0): presumably the byte ids for "\n" and NUL,
        # assuming a byte-level tokenizer -- TODO confirm.
        full, generated, confs = model.generate_with_cache(
            ids, n_new_tokens=8, stop_ids=(10, 0)
        )

    # NOTE(review): ">=" only proves the prompt was not truncated; whether
    # `full` always includes the generated tokens is not visible from here,
    # so the comparison is left as-is.
    assert full.size(1) >= ids.size(1), "no tokens generated"
    assert len(generated) >= 1, "no generated_id_list entries"
    assert all(0.0 <= c <= 1.0 for c in confs), "confidence values out of range"