""" FastAPI server for the Data Cleaning RL Environment. Uses openenv's create_app() to get the standard /reset, /step, /state, /health, /schema, /ws endpoints automatically. Adds custom hackathon endpoints: /tasks, /grader, /baseline. """ from __future__ import annotations import os from fastapi import FastAPI, HTTPException from pydantic import BaseModel from openenv.core.env_server.http_server import create_app from data_cleaning_env.action_registry import build_action_schema from data_cleaning_env.models import CleaningAction, Observation from data_cleaning_env.server.environment import DataCleaningEnvironment # --------------- # Configuration # --------------- SERVER_PORT: int = int(os.environ.get("SERVER_PORT", "8000")) # --------------- # Create the app via openenv — gives us /reset, /step, /state, /health, /ws # --------------- app: FastAPI = create_app( DataCleaningEnvironment, # factory callable CleaningAction, # Action class Observation, # Observation class env_name="data-cleaning-env", max_concurrent_envs=4, ) # --------------- # Custom hackathon endpoints (mounted on the same app) # --------------- class GraderRequest(BaseModel): episode_id: str @app.get("/tasks", summary="List tasks and action schema") async def tasks(): return { "tasks": [ { "id": "easy", "description": "Fix missing values in the Iris dataset (15% numeric missing).", "dataset": "iris (OpenML ID 61)", "max_steps": 20, "noise_types": ["missing_values"], }, { "id": "medium", "description": "Fix missing values, type errors, and duplicates in Adult Income (2k sample).", "dataset": "adult (OpenML ID 1590, 2k sample)", "max_steps": 40, "noise_types": ["missing_values", "type_errors", "duplicates"], }, { "id": "hard", "description": "Fix missing values, type errors, duplicates, outliers, and schema violations in Credit-G.", "dataset": "credit-g (OpenML ID 31)", "max_steps": 60, "noise_types": ["missing_values", "type_errors", "duplicates", "outliers", "schema_violations"], }, { "id": "expert", "description": "Clean ML metadata: mislabels, corrupted paths, inconsistent labels, duplicates, invalid dimensions.", "dataset": "ml-metadata (synthetic, 2000 rows)", "max_steps": 80, "noise_types": [ "mislabels", "corrupted_paths", "inconsistent_labels", "duplicates", "missing_values", "invalid_dimensions", "format_inconsistency", ], }, ], "action_schema": build_action_schema(), } @app.post("/grader", summary="Grade a completed episode") async def grader(req: GraderRequest): """ Compute the grader score for an episode. Requires a REST-based workflow where the environment instance is long-lived. For WebSocket-based GenericEnvClient usage, grading happens via the reward in the observation. """ # The openenv server creates fresh instances per session, so this endpoint # is only useful in our REST workflow. Return a generic message if not found. raise HTTPException( status_code=501, detail="Grading is computed via observation rewards in the openenv protocol. " "Use the per-step reward and done signal to evaluate agent performance.", ) @app.post("/baseline", summary="Run the baseline heuristic agent") async def baseline(): from data_cleaning_env.baseline import run_baseline_all scores = run_baseline_all() return {"baseline_scores": scores} def main(host: str = "0.0.0.0", port: int = 8000) -> None: import uvicorn uvicorn.run(app, host=host, port=port) if __name__ == "__main__": main()