File size: 6,274 Bytes
c34e7cc
 
 
 
 
d1600e6
c34e7cc
 
 
 
 
 
 
 
 
 
f51115b
c34e7cc
f51115b
c34e7cc
f51115b
c34e7cc
f51115b
 
 
c34e7cc
 
 
f51115b
c34e7cc
 
 
 
f51115b
c34e7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f51115b
 
 
 
 
 
 
 
c34e7cc
 
 
f51115b
 
 
c34e7cc
 
 
 
 
 
 
 
 
f51115b
 
c34e7cc
 
 
 
 
 
 
f51115b
 
c34e7cc
 
 
f51115b
 
 
 
 
 
 
c34e7cc
 
 
 
 
 
 
 
 
 
facabc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d58554d
c34e7cc
 
 
d58554d
c34e7cc
f51115b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ba6413
 
 
 
 
 
 
 
 
c34e7cc
 
 
 
 
 
d58554d
c34e7cc
d58554d
c34e7cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Pydantic v2 typed Action and Observation models for the
Distributed Infrastructure Management Environment.
"""

from typing import Dict, List, Literal, Optional

from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, model_validator


class InfraAction(Action):
    """
    A single management command issued by the LLM agent against the system.

    ``action_type`` is always required and selects the operation; the other
    fields are that operation's parameters:

        - restart_node: Bring a failed node back online (2-step delay,
          5-step cooldown). Needs ``target``.
        - reroute_traffic: Shift a fraction of load between two nodes.
          Needs ``from_node`` and ``to_node``.
        - scale_up: Add a temporary capacity node for 10 steps (costs 1
          cloud budget unit). No parameters are checked.
        - throttle: Reduce incoming request acceptance rate. Needs ``rate``.
        - query_logs: Investigate a node with telemetry dropout (partial
          observability). Needs ``target``.
        - no_op: Take no action (passive observation step).

    Alternatively, ``raw_command`` may hold a kubectl/AWS CLI string. A
    non-empty ``raw_command`` takes priority: the per-action parameter checks
    are skipped here and the string is parsed into structured fields later.
    """

    action_type: Literal[
        "restart_node",
        "reroute_traffic",
        "scale_up",
        "throttle",
        "query_logs",
        "no_op",
    ] = Field(description="The management action to perform.")

    # --- Node selectors ---
    target: Optional[int] = Field(
        default=None, description="Target node index (used by restart_node, query_logs)."
    )
    from_node: Optional[int] = Field(
        default=None, description="Source node index (used by reroute_traffic)."
    )
    to_node: Optional[int] = Field(
        default=None, description="Destination node index (used by reroute_traffic)."
    )

    # --- Throttle parameter (bounded to [0, 1] by the field constraints) ---
    rate: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description=(
            "Throttle rate in [0, 1] (used by throttle). "
            "1.0 = accept all, 0.0 = reject all."
        ),
    )

    # --- Free-form alternative to the structured fields ---
    raw_command: Optional[str] = Field(
        default=None,
        description=(
            "Raw kubectl/AWS CLI command string. "
            "When set, the environment parses this into a structured action "
            "automatically. Takes priority over other fields."
        ),
    )

    @model_validator(mode="after")
    def validate_action_params(self) -> "InfraAction":
        """
        Ensure the parameters required by ``action_type`` were supplied.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If a required parameter for the selected action is
                ``None`` and no ``raw_command`` was given.
        """
        # A non-empty raw command is parsed later, so the structured fields
        # are deliberately not checked in that case.
        if self.raw_command:
            return self

        # action_type -> (is a required parameter missing?, error message).
        # Action types absent from this table have no required parameters.
        requirement_checks = {
            "restart_node": (
                self.target is None,
                "restart_node requires 'target' node index.",
            ),
            "reroute_traffic": (
                self.from_node is None or self.to_node is None,
                "reroute_traffic requires both 'from_node' and 'to_node'.",
            ),
            "throttle": (
                self.rate is None,
                "throttle requires 'rate' parameter.",
            ),
            "query_logs": (
                self.target is None,
                "query_logs requires 'target' node index.",
            ),
        }

        check = requirement_checks.get(self.action_type)
        if check is not None:
            is_missing, message = check
            if is_missing:
                raise ValueError(message)
        return self


class InfraObservation(Observation):
    """
    Per-step snapshot of the distributed system handed back to the LLM agent.

    Groups the observable system metrics together with episode/task
    bookkeeping, partial-observability status, anti-hacking feedback and
    Prometheus-style telemetry. Only ``cpu_loads``, ``queue_lengths``,
    ``failed_nodes``, ``latency_ms``, ``request_rate``, ``step`` and
    ``task_hint`` are required; every other field has a default.
    """

    # --- Core system metrics ---
    cpu_loads: List[float] = Field(
        description=(
            "CPU utilization [0.0, 1.0] for each node. A value of -1.0 "
            "indicates telemetry dropout (timeout)."
        )
    )
    queue_lengths: List[int] = Field(
        description=(
            "Number of pending requests per node. "
            "-1 indicates telemetry dropout."
        )
    )
    failed_nodes: List[int] = Field(description="Indices of nodes currently in failed state.")
    latency_ms: float = Field(description="Rolling average end-to-end latency in milliseconds.")
    request_rate: float = Field(description="Incoming requests per second into the system.")
    mem_utilizations: List[float] = Field(
        default_factory=list,
        description=(
            "Memory utilization [0.0, 1.0] per node "
            "(same ordering as cpu_loads)."
        ),
    )
    io_wait: float = Field(
        default=0.0, description="Database disk I/O wait / saturation proxy in [0.0, 1.0]."
    )
    p99_latency: float = Field(default=0.0, description="P99 tail latency in milliseconds.")
    error_budget: float = Field(
        default=100.0, description="Remaining error budget token bucket for throttling actions."
    )

    # --- ML-friendly normalized features ---
    request_rate_norm: float = Field(
        default=0.0, description="request_rate normalized to [0,1] (divide by 5000.0, clipped)."
    )
    p99_latency_norm: float = Field(
        default=0.0, description="p99_latency normalized to [0,1] (divide by 1000.0, clipped)."
    )

    # --- Episode / task bookkeeping ---
    step: int = Field(description="Current step within the episode.")
    task_hint: str = Field(
        description="Natural language description of the current task objective."
    )
    task_score: float = Field(default=0.01, description="Current grader score")

    # --- Partial observability ---
    telemetry_status: Dict[int, str] = Field(
        default_factory=dict, description="Per-node telemetry status: 'ok' or 'timeout'."
    )

    # --- Anti-hacking sandbox ---
    action_errors: List[str] = Field(
        default_factory=list,
        description=(
            "Errors from the last action "
            "(e.g. InsufficientFunds, CooldownActive, ParseError)."
        ),
    )
    cloud_budget: int = Field(
        default=10, description="Remaining cloud budget units for scale_up."
    )

    # --- Prometheus-style telemetry ---
    prometheus_metrics: List[Dict] = Field(
        default_factory=list,
        description=(
            "Prometheus-style structured metrics. Each entry is a dict "
            "with 'metric', 'labels', 'value', 'timestamp' keys."
        ),
    )


class InfraState(State):
    """
    Environment-side state layered on top of the base OpenEnv ``State``.

    Adds the identifier of the active task and its current grader score;
    both fields are optional and fall back to their defaults.
    """

    task_id: Optional[str] = Field(
        default=None,
        description="Current task identifier.",
    )
    # The (0.0, 1.0) range is stated in the description only; no numeric
    # constraint is attached to the field.
    task_score: float = Field(
        default=0.01,
        description="Current task grader score in (0.0, 1.0) strictly.",
    )