developerabu committed on
Commit
19e6a5e
·
verified ·
1 Parent(s): 5acef5b

Upload run_inference.py

Browse files
Files changed (1) hide show
  1. run_inference.py +145 -0
run_inference.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import ctypes
6
+ import os
7
+ import sys
8
+ from contextlib import contextmanager
9
+ from pathlib import Path
10
+
11
+
12
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        Namespace with the attributes ``prompt``, ``config``, ``system``,
        ``stream``, ``raw_prompt``, ``tmp_path`` and ``show_stats``.
    """
    parser = argparse.ArgumentParser(
        description="Run inference with the local LFM2-350M MNN export."
    )
    # Optional positional: when absent, ``prompt`` is None.
    parser.add_argument(
        "prompt",
        nargs="?",
        help="User prompt. If omitted, the script reads from stdin.",
    )
    parser.add_argument(
        "--config",
        default="config.json",
        help="Path to the exported MNN config file. Defaults to config.json next to this script.",
    )
    parser.add_argument(
        "--system",
        default="",
        help="Optional system prompt inserted ahead of the user prompt.",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream tokens to stdout while generating.",
    )
    parser.add_argument(
        "--raw-prompt",
        action="store_true",
        help="Treat the provided prompt as a fully formatted raw model prompt.",
    )
    parser.add_argument(
        "--tmp-path",
        default="tmp",
        help="Temporary directory passed to the MNN runtime.",
    )
    parser.add_argument(
        "--show-stats",
        action="store_true",
        help="Print prompt and generation stats to stderr after inference.",
    )
    return parser.parse_args(argv)
52
+
53
+
54
def resolve_path(base_dir: Path, value: str) -> Path:
    """Turn a user-supplied path string into a ``Path``.

    Args:
        base_dir: Directory that relative values are anchored to.
        value: Path text as given on the command line.

    Returns:
        ``value`` itself when it is absolute, otherwise ``base_dir / value``.
    """
    candidate = Path(value)
    return candidate if candidate.is_absolute() else base_dir / candidate
59
+
60
+
61
def read_prompt(args: argparse.Namespace) -> str:
    """Return the user prompt from the CLI argument or, failing that, stdin.

    Args:
        args: Parsed options; only ``args.prompt`` is read.

    Returns:
        ``args.prompt`` when it was given (even if empty), otherwise the
        full, unmodified text read from a non-interactive stdin.

    Raises:
        SystemExit: If no prompt argument was given and stdin is a terminal
            or yields nothing but whitespace.
    """
    if args.prompt is not None:
        return args.prompt
    # Only read stdin when it is piped/redirected; reading from a terminal
    # would block waiting for input.
    if not sys.stdin.isatty():
        piped = sys.stdin.read()
        # A bare newline (e.g. from ``echo | script``) is not a usable
        # prompt, so treat whitespace-only input the same as no input.
        # The text itself is returned untouched.
        if piped.strip():
            return piped
    raise SystemExit("Provide a prompt argument or pipe prompt text on stdin.")
69
+
70
+
71
def build_prompt(user_prompt: str, system_prompt: str) -> str:
    """Wrap the prompts in the chat markup the script feeds to the model.

    Args:
        user_prompt: User message; trailing whitespace is stripped.
        system_prompt: Optional system message; the system turn is emitted
            only when this string is non-empty, and trailing whitespace is
            stripped.

    Returns:
        ``<|startoftext|>``, the optional system turn, the user turn, and an
        opened assistant turn, concatenated in that order.
    """
    # NOTE(review): each message is followed by a newline before
    # ``<|im_end|>`` - confirm this matches the model's own chat template.
    system_turn = (
        f"<|im_start|>system\n{system_prompt.rstrip()}\n<|im_end|>\n"
        if system_prompt
        else ""
    )
    user_and_assistant = (
        f"<|im_start|>user\n{user_prompt.rstrip()}\n<|im_end|>\n<|im_start|>assistant\n"
    )
    return "<|startoftext|>" + system_turn + user_and_assistant
77
+
78
+
79
@contextmanager
def suppress_native_stdout(enabled: bool):
    """Temporarily point the stdout file descriptor at ``os.devnull``.

    The redirection happens at descriptor level (``os.dup2``), so it also
    silences output written by native code that bypasses ``sys.stdout``.

    Args:
        enabled: When false the context manager does nothing and the body
            runs with stdout untouched.

    Yields:
        None. The original descriptor is restored when the block exits,
        whether normally or through an exception.
    """
    if not enabled:
        yield
        return

    # Push out anything already buffered (Python and C stdio) so it still
    # reaches the real stdout rather than being swallowed.
    sys.stdout.flush()
    libc = ctypes.CDLL(None)
    libc.fflush(None)
    stdout_fd = sys.stdout.fileno()
    saved_stdout_fd = os.dup(stdout_fd)

    try:
        devnull_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(devnull_fd, stdout_fd)
        finally:
            # stdout_fd now refers to the null device on its own; the
            # temporary descriptor is no longer needed.
            os.close(devnull_fd)
        yield
    finally:
        try:
            # Drain buffers while stdout still points at devnull; otherwise
            # text written inside the block would surface on the real
            # stdout at the next flush after restoration.
            sys.stdout.flush()
            libc.fflush(None)
        finally:
            try:
                os.dup2(saved_stdout_fd, stdout_fd)
            finally:
                # Release the backup descriptor even if the restore failed.
                os.close(saved_stdout_fd)
99
+
100
+
101
def main() -> int:
    """Run one prompt through the MNN model and print the response.

    Relative ``--config`` and ``--tmp-path`` values are resolved against the
    directory containing this script. Inputs are validated before any
    directory is created or the native runtime is imported.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the config path does not exist or no prompt was
            supplied.
        RuntimeError: If the model status after loading is not ``RUNNING``.
    """
    args = parse_args()
    script_dir = Path(__file__).resolve().parent

    # Fail fast with a readable message instead of handing a nonexistent
    # path to the native runtime.
    config_path = resolve_path(script_dir, args.config)
    if not config_path.exists():
        raise SystemExit(f"Config not found: {config_path}")

    prompt = read_prompt(args)
    formatted_prompt = prompt if args.raw_prompt else build_prompt(prompt, args.system)

    # Create the scratch directory only once the inputs are known to be
    # usable, so a usage error leaves nothing behind on disk.
    tmp_path = resolve_path(script_dir, args.tmp_path)
    tmp_path.mkdir(parents=True, exist_ok=True)

    # Without --stream, stdout is redirected for the whole native section
    # (import, load and generation) and the result is printed afterwards.
    with suppress_native_stdout(not args.stream):
        import MNN.llm as mnn_llm

        model = mnn_llm.create(str(config_path))
        # The prompt is already formatted by this script (or supplied raw).
        # NOTE(review): "use_template": False presumably turns off the
        # runtime's own chat templating - confirm against the MNN docs.
        model.set_config({"tmp_path": str(tmp_path), "use_template": False})
        model.load()

        if model.context.status != mnn_llm.LlmStatus.RUNNING:
            raise RuntimeError(f"Model failed to load correctly: {model.context.status}")

        result = model.response(formatted_prompt, args.stream)

    # In streaming mode the text is not written again here.
    if not args.stream:
        sys.stdout.write(result)
        if result and not result.endswith("\n"):
            sys.stdout.write("\n")

    if args.show_stats:
        context = model.context
        print(
            (
                f"prompt_len={context.prompt_len} "
                f"gen_seq_len={context.gen_seq_len} "
                f"prefill_us={context.prefill_us} "
                f"decode_us={context.decode_us} "
                f"status={context.status}"
            ),
            file=sys.stderr,
        )

    return 0
142
+
143
+
144
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())