OzTianlu commited on
Commit
ee6d50b
·
verified ·
1 Parent(s): d2efc5f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -12
handler.py CHANGED
@@ -22,10 +22,23 @@ def _is_messages(x: Any) -> bool:
22
  class EndpointHandler:
23
  """
24
  Hugging Face Inference Endpoints custom handler.
25
- Expects:
26
- - request body is a dict
27
- - always contains `inputs`
28
- - may contain `parameters` for generation
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
 
31
  def __init__(self, model_dir: str):
@@ -73,6 +86,7 @@ class EndpointHandler:
73
  top_p = float(params.get("top_p", 0.95))
74
  top_k = int(params.get("top_k", 0))
75
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
 
76
 
77
  do_sample = bool(params.get("do_sample", temperature > 0))
78
  num_beams = int(params.get("num_beams", 1))
@@ -87,11 +101,21 @@ class EndpointHandler:
87
 
88
  if _is_messages(item):
89
  # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
90
- input_ids = self.tokenizer.apply_chat_template(
91
- item,
92
- return_tensors="pt",
93
- add_generation_prompt=True,
94
- )
 
 
 
 
 
 
 
 
 
 
95
  else:
96
  if not isinstance(item, str):
97
  item = str(item)
@@ -114,9 +138,12 @@ class EndpointHandler:
114
  eos_token_id=self.tokenizer.eos_token_id,
115
  )
116
 
117
- # Only return newly generated tokens
118
- new_tokens = gen_ids[0, input_len:]
119
- text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
 
 
 
120
  return {"generated_text": text}
121
 
122
  # Batch support
 
22
  class EndpointHandler:
23
  """
24
  Hugging Face Inference Endpoints custom handler.
25
+
26
+ Supports both text and chat formats:
27
+
28
+ Text format:
29
+ {"inputs": "Hello, how are you?"}
30
+
31
+ Chat format (recommended):
32
+ {"inputs": [{"role": "user", "content": "Hello!"}]}
33
+ or
34
+ {"inputs": {"messages": [{"role": "user", "content": "Hello!"}]}}
35
+
36
+ Parameters:
37
+ - max_new_tokens (default: 256): Max tokens to generate
38
+ - temperature (default: 0.7): Sampling temperature
39
+ - top_p (default: 0.95): Nucleus sampling
40
+ - repetition_penalty (default: 1.0): Penalize repetitions
41
+ - return_full_text (default: False): If True, return full conversation; if False, only new tokens
42
  """
43
 
44
  def __init__(self, model_dir: str):
 
86
  top_p = float(params.get("top_p", 0.95))
87
  top_k = int(params.get("top_k", 0))
88
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
89
+ return_full_text = bool(params.get("return_full_text", False))
90
 
91
  do_sample = bool(params.get("do_sample", temperature > 0))
92
  num_beams = int(params.get("num_beams", 1))
 
101
 
102
  if _is_messages(item):
103
  # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
104
+ try:
105
+ # Use tokenize=False to get the formatted string first
106
+ prompt = self.tokenizer.apply_chat_template(
107
+ item,
108
+ tokenize=False,
109
+ add_generation_prompt=True,
110
+ )
111
+ # Then tokenize it separately to avoid unpacking issues
112
+ enc = self.tokenizer(prompt, return_tensors="pt")
113
+ input_ids = enc["input_ids"]
114
+ except Exception:
115
+ # Fallback: if chat template fails, use the last user message
116
+ last_user_msg = next((m["content"] for m in reversed(item) if m.get("role") == "user"), "")
117
+ enc = self.tokenizer(last_user_msg, return_tensors="pt")
118
+ input_ids = enc["input_ids"]
119
  else:
120
  if not isinstance(item, str):
121
  item = str(item)
 
138
  eos_token_id=self.tokenizer.eos_token_id,
139
  )
140
 
141
+ # Return newly generated tokens by default, or full text if requested
142
+ if return_full_text:
143
+ text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
144
+ else:
145
+ new_tokens = gen_ids[0, input_len:]
146
+ text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
147
  return {"generated_text": text}
148
 
149
  # Batch support