#!/usr/bin/env python3
"""Run inference with a local LFM2-350M MNN export.

The prompt is taken from the command line or stdin and, unless
``--raw-prompt`` is given, wrapped in the LFM2 chat template before being
handed to the MNN LLM runtime.
"""

from __future__ import annotations

import argparse
import ctypes
import os
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script."""
    parser = argparse.ArgumentParser(
        description="Run inference with the local LFM2-350M MNN export."
    )
    parser.add_argument(
        "prompt",
        nargs="?",
        help="User prompt. If omitted, the script reads from stdin.",
    )
    parser.add_argument(
        "--config",
        default="config.json",
        help="Path to the exported MNN config file. Defaults to config.json next to this script.",
    )
    parser.add_argument(
        "--system",
        default="",
        help="Optional system prompt inserted ahead of the user prompt.",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream tokens to stdout while generating.",
    )
    parser.add_argument(
        "--raw-prompt",
        action="store_true",
        help="Treat the provided prompt as a fully formatted raw model prompt.",
    )
    parser.add_argument(
        "--tmp-path",
        default="tmp",
        help="Temporary directory passed to the MNN runtime.",
    )
    parser.add_argument(
        "--show-stats",
        action="store_true",
        help="Print prompt and generation stats to stderr after inference.",
    )
    return parser.parse_args()


def resolve_path(base_dir: Path, value: str) -> Path:
    """Return *value* as a path, anchoring relative paths at *base_dir*.

    Absolute paths are returned unchanged.
    """
    path = Path(value)
    if path.is_absolute():
        return path
    return base_dir / path


def read_prompt(args: argparse.Namespace) -> str:
    """Return the user prompt from ``args.prompt`` or, failing that, stdin.

    An explicit prompt argument is returned as given. Otherwise stdin is read
    when it is not a terminal.

    Raises:
        SystemExit: if no prompt argument was given and stdin is a terminal
            or yields only whitespace.
    """
    if args.prompt is not None:
        return args.prompt
    if not sys.stdin.isatty():
        prompt = sys.stdin.read()
        # Whitespace-only input (e.g. the lone newline from `echo |`) would
        # otherwise produce an empty user turn.
        if prompt.strip():
            return prompt
    raise SystemExit("Provide a prompt argument or pipe prompt text on stdin.")


def build_prompt(user_prompt: str, system_prompt: str) -> str:
    """Wrap the prompts in the LFM2 (ChatML-style) chat template.

    Layout: ``<|startoftext|>``, then for each message ``<|im_start|>ROLE``,
    a newline, the content immediately followed by ``<|im_end|>``, and a
    newline; finally an open ``<|im_start|>assistant`` turn. Trailing
    whitespace is stripped from each message so piped input does not push a
    stray newline into the turn. A blank system prompt emits no system turn.
    """
    parts = ["<|startoftext|>"]
    system = system_prompt.rstrip()
    if system:
        parts.append(f"<|im_start|>system\n{system}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{user_prompt.rstrip()}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


def _load_libc() -> ctypes.CDLL | None:
    """Return a handle to the process's C runtime, or None where unsupported."""
    try:
        return ctypes.CDLL(None)
    except (OSError, TypeError):
        # e.g. Windows, where there is no "current process" dlopen handle;
        # flushing C stdio then becomes a best-effort no-op.
        return None


@contextmanager
def suppress_native_stdout(enabled: bool) -> Iterator[None]:
    """Temporarily point file descriptor 1 at the null device.

    Redirecting at the descriptor level (instead of swapping ``sys.stdout``)
    also silences text written by native code such as the MNN runtime.
    When *enabled* is false the context manager does nothing.
    """
    if not enabled:
        yield
        return

    libc = _load_libc()

    def flush_all() -> None:
        # Flush both the Python-level and C-level stdout buffers so nothing
        # written on one side of the redirect lands on the other side.
        sys.stdout.flush()
        if libc is not None:
            libc.fflush(None)

    flush_all()
    stdout_fd = sys.stdout.fileno()
    saved_stdout_fd = os.dup(stdout_fd)
    try:
        with open(os.devnull, "w", encoding="utf-8") as devnull:
            os.dup2(devnull.fileno(), stdout_fd)
            yield
    finally:
        # Drain anything buffered while redirected into the null device
        # before the real stdout is put back.
        flush_all()
        os.dup2(saved_stdout_fd, stdout_fd)
        os.close(saved_stdout_fd)


def main() -> int:
    """Run a single generation and return the process exit status."""
    args = parse_args()
    base_dir = Path(__file__).resolve().parent
    config_path = resolve_path(base_dir, args.config)
    tmp_path = resolve_path(base_dir, args.tmp_path)

    # Validate inputs before creating directories or loading the runtime.
    prompt = read_prompt(args)
    if not config_path.exists():
        raise SystemExit(f"Config file not found: {config_path}")
    tmp_path.mkdir(parents=True, exist_ok=True)

    formatted_prompt = prompt if args.raw_prompt else build_prompt(prompt, args.system)

    # Without --stream the runtime's own stdout output is hidden and the final
    # text is printed below; with --stream tokens go straight to stdout.
    with suppress_native_stdout(not args.stream):
        import MNN.llm as mnn_llm

        model = mnn_llm.create(str(config_path))
        # The prompt is already templated here, so the runtime's template is off.
        model.set_config({"tmp_path": str(tmp_path), "use_template": False})
        model.load()
        if model.context.status != mnn_llm.LlmStatus.RUNNING:
            raise RuntimeError(f"Model failed to load correctly: {model.context.status}")
        result = model.response(formatted_prompt, args.stream)

    if not args.stream:
        sys.stdout.write(result)
        if result and not result.endswith("\n"):
            sys.stdout.write("\n")

    if args.show_stats:
        context = model.context
        print(
            (
                f"prompt_len={context.prompt_len} "
                f"gen_seq_len={context.gen_seq_len} "
                f"prefill_us={context.prefill_us} "
                f"decode_us={context.decode_us} "
                f"status={context.status}"
            ),
            file=sys.stderr,
        )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())