"""
OpenEnv spec routes.

POST /env/reset  → Observation
POST /env/step   → {observation: Observation, reward: RewardInfo}
GET  /env/state  → current episode state dict
GET  /env/tasks  → list of task metadata
GET  /env/info   → env metadata

The routes below are registered without the ``/env`` prefix; presumably the
prefix is applied where this router is mounted (not visible in this module).
"""

from __future__ import annotations

from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ValidationError

from env.sql_env import get_env, Observation, Action, RewardInfo
from env.tasks import get_all_tasks

router = APIRouter()


# ─── Request Models ───────────────────────────────────────────────

class ResetRequest(BaseModel):
    """Request body for ``POST /reset``."""

    # Task to start an episode for.
    task_id: str = "simple_queries"
    # Optional specific question; when set (and non-empty) the reset is
    # routed through ``reset_with_question`` instead of ``reset``.
    question_id: Optional[str] = None


class StepRequest(BaseModel):
    """Request body for ``POST /step``."""

    # Name of the action to apply; forwarded unchanged to ``Action``.
    repair_action: str = "generate"
    # Optional SQL text forwarded unchanged to ``Action``.
    custom_sql: Optional[str] = None


# ─── Routes ───────────────────────────────────────────────────────

@router.post("/reset", response_model=Observation)
async def env_reset(req: ResetRequest):
    """Reset the environment to start a new episode."""
    env = get_env()
    # An empty-string question_id is falsy and therefore takes the plain
    # task-level reset path, same as an omitted one.
    if req.question_id:
        return env.reset_with_question(req.task_id, req.question_id)
    return env.reset(req.task_id)


@router.post("/step")
async def env_step(req: StepRequest):
    """Execute one step in the current episode.

    Returns a JSON object with ``observation`` and ``reward`` keys.
    Responds with HTTP 400 when the environment raises ``RuntimeError`` and
    with HTTP 422 when the requested action fails validation.
    """
    env = get_env()

    # Action's validation rules live in env.sql_env and are not visible here.
    # If it rejects the client-supplied values, report a client error rather
    # than letting the exception surface as an HTTP 500.
    try:
        action = Action(
            repair_action=req.repair_action,
            custom_sql=req.custom_sql,
        )
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    # Keep the try body to the one call whose RuntimeError is translated.
    try:
        obs, reward = await env.step(action)
    except RuntimeError as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
    }


@router.get("/state")
async def env_state():
    """Get the current episode state."""
    env = get_env()
    return env.state()


@router.get("/tasks")
async def list_tasks():
    """List all available tasks with metadata."""
    tasks = get_all_tasks()
    # One entry per task, each embedding a summary of its questions.
    return [
        {
            "id": t.id,
            "name": t.name,
            "difficulty": t.difficulty,
            "description": t.description,
            "question_count": len(t.questions),
            "questions": [
                {
                    "id": q.id,
                    "question": q.question,
                    "hint_tables": q.hint_tables,
                }
                for q in t.questions
            ],
        }
        for t in tasks
    ]


@router.get("/info")
async def env_info():
    """Return environment metadata (matches openenv.yaml spec)."""
    return {
        "name": "sql-agent-openenv",
        "version": "1.0.0",
        "description": "SQL generation and repair environment with RL-driven repair strategy selection.",
        "action_space": {
            "type": "discrete",
            "actions": [
                "generate",
                "rewrite_full",
                "fix_column",
                "fix_table",
                "add_groupby",
                "rewrite_cte",
                "fix_syntax",
                "change_dialect",
                "relax_filter",
            ],
        },
        "observation_space": {
            "type": "dict",
            "fields": [
                "question",
                "schema_info",
                "current_sql",
                "error_message",
                "error_class",
                "attempt_number",
                "max_attempts",
                "task_id",
                "task_difficulty",
            ],
        },
        "reward_range": [-1.5, 1.5],
        "max_steps": 5,
        "tasks": ["simple_queries", "join_queries", "complex_queries"],
        "rl_algorithm": "LinUCB (contextual bandit)",
        "feature_dim": 20,
        # NOTE(review): nine actions are listed above but num_actions is 8 —
        # possibly "generate" is not counted as a bandit arm; confirm against
        # openenv.yaml before changing either value.
        "num_actions": 8,
    }