File size: 3,473 Bytes
842577f
 
6b66cfc
68118d3
 
842577f
6b66cfc
68118d3
 
 
842577f
 
6b66cfc
 
68118d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b66cfc
 
 
68118d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842577f
6b66cfc
842577f
6b66cfc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
FastAPI application for the Data Validation Environment.

Uses a STATEFUL single environment instance so that /reset and /step share state.
Responses use the standard OpenEnv format: {observation, reward, done}.
"""

from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, ValidationError

from env.models import DataCleanAction, DataCleanObservation
from env.environment import DataValidationEnvironment


# ── Pydantic request / response models matching OpenEnv wire format ──────────

class ResetRequest(BaseModel):
    """Body of ``POST /reset``.

    Every field is optional, and unknown keys are accepted (``extra="allow"``)
    so clients that send additional reset options are not rejected.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(extra="allow")

    # Task to load; forwarded to ``env.reset`` as-is (may be None).
    task_name: Optional[str] = None
    # RNG seed used for the reset when the client does not supply one.
    seed: Optional[int] = 42
    # Optional caller-supplied episode identifier, forwarded to ``env.reset``.
    episode_id: Optional[str] = None


class StepRequest(BaseModel):
    """Body of ``POST /step``.

    ``action`` is kept as a raw dict here and validated against
    ``DataCleanAction`` by the handler. Unknown top-level keys are accepted.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(extra="allow")

    # Raw action payload; its schema is enforced later, not at this layer.
    action: Dict[str, Any]


class EnvResponse(BaseModel):
    """OpenEnv wire envelope: ``{observation, reward, done}``."""

    # Observation payload as a plain JSON-compatible dict.
    observation: Dict[str, Any]
    # Scalar reward; defaults to None when not supplied.
    reward: Optional[float] = None
    # Episode-termination flag; defaults to False.
    done: bool = False


# ── Shared environment instance (stateful across requests) ───────────────────
# Every endpoint reads and mutates this single module-level object, which is
# how /reset and /step end up sharing episode state.
# NOTE(review): the handlers are plain ``def`` functions, which FastAPI runs in
# a threadpool; concurrent clients would share (and could race on) this one
# instance — confirm single-client use is intended.
env = DataValidationEnvironment()


# ── FastAPI app ──────────────────────────────────────────────────────────────

app = FastAPI(title="OpenEnv Environment HTTP API", version="1.0.0")


def _serialize_observation(obs: DataCleanObservation) -> EnvResponse:
    """Wrap an observation in the OpenEnv ``{observation, reward, done}`` envelope.

    ``reward`` and ``done`` are lifted to the top level of the response, so
    they are left out of the nested observation dict, along with ``metadata``.
    """
    hidden_fields = {"reward", "done", "metadata"}
    payload = obs.model_dump(exclude=hidden_fields)
    return EnvResponse(observation=payload, reward=obs.reward, done=obs.done)


# ── Endpoints ────────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Liveness probe; unconditionally reports a healthy status."""
    status = "healthy"
    return {"status": status}


@app.get("/metadata")
def metadata():
    """Return a static name and description for this environment."""
    info = dict(
        name="data_validation_env",
        description="An RL environment for training agents to clean and validate structured data.",
    )
    return info


@app.get("/schema")
def schema():
    """Return JSON Schemas for the action and observation models.

    No state schema is published; ``"state"`` is always an empty object.
    """
    action_schema = DataCleanAction.model_json_schema()
    observation_schema = DataCleanObservation.model_json_schema()
    return {"action": action_schema, "observation": observation_schema, "state": {}}


@app.get("/state")
def state():
    """Return the shared environment's current state as a dict.

    If ``env.state`` has no ``model_dump`` method, a placeholder
    ``{"episode_id": None, "step_count": 0}`` is returned instead.
    """
    current = env.state
    if not hasattr(current, "model_dump"):
        return {"episode_id": None, "step_count": 0}
    return current.model_dump()


@app.post("/reset")
def reset(request: ResetRequest = ResetRequest()):
    """Start a new episode on the shared environment.

    The request body is optional; when it is omitted, ``ResetRequest``'s
    defaults apply. An explicit ``"seed": null`` is treated as seed 42.
    Returns the initial observation in ``{observation, reward, done}`` form.
    """
    seed = 42 if request.seed is None else request.seed
    first_obs = env.reset(
        task_name=request.task_name,
        seed=seed,
        episode_id=request.episode_id,
    )
    return _serialize_observation(first_obs)


@app.post("/step")
def step(request: StepRequest):
    """Apply one action to the shared environment.

    ``request.action`` is validated against ``DataCleanAction``. A payload
    that fails validation is rejected with HTTP 422, carrying the validation
    message, before the environment is touched.

    Returns the resulting observation in ``{observation, reward, done}`` form.
    """
    try:
        action = DataCleanAction.model_validate(request.action)
    except ValidationError as e:
        # Only schema failures are client errors. Any other exception is a
        # server-side bug and should surface as a 500, not be masked as a 422.
        raise HTTPException(status_code=422, detail=str(e)) from e
    obs = env.step(action)
    return _serialize_observation(obs)


# ── Entry point ──────────────────────────────────────────────────────────────

def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn on the given host and port."""
    from uvicorn import run as serve_app

    serve_app(app, host=host, port=port)


if __name__ == "__main__":
    main()