"""HF Inference Endpoint handler that runs Gemma 4 31B with MTP on/off.

Loads the 31B target and the 0.5B assistant drafter at init, then per-request
flips speculative decoding via the `use_mtp` parameter — same hardware, same
weights, only difference is the drafter being passed to .generate().
"""

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

TARGET_ID = "google/gemma-4-31B-it"
ASSISTANT_ID = "google/gemma-4-31B-it-assistant"

# Accepted spellings for boolean request parameters that arrive as strings.
# "" is treated as False to match plain bool("") semantics.
_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off", ""})


def _parse_bool(value, name):
    """Coerce a request parameter to bool, understanding string spellings.

    ``bool("false")`` is ``True`` in Python, so a client sending
    ``"use_mtp": "false"`` would otherwise silently get MTP enabled and skew
    the on/off comparison. Non-string values keep plain ``bool(value)``
    semantics (so JSON ``true``/``false``/``0``/``1`` behave as before).

    Args:
        value: Raw parameter value from the request payload.
        name: Parameter name, used only in the error message.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is a string that is not a recognized
            true/false spelling.
    """
    if isinstance(value, str):
        key = value.strip().lower()
        if key in _TRUE_STRINGS:
            return True
        if key in _FALSE_STRINGS:
            return False
        raise ValueError(f"parameter {name!r} must be a boolean, got {value!r}")
    return bool(value)


class EndpointHandler:
    """Serve the Gemma target model, optionally with the assistant drafter.

    Both models are loaded once at construction; each request chooses whether
    the drafter is passed to ``generate`` via the ``use_mtp`` parameter.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, target and assistant models, then warm them up.

        Args:
            path: Directory of the deployed repo (supplied by the endpoint
                runtime). Unused — weights are pulled from TARGET_ID and
                ASSISTANT_ID instead.
        """
        # path is the repo this handler.py was deployed in — we ignore it and
        # pull the canonical Gemma 4 weights directly from HF.
        self.tokenizer = AutoTokenizer.from_pretrained(TARGET_ID)

        self.target = AutoModelForCausalLM.from_pretrained(
            TARGET_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.target.eval()

        self.assistant = AutoModelForCausalLM.from_pretrained(
            ASSISTANT_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.assistant.eval()

        # one-shot warmup so the first real request isn't paying for kernel autotuning
        warm = self.tokenizer("hello", return_tensors="pt").to(self.target.device)
        with torch.no_grad():
            # Plain (no-drafter) greedy path.
            _ = self.target.generate(**warm, max_new_tokens=4, do_sample=False)
            # Assisted (drafter) sampling path.
            _ = self.target.generate(
                **warm,
                max_new_tokens=4,
                do_sample=True,
                temperature=0.7,
                assistant_model=self.assistant,
            )

    def __call__(self, data: dict):
        """Generate a completion and report timing for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt text
                (default ``""``). ``data["parameters"]`` is an optional dict
                with:
                  - ``max_new_tokens`` (int, default 200)
                  - ``use_mtp`` (bool, default True): pass the assistant
                    drafter to ``generate``.
                  - ``do_sample`` (bool, default True)
                  - ``temperature`` (float, default 0.7; only applied when
                    sampling).
                Boolean parameters also accept string spellings such as
                ``"true"``/``"false"``/``"1"``/``"0"``.

        Returns:
            A one-element list holding a dict with the generated text,
            prompt/generated token counts, wall-clock generation time,
            tokens per second, and the effective ``use_mtp``,
            ``do_sample`` and ``temperature`` (``None`` when greedy).

        Raises:
            ValueError: If a boolean parameter is an unrecognized string.
        """
        prompt = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        max_new_tokens = int(params.get("max_new_tokens", 200))
        use_mtp = _parse_bool(params.get("use_mtp", True), "use_mtp")
        do_sample = _parse_bool(params.get("do_sample", True), "do_sample")
        temperature = float(params.get("temperature", 0.7))

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.target.device)
        prompt_tokens = inputs["input_ids"].shape[1]

        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
        if use_mtp:
            # The only difference between the two modes: the drafter.
            gen_kwargs["assistant_model"] = self.assistant

        # Synchronize around the timed region so pending GPU work is not
        # attributed to (or hidden from) this request.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            out = self.target.generate(**gen_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        # generate() returns prompt + continuation; slice off the prompt.
        gen_tokens = out.shape[1] - prompt_tokens
        text = self.tokenizer.decode(
            out[0][prompt_tokens:], skip_special_tokens=True
        )

        return [{
            "generated_text": text,
            "prompt_tokens": prompt_tokens,
            "generated_tokens": int(gen_tokens),
            "elapsed_seconds": round(elapsed, 3),
            "tokens_per_second": round(gen_tokens / elapsed, 2) if elapsed > 0 else 0.0,
            "use_mtp": use_mtp,
            "do_sample": do_sample,
            "temperature": temperature if do_sample else None,
        }]