MrShadowBlade committed on
Commit
57c06cb
·
1 Parent(s): c8daa82

Implement Kubernetes action classes and execution logic

Browse files
server/actions/__init__.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .scale_action import ScaleAction
2
+ from .patch_action import PatchAction
3
+ from .delete_pod_action import DeletePodAction
4
+ from .rollout_action import RolloutRestartAction
5
+ from .hpa_action import SetHPAAction
6
+ from .drain_action import DrainNodeAction
7
+ from .describe_action import DescribeAction
8
+ from typing import Union, Any, Dict, Literal
9
+
10
# Any one of the concrete action models imported above.
KubeAction = Union[
    ScaleAction,
    PatchAction,
    DeletePodAction,
    RolloutRestartAction,
    SetHPAAction,
    DrainNodeAction,
    DescribeAction,
]

# The `action_type` tags, one per model in KubeAction; these are the same
# strings `parse_action` uses as lookup keys.
ActionType = Literal[
    "scale",
    "patch",
    "delete_pod",
    "rollout_restart",
    "set_hpa",
    "drain_node",
    "describe",
]
21
+
22
+
23
def parse_action(data: Dict[str, Any]) -> KubeAction:
    """Build the concrete action model selected by ``data["action_type"]``.

    Args:
        data: Mapping holding an ``action_type`` tag plus the fields of the
            matching action model. The whole mapping is passed to the model
            constructor as keyword arguments.

    Returns:
        An instance of the action class registered for the tag.

    Raises:
        ValueError: If ``data`` is not a dict, the tag is missing or empty,
            or the tag does not name a known action. Errors raised by the
            model constructor for invalid fields propagate unchanged.
    """
    if not isinstance(data, dict):
        raise ValueError(f"Expected dict, got {type(data)}")

    action_type = data.get("action_type")
    if not action_type:
        raise ValueError("Missing 'action_type' field")

    # A non-string tag (for example a list decoded from JSON) can be
    # unhashable, which would make the dict lookup below raise TypeError.
    # Report it as an unknown tag so this function only raises ValueError
    # for malformed input.
    if not isinstance(action_type, str):
        raise ValueError(f"Unknown action_type: {action_type}")

    action_map = {
        "scale": ScaleAction,
        "patch": PatchAction,
        "delete_pod": DeletePodAction,
        "rollout_restart": RolloutRestartAction,
        "set_hpa": SetHPAAction,
        "drain_node": DrainNodeAction,
        "describe": DescribeAction,
    }

    action_class = action_map.get(action_type)
    if action_class is None:
        raise ValueError(f"Unknown action_type: {action_type}")

    return action_class(**data)
46
+
47
+
48
# Public API of the actions package. ActionType is defined in this module as
# a public alias, so it is exported alongside KubeAction.
__all__ = [
    "ScaleAction",
    "PatchAction",
    "DeletePodAction",
    "RolloutRestartAction",
    "SetHPAAction",
    "DrainNodeAction",
    "DescribeAction",
    "KubeAction",
    "ActionType",
    "parse_action",
]
server/actions/delete_pod_action.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Literal
3
+
4
+
5
# Action model: delete one pod, addressed by its exact name.
class DeletePodAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["delete_pod"] = "delete_pod"
    # Required; no default is provided.
    pod_name: str = Field(
        ...,
        description="Exact name of the pod to delete",
    )
server/actions/describe_action.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Literal
3
+
4
+
5
# Action model: inspect a single named resource of one of five kinds.
class DescribeAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["describe"] = "describe"
    # Required; restricted to the listed resource kinds.
    resource_type: Literal[
        "deployment", "pod", "node", "service", "configmap"
    ] = Field(
        ...,
        description="Resource type to inspect",
    )
    # Required; name of the resource to look up.
    name: str = Field(
        ...,
        description="Resource name to inspect",
    )
server/actions/drain_action.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Literal
3
+
4
+
5
# Action model: cordon and drain one node, addressed by name.
class DrainNodeAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["drain_node"] = "drain_node"
    # Required; no default is provided.
    node_name: str = Field(
        ...,
        description="Node to cordon and drain",
    )
server/actions/hpa_action.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field, field_validator
2
+ from typing import Literal
3
+
4
+
5
# Action model: configure a horizontal pod autoscaler for a deployment.
# Replica bounds are limited to 1..20 and the CPU target to 10..90 percent.
class SetHPAAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["set_hpa"] = "set_hpa"
    deployment: str = Field(
        ...,
        description="Target deployment name",
    )
    min_replicas: int = Field(
        ...,
        ge=1,
        le=20,
        description="Minimum replicas",
    )
    max_replicas: int = Field(
        ...,
        ge=1,
        le=20,
        description="Maximum replicas",
    )
    cpu_target_percent: int = Field(
        ...,
        ge=10,
        le=90,
        description="Target CPU percentage",
    )

    @field_validator("max_replicas")
    @classmethod
    def max_must_be_gte_min(cls, v, info):
        # info.data carries the fields validated so far; min_replicas is
        # missing from it when that field failed its own checks, in which
        # case the cross-field comparison is skipped.
        lower_bound = info.data.get("min_replicas")
        if lower_bound is not None and v < lower_bound:
            raise ValueError("max_replicas must be >= min_replicas")
        return v
