File size: 34,091 Bytes
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
import json
import logging
from pathlib import Path
import random
import re
import sqlite3
import time
import uuid
from typing import Protocol

from .reward import compute_step_reward
from .verifier import verify_answer

# ``chart_intent`` is gradio-free + dependency-light (stdlib + pydantic only), so
# importing it here does NOT pull gradio/torch/trl/transformers into the env. It
# is the SINGLE strip site for the ``​```chart {…}```​`` block (display AND
# scoring): ``_handle_answer`` strips the block from the model's ANSWER value
# before ``verify_answer`` so a prose+block answer still matches its gold value.
try:
    from .chart_intent import strip_chart_block
    from .sql_ident import is_valid_identifier, quote_ident
except ImportError:  # pragma: no cover - flat-layout / direct-run fallback
    from chart_intent import strip_chart_block  # type: ignore[no-redef]
    from sql_ident import is_valid_identifier, quote_ident  # type: ignore[no-redef]


class ModelTokenizer(Protocol):
    """Structural type for the single tokenizer method this environment needs.

    Stands in for OpenEnv's ``ModelTokenizer`` interface. Anything that
    exposes ``apply_chat_template`` qualifies: HuggingFace tokenizers,
    ``MockTokenizer``, or the training adapter's stub.
    """

    def apply_chat_template(
        self, messages: list[dict[str, str]], **kwargs
    ) -> str:
        """Protocol stub; implementations return the templated ``str``."""
        ...


try:
    from sql_env.models import (
        EpisodeContext,
        QuestionRecord,
        SQLAction,
        SQLObservation,
        SQLState,
    )
except ImportError:
    # Fallback for Docker where PYTHONPATH=/app/env
    from models import (  # type: ignore[no-redef]
        EpisodeContext,
        QuestionRecord,
        SQLAction,
        SQLObservation,
        SQLState,
    )

logger = logging.getLogger(__name__)

# Heuristic table extractor: captures the bare identifier that follows a FROM or
# JOIN keyword (case-insensitive). Quoted identifiers ("t", `t`, [t]) and
# subqueries are not matched, and for ``schema.table`` only the part before the
# dot is captured.
_TABLE_FROM_JOIN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE
)
# Leading word of a SQL string (after any whitespace); the SELECT/WITH-only
# guards in _execute_gold_sql and _execute_sql read its first group. SQL that
# starts with a comment or a parenthesis yields no match and is rejected.
# (``\s`` already covers \n, \r and \t; the explicit escapes are redundant.)
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")


def resolve_db_path(db_dir: str | Path, db_id: str) -> Path | None:
    """Resolve the existing ``.sqlite`` file for ``db_id`` under ``db_dir``.

    The SINGLE source of the db-id -> file resolution AND its path-traversal
    defense, shared by ``SQLEnvironment._open_db`` (episode setup) and
    ``agent_loop._resolve_db_path`` (read-only re-exec). Tries the two layout
    candidates β€” ``<root>/<id>/<id>.sqlite`` then ``<root>/<id>.sqlite`` β€” and
    returns the first that BOTH exists AND passes the containment guard
    (``.resolve()`` + ``db_root in candidate.parents``), so a ``db_id`` like
    ``../escape`` can never resolve a file outside ``db_dir``. Returns ``None``
    when no contained candidate exists.

    Identifier-charset validation (``^[A-Za-z0-9_]+$``) is the CALLER's
    responsibility β€” ``_open_db`` enforces it and raises its own error; the
    containment guard here is defense-in-depth that stands on its own.
    """
    root = Path(db_dir)
    db_root = root.resolve()
    candidates = [
        (root / db_id / f"{db_id}.sqlite").resolve(),
        (root / f"{db_id}.sqlite").resolve(),
    ]
    for candidate in candidates:
        if candidate.exists() and db_root in candidate.parents:
            return candidate
    return None


class HarnessError(RuntimeError):
    """A broken *episode setup* β€” never a model failure.

    Raised by ``reset()`` when the environment itself cannot produce a valid
    episode: the gold SQL errors or times out, the database file is missing, or
    the gold answer is empty/degenerate. Such episodes must be excluded from
    training (they would poison the gradient), not scored as model mistakes.
    The ``reason`` is a stable key (``db_missing`` | ``gold_sql_error`` |
    ``gold_empty``) so failures can be counted and triaged.
    """

    def __init__(self, reason: str, detail: str = "", question_id: str | None = None):
        self.reason = reason
        self.detail = detail
        self.question_id = question_id
        suffix = f" (question {question_id})" if question_id else ""
        super().__init__(f"[{reason}] {detail}{suffix}")


