"""Shared SQL agent tool-call contract (parser, tool defs, system prompt). Single source of truth for the four agent-loop primitives so the evaluated loop (``evaluation.model_policy.ModelPolicy``) and the future serving app stay byte-identical to the RL-trained Hermes ``{"name","arguments"}`` format. Lives in the lightweight ``server`` package and pulls NO heavy training deps (``trl``/``torch``/``transformers``) — ``SQLEnvTRL`` is resolved lazily inside ``get_tool_definitions`` only when needed for introspection. """ import inspect import json import re try: from sql_env.models import SQLAction except ImportError: # pragma: no cover - flat-layout / Docker fallback from models import SQLAction # type: ignore[no-redef] _SYSTEM_PROMPT = ( "You answer questions about a SQL database. " "Use ONLY the provided tools.\n\n" "Strategy:\n" "1. Call describe(table_name=...) to see columns. Before a query that " "JOINs tables, describe EACH table you plan to join so you use the real " "column and key names.\n" "2. Call query(sql=...) to run SELECT queries. If a query errors (e.g. " "'no such column'), describe the relevant table again and fix the column " "or join key — do NOT re-send the same failing query.\n" "3. Call answer(value=...) to submit your final answer\n\n" "Answer format: submit ONLY the data values from your query result.\n" "- Single value: 42 or ford\n" "- Multiple values: alice, bob, charlie\n" "- Table rows: col1 | col2 (one row per line)\n" "- No results: []\n\n" "IMPORTANT: Call only ONE tool at a time, then read the " "response before deciding what to do next." ) def get_system_prompt(*, enable_thinking: bool = False) -> str: """Return the SQL exploration system prompt. Parameters ---------- enable_thinking When False (default), prepends ``/no_think`` to disable Qwen3 thinking mode. When True, returns prompt as-is. Returns ------- str Deterministic prompt text describing tool-calling strategy. """ if enable_thinking: return _SYSTEM_PROMPT return "/no_think\n" + _SYSTEM_PROMPT def get_tool_definitions(env_cls: type | None = None) -> list[dict]: """Extract tool definitions from an environment class via introspection. Inspects public methods (excluding reset and dunder) to build the same JSON schema that TRL generates for environment_factory. This guarantees SFT and GRPO see identical tool definitions. """ if env_cls is None: try: from sql_env.training.trl_adapter import SQLEnvTRL # noqa: PLC0415 except ImportError: # pragma: no cover - flat-layout fallback from training.trl_adapter import SQLEnvTRL # type: ignore[no-redef] # noqa: PLC0415 env_cls = SQLEnvTRL _SKIP = {"reset", "reward"} tools = [] for name, method in inspect.getmembers(env_cls, predicate=inspect.isfunction): if name.startswith("_") or name in _SKIP: continue sig = inspect.signature(method) doc = inspect.getdoc(method) or "" # Split docstring into description and Args/Returns sections lines = doc.split("\n") description = lines[0].strip() if lines else name # Parse Args section for parameter descriptions param_descriptions: dict[str, str] = {} return_description = "" section = "" for line in lines[1:]: stripped = line.strip() if stripped.lower().startswith("args:"): section = "args" continue if stripped.lower().startswith("returns:"): section = "returns" continue if section == "args" and ":" in stripped: param_name, param_desc = stripped.split(":", 1) param_descriptions[param_name.strip()] = param_desc.strip() if section == "returns" and stripped: return_description = stripped # Build parameters schema from signature properties = {} required = [] for param_name, param in sig.parameters.items(): if param_name == "self": continue properties[param_name] = { "type": "string", "description": param_descriptions.get( param_name, f"{param_name} parameter." ), } if param.default is inspect.Parameter.empty: required.append(param_name) tool = { "type": "function", "function": { "name": name, "description": description, "parameters": { "type": "object", "properties": properties, "required": required, }, }, } if return_description: tool["function"]["return"] = { "type": "string", "description": return_description, } tools.append(tool) # Sort by name for deterministic ordering tools.sort(key=lambda t: t["function"]["name"]) return tools # Maps the tool name the model emits (lowercased method name, per # trl_adapter.get_tool_definitions) -> (SQLAction action_type, argument key). _TOOL_TO_ACTION: dict[str, tuple[str, str]] = { "describe": ("DESCRIBE", "table_name"), "sample": ("SAMPLE", "table_name"), "query": ("QUERY", "sql"), "answer": ("ANSWER", "value"), } _TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) def parse_action(completion: str) -> SQLAction | None: """Extract the first JSON and map it to an SQLAction.""" match = _TOOL_CALL_RE.search(completion) if not match: return None try: payload = json.loads(match.group(1)) except json.JSONDecodeError: return None name = str(payload.get("name", "")).lower() mapping = _TOOL_TO_ACTION.get(name) if mapping is None: return None action_type, arg_key = mapping args = payload.get("arguments", {}) or {} argument = args.get(arg_key) if argument is None and len(args) == 1: argument = next(iter(args.values())) if argument is None: return None return SQLAction(action_type=action_type, argument=str(argument))