server/actions/patch_action.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Literal, Dict, Any
3
+
4
+
5
# Action model: apply a partial update to a deployment, configmap or service.
class PatchAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["patch"] = "patch"
    # Required; restricted to the three patchable resource kinds.
    resource_type: Literal["deployment", "configmap", "service"] = Field(
        ...,
        description="One of: deployment, configmap, service",
    )
    name: str = Field(
        ...,
        description="Resource name",
    )
    # Free-form mapping; its structure is not constrained by this model.
    patch: Dict[str, Any] = Field(
        ...,
        description="Fields to update (partial patch)",
    )
server/actions/rollout_action.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Literal
3
+
4
+
5
# Action model: restart every pod of one deployment.
class RolloutRestartAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["rollout_restart"] = "rollout_restart"
    # Required; no default is provided.
    deployment: str = Field(
        ...,
        description="Deployment to restart all pods for",
    )
server/actions/scale_action.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Literal
3
+
4
+
5
# Action model: set a deployment's replica count (1..20 inclusive).
class ScaleAction(BaseModel):
    # Fixed tag identifying this action kind.
    action_type: Literal["scale"] = "scale"
    deployment: str = Field(
        ...,
        description="Name of the deployment to scale",
    )
    replicas: int = Field(
        ...,
        ge=1,
        le=20,
        description="Target replica count (1-20)",
    )
server/conditions/__init__.py ADDED
File without changes
server/executor.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Any, Dict, Optional
3
+ from server.actions import (
4
+ KubeAction,
5
+ ScaleAction,
6
+ DeletePodAction,
7
+ PatchAction,
8
+ RolloutRestartAction,
9
+ SetHPAAction,
10
+ DrainNodeAction,
11
+ DescribeAction,
12
+ )
13
+ from server.models import ClusterObservation
14
+
15
+
16
# Outcome of running one action against the world.
class ExecutionResult(BaseModel):
    # Observation taken from the world after the action ran.
    observation: ClusterObservation
    # Human-readable summary of what was done.
    action_applied: str
    # True when the handler called world.tick(); describe leaves it False.
    tick_advanced: bool
    # Only populated by the describe handler; None otherwise.
    describe_detail: Optional[Dict[str, Any]] = None
21
+
22
+
23
def execute(action: KubeAction, world) -> ExecutionResult:
    """Run ``action`` against ``world`` using the handler for its model type.

    Args:
        action: One of the concrete action models.
        world: Environment object the selected handler operates on.

    Returns:
        The ``ExecutionResult`` produced by the matching ``_execute_*`` helper.

    Raises:
        ValueError: If ``action`` is not an instance of any known action model.
    """
    # Ordered (model, handler) pairs, tested with isinstance in sequence.
    handlers = (
        (ScaleAction, _execute_scale),
        (DeletePodAction, _execute_delete_pod),
        (PatchAction, _execute_patch),
        (RolloutRestartAction, _execute_rollout_restart),
        (SetHPAAction, _execute_set_hpa),
        (DrainNodeAction, _execute_drain_node),
        (DescribeAction, _execute_describe),
    )
    for model, handler in handlers:
        if isinstance(action, model):
            return handler(action, world)
    raise ValueError(f"Unknown action type: {type(action)}")
40
+
41
+
42
def _execute_scale(action: ScaleAction, world) -> ExecutionResult:
    """Scale the deployment, advance the world one tick and report the result."""
    target, count = action.deployment, action.replicas
    world.scale(target, count)
    world.tick()
    summary = f"Scaled '{target}' to {count} replicas"
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=summary,
        tick_advanced=True,
    )
50
+
51
+
52
def _execute_delete_pod(action: DeletePodAction, world) -> ExecutionResult:
    """Delete the named pod, advance the world one tick and report the result."""
    victim = action.pod_name
    world.delete_pod(victim)
    world.tick()
    summary = f"Deleted pod '{victim}'"
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=summary,
        tick_advanced=True,
    )
60
+
61
+
62
def _execute_patch(action: PatchAction, world) -> ExecutionResult:
    """Apply the patch to the named resource, tick the world and report."""
    kind, target = action.resource_type, action.name
    world.apply_patch(kind, target, action.patch)
    world.tick()
    summary = f"Patched {kind} '{target}'"
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=summary,
        tick_advanced=True,
    )
70
+
71
+
72
def _execute_rollout_restart(action: RolloutRestartAction, world) -> ExecutionResult:
    """Trigger a rollout restart of the deployment, tick the world and report."""
    target = action.deployment
    world.rollout_restart(target)
    world.tick()
    summary = f"Rollout restarted '{target}'"
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=summary,
        tick_advanced=True,
    )
80
+
81
+
82
def _execute_set_hpa(action: SetHPAAction, world) -> ExecutionResult:
    """Configure the deployment's HPA, tick the world and report."""
    target = action.deployment
    low, high = action.min_replicas, action.max_replicas
    cpu = action.cpu_target_percent
    world.set_hpa(target, low, high, cpu)
    world.tick()
    summary = f"Set HPA for '{target}': {low}-{high} replicas, {cpu}% CPU"
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=summary,
        tick_advanced=True,
    )
