File size: 3,395 Bytes
aa94965
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""HF Inference Endpoint handler that runs Gemma 4 31B with MTP on/off.

Loads the 31B target and the 0.5B assistant drafter at init, then toggles
speculative decoding per request via the `use_mtp` parameter. Hardware and
weights are identical in both modes; the only difference is whether the
drafter is passed to `.generate()`.
"""
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub ID of the target model; the tokenizer and the main model are both
# loaded from here in EndpointHandler.__init__.
TARGET_ID = "google/gemma-4-31B-it"
# Hub ID of the assistant (draft) model; passed to generate() as
# `assistant_model` when a request has use_mtp enabled.
ASSISTANT_ID = "google/gemma-4-31B-it-assistant"


_TRUE_STRINGS = frozenset({"1", "true", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "no", "n", "off", ""})


def _as_bool(value, default: bool) -> bool:
    """Interpret a request parameter as a boolean.

    ``bool("false")`` is ``True`` in Python, so string-valued flags (common
    when clients stringify their parameters) must be parsed explicitly.

    Args:
        value: Raw parameter value (bool, int, str, or None).
        default: Returned when ``value`` is None (parameter absent or null).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is a string that is not a recognised
            true/false spelling.
    """
    if value is None:
        return default
    if isinstance(value, str):
        key = value.strip().lower()
        if key in _TRUE_STRINGS:
            return True
        if key in _FALSE_STRINGS:
            return False
        raise ValueError(f"cannot interpret {value!r} as a boolean")
    return bool(value)


def _param(params: dict, key: str, default):
    """Return ``params[key]``, or ``default`` when it is missing or None."""
    value = params.get(key)
    return default if value is None else value


class EndpointHandler:
    """Serve Gemma 4 31B generation with speculative decoding toggled per request.

    The target and assistant models are loaded once. Each request decides,
    via its ``use_mtp`` parameter, whether the assistant is passed to
    ``generate()``. The response reports timing so both modes can be compared
    on the same hardware and weights.
    """

    DEFAULT_MAX_NEW_TOKENS = 200
    DEFAULT_TEMPERATURE = 0.7

    def __init__(self, path: str = ""):
        """Load tokenizer, target and assistant models, then warm up both paths.

        Args:
            path: The repo this handler.py was deployed in. Ignored: the
                weights are pulled from the hub IDs TARGET_ID / ASSISTANT_ID.
        """
        # path is the repo this handler.py was deployed in — we ignore it and
        # pull the canonical Gemma 4 weights directly from HF.
        self.tokenizer = AutoTokenizer.from_pretrained(TARGET_ID)
        self.target = AutoModelForCausalLM.from_pretrained(
            TARGET_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.target.eval()
        self.assistant = AutoModelForCausalLM.from_pretrained(
            ASSISTANT_ID,
            dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.assistant.eval()
        # one-shot warmup so the first real request isn't paying for kernel autotuning;
        # exercise both the plain greedy path and the sampled assisted path.
        warm = self.tokenizer("hello", return_tensors="pt").to(self.target.device)
        with torch.no_grad():
            _ = self.target.generate(**warm, max_new_tokens=4, do_sample=False)
            _ = self.target.generate(
                **warm,
                max_new_tokens=4,
                do_sample=True,
                temperature=self.DEFAULT_TEMPERATURE,
                assistant_model=self.assistant,
            )

    def __call__(self, data: dict) -> list:
        """Run one generation request and report throughput.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt string.
                ``data["parameters"]`` optionally holds:
                ``max_new_tokens`` (int, default 200),
                ``use_mtp`` (bool, default True),
                ``do_sample`` (bool, default True), and
                ``temperature`` (float, default 0.7, only used when sampling).
                Boolean parameters also accept strings such as
                ``"false"``/``"0"``. A null value means "use the default".

        Returns:
            A one-element list whose dict holds the generated text, prompt
            and generated token counts, wall-clock seconds, tokens/second,
            and the settings that were actually applied.

        Raises:
            ValueError: If a parameter value cannot be parsed.
        """
        prompt = data.get("inputs", "")
        params = data.get("parameters") or {}
        max_new_tokens = int(
            _param(params, "max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
        )
        use_mtp = _as_bool(params.get("use_mtp"), default=True)
        do_sample = _as_bool(params.get("do_sample"), default=True)
        temperature = float(_param(params, "temperature", self.DEFAULT_TEMPERATURE))

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.target.device)
        prompt_tokens = inputs["input_ids"].shape[1]

        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
        )
        # temperature only has meaning for sampling; omit it for greedy decoding.
        if do_sample:
            gen_kwargs["temperature"] = temperature
        # The single A/B switch: with the assistant, generate() runs assisted
        # (speculative) decoding; without it, plain autoregressive decoding.
        if use_mtp:
            gen_kwargs["assistant_model"] = self.assistant

        # Synchronize around the timed region so queued GPU work is not
        # attributed to (or hidden from) this request's measurement.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            out = self.target.generate(**gen_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        # generate() returns prompt + continuation; strip the prompt part.
        gen_tokens = int(out.shape[1] - prompt_tokens)
        text = self.tokenizer.decode(
            out[0][prompt_tokens:], skip_special_tokens=True
        )
        return [{
            "generated_text": text,
            "prompt_tokens": prompt_tokens,
            "generated_tokens": gen_tokens,
            "elapsed_seconds": round(elapsed, 3),
            "tokens_per_second": round(gen_tokens / elapsed, 2) if elapsed > 0 else 0.0,
            "use_mtp": use_mtp,
            "do_sample": do_sample,
            "temperature": temperature if do_sample else None,
        }]