File size: 6,415 Bytes
dd24a31
 
 
 
 
 
 
 
 
 
 
 
 
 
cf2697b
 
 
 
dd24a31
 
 
654c8c7
dd24a31
 
 
 
 
 
 
 
 
 
 
 
654c8c7
dd24a31
654c8c7
 
 
dd24a31
 
 
 
 
 
654c8c7
dd24a31
 
 
 
654c8c7
dd24a31
654c8c7
dd24a31
 
654c8c7
dd24a31
 
 
 
 
654c8c7
 
 
dd24a31
 
654c8c7
dd24a31
654c8c7
dd24a31
 
 
 
 
654c8c7
dd24a31
 
654c8c7
 
 
 
 
 
 
4b5c463
654c8c7
 
 
 
4b5c463
630f735
 
 
 
 
 
 
 
 
654c8c7
 
 
 
 
 
 
 
 
77ede9e
654c8c7
 
 
 
 
 
 
 
 
 
4b5c463
 
77ede9e
 
 
630f735
 
 
6ad7bd8
 
 
630f735
 
 
 
 
dd24a31
654c8c7
dd24a31
 
 
 
654c8c7
dd24a31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""AntiAtropos Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

try:
    from .models import SREAction, ClusterObservation, NodeObservation, NodeStatus
except ImportError:
    from models import SREAction, ClusterObservation, NodeObservation, NodeStatus  # type: ignore


class AntiAtroposEnv(
    EnvClient[SREAction, ClusterObservation, State]
):
    """
    Client for the AntiAtropos Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Example:
        >>> # Connect to a running server
        >>> with AntiAtroposEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.average_latency_ms)
        ...
        ...     action = SREAction(action_type="SCALE_UP", target_node_id="node-0", parameter=2.0)
        ...     result = client.step(action)
        ...     print(result.observation.lyapunov_energy)

    Example with Docker:
        >>> # Automatically start container and connect.
        >>> # Note: Docker repository names must be lowercase.
        >>> client = AntiAtroposEnv.from_docker_image("antiatropos-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(SREAction(action_type="NO_OP"))
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: SREAction) -> Dict:
        """
        Convert SREAction to JSON payload for step message.

        Args:
            action: SREAction instance

        Returns:
            Dictionary with ``action_type``, ``target_node_id`` and ``parameter``
            keys, suitable for JSON encoding.
        """
        action_type = action.action_type
        # Send the enum's ``.value`` when given an enum member, but also accept a
        # plain string (e.g. if the model is configured to store enum values).
        action_type = getattr(action_type, "value", action_type)
        return {
            "action_type": action_type,
            "target_node_id": action.target_node_id,
            "parameter": action.parameter,
        }

    @staticmethod
    def _parse_node(n: Dict) -> NodeObservation:
        """
        Build a NodeObservation from one entry of the observation's ``nodes`` list.

        Missing keys fall back to the client-side defaults below; ``status``
        defaults to ``NodeStatus.HEALTHY``.

        Args:
            n: JSON dict describing a single node

        Returns:
            NodeObservation for that node
        """
        return NodeObservation(
            node_id=n.get("node_id", ""),
            status=NodeStatus(n.get("status", NodeStatus.HEALTHY)),
            is_vip=n.get("is_vip", False),
            queue_depth=n.get("queue_depth", 0),
            latency_ms=n.get("latency_ms", 0.0),
            incoming_request_rate=n.get("incoming_request_rate", 0.0),
            cpu_utilization=n.get("cpu_utilization", 0.0),
            importance_weight=n.get("importance_weight", 1.0),
            capacity=n.get("capacity", 0.0),
            pending_capacity=n.get("pending_capacity", 0.0),
            queue_delta=n.get("queue_delta", 0.0),
            sla_proximity=n.get("sla_proximity", 0.0),
            outflow_rate=n.get("outflow_rate", 0.0),
            upstream_nodes=n.get("upstream_nodes", []),
            downstream_nodes=n.get("downstream_nodes", []),
            upstream_pressure=n.get("upstream_pressure", 0.0),
            node_reward=n.get("node_reward", 0.0),
            done=n.get("done", False),
            reward=n.get("reward", 0.0),
        )

    def _parse_result(self, payload: Dict) -> StepResult[ClusterObservation]:
        """
        Parse server response into StepResult[ClusterObservation].

        Args:
            payload: JSON response data from server. Cluster metrics are read
                from ``payload["observation"]``; the episode-level ``done`` and
                ``reward`` are read from the top level of ``payload``.

        Returns:
            StepResult with ClusterObservation
        """
        # ``or`` (rather than a .get default) also tolerates an explicit JSON
        # ``null`` for the observation or its node list.
        obs_data = payload.get("observation") or {}
        raw_nodes = obs_data.get("nodes") or []
        node_obs = [self._parse_node(n) for n in raw_nodes]

        # Episode-level flags live at the top level of the payload and are
        # mirrored into both the observation and the StepResult.
        done = payload.get("done", False)
        reward = payload.get("reward", 0.0)

        observation = ClusterObservation(
            cluster_id=obs_data.get("cluster_id", ""),
            task_id=obs_data.get("task_id", "task-1"),
            mode=obs_data.get("mode", "simulated"),
            active_nodes=obs_data.get("active_nodes", 0),
            average_latency_ms=obs_data.get("average_latency_ms", 0.0),
            error_rate=obs_data.get("error_rate", 0.0),
            total_queue_backlog=obs_data.get("total_queue_backlog", 0),
            current_cost_per_hour=obs_data.get("current_cost_per_hour", 0.0),
            lyapunov_energy=obs_data.get("lyapunov_energy", 0.0),
            nodes=node_obs,
            step=obs_data.get("step", 0),
            max_steps=obs_data.get("max_steps", 100),
            sla_violations=obs_data.get("sla_violations", 0),
            invalid_action_count=obs_data.get("invalid_action_count", 0),
            vip_failure_count=obs_data.get("vip_failure_count", 0),
            metric_timestamp=obs_data.get("metric_timestamp", 0.0),
            data_freshness_ms=obs_data.get("data_freshness_ms", 0),
            action_ack_status=obs_data.get("action_ack_status", "success"),
            action_id=obs_data.get("action_id", ""),
            executor_latency_ms=obs_data.get("executor_latency_ms", 0.0),
            executor_error_code=obs_data.get("executor_error_code", ""),
            raw_reward=obs_data.get("raw_reward", 0.0),
            normalized_reward=obs_data.get("normalized_reward", 0.0),
            reward_scale_version=obs_data.get("reward_scale_version", "sigmoid-v1"),
            reward_drift=obs_data.get("reward_drift", 0.0),
            reward_cost=obs_data.get("reward_cost", 0.0),
            reward_sla=obs_data.get("reward_sla", 0.0),
            reward_barrier=obs_data.get("reward_barrier", 0.0),
            choke_level=obs_data.get("choke_level", 0.0),
            done=done,
            reward=reward,
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with ``episode_id`` (None if the server omits it) and
            ``step_count`` (0 if omitted)
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )