Spaces:
Sleeping
Sleeping
| """ | |
| app.py - FastAPI server for SQL Repair Clinic OpenEnv environment. | |
| Endpoints: | |
| GET / health check | |
| GET /info environment metadata | |
| POST /reset start a new episode | |
| POST /step submit an action | |
| GET /state get current state (no side-effects) | |
| Run locally: | |
| uvicorn server.app:app --host 0.0.0.0 --port 7860 | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from environment import SQLRepairEnv | |
| from models import ( | |
| EnvironmentState, | |
| ResetRequest, | |
| SQLAction, | |
| SQLObservation, | |
| StepResponse, | |
| ) | |
| from tasks import VALID_TASKS | |
# -----------------------------------------------------------------------------
# Application setup
# -----------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sql_repair_clinic")

# Metadata shown in the generated OpenAPI docs.
_APP_METADATA: Dict[str, Any] = {
    "title": "SQL Repair Clinic - OpenEnv",
    "description": (
        "An RL environment where agents learn to repair and write SQL queries. "
        "Three difficulty levels: fix_syntax, fix_logic, write_analytical."
    ),
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)

# Fully permissive CORS: any origin, HTTP method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# A single environment instance shared by every request handler below
# (single-user / evaluation mode): all clients act on the same episode.
_env = SQLRepairEnv()
| # ----------------------------------------------------------------------------- | |
| # Routes | |
| # ----------------------------------------------------------------------------- | |
@app.get("/")
def health() -> Dict[str, Any]:
    """Health check - returns 200 and basic info.

    Registered as ``GET /`` on ``app`` (without the decorator the route was
    never mounted and the documented health check returned 404).

    Returns:
        A JSON-serialisable dict with a static ``"ok"`` status, the
        environment name and version, and the accepted task names.
    """
    return {
        "status": "ok",
        "environment": "sql-repair-clinic",
        "version": "1.0.0",
        "valid_tasks": VALID_TASKS,
    }
@app.get("/info")
def info() -> Dict[str, Any]:
    """Return environment metadata (OpenEnv spec).

    Registered as ``GET /info`` on ``app`` (previously undecorated, so the
    route was never mounted).

    Returns:
        A static description of the environment: name, version, action and
        observation space descriptions, reward range and the valid task names.
    """
    return {
        "name": "sql-repair-clinic",
        "version": "1.0.0",
        "description": (
            "SQL Query Repair Clinic: the agent must fix or write SQL queries "
            "to match a ground-truth result set. Three tasks with easy->medium->hard "
            "difficulty progression."
        ),
        "action_space": {
            "type": "text",
            "schema": {"query": "string - a valid SQL SELECT statement"},
        },
        "observation_space": {
            "task_name": "string",
            "difficulty": "string (easy|medium|hard)",
            "task_description": "string",
            "schema_info": "string (DDL + sample rows)",
            "initial_broken_query": "string",
            "last_submitted_query": "string",
            "error_message": "string | null",
            "result_preview": "list[dict] | null (up to 5 rows)",
            "step_count": "integer",
            "max_steps": "integer",
            "last_reward": "float [0.0, 1.0]",
            "hint": "string | null (shown after 3+ failed attempts)",
        },
        "reward_range": [0.0, 1.0],
        "tasks": VALID_TASKS,
    }
@app.post("/reset", response_model=SQLObservation)
def reset(body: Optional[ResetRequest] = None) -> SQLObservation:
    """
    Reset the environment and start a new episode.

    Registered as ``POST /reset`` on ``app`` (previously undecorated, so the
    route was never mounted).

    Body (optional JSON):
        { "task": "fix_syntax" }   # or fix_logic / write_analytical
    A ``session_id`` field in the body, if set, is forwarded to the
    environment's ``reset``. With no body, or a missing/empty ``task``,
    the task defaults to ``"fix_syntax"``.

    Returns:
        The initial observation produced by ``_env.reset``.

    Raises:
        HTTPException: 422 when ``task`` is not one of ``VALID_TASKS``.
    """
    # Explicit None checks: the body is optional, and the model object's
    # truthiness is not what we want to test here.
    task = (body.task if body is not None else None) or "fix_syntax"
    session_id = body.session_id if body is not None else None
    if task not in VALID_TASKS:
        raise HTTPException(
            status_code=422,
            detail=f"Unknown task '{task}'. Valid: {VALID_TASKS}",
        )
    logger.info("reset task=%s session=%s", task, session_id)
    return _env.reset(task=task, session_id=session_id)
@app.post("/step", response_model=StepResponse)
def step(action: SQLAction) -> StepResponse:
    """
    Submit an SQL query action and receive a graded observation.

    Registered as ``POST /step`` on ``app`` (previously undecorated, so the
    route was never mounted).

    Body:
        { "query": "SELECT name, salary FROM employees WHERE ..." }

    Returns:
        A ``StepResponse`` wrapping the observation, reward, done flag and
        info dict returned by ``_env.step``. Exceptions raised by
        ``_env.step`` propagate unchanged.
    """
    # ``step_info`` rather than ``info`` so the module-level ``info`` route
    # function is not shadowed inside this handler.
    obs, reward, done, step_info = _env.step(action)
    logger.info(
        "step task=%s step=%d reward=%.3f done=%s",
        step_info.get("task", "?"), step_info.get("step", 0), reward, done,
    )
    return StepResponse(observation=obs, reward=reward, done=done, info=step_info)
@app.get("/state", response_model=EnvironmentState)
def state() -> EnvironmentState:
    """Return the current environment state without side-effects.

    Registered as ``GET /state`` on ``app`` (previously undecorated, so the
    route was never mounted). Delegates directly to ``_env.state()``.
    """
    return _env.state()
def main(*, host: str = "0.0.0.0", port: int = 7860) -> None:
    """Console-script entry point for OpenEnv validation.

    Serves the module-level ``app`` with uvicorn (no auto-reload).

    The application object is handed to uvicorn directly rather than as the
    ``"server.app:app"`` import string. When this file is executed as a
    script, an import string makes uvicorn import the module a second time
    (re-running the module-level setup, including a second ``SQLRepairEnv``)
    and fails outright if ``server`` is not importable as a package. Passing
    the object is supported because reload is disabled.

    Args:
        host: Interface to bind; defaults to all interfaces.
        port: TCP port to listen on; defaults to 7860.
    """
    # Local import keeps uvicorn out of the import path for code that only
    # needs the ``app`` object (e.g. tests or an external ASGI server).
    import uvicorn

    uvicorn.run(app, host=host, port=port, reload=False)


if __name__ == "__main__":
    main()