File size: 4,153 Bytes
19e6a5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import ctypes
import os
import sys
from contextlib import contextmanager
from pathlib import Path


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        A namespace with the attributes ``prompt`` (``None`` when omitted),
        ``config``, ``system``, ``stream``, ``raw_prompt``, ``tmp_path`` and
        ``show_stats``.
    """
    cli = argparse.ArgumentParser(
        description="Run inference with the local LFM2-350M MNN export."
    )
    cli.add_argument(
        "prompt",
        nargs="?",
        help="User prompt. If omitted, the script reads from stdin.",
    )

    # Optional flags, registered in this order so ``--help`` lists them the same way.
    option_specs = (
        (
            "--config",
            {
                "default": "config.json",
                "help": "Path to the exported MNN config file. Defaults to config.json next to this script.",
            },
        ),
        (
            "--system",
            {
                "default": "",
                "help": "Optional system prompt inserted ahead of the user prompt.",
            },
        ),
        (
            "--stream",
            {
                "action": "store_true",
                "help": "Stream tokens to stdout while generating.",
            },
        ),
        (
            "--raw-prompt",
            {
                "action": "store_true",
                "help": "Treat the provided prompt as a fully formatted raw model prompt.",
            },
        ),
        (
            "--tmp-path",
            {
                "default": "tmp",
                "help": "Temporary directory passed to the MNN runtime.",
            },
        ),
        (
            "--show-stats",
            {
                "action": "store_true",
                "help": "Print prompt and generation stats to stderr after inference.",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def resolve_path(base_dir: Path, value: str) -> Path:
    """Return *value* as a path, anchored at *base_dir* unless already absolute."""
    candidate = Path(value)
    return candidate if candidate.is_absolute() else base_dir / candidate


def read_prompt(args: argparse.Namespace) -> str:
    """Return the prompt from the CLI argument or, failing that, piped stdin.

    An explicit ``prompt`` argument is returned unchanged. Otherwise stdin is
    read only when it is not an interactive terminal. Its text is returned
    as-is, provided it contains something other than whitespace.

    Raises:
        SystemExit: when no prompt argument was given and stdin is a terminal
            or yields only empty/whitespace-only text.
    """
    if args.prompt is not None:
        return args.prompt
    # Only consume redirected/piped input; never block on an interactive terminal.
    if not sys.stdin.isatty():
        prompt = sys.stdin.read()
        # `echo | script` produces "\n"; treat whitespace-only input as missing
        # instead of sending an empty user turn to the model.
        if prompt.strip():
            return prompt
    raise SystemExit("Provide a prompt argument or pipe prompt text on stdin.")


def build_prompt(user_prompt: str, system_prompt: str) -> str:
    """Wrap a user (and optional system) message in the LFM2 chat template.

    The output starts with ``<|startoftext|>``. Each turn is ``<|im_start|>``,
    the role name and a newline, then the message text immediately followed by
    ``<|im_end|>`` and a newline. The prompt ends with an open assistant turn
    so generation continues as the assistant.

    Trailing whitespace is stripped from both messages. A system prompt that
    is empty after stripping is omitted entirely.
    """
    parts = ["<|startoftext|>"]
    system_text = system_prompt.rstrip()
    if system_text:
        # No newline before <|im_end|>: the LFM2 template closes the turn
        # directly after the content.
        parts.append(f"<|im_start|>system\n{system_text}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{user_prompt.rstrip()}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


@contextmanager
def suppress_native_stdout(enabled: bool):
    """Temporarily point the process-level stdout file descriptor at ``os.devnull``.

    The descriptor behind ``sys.stdout`` is redirected, not just the Python
    object. This means output that native code writes to that descriptor is
    discarded too. When *enabled* is false the context manager does nothing.

    Python's ``sys.stdout`` buffer and C stdio buffers (``fflush(NULL)``) are
    flushed on entry, so earlier output still reaches the real stdout. They are
    flushed again on exit while the descriptor still points at the null device,
    so output written inside the block is not leaked after restoration.
    """
    if not enabled:
        yield
        return

    # Push anything already buffered to the real stdout before redirecting.
    sys.stdout.flush()
    libc = ctypes.CDLL(None)
    libc.fflush(None)
    stdout_fd = sys.stdout.fileno()
    saved_stdout_fd = os.dup(stdout_fd)

    try:
        with open(os.devnull, "w", encoding="utf-8") as devnull:
            os.dup2(devnull.fileno(), stdout_fd)
            yield
    finally:
        # Drain Python-level writes made inside the block into the null device;
        # otherwise they would surface on the real stdout once fd is restored.
        try:
            sys.stdout.flush()
        finally:
            libc.fflush(None)
            os.dup2(saved_stdout_fd, stdout_fd)
            os.close(saved_stdout_fd)


def main() -> int:
    """Parse the CLI, run a single MNN LLM generation, and emit the result.

    Returns:
        The process exit status, ``0`` when generation completes.

    Raises:
        RuntimeError: if the model context status is not ``LlmStatus.RUNNING``
            after ``load()``.
    """
    args = parse_args()
    script_dir = Path(__file__).resolve().parent
    config_path = resolve_path(script_dir, args.config)
    tmp_path = resolve_path(script_dir, args.tmp_path)
    tmp_path.mkdir(parents=True, exist_ok=True)

    user_text = read_prompt(args)
    if args.raw_prompt:
        formatted_prompt = user_text
    else:
        formatted_prompt = build_prompt(user_text, args.system)

    streaming = args.stream
    # stdout stays connected only when tokens are streamed; otherwise the
    # descriptor is redirected for the whole load/generate phase.
    with suppress_native_stdout(not streaming):
        # Imported inside the redirect, so stdout output produced during the
        # import is discarded as well.
        import MNN.llm as mnn_llm

        model = mnn_llm.create(str(config_path))
        # NOTE(review): "use_template": False presumably disables MNN's own chat
        # templating because the prompt is already formatted here — confirm.
        model.set_config({"tmp_path": str(tmp_path), "use_template": False})
        model.load()

        if model.context.status != mnn_llm.LlmStatus.RUNNING:
            raise RuntimeError(f"Model failed to load correctly: {model.context.status}")

        result = model.response(formatted_prompt, streaming)

    if not streaming:
        sys.stdout.write(result)
        needs_newline = bool(result) and not result.endswith("\n")
        if needs_newline:
            sys.stdout.write("\n")

    if args.show_stats:
        ctx = model.context
        stat_fields = (
            f"prompt_len={ctx.prompt_len}",
            f"gen_seq_len={ctx.gen_seq_len}",
            f"prefill_us={ctx.prefill_us}",
            f"decode_us={ctx.decode_us}",
            f"status={ctx.status}",
        )
        print(" ".join(stat_fields), file=sys.stderr)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())