import os
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


class EndpointHandler:
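    """Custom Inference Endpoints handler serving chat completions from Meta-Llama-3.1-8B-Instruct."""
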
| def __init__(self, path: str = "") -> None: |
        token = (
            os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGING_FACE_HUB_TOKEN")
            or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        )
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a secret on the Inference Endpoint "
                "so the handler can download the gated meta-llama/Meta-Llama-3.1-8B-Instruct weights."
            )

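        # Prefer the tokenizer bundled with the endpoint repository at `path`, falling
        # back to the gated base model when no local copy is available.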
        tokenizer_source = path or BASE_MODEL
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            token=token,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()

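        # Llama 3.1's tokenizer has no pad token by default; reuse EOS so generate() can pad.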
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
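        """Generate a chat completion for the messages in the request payload."""
        # Accept both {"inputs": {"messages": [...]}} and a bare top-level "messages" list.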
| inputs_payload = data.get("inputs", data) |
| messages = ( |
| inputs_payload.get("messages") |
| if isinstance(inputs_payload, dict) |
| else None |
| ) or data.get("messages") |
|
|
| if not messages: |
| raise ValueError( |
| "Request payload must include a 'messages' list, e.g. " |
| '{"inputs": {"messages": [{"role": "user", "content": "hi"}]}}.' |
| ) |
|
|
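        # Optional generation parameters with conservative defaults; decoding is greedy
        # unless the caller explicitly enables sampling.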
| parameters: Dict[str, Any] = data.get("parameters") or {} |
| max_new_tokens = int(parameters.get("max_new_tokens", 256)) |
| do_sample = bool(parameters.get("do_sample", False)) |
| temperature = float(parameters.get("temperature", 0.7)) |
| top_p = float(parameters.get("top_p", 0.9)) |
|
|
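        # Render the conversation with the model's chat template and move the resulting
        # tensors onto the device the model was loaded to.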
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

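        # Arguments shared by greedy and sampling decoding.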
        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
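        # temperature and top_p only apply when sampling; passing them alongside
        # do_sample=False makes transformers emit warnings.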
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **generate_kwargs)

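        # Decode only the newly generated tokens, skipping the prompt.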
| prompt_len = inputs["input_ids"].shape[-1] |
| decoded = self.tokenizer.decode( |
| outputs[0][prompt_len:], |
| skip_special_tokens=True, |
| ) |
|
|
| return {"generated_text": decoded} |