OzTianlu commited on
Commit
d2efc5f
·
verified ·
1 Parent(s): 6ce2ca7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -39
handler.py CHANGED
@@ -22,23 +22,10 @@ def _is_messages(x: Any) -> bool:
22
  class EndpointHandler:
23
  """
24
  Hugging Face Inference Endpoints custom handler.
25
-
26
- Supports both text and chat formats:
27
-
28
- Text format:
29
- {"inputs": "Hello, how are you?"}
30
-
31
- Chat format (recommended):
32
- {"inputs": [{"role": "user", "content": "Hello!"}]}
33
- or
34
- {"inputs": {"messages": [{"role": "user", "content": "Hello!"}]}}
35
-
36
- Parameters:
37
- - max_new_tokens (default: 256): Max tokens to generate
38
- - temperature (default: 0.7): Sampling temperature
39
- - top_p (default: 0.95): Nucleus sampling
40
- - repetition_penalty (default: 1.0): Penalize repetitions
41
- - return_full_text (default: False): If True, return full conversation; if False, only new tokens
42
  """
43
 
44
  def __init__(self, model_dir: str):
@@ -86,7 +73,6 @@ class EndpointHandler:
86
  top_p = float(params.get("top_p", 0.95))
87
  top_k = int(params.get("top_k", 0))
88
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
89
- return_full_text = bool(params.get("return_full_text", False))
90
 
91
  do_sample = bool(params.get("do_sample", temperature > 0))
92
  num_beams = int(params.get("num_beams", 1))
@@ -101,21 +87,11 @@ class EndpointHandler:
101
 
102
  if _is_messages(item):
103
  # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
104
- try:
105
- # Use tokenize=False to get the formatted string first
106
- prompt = self.tokenizer.apply_chat_template(
107
- item,
108
- tokenize=False,
109
- add_generation_prompt=True,
110
- )
111
- # Then tokenize it separately to avoid unpacking issues
112
- enc = self.tokenizer(prompt, return_tensors="pt")
113
- input_ids = enc["input_ids"]
114
- except Exception:
115
- # Fallback: if chat template fails, use the last user message
116
- last_user_msg = next((m["content"] for m in reversed(item) if m.get("role") == "user"), "")
117
- enc = self.tokenizer(last_user_msg, return_tensors="pt")
118
- input_ids = enc["input_ids"]
119
  else:
120
  if not isinstance(item, str):
121
  item = str(item)
@@ -138,12 +114,9 @@ class EndpointHandler:
138
  eos_token_id=self.tokenizer.eos_token_id,
139
  )
140
 
141
- # Return newly generated tokens by default, or full text if requested
142
- if return_full_text:
143
- text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
144
- else:
145
- new_tokens = gen_ids[0, input_len:]
146
- text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
147
  return {"generated_text": text}
148
 
149
  # Batch support
 
22
  class EndpointHandler:
23
  """
24
  Hugging Face Inference Endpoints custom handler.
25
+ Expects:
26
+ - request body is a dict
27
+ - always contains `inputs`
28
+ - may contain `parameters` for generation
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
 
31
  def __init__(self, model_dir: str):
 
73
  top_p = float(params.get("top_p", 0.95))
74
  top_k = int(params.get("top_k", 0))
75
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
 
76
 
77
  do_sample = bool(params.get("do_sample", temperature > 0))
78
  num_beams = int(params.get("num_beams", 1))
 
87
 
88
  if _is_messages(item):
89
  # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
90
+ input_ids = self.tokenizer.apply_chat_template(
91
+ item,
92
+ return_tensors="pt",
93
+ add_generation_prompt=True,
94
+ )
 
 
 
 
 
 
 
 
 
 
95
  else:
96
  if not isinstance(item, str):
97
  item = str(item)
 
114
  eos_token_id=self.tokenizer.eos_token_id,
115
  )
116
 
117
+ # Only return newly generated tokens
118
+ new_tokens = gen_ids[0, input_len:]
119
+ text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
 
 
 
120
  return {"generated_text": text}
121
 
122
  # Batch support