95
+
96
+
97
def _execute_drain_node(action: DrainNodeAction, world) -> ExecutionResult:
    """Drain the named node, advance the world one tick and report."""
    target = action.node_name
    world.drain_node(target)
    world.tick()
    summary = f"Drained node '{target}'"
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=summary,
        tick_advanced=True,
    )
105
+
106
+
107
def _execute_describe(action: DescribeAction, world) -> ExecutionResult:
    """Fetch details for one resource without advancing the world clock.

    Unlike the other handlers this one never calls ``world.tick()``; the
    value returned by ``world.describe`` is passed back as ``describe_detail``.
    """
    kind, target = action.resource_type, action.name
    # describe() runs before the observation is taken.
    detail = world.describe(kind, target)
    return ExecutionResult(
        observation=world.get_observation(),
        action_applied=f"Described {kind} '{target}'",
        tick_advanced=False,
        describe_detail=detail,
    )
server/graders/__init__.py ADDED
File without changes
server/models.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List, Dict, Any, Literal
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
# Snapshot of a single pod. Only `name` and `status` are required.
class PodStatus(BaseModel):
    name: str
    namespace: str = "default"
    # Closed set of pod phases/conditions accepted by this model.
    status: Literal[
        "Running",
        "Pending",
        "CrashLoopBackOff",
        "OOMKilled",
        "Terminating",
        "Unknown",
    ]
    # Name of the hosting node; None when not set.
    node: Optional[str] = None
    restarts: int = 0
    # NOTE(review): units of the usage figures are not visible here - confirm.
    cpu_usage: float = 0.0
    mem_usage: float = 0.0
    container_image: str = "nginx:1.21"
    env_vars: Dict[str, str] = Field(default_factory=dict)
    # Fresh {"limits": {}, "requests": {}} mapping per instance.
    resources: Dict[str, Dict[str, str]] = Field(
        default_factory=lambda: {"limits": {}, "requests": {}},
    )
16
+
17
+
18
# Snapshot of a single node. Only `name` is required.
class NodeStatus(BaseModel):
    name: str
    status: Literal["Ready", "NotReady", "SchedulingDisabled"] = "Ready"
    # NOTE(review): capacity/usage units are not stated in this file - confirm.
    cpu_capacity: float = 4.0
    mem_capacity: float = 8192.0
    cpu_usage: float = 0.0
    mem_usage: float = 0.0
    # List of strings, presumably pod names - verify against the producer.
    pods: List[str] = Field(default_factory=list)
26
+
27
+
28
# Snapshot of a single deployment. Only `name` is required.
class DeploymentStatus(BaseModel):
    name: str
    namespace: str = "default"
    desired_replicas: int = 1
    available_replicas: int = 1
    image: str = "nginx:1.21"
    # A list of string-to-string dicts here, whereas PodStatus.env_vars is a
    # flat dict.
    env_vars: List[Dict[str, str]] = Field(default_factory=list)
    # Fresh {"limits": {}, "requests": {}} mapping per instance.
    resources: Dict[str, Dict[str, str]] = Field(
        default_factory=lambda: {"limits": {}, "requests": {}},
    )
    # Free-form HPA settings; None when not set.
    hpa: Optional[Dict[str, Any]] = None
37
+
38
+
39
# Snapshot of a single service. Only `name` is required.
class ServiceStatus(BaseModel):
    name: str
    namespace: str = "default"
    service_type: str = "ClusterIP"
    selector: Dict[str, str] = Field(default_factory=dict)
    # Defaults to a single 80 -> 80 port mapping, built fresh per instance.
    ports: List[Dict[str, Any]] = Field(
        default_factory=lambda: [{"port": 80, "targetPort": 80}],
    )
    external_ip: Optional[str] = None
    # NOTE(review): scale/units of these two metrics are not visible here.
    error_rate: float = 0.0
    latency_p95: float = 0.0
48
+
49
+
50
# Snapshot of a single configmap. Only `name` is required.
class ConfigMapStatus(BaseModel):
    name: str
    namespace: str = "default"
    # String key/value payload; empty by default.
    data: Dict[str, str] = Field(default_factory=dict)
54
+
55
+
56
# Snapshot of a single horizontal pod autoscaler.
# `name` and `target_deployment` are required.
class HPAStatus(BaseModel):
    name: str
    namespace: str = "default"
    target_deployment: str
    min_replicas: int = 1
    max_replicas: int = 10
    cpu_target_percent: int = 80
    current_replicas: int = 1
64
+
65
+
66
# One cluster event record. `message` and `reason` are required.
class ClusterEvent(BaseModel):
    message: str
    reason: str
    type: Literal["Normal", "Warning"] = "Normal"
    # Empty string when no object is attached.
    involved_object: str = ""
    # Stored as a string; the timestamp format is not fixed by this model.
    first_timestamp: Optional[str] = None
    count: int = 1
