Spaces:
Configuration error
Configuration error
File size: 6,274 Bytes
c34e7cc d1600e6 c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc f51115b c34e7cc facabc7 d58554d c34e7cc d58554d c34e7cc f51115b 2ba6413 c34e7cc d58554d c34e7cc d58554d c34e7cc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | """
Pydantic v2 typed Action, Observation, and State models for the
Distributed Infrastructure Management Environment.
"""
from typing import Dict, List, Literal, Optional
from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, model_validator
class InfraAction(Action):
    """
    Action the LLM agent can take to manage the distributed system.

    Supported action types:
    - restart_node: Bring a failed node back online (2-step delay, 5-step cooldown).
    - reroute_traffic: Shift a fraction of load between two nodes.
    - scale_up: Add a temporary capacity node for 10 steps (costs 1 cloud budget unit).
    - throttle: Reduce incoming request acceptance rate.
    - query_logs: Investigate a node with telemetry dropout (partial observability).
    - no_op: Take no action (passive observation step).

    Optionally, ``raw_command`` can be set to a kubectl/AWS CLI string
    which takes priority and is parsed into structured fields automatically.

    Node index fields (``target``, ``from_node``, ``to_node``) must be
    non-negative integers when provided.
    """

    action_type: Literal[
        "restart_node", "reroute_traffic", "scale_up", "throttle", "query_logs", "no_op"
    ] = Field(description="The management action to perform.")

    # Node indices are constrained to >= 0. A negative int would otherwise pass
    # validation and, if the environment uses it as a list index (presumably it
    # does — verify against the env), silently address a node from the end.
    target: Optional[int] = Field(
        default=None,
        ge=0,
        description="Target node index (used by restart_node, query_logs).",
    )
    from_node: Optional[int] = Field(
        default=None,
        ge=0,
        description="Source node index (used by reroute_traffic).",
    )
    to_node: Optional[int] = Field(
        default=None,
        ge=0,
        description="Destination node index (used by reroute_traffic).",
    )
    rate: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Throttle rate in [0, 1] (used by throttle). 1.0 = accept all, 0.0 = reject all.",
    )
    raw_command: Optional[str] = Field(
        default=None,
        description=(
            "Raw kubectl/AWS CLI command string. When set, the environment "
            "parses this into a structured action automatically. Takes "
            "priority over other fields."
        ),
    )

    @model_validator(mode="after")
    def validate_action_params(self) -> "InfraAction":
        """
        Check that each action_type carries the parameters it needs.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: if restart_node/query_logs lack ``target``,
                reroute_traffic lacks ``from_node`` or ``to_node``, or
                throttle lacks ``rate``. These checks are skipped entirely
                when a non-empty ``raw_command`` is supplied.
        """
        # A non-empty raw_command takes priority and is parsed later, so the
        # structured fields may legitimately be unset at this point.
        if self.raw_command:
            return self
        if self.action_type == "restart_node" and self.target is None:
            raise ValueError("restart_node requires 'target' node index.")
        if self.action_type == "reroute_traffic":
            if self.from_node is None or self.to_node is None:
                raise ValueError(
                    "reroute_traffic requires both 'from_node' and 'to_node'."
                )
        if self.action_type == "throttle" and self.rate is None:
            raise ValueError("throttle requires 'rate' parameter.")
        if self.action_type == "query_logs" and self.target is None:
            raise ValueError("query_logs requires 'target' node index.")
        return self
class InfraObservation(Observation):
    """
    Per-step observation handed back to the LLM agent.

    Bundles the observable state of the distributed system together with
    metadata used for partial observability and anti-hacking checks.
    """

    # Core per-node / system-wide telemetry. In cpu_loads and queue_lengths a
    # -1 sentinel marks a node whose telemetry dropped out this step.
    cpu_loads: List[float] = Field(
        ...,
        description="CPU utilization [0.0, 1.0] for each node. "
        "A value of -1.0 indicates telemetry dropout (timeout).",
    )
    queue_lengths: List[int] = Field(
        ...,
        description="Number of pending requests per node. -1 indicates telemetry dropout.",
    )
    failed_nodes: List[int] = Field(
        ..., description="Indices of nodes currently in failed state."
    )
    latency_ms: float = Field(
        ..., description="Rolling average end-to-end latency in milliseconds."
    )
    request_rate: float = Field(
        ..., description="Incoming requests per second into the system."
    )

    # Additional resource and SLO signals (all optional, with defaults).
    mem_utilizations: List[float] = Field(
        default_factory=list,
        description="Memory utilization [0.0, 1.0] per node (same ordering as cpu_loads).",
    )
    io_wait: float = Field(
        0.0, description="Database disk I/O wait / saturation proxy in [0.0, 1.0]."
    )
    p99_latency: float = Field(0.0, description="P99 tail latency in milliseconds.")
    error_budget: float = Field(
        100.0,
        description="Remaining error budget token bucket for throttling actions.",
    )

    # Pre-scaled copies of raw signals, convenient as model inputs.
    request_rate_norm: float = Field(
        0.0,
        description="request_rate normalized to [0,1] (divide by 5000.0, clipped).",
    )
    p99_latency_norm: float = Field(
        0.0,
        description="p99_latency normalized to [0,1] (divide by 1000.0, clipped).",
    )

    # Episode progress and task context.
    step: int = Field(..., description="Current step within the episode.")
    task_hint: str = Field(
        ..., description="Natural language description of the current task objective."
    )
    task_score: float = Field(0.01, description="Current grader score")

    # Partial observability: per-node telemetry health.
    telemetry_status: Dict[int, str] = Field(
        default_factory=dict,
        description="Per-node telemetry status: 'ok' or 'timeout'.",
    )

    # Anti-hacking sandbox feedback and resource accounting.
    action_errors: List[str] = Field(
        default_factory=list,
        description="Errors from the last action (e.g. InsufficientFunds, "
        "CooldownActive, ParseError).",
    )
    cloud_budget: int = Field(
        10, description="Remaining cloud budget units for scale_up."
    )

    # Structured, Prometheus-like metric samples.
    prometheus_metrics: List[Dict] = Field(
        default_factory=list,
        description="Prometheus-style structured metrics. Each entry is a dict with "
        "'metric', 'labels', 'value', 'timestamp' keys.",
    )
class InfraState(State):
    """
    Environment-side state: OpenEnv's base State plus task bookkeeping.
    """

    # Identifier of the active task; defaults to None.
    task_id: Optional[str] = Field(None, description="Current task identifier.")
    # Latest grader score. The description documents an open (0.0, 1.0) range;
    # no validator here enforces it.
    task_score: float = Field(
        0.01,
        description="Current task grader score in (0.0, 1.0) strictly.",
    )
|