gemma4-mtp-bench / bench.py
Blongo1's picture
add MTP-on-vs-off bench handler
aa94965 verified
#!/usr/bin/env python3
"""Hit the deployed endpoint with the same prompt — MTP off vs. MTP on.

Both paths are warmed once, then each is called in three reported runs.
Set the ENDPOINT_URL and HF_TOKEN env vars before running.
"""
import os, json, sys
import urllib.request
# The endpoint address and the access token both come from the environment.
URL = os.getenv("ENDPOINT_URL")
TOKEN = os.getenv("HF_TOKEN")

# Stop with a non-zero exit status when either variable is unset or empty.
if not (URL and TOKEN):
    print("set ENDPOINT_URL and HF_TOKEN env vars first", file=sys.stderr)
    sys.exit(1)
# Fixed prompt sent with every request, so all runs receive identical input.
# The literal is split on sentence boundaries; the pieces join into one string.
PROMPT = (
    "Write a Python function `decompose_modal(sentence)` that takes a legal sentence "
    "and returns (modal_word, sentence_without_modal). "
    "Handle: shall, must, may, will, should. "
    "Return (None, sentence) if no modal found. "
    "Include 5 example calls."
)
def call(use_mtp: bool):
    """Send the benchmark prompt to the endpoint and return the parsed JSON reply.

    Args:
        use_mtp: Value forwarded as the ``use_mtp`` entry of the request's
            ``parameters`` object.

    Returns:
        The endpoint's JSON response decoded into Python objects.
    """
    generation_params = {
        "max_new_tokens": 300,
        "use_mtp": use_mtp,
        "do_sample": True,
        "temperature": 0.7,
    }
    body = json.dumps({"inputs": PROMPT, "parameters": generation_params}).encode()
    headers = {
        "Authorization": f"Bearer {TOKEN}",
        "Content-Type": "application/json",
    }
    # Passing `data` makes urllib send the request as a POST.
    request = urllib.request.Request(URL, data=body, headers=headers)
    with urllib.request.urlopen(request, timeout=600) as response:
        return json.loads(response.read())
def fmt(label, result):
    """Print a short report for one endpoint response.

    Args:
        label: Heading printed on the first line of the report.
        result: Parsed response. A non-empty list is represented by its
            first element; any other value is used as-is.

    The report shows the ``use_mtp``, ``generated_tokens``,
    ``elapsed_seconds`` and ``tokens_per_second`` fields (``None`` when a
    field is absent) and the first 160 characters of ``generated_text``.
    When the payload offers no ``get`` lookup (for example an empty list or
    a bare string), its repr is printed instead of raising, so one
    malformed reply does not abort the remaining runs.
    """
    # Only index into the list when it actually has an element.
    r = result[0] if isinstance(result, list) and result else result
    print(f"--- {label} ---")
    if not hasattr(r, "get"):
        # No field lookup is possible; show the raw payload for diagnosis.
        print(f" unexpected response: {r!r}")
        print()
        return
    print(f" use_mtp: {r.get('use_mtp')}")
    print(f" generated_tokens: {r.get('generated_tokens')}")
    print(f" elapsed_seconds: {r.get('elapsed_seconds')}")
    print(f" tokens_per_second: {r.get('tokens_per_second')}")
    # `or ''` guards against a missing or null generated_text before slicing.
    print(f" text excerpt: {(r.get('generated_text') or '')[:160]!r}")
    print()
if __name__ == "__main__":
    # Warm both paths once before the reported runs; these replies are discarded.
    print("warming...")
    for flag in (False, True):
        call(use_mtp=flag)

    print("\n=== TIMED RUNS ===\n")
    for run_no in range(1, 4):
        print(f"### run {run_no} ###")
        # Within each run, MTP off is always requested before MTP on.
        for label, flag in (("MTP OFF", False), ("MTP ON ", True)):
            fmt(label, call(use_mtp=flag))