73
+
74
+
75
# Aggregate view of the cluster. Every field has a default, so an empty
# observation can be built with no arguments.
class ClusterObservation(BaseModel):
    nodes: List[NodeStatus] = Field(default_factory=list)
    pods: List[PodStatus] = Field(default_factory=list)
    deployments: List[DeploymentStatus] = Field(default_factory=list)
    services: List[ServiceStatus] = Field(default_factory=list)
    configmaps: List[ConfigMapStatus] = Field(default_factory=list)
    hpa: List[HPAStatus] = Field(default_factory=list)
    events: List[ClusterEvent] = Field(default_factory=list)
    # Step counter and free-text objective; both default to "empty".
    step: int = 0
    objective: str = ""
server/tasks/__init__.py ADDED
File without changes
server/validator.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any, List
2
+ from server.actions import (
3
+ KubeAction,
4
+ ScaleAction,
5
+ DeletePodAction,
6
+ PatchAction,
7
+ RolloutRestartAction,
8
+ SetHPAAction,
9
+ DrainNodeAction,
10
+ DescribeAction,
11
+ )
12
+
13
+
14
def validate(action: KubeAction, world_state: Dict[str, Any]) -> Optional[str]:
    """Check ``action`` against ``world_state`` before it is executed.

    Args:
        action: A parsed action model.
        world_state: Raw cluster state; the per-action validators read its
            ``deployments``, ``pods``, ``nodes``, ``services`` and
            ``configmaps`` lists.

    Returns:
        ``None`` when the action may proceed, otherwise a human-readable
        error message.
    """
    # Ordered (model, validator) pairs, tested with isinstance in sequence.
    validators = (
        (ScaleAction, _validate_scale),
        (DeletePodAction, _validate_delete_pod),
        (PatchAction, _validate_patch),
        (RolloutRestartAction, _validate_rollout_restart),
        (SetHPAAction, _validate_set_hpa),
        (DrainNodeAction, _validate_drain_node),
        (DescribeAction, _validate_describe),
    )
    for model, check in validators:
        if isinstance(action, model):
            return check(action, world_state)

    # An object matching no action model used to fall through and be reported
    # as valid. Reject it here so a malformed action becomes a validation
    # error instead of passing this check.
    return f"Unsupported action type: {type(action).__name__}"
