File size: 2,695 Bytes
9e64e71 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | """Deterministic oracle policy for upper-bound evaluation baselines."""
from __future__ import annotations
try:
from ..models import QuestionRecord, SQLAction, SQLObservation
except ImportError:
try:
from models import QuestionRecord, SQLAction, SQLObservation # type: ignore[no-redef]
except ImportError:
from sql_env.models import QuestionRecord, SQLAction, SQLObservation # type: ignore[no-redef]
class OraclePolicy:
    """Play deterministic optimal actions using question gold data.

    The policy looks up the observation's question text among the supplied
    ``QuestionRecord`` objects and then, within one episode:

    1. DESCRIBEs each distinct table listed in ``tables_involved``,
    2. issues a single QUERY containing the gold SQL,
    3. ANSWERs with the gold answer.

    Whenever ``budget_remaining`` is 1 or less it answers immediately with the
    gold answer. For a question it has no record of, it answers with an empty
    string.
    """

    def __init__(self, questions: list[QuestionRecord]) -> None:
        """Index *questions* by their ``question_text``.

        If several records share the same text, the last one in *questions*
        wins.
        """
        self._question_lookup: dict[str, QuestionRecord] = {
            question.question_text: question for question in questions
        }
        # Per-episode state; (re)initialised by ``_start_episode``.
        self._current_question: QuestionRecord | None = None
        self._tables_to_describe: list[str] = []
        self._gold_sql_sent = False

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Select the next deterministic oracle action for *observation*.

        Returns a DESCRIBE for the next pending table, then one QUERY with the
        gold SQL, then an ANSWER with the gold answer. When the budget is
        nearly exhausted, it returns the ANSWER immediately instead.
        """
        if self._needs_episode_reset(observation):
            self._start_episode(observation.question)

        if self._current_question is None:
            # Unknown question: no gold data is available to exploit.
            return SQLAction(action_type="ANSWER", argument="")

        answer_value = self._gold_answer()
        # Spend the last available step on the answer rather than exploration.
        if observation.budget_remaining <= 1:
            return SQLAction(action_type="ANSWER", argument=answer_value)

        if self._tables_to_describe:
            table_name = self._tables_to_describe.pop(0)
            return SQLAction(action_type="DESCRIBE", argument=table_name)

        if not self._gold_sql_sent:
            self._gold_sql_sent = True
            return SQLAction(action_type="QUERY", argument=self._gold_sql())

        return SQLAction(action_type="ANSWER", argument=answer_value)

    def _needs_episode_reset(self, observation: SQLObservation) -> bool:
        """Return True when *observation* should start a fresh episode.

        That is the case when no question is currently loaded, when the step
        counter is back at zero, or when the question text has changed.
        """
        if self._current_question is None:
            return True
        if observation.step_count == 0:
            return True
        return observation.question != self._current_question.question_text

    def _start_episode(self, question_text: str) -> None:
        """Load the gold record for *question_text* and reset episode state.

        Repeated table names in ``tables_involved`` are collapsed, keeping
        first-occurrence order, so each table is described only once.
        Describing the same table twice would spend an extra step and budget,
        and presumably yields no new schema information.
        """
        self._current_question = self._question_lookup.get(question_text)
        self._gold_sql_sent = False
        if self._current_question is None:
            self._tables_to_describe = []
        else:
            # dict.fromkeys drops repeats while preserving insertion order.
            self._tables_to_describe = list(
                dict.fromkeys(self._current_question.tables_involved)
            )

    def _gold_sql(self) -> str:
        """Return the current question's gold SQL, or "" if none is loaded."""
        if self._current_question is None:
            return ""
        return self._current_question.gold_sql

    def _gold_answer(self) -> str:
        """Return the current question's gold answer, or "" if none is loaded."""
        if self._current_question is None:
            return ""
        return self._current_question.gold_answer
|