File size: 5,169 Bytes
f86dc09 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | #!/usr/bin/env python3
"""Generic text generator — works with both bundled checkpoints.
Auto-routes between the two architectures based on the checkpoint config (and on
whether the weights carry an abstain head). Usage:
    python infer.py  # default: chat with v4 (FP32 chat-SFT'd, deployed)
    python infer.py --ckpt checkpoints/tilelli_pretrain_v1_ternary.pt --prompt "Once upon a time"
For lite-architecture checkpoints such as v4 (the deployed chat model), the prompt is
automatically wrapped in the chat template (a `USER: <prompt>` line followed by
`TILELLI:`) unless you pass --raw. Pretrain checkpoints have no chat format, so the
prompt is used verbatim.
"""
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent / "src"))
import torch
from tilelli.utils import safe_load_checkpoint
from tilelli.distillery.tokenize import ByteTokenizer
def _strip_prefix(state_dict, prefix):
return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
def load_model(ckpt_path: str):
    """Inspect the checkpoint config and instantiate the right model class.

    The config is taken from the first non-empty of ``base_model_cfg``,
    ``model_cfg`` or ``config`` in the checkpoint (empty dict if none).
    Weights come from ``ckpt["model"]`` when present; otherwise the
    checkpoint itself is treated as the state dict.

    Routing: the checkpoint is built as ``TilelliLiteLM`` when its config
    names the ``tilelli_lite`` builder (also the default when no builder
    is recorded) or when its weights contain an ``abstain`` head;
    everything else is built as the parent ``TilelliLM``.

    Args:
        ckpt_path: Path of the checkpoint file to load.

    Returns:
        ``(model, cfg, kind)``: the model switched to eval mode, the config
        dict that was used, and ``"lite"`` or ``"parent"``.
    """
    ckpt = safe_load_checkpoint(ckpt_path, trusted=True)
    cfg = ckpt.get("base_model_cfg") or ckpt.get("model_cfg") or ckpt.get("config") or {}
    raw = ckpt.get("model", ckpt)
    builder = cfg.get("builder", "tilelli_lite")

    if builder == "tilelli_lite" or "abstain.weight" in raw or "abstain.bias" in raw:
        # Lite 3-pathway — the deployed chat v4 lives here
        from tilelli.core.tilelli_lite import TilelliLiteLM

        model = TilelliLiteLM(
            vocab_size=cfg.get("vocab_size", 256),
            d_model=cfg.get("d_model", 256),
            n_layers=cfg.get("n_layers", 8),
            n_heads=cfg.get("n_heads", 8),
            top_k=cfg.get("top_k", 16),
            ffn_expand=cfg.get("dense_expand", 4),
            max_seq_len=cfg.get("max_seq_len", 256),
            quantize=cfg.get("quantize", False),
        )
        # Drop the "abstain." head (the bare LM has no such module) and strip
        # a leading "base." wrapper prefix. Only a *leading* prefix is removed:
        # str.replace would also rewrite a "base." deeper in a key path
        # (e.g. "blocks.0.base.weight"), silently mis-mapping that weight.
        prefix = "base."
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in raw.items()
            if not k.startswith("abstain.")
        }
        kind = "lite"
    else:
        # Parent multi-pathway (TilelliLM) — the ternary pretrain lives here
        from tilelli.core.tilelli_lm import TilelliLM

        model = TilelliLM(
            vocab_size=cfg.get("vocab_size", 256),
            d_model=cfg.get("d_model", 512),
            n_layers=cfg.get("n_layers", 7),
            d_head=cfg.get("d_head", 64),
            top_k=cfg.get("top_k", 8),
            pathways=cfg.get("pathways", 5),
            max_seq_len=cfg.get("max_seq_len", 256),
            quantize=cfg.get("quantize", True),
            n_banks=cfg.get("n_banks", 1),
            per_row=cfg.get("per_row", False),
            hadamard=cfg.get("hadamard", False),
            lsq=cfg.get("lsq", False),
            dense_expand=cfg.get("dense_expand", 2),
            fp_attention=cfg.get("fp_attention", False),
        )
        state = raw
        kind = "parent"

    # strict=False tolerates extra entries in the checkpoint, but parameters
    # that are *missing* stay at their random init and yield garbage text —
    # report them instead of failing silently.
    result = model.load_state_dict(state, strict=False)
    missing = list(getattr(result, "missing_keys", None) or [])
    if missing:
        preview = ", ".join(missing[:5]) + (", ..." if len(missing) > 5 else "")
        print(
            f"[infer] warning: {len(missing)} missing key(s) when loading {ckpt_path}: {preview}",
            file=sys.stderr,
        )

    model.eval()
    return model, cfg, kind
