#!/usr/bin/env python3
"""MNN LLM Inference & Benchmark script for LFM2-350M model."""

import os
import time
import argparse

import MNN.llm as llm
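
# Example invocations (the config path here is illustrative; all flags are
# defined in main() via argparse):
#   python inference.py --config ./config.json --prompt "Hello!" --stream
#   python inference.py --config ./config.json --benchmark --runs 3 --max-tokens 128
#   python inference.py --config ./config.json                # interactive chat loop

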
def run_inference(model, prompt, stream=False):
"""Run a single inference and return the response + timing context."""
model.reset()
response = model.response(prompt, stream)
if stream:
output = ""
for chunk in response:
print(chunk, end="", flush=True)
output += chunk
print()
return output
return response
def benchmark(model, prompts, warmup=1, runs=3):
"""Benchmark prefill and decode performance across multiple prompts."""
print("=" * 60)
print("BENCHMARK")
print("=" * 60)
# Warmup
print(f"\nWarmup ({warmup} run(s))...")
for i in range(warmup):
model.reset()
model.response(prompts[0], False)
results = []
for idx, prompt in enumerate(prompts):
prompt_results = []
for run in range(runs):
model.reset()
t0 = time.perf_counter()
response = model.response(prompt, False)
t1 = time.perf_counter()
wall_time = t1 - t0
ctx = model.context
ctx.refresh()
prompt_tokens = ctx.prompt_len
gen_tokens = ctx.gen_seq_len
prefill_us = ctx.prefill_us
decode_us = ctx.decode_us
prefill_s = prefill_us / 1e6 if prefill_us else 0
decode_s = decode_us / 1e6 if decode_us else 0
prefill_tps = prompt_tokens / prefill_s if prefill_s > 0 else 0
decode_tps = gen_tokens / decode_s if decode_s > 0 else 0
prompt_results.append({
"prompt_tokens": prompt_tokens,
"gen_tokens": gen_tokens,
"wall_time": wall_time,
"prefill_s": prefill_s,
"decode_s": decode_s,
"prefill_tps": prefill_tps,
"decode_tps": decode_tps,
"response": response,
})
results.append(prompt_results)
# Print per-prompt summary
avg_prefill_tps = sum(r["prefill_tps"] for r in prompt_results) / runs
avg_decode_tps = sum(r["decode_tps"] for r in prompt_results) / runs
avg_wall = sum(r["wall_time"] for r in prompt_results) / runs
prompt_tokens = prompt_results[0]["prompt_tokens"]
avg_gen = sum(r["gen_tokens"] for r in prompt_results) / runs
print(f"\nPrompt {idx + 1}: \"{prompt[:60]}{'...' if len(prompt) > 60 else ''}\"")
print(f" Prompt tokens : {prompt_tokens}")
print(f" Avg gen tokens : {avg_gen:.1f}")
print(f" Avg wall time : {avg_wall:.3f} s")
print(f" Avg prefill : {avg_prefill_tps:.1f} tok/s")
print(f" Avg decode : {avg_decode_tps:.1f} tok/s")
# Overall summary
all_runs = [r for pr in results for r in pr]
overall_prefill = sum(r["prefill_tps"] for r in all_runs) / len(all_runs)
overall_decode = sum(r["decode_tps"] for r in all_runs) / len(all_runs)
print("\n" + "=" * 60)
print(f"Overall avg prefill : {overall_prefill:.1f} tok/s")
print(f"Overall avg decode : {overall_decode:.1f} tok/s")
print("=" * 60)
return results
def main():
    parser = argparse.ArgumentParser(description="MNN LLM Inference & Benchmark")
    parser.add_argument("--config", default="config.json",
                        help="Path to MNN config.json (default: config.json)")
    parser.add_argument("--prompt", default=None,
                        help="Single prompt for inference")
    parser.add_argument("--stream", action="store_true",
                        help="Stream output tokens")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run benchmark suite")
    parser.add_argument("--warmup", type=int, default=1,
                        help="Warmup runs for benchmark (default: 1)")
    parser.add_argument("--runs", type=int, default=3,
                        help="Benchmark runs per prompt (default: 3)")
    parser.add_argument("--backend", default=None,
                        choices=["cpu", "metal"],
                        help="Override backend type")
    parser.add_argument("--threads", type=int, default=None,
                        help="Override thread count")
    parser.add_argument("--max-tokens", type=int, default=128,
                        help="Max tokens to generate (default: 128)")
    args = parser.parse_args()

    config_path = os.path.abspath(args.config)
    print(f"Loading model from: {config_path}")
    model = llm.create(config_path)

    # Apply optional runtime overrides before load()
    if args.backend:
        model.set_config({"backend_type": args.backend})
    if args.threads:
        model.set_config({"thread_num": args.threads})
    model.set_config({"max_new_tokens": args.max_tokens})
    model.load()
    print("Model loaded.\n")

    if args.benchmark:
        bench_prompts = [
            "Hello!",
            "What is the capital of France?",
            "Explain quantum computing in simple terms.",
            "Write a short poem about the ocean.",
            "List 5 programming languages and their main use cases.",
        ]
        benchmark(model, bench_prompts, warmup=args.warmup, runs=args.runs)
    elif args.prompt:
        print(f"Prompt: {args.prompt}\n")
        response = run_inference(model, args.prompt, stream=args.stream)
        if not args.stream:
            print(f"Response:\n{response}")

        ctx = model.context
        ctx.refresh()
        print("\n--- Stats ---")
        print(f"Prompt tokens : {ctx.prompt_len}")
        print(f"Gen tokens    : {ctx.gen_seq_len}")
        prefill_s = ctx.prefill_us / 1e6 if ctx.prefill_us else 0
        decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
        if prefill_s > 0:
            print(f"Prefill : {ctx.prompt_len / prefill_s:.1f} tok/s ({prefill_s:.3f}s)")
        if decode_s > 0:
            print(f"Decode  : {ctx.gen_seq_len / decode_s:.1f} tok/s ({decode_s:.3f}s)")
    else:
        # Interactive mode
        print("Interactive mode (type 'quit' to exit)\n")
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nBye!")
                break
            if user_input.lower() in ("quit", "exit"):
                break
            if not user_input:
                continue
            run_inference(model, user_input, stream=True)

            # Report decode throughput for the turn just generated
            ctx = model.context
            ctx.refresh()
            decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
            if decode_s > 0:
                print(f"  [{ctx.gen_seq_len} tokens, {ctx.gen_seq_len / decode_s:.1f} tok/s]")


if __name__ == "__main__":
    main()