younissk committed
Commit 3e6e18c · verified · 1 Parent(s): 630b0c1

Update handler.py

Files changed (1):
  1. handler.py +70 -31
handler.py CHANGED
@@ -135,48 +135,87 @@ class EndpointHandler:
         # best-effort: return canonicalized even if schema still complains
         return obj
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        # Accept either {"messages":[...], "tools":[...]} or {"inputs": "..."}
-        messages = data.get("messages")
-        tools = data.get("tools") or data.get("functions") or []
-        temperature = float(data.get("temperature", 0.0))
-        max_new = int(data.get("max_new_tokens", 192))
-
+    def _unpack(self, data: Dict[str, Any]):
+        """Normalize payload coming from IE:
+        - accept top-level or inputs-nested messages/tools
+        - accept parameters both top-level and nested
+        """
+        body = data.get("inputs", data)  # if no "inputs", body == data
+        params = data.get("parameters") or {}
+        # pull messages/tools from body if dict
+        messages = None
+        tools = None
+        if isinstance(body, dict):
+            messages = body.get("messages")
+            tools = body.get("tools") or body.get("functions")
+        # allow top-level fallbacks
+        if messages is None:
+            messages = data.get("messages")
+        if tools is None:
+            tools = data.get("tools") or data.get("functions") or []
+
+        # if still no messages, treat body as raw text
         if not messages:
-            text = data.get("inputs") or data.get("text") or ""
-            messages = [{"role": "user", "content": text}]
+            raw = body if isinstance(body, str) else data.get("text", "")
+            messages = [{"role": "user", "content": str(raw)}]
+
+        # generation params (support both locations)
+        temperature = float(params.get("temperature", data.get("temperature", 0.0)))
+        max_new = int(params.get("max_new_tokens", data.get("max_new_tokens", 192)))
+        top_p = float(params.get("top_p", data.get("top_p", 1.0)))
+
+        return messages, tools, temperature, max_new, top_p
+
+    def _encode_messages(self, msgs: List[dict]):
+        # Try chat template; fallback to a simple role-tagged prompt
+        try:
+            return self.tokenizer.apply_chat_template(
+                msgs, add_generation_prompt=True, return_tensors="pt"
+            ).to(self.model.device)
+        except Exception:
+            lines = []
+            for m in msgs:
+                role = m.get("role", "user")
+                content = m.get("content", "")
+                lines.append(f"{role}: {content}")
+            lines.append("assistant:")
+            prompt_text = "\n".join(lines)
+            toks = self.tokenizer(prompt_text, return_tensors="pt")
+            return toks["input_ids"].to(self.model.device)
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        messages, tools, temperature, max_new, top_p = self._unpack(data)
 
-        # Build a fresh system message from tools; prepend to any user/system provided
-        sys_text = tools_to_system_text(tools) if isinstance(tools, list) and tools else None
-        msgs = []
-        if sys_text:
-            msgs.append({"role": "system", "content": sys_text})
-        msgs.extend(messages)
+        # Build a system message from tools; prepend to conversation
+        sys_text = tools_to_system_text(tools) if tools else None
+        msgs = [{"role": "system", "content": sys_text}] + messages if sys_text else messages
 
-        # The last user content is used by the guard to infer heuristics (e.g., "top 5")
+        # Remember last user text for the guard's heuristics
         user_text = ""
-        for m in msgs:
+        for m in reversed(msgs):
             if m.get("role") == "user":
                 user_text = m.get("content", "")
+                break
 
-        prompt = self.tokenizer.apply_chat_template(
-            msgs, add_generation_prompt=True, return_tensors="pt"
-        ).to(self.model.device)
+        input_ids = self._encode_messages(msgs)
+
+        gen_kwargs = dict(
+            input_ids=input_ids,
+            max_new_tokens=max_new,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )
+        if temperature > 0:
+            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
+        else:
+            gen_kwargs.update(do_sample=False)
 
         with torch.inference_mode():
-            out = self.model.generate(
-                input_ids=prompt,
-                max_new_tokens=max_new,
-                do_sample=temperature > 0,
-                temperature=temperature if temperature > 0 else None,
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-        raw = self.tokenizer.decode(out[0][prompt.shape[-1]:], skip_special_tokens=True).strip()
+            out = self.model.generate(**gen_kwargs)
 
+        raw = self.tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
         guarded = self._apply_guard(user_text, tools, raw)
 
-        # Return both for convenience
         return {
-            "generated_text": raw,  # string (for quick cURL)
-            "envelope": guarded     # dict with {"tool_calls":[...]} | {"function_call":...} | {"final_answer":...}
+            "generated_text": raw,
+            "envelope": guarded
        }
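
With `_unpack` in place, the handler accepts both a flat payload (top-level `messages`/`tools`/generation params) and the Inference-Endpoints-style nested payload (`inputs` + `parameters`). A minimal sketch of both call shapes, assuming the handler is instantiated locally and that `EndpointHandler` takes a model path (the path and prompts below are placeholders, not part of the commit):

    # Sketch: exercise both payload shapes accepted by _unpack.
    handler = EndpointHandler("path/to/model")  # hypothetical local path

    # Shape 1: flat payload, everything top-level
    flat = {
        "messages": [{"role": "user", "content": "List the top 5 moons of Jupiter"}],
        "tools": [],
        "temperature": 0.0,
        "max_new_tokens": 192,
    }

    # Shape 2: nested payload as sent by Inference Endpoints
    nested = {
        "inputs": {"messages": [{"role": "user", "content": "List the top 5 moons of Jupiter"}]},
        "parameters": {"temperature": 0.7, "top_p": 0.9, "max_new_tokens": 128},
    }

    for payload in (flat, nested):
        result = handler(payload)
        print(result["generated_text"])  # raw model text
        print(result["envelope"])        # guarded dict: tool_calls / function_call / final_answer

Either way, `_unpack` returns the same `(messages, tools, temperature, max_new, top_p)` tuple, so `__call__` no longer needs to care which shape the caller used.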