Spaces:
Configuration error
Configuration error
Commit ·
d58554d
1
Parent(s): 4f9312c
debugging
Browse files- inference.py +51 -24
- server/environment.py +20 -20
- server/models.py +5 -13
- server/tasks.py +26 -17
inference.py
CHANGED
|
@@ -31,10 +31,7 @@ TASKS = ["traffic_spike", "node_failure", "cascading_failure", "flash_crowd"]
|
|
| 31 |
MAX_RETRIES = 3
|
| 32 |
BENCHMARK = "distributed_infra_env"
|
| 33 |
|
| 34 |
-
client = OpenAI(
|
| 35 |
-
base_url=API_BASE_URL,
|
| 36 |
-
api_key=API_KEY
|
| 37 |
-
)
|
| 38 |
|
| 39 |
SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE).
|
| 40 |
You receive observations about the system state as JSON and must respond with a single action as JSON.
|
|
@@ -65,25 +62,40 @@ Respond with ONLY a valid JSON action object. No markdown formatting, and no oth
|
|
| 65 |
# Required Logging Functions
|
| 66 |
# ---------------------------------------------------------------------------
|
| 67 |
|
|
|
|
| 68 |
def log_start(task: str, env: str, model: str) -> None:
|
| 69 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
error_val = error if error else "null"
|
| 73 |
done_val = str(done).lower()
|
| 74 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 77 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 78 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# ---------------------------------------------------------------------------
|
| 81 |
# Core Logic
|
| 82 |
# ---------------------------------------------------------------------------
|
| 83 |
|
|
|
|
| 84 |
def llm_decide(observation: dict) -> dict:
|
| 85 |
obs_str = json.dumps(observation)
|
| 86 |
-
user_prompt =
|
|
|
|
|
|
|
| 87 |
|
| 88 |
for attempt in range(MAX_RETRIES):
|
| 89 |
try:
|
|
@@ -104,14 +116,19 @@ def llm_decide(observation: dict) -> dict:
|
|
| 104 |
return json.loads(content)
|
| 105 |
except Exception as e:
|
| 106 |
# FIX: Print the actual error so it shows up in logs!
|
| 107 |
-
print(
|
|
|
|
|
|
|
| 108 |
time.sleep(1)
|
| 109 |
-
|
| 110 |
# If it fails all retries, return a no_op
|
| 111 |
return {"action_type": "no_op"}
|
| 112 |
|
|
|
|
| 113 |
def env_reset(task_id: str) -> dict:
|
| 114 |
-
response = requests.post(
|
|
|
|
|
|
|
| 115 |
response.raise_for_status()
|
| 116 |
payload = response.json()
|
| 117 |
data_block = payload.get("data", payload)
|
|
@@ -119,14 +136,18 @@ def env_reset(task_id: str) -> dict:
|
|
| 119 |
return data_block["observation"]
|
| 120 |
return data_block
|
| 121 |
|
|
|
|
| 122 |
def env_step(action: dict) -> dict:
|
| 123 |
-
response = requests.post(
|
|
|
|
|
|
|
| 124 |
response.raise_for_status()
|
| 125 |
return response.json()
|
| 126 |
|
|
|
|
| 127 |
def run_task(task_id: str) -> float:
|
| 128 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 129 |
-
|
| 130 |
try:
|
| 131 |
obs = env_reset(task_id)
|
| 132 |
except Exception as e:
|
|
@@ -135,18 +156,18 @@ def run_task(task_id: str) -> float:
|
|
| 135 |
|
| 136 |
step = 0
|
| 137 |
rewards_list = []
|
| 138 |
-
|
| 139 |
-
# Initialize task_score outside the loop so we always have a value
|
| 140 |
# even if the loop breaks early or errors out.
|
| 141 |
task_score = 0.0
|
| 142 |
|
| 143 |
while True:
|
| 144 |
step += 1
|
| 145 |
action = llm_decide(obs)
|
| 146 |
-
|
| 147 |
# Format action strictly on one line without quotes that break bash/parsing
|
| 148 |
-
action_str = json.dumps(action).replace('"', "'")
|
| 149 |
-
|
| 150 |
error_msg = None
|
| 151 |
reward = 0.0
|
| 152 |
done = False
|
|
@@ -155,35 +176,41 @@ def run_task(task_id: str) -> float:
|
|
| 155 |
# ---> THE CHANGES YOU ASKED ABOUT ARE HERE <---
|
| 156 |
result = env_step(action)
|
| 157 |
data_block = result.get("data", result)
|
| 158 |
-
|
| 159 |
-
if "observation" in data_block and isinstance(
|
|
|
|
|
|
|
| 160 |
obs = data_block["observation"]
|
| 161 |
else:
|
| 162 |
obs = data_block
|
| 163 |
|
| 164 |
reward = float(data_block.get("reward", obs.get("reward", 0.0)))
|
| 165 |
done = bool(data_block.get("done", obs.get("done", False)))
|
| 166 |
-
|
| 167 |
# This continuously updates the task_score on every single step.
|
| 168 |
task_score = float(obs.get("task_score", 0.0))
|
| 169 |
|
| 170 |
except Exception as e:
|
| 171 |
-
error_msg = str(e).replace("\n", " ")
|
| 172 |
done = True
|
| 173 |
|
| 174 |
rewards_list.append(reward)
|
| 175 |
-
log_step(
|
|
|
|
|
|
|
| 176 |
|
| 177 |
# Even if step > 100 hits (timeout failure), task_score has the partial credit from the last step!
|
| 178 |
if done or step > 100:
|
| 179 |
# Define success: Let's say getting more than 0.1 points counts as partial success
|
| 180 |
-
success = task_score >= 0.1
|
| 181 |
log_end(success=success, steps=step, score=task_score, rewards=rewards_list)
|
| 182 |
return task_score
|
| 183 |
|
|
|
|
| 184 |
def main():
|
| 185 |
for task_id in TASKS:
|
| 186 |
run_task(task_id)
|
| 187 |
|
|
|
|
| 188 |
if __name__ == "__main__":
|
| 189 |
main()
|
|
|
|
| 31 |
MAX_RETRIES = 3
|
| 32 |
BENCHMARK = "distributed_infra_env"
|
| 33 |
|
| 34 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
SYSTEM_PROMPT = """You are an expert Site Reliability Engineer (SRE).
|
| 37 |
You receive observations about the system state as JSON and must respond with a single action as JSON.
|
|
|
|
| 62 |
# Required Logging Functions
|
| 63 |
# ---------------------------------------------------------------------------
|
| 64 |
|
| 65 |
+
|
| 66 |
def log_start(task: str, env: str, model: str) -> None:
|
| 67 |
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 68 |
|
| 69 |
+
|
| 70 |
+
def log_step(
|
| 71 |
+
step: int, action: str, reward: float, done: bool, error: Optional[str]
|
| 72 |
+
) -> None:
|
| 73 |
error_val = error if error else "null"
|
| 74 |
done_val = str(done).lower()
|
| 75 |
+
print(
|
| 76 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 77 |
+
flush=True,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
|
| 81 |
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 82 |
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 83 |
+
print(
|
| 84 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 85 |
+
flush=True,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
|
| 89 |
# ---------------------------------------------------------------------------
|
| 90 |
# Core Logic
|
| 91 |
# ---------------------------------------------------------------------------
|
| 92 |
|
| 93 |
+
|
| 94 |
def llm_decide(observation: dict) -> dict:
|
| 95 |
obs_str = json.dumps(observation)
|
| 96 |
+
user_prompt = (
|
| 97 |
+
f"Current system state:\n{obs_str}\nRespond with ONLY a JSON action object."
|
| 98 |
+
)
|
| 99 |
|
| 100 |
for attempt in range(MAX_RETRIES):
|
| 101 |
try:
|
|
|
|
| 116 |
return json.loads(content)
|
| 117 |
except Exception as e:
|
| 118 |
# FIX: Print the actual error so it shows up in logs!
|
| 119 |
+
print(
|
| 120 |
+
f"[DEBUG] LLM call attempt {attempt + 1} failed: {str(e)}", flush=True
|
| 121 |
+
)
|
| 122 |
time.sleep(1)
|
| 123 |
+
|
| 124 |
# If it fails all retries, return a no_op
|
| 125 |
return {"action_type": "no_op"}
|
| 126 |
|
| 127 |
+
|
| 128 |
def env_reset(task_id: str) -> dict:
|
| 129 |
+
response = requests.post(
|
| 130 |
+
f"{ENV_SERVER_URL}/reset", json={"task": task_id}, timeout=10
|
| 131 |
+
)
|
| 132 |
response.raise_for_status()
|
| 133 |
payload = response.json()
|
| 134 |
data_block = payload.get("data", payload)
|
|
|
|
| 136 |
return data_block["observation"]
|
| 137 |
return data_block
|
| 138 |
|
| 139 |
+
|
| 140 |
def env_step(action: dict) -> dict:
|
| 141 |
+
response = requests.post(
|
| 142 |
+
f"{ENV_SERVER_URL}/step", json={"action": action}, timeout=10
|
| 143 |
+
)
|
| 144 |
response.raise_for_status()
|
| 145 |
return response.json()
|
| 146 |
|
| 147 |
+
|
| 148 |
def run_task(task_id: str) -> float:
|
| 149 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 150 |
+
|
| 151 |
try:
|
| 152 |
obs = env_reset(task_id)
|
| 153 |
except Exception as e:
|
|
|
|
| 156 |
|
| 157 |
step = 0
|
| 158 |
rewards_list = []
|
| 159 |
+
|
| 160 |
+
# Initialize task_score outside the loop so we always have a value
|
| 161 |
# even if the loop breaks early or errors out.
|
| 162 |
task_score = 0.0
|
| 163 |
|
| 164 |
while True:
|
| 165 |
step += 1
|
| 166 |
action = llm_decide(obs)
|
| 167 |
+
|
| 168 |
# Format action strictly on one line without quotes that break bash/parsing
|
| 169 |
+
action_str = json.dumps(action).replace('"', "'")
|
| 170 |
+
|
| 171 |
error_msg = None
|
| 172 |
reward = 0.0
|
| 173 |
done = False
|
|
|
|
| 176 |
# ---> THE CHANGES YOU ASKED ABOUT ARE HERE <---
|
| 177 |
result = env_step(action)
|
| 178 |
data_block = result.get("data", result)
|
| 179 |
+
|
| 180 |
+
if "observation" in data_block and isinstance(
|
| 181 |
+
data_block["observation"], dict
|
| 182 |
+
):
|
| 183 |
obs = data_block["observation"]
|
| 184 |
else:
|
| 185 |
obs = data_block
|
| 186 |
|
| 187 |
reward = float(data_block.get("reward", obs.get("reward", 0.0)))
|
| 188 |
done = bool(data_block.get("done", obs.get("done", False)))
|
| 189 |
+
|
| 190 |
# This continuously updates the task_score on every single step.
|
| 191 |
task_score = float(obs.get("task_score", 0.0))
|
| 192 |
|
| 193 |
except Exception as e:
|
| 194 |
+
error_msg = str(e).replace("\n", " ") # Prevent newline breaks in STDOUT
|
| 195 |
done = True
|
| 196 |
|
| 197 |
rewards_list.append(reward)
|
| 198 |
+
log_step(
|
| 199 |
+
step=step, action=action_str, reward=reward, done=done, error=error_msg
|
| 200 |
+
)
|
| 201 |
|
| 202 |
# Even if step > 100 hits (timeout failure), task_score has the partial credit from the last step!
|
| 203 |
if done or step > 100:
|
| 204 |
# Define success: Let's say getting more than 0.1 points counts as partial success
|
| 205 |
+
success = task_score >= 0.1
|
| 206 |
log_end(success=success, steps=step, score=task_score, rewards=rewards_list)
|
| 207 |
return task_score
|
| 208 |
|
| 209 |
+
|
| 210 |
def main():
|
| 211 |
for task_id in TASKS:
|
| 212 |
run_task(task_id)
|
| 213 |
|
| 214 |
+
|
| 215 |
if __name__ == "__main__":
|
| 216 |
main()
|
server/environment.py
CHANGED
|
@@ -23,6 +23,7 @@ def _get_tasks():
|
|
| 23 |
global _TASKS
|
| 24 |
if _TASKS is None:
|
| 25 |
from server.tasks import TASKS
|
|
|
|
| 26 |
_TASKS = TASKS
|
| 27 |
return _TASKS
|
| 28 |
|
|
@@ -41,10 +42,10 @@ class Node:
|
|
| 41 |
capacity: int = 15
|
| 42 |
is_failed: bool = False
|
| 43 |
memory_util: float = 0.2
|
| 44 |
-
high_cpu_streak: int = 0
|
| 45 |
-
restart_countdown: int = 0
|
| 46 |
-
is_temporary: bool = False
|
| 47 |
-
ttl: int = 0
|
| 48 |
|
| 49 |
|
| 50 |
@dataclass
|
|
@@ -56,9 +57,9 @@ class SimulationState:
|
|
| 56 |
step_count: int = 0
|
| 57 |
base_request_rate: float = 100.0
|
| 58 |
current_request_rate: float = 100.0
|
| 59 |
-
throttle_rate: float = 1.0
|
| 60 |
latency_ms: float = 20.0
|
| 61 |
-
actions_taken: int = 0
|
| 62 |
cascade_bonus_awarded: bool = False
|
| 63 |
task_id: str = ""
|
| 64 |
max_steps: int = 30
|
|
@@ -74,6 +75,7 @@ class SimulationState:
|
|
| 74 |
# Default graph topology: 8 nodes in a mesh-like structure
|
| 75 |
# ---------------------------------------------------------------------------
|
| 76 |
|
|
|
|
| 77 |
def _build_default_graph(n: int = 8) -> Tuple[List[Node], Dict[int, List[int]]]:
|
| 78 |
"""Create a default mesh-like graph of n nodes."""
|
| 79 |
nodes = [Node(cpu_util=0.25 + random.uniform(-0.05, 0.05)) for _ in range(n)]
|
|
@@ -175,7 +177,7 @@ class DistributedInfraEnvironment(Environment):
|
|
| 175 |
episode_id=eid,
|
| 176 |
step_count=0,
|
| 177 |
task_id=task_id,
|
| 178 |
-
task_score=0.
|
| 179 |
)
|
| 180 |
|
| 181 |
return self._make_observation()
|
|
@@ -224,7 +226,7 @@ class DistributedInfraEnvironment(Environment):
|
|
| 224 |
# 10. Check termination
|
| 225 |
tasks = _get_tasks()
|
| 226 |
done = sim.step_count >= sim.max_steps
|
| 227 |
-
task_score = 0.
|
| 228 |
if sim.task_id in tasks:
|
| 229 |
task_info = tasks[sim.task_id]
|
| 230 |
if task_info["is_done"](self):
|
|
@@ -293,7 +295,8 @@ class DistributedInfraEnvironment(Environment):
|
|
| 293 |
for neighbor_idx in sim.adjacency.get(src, []):
|
| 294 |
if (
|
| 295 |
not sim.nodes[neighbor_idx].is_failed
|
| 296 |
-
and sim.nodes[neighbor_idx].cpu_util
|
|
|
|
| 297 |
):
|
| 298 |
sim.cascade_bonus_awarded = True
|
| 299 |
break
|
|
@@ -310,9 +313,7 @@ class DistributedInfraEnvironment(Environment):
|
|
| 310 |
sim.nodes.append(new_node)
|
| 311 |
# Connect to a few existing nodes
|
| 312 |
sim.adjacency[new_idx] = []
|
| 313 |
-
connect_to = self._rng.sample(
|
| 314 |
-
range(new_idx), min(3, new_idx)
|
| 315 |
-
)
|
| 316 |
for c in connect_to:
|
| 317 |
sim.adjacency[new_idx].append(c)
|
| 318 |
sim.adjacency[c].append(new_idx)
|
|
@@ -385,9 +386,7 @@ class DistributedInfraEnvironment(Environment):
|
|
| 385 |
new_adj: Dict[int, List[int]] = {}
|
| 386 |
for k, v in sim.adjacency.items():
|
| 387 |
new_k = k if k < idx else k - 1
|
| 388 |
-
new_v = [
|
| 389 |
-
(x if x < idx else x - 1) for x in v if x != idx
|
| 390 |
-
]
|
| 391 |
new_adj[new_k] = new_v
|
| 392 |
sim.adjacency = new_adj
|
| 393 |
|
|
@@ -411,7 +410,10 @@ class DistributedInfraEnvironment(Environment):
|
|
| 411 |
0.05,
|
| 412 |
min(
|
| 413 |
1.0,
|
| 414 |
-
node.cpu_util
|
|
|
|
|
|
|
|
|
|
| 415 |
+ self._rng.uniform(-0.02, 0.02),
|
| 416 |
),
|
| 417 |
)
|
|
@@ -439,7 +441,7 @@ class DistributedInfraEnvironment(Environment):
|
|
| 439 |
# Latency model: base + queue component + CPU-pressure component
|
| 440 |
base_latency = 10.0
|
| 441 |
queue_latency = avg_queue * 1.5
|
| 442 |
-
cpu_latency = (avg_cpu
|
| 443 |
|
| 444 |
new_latency = base_latency + queue_latency + cpu_latency
|
| 445 |
# Exponential moving average
|
|
@@ -523,9 +525,7 @@ class DistributedInfraEnvironment(Environment):
|
|
| 523 |
normalized_latency = min(2.0, sim.latency_ms / TARGET_LATENCY_MS)
|
| 524 |
|
| 525 |
overloaded = sum(
|
| 526 |
-
1
|
| 527 |
-
for n in sim.nodes
|
| 528 |
-
if not n.is_failed and n.cpu_util > OVERLOAD_THRESHOLD
|
| 529 |
)
|
| 530 |
overload_fraction = overloaded / total
|
| 531 |
|
|
|
|
| 23 |
global _TASKS
|
| 24 |
if _TASKS is None:
|
| 25 |
from server.tasks import TASKS
|
| 26 |
+
|
| 27 |
_TASKS = TASKS
|
| 28 |
return _TASKS
|
| 29 |
|
|
|
|
| 42 |
capacity: int = 15
|
| 43 |
is_failed: bool = False
|
| 44 |
memory_util: float = 0.2
|
| 45 |
+
high_cpu_streak: int = 0 # consecutive steps above 90% CPU
|
| 46 |
+
restart_countdown: int = 0 # >0 means the node is restarting
|
| 47 |
+
is_temporary: bool = False # True for scale-up nodes
|
| 48 |
+
ttl: int = 0 # remaining lifetime for temp nodes
|
| 49 |
|
| 50 |
|
| 51 |
@dataclass
|
|
|
|
| 57 |
step_count: int = 0
|
| 58 |
base_request_rate: float = 100.0
|
| 59 |
current_request_rate: float = 100.0
|
| 60 |
+
throttle_rate: float = 1.0 # 1.0 = accept all
|
| 61 |
latency_ms: float = 20.0
|
| 62 |
+
actions_taken: int = 0 # non-no_op actions
|
| 63 |
cascade_bonus_awarded: bool = False
|
| 64 |
task_id: str = ""
|
| 65 |
max_steps: int = 30
|
|
|
|
| 75 |
# Default graph topology: 8 nodes in a mesh-like structure
|
| 76 |
# ---------------------------------------------------------------------------
|
| 77 |
|
| 78 |
+
|
| 79 |
def _build_default_graph(n: int = 8) -> Tuple[List[Node], Dict[int, List[int]]]:
|
| 80 |
"""Create a default mesh-like graph of n nodes."""
|
| 81 |
nodes = [Node(cpu_util=0.25 + random.uniform(-0.05, 0.05)) for _ in range(n)]
|
|
|
|
| 177 |
episode_id=eid,
|
| 178 |
step_count=0,
|
| 179 |
task_id=task_id,
|
| 180 |
+
task_score=0.01,
|
| 181 |
)
|
| 182 |
|
| 183 |
return self._make_observation()
|
|
|
|
| 226 |
# 10. Check termination
|
| 227 |
tasks = _get_tasks()
|
| 228 |
done = sim.step_count >= sim.max_steps
|
| 229 |
+
task_score = 0.01
|
| 230 |
if sim.task_id in tasks:
|
| 231 |
task_info = tasks[sim.task_id]
|
| 232 |
if task_info["is_done"](self):
|
|
|
|
| 295 |
for neighbor_idx in sim.adjacency.get(src, []):
|
| 296 |
if (
|
| 297 |
not sim.nodes[neighbor_idx].is_failed
|
| 298 |
+
and sim.nodes[neighbor_idx].cpu_util
|
| 299 |
+
> CASCADE_AWARENESS_THRESHOLD
|
| 300 |
):
|
| 301 |
sim.cascade_bonus_awarded = True
|
| 302 |
break
|
|
|
|
| 313 |
sim.nodes.append(new_node)
|
| 314 |
# Connect to a few existing nodes
|
| 315 |
sim.adjacency[new_idx] = []
|
| 316 |
+
connect_to = self._rng.sample(range(new_idx), min(3, new_idx))
|
|
|
|
|
|
|
| 317 |
for c in connect_to:
|
| 318 |
sim.adjacency[new_idx].append(c)
|
| 319 |
sim.adjacency[c].append(new_idx)
|
|
|
|
| 386 |
new_adj: Dict[int, List[int]] = {}
|
| 387 |
for k, v in sim.adjacency.items():
|
| 388 |
new_k = k if k < idx else k - 1
|
| 389 |
+
new_v = [(x if x < idx else x - 1) for x in v if x != idx]
|
|
|
|
|
|
|
| 390 |
new_adj[new_k] = new_v
|
| 391 |
sim.adjacency = new_adj
|
| 392 |
|
|
|
|
| 410 |
0.05,
|
| 411 |
min(
|
| 412 |
1.0,
|
| 413 |
+
node.cpu_util
|
| 414 |
+
+ cpu_from_queue
|
| 415 |
+
+ cpu_from_processing
|
| 416 |
+
- natural_decay
|
| 417 |
+ self._rng.uniform(-0.02, 0.02),
|
| 418 |
),
|
| 419 |
)
|
|
|
|
| 441 |
# Latency model: base + queue component + CPU-pressure component
|
| 442 |
base_latency = 10.0
|
| 443 |
queue_latency = avg_queue * 1.5
|
| 444 |
+
cpu_latency = (avg_cpu**2) * 80.0 # non-linear increase under load
|
| 445 |
|
| 446 |
new_latency = base_latency + queue_latency + cpu_latency
|
| 447 |
# Exponential moving average
|
|
|
|
| 525 |
normalized_latency = min(2.0, sim.latency_ms / TARGET_LATENCY_MS)
|
| 526 |
|
| 527 |
overloaded = sum(
|
| 528 |
+
1 for n in sim.nodes if not n.is_failed and n.cpu_util > OVERLOAD_THRESHOLD
|
|
|
|
|
|
|
| 529 |
)
|
| 530 |
overload_fraction = overloaded / total
|
| 531 |
|
server/models.py
CHANGED
|
@@ -68,9 +68,7 @@ class InfraObservation(Observation):
|
|
| 68 |
cpu_loads: List[float] = Field(
|
| 69 |
description="CPU utilization [0.0, 1.0] for each node."
|
| 70 |
)
|
| 71 |
-
queue_lengths: List[int] = Field(
|
| 72 |
-
description="Number of pending requests per node."
|
| 73 |
-
)
|
| 74 |
failed_nodes: List[int] = Field(
|
| 75 |
description="Indices of nodes currently in failed state."
|
| 76 |
)
|
|
@@ -80,15 +78,11 @@ class InfraObservation(Observation):
|
|
| 80 |
request_rate: float = Field(
|
| 81 |
description="Incoming requests per second into the system."
|
| 82 |
)
|
| 83 |
-
step: int = Field(
|
| 84 |
-
description="Current step within the episode."
|
| 85 |
-
)
|
| 86 |
task_hint: str = Field(
|
| 87 |
description="Natural language description of the current task objective."
|
| 88 |
)
|
| 89 |
-
task_score: float = Field(
|
| 90 |
-
default=0.0, description="Current grader score"
|
| 91 |
-
)
|
| 92 |
|
| 93 |
|
| 94 |
class InfraState(State):
|
|
@@ -96,9 +90,7 @@ class InfraState(State):
|
|
| 96 |
Internal environment state extending the base OpenEnv State.
|
| 97 |
"""
|
| 98 |
|
| 99 |
-
task_id: Optional[str] = Field(
|
| 100 |
-
default=None, description="Current task identifier."
|
| 101 |
-
)
|
| 102 |
task_score: float = Field(
|
| 103 |
-
default=0.
|
| 104 |
)
|
|
|
|
| 68 |
cpu_loads: List[float] = Field(
|
| 69 |
description="CPU utilization [0.0, 1.0] for each node."
|
| 70 |
)
|
| 71 |
+
queue_lengths: List[int] = Field(description="Number of pending requests per node.")
|
|
|
|
|
|
|
| 72 |
failed_nodes: List[int] = Field(
|
| 73 |
description="Indices of nodes currently in failed state."
|
| 74 |
)
|
|
|
|
| 78 |
request_rate: float = Field(
|
| 79 |
description="Incoming requests per second into the system."
|
| 80 |
)
|
| 81 |
+
step: int = Field(description="Current step within the episode.")
|
|
|
|
|
|
|
| 82 |
task_hint: str = Field(
|
| 83 |
description="Natural language description of the current task objective."
|
| 84 |
)
|
| 85 |
+
task_score: float = Field(default=0.01, description="Current grader score")
|
|
|
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
class InfraState(State):
|
|
|
|
| 90 |
Internal environment state extending the base OpenEnv State.
|
| 91 |
"""
|
| 92 |
|
| 93 |
+
task_id: Optional[str] = Field(default=None, description="Current task identifier.")
|
|
|
|
|
|
|
| 94 |
task_score: float = Field(
|
| 95 |
+
default=0.01, description="Current task grader score in (0.0, 1.0) strictly."
|
| 96 |
)
|
server/tasks.py
CHANGED
|
@@ -4,7 +4,7 @@ Distributed Infrastructure Management Environment.
|
|
| 4 |
|
| 5 |
Each task provides:
|
| 6 |
- setup(env, rng): configure initial node states and scenario parameters
|
| 7 |
-
- grade(env): return float in
|
| 8 |
- is_done(env): termination condition check
|
| 9 |
- hint: natural language task description for the agent
|
| 10 |
"""
|
|
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
|
| 21 |
# Task 1 — Easy: Traffic Spike Recovery
|
| 22 |
# ============================================================================
|
| 23 |
|
|
|
|
| 24 |
def _setup_traffic_spike(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 25 |
"""System receives 3x normal request rate."""
|
| 26 |
sim = env.sim
|
|
@@ -38,7 +39,7 @@ def _grade_traffic_spike(env: "DistributedInfraEnvironment") -> float:
|
|
| 38 |
"""
|
| 39 |
sim = env.sim
|
| 40 |
if not sim.latency_history:
|
| 41 |
-
return 0.
|
| 42 |
|
| 43 |
# Latency component: fraction of steps where latency was below target
|
| 44 |
target = 50.0 # ms
|
|
@@ -46,14 +47,16 @@ def _grade_traffic_spike(env: "DistributedInfraEnvironment") -> float:
|
|
| 46 |
latency_score = below_target / len(sim.latency_history)
|
| 47 |
|
| 48 |
# Uptime component: average uptime ratio
|
| 49 |
-
avg_uptime =
|
|
|
|
|
|
|
| 50 |
|
| 51 |
# Efficiency: penalty for excessive actions
|
| 52 |
max_reasonable = sim.max_steps * 0.5
|
| 53 |
efficiency = max(0.0, 1.0 - sim.actions_taken / max(1, max_reasonable))
|
| 54 |
|
| 55 |
score = 0.50 * latency_score + 0.30 * avg_uptime + 0.20 * efficiency
|
| 56 |
-
return round(min(
|
| 57 |
|
| 58 |
|
| 59 |
def _is_done_traffic_spike(env: "DistributedInfraEnvironment") -> bool:
|
|
@@ -64,6 +67,7 @@ def _is_done_traffic_spike(env: "DistributedInfraEnvironment") -> bool:
|
|
| 64 |
# Task 2 — Medium: Single Node Failure
|
| 65 |
# ============================================================================
|
| 66 |
|
|
|
|
| 67 |
def _setup_node_failure(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 68 |
"""One node will fail at step 5. Agent must maintain 80%+ uptime."""
|
| 69 |
sim = env.sim
|
|
@@ -82,7 +86,7 @@ def _grade_node_failure(env: "DistributedInfraEnvironment") -> float:
|
|
| 82 |
sim = env.sim
|
| 83 |
|
| 84 |
if not sim.uptime_history:
|
| 85 |
-
return 0.
|
| 86 |
|
| 87 |
# MTTR: how quickly system recovered from the failure
|
| 88 |
failure_duration = 0
|
|
@@ -105,7 +109,7 @@ def _grade_node_failure(env: "DistributedInfraEnvironment") -> float:
|
|
| 105 |
restart_penalty = max(0.0, 1.0 - max(0, sim.restart_count - 1) / 5)
|
| 106 |
|
| 107 |
score = 0.40 * mttr_score + 0.40 * uptime_score + 0.20 * restart_penalty
|
| 108 |
-
return round(min(
|
| 109 |
|
| 110 |
|
| 111 |
def _is_done_node_failure(env: "DistributedInfraEnvironment") -> bool:
|
|
@@ -125,6 +129,7 @@ def _is_done_node_failure(env: "DistributedInfraEnvironment") -> bool:
|
|
| 125 |
# Task 3 — Hard: Cascading Failure Prevention
|
| 126 |
# ============================================================================
|
| 127 |
|
|
|
|
| 128 |
def _setup_cascading_failure(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 129 |
"""Two nodes near critical CPU. Agent must prevent cascade chain."""
|
| 130 |
sim = env.sim
|
|
@@ -157,10 +162,7 @@ def _grade_cascading_failure(env: "DistributedInfraEnvironment") -> float:
|
|
| 157 |
cascade_score = 1.0 if not sim.cascade_occurred else 0.3
|
| 158 |
|
| 159 |
if sim.uptime_history:
|
| 160 |
-
healthy_now = sum(
|
| 161 |
-
1 for n in sim.nodes
|
| 162 |
-
if not n.is_failed and n.cpu_util < 0.85
|
| 163 |
-
)
|
| 164 |
total_now = len(sim.nodes)
|
| 165 |
cpu_score = healthy_now / total_now if total_now > 0 else 0.0
|
| 166 |
else:
|
|
@@ -170,7 +172,7 @@ def _grade_cascading_failure(env: "DistributedInfraEnvironment") -> float:
|
|
| 170 |
efficiency = max(0.0, 1.0 - sim.actions_taken / max(1, max_reasonable))
|
| 171 |
|
| 172 |
score = 0.50 * cascade_score + 0.30 * cpu_score + 0.20 * efficiency
|
| 173 |
-
return round(min(
|
| 174 |
|
| 175 |
|
| 176 |
def _is_done_cascading_failure(env: "DistributedInfraEnvironment") -> bool:
|
|
@@ -185,6 +187,7 @@ def _is_done_cascading_failure(env: "DistributedInfraEnvironment") -> bool:
|
|
| 185 |
# Task 4 — Expert: Flash Crowd
|
| 186 |
# ============================================================================
|
| 187 |
|
|
|
|
| 188 |
def _setup_flash_crowd(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 189 |
"""Massive 5x traffic spike. Agent must scale up AND throttle to survive."""
|
| 190 |
sim = env.sim
|
|
@@ -194,24 +197,30 @@ def _setup_flash_crowd(env: "DistributedInfraEnvironment", rng: "random.Random")
|
|
| 194 |
node.cpu_util = 0.60 + rng.uniform(-0.05, 0.1)
|
| 195 |
node.queue_length = rng.randint(15, 30)
|
| 196 |
|
|
|
|
| 197 |
def _grade_flash_crowd(env: "DistributedInfraEnvironment") -> float:
|
| 198 |
"""
|
| 199 |
Score = Survival Uptime (50%) + Latency control (50%).
|
| 200 |
Cascade penalty applied if the system collapses.
|
| 201 |
"""
|
| 202 |
sim = env.sim
|
| 203 |
-
|
| 204 |
-
avg_uptime =
|
| 205 |
-
|
|
|
|
|
|
|
| 206 |
# Latency target is more generous for a massive flash crowd (100ms)
|
| 207 |
-
target = 100.0
|
| 208 |
below_target = sum(1 for lat in sim.latency_history if lat < target)
|
| 209 |
-
latency_score =
|
|
|
|
|
|
|
| 210 |
|
| 211 |
cascade_penalty = 0.4 if sim.cascade_occurred else 0.0
|
| 212 |
|
| 213 |
score = 0.50 * avg_uptime + 0.50 * latency_score - cascade_penalty
|
| 214 |
-
return round(min(
|
|
|
|
| 215 |
|
| 216 |
def _is_done_flash_crowd(env: "DistributedInfraEnvironment") -> bool:
|
| 217 |
failed_count = sum(1 for n in env.sim.nodes if n.is_failed)
|
|
|
|
| 4 |
|
| 5 |
Each task provides:
|
| 6 |
- setup(env, rng): configure initial node states and scenario parameters
|
| 7 |
+
- grade(env): return float in (0.0, 1.0) with partial credit (strictly between 0 and 1)
|
| 8 |
- is_done(env): termination condition check
|
| 9 |
- hint: natural language task description for the agent
|
| 10 |
"""
|
|
|
|
| 21 |
# Task 1 — Easy: Traffic Spike Recovery
|
| 22 |
# ============================================================================
|
| 23 |
|
| 24 |
+
|
| 25 |
def _setup_traffic_spike(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 26 |
"""System receives 3x normal request rate."""
|
| 27 |
sim = env.sim
|
|
|
|
| 39 |
"""
|
| 40 |
sim = env.sim
|
| 41 |
if not sim.latency_history:
|
| 42 |
+
return 0.01
|
| 43 |
|
| 44 |
# Latency component: fraction of steps where latency was below target
|
| 45 |
target = 50.0 # ms
|
|
|
|
| 47 |
latency_score = below_target / len(sim.latency_history)
|
| 48 |
|
| 49 |
# Uptime component: average uptime ratio
|
| 50 |
+
avg_uptime = (
|
| 51 |
+
sum(sim.uptime_history) / len(sim.uptime_history) if sim.uptime_history else 1.0
|
| 52 |
+
)
|
| 53 |
|
| 54 |
# Efficiency: penalty for excessive actions
|
| 55 |
max_reasonable = sim.max_steps * 0.5
|
| 56 |
efficiency = max(0.0, 1.0 - sim.actions_taken / max(1, max_reasonable))
|
| 57 |
|
| 58 |
score = 0.50 * latency_score + 0.30 * avg_uptime + 0.20 * efficiency
|
| 59 |
+
return round(min(0.99, max(0.01, score)), 4)
|
| 60 |
|
| 61 |
|
| 62 |
def _is_done_traffic_spike(env: "DistributedInfraEnvironment") -> bool:
|
|
|
|
| 67 |
# Task 2 — Medium: Single Node Failure
|
| 68 |
# ============================================================================
|
| 69 |
|
| 70 |
+
|
| 71 |
def _setup_node_failure(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 72 |
"""One node will fail at step 5. Agent must maintain 80%+ uptime."""
|
| 73 |
sim = env.sim
|
|
|
|
| 86 |
sim = env.sim
|
| 87 |
|
| 88 |
if not sim.uptime_history:
|
| 89 |
+
return 0.01
|
| 90 |
|
| 91 |
# MTTR: how quickly system recovered from the failure
|
| 92 |
failure_duration = 0
|
|
|
|
| 109 |
restart_penalty = max(0.0, 1.0 - max(0, sim.restart_count - 1) / 5)
|
| 110 |
|
| 111 |
score = 0.40 * mttr_score + 0.40 * uptime_score + 0.20 * restart_penalty
|
| 112 |
+
return round(min(0.99, max(0.01, score)), 4)
|
| 113 |
|
| 114 |
|
| 115 |
def _is_done_node_failure(env: "DistributedInfraEnvironment") -> bool:
|
|
|
|
| 129 |
# Task 3 — Hard: Cascading Failure Prevention
|
| 130 |
# ============================================================================
|
| 131 |
|
| 132 |
+
|
| 133 |
def _setup_cascading_failure(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 134 |
"""Two nodes near critical CPU. Agent must prevent cascade chain."""
|
| 135 |
sim = env.sim
|
|
|
|
| 162 |
cascade_score = 1.0 if not sim.cascade_occurred else 0.3
|
| 163 |
|
| 164 |
if sim.uptime_history:
|
| 165 |
+
healthy_now = sum(1 for n in sim.nodes if not n.is_failed and n.cpu_util < 0.85)
|
|
|
|
|
|
|
|
|
|
| 166 |
total_now = len(sim.nodes)
|
| 167 |
cpu_score = healthy_now / total_now if total_now > 0 else 0.0
|
| 168 |
else:
|
|
|
|
| 172 |
efficiency = max(0.0, 1.0 - sim.actions_taken / max(1, max_reasonable))
|
| 173 |
|
| 174 |
score = 0.50 * cascade_score + 0.30 * cpu_score + 0.20 * efficiency
|
| 175 |
+
return round(min(0.99, max(0.01, score)), 4)
|
| 176 |
|
| 177 |
|
| 178 |
def _is_done_cascading_failure(env: "DistributedInfraEnvironment") -> bool:
|
|
|
|
| 187 |
# Task 4 — Expert: Flash Crowd
|
| 188 |
# ============================================================================
|
| 189 |
|
| 190 |
+
|
| 191 |
def _setup_flash_crowd(env: "DistributedInfraEnvironment", rng: "random.Random"):
|
| 192 |
"""Massive 5x traffic spike. Agent must scale up AND throttle to survive."""
|
| 193 |
sim = env.sim
|
|
|
|
| 197 |
node.cpu_util = 0.60 + rng.uniform(-0.05, 0.1)
|
| 198 |
node.queue_length = rng.randint(15, 30)
|
| 199 |
|
| 200 |
+
|
| 201 |
def _grade_flash_crowd(env: "DistributedInfraEnvironment") -> float:
|
| 202 |
"""
|
| 203 |
Score = Survival Uptime (50%) + Latency control (50%).
|
| 204 |
Cascade penalty applied if the system collapses.
|
| 205 |
"""
|
| 206 |
sim = env.sim
|
| 207 |
+
|
| 208 |
+
avg_uptime = (
|
| 209 |
+
sum(sim.uptime_history) / len(sim.uptime_history) if sim.uptime_history else 0.0
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
# Latency target is more generous for a massive flash crowd (100ms)
|
| 213 |
+
target = 100.0
|
| 214 |
below_target = sum(1 for lat in sim.latency_history if lat < target)
|
| 215 |
+
latency_score = (
|
| 216 |
+
below_target / len(sim.latency_history) if sim.latency_history else 0.0
|
| 217 |
+
)
|
| 218 |
|
| 219 |
cascade_penalty = 0.4 if sim.cascade_occurred else 0.0
|
| 220 |
|
| 221 |
score = 0.50 * avg_uptime + 0.50 * latency_score - cascade_penalty
|
| 222 |
+
return round(min(0.99, max(0.01, score)), 4)
|
| 223 |
+
|
| 224 |
|
| 225 |
def _is_done_flash_crowd(env: "DistributedInfraEnvironment") -> bool:
|
| 226 |
failed_count = sum(1 for n in env.sim.nodes if n.is_failed)
|