Spaces:
Running on Zero
Running on Zero
| import json | |
| import logging | |
| from pathlib import Path | |
| import random | |
| import re | |
| import sqlite3 | |
| import time | |
| import uuid | |
| from typing import Protocol | |
| from .reward import compute_step_reward | |
| from .verifier import verify_answer | |
| # ``chart_intent`` is gradio-free + dependency-light (stdlib + pydantic only), so | |
| # importing it here does NOT pull gradio/torch/trl/transformers into the env. It | |
| # is the SINGLE strip site for the `````chart {…}````` block (display AND | |
| # scoring): ``_handle_answer`` strips the block from the model's ANSWER value | |
| # before ``verify_answer`` so a prose+block answer still matches its gold value. | |
| try: | |
| from .chart_intent import strip_chart_block | |
| from .sql_ident import is_valid_identifier, quote_ident | |
| except ImportError: # pragma: no cover - flat-layout / direct-run fallback | |
| from chart_intent import strip_chart_block # type: ignore[no-redef] | |
| from sql_ident import is_valid_identifier, quote_ident # type: ignore[no-redef] | |
class ModelTokenizer(Protocol):
    """Structural type for the tokenizer handed to the environment.

    Stands in for OpenEnv's ModelTokenizer interface. Any object that offers an
    ``apply_chat_template`` method qualifies (HuggingFace tokenizers,
    MockTokenizer, the training adapter's stub).
    """

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        **kwargs,
    ) -> str:
        ...
| try: | |
| from sql_env.models import ( | |
| EpisodeContext, | |
| QuestionRecord, | |
| SQLAction, | |
| SQLObservation, | |
| SQLState, | |
| ) | |
| except ImportError: | |
| # Fallback for Docker where PYTHONPATH=/app/env | |
| from models import ( # type: ignore[no-redef] | |
| EpisodeContext, | |
| QuestionRecord, | |
| SQLAction, | |
| SQLObservation, | |
| SQLState, | |
| ) | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Captures the plain identifier that directly follows a FROM or JOIN keyword
# (case-insensitive). Only unquoted [A-Za-z_][A-Za-z0-9_]* names are captured.
_TABLE_FROM_JOIN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE
)
# Captures the first word of a statement, skipping any leading whitespace; the
# executors below upper-case it to check the leading SQL keyword.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")
| def resolve_db_path(db_dir: str | Path, db_id: str) -> Path | None: | |
| """Resolve the existing ``.sqlite`` file for ``db_id`` under ``db_dir``. | |
| The SINGLE source of the db-id -> file resolution AND its path-traversal | |
| defense, shared by ``SQLEnvironment._open_db`` (episode setup) and | |
| ``agent_loop._resolve_db_path`` (read-only re-exec). Tries the two layout | |
| candidates — ``<root>/<id>/<id>.sqlite`` then ``<root>/<id>.sqlite`` — and | |
| returns the first that BOTH exists AND passes the containment guard | |
| (``.resolve()`` + ``db_root in candidate.parents``), so a ``db_id`` like | |
| ``../escape`` can never resolve a file outside ``db_dir``. Returns ``None`` | |
| when no contained candidate exists. | |
| Identifier-charset validation (``^[A-Za-z0-9_]+$``) is the CALLER's | |
| responsibility — ``_open_db`` enforces it and raises its own error; the | |
| containment guard here is defense-in-depth that stands on its own. | |
| """ | |
| root = Path(db_dir) | |
| db_root = root.resolve() | |
| candidates = [ | |
| (root / db_id / f"{db_id}.sqlite").resolve(), | |
| (root / f"{db_id}.sqlite").resolve(), | |
| ] | |
| for candidate in candidates: | |
| if candidate.exists() and db_root in candidate.parents: | |
| return candidate | |
| return None | |
| class HarnessError(RuntimeError): | |
| """A broken *episode setup* — never a model failure. | |
| Raised by ``reset()`` when the environment itself cannot produce a valid | |
| episode: the gold SQL errors or times out, the database file is missing, or | |
| the gold answer is empty/degenerate. Such episodes must be excluded from | |
| training (they would poison the gradient), not scored as model mistakes. | |
| The ``reason`` is a stable key (``db_missing`` | ``gold_sql_error`` | | |
| ``gold_empty``) so failures can be counted and triaged. | |
| """ | |
| def __init__(self, reason: str, detail: str = "", question_id: str | None = None): | |
| self.reason = reason | |
| self.detail = detail | |
| self.question_id = question_id | |
| suffix = f" (question {question_id})" if question_id else "" | |
| super().__init__(f"[{reason}] {detail}{suffix}") | |
def is_degenerate_gold(rows: list[tuple], gold_answer: str) -> bool:
    """Report whether a gold result cannot serve as a training target.

    A result is degenerate when any of the following holds:

    * there are no rows at all;
    * there is exactly one row and every value in it is NULL (for example
      MAX()/MIN() over an empty set), whatever the column count;
    * the formatted ``gold_answer`` is blank after stripping whitespace.

    Such answers could be matched by accident (e.g. by answering "" or
    "None"), so they are treated as harness/data failures.
    """
    if rows:
        lone_all_null_row = len(rows) == 1 and not any(
            cell is not None for cell in rows[0]
        )
        if not lone_all_null_row:
            # Rows look usable; only a blank formatted answer disqualifies them.
            return not gold_answer.strip()
    return True
