simon-muenker's picture
Create handler.py
812f8d3 verified
raw
history blame contribute delete
818 Bytes
import typing
import transformers
class EndpointHandler:
    """Request handler wrapping a ``transformers`` text-generation pipeline.

    The handler is constructed once with a model location and is then called
    with one request payload (a ``dict``) per inference.
    """

    def __init__(self, path: str = ""):
        """Create the pipeline and load the adapter found at ``path``.

        Args:
            path: Location used both as the pipeline's ``model`` argument and
                as the argument to ``load_adapter`` on the pipeline's model.
        """
        self.pipeline = transformers.pipeline("text-generation", model=path)
        self.pipeline.model.load_adapter(path)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.List[typing.Dict[str, typing.Any]]:
        """Render the request through the chat template and generate text.

        Args:
            data: Request payload with the keys

                * ``inputs`` (required): either a list of chat messages
                  (``{"role": ..., "content": ...}`` dicts) or a plain
                  ``str``, which is treated as a single ``user`` message.
                * ``max_new_tokens`` (optional, ``int``): forwarded to the
                  pipeline; when absent or ``None`` the pipeline's own
                  default generation length is used.

        Returns:
            Whatever the pipeline returns for the rendered prompt (called
            with ``return_full_text=False``); it will be serialized and
            returned to the client.

        Raises:
            KeyError: If ``data`` has no ``inputs`` key.
        """
        messages = data["inputs"]
        if isinstance(messages, str):
            # The chat template iterates over role/content messages, so a
            # bare prompt string has to be wrapped before rendering.
            messages = [{"role": "user", "content": messages}]

        # NOTE(review): the template is rendered without
        # ``add_generation_prompt=True``; confirm this matches how the
        # adapter was trained before changing it.
        prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False)

        generation_kwargs: typing.Dict[str, typing.Any] = {"return_full_text": False}
        max_new_tokens = data.get("max_new_tokens")
        if max_new_tokens is not None:
            # Only forward the limit when the caller supplied one, so that a
            # request without it does not fail with KeyError.
            generation_kwargs["max_new_tokens"] = max_new_tokens

        return self.pipeline(prompt, **generation_kwargs)