Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """AntiAtropos Environment Client.""" | |
| from typing import Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| try: | |
| from .models import SREAction, ClusterObservation, NodeObservation, NodeStatus | |
| except ImportError: | |
| from models import SREAction, ClusterObservation, NodeObservation, NodeStatus # type: ignore | |
| class AntiAtroposEnv( | |
| EnvClient[SREAction, ClusterObservation, State] | |
| ): | |
| """ | |
| Client for the AntiAtropos Environment. | |
| This client maintains a persistent WebSocket connection to the environment server, | |
| enabling efficient multi-step interactions with lower latency. | |
| Each client instance has its own dedicated environment session on the server. | |
| Example: | |
| >>> # Connect to a running server | |
| >>> with AntiAtroposEnv(base_url="http://localhost:8000") as client: | |
| ... result = client.reset() | |
| ... print(result.observation.average_latency_ms) | |
| ... | |
| ... action = SREAction(action_type="SCALE_UP", target_node_id="node-0", parameter=2.0) | |
| ... result = client.step(action) | |
| ... print(result.observation.lyapunov_energy) | |
| Example with Docker: | |
| >>> # Automatically start container and connect | |
| >>> client = AntiAtroposEnv.from_docker_image("AntiAtropos-env:latest") | |
| >>> try: | |
| ... result = client.reset() | |
| ... result = client.step(SREAction(action_type="NO_OP")) | |
| ... finally: | |
| ... client.close() | |
| """ | |
| def _step_payload(self, action: SREAction) -> Dict: | |
| """ | |
| Convert SREAction to JSON payload for step message. | |
| Args: | |
| action: SREAction instance | |
| Returns: | |
| Dictionary representation suitable for JSON encoding | |
| """ | |
| return { | |
| "action_type": action.action_type.value, | |
| "target_node_id": action.target_node_id, | |
| "parameter": action.parameter, | |
| } | |
| def _parse_result(self, payload: Dict) -> StepResult[ClusterObservation]: | |
| """ | |
| Parse server response into StepResult[ClusterObservation]. | |
| Args: | |
| payload: JSON response data from server | |
| Returns: | |
| StepResult with ClusterObservation | |
| """ | |
| obs_data = payload.get("observation", {}) | |
| # Parse per-node list into NodeObservation objects | |
| raw_nodes = obs_data.get("nodes", []) | |
| node_obs = [ | |
| NodeObservation( | |
| node_id=n.get("node_id", ""), | |
| status=NodeStatus(n.get("status", NodeStatus.HEALTHY)), | |
| is_vip=n.get("is_vip", False), | |
| queue_depth=n.get("queue_depth", 0), | |
| latency_ms=n.get("latency_ms", 0.0), | |
| incoming_request_rate=n.get("incoming_request_rate", 0.0), | |
| cpu_utilization=n.get("cpu_utilization", 0.0), | |
| importance_weight=n.get("importance_weight", 1.0), | |
| capacity=n.get("capacity", 0.0), | |
| pending_capacity=n.get("pending_capacity", 0.0), | |
| queue_delta=n.get("queue_delta", 0.0), | |
| sla_proximity=n.get("sla_proximity", 0.0), | |
| outflow_rate=n.get("outflow_rate", 0.0), | |
| upstream_nodes=n.get("upstream_nodes", []), | |
| downstream_nodes=n.get("downstream_nodes", []), | |
| upstream_pressure=n.get("upstream_pressure", 0.0), | |
| node_reward=n.get("node_reward", 0.0), | |
| done=n.get("done", False), | |
| reward=n.get("reward", 0.0), | |
| ) | |
| for n in raw_nodes | |
| ] | |
| observation = ClusterObservation( | |
| cluster_id=obs_data.get("cluster_id", ""), | |
| task_id=obs_data.get("task_id", "task-1"), | |
| mode=obs_data.get("mode", "simulated"), | |
| active_nodes=obs_data.get("active_nodes", 0), | |
| average_latency_ms=obs_data.get("average_latency_ms", 0.0), | |
| error_rate=obs_data.get("error_rate", 0.0), | |
| total_queue_backlog=obs_data.get("total_queue_backlog", 0), | |
| current_cost_per_hour=obs_data.get("current_cost_per_hour", 0.0), | |
| lyapunov_energy=obs_data.get("lyapunov_energy", 0.0), | |
| nodes=node_obs, | |
| step=obs_data.get("step", 0), | |
| max_steps=obs_data.get("max_steps", 100), | |
| sla_violations=obs_data.get("sla_violations", 0), | |
| invalid_action_count=obs_data.get("invalid_action_count", 0), | |
| vip_failure_count=obs_data.get("vip_failure_count", 0), | |
| metric_timestamp=obs_data.get("metric_timestamp", 0.0), | |
| data_freshness_ms=obs_data.get("data_freshness_ms", 0), | |
| action_ack_status=obs_data.get("action_ack_status", "success"), | |
| action_id=obs_data.get("action_id", ""), | |
| executor_latency_ms=obs_data.get("executor_latency_ms", 0.0), | |
| executor_error_code=obs_data.get("executor_error_code", ""), | |
| raw_reward=obs_data.get("raw_reward", 0.0), | |
| normalized_reward=obs_data.get("normalized_reward", 0.0), | |
| reward_scale_version=obs_data.get("reward_scale_version", "sigmoid-v1"), | |
| reward_drift=obs_data.get("reward_drift", 0.0), | |
| reward_cost=obs_data.get("reward_cost", 0.0), | |
| reward_sla=obs_data.get("reward_sla", 0.0), | |
| reward_barrier=obs_data.get("reward_barrier", 0.0), | |
| choke_level=obs_data.get("choke_level", 0.0), | |
| done=payload.get("done", False), | |
| reward=payload.get("reward", 0.0), | |
| ) | |
| return StepResult( | |
| observation=observation, | |
| reward=payload.get("reward", 0.0), | |
| done=payload.get("done", False), | |
| ) | |
| def _parse_state(self, payload: Dict) -> State: | |
| """ | |
| Parse server response into State object. | |
| Args: | |
| payload: JSON response from state request | |
| Returns: | |
| State object with episode_id and step_count | |
| """ | |
| return State( | |
| episode_id=payload.get("episode_id"), | |
| step_count=payload.get("step_count", 0), | |
| ) | |