AntiAtropos / models.py
PranavKK1201
graph modelling
b5e5650
from enum import Enum
from typing import Annotated, Literal, Optional
from pydantic import BaseModel, Field
class EnvironmentMode(str, Enum):
    """String-valued enumeration of the modes an environment can run in.

    Because the class also derives from ``str``, every member compares equal
    to its lowercase string value (e.g. ``EnvironmentMode.LIVE == "live"``)
    and can be looked up from it (``EnvironmentMode("live")``).
    """

    SIMULATED = "simulated"
    HYBRID = "hybrid"
    LIVE = "live"
    AWS = "aws"
# ---------------------------------------------------------------------------
# SRE Action Schema (Control Plane)
# ---------------------------------------------------------------------------
class ActionType(str, Enum):
    """String-valued enumeration of the actions an SRE agent may issue.

    Every member's value is spelled exactly like its name, and because the
    class derives from ``str`` a member compares equal to that string.
    """

    NO_OP = "NO_OP"
    SCALE_UP = "SCALE_UP"
    SCALE_DOWN = "SCALE_DOWN"
    REROUTE_TRAFFIC = "REROUTE_TRAFFIC"
    SHED_LOAD = "SHED_LOAD"
class SREAction(BaseModel):
    """
    One management command issued by the SRE agent against a single node.

    The meaning of ``parameter`` depends on ``action_type``:

    * SCALE_UP: Increment capacity on target_node_id by parameter (1-5 units).
    * SCALE_DOWN: Decrement capacity on target_node_id by parameter (1-5 units).
    * REROUTE_TRAFFIC: Shift 'parameter' [0, 1] of incoming traffic AWAY from
      target_node_id and redistribute across healthy peers.
    * SHED_LOAD: Drop 'parameter' [0, 1] of incoming traffic targeting
      target_node_id for 1 tick.
    """

    # Which kind of action to perform.
    action_type: ActionType

    # Identifier of the node the action applies to.
    target_node_id: str

    # Action magnitude; defaults to 0.0 when omitted.
    # NOTE(review): this schema only enforces 0.0 <= parameter <= 10.0. The
    # narrower per-action ranges listed in the docstring are not validated
    # here -- confirm they are checked by whatever consumes the action.
    parameter: float = Field(default=0.0, ge=0.0, le=10.0)
# ---------------------------------------------------------------------------
# Observation Schema (Data Plane)
# ---------------------------------------------------------------------------
class NodeStatus(str, Enum):
    """String-valued enumeration of the health states reported for a node.

    Each member's value matches its name, and members compare equal to that
    string because the class derives from ``str``.
    """

    HEALTHY = "HEALTHY"
    DEGRADED = "DEGRADED"
    FAILED = "FAILED"
class NodeObservation(BaseModel):
    """Telemetry snapshot for one service instance (node).

    ``node_id`` and ``status`` are required; every other field has a default.
    Most numeric readings are validated to lie in [0, 1].
    """

    # --- Identity -----------------------------------------------------------
    node_id: str
    status: NodeStatus
    is_vip: bool = False

    # --- Normalized telemetry -----------------------------------------------
    # All numerical telemetry is normalized to [0, 1] for RL stability.
    queue_depth: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Normalized queue depth [0.0, 1.0]. Represents the % of theoretical max queue.",
    )
    latency_ms: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Normalized processing latency [0.0, 1.0] relative to 1000ms SLA limit.",
    )
    incoming_request_rate: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Normalized incoming request rate [0.0, 1.0] for this node (requests per tick).",
    )
    cpu_utilization: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Estimated CPU load [0.0, 1.0].",
    )

    # --- Weighting and capacity (non-negative, no upper bound enforced) -----
    importance_weight: float = Field(
        default=1.0, ge=0.0,
        description="Business criticality weight. VIP nodes have higher impact on scoring.",
    )
    # NOTE(review): the description mentions 0-5, but only the lower bound is
    # validated by this schema.
    capacity: float = Field(
        default=0.0, ge=0.0,
        description="Current capacity units provisioned for this node (0-5).",
    )
    pending_capacity: float = Field(
        default=0.0, ge=0.0,
        description="Capacity units being booted (will be live after boot delay).",
    )

    # --- Derived signals ----------------------------------------------------
    # The only signed reading in this model: bounded to [-1, 1].
    queue_delta: float = Field(
        default=0.0, ge=-1.0, le=1.0,
        description="Normalized queue depth change from previous tick (-1 to +1).",
    )
    sla_proximity: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="How close this node is to SLA violation (0=safe, 1=violating).",
    )
    outflow_rate: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Normalised rate of requests dispatched downstream [0, 1].",
    )

    # --- Graph neighbourhood ------------------------------------------------
    # default_factory gives every instance its own fresh empty list.
    upstream_nodes: list[str] = Field(default_factory=list)
    downstream_nodes: list[str] = Field(default_factory=list)
    upstream_pressure: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Mean queue depth of upstream parent nodes (normalised).",
    )

    # --- Reward -------------------------------------------------------------
    # Unbounded: no ge/le constraint is applied.
    node_reward: float = Field(
        default=0.0,
        description="Per-node reward contribution for credit assignment.",
    )

    # Episode interaction fields (handled by framework)
    done: bool = False
    reward: float = 0.0
