File size: 5,169 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
"""Generic text generator — works with both bundled checkpoints.

Auto-routes between the two architectures based on the checkpoint config:

    python infer.py                          # default: chat with v4 (FP32 chat-SFT'd, deployed)
    python infer.py --ckpt checkpoints/tilelli_pretrain_v1_ternary.pt --prompt "Once upon a time"

For v4 (the deployed chat model), the prompt is wrapped as `USER: ... TILELLI:` automatically
unless you pass --raw. For pretrain checkpoints there's no chat format, so the prompt is
used verbatim.
"""
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent / "src"))

import torch

from tilelli.utils import safe_load_checkpoint
from tilelli.distillery.tokenize import ByteTokenizer


def _strip_prefix(state_dict, prefix):
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def load_model(ckpt_path: str):
    """Inspect the checkpoint config and instantiate the right model class.

    The config dict is the first non-empty of the checkpoint's
    ``base_model_cfg``, ``model_cfg`` or ``config`` entries (an empty dict when
    none is present). The weights are the checkpoint's ``model`` entry, or the
    checkpoint itself when that key is absent.

    Routing:
      * ``TilelliLiteLM`` when ``cfg["builder"]`` is ``"tilelli_lite"`` (which is
        also the default when the key is missing), or when the weights contain
        an ``abstain.weight`` / ``abstain.bias`` entry. ``abstain.*`` entries
        are dropped and a leading ``base.`` is removed from the remaining keys.
      * ``TilelliLM`` otherwise, with the weights loaded under their own names.

    Weights are loaded with ``strict=False``; keys the model expected but did
    not receive, and keys the model did not recognise, are summarised on stderr
    so a mismatched checkpoint does not load silently.

    Args:
        ckpt_path: Path handed to ``safe_load_checkpoint`` (with ``trusted=True``).

    Returns:
        A ``(model, cfg, kind)`` tuple: the model in eval mode, the config dict
        that was used, and ``"lite"`` or ``"parent"`` naming the chosen branch.
    """
    ckpt = safe_load_checkpoint(ckpt_path, trusted=True)
    cfg = ckpt.get("base_model_cfg") or ckpt.get("model_cfg") or ckpt.get("config") or {}
    raw = ckpt.get("model", ckpt)

    builder = cfg.get("builder", "tilelli_lite")
    if builder == "tilelli_lite" or "abstain.weight" in raw or "abstain.bias" in raw:
        # Lite 3-pathway — the deployed chat v4 lives here
        from tilelli.core.tilelli_lite import TilelliLiteLM
        model = TilelliLiteLM(
            vocab_size=cfg.get("vocab_size", 256),
            d_model=cfg.get("d_model", 256),
            n_layers=cfg.get("n_layers", 8),
            n_heads=cfg.get("n_heads", 8),
            top_k=cfg.get("top_k", 16),
            ffn_expand=cfg.get("dense_expand", 4),
            max_seq_len=cfg.get("max_seq_len", 256),
            quantize=cfg.get("quantize", False),
        )
        # Strip "base." only when it is a true leading prefix; a plain
        # str.replace would also rewrite keys that merely contain "base."
        # somewhere in the middle (e.g. "layers.0.base.weight").
        prefix = "base."
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in raw.items()
            if not k.startswith("abstain.")
        }
        kind = "lite"
    else:
        # Parent multi-pathway (TilelliLM) — the ternary pretrain lives here
        from tilelli.core.tilelli_lm import TilelliLM
        model = TilelliLM(
            vocab_size=cfg.get("vocab_size", 256),
            d_model=cfg.get("d_model", 512),
            n_layers=cfg.get("n_layers", 7),
            d_head=cfg.get("d_head", 64),
            top_k=cfg.get("top_k", 8),
            pathways=cfg.get("pathways", 5),
            max_seq_len=cfg.get("max_seq_len", 256),
            quantize=cfg.get("quantize", True),
            n_banks=cfg.get("n_banks", 1),
            per_row=cfg.get("per_row", False),
            hadamard=cfg.get("hadamard", False),
            lsq=cfg.get("lsq", False),
            dense_expand=cfg.get("dense_expand", 2),
            fp_attention=cfg.get("fp_attention", False),
        )
        state = raw
        kind = "parent"

    # Loading stays non-strict (best effort), but the mismatch is surfaced
    # instead of being thrown away.
    result = model.load_state_dict(state, strict=False)
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if missing or unexpected:
        print(
            f"[infer] warning: non-strict load of {ckpt_path}: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            file=sys.stderr,
        )

    model.eval()
    return model, cfg, kind


