import warnings
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
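
# NOTE: the handler below references PROMPT_FOR_GENERATION_FORMAT, which the
# original file never defines. This is a minimal sketch of the Alpaca-style
# instruction template commonly paired with instruction-tuned checkpoints;
# the INTRO_BLURB wording and key strings are assumptions — swap in the exact
# template your checkpoint was trained on.
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
PROMPT_FOR_GENERATION_FORMAT = (
    f"{INTRO_BLURB}\n{INSTRUCTION_KEY}\n{{instruction}}\n{RESPONSE_KEY}\n"
)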


class EndpointHandler:
    def __init__(self, path="", torch_dtype=torch.bfloat16, trust_remote_code=True):
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

        if tokenizer.pad_token_id is None:
            warnings.warn(
                "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
            )
            tokenizer.pad_token = tokenizer.eos_token

        tokenizer.padding_side = "right"
        self.tokenizer = tokenizer

        # Run inference on GPU when available; eval() disables dropout.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=self.device, dtype=torch_dtype)
|
| | self.generate_kwargs = { |
| | "temperature": 0.01, |
| | "top_p": 0.92, |
| | "top_k": 0, |
| | "max_new_tokens": 512, |
| | "use_cache": True, |
| | "do_sample": True, |
| | "eos_token_id": self.tokenizer.eos_token_id, |
| | "pad_token_id": self.tokenizer.pad_token_id, |
| | "repetition_penalty": 1.0 |
| | } |
| |
|
    def format_instruction(self, instruction):
        # Wrap the raw instruction in the model's expected prompt template.
        return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        inputs = data.pop("inputs", data)
        # "parameters" may be absent or null; fall back to an empty dict so
        # the merge below does not fail on None.
        parameters = data.pop("parameters", None) or {}

        prompt = self.format_instruction(inputs)
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        gkw = {**self.generate_kwargs, **parameters}

        with torch.no_grad():
            # Passing the attention mask alongside input_ids avoids the
            # "attention mask is not set" warning from generate().
            output_ids = self.model.generate(**encoded, **gkw)

        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = output_ids[0, encoded["input_ids"].shape[1]:]
        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": output_text}]
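

# Minimal local smoke test, assuming the model weights live in "./model";
# the directory name and example instruction are placeholders, not part of
# the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler({"inputs": "Explain what a tokenizer does in one sentence."})
    print(result[0]["generated_text"])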