# File size: 818 Bytes (revision 812f8d3)
import typing
import transformers
class EndpointHandler:
    """Custom inference handler wrapping a ``text-generation`` pipeline.

    The handler is constructed once with the model directory and is then
    called once per request with the decoded request payload.
    """

    def __init__(self, path: str = ""):
        """Build the text-generation pipeline and load the adapter from *path*.

        Args:
            path: Directory (or model id) passed both as the pipeline's
                ``model`` and as the adapter location.
        """
        self.pipeline = transformers.pipeline("text-generation", model=path)
        # NOTE(review): the adapter is loaded from the same location as the
        # base model. Presumably `path` holds PEFT adapter weights; confirm
        # the installed transformers version does not already load them in
        # `pipeline(...)`, which could make this second load redundant.
        self.pipeline.model.load_adapter(path)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.List[typing.Dict[str, typing.Any]]:
        """Render the chat template for the request and generate a completion.

        Args:
            data: Request payload with the keys:

                * ``inputs`` (required): the conversation handed to the
                  tokenizer's ``apply_chat_template`` (typically a list of
                  ``{"role": ..., "content": ...}`` message dicts).
                * ``max_new_tokens`` (optional): upper bound on the number
                  of generated tokens. When absent or ``None`` the argument
                  is not forwarded, so the pipeline's own default applies.

        Returns:
            Whatever the pipeline returns for the rendered prompt; it is
            serialized and sent back to the caller. ``return_full_text`` is
            disabled, so the prompt is not requested back in the output.

        Raises:
            KeyError: If ``inputs`` is missing from *data*.
        """
        # NOTE(review): no `add_generation_prompt` is passed, so the
        # tokenizer's default is used. Verify the model is expected to open
        # the assistant turn by itself.
        prompt = self.pipeline.tokenizer.apply_chat_template(
            data["inputs"],
            tokenize=False,
        )

        generation_kwargs: typing.Dict[str, typing.Any] = {"return_full_text": False}
        max_new_tokens = data.get("max_new_tokens")
        if max_new_tokens is not None:
            # Forward the limit only when the caller supplied one, instead of
            # failing the whole request with a KeyError.
            generation_kwargs["max_new_tokens"] = max_new_tokens

        return self.pipeline(prompt, **generation_kwargs)