File size: 2,695 Bytes
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Deterministic oracle policy for upper-bound evaluation baselines."""

from __future__ import annotations

try:
    from ..models import QuestionRecord, SQLAction, SQLObservation
except ImportError:
    try:
        from models import QuestionRecord, SQLAction, SQLObservation  # type: ignore[no-redef]
    except ImportError:
        from sql_env.models import QuestionRecord, SQLAction, SQLObservation  # type: ignore[no-redef]


class OraclePolicy:
    """Scripted policy that replays the gold data stored on each question.

    For a known question the policy emits, in order: one ``DESCRIBE`` action
    per entry in the record's ``tables_involved``, a single ``QUERY`` action
    carrying the record's ``gold_sql``, and then ``ANSWER`` actions carrying
    the record's ``gold_answer``.  Questions are matched by their exact
    ``question_text``.
    """

    def __init__(self, questions: list[QuestionRecord]) -> None:
        """Index ``questions`` by text and initialise empty episode state.

        Args:
            questions: Records exposing ``question_text``. If two records
                share the same text, the later one replaces the earlier one
                in the index.
        """
        lookup: dict[str, QuestionRecord] = {}
        for record in questions:
            lookup[record.question_text] = record
        self._question_lookup: dict[str, QuestionRecord] = lookup

        # Per-episode state; repopulated by ``_start_episode``.
        self._current_question: QuestionRecord | None = None
        self._tables_to_describe: list[str] = []
        self._gold_sql_sent = False

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return the next scripted action for ``observation``.

        Args:
            observation: Object providing ``question``, ``step_count`` and
                ``budget_remaining``.

        Returns:
            An ``ANSWER`` action with an empty argument when the question is
            not in the index; otherwise the next ``DESCRIBE``/``QUERY`` step
            of the script, or an ``ANSWER`` with the gold answer once the
            script is exhausted or ``budget_remaining`` is at most 1.
        """
        if self._needs_episode_reset(observation):
            self._start_episode(observation.question)

        if self._current_question is None:
            # Question not in the index: there is no gold data to replay.
            return SQLAction(action_type="ANSWER", argument="")

        final_answer = self._gold_answer()
        out_of_budget = observation.budget_remaining <= 1

        if not out_of_budget:
            if self._tables_to_describe:
                next_table = self._tables_to_describe.pop(0)
                return SQLAction(action_type="DESCRIBE", argument=next_table)

            if not self._gold_sql_sent:
                # Flag first so the gold query is issued only once per episode.
                self._gold_sql_sent = True
                return SQLAction(action_type="QUERY", argument=self._gold_sql())

        # Either the script is finished or the budget guard fired.
        return SQLAction(action_type="ANSWER", argument=final_answer)

    def _needs_episode_reset(self, observation: SQLObservation) -> bool:
        """Report whether ``observation`` should trigger a fresh episode.

        True when no question is active, when ``observation.step_count`` is
        0, or when the observed question text differs from the active one.
        """
        if self._current_question is None or observation.step_count == 0:
            return True
        return observation.question != self._current_question.question_text

    def _start_episode(self, question_text: str) -> None:
        """Reset episode state and load the record for ``question_text``.

        When the text is not in the index the active question becomes
        ``None`` and the describe queue stays empty.
        """
        record = self._question_lookup.get(question_text)
        self._current_question = record
        self._gold_sql_sent = False
        self._tables_to_describe = []
        if record is None:
            return
        # Copy so popping from the queue never mutates the record itself.
        self._tables_to_describe = list(record.tables_involved)

    def _gold_sql(self) -> str:
        """Return the active record's ``gold_sql``, or ``""`` if none is active."""
        record = self._current_question
        return "" if record is None else record.gold_sql

    def _gold_answer(self) -> str:
        """Return the active record's ``gold_answer``, or ``""`` if none is active."""
        record = self._current_question
        return "" if record is None else record.gold_answer