# Spaces: Sleeping (page-capture residue; not part of app.py)
# app.py — FastAPI server for the Data Cleaning OpenEnv environment
| from fastapi import FastAPI, HTTPException | |
| from typing import Dict | |
| import uvicorn | |
| import sys | |
| import os | |
# Append this file's grandparent directory (the parent of the directory that
# contains this file) to sys.path so `data_cleaning_env` can be imported below.
_this_file = os.path.abspath(__file__)
_parent_dir = os.path.dirname(os.path.dirname(_this_file))
sys.path.append(_parent_dir)
| from data_cleaning_env import ( | |
| DataCleaningEnvironment, CleaningAction, | |
| DatasetObservation, EnvironmentState | |
| ) | |
# The ASGI application object that `main()` hands to uvicorn. The keyword
# arguments are metadata passed straight to the FastAPI constructor.
app = FastAPI(
    title="Data Cleaning OpenEnv",
    description="RL environment for data cleaning tasks",
    version="1.0.0",
)
# One environment instance per task id (1, 2 and 3), built once at import time.
# The reset/step/state handlers in this module all look their environment up
# here, so state persists between requests for the same task id.
envs: Dict[int, DataCleaningEnvironment] = {
    task_id: DataCleaningEnvironment(task_id=task_id) for task_id in (1, 2, 3)
}
@app.get("/")
def root() -> Dict[str, str]:
    """Return a static payload identifying this service.

    Registered on ``app`` as ``GET /``; without a route decorator the
    function is never reachable over HTTP.

    NOTE(review): the route path/method were chosen from the handler name —
    confirm against existing clients.

    Returns:
        A dict with the fixed keys ``status``, ``env`` and ``version``.
    """
    return {"status": "ok", "env": "data-cleaning", "version": "1.0.0"}
@app.get("/health")
def health() -> Dict[str, str]:
    """Return a constant health payload.

    Registered on ``app`` as ``GET /health``; without a route decorator the
    function is never reachable over HTTP. The handler performs no checks of
    its own — it always answers ``{"status": "healthy"}``.

    NOTE(review): the route path/method were chosen from the handler name —
    confirm against existing clients.
    """
    return {"status": "healthy"}
@app.get("/tasks")
def list_tasks() -> Dict[str, list]:
    """Return the hard-coded catalogue of tasks.

    Registered on ``app`` as ``GET /tasks``; without a route decorator the
    function is never reachable over HTTP.

    NOTE(review): the route path/method were chosen from the handler name —
    confirm against existing clients.

    Returns:
        ``{"tasks": [...]}`` where each entry carries ``id``, ``difficulty``
        and ``name``. The list is a literal and is not derived from ``envs``.
    """
    return {
        "tasks": [
            {"id": 1, "difficulty": "easy", "name": "Remove null values"},
            {"id": 2, "difficulty": "medium", "name": "Fix date formats"},
            {"id": 3, "difficulty": "hard", "name": "Remove outliers"},
        ]
    }
@app.post("/reset")
def reset(task_id: int = 1) -> dict:
    """Reset the environment for ``task_id`` and return its observation.

    Registered on ``app`` as ``POST /reset``; without a route decorator the
    function is never reachable over HTTP.

    NOTE(review): the route path/method were chosen from the handler name —
    confirm against existing clients.

    Args:
        task_id: Key into the module-level ``envs`` mapping. Defaults to 1.

    Returns:
        ``model_dump()`` of the observation returned by the environment's
        ``reset()``.

    Raises:
        HTTPException: Status 400 when ``task_id`` is not a key of ``envs``.
    """
    # Single lookup, matching how `step` and `state` resolve their environment.
    env = envs.get(task_id)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"task_id must be 1, 2, or 3. Got {task_id}",
        )
    return env.reset().model_dump()
@app.post("/step")
def step(action: CleaningAction) -> dict:
    """Apply ``action`` to the environment selected by ``action.task_id``.

    Registered on ``app`` as ``POST /step``; without a route decorator the
    function is never reachable over HTTP.

    NOTE(review): the route path/method were chosen from the handler name —
    confirm against existing clients.

    Args:
        action: The action to apply; its ``task_id`` attribute selects the
            environment from the module-level ``envs`` mapping.

    Returns:
        A dict with ``observation`` (the observation's ``model_dump()``),
        ``reward``, ``done`` and ``info``, taken from the 4-tuple returned by
        the environment's ``step``.

    Raises:
        HTTPException: Status 400 when ``action.task_id`` is not a key of
            ``envs``.
    """
    env = envs.get(action.task_id)
    if env is None:
        raise HTTPException(status_code=400, detail="task_id must be 1, 2, or 3")

    obs, reward, done, info = env.step(action)
    return {
        "observation": obs.model_dump(),
        "reward": reward,
        "done": done,
        "info": info,
    }
@app.get("/state")
def state(task_id: int = 1) -> dict:
    """Return the current state of the environment for ``task_id``.

    Registered on ``app`` as ``GET /state``; without a route decorator the
    function is never reachable over HTTP.

    NOTE(review): the route path/method were chosen from the handler name —
    confirm against existing clients.

    Args:
        task_id: Key into the module-level ``envs`` mapping. Defaults to 1.

    Returns:
        ``model_dump()`` of whatever the environment's ``state()`` returns.

    Raises:
        HTTPException: Status 400 when ``task_id`` is not a key of ``envs``.
    """
    env = envs.get(task_id)
    if env is None:
        raise HTTPException(status_code=400, detail="task_id must be 1, 2, or 3")
    return env.state().model_dump()
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. The default ``"0.0.0.0"`` is the value that
            was previously hard-coded.
        port: TCP port to listen on. The default ``7860`` is the value that
            was previously hard-coded.

    Calling ``main()`` with no arguments behaves exactly as before.
    """
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()