younissk committed on
Commit
9319bba
·
verified ·
1 Parent(s): 50b28cd

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +65 -32
handler.py CHANGED
@@ -57,28 +57,34 @@ def tools_to_system_text(tools: List[dict], max_tools=12, max_props=6) -> str:
57
  return "\n".join(lines)
58
 
59
  def build_schema_from_tools(tools: List[dict]) -> dict:
60
- # strict per-request schema (name + args schema)
61
  from copy import deepcopy
62
  tool_variants, defs = [], {}
63
- for t in tools:
64
- f = t.get("function", {})
65
- name = f.get("name")
66
- params = f.get("parameters") or {"type":"object","properties":{},"additionalProperties":True}
67
- if not isinstance(name, str):
68
  continue
 
 
 
 
 
 
 
 
 
69
  defs[f"{name}_args"] = deepcopy(params)
70
  tool_variants.append({
71
- "type": "object",
72
- "properties": {"name": {"const": name}, "arguments": {"$ref": f"#/$defs/{name}_args"}},
73
- "required": ["name","arguments"],
74
  "additionalProperties": False
75
  })
 
76
  return {
77
- "$schema": "https://json-schema.org/draft/2020-12/schema",
78
- "oneOf": [
79
- {"type":"object","properties":{"function_call":{"oneOf": tool_variants}},
80
- "required":["function_call"],"additionalProperties": False},
81
- {"type":"object","properties":{"tool_calls":{"type":"array","minItems":1,"items":{"oneOf": tool_variants}}},
82
  "required":["tool_calls"],"additionalProperties": False},
83
  {"type":"object","properties":{"final_answer":{"type":"string","minLength":1}},
84
  "required":["final_answer"],"additionalProperties": False}
@@ -86,6 +92,7 @@ def build_schema_from_tools(tools: List[dict]) -> dict:
86
  "$defs": defs
87
  }
88
 
 
89
  class EndpointHandler:
90
  def __init__(self, path: str = ""):
91
  model_id = path or os.getenv("MODEL_ID", ".")
@@ -166,22 +173,48 @@ class EndpointHandler:
166
 
167
  return messages, tools, temperature, max_new, top_p
168
 
169
- def _encode_messages(self, msgs: List[dict]):
170
- # Try chat template; fallback to a simple role-tagged prompt
171
- try:
172
- return self.tokenizer.apply_chat_template(
173
- msgs, add_generation_prompt=True, return_tensors="pt"
174
- ).to(self.model.device)
175
- except Exception:
176
- lines = []
177
- for m in msgs:
178
- role = m.get("role", "user")
179
- content = m.get("content", "")
180
- lines.append(f"{role}: {content}")
181
- lines.append("assistant:")
182
- prompt_text = "\n".join(lines)
183
- toks = self.tokenizer(prompt_text, return_tensors="pt")
184
- return toks["input_ids"].to(self.model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
187
  messages, tools, temperature, max_new, top_p = self._unpack(data)
@@ -192,12 +225,12 @@ class EndpointHandler:
192
 
193
  # Remember last user text for the guard’s heuristics
194
  user_text = ""
195
- for m in reversed(msgs):
196
  if m.get("role") == "user":
197
  user_text = m.get("content", "")
198
  break
199
 
200
- input_ids = self._encode_messages(msgs)
201
 
202
  gen_kwargs = dict(
203
  input_ids=input_ids,
 
57
  return "\n".join(lines)
58
 
59
  def build_schema_from_tools(tools: List[dict]) -> dict:
 
60
  from copy import deepcopy
61
  tool_variants, defs = [], {}
62
+ for t in tools or []:
63
+ f = t.get("function", t) if isinstance(t, dict) else {}
64
+ name = f.get("name") or f.get("api_call") or f.get("api_name")
65
+ if not isinstance(name, str) or not name:
 
66
  continue
67
+ params = f.get("parameters") or {"type":"object","properties":{},"additionalProperties":True}
68
+ if isinstance(params, list): # allow list form
69
+ props = {}
70
+ for p in params:
71
+ if isinstance(p, dict) and "name" in p:
72
+ nm = p["name"]; pd = {k:v for k,v in p.items() if k!="name"}
73
+ props[nm] = pd
74
+ if props:
75
+ params = {"type":"object","properties":props}
76
  defs[f"{name}_args"] = deepcopy(params)
77
  tool_variants.append({
78
+ "type":"object",
79
+ "properties":{"name":{"const":name},"arguments":{"$ref": f"#/$defs/{name}_args"}},
80
+ "required":["name","arguments"],
81
  "additionalProperties": False
82
  })
83
+
84
  return {
85
+ "$schema":"https://json-schema.org/draft/2020-12/schema",
86
+ "oneOf":[
87
+ {"type":"object","properties":{"tool_calls":{"type":"array","minItems":1,"items":{"oneOf":tool_variants}}},
 
 
88
  "required":["tool_calls"],"additionalProperties": False},
89
  {"type":"object","properties":{"final_answer":{"type":"string","minLength":1}},
90
  "required":["final_answer"],"additionalProperties": False}
 
92
  "$defs": defs
93
  }
94
 
95
+
96
  class EndpointHandler:
97
  def __init__(self, path: str = ""):
98
  model_id = path or os.getenv("MODEL_ID", ".")
 
173
 
174
  return messages, tools, temperature, max_new, top_p
175
 
176
+ SYS_INSTR = (
177
+ "You're a tool-calling assistant. "
178
+ "Return ONLY valid JSON for your answer, with this exact shape:\n"
179
+ "{\"tool_calls\": [{\"name\": \"<function_name>\", \"arguments\": {<key>: <value>, ...}}, ...]}\n"
180
+ "No prose. No explanations. JSON only."
181
+ )
182
+
183
+ def _flat_tool(self, t: dict):
184
+ """Accept both {'function':{...}} and flat {'name':...,'parameters':...}."""
185
+ f = t.get("function", t) if isinstance(t, dict) else {}
186
+ name = f.get("name") or f.get("api_call") or f.get("api_name") or ""
187
+ params = f.get("parameters") or {}
188
+ # Normalize parameter names for signature display
189
+ prop_names = []
190
+ if isinstance(params, dict):
191
+ props = params.get("properties")
192
+ if isinstance(props, dict):
193
+ prop_names = list(props.keys())[:12]
194
+ elif isinstance(props, list):
195
+ prop_names = [p.get("name","") for p in props if isinstance(p,dict)][:12]
196
+ return name, params, prop_names
197
+
198
+ def _render_tools_signature(self, tools: List[dict]) -> str:
199
+ lines = []
200
+ for t in tools[:12]:
201
+ name, _, pnames = self._flat_tool(t)
202
+ if not name:
203
+ continue
204
+ lines.append(f"- {name}({', '.join(pnames)})" if pnames else f"- {name}()")
205
+ return "\n".join(lines) if lines else "- (tools omitted)"
206
+
207
+ def _encode_messages(self, user_text: str, tools: List[dict]):
208
+ # Build the exact same prompt you used for training
209
+ sig = self._render_tools_signature(tools)
210
+ prompt = (
211
+ "<|system|>\n" + SYS_INSTR + "\n\n"
212
+ "<|tools|>\n" + sig + "\n\n"
213
+ "<|user|>\n" + user_text + "\n\n"
214
+ "<|assistant|>\n"
215
+ )
216
+ toks = self.tokenizer(prompt, return_tensors="pt")
217
+ return toks["input_ids"].to(self.model.device)
218
 
219
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
220
  messages, tools, temperature, max_new, top_p = self._unpack(data)
 
225
 
226
  # Remember last user text for the guard’s heuristics
227
  user_text = ""
228
+ for m in reversed(messages):
229
  if m.get("role") == "user":
230
  user_text = m.get("content", "")
231
  break
232
 
233
+ input_ids = self._encode_messages(user_text, tools)
234
 
235
  gen_kwargs = dict(
236
  input_ids=input_ids,