#!/usr/bin/env python3
"""MNN LLM inference & benchmark script for the LFM2-350M model."""
import os
import time
import argparse

import MNN.llm as llm


def run_inference(model, prompt, stream=False):
    """Run a single inference and return the response text."""
    model.reset()
    response = model.response(prompt, stream)
    if stream:
        # Streaming mode yields chunks; echo them as they arrive and
        # accumulate the full text so the caller still gets a string back.
        output = ""
        for chunk in response:
            print(chunk, end="", flush=True)
            output += chunk
        print()
        return output
    return response


def benchmark(model, prompts, warmup=1, runs=3):
    """Benchmark prefill and decode performance across multiple prompts."""
    print("=" * 60)
    print("BENCHMARK")
    print("=" * 60)

    # Warmup runs populate caches and stabilize timings before measuring.
    print(f"\nWarmup ({warmup} run(s))...")
    for _ in range(warmup):
        model.reset()
        model.response(prompts[0], False)

    results = []
    for idx, prompt in enumerate(prompts):
        prompt_results = []
        for _ in range(runs):
            model.reset()
            t0 = time.perf_counter()
            response = model.response(prompt, False)
            t1 = time.perf_counter()
            wall_time = t1 - t0

            # Pull token counts and microsecond timings from the model context.
            ctx = model.context
            ctx.refresh()
            prompt_tokens = ctx.prompt_len
            gen_tokens = ctx.gen_seq_len
            prefill_s = ctx.prefill_us / 1e6 if ctx.prefill_us else 0
            decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
            prefill_tps = prompt_tokens / prefill_s if prefill_s > 0 else 0
            decode_tps = gen_tokens / decode_s if decode_s > 0 else 0

            prompt_results.append({
                "prompt_tokens": prompt_tokens,
                "gen_tokens": gen_tokens,
                "wall_time": wall_time,
                "prefill_s": prefill_s,
                "decode_s": decode_s,
                "prefill_tps": prefill_tps,
                "decode_tps": decode_tps,
                "response": response,
            })
        results.append(prompt_results)

        # Per-prompt summary, averaged over runs.
        avg_prefill_tps = sum(r["prefill_tps"] for r in prompt_results) / runs
        avg_decode_tps = sum(r["decode_tps"] for r in prompt_results) / runs
        avg_wall = sum(r["wall_time"] for r in prompt_results) / runs
        prompt_tokens = prompt_results[0]["prompt_tokens"]
        avg_gen = sum(r["gen_tokens"] for r in prompt_results) / runs
        ellipsis = "..." if len(prompt) > 60 else ""
        print(f"\nPrompt {idx + 1}: \"{prompt[:60]}{ellipsis}\"")
        print(f"  Prompt tokens  : {prompt_tokens}")
        print(f"  Avg gen tokens : {avg_gen:.1f}")
        print(f"  Avg wall time  : {avg_wall:.3f} s")
        print(f"  Avg prefill    : {avg_prefill_tps:.1f} tok/s")
        print(f"  Avg decode     : {avg_decode_tps:.1f} tok/s")

    # Overall summary across all prompts and runs.
    all_runs = [r for pr in results for r in pr]
    overall_prefill = sum(r["prefill_tps"] for r in all_runs) / len(all_runs)
    overall_decode = sum(r["decode_tps"] for r in all_runs) / len(all_runs)
    print("\n" + "=" * 60)
    print(f"Overall avg prefill : {overall_prefill:.1f} tok/s")
    print(f"Overall avg decode  : {overall_decode:.1f} tok/s")
    print("=" * 60)
    return results


def main():
    parser = argparse.ArgumentParser(description="MNN LLM Inference & Benchmark")
    parser.add_argument("--config", default="config.json",
                        help="Path to MNN config.json (default: config.json)")
    parser.add_argument("--prompt", default=None,
                        help="Single prompt for inference")
    parser.add_argument("--stream", action="store_true",
                        help="Stream output tokens")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run benchmark suite")
    parser.add_argument("--warmup", type=int, default=1,
                        help="Warmup runs for benchmark (default: 1)")
    parser.add_argument("--runs", type=int, default=3,
                        help="Benchmark runs per prompt (default: 3)")
    parser.add_argument("--backend", default=None, choices=["cpu", "metal"],
                        help="Override backend type")
    parser.add_argument("--threads", type=int, default=None,
                        help="Override thread count")
    parser.add_argument("--max-tokens", type=int, default=128,
                        help="Max tokens to generate (default: 128)")
    args = parser.parse_args()

    config_path = os.path.abspath(args.config)
    print(f"Loading model from: {config_path}")
    model = llm.create(config_path)

    # Apply optional config overrides before loading weights.
    if args.backend:
        model.set_config({"backend_type": args.backend})
    if args.threads:
        model.set_config({"thread_num": args.threads})
    model.set_config({"max_new_tokens": args.max_tokens})

    model.load()
    print("Model loaded.\n")

    if args.benchmark:
        bench_prompts = [
            "Hello!",
            "What is the capital of France?",
            "Explain quantum computing in simple terms.",
            "Write a short poem about the ocean.",
            "List 5 programming languages and their main use cases.",
        ]
        benchmark(model, bench_prompts, warmup=args.warmup, runs=args.runs)
    elif args.prompt:
        print(f"Prompt: {args.prompt}\n")
        response = run_inference(model, args.prompt, stream=args.stream)
        if not args.stream:
            print(f"Response:\n{response}")
        ctx = model.context
        ctx.refresh()
        print("\n--- Stats ---")
        print(f"Prompt tokens : {ctx.prompt_len}")
        print(f"Gen tokens    : {ctx.gen_seq_len}")
        prefill_s = ctx.prefill_us / 1e6 if ctx.prefill_us else 0
        decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
        if prefill_s > 0:
            print(f"Prefill       : {ctx.prompt_len / prefill_s:.1f} tok/s ({prefill_s:.3f}s)")
        if decode_s > 0:
            print(f"Decode        : {ctx.gen_seq_len / decode_s:.1f} tok/s ({decode_s:.3f}s)")
    else:
        # Interactive mode: a simple REPL with streamed output and a
        # per-turn decode-speed readout.
        print("Interactive mode (type 'quit' to exit)\n")
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nBye!")
                break
            if user_input.lower() in ("quit", "exit"):
                break
            if not user_input:
                continue
            run_inference(model, user_input, stream=True)
            ctx = model.context
            ctx.refresh()
            decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
            if decode_s > 0:
                print(f"  [{ctx.gen_seq_len} tokens, {ctx.gen_seq_len / decode_s:.1f} tok/s]")


if __name__ == "__main__":
    main()
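
# Example invocations, as a sketch: the script filename ("run_lfm2.py") and
# the model directory ("lfm2-350m-mnn/") are illustrative assumptions, not
# names fixed by this script. Only --config, --prompt, --stream, --benchmark,
# --runs, --backend, and --threads below are real flags defined above.
#
#   # Single prompt with streamed output:
#   python run_lfm2.py --config lfm2-350m-mnn/config.json --prompt "Hi" --stream
#
#   # Benchmark suite, 5 timed runs per prompt:
#   python run_lfm2.py --config lfm2-350m-mnn/config.json --benchmark --runs 5
#
#   # Force the Metal backend with 4 threads:
#   python run_lfm2.py --config lfm2-350m-mnn/config.json --backend metal --threads 4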