# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Aws Rl Env Environment Client."""

from typing import Any, Dict, Optional

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from models import (
    AwsRlAction,
    AwsRlObservation,
    EpisodeID,
    StepCount,
    AwsRlState,
    Task,
)


class AwsRlEnv(EnvClient[AwsRlAction, AwsRlObservation, AwsRlState]):
    """
    Client for the Aws Rl Env Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the
    server.

    ``reset`` is a coroutine, so its result must be awaited.

    NOTE(review): the examples below assume the async ``EnvClient`` API
    (``async with``, awaited ``step`` / ``from_docker_image`` / ``close``) --
    confirm against the installed openenv version.

    Example:
        >>> async with AwsRlEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     print(result.observation.command_output)
        ...
        ...     result = await client.step(AwsRlAction(command="aws s3 ls"))
        ...     print(result.observation.command_output)

    Example with Docker:
        >>> client = await AwsRlEnv.from_docker_image("aws_rl_env-env:latest")
        >>> try:
        ...     result = await client.reset()
        ...     result = await client.step(AwsRlAction(command="aws s3 ls"))
        ... finally:
        ...     await client.close()
    """

    async def reset(
        self,
        task: Optional[Task] = None,
        **kwargs: Any,
    ) -> StepResult[AwsRlObservation]:
        """Reset the environment.

        Pass a `Task` object to force that exact task (trainer mode) — the
        full task is serialised to the server so the env never has to look
        it up through its own curriculum. Without `task`, the server's local
        curriculum picks the next task.

        Args:
            task: Optional task to run; sent as ``task.model_dump()`` under
                the ``"task"`` key.
            **kwargs: Extra reset arguments forwarded unchanged to the base
                client's ``reset``.

        Returns:
            The result produced by the base client's ``reset``.
        """
        if task is not None:
            kwargs["task"] = task.model_dump()
        return await super().reset(**kwargs)

    def _step_payload(self, action: AwsRlAction) -> Dict[str, Any]:
        """Convert AwsRlAction to JSON payload for step message.

        Only the action's ``command`` is sent, under the ``"command"`` key.
        """
        return {"command": action.command}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AwsRlObservation]:
        """Parse server response into StepResult[AwsRlObservation].

        ``reward`` and ``done`` are read from the top level of ``payload`` and
        copied into both the observation and the returned ``StepResult``.
        All other fields come from ``payload["observation"]``; missing keys
        fall back to empty/zero/False defaults (``task`` falls back to None).

        Args:
            payload: Decoded server message.

        Returns:
            The step result wrapping the parsed observation.
        """
        # `or {}` also covers an explicit `"observation": null`, which a plain
        # `.get("observation", {})` would pass through as None.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = AwsRlObservation(
            episode_id=EpisodeID(obs_data.get("episode_id", "")),
            step_count=StepCount(obs_data.get("step_count", 0)),
            command_success=obs_data.get("command_success", False),
            command_output=obs_data.get("command_output", ""),
            error=obs_data.get("error", ""),
            task=obs_data.get("task"),
            task_achieved=obs_data.get("task_achieved", False),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> AwsRlState:
        """Parse server response into AwsRlState object.

        ``current_task`` and ``tracker`` are rebuilt from their dict payloads
        when present and non-empty; otherwise ``current_task`` is None and
        ``tracker`` is a default-constructed ``TrackerState``.

        Args:
            payload: Decoded server state message.

        Returns:
            The parsed environment state.
        """
        # Task is already imported at module level; only TrackerState is
        # needed here.
        from models import TrackerState

        tracker_data = payload.get("tracker")
        task_data = payload.get("current_task")

        return AwsRlState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            current_task=Task(**task_data) if task_data else None,
            tracker=TrackerState(**tracker_data) if tracker_data else TrackerState(),
            infra_state=payload.get("infra_state", {}),
            chaos_occurred=payload.get("chaos_occurred", False),
            current_tier=payload.get("current_tier", "warmup"),
        )