File size: 3,033 Bytes
025033b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
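"""Text-generation CLI for the MLX implementation of the LLaDA2.0-Uni backbone.

Usage:
    python generate.py --prompt "Hello world" [--gen-length 128] [--block-length 32]
        [--steps-per-block 16] [--temperature 0.0] [--threshold 0.95]
        [--repo-id inclusionAI/LLaDA2.0-Uni]
"""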
import argparse
import json
import time
from pathlib import Path

import mlx.core as mx
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from llada2.generate import generate_text
from llada2.model import LLaDA2Config, LLaDA2Model
from llada2.weights import load_weights_into_model


def apply_chat_template(tokenizer, prompt: str) -> str:
    """Render *prompt* as a single user turn using the tokenizer's chat template.

    The template is rendered to text only (``tokenize=False``), and the
    generation prompt for the assistant turn is appended
    (``add_generation_prompt=True``). Tokenization is left to the caller.

    Args:
        tokenizer: Any object exposing a Hugging Face style
            ``apply_chat_template(messages, tokenize=..., add_generation_prompt=...)``.
        prompt: The user message text.

    Returns:
        The templated prompt string.
    """
    conversation = [dict(role="user", content=prompt)]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered


def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that only accepts integers >= 1."""
    value = int(text)  # a ValueError here is reported by argparse as an invalid value
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main():
    """Generate text for one prompt with the MLX LLaDA2.0-Uni model and print it.

    Pipeline:
      1. parse the CLI flags;
      2. download (or reuse from the local Hub cache) the checkpoint, config
         and tokenizer files of ``--repo-id``;
      3. build ``LLaDA2Model`` from ``config.json`` and load its weights in
         bfloat16;
      4. wrap ``--prompt`` in the chat template and tokenize it;
      5. run ``generate_text`` and print the decoded continuation together
         with the generation time.
    """
    parser = argparse.ArgumentParser(
        description="Generate text with the MLX LLaDA2.0-Uni backbone.",
    )
    parser.add_argument(
        "--prompt", type=str, required=True,
        help="user message; wrapped in the chat template before tokenizing",
    )
    parser.add_argument(
        "--gen-length", type=_positive_int, default=128,
        help="generation length in tokens, passed to generate_text (default: %(default)s)",
    )
    parser.add_argument(
        "--block-length", type=_positive_int, default=32,
        help="block length in tokens, passed to generate_text (default: %(default)s)",
    )
    parser.add_argument(
        "--steps-per-block", type=_positive_int, default=16,
        help="refinement steps per block, passed to generate_text (default: %(default)s)",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.0,
        help="sampling temperature, passed to generate_text (default: %(default)s)",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.95,
        help="confidence threshold, passed to generate_text (default: %(default)s)",
    )
    parser.add_argument(
        "--repo-id", default="inclusionAI/LLaDA2.0-Uni",
        help="Hugging Face Hub repository to download from (default: %(default)s)",
    )
    args = parser.parse_args()

    print("[gen] fetching model files…")
    snap = Path(
        snapshot_download(
            args.repo_id,
            allow_patterns=[
                "model-*.safetensors",
                "model.safetensors.index.json",
                "config.json",
                "tokenizer*",
                "special_tokens_map.json",
                # Recent transformers releases may keep the chat template in a
                # standalone chat_template.jinja rather than tokenizer_config.json;
                # harmless when the repo has no such file.
                "chat_template*",
            ],
        )
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(str(snap), trust_remote_code=True)

    # Build the model skeleton from the checkpoint's HF config.
    cfg_data = json.loads((snap / "config.json").read_text())
    config = LLaDA2Config.from_hf(cfg_data)
    model = LLaDA2Model(config)

    # Load weights (perf_counter: monotonic clock meant for measuring intervals).
    print(f"[gen] loading weights from {snap}")
    t0 = time.perf_counter()
    load_weights_into_model(model, snap, dtype=mx.bfloat16)
    print(f"[gen] weights loaded in {time.perf_counter() - t0:.1f}s")

    # Tokenize prompt
    prompt_text = apply_chat_template(tokenizer, args.prompt)
    print(f"\n[gen] prompt (chat-templated):\n{prompt_text!r}\n")
    prompt_ids = tokenizer(prompt_text, return_tensors="np").input_ids
    prompt_ids = mx.array(prompt_ids, dtype=mx.int32)
    prompt_len = prompt_ids.shape[1]
    print(f"[gen] prompt token count: {prompt_len}")

    # Generate
    t0 = time.perf_counter()
    out_ids = generate_text(
        model, prompt_ids,
        gen_length=args.gen_length,
        block_length=args.block_length,
        steps_per_block=args.steps_per_block,
        temperature=args.temperature,
        threshold=args.threshold,
        mask_token_id=config.mask_token_id,
        eos_token_id=config.eos_token_id,
    )
    # MLX evaluates lazily; force the computation so the timing covers the real work.
    mx.eval(out_ids)
    dt = time.perf_counter() - t0

    # Decode only the tokens after the prompt. Special tokens are deliberately
    # not stripped (skip_special_tokens=False), so any that remain are printed.
    gen_ids = out_ids[0, prompt_len:].tolist()
    text = tokenizer.decode(gen_ids, skip_special_tokens=False)
    print(f"\n[gen] ==== GENERATED ({dt:.1f}s) ====\n{text}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()