OzTianlu committed on
Commit d94ae83 · verified · 1 Parent(s): 5f4d724

Update handler.py

Files changed (1): handler.py +100 -61
handler.py CHANGED
@@ -1,87 +1,126 @@
 # handler.py
-from typing import Any, Dict, List
+from __future__ import annotations
+
+from typing import Any, Dict, List, Union
+
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 
 Json = Dict[str, Any]
+Messages = List[Dict[str, str]]  # [{"role":"user|assistant|system", "content":"..."}]
+
+
+def _is_messages(x: Any) -> bool:
+    return (
+        isinstance(x, list)
+        and len(x) > 0
+        and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
+    )
+
 
 class EndpointHandler:
     """
-    Minimal custom handler for Hugging Face Inference Endpoints.
-
-    Implements __init__() to load the model/tokenizer,
-    and __call__() to handle inference requests.
+    Hugging Face Inference Endpoints custom handler.
+    Expects:
+      - request body is a dict
+      - always contains `inputs`
+      - may contain `parameters` for generation
     """
 
     def __init__(self, model_dir: str):
-        """
-        Called once on endpoint startup.
-
-        Args:
-            model_dir (str): Local path where the model repo was downloaded.
-        """
-        # Load tokenizer and model
-        # Set trust_remote_code=True if the model repo has custom code
+        self.model_dir = model_dir
+
+        # Pick dtype/device
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if self.device == "cuda":
+            # bfloat16 is usually safe on A100/H100; if your instance doesn't support bf16, change to float16
+            self.dtype = torch.bfloat16
+        else:
+            self.dtype = torch.float32
+
+        # IMPORTANT: trust_remote_code=True because repo contains AsteriskForCausalLM.py + auto_map
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_dir,
-            trust_remote_code=True,  # allow custom code in repo
+            trust_remote_code=True,
             use_fast=True,
         )
 
+        # Make sure pad token exists (your config uses pad_token_id=2 which equals eos_token_id in many llama-like models)
+        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
         self.model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
+            torch_dtype=self.dtype,
+            device_map="auto" if self.device == "cuda" else None,
         )
 
-        # Put model in eval mode
+        if self.device != "cuda":
+            self.model.to(self.device)
+
         self.model.eval()
 
     @torch.inference_mode()
-    def __call__(self, data: Json) -> List[Json]:
-        """
-        Called for each inference request.
-
-        Args:
-            data (dict): {"inputs": str or list[str], "parameters": {...}}
-
-        Returns:
-            List[dict]: list of output dicts (each must be serializable).
-        """
-        # Parse incoming prompt(s)
+    def __call__(self, data: Json) -> Union[Json, List[Json]]:
         inputs = data.get("inputs", "")
         params = data.get("parameters", {}) or {}
 
-        # Tokenize
-        enc = self.tokenizer(
-            inputs,
-            return_tensors="pt",
-            padding=True,
-        )
-
-        input_ids = enc["input_ids"]
-        attention_mask = enc["attention_mask"]
-
-        # Move tensors to model device
-        device = next(self.model.parameters()).device
-        input_ids = input_ids.to(device)
-        attention_mask = attention_mask.to(device)
-
-        # Generation parameters (optional overrides)
-        max_new_tokens = int(params.get("max_new_tokens", 128))
-        temperature = float(params.get("temperature", 1.0))
-
-        # Run generation
-        output_ids = self.model.generate(
-            input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-        )
-
-        # Decode to text
-        outputs = []
-        for seq in output_ids:
-            text = self.tokenizer.decode(seq, skip_special_tokens=True)
-            outputs.append({"generated_text": text})
-
-        return outputs
+        # Generation defaults (can be overridden via `parameters`)
+        max_new_tokens = int(params.get("max_new_tokens", 256))
+        temperature = float(params.get("temperature", 0.7))
+        top_p = float(params.get("top_p", 0.95))
+        top_k = int(params.get("top_k", 0))
+        repetition_penalty = float(params.get("repetition_penalty", 1.0))
+
+        do_sample = bool(params.get("do_sample", temperature > 0))
+        num_beams = int(params.get("num_beams", 1))
+
+        def _one(item: Any) -> Json:
+            # Accept:
+            #   1) string prompt
+            #   2) messages list: [{"role":"user","content":"..."}]
+            #   3) dict {"messages":[...]} (common chat style)
+            if isinstance(item, dict) and "messages" in item:
+                item = item["messages"]
+
+            if _is_messages(item):
+                # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
+                input_ids = self.tokenizer.apply_chat_template(
+                    item,
+                    return_tensors="pt",
+                    add_generation_prompt=True,
+                )
+            else:
+                if not isinstance(item, str):
+                    item = str(item)
+                enc = self.tokenizer(item, return_tensors="pt")
+                input_ids = enc["input_ids"]
+
+            input_ids = input_ids.to(self.model.device)
+            input_len = input_ids.shape[-1]
+
+            gen_ids = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature if do_sample else None,
+                top_p=top_p if do_sample else None,
+                top_k=top_k if do_sample and top_k > 0 else None,
+                num_beams=num_beams,
+                repetition_penalty=repetition_penalty,
+                pad_token_id=self.tokenizer.pad_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+
+            # Only return newly generated tokens
+            new_tokens = gen_ids[0, input_len:]
+            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+            return {"generated_text": text}
+
+        # Batch support
+        if isinstance(inputs, list) and not _is_messages(inputs):
+            return [_one(x) for x in inputs]
+        else:
+            return _one(inputs)
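
For reference, the new handler can be exercised locally before deploying. The sketch below is not part of the commit: it assumes the model repo (including the custom code referenced via auto_map) has been downloaded to a placeholder path ./model, and it shows the three input shapes __call__ now accepts.

    # smoke_test.py -- illustrative sketch, not part of this commit
    from handler import EndpointHandler

    handler = EndpointHandler(model_dir="./model")  # placeholder path

    # 1) Plain string prompt -> single {"generated_text": ...} dict
    print(handler({"inputs": "Write a haiku about autumn."}))

    # 2) Chat-style messages -> routed through tokenizer.apply_chat_template
    print(handler({
        "inputs": [{"role": "user", "content": "Hello!"}],
        "parameters": {"max_new_tokens": 64, "temperature": 0.2},
    }))

    # 3) Batch of string prompts -> list of {"generated_text": ...} dicts
    print(handler({"inputs": ["One.", "Two."], "parameters": {"do_sample": False}}))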
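
Once deployed, the endpoint accepts the same payload over HTTP. A hypothetical client call (the URL and token below are placeholders, not values from this repo):

    import requests

    resp = requests.post(
        "https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder URL
        headers={"Authorization": "Bearer <hf_token>"},  # placeholder token
        json={
            "inputs": "Write a haiku about autumn.",
            "parameters": {"max_new_tokens": 64, "temperature": 0.2},
        },
    )
    print(resp.json())  # -> {"generated_text": "..."}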