Spaces:
Running on Zero
Running on Zero
| """Shared SQL agent tool-call contract (parser, tool defs, system prompt). | |
| Single source of truth for the four agent-loop primitives so the evaluated | |
| loop (``evaluation.model_policy.ModelPolicy``) and the future serving app stay | |
| byte-identical to the RL-trained Hermes ``<tool_call>{"name","arguments"}`` | |
| format. Lives in the lightweight ``server`` package and pulls NO heavy training | |
| deps (``trl``/``torch``/``transformers``) — ``SQLEnvTRL`` is resolved lazily | |
| inside ``get_tool_definitions`` only when needed for introspection. | |
| """ | |
| import inspect | |
| import json | |
| import re | |
| try: | |
| from sql_env.models import SQLAction | |
| except ImportError: # pragma: no cover - flat-layout / Docker fallback | |
| from models import SQLAction # type: ignore[no-redef] | |
# Base system prompt returned (optionally prefixed) by get_system_prompt().
# The adjacent string literals are concatenated into one string; the module
# docstring requires this contract to stay byte-identical to the trained
# format, so the wording, punctuation and newlines below are kept verbatim.
_SYSTEM_PROMPT = (
    "You answer questions about a SQL database. "
    "Use ONLY the provided tools.\n\n"
    "Strategy:\n"
    "1. Call describe(table_name=...) to see columns. Before a query that "
    "JOINs tables, describe EACH table you plan to join so you use the real "
    "column and key names.\n"
    "2. Call query(sql=...) to run SELECT queries. If a query errors (e.g. "
    "'no such column'), describe the relevant table again and fix the column "
    "or join key — do NOT re-send the same failing query.\n"
    "3. Call answer(value=...) to submit your final answer\n\n"
    "Answer format: submit ONLY the data values from your query result.\n"
    "- Single value: 42 or ford\n"
    "- Multiple values: alice, bob, charlie\n"
    "- Table rows: col1 | col2 (one row per line)\n"
    "- No results: []\n\n"
    "IMPORTANT: Call only ONE tool at a time, then read the "
    "response before deciding what to do next."
)
def get_system_prompt(*, enable_thinking: bool = False) -> str:
    """Build the system prompt for the SQL exploration agent.

    Parameters
    ----------
    enable_thinking
        If True, the base prompt is returned unchanged. If False (the
        default), ``/no_think`` followed by a newline is placed in front of
        the base prompt to switch off Qwen3 thinking mode.

    Returns
    -------
    str
        The prompt text; the same flag value always yields the same string.
    """
    # The directive must be the very first line of the prompt when present.
    prefix = "" if enable_thinking else "/no_think\n"
    return prefix + _SYSTEM_PROMPT
| def get_tool_definitions(env_cls: type | None = None) -> list[dict]: | |
| """Extract tool definitions from an environment class via introspection. | |
| Inspects public methods (excluding reset and dunder) to build the | |
| same JSON schema that TRL generates for environment_factory. This | |
| guarantees SFT and GRPO see identical tool definitions. | |
| """ | |
| if env_cls is None: | |
| try: | |
| from sql_env.training.trl_adapter import SQLEnvTRL # noqa: PLC0415 | |
| except ImportError: # pragma: no cover - flat-layout fallback | |
| from training.trl_adapter import SQLEnvTRL # type: ignore[no-redef] # noqa: PLC0415 | |
| env_cls = SQLEnvTRL | |
| _SKIP = {"reset", "reward"} | |
| tools = [] | |
| for name, method in inspect.getmembers(env_cls, predicate=inspect.isfunction): | |
| if name.startswith("_") or name in _SKIP: | |
| continue | |
| sig = inspect.signature(method) | |
| doc = inspect.getdoc(method) or "" | |
| # Split docstring into description and Args/Returns sections | |
| lines = doc.split("\n") | |
| description = lines[0].strip() if lines else name | |
| # Parse Args section for parameter descriptions | |
| param_descriptions: dict[str, str] = {} | |
| return_description = "" | |
| section = "" | |
| for line in lines[1:]: | |
| stripped = line.strip() | |
| if stripped.lower().startswith("args:"): | |
| section = "args" | |
| continue | |
| if stripped.lower().startswith("returns:"): | |
| section = "returns" | |
| continue | |
| if section == "args" and ":" in stripped: | |
| param_name, param_desc = stripped.split(":", 1) | |
| param_descriptions[param_name.strip()] = param_desc.strip() | |
| if section == "returns" and stripped: | |
| return_description = stripped | |
| # Build parameters schema from signature | |
| properties = {} | |
| required = [] | |
| for param_name, param in sig.parameters.items(): | |
| if param_name == "self": | |
| continue | |
| properties[param_name] = { | |
| "type": "string", | |
| "description": param_descriptions.get( | |
| param_name, f"{param_name} parameter." | |
| ), | |
| } | |
| if param.default is inspect.Parameter.empty: | |
| required.append(param_name) | |
| tool = { | |
| "type": "function", | |
| "function": { | |
| "name": name, | |
| "description": description, | |
| "parameters": { | |
| "type": "object", | |
| "properties": properties, | |
| "required": required, | |
| }, | |
| }, | |
| } | |
| if return_description: | |
| tool["function"]["return"] = { | |
| "type": "string", | |
| "description": return_description, | |
| } | |
| tools.append(tool) | |
| # Sort by name for deterministic ordering | |
| tools.sort(key=lambda t: t["function"]["name"]) | |
| return tools | |
# Maps the lowercased tool name emitted by the model to
# (SQLAction action_type, name of the argument that carries the payload).
_TOOL_TO_ACTION: dict[str, tuple[str, str]] = {
    "describe": ("DESCRIBE", "table_name"),
    "sample": ("SAMPLE", "table_name"),
    "query": ("QUERY", "sql"),
    "answer": ("ANSWER", "value"),
}

# Captures the JSON object between <tool_call> tags. The lazy body must be
# followed by the closing tag, so nested braces inside "arguments" are kept;
# DOTALL lets the object span several lines.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)


def parse_action(completion: str) -> SQLAction | None:
    """Extract the first ``<tool_call>`` JSON object and map it to an SQLAction.

    Parameters
    ----------
    completion
        Raw model output that may contain a
        ``<tool_call>{"name": ..., "arguments": {...}}</tool_call>`` span.

    Returns
    -------
    SQLAction | None
        The action for the first tool call, with the argument converted via
        ``str()``. None is returned when there is no tool call, the JSON is
        invalid, the tool name is unknown, ``arguments`` is not a JSON
        object, or no usable argument value is present.
    """
    match = _TOOL_CALL_RE.search(completion)
    if not match:
        return None
    try:
        payload = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None

    # Tool names are matched case-insensitively.
    mapping = _TOOL_TO_ACTION.get(str(payload.get("name", "")).lower())
    if mapping is None:
        return None
    action_type, arg_key = mapping

    args = payload.get("arguments") or {}
    if not isinstance(args, dict):
        # The model emitted e.g. a bare string or list; treat it as a
        # malformed call rather than crashing on ``.get`` below.
        return None

    argument = args.get(arg_key)
    if argument is None and len(args) == 1:
        # Tolerate a wrong key name when there is exactly one argument.
        argument = next(iter(args.values()))
    if argument is None:
        return None
    return SQLAction(action_type=action_type, argument=str(argument))