class ClusterObservation(BaseModel):
    """Cluster-level telemetry: the agent's system-wide 'dashboard'.

    Required fields are ``cluster_id``, ``task_id``, ``step``, ``max_steps``,
    ``active_nodes`` and ``nodes``; all remaining fields carry defaults.
    """

    # --- Episode identity and progress --------------------------------------
    cluster_id: str
    task_id: str
    step: int
    max_steps: int
    mode: EnvironmentMode = EnvironmentMode.SIMULATED

    # Required (no default); validated to lie within 0..10 inclusive.
    active_nodes: int = Field(ge=0, le=10)

    # --- Aggregate health metrics (each validated to [0, 1]) ----------------
    average_latency_ms: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Cluster-wide average latency (normalized [0.0, 1.0]).",
    )
    error_rate: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Cluster-wide fraction of dropped/failed requests [0.0, 1.0].",
    )
    total_queue_backlog: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Normalized sum of queue_depth across all nodes [0.0, 1.0].",
    )

    # --- Cost and stability -------------------------------------------------
    current_cost_per_hour: float = Field(
        default=0.0, ge=0.0,
        description="Infrastructure cost in USD/hr based on provisioned capacity.",
    )
    lyapunov_energy: float = Field(
        default=0.0,
        description="Stability metric (Sum of squares of queue depths). Low is good.",
    )

    # --- Counters (no bounds are enforced by this schema) -------------------
    sla_violations: int = Field(
        default=0,
        description="Cumulative count of SLA violations this episode.",
    )
    invalid_action_count: int = Field(
        default=0,
        description="Number of forbidden actions (e.g. SHED_LOAD on critical nodes).",
    )
    vip_failure_count: int = Field(
        default=0,
        description="Number of failed VIP nodes in the current observation.",
    )

    # --- New fields for Prometheus/Kubernetes integration -------------------
    metric_timestamp: float = 0.0
    data_freshness_ms: int = 0
    action_ack_status: str = "success"
    action_id: str = ""
    executor_latency_ms: float = Field(default=0.0, ge=0.0)
    executor_error_code: str = ""

    # --- Reward -------------------------------------------------------------
    raw_reward: float = 0.0
    # Validated to [0, 1], unlike raw_reward which is unconstrained.
    normalized_reward: float = Field(default=0.0, ge=0.0, le=1.0)
    reward_scale_version: str = "sigmoid-v1"

    # Reward components breakdown
    reward_drift: float = Field(
        default=0.0,
        description="Lyapunov drift component of the reward.",
    )
    reward_cost: float = Field(
        default=0.0,
        description="Infrastructure cost component of the reward.",
    )
    reward_sla: float = Field(
        default=0.0,
        description="SLA penalty component of the reward.",
    )
    reward_barrier: float = Field(
        default=0.0,
        description="Barrier function penalty component of the reward.",
    )

    choke_level: float = 0.0

    # Per-node observations; required, so callers must always supply the list.
    nodes: list[NodeObservation]

    # Episode interaction fields (handled by framework)
    done: bool = False
    reward: float = 0.0