rawcell committed on
Commit
cca395f
·
verified ·
1 Parent(s): 44cfc7c

Add handler.py for Inference Endpoints

Browse files
Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    Loads a causal language model and tokenizer from ``path`` and serves
    text generation. Accepts either a plain string prompt or a chat-style
    list of ``{"role": ..., "content": ...}`` messages under ``"inputs"``,
    with optional sampling overrides under ``"parameters"``.
    """

    def __init__(self, path: str = ""):
        # `path` is the local directory where the endpoint has downloaded
        # the model repository.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # Ensure inference mode (disables dropout etc.); from_pretrained
        # does not guarantee eval() for every architecture.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for the request payload.

        Args:
            data: Request body. ``data["inputs"]`` is a prompt string or a
                list of chat messages; ``data["parameters"]`` may override
                ``max_new_tokens``, ``temperature``, ``top_p``, ``do_sample``.
                If ``"inputs"`` is absent, the whole payload is treated as
                the input (legacy behavior).

        Returns:
            A single-element list: ``[{"generated_text": <completion>}]``.
        """
        # Pop "parameters" BEFORE falling back to `data` as the input:
        # with the reversed order, a payload lacking "inputs" made
        # `inputs` alias `data`, and the later pop of "parameters"
        # silently mutated the inputs as well.
        parameters = data.pop("parameters", {})
        inputs = data.pop("inputs", data)

        if isinstance(inputs, list) and inputs and isinstance(inputs[0], dict):
            # Chat-style payload: render with the model's chat template and
            # append the assistant generation prompt.
            text = self.tokenizer.apply_chat_template(
                inputs,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            text = inputs

        encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
        }

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # Decode only the newly generated tokens. Slicing off the prompt by
        # token count works for any chat template, unlike the previous
        # approach of decoding everything and splitting on
        # "<|im_start|>assistant" (which duplicated the prompt for models
        # that don't use ChatML).
        prompt_len = encoded["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        # Defensive cleanup: some tokenizers do not register the ChatML
        # markers as special tokens, so skip_special_tokens leaves them in.
        if "<|im_start|>assistant" in decoded:
            decoded = decoded.split("<|im_start|>assistant")[-1].strip()
        if decoded.endswith("<|im_end|>"):
            decoded = decoded.removesuffix("<|im_end|>").strip()

        return [{"generated_text": decoded}]