| """WebSocket client for the SimMart environment.""" |
|
|
| from typing import Dict |
|
|
| from openenv.core.env_client import EnvClient |
| from openenv.core.client_types import StepResult |
|
|
| from .models import SimMartAction, SimMartObservation, SimMartState |
|
|
|
|
class SimMartEnv(
    EnvClient[SimMartAction, SimMartObservation, SimMartState]
):
    """
    Async/sync client for SimMart.

    Converts ``SimMartAction`` objects into the dict payload sent to the
    server, and converts server responses back into ``StepResult`` /
    ``SimMartState`` objects.

    Example (sync):
        >>> with SimMartEnv(base_url="http://localhost:7860").sync() as env:
        ...     result = env.reset(seed=42)
        ...     result = env.step(SimMartAction(action_type="decide"))
    """

    def _step_payload(self, action: SimMartAction) -> Dict:
        """Serialize ``action`` into the dict sent to the server for a step.

        Only ``action_type`` and ``content`` are forwarded.

        Args:
            action: The action to send.

        Returns:
            A dict with keys ``"action_type"`` and ``"content"``.
        """
        return {
            "action_type": action.action_type,
            "content": action.content,
        }

    def _parse_result(self, payload: Dict) -> StepResult[SimMartObservation]:
        """Build a ``StepResult`` from a server response dict.

        Args:
            payload: Response dict. ``done`` and ``reward`` are read from the
                top level; observation fields are read from the nested
                ``observation`` dict. Missing keys fall back to empty/zero
                defaults (``reward`` falls back to ``None``).

        Returns:
            A ``StepResult`` wrapping a ``SimMartObservation``. The same
            ``done``/``reward`` values are set on both the observation and
            the result.
        """
        # ``or {}`` also covers an explicit ``"observation": null``; with
        # ``.get(key, {})`` that would yield ``None`` and the field lookups
        # below would raise AttributeError.
        obs = payload.get("observation") or {}
        # Read once so the observation and the StepResult cannot disagree.
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = SimMartObservation(
            output=obs.get("output", ""),
            task_description=obs.get("task_description", ""),
            step_number=obs.get("step_number", 0),
            day_of_quarter=obs.get("day_of_quarter", 0),
            week_of_quarter=obs.get("week_of_quarter", 0),
            message=obs.get("message", ""),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> SimMartState:
        """Build a ``SimMartState`` from a server state dict.

        Args:
            payload: State dict. A missing ``episode_id`` defaults to ``""``;
                missing ``step_count``, ``day``, ``week`` and ``rng_seed``
                default to ``0``.

        Returns:
            The parsed ``SimMartState``.
        """
        return SimMartState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            day=payload.get("day", 0),
            week=payload.get("week", 0),
            rng_seed=payload.get("rng_seed", 0),
        )
|
|