def is_degenerate_gold(rows: list[tuple], gold_answer: str) -> bool:
    """Return True when a gold result cannot serve as a training target.

    Two result shapes are degenerate:

    * an empty result set;
    * exactly one row whose every value is NULL (e.g. MAX()/MIN() over an
      empty set), whatever its column count.

    Both give an empty or ambiguous gold answer that a model can fluke-match
    by answering "" or "None". They are therefore treated as harness/data
    failures. Otherwise the result is degenerate only when the rendered
    ``gold_answer`` is blank.
    """
    if not rows:
        return True
    single_null_row = len(rows) == 1 and all(cell is None for cell in rows[0])
    return single_null_row or not gold_answer.strip()


class SQLEnvironment:
    """SQLEnv implementation with a structured SQL action loop.

    Runs in-process (TRL training calls reset()/step() directly). Formerly an
    OpenEnv ``Environment`` subclass; the base class only stored an optional
    transform, which this environment never used, so it is now standalone.
    """

    def __init__(
        self,
        questions_path: str,
        db_dir: str,
        tokenizer: ModelTokenizer,
        step_budget: int = 15,
    ):
        if not hasattr(tokenizer, "apply_chat_template"):
            raise ValueError("Tokenizer must have 'apply_chat_template' method")
        if step_budget <= 0:
            raise ValueError("step_budget must be a positive integer")

        questions_file = Path(questions_path)
        database_dir = Path(db_dir)
        if not questions_file.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_file}")
        if not database_dir.exists() or not database_dir.is_dir():
            raise FileNotFoundError(f"Database directory not found: {database_dir}")

        self.tokenizer = tokenizer
        self.questions_path = questions_file
        self.db_dir = database_dir
        self.step_budget = step_budget
        self.questions = self._load_questions(str(questions_file))

        if not self.questions:
            raise ValueError("Questions file contains no questions")

        self._episode: EpisodeContext | None = None
        self._last_result = ""
        self._last_error = ""
        self._last_reward: float | None = None
        self._last_query_truncated = False

        self._state = SQLState()

    def _extract_tables_from_sql(self, sql: str) -> list[str]:
        """List identifiers that follow FROM/JOIN, first-seen order, no duplicates.

        This is a regex heuristic: quoted identifiers and subqueries are not
        recognised, and duplicates are detected by exact (case-sensitive)
        string match.
        """
        # dict keeps insertion order, so fromkeys() is a stable de-duplication.
        return list(dict.fromkeys(_TABLE_FROM_JOIN_PATTERN.findall(sql)))

    def _load_questions(self, path: str) -> list[QuestionRecord]:
        """Load a questions JSON array into ``QuestionRecord`` instances.

        Each record may use the curated keys (``question_text``,
        ``database_name``, ``gold_sql``) or the raw Spider keys (``question``,
        ``db_id``, ``query``). The curated key wins when it is present and
        non-empty.

        Optional keys fall back to a default when missing OR JSON ``null``:

        * ``question_id`` -> ``"q-<index>"``
        * ``gold_answer`` -> ``""``
        * ``answer_type`` -> ``"string"``
        * ``difficulty`` -> ``"medium"``

        An empty or missing ``tables_involved`` falls back to a FROM/JOIN scan
        of the gold SQL.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            ValueError: invalid JSON, a non-array payload, a non-object record,
                a missing required field, or a ``db_id`` rejected by
                ``is_valid_identifier``.
        """
        questions_path = Path(path)
        if not questions_path.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_path}")

        try:
            with questions_path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Invalid questions JSON format: {questions_path}"
            ) from exc

        if not isinstance(payload, list):
            raise ValueError("Questions JSON must be an array of records")

        def _field_or_default(item: dict, key: str, default):
            # ``item.get(key, default)`` only covers a MISSING key. An explicit
            # JSON ``null`` must fall back too; otherwise a null gold_answer was
            # stringified to "None" and a null answer_type/difficulty leaked
            # through as None.
            value = item.get(key)
            return default if value is None else value

        question_records: list[QuestionRecord] = []
        for idx, item in enumerate(payload):
            if not isinstance(item, dict):
                raise ValueError(f"Question at index {idx} must be an object")

            # Support both raw Spider format and curated format
            question_text = item.get("question_text") or item.get("question")
            db_name = item.get("database_name") or item.get("db_id")
            gold_sql = item.get("gold_sql") or item.get("query")

            if not isinstance(question_text, str) or not question_text.strip():
                raise ValueError(
                    f"Question at index {idx} missing non-empty 'question'"
                )
            if not isinstance(db_name, str) or not db_name.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'db_id'")
            if not isinstance(gold_sql, str) or not gold_sql.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'query'")

            normalized_db_name = db_name.strip()
            if not is_valid_identifier(normalized_db_name):
                raise ValueError(
                    f"Question at index {idx} has invalid db_id '{normalized_db_name}'"
                )

            gold_answer = _field_or_default(item, "gold_answer", "")
            if not isinstance(gold_answer, str):
                gold_answer = str(gold_answer)

            question_records.append(
                QuestionRecord(
                    question_id=_field_or_default(item, "question_id", f"q-{idx}"),
                    question_text=question_text,
                    database_name=normalized_db_name,
                    gold_sql=gold_sql,
                    gold_answer=gold_answer,
                    answer_type=_field_or_default(item, "answer_type", "string"),
                    difficulty=_field_or_default(item, "difficulty", "medium"),
                    tables_involved=item.get("tables_involved")
                    or self._extract_tables_from_sql(gold_sql),
                )
            )

        return question_records

    def _open_db(self, db_name: str) -> sqlite3.Connection:
        """Open a read-only SQLite connection for ``db_name`` under ``self.db_dir``.

        Raises:
            ValueError: ``db_name`` (after stripping) fails ``is_valid_identifier``.
            FileNotFoundError: ``resolve_db_path`` finds no contained ``.sqlite``
                file for it.
        """
        normalized_db_name = db_name.strip()
        if not is_valid_identifier(normalized_db_name):
            raise ValueError(f"Invalid database name: '{db_name}'")

        db_path = resolve_db_path(self.db_dir, normalized_db_name)
        if db_path is None:
            raise FileNotFoundError(
                f"Database '{normalized_db_name}' not found in {self.db_dir}"
            )

        # mode=ro makes SQLite refuse writes on this handle. Build the URI
        # with Path.as_uri() (db_path is absolute, having been resolve()d).
        # It percent-encodes characters that are special in a SQLite URI
        # ("?", "#", "%", spaces). With a raw "file:{path}" f-string, a
        # directory containing them would be cut short or parsed as extra
        # URI parameters.
        uri = f"{db_path.as_uri()}?mode=ro"
        return sqlite3.connect(uri, uri=True)

    def _format_gold_answer(self, rows: list[tuple]) -> str:
        """Convert SQL rows into a stable string answer for episode comparison."""
        if not rows:
            return ""
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])
        return "\n".join(" | ".join(str(value) for value in row) for row in rows)

    def _execute_gold_sql(
        self,
        connection: sqlite3.Connection,
        sql: str,
        timeout_s: float = 5.0,
    ) -> list[tuple]:
        """Run gold SQL on ``connection`` and return every result row.

        Only statements starting with SELECT or WITH are accepted. Execution
        is aborted once ``timeout_s`` seconds have elapsed.

        Raises:
            ValueError: empty SQL or a non-SELECT/WITH statement.
            sqlite3.OperationalError: SQLite failures. A timeout is reported
                as "Query timed out after N seconds".
        """
        statement = sql.strip()
        if not statement:
            raise ValueError("SQL query cannot be empty")

        keyword_match = _FIRST_KEYWORD_PATTERN.match(statement)
        leading_keyword = keyword_match.group(1).upper() if keyword_match else ""
        if leading_keyword not in ("SELECT", "WITH"):
            raise ValueError(f"Only SELECT queries are allowed. Got: {leading_keyword}")

        deadline = time.monotonic() + timeout_s
        # The handler is polled every 1000 SQLite VM instructions; a non-zero
        # return aborts the statement, which surfaces as "interrupted".
        connection.set_progress_handler(
            lambda: 1 if time.monotonic() > deadline else 0, 1000
        )
        try:
            return connection.cursor().execute(statement).fetchall()
        except sqlite3.OperationalError as exc:
            if "interrupted" not in str(exc).lower():
                raise
            raise sqlite3.OperationalError(
                f"Query timed out after {timeout_s:.1f} seconds"
            ) from exc
        finally:
            # Always detach, so later statements on this connection are unbounded again.
            connection.set_progress_handler(None, 0)

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        question_index: int | None = None,
        **kwargs,
    ) -> SQLObservation:
        """Start a new gold-scored episode and return its initial observation.

        Question selection:

        * ``question_index`` (taken modulo the number of questions) picks a
          question deterministically. The full-set eval protocol uses it to
          visit every question exactly once, avoiding the
          sampling-with-replacement noise that made N=50 evals unreliable.
        * Otherwise a question is drawn with ``random.Random(seed)`` when
          ``seed`` is given, or with the module-level ``random`` generator.

        The gold SQL is executed up front to compute the episode's gold answer
        and rows. Extra keyword arguments are accepted and ignored.

        Raises:
            HarnessError: one of three reasons, always with the previous
                episode already closed and cleared:
                ``db_missing`` (database missing or invalid name),
                ``gold_sql_error`` (gold SQL failed or timed out), or
                ``gold_empty`` (degenerate gold result).
        """
        del kwargs

        # Close AND forget the previous episode up front. If any setup step
        # below raises HarnessError, self._episode must not keep pointing at
        # an episode whose connection is already closed; callers then see
        # "no active episode" instead of operating on a closed database.
        if self._episode is not None:
            self._episode.db_connection.close()
            self._episode = None

        if question_index is not None:
            question = self.questions[question_index % len(self.questions)]
        else:
            chooser = random.Random(seed) if seed is not None else random
            question = chooser.choice(self.questions)

        # --- Harness guardrail: fail fast on a broken episode setup ---------
        # A broken gold answer is a HARNESS failure, not a model failure. We
        # raise HarnessError so the caller can count + exclude it instead of
        # silently training on an empty/degenerate gold (which poisons the
        # gradient). All failure paths close the connection (no leaks).
        try:
            connection = self._open_db(question.database_name)
        except (FileNotFoundError, ValueError) as exc:
            raise HarnessError("db_missing", str(exc), question.question_id) from exc

        try:
            gold_rows = self._execute_gold_sql(connection, question.gold_sql)
        except Exception as exc:
            connection.close()
            raise HarnessError(
                "gold_sql_error", str(exc), question.question_id
            ) from exc

        gold_answer = self._format_gold_answer(gold_rows)
        if is_degenerate_gold(gold_rows, gold_answer):
            connection.close()
            raise HarnessError(
                "gold_empty",
                "gold SQL returned an empty or degenerate result set",
                question.question_id,
            )

        # Copy of the question record with gold_answer recomputed from the
        # live gold SQL result.
        question_for_episode = QuestionRecord(
            question_id=question.question_id,
            question_text=question.question_text,
            database_name=question.database_name,
            gold_sql=question.gold_sql,
            gold_answer=gold_answer,
            answer_type=question.answer_type,
            difficulty=question.difficulty,
            tables_involved=list(question.tables_involved),
        )

        resolved_episode_id = episode_id or str(uuid.uuid4())
        self._episode = EpisodeContext(
            episode_id=resolved_episode_id,
            db_connection=connection,
            question_record=question_for_episode,
            step_count=0,
            budget=self.step_budget,
            done=False,
            gold_answer=gold_answer,
            gold_rows=gold_rows,
        )

        self._state.episode_id = resolved_episode_id
        self._state.step_count = 0
        self._state.current_action_type = "QUERY"
        self._state.history_messages = []

        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        self._last_query_truncated = False

        return self._build_observation()

    def begin_episode(
        self,
        db_id: str,
        question: str,
        *,
        episode_id: str | None = None,
        gold: QuestionRecord | None = None,
    ) -> SQLObservation:
        """Start a NON-GOLD episode for a user question against ``db_id``.

        The demo loop (``agent_loop.run_agent_turn``) uses this INSTEAD of
        ``reset()``, which is gold-coupled: random gold question, gold SQL and
        ``verify_answer`` scoring.

        The database is opened read-only via ``_open_db``. The resulting
        ``EpisodeContext`` has ``gold_answer=None`` and ``gold_rows=[]``, so
        there is no scoring target. A non-gold ANSWER therefore terminates
        without scoring (see ``_handle_answer``).

        Args:
            db_id: database to open (resolved by ``_open_db``).
            question: the user's plain-English question (no gold answer exists).
            episode_id: optional explicit id (a uuid is generated otherwise).
            gold: reserved for a future gold-seeded variant. ``None`` (the
                default) is the non-gold path; when supplied, it carries the
                scoring target.

        Returns:
            the initial ``SQLObservation`` for the started episode.

        Raises:
            FileNotFoundError / ValueError: ``db_id`` is missing or invalid
                (from ``_open_db``). This is a setup error surfaced to the
                caller. The previous episode has already been closed and
                cleared at that point.
        """
        # Close AND forget the previous episode before opening the new DB: if
        # _open_db raises, self._episode must not keep referencing an episode
        # whose connection is already closed.
        if self._episode is not None:
            self._episode.db_connection.close()
            self._episode = None

        connection = self._open_db(db_id)

        question_for_episode = QuestionRecord(
            question_id=episode_id or "user-question",
            question_text=question,
            database_name=db_id.strip(),
            gold_sql=gold.gold_sql if gold is not None else "",
            gold_answer=gold.gold_answer if gold is not None else "",
            answer_type=gold.answer_type if gold is not None else "string",
            difficulty=gold.difficulty if gold is not None else "medium",
            tables_involved=list(gold.tables_involved) if gold is not None else [],
        )

        resolved_episode_id = episode_id or str(uuid.uuid4())
        self._episode = EpisodeContext(
            episode_id=resolved_episode_id,
            db_connection=connection,
            question_record=question_for_episode,
            step_count=0,
            budget=self.step_budget,
            done=False,
            # None marks the episode as unscored for _handle_answer.
            gold_answer=gold.gold_answer if gold is not None else None,
            gold_rows=[],
        )

        self._state.episode_id = resolved_episode_id
        self._state.step_count = 0
        self._state.current_action_type = "QUERY"
        self._state.history_messages = []

        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        self._last_query_truncated = False

        return self._build_observation()

    def _get_table_names(self, connection: sqlite3.Connection) -> list[str]:
        """Return user-visible table names for the active SQLite database."""
        cursor = connection.cursor()
        cursor.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
            ORDER BY name
            """
        )
        return [str(row[0]) for row in cursor.fetchall()]

    def _resolve_table_name(self, table_name: str) -> tuple[str | None, list[str]]:
        """Resolve requested table name against active DB tables."""
        if self._episode is None:
            return None, []
        available_tables = self._get_table_names(self._episode.db_connection)
        lookup = {table.lower(): table for table in available_tables}
        resolved = lookup.get(table_name.strip().lower())
        return resolved, available_tables

    def _format_rows(self, rows: list[tuple]) -> str:
        """Format SQL rows as readable text."""
        if not rows:
            return "No rows returned."
        lines = [
            f"{idx}. {' | '.join(str(value) for value in row)}"
            for idx, row in enumerate(rows, start=1)
        ]
        return "\n".join(lines)

    def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
        """Execute SQL in sandbox: SELECT-only, single statement, timeout, truncation."""
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        sql_stripped = sql.strip()
        if not sql_stripped:
            raise ValueError("SQL query cannot be empty")

        first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
        first_keyword = (
            first_keyword_match.group(1).upper() if first_keyword_match else ""
        )
        if first_keyword not in ("SELECT", "WITH"):
            raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")

        single_statement_sql = sql_stripped.rstrip(";").strip()
        if ";" in single_statement_sql:
            raise ValueError("Only a single SELECT statement is allowed")

        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            return 1 if time.monotonic() > deadline else 0

        connection = self._episode.db_connection
        connection.set_progress_handler(_progress_callback, 1000)

        self._last_query_truncated = False
        try:
            cursor = connection.cursor()
            cursor.execute(sql_stripped)
            rows = cursor.fetchmany(21)
            if len(rows) > 20:
                self._last_query_truncated = True
                rows = rows[:20]
            return rows
        except sqlite3.OperationalError as exc:
            if "interrupted" in str(exc).lower():
                raise sqlite3.OperationalError(
                    f"Query timed out after {timeout_s:.1f} seconds"
                ) from exc
            raise
        finally:
            connection.set_progress_handler(None, 0)

    def _handle_describe(self, table_name: str) -> str:
        """Describe a table: each column with its declared type, plus the row count.

        The table name is matched case-insensitively. Columns with an empty
        declared type are shown as ``UNKNOWN``. A successful describe adds
        the table to ``episode.described_tables``.

        Raises:
            RuntimeError: no active episode.
            ValueError: empty name, unknown table (the message lists the
                available tables), or a table with no visible columns.
        """
        episode = self._episode
        if episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for DESCRIBE")

        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            listing = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {listing}"
            )

        quoted = quote_ident(resolved_table)
        cursor = episode.db_connection.cursor()
        columns = cursor.execute(f"PRAGMA table_info({quoted})").fetchall()
        if not columns:
            raise ValueError(f"Table '{resolved_table}' has no visible columns")

        count_row = cursor.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()
        row_count = int(count_row[0])
        episode.described_tables.add(resolved_table)

        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk).
        column_lines = [
            f"- {col_name}: {str(col_type).strip() or 'UNKNOWN'}"
            for _, col_name, col_type, _, _, _ in columns
        ]
        return "\n".join(
            [f"Table '{resolved_table}' columns:", *column_lines, f"Row count: {row_count}"]
        )

    def _handle_sample(self, table_name: str, limit: int = 5) -> str:
        """Return up to ``limit`` rows of a table, formatted for the agent.

        The table name is matched case-insensitively and ``limit`` is clamped
        to ``1..20``. The query itself runs through the ``_execute_sql``
        sandbox.

        Raises:
            RuntimeError: no active episode.
            ValueError: empty name or unknown table (the message lists the
                available tables).
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for SAMPLE")

        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            listing = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {listing}"
            )

        quoted = quote_ident(resolved_table)
        row_cap = max(1, min(limit, 20))
        sampled = self._execute_sql(f"SELECT * FROM {quoted} LIMIT {row_cap}")
        body = self._format_rows(sampled)
        return f"Sample from '{resolved_table}':\n{body}"

    def _handle_query(self, sql: str) -> tuple[str, list[tuple]]:
        """Execute query and return formatted output with raw result rows."""
        sql_text = sql.strip()
        if not sql_text:
            raise ValueError("Argument cannot be empty for QUERY")

        rows = self._execute_sql(sql_text, timeout_s=5.0)
        output = self._format_rows(rows)
        if self._last_query_truncated:
            output = f"{output}\n... (truncated to 20 rows)"
        return output, rows

    def _handle_answer(self, value: str) -> tuple[bool, float]:
        """Compare submitted answer against episode gold answer.

        Non-gold episodes (``gold_answer is None``, started via ``begin_episode``
        for a user question that has no scoring target) skip ``verify_answer``
        entirely: the episode terminates with no score. The gold path
        (``gold_answer is not None``) is byte-identical to before.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        if self._episode.gold_answer is None:
            # Non-gold (user) question β€” no gold target, never score it.
            self._episode.done = True
            return False, 0.0

        # F005/C1: strip any ``​```chart {…}```​`` block so SCORING sees the clean
        # prose answer (a real model emitting prose + a chart block would otherwise
        # fail gold comparison β€” ``verify_answer`` only unwraps a fence when the
        # WHOLE string is one fenced block). GATED on the ``​```chart``​`` marker
        # (same guard as ``verifier.verify_answer``): ``strip_chart_block``'s
        # orphan-fence scrub eats a bare closing ``​```​`` so calling it on a legit
        # non-chart fenced answer (e.g. a ``​```sql … ```​`` block) would unbalance
        # the fence. ``verify_answer`` re-strips self-sufficiently downstream, so
        # this is defense-in-depth: a no-op for block-free answers and only fires
        # for actual chart blocks.
        if "```chart" in value.lower():
            value = strip_chart_block(value)

        is_correct = verify_answer(
            predicted=value,
            gold=self._episode.gold_answer or "",
            answer_type=self._episode.question_record.answer_type,
            gold_rows=self._episode.gold_rows,
        )
        self._episode.done = True
        return is_correct, 1.0 if is_correct else 0.0

    def step(
        self,
        action: SQLAction,
        *,
        timeout_s: float = 30,
        **kwargs,
    ) -> SQLObservation:
        """Dispatch one structured action and return the updated observation.

        Recognized action types (case-insensitive, surrounding whitespace
        ignored) are DESCRIBE, SAMPLE, QUERY and ANSWER.

        * Unknown types and empty arguments are recorded as errors and spend
          one step and one unit of budget.
        * DESCRIBE / SAMPLE / QUERY spend one step and one unit of budget. A
          ``ValueError`` or ``sqlite3.Error`` raised by the handler is reported
          in the observation's ``error`` field instead of propagating. While
          budget remains, the step reward comes from ``compute_step_reward``.
        * ANSWER is graded by ``_handle_answer``, which ends the episode; a
          successfully graded answer spends one step but no budget. If
          grading itself raises ``ValueError``/``sqlite3.Error``, the action is
          handled like a failed non-answer action (error reported, budget
          spent).

        When budget reaches zero the episode is marked done and, if no reward
        was assigned on that step, the reward is 0.0.

        Args:
            action: The action to execute.
            timeout_s: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The observation after applying the action. Without an active
            episode, an observation carrying an error message. If the episode
            is already done, the current observation with no further changes
            (the previous result/error/reward are re-emitted).
        """
        del timeout_s
        del kwargs

        if self._episode is None:
            self._last_result = ""
            self._last_error = "No active episode. Call reset() before step()."
            self._last_reward = None
            return self._build_observation()

        if self._episode.done:
            # NOTE(review): last result/error/reward are not cleared here, so a
            # caller stepping a finished episode sees the terminal reward again.
            return self._build_observation()

        action_type = str(action.action_type).strip().upper()
        argument = str(action.argument)

        self._state.current_action_type = action_type or "QUERY"
        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        # Only QUERY overwrites these; other action types pass the defaults
        # through to compute_step_reward.
        reward_rows: list[tuple] | None = []
        reward_sql = ""
        # Set only when ANSWER was graded without raising.
        answer_outcome: tuple[bool, float] | None = None

        def _consume_invalid_step(error_text: str) -> SQLObservation:
            """Record a malformed action: spend one step and one budget unit."""
            self._last_error = error_text
            self._episode.step_count += 1
            self._episode.budget = max(0, self._episode.budget - 1)
            self._episode.action_log.append(f"{action_type} -> ERROR: {error_text}")
            if self._episode.budget == 0:
                self._episode.done = True
                self._last_reward = 0.0
            self._state.step_count = self._episode.step_count
            return self._build_observation()

        valid_action_types = {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}
        if action_type not in valid_action_types:
            return _consume_invalid_step(
                f"Unknown action type '{action.action_type}'. "
                "Valid types: DESCRIBE, SAMPLE, QUERY, ANSWER"
            )

        argument_stripped = argument.strip()
        if not argument_stripped:
            return _consume_invalid_step(f"Argument cannot be empty for {action_type}")

        # Keep the try body limited to the handler calls themselves so that
        # only failures of the action are reported as action errors.
        try:
            if action_type == "DESCRIBE":
                self._last_result = self._handle_describe(argument_stripped)
            elif action_type == "SAMPLE":
                self._last_result = self._handle_sample(argument_stripped)
            elif action_type == "QUERY":
                reward_sql = argument_stripped
                self._last_result, reward_rows = self._handle_query(argument_stripped)
            else:
                answer_outcome = self._handle_answer(argument_stripped)
        except ValueError as exc:
            self._last_error = str(exc)
        except sqlite3.Error as exc:
            self._last_error = f"SQL error: {exc}"

        if answer_outcome is not None:
            # ANSWER always terminates the episode (_handle_answer sets
            # done=True), so return early without decrementing budget. This
            # bookkeeping lives outside the try: an sqlite3.Error while building
            # the observation must not be misreported as a failed action that
            # spends budget and overwrites the answer reward.
            is_correct, reward = answer_outcome
            verdict = "correct" if is_correct else "incorrect"
            self._last_result = f"Answer submitted: {verdict}."
            self._last_reward = reward
            self._episode.step_count += 1
            self._episode.action_log.append(
                f"ANSWER {argument_stripped} -> {verdict}"
            )
            self._state.step_count = self._episode.step_count
            return self._build_observation()

        self._episode.step_count += 1
        self._episode.budget = max(0, self._episode.budget - 1)
        self._state.step_count = self._episode.step_count

        # The step that exhausts the budget gets no shaped reward; it falls
        # through to the terminal 0.0 assigned below.
        if self._episode.budget > 0:
            self._last_reward = compute_step_reward(
                ctx=self._episode,
                action_type=action_type,
                sql=reward_sql,
                rows=reward_rows,
                error=self._last_error or None,
            )

        if self._last_error:
            self._episode.action_log.append(
                f"{action_type} -> ERROR: {self._last_error}"
            )
        else:
            # Log only the first line of the result to keep history compact.
            preview = self._last_result.splitlines()[0] if self._last_result else "ok"
            self._episode.action_log.append(f"{action_type} -> {preview}")

        if self._episode.budget == 0:
            self._episode.done = True
            if self._last_reward is None:
                self._last_reward = 0.0

        return self._build_observation()

    def _build_observation(self) -> SQLObservation:
        """Snapshot the environment into an ``SQLObservation``.

        With no active episode, a placeholder observation is returned: empty
        question and schema, zeroed counters, ``done=False``, and the last
        result/error/reward. Otherwise the schema text lists every table in the
        episode database, followed by a ``name TYPE`` column summary (read via
        ``PRAGMA table_info``) for each table the agent has described. Described
        tables missing from the active schema are flagged as unavailable.
        """
        episode = self._episode
        if episode is None:
            return SQLObservation(
                question="",
                schema_info="",
                result=self._last_result,
                error=self._last_error,
                step_count=0,
                budget_remaining=0,
                action_history=[],
                done=False,
                reward=self._last_reward,
            )

        connection = episode.db_connection
        table_names = self._get_table_names(connection)
        active_tables = set(table_names)
        lines = ["Available tables:"]
        lines.extend(f"- {name}" for name in table_names)

        if episode.described_tables:
            lines += ["", "Described tables:"]
            for described in sorted(episode.described_tables):
                if described not in active_tables:
                    lines.append(f"- {described}: unavailable (not in active schema)")
                    continue
                cur = connection.cursor()
                cur.execute(f"PRAGMA table_info({quote_ident(described)})")
                info_rows = cur.fetchall()
                if info_rows:
                    # PRAGMA table_info rows: index 1 is the column name, index 2
                    # the declared type ('' when untyped -> shown as UNKNOWN).
                    summary = ", ".join(
                        f"{str(row[1])} {str(row[2]) or 'UNKNOWN'}"
                        for row in info_rows
                    )
                    lines.append(f"- {described}: {summary}")
                else:
                    lines.append(f"- {described}: no columns available")

        return SQLObservation(
            question=episode.question_record.question_text,
            schema_info="\n".join(lines),
            result=self._last_result,
            error=self._last_error,
            step_count=episode.step_count,
            budget_remaining=episode.budget,
            action_history=list(episode.action_log),
            done=episode.done,
            reward=self._last_reward,
        )

    @property
    def state(self) -> SQLState:
        """Get current exposed state metadata."""
        return self._state

    def message_to_action(self, message: dict[str, str]) -> SQLAction:
        """Convert a chat message into a structured SQLAction.

        A ``user`` message whose first whitespace-delimited word is one of
        DESCRIBE, SAMPLE, QUERY or ANSWER (case-insensitive) becomes that
        action. The rest of the message, after the whitespace following the
        keyword, is its argument; the argument is empty if the keyword stands
        alone. Any other message is treated as a raw QUERY whose argument is
        the full, unstripped content.

        Side effects: records the resolved action type on
        ``self._state.current_action_type`` and appends ``message`` to
        ``self._state.history_messages``.

        Raises:
            ValueError: if ``role`` or ``content`` is missing, or content is None.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        content = str(message["content"])
        parsed = content.strip()

        action_type = "QUERY"
        argument = content

        if message["role"].lower() == "user" and parsed:
            # Split on any whitespace run (space, tab, newline), not just a
            # literal space, so commands like "QUERY\nSELECT ..." are recognized
            # instead of being sent verbatim (keyword included) as SQL.
            # ``parsed`` is non-empty and stripped, so split() yields >= 1 item.
            keyword, *rest = parsed.split(maxsplit=1)
            normalized_keyword = keyword.upper()
            if normalized_keyword in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
                action_type = normalized_keyword
                argument = rest[0] if rest else ""

        self._state.current_action_type = action_type
        self._state.history_messages.append(message)

        return SQLAction(action_type=action_type, argument=argument)