# Residue from the file viewer this source was copied from: "File size: 3,166 Bytes".
import httpx
import json
import os
from typing import Any, Dict, Optional
from dataclasses import dataclass
@dataclass
class NL2SQLAction:
    """Action handed to the environment client; carries a single query string."""

    # Text of the query to submit.
    query: str
@dataclass
class NL2SQLObservation:
    """Observation record holding the fields reported for one environment step.

    All fields are required and positional order is part of the interface.
    """

    # --- task description ---
    question: str
    schema_context: str
    task_name: str

    # --- outcome of the most recent query ---
    last_query: str
    last_result: list
    last_error: Optional[str]  # None when no error was reported
    result_columns: list

    # --- episode progress ---
    step: int
    max_steps: int
    done: bool

    # --- scoring ---
    reward: float
    score: float
@dataclass
class StepResult:
    """Bundle of an observation with its reward and episode-termination flag."""

    # Parsed observation for this step.
    observation: NL2SQLObservation
    # Reward value associated with this step.
    reward: float
    # True once the episode has finished.
    done: bool
class NL2SQLEnv:
    """Async HTTP client for an NL2SQL environment server.

    Posts to the server's ``/reset`` and ``/step`` endpoints and converts the
    JSON responses into ``StepResult`` objects. Usable as an async context
    manager, which closes the underlying ``httpx.AsyncClient`` on exit.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Create the client.

        Args:
            base_url: Root URL of the environment server; a trailing slash is
                stripped.
        """
        self.base_url = base_url.rstrip("/")
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=120.0)

    async def __aenter__(self):
        """Return this client for use inside ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client when leaving the ``async with`` block."""
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client (for callers not using ``async with``)."""
        await self.client.aclose()

    async def reset(self, task_name: Optional[str] = None) -> StepResult:
        """Start a new episode and return the initial result.

        Args:
            task_name: Task to request. When omitted, the value of the
                ``NL2SQL_DEFAULT_TASK`` environment variable is used, falling
                back to ``"simple-filter"``.

        Returns:
            The parsed response of ``POST /reset``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        if task_name is None:
            task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
        # task_name is sent at the top level of the request body only.
        # NOTE(review): the original author noted that some openenv-core
        # versions may read it from the action wrapper instead — unverified here.
        payload = {"task_name": task_name}
        resp = await self.client.post("/reset", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    async def step(self, action: NL2SQLAction) -> StepResult:
        """Submit one action and return the resulting step.

        Args:
            action: Action whose ``query`` is sent to the server.

        Returns:
            The parsed response of ``POST /step``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        # The query is wrapped as {"action": {"query": ...}}. Per the original
        # author's note, the server's action parsing expects this wrapper and
        # a top-level {"query": ...} bypasses it, yielding 0 reward.
        payload = {"action": {"query": action.query}}
        resp = await self.client.post("/step", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a decoded JSON response into a ``StepResult``.

        The observation is read from ``payload["observation"]`` when that is a
        dict; otherwise the payload itself is treated as the observation.
        Missing or ``null`` fields fall back to neutral defaults.

        Args:
            payload: Decoded JSON body of a ``/reset`` or ``/step`` response.

        Returns:
            A ``StepResult`` whose ``reward`` and ``done`` mirror the values
            stored on its observation.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            # Absent, null or malformed "observation": read the flat payload
            # instead of failing on a non-dict value.
            obs_data = payload

        # Reward: the top-level value wins, the nested observation is the fallback.
        raw_reward = payload.get("reward")
        if raw_reward is None:
            raw_reward = obs_data.get("reward")
        safe_reward = float(raw_reward) if raw_reward is not None else 0.0
        safe_score = float(obs_data.get("score") or 0.0)
        safe_done = bool(payload.get("done") or obs_data.get("done"))

        # An explicit 0 must be kept, so only None triggers the default of 5.
        max_steps = obs_data.get("max_steps")
        if max_steps is None:
            max_steps = 5

        # "or <default>" also covers JSON null, which dict.get's default does not.
        obs = NL2SQLObservation(
            question=obs_data.get("question") or "",
            schema_context=obs_data.get("schema_context") or "",
            task_name=obs_data.get("task_name") or "",
            last_query=obs_data.get("last_query") or "",
            last_result=obs_data.get("last_result") or [],
            last_error=obs_data.get("last_error"),
            result_columns=obs_data.get("result_columns") or [],
            step=obs_data.get("step") or 0,
            max_steps=max_steps,
            done=safe_done,
            reward=safe_reward,
            score=safe_score,
        )
        return StepResult(
            observation=obs,
            reward=safe_reward,
            done=safe_done,
        )