import json
import logging
from pathlib import Path
import random
import re
import sqlite3
import time
import uuid
from typing import Protocol

from .reward import compute_step_reward
from .verifier import verify_answer

# ``chart_intent`` is gradio-free + dependency-light (stdlib + pydantic only), so
# importing it here does NOT pull gradio/torch/trl/transformers into the env. It
# is the SINGLE strip site for the fenced ``chart`` block (display AND scoring):
# ``_handle_answer`` strips the block from the model's ANSWER value before
# ``verify_answer`` so a prose+block answer still matches its gold value.
try:
    from .chart_intent import strip_chart_block
    from .sql_ident import is_valid_identifier, quote_ident
except ImportError:  # pragma: no cover - flat-layout / direct-run fallback
    from chart_intent import strip_chart_block  # type: ignore[no-redef]
    from sql_ident import is_valid_identifier, quote_ident  # type: ignore[no-redef]


class ModelTokenizer(Protocol):
    """Minimal tokenizer contract the environment relies on.

    Replaces OpenEnv's ModelTokenizer interface. Any object exposing
    ``apply_chat_template`` (HuggingFace tokenizers, MockTokenizer, the
    training adapter's stub) satisfies it.
    """

    def apply_chat_template(self, messages: list[dict[str, str]], **kwargs) -> str: ...


try:
    from sql_env.models import (
        EpisodeContext,
        QuestionRecord,
        SQLAction,
        SQLObservation,
        SQLState,
    )
except ImportError:
    # Fallback for Docker where PYTHONPATH=/app/env
    from models import (  # type: ignore[no-redef]
        EpisodeContext,
        QuestionRecord,
        SQLAction,
        SQLObservation,
        SQLState,
    )

logger = logging.getLogger(__name__)

_TABLE_FROM_JOIN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE
)
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")

# Action verbs understood by ``SQLEnvironment.step`` / ``message_to_action``.
_VALID_ACTION_TYPES = frozenset({"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"})


def resolve_db_path(db_dir: str | Path, db_id: str) -> Path | None:
    """Resolve the existing ``.sqlite`` file for ``db_id`` under ``db_dir``.

    The SINGLE source of the db-id -> file resolution AND its path-traversal
    defense, shared by ``SQLEnvironment._open_db`` (episode setup) and
    ``agent_loop._resolve_db_path`` (read-only re-exec).

    Tries the two layout candidates — ``{db_dir}/{db_id}/{db_id}.sqlite`` then
    ``{db_dir}/{db_id}.sqlite`` — and returns the first that BOTH passes the
    containment guard (``.resolve()`` + ``db_root in candidate.parents``) AND
    is an existing regular file, so a ``db_id`` like ``../escape`` can never
    resolve a file outside ``db_dir``. Returns ``None`` when no contained
    candidate exists.

    Identifier-charset validation (``^[A-Za-z0-9_]+$``) is the CALLER's
    responsibility — ``_open_db`` enforces it and raises its own error; the
    containment guard here is defense-in-depth that stands on its own.

    Args:
        db_dir: directory that holds the databases.
        db_id: database identifier (also the file stem).

    Returns:
        The resolved path of the database file, or ``None``.
    """
    root = Path(db_dir)
    db_root = root.resolve()
    candidates = [
        (root / db_id / f"{db_id}.sqlite").resolve(),
        (root / f"{db_id}.sqlite").resolve(),
    ]
    for candidate in candidates:
        # Containment first: never stat a path that escaped ``db_dir``.
        # ``is_file`` (not ``exists``) so a directory named ``x.sqlite`` is
        # not handed to ``sqlite3.connect``.
        if db_root in candidate.parents and candidate.is_file():
            return candidate
    return None


