File size: 3,125 Bytes
f89b1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Typed clients for the DataOpsEnv environment."""

from typing import Optional

import requests

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from models import DataOpsAction, DataOpsObservation, DataOpsState


class DataOpsEnv(EnvClient[DataOpsAction, DataOpsObservation, DataOpsState]):
    """Native OpenEnv WebSocket client for persistent sessions.

    This subclass only supplies the conversion hooks between the DataOps
    models and plain ``dict`` payloads; everything else is inherited from
    ``EnvClient``.
    """

    def _step_payload(self, action: DataOpsAction) -> dict:
        """Return *action* serialized to a plain dict via ``model_dump()``."""
        return action.model_dump()

    def _parse_result(self, payload: dict) -> StepResult[DataOpsObservation]:
        """Build a ``StepResult`` from a decoded response payload.

        Args:
            payload: Mapping that may contain ``"observation"`` (a mapping of
                observation fields), ``"reward"`` and ``"done"``.

        Returns:
            A ``StepResult`` whose ``observation`` carries the same ``reward``
            and ``done`` values as the result itself.

        A top-level ``"reward"`` / ``"done"`` wins; when absent, a value
        nested inside ``"observation"`` is used; otherwise the defaults are
        ``None`` and ``False``.
        """
        # Copy so the caller's payload is never mutated; ``or {}`` also covers
        # an explicit ``"observation": null``.
        fields = dict(payload.get("observation") or {})

        reward = payload.get("reward", fields.get("reward"))
        done = payload.get("done", fields.get("done", False))

        # Write the values into the dict instead of passing them as extra
        # keywords next to ``**fields``: if the observation already contained
        # "reward" or "done", that call form raises a duplicate-keyword
        # TypeError.
        fields["reward"] = reward
        fields["done"] = done

        return StepResult(
            observation=DataOpsObservation(**fields),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict) -> DataOpsState:
        """Build a ``DataOpsState`` using *payload*'s items as keyword arguments."""
        return DataOpsState(**payload)


class DataOpsEnvClient:
    """Compatibility HTTP client for the validator-facing REST API.

    All requests go through one ``requests.Session`` and share the same
    timeout. The client can be used as a context manager, which closes the
    session on exit.
    """

    def __init__(
        self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0
    ) -> None:
        """Create the client.

        Args:
            base_url: Root URL of the server; trailing slashes are stripped.
            timeout: Timeout passed to every request made by this client.
        """
        # Strip trailing "/" so the f-string joins below never produce "//".
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    @staticmethod
    def _parse_observation(payload: dict) -> DataOpsObservation:
        """Build a ``DataOpsObservation`` from a decoded response payload.

        Fields come from ``payload["observation"]``. A top-level ``"reward"``
        or ``"done"``, when present, overrides the nested value.
        """
        # Copy so the caller's payload is not mutated. ``or {}`` also covers
        # an explicit ``"observation": null``, which ``dict(None)`` would
        # reject with a TypeError.
        observation_payload = dict(payload.get("observation") or {})
        if "reward" in payload:
            observation_payload["reward"] = payload["reward"]
        if "done" in payload:
            observation_payload["done"] = payload["done"]
        return DataOpsObservation(**observation_payload)

    def reset(
        self, task_id: str = "task_1_easy_anomaly", seed: Optional[int] = None,
    ) -> DataOpsObservation:
        """POST to ``/reset`` and return the resulting observation.

        Args:
            task_id: Sent as the ``task_id`` query parameter.
            seed: Sent in the JSON body as ``{"seed": seed}`` (``None`` is
                sent as JSON ``null``).

        Raises:
            requests.HTTPError: If the response status is 4xx/5xx.
        """
        resp = self._session.post(
            f"{self.base_url}/reset",
            params={"task_id": task_id},
            json={"seed": seed},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def step(self, action: DataOpsAction) -> DataOpsObservation:
        """POST *action* to ``/step`` and return the resulting observation.

        The action is serialized with ``model_dump()`` and wrapped as
        ``{"action": ...}`` in the JSON body.

        Raises:
            requests.HTTPError: If the response status is 4xx/5xx.
        """
        resp = self._session.post(
            f"{self.base_url}/step",
            json={"action": action.model_dump()},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def state(self) -> DataOpsState:
        """GET ``/state`` and return it as a ``DataOpsState``.

        Raises:
            requests.HTTPError: If the response status is 4xx/5xx.
        """
        resp = self._session.get(f"{self.base_url}/state", timeout=self.timeout)
        resp.raise_for_status()
        return DataOpsState(**resp.json())

    def grade(self, task_id: Optional[str] = None) -> dict:
        """GET the grader endpoint and return the decoded JSON body.

        Args:
            task_id: When truthy, ``/grader/{task_id}`` is requested;
                otherwise (``None`` or an empty string) plain ``/grader``.

        Raises:
            requests.HTTPError: If the response status is 4xx/5xx.
        """
        url = f"{self.base_url}/grader/{task_id}" if task_id else f"{self.base_url}/grader"
        resp = self._session.get(url, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def close(self) -> None:
        """Close the underlying ``requests.Session``."""
        self._session.close()

    def __enter__(self) -> "DataOpsEnvClient":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the session; returns ``None`` so exceptions propagate."""
        self.close()