30
+
31
+
32
+ def _validate_scale(action: ScaleAction, world_state: Dict[str, Any]) -> Optional[str]:
33
+ deployments = world_state.get("deployments", [])
34
+ deployment_names = [d.get("name") for d in deployments]
35
+
36
+ if action.deployment not in deployment_names:
37
+ return f"Deployment '{action.deployment}' not found. Available: {deployment_names}"
38
+
39
+ if action.replicas < 1 or action.replicas > 20:
40
+ return f"Replica count must be between 1 and 20, got {action.replicas}"
41
+
42
+ return None
43
+
44
+
45
+ def _validate_delete_pod(action: DeletePodAction, world_state: Dict[str, Any]) -> Optional[str]:
46
+ pods = world_state.get("pods", [])
47
+ pod_names = [p.get("name") for p in pods]
48
+
49
+ if action.pod_name not in pod_names:
50
+ return f"Pod '{action.pod_name}' not found in cluster. Available: {pod_names}"
51
+
52
+ pod = next((p for p in pods if p.get("name") == action.pod_name), None)
53
+ if pod and pod.get("status") == "Terminating":
54
+ return f"Pod '{action.pod_name}' is already terminating"
55
+
56
+ return None
57
+
58
+
59
+ def _validate_patch(action: PatchAction, world_state: Dict[str, Any]) -> Optional[str]:
60
+ resource_type = action.resource_type
61
+ name = action.name
62
+
63
+ if resource_type == "deployment":
64
+ deployments = world_state.get("deployments", [])
65
+ deployment_names = [d.get("name") for d in deployments]
66
+ if name not in deployment_names:
67
+ return f"Deployment '{name}' not found. Available: {deployment_names}"
68
+
69
+ elif resource_type == "configmap":
70
+ configmaps = world_state.get("configmaps", [])
71
+ configmap_names = [c.get("name") for c in configmaps]
72
+ if name not in configmap_names:
73
+ return f"ConfigMap '{name}' not found. Available: {configmap_names}"
74
+
75
+ elif resource_type == "service":
76
+ services = world_state.get("services", [])
77
+ service_names = [s.get("name") for s in services]
78
+ if name not in service_names:
79
+ return f"Service '{name}' not found. Available: {service_names}"
80
+
81
+ else:
82
+ return f"Invalid resource_type: {resource_type}. Must be one of: deployment, configmap, service"
83
+
84
+ return None
85
+
86
+
87
+ def _validate_rollout_restart(action: RolloutRestartAction, world_state: Dict[str, Any]) -> Optional[str]:
88
+ deployments = world_state.get("deployments", [])
89
+ deployment_names = [d.get("name") for d in deployments]
90
+
91
+ if action.deployment not in deployment_names:
92
+ return f"Deployment '{action.deployment}' not found. Available: {deployment_names}"
93
+
94
+ return None
95
+
96
+
97
+ def _validate_set_hpa(action: SetHPAAction, world_state: Dict[str, Any]) -> Optional[str]:
98
+ deployments = world_state.get("deployments", [])
99
+ deployment_names = [d.get("name") for d in deployments]
100
+
101
+ if action.deployment not in deployment_names:
102
+ return f"Deployment '{action.deployment}' not found. Available: {deployment_names}"
103
+
104
+ if action.max_replicas < action.min_replicas:
105
+ return f"max_replicas ({action.max_replicas}) must be >= min_replicas ({action.min_replicas})"
106
+
107
+ if action.cpu_target_percent < 10 or action.cpu_target_percent > 90:
108
+ return f"cpu_target_percent must be between 10 and 90, got {action.cpu_target_percent}"
109
+
110
+ return None
111
+
112
+
113
+ def _validate_drain_node(action: DrainNodeAction, world_state: Dict[str, Any]) -> Optional[str]:
114
+ nodes = world_state.get("nodes", [])
115
+ node_names = [n.get("name") for n in nodes]
116
+
117
+ if action.node_name not in node_names:
118
+ return f"Node '{action.node_name}' not found. Available: {node_names}"
119
+
120
+ node = next((n for n in nodes if n.get("name") == action.node_name), None)
121
+ if node and node.get("status") == "SchedulingDisabled":
122
+ return f"Node '{action.node_name}' is already drained (SchedulingDisabled)"
123
+
124
+ ready_nodes = [n for n in nodes if n.get("status") == "Ready"]
125
+ if len(ready_nodes) <= 1 and node and node.get("status") == "Ready":
126
+ return "Cannot drain last healthy node — cluster would lose all capacity"
127
+
128
+ return None
129
+
130
+
131
+ def _validate_describe(action: DescribeAction, world_state: Dict[str, Any]) -> Optional[str]:
132
+ resource_type = action.resource_type
133
+ name = action.name
134
+
135
+ if resource_type == "deployment":
136
+ deployments = world_state.get("deployments", [])
137
+ deployment_names = [d.get("name") for d in deployments]
138
+ if name not in deployment_names:
139
+ return f"Deployment '{name}' not found. Available: {deployment_names}"
140
+
141
+ elif resource_type == "pod":
142
+ pods = world_state.get("pods", [])
143
+ pod_names = [p.get("name") for p in pods]
144
+ if name not in pod_names:
145
+ return f"Pod '{name}' not found. Available: {pod_names}"
146
+
147
+ elif resource_type == "node":
148
+ nodes = world_state.get("nodes", [])
149
+ node_names = [n.get("name") for n in nodes]
150
+ if name not in node_names:
151
+ return f"Node '{name}' not found. Available: {node_names}"
152
+
153
+ elif resource_type == "service":
154
+ services = world_state.get("services", [])
155
+ service_names = [s.get("name") for s in services]
156
+ if name not in service_names:
157
+ return f"Service '{name}' not found. Available: {service_names}"
158
+
159
+ elif resource_type == "configmap":
160
+ configmaps = world_state.get("configmaps", [])
161
+ configmap_names = [c.get("name") for c in configmaps]
162
+ if name not in configmap_names:
163
+ return f"ConfigMap '{name}' not found. Available: {configmap_names}"
164
+
165
+ else:
166
+ return f"Invalid resource_type: {resource_type}. Must be one of: deployment, pod, node, service, configmap"
167
+
168
+ return None
server/worker.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Callable, Any, Optional, Dict
3
+
4
+
5
@dataclass
class StepRecord:
    """Log entry for one step of an episode."""

    step: int
    # Summary of the executed action.
    action_applied: str
    reward: float
    done: bool
    # Optional error message for the step; None by default.
    error: Optional[str] = None
12
+
13
+
14
@dataclass
class EpisodeResult:
    """Summary of one finished episode."""

    task_id: str
    steps_taken: int
    # One reward value per recorded step.
    rewards: List[float]
    success: bool
    # Per-step records; a fresh empty list per instance when omitted.
    history: List[StepRecord] = field(default_factory=list)

    @property
    def total_reward(self) -> float:
        """Sum of all entries in ``rewards``."""
        collected = self.rewards
        return sum(collected)
