| | |
| | from typing import Dict, List, Any |
| | import torch |
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a causal-LM storyteller.

    Loads the tokenizer and model from ``path`` once at construction and
    serves text-generation requests through ``__call__``.
    """

    def __init__(self, path: str = ""):
        # Deferred import: transformers is only needed when the handler is
        # actually instantiated by the endpoint runtime.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.float16, device_map="auto"
        )

        # Prompt template; the lone ``{}`` is filled with the user's prompt.
        self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Create an engaging and educational story that combines whimsical elements with real-world facts.

### Instruction:
You are a creative storyteller who specializes in writing whimsical children's stories that incorporate educational facts about the real world.
Please create a story based on the following prompt.

### Prompt:
{}

### Response:
<think>
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a story for one request payload.

        Args:
            data: Request dict, expected shape
                ``{"inputs": <prompt str>, "parameters": {<generate kwargs>}}``.

        Returns:
            ``[{"generated_text": <model response text>}]``.
        """
        question = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # BUG FIX: pop (not get) so **parameters cannot pass max_new_tokens a
        # second time to generate(), which would raise TypeError.
        max_new_tokens = parameters.pop("max_new_tokens", 12000)

        # BUG FIX: do not append eos_token to the prompt — an EOS token inside
        # the *input* marks end-of-sequence and can truncate generation; it is
        # a training-target convention, not an inference-input one.
        prompt = self.inference_prompt_style.format(question)

        # BUG FIX: move tensors to the model's device. With device_map="auto"
        # the model is usually on GPU while the tokenizer returns CPU tensors,
        # which would otherwise raise a device-mismatch error in generate().
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            **parameters,
        )

        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Keep only the text after the response marker. Fall back to the full
        # decoded output if the marker is missing (previously an IndexError).
        decoded = response[0]
        marker = "### Response:"
        result = decoded.split(marker, 1)[1] if marker in decoded else decoded

        return [{"generated_text": result}]