"""MNN LLM Inference & Benchmark script for LFM2-350M model."""

import os
import time
import argparse

import MNN.llm as llm


def run_inference(model, prompt, stream=False):
    """Run a single inference and return the response + timing context."""
    model.reset()
    response = model.response(prompt, stream)
    if stream:
        # Print chunks as they arrive while accumulating the full text to return
        output = ""
        for chunk in response:
            print(chunk, end="", flush=True)
            output += chunk
        print()
        return output
    return response


def benchmark(model, prompts, warmup=1, runs=3):
    """Benchmark prefill and decode performance across multiple prompts."""
    print("=" * 60)
    print("BENCHMARK")
    print("=" * 60)

    print(f"\nWarmup ({warmup} run(s))...")
    for i in range(warmup):
        model.reset()
        model.response(prompts[0], False)

    results = []
    for idx, prompt in enumerate(prompts):
        prompt_results = []
        for run in range(runs):
            model.reset()
            t0 = time.perf_counter()
            response = model.response(prompt, False)
            t1 = time.perf_counter()
            wall_time = t1 - t0

            # Pull per-run token counts and timing statistics from the model context
            ctx = model.context
            ctx.refresh()

            prompt_tokens = ctx.prompt_len
            gen_tokens = ctx.gen_seq_len
            prefill_us = ctx.prefill_us
            decode_us = ctx.decode_us

            prefill_s = prefill_us / 1e6 if prefill_us else 0
            decode_s = decode_us / 1e6 if decode_us else 0

            prefill_tps = prompt_tokens / prefill_s if prefill_s > 0 else 0
            decode_tps = gen_tokens / decode_s if decode_s > 0 else 0

            prompt_results.append({
                "prompt_tokens": prompt_tokens,
                "gen_tokens": gen_tokens,
                "wall_time": wall_time,
                "prefill_s": prefill_s,
                "decode_s": decode_s,
                "prefill_tps": prefill_tps,
                "decode_tps": decode_tps,
                "response": response,
            })

        results.append(prompt_results)

        # Per-prompt averages across the timed runs
        avg_prefill_tps = sum(r["prefill_tps"] for r in prompt_results) / runs
        avg_decode_tps = sum(r["decode_tps"] for r in prompt_results) / runs
        avg_wall = sum(r["wall_time"] for r in prompt_results) / runs
        prompt_tokens = prompt_results[0]["prompt_tokens"]
        avg_gen = sum(r["gen_tokens"] for r in prompt_results) / runs

        print(f"\nPrompt {idx + 1}: \"{prompt[:60]}{'...' if len(prompt) > 60 else ''}\"")
        print(f"  Prompt tokens  : {prompt_tokens}")
        print(f"  Avg gen tokens : {avg_gen:.1f}")
        print(f"  Avg wall time  : {avg_wall:.3f} s")
        print(f"  Avg prefill    : {avg_prefill_tps:.1f} tok/s")
        print(f"  Avg decode     : {avg_decode_tps:.1f} tok/s")

    # Overall averages across all prompts and runs
    all_runs = [r for pr in results for r in pr]
    overall_prefill = sum(r["prefill_tps"] for r in all_runs) / len(all_runs)
    overall_decode = sum(r["decode_tps"] for r in all_runs) / len(all_runs)
    print("\n" + "=" * 60)
    print(f"Overall avg prefill : {overall_prefill:.1f} tok/s")
    print(f"Overall avg decode  : {overall_decode:.1f} tok/s")
    print("=" * 60)

    return results


def main():
    parser = argparse.ArgumentParser(description="MNN LLM Inference & Benchmark")
    parser.add_argument("--config", default="config.json",
                        help="Path to MNN config.json (default: config.json)")
    parser.add_argument("--prompt", default=None,
                        help="Single prompt for inference")
    parser.add_argument("--stream", action="store_true",
                        help="Stream output tokens")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run benchmark suite")
    parser.add_argument("--warmup", type=int, default=1,
                        help="Warmup runs for benchmark (default: 1)")
    parser.add_argument("--runs", type=int, default=3,
                        help="Benchmark runs per prompt (default: 3)")
    parser.add_argument("--backend", default=None,
                        choices=["cpu", "metal"],
                        help="Override backend type")
    parser.add_argument("--threads", type=int, default=None,
                        help="Override thread count")
    parser.add_argument("--max-tokens", type=int, default=128,
                        help="Max tokens to generate (default: 128)")
    args = parser.parse_args()

    config_path = os.path.abspath(args.config)

    print(f"Loading model from: {config_path}")
    model = llm.create(config_path)

    # Optional runtime overrides, applied before the weights are loaded
    if args.backend:
        model.set_config({"backend_type": args.backend})
    if args.threads:
        model.set_config({"thread_num": args.threads})
    model.set_config({"max_new_tokens": args.max_tokens})

    model.load()
    print("Model loaded.\n")

    if args.benchmark:
        bench_prompts = [
            "Hello!",
            "What is the capital of France?",
            "Explain quantum computing in simple terms.",
            "Write a short poem about the ocean.",
            "List 5 programming languages and their main use cases.",
        ]
        benchmark(model, bench_prompts, warmup=args.warmup, runs=args.runs)
    elif args.prompt:
        print(f"Prompt: {args.prompt}\n")
        response = run_inference(model, args.prompt, stream=args.stream)
        if not args.stream:
            print(f"Response:\n{response}")

        ctx = model.context
        ctx.refresh()
        print("\n--- Stats ---")
        print(f"Prompt tokens : {ctx.prompt_len}")
        print(f"Gen tokens    : {ctx.gen_seq_len}")
        prefill_s = ctx.prefill_us / 1e6 if ctx.prefill_us else 0
        decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
        if prefill_s > 0:
            print(f"Prefill : {ctx.prompt_len / prefill_s:.1f} tok/s ({prefill_s:.3f}s)")
        if decode_s > 0:
            print(f"Decode  : {ctx.gen_seq_len / decode_s:.1f} tok/s ({decode_s:.3f}s)")
    else:
        # No --prompt and no --benchmark: fall back to a simple chat loop
        print("Interactive mode (type 'quit' to exit)\n")
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nBye!")
                break
            if user_input.lower() in ("quit", "exit"):
                break
            if not user_input:
                continue
            run_inference(model, user_input, stream=True)
            ctx = model.context
            ctx.refresh()
            decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
            if decode_s > 0:
                print(f"  [{ctx.gen_seq_len} tokens, {ctx.gen_seq_len / decode_s:.1f} tok/s]")


if __name__ == "__main__":
    main()
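
# A minimal programmatic usage sketch (assumptions: this file is importable as
# `mnn_llm_bench`, and the config path is illustrative -- adapt both to your setup):
#
#     import MNN.llm as llm
#     from mnn_llm_bench import benchmark
#
#     model = llm.create("/path/to/lfm2-350m-mnn/config.json")
#     model.load()
#     benchmark(model, ["Hello!"], warmup=1, runs=3)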