25
+
26
+
27
class Worker:
    """Runs a single episode: ask for an action, validate it, execute it, grade it."""

    def run_episode(
        self,
        task_id: str,
        world: Any,
        get_action: Callable[[Any], Any],
        max_steps: int,
        grader: Any
    ) -> EpisodeResult:
        """Play one episode of ``task_id`` for at most ``max_steps`` steps.

        Args:
            task_id: Passed to ``world.reset(task=...)`` and echoed in the result.
            world: Environment object; this method calls its ``reset`` and
                ``get_raw_state`` methods and hands it to ``execute``.
            get_action: Called with the latest observation; returns the next action.
            max_steps: Upper bound on loop iterations (rejected actions count too).
            grader: Object providing ``grade(state, step, max_steps)`` and
                ``is_done(state)``.

        Returns:
            An ``EpisodeResult`` whose ``success`` is the ``is_done`` value of
            the last executed step (``False`` if no action was ever executed).
        """
        # Kept function-local as in the original code (presumably to avoid an
        # import cycle - TODO confirm), but resolved once per episode instead
        # of on every loop iteration.
        from server.validator import validate
        from server.executor import execute

        obs = world.reset(task=task_id)
        history: List[StepRecord] = []
        rewards: List[float] = []
        done = False

        for step in range(1, max_steps + 1):
            action = get_action(obs)

            validation_error = validate(action, world.get_raw_state())
            if validation_error:
                # A rejected action is recorded with zero reward; the world is
                # not touched and the agent sees the same observation again.
                history.append(StepRecord(
                    step=step,
                    action_applied="invalid_action",
                    reward=0.0,
                    done=False,
                    error=validation_error
                ))
                rewards.append(0.0)
                continue

            result = execute(action, world)

            reward = grader.grade(world.get_raw_state(), step, max_steps)
            done = grader.is_done(world.get_raw_state())

            history.append(StepRecord(
                step=step,
                action_applied=result.action_applied,
                reward=reward,
                done=done,
                error=None
            ))
            rewards.append(reward)
            obs = result.observation

            if done:
                break

        return EpisodeResult(
            task_id=task_id,
            steps_taken=len(history),
            rewards=rewards,
            success=done,
            history=history
        )
tests/__init__.py ADDED
File without changes
tests/test_actions.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from pydantic import ValidationError
3
+ from server.actions import (
4
+ ScaleAction,
5
+ PatchAction,
6
+ DeletePodAction,
7
+ RolloutRestartAction,
8
+ SetHPAAction,
9
+ DrainNodeAction,
10
+ DescribeAction,
11
+ parse_action,
12
+ )
13
+
14
+
15
class TestScaleAction:
    """Field validation for ScaleAction (replicas must lie in 1..20)."""

    @staticmethod
    def _make(replicas):
        # Shared constructor call; only the replica count varies per test.
        return ScaleAction(action_type="scale", deployment="frontend", replicas=replicas)

    def test_valid_scale_action(self):
        action = self._make(3)
        assert action.deployment == "frontend"
        assert action.replicas == 3

    def test_scale_action_rejects_zero_replicas(self):
        with pytest.raises(ValidationError):
            self._make(0)

    def test_scale_action_rejects_negative_replicas(self):
        with pytest.raises(ValidationError):
            self._make(-1)

    def test_scale_action_rejects_too_many_replicas(self):
        with pytest.raises(ValidationError):
            self._make(21)

    def test_scale_action_accepts_boundary_values(self):
        lowest = self._make(1)
        highest = self._make(20)
        assert lowest.replicas == 1
        assert highest.replicas == 20
38
+
39
+
40
class TestPatchAction:
    """Construction and resource_type validation for PatchAction."""

    def test_valid_patch_action(self):
        env_patch = {"env": [{"name": "DB_HOST", "value": "db.prod.internal"}]}
        action = PatchAction(
            action_type="patch",
            resource_type="deployment",
            name="frontend",
            patch=env_patch,
        )
        assert action.resource_type == "deployment"
        assert action.name == "frontend"

    def test_patch_action_rejects_invalid_resource_type(self):
        bad_fields = {
            "action_type": "patch",
            "resource_type": "invalid",
            "name": "frontend",
            "patch": {},
        }
        with pytest.raises(ValidationError):
            PatchAction(**bad_fields)
59
+
60
+
61
class TestDeletePodAction:
    """Construction of DeletePodAction."""

    def test_valid_delete_pod_action(self):
        pod = "frontend-7d9f-xkp2"
        action = DeletePodAction(action_type="delete_pod", pod_name=pod)
        assert action.pod_name == "frontend-7d9f-xkp2"
65
+
66
+
67
class TestRolloutRestartAction:
    """Construction of RolloutRestartAction."""

    def test_valid_rollout_restart_action(self):
        fields = {"action_type": "rollout_restart", "deployment": "frontend"}
        action = RolloutRestartAction(**fields)
        assert action.deployment == "frontend"
71
+
72
+
73
class TestSetHPAAction:
    """Bounds and cross-field validation for SetHPAAction."""

    @staticmethod
    def _make(min_replicas, max_replicas, cpu_target_percent):
        # Shared constructor call; deployment is always "api".
        return SetHPAAction(
            action_type="set_hpa",
            deployment="api",
            min_replicas=min_replicas,
            max_replicas=max_replicas,
            cpu_target_percent=cpu_target_percent,
        )

    def test_valid_hpa_action(self):
        action = self._make(2, 10, 70)
        assert action.deployment == "api"
        assert action.min_replicas == 2
        assert action.max_replicas == 10

    def test_hpa_action_rejects_max_less_than_min(self):
        with pytest.raises(ValidationError):
            self._make(5, 2, 60)

    def test_hpa_action_rejects_invalid_cpu_target(self):
        with pytest.raises(ValidationError):
            self._make(1, 10, 5)

    def test_hpa_action_accepts_boundary_cpu_target(self):
        lowest = self._make(1, 10, 10)
        highest = self._make(1, 10, 90)
        assert lowest.cpu_target_percent == 10
        assert highest.cpu_target_percent == 90
