# handler.py — Falcon H1 7B tool-calling for Hugging Face Inference Endpoints
# This handler builds the exact prompt format used in training and returns:
# {
# "generated_text": "<raw model output>",
# "envelope": { "tool_calls": [...]} | {"function_call": {...}} | {"final_answer": "..."}
# }
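#
# Example request (illustrative; "get_weather" is a hypothetical tool, not part of this repo):
# {
#   "inputs": {
#     "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
#     "tools": [{"function": {"name": "get_weather",
#                             "parameters": {"type": "object",
#                                            "properties": {"city": {"type": "string"}}}}}],
#     "parameters": {"max_new_tokens": 128, "temperature": 0.0}
#   }
# }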
from typing import Dict, Any, List, Tuple
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from jsonschema import Draft202012Validator
# ---------- Prompt instruction (match training) ----------
SYS_INSTR = (
"You're a tool-calling assistant. "
"Return ONLY valid JSON for your answer, with this exact shape:\n"
"{\"tool_calls\": [{\"name\": \"<function_name>\", \"arguments\": {<key>: <value>, ...}}, ...]}\n"
"No prose. No explanations. JSON only."
)
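# A reply that satisfies SYS_INSTR looks like this (tool name is illustrative only):
# {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}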
# ---------- Schema builder (accepts both tool shapes) ----------
def build_schema_from_tools(tools: List[dict]) -> dict:
"""
Build a strict JSON Schema that allows either:
- { "tool_calls": [ { "name": <tool>, "arguments": <schema-per-tool> }, ... ] }
- { "function_call": { "name": <tool>, "arguments": <schema-per-tool> } }
- { "final_answer": <string> }
Tools can be provided either as:
{"function": {"name": "...", "parameters": {...}}} OR {"name": "...", "parameters": {...}}
"""
from copy import deepcopy
tool_variants, defs = [], {}
for t in tools or []:
f = t.get("function", t) if isinstance(t, dict) else {}
name = f.get("name") or f.get("api_call") or f.get("api_name")
if not isinstance(name, str) or not name:
continue
params = f.get("parameters") or {"type": "object", "properties": {}, "additionalProperties": True}
# Normalize list-of-params to object.properties form
if isinstance(params, list):
props = {}
for p in params:
if isinstance(p, dict) and "name" in p:
nm = p["name"]
pd = {k: v for k, v in p.items() if k != "name"}
props[nm] = pd
            if props:
                params = {"type": "object", "properties": props}
            else:
                # Fallback: keep a dict schema so the $ref below never points at a list
                params = {"type": "object", "properties": {}, "additionalProperties": True}
defs[f"{name}_args"] = deepcopy(params)
tool_variants.append({
"type": "object",
"properties": {
"name": {"const": name},
"arguments": {"$ref": f"#/$defs/{name}_args"}
},
"required": ["name", "arguments"],
"additionalProperties": False
})
return {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"oneOf": [
{
"type": "object",
"properties": {
"tool_calls": {
"type": "array",
"minItems": 1,
"items": {"oneOf": tool_variants}
}
},
"required": ["tool_calls"],
"additionalProperties": False
},
{
"type": "object",
"properties": {
"function_call": {"oneOf": tool_variants}
},
"required": ["function_call"],
"additionalProperties": False
},
{
"type": "object",
"properties": {
"final_answer": {"type": "string", "minLength": 1}
},
"required": ["final_answer"],
"additionalProperties": False
}
],
"$defs": defs
}
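# Sketch of the schema in action, assuming the hypothetical "get_weather" tool from above:
#
#   schema = build_schema_from_tools([
#       {"function": {"name": "get_weather",
#                     "parameters": {"type": "object",
#                                    "properties": {"city": {"type": "string"}},
#                                    "required": ["city"]}}}
#   ])
#   v = Draft202012Validator(schema)
#   v.is_valid({"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]})  # True
#   v.is_valid({"final_answer": "It is sunny."})                                           # True
#   v.is_valid({"tool_calls": [{"name": "other_tool", "arguments": {}}]})                  # False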
# ---------- Main handler ----------
class EndpointHandler:
def __init__(self, path: str = ""):
"""
If the repo contains a merged model, MODEL_ID should point to it.
If you use adapter-only repos, modify __init__ to load base + adapter.
"""
model_id = path or os.getenv("MODEL_ID", ".")
# Tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Match training (we trained with right padding)
self.tokenizer.padding_side = "right"
# Choose dtype
if torch.cuda.is_available():
try:
dtype = torch.bfloat16 if torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
except Exception:
dtype = torch.float16
else:
dtype = torch.float32
# Model
self.model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=dtype, device_map="auto"
)
self.model.eval()
# Keep special tokens consistent
for obj in (self.model.config, self.model.generation_config):
obj.pad_token_id = self.tokenizer.pad_token_id
obj.eos_token_id = self.tokenizer.eos_token_id
obj.bos_token_id = self.tokenizer.bos_token_id
# ---- tools signature (exact format used in training) ----
def _flat_tool(self, t: dict) -> Tuple[str, dict, List[str]]:
f = t.get("function", t) if isinstance(t, dict) else {}
name = f.get("name") or f.get("api_call") or f.get("api_name") or ""
params = f.get("parameters") or {}
prop_names: List[str] = []
if isinstance(params, dict):
props = params.get("properties")
if isinstance(props, dict):
prop_names = list(props.keys())[:12]
elif isinstance(props, list):
prop_names = [p.get("name", "") for p in props if isinstance(p, dict)][:12]
return name, params, prop_names
def _render_tools_signature(self, tools: List[dict]) -> str:
lines = []
for t in tools[:12]:
name, _, pnames = self._flat_tool(t)
if not name:
continue
lines.append(f"- {name}({', '.join(pnames)})" if pnames else f"- {name}()")
return "\n".join(lines) if lines else "- (tools omitted)"
def _encode_messages(self, user_text: str, tools: List[dict]):
sig = self._render_tools_signature(tools)
prompt = (
"<|system|>\n" + SYS_INSTR + "\n\n"
"<|tools|>\n" + sig + "\n\n"
"<|user|>\n" + user_text + "\n\n"
"<|assistant|>\n"
)
toks = self.tokenizer(prompt, return_tensors="pt")
return toks["input_ids"].to(self.model.device)
# ---- request parsing / params ----
def _unpack(self, data: Dict[str, Any]):
"""
Accept both:
- {"inputs": {"messages": [...], "tools": [...], "parameters": {...}}}
- {"messages": [...], "tools": [...], "parameters": {...}}
- {"text": "..."} as a minimal fallback
"""
body = data.get("inputs", data)
params = data.get("parameters") or (body.get("parameters") if isinstance(body, dict) else {}) or {}
messages = None
tools = None
if isinstance(body, dict):
messages = body.get("messages")
tools = body.get("tools") or body.get("functions")
if messages is None:
messages = data.get("messages")
if tools is None:
tools = data.get("tools") or data.get("functions") or []
if not messages:
raw = body if isinstance(body, str) else data.get("text", "")
messages = [{"role": "user", "content": str(raw)}]
temperature = float(params.get("temperature", data.get("temperature", 0.0)))
max_new = int(params.get("max_new_tokens", data.get("max_new_tokens", 192)))
top_p = float(params.get("top_p", data.get("top_p", 1.0)))
# last user message text
user_text = ""
for m in reversed(messages):
if m.get("role") == "user":
user_text = m.get("content", "")
break
return user_text, tools, temperature, max_new, top_p
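    # Example (illustrative): {"inputs": {"messages": [{"role": "user", "content": "hi"}]}}
    # unpacks to ("hi", [], 0.0, 192, 1.0), i.e. greedy decoding with up to 192 new tokens.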
# ---- best-effort validation (no canonicalization) ----
def _apply_guard(self, tools: List[dict], raw_text: str):
try:
obj = json.loads(raw_text)
except Exception:
# Model did not emit JSON → wrap as final answer
return {"final_answer": raw_text.strip()}
# Validate against per-request schema (non-blocking)
schema = build_schema_from_tools(tools)
_ = [e.message for e in Draft202012Validator(schema).iter_errors(obj)]
# We return the object regardless of validation outcome (best effort).
return obj
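    # Examples (illustrative):
    #   '{"final_answer": "Paris"}'  -> returned as-is after parsing
    #   'The capital is Paris.'      -> {"final_answer": "The capital is Paris."}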
# ---- entrypoint ----
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
user_text, tools, temperature, max_new, top_p = self._unpack(data)
input_ids = self._encode_messages(user_text, tools)
gen_kwargs = dict(
input_ids=input_ids,
max_new_tokens=max_new,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
)
if temperature > 0:
gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
else:
gen_kwargs.update(do_sample=False)
with torch.inference_mode():
out = self.model.generate(**gen_kwargs)
raw = self.tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
envelope = self._apply_guard(tools, raw)
return {"generated_text": raw, "envelope": envelope}