@torch.no_grad()
def generate(model, prompt_ids: torch.Tensor, n_new: int = 120, stop_ids=(10, 0)) -> torch.Tensor:
    """Generic greedy generation that works for both architectures.

    Dispatch order:
      1. ``model.generate_with_cache`` if present — called with ``n_new_tokens``
         and ``stop_ids``; the first element of its 3-tuple result is returned.
      2. ``model.generate`` if present — called with ``n_new_tokens`` only;
         ``stop_ids`` is not forwarded on this path.
      3. Otherwise a slow argmax loop that calls ``model(window)`` once per new
         token, where ``window`` is the last ``model.max_seq_len`` positions
         (256 when the attribute is absent).

    Args:
        model: Callable model; see the dispatch order above.
        prompt_ids: Token ids, indexed as ``[row, position]`` by the fallback loop.
        n_new: Maximum number of tokens to append.
        stop_ids: Token ids that end generation. ``None`` or an empty
            collection disables early stopping in the fallback loop.
            (The defaults 10 and 0 are presumably newline and NUL under the
            byte tokenizer — confirm.)

    Returns:
        The prompt with the generated tokens appended along dimension 1. In the
        fallback loop the stop token itself is included. With several rows, the
        loop runs until every row has produced a stop token at some step, so
        rows that finished earlier keep receiving greedy tokens after their
        stop token.
    """
    if hasattr(model, "generate_with_cache"):
        full, _, _ = model.generate_with_cache(prompt_ids, n_new_tokens=n_new, stop_ids=stop_ids)
        return full
    if hasattr(model, "generate"):
        return model.generate(prompt_ids, n_new_tokens=n_new)
    # Fall back to a slow loop
    stop_set = frozenset(int(s) for s in stop_ids) if stop_ids else frozenset()
    ids = prompt_ids
    max_ctx = getattr(model, "max_seq_len", 256)
    # One flag per row: int(nxt) only works for a single-row batch, so row
    # completion is tracked explicitly instead.
    finished = [False] * ids.shape[0]
    for _ in range(n_new):
        window = ids[:, -max_ctx:]
        logits = model(window)
        if logits.ndim == 3:
            # Keep only the prediction for the last position.
            logits = logits[:, -1, :]
        nxt = logits.argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, nxt], dim=1)
        if stop_set:
            for row, token in enumerate(nxt.reshape(-1).tolist()):
                if token in stop_set:
                    finished[row] = True
            if all(finished):
                break
    return ids


def main():
    """Command-line entry point: load a checkpoint, generate, print the text.

    Parses ``--ckpt``, ``--prompt``, ``--raw`` and ``--max-new``, logs a one-line
    model summary to stderr, and prints the decoded output of ``generate`` to
    stdout. For a "lite" model the prompt is wrapped in the USER:/TILELLI: chat
    template unless ``--raw`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        default="checkpoints/tilelli_chat_v4.pt",
        help="Checkpoint to load. Default = the FP32 chat-SFT'd v4 (deployed).",
    )
    parser.add_argument(
        "--prompt",
        default=None,
        help="Text to continue. For v4 it gets wrapped as USER:/TILELLI:.",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="Skip the USER:/TILELLI: wrapping (treat prompt as continuation seed).",
    )
    parser.add_argument("--max-new", type=int, default=120)
    args = parser.parse_args()

    tok = ByteTokenizer()
    model, cfg, kind = load_model(args.ckpt)

    # Summary line goes to stderr so stdout carries only the generated text.
    n_params = 0
    for p in model.parameters():
        n_params += p.numel()
    header = f"[infer] {args.ckpt}"
    details = f"({kind}, {n_params/1e6:.2f}M params, quantize={cfg.get('quantize')})"
    print(header, details, file=sys.stderr)

    is_chat = kind == "lite"
    prompt = args.prompt
    if not prompt:
        # No (or empty) --prompt: pick a default that suits the model kind.
        prompt = "Hello, who are you?" if is_chat else "Once upon a time"
    if is_chat and not args.raw:
        prompt = f"USER: {prompt}\nTILELLI:"

    encoded = tok.encode(prompt).long()
    ids = encoded.unsqueeze(0)
    out = generate(model, ids, n_new=args.max_new)
    print(tok.decode(out[0].tolist()))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()