File size: 2,884 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
nl2sql-bench/models.py
======================
Typed contracts for the NL2SQL-Bench OpenEnv environment.

Action  : The SQL query the agent submits.
Observation : What the agent sees after each step.
State   : Episode-level metadata (for state() endpoint).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from openenv.core.env_server import Action, Observation, State


# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------

class NL2SQLAction(Action):
    """
    A single SQL query submitted by the agent.

    Fields
    ------
    query : SQL text of the submission; defaults to the empty string.
    """
    query: str = ""


# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------


class NL2SQLObservation(Observation):
    """
    Everything the agent needs to reason about and iterate its SQL query.

    Every field has a default, so an observation can be built with only the
    values that are known at the time.

    Fields
    ------
    question        : The natural-language question to answer.
    schema_context  : Relevant table/column descriptions as a string block.
    task_name       : Identifier of the current task (easy / medium / hard).
    last_query      : The SQL the agent submitted on the last step (empty on reset).
    last_result     : Up to 10 rows returned by the last query (list of dicts).
    last_error      : SQLite error string if the query failed, else None.
    result_columns  : Column names of last_result rows.
    step            : Current step number (1-indexed); the default is 0.
    max_steps       : Maximum steps allowed per episode (default 5).
    done            : True when the episode is over (success or step exhausted).
    reward          : Reward for the most recent action (None on reset).
    score           : Normalised cumulative score so far [0.0, 1.0].
    """
    # --- Task description ---------------------------------------------------
    question: str = ""
    schema_context: str = ""
    task_name: str = ""

    # --- Feedback on the previous submission --------------------------------
    last_query: str = ""
    # NOTE(review): `field` is imported from `dataclasses`, but no @dataclass
    # decorator is applied to this class. Confirm that the Observation base
    # class turns these into per-instance empty-list defaults.
    last_result: List[Dict[str, Any]] = field(default_factory=list)
    last_error: Optional[str] = None
    result_columns: List[str] = field(default_factory=list)

    # --- Episode progress ---------------------------------------------------
    step: int = 0
    max_steps: int = 5
    done: bool = False
    reward: Optional[float] = None
    score: float = 0.0


# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------

class NL2SQLState(State):
    """
    Episode-level state (returned by the /state endpoint).

    Fields
    ------
    episode_id        : Identifier of the current episode, or None if unset.
    step_count        : Step counter for the episode (default 0).
    task_name         : Identifier of the current task.
    task_difficulty   : Difficulty label: easy | medium | hard.
    question          : The natural-language question for this episode.
    best_reward       : Highest reward seen this episode.
    cumulative_reward : Running reward total for the episode (default 0.0).
    solved            : True if exact match was achieved.
    """
    episode_id: Optional[str] = None
    step_count: int = 0
    task_name: str = ""
    task_difficulty: str = ""        # easy | medium | hard
    question: str = ""
    best_reward: float = 0.0         # highest reward seen this episode
    cumulative_reward: float = 0.0
    solved: bool = False             # True if exact match was achieved