Spaces:
Sleeping
Sleeping
File size: 3,525 Bytes
dce68a7 c22bf49 dce68a7 c22bf49 dce68a7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from __future__ import annotations
import argparse
from typing import Any, Literal
import uvicorn
from fastapi import Body, FastAPI
from fastapi.responses import RedirectResponse, Response
from pydantic import BaseModel
from models import Action, Observation
from .environment import DataCleaningEnv
# Identifier and human-readable summary advertised by the metadata endpoints.
ENV_NAME = "data_cleaning_env"
ENV_DESCRIPTION = (
    "RL environment for interactive tabular data cleaning and preparation. "
    "Agents must fix missing values, duplicates, dtype issues, category inconsistencies, "
    "and derived-feature requirements."
)

# Task names exposed by GET /tasks; ResetRequest.task_name lists the same values.
TASKS = [
    "basic_cleaning",
    "moderate_cleaning",
    "full_pipeline",
]

# The ASGI application and the single module-level environment instance that
# every request handler in this module reads and mutates.
# NOTE(review): plain `def` handlers run in FastAPI's threadpool and nothing here
# serializes access to ENV — confirm the server is only driven by one client.
app = FastAPI(title="Data Cleaning OpenEnv", version="1.0.0")
ENV = DataCleaningEnv()
class ResetRequest(BaseModel):
    """Optional JSON body for ``POST /reset`` choosing which task to load."""

    # Allowed values must stay in sync with the module-level TASKS list.
    task_name: Literal[
        "basic_cleaning",
        "moderate_cleaning",
        "full_pipeline",
    ] = "basic_cleaning"
def _metadata() -> dict[str, Any]:
    """Build the metadata payload shared by ``GET /`` and ``GET /metadata``.

    A new dict is built on every call, so callers may add keys (as ``root``
    does with ``"status"``) without affecting other responses.

    Returns:
        Mapping with the environment ``name``, ``description``, ``version``,
        ``tasks`` list and run ``mode``.
    """
    return {
        "name": ENV_NAME,
        "description": ENV_DESCRIPTION,
        # Read from the FastAPI app so the advertised version cannot drift from
        # the one declared when the app is constructed (and shown in OpenAPI).
        "version": app.version,
        # Copy so a caller mutating the payload cannot alter module-level TASKS.
        "tasks": list(TASKS),
        "mode": "simulation",
    }
@app.get("/")
def root() -> dict[str, Any]:
    """Return the environment metadata extended with a ``"status": "ok"`` flag."""
    return {**_metadata(), "status": "ok"}
def _redirect_to_root() -> RedirectResponse:
    """Build a 307 Temporary Redirect to ``/`` (307 keeps the request method)."""
    return RedirectResponse(url="/", status_code=307)


@app.get("/web", include_in_schema=False)
def web_root() -> RedirectResponse:
    """Redirect ``/web`` to the metadata root; hidden from the OpenAPI schema."""
    return _redirect_to_root()


@app.get("/web/", include_in_schema=False)
def web_root_slash() -> RedirectResponse:
    """Redirect the trailing-slash form ``/web/`` to the metadata root."""
    return _redirect_to_root()
@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Answer favicon probes with an empty 204 No Content instead of a 404."""
    no_content = Response(status_code=204)
    return no_content
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe: unconditionally reports ``{"status": "healthy"}``."""
    return dict(status="healthy")
@app.get("/metadata")
def metadata() -> dict[str, Any]:
    """Expose the environment metadata built by ``_metadata`` without extra keys."""
    payload = _metadata()
    return payload
@app.get("/tasks")
def list_tasks() -> dict[str, list[str]]:
    """List the task names that can be selected via ``POST /reset``."""
    response = {"tasks": TASKS}
    return response
@app.get("/schema")
def schema() -> dict[str, Any]:
    """Publish JSON Schemas for the action, observation and state payloads.

    The state entry reuses the Observation schema.
    NOTE(review): assumes ``ENV.state()`` returns an ``Observation`` — confirm.
    """
    obs = Observation.model_json_schema()
    return dict(
        action=Action.model_json_schema(),
        observation=obs,
        state=obs,
    )
@app.post("/mcp")
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Stub MCP endpoint: answer every JSON-RPC call with a "Method not found" error.

    The request ``id`` is echoed back (``None`` when absent) so the client can
    correlate the error with its call.
    """
    request_id = payload.get("id")
    not_implemented = {
        "code": -32601,  # JSON-RPC 2.0 "Method not found"
        "message": "MCP methods are not implemented for this benchmark.",
    }
    return {"jsonrpc": "2.0", "id": request_id, "error": not_implemented}
@app.post("/reset")
def reset(request: ResetRequest | None = None) -> dict[str, Any]:
    """Reset the shared environment onto the requested task.

    The request body is optional; without one a default ``ResetRequest`` is
    used. The chosen task name is written to ``ENV.task_name`` before
    ``ENV.reset()`` is called, and the returned observation is serialized.
    """
    chosen = ResetRequest() if request is None else request
    ENV.task_name = chosen.task_name
    return ENV.reset().model_dump()
@app.post("/step")
def step(action: Action) -> dict[str, Any]:
    """Apply one action to the shared environment.

    Returns the serialized observation together with the ``reward``, ``done``
    and ``info`` values produced by ``ENV.step``.
    """
    obs, reward, done, info = ENV.step(action)
    result: dict[str, Any] = {"observation": obs.model_dump()}
    result["reward"] = reward
    result["done"] = done
    result["info"] = info
    return result
@app.get("/state")
def state() -> dict[str, Any]:
    """Return the serialized environment state, resetting first if ``ENV.dataset`` is falsy."""
    # NOTE(review): relies on the truthiness of ENV.dataset; if it can be a
    # pandas DataFrame, ``not`` raises ValueError — confirm its type.
    needs_reset = not ENV.dataset
    if needs_reset:
        ENV.reset()
    return ENV.state().model_dump()
def main(host: str | None = None, port: int | None = None) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. When omitted, taken from ``--host``
            (default ``0.0.0.0``).
        port: TCP port to bind. When omitted, taken from ``--port``
            (default ``7860``).

    The command line is only parsed when ``host`` or ``port`` is missing;
    explicitly passed arguments always win over parsed ones.
    """
    if host is None or port is None:
        cli = argparse.ArgumentParser()
        cli.add_argument("--host", default="0.0.0.0")
        cli.add_argument("--port", type=int, default=7860)
        parsed = cli.parse_args()
        if host is None:
            host = parsed.host
        if port is None:
            port = parsed.port
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|