# analyst-buddy / server/sql_environment.py
# Author: hjerpe. Last commit 656f91e (verified):
# "F006/F008: serve Qwen models + model switcher (vanilla-first)". Size: 34.1 kB.
import json
import logging
from pathlib import Path
import random
import re
import sqlite3
import time
import uuid
from typing import Protocol
from .reward import compute_step_reward
from .verifier import verify_answer
# ``chart_intent`` is gradio-free + dependency-light (stdlib + pydantic only), so
# importing it here does NOT pull gradio/torch/trl/transformers into the env. It
# is the SINGLE strip site for the ``​```chart {…}```​`` block (display AND
# scoring): ``_handle_answer`` strips the block from the model's ANSWER value
# before ``verify_answer`` so a prose+block answer still matches its gold value.
try:
from .chart_intent import strip_chart_block
from .sql_ident import is_valid_identifier, quote_ident
except ImportError: # pragma: no cover - flat-layout / direct-run fallback
from chart_intent import strip_chart_block # type: ignore[no-redef]
from sql_ident import is_valid_identifier, quote_ident # type: ignore[no-redef]
class ModelTokenizer(Protocol):
    """Structural type for the tokenizer handed to ``SQLEnvironment``.

    Stands in for OpenEnv's former ``ModelTokenizer`` interface. The
    environment only relies on ``apply_chat_template``, so any object that
    provides it (a HuggingFace tokenizer, ``MockTokenizer``, or the training
    adapter's stub) conforms without inheriting from this class.
    """

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        **kwargs,
    ) -> str:
        """Return a string rendering of ``messages``; options go in ``kwargs``."""
        ...
try:
from sql_env.models import (
EpisodeContext,
QuestionRecord,
SQLAction,
SQLObservation,
SQLState,
)
except ImportError:
# Fallback for Docker where PYTHONPATH=/app/env
from models import ( # type: ignore[no-redef]
EpisodeContext,
QuestionRecord,
SQLAction,
SQLObservation,
SQLState,
)
logger = logging.getLogger(__name__)

# Bare table identifier immediately following a FROM or JOIN keyword
# (case-insensitive). Quoted identifiers are not matched, and for
# ``schema.table`` only the first component is captured. Used to derive
# ``tables_involved`` when a question record does not list them.
_TABLE_FROM_JOIN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)",
    re.IGNORECASE,
)

# First word of a statement after any leading whitespace; the SELECT-only
# guards upper-case it and compare it against SELECT / WITH.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")
def resolve_db_path(db_dir: str | Path, db_id: str) -> Path | None:
"""Resolve the existing ``.sqlite`` file for ``db_id`` under ``db_dir``.
The SINGLE source of the db-id -> file resolution AND its path-traversal
defense, shared by ``SQLEnvironment._open_db`` (episode setup) and
``agent_loop._resolve_db_path`` (read-only re-exec). Tries the two layout
candidates — ``<root>/<id>/<id>.sqlite`` then ``<root>/<id>.sqlite`` — and
returns the first that BOTH exists AND passes the containment guard
(``.resolve()`` + ``db_root in candidate.parents``), so a ``db_id`` like
``../escape`` can never resolve a file outside ``db_dir``. Returns ``None``
when no contained candidate exists.
Identifier-charset validation (``^[A-Za-z0-9_]+$``) is the CALLER's
responsibility — ``_open_db`` enforces it and raises its own error; the
containment guard here is defense-in-depth that stands on its own.
"""
root = Path(db_dir)
db_root = root.resolve()
candidates = [
(root / db_id / f"{db_id}.sqlite").resolve(),
(root / f"{db_id}.sqlite").resolve(),
]
for candidate in candidates:
if candidate.exists() and db_root in candidate.parents:
return candidate
return None
class HarnessError(RuntimeError):
"""A broken *episode setup* — never a model failure.
Raised by ``reset()`` when the environment itself cannot produce a valid
episode: the gold SQL errors or times out, the database file is missing, or
the gold answer is empty/degenerate. Such episodes must be excluded from
training (they would poison the gradient), not scored as model mistakes.
The ``reason`` is a stable key (``db_missing`` | ``gold_sql_error`` |
``gold_empty``) so failures can be counted and triaged.
"""
def __init__(self, reason: str, detail: str = "", question_id: str | None = None):
self.reason = reason
self.detail = detail
self.question_id = question_id
suffix = f" (question {question_id})" if question_id else ""
super().__init__(f"[{reason}] {detail}{suffix}")
def is_degenerate_gold(rows: list[tuple], gold_answer: str) -> bool:
    """Report whether a gold result cannot serve as a training target.

    A gold is degenerate in any of these cases:

    - the result set is empty;
    - the result is a single row whose every value is ``None`` (for example
      ``MAX()``/``MIN()`` over an empty set), at any column count;
    - the formatted ``gold_answer`` is blank.

    A model could fluke-match such a target by answering ``""`` or
    ``"None"``. These are therefore treated as harness/data failures, not
    legitimate episodes.
    """
    if not rows:
        return True
    only_row_all_null = len(rows) == 1 and all(cell is None for cell in rows[0])
    return only_row_all_null or gold_answer.strip() == ""
