Spaces:
Sleeping
Sleeping
| """ | |
| OpenEnv spec routes. | |
| POST /env/reset → Observation | |
| POST /env/step → {observation: Observation, reward: RewardInfo} | |
| GET /env/state → current episode state dict | |
| GET /env/tasks → list of task metadata | |
| GET /env/info → env metadata | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from fastapi import APIRouter, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from env.sql_env import get_env, Observation, Action, RewardInfo | |
| from env.tasks import get_all_tasks | |
# Router object for this module; the module docstring lists the intended
# endpoints.
# NOTE(review): no handler visible in this file is attached to this router
# with a route decorator — confirm where/how the routes are registered.
router = APIRouter()
| def _log(tag: str, payload: dict) -> None: | |
| """Emit a single structured log line to stdout: [TAG] <json>""" | |
| print(f"[{tag}] {json.dumps(payload)}", flush=True) | |
| # ─── Request Models ─────────────────────────────────────────────── | |
class ResetRequest(BaseModel):
    """Request body accepted by ``env_reset``."""

    # Task to reset to; defaults to "simple_queries" when omitted.
    task_id: str = "simple_queries"

    # Optional question id. When truthy, ``env_reset`` calls
    # ``reset_with_question`` instead of ``reset``.
    question_id: Optional[str] = None
class StepRequest(BaseModel):
    """Request body accepted by ``env_step``."""

    # Action name forwarded to ``Action.repair_action``; defaults to
    # "generate" when omitted.
    repair_action: str = "generate"

    # Optional SQL text forwarded to ``Action.custom_sql``.
    custom_sql: Optional[str] = None
| # ─── Routes ─────────────────────────────────────────────────────── | |
async def env_reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    When ``req.question_id`` is truthy the environment's
    ``reset_with_question(task_id, question_id)`` is used; otherwise
    ``reset(task_id)`` is used. A ``START`` log record with the task id,
    difficulty, question and attempt limit is emitted before returning.
    """
    env = get_env()
    obs = (
        env.reset_with_question(req.task_id, req.question_id)
        if req.question_id
        else env.reset(req.task_id)
    )

    start_record = {
        "task_id": obs.task_id,
        "task_difficulty": obs.task_difficulty,
        "question": obs.question,
        "max_attempts": obs.max_attempts,
    }
    _log("START", start_record)
    return obs
def _episode_total_reward(env, fallback):
    """Return the summed step rewards of ``env``'s episode, else *fallback*.

    ``_episode`` is a private attribute of the environment, so it is read
    defensively: a missing or falsy episode, or one without steps, yields
    *fallback* instead of raising after the step has already been applied.
    """
    episode = getattr(env, "_episode", None)
    if not episode:
        return fallback
    steps = getattr(episode, "steps", None)
    if not steps:
        return fallback
    return sum(s.reward for s in steps)


async def env_step(req: StepRequest):
    """Execute one step in the current episode.

    Builds an ``Action`` from the request, awaits ``env.step`` and returns
    ``{"observation": ..., "reward": ...}`` with both models dumped to
    plain dicts. A ``STEP`` log record is emitted for every step, followed
    by an ``END`` record when the reward reports the episode as done.

    Raises:
        HTTPException: status 400 when building the action or stepping the
            environment raises ``RuntimeError``; the original error is
            chained as the cause.
    """
    env = get_env()

    # Only the environment interaction is guarded, so a failure while
    # logging or serializing is not misreported as a client error.
    try:
        action = Action(
            repair_action=req.repair_action,
            custom_sql=req.custom_sql,
        )
        obs, reward = await env.step(action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    _log("STEP", {
        "attempt": obs.attempt_number,
        "action": req.repair_action,
        "sql": obs.current_sql or "",
        "error": obs.error_message,
        "error_class": obs.error_class,
        "reward": round(reward.value, 4),
        "success": reward.success,
        "done": reward.done,
    })

    if reward.done:
        # Fall back to this step's reward when no per-step history exists.
        _log("END", {
            "success": reward.success,
            "attempts": obs.attempt_number,
            "total_reward": round(
                _episode_total_reward(env, reward.value), 4
            ),
        })

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
    }
async def env_state():
    """Return the result of ``state()`` on the environment from ``get_env()``."""
    return get_env().state()
async def list_tasks():
    """Return metadata for every task from ``get_all_tasks()``.

    Each entry carries the task's id, name, difficulty, description, the
    number of questions, and a list of per-question summaries (id,
    question text and hint tables).
    """

    def summarize_question(item):
        # Only these three question attributes are exposed.
        return {
            "id": item.id,
            "question": item.question,
            "hint_tables": item.hint_tables,
        }

    def summarize_task(task):
        return {
            "id": task.id,
            "name": task.name,
            "difficulty": task.difficulty,
            "description": task.description,
            "question_count": len(task.questions),
            "questions": [summarize_question(q) for q in task.questions],
        }

    return [summarize_task(task) for task in get_all_tasks()]
async def env_info():
    """Return a static description of the environment.

    The payload is built from hard-coded values on every call: name,
    version, description, action and observation spaces, reward range,
    step limit, task ids and RL settings. The original note says it
    matches the openenv.yaml spec; that cannot be checked from this file.
    """
    action_names = [
        "generate",
        "rewrite_full",
        "fix_column",
        "fix_table",
        "add_groupby",
        "rewrite_cte",
        "fix_syntax",
        "change_dialect",
        "relax_filter",
    ]
    observation_fields = [
        "question",
        "schema_info",
        "current_sql",
        "error_message",
        "error_class",
        "attempt_number",
        "max_attempts",
        "task_id",
        "task_difficulty",
    ]
    task_ids = ["simple_queries", "join_queries", "complex_queries"]

    metadata = {
        "name": "sql-agent-openenv",
        "version": "1.0.0",
        "description": "SQL generation and repair environment with RL-driven repair strategy selection.",
        "action_space": {"type": "discrete", "actions": action_names},
        "observation_space": {"type": "dict", "fields": observation_fields},
        "reward_range": [0.05, 0.95],
        "max_steps": 5,
        "tasks": task_ids,
        "rl_algorithm": "LinUCB (contextual bandit)",
        "feature_dim": 20,
        # NOTE(review): action_space lists 9 names but num_actions is 8 —
        # presumably "generate" is not counted as a bandit arm; confirm.
        "num_actions": 8,
    }
    return metadata