File size: 1,928 Bytes
7c170e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828e52c
7c170e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# handler.py
from typing import Dict, List, Any
import torch

class EndpointHandler:
    """Custom inference handler wrapping a causal LM and a story-writing prompt.

    The handler is constructed once with a model path and is then called with
    request payloads of the form ``{"inputs": ..., "parameters": {...}}``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and an fp16 causal LM from *path*.

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself require transformers.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # device_map="auto" asks transformers/accelerate to place the weights
        # on the available device(s).
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.float16, device_map="auto"
        )

        # Prompt template; ``{}`` is filled with the user's prompt. It ends with
        # an opening "<think>" tag, presumably so the model begins with its
        # reasoning section.
        self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request. 
Create an engaging and educational story that combines whimsical elements with real-world facts.

### Instruction:
You are a creative storyteller who specializes in writing whimsical children's stories that incorporate educational facts about the real world.
Please create a story based on the following prompt.

### Prompt:
{}

### Response:
<think>
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a story for a single request.

        Args:
            data: Request payload. Both keys below are popped from *data*.
                ``data["inputs"]`` is the user prompt; if it is absent, the
                payload dict itself is used instead.
                ``data["parameters"]`` optionally holds keyword arguments
                forwarded to ``model.generate``. They override the defaults
                (``max_new_tokens=12000``, the tokenizer's ``eos_token_id``,
                ``use_cache=True``). ``None`` is treated as "no parameters".

        Returns:
            A one-element list ``[{"generated_text": text}]``. *text* is the
            decoded output after the first ``"### Response:"`` marker, or the
            whole decoded output if the marker is missing.
        """
        question = data.pop("inputs", data)
        # A client may send "parameters": null; treat that like no parameters.
        parameters = data.pop("parameters", None) or {}

        # Defaults first, caller-supplied parameters last so they win. Passing
        # these keys both explicitly and via **parameters would raise
        # "got multiple values for keyword argument".
        generate_kwargs = {
            "max_new_tokens": 12000,
            "eos_token_id": self.tokenizer.eos_token_id,
            "use_cache": True,
            **parameters,
        }

        # No EOS token is appended: the prompt must stay open so the model
        # continues the "### Response:" / "<think>" section. It should not
        # continue after an end-of-sequence marker.
        prompt = self.inference_prompt_style.format(question)

        inputs = self.tokenizer([prompt], return_tensors="pt")
        device = self.model.device
        outputs = self.model.generate(
            input_ids=inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            **generate_kwargs,
        )

        # The decoded sequence still contains the prompt, so keep only the text
        # after the template's first "### Response:" marker. Unlike
        # split(...)[1], partition() neither truncates the completion when the
        # model repeats the marker nor raises when the marker is missing.
        # NOTE(review): if the user prompt itself contains "### Response:",
        # the split happens inside the prompt; decoding only the new tokens
        # would avoid that.
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        _, marker, completion = decoded.partition("### Response:")
        result = completion if marker else decoded

        return [{"generated_text": result}]