analyst-buddy / server /tooling.py
hjerpe's picture
F006/F008: serve Qwen models + model switcher (vanilla-first)
656f91e verified
Raw
History Blame Contribute Delete
6.46 kB
"""Shared SQL agent tool-call contract (parser, tool defs, system prompt).
Single source of truth for the four agent-loop primitives so the evaluated
loop (``evaluation.model_policy.ModelPolicy``) and the future serving app stay
byte-identical to the RL-trained Hermes ``<tool_call>{"name","arguments"}``
format. Lives in the lightweight ``server`` package and pulls NO heavy training
deps (``trl``/``torch``/``transformers``) — ``SQLEnvTRL`` is resolved lazily
inside ``get_tool_definitions`` only when needed for introspection.
"""
import inspect
import json
import re
try:
from sql_env.models import SQLAction
except ImportError: # pragma: no cover - flat-layout / Docker fallback
from models import SQLAction # type: ignore[no-redef]
# System prompt shared by evaluation and serving. Each tuple entry is one line
# of the final prompt; entries are joined with "\n", so an empty entry yields a
# blank separator line. Any character change alters what the model sees.
_SYSTEM_PROMPT = "\n".join(
    (
        "You answer questions about a SQL database. Use ONLY the provided tools.",
        "",
        "Strategy:",
        (
            "1. Call describe(table_name=...) to see columns. Before a query "
            "that JOINs tables, describe EACH table you plan to join so you "
            "use the real column and key names."
        ),
        (
            "2. Call query(sql=...) to run SELECT queries. If a query errors "
            "(e.g. 'no such column'), describe the relevant table again and "
            "fix the column or join key — do NOT re-send the same failing "
            "query."
        ),
        "3. Call answer(value=...) to submit your final answer",
        "",
        "Answer format: submit ONLY the data values from your query result.",
        "- Single value: 42 or ford",
        "- Multiple values: alice, bob, charlie",
        "- Table rows: col1 | col2 (one row per line)",
        "- No results: []",
        "",
        (
            "IMPORTANT: Call only ONE tool at a time, then read the response "
            "before deciding what to do next."
        ),
    )
)
def get_system_prompt(*, enable_thinking: bool = False) -> str:
    """Build the system prompt for the SQL exploration agent.

    Parameters
    ----------
    enable_thinking
        Whether to leave Qwen3 thinking mode on. When False (the default)
        the prompt is prefixed with the ``/no_think`` directive on its own
        line; when True the base prompt is returned unchanged.

    Returns
    -------
    str
        The tool-calling strategy prompt; the same arguments always yield
        the same text.
    """
    directive = "" if enable_thinking else "/no_think\n"
    return f"{directive}{_SYSTEM_PROMPT}"
def get_tool_definitions(env_cls: type | None = None) -> list[dict]:
"""Extract tool definitions from an environment class via introspection.
Inspects public methods (excluding reset and dunder) to build the
same JSON schema that TRL generates for environment_factory. This
guarantees SFT and GRPO see identical tool definitions.
"""
if env_cls is None:
try:
from sql_env.training.trl_adapter import SQLEnvTRL # noqa: PLC0415
except ImportError: # pragma: no cover - flat-layout fallback
from training.trl_adapter import SQLEnvTRL # type: ignore[no-redef] # noqa: PLC0415
env_cls = SQLEnvTRL
_SKIP = {"reset", "reward"}
tools = []
for name, method in inspect.getmembers(env_cls, predicate=inspect.isfunction):
if name.startswith("_") or name in _SKIP:
continue
sig = inspect.signature(method)
doc = inspect.getdoc(method) or ""
# Split docstring into description and Args/Returns sections
lines = doc.split("\n")
description = lines[0].strip() if lines else name
# Parse Args section for parameter descriptions
param_descriptions: dict[str, str] = {}
return_description = ""
section = ""
for line in lines[1:]:
stripped = line.strip()
if stripped.lower().startswith("args:"):
section = "args"
continue
if stripped.lower().startswith("returns:"):
section = "returns"
continue
if section == "args" and ":" in stripped:
param_name, param_desc = stripped.split(":", 1)
param_descriptions[param_name.strip()] = param_desc.strip()
if section == "returns" and stripped:
return_description = stripped
# Build parameters schema from signature
properties = {}
required = []
for param_name, param in sig.parameters.items():
if param_name == "self":
continue
properties[param_name] = {
"type": "string",
"description": param_descriptions.get(
param_name, f"{param_name} parameter."
),
}
if param.default is inspect.Parameter.empty:
required.append(param_name)
tool = {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
}
if return_description:
tool["function"]["return"] = {
"type": "string",
"description": return_description,
}
tools.append(tool)
# Sort by name for deterministic ordering
tools.sort(key=lambda t: t["function"]["name"])
return tools
# Tool name emitted by the model -> (SQLAction action_type, key of the single
# argument to read from the call's "arguments" object). parse_action lowercases
# the emitted name before looking it up, so keys must be lowercase. They are
# expected to match the public method names that get_tool_definitions exposes
# (NOTE(review): confirm against the env class when adding a tool).
_TOOL_TO_ACTION: dict[str, tuple[str, str]] = dict(
    describe=("DESCRIBE", "table_name"),
    sample=("SAMPLE", "table_name"),
    query=("QUERY", "sql"),
    answer=("ANSWER", "value"),
)

# Hermes-style tool call: captures the JSON object between <tool_call> tags.
# DOTALL lets the payload span lines; the lazy group stops at the first "}"
# followed (modulo whitespace) by the closing tag, so nested objects inside
# "arguments" are captured whole.
_TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*(\{.*?\})\s*</tool_call>",
    flags=re.DOTALL,
)
def parse_action(completion: str) -> SQLAction | None:
    """Extract the first ``<tool_call>`` JSON payload and map it to an SQLAction.

    Parameters
    ----------
    completion
        Raw model output that may contain a Hermes-style
        ``<tool_call>{"name": ..., "arguments": {...}}</tool_call>`` span.

    Returns
    -------
    SQLAction | None
        The mapped action, or ``None`` when there is no tool call, the JSON is
        malformed, the tool name is unknown, ``arguments`` is not a JSON
        object (or a string encoding one), or no argument value is present.
    """
    match = _TOOL_CALL_RE.search(completion)
    if not match:
        return None
    try:
        payload = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):  # defensive: the regex only captures {...}
        return None

    # Tool names are matched case-insensitively.
    name = str(payload.get("name", "")).lower()
    mapping = _TOOL_TO_ACTION.get(name)
    if mapping is None:
        return None
    action_type, arg_key = mapping

    args = payload.get("arguments") or {}
    # Model output is untrusted: "arguments" may arrive OpenAI-style as a
    # JSON-encoded string, or as some other non-object value. Decode strings
    # and reject anything that is not an object instead of crashing on .get().
    if isinstance(args, str):
        try:
            args = json.loads(args)
        except json.JSONDecodeError:
            return None
    if not isinstance(args, dict):
        return None

    argument = args.get(arg_key)
    # Tolerate a misnamed key when the call carries exactly one argument.
    if argument is None and len(args) == 1:
        argument = next(iter(args.values()))
    if argument is None:
        return None
    return SQLAction(action_type=action_type, argument=str(argument))