class SQLEnvironment:
    """SQLEnv implementation with a structured SQL action loop.

    Runs in-process (TRL training calls reset()/step() directly). Formerly an
    OpenEnv ``Environment`` subclass; the base class only stored an optional
    transform, which this environment never used, so it is now standalone.

    Lifecycle:

    - ``reset()`` starts a gold-scored episode; ``begin_episode()`` starts a
      non-gold one. Either opens a read-only SQLite connection.
    - ``step()`` dispatches DESCRIBE / SAMPLE / QUERY / ANSWER actions against
      that connection.
    - Starting another episode, or calling ``close()``, closes the previous
      connection and forgets the episode.
    """

    # Maximum number of result rows shown to the agent for QUERY/SAMPLE.
    # QUERY fetches one extra row to detect (and report) truncation.
    _MAX_RESULT_ROWS = 20

    def __init__(
        self,
        questions_path: str,
        db_dir: str,
        tokenizer: ModelTokenizer,
        step_budget: int = 15,
    ):
        """Validate configuration and load the question bank.

        Args:
            questions_path: JSON file holding an array of question records
                (raw Spider or curated format; see ``_load_questions``).
            db_dir: directory containing the ``.sqlite`` databases.
            tokenizer: any object exposing ``apply_chat_template``.
            step_budget: per-episode step budget. Every dispatched non-ANSWER
                step, including invalid ones, consumes one unit.

        Raises:
            ValueError: the tokenizer lacks ``apply_chat_template``, the budget
                is not positive, or the questions file is malformed or empty.
            FileNotFoundError: the questions file or database directory is
                missing.
        """
        if not hasattr(tokenizer, "apply_chat_template"):
            raise ValueError("Tokenizer must have 'apply_chat_template' method")
        if step_budget <= 0:
            raise ValueError("step_budget must be a positive integer")
        questions_file = Path(questions_path)
        database_dir = Path(db_dir)
        if not questions_file.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_file}")
        if not database_dir.exists() or not database_dir.is_dir():
            raise FileNotFoundError(f"Database directory not found: {database_dir}")
        self.tokenizer = tokenizer
        self.questions_path = questions_file
        self.db_dir = database_dir
        self.step_budget = step_budget
        self.questions = self._load_questions(str(questions_file))
        if not self.questions:
            raise ValueError("Questions file contains no questions")
        self._episode: EpisodeContext | None = None
        self._last_result = ""
        self._last_error = ""
        self._last_reward: float | None = None
        self._last_query_truncated = False
        self._state = SQLState()

    # ------------------------------------------------------------------
    # Question bank
    # ------------------------------------------------------------------

    def _extract_tables_from_sql(self, sql: str) -> list[str]:
        """Extract table names from basic FROM/JOIN clauses.

        Returns unique names in first-seen order; the dedup is case-sensitive.
        Quoted identifiers are not matched, and for ``schema.table`` only the
        first component is captured.
        """
        return list(dict.fromkeys(_TABLE_FROM_JOIN_PATTERN.findall(sql)))

    def _load_questions(self, path: str) -> list[QuestionRecord]:
        """Load Spider questions JSON into QuestionRecord instances.

        Accepts both key sets:

        - raw Spider: ``question`` / ``db_id`` / ``query``;
        - curated: ``question_text`` / ``database_name`` / ``gold_sql``.

        When both are present and truthy, the curated key wins.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            ValueError: the JSON is invalid or not an array, or a record is
                missing a required field or has an invalid ``db_id``.
        """
        questions_path = Path(path)
        if not questions_path.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_path}")
        try:
            with questions_path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Invalid questions JSON format: {questions_path}"
            ) from exc
        if not isinstance(payload, list):
            raise ValueError("Questions JSON must be an array of records")
        question_records: list[QuestionRecord] = []
        for idx, item in enumerate(payload):
            if not isinstance(item, dict):
                raise ValueError(f"Question at index {idx} must be an object")
            # Support both raw Spider format and curated format
            question_text = item.get("question_text") or item.get("question")
            db_name = item.get("database_name") or item.get("db_id")
            gold_sql = item.get("gold_sql") or item.get("query")
            if not isinstance(question_text, str) or not question_text.strip():
                raise ValueError(
                    f"Question at index {idx} missing non-empty 'question'"
                )
            if not isinstance(db_name, str) or not db_name.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'db_id'")
            if not isinstance(gold_sql, str) or not gold_sql.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'query'")
            normalized_db_name = db_name.strip()
            if not is_valid_identifier(normalized_db_name):
                raise ValueError(
                    f"Question at index {idx} has invalid db_id '{normalized_db_name}'"
                )
            gold_answer = item.get("gold_answer", "")
            if not isinstance(gold_answer, str):
                gold_answer = str(gold_answer)
            question_records.append(
                QuestionRecord(
                    question_id=item.get("question_id", f"q-{idx}"),
                    question_text=question_text,
                    database_name=normalized_db_name,
                    gold_sql=gold_sql,
                    gold_answer=gold_answer,
                    answer_type=item.get("answer_type", "string"),
                    difficulty=item.get("difficulty", "medium"),
                    tables_involved=item.get("tables_involved")
                    or self._extract_tables_from_sql(gold_sql),
                )
            )
        return question_records

    # ------------------------------------------------------------------
    # Database access
    # ------------------------------------------------------------------

    def _open_db(self, db_name: str) -> sqlite3.Connection:
        """Open a read-only SQLite connection for the requested database.

        Raises:
            ValueError: ``db_name`` is not a valid identifier.
            FileNotFoundError: no contained ``.sqlite`` file exists for it.
        """
        normalized_db_name = db_name.strip()
        if not is_valid_identifier(normalized_db_name):
            raise ValueError(f"Invalid database name: '{db_name}'")
        db_path = resolve_db_path(self.db_dir, normalized_db_name)
        if db_path is None:
            raise FileNotFoundError(
                f"Database '{normalized_db_name}' not found in {self.db_dir}"
            )
        # URI form so SQLite itself enforces read-only access.
        uri = f"file:{db_path}?mode=ro"
        return sqlite3.connect(uri, uri=True)

    def _format_gold_answer(self, rows: list[tuple]) -> str:
        """Convert SQL rows into a stable string answer for episode comparison.

        A single scalar becomes its ``str``. Anything else becomes one line per
        row, with values joined by `` | ``. An empty result gives ``""``.
        """
        if not rows:
            return ""
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])
        return "\n".join(" | ".join(str(value) for value in row) for row in rows)

    @staticmethod
    def _require_select_only(sql: str) -> str:
        """Return ``sql`` stripped, after checking it is a SELECT/WITH query.

        Raises:
            ValueError: the query is empty or does not start with SELECT/WITH.
        """
        sql_stripped = sql.strip()
        if not sql_stripped:
            raise ValueError("SQL query cannot be empty")
        first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
        first_keyword = (
            first_keyword_match.group(1).upper() if first_keyword_match else ""
        )
        if first_keyword not in ("SELECT", "WITH"):
            raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")
        return sql_stripped

    @staticmethod
    def _run_with_timeout(
        connection: sqlite3.Connection,
        sql: str,
        timeout_s: float,
        max_rows: int | None = None,
    ) -> list[tuple]:
        """Execute ``sql`` on ``connection`` with a wall-clock deadline.

        Args:
            connection: connection to run on. Its progress handler is replaced
                for the duration of the call and cleared afterwards.
            sql: a single, already validated statement.
            timeout_s: seconds before the query is aborted.
            max_rows: fetch at most this many rows; ``None`` fetches all.

        Raises:
            sqlite3.OperationalError: "Query timed out after ..." when the
                deadline passes, or the original error otherwise.
        """
        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return makes SQLite abort the running statement.
            return 1 if time.monotonic() > deadline else 0

        # Check the deadline every 1000 SQLite VM instructions.
        connection.set_progress_handler(_progress_callback, 1000)
        try:
            cursor = connection.cursor()
            cursor.execute(sql)
            if max_rows is None:
                return cursor.fetchall()
            return cursor.fetchmany(max_rows)
        except sqlite3.OperationalError as exc:
            if "interrupted" in str(exc).lower():
                raise sqlite3.OperationalError(
                    f"Query timed out after {timeout_s:.1f} seconds"
                ) from exc
            raise
        finally:
            connection.set_progress_handler(None, 0)

    def _execute_gold_sql(
        self,
        connection: sqlite3.Connection,
        sql: str,
        timeout_s: float = 5.0,
    ) -> list[tuple]:
        """Execute gold SQL with read-only/SELECT-only timeout protections.

        Returns all result rows (no truncation).
        """
        sql_stripped = self._require_select_only(sql)
        return self._run_with_timeout(connection, sql_stripped, timeout_s)

    # ------------------------------------------------------------------
    # Episode lifecycle
    # ------------------------------------------------------------------

    def _close_active_episode(self) -> None:
        """Forget the active episode (if any) and close its connection.

        The episode is detached before closing. A failed setup afterwards then
        leaves no episode at all, rather than one bound to a closed connection.
        """
        episode, self._episode = self._episode, None
        if episode is not None:
            episode.db_connection.close()

    def close(self) -> None:
        """Release the active episode's database connection.

        Safe to call repeatedly. Afterwards ``step()`` reports "No active
        episode" until ``reset()`` or ``begin_episode()`` starts a new one.
        """
        self._close_active_episode()

    def _start_episode(
        self,
        connection: sqlite3.Connection,
        question_record: QuestionRecord,
        episode_id: str | None,
        gold_answer: str | None,
        gold_rows: list[tuple],
    ) -> SQLObservation:
        """Install a fresh episode, reset per-step state, and return the first observation.

        A uuid4 episode id is generated when ``episode_id`` is falsy.
        """
        resolved_episode_id = episode_id or str(uuid.uuid4())
        self._episode = EpisodeContext(
            episode_id=resolved_episode_id,
            db_connection=connection,
            question_record=question_record,
            step_count=0,
            budget=self.step_budget,
            done=False,
            gold_answer=gold_answer,
            gold_rows=gold_rows,
        )
        self._state.episode_id = resolved_episode_id
        self._state.step_count = 0
        self._state.current_action_type = "QUERY"
        self._state.history_messages = []
        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        self._last_query_truncated = False
        return self._build_observation()

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        question_index: int | None = None,
        **kwargs,
    ) -> SQLObservation:
        """Reset episode context and return the initial rich observation.

        ``question_index`` (optional) selects a SPECIFIC question
        deterministically, instead of the random ``seed`` draw. The full-set
        eval protocol uses it to visit every question exactly once, avoiding
        the sampling-with-replacement noise that made N=50 evals unreliable.
        The index wraps modulo the number of questions.

        Raises:
            HarnessError: the episode setup is broken.
                ``db_missing``: the database cannot be found or opened.
                ``gold_sql_error``: the gold SQL fails or times out.
                ``gold_empty``: the gold result is degenerate.
                After a HarnessError no episode is active.
        """
        del kwargs
        self._close_active_episode()
        if question_index is not None:
            question = self.questions[question_index % len(self.questions)]
        else:
            chooser = random.Random(seed) if seed is not None else random
            question = chooser.choice(self.questions)
        # --- Harness guardrail: fail fast on a broken episode setup ---------
        # A broken gold answer is a HARNESS failure, not a model failure. We
        # raise HarnessError so the caller can count + exclude it instead of
        # silently training on an empty/degenerate gold (which poisons the
        # gradient). All failure paths close the connection (no leaks).
        try:
            connection = self._open_db(question.database_name)
        except (FileNotFoundError, ValueError, sqlite3.Error) as exc:
            raise HarnessError("db_missing", str(exc), question.question_id) from exc
        try:
            gold_rows = self._execute_gold_sql(connection, question.gold_sql)
        except Exception as exc:
            connection.close()
            raise HarnessError(
                "gold_sql_error", str(exc), question.question_id
            ) from exc
        gold_answer = self._format_gold_answer(gold_rows)
        if is_degenerate_gold(gold_rows, gold_answer):
            connection.close()
            raise HarnessError(
                "gold_empty",
                "gold SQL returned an empty or degenerate result set",
                question.question_id,
            )
        # The episode's gold answer is recomputed from the live DB; the
        # stored ``gold_answer`` in the question bank is not used.
        question_for_episode = QuestionRecord(
            question_id=question.question_id,
            question_text=question.question_text,
            database_name=question.database_name,
            gold_sql=question.gold_sql,
            gold_answer=gold_answer,
            answer_type=question.answer_type,
            difficulty=question.difficulty,
            tables_involved=list(question.tables_involved),
        )
        return self._start_episode(
            connection, question_for_episode, episode_id, gold_answer, gold_rows
        )

    def begin_episode(
        self,
        db_id: str,
        question: str,
        *,
        episode_id: str | None = None,
        gold: QuestionRecord | None = None,
    ) -> SQLObservation:
        """Start a NON-GOLD episode for a user question against ``db_id``.

        This is the entry point the demo loop (``agent_loop.run_agent_turn``)
        uses INSTEAD of ``reset()``. ``reset()`` is gold-coupled: it picks a
        random gold question, runs the gold SQL, and scores with
        ``verify_answer``.

        This method:

        - opens the DB read-only via ``_open_db``;
        - builds an ``EpisodeContext`` with ``gold_answer=None`` and
          ``gold_rows=[]`` (no scoring target);
        - returns the initial observation.

        A non-gold ANSWER ends the episode without scoring (see
        ``_handle_answer``).

        Args:
            db_id: database to open (resolved by ``_open_db``).
            question: the user's plain-English question (no gold answer exists).
            episode_id: optional explicit id (a uuid is generated otherwise).
            gold: reserved for a future gold-seeded variant. ``None`` (the
                default) is the non-gold path; when supplied, it carries the
                scoring target.
        Returns:
            the initial ``SQLObservation`` for the started episode.
        Raises:
            FileNotFoundError / ValueError: ``db_id`` is missing or invalid
                (from ``_open_db``). This is a setup error, surfaced to the
                caller; no episode is active afterwards.
        """
        self._close_active_episode()
        connection = self._open_db(db_id)
        question_for_episode = QuestionRecord(
            question_id=episode_id or "user-question",
            question_text=question,
            database_name=db_id.strip(),
            gold_sql=gold.gold_sql if gold is not None else "",
            gold_answer=gold.gold_answer if gold is not None else "",
            answer_type=gold.answer_type if gold is not None else "string",
            difficulty=gold.difficulty if gold is not None else "medium",
            tables_involved=list(gold.tables_involved) if gold is not None else [],
        )
        return self._start_episode(
            connection,
            question_for_episode,
            episode_id,
            gold.gold_answer if gold is not None else None,
            [],
        )

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------

    def _get_table_names(self, connection: sqlite3.Connection) -> list[str]:
        """Return user-visible table names for the active SQLite database.

        Names are sorted; SQLite-internal ``sqlite_*`` tables are excluded.
        """
        cursor = connection.cursor()
        cursor.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
            ORDER BY name
            """
        )
        return [str(row[0]) for row in cursor.fetchall()]

    def _resolve_table_name(self, table_name: str) -> tuple[str | None, list[str]]:
        """Resolve a table name case-insensitively against the active DB.

        Returns:
            ``(canonical_name_or_None, all_table_names)``. Returns
            ``(None, [])`` when no episode is active.
        """
        if self._episode is None:
            return None, []
        available_tables = self._get_table_names(self._episode.db_connection)
        lookup = {table.lower(): table for table in available_tables}
        resolved = lookup.get(table_name.strip().lower())
        return resolved, available_tables

    def _format_rows(self, rows: list[tuple]) -> str:
        """Format SQL rows as 1-based numbered lines, with values joined by `` | ``."""
        if not rows:
            return "No rows returned."
        return "\n".join(
            f"{idx}. {' | '.join(str(value) for value in row)}"
            for idx, row in enumerate(rows, start=1)
        )

    def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
        """Execute SQL in sandbox: SELECT-only, single statement, timeout, truncation.

        Returns at most ``_MAX_RESULT_ROWS`` rows. Sets
        ``_last_query_truncated`` when more rows were available.

        Raises:
            RuntimeError: no active episode.
            ValueError: empty, non-SELECT, or multi-statement SQL.
            sqlite3.Error: execution failure or timeout.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")
        sql_stripped = self._require_select_only(sql)
        # A trailing semicolon is tolerated; any other ';' is rejected. This
        # also rejects ';' inside string literals (conservative by design).
        single_statement_sql = sql_stripped.rstrip(";").strip()
        if ";" in single_statement_sql:
            raise ValueError("Only a single SELECT statement is allowed")
        self._last_query_truncated = False
        # Fetch one extra row so truncation can be detected.
        rows = self._run_with_timeout(
            self._episode.db_connection,
            sql_stripped,
            timeout_s,
            max_rows=self._MAX_RESULT_ROWS + 1,
        )
        if len(rows) > self._MAX_RESULT_ROWS:
            self._last_query_truncated = True
            rows = rows[: self._MAX_RESULT_ROWS]
        return rows

    def _handle_describe(self, table_name: str) -> str:
        """Return table schema and row count, and record the table as described.

        Raises:
            RuntimeError: no active episode.
            ValueError: empty argument, unknown table, or no visible columns.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")
        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for DESCRIBE")
        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            available = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {available}"
            )
        quoted_table = quote_ident(resolved_table)
        cursor = self._episode.db_connection.cursor()
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        columns = cursor.fetchall()
        if not columns:
            raise ValueError(f"Table '{resolved_table}' has no visible columns")
        cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
        row_count = int(cursor.fetchone()[0])
        self._episode.described_tables.add(resolved_table)
        lines = [f"Table '{resolved_table}' columns:"]
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk).
        for _, col_name, col_type, _, _, _ in columns:
            normalized_type = str(col_type).strip() or "UNKNOWN"
            lines.append(f"- {col_name}: {normalized_type}")
        lines.append(f"Row count: {row_count}")
        return "\n".join(lines)

    def _handle_sample(self, table_name: str, limit: int = 5) -> str:
        """Return up to ``limit`` sample rows from a table.

        ``limit`` is clamped to ``1..20`` (``_MAX_RESULT_ROWS``).
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")
        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for SAMPLE")
        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            available = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {available}"
            )
        quoted_table = quote_ident(resolved_table)
        bounded_limit = max(1, min(limit, self._MAX_RESULT_ROWS))
        rows = self._execute_sql(f"SELECT * FROM {quoted_table} LIMIT {bounded_limit}")
        return f"Sample from '{resolved_table}':\n{self._format_rows(rows)}"

    def _handle_query(self, sql: str) -> tuple[str, list[tuple]]:
        """Execute query and return formatted output with raw result rows."""
        sql_text = sql.strip()
        if not sql_text:
            raise ValueError("Argument cannot be empty for QUERY")
        rows = self._execute_sql(sql_text, timeout_s=5.0)
        output = self._format_rows(rows)
        if self._last_query_truncated:
            output = f"{output}\n... (truncated to {self._MAX_RESULT_ROWS} rows)"
        return output, rows

    def _handle_answer(self, value: str) -> tuple[bool, float]:
        """Compare a submitted answer against the episode gold answer.

        Always marks the episode done.

        Non-gold episodes (``gold_answer is None``, started via
        ``begin_episode`` for a user question with no scoring target) skip
        ``verify_answer`` entirely and return ``(False, 0.0)``. Gold episodes
        return ``(is_correct, 1.0 or 0.0)``.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")
        if self._episode.gold_answer is None:
            # Non-gold (user) question — no gold target, never score it.
            self._episode.done = True
            return False, 0.0
        # F005/C1: strip any ``​```chart {…}```​`` block so SCORING sees the clean
        # prose answer. Otherwise a real model emitting prose + a chart block
        # would fail the gold comparison, because ``verify_answer`` only
        # unwraps a fence when the WHOLE string is one fenced block.
        #
        # GATED on the ``​```chart``​`` marker (same guard as
        # ``verifier.verify_answer``). ``strip_chart_block``'s orphan-fence
        # scrub eats a bare closing ``​```​``, so calling it on a legit
        # non-chart fenced answer (e.g. a ``​```sql … ```​`` block) would
        # unbalance the fence.
        #
        # ``verify_answer`` re-strips on its own downstream, so this is
        # defense-in-depth: a no-op for block-free answers that only fires
        # for actual chart blocks.
        if "```chart" in value.lower():
            value = strip_chart_block(value)
        is_correct = verify_answer(
            predicted=value,
            gold=self._episode.gold_answer or "",
            answer_type=self._episode.question_record.answer_type,
            gold_rows=self._episode.gold_rows,
        )
        self._episode.done = True
        return is_correct, 1.0 if is_correct else 0.0

    # ------------------------------------------------------------------
    # Step / observation
    # ------------------------------------------------------------------

    def step(
        self,
        action: SQLAction,
        *,
        timeout_s: float = 30,
        **kwargs,
    ) -> SQLObservation:
        """Dispatch one structured action and return the updated observation.

        Behavior by case:

        - No active episode: the observation carries an error and no state
          changes.
        - Finished episode: the last observation is returned unchanged.
        - Unknown action type or empty argument: consumes one budget unit.
        - ANSWER: ends the episode without consuming budget.
        - Any other action: consumes one budget unit. The step reward comes
          from ``compute_step_reward`` while budget remains; when the budget
          is exhausted the episode ends with reward ``0.0``.

        ``timeout_s`` is currently unused; queries use their own fixed timeout.
        """
        del timeout_s
        del kwargs
        if self._episode is None:
            self._last_result = ""
            self._last_error = "No active episode. Call reset() before step()."
            self._last_reward = None
            return self._build_observation()
        if self._episode.done:
            return self._build_observation()
        action_type = str(action.action_type).strip().upper()
        argument = str(action.argument)
        self._state.current_action_type = action_type or "QUERY"
        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        # Only QUERY fills these; other actions report empty rows/sql.
        reward_rows: list[tuple] | None = []
        reward_sql = ""

        def _consume_invalid_step(error_text: str) -> SQLObservation:
            # A malformed action still burns a step and a unit of budget.
            self._last_error = error_text
            self._episode.step_count += 1
            self._episode.budget = max(0, self._episode.budget - 1)
            self._episode.action_log.append(f"{action_type} -> ERROR: {error_text}")
            if self._episode.budget == 0:
                self._episode.done = True
                self._last_reward = 0.0
            self._state.step_count = self._episode.step_count
            return self._build_observation()

        valid_action_types = {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}
        if action_type not in valid_action_types:
            return _consume_invalid_step(
                f"Unknown action type '{action.action_type}'. "
                "Valid types: DESCRIBE, SAMPLE, QUERY, ANSWER"
            )
        argument_stripped = argument.strip()
        if not argument_stripped:
            return _consume_invalid_step(f"Argument cannot be empty for {action_type}")
        try:
            if action_type == "DESCRIBE":
                self._last_result = self._handle_describe(argument_stripped)
            elif action_type == "SAMPLE":
                self._last_result = self._handle_sample(argument_stripped)
            elif action_type == "QUERY":
                reward_sql = argument_stripped
                self._last_result, reward_rows = self._handle_query(argument_stripped)
            else:
                # ANSWER always terminates the episode (_handle_answer sets
                # done=True), so we return early without decrementing budget.
                is_correct, reward = self._handle_answer(argument_stripped)
                verdict = "correct" if is_correct else "incorrect"
                self._last_result = f"Answer submitted: {verdict}."
                self._last_reward = reward
                self._episode.step_count += 1
                self._episode.action_log.append(
                    f"ANSWER {argument_stripped} -> {verdict}"
                )
                self._state.step_count = self._episode.step_count
                return self._build_observation()
        except ValueError as exc:
            self._last_error = str(exc)
        except sqlite3.Error as exc:
            self._last_error = f"SQL error: {exc}"
        self._episode.step_count += 1
        self._episode.budget = max(0, self._episode.budget - 1)
        self._state.step_count = self._episode.step_count
        if self._episode.budget > 0:
            self._last_reward = compute_step_reward(
                ctx=self._episode,
                action_type=action_type,
                sql=reward_sql,
                rows=reward_rows,
                error=self._last_error or None,
            )
        if self._last_error:
            self._episode.action_log.append(
                f"{action_type} -> ERROR: {self._last_error}"
            )
        else:
            preview = self._last_result.splitlines()[0] if self._last_result else "ok"
            self._episode.action_log.append(f"{action_type} -> {preview}")
        if self._episode.budget == 0:
            self._episode.done = True
            if self._last_reward is None:
                self._last_reward = 0.0
        return self._build_observation()

    def _build_observation(self) -> SQLObservation:
        """Construct a rich observation from the current episode context.

        ``schema_info`` always lists the available tables. Tables described
        earlier in the episode also get a column summary (``name type``,
        typeless columns shown as ``UNKNOWN``).
        """
        if self._episode is None:
            return SQLObservation(
                question="",
                schema_info="",
                result=self._last_result,
                error=self._last_error,
                step_count=0,
                budget_remaining=0,
                action_history=[],
                done=False,
                reward=self._last_reward,
            )
        table_names = self._get_table_names(self._episode.db_connection)
        known_tables = set(table_names)
        schema_lines = ["Available tables:", *[f"- {name}" for name in table_names]]
        if self._episode.described_tables:
            schema_lines.append("")
            schema_lines.append("Described tables:")
            for table_name in sorted(self._episode.described_tables):
                if table_name not in known_tables:
                    schema_lines.append(
                        f"- {table_name}: unavailable (not in active schema)"
                    )
                    continue
                cursor = self._episode.db_connection.cursor()
                cursor.execute(f"PRAGMA table_info({quote_ident(table_name)})")
                columns = cursor.fetchall()
                if not columns:
                    schema_lines.append(f"- {table_name}: no columns available")
                    continue
                # Normalize types the same way _handle_describe does.
                column_summary = ", ".join(
                    f"{str(column[1])} {str(column[2]).strip() or 'UNKNOWN'}"
                    for column in columns
                )
                schema_lines.append(f"- {table_name}: {column_summary}")
        return SQLObservation(
            question=self._episode.question_record.question_text,
            schema_info="\n".join(schema_lines),
            result=self._last_result,
            error=self._last_error,
            step_count=self._episode.step_count,
            budget_remaining=self._episode.budget,
            action_history=list(self._episode.action_log),
            done=self._episode.done,
            reward=self._last_reward,
        )

    @property
    def state(self) -> SQLState:
        """Get current exposed state metadata."""
        return self._state

    def message_to_action(self, message: dict[str, str]) -> SQLAction:
        """Convert a free-form chat message into a structured SQLAction.

        A ``user`` message whose first word is DESCRIBE / SAMPLE / QUERY /
        ANSWER (case-insensitive) becomes that action, with the rest of the
        text as its argument. Any other message becomes a QUERY whose argument
        is the raw content. The message is also appended to
        ``state.history_messages``.

        Raises:
            ValueError: ``role`` or ``content`` is missing, or content is None.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")
        content = str(message["content"])
        parsed = content.strip()
        action_type = "QUERY"
        argument = content
        if message["role"].lower() == "user" and parsed:
            prefix, separator, remainder = parsed.partition(" ")
            normalized_prefix = prefix.upper()
            if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
                action_type = normalized_prefix
                if separator:
                    argument = remainder
                else:
                    argument = ""
        self._state.current_action_type = action_type
        self._state.history_messages.append(message)
        return SQLAction(action_type=action_type, argument=argument)