"""
Pydantic v2 typed models for the Distributed Infrastructure Management
Environment: the agent's Action, the per-step Observation, and the
environment's internal State.
"""

from typing import Dict, List, Literal, Optional

from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, model_validator


# Structured-parameter requirements per action type.
#
# Each entry maps an action_type to (names of fields that must be non-None,
# message of the ValueError raised when any of them is None).
# Action types with no entry here ("scale_up", "no_op") need no parameters.
_REQUIRED_PARAMS = {
    "restart_node": (
        ("target",),
        "restart_node requires 'target' node index.",
    ),
    "reroute_traffic": (
        ("from_node", "to_node"),
        "reroute_traffic requires both 'from_node' and 'to_node'.",
    ),
    "throttle": (
        ("rate",),
        "throttle requires 'rate' parameter.",
    ),
    "query_logs": (
        ("target",),
        "query_logs requires 'target' node index.",
    ),
}


class InfraAction(Action):
    """
    Action the LLM agent can take to manage the distributed system.

    Supported action types:
    - restart_node: Bring a failed node back online (2-step delay, 5-step cooldown).
    - reroute_traffic: Shift a fraction of load between two nodes.
    - scale_up: Add a temporary capacity node for 10 steps (costs 1 cloud budget unit).
    - throttle: Reduce incoming request acceptance rate.
    - query_logs: Investigate a node with telemetry dropout (partial observability).
    - no_op: Take no action (passive observation step).

    Optionally, ``raw_command`` can be set to a kubectl/AWS CLI string which
    takes priority and is parsed into structured fields automatically.
    """

    # NOTE: this docstring and every `description=` below feed the model's
    # JSON schema, so their text is kept exactly as-is.

    # Which management operation to run; required, no default.
    action_type: Literal[
        "restart_node", "reroute_traffic", "scale_up", "throttle", "query_logs", "no_op"
    ] = Field(description="The management action to perform.")

    # Node index consumed by restart_node and query_logs.
    target: Optional[int] = Field(
        default=None,
        description="Target node index (used by restart_node, query_logs).",
    )

    # Endpoints of a reroute_traffic action.
    from_node: Optional[int] = Field(
        default=None,
        description="Source node index (used by reroute_traffic).",
    )
    to_node: Optional[int] = Field(
        default=None,
        description="Destination node index (used by reroute_traffic).",
    )

    # Acceptance rate for throttle; the ge/le bounds are enforced by Pydantic.
    rate: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Throttle rate in [0, 1] (used by throttle). 1.0 = accept all, 0.0 = reject all.",
    )

    # Free-form CLI command; when non-empty it disables the parameter check
    # in validate_action_params below.
    raw_command: Optional[str] = Field(
        default=None,
        description=(
            "Raw kubectl/AWS CLI command string. When set, the environment "
            "parses this into a structured action automatically. Takes "
            "priority over other fields."
        ),
    )

    @model_validator(mode="after")
    def validate_action_params(self) -> "InfraAction":
        """Check that the parameters required by ``action_type`` are present.

        The check is skipped whenever ``raw_command`` is truthy, because the
        structured fields are populated from the raw command later on.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: when a field listed for ``action_type`` in
                ``_REQUIRED_PARAMS`` is None.
        """
        if not self.raw_command:
            requirement = _REQUIRED_PARAMS.get(self.action_type)
            if requirement is not None:
                field_names, message = requirement
                if any(getattr(self, name) is None for name in field_names):
                    raise ValueError(message)
        return self


class InfraObservation(Observation):
    """
    Observation returned to the LLM agent at each step.

    Contains the observable state of the distributed system plus
    anti-hacking and partial-observability metadata.
    """

    # ---- Core per-node and system-wide telemetry (all required) ----
    # Per-node lists share one node ordering; -1 marks a telemetry dropout.
    cpu_loads: List[float] = Field(
        description=(
            "CPU utilization [0.0, 1.0] for each node. "
            "A value of -1.0 indicates telemetry dropout (timeout)."
        )
    )
    queue_lengths: List[int] = Field(
        description="Number of pending requests per node. -1 indicates telemetry dropout."
    )
    failed_nodes: List[int] = Field(
        description="Indices of nodes currently in failed state."
    )
    latency_ms: float = Field(
        description="Rolling average end-to-end latency in milliseconds."
    )
    request_rate: float = Field(
        description="Incoming requests per second into the system."
    )

    # ---- Additional resource signals (optional, with defaults) ----
    mem_utilizations: List[float] = Field(
        default_factory=list,
        description="Memory utilization [0.0, 1.0] per node (same ordering as cpu_loads).",
    )
    io_wait: float = Field(
        default=0.0,
        description="Database disk I/O wait / saturation proxy in [0.0, 1.0].",
    )
    p99_latency: float = Field(
        default=0.0,
        description="P99 tail latency in milliseconds.",
    )
    error_budget: float = Field(
        default=100.0,
        description="Remaining error budget token bucket for throttling actions.",
    )

    # ---- Normalized copies of raw metrics, intended as model inputs ----
    request_rate_norm: float = Field(
        default=0.0,
        description="request_rate normalized to [0,1] (divide by 5000.0, clipped).",
    )
    p99_latency_norm: float = Field(
        default=0.0,
        description="p99_latency normalized to [0,1] (divide by 1000.0, clipped).",
    )

    # ---- Episode / task bookkeeping ----
    step: int = Field(description="Current step within the episode.")
    task_hint: str = Field(
        description="Natural language description of the current task objective."
    )
    task_score: float = Field(default=0.01, description="Current grader score")

    # ---- Partial observability: which nodes' telemetry timed out ----
    telemetry_status: Dict[int, str] = Field(
        default_factory=dict,
        description="Per-node telemetry status: 'ok' or 'timeout'.",
    )

    # ---- Sandbox feedback on the previous action ----
    action_errors: List[str] = Field(
        default_factory=list,
        description=(
            "Errors from the last action (e.g. InsufficientFunds, "
            "CooldownActive, ParseError)."
        ),
    )
    cloud_budget: int = Field(
        default=10,
        description="Remaining cloud budget units for scale_up.",
    )

    # ---- Structured metrics in a Prometheus-like shape ----
    prometheus_metrics: List[Dict] = Field(
        default_factory=list,
        description=(
            "Prometheus-style structured metrics. Each entry is a dict with "
            "'metric', 'labels', 'value', 'timestamp' keys."
        ),
    )


class InfraState(State):
    """
    Internal environment state extending the base OpenEnv State.
    """

    # Identifier of the active task; None until one is selected.
    task_id: Optional[str] = Field(default=None, description="Current task identifier.")
    # NOTE(review): the open-interval bound in the description is not
    # enforced by a Field constraint here.
    task_score: float = Field(
        default=0.01, description="Current task grader score in (0.0, 1.0) strictly."
    )