| from __future__ import annotations |
|
|
| import argparse |
| from typing import Any, Literal |
|
|
| import uvicorn |
| from fastapi import Body, FastAPI |
| from fastapi.responses import RedirectResponse, Response |
| from pydantic import BaseModel |
|
|
| from models import Action, Observation |
|
|
| from .environment import DataCleaningEnv |
|
|
# Task identifiers advertised by this server; listed in the metadata payload
# and by the /tasks endpoint.
TASKS = [
    "basic_cleaning",
    "moderate_cleaning",
    "full_pipeline",
]

# Environment name reported in the metadata payload.
ENV_NAME = "data_cleaning_env"

# Human-readable summary reported in the metadata payload. The three adjacent
# literals are concatenated by the parser into one string.
ENV_DESCRIPTION = (
    "RL environment for interactive tabular data cleaning and preparation. "
    "Agents must fix missing values, duplicates, dtype issues, category inconsistencies, "
    "and derived-feature requirements."
)
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(
    title="Data Cleaning OpenEnv",
    version="1.0.0",
)

# Single module-level environment instance used by the /reset, /step and
# /state handlers.
ENV = DataCleaningEnv()
|
|
|
|
class ResetRequest(BaseModel):
    """Request body accepted by ``POST /reset``.

    ``task_name`` is restricted to the three known task identifiers and
    defaults to ``"basic_cleaning"`` when omitted.
    """

    task_name: Literal[
        "basic_cleaning",
        "moderate_cleaning",
        "full_pipeline",
    ] = "basic_cleaning"
|
|
|
|
def _metadata() -> dict[str, Any]:
    """Build a new dict describing this environment.

    Returns:
        A freshly created dict with the keys ``name``, ``description``,
        ``version``, ``tasks`` and ``mode``. The ``tasks`` value is the
        module-level ``TASKS`` list itself, not a copy.
    """
    info: dict[str, Any] = {"name": ENV_NAME, "description": ENV_DESCRIPTION}
    info["version"] = "1.0.0"
    info["tasks"] = TASKS
    info["mode"] = "simulation"
    return info
|
|
|
|
@app.get("/")
def root() -> dict[str, Any]:
    """Serve the environment metadata with an added ``status`` of ``"ok"``."""
    # ``_metadata()`` has no "status" key, so unpacking and appending yields
    # the same keys in the same order as assigning into the dict afterwards.
    return {**_metadata(), "status": "ok"}
|
|
|
|
@app.get("/web", include_in_schema=False)
def web_root() -> RedirectResponse:
    """Redirect ``/web`` to ``/`` using HTTP status 307."""
    redirect = RedirectResponse(status_code=307, url="/")
    return redirect
|
|
|
|
@app.get("/web/", include_in_schema=False)
def web_root_slash() -> RedirectResponse:
    """Redirect ``/web/`` (trailing slash) to ``/`` using HTTP status 307."""
    redirect = RedirectResponse(status_code=307, url="/")
    return redirect
|
|
|
|
@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Answer favicon requests with a 204 response."""
    empty_reply = Response(status_code=204)
    return empty_reply
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Report a fixed ``{"status": "healthy"}`` payload."""
    report: dict[str, str] = {"status": "healthy"}
    return report
|
|
|
|
@app.get("/metadata")
def metadata() -> dict[str, Any]:
    """Serve the dict produced by ``_metadata()`` unchanged."""
    description = _metadata()
    return description
|
|
|
|
@app.get("/tasks")
def list_tasks() -> dict[str, list[str]]:
    """Return the module-level ``TASKS`` list under the ``tasks`` key."""
    listing: dict[str, list[str]] = {"tasks": TASKS}
    return listing
|
|
|
|
@app.get("/schema")
def schema() -> dict[str, Any]:
    """Return the JSON schemas of the action and observation models.

    Returns:
        A dict with the keys ``action``, ``observation`` and ``state``. The
        ``state`` entry reuses the observation schema object.
    """
    obs_schema = Observation.model_json_schema()
    action_schema = Action.model_json_schema()
    return {
        "action": action_schema,
        "observation": obs_schema,
        "state": obs_schema,
    }
|
|
|
|
@app.post("/mcp")
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Reply to every ``POST /mcp`` request with a JSON-RPC error object.

    Args:
        payload: Parsed request body; defaults to an empty dict.

    Returns:
        A JSON-RPC 2.0 style envelope that echoes the request's ``id``
        (``None`` when absent) and carries error code ``-32601``.
    """
    not_implemented = {
        "code": -32601,
        "message": "MCP methods are not implemented for this benchmark.",
    }
    request_id = payload.get("id")
    return {"jsonrpc": "2.0", "id": request_id, "error": not_implemented}
|
|
|
|
@app.post("/reset")
def reset(request: ResetRequest | None = None) -> dict[str, Any]:
    """Select a task on the shared environment and reset it.

    Args:
        request: Optional body naming the task; a default ``ResetRequest``
            is used when it is missing (or otherwise falsy).

    Returns:
        The observation returned by ``ENV.reset()``, dumped to a dict.
    """
    chosen = request if request else ResetRequest()
    # The task name is assigned before the reset call.
    ENV.task_name = chosen.task_name
    return ENV.reset().model_dump()
|
|
|
|
@app.post("/step")
def step(action: Action) -> dict[str, Any]:
    """Apply one action to the shared environment.

    Args:
        action: The action passed through to ``ENV.step``.

    Returns:
        A dict with the dumped ``observation`` plus the ``reward``, ``done``
        and ``info`` values exactly as ``ENV.step`` returned them.
    """
    obs, reward, done, info = ENV.step(action)
    response: dict[str, Any] = {"observation": obs.model_dump()}
    response["reward"] = reward
    response["done"] = done
    response["info"] = info
    return response
|
|
|
|
@app.get("/state")
def state() -> dict[str, Any]:
    """Return the shared environment's current state as a dict.

    When ``ENV.dataset`` is falsy the environment is reset first, so a
    state can be served before any explicit ``/reset`` call.
    """
    # NOTE(review): relies on plain truthiness of ``ENV.dataset``; presumably
    # it is None or an empty container before the first reset — confirm.
    dataset_loaded = bool(ENV.dataset)
    if not dataset_loaded:
        ENV.reset()
    snapshot = ENV.state()
    return snapshot.model_dump()
|
|
|
|
def main(host: str | None = None, port: int | None = None) -> None:
    """Run the app under uvicorn.

    Command-line arguments are parsed only when at least one of ``host`` or
    ``port`` is not supplied; an explicitly passed value always wins over
    its command-line counterpart.

    Args:
        host: Interface to bind. Falls back to ``--host`` (default
            ``"0.0.0.0"``) when ``None``.
        port: Port to bind. Falls back to ``--port`` (default ``7860``)
            when ``None``.
    """
    needs_cli = host is None or port is None
    if needs_cli:
        cli = argparse.ArgumentParser()
        cli.add_argument("--host", default="0.0.0.0")
        cli.add_argument("--port", type=int, default=7860)
        parsed = cli.parse_args()
        if host is None:
            host = parsed.host
        if port is None:
            port = parsed.port
    uvicorn.run(app, host=host, port=port)
|
|
|
|
# Script entry point: start the server only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|