import os
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


class EndpointHandler:
    """Custom Inference Endpoints handler that serves chat completions
    from the gated Llama 3.1 8B Instruct model."""

    def __init__(self, path: str = "") -> None:
        # The base model is gated, so an auth token must be available.
        # Accept the common environment variable spellings.
        token = (
            os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGING_FACE_HUB_TOKEN")
            or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        )
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a secret on the Inference Endpoint "
                "so the handler can download the gated "
                "meta-llama/Meta-Llama-3.1-8B-Instruct weights."
            )

        # Prefer a tokenizer bundled with the deployed repository (path),
        # falling back to the base model's. Pass the token here as well,
        # since the base repository is gated.
        tokenizer_source = path or BASE_MODEL
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, token=token)

        self.model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            token=token,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()

        # Llama tokenizers ship without a pad token; reuse EOS so that
        # generate() does not warn or fail when padding is required.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept both {"inputs": {"messages": [...]}} and a top-level
        # {"messages": [...]} payload shape.
        inputs_payload = data.get("inputs", data)
        messages = (
            inputs_payload.get("messages")
            if isinstance(inputs_payload, dict)
            else None
        ) or data.get("messages")
        if not messages:
            raise ValueError(
                "Request payload must include a 'messages' list, e.g. "
                '{"inputs": {"messages": [{"role": "user", "content": "hi"}]}}.'
            )

        parameters: Dict[str, Any] = data.get("parameters") or {}
        max_new_tokens = int(parameters.get("max_new_tokens", 256))
        do_sample = bool(parameters.get("do_sample", False))
        temperature = float(parameters.get("temperature", 0.7))
        top_p = float(parameters.get("top_p", 0.9))

        # Render the chat messages with the model's chat template and move
        # the resulting tensors onto the model's device.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Only forward sampling knobs when sampling is enabled; passing
        # temperature/top_p with do_sample=False triggers warnings.
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **generate_kwargs)

        # Strip the prompt tokens so only the newly generated completion
        # is decoded and returned.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        )
        return {"generated_text": decoded}
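

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch of how the endpoint runtime invokes the
# handler, not part of the Inference Endpoints contract. It assumes HF_TOKEN
# is set with access to the gated Llama repo and that enough GPU memory is
# available for the 8B model in bfloat16.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler(
        {
            "inputs": {
                "messages": [
                    {"role": "user", "content": "Say hello in one sentence."}
                ]
            },
            "parameters": {"max_new_tokens": 32},
        }
    )
    print(response["generated_text"])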