Spaces:
Sleeping
Sleeping
File size: 5,064 Bytes
3c665d2 55f54ec 3c665d2 55f54ec 3c665d2 55f54ec 3c665d2 55f54ec 3c665d2 719c147 3c665d2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | """
OpenEnv spec routes.
POST /env/reset → Observation
POST /env/step → {observation: Observation, reward: RewardInfo}
GET /env/state → current episode state dict
GET /env/tasks → list of task metadata
GET /env/info → env metadata
"""
from __future__ import annotations
import json
import sys
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional
from env.sql_env import get_env, Observation, Action, RewardInfo
from env.tasks import get_all_tasks
# Route table for the OpenEnv endpoints defined below. Paths are registered
# relative ("/reset", "/step", ...); the module docstring lists them under
# /env, so the application presumably mounts this router with prefix="/env"
# — confirm against the app factory.
router = APIRouter()
def _log(tag: str, payload: dict) -> None:
"""Emit a single structured log line to stdout: [TAG] <json>"""
print(f"[{tag}] {json.dumps(payload)}", flush=True)
# ─── Request Models ──────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body for the ``/reset`` route.

    When ``question_id`` is non-empty, ``env_reset`` calls
    ``env.reset_with_question(task_id, question_id)``; otherwise it calls
    ``env.reset(task_id)``.
    """

    # Task to start an episode for; defaults to "simple_queries".
    task_id: str = "simple_queries"
    # Optional specific question within the task. Empty string and None both
    # fall back to env.reset(task_id), because env_reset tests truthiness.
    question_id: Optional[str] = None
class StepRequest(BaseModel):
    """Request body for the ``/step`` route; forwarded into an ``Action``."""

    # Strategy name passed as Action(repair_action=...); defaults to "generate".
    # NOTE(review): not validated here — presumably Action / the env rejects
    # names outside the list advertised by /info; confirm.
    repair_action: str = "generate"
    # Optional caller-supplied SQL, passed as Action(custom_sql=...).
    custom_sql: Optional[str] = None
# ─── Routes ──────────────────────────────────────────────────────
@router.post("/reset", response_model=Observation)
async def env_reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    A non-empty ``req.question_id`` pins the episode to that question of
    ``req.task_id`` via ``reset_with_question``; otherwise ``reset(task_id)``
    is used. A ``[START]`` log line describing the episode is emitted before
    the observation is returned.
    """
    environment = get_env()
    observation = (
        environment.reset_with_question(req.task_id, req.question_id)
        if req.question_id
        else environment.reset(req.task_id)
    )
    start_record = {
        "task_id": observation.task_id,
        "task_difficulty": observation.task_difficulty,
        "question": observation.question,
        "max_attempts": observation.max_attempts,
    }
    _log("START", start_record)
    return observation
@router.post("/step")
async def env_step(req: StepRequest):
    """Execute one step in the current episode.

    Builds an ``Action`` from the request, awaits ``env.step`` and returns
    ``{"observation": ..., "reward": ...}`` as plain dicts (``model_dump``).
    A ``[STEP]`` log line is emitted for every step, plus an ``[END]`` line
    once the reward reports ``done``.

    Raises:
        HTTPException: status 400 when building the action or stepping the
            environment raises ``RuntimeError`` (original error chained).
    """
    env = get_env()
    # Only the environment interaction is translated into a client error.
    # Failures in logging / serialisation below are server-side bugs and
    # must not be disguised as a 400.
    try:
        action = Action(
            repair_action=req.repair_action,
            custom_sql=req.custom_sql,
        )
        obs, reward = await env.step(action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    _log("STEP", {
        "attempt": obs.attempt_number,
        "action": req.repair_action,
        "sql": obs.current_sql or "",
        "error": obs.error_message,
        "error_class": obs.error_class,
        "reward": round(reward.value, 4),
        "success": reward.success,
        "done": reward.done,
    })

    if reward.done:
        # NOTE(review): reads the env's private episode record. getattr keeps
        # the END log from failing an already-completed step if that internal
        # attribute is missing; prefer a public accessor if the env adds one.
        ep = getattr(env, "_episode", None)
        steps = getattr(ep, "steps", None) if ep else None
        # Sum of per-step rewards when the episode history is available,
        # otherwise fall back to this step's reward alone.
        total = sum(s.reward for s in steps) if steps else reward.value
        _log("END", {
            "success": reward.success,
            "attempts": obs.attempt_number,
            "total_reward": round(total, 4),
        })

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
    }
@router.get("/state")
async def env_state():
    """Return whatever ``env.state()`` reports for the current episode."""
    return get_env().state()
def _question_summary(q) -> dict:
    """Describe one task question (id, text, hint tables) for ``/tasks``."""
    return {
        "id": q.id,
        "question": q.question,
        "hint_tables": q.hint_tables,
    }


@router.get("/tasks")
async def list_tasks():
    """List every task from ``get_all_tasks()`` with its questions.

    Each entry carries the task's id, name, difficulty, description, the
    number of questions, and a per-question summary.
    """
    summaries = []
    for task in get_all_tasks():
        summaries.append({
            "id": task.id,
            "name": task.name,
            "difficulty": task.difficulty,
            "description": task.description,
            "question_count": len(task.questions),
            "questions": [_question_summary(q) for q in task.questions],
        })
    return summaries
@router.get("/info")
async def env_info():
    """Return static environment metadata (matches openenv.yaml spec).

    A fresh dict is built on every call, so callers may mutate the result
    without affecting later responses.
    """
    action_names = [
        "generate",
        "rewrite_full",
        "fix_column",
        "fix_table",
        "add_groupby",
        "rewrite_cte",
        "fix_syntax",
        "change_dialect",
        "relax_filter",
    ]
    observation_fields = [
        "question",
        "schema_info",
        "current_sql",
        "error_message",
        "error_class",
        "attempt_number",
        "max_attempts",
        "task_id",
        "task_difficulty",
    ]
    description = (
        "SQL generation and repair environment with "
        "RL-driven repair strategy selection."
    )
    # NOTE(review): nine action names are listed above but num_actions is 8.
    # Presumably it counts the bandit's repair arms (excluding "generate");
    # confirm against the LinUCB policy before relying on either value.
    return {
        "name": "sql-agent-openenv",
        "version": "1.0.0",
        "description": description,
        "action_space": {"type": "discrete", "actions": action_names},
        "observation_space": {"type": "dict", "fields": observation_fields},
        "reward_range": [0.05, 0.95],
        "max_steps": 5,
        "tasks": ["simple_queries", "join_queries", "complex_queries"],
        "rl_algorithm": "LinUCB (contextual bandit)",
        "feature_dim": 20,
        "num_actions": 8,
    }
|