# test-Llama / handler.py
# Author: ecmendez25
# Commit 802eb23: Bundle Llama 3.1 tokenizer locally; speed up cold start
import os
from typing import Any, Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repo id of the gated base model whose weights the handler always loads;
# also the fallback tokenizer source when no local path is supplied.
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
class EndpointHandler:
    """Inference Endpoint handler serving chat completions from ``BASE_MODEL``.

    ``__init__`` loads the tokenizer and model once; each call to the instance
    handles one request payload and returns the generated assistant reply.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and the bfloat16 base model.

        Args:
            path: Local repository directory. When non-empty the tokenizer is
                loaded from it (a bundled copy); otherwise it is fetched from
                ``BASE_MODEL`` on the Hub.

        Raises:
            RuntimeError: If no Hugging Face access token is found in the
                environment (the base weights are gated).
        """
        token = (
            os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGING_FACE_HUB_TOKEN")
            or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        )
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a secret on the Inference Endpoint "
                "so the handler can download the gated meta-llama/Meta-Llama-3.1-8B-Instruct weights."
            )
        tokenizer_source = path or BASE_MODEL
        # The token is needed here too: with an empty ``path`` the tokenizer is
        # downloaded from the gated repo. For a local directory it is ignored.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            token=token,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()
        # Reuse EOS as padding when the tokenizer defines no pad token, so
        # generate() always receives an explicit pad_token_id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one assistant reply for a chat conversation.

        Accepted payload shapes (a message is a ``{"role": ..., "content": ...}`` dict):
            {"inputs": {"messages": [...]}, "parameters": {...}}
            {"messages": [...], "parameters": {...}}
            {"inputs": [...]}        # the message list itself
            {"inputs": "some text"}  # shorthand for a single user message

        Recognised ``parameters``:
            max_new_tokens: positive int, default 256.
            do_sample: default False; string values such as "false"/"true" are parsed.
            temperature, top_p: defaults 0.7 / 0.9, only forwarded when sampling.

        Returns:
            ``{"generated_text": str}``: only the newly generated tokens,
            decoded with special tokens skipped.

        Raises:
            ValueError: If no usable message list is found, or
                ``max_new_tokens`` is not positive.
        """
        messages = self._extract_messages(data)
        parameters: Dict[str, Any] = data.get("parameters") or {}
        generate_kwargs = self._build_generate_kwargs(parameters)

        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **generate_kwargs)

        # generate() returns prompt + completion; drop the prompt tokens so
        # only the new reply is decoded.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )
        return {"generated_text": decoded}

    @staticmethod
    def _extract_messages(data: Dict[str, Any]) -> list:
        """Return the chat message list from a request payload (shapes listed in ``__call__``)."""
        payload = data.get("inputs", data)
        if isinstance(payload, str):
            # Plain-text shorthand: treat a non-blank string as one user turn.
            messages = [{"role": "user", "content": payload}] if payload.strip() else None
        elif isinstance(payload, (list, tuple)):
            messages = payload
        elif isinstance(payload, dict):
            messages = payload.get("messages")
        else:
            messages = None
        messages = messages or data.get("messages")
        if (
            not isinstance(messages, (list, tuple))
            or not messages
            or not all(isinstance(message, dict) for message in messages)
        ):
            raise ValueError(
                "Request payload must include a 'messages' list, e.g. "
                '{"inputs": {"messages": [{"role": "user", "content": "hi"}]}}.'
            )
        return list(messages)

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a JSON-ish flag; plain ``bool("false")`` would be True."""
        if isinstance(value, str):
            return value.strip().lower() in {"1", "true", "yes", "on"}
        return bool(value)

    def _stop_token_ids(self) -> Any:
        """Return the de-duplicated union of the model's and tokenizer's EOS ids, or None.

        Passing only ``tokenizer.eos_token_id`` to ``generate`` would replace the
        model's ``generation_config.eos_token_id``, which may list several stop
        tokens; merging keeps every one of them.
        """
        generation_config = getattr(self.model, "generation_config", None)
        candidates = (
            getattr(generation_config, "eos_token_id", None),
            self.tokenizer.eos_token_id,
        )
        stop_ids = []
        for candidate in candidates:
            if candidate is None:
                continue
            if isinstance(candidate, int):
                candidate = [candidate]
            for token_id in candidate:
                if token_id not in stop_ids:
                    stop_ids.append(token_id)
        return stop_ids or None

    def _build_generate_kwargs(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request ``parameters`` into ``model.generate`` keyword arguments."""
        max_new_tokens = int(parameters.get("max_new_tokens", 256))
        if max_new_tokens <= 0:
            raise ValueError(
                f"'max_new_tokens' must be a positive integer, got {max_new_tokens}."
            )
        do_sample = self._as_bool(parameters.get("do_sample", False))
        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self._stop_token_ids(),
        }
        # Sampling knobs are only meaningful (and only forwarded) when sampling.
        if do_sample:
            generate_kwargs["temperature"] = float(parameters.get("temperature", 0.7))
            generate_kwargs["top_p"] = float(parameters.get("top_p", 0.9))
        return generate_kwargs