| from typing import Dict, List, Any | |
| from optimum.intel import OVModelForSeq2SeqLM | |
| from transformers import AutoTokenizer | |
# Task prefix placed in front of every input text before tokenization.
INSTRUCTION = "rewrite: "

# Default keyword arguments forwarded to ``model.generate`` when a request
# does not supply its own "parameters".
generation_config = dict(
    max_new_tokens=16,
    use_cache=True,
    temperature=0.6,
    do_sample=True,
    top_p=0.95,
)
class EndpointHandler:
    """Callable request handler around an OpenVINO seq2seq model.

    The model and tokenizer are loaded once at construction time; each call
    prefixes the request text with ``INSTRUCTION``, generates, and returns
    the decoded text.
    """

    def __init__(self, path="."):
        """Load the model and its tokenizer from ``path``.

        Args:
            path: Location passed to both ``from_pretrained`` calls.
        """
        self.model = OVModelForSeq2SeqLM.from_pretrained(
            path, use_cache=True, use_io_binding=False
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Generate text for a single request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is the text to process
                (when the key is absent the payload itself is formatted into
                the prompt). ``data["parameters"]``, when present and not
                ``None``, is used *instead of* ``generation_config`` as the
                keyword arguments for ``model.generate``. Both keys are
                popped from ``data``.

        Returns:
            The list of decoded strings produced by
            ``tokenizer.batch_decode`` (special tokens skipped).
        """
        text = data.pop("inputs", data)

        # An explicit ``"parameters": null`` would otherwise reach
        # ``**parameters`` as None and raise TypeError; treat it like a
        # missing key. An explicit dict (even an empty one) is honored as-is.
        parameters = data.pop("parameters", None)
        if parameters is None:
            parameters = generation_config

        # NOTE(review): INSTRUCTION already ends with a space, so this yields
        # two spaces between the prefix and the text. Kept unchanged because
        # the model may expect exactly this prompt format -- confirm.
        prompt = "{} {}".format(INSTRUCTION, text)

        # The prompt is truncated to at most 20 tokens; a single-item batch
        # needs no padding.
        encoded = self.tokenizer(
            [prompt],
            padding=False,
            return_tensors="pt",
            max_length=20,
            truncation=True,
        )
        outputs = self.model.generate(**encoded, **parameters)
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)