| class SQLEnvironment: | |
| """SQLEnv implementation with a structured SQL action loop. | |
| Runs in-process (TRL training calls reset()/step() directly). Formerly an | |
| OpenEnv ``Environment`` subclass; the base class only stored an optional | |
| transform, which this environment never used, so it is now standalone. | |
| """ | |
def __init__(
    self,
    questions_path: str,
    db_dir: str,
    tokenizer: ModelTokenizer,
    step_budget: int = 15,
):
    """Validate the inputs, load the question set and prepare empty episode state.

    Args:
        questions_path: path to the questions JSON file.
        db_dir: directory that holds the SQLite databases.
        tokenizer: object exposing ``apply_chat_template``.
        step_budget: number of steps granted to each episode; must be > 0.

    Raises:
        ValueError: tokenizer lacks ``apply_chat_template``, ``step_budget`` is
            not positive, or the questions file yields no questions.
        FileNotFoundError: the questions file or database directory is missing.
    """
    # Argument checks run before any filesystem access.
    if not hasattr(tokenizer, "apply_chat_template"):
        raise ValueError("Tokenizer must have 'apply_chat_template' method")
    if step_budget <= 0:
        raise ValueError("step_budget must be a positive integer")

    questions_file = Path(questions_path)
    if not questions_file.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_file}")
    database_dir = Path(db_dir)
    if not (database_dir.exists() and database_dir.is_dir()):
        raise FileNotFoundError(f"Database directory not found: {database_dir}")

    self.tokenizer = tokenizer
    self.questions_path = questions_file
    self.db_dir = database_dir
    self.step_budget = step_budget

    loaded = self._load_questions(str(questions_file))
    self.questions = loaded
    if not loaded:
        raise ValueError("Questions file contains no questions")

    # No episode is active until reset() / begin_episode() is called.
    self._episode: EpisodeContext | None = None
    self._last_result = ""
    self._last_error = ""
    self._last_reward: float | None = None
    self._last_query_truncated = False
    self._state = SQLState()
def _extract_tables_from_sql(self, sql: str) -> list[str]:
    """List the table names found after FROM/JOIN, without duplicates.

    Names keep the order of their first appearance. Matching is done by a
    simple regex over the SQL text, so only basic FROM/JOIN clauses are
    recognised.
    """
    found = _TABLE_FROM_JOIN_PATTERN.findall(sql)
    # dict keys preserve insertion order, giving an ordered de-duplication.
    return list(dict.fromkeys(found))
