Spaces:
Sleeping
Sleeping
File size: 3,473 Bytes
842577f 6b66cfc 68118d3 842577f 6b66cfc 68118d3 842577f 6b66cfc 68118d3 6b66cfc 68118d3 842577f 6b66cfc 842577f 6b66cfc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 | """
FastAPI application for the Data Validation Environment.
Uses a STATEFUL single environment instance so that /reset and /step share state.
Responses use the standard OpenEnv format: {observation, reward, done}.
"""
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, ValidationError

from env.models import DataCleanAction, DataCleanObservation
from env.environment import DataValidationEnvironment
# ── Pydantic request / response models matching OpenEnv wire format ──────────
class ResetRequest(BaseModel):
    """Body of ``POST /reset``; every field is optional.

    Unknown keys are kept on the model (``extra="allow"``) instead of being
    dropped; the /reset handler only reads the three fields declared below.
    """

    # Pydantic v2 configuration. Replaces the inner ``class Config``, which is
    # deprecated since Pydantic 2.0 and emits a deprecation warning.
    model_config = ConfigDict(extra="allow")

    task_name: Optional[str] = None
    seed: Optional[int] = 42  # default seed when the client omits the key
    episode_id: Optional[str] = None
class StepRequest(BaseModel):
    """Body of ``POST /step``.

    ``action`` is accepted as a raw mapping here; the /step handler validates
    it against ``DataCleanAction``. Unknown top-level keys are kept on the
    model (``extra="allow"``) rather than dropped.
    """

    # Pydantic v2 configuration. Replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(extra="allow")

    action: Dict[str, Any]
class EnvResponse(BaseModel):
    """Response envelope for /reset and /step: ``{observation, reward, done}``."""

    # Observation fields as a plain dict. _serialize_observation fills this with
    # the observation minus its reward/done/metadata fields.
    observation: Dict[str, Any]
    # Copied from the observation's ``reward``; may be None.
    reward: Optional[float] = None
    # Copied from the observation's ``done`` flag.
    done: bool = False
# ── Shared environment instance (stateful across requests) ───────────────────
# A single process-wide environment: /reset and /step drive it and /state reads
# it, which is what lets one episode span several HTTP calls.
# NOTE(review): FastAPI runs plain ``def`` endpoints in a worker threadpool, so
# concurrent requests can reach this instance simultaneously. Confirm whether
# DataValidationEnvironment needs a lock if more than one client connects.
env = DataValidationEnvironment()

# ── FastAPI app ──────────────────────────────────────────────────────────────
app = FastAPI(version="1.0.0", title="OpenEnv Environment HTTP API")
def _serialize_observation(obs: DataCleanObservation) -> EnvResponse:
    """Wrap an observation in the OpenEnv ``{observation, reward, done}`` envelope.

    ``reward`` and ``done`` are lifted to the top level of the response. They are
    therefore excluded from the nested observation payload, together with
    ``metadata``.
    """
    hidden_fields = {"reward", "done", "metadata"}
    payload = obs.model_dump(exclude=hidden_fields)
    reward, done = obs.reward, obs.done
    return EnvResponse(observation=payload, reward=reward, done=done)
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe; always answers with a fixed "healthy" status."""
    return dict(status="healthy")
@app.get("/metadata")
def metadata():
    """Return the static name/description record for this environment."""
    description = (
        "An RL environment for training agents to clean and validate structured data."
    )
    return dict(name="data_validation_env", description=description)
@app.get("/schema")
def schema():
    """Publish the JSON schemas of the action and observation models.

    ``state`` is returned as an empty placeholder; no state schema is exposed.
    """
    action_schema = DataCleanAction.model_json_schema()
    observation_schema = DataCleanObservation.model_json_schema()
    return {"action": action_schema, "observation": observation_schema, "state": {}}
@app.get("/state")
def state():
    """Return the shared environment's current state as a plain dict.

    If ``env.state`` has no ``model_dump`` method, a placeholder
    ``{"episode_id": None, "step_count": 0}`` is returned instead.
    NOTE(review): a non-pydantic state object (e.g. a dict or dataclass) is
    also replaced by this placeholder rather than serialized. Confirm that
    this is intended.
    """
    current = env.state
    if not hasattr(current, "model_dump"):
        return {"episode_id": None, "step_count": 0}
    return current.model_dump()
@app.post("/reset")
def reset(request: ResetRequest = ResetRequest()):
    """Reset the shared environment and return the resulting observation.

    An omitted body uses the ``ResetRequest`` defaults. An explicit
    ``"seed": null`` is mapped back to 42 before calling ``env.reset``.
    """
    seed = request.seed
    if seed is None:
        seed = 42
    observation = env.reset(
        task_name=request.task_name,
        seed=seed,
        episode_id=request.episode_id,
    )
    return _serialize_observation(observation)
@app.post("/step")
def step(request: StepRequest):
    """Apply one action to the shared environment and return the new observation.

    ``request.action`` is validated against ``DataCleanAction`` before it
    reaches the environment.

    Raises:
        HTTPException: 422 with pydantic's error text when the action fails
            validation.
    """
    try:
        action = DataCleanAction.model_validate(request.action)
    except ValidationError as exc:
        # Only validation failures are client errors. Any other exception raised
        # here is a server-side bug and should surface as one (HTTP 500) rather
        # than being disguised as a 422.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    obs = env.step(action)
    return _serialize_observation(obs)
# ── Entry point ──────────────────────────────────────────────────────────────
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn on ``host``:``port``.

    uvicorn is imported inside the function, so importing this module does not
    import uvicorn.
    """
    from uvicorn import run as serve

    serve(app, host=host, port=port)


if __name__ == "__main__":
    main()
|