| """HF Inference Endpoint handler that runs Gemma 4 31B with MTP on/off. |
| |
| Loads the 31B target and the 0.5B assistant drafter at init, then per-request |
| flips speculative decoding via the `use_mtp` parameter — same hardware, |
| same weights, only difference is the drafter being passed to .generate(). |
| """ |
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

TARGET_ID = "google/gemma-4-31B-it"
ASSISTANT_ID = "google/gemma-4-31B-it-assistant"
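
# Assumption, not enforced here: the plain `assistant_model=` path in
# transformers requires the drafter to share the target's tokenizer and
# vocabulary; a different-vocab drafter would instead need the tokenizer
# arguments for universal assisted decoding.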


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(TARGET_ID)
        self.target = AutoModelForCausalLM.from_pretrained(
            TARGET_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.target.eval()
        self.assistant = AutoModelForCausalLM.from_pretrained(
            ASSISTANT_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.assistant.eval()

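        # Warm up both decode paths once so the first real request doesn't
        # pay one-time initialization costs (allocator, kernel selection).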
        warm = self.tokenizer("hello", return_tensors="pt").to(self.target.device)
        with torch.no_grad():
            _ = self.target.generate(**warm, max_new_tokens=4, do_sample=False)
            _ = self.target.generate(
                **warm,
                max_new_tokens=4,
                do_sample=True,
                temperature=0.7,
                assistant_model=self.assistant,
            )

    def __call__(self, data: dict):
        prompt = data.get("inputs", "")
        params = data.get("parameters", {}) or {}
        max_new_tokens = int(params.get("max_new_tokens", 200))
        use_mtp = bool(params.get("use_mtp", True))
        do_sample = bool(params.get("do_sample", True))
        temperature = float(params.get("temperature", 0.7))

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.target.device)
        prompt_tokens = inputs["input_ids"].shape[1]
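        # `use_mtp` toggles transformers' assisted (speculative) decoding:
        # the drafter proposes candidate tokens that the target verifies in a
        # single forward pass, so outputs keep the target's distribution.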
        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
        if use_mtp:
            gen_kwargs["assistant_model"] = self.assistant
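        # Synchronize around generate() so the elapsed time measures finished
        # GPU work rather than asynchronous kernel launches.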
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            out = self.target.generate(**gen_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        gen_tokens = out.shape[1] - prompt_tokens
        text = self.tokenizer.decode(
            out[0][prompt_tokens:], skip_special_tokens=True
        )
        return [{
            "generated_text": text,
            "prompt_tokens": prompt_tokens,
            "generated_tokens": int(gen_tokens),
            "elapsed_seconds": round(elapsed, 3),
            "tokens_per_second": round(gen_tokens / elapsed, 2) if elapsed > 0 else 0.0,
            "use_mtp": use_mtp,
            "do_sample": do_sample,
            "temperature": temperature if do_sample else None,
        }]
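

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the endpoint contract):
    # run the same prompt with the drafter on and off and print the measured
    # throughput of each path. The prompt text is an arbitrary example.
    handler = EndpointHandler()
    for flag in (True, False):
        result = handler({
            "inputs": "Explain speculative decoding in one paragraph.",
            "parameters": {"max_new_tokens": 128, "use_mtp": flag},
        })[0]
        print(
            f"use_mtp={flag}: {result['tokens_per_second']} tok/s "
            f"({result['generated_tokens']} tokens in {result['elapsed_seconds']}s)"
        )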