|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Any, List, Tuple |
|
|
import os |
|
|
import json |
|
|
import re |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from jsonschema import Draft202012Validator |
|
|
|
|
|
|
|
|
SYS_INSTR = ( |
|
|
"You're a tool-calling assistant. " |
|
|
"Return ONLY valid JSON for your answer, with this exact shape:\n" |
|
|
"{\"tool_calls\": [{\"name\": \"<function_name>\", \"arguments\": {<key>: <value>, ...}}, ...]}\n" |
|
|
"No prose. No explanations. JSON only." |
|
|
) |
|
|
|
|
|
|
|
|
def build_schema_from_tools(tools: List[dict]) -> dict:
    """Build a strict JSON Schema (draft 2020-12) for tool-calling output.

    The schema accepts exactly one of:
      - {"tool_calls": [{"name": <tool>, "arguments": <per-tool schema>}, ...]}
      - {"function_call": {"name": <tool>, "arguments": <per-tool schema>}}
      - {"final_answer": <non-empty string>}

    Tools may be provided either as
    {"function": {"name": "...", "parameters": {...}}} or flat as
    {"name": "...", "parameters": {...}}.  "parameters" may be a JSON-Schema
    object or a list of {"name": ..., <schema keys>} parameter descriptors.

    Fixes vs. the previous version:
      - With no usable tools, the "tool_calls"/"function_call" branches used
        to contain "oneOf": [] — a subschema that can never validate and is
        invalid per the 2020-12 meta-schema. Those branches are now omitted,
        leaving only the "final_answer" shape.
      - A list-shaped "parameters" with no named entries used to be stored
        verbatim in $defs (a list is not a schema); it now falls back to a
        permissive object schema.
    """
    from copy import deepcopy

    tool_variants: List[dict] = []
    defs: Dict[str, Any] = {}
    for t in tools or []:
        f = t.get("function", t) if isinstance(t, dict) else {}
        name = f.get("name") or f.get("api_call") or f.get("api_name")
        if not isinstance(name, str) or not name:
            continue  # skip tools without a usable name

        params = f.get("parameters") or {"type": "object", "properties": {}, "additionalProperties": True}

        # Normalize list-style parameter descriptors into an object schema.
        if isinstance(params, list):
            props = {}
            for p in params:
                if isinstance(p, dict) and "name" in p:
                    props[p["name"]] = {k: v for k, v in p.items() if k != "name"}
            if props:
                params = {"type": "object", "properties": props}
            else:
                # A bare list is not a valid schema; accept any arguments.
                params = {"type": "object", "properties": {}, "additionalProperties": True}

        defs[f"{name}_args"] = deepcopy(params)
        tool_variants.append({
            "type": "object",
            "properties": {
                "name": {"const": name},
                "arguments": {"$ref": f"#/$defs/{name}_args"},
            },
            "required": ["name", "arguments"],
            "additionalProperties": False,
        })

    variants: List[dict] = []
    if tool_variants:  # only offer tool-call shapes when tools exist
        variants.append({
            "type": "object",
            "properties": {
                "tool_calls": {
                    "type": "array",
                    "minItems": 1,
                    "items": {"oneOf": tool_variants},
                }
            },
            "required": ["tool_calls"],
            "additionalProperties": False,
        })
        variants.append({
            "type": "object",
            "properties": {
                "function_call": {"oneOf": tool_variants}
            },
            "required": ["function_call"],
            "additionalProperties": False,
        })
    variants.append({
        "type": "object",
        "properties": {
            "final_answer": {"type": "string", "minLength": 1}
        },
        "required": ["final_answer"],
        "additionalProperties": False,
    })

    return {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "oneOf": variants,
        "$defs": defs,
    }
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoint handler for a tool-calling causal LM.

    Request flow: __call__ -> _unpack (parse request body) ->
    _encode_messages (build prompt + tokenize) -> model.generate ->
    _apply_guard (JSON-schema validation of the decoded output).
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model.

        If the repo contains a merged model, MODEL_ID should point to it.
        If you use adapter-only repos, modify __init__ to load base + adapter.

        Args:
            path: model directory / repo id; falls back to the MODEL_ID env
                var, then the current directory.
        """
        model_id = path or os.getenv("MODEL_ID", ".")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        if self.tokenizer.pad_token is None:
            # Reuse EOS as PAD for models shipped without a pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # Prefer bf16 on compute capability >= 8 (Ampere+), fp16 on older
        # GPUs, fp32 on CPU.
        if torch.cuda.is_available():
            try:
                major = torch.cuda.get_device_capability(0)[0]
                dtype = torch.bfloat16 if major >= 8 else torch.float16
            except Exception:
                dtype = torch.float16
        else:
            dtype = torch.float32

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map="auto"
        )
        self.model.eval()

        # Keep config and generation_config in sync with the tokenizer's
        # special tokens (pad may have just been aliased to eos above).
        for cfg in (self.model.config, self.model.generation_config):
            cfg.pad_token_id = self.tokenizer.pad_token_id
            cfg.eos_token_id = self.tokenizer.eos_token_id
            cfg.bos_token_id = self.tokenizer.bos_token_id

    def _flat_tool(self, t: dict) -> Tuple[str, dict, List[str]]:
        """Flatten one tool spec to (name, parameters, property names).

        Accepts {"function": {...}} wrappers or flat dicts; the name may live
        under "name", "api_call", or "api_name". Returns "" for the name when
        none is found. Property names are capped at 12 to bound prompt size.
        """
        f = t.get("function", t) if isinstance(t, dict) else {}
        name = f.get("name") or f.get("api_call") or f.get("api_name") or ""
        params = f.get("parameters") or {}
        prop_names: List[str] = []
        if isinstance(params, dict):
            props = params.get("properties")
            if isinstance(props, dict):
                prop_names = list(props.keys())[:12]
            elif isinstance(props, list):
                # List-style parameter descriptors: pull each entry's "name".
                prop_names = [p.get("name", "") for p in props if isinstance(p, dict)][:12]
        return name, params, prop_names

    def _render_tools_signature(self, tools: List[dict]) -> str:
        """Render up to 12 tools as one "- name(arg1, arg2)" line each."""
        lines = []
        for t in tools[:12]:
            name, _, pnames = self._flat_tool(t)
            if not name:
                continue
            lines.append(f"- {name}({', '.join(pnames)})" if pnames else f"- {name}()")
        return "\n".join(lines) if lines else "- (tools omitted)"

    def _encode_messages(self, user_text: str, tools: List[dict]):
        """Build the chat-style prompt and tokenize it.

        Returns (input_ids, attention_mask) tensors on the model device.
        Bug fix: the attention mask is now returned and forwarded to
        generate(); because pad == eos it cannot be inferred from the ids.
        """
        sig = self._render_tools_signature(tools)
        prompt = (
            "<|system|>\n" + SYS_INSTR + "\n\n"
            "<|tools|>\n" + sig + "\n\n"
            "<|user|>\n" + user_text + "\n\n"
            "<|assistant|>\n"
        )
        toks = self.tokenizer(prompt, return_tensors="pt")
        return (
            toks["input_ids"].to(self.model.device),
            toks["attention_mask"].to(self.model.device),
        )

    def _unpack(self, data: Dict[str, Any]):
        """Parse a request body into (user_text, tools, temperature, max_new, top_p).

        Accepts any of:
          - {"inputs": {"messages": [...], "tools": [...], "parameters": {...}}}
          - {"messages": [...], "tools": [...], "parameters": {...}}
          - {"text": "..."} (or a bare string under "inputs") as fallback
        """
        body = data.get("inputs", data)
        params = data.get("parameters") or (body.get("parameters") if isinstance(body, dict) else {}) or {}

        messages = None
        tools = None
        if isinstance(body, dict):
            messages = body.get("messages")
            tools = body.get("tools") or body.get("functions")
        if messages is None:
            messages = data.get("messages")
        if tools is None:
            tools = data.get("tools") or data.get("functions") or []

        if not messages:
            # Minimal fallback: treat the raw payload as a single user turn.
            raw = body if isinstance(body, str) else data.get("text", "")
            messages = [{"role": "user", "content": str(raw)}]

        # Generation knobs may live in "parameters" or at the top level.
        temperature = float(params.get("temperature", data.get("temperature", 0.0)))
        max_new = int(params.get("max_new_tokens", data.get("max_new_tokens", 192)))
        top_p = float(params.get("top_p", data.get("top_p", 1.0)))

        # Prompt with the most recent user turn; skip malformed entries
        # (bug fix: a non-dict message used to raise AttributeError).
        user_text = ""
        for m in reversed(messages):
            if isinstance(m, dict) and m.get("role") == "user":
                user_text = m.get("content", "")
                break

        return user_text, tools, temperature, max_new, top_p

    def _apply_guard(self, tools: List[dict], raw_text: str):
        """Validate raw model output against the tool schema.

        Returns the parsed object when it both parses as JSON and conforms
        to the schema built from `tools`; otherwise wraps the raw text as a
        plain {"final_answer": ...} envelope so callers always receive a
        well-formed shape.

        Bug fix: validation errors used to be computed and then discarded
        (`_ = [...]`), so schema-invalid JSON passed straight through; the
        guard is now enforced.
        """
        try:
            obj = json.loads(raw_text)
        except Exception:
            # Not JSON at all -> treat the whole generation as prose.
            return {"final_answer": raw_text.strip()}

        schema = build_schema_from_tools(tools)
        errors = [e.message for e in Draft202012Validator(schema).iter_errors(obj)]
        if errors:
            return {"final_answer": raw_text.strip()}
        return obj

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request; returns {"generated_text": str, "envelope": dict}."""
        user_text, tools, temperature, max_new, top_p = self._unpack(data)
        input_ids, attention_mask = self._encode_messages(user_text, tools)

        gen_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,  # explicit: pad == eos, not inferable
            max_new_tokens=max_new,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            gen_kwargs.update(do_sample=False)  # greedy decoding

        with torch.inference_mode():
            out = self.model.generate(**gen_kwargs)

        # Strip the prompt tokens; decode only the newly generated tail.
        raw = self.tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        envelope = self._apply_guard(tools, raw)
        return {"generated_text": raw, "envelope": envelope}
|
|
|