# File size: 1,928 Bytes
# handler.py
from typing import Dict, List, Any
import torch
class EndpointHandler:
    """Callable request handler that generates a story from a text prompt.

    On construction the handler loads a tokenizer and a causal language model
    from ``path``. Each call wraps the request's ``inputs`` in a fixed
    storytelling prompt template, runs ``generate`` on the model, and returns
    the decoded text that follows the template's ``### Response:`` marker.
    """

    # Marker in the prompt template after which the model's answer begins.
    _RESPONSE_MARKER = "### Response:"
    # Cap on newly generated tokens used when the request does not set one.
    _DEFAULT_MAX_NEW_TOKENS = 12000

    def __init__(self, path=""):
        """Load the tokenizer and model found at ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.float16, device_map="auto"
        )
        # Prompt template; the single ``{}`` placeholder receives the request
        # text via ``str.format``.
        self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Create an engaging and educational story that combines whimsical elements with real-world facts.
### Instruction:
You are a creative storyteller who specializes in writing whimsical children's stories that incorporate educational facts about the real world.
Please create a story based on the following prompt.
### Prompt:
{}
### Response:
<think>
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a story for one request.

        Args:
            data: Request payload. ``"inputs"`` holds the story prompt (when
                the key is missing the payload itself is used, as before) and
                the optional ``"parameters"`` mapping holds keyword arguments
                forwarded to ``model.generate``. Both keys are popped from
                ``data``.

        Returns:
            A one-element list ``[{"generated_text": <text>}]``.
        """
        question = data.pop("inputs", data)
        # A JSON ``null`` for "parameters" arrives as None; treat it as empty.
        parameters = data.pop("parameters", None) or {}

        # NOTE(review): the EOS token is appended to the prompt before
        # generation, exactly as the original code did. That is unusual for an
        # inference prompt -- confirm it matches how the model was fine-tuned.
        # A tokenizer without an EOS token contributes nothing instead of
        # raising TypeError on ``str + None``.
        eos_token = self.tokenizer.eos_token or ""
        prompt = self.inference_prompt_style.format(question) + eos_token

        # Move the encoded prompt to the model's device so generation does not
        # mix devices.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        # Defaults come first so request parameters override them instead of
        # colliding with explicit keyword arguments (which raised TypeError
        # whenever a request set e.g. ``max_new_tokens``). The encoded prompt
        # goes last so a request cannot replace it.
        generation_kwargs = {
            "max_new_tokens": self._DEFAULT_MAX_NEW_TOKENS,
            "eos_token_id": self.tokenizer.eos_token_id,
            "use_cache": True,
            **parameters,
            "input_ids": inputs.input_ids,
            "attention_mask": inputs.attention_mask,
        }
        outputs = self.model.generate(**generation_kwargs)

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return [{"generated_text": self._extract_response(decoded, prompt)}]

    @classmethod
    def _extract_response(cls, decoded: str, prompt: str) -> str:
        """Return the part of ``decoded`` that follows the prompt's own marker.

        The decoded text is split on the response marker. The segment chosen
        is the one after the last marker occurring in ``prompt``, so markers
        typed by the user inside their request are skipped. If the decoded
        text has fewer markers than expected, the last available segment is
        returned (the whole text when no marker is present) rather than
        raising ``IndexError``.
        """
        segments = decoded.split(cls._RESPONSE_MARKER)
        # With no marker in the user's text this is 1, the original behaviour.
        wanted = max(prompt.count(cls._RESPONSE_MARKER), 1)
        return segments[min(wanted, len(segments) - 1)]