Spaces:
Sleeping
Sleeping
"""
FastAPI application for the Data Validation Environment.

Uses a single, stateful environment instance shared by all requests, so that
/reset and /step operate on the same episode.
Responses use the standard OpenEnv format: {observation, reward, done}.
"""
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Any, Dict, Optional | |
| from env.models import DataCleanAction, DataCleanObservation | |
| from env.environment import DataValidationEnvironment | |
# ── Pydantic request / response models matching the OpenEnv wire format ──────
class ResetRequest(BaseModel):
    """Request body for the reset endpoint.

    Every field is optional. Unknown extra keys are accepted and kept on the
    model (``extra="allow"``), so clients that send additional OpenEnv fields
    are not rejected.
    """

    # Pydantic v2 configuration; the nested ``class Config`` form is deprecated
    # in v2 (this module already relies on v2 APIs such as ``model_dump``).
    model_config = {"extra": "allow"}

    # Forwarded to the environment's reset as ``task_name``.
    task_name: Optional[str] = None
    # Seed forwarded to the environment's reset; defaults to 42 when omitted.
    seed: Optional[int] = 42
    # Forwarded to the environment's reset as ``episode_id``.
    episode_id: Optional[str] = None
class StepRequest(BaseModel):
    """Request body for the step endpoint.

    ``action`` is a raw mapping; the step handler validates it against
    ``DataCleanAction``. Extra top-level keys are accepted (``extra="allow"``)
    and are not read by the handler.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = {"extra": "allow"}

    # Unvalidated action payload; parsed into DataCleanAction by the handler.
    action: Dict[str, Any]
class EnvResponse(BaseModel):
    """OpenEnv response envelope: ``{observation, reward, done}``."""

    # Observation fields, excluding the reward/done values lifted to the top level.
    observation: Dict[str, Any]
    # Reward for the last transition; None when the observation carries none.
    reward: Optional[float] = None
    # True once the episode has ended.
    done: bool = False
# ── Shared environment instance (stateful across requests) ───────────────────
# One module-level environment shared by every request, so /reset and /step
# act on the same episode.
# NOTE(review): FastAPI runs plain ``def`` endpoints in a worker threadpool,
# so concurrent requests may touch this object at the same time. Confirm that
# DataValidationEnvironment tolerates this, or serialize access with a lock.
env = DataValidationEnvironment()
# ── FastAPI app ──────────────────────────────────────────────────────────────
# ASGI application object; main() hands it to uvicorn.
app = FastAPI(
    version="1.0.0",
    title="OpenEnv Environment HTTP API",
)
def _serialize_observation(obs: DataCleanObservation) -> EnvResponse:
    """Wrap an observation in the OpenEnv ``{observation, reward, done}`` envelope.

    ``reward`` and ``done`` are promoted to top-level fields, so they (along
    with ``metadata``) are omitted from the nested observation payload.
    """
    promoted_or_hidden = {"reward", "done", "metadata"}
    payload = obs.model_dump(exclude=promoted_or_hidden)
    envelope = EnvResponse(observation=payload, reward=obs.reward, done=obs.done)
    return envelope
# ── Endpoints ────────────────────────────────────────────────────────────────
# Registered on ``app``; without a route decorator the handler was never exposed.
@app.get("/health")
def health():
    """Report that the server is up and able to answer requests."""
    return {"status": "healthy"}
# Registered on ``app``; without a route decorator the handler was never exposed.
@app.get("/metadata")
def metadata():
    """Return the environment's static name and human-readable description."""
    return {
        "name": "data_validation_env",
        "description": "An RL environment for training agents to clean and validate structured data.",
    }
# Registered on ``app``; without a route decorator the handler was never exposed.
@app.get("/schema")
def schema():
    """Return JSON Schemas for the action and observation models.

    ``state`` is published as an empty object; no schema is generated for it
    here.
    """
    return {
        "action": DataCleanAction.model_json_schema(),
        "observation": DataCleanObservation.model_json_schema(),
        "state": {},
    }
# Registered on ``app``; without a route decorator the handler was never exposed.
@app.get("/state")
def state():
    """Return the shared environment's current state as a plain dict.

    If ``env.state`` has no ``model_dump`` method (i.e. is not a Pydantic
    model), a placeholder with no episode and zero steps is returned instead.
    """
    current = env.state
    if hasattr(current, "model_dump"):
        return current.model_dump()
    return {"episode_id": None, "step_count": 0}
# Registered on ``app``; without a route decorator the handler was never exposed.
@app.post("/reset")
def reset(request: Optional[ResetRequest] = None) -> EnvResponse:
    """Start a new episode on the shared environment.

    Args:
        request: Optional reset parameters. When the body is omitted (or the
            function is called without it), a default ``ResetRequest`` is
            built for this call.

    Returns:
        The initial observation in the OpenEnv response envelope.
    """
    # Build the default per call instead of sharing one instance created at
    # definition time (mutable-default pitfall).
    if request is None:
        request = ResetRequest()
    # An explicit ``"seed": null`` bypasses the model default, so restore 42.
    seed = request.seed if request.seed is not None else 42
    obs = env.reset(
        task_name=request.task_name,
        seed=seed,
        episode_id=request.episode_id,
    )
    return _serialize_observation(obs)
# Registered on ``app``; without a route decorator the handler was never exposed.
@app.post("/step")
def step(request: StepRequest) -> EnvResponse:
    """Apply one action to the shared environment.

    Args:
        request: Body whose ``action`` mapping is validated as a
            ``DataCleanAction``.

    Returns:
        The resulting observation in the OpenEnv response envelope.

    Raises:
        HTTPException: 422 when ``action`` does not validate.
    """
    try:
        action = DataCleanAction.model_validate(request.action)
    except (ValueError, TypeError) as exc:
        # pydantic.ValidationError subclasses ValueError. Unrelated errors are
        # no longer disguised as client-side 422s.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    obs = env.step(action)
    return _serialize_observation(obs)
# ── Entry point ──────────────────────────────────────────────────────────────
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind; the default ``0.0.0.0`` listens on all IPv4
            interfaces.
        port: TCP port to listen on.
    """
    # Imported lazily so that importing this module does not require uvicorn.
    from uvicorn import run as serve_with_uvicorn

    serve_with_uvicorn(app, host=host, port=port)
# Running this file directly starts the server with main()'s default host/port.
if __name__ == "__main__":
    main()