yashmarathe's picture
refactor: full openenv protocol compliance
1a55ff4
"""
FastAPI server for the Data Cleaning RL Environment.
Uses openenv's create_app() to get the standard /reset, /step, /state,
/health, /schema, /ws endpoints automatically. Adds custom hackathon
endpoints: /tasks, /grader, /baseline.
"""
from __future__ import annotations
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from openenv.core.env_server.http_server import create_app
from data_cleaning_env.action_registry import build_action_schema
from data_cleaning_env.models import CleaningAction, Observation
from data_cleaning_env.server.environment import DataCleaningEnvironment
# ---------------
# Configuration
# ---------------
# TCP port for the HTTP server, taken from the SERVER_PORT environment
# variable and falling back to 8000. A non-integer value raises ValueError
# when this module is imported.
SERVER_PORT: int = int(os.getenv("SERVER_PORT", "8000"))
# ---------------
# Create the app via openenv — gives us /reset, /step, /state, /health,
# /schema and /ws. The custom endpoints below are registered on this same app.
# ---------------
app: FastAPI = create_app(
    DataCleaningEnvironment,  # factory callable used to construct environment instances
    CleaningAction,  # Action class (presumably the model /step payloads are parsed into — confirm)
    Observation,  # Observation class (presumably the model returned to clients — confirm)
    env_name="data-cleaning-env",
    max_concurrent_envs=4,  # NOTE(review): presumably caps simultaneously live env instances — confirm in openenv create_app
)
# ---------------
# Custom hackathon endpoints (mounted on the same app)
# ---------------
class GraderRequest(BaseModel):
    """JSON request body for ``POST /grader``.

    Identifies which episode the caller would like graded.
    """

    episode_id: str  # identifier of the episode to grade
def _task_entry(
    task_id: str,
    description: str,
    dataset: str,
    max_steps: int,
    noise_types: list[str],
) -> dict:
    """Build one task descriptor in the shape served by ``GET /tasks``."""
    return {
        "id": task_id,
        "description": description,
        "dataset": dataset,
        "max_steps": max_steps,
        "noise_types": noise_types,
    }


@app.get("/tasks", summary="List tasks and action schema")
async def tasks():
    """Describe every task difficulty level together with the action schema.

    Returns:
        A dict with two keys:

        * ``"tasks"`` — one descriptor per difficulty ("easy", "medium",
          "hard", "expert"), each with its id, human-readable description,
          source dataset, step budget and the noise types injected.
        * ``"action_schema"`` — the result of ``build_action_schema()``.

    The catalog is rebuilt on every call, so callers never share mutable
    lists between responses.
    """
    catalog = [
        _task_entry(
            "easy",
            "Fix missing values in the Iris dataset (15% numeric missing).",
            "iris (OpenML ID 61)",
            20,
            ["missing_values"],
        ),
        _task_entry(
            "medium",
            "Fix missing values, type errors, and duplicates in Adult Income (2k sample).",
            "adult (OpenML ID 1590, 2k sample)",
            40,
            ["missing_values", "type_errors", "duplicates"],
        ),
        _task_entry(
            "hard",
            "Fix missing values, type errors, duplicates, outliers, and schema violations in Credit-G.",
            "credit-g (OpenML ID 31)",
            60,
            ["missing_values", "type_errors", "duplicates", "outliers", "schema_violations"],
        ),
        _task_entry(
            "expert",
            # Description enumerates every entry of noise_types, in the same
            # order, like the other tasks' descriptions do.
            "Clean ML metadata: mislabels, corrupted paths, inconsistent labels, "
            "duplicates, missing values, invalid dimensions, and format inconsistencies.",
            "ml-metadata (synthetic, 2000 rows)",
            80,
            [
                "mislabels", "corrupted_paths", "inconsistent_labels",
                "duplicates", "missing_values", "invalid_dimensions", "format_inconsistency",
            ],
        ),
    ]
    return {"tasks": catalog, "action_schema": build_action_schema()}
@app.post("/grader", summary="Grade a completed episode")
async def grader(req: GraderRequest):
    """Placeholder grading endpoint: always responds with HTTP 501.

    The openenv server creates a fresh environment instance per session, so
    there is no long-lived episode this handler could look up by
    ``req.episode_id``. Agent performance is instead conveyed by the reward
    and done signal carried in each step's observation. This handler only
    points callers at that mechanism.

    Raises:
        HTTPException: unconditionally, with status code 501.
    """
    # ``req`` is still declared so the request body is parsed as a
    # GraderRequest; its contents are not used.
    message = (
        "Grading is computed via observation rewards in the openenv protocol. "
        "Use the per-step reward and done signal to evaluate agent performance."
    )
    raise HTTPException(status_code=501, detail=message)
@app.post("/baseline", summary="Run the baseline heuristic agent")
async def baseline():
    """Run the baseline heuristic agent and report its scores.

    ``run_baseline_all`` is a plain synchronous function. Calling it directly
    inside this coroutine would block the event loop, stalling every other
    request (including live ``/ws`` sessions and ``/health``) until it
    returns. It is therefore executed in a worker thread.

    NOTE(review): concurrent ``/baseline`` requests now run in parallel
    threads instead of being serialized on the event loop. Confirm that
    ``run_baseline_all`` does not depend on shared mutable global state.

    Returns:
        ``{"baseline_scores": <return value of run_baseline_all()>}``.
    """
    import asyncio

    # Deferred import: the baseline module is only loaded when this
    # endpoint is actually called.
    from data_cleaning_env.baseline import run_baseline_all

    scores = await asyncio.to_thread(run_baseline_all)
    return {"baseline_scores": scores}
def main(host: str = "0.0.0.0", port: int = SERVER_PORT) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface address to bind; ``"0.0.0.0"`` binds all interfaces.
        port: TCP port to listen on. Defaults to ``SERVER_PORT``, so the
            ``SERVER_PORT`` environment variable is honoured (falling back
            to 8000). Previously the default was a hard-coded 8000 and the
            env var had no effect.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()