def _load_questions(self, path: str) -> list[QuestionRecord]:
    """Parse the questions JSON file into ``QuestionRecord`` objects.

    Both the raw Spider keys (``question`` / ``db_id`` / ``query``) and the
    curated keys (``question_text`` / ``database_name`` / ``gold_sql``) are
    accepted; the curated key wins when both are present and truthy.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: the file is not valid JSON, is not a JSON array, or an
            entry is not an object, lacks a required non-empty string field,
            or carries a db id that fails ``is_valid_identifier``.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"Questions file not found: {source}")

    try:
        with source.open("r", encoding="utf-8") as stream:
            raw_records = json.load(stream)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid questions JSON format: {source}") from exc

    if not isinstance(raw_records, list):
        raise ValueError("Questions JSON must be an array of records")

    def _is_blank(candidate: object) -> bool:
        """True unless ``candidate`` is a string with non-whitespace content."""
        return not isinstance(candidate, str) or not candidate.strip()

    loaded: list[QuestionRecord] = []
    for idx, entry in enumerate(raw_records):
        if not isinstance(entry, dict):
            raise ValueError(f"Question at index {idx} must be an object")

        # Curated key first, raw Spider key as the fallback.
        text = entry.get("question_text") or entry.get("question")
        database = entry.get("database_name") or entry.get("db_id")
        sql = entry.get("gold_sql") or entry.get("query")

        # Required fields are checked in a fixed order; error labels use the
        # raw Spider key names.
        for field_value, field_label in (
            (text, "question"),
            (database, "db_id"),
            (sql, "query"),
        ):
            if _is_blank(field_value):
                raise ValueError(
                    f"Question at index {idx} missing non-empty '{field_label}'"
                )

        database = database.strip()
        if not is_valid_identifier(database):
            raise ValueError(
                f"Question at index {idx} has invalid db_id '{database}'"
            )

        answer = entry.get("gold_answer", "")
        if not isinstance(answer, str):
            answer = str(answer)

        # Fall back to tables scraped from the gold SQL when none are listed.
        tables = entry.get("tables_involved") or self._extract_tables_from_sql(sql)

        loaded.append(
            QuestionRecord(
                question_id=entry.get("question_id", f"q-{idx}"),
                question_text=text,
                database_name=database,
                gold_sql=sql,
                gold_answer=answer,
                answer_type=entry.get("answer_type", "string"),
                difficulty=entry.get("difficulty", "medium"),
                tables_involved=tables,
            )
        )
    return loaded
def _open_db(self, db_name: str) -> sqlite3.Connection:
    """Open a read-only SQLite connection for the requested database.

    Args:
        db_name: database identifier; surrounding whitespace is ignored.

    Returns:
        A connection opened through a ``file:`` URI with ``mode=ro``.

    Raises:
        ValueError: ``db_name`` fails ``is_valid_identifier``.
        FileNotFoundError: no matching ``.sqlite`` file exists under
            ``self.db_dir``.
    """
    normalized_db_name = db_name.strip()
    if not is_valid_identifier(normalized_db_name):
        raise ValueError(f"Invalid database name: '{db_name}'")
    db_path = resolve_db_path(self.db_dir, normalized_db_name)
    if db_path is None:
        raise FileNotFoundError(
            f"Database '{normalized_db_name}' not found in {self.db_dir}"
        )
    # Build the URI with as_uri() rather than pasting the raw path: it
    # percent-encodes characters such as '?', '#' and '%' that would otherwise
    # be read as URI syntax when they occur in the directory path.
    # resolve_db_path returns a resolved (absolute) path, which as_uri() needs.
    uri = f"{db_path.as_uri()}?mode=ro"
    return sqlite3.connect(uri, uri=True)
| def _format_gold_answer(self, rows: list[tuple]) -> str: | |
| """Convert SQL rows into a stable string answer for episode comparison.""" | |
| if not rows: | |
| return "" | |
| if len(rows) == 1 and len(rows[0]) == 1: | |
| return str(rows[0][0]) | |
| return "\n".join(" | ".join(str(value) for value in row) for row in rows) | |
def _execute_gold_sql(
    self,
    connection: sqlite3.Connection,
    sql: str,
    timeout_s: float = 5.0,
) -> list[tuple]:
    """Run the gold SQL and return every row, guarded by a wall-clock timeout.

    Only statements whose first word is SELECT or WITH are accepted. While the
    query runs, a progress handler aborts it once ``timeout_s`` seconds have
    passed; the handler is always removed again afterwards.

    Raises:
        ValueError: the SQL is blank or does not start with SELECT/WITH.
        sqlite3.OperationalError: SQLite failed; an interrupted query is
            re-raised with a "timed out" message.
    """
    statement = sql.strip()
    if not statement:
        raise ValueError("SQL query cannot be empty")

    keyword_hit = _FIRST_KEYWORD_PATTERN.match(statement)
    leading_keyword = "" if keyword_hit is None else keyword_hit.group(1).upper()
    if leading_keyword != "SELECT" and leading_keyword != "WITH":
        raise ValueError(f"Only SELECT queries are allowed. Got: {leading_keyword}")

    cutoff = time.monotonic() + timeout_s

    def _abort_when_overdue() -> int:
        # A non-zero return tells SQLite to interrupt the running statement.
        if time.monotonic() > cutoff:
            return 1
        return 0

    connection.set_progress_handler(_abort_when_overdue, 1000)
    try:
        runner = connection.cursor()
        runner.execute(statement)
        return runner.fetchall()
    except sqlite3.OperationalError as exc:
        if "interrupted" not in str(exc).lower():
            raise
        raise sqlite3.OperationalError(
            f"Query timed out after {timeout_s:.1f} seconds"
        ) from exc
    finally:
        connection.set_progress_handler(None, 0)
def reset(
    self,
    *,
    seed: int | None = None,
    episode_id: str | None = None,
    question_index: int | None = None,
    **kwargs,
) -> SQLObservation:
    """Reset episode context and return the initial rich observation.

    ``question_index`` (optional) selects a SPECIFIC question deterministically
    instead of the random ``seed`` draw — the full-set eval protocol iterates
    every question exactly once with it, eliminating the sampling-with-
    replacement noise that made N=50 evals unreliable.

    Args:
        seed: seeds a private ``random.Random`` for the question draw; when
            ``None`` the module-level ``random`` generator is used.
        episode_id: explicit episode id; a uuid4 string is generated otherwise.
        question_index: index into the loaded questions, wrapped modulo their
            count; takes precedence over ``seed``.
        **kwargs: accepted and ignored.

    Returns:
        The initial ``SQLObservation`` of the new episode.

    Raises:
        HarnessError: the database cannot be opened (``db_missing``), the gold
            SQL fails (``gold_sql_error``), or its result is empty/degenerate
            (``gold_empty``).
    """
    del kwargs
    if self._episode is not None:
        self._episode.db_connection.close()
        # Forget the finished episode immediately. If setup below raises
        # HarnessError, a later step() must report "no active episode" instead
        # of running against the connection that was just closed.
        self._episode = None
    if question_index is not None:
        question = self.questions[question_index % len(self.questions)]
    else:
        chooser = random.Random(seed) if seed is not None else random
        question = chooser.choice(self.questions)
    # --- Harness guardrail: fail fast on a broken episode setup ---------
    # A broken gold answer is a HARNESS failure, not a model failure. We
    # raise HarnessError so the caller can count + exclude it instead of
    # silently training on an empty/degenerate gold (which poisons the
    # gradient). All failure paths close the connection (no leaks).
    try:
        connection = self._open_db(question.database_name)
    except (FileNotFoundError, ValueError) as exc:
        raise HarnessError("db_missing", str(exc), question.question_id) from exc
    try:
        gold_rows = self._execute_gold_sql(connection, question.gold_sql)
    except Exception as exc:
        connection.close()
        raise HarnessError(
            "gold_sql_error", str(exc), question.question_id
        ) from exc
    gold_answer = self._format_gold_answer(gold_rows)
    if is_degenerate_gold(gold_rows, gold_answer):
        connection.close()
        raise HarnessError(
            "gold_empty",
            "gold SQL returned an empty or degenerate result set",
            question.question_id,
        )
    # Copy the record so the episode carries the freshly computed gold answer
    # without mutating the loaded question.
    question_for_episode = QuestionRecord(
        question_id=question.question_id,
        question_text=question.question_text,
        database_name=question.database_name,
        gold_sql=question.gold_sql,
        gold_answer=gold_answer,
        answer_type=question.answer_type,
        difficulty=question.difficulty,
        tables_involved=list(question.tables_involved),
    )
    resolved_episode_id = episode_id or str(uuid.uuid4())
    self._episode = EpisodeContext(
        episode_id=resolved_episode_id,
        db_connection=connection,
        question_record=question_for_episode,
        step_count=0,
        budget=self.step_budget,
        done=False,
        gold_answer=gold_answer,
        gold_rows=gold_rows,
    )
    self._state.episode_id = resolved_episode_id
    self._state.step_count = 0
    self._state.current_action_type = "QUERY"
    self._state.history_messages = []
    self._last_result = ""
    self._last_error = ""
    self._last_reward = None
    self._last_query_truncated = False
    return self._build_observation()
def begin_episode(
    self,
    db_id: str,
    question: str,
    *,
    episode_id: str | None = None,
    gold: QuestionRecord | None = None,
) -> SQLObservation:
    """Start a NON-GOLD episode for a user question against ``db_id``.

    The seam the demo loop (``agent_loop.run_agent_turn``) uses INSTEAD of
    ``reset()`` (which is gold-coupled: random gold question + gold SQL +
    ``verify_answer`` scoring). Opens the DB read-only via the existing
    ``_open_db``, constructs an ``EpisodeContext`` with ``gold_answer=None``
    and ``gold_rows=[]`` (no scoring target), and returns the initial
    observation via the existing ``_build_observation()``. A non-gold ANSWER
    terminates without scoring (see ``_handle_answer``).

    Args:
        db_id: database to open (resolved by ``_open_db``).
        question: the user's plain-English question (no gold answer exists).
        episode_id: optional explicit id (a uuid is generated otherwise).
        gold: reserved for a future gold-seeded variant; ``None`` (the default)
            is the non-gold path. When supplied it carries the scoring target.

    Returns:
        the initial ``SQLObservation`` for the started episode.

    Raises:
        FileNotFoundError / ValueError: db_id missing or invalid (from
            ``_open_db``) — a setup error, surfaced to the caller.
    """
    if self._episode is not None:
        self._episode.db_connection.close()
        # Forget the finished episode immediately. If _open_db raises below, a
        # later step() must report "no active episode" instead of running
        # against the connection that was just closed.
        self._episode = None
    connection = self._open_db(db_id)
    question_for_episode = QuestionRecord(
        question_id=episode_id or "user-question",
        question_text=question,
        database_name=db_id.strip(),
        gold_sql=gold.gold_sql if gold is not None else "",
        gold_answer=gold.gold_answer if gold is not None else "",
        answer_type=gold.answer_type if gold is not None else "string",
        difficulty=gold.difficulty if gold is not None else "medium",
        tables_involved=list(gold.tables_involved) if gold is not None else [],
    )
    resolved_episode_id = episode_id or str(uuid.uuid4())
    self._episode = EpisodeContext(
        episode_id=resolved_episode_id,
        db_connection=connection,
        question_record=question_for_episode,
        step_count=0,
        budget=self.step_budget,
        done=False,
        # None marks the episode as non-gold so _handle_answer skips scoring.
        gold_answer=gold.gold_answer if gold is not None else None,
        gold_rows=[],
    )
    self._state.episode_id = resolved_episode_id
    self._state.step_count = 0
    self._state.current_action_type = "QUERY"
    self._state.history_messages = []
    self._last_result = ""
    self._last_error = ""
    self._last_reward = None
    self._last_query_truncated = False
    return self._build_observation()
| def _get_table_names(self, connection: sqlite3.Connection) -> list[str]: | |
| """Return user-visible table names for the active SQLite database.""" | |
| cursor = connection.cursor() | |
| cursor.execute( | |
| """ | |
| SELECT name | |
| FROM sqlite_master | |
| WHERE type = 'table' AND name NOT LIKE 'sqlite_%' | |
| ORDER BY name | |
| """ | |
| ) | |
| return [str(row[0]) for row in cursor.fetchall()] | |
| def _resolve_table_name(self, table_name: str) -> tuple[str | None, list[str]]: | |
| """Resolve requested table name against active DB tables.""" | |
| if self._episode is None: | |
| return None, [] | |
| available_tables = self._get_table_names(self._episode.db_connection) | |
| lookup = {table.lower(): table for table in available_tables} | |
| resolved = lookup.get(table_name.strip().lower()) | |
| return resolved, available_tables | |
| def _format_rows(self, rows: list[tuple]) -> str: | |
| """Format SQL rows as readable text.""" | |
| if not rows: | |
| return "No rows returned." | |
| lines = [ | |
| f"{idx}. {' | '.join(str(value) for value in row)}" | |
| for idx, row in enumerate(rows, start=1) | |
| ] | |
| return "\n".join(lines) | |
def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
    """Run one agent-supplied statement against the episode database.

    Safeguards applied, in order:

    * the first word must be SELECT or WITH;
    * after removing trailing semicolons no ``;`` may remain (single
      statement only);
    * a progress handler interrupts the query after ``timeout_s`` seconds;
    * at most 20 rows are returned; ``self._last_query_truncated`` records
      whether more were available.

    Raises:
        RuntimeError: no episode is active.
        ValueError: blank SQL, a non-SELECT/WITH statement, or several
            statements.
        sqlite3.OperationalError: SQLite failed; an interrupted query is
            re-raised with a "timed out" message.
    """
    episode = self._episode
    if episode is None:
        raise RuntimeError("No active episode. Call reset() before step().")

    statement = sql.strip()
    if not statement:
        raise ValueError("SQL query cannot be empty")

    keyword_hit = _FIRST_KEYWORD_PATTERN.match(statement)
    leading_keyword = "" if keyword_hit is None else keyword_hit.group(1).upper()
    if leading_keyword != "SELECT" and leading_keyword != "WITH":
        raise ValueError(f"Only SELECT queries are allowed. Got: {leading_keyword}")

    # Trailing semicolons are tolerated; any other one means a second statement.
    if ";" in statement.rstrip(";").strip():
        raise ValueError("Only a single SELECT statement is allowed")

    cutoff = time.monotonic() + timeout_s

    def _abort_when_overdue() -> int:
        # A non-zero return tells SQLite to interrupt the running statement.
        if time.monotonic() > cutoff:
            return 1
        return 0

    connection = episode.db_connection
    connection.set_progress_handler(_abort_when_overdue, 1000)
    self._last_query_truncated = False
    try:
        runner = connection.cursor()
        runner.execute(statement)
        # Fetch one row beyond the cap purely to detect truncation.
        fetched = runner.fetchmany(21)
        if len(fetched) <= 20:
            return fetched
        self._last_query_truncated = True
        return fetched[:20]
    except sqlite3.OperationalError as exc:
        if "interrupted" not in str(exc).lower():
            raise
        raise sqlite3.OperationalError(
            f"Query timed out after {timeout_s:.1f} seconds"
        ) from exc
    finally:
        connection.set_progress_handler(None, 0)
def _handle_describe(self, table_name: str) -> str:
    """Describe a table: its columns with declared types, plus its row count.

    The table is looked up case-insensitively. On success its resolved name is
    added to the episode's ``described_tables`` set.

    Raises:
        RuntimeError: no episode is active.
        ValueError: the argument is blank, the table is unknown, or the table
            exposes no columns.
    """
    episode = self._episode
    if episode is None:
        raise RuntimeError("No active episode. Call reset() before step().")

    wanted = table_name.strip()
    if not wanted:
        raise ValueError("Argument cannot be empty for DESCRIBE")

    actual, known = self._resolve_table_name(wanted)
    if actual is None:
        listing = "none" if not known else ", ".join(known)
        raise ValueError(f"Table '{wanted}' not found. Available tables: {listing}")

    safe_name = quote_ident(actual)
    reader = episode.db_connection.cursor()
    reader.execute(f"PRAGMA table_info({safe_name})")
    column_rows = reader.fetchall()
    if not column_rows:
        raise ValueError(f"Table '{actual}' has no visible columns")

    reader.execute(f"SELECT COUNT(*) FROM {safe_name}")
    total_rows = int(reader.fetchone()[0])
    episode.described_tables.add(actual)

    report = [f"Table '{actual}' columns:"]
    # Each PRAGMA row is unpacked into six fields; only name and type are used.
    for _cid, name, declared, _notnull, _default, _pk in column_rows:
        shown_type = str(declared).strip() or "UNKNOWN"
        report.append(f"- {name}: {shown_type}")
    report.append(f"Row count: {total_rows}")
    return "\n".join(report)
def _handle_sample(self, table_name: str, limit: int = 5) -> str:
    """Show a few rows of a table, resolved case-insensitively.

    ``limit`` is clamped to the range 1..20 before it is placed in the query,
    and the query runs through ``_execute_sql``.

    Raises:
        RuntimeError: no episode is active.
        ValueError: the argument is blank or the table is unknown.
    """
    if self._episode is None:
        raise RuntimeError("No active episode. Call reset() before step().")

    wanted = table_name.strip()
    if not wanted:
        raise ValueError("Argument cannot be empty for SAMPLE")

    actual, known = self._resolve_table_name(wanted)
    if actual is None:
        listing = "none" if not known else ", ".join(known)
        raise ValueError(f"Table '{wanted}' not found. Available tables: {listing}")

    row_cap = max(1, min(limit, 20))
    sample_sql = f"SELECT * FROM {quote_ident(actual)} LIMIT {row_cap}"
    sampled = self._execute_sql(sample_sql)
    return f"Sample from '{actual}':\n{self._format_rows(sampled)}"
| def _handle_query(self, sql: str) -> tuple[str, list[tuple]]: | |
| """Execute query and return formatted output with raw result rows.""" | |
| sql_text = sql.strip() | |
| if not sql_text: | |
| raise ValueError("Argument cannot be empty for QUERY") | |
| rows = self._execute_sql(sql_text, timeout_s=5.0) | |
| output = self._format_rows(rows) | |
| if self._last_query_truncated: | |
| output = f"{output}\n... (truncated to 20 rows)" | |
| return output, rows | |
| def _handle_answer(self, value: str) -> tuple[bool, float]: | |
| """Compare submitted answer against episode gold answer. | |
| Non-gold episodes (``gold_answer is None``, started via ``begin_episode`` | |
| for a user question that has no scoring target) skip ``verify_answer`` | |
| entirely: the episode terminates with no score. The gold path | |
| (``gold_answer is not None``) is byte-identical to before. | |
| """ | |
| if self._episode is None: | |
| raise RuntimeError("No active episode. Call reset() before step().") | |
| if self._episode.gold_answer is None: | |
| # Non-gold (user) question — no gold target, never score it. | |
| self._episode.done = True | |
| return False, 0.0 | |
| # F005/C1: strip any `````chart {…}````` block so SCORING sees the clean | |
| # prose answer (a real model emitting prose + a chart block would otherwise | |
| # fail gold comparison — ``verify_answer`` only unwraps a fence when the | |
| # WHOLE string is one fenced block). GATED on the `````chart```` marker | |
| # (same guard as ``verifier.verify_answer``): ``strip_chart_block``'s | |
| # orphan-fence scrub eats a bare closing ``````` so calling it on a legit | |
| # non-chart fenced answer (e.g. a `````sql … ````` block) would unbalance | |
| # the fence. ``verify_answer`` re-strips self-sufficiently downstream, so | |
| # this is defense-in-depth: a no-op for block-free answers and only fires | |
| # for actual chart blocks. | |
| if "```chart" in value.lower(): | |
| value = strip_chart_block(value) | |
| is_correct = verify_answer( | |
| predicted=value, | |
| gold=self._episode.gold_answer or "", | |
| answer_type=self._episode.question_record.answer_type, | |
| gold_rows=self._episode.gold_rows, | |
| ) | |
| self._episode.done = True | |
| return is_correct, 1.0 if is_correct else 0.0 | |
def step(
    self,
    action: SQLAction,
    *,
    timeout_s: float = 30,
    **kwargs,
) -> SQLObservation:
    """Dispatch one structured action and return updated observation.

    Behaviour by situation:

    * No active episode: an observation carrying a "No active episode" error
      is returned and nothing else changes.
    * Episode already done: the current observation is returned unchanged.
    * Unknown action type or blank argument: the step is consumed (step count
      up, budget down) and the error is logged; no handler runs.
    * DESCRIBE / SAMPLE / QUERY: the handler runs; ``ValueError`` and
      ``sqlite3.Error`` are captured into the observation's error text. The
      step then costs one budget unit. A step reward is computed only while
      budget remains; when the budget reaches 0 the episode ends and a missing
      reward becomes 0.0.
    * ANSWER: the answer is scored, the episode ends, and the observation is
      returned right away without touching the budget.

    Args:
        action: the action to run; ``action_type`` is normalised to upper case
            and ``argument`` is stripped before dispatch.
        timeout_s: accepted and ignored.
        **kwargs: accepted and ignored.

    Returns:
        The observation reflecting the state after this step.
    """
    # Both are accepted for interface compatibility only and never used.
    del timeout_s
    del kwargs
    if self._episode is None:
        self._last_result = ""
        self._last_error = "No active episode. Call reset() before step()."
        self._last_reward = None
        return self._build_observation()
    # A finished episode ignores further actions.
    if self._episode.done:
        return self._build_observation()
    action_type = str(action.action_type).strip().upper()
    argument = str(action.argument)
    self._state.current_action_type = action_type or "QUERY"
    # Clear the per-step outputs before running the new action.
    self._last_result = ""
    self._last_error = ""
    self._last_reward = None
    # Inputs for compute_step_reward; only a QUERY action overwrites them.
    reward_rows: list[tuple] | None = []
    reward_sql = ""

    def _consume_invalid_step(error_text: str) -> SQLObservation:
        """Charge one step for a malformed action and report ``error_text``."""
        self._last_error = error_text
        self._episode.step_count += 1
        self._episode.budget = max(0, self._episode.budget - 1)
        self._episode.action_log.append(f"{action_type} -> ERROR: {error_text}")
        # The reward is set (to 0.0) only when this step exhausts the budget.
        if self._episode.budget == 0:
            self._episode.done = True
            self._last_reward = 0.0
        self._state.step_count = self._episode.step_count
        return self._build_observation()

    valid_action_types = {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}
    if action_type not in valid_action_types:
        return _consume_invalid_step(
            f"Unknown action type '{action.action_type}'. "
            "Valid types: DESCRIBE, SAMPLE, QUERY, ANSWER"
        )
    argument_stripped = argument.strip()
    if not argument_stripped:
        return _consume_invalid_step(f"Argument cannot be empty for {action_type}")
    try:
        if action_type == "DESCRIBE":
            self._last_result = self._handle_describe(argument_stripped)
        elif action_type == "SAMPLE":
            self._last_result = self._handle_sample(argument_stripped)
        elif action_type == "QUERY":
            reward_sql = argument_stripped
            self._last_result, reward_rows = self._handle_query(argument_stripped)
        else:
            # ANSWER always terminates the episode (_handle_answer sets
            # done=True), so we return early without decrementing budget.
            # NOTE(review): for a non-gold episode _handle_answer returns
            # (False, 0.0), so the unscored answer is reported as "incorrect"
            # here — confirm this wording is intended for user questions.
            is_correct, reward = self._handle_answer(argument_stripped)
            verdict = "correct" if is_correct else "incorrect"
            self._last_result = f"Answer submitted: {verdict}."
            self._last_reward = reward
            self._episode.step_count += 1
            self._episode.action_log.append(
                f"ANSWER {argument_stripped} -> {verdict}"
            )
            self._state.step_count = self._episode.step_count
            return self._build_observation()
    except ValueError as exc:
        # Handler-level validation problems become the observation's error.
        self._last_error = str(exc)
    except sqlite3.Error as exc:
        self._last_error = f"SQL error: {exc}"
    # Reached for DESCRIBE / SAMPLE / QUERY, whether they succeeded or failed.
    self._episode.step_count += 1
    self._episode.budget = max(0, self._episode.budget - 1)
    self._state.step_count = self._episode.step_count
    # The step reward is skipped on the step that exhausts the budget.
    if self._episode.budget > 0:
        self._last_reward = compute_step_reward(
            ctx=self._episode,
            action_type=action_type,
            sql=reward_sql,
            rows=reward_rows,
            error=self._last_error or None,
        )
    if self._last_error:
        self._episode.action_log.append(
            f"{action_type} -> ERROR: {self._last_error}"
        )
    else:
        # Log only the first line of the result to keep the history short.
        preview = self._last_result.splitlines()[0] if self._last_result else "ok"
        self._episode.action_log.append(f"{action_type} -> {preview}")
    if self._episode.budget == 0:
        self._episode.done = True
        if self._last_reward is None:
            self._last_reward = 0.0
    return self._build_observation()
def _build_observation(self) -> SQLObservation:
    """Assemble the observation for the current moment.

    Without an active episode the observation carries only the last
    result/error/reward, with empty question and schema, zero counters and
    ``done=False``. With an episode, ``schema_info`` lists all tables and, for
    every table described so far, a one-line column summary read from the
    database.
    """
    episode = self._episode
    if episode is None:
        return SQLObservation(
            question="",
            schema_info="",
            result=self._last_result,
            error=self._last_error,
            step_count=0,
            budget_remaining=0,
            action_history=[],
            done=False,
            reward=self._last_reward,
        )

    table_names = self._get_table_names(episode.db_connection)
    schema_lines = ["Available tables:"]
    schema_lines.extend(f"- {name}" for name in table_names)

    if episode.described_tables:
        live_tables = set(table_names)
        schema_lines += ["", "Described tables:"]
        for described in sorted(episode.described_tables):
            # A described table may no longer exist in the live schema.
            if described not in live_tables:
                schema_lines.append(
                    f"- {described}: unavailable (not in active schema)"
                )
                continue
            reader = episode.db_connection.cursor()
            reader.execute(f"PRAGMA table_info({quote_ident(described)})")
            column_rows = reader.fetchall()
            if column_rows:
                summary = ", ".join(
                    f"{str(column[1])} {str(column[2]) or 'UNKNOWN'}"
                    for column in column_rows
                )
                schema_lines.append(f"- {described}: {summary}")
            else:
                schema_lines.append(f"- {described}: no columns available")

    return SQLObservation(
        question=episode.question_record.question_text,
        schema_info="\n".join(schema_lines),
        result=self._last_result,
        error=self._last_error,
        step_count=episode.step_count,
        budget_remaining=episode.budget,
        action_history=list(episode.action_log),
        done=episode.done,
        reward=self._last_reward,
    )
def state(self) -> SQLState:
    """Return the environment's ``SQLState`` object.

    The stored instance itself is returned, not a copy. It is the object that
    ``reset()``, ``begin_episode()``, ``step()`` and ``message_to_action()``
    update in place.
    """
    return self._state
def message_to_action(self, message: dict[str, str]) -> SQLAction:
    """Convert a free-form chat message into a structured ``SQLAction``.

    For a ``user`` message whose first word is DESCRIBE, SAMPLE, QUERY or
    ANSWER (any letter case), that word becomes the action type and the text
    after the first whitespace character becomes the argument. Any other
    message becomes a QUERY whose argument is the full, unstripped content.

    The message is appended to ``history_messages`` and the action type is
    recorded on the exposed state.

    Args:
        message: mapping with ``role`` and ``content`` keys.

    Returns:
        The parsed ``SQLAction``.

    Raises:
        ValueError: ``role`` or ``content`` is missing, or ``content`` is
            ``None``.
    """
    if "role" not in message:
        raise ValueError("Message must contain a 'role' key")
    if "content" not in message:
        raise ValueError("Message must contain a 'content' key")
    if message["content"] is None:
        raise ValueError("Message content cannot be None")
    content = str(message["content"])
    parsed = content.strip()
    action_type = "QUERY"
    argument = content
    # str() keeps a non-string role from raising AttributeError on .lower().
    if str(message["role"]).lower() == "user" and parsed:
        # Split at the FIRST whitespace character of any kind, not only a
        # space, so "QUERY\nSELECT ..." and "DESCRIBE\tusers" are recognised.
        # For a space separator the result matches the old partition(" ").
        split_at = next(
            (index for index, char in enumerate(parsed) if char.isspace()),
            len(parsed),
        )
        prefix = parsed[:split_at]
        normalized_prefix = prefix.upper()
        if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
            action_type = normalized_prefix
            # Empty when the message is the bare keyword with nothing after it.
            argument = parsed[split_at + 1:]
    self._state.current_action_type = action_type
    self._state.history_messages.append(message)
    return SQLAction(action_type=action_type, argument=argument)