123
+
124
+
125
class TestDrainNodeAction:
    """Construction of DrainNodeAction."""

    def test_valid_drain_node_action(self):
        fields = {"action_type": "drain_node", "node_name": "node-1"}
        action = DrainNodeAction(**fields)
        assert action.node_name == "node-1"
129
+
130
+
131
class TestDescribeAction:
    """Construction and resource_type validation for DescribeAction."""

    @staticmethod
    def _make(resource_type):
        # Shared constructor call; only the resource kind varies per test.
        return DescribeAction(
            action_type="describe",
            resource_type=resource_type,
            name="frontend",
        )

    def test_valid_describe_action(self):
        action = self._make("deployment")
        assert action.resource_type == "deployment"
        assert action.name == "frontend"

    def test_describe_action_rejects_invalid_resource_type(self):
        with pytest.raises(ValidationError):
            self._make("invalid")
148
+
149
+
150
class TestParseAction:
    """parse_action picks the right model per tag and rejects bad tags."""

    def test_parse_scale_action(self):
        action = parse_action(
            {"action_type": "scale", "deployment": "frontend", "replicas": 3}
        )
        assert isinstance(action, ScaleAction)
        assert action.deployment == "frontend"
        assert action.replicas == 3

    def test_parse_delete_pod_action(self):
        action = parse_action(
            {"action_type": "delete_pod", "pod_name": "frontend-7d9f-xkp2"}
        )
        assert isinstance(action, DeletePodAction)
        assert action.pod_name == "frontend-7d9f-xkp2"

    def test_parse_patch_action(self):
        payload = {
            "action_type": "patch",
            "resource_type": "deployment",
            "name": "frontend",
            "patch": {"env": [{"name": "DB_HOST", "value": "db.prod.internal"}]},
        }
        action = parse_action(payload)
        assert isinstance(action, PatchAction)
        assert action.name == "frontend"

    def test_parse_rollout_restart_action(self):
        action = parse_action(
            {"action_type": "rollout_restart", "deployment": "frontend"}
        )
        assert isinstance(action, RolloutRestartAction)
        assert action.deployment == "frontend"

    def test_parse_hpa_action(self):
        payload = {
            "action_type": "set_hpa",
            "deployment": "api",
            "min_replicas": 2,
            "max_replicas": 10,
            "cpu_target_percent": 70,
        }
        action = parse_action(payload)
        assert isinstance(action, SetHPAAction)
        assert action.deployment == "api"

    def test_parse_drain_node_action(self):
        action = parse_action({"action_type": "drain_node", "node_name": "node-1"})
        assert isinstance(action, DrainNodeAction)
        assert action.node_name == "node-1"

    def test_parse_describe_action(self):
        action = parse_action(
            {"action_type": "describe", "resource_type": "deployment", "name": "frontend"}
        )
        assert isinstance(action, DescribeAction)
        assert action.name == "frontend"

    def test_parse_unknown_action_type(self):
        payload = {"action_type": "unknown_action"}
        with pytest.raises(ValueError, match="Unknown action_type"):
            parse_action(payload)

    def test_parse_missing_action_type(self):
        payload = {"deployment": "frontend"}
        with pytest.raises(ValueError, match="Missing 'action_type'"):
            parse_action(payload)
tests/test_executor.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from unittest.mock import MagicMock, call
3
+ from server.actions import (
4
+ ScaleAction,
5
+ DeletePodAction,
6
+ PatchAction,
7
+ RolloutRestartAction,
8
+ SetHPAAction,
9
+ DrainNodeAction,
10
+ DescribeAction,
11
+ )
12
+ from server.executor import execute
13
+ from server.models import ClusterObservation
14
+
15
+
16
class MockWorld:
    """Hand-rolled stand-in for the world object that records the calls it receives."""

    def __init__(self):
        # One slot per world method; each stays None until that method runs.
        self.scale_called_with = None
        self.delete_pod_called_with = None
        self.apply_patch_called_with = None
        self.rollout_restart_called_with = None
        self.set_hpa_called_with = None
        self.drain_node_called_with = None
        self.describe_called_with = None
        # Flipped to True by tick().
        self.tick_called = False
        # Empty observation and raw state returned by the getters below.
        self._observation = ClusterObservation(
            nodes=[],
            pods=[],
            deployments=[],
            services=[],
            configmaps=[],
            hpa=[],
            events=[],
            step=0,
            objective="",
        )
        self._raw_state = {
            "nodes": [],
            "pods": [],
            "deployments": [],
            "services": [],
            "configmaps": [],
        }

    def scale(self, deployment, replicas):
        """Record the arguments of a scale call."""
        self.scale_called_with = (deployment, replicas)

    def delete_pod(self, pod_name):
        """Record the pod name of a delete call."""
        self.delete_pod_called_with = pod_name

    def apply_patch(self, resource_type, name, patch):
        """Record the arguments of a patch call."""
        self.apply_patch_called_with = (resource_type, name, patch)

    def rollout_restart(self, deployment):
        """Record the deployment of a rollout-restart call."""
        self.rollout_restart_called_with = deployment

    def set_hpa(self, deployment, min_replicas, max_replicas, cpu_target_percent):
        """Record the arguments of a set-HPA call."""
        self.set_hpa_called_with = (
            deployment,
            min_replicas,
            max_replicas,
            cpu_target_percent,
        )

    def drain_node(self, node_name):
        """Record the node name of a drain call."""
        self.drain_node_called_with = node_name

    def describe(self, resource_type, name):
        """Record the arguments and return a canned detail dict."""
        self.describe_called_with = (resource_type, name)
        return {"type": resource_type, "name": name, "detail": "mock detail"}

    def tick(self):
        """Note that the clock was advanced."""
        self.tick_called = True

    def get_observation(self):
        """Return the fixed empty observation."""
        return self._observation

    def get_raw_state(self):
        """Return the fixed empty raw state."""
        return self._raw_state
