# AKGW580's picture
# Fix OpenEnv multi-mode validation requirements
# bfe1e62
"""
app.py - FastAPI server for SQL Repair Clinic OpenEnv environment.
Endpoints:
GET / health check
GET /info environment metadata
POST /reset start a new episode
POST /step submit an action
GET /state get current state (no side-effects)
Run locally:
uvicorn server.app:app --host 0.0.0.0 --port 7860
"""
from __future__ import annotations
import logging
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from environment import SQLRepairEnv
from models import (
EnvironmentState,
ResetRequest,
SQLAction,
SQLObservation,
StepResponse,
)
from tasks import VALID_TASKS
# -----------------------------------------------------------------------------
# Logging is configured at import time so every handler below can log
# through the module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sql_repair_clinic")

# The ASGI application object; route handlers attach to it via decorators.
app = FastAPI(
    version="1.0.0",
    title="SQL Repair Clinic - OpenEnv",
    description=(
        "An RL environment where agents learn to repair and write SQL queries."
        " Three difficulty levels: fix_syntax, fix_logic, write_analytical."
    ),
)

# Wildcard CORS: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# One process-wide environment shared by all requests
# (single-user / evaluation mode).
_env = SQLRepairEnv()
# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------
@app.get("/", tags=["Health"])
def health() -> Dict[str, Any]:
"""Health check - returns 200 and basic info."""
return {
"status": "ok",
"environment": "sql-repair-clinic",
"version": "1.0.0",
"valid_tasks": VALID_TASKS,
}
@app.get("/info", tags=["Meta"])
def info() -> Dict[str, Any]:
"""Return environment metadata (OpenEnv spec)."""
return {
"name": "sql-repair-clinic",
"version": "1.0.0",
"description": (
"SQL Query Repair Clinic: the agent must fix or write SQL queries "
"to match a ground-truth result set. Three tasks with easy->medium->hard "
"difficulty progression."
),
"action_space": {
"type": "text",
"schema": {"query": "string - a valid SQL SELECT statement"},
},
"observation_space": {
"task_name": "string",
"difficulty": "string (easy|medium|hard)",
"task_description": "string",
"schema_info": "string (DDL + sample rows)",
"initial_broken_query": "string",
"last_submitted_query": "string",
"error_message": "string | null",
"result_preview": "list[dict] | null (up to 5 rows)",
"step_count": "integer",
"max_steps": "integer",
"last_reward": "float [0.0, 1.0]",
"hint": "string | null (shown after 3+ failed attempts)",
},
"reward_range": [0.0, 1.0],
"tasks": VALID_TASKS,
}
@app.post("/reset", response_model=SQLObservation, tags=["OpenEnv"])
def reset(body: Optional[ResetRequest] = None) -> SQLObservation:
"""
Reset the environment and start a new episode.
Body (optional JSON):
{ "task": "fix_syntax" } # or fix_logic / write_analytical
"""
task = (body.task if body else None) or "fix_syntax"
session_id = (body.session_id if body else None)
if task not in VALID_TASKS:
raise HTTPException(
status_code=422,
detail=f"Unknown task '{task}'. Valid: {VALID_TASKS}",
)
logger.info("reset task=%s session=%s", task, session_id)
obs = _env.reset(task=task, session_id=session_id)
return obs
@app.post("/step", response_model=StepResponse, tags=["OpenEnv"])
def step(action: SQLAction) -> StepResponse:
"""
Submit an SQL query action and receive a graded observation.
Body:
{ "query": "SELECT name, salary FROM employees WHERE ..." }
"""
obs, reward, done, info = _env.step(action)
logger.info(
"step task=%s step=%d reward=%.3f done=%s",
info.get("task", "?"), info.get("step", 0), reward, done,
)
return StepResponse(observation=obs, reward=reward, done=done, info=info)
@app.get("/state", response_model=EnvironmentState, tags=["OpenEnv"])
def state() -> EnvironmentState:
"""Return the current environment state without side-effects."""
return _env.state()
def main() -> None:
    """Serve the app with uvicorn on 0.0.0.0:7860.

    Console-script entry point for OpenEnv validation.
    """
    # Imported lazily so importing this module does not need uvicorn until
    # the server is actually started from here.
    from uvicorn import run as serve

    # NOTE(review): the import string assumes this module is importable as
    # ``server.app`` - confirm against the project layout.
    serve("server.app:app", host="0.0.0.0", port=7860, reload=False)


if __name__ == "__main__":
    main()