| |
| from __future__ import annotations |
|
|
| import argparse |
| import ctypes |
| import os |
| import sys |
| from contextlib import contextmanager |
| from pathlib import Path |
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the LFM2-350M MNN inference script.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse reads ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call has always done. Passing an
            explicit list lets other code and tests drive the parser.

    Returns:
        A namespace with the attributes ``prompt``, ``config``, ``system``,
        ``stream``, ``raw_prompt``, ``tmp_path`` and ``show_stats``.
    """
    parser = argparse.ArgumentParser(
        description="Run inference with the local LFM2-350M MNN export."
    )
    # The prompt is optional on the command line; read_prompt() falls back to
    # stdin when it is omitted.
    parser.add_argument(
        "prompt",
        nargs="?",
        help="User prompt. If omitted, the script reads from stdin.",
    )
    parser.add_argument(
        "--config",
        default="config.json",
        help="Path to the exported MNN config file. Defaults to config.json next to this script.",
    )
    parser.add_argument(
        "--system",
        default="",
        help="Optional system prompt inserted ahead of the user prompt.",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream tokens to stdout while generating.",
    )
    parser.add_argument(
        "--raw-prompt",
        action="store_true",
        help="Treat the provided prompt as a fully formatted raw model prompt.",
    )
    parser.add_argument(
        "--tmp-path",
        default="tmp",
        help="Temporary directory passed to the MNN runtime.",
    )
    parser.add_argument(
        "--show-stats",
        action="store_true",
        help="Print prompt and generation stats to stderr after inference.",
    )
    return parser.parse_args(argv)
|
|
|
|
def resolve_path(base_dir: Path, value: str) -> Path:
    """Turn ``value`` into a path, anchoring relative values at ``base_dir``.

    An absolute ``value`` is returned as-is; any other value is joined onto
    ``base_dir``.
    """
    candidate = Path(value)
    return candidate if candidate.is_absolute() else base_dir / candidate
|
|
|
|
def read_prompt(args: argparse.Namespace) -> str:
    """Return the user prompt from the CLI argument or, failing that, stdin.

    Args:
        args: Parsed options; only ``args.prompt`` is read.

    Returns:
        ``args.prompt`` when it was supplied, otherwise the full text read
        from stdin, returned unmodified.

    Raises:
        SystemExit: If no prompt argument was given and stdin is missing, is
            a terminal, or delivers only whitespace.
    """
    if args.prompt is not None:
        return args.prompt

    # sys.stdin can be None when the interpreter runs without a console.
    stdin = sys.stdin
    if stdin is not None and not stdin.isatty():
        piped = stdin.read()
        # An "empty" pipe such as `echo | script` still delivers a newline, so
        # test for real content rather than mere truthiness.
        if piped.strip():
            return piped

    raise SystemExit("Provide a prompt argument or pipe prompt text on stdin.")
|
|
|
|
def build_prompt(user_prompt: str, system_prompt: str) -> str:
    """Wrap the prompts in the ``<|im_start|>``/``<|im_end|>`` chat markup.

    The result starts with ``<|startoftext|>``, optionally carries a system
    turn (only when ``system_prompt`` is non-empty), then the user turn, and
    ends with an opened assistant turn. Trailing whitespace is stripped from
    both prompts before they are embedded.
    """
    system_turn = (
        f"<|im_start|>system\n{system_prompt.rstrip()}\n<|im_end|>\n"
        if system_prompt
        else ""
    )
    user_turn = f"<|im_start|>user\n{user_prompt.rstrip()}\n<|im_end|>\n<|im_start|>assistant\n"
    return "<|startoftext|>" + system_turn + user_turn
|
|
|
|
@contextmanager
def suppress_native_stdout(enabled: bool):
    """Temporarily redirect the stdout file descriptor to ``os.devnull``.

    The redirection happens at file-descriptor level (``os.dup2``) rather
    than by swapping ``sys.stdout``, so it also affects output written
    straight to the descriptor instead of through Python's stream object.

    Args:
        enabled: When false the context manager does nothing.

    Yields:
        ``None``. While the body runs, writes to the descriptor behind
        ``sys.stdout`` are discarded.
    """
    if not enabled:
        yield
        return

    # Push out everything buffered so far so earlier output is not lost.
    sys.stdout.flush()
    libc = ctypes.CDLL(None)
    libc.fflush(None)  # fflush(NULL): flush the C stdio output streams
    stdout_fd = sys.stdout.fileno()
    saved_stdout_fd = os.dup(stdout_fd)

    try:
        with open(os.devnull, "w", encoding="utf-8") as devnull:
            os.dup2(devnull.fileno(), stdout_fd)
            yield
    finally:
        try:
            # Drain both buffer layers while the descriptor still points at
            # devnull. Otherwise text written inside the context would sit in
            # a buffer and surface on the real stdout after the restore.
            sys.stdout.flush()
            libc.fflush(None)
        finally:
            # Always restore the descriptor, even if a flush fails.
            os.dup2(saved_stdout_fd, stdout_fd)
            os.close(saved_stdout_fd)
|
|
|
|
def main() -> int:
    """Run one prompt through the MNN model and print the response.

    Relative ``--config`` and ``--tmp-path`` values are resolved against the
    directory containing this script. Unless ``--stream`` is given, the
    stdout descriptor is silenced while the MNN module is imported and the
    model is created, loaded and queried; the response text is then written
    to stdout in one piece.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the config path does not exist or no prompt is given.
        RuntimeError: If the model status is not ``RUNNING`` after loading.
    """
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    config_path = resolve_path(base_dir, args.config)
    # Fail early with a readable message instead of handing a missing path to
    # the native runtime; checked before any directory is created.
    if not config_path.exists():
        raise SystemExit(f"Config file not found: {config_path}")
    tmp_path = resolve_path(base_dir, args.tmp_path)
    tmp_path.mkdir(parents=True, exist_ok=True)

    prompt = read_prompt(args)
    formatted_prompt = prompt if args.raw_prompt else build_prompt(prompt, args.system)

    with suppress_native_stdout(not args.stream):
        # Imported inside the context so the import itself is covered by the
        # stdout suppression as well.
        import MNN.llm as mnn_llm

        model = mnn_llm.create(str(config_path))
        # NOTE(review): "use_template": False presumably disables the
        # runtime's own chat templating because the prompt is already
        # formatted above; confirm against the MNN documentation.
        model.set_config({"tmp_path": str(tmp_path), "use_template": False})
        model.load()

        if model.context.status != mnn_llm.LlmStatus.RUNNING:
            raise RuntimeError(f"Model failed to load correctly: {model.context.status}")

        result = model.response(formatted_prompt, args.stream)

    if not args.stream:
        sys.stdout.write(result)
        # End the output with exactly one newline when there is any output.
        if result and not result.endswith("\n"):
            sys.stdout.write("\n")

    if args.show_stats:
        context = model.context
        # Stats go to stderr so they never mix with the response on stdout.
        print(
            (
                f"prompt_len={context.prompt_len} "
                f"gen_seq_len={context.gen_seq_len} "
                f"prefill_us={context.prefill_us} "
                f"decode_us={context.decode_us} "
                f"status={context.status}"
            ),
            file=sys.stderr,
        )

    return 0
|
|
|
|
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())