Update handler.py
Browse files — handler.py (+65, −32)
handler.py
CHANGED
|
@@ -57,28 +57,34 @@ def tools_to_system_text(tools: List[dict], max_tools=12, max_props=6) -> str:
|
|
| 57 |
return "\n".join(lines)
|
| 58 |
|
| 59 |
def build_schema_from_tools(tools: List[dict]) -> dict:
|
| 60 |
-
# strict per-request schema (name + args schema)
|
| 61 |
from copy import deepcopy
|
| 62 |
tool_variants, defs = [], {}
|
| 63 |
-
for t in tools:
|
| 64 |
-
f = t.get("function", {}
|
| 65 |
-
name = f.get("name")
|
| 66 |
-
|
| 67 |
-
if not isinstance(name, str):
|
| 68 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
defs[f"{name}_args"] = deepcopy(params)
|
| 70 |
tool_variants.append({
|
| 71 |
-
"type":
|
| 72 |
-
"properties":
|
| 73 |
-
"required":
|
| 74 |
"additionalProperties": False
|
| 75 |
})
|
|
|
|
| 76 |
return {
|
| 77 |
-
"$schema":
|
| 78 |
-
"oneOf":
|
| 79 |
-
{"type":"object","properties":{"
|
| 80 |
-
"required":["function_call"],"additionalProperties": False},
|
| 81 |
-
{"type":"object","properties":{"tool_calls":{"type":"array","minItems":1,"items":{"oneOf": tool_variants}}},
|
| 82 |
"required":["tool_calls"],"additionalProperties": False},
|
| 83 |
{"type":"object","properties":{"final_answer":{"type":"string","minLength":1}},
|
| 84 |
"required":["final_answer"],"additionalProperties": False}
|
|
@@ -86,6 +92,7 @@ def build_schema_from_tools(tools: List[dict]) -> dict:
|
|
| 86 |
"$defs": defs
|
| 87 |
}
|
| 88 |
|
|
|
|
| 89 |
class EndpointHandler:
|
| 90 |
def __init__(self, path: str = ""):
|
| 91 |
model_id = path or os.getenv("MODEL_ID", ".")
|
|
@@ -166,22 +173,48 @@ class EndpointHandler:
|
|
| 166 |
|
| 167 |
return messages, tools, temperature, max_new, top_p
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 187 |
messages, tools, temperature, max_new, top_p = self._unpack(data)
|
|
@@ -192,12 +225,12 @@ class EndpointHandler:
|
|
| 192 |
|
| 193 |
# Remember last user text for the guard’s heuristics
|
| 194 |
user_text = ""
|
| 195 |
-
for m in reversed(
|
| 196 |
if m.get("role") == "user":
|
| 197 |
user_text = m.get("content", "")
|
| 198 |
break
|
| 199 |
|
| 200 |
-
input_ids = self._encode_messages(
|
| 201 |
|
| 202 |
gen_kwargs = dict(
|
| 203 |
input_ids=input_ids,
|
|
|
|
| 57 |
return "\n".join(lines)
|
| 58 |
|
| 59 |
def build_schema_from_tools(tools: List[dict]) -> dict:
    """Build a strict per-request JSON Schema constraining model output.

    Accepts both OpenAI-style ``{"function": {...}}`` wrappers and flat
    ``{"name": ..., "parameters": ...}`` tool dicts, plus a list-form
    ``parameters`` (a list of ``{"name": ..., ...}`` property dicts).

    Args:
        tools: Tool specs (may be None or empty).

    Returns:
        A draft 2020-12 schema whose ``oneOf`` admits either a
        ``tool_calls`` array (one variant per tool, arguments validated via
        ``$defs``) or a non-empty ``final_answer`` string.
    """
    from copy import deepcopy

    # Fallback argument schema when a tool declares no usable parameters.
    permissive = {"type": "object", "properties": {}, "additionalProperties": True}

    tool_variants, defs = [], {}
    for t in tools or []:
        # Accept both {'function': {...}} and flat tool dicts.
        f = t.get("function", t) if isinstance(t, dict) else {}
        name = f.get("name") or f.get("api_call") or f.get("api_name")
        if not isinstance(name, str) or not name:
            continue
        params = f.get("parameters") or permissive
        if isinstance(params, list):  # allow list form
            props = {}
            for p in params:
                if isinstance(p, dict) and "name" in p:
                    nm = p["name"]
                    props[nm] = {k: v for k, v in p.items() if k != "name"}
            # FIX: previously a list with no usable entries leaked the raw
            # list into $defs, making the $ref target a non-schema value.
            params = {"type": "object", "properties": props} if props else permissive
        # deepcopy so later mutation of the caller's tool spec cannot
        # silently alter the schema we hand to the decoder.
        defs[f"{name}_args"] = deepcopy(params)
        tool_variants.append({
            "type": "object",
            "properties": {
                "name": {"const": name},
                "arguments": {"$ref": f"#/$defs/{name}_args"},
            },
            "required": ["name", "arguments"],
            "additionalProperties": False,
        })

    variants = []
    if tool_variants:
        # FIX: only emit the tool_calls variant when at least one tool
        # parsed — an empty "oneOf" array is invalid JSON Schema
        # (draft 2020-12 requires a non-empty array for oneOf).
        variants.append({
            "type": "object",
            "properties": {
                "tool_calls": {
                    "type": "array",
                    "minItems": 1,
                    "items": {"oneOf": tool_variants},
                }
            },
            "required": ["tool_calls"],
            "additionalProperties": False,
        })
    variants.append({
        "type": "object",
        "properties": {"final_answer": {"type": "string", "minLength": 1}},
        "required": ["final_answer"],
        "additionalProperties": False,
    })
    return {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "oneOf": variants,
        "$defs": defs,
    }
|
| 94 |
|
| 95 |
+
|
| 96 |
class EndpointHandler:
|
| 97 |
def __init__(self, path: str = ""):
|
| 98 |
model_id = path or os.getenv("MODEL_ID", ".")
|
|
|
|
| 173 |
|
| 174 |
return messages, tools, temperature, max_new, top_p
|
| 175 |
|
| 176 |
+
# System instruction prepended to every prompt. In the full file this is a
# class attribute of EndpointHandler (it sits between its methods), so
# methods must access it as self.SYS_INSTR — a bare SYS_INSTR reference
# inside a method would raise NameError (class-body names are not in
# method scope). Demands strict JSON-only tool-call output from the model.
SYS_INSTR = (
    "You're a tool-calling assistant. "
    "Return ONLY valid JSON for your answer, with this exact shape:\n"
    "{\"tool_calls\": [{\"name\": \"<function_name>\", \"arguments\": {<key>: <value>, ...}}, ...]}\n"
    "No prose. No explanations. JSON only."
)
|
| 182 |
+
|
| 183 |
+
def _flat_tool(self, t: dict):
|
| 184 |
+
"""Accept both {'function':{...}} and flat {'name':...,'parameters':...}."""
|
| 185 |
+
f = t.get("function", t) if isinstance(t, dict) else {}
|
| 186 |
+
name = f.get("name") or f.get("api_call") or f.get("api_name") or ""
|
| 187 |
+
params = f.get("parameters") or {}
|
| 188 |
+
# Normalize parameter names for signature display
|
| 189 |
+
prop_names = []
|
| 190 |
+
if isinstance(params, dict):
|
| 191 |
+
props = params.get("properties")
|
| 192 |
+
if isinstance(props, dict):
|
| 193 |
+
prop_names = list(props.keys())[:12]
|
| 194 |
+
elif isinstance(props, list):
|
| 195 |
+
prop_names = [p.get("name","") for p in props if isinstance(p,dict)][:12]
|
| 196 |
+
return name, params, prop_names
|
| 197 |
+
|
| 198 |
+
def _render_tools_signature(self, tools: List[dict]) -> str:
    """Render one "- name(arg1, arg2)" line per tool, capped at 12 tools."""
    rendered = []
    for tool in tools[:12]:
        tool_name, _, arg_names = self._flat_tool(tool)
        if not tool_name:
            # Unnamed/unparseable specs contribute nothing to the prompt.
            continue
        rendered.append(f"- {tool_name}({', '.join(arg_names)})")
    return "\n".join(rendered) if rendered else "- (tools omitted)"
|
| 206 |
+
|
| 207 |
+
def _encode_messages(self, user_text: str, tools: List[dict]):
    """Tokenize the prompt in the exact format used during training.

    Args:
        user_text: The most recent user message content.
        tools: Tool specs rendered into the <|tools|> section.

    Returns:
        input_ids tensor moved to the model's device.
    """
    # Build the exact same prompt you used for training
    sig = self._render_tools_signature(tools)
    prompt = (
        # FIX: SYS_INSTR is a class attribute; a bare SYS_INSTR reference
        # raises NameError at call time (class-body names are not visible
        # from method scope) — it must be looked up via self.
        "<|system|>\n" + self.SYS_INSTR + "\n\n"
        "<|tools|>\n" + sig + "\n\n"
        "<|user|>\n" + user_text + "\n\n"
        "<|assistant|>\n"
    )
    toks = self.tokenizer(prompt, return_tensors="pt")
    return toks["input_ids"].to(self.model.device)
|
| 218 |
|
| 219 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 220 |
messages, tools, temperature, max_new, top_p = self._unpack(data)
|
|
|
|
| 225 |
|
| 226 |
# Remember last user text for the guard’s heuristics
|
| 227 |
user_text = ""
|
| 228 |
+
for m in reversed(messages):
|
| 229 |
if m.get("role") == "user":
|
| 230 |
user_text = m.get("content", "")
|
| 231 |
break
|
| 232 |
|
| 233 |
+
input_ids = self._encode_messages(user_text, tools)
|
| 234 |
|
| 235 |
gen_kwargs = dict(
|
| 236 |
input_ids=input_ids,
|