OzTianlu commited on
Commit
6198884
·
verified ·
1 Parent(s): 5e08673

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +126 -0
handler.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Dict, List, Union
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+
10
+ Json = Dict[str, Any]
11
+ Messages = List[Dict[str, str]] # [{"role":"user|assistant|system", "content":"..."}]
12
+
13
+
14
+ def _is_messages(x: Any) -> bool:
15
+ return (
16
+ isinstance(x, list)
17
+ and len(x) > 0
18
+ and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
19
+ )
20
+
21
+
22
+ class EndpointHandler:
23
+ """
24
+ Hugging Face Inference Endpoints custom handler.
25
+ Expects:
26
+ - request body is a dict
27
+ - always contains `inputs`
28
+ - may contain `parameters` for generation
29
+ """
30
+
31
+ def __init__(self, model_dir: str):
32
+ self.model_dir = model_dir
33
+
34
+ # Pick dtype/device
35
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
+ if self.device == "cuda":
37
+ # bfloat16 is usually safe on A100/H100; if your instance doesn't support bf16, change to float16
38
+ self.dtype = torch.bfloat16
39
+ else:
40
+ self.dtype = torch.float32
41
+
42
+ # IMPORTANT: trust_remote_code=True because repo contains AsteriskForCausalLM.py + auto_map
43
+ self.tokenizer = AutoTokenizer.from_pretrained(
44
+ model_dir,
45
+ trust_remote_code=True,
46
+ use_fast=True,
47
+ )
48
+
49
+ # Make sure pad token exists (your config uses pad_token_id=2 which equals eos_token_id in many llama-like models)
50
+ if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
51
+ self.tokenizer.pad_token = self.tokenizer.eos_token
52
+
53
+ self.model = AutoModelForCausalLM.from_pretrained(
54
+ model_dir,
55
+ trust_remote_code=True,
56
+ torch_dtype=self.dtype,
57
+ device_map="auto" if self.device == "cuda" else None,
58
+ )
59
+
60
+ if self.device != "cuda":
61
+ self.model.to(self.device)
62
+
63
+ self.model.eval()
64
+
65
+ @torch.inference_mode()
66
+ def __call__(self, data: Json) -> Union[Json, List[Json]]:
67
+ inputs = data.get("inputs", "")
68
+ params = data.get("parameters", {}) or {}
69
+
70
+ # Generation defaults (can be overridden via `parameters`)
71
+ max_new_tokens = int(params.get("max_new_tokens", 256))
72
+ temperature = float(params.get("temperature", 0.7))
73
+ top_p = float(params.get("top_p", 0.95))
74
+ top_k = int(params.get("top_k", 0))
75
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
76
+
77
+ do_sample = bool(params.get("do_sample", temperature > 0))
78
+ num_beams = int(params.get("num_beams", 1))
79
+
80
+ def _one(item: Any) -> Json:
81
+ # Accept:
82
+ # 1) string prompt
83
+ # 2) messages list: [{"role":"user","content":"..."}]
84
+ # 3) dict {"messages":[...]} (common chat style)
85
+ if isinstance(item, dict) and "messages" in item:
86
+ item = item["messages"]
87
+
88
+ if _is_messages(item):
89
+ # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
90
+ input_ids = self.tokenizer.apply_chat_template(
91
+ item,
92
+ return_tensors="pt",
93
+ add_generation_prompt=True,
94
+ )
95
+ else:
96
+ if not isinstance(item, str):
97
+ item = str(item)
98
+ enc = self.tokenizer(item, return_tensors="pt")
99
+ input_ids = enc["input_ids"]
100
+
101
+ input_ids = input_ids.to(self.model.device)
102
+ input_len = input_ids.shape[-1]
103
+
104
+ gen_ids = self.model.generate(
105
+ input_ids=input_ids,
106
+ max_new_tokens=max_new_tokens,
107
+ do_sample=do_sample,
108
+ temperature=temperature if do_sample else None,
109
+ top_p=top_p if do_sample else None,
110
+ top_k=top_k if do_sample and top_k > 0 else None,
111
+ num_beams=num_beams,
112
+ repetition_penalty=repetition_penalty,
113
+ pad_token_id=self.tokenizer.pad_token_id,
114
+ eos_token_id=self.tokenizer.eos_token_id,
115
+ )
116
+
117
+ # Only return newly generated tokens
118
+ new_tokens = gen_ids[0, input_len:]
119
+ text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
120
+ return {"generated_text": text}
121
+
122
+ # Batch support
123
+ if isinstance(inputs, list) and not _is_messages(inputs):
124
+ return [_one(x) for x in inputs]
125
+ else:
126
+ return _one(inputs)