class HarnessError(RuntimeError):
    """A broken *episode setup* — never a model failure.

    Raised by ``reset()`` when the environment itself cannot produce a valid
    episode: the gold SQL errors or times out, the database file is missing,
    or the gold answer is empty/degenerate. Such episodes must be excluded
    from training (they would poison the gradient), not scored as model
    mistakes. The ``reason`` is a stable key (``db_missing`` |
    ``gold_sql_error`` | ``gold_empty``) so failures can be counted and
    triaged.
    """

    def __init__(self, reason: str, detail: str = "", question_id: str | None = None):
        self.reason = reason
        self.detail = detail
        self.question_id = question_id
        suffix = f" (question {question_id})" if question_id else ""
        super().__init__(f"[{reason}] {detail}{suffix}")


def is_degenerate_gold(rows: list[tuple], gold_answer: str) -> bool:
    """True if a gold result is unusable as a training target.

    Empty result sets and single all-NULL rows (e.g. MAX()/MIN() over an empty
    set) give an empty/ambiguous gold answer a model can fluke-match
    (answering "" or "None"), so they are treated as harness/data failures
    rather than legitimate episodes. The all-NULL check is applied regardless
    of column count so single- and multi-column degenerate rows are handled
    consistently.
    """
    if not rows:
        return True
    if len(rows) == 1 and all(value is None for value in rows[0]):
        return True
    return not gold_answer.strip()


