from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Need to set HF_TOKEN on the endpoint creation process for this to work
# (the model repository below is gated).
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"


class EndpointHandler:
    """Custom inference handler wrapping a text-generation pipeline.

    The handler builds the pipeline once at construction time and turns each
    request payload into a list of "next chat turn" results.
    """

    def __init__(self, path=""):
        """Create the inference pipeline.

        Args:
            path: Accepted for interface compatibility with the endpoint
                runtime; it is not used, the model is always loaded from
                ``model_name``.
        """
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation for a request payload.

        Args:
            data: Request dict with the elements
                ``inputs``: a batch of conversations
                (``List[List[Dict[str, str]]]``), a batch of prompts
                (``List[str]``), or a single conversation / prompt.
                If the key is missing, the remaining payload itself is
                passed to the pipeline.
                ``parameters``: optional dict of keyword arguments forwarded
                to the pipeline call.

        Returns:
            One dict per input with the keys
            ``next_chat_turn`` (message dict with a ``content`` entry) and
            ``next_chat_text`` (the ``content`` string of that message).
        """
        # Work on a shallow copy so the caller's payload is not mutated.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)
        parameters = payload.pop("parameters", None)
        if parameters is None:
            parameters = {}

        predictions = self.pipeline(inputs, **parameters)

        # A single (unbatched) input yields a flat list of result dicts
        # instead of one list per input; normalise to the batched shape.
        if predictions and isinstance(predictions[0], dict):
            predictions = [predictions]

        results = []
        for sequences in predictions:
            # Only the first returned sequence of each input is reported.
            turn = self._to_chat_turn(sequences[0]["generated_text"])
            results.append({
                'next_chat_turn': turn,
                'next_chat_text': turn['content'],
            })
        return results

    @staticmethod
    def _to_chat_turn(generated: Any) -> Dict[str, Any]:
        """Return the final chat message for one pipeline result."""
        if isinstance(generated, str):
            # Plain-string prompts produce plain text rather than a message
            # list; wrap it so both input styles share one output schema.
            # NOTE(review): the text may still include the prompt unless the
            # caller passes return_full_text=False -- confirm desired default.
            return {"role": "assistant", "content": generated}
        # Conversational inputs: the last message is the newly generated turn.
        return generated[-1]