#!/usr/bin/env python3
"""MNN LLM Inference & Benchmark script for LFM2-350M model."""
import sys
import os
import time
import argparse
import MNN.llm as llm


def run_inference(model, prompt, stream=False):
    """Run a single inference and return the response + timing context."""
    model.reset()
    response = model.response(prompt, stream)
    if stream:
        output = ""
        for chunk in response:
            print(chunk, end="", flush=True)
            output += chunk
        print()
        return output
    return response


def benchmark(model, prompts, warmup=1, runs=3):
    """Benchmark prefill and decode performance across multiple prompts."""
    print("=" * 60)
    print("BENCHMARK")
    print("=" * 60)

    # Warmup
    print(f"\nWarmup ({warmup} run(s))...")
    for _ in range(warmup):
        model.reset()
        model.response(prompts[0], False)

    results = []
    for idx, prompt in enumerate(prompts):
        prompt_results = []
        for _ in range(runs):
            model.reset()
            t0 = time.perf_counter()
            response = model.response(prompt, False)
            t1 = time.perf_counter()
            wall_time = t1 - t0

            # model.context exposes token counts and per-stage timings for
            # the last call; the *_us fields are microseconds.
            ctx = model.context
            ctx.refresh()
            prompt_tokens = ctx.prompt_len
            gen_tokens = ctx.gen_seq_len
            prefill_us = ctx.prefill_us
            decode_us = ctx.decode_us
            prefill_s = prefill_us / 1e6 if prefill_us else 0
            decode_s = decode_us / 1e6 if decode_us else 0
            prefill_tps = prompt_tokens / prefill_s if prefill_s > 0 else 0
            decode_tps = gen_tokens / decode_s if decode_s > 0 else 0

            prompt_results.append({
                "prompt_tokens": prompt_tokens,
                "gen_tokens": gen_tokens,
                "wall_time": wall_time,
                "prefill_s": prefill_s,
                "decode_s": decode_s,
                "prefill_tps": prefill_tps,
                "decode_tps": decode_tps,
                "response": response,
            })
        results.append(prompt_results)

        # Print per-prompt summary
        avg_prefill_tps = sum(r["prefill_tps"] for r in prompt_results) / runs
        avg_decode_tps = sum(r["decode_tps"] for r in prompt_results) / runs
        avg_wall = sum(r["wall_time"] for r in prompt_results) / runs
        prompt_tokens = prompt_results[0]["prompt_tokens"]
        avg_gen = sum(r["gen_tokens"] for r in prompt_results) / runs
        print(f"\nPrompt {idx + 1}: \"{prompt[:60]}{'...' if len(prompt) > 60 else ''}\"")
        print(f"  Prompt tokens  : {prompt_tokens}")
        print(f"  Avg gen tokens : {avg_gen:.1f}")
        print(f"  Avg wall time  : {avg_wall:.3f} s")
        print(f"  Avg prefill    : {avg_prefill_tps:.1f} tok/s")
        print(f"  Avg decode     : {avg_decode_tps:.1f} tok/s")

    # Overall summary
    all_runs = [r for pr in results for r in pr]
    overall_prefill = sum(r["prefill_tps"] for r in all_runs) / len(all_runs)
    overall_decode = sum(r["decode_tps"] for r in all_runs) / len(all_runs)
    print("\n" + "=" * 60)
    print(f"Overall avg prefill : {overall_prefill:.1f} tok/s")
    print(f"Overall avg decode  : {overall_decode:.1f} tok/s")
    print("=" * 60)
    return results
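
# Programmatic usage sketch (assumes an already-converted MNN model; the
# config path here is illustrative):
#
#   model = llm.create("./lfm2-350m-mnn/config.json")
#   model.load()
#   results = benchmark(model, ["Hello!"], warmup=1, runs=3)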


def main():
    parser = argparse.ArgumentParser(description="MNN LLM Inference & Benchmark")
    parser.add_argument("--config", default="config.json",
                        help="Path to MNN config.json (default: config.json)")
    parser.add_argument("--prompt", default=None,
                        help="Single prompt for inference")
    parser.add_argument("--stream", action="store_true",
                        help="Stream output tokens")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run benchmark suite")
    parser.add_argument("--warmup", type=int, default=1,
                        help="Warmup runs for benchmark (default: 1)")
    parser.add_argument("--runs", type=int, default=3,
                        help="Benchmark runs per prompt (default: 3)")
    parser.add_argument("--backend", default=None,
                        choices=["cpu", "metal"],
                        help="Override backend type")
    parser.add_argument("--threads", type=int, default=None,
                        help="Override thread count")
    parser.add_argument("--max-tokens", type=int, default=128,
                        help="Max tokens to generate (default: 128)")
    args = parser.parse_args()

    config_path = os.path.abspath(args.config)
    print(f"Loading model from: {config_path}")
    model = llm.create(config_path)

    # Apply CLI overrides before loading weights.
    if args.backend:
        model.set_config({"backend_type": args.backend})
    if args.threads:
        model.set_config({"thread_num": args.threads})
    model.set_config({"max_new_tokens": args.max_tokens})
    model.load()
    print("Model loaded.\n")

    if args.benchmark:
        bench_prompts = [
            "Hello!",
            "What is the capital of France?",
            "Explain quantum computing in simple terms.",
            "Write a short poem about the ocean.",
            "List 5 programming languages and their main use cases.",
        ]
        benchmark(model, bench_prompts, warmup=args.warmup, runs=args.runs)
    elif args.prompt:
        print(f"Prompt: {args.prompt}\n")
        response = run_inference(model, args.prompt, stream=args.stream)
        if not args.stream:
            print(f"Response:\n{response}")
        ctx = model.context
        ctx.refresh()
        print("\n--- Stats ---")
        print(f"Prompt tokens : {ctx.prompt_len}")
        print(f"Gen tokens    : {ctx.gen_seq_len}")
        prefill_s = ctx.prefill_us / 1e6 if ctx.prefill_us else 0
        decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
        if prefill_s > 0:
            print(f"Prefill : {ctx.prompt_len / prefill_s:.1f} tok/s ({prefill_s:.3f}s)")
        if decode_s > 0:
            print(f"Decode  : {ctx.gen_seq_len / decode_s:.1f} tok/s ({decode_s:.3f}s)")
    else:
        # Interactive mode
        print("Interactive mode (type 'quit' to exit)\n")
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nBye!")
                break
            if user_input.lower() in ("quit", "exit"):
                break
            if not user_input:
                continue
            run_inference(model, user_input, stream=True)
            ctx = model.context
            ctx.refresh()
            decode_s = ctx.decode_us / 1e6 if ctx.decode_us else 0
            if decode_s > 0:
                print(f"  [{ctx.gen_seq_len} tokens, {ctx.gen_seq_len / decode_s:.1f} tok/s]")


if __name__ == "__main__":
    main()