#!/usr/bin/env python3
import argparse
import json
import subprocess
import tempfile
from collections.abc import Mapping
from pathlib import Path
from transformers import AutoTokenizer
def resolve_mask_id(tok) -> int:
    """Return the diffusion mask token id for *tok*, probing known conventions.

    Probes, in priority order: the LLaDA-specific ``<|mdm_mask|>`` token, the
    tokenizer's ``gmask_token_id`` / ``mask_token_id`` attributes, then the
    common literal mask tokens. A probe counts only if it is non-None and not
    the unknown-token id (convert_tokens_to_ids maps misses to unk).

    Raises:
        RuntimeError: if no probe yields a usable id.
    """
    probes = (
        tok.convert_tokens_to_ids("<|mdm_mask|>"),
        getattr(tok, "gmask_token_id", None),
        getattr(tok, "mask_token_id", None),
        tok.convert_tokens_to_ids("[gMASK]"),
        tok.convert_tokens_to_ids("[MASK]"),
        tok.convert_tokens_to_ids("<mask>"),
    )
    unk = tok.unk_token_id
    hit = next((pid for pid in probes if pid is not None and pid != unk), None)
    if hit is None:
        raise RuntimeError("Could not determine mask token id for tokenizer")
    return int(hit)
def _default_model_path(root: Path) -> Path:
    """Pick the default model artifact, preferring .mlpackage over .mlmodelc."""
    package_path = root / "llada_8b_instruct_seq192.mlpackage"
    if package_path.exists():
        return package_path
    compiled_path = root / "llada_8b_instruct_seq192.mlmodelc"
    if compiled_path.exists():
        return compiled_path
    # Local workspace fallback used during development.
    return root / "models" / "compiled_v2" / "llada_8b_instruct_seq192.mlmodelc"


def _find_swift_script(script_dir: Path, root: Path) -> Path:
    """Locate llada_diffuse.swift next to this file, else under scripts/."""
    swift_script = script_dir / "llada_diffuse.swift"
    if swift_script.exists():
        return swift_script
    return root / "scripts" / "llada_diffuse.swift"


def _build_parser(default_model: Path) -> argparse.ArgumentParser:
    """Build the CLI argument parser (flags unchanged from the original)."""
    p = argparse.ArgumentParser(description="Run LLaDA CoreML diffusion loop from a text prompt")
    p.add_argument("prompt", help="User prompt text")
    p.add_argument("--model", default=str(default_model), help="Path to compiled .mlmodelc or .mlpackage")
    p.add_argument("--tokenizer", default="GSAI-ML/LLaDA-8B-Instruct", help="HF tokenizer repo")
    p.add_argument("--seq-len", type=int, default=192, help="Fixed model sequence length")
    p.add_argument("--max-new-tokens", type=int, default=64, help="Number of tokens to generate")
    p.add_argument("--steps", type=int, default=24, help="Diffusion denoise steps")
    p.add_argument(
        "--compute-units",
        default="all",
        choices=["all", "cpuOnly", "cpuAndGPU", "cpuAndNeuralEngine", "cpu", "cpugpu", "cpune"],
        help="CoreML compute units",
    )
    p.add_argument("--no-chat-template", action="store_true", help="Disable tokenizer chat template")
    p.add_argument("--json", action="store_true", help="Print full JSON result")
    return p


def _encode_prompt(tok, prompt: str, use_chat_template: bool) -> list:
    """Encode *prompt* to token ids, via the chat template when available.

    apply_chat_template may return either a plain id list or a Mapping with
    an "input_ids" entry depending on the transformers version, so both
    shapes are accepted.
    """
    if use_chat_template and getattr(tok, "chat_template", None):
        chat_tokens = tok.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=True,
            add_generation_prompt=True,
        )
        if isinstance(chat_tokens, Mapping):
            return chat_tokens["input_ids"]
        return chat_tokens
    return tok(prompt, add_special_tokens=True, return_attention_mask=False)["input_ids"]


def _run_swift(swift_script: Path, model: str, payload: dict) -> dict:
    """Hand *payload* to the Swift runner via temp JSON files; return its output.

    Raises:
        subprocess.CalledProcessError: if the Swift process exits non-zero
            (check=True prevents reading a stale or missing output file).
    """
    with tempfile.TemporaryDirectory(prefix="llada_run_") as td:
        td_path = Path(td)
        input_path = td_path / "input.json"
        output_path = td_path / "output.json"
        input_path.write_text(json.dumps(payload), encoding="utf-8")
        cmd = [
            "swift",
            str(swift_script),
            "--model",
            str(model),
            "--input",
            str(input_path),
            "--output",
            str(output_path),
        ]
        subprocess.run(cmd, check=True)
        return json.loads(output_path.read_text(encoding="utf-8"))


def main() -> None:
    """Tokenize a prompt, run the Swift CoreML diffusion loop, print results."""
    root = Path(__file__).resolve().parents[1]
    script_dir = Path(__file__).resolve().parent
    default_model = _default_model_path(root)
    swift_script = _find_swift_script(script_dir, root)
    args = _build_parser(default_model).parse_args()

    tok = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    prompt_ids = _encode_prompt(tok, args.prompt, use_chat_template=not args.no_chat_template)
    # Leave at least one position free for generation within the fixed sequence.
    if len(prompt_ids) >= args.seq_len:
        prompt_ids = prompt_ids[: args.seq_len - 1]

    mask_id = resolve_mask_id(tok)
    # Fall back to 0 for eos, and to eos for pad, so the Swift side always
    # receives concrete integer ids.
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 0
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else eos_id
    payload = {
        "prompt": args.prompt,
        "prompt_ids": [int(x) for x in prompt_ids],
        "seq_len": int(args.seq_len),
        "max_new_tokens": int(args.max_new_tokens),
        "steps": int(args.steps),
        "mask_token_id": int(mask_id),
        "eos_token_id": int(eos_id),
        "pad_token_id": int(pad_id),
        "compute_units": args.compute_units,
    }

    out = _run_swift(swift_script, args.model, payload)
    generated_ids = out["generated_ids"]
    generated_ids_untrimmed = out["generated_ids_untrimmed"]
    decoded_generated = tok.decode(generated_ids, skip_special_tokens=True)
    decoded_generated_with_specials = tok.decode(generated_ids_untrimmed, skip_special_tokens=False)
    prompt_decoded = tok.decode(prompt_ids, skip_special_tokens=False)
    out["decoded_prompt"] = prompt_decoded
    out["decoded_generated"] = decoded_generated
    out["decoded_generated_with_specials"] = decoded_generated_with_specials

    if args.json:
        print(json.dumps(out, indent=2))
        return

    print("Prompt:")
    print(args.prompt)
    print()
    print("Decoded Prompt Tokens:")
    print(prompt_decoded)
    print()
    print("Generated Text:")
    print(decoded_generated if decoded_generated else "(empty)")
    print()
    print("Generated Text (with specials):")
    print(decoded_generated_with_specials)
    print()
    print(
        f"Timing: load={out['load_seconds']:.2f}s predict_total={out['total_predict_seconds']:.2f}s loop={out['loop_seconds']:.2f}s"
    )
    print(f"Tokens: prompt={out['prompt_len']} total={out['total_len']} generated={len(generated_ids)}")


if __name__ == "__main__":
    main()