# sql-agent-openenv / backend/api/openenv.py
# Last commit 719c147 by ar9avg:
#   "Nuclear clamp: every reward source in the codebase now returns (0.05, 0.95)"
"""
OpenEnv spec routes.
POST /env/reset → Observation
POST /env/step → {observation: Observation, reward: RewardInfo}
GET /env/state → current episode state dict
GET /env/tasks → list of task metadata
GET /env/info → env metadata
"""
from __future__ import annotations
import json
import sys
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional
from env.sql_env import get_env, Observation, Action, RewardInfo
from env.tasks import get_all_tasks
# Router holding the OpenEnv spec endpoints below. Paths here are relative
# ("/reset", "/step", ...); the "/env" prefix shown in the module docstring is
# presumably applied where this router is included — confirm in the app setup.
router = APIRouter()
def _log(tag: str, payload: dict) -> None:
"""Emit a single structured log line to stdout: [TAG] <json>"""
print(f"[{tag}] {json.dumps(payload)}", flush=True)
# ─── Request Models ───────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Body of ``POST /reset``: which task (and optionally which question) to start."""

    # Task identifier passed to the env; presumably one of the ids listed by
    # GET /tasks — confirm against env.tasks.
    task_id: str = "simple_queries"
    # When set and non-empty, env_reset pins the episode to this question via
    # env.reset_with_question; otherwise env.reset(task_id) is used.
    question_id: Optional[str] = None
class StepRequest(BaseModel):
    """Body of ``POST /step``: the repair action to apply on this turn."""

    # Action name; GET /info advertises the available choices. No validation
    # happens here — any check occurs when env_step builds an Action from it.
    repair_action: str = "generate"
    # Optional caller-supplied SQL, forwarded unchanged as Action.custom_sql.
    custom_sql: Optional[str] = None
# ─── Routes ───────────────────────────────────────────────────────
@router.post("/reset", response_model=Observation)
async def env_reset(req: ResetRequest):
    """Start a new episode and return its first observation.

    If ``req.question_id`` is truthy the episode is pinned to that question
    within ``req.task_id`` (``reset_with_question``); otherwise
    ``reset(task_id)`` is called. A ``[START]`` log line is emitted either way.
    """
    environment = get_env()
    observation = (
        environment.reset_with_question(req.task_id, req.question_id)
        if req.question_id
        else environment.reset(req.task_id)
    )
    start_record = {
        "task_id": observation.task_id,
        "task_difficulty": observation.task_difficulty,
        "question": observation.question,
        "max_attempts": observation.max_attempts,
    }
    _log("START", start_record)
    return observation
@router.post("/step")
async def env_step(req: StepRequest):
    """Execute one step in the current episode.

    Builds an ``Action`` from the request, advances the shared environment,
    emits a ``[STEP]`` log line (plus an ``[END]`` line once the episode is
    done) and returns ``{"observation": ..., "reward": ...}`` as plain dicts.

    Raises:
        HTTPException: status 400 if the request cannot be turned into an
            ``Action`` (``ValueError`` — pydantic's ``ValidationError`` is a
            subclass, presumably what Action raises on bad input), or if
            ``env.step`` raises ``RuntimeError``.
    """
    env = get_env()

    try:
        action = Action(
            repair_action=req.repair_action,
            custom_sql=req.custom_sql,
        )
    except ValueError as e:
        # Invalid client input; without this it would escape as a 500.
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        obs, reward = await env.step(action)
    except RuntimeError as e:
        # RuntimeError from env.step is treated as a client error (presumably
        # e.g. stepping without an active episode — confirm in env.sql_env).
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Logging sits outside the try blocks so a logging failure is never
    # misreported to the client as a 400.
    _log("STEP", {
        "attempt": obs.attempt_number,
        "action": req.repair_action,
        "sql": obs.current_sql or "",
        "error": obs.error_message,
        "error_class": obs.error_class,
        "reward": round(reward.value, 4),
        "success": reward.success,
        "done": reward.done,
    })

    if reward.done:
        # _episode is private env state; read it defensively so a change there
        # degrades the END log (falling back to this step's reward) instead of
        # failing a step that has already been applied.
        episode = getattr(env, "_episode", None)
        steps = getattr(episode, "steps", None)
        total_reward = sum(s.reward for s in steps) if steps else reward.value
        _log("END", {
            "success": reward.success,
            "attempts": obs.attempt_number,
            "total_reward": round(total_reward, 4),
        })

    return {
        "observation": obs.model_dump(),
        "reward": reward.model_dump(),
    }
@router.get("/state")
async def env_state():
    """Return whatever ``state()`` reports for the shared environment."""
    return get_env().state()
@router.get("/tasks")
async def list_tasks():
    """Describe every task from ``get_all_tasks()`` together with its questions."""

    def question_summary(q):
        # Per-question view: id, text and the table hints only.
        return {
            "id": q.id,
            "question": q.question,
            "hint_tables": q.hint_tables,
        }

    def task_summary(t):
        # Task metadata plus a count and the list of question summaries.
        return {
            "id": t.id,
            "name": t.name,
            "difficulty": t.difficulty,
            "description": t.description,
            "question_count": len(t.questions),
            "questions": [question_summary(q) for q in t.questions],
        }

    all_tasks = get_all_tasks()
    return [task_summary(t) for t in all_tasks]
@router.get("/info")
async def env_info():
    """Return static environment metadata (kept in line with the openenv.yaml spec).

    Nothing here is read from the live environment; every value is a literal.
    """
    action_names = [
        "generate",
        "rewrite_full",
        "fix_column",
        "fix_table",
        "add_groupby",
        "rewrite_cte",
        "fix_syntax",
        "change_dialect",
        "relax_filter",
    ]
    observation_fields = [
        "question",
        "schema_info",
        "current_sql",
        "error_message",
        "error_class",
        "attempt_number",
        "max_attempts",
        "task_id",
        "task_difficulty",
    ]
    task_ids = ["simple_queries", "join_queries", "complex_queries"]
    return {
        "name": "sql-agent-openenv",
        "version": "1.0.0",
        "description": "SQL generation and repair environment with RL-driven repair strategy selection.",
        "action_space": {"type": "discrete", "actions": action_names},
        "observation_space": {"type": "dict", "fields": observation_fields},
        "reward_range": [0.05, 0.95],
        "max_steps": 5,
        "tasks": task_ids,
        "rl_algorithm": "LinUCB (contextual bandit)",
        "feature_dim": 20,
        # NOTE(review): action_names has 9 entries but num_actions is 8;
        # possibly the bandit arms exclude "generate" — confirm before changing.
        "num_actions": 8,
    }