d-s-b committed on
Commit
7c170e5
·
verified ·
1 Parent(s): 8a8aa76

adding a handler file

Browse files
Files changed (1) hide show
  1. handler.py +53 -0
handler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# handler.py
from typing import Dict, List, Any
import torch

class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a causal-LM storyteller.

    Loads a tokenizer and fp16 causal language model from ``path`` and, on
    each call, formats the request into a fixed story-writing prompt,
    generates a completion, and returns the text after the response marker.
    """

    def __init__(self, path=""):
        # Local import keeps module import cheap; transformers is only
        # needed when the endpoint actually instantiates the handler.
        from transformers import AutoModelForCausalLM, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # device_map="auto" may place the model on GPU; inputs must be
        # moved to self.model.device before generate() (see __call__).
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype=torch.float16, device_map="auto"
        )

        # Prompt template; {} is filled with the user's prompt text.
        # NOTE: the trailing "<think>" primes the model's reasoning style —
        # keep byte-identical to the training-time template.
        self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Create an engaging and educational story that combines whimsical elements with real-world facts.

### Instruction:
You are a creative storyteller who specializes in writing whimsical children's stories that incorporate educational facts about the real world.
Please create a story based on the following prompt.

### Prompt:
{}

### Response:
<think>
"""

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request payload; expects ``"inputs"`` (the story prompt)
                and optionally ``"parameters"`` (extra ``generate`` kwargs).

        Returns:
            A one-element list ``[{"generated_text": ...}]`` containing the
            text generated after the ``### Response:`` marker.
        """
        # Fall back to the whole payload if no "inputs" key is present.
        question = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # BUGFIX: pop (not get) — otherwise a caller-supplied
        # "max_new_tokens" would be passed to generate() twice (once
        # explicitly, once via **parameters) and raise TypeError.
        max_new_tokens = parameters.pop("max_new_tokens", 1200)

        # Format prompt (eos_token appended to match the training format).
        prompt = self.inference_prompt_style.format(question) + self.tokenizer.eos_token

        # BUGFIX: move tensors to the model's device — with
        # device_map="auto" the model is typically on GPU while the
        # tokenizer output defaults to CPU.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            **parameters,
        )

        # Decode full sequence (prompt + completion) and keep only the part
        # after the response marker; the marker is always present because
        # the echoed prompt contains it.
        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        result = response[0].split("### Response:")[1]

        return [{"generated_text": result}]