class SQLEnvironment:
    """SQLEnv implementation with a structured SQL action loop.

    Runs in-process (TRL training calls reset()/step() directly). Formerly an
    OpenEnv ``Environment`` subclass; the base class only stored an optional
    transform, which this environment never used, so it is now standalone.
    """

    def __init__(
        self,
        questions_path: str,
        db_dir: str,
        tokenizer: ModelTokenizer,
        step_budget: int = 15,
    ):
        """Validate inputs, load the questions file and initialise idle state.

        Raises:
            ValueError: tokenizer lacks ``apply_chat_template``, the budget is
                not positive, or the questions file is empty/invalid.
            FileNotFoundError: questions file or database directory missing.
        """
        if not hasattr(tokenizer, "apply_chat_template"):
            raise ValueError("Tokenizer must have 'apply_chat_template' method")
        if step_budget <= 0:
            raise ValueError("step_budget must be a positive integer")

        questions_file = Path(questions_path)
        database_dir = Path(db_dir)
        if not questions_file.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_file}")
        if not database_dir.exists() or not database_dir.is_dir():
            raise FileNotFoundError(f"Database directory not found: {database_dir}")

        self.tokenizer = tokenizer
        self.questions_path = questions_file
        self.db_dir = database_dir
        self.step_budget = step_budget
        self.questions = self._load_questions(str(questions_file))
        if not self.questions:
            raise ValueError("Questions file contains no questions")

        self._episode: EpisodeContext | None = None
        self._last_result = ""
        self._last_error = ""
        self._last_reward: float | None = None
        self._last_query_truncated = False
        self._state = SQLState()

    def _extract_tables_from_sql(self, sql: str) -> list[str]:
        """Extract table names from basic FROM/JOIN clauses.

        Names are returned in first-seen order without duplicates.
        """
        return list(dict.fromkeys(_TABLE_FROM_JOIN_PATTERN.findall(sql)))

    def _load_questions(self, path: str) -> list[QuestionRecord]:
        """Load Spider questions JSON into QuestionRecord instances.

        Raises:
            FileNotFoundError: the file does not exist.
            ValueError: the JSON is malformed, is not an array of objects, or
                a record lacks a non-empty question / db id / gold query, or
                carries an invalid db id.
        """
        questions_path = Path(path)
        if not questions_path.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_path}")

        try:
            with questions_path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Invalid questions JSON format: {questions_path}"
            ) from exc

        if not isinstance(payload, list):
            raise ValueError("Questions JSON must be an array of records")

        question_records: list[QuestionRecord] = []
        for idx, item in enumerate(payload):
            if not isinstance(item, dict):
                raise ValueError(f"Question at index {idx} must be an object")

            # Support both raw Spider format and curated format
            question_text = item.get("question_text") or item.get("question")
            db_name = item.get("database_name") or item.get("db_id")
            gold_sql = item.get("gold_sql") or item.get("query")

            if not isinstance(question_text, str) or not question_text.strip():
                raise ValueError(
                    f"Question at index {idx} missing non-empty 'question'"
                )
            if not isinstance(db_name, str) or not db_name.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'db_id'")
            if not isinstance(gold_sql, str) or not gold_sql.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'query'")

            normalized_db_name = db_name.strip()
            if not is_valid_identifier(normalized_db_name):
                raise ValueError(
                    f"Question at index {idx} has invalid db_id '{normalized_db_name}'"
                )

            gold_answer = item.get("gold_answer", "")
            if not isinstance(gold_answer, str):
                gold_answer = str(gold_answer)

            question_records.append(
                QuestionRecord(
                    question_id=item.get("question_id", f"q-{idx}"),
                    question_text=question_text,
                    database_name=normalized_db_name,
                    gold_sql=gold_sql,
                    gold_answer=gold_answer,
                    answer_type=item.get("answer_type", "string"),
                    difficulty=item.get("difficulty", "medium"),
                    tables_involved=item.get("tables_involved")
                    or self._extract_tables_from_sql(gold_sql),
                )
            )

        return question_records

    def _open_db(self, db_name: str) -> sqlite3.Connection:
        """Open a read-only SQLite connection for the requested database.

        Raises:
            ValueError: ``db_name`` is not a valid identifier.
            FileNotFoundError: no contained database file exists for it.
        """
        normalized_db_name = db_name.strip()
        if not is_valid_identifier(normalized_db_name):
            raise ValueError(f"Invalid database name: '{db_name}'")

        db_path = resolve_db_path(self.db_dir, normalized_db_name)
        if db_path is None:
            raise FileNotFoundError(
                f"Database '{normalized_db_name}' not found in {self.db_dir}"
            )

        # ``mode=ro`` makes SQLite itself refuse writes on this connection.
        uri = f"file:{db_path}?mode=ro"
        return sqlite3.connect(uri, uri=True)

    def _format_gold_answer(self, rows: list[tuple]) -> str:
        """Convert SQL rows into a stable string answer for episode comparison."""
        if not rows:
            return ""
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])
        return "\n".join(" | ".join(str(value) for value in row) for row in rows)

    @staticmethod
    def _require_select_sql(sql: str) -> str:
        """Return ``sql`` stripped, or raise if it is empty or not SELECT/WITH."""
        sql_stripped = sql.strip()
        if not sql_stripped:
            raise ValueError("SQL query cannot be empty")

        first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
        first_keyword = (
            first_keyword_match.group(1).upper() if first_keyword_match else ""
        )
        if first_keyword not in ("SELECT", "WITH"):
            raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")
        return sql_stripped

    def _execute_gold_sql(
        self,
        connection: sqlite3.Connection,
        sql: str,
        timeout_s: float = 5.0,
    ) -> list[tuple]:
        """Execute gold SQL with read-only/SELECT-only timeout protections.

        Returns all result rows.

        Raises:
            ValueError: the SQL is empty or does not start with SELECT/WITH.
            sqlite3.OperationalError: the query failed or exceeded
                ``timeout_s`` (reported as "Query timed out ...").
        """
        sql_stripped = self._require_select_sql(sql)

        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return makes SQLite interrupt the running statement.
            return 1 if time.monotonic() > deadline else 0

        connection.set_progress_handler(_progress_callback, 1000)
        try:
            cursor = connection.cursor()
            cursor.execute(sql_stripped)
            return cursor.fetchall()
        except sqlite3.OperationalError as exc:
            if "interrupted" in str(exc).lower():
                raise sqlite3.OperationalError(
                    f"Query timed out after {timeout_s:.1f} seconds"
                ) from exc
            raise
        finally:
            connection.set_progress_handler(None, 0)

    def _close_active_episode(self) -> None:
        """Close the current episode's connection and forget the episode.

        The episode reference is dropped BEFORE closing so that a failure in
        the subsequent setup never leaves ``self._episode`` pointing at a
        closed connection (which would make ``step()`` crash while building
        its observation).
        """
        if self._episode is None:
            return
        episode, self._episode = self._episode, None
        episode.db_connection.close()

    def _reset_tracking(self, episode_id: str) -> None:
        """Reset exposed state and last-step bookkeeping for a new episode."""
        self._state.episode_id = episode_id
        self._state.step_count = 0
        self._state.current_action_type = "QUERY"
        self._state.history_messages = []

        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        self._last_query_truncated = False

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        question_index: int | None = None,
        **kwargs,
    ) -> SQLObservation:
        """Reset episode context and return the initial rich observation.

        ``question_index`` (optional) selects a SPECIFIC question
        deterministically instead of the random ``seed`` draw — the full-set
        eval protocol iterates every question exactly once with it,
        eliminating the sampling-with-replacement noise that made N=50 evals
        unreliable. The index wraps modulo the number of questions.

        Raises:
            HarnessError: the database is missing/unopenable (``db_missing``),
                the gold SQL fails (``gold_sql_error``) or yields an empty or
                degenerate result (``gold_empty``). No episode is active after
                a failed reset.
        """
        del kwargs

        # Any previous episode ends here, whether or not the new setup works.
        self._close_active_episode()

        if question_index is not None:
            question = self.questions[question_index % len(self.questions)]
        else:
            chooser = random.Random(seed) if seed is not None else random
            question = chooser.choice(self.questions)

        # --- Harness guardrail: fail fast on a broken episode setup ---------
        # A broken gold answer is a HARNESS failure, not a model failure. We
        # raise HarnessError so the caller can count + exclude it instead of
        # silently training on an empty/degenerate gold (which poisons the
        # gradient). All failure paths close the connection (no leaks).
        try:
            connection = self._open_db(question.database_name)
        except (FileNotFoundError, ValueError, sqlite3.Error) as exc:
            # ``sqlite3.Error`` covers a file that exists but cannot be opened;
            # it is a setup failure too, so it shares the ``db_missing`` key.
            raise HarnessError("db_missing", str(exc), question.question_id) from exc

        try:
            gold_rows = self._execute_gold_sql(connection, question.gold_sql)
        except Exception as exc:
            connection.close()
            raise HarnessError(
                "gold_sql_error", str(exc), question.question_id
            ) from exc

        gold_answer = self._format_gold_answer(gold_rows)
        if is_degenerate_gold(gold_rows, gold_answer):
            connection.close()
            raise HarnessError(
                "gold_empty",
                "gold SQL returned an empty or degenerate result set",
                question.question_id,
            )

        # Copy the record so the freshly computed gold answer never mutates
        # the loaded question list.
        question_for_episode = QuestionRecord(
            question_id=question.question_id,
            question_text=question.question_text,
            database_name=question.database_name,
            gold_sql=question.gold_sql,
            gold_answer=gold_answer,
            answer_type=question.answer_type,
            difficulty=question.difficulty,
            tables_involved=list(question.tables_involved),
        )

        resolved_episode_id = episode_id or str(uuid.uuid4())
        self._episode = EpisodeContext(
            episode_id=resolved_episode_id,
            db_connection=connection,
            question_record=question_for_episode,
            step_count=0,
            budget=self.step_budget,
            done=False,
            gold_answer=gold_answer,
            gold_rows=gold_rows,
        )
        self._reset_tracking(resolved_episode_id)

        return self._build_observation()

    def begin_episode(
        self,
        db_id: str,
        question: str,
        *,
        episode_id: str | None = None,
        gold: QuestionRecord | None = None,
    ) -> SQLObservation:
        """Start a NON-GOLD episode for a user question against ``db_id``.

        The seam the demo loop (``agent_loop.run_agent_turn``) uses INSTEAD of
        ``reset()`` (which is gold-coupled: random gold question + gold SQL +
        ``verify_answer`` scoring). Opens the DB read-only via the existing
        ``_open_db``, constructs an ``EpisodeContext`` with
        ``gold_answer=None`` and ``gold_rows=[]`` (no scoring target), and
        returns the initial observation via the existing
        ``_build_observation()``. A non-gold ANSWER terminates without scoring
        (see ``_handle_answer``).

        Args:
            db_id: database to open (resolved by ``_open_db``).
            question: the user's plain-English question (no gold answer
                exists).
            episode_id: optional explicit id (a uuid is generated otherwise).
            gold: reserved for a future gold-seeded variant; ``None`` (the
                default) is the non-gold path. When supplied it carries the
                scoring target.

        Returns:
            the initial ``SQLObservation`` for the started episode.

        Raises:
            FileNotFoundError / ValueError: db_id missing or invalid (from
                ``_open_db``) — a setup error, surfaced to the caller. No
                episode is active afterwards.
        """
        # Drop the previous episode first so a failing ``_open_db`` cannot
        # leave a stale episode bound to a closed connection.
        self._close_active_episode()

        connection = self._open_db(db_id)

        has_gold = gold is not None
        question_for_episode = QuestionRecord(
            question_id=episode_id or "user-question",
            question_text=question,
            database_name=db_id.strip(),
            gold_sql=gold.gold_sql if has_gold else "",
            gold_answer=gold.gold_answer if has_gold else "",
            answer_type=gold.answer_type if has_gold else "string",
            difficulty=gold.difficulty if has_gold else "medium",
            tables_involved=list(gold.tables_involved) if has_gold else [],
        )

        resolved_episode_id = episode_id or str(uuid.uuid4())
        self._episode = EpisodeContext(
            episode_id=resolved_episode_id,
            db_connection=connection,
            question_record=question_for_episode,
            step_count=0,
            budget=self.step_budget,
            done=False,
            gold_answer=gold.gold_answer if has_gold else None,
            gold_rows=[],
        )
        self._reset_tracking(resolved_episode_id)

        return self._build_observation()

    def _get_table_names(self, connection: sqlite3.Connection) -> list[str]:
        """Return user-visible table names for the active SQLite database."""
        cursor = connection.cursor()
        cursor.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
            ORDER BY name
            """
        )
        return [str(row[0]) for row in cursor.fetchall()]

    def _resolve_table_name(self, table_name: str) -> tuple[str | None, list[str]]:
        """Resolve requested table name against active DB tables.

        Matching is case-insensitive. Returns ``(resolved_name_or_None,
        available_tables)``; both are empty/None without an active episode.
        """
        if self._episode is None:
            return None, []

        available_tables = self._get_table_names(self._episode.db_connection)
        lookup = {table.lower(): table for table in available_tables}
        resolved = lookup.get(table_name.strip().lower())
        return resolved, available_tables

    def _format_rows(self, rows: list[tuple]) -> str:
        """Format SQL rows as readable, 1-numbered text lines."""
        if not rows:
            return "No rows returned."

        lines = [
            f"{idx}. {' | '.join(str(value) for value in row)}"
            for idx, row in enumerate(rows, start=1)
        ]
        return "\n".join(lines)

    def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
        """Execute SQL in sandbox: SELECT-only, single statement, timeout, truncation.

        At most 20 rows are returned; ``self._last_query_truncated`` records
        whether more were available.

        Raises:
            RuntimeError: no episode is active.
            ValueError: empty SQL, not SELECT/WITH, or more than one statement.
            sqlite3.OperationalError: the query failed or timed out.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        sql_stripped = self._require_select_sql(sql)

        single_statement_sql = sql_stripped.rstrip(";").strip()
        if ";" in single_statement_sql:
            raise ValueError("Only a single SELECT statement is allowed")

        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return makes SQLite interrupt the running statement.
            return 1 if time.monotonic() > deadline else 0

        connection = self._episode.db_connection
        connection.set_progress_handler(_progress_callback, 1000)

        self._last_query_truncated = False
        try:
            cursor = connection.cursor()
            # Run exactly the text that was validated: repeated trailing
            # semicolons ("SELECT 1;;") pass the check above but would be
            # rejected by sqlite3 as a second statement if left in place.
            cursor.execute(single_statement_sql)
            # Fetch one row past the cap purely to detect truncation.
            rows = cursor.fetchmany(21)
            if len(rows) > 20:
                self._last_query_truncated = True
                rows = rows[:20]
            return rows
        except sqlite3.OperationalError as exc:
            if "interrupted" in str(exc).lower():
                raise sqlite3.OperationalError(
                    f"Query timed out after {timeout_s:.1f} seconds"
                ) from exc
            raise
        finally:
            connection.set_progress_handler(None, 0)

    def _handle_describe(self, table_name: str) -> str:
        """Return table schema and row count, and mark the table as described.

        Raises:
            RuntimeError: no episode is active.
            ValueError: empty argument, unknown table, or no visible columns.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for DESCRIBE")

        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            available = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {available}"
            )

        # The name comes from sqlite_master and is quoted before interpolation.
        quoted_table = quote_ident(resolved_table)
        cursor = self._episode.db_connection.cursor()
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        columns = cursor.fetchall()
        if not columns:
            raise ValueError(f"Table '{resolved_table}' has no visible columns")

        cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
        row_count = int(cursor.fetchone()[0])

        self._episode.described_tables.add(resolved_table)

        lines = [f"Table '{resolved_table}' columns:"]
        for _, col_name, col_type, _, _, _ in columns:
            normalized_type = str(col_type).strip() or "UNKNOWN"
            lines.append(f"- {col_name}: {normalized_type}")
        lines.append(f"Row count: {row_count}")
        return "\n".join(lines)

    def _handle_sample(self, table_name: str, limit: int = 5) -> str:
        """Return sample rows from a table (``limit`` is clamped to 1..20).

        Raises:
            RuntimeError: no episode is active.
            ValueError: empty argument or unknown table.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for SAMPLE")

        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            available = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {available}"
            )

        quoted_table = quote_ident(resolved_table)
        bounded_limit = max(1, min(limit, 20))
        rows = self._execute_sql(f"SELECT * FROM {quoted_table} LIMIT {bounded_limit}")
        return f"Sample from '{resolved_table}':\n{self._format_rows(rows)}"

    def _handle_query(self, sql: str) -> tuple[str, list[tuple]]:
        """Execute query and return formatted output with raw result rows."""
        sql_text = sql.strip()
        if not sql_text:
            raise ValueError("Argument cannot be empty for QUERY")

        rows = self._execute_sql(sql_text, timeout_s=5.0)
        output = self._format_rows(rows)
        if self._last_query_truncated:
            output = f"{output}\n... (truncated to 20 rows)"
        return output, rows

    def _handle_answer(self, value: str) -> tuple[bool, float]:
        """Compare submitted answer against episode gold answer.

        Non-gold episodes (``gold_answer is None``, started via
        ``begin_episode`` for a user question that has no scoring target) skip
        ``verify_answer`` entirely: the episode terminates with no score. The
        gold path (``gold_answer is not None``) is scored via
        ``verify_answer``.

        Returns:
            ``(is_correct, reward)`` where reward is 1.0 for a correct answer
            and 0.0 otherwise. The episode is marked done in every case.

        Raises:
            RuntimeError: no episode is active.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        if self._episode.gold_answer is None:
            # Non-gold (user) question — no gold target, never score it.
            self._episode.done = True
            return False, 0.0

        # F005/C1: strip any fenced ``chart`` block so SCORING sees the clean
        # prose answer (a real model emitting prose + a chart block would
        # otherwise fail gold comparison — ``verify_answer`` only unwraps a
        # fence when the WHOLE string is one fenced block). GATED on the
        # chart-fence marker (same guard as ``verifier.verify_answer``):
        # ``strip_chart_block``'s orphan-fence scrub eats a bare closing fence,
        # so calling it on a legit non-chart fenced answer (e.g. an ``sql``
        # fenced block) would unbalance the fence. ``verify_answer`` re-strips
        # self-sufficiently downstream, so this is defense-in-depth: a no-op
        # for block-free answers and only fires for actual chart blocks.
        if "```chart" in value.lower():
            value = strip_chart_block(value)

        is_correct = verify_answer(
            predicted=value,
            gold=self._episode.gold_answer or "",
            answer_type=self._episode.question_record.answer_type,
            gold_rows=self._episode.gold_rows,
        )
        self._episode.done = True
        return is_correct, 1.0 if is_correct else 0.0

    def step(
        self,
        action: SQLAction,
        *,
        timeout_s: float = 30,
        **kwargs,
    ) -> SQLObservation:
        """Dispatch one structured action and return updated observation.

        Without an active episode an error observation is returned; a finished
        episode returns its current observation unchanged. DESCRIBE / SAMPLE /
        QUERY (and invalid actions) consume one unit of budget; ANSWER ends
        the episode without touching the budget. ``timeout_s`` and extra
        keyword arguments are accepted for interface compatibility and
        ignored.
        """
        del timeout_s
        del kwargs

        if self._episode is None:
            self._last_result = ""
            self._last_error = "No active episode. Call reset() before step()."
            self._last_reward = None
            return self._build_observation()

        if self._episode.done:
            return self._build_observation()

        action_type = str(action.action_type).strip().upper()
        argument = str(action.argument)

        self._state.current_action_type = action_type or "QUERY"
        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        reward_rows: list[tuple] | None = []
        reward_sql = ""

        def _consume_invalid_step(error_text: str) -> SQLObservation:
            # Malformed actions still cost a step so they cannot be spammed
            # for free; they never earn a shaped reward.
            self._last_error = error_text
            self._episode.step_count += 1
            self._episode.budget = max(0, self._episode.budget - 1)
            self._episode.action_log.append(f"{action_type} -> ERROR: {error_text}")
            if self._episode.budget == 0:
                self._episode.done = True
                self._last_reward = 0.0
            self._state.step_count = self._episode.step_count
            return self._build_observation()

        if action_type not in _VALID_ACTION_TYPES:
            return _consume_invalid_step(
                f"Unknown action type '{action.action_type}'. "
                "Valid types: DESCRIBE, SAMPLE, QUERY, ANSWER"
            )

        argument_stripped = argument.strip()
        if not argument_stripped:
            return _consume_invalid_step(f"Argument cannot be empty for {action_type}")

        try:
            if action_type == "DESCRIBE":
                self._last_result = self._handle_describe(argument_stripped)
            elif action_type == "SAMPLE":
                self._last_result = self._handle_sample(argument_stripped)
            elif action_type == "QUERY":
                reward_sql = argument_stripped
                self._last_result, reward_rows = self._handle_query(argument_stripped)
            else:
                # ANSWER always terminates the episode (_handle_answer sets
                # done=True), so we return early without decrementing budget.
                is_correct, reward = self._handle_answer(argument_stripped)
                verdict = "correct" if is_correct else "incorrect"
                self._last_result = f"Answer submitted: {verdict}."
                self._last_reward = reward
                self._episode.step_count += 1
                self._episode.action_log.append(
                    f"ANSWER {argument_stripped} -> {verdict}"
                )
                self._state.step_count = self._episode.step_count
                return self._build_observation()
        except ValueError as exc:
            self._last_error = str(exc)
        except sqlite3.Error as exc:
            self._last_error = f"SQL error: {exc}"

        self._episode.step_count += 1
        self._episode.budget = max(0, self._episode.budget - 1)
        self._state.step_count = self._episode.step_count

        # Shaped step reward only while budget remains; the exhausting step
        # falls through to the terminal 0.0 below.
        if self._episode.budget > 0:
            self._last_reward = compute_step_reward(
                ctx=self._episode,
                action_type=action_type,
                sql=reward_sql,
                rows=reward_rows,
                error=self._last_error or None,
            )

        if self._last_error:
            self._episode.action_log.append(
                f"{action_type} -> ERROR: {self._last_error}"
            )
        else:
            preview = self._last_result.splitlines()[0] if self._last_result else "ok"
            self._episode.action_log.append(f"{action_type} -> {preview}")

        if self._episode.budget == 0:
            self._episode.done = True
            if self._last_reward is None:
                self._last_reward = 0.0

        return self._build_observation()

    def _build_observation(self) -> SQLObservation:
        """Construct a rich observation from the current episode context.

        Without an episode the observation carries only the last
        result/error/reward. Otherwise it lists the database's tables plus a
        column summary for every table described so far.
        """
        if self._episode is None:
            return SQLObservation(
                question="",
                schema_info="",
                result=self._last_result,
                error=self._last_error,
                step_count=0,
                budget_remaining=0,
                action_history=[],
                done=False,
                reward=self._last_reward,
            )

        table_names = self._get_table_names(self._episode.db_connection)
        known_tables = set(table_names)
        schema_lines = ["Available tables:", *[f"- {name}" for name in table_names]]

        if self._episode.described_tables:
            schema_lines.append("")
            schema_lines.append("Described tables:")
            for table_name in sorted(self._episode.described_tables):
                if table_name not in known_tables:
                    schema_lines.append(
                        f"- {table_name}: unavailable (not in active schema)"
                    )
                    continue
                cursor = self._episode.db_connection.cursor()
                cursor.execute(f"PRAGMA table_info({quote_ident(table_name)})")
                columns = cursor.fetchall()
                if not columns:
                    schema_lines.append(f"- {table_name}: no columns available")
                    continue
                column_summary = ", ".join(
                    f"{str(column[1])} {str(column[2]) or 'UNKNOWN'}"
                    for column in columns
                )
                schema_lines.append(f"- {table_name}: {column_summary}")

        return SQLObservation(
            question=self._episode.question_record.question_text,
            schema_info="\n".join(schema_lines),
            result=self._last_result,
            error=self._last_error,
            step_count=self._episode.step_count,
            budget_remaining=self._episode.budget,
            action_history=list(self._episode.action_log),
            done=self._episode.done,
            reward=self._last_reward,
        )

    @property
    def state(self) -> SQLState:
        """Get current exposed state metadata."""
        return self._state

    def message_to_action(self, message: dict[str, str]) -> SQLAction:
        """Convert free-form messages into structured SQLAction values.

        A ``user`` message whose first space-delimited word is a known action
        verb becomes that action with the remainder as its argument; anything
        else is treated as a QUERY carrying the full content. The message is
        appended to the exposed history.

        Raises:
            ValueError: the message lacks ``role``/``content`` or its content
                is ``None``.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        content = str(message["content"])
        parsed = content.strip()

        action_type = "QUERY"
        argument = content

        if message["role"].lower() == "user" and parsed:
            prefix, separator, remainder = parsed.partition(" ")
            normalized_prefix = prefix.upper()
            if normalized_prefix in _VALID_ACTION_TYPES:
                action_type = normalized_prefix
                argument = remainder if separator else ""

        self._state.current_action_type = action_type
        self._state.history_messages.append(message)

        return SQLAction(action_type=action_type, argument=argument)