"""HF Inference Endpoint handler that runs Gemma 4 31B with MTP on/off.

Loads the 31B target and the 0.5B assistant drafter at init, then flips
speculative decoding per request via the `use_mtp` parameter: same hardware,
same weights; the only difference is whether the drafter is passed to
`.generate()`.
"""
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

TARGET_ID = "google/gemma-4-31B-it"
ASSISTANT_ID = "google/gemma-4-31B-it-assistant"


class EndpointHandler:
    def __init__(self, path: str = ""):
        # `path` is the repo this handler.py was deployed in; we ignore it and
        # pull the canonical Gemma 4 weights directly from HF.
        self.tokenizer = AutoTokenizer.from_pretrained(TARGET_ID)
        self.target = AutoModelForCausalLM.from_pretrained(
            TARGET_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.target.eval()
        self.assistant = AutoModelForCausalLM.from_pretrained(
            ASSISTANT_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.assistant.eval()
        # One-shot warmup so the first real request isn't paying for kernel
        # autotuning: exercise both the plain path and the assisted path once.
        warm = self.tokenizer("hello", return_tensors="pt").to(self.target.device)
        with torch.no_grad():
            _ = self.target.generate(**warm, max_new_tokens=4, do_sample=False)
            _ = self.target.generate(
                **warm,
                max_new_tokens=4,
                do_sample=True,
                temperature=0.7,
                assistant_model=self.assistant,
            )
    def __call__(self, data: dict):
        prompt = data.get("inputs", "")
        params = data.get("parameters", {}) or {}
        max_new_tokens = int(params.get("max_new_tokens", 200))
        use_mtp = bool(params.get("use_mtp", True))
        do_sample = bool(params.get("do_sample", True))
        temperature = float(params.get("temperature", 0.7))

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.target.device)
        prompt_tokens = inputs["input_ids"].shape[1]

        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
        if use_mtp:
            # Passing the drafter switches .generate() into assisted
            # (speculative) decoding; omitting it is plain autoregression.
            gen_kwargs["assistant_model"] = self.assistant

        # Synchronize around the timed region so CUDA's async execution
        # doesn't skew the wall-clock measurement.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            out = self.target.generate(**gen_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        gen_tokens = out.shape[1] - prompt_tokens
        text = self.tokenizer.decode(
            out[0][prompt_tokens:], skip_special_tokens=True
        )
        return [{
            "generated_text": text,
            "prompt_tokens": prompt_tokens,
            "generated_tokens": int(gen_tokens),
            "elapsed_seconds": round(elapsed, 3),
            "tokens_per_second": round(gen_tokens / elapsed, 2) if elapsed > 0 else 0.0,
            "use_mtp": use_mtp,
            "do_sample": do_sample,
            "temperature": temperature if do_sample else None,
        }]
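

# Minimal local smoke test: a sketch, not part of the Inference Endpoint
# contract. Constructing EndpointHandler downloads both checkpoints, so this
# assumes a box with enough GPU memory for the 31B target plus the drafter.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {"inputs": "Explain speculative decoding in one sentence."}
    for flag in (True, False):
        payload["parameters"] = {"use_mtp": flag, "max_new_tokens": 64}
        result = handler(payload)[0]
        print(f"use_mtp={flag}: {result['tokens_per_second']} tok/s")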