| |
| import argparse |
| import json |
| import subprocess |
| import tempfile |
| from collections.abc import Mapping |
| from pathlib import Path |
|
|
| from transformers import AutoTokenizer |
|
|
|
|
def resolve_mask_id(tok) -> int:
    """Return the diffusion mask token id for *tok*.

    Tries LLaDA's dedicated ``<|mdm_mask|>`` token first, then common
    mask-token attributes and literal mask strings. A candidate is usable
    only when it resolves to a real id — not ``None`` and not the
    unknown-token id (which ``convert_tokens_to_ids`` returns for tokens
    absent from the vocabulary).

    Raises:
        RuntimeError: if no candidate resolves to a usable id.
    """
    unk = tok.unk_token_id
    for candidate in (
        tok.convert_tokens_to_ids("<|mdm_mask|>"),
        getattr(tok, "gmask_token_id", None),
        getattr(tok, "mask_token_id", None),
        tok.convert_tokens_to_ids("[gMASK]"),
        tok.convert_tokens_to_ids("[MASK]"),
        tok.convert_tokens_to_ids("<mask>"),
    ):
        if candidate is None or candidate == unk:
            continue
        return int(candidate)
    raise RuntimeError("Could not determine mask token id for tokenizer")
|
|
|
|
def _default_model_path(root: Path) -> Path:
    """Pick the default CoreML model path, preferring an .mlpackage in *root*."""
    package_path = root / "llada_8b_instruct_seq192.mlpackage"
    compiled_path = root / "llada_8b_instruct_seq192.mlmodelc"
    if package_path.exists():
        return package_path
    if compiled_path.exists():
        return compiled_path
    # Neither top-level artifact exists: fall back to the compiled_v2 location.
    return root / "models" / "compiled_v2" / "llada_8b_instruct_seq192.mlmodelc"


def _locate_swift_script(root: Path, script_dir: Path) -> Path:
    """Find llada_diffuse.swift next to this script, else under scripts/."""
    swift_script = script_dir / "llada_diffuse.swift"
    if not swift_script.exists():
        swift_script = root / "scripts" / "llada_diffuse.swift"
    return swift_script


def _build_parser(default_model: Path) -> argparse.ArgumentParser:
    """Command-line interface for the CoreML diffusion runner."""
    p = argparse.ArgumentParser(description="Run LLaDA CoreML diffusion loop from a text prompt")
    p.add_argument("prompt", help="User prompt text")
    p.add_argument("--model", default=str(default_model), help="Path to compiled .mlmodelc or .mlpackage")
    p.add_argument("--tokenizer", default="GSAI-ML/LLaDA-8B-Instruct", help="HF tokenizer repo")
    p.add_argument("--seq-len", type=int, default=192, help="Fixed model sequence length")
    p.add_argument("--max-new-tokens", type=int, default=64, help="Number of tokens to generate")
    p.add_argument("--steps", type=int, default=24, help="Diffusion denoise steps")
    p.add_argument(
        "--compute-units",
        default="all",
        choices=["all", "cpuOnly", "cpuAndGPU", "cpuAndNeuralEngine", "cpu", "cpugpu", "cpune"],
        help="CoreML compute units",
    )
    p.add_argument("--no-chat-template", action="store_true", help="Disable tokenizer chat template")
    p.add_argument("--json", action="store_true", help="Print full JSON result")
    return p


def _encode_prompt(tok, args: argparse.Namespace) -> list:
    """Tokenize the prompt, via the chat template when available and enabled.

    Truncates to ``seq_len - 1`` ids so at least one position remains for
    generation in the fixed-length sequence.
    """
    if not args.no_chat_template and getattr(tok, "chat_template", None):
        chat_tokens = tok.apply_chat_template(
            [{"role": "user", "content": args.prompt}],
            tokenize=True,
            add_generation_prompt=True,
        )
        # apply_chat_template may return a BatchEncoding mapping or a raw id list.
        prompt_ids = chat_tokens["input_ids"] if isinstance(chat_tokens, Mapping) else chat_tokens
    else:
        prompt_ids = tok(args.prompt, add_special_tokens=True, return_attention_mask=False)["input_ids"]

    if len(prompt_ids) >= args.seq_len:
        prompt_ids = prompt_ids[: args.seq_len - 1]
    return prompt_ids


def _run_swift(swift_script: Path, model: str, payload: Mapping) -> dict:
    """Hand *payload* to the Swift diffusion loop and return its JSON result.

    Exchanges data through JSON files in a temporary directory; raises
    ``subprocess.CalledProcessError`` if the Swift process fails.
    """
    with tempfile.TemporaryDirectory(prefix="llada_run_") as td:
        td_path = Path(td)
        input_path = td_path / "input.json"
        output_path = td_path / "output.json"
        input_path.write_text(json.dumps(payload), encoding="utf-8")

        cmd = [
            "swift",
            str(swift_script),
            "--model",
            str(model),
            "--input",
            str(input_path),
            "--output",
            str(output_path),
        ]
        subprocess.run(cmd, check=True)
        return json.loads(output_path.read_text(encoding="utf-8"))


def _print_report(prompt: str, out: dict, generated_ids: list) -> None:
    """Print a human-readable summary of one diffusion run."""
    print("Prompt:")
    print(prompt)
    print()
    print("Decoded Prompt Tokens:")
    print(out["decoded_prompt"])
    print()
    print("Generated Text:")
    print(out["decoded_generated"] if out["decoded_generated"] else "(empty)")
    print()
    print("Generated Text (with specials):")
    print(out["decoded_generated_with_specials"])
    print()
    print(
        f"Timing: load={out['load_seconds']:.2f}s predict_total={out['total_predict_seconds']:.2f}s loop={out['loop_seconds']:.2f}s"
    )
    print(f"Tokens: prompt={out['prompt_len']} total={out['total_len']} generated={len(generated_ids)}")


def main() -> None:
    """Tokenize a prompt, run the Swift/CoreML LLaDA diffusion loop, report results."""
    root = Path(__file__).resolve().parents[1]
    script_dir = Path(__file__).resolve().parent
    default_model = _default_model_path(root)
    swift_script = _locate_swift_script(root, script_dir)

    args = _build_parser(default_model).parse_args()

    tok = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    prompt_ids = _encode_prompt(tok, args)

    mask_id = resolve_mask_id(tok)
    # Fall back to 0 / eos when the tokenizer declares no eos / pad token.
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 0
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else eos_id

    payload = {
        "prompt": args.prompt,
        "prompt_ids": [int(x) for x in prompt_ids],
        "seq_len": int(args.seq_len),
        "max_new_tokens": int(args.max_new_tokens),
        "steps": int(args.steps),
        "mask_token_id": int(mask_id),
        "eos_token_id": int(eos_id),
        "pad_token_id": int(pad_id),
        "compute_units": args.compute_units,
    }

    out = _run_swift(swift_script, args.model, payload)

    generated_ids = out["generated_ids"]
    generated_ids_untrimmed = out["generated_ids_untrimmed"]

    out["decoded_prompt"] = tok.decode(prompt_ids, skip_special_tokens=False)
    out["decoded_generated"] = tok.decode(generated_ids, skip_special_tokens=True)
    out["decoded_generated_with_specials"] = tok.decode(generated_ids_untrimmed, skip_special_tokens=False)

    if args.json:
        print(json.dumps(out, indent=2))
        return

    _print_report(args.prompt, out, generated_ids)
|
|
|
|
# Script entry point: all argument parsing and side effects live in main().
if __name__ == "__main__":
    main()
|
|