"""
app.py - FastAPI server for SQL Repair Clinic OpenEnv environment.

Endpoints:
    GET  /       health check
    GET  /info   environment metadata
    POST /reset  start a new episode
    POST /step   submit an action
    GET  /state  get current state (no side-effects)

Run locally:
    uvicorn server.app:app --host 0.0.0.0 --port 7860
"""

from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from environment import SQLRepairEnv
from models import (
    EnvironmentState,
    ResetRequest,
    SQLAction,
    SQLObservation,
    StepResponse,
)
from tasks import VALID_TASKS

# -----------------------------------------------------------------------------

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sql_repair_clinic")

# Single source of truth for values reported by several endpoints, so the
# FastAPI metadata, the health check and /info cannot drift apart.
_ENV_NAME = "sql-repair-clinic"
_ENV_VERSION = "1.0.0"

# Task used by /reset when the request carries no body or no task name.
_DEFAULT_TASK = "fix_syntax"

app = FastAPI(
    title="SQL Repair Clinic - OpenEnv",
    description=(
        "An RL environment where agents learn to repair and write SQL queries. "
        "Three difficulty levels: fix_syntax, fix_logic, write_analytical."
    ),
    version=_ENV_VERSION,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global environment instance (single-user / evaluation mode).
# Every route below operates on this one shared object.
_env = SQLRepairEnv()


# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------


@app.get("/", tags=["Health"])
def health() -> Dict[str, Any]:
    """Health check - returns 200 and basic info."""
    return {
        "status": "ok",
        "environment": _ENV_NAME,
        "version": _ENV_VERSION,
        "valid_tasks": VALID_TASKS,
    }


@app.get("/info", tags=["Meta"])
def info() -> Dict[str, Any]:
    """Return environment metadata (OpenEnv spec).

    The payload is a static description of the action space, the observation
    fields, the reward range and the list of valid task names.
    """
    return {
        "name": _ENV_NAME,
        "version": _ENV_VERSION,
        "description": (
            "SQL Query Repair Clinic: the agent must fix or write SQL queries "
            "to match a ground-truth result set. Three tasks with easy->medium->hard "
            "difficulty progression."
        ),
        "action_space": {
            "type": "text",
            "schema": {"query": "string - a valid SQL SELECT statement"},
        },
        "observation_space": {
            "task_name": "string",
            "difficulty": "string (easy|medium|hard)",
            "task_description": "string",
            "schema_info": "string (DDL + sample rows)",
            "initial_broken_query": "string",
            "last_submitted_query": "string",
            "error_message": "string | null",
            "result_preview": "list[dict] | null (up to 5 rows)",
            "step_count": "integer",
            "max_steps": "integer",
            "last_reward": "float [0.0, 1.0]",
            "hint": "string | null (shown after 3+ failed attempts)",
        },
        "reward_range": [0.0, 1.0],
        "tasks": VALID_TASKS,
    }


@app.post("/reset", response_model=SQLObservation, tags=["OpenEnv"])
def reset(body: Optional[ResetRequest] = None) -> SQLObservation:
    """
    Reset the environment and start a new episode.

    Body (optional JSON):
        { "task": "fix_syntax" }   # or fix_logic / write_analytical

    The body may also carry a ``session_id``, which is forwarded to the
    environment unchanged. A missing body, or a missing/empty ``task``,
    selects the default task ("fix_syntax").

    Raises:
        HTTPException: 422 when ``task`` is not one of ``VALID_TASKS``.
    """
    task = (body.task if body else None) or _DEFAULT_TASK
    session_id = body.session_id if body else None

    if task not in VALID_TASKS:
        raise HTTPException(
            status_code=422,
            detail=f"Unknown task '{task}'. Valid: {VALID_TASKS}",
        )

    logger.info("reset task=%s session=%s", task, session_id)
    return _env.reset(task=task, session_id=session_id)


@app.post("/step", response_model=StepResponse, tags=["OpenEnv"])
def step(action: SQLAction) -> StepResponse:
    """
    Submit an SQL query action and receive a graded observation.

    Body:
        { "query": "SELECT name, salary FROM employees WHERE ..." }

    Returns:
        A ``StepResponse`` wrapping the observation, reward, done flag and
        info mapping produced by the environment's ``step`` call.
    """
    # Named ``step_info`` so the local does not shadow the ``info`` route above.
    obs, reward, done, step_info = _env.step(action)
    logger.info(
        "step task=%s step=%d reward=%.3f done=%s",
        step_info.get("task", "?"),
        step_info.get("step", 0),
        reward,
        done,
    )
    return StepResponse(observation=obs, reward=reward, done=done, info=step_info)


@app.get("/state", response_model=EnvironmentState, tags=["OpenEnv"])
def state() -> EnvironmentState:
    """Return the current environment state without side-effects.

    Delegates directly to the shared environment's ``state()`` method.
    """
    return _env.state()


def main() -> None:
    """Console-script entry point for OpenEnv validation."""
    # Imported lazily so that importing this module (e.g. under an external
    # ASGI runner) does not require uvicorn at import time.
    import uvicorn

    uvicorn.run("server.app:app", host="0.0.0.0", port=7860, reload=False)


if __name__ == "__main__":
    main()