#!/usr/bin/env python3
from __future__ import annotations
import argparse
import ctypes
import os
import sys
from contextlib import contextmanager
from pathlib import Path
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the inference script.

    Returns:
        The ``argparse`` namespace with the attributes ``prompt``, ``config``,
        ``system``, ``stream``, ``raw_prompt``, ``tmp_path`` and
        ``show_stats``.
    """
    # Arguments are declared as data, in the order they are registered with
    # the parser (which is also the order shown by --help).
    argument_specs = (
        (
            "prompt",
            {
                "nargs": "?",
                "help": "User prompt. If omitted, the script reads from stdin.",
            },
        ),
        (
            "--config",
            {
                "default": "config.json",
                "help": "Path to the exported MNN config file. Defaults to config.json next to this script.",
            },
        ),
        (
            "--system",
            {
                "default": "",
                "help": "Optional system prompt inserted ahead of the user prompt.",
            },
        ),
        (
            "--stream",
            {
                "action": "store_true",
                "help": "Stream tokens to stdout while generating.",
            },
        ),
        (
            "--raw-prompt",
            {
                "action": "store_true",
                "help": "Treat the provided prompt as a fully formatted raw model prompt.",
            },
        ),
        (
            "--tmp-path",
            {
                "default": "tmp",
                "help": "Temporary directory passed to the MNN runtime.",
            },
        ),
        (
            "--show-stats",
            {
                "action": "store_true",
                "help": "Print prompt and generation stats to stderr after inference.",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Run inference with the local LFM2-350M MNN export."
    )
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def resolve_path(base_dir: Path, value: str) -> Path:
    """Turn *value* into a path, anchoring relative values at *base_dir*.

    Args:
        base_dir: Directory that relative values are joined onto.
        value: Path text as given by the user.

    Returns:
        ``Path(value)`` unchanged when it is absolute, otherwise
        ``base_dir / value``.
    """
    candidate = Path(value)
    return candidate if candidate.is_absolute() else base_dir / candidate
def read_prompt(args: argparse.Namespace) -> str:
    """Return the user prompt from the command line or, failing that, stdin.

    Args:
        args: Parsed arguments; only ``args.prompt`` is read.

    Returns:
        ``args.prompt`` when it is not ``None`` (an empty string counts as
        given); otherwise the full text read from a non-interactive stdin.

    Raises:
        SystemExit: If no prompt argument was given and stdin is a terminal
            or yields no text.
    """
    explicit = args.prompt
    if explicit is not None:
        return explicit

    # Only consume stdin when it is not an interactive terminal.
    piped = "" if sys.stdin.isatty() else sys.stdin.read()
    if not piped:
        raise SystemExit("Provide a prompt argument or pipe prompt text on stdin.")
    return piped
def build_prompt(user_prompt: str, system_prompt: str) -> str:
    """Wrap the prompts in the chat markup sent to the model.

    Trailing whitespace is stripped from both prompts. The system turn is
    emitted only when *system_prompt* is a non-empty string, and the result
    always ends with an opened assistant turn.

    Args:
        user_prompt: Text of the user turn.
        system_prompt: Text of the optional system turn; ``""`` omits it.

    Returns:
        The fully formatted prompt string.
    """
    system_turn = (
        f"<|im_start|>system\n{system_prompt.rstrip()}\n<|im_end|>\n"
        if system_prompt
        else ""
    )
    user_turn = (
        f"<|im_start|>user\n{user_prompt.rstrip()}\n<|im_end|>\n<|im_start|>assistant\n"
    )
    return "<|startoftext|>" + system_turn + user_turn
@contextmanager
def suppress_native_stdout(enabled: bool):
    """Temporarily redirect the stdout file descriptor to the null device.

    The redirection happens at descriptor level (``os.dup2``), so it also
    covers output written by non-Python code that shares the descriptor.

    Pending output is flushed on entry so it still reaches the real stdout.
    On exit both the Python-level stream and the C stdio buffers are flushed
    *before* the descriptor is restored, so nothing written inside the block
    can surface on the real stdout afterwards.

    Args:
        enabled: When false the context manager does nothing.

    Yields:
        ``None``; the body runs with stdout silenced when *enabled* is true.
    """
    if not enabled:
        yield
        return

    # Remember the stream seen on entry; that is the one whose descriptor is
    # redirected, even if the body rebinds ``sys.stdout``.
    py_stdout = sys.stdout
    py_stdout.flush()
    libc = ctypes.CDLL(None)
    libc.fflush(None)  # fflush(NULL) flushes every open C output stream.

    stdout_fd = py_stdout.fileno()
    saved_stdout_fd = os.dup(stdout_fd)
    try:
        with open(os.devnull, "w", encoding="utf-8") as devnull:
            os.dup2(devnull.fileno(), stdout_fd)
            yield
    finally:
        try:
            # Drain both buffer layers while the descriptor still points at
            # the null device; text left in a buffer would otherwise be
            # emitted on the real stdout by a later flush.
            py_stdout.flush()
            libc.fflush(None)
        finally:
            # Always restore the descriptor, even if a flush failed.
            os.dup2(saved_stdout_fd, stdout_fd)
            os.close(saved_stdout_fd)
def main() -> int:
    """Run one prompt through the exported model and print the reply.

    Relative ``--config`` and ``--tmp-path`` values are resolved against the
    directory containing this script, and the temporary directory is created
    if missing. Unless ``--stream`` is given, stdout is silenced while the
    model is created, loaded and queried; the reply is written afterwards,
    terminated by a newline.

    Returns:
        ``0`` on success.

    Raises:
        RuntimeError: If the model status is not ``RUNNING`` after loading.
        SystemExit: Propagated from ``read_prompt`` when no prompt is given.
    """
    cli = parse_args()

    script_dir = Path(__file__).resolve().parent
    config_file = resolve_path(script_dir, cli.config)
    scratch_dir = resolve_path(script_dir, cli.tmp_path)
    scratch_dir.mkdir(parents=True, exist_ok=True)

    user_text = read_prompt(cli)
    if cli.raw_prompt:
        model_input = user_text
    else:
        model_input = build_prompt(user_text, cli.system)

    streaming = cli.stream
    with suppress_native_stdout(not streaming):
        # Imported inside the block, so output produced during the import is
        # covered by the suppression as well.
        import MNN.llm as mnn_llm

        llm = mnn_llm.create(str(config_file))
        # NOTE(review): "use_template": False presumably turns off the
        # runtime's own chat templating because the prompt is already
        # formatted here -- confirm against the MNN documentation.
        llm.set_config({"tmp_path": str(scratch_dir), "use_template": False})
        llm.load()
        if llm.context.status != mnn_llm.LlmStatus.RUNNING:
            raise RuntimeError(f"Model failed to load correctly: {llm.context.status}")
        reply = llm.response(model_input, streaming)

    if not streaming:
        sys.stdout.write(reply)
        missing_newline = bool(reply) and not reply.endswith("\n")
        if missing_newline:
            sys.stdout.write("\n")

    if cli.show_stats:
        ctx = llm.context
        stat_names = ("prompt_len", "gen_seq_len", "prefill_us", "decode_us", "status")
        stats_line = " ".join(f"{name}={getattr(ctx, name)}" for name in stat_names)
        print(stats_line, file=sys.stderr)

    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())