# handler.py — Falcon H1 7B tool-calling for Hugging Face Inference Endpoints
# This handler builds the exact prompt format used in training and returns:
# {
#   "generated_text": "<raw model output>",
#   "envelope": { "tool_calls": [...]} | {"function_call": {...}} | {"final_answer": "..."}
# }

from typing import Dict, Any, List, Tuple
from copy import deepcopy
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from jsonschema import Draft202012Validator

# ---------- Prompt instruction (match training) ----------
SYS_INSTR = (
    "You're a tool-calling assistant. "
    "Return ONLY valid JSON for your answer, with this exact shape:\n"
    "{\"tool_calls\": [{\"name\": \"<function_name>\", \"arguments\": {<key>: <value>, ...}}, ...]}\n"
    "No prose. No explanations. JSON only."
)
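
# Illustrative only: a reply that satisfies SYS_INSTR, for a hypothetical
# "search" tool, is a single JSON object such as
#   {"tool_calls": [{"name": "search", "arguments": {"query": "falcon h1"}}]}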

# ---------- Schema builder (accepts both tool shapes) ----------
def build_schema_from_tools(tools: List[dict]) -> dict:
    """
    Build a strict JSON Schema that allows either:
      - { "tool_calls": [ { "name": <tool>, "arguments": <schema-per-tool> }, ... ] }
      - { "function_call": { "name": <tool>, "arguments": <schema-per-tool> } }
      - { "final_answer": <string> }
    Tools can be provided either as:
      {"function": {"name": "...", "parameters": {...}}}  OR  {"name": "...", "parameters": {...}}
    """
    tool_variants, defs = [], {}
    for t in tools or []:
        f = t.get("function", t) if isinstance(t, dict) else {}
        name = f.get("name") or f.get("api_call") or f.get("api_name")
        if not isinstance(name, str) or not name:
            continue

        params = f.get("parameters") or {"type": "object", "properties": {}, "additionalProperties": True}
        # Normalize list-of-params to object.properties form
        if isinstance(params, list):
            props = {}
            for p in params:
                if isinstance(p, dict) and "name" in p:
                    nm = p["name"]
                    pd = {k: v for k, v in p.items() if k != "name"}
                    props[nm] = pd
            if props:
                params = {"type": "object", "properties": props}

        defs[f"{name}_args"] = deepcopy(params)
        tool_variants.append({
            "type": "object",
            "properties": {
                "name": {"const": name},
                "arguments": {"$ref": f"#/$defs/{name}_args"}
            },
            "required": ["name", "arguments"],
            "additionalProperties": False
        })

    return {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "oneOf": [
            {
                "type": "object",
                "properties": {
                    "tool_calls": {
                        "type": "array",
                        "minItems": 1,
                        "items": {"oneOf": tool_variants}
                    }
                },
                "required": ["tool_calls"],
                "additionalProperties": False
            },
            {
                "type": "object",
                "properties": {
                    "function_call": {"oneOf": tool_variants}
                },
                "required": ["function_call"],
                "additionalProperties": False
            },
            {
                "type": "object",
                "properties": {
                    "final_answer": {"type": "string", "minLength": 1}
                },
                "required": ["final_answer"],
                "additionalProperties": False
            }
        ],
        "$defs": defs
    }
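
# A hedged sketch of how the schema behaves for one hypothetical tool; the
# tool name and parameters below are examples, not part of the handler:
#
#   tools = [{"function": {"name": "get_weather",
#                          "parameters": {"type": "object",
#                                         "properties": {"city": {"type": "string"}},
#                                         "required": ["city"]}}}]
#   schema = build_schema_from_tools(tools)
#   ok = {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}
#   assert not list(Draft202012Validator(schema).iter_errors(ok))
#   bad = {"tool_calls": [{"name": "unknown_tool", "arguments": {}}]}
#   assert list(Draft202012Validator(schema).iter_errors(bad))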

# ---------- Main handler ----------
class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        If the repo contains a merged model, MODEL_ID should point to it.
        If you use adapter-only repos, modify __init__ to load base + adapter.
        """
        model_id = path or os.getenv("MODEL_ID", ".")

        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Match training (we trained with right padding)
        self.tokenizer.padding_side = "right"

        # Choose dtype
        if torch.cuda.is_available():
            try:
                dtype = torch.bfloat16 if torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
            except Exception:
                dtype = torch.float16
        else:
            dtype = torch.float32

        # Model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map="auto"
        )
        self.model.eval()

        # Keep special tokens consistent
        for obj in (self.model.config, self.model.generation_config):
            obj.pad_token_id = self.tokenizer.pad_token_id
            obj.eos_token_id = self.tokenizer.eos_token_id
            obj.bos_token_id = self.tokenizer.bos_token_id

    # ---- tools signature (exact format used in training) ----
    def _flat_tool(self, t: dict) -> Tuple[str, dict, List[str]]:
        f = t.get("function", t) if isinstance(t, dict) else {}
        name = f.get("name") or f.get("api_call") or f.get("api_name") or ""
        params = f.get("parameters") or {}
        prop_names: List[str] = []
        if isinstance(params, dict):
            props = params.get("properties")
            if isinstance(props, dict):
                prop_names = list(props.keys())[:12]
            elif isinstance(props, list):
                prop_names = [p.get("name", "") for p in props if isinstance(p, dict)][:12]
        return name, params, prop_names

    def _render_tools_signature(self, tools: List[dict]) -> str:
        lines = []
        for t in tools[:12]:
            name, _, pnames = self._flat_tool(t)
            if not name:
                continue
            lines.append(f"- {name}({', '.join(pnames)})" if pnames else f"- {name}()")
        return "\n".join(lines) if lines else "- (tools omitted)"

    def _encode_messages(self, user_text: str, tools: List[dict]):
        sig = self._render_tools_signature(tools)
        prompt = (
            "<|system|>\n" + SYS_INSTR + "\n\n"
            "<|tools|>\n" + sig + "\n\n"
            "<|user|>\n" + user_text + "\n\n"
            "<|assistant|>\n"
        )
        toks = self.tokenizer(prompt, return_tensors="pt")
        return toks["input_ids"].to(self.model.device)
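
    # Assembled prompt for the hypothetical get_weather example (illustrative,
    # SYS_INSTR abbreviated):
    #   <|system|>
    #   You're a tool-calling assistant. Return ONLY valid JSON ...
    #
    #   <|tools|>
    #   - get_weather(city)
    #
    #   <|user|>
    #   What's the weather in Paris?
    #
    #   <|assistant|>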

    # ---- request parsing / params ----
    def _unpack(self, data: Dict[str, Any]):
        """
        Accept both:
          - {"inputs": {"messages": [...], "tools": [...], "parameters": {...}}}
          - {"messages": [...], "tools": [...], "parameters": {...}}
          - {"text": "..."} as a minimal fallback
        """
        body = data.get("inputs", data)
        params = data.get("parameters") or (body.get("parameters") if isinstance(body, dict) else {}) or {}

        messages = None
        tools = None
        if isinstance(body, dict):
            messages = body.get("messages")
            tools = body.get("tools") or body.get("functions")
        if messages is None:
            messages = data.get("messages")
        if tools is None:
            tools = data.get("tools") or data.get("functions") or []

        if not messages:
            raw = body if isinstance(body, str) else data.get("text", "")
            messages = [{"role": "user", "content": str(raw)}]

        temperature = float(params.get("temperature", data.get("temperature", 0.0)))
        max_new = int(params.get("max_new_tokens", data.get("max_new_tokens", 192)))
        top_p = float(params.get("top_p", data.get("top_p", 1.0)))

        # last user message text (treat explicit null content as empty)
        user_text = ""
        for m in reversed(messages):
            if m.get("role") == "user":
                user_text = m.get("content") or ""
                break

        return user_text, tools, temperature, max_new, top_p

    # ---- best-effort validation (no canonicalization) ----
    def _apply_guard(self, tools: List[dict], raw_text: str):
        try:
            obj = json.loads(raw_text)
        except Exception:
            # Model did not emit JSON → wrap as final answer
            return {"final_answer": raw_text.strip()}

        # Validate against per-request schema (non-blocking)
        schema = build_schema_from_tools(tools)
        _ = [e.message for e in Draft202012Validator(schema).iter_errors(obj)]
        # We return the object regardless of validation outcome (best effort).
        return obj
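
    # Illustrative guard behavior (values are examples):
    #   _apply_guard(tools, '{"final_answer": "hi"}')  -> {"final_answer": "hi"}
    #   _apply_guard(tools, 'plain prose, not JSON')   -> {"final_answer": "plain prose, not JSON"}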

    # ---- entrypoint ----
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        user_text, tools, temperature, max_new, top_p = self._unpack(data)
        input_ids = self._encode_messages(user_text, tools)

        gen_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=max_new,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            gen_kwargs.update(do_sample=False)

        with torch.inference_mode():
            out = self.model.generate(**gen_kwargs)

        raw = self.tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        envelope = self._apply_guard(tools, raw)
        return {"generated_text": raw, "envelope": envelope}