@torch.no_grad()
def generate(model, prompt_ids: torch.Tensor, n_new: int = 120, stop_ids=(10, 0)) -> torch.Tensor:
    """Greedily extend *prompt_ids* by up to *n_new* tokens, for either architecture.

    Dispatch order:
      1. ``model.generate_with_cache`` if the model has it (``stop_ids`` is forwarded);
      2. ``model.generate`` if the model has it (``stop_ids`` is NOT forwarded);
      3. otherwise a plain argmax loop that re-runs the model on the last
         ``model.max_seq_len`` tokens (256 if the attribute is absent) each step.

    In the fallback loop the stop token itself is kept in the output, and
    generation ends once every row of the batch has emitted a token from
    ``stop_ids``. Rows that finished earlier keep being extended until then.

    Args:
        model: A model exposing one of the generation methods above, or a
            callable returning logits of shape ``(batch, seq, vocab)`` or
            ``(batch, vocab)``.
        prompt_ids: Token ids; the fallback loop indexes it as ``(batch, seq)``.
        n_new: Maximum number of tokens to append.
        stop_ids: Token ids that end generation (10 and 0 by default —
            presumably newline and NUL for byte-level ids).

    Returns:
        The prompt ids followed by the generated ids.
    """
    if hasattr(model, "generate_with_cache"):
        full, _, _ = model.generate_with_cache(prompt_ids, n_new_tokens=n_new, stop_ids=stop_ids)
        return full
    if hasattr(model, "generate"):
        return model.generate(prompt_ids, n_new_tokens=n_new)

    # Fall back to a slow loop: a full forward pass over the window each step.
    ids = prompt_ids
    max_ctx = getattr(model, "max_seq_len", 256)
    stops = torch.as_tensor(tuple(stop_ids), dtype=torch.long, device=prompt_ids.device)
    # Per-row "has emitted a stop token" flags, so batches larger than 1 work
    # (converting the whole (batch, 1) tensor with int() only works for batch 1).
    finished = torch.zeros(ids.shape[0], dtype=torch.bool, device=prompt_ids.device)
    for _ in range(n_new):
        window = ids[:, -max_ctx:]
        logits = model(window)
        if logits.ndim == 3:
            logits = logits[:, -1, :]
        nxt = logits.argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, nxt], dim=1)
        # (batch, 1) vs (n_stops,) broadcasts to (batch, n_stops).
        finished |= (nxt == stops).any(dim=-1)
        if bool(finished.all()):
            break
    return ids
def main():
    """Command-line entry point: load a checkpoint, build a prompt, print the completion.

    A one-line model summary goes to stderr, so stdout carries only the
    decoded text (the prompt followed by the generated continuation).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        default="checkpoints/tilelli_chat_v4.pt",
        help="Checkpoint to load. Default = the FP32 chat-SFT'd v4 (deployed).",
    )
    parser.add_argument(
        "--prompt",
        default=None,
        help="Text to continue. For v4 it gets wrapped as USER:/TILELLI:.",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="Skip the USER:/TILELLI: wrapping (treat prompt as continuation seed).",
    )
    parser.add_argument("--max-new", type=int, default=120)
    opts = parser.parse_args()

    tokenizer = ByteTokenizer()
    model, cfg, kind = load_model(opts.ckpt)
    is_lite = kind == "lite"

    param_count = sum(p.numel() for p in model.parameters())
    print(
        f"[infer] {opts.ckpt}",
        f"({kind}, {param_count/1e6:.2f}M params, quantize={cfg.get('quantize')})",
        file=sys.stderr,
    )

    # An empty or absent --prompt falls back to a per-architecture default.
    if opts.prompt:
        prompt = opts.prompt
    elif is_lite:
        prompt = "Hello, who are you?"
    else:
        prompt = "Once upon a time"
    if is_lite and not opts.raw:
        prompt = f"USER: {prompt}\nTILELLI:"

    prompt_ids = tokenizer.encode(prompt).long().unsqueeze(0)
    completion = generate(model, prompt_ids, n_new=opts.max_new)
    print(tokenizer.decode(completion[0].tolist()))
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|