File size: 3,166 Bytes
a39d8ef
 
 
 
 
 
1a16689
a39d8ef
 
 
 
1a16689
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a16689
a39d8ef
 
 
 
 
 
1a16689
a39d8ef
 
1a16689
 
a39d8ef
 
 
 
 
 
 
 
 
1a16689
 
 
 
 
a39d8ef
 
 
1a16689
 
 
 
a39d8ef
1a16689
a39d8ef
 
 
1a45976
1a16689
 
 
 
 
 
 
 
1a45976
ed2e608
 
a39d8ef
 
 
 
 
 
 
 
 
 
ae4bbb1
1a45976
ed2e608
a39d8ef
 
 
 
ae4bbb1
a39d8ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import httpx
import json
import os
from typing import Any, Dict, Optional
from dataclasses import dataclass


@dataclass()
class NL2SQLAction:
    """One agent action: a SQL query string to submit to the environment."""

    # Raw SQL text; NL2SQLEnv.step sends it as {"action": {"query": ...}}.
    query: str


@dataclass()
class NL2SQLObservation:
    """Observation returned by the NL2SQL server after a reset or step.

    Instances are built from the server's JSON by ``NL2SQLEnv._parse_result``.
    Field order is significant: positional construction relies on it.
    """

    # --- task description ---
    question: str          # natural-language question text from the server
    schema_context: str    # schema text supplied by the server
    task_name: str         # name of the task currently loaded

    # --- outcome of the most recent query ---
    last_query: str              # SQL submitted on the previous step
    last_result: list            # result rows as returned in the JSON
    last_error: Optional[str]    # error text, or None when there was none
    result_columns: list         # column names for ``last_result``

    # --- episode bookkeeping ---
    step: int              # current step counter reported by the server
    max_steps: int         # step budget reported by the server
    done: bool             # True once the episode has ended
    reward: float          # reward value reported for this step
    score: float           # score value reported by the server


@dataclass()
class StepResult:
    """Result of a reset/step call: the observation plus reward and done flag.

    ``reward`` and ``done`` duplicate the corresponding observation fields, so
    callers can read them without unpacking the observation.
    """

    observation: NL2SQLObservation  # full parsed observation
    reward: float                   # same value as observation.reward
    done: bool                      # same value as observation.done


class NL2SQLEnv:
    """Async HTTP client for an NL2SQL environment server.

    Wraps the server's ``/reset`` and ``/step`` endpoints and converts their
    JSON responses into ``StepResult`` objects. Use it as an async context
    manager (or call ``aclose()`` explicitly) so the underlying
    ``httpx.AsyncClient`` is closed.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 120.0):
        """Create the client.

        Args:
            base_url: Root URL of the server; a trailing slash is stripped.
            timeout: Request timeout in seconds, passed to ``httpx.AsyncClient``.
        """
        self.base_url = base_url.rstrip("/")
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=timeout)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying HTTP client (for use outside ``async with``)."""
        await self.client.aclose()

    async def reset(self, task_name: Optional[str] = None) -> StepResult:
        """Start a new episode and return the initial observation.

        Args:
            task_name: Task to load. If omitted, the ``NL2SQL_DEFAULT_TASK``
                environment variable is used, falling back to ``"simple-filter"``.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        if task_name is None:
            task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
        # task_name is sent only as a top-level JSON body field.
        # NOTE(review): if some server versions expect it inside an action
        # wrapper instead, that variant would need to be added here.
        payload = {"task_name": task_name}
        resp = await self.client.post("/reset", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    async def step(self, action: NL2SQLAction) -> StepResult:
        """Submit one SQL query and return the resulting observation.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        # The server parses actions from a nested {"action": {...}} object;
        # per the original author's note, a top-level {"query": ...} is not
        # recognised as an action and yields 0 reward.
        payload = {"action": {"query": action.query}}
        resp = await self.client.post("/step", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    @staticmethod
    def _field(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or *default* when the key is missing or null.

        Unlike ``data.get(key) or default``, legitimate falsy values such as
        ``0`` or ``""`` are preserved; only ``None`` is replaced.
        """
        value = data.get(key)
        return default if value is None else value

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a server JSON response into a ``StepResult``.

        The observation is read from ``payload["observation"]`` when that is a
        dict; otherwise the payload itself is treated as the observation.
        ``reward`` and ``done`` are looked up on the top-level payload first,
        then on the observation. Missing or null fields get safe defaults.
        """
        obs_data = payload.get("observation")
        # Guard against a missing, null, or non-object "observation" value,
        # which previously raised AttributeError on obs_data.get(...).
        if not isinstance(obs_data, dict):
            obs_data = payload

        # The top-level reward is checked first, then the nested observation.
        raw_reward = payload.get("reward")
        if raw_reward is None:
            raw_reward = obs_data.get("reward")
        safe_reward = 0.0 if raw_reward is None else float(raw_reward)

        safe_score = float(obs_data.get("score") or 0.0)
        safe_done = bool(payload.get("done") or obs_data.get("done"))

        obs = NL2SQLObservation(
            question=self._field(obs_data, "question", ""),
            schema_context=self._field(obs_data, "schema_context", ""),
            task_name=self._field(obs_data, "task_name", ""),
            last_query=self._field(obs_data, "last_query", ""),
            # A fresh [] per call, so observations never share a default list.
            last_result=self._field(obs_data, "last_result", []),
            last_error=obs_data.get("last_error"),
            result_columns=self._field(obs_data, "result_columns", []),
            step=self._field(obs_data, "step", 0),
            max_steps=self._field(obs_data, "max_steps", 5),
            done=safe_done,
            reward=safe_reward,
            score=safe_score,
        )
        return StepResult(
            observation=obs,
            reward=safe_reward,
            done=safe_done,
        )