| from typing import Dict, List, Any, Optional | |
| import transformers | |
| import torch | |
# Default generation cap applied when the caller supplies no parameters.
MAX_TOKENS = 1024
class EndpointHandler(object):
    """Hugging Face Inference Endpoints custom handler wrapping a
    text-generation pipeline for the SEA-LION v3 instruct model."""

    def __init__(self, path=''):
        # `path` is part of the Endpoints handler contract; this handler
        # loads a fixed hub model instead of the local `path`.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation on one request payload.

        :param data: request dict with keys:
            - "inputs": prompt or chat-message payload for the pipeline
              (required; a missing key raises KeyError)
            - "parameters": optional dict of pipeline keyword arguments
        :return: the pipeline's output records (for text-generation this
            is a list of dicts containing "generated_text").
        """
        print(f"data: {data}")
        inputs = data.pop("inputs")
        parameters: Optional[Dict] = data.pop("parameters", None)
        if parameters is None:
            # Same default the original applied on the no-parameters path;
            # caller-supplied parameters are forwarded untouched.
            parameters = {"max_new_tokens": MAX_TOKENS}
        return self.pipeline(inputs, **parameters)