| from typing import Dict, List, Any | |
| import torch | |
| from transformers import pipeline, set_seed | |
class EndpointHandler:
    """Callable request handler wrapping a GPT-2 text-generation pipeline.

    The pipeline is created once in ``__init__`` and stored on
    ``self.pipeline``; each call to the instance runs one generation request
    through it.
    """

    def __init__(self, path=""):
        """Create the text-generation pipeline.

        Args:
            path: Accepted but not used by this handler; the model id is
                hard-coded to ``openai-community/gpt2``.
        """
        # Forwarded to the model loader via ``model_kwargs``.
        # NOTE(review): presumably requests 4-bit quantized weights and needs
        # the matching quantization backend installed -- confirm.
        loader_kwargs = {"load_in_4bit": True}

        self.pipeline = pipeline(
            "text-generation",
            model="openai-community/gpt2",
            device_map="auto",
            model_kwargs=loader_kwargs,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation for one request payload.

        Both recognised keys are removed from ``data`` with ``pop``, so the
        caller's dict is modified in place.

        Args:
            data: Request payload with the optional keys
                ``"inputs"`` (the prompt; defaults to ``""``) and
                ``"parameters"`` (keyword arguments forwarded unchanged to
                the pipeline call; a missing or falsy value is treated as no
                extra arguments). For the supported generation arguments see
                https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation

        Returns:
            Whatever the pipeline call returns, unmodified. According to the
            pipeline documentation this is a list (or a list of lists) of
            dicts, each holding one of the following keys, never both:

            * ``generated_text`` (str, present when ``return_text=True``):
              the generated text.
            * ``generated_token_ids`` (tensor, present when
              ``return_tensors=True``): the token ids of the generated text.
        """
        prompt = data.pop("inputs", "")

        # A missing, None or empty "parameters" entry collapses to an empty
        # dict so it can always be splatted into the pipeline call.
        generate_kwargs = data.pop("parameters", ()) or {}

        # The seed is reset to the same fixed value on every request,
        # presumably so that sampled generations are reproducible.
        set_seed(42)

        return self.pipeline(prompt, **generate_kwargs)