File size: 6,415 Bytes
dd24a31 cf2697b dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 654c8c7 4b5c463 654c8c7 4b5c463 630f735 654c8c7 77ede9e 654c8c7 4b5c463 77ede9e 630f735 6ad7bd8 630f735 dd24a31 654c8c7 dd24a31 654c8c7 dd24a31 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""AntiAtropos Environment Client."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
try:
from .models import SREAction, ClusterObservation, NodeObservation, NodeStatus
except ImportError:
from models import SREAction, ClusterObservation, NodeObservation, NodeStatus # type: ignore
class AntiAtroposEnv(
    EnvClient[SREAction, ClusterObservation, State]
):
    """
    Client for the AntiAtropos Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Example:
        >>> # Connect to a running server
        >>> with AntiAtroposEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.average_latency_ms)
        ...
        ...     action = SREAction(action_type="SCALE_UP", target_node_id="node-0", parameter=2.0)
        ...     result = client.step(action)
        ...     print(result.observation.lyapunov_energy)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = AntiAtroposEnv.from_docker_image("AntiAtropos-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(SREAction(action_type="NO_OP"))
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: SREAction) -> Dict:
        """
        Convert an SREAction into the JSON payload for a step message.

        Args:
            action: SREAction instance to send to the server.

        Returns:
            Dictionary with ``action_type``, ``target_node_id`` and ``parameter``
            keys, suitable for JSON encoding.
        """
        action_type = action.action_type
        # Send the enum's underlying value. Also tolerate a plain value in case
        # the model stored the raw value instead of the enum member (e.g. a
        # pydantic model configured with ``use_enum_values``); previously that
        # case raised AttributeError on ``.value``.
        action_type = getattr(action_type, "value", action_type)
        return {
            "action_type": action_type,
            "target_node_id": action.target_node_id,
            "parameter": action.parameter,
        }

    @staticmethod
    def _parse_node(n: Dict) -> NodeObservation:
        """
        Build a NodeObservation from one entry of the observation's ``nodes`` list.

        Args:
            n: Per-node dictionary from the server payload.

        Returns:
            NodeObservation. Keys missing from ``n`` fall back to the defaults
            below; a missing or null ``status`` becomes ``NodeStatus.HEALTHY``.
        """
        status = n.get("status")
        return NodeObservation(
            node_id=n.get("node_id", ""),
            # Guard against an explicit null, which NodeStatus(None) would reject.
            status=NodeStatus(status) if status is not None else NodeStatus.HEALTHY,
            is_vip=n.get("is_vip", False),
            queue_depth=n.get("queue_depth", 0),
            latency_ms=n.get("latency_ms", 0.0),
            incoming_request_rate=n.get("incoming_request_rate", 0.0),
            cpu_utilization=n.get("cpu_utilization", 0.0),
            importance_weight=n.get("importance_weight", 1.0),
            capacity=n.get("capacity", 0.0),
            pending_capacity=n.get("pending_capacity", 0.0),
            queue_delta=n.get("queue_delta", 0.0),
            sla_proximity=n.get("sla_proximity", 0.0),
            outflow_rate=n.get("outflow_rate", 0.0),
            upstream_nodes=n.get("upstream_nodes", []),
            downstream_nodes=n.get("downstream_nodes", []),
            upstream_pressure=n.get("upstream_pressure", 0.0),
            node_reward=n.get("node_reward", 0.0),
            done=n.get("done", False),
            reward=n.get("reward", 0.0),
        )

    def _parse_result(self, payload: Dict) -> StepResult[ClusterObservation]:
        """
        Parse a server response into StepResult[ClusterObservation].

        Args:
            payload: JSON response data from the server. Cluster fields are read
                from ``payload["observation"]``. ``done`` and ``reward`` are read
                from the top level of ``payload``.

        Returns:
            StepResult wrapping the parsed ClusterObservation. Its ``done`` and
            ``reward`` match the values copied into the observation.
        """
        # ``or`` also covers an explicit JSON null, which ``.get(key, default)``
        # would pass through and then crash on.
        obs_data = payload.get("observation") or {}
        node_obs = [self._parse_node(n) for n in obs_data.get("nodes") or []]

        done = payload.get("done", False)
        reward = payload.get("reward", 0.0)

        observation = ClusterObservation(
            cluster_id=obs_data.get("cluster_id", ""),
            task_id=obs_data.get("task_id", "task-1"),
            mode=obs_data.get("mode", "simulated"),
            active_nodes=obs_data.get("active_nodes", 0),
            average_latency_ms=obs_data.get("average_latency_ms", 0.0),
            error_rate=obs_data.get("error_rate", 0.0),
            total_queue_backlog=obs_data.get("total_queue_backlog", 0),
            current_cost_per_hour=obs_data.get("current_cost_per_hour", 0.0),
            lyapunov_energy=obs_data.get("lyapunov_energy", 0.0),
            nodes=node_obs,
            step=obs_data.get("step", 0),
            max_steps=obs_data.get("max_steps", 100),
            sla_violations=obs_data.get("sla_violations", 0),
            invalid_action_count=obs_data.get("invalid_action_count", 0),
            vip_failure_count=obs_data.get("vip_failure_count", 0),
            metric_timestamp=obs_data.get("metric_timestamp", 0.0),
            data_freshness_ms=obs_data.get("data_freshness_ms", 0),
            action_ack_status=obs_data.get("action_ack_status", "success"),
            action_id=obs_data.get("action_id", ""),
            executor_latency_ms=obs_data.get("executor_latency_ms", 0.0),
            executor_error_code=obs_data.get("executor_error_code", ""),
            raw_reward=obs_data.get("raw_reward", 0.0),
            normalized_reward=obs_data.get("normalized_reward", 0.0),
            reward_scale_version=obs_data.get("reward_scale_version", "sigmoid-v1"),
            reward_drift=obs_data.get("reward_drift", 0.0),
            reward_cost=obs_data.get("reward_cost", 0.0),
            reward_sla=obs_data.get("reward_sla", 0.0),
            reward_barrier=obs_data.get("reward_barrier", 0.0),
            choke_level=obs_data.get("choke_level", 0.0),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse a state-request response into a State object.

        Args:
            payload: JSON response from the state request.

        Returns:
            State carrying ``episode_id`` (None if absent) and ``step_count``
            (0 if absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|