# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""AntiAtropos Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

try:
    from .models import SREAction, ClusterObservation, NodeObservation, NodeStatus
except ImportError:
    from models import SREAction, ClusterObservation, NodeObservation, NodeStatus  # type: ignore


class AntiAtroposEnv(
    EnvClient[SREAction, ClusterObservation, State]
):
    """
    Client for the AntiAtropos Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Example:
        >>> # Connect to a running server
        >>> with AntiAtroposEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.average_latency_ms)
        ...
        ...     action = SREAction(action_type="SCALE_UP", target_node_id="node-0", parameter=2.0)
        ...     result = client.step(action)
        ...     print(result.observation.lyapunov_energy)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = AntiAtroposEnv.from_docker_image("AntiAtropos-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(SREAction(action_type="NO_OP"))
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: SREAction) -> Dict:
        """
        Convert SREAction to JSON payload for step message.

        Args:
            action: SREAction instance

        Returns:
            Dictionary representation suitable for JSON encoding, with keys
            ``action_type``, ``target_node_id`` and ``parameter``.
        """
        action_type = action.action_type
        return {
            # Send the enum's underlying value. If the model already holds a
            # plain value (e.g. a model using ``use_enum_values`` or built via
            # ``model_construct``), send it unchanged instead of failing on
            # ``.value``.
            "action_type": getattr(action_type, "value", action_type),
            "target_node_id": action.target_node_id,
            "parameter": action.parameter,
        }

    @staticmethod
    def _parse_node(n: Dict) -> NodeObservation:
        """
        Parse one entry of the observation's ``nodes`` list.

        Missing keys fall back to the defaults given below. ``status`` and the
        upstream/downstream lists also fall back when the server sends an
        explicit JSON ``null``. Without this, ``NodeStatus(None)`` would not
        map to a member, and the list fields would receive ``None`` instead of
        their empty-list default.

        Args:
            n: JSON dict describing a single node

        Returns:
            NodeObservation populated from ``n``
        """
        status = n.get("status")
        if status is None:
            status = NodeStatus.HEALTHY
        # NOTE(review): these fallbacks presumably mirror the defaults declared
        # on NodeObservation; keep them in sync if the model changes.
        return NodeObservation(
            node_id=n.get("node_id", ""),
            status=NodeStatus(status),
            is_vip=n.get("is_vip", False),
            queue_depth=n.get("queue_depth", 0),
            latency_ms=n.get("latency_ms", 0.0),
            incoming_request_rate=n.get("incoming_request_rate", 0.0),
            cpu_utilization=n.get("cpu_utilization", 0.0),
            importance_weight=n.get("importance_weight", 1.0),
            capacity=n.get("capacity", 0.0),
            pending_capacity=n.get("pending_capacity", 0.0),
            queue_delta=n.get("queue_delta", 0.0),
            sla_proximity=n.get("sla_proximity", 0.0),
            outflow_rate=n.get("outflow_rate", 0.0),
            # `or []` builds a fresh list per node and also covers JSON null.
            upstream_nodes=n.get("upstream_nodes") or [],
            downstream_nodes=n.get("downstream_nodes") or [],
            upstream_pressure=n.get("upstream_pressure", 0.0),
            node_reward=n.get("node_reward", 0.0),
            done=n.get("done", False),
            reward=n.get("reward", 0.0),
        )

    def _parse_result(self, payload: Dict) -> StepResult[ClusterObservation]:
        """
        Parse server response into StepResult[ClusterObservation].

        Args:
            payload: JSON response data from server. Cluster-level fields and
                the per-node ``nodes`` list are read from
                ``payload["observation"]``; ``done`` and ``reward`` are read
                from the top level of ``payload``.

        Returns:
            StepResult with ClusterObservation
        """
        # Treat an explicit JSON null like a missing key; otherwise `.get` on
        # None / iterating None would crash instead of using the defaults.
        obs_data = payload.get("observation") or {}
        raw_nodes = obs_data.get("nodes") or []
        node_obs = [self._parse_node(n) for n in raw_nodes]

        # done/reward arrive at the top level of the payload and are copied
        # into both the observation and the StepResult.
        done = payload.get("done", False)
        reward = payload.get("reward", 0.0)

        # NOTE(review): these fallbacks presumably mirror the defaults declared
        # on ClusterObservation; keep them in sync if the model changes.
        observation = ClusterObservation(
            cluster_id=obs_data.get("cluster_id", ""),
            task_id=obs_data.get("task_id", "task-1"),
            mode=obs_data.get("mode", "simulated"),
            active_nodes=obs_data.get("active_nodes", 0),
            average_latency_ms=obs_data.get("average_latency_ms", 0.0),
            error_rate=obs_data.get("error_rate", 0.0),
            total_queue_backlog=obs_data.get("total_queue_backlog", 0),
            current_cost_per_hour=obs_data.get("current_cost_per_hour", 0.0),
            lyapunov_energy=obs_data.get("lyapunov_energy", 0.0),
            nodes=node_obs,
            step=obs_data.get("step", 0),
            max_steps=obs_data.get("max_steps", 100),
            sla_violations=obs_data.get("sla_violations", 0),
            invalid_action_count=obs_data.get("invalid_action_count", 0),
            vip_failure_count=obs_data.get("vip_failure_count", 0),
            metric_timestamp=obs_data.get("metric_timestamp", 0.0),
            data_freshness_ms=obs_data.get("data_freshness_ms", 0),
            action_ack_status=obs_data.get("action_ack_status", "success"),
            action_id=obs_data.get("action_id", ""),
            executor_latency_ms=obs_data.get("executor_latency_ms", 0.0),
            executor_error_code=obs_data.get("executor_error_code", ""),
            raw_reward=obs_data.get("raw_reward", 0.0),
            normalized_reward=obs_data.get("normalized_reward", 0.0),
            reward_scale_version=obs_data.get("reward_scale_version", "sigmoid-v1"),
            reward_drift=obs_data.get("reward_drift", 0.0),
            reward_cost=obs_data.get("reward_cost", 0.0),
            reward_sla=obs_data.get("reward_sla", 0.0),
            reward_barrier=obs_data.get("reward_barrier", 0.0),
            choke_level=obs_data.get("choke_level", 0.0),
            done=done,
            reward=reward,
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id (None if absent) and step_count
            (0 if absent)
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )