# cephcyn's picture
# Update handler.py
# a872063 verified
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Need to set HF_TOKEN on the endpoint creation process for this to work
# (Meta-Llama 3.1 is a gated repository; access must be granted to the token's account.)
# Model id handed to the text-generation pipeline in EndpointHandler below.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
class EndpointHandler:
    """Hugging Face Inference Endpoint handler wrapping a chat text-generation pipeline."""

    def __init__(self, path=""):
        # Build the inference pipeline once at endpoint startup.
        # NOTE: loading this gated Llama model requires HF_TOKEN to be set
        # when the endpoint is created (see comment at module top).
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run batched text generation and extract the next chat turn.

        Args:
            data: dict with elements...
                inputs: List[List[Dict[str, str]]] (chat format) or List[str],
                    inputs to batch-process in conversational format
                parameters: optional dict of kwargs forwarded to the pipeline

        Returns:
            One {'next_chat_turn': dict, 'next_chat_text': str} entry per input.
        """
        # If "inputs" is absent, fall back to treating the whole payload as
        # the inputs (preserves this handler's original behavior).
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        if parameters is not None:
            predictions = self.pipeline(inputs, **parameters)
        else:
            predictions = self.pipeline(inputs)
        # Postprocess: each prediction is a list of candidates; the first
        # candidate's "generated_text" is the full message list for
        # chat-format inputs, or a plain string for string inputs.
        results = []
        for prediction in predictions:
            generated = prediction[0]["generated_text"]
            if isinstance(generated, str):
                # BUGFIX: plain-string inputs previously indexed the last
                # *character* of the text and then crashed on turn['content'];
                # wrap the generated text as an assistant-style turn instead.
                turn = {"role": "assistant", "content": generated}
            else:
                # Chat-format input: the model's reply is the final message.
                turn = generated[-1]
            results.append({
                'next_chat_turn': turn,
                'next_chat_text': turn['content'],
            })
        return results