59
+
60
+
61
class TestExecutorScale:
    """execute() on a ScaleAction scales, ticks and reports."""

    def test_scale_calls_world_scale_and_ticks(self):
        world = MockWorld()
        outcome = execute(
            ScaleAction(action_type="scale", deployment="frontend", replicas=3),
            world,
        )

        assert world.scale_called_with == ("frontend", 3)
        assert world.tick_called is True
        assert outcome.tick_advanced is True
        assert "Scaled" in outcome.action_applied

    def test_scale_action_applied_message(self):
        world = MockWorld()
        outcome = execute(
            ScaleAction(action_type="scale", deployment="frontend", replicas=5),
            world,
        )

        assert outcome.action_applied == "Scaled 'frontend' to 5 replicas"
78
+
79
+
80
class TestExecutorDeletePod:
    """execute() on a DeletePodAction deletes the pod and ticks."""

    def test_delete_pod_calls_world_and_ticks(self):
        world = MockWorld()
        outcome = execute(
            DeletePodAction(action_type="delete_pod", pod_name="frontend-7d9f-xkp2"),
            world,
        )

        assert world.delete_pod_called_with == "frontend-7d9f-xkp2"
        assert world.tick_called is True
        assert outcome.tick_advanced is True
89
+
90
+
91
class TestExecutorPatch:
    """execute() on a PatchAction forwards the patch and ticks."""

    def test_patch_calls_world_and_ticks(self):
        world = MockWorld()
        action = PatchAction(
            action_type="patch",
            resource_type="deployment",
            name="frontend",
            patch={"env": [{"name": "DB_HOST", "value": "db.prod.internal"}]},
        )
        outcome = execute(action, world)

        expected_call = (
            "deployment",
            "frontend",
            {"env": [{"name": "DB_HOST", "value": "db.prod.internal"}]},
        )
        assert world.apply_patch_called_with == expected_call
        assert world.tick_called is True
        assert outcome.tick_advanced is True
109
+
110
+
111
class TestExecutorRolloutRestart:
    """execute() on a RolloutRestartAction restarts the deployment and ticks."""

    def test_rollout_restart_calls_world_and_ticks(self):
        world = MockWorld()
        outcome = execute(
            RolloutRestartAction(action_type="rollout_restart", deployment="frontend"),
            world,
        )

        assert world.rollout_restart_called_with == "frontend"
        assert world.tick_called is True
        assert outcome.tick_advanced is True
120
+
121
+
122
class TestExecutorSetHPA:
    """execute() on a SetHPAAction forwards all four settings and ticks."""

    def test_set_hpa_calls_world_and_ticks(self):
        world = MockWorld()
        fields = {
            "action_type": "set_hpa",
            "deployment": "api",
            "min_replicas": 2,
            "max_replicas": 10,
            "cpu_target_percent": 70,
        }
        outcome = execute(SetHPAAction(**fields), world)

        assert world.set_hpa_called_with == ("api", 2, 10, 70)
        assert world.tick_called is True
        assert outcome.tick_advanced is True
137
+
138
+
139
class TestExecutorDrainNode:
    """execute() on a DrainNodeAction drains the node and ticks."""

    def test_drain_node_calls_world_and_ticks(self):
        world = MockWorld()
        outcome = execute(
            DrainNodeAction(action_type="drain_node", node_name="node-1"),
            world,
        )

        assert world.drain_node_called_with == "node-1"
        assert world.tick_called is True
        assert outcome.tick_advanced is True
148
+
149
+
150
class TestExecutorDescribe:
    """execute() on a DescribeAction returns detail without ticking."""

    @staticmethod
    def _describe_frontend(world):
        # Runs the same describe action used by both tests below.
        action = DescribeAction(
            action_type="describe",
            resource_type="deployment",
            name="frontend",
        )
        return execute(action, world)

    def test_describe_does_not_tick(self):
        world = MockWorld()
        outcome = self._describe_frontend(world)

        assert world.describe_called_with == ("deployment", "frontend")
        assert world.tick_called is False
        assert outcome.tick_advanced is False

    def test_describe_returns_detail(self):
        world = MockWorld()
        outcome = self._describe_frontend(world)

        assert outcome.describe_detail is not None
        assert outcome.describe_detail["type"] == "deployment"