PranavKK1201 commited on
Commit
94aef6f
·
1 Parent(s): 7c9c3e7

ui and inference changes

Browse files
deploy/grafana/provisioning/dashboards/json/antiatropos-overview.json CHANGED
@@ -76,8 +76,8 @@
76
  "targets": [
77
  {
78
  "editorMode": "code",
79
- "expr": "antiatropos_reward",
80
- "legendFormat": "reward",
81
  "range": true,
82
  "refId": "A"
83
  }
@@ -143,8 +143,8 @@
143
  "targets": [
144
  {
145
  "editorMode": "code",
146
- "expr": "antiatropos_total_queue_backlog",
147
- "legendFormat": "queue backlog",
148
  "range": true,
149
  "refId": "A"
150
  }
@@ -210,8 +210,8 @@
210
  "targets": [
211
  {
212
  "editorMode": "code",
213
- "expr": "antiatropos_average_latency_norm",
214
- "legendFormat": "latency",
215
  "range": true,
216
  "refId": "A"
217
  }
@@ -277,8 +277,8 @@
277
  "targets": [
278
  {
279
  "editorMode": "code",
280
- "expr": "antiatropos_lyapunov_energy",
281
- "legendFormat": "lyapunov energy",
282
  "range": true,
283
  "refId": "A"
284
  }
@@ -369,15 +369,15 @@
369
  "targets": [
370
  {
371
  "editorMode": "code",
372
- "expr": "antiatropos_reward",
373
- "legendFormat": "reward {{task_id}}",
374
  "range": true,
375
  "refId": "A"
376
  },
377
  {
378
  "editorMode": "code",
379
- "expr": "antiatropos_lyapunov_energy",
380
- "legendFormat": "lyapunov {{task_id}}",
381
  "range": true,
382
  "refId": "B"
383
  }
@@ -468,15 +468,15 @@
468
  "targets": [
469
  {
470
  "editorMode": "code",
471
- "expr": "antiatropos_total_queue_backlog",
472
- "legendFormat": "queue {{task_id}}",
473
  "range": true,
474
  "refId": "A"
475
  },
476
  {
477
  "editorMode": "code",
478
- "expr": "antiatropos_average_latency_norm",
479
- "legendFormat": "latency {{task_id}}",
480
  "range": true,
481
  "refId": "B"
482
  }
@@ -535,15 +535,15 @@
535
  "targets": [
536
  {
537
  "editorMode": "code",
538
- "expr": "rate(antiatropos_steps_total[1m])",
539
- "legendFormat": "steps/sec {{task_id}}",
540
  "range": true,
541
  "refId": "A"
542
  },
543
  {
544
  "editorMode": "code",
545
- "expr": "rate(antiatropos_actions_total[1m])",
546
- "legendFormat": "actions/sec {{action_type}}",
547
  "range": true,
548
  "refId": "B"
549
  }
@@ -602,14 +602,14 @@
602
  "targets": [
603
  {
604
  "editorMode": "code",
605
- "expr": "rate(antiatropos_executor_errors_total[5m])",
606
- "legendFormat": "executor errors {{error_code}}",
607
  "range": true,
608
  "refId": "A"
609
  },
610
  {
611
  "editorMode": "code",
612
- "expr": "histogram_quantile(0.95, sum(rate(antiatropos_executor_latency_ms_bucket[5m])) by (le, mode))",
613
  "legendFormat": "p95 executor latency {{mode}}",
614
  "range": true,
615
  "refId": "B"
@@ -637,6 +637,6 @@
637
  "timezone": "browser",
638
  "title": "AntiAtropos Overview",
639
  "uid": "antiatropos-overview",
640
- "version": 1,
641
  "weekStart": ""
642
  }
 
76
  "targets": [
77
  {
78
  "editorMode": "code",
79
+ "expr": "scalar(avg(last_over_time(antiatropos_reward{mode=\"simulated\"}[1m])))",
80
+ "legendFormat": "reward (simulated)",
81
  "range": true,
82
  "refId": "A"
83
  }
 
143
  "targets": [
144
  {
145
  "editorMode": "code",
146
+ "expr": "scalar(avg(last_over_time(antiatropos_total_queue_backlog{mode=\"simulated\"}[1m])))",
147
+ "legendFormat": "queue backlog (simulated)",
148
  "range": true,
149
  "refId": "A"
150
  }
 
210
  "targets": [
211
  {
212
  "editorMode": "code",
213
+ "expr": "scalar(avg(last_over_time(antiatropos_average_latency_norm{mode=\"simulated\"}[1m])))",
214
+ "legendFormat": "latency (simulated)",
215
  "range": true,
216
  "refId": "A"
217
  }
 
277
  "targets": [
278
  {
279
  "editorMode": "code",
280
+ "expr": "scalar(avg(last_over_time(antiatropos_lyapunov_energy{mode=\"simulated\"}[1m])))",
281
+ "legendFormat": "lyapunov energy (simulated)",
282
  "range": true,
283
  "refId": "A"
284
  }
 
369
  "targets": [
370
  {
371
  "editorMode": "code",
372
+ "expr": "antiatropos_reward{mode=\"simulated\"}",
373
+ "legendFormat": "reward {{task_id}} ({{mode}})",
374
  "range": true,
375
  "refId": "A"
376
  },
377
  {
378
  "editorMode": "code",
379
+ "expr": "antiatropos_lyapunov_energy{mode=\"simulated\"}",
380
+ "legendFormat": "lyapunov {{task_id}} ({{mode}})",
381
  "range": true,
382
  "refId": "B"
383
  }
 
468
  "targets": [
469
  {
470
  "editorMode": "code",
471
+ "expr": "antiatropos_total_queue_backlog{mode=\"simulated\"}",
472
+ "legendFormat": "queue {{task_id}} ({{mode}})",
473
  "range": true,
474
  "refId": "A"
475
  },
476
  {
477
  "editorMode": "code",
478
+ "expr": "antiatropos_average_latency_norm{mode=\"simulated\"}",
479
+ "legendFormat": "latency {{task_id}} ({{mode}})",
480
  "range": true,
481
  "refId": "B"
482
  }
 
535
  "targets": [
536
  {
537
  "editorMode": "code",
538
+ "expr": "sum by (task_id, mode) (rate(antiatropos_steps_total{mode=\"simulated\"}[1m]))",
539
+ "legendFormat": "steps/sec {{task_id}} ({{mode}})",
540
  "range": true,
541
  "refId": "A"
542
  },
543
  {
544
  "editorMode": "code",
545
+ "expr": "sum by (task_id, mode, action_type) (rate(antiatropos_actions_total{mode=\"simulated\"}[1m]))",
546
+ "legendFormat": "actions/sec {{action_type}} ({{task_id}}, {{mode}})",
547
  "range": true,
548
  "refId": "B"
549
  }
 
602
  "targets": [
603
  {
604
  "editorMode": "code",
605
+ "expr": "sum by (mode, error_code) (rate(antiatropos_executor_errors_total{mode=\"simulated\"}[5m]))",
606
+ "legendFormat": "executor errors {{error_code}} ({{mode}})",
607
  "range": true,
608
  "refId": "A"
609
  },
610
  {
611
  "editorMode": "code",
612
+ "expr": "histogram_quantile(0.95, sum(rate(antiatropos_executor_latency_ms_bucket{mode=\"simulated\"}[5m])) by (le, mode))",
613
  "legendFormat": "p95 executor latency {{mode}}",
614
  "range": true,
615
  "refId": "B"
 
637
  "timezone": "browser",
638
  "title": "AntiAtropos Overview",
639
  "uid": "antiatropos-overview",
640
+ "version": 2,
641
  "weekStart": ""
642
  }
inference.py CHANGED
@@ -1,50 +1,57 @@
1
  import asyncio
2
  import json
3
  import os
 
4
  import textwrap
5
- from typing import List, Optional
 
 
 
6
 
7
  from dotenv import load_dotenv
8
-
9
- load_dotenv() # Load variables from .env file
10
-
11
  from openai import AsyncOpenAI
12
 
13
  from AntiAtropos.client import AntiAtroposEnv
14
  from AntiAtropos.grader import EpisodeGrader
15
  from AntiAtropos.models import ActionType, SREAction
16
 
 
17
 
18
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
19
- MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-70b-versatile")
20
  API_KEY = (
21
- os.getenv("HF_TOKEN")
22
  or os.getenv("OPENAI_API_KEY")
23
  or os.getenv("API_KEY")
24
- or os.getenv("GROQ_API_KEY")
25
  )
26
- LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
27
- ENV_URL = os.getenv("ANTIATROPOS_ENV_URL", "http://127.0.0.1:8000")
28
- TASK_NAME = os.getenv("ANTIATROPOS_TASK", "task-3")
29
- BENCHMARK = os.getenv("ANTIATROPOS_BENCHMARK", "antiatropos")
30
  ENV_MODE = os.getenv("ANTIATROPOS_MODE", "simulated")
31
- MAX_STEPS = int(os.getenv("ANTIATROPOS_MAX_STEPS", "35"))
32
- TEMPERATURE = float(os.getenv("ANTIATROPOS_TEMPERATURE", "0.05"))
 
 
 
 
 
 
 
33
  MAX_TOKENS = int(os.getenv("ANTIATROPOS_MAX_TOKENS", "180"))
 
34
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("ANTIATROPOS_SUCCESS_THRESHOLD", "0.55"))
35
 
 
 
 
 
 
 
36
  SYSTEM_PROMPT = textwrap.dedent(
37
  """
38
  You are an autonomous SRE controller managing a five-node microservice cluster.
39
 
40
- Objectives:
41
- - minimize Lyapunov energy and queue growth
42
- - keep normalized average latency at or below 0.20
43
- - avoid invalid actions, especially SHED_LOAD on node-0, node-1, and node-2
44
- - scale proactively because SCALE_UP takes 5 ticks to take effect
45
- - protect the VIP gateway node-0
46
-
47
- Output exactly one JSON object:
48
  {
49
  "action_type": "SCALE_UP" | "SCALE_DOWN" | "REROUTE_TRAFFIC" | "SHED_LOAD" | "NO_OP",
50
  "target_node_id": "node-0" | "node-1" | "node-2" | "node-3" | "node-4",
@@ -54,57 +61,94 @@ SYSTEM_PROMPT = textwrap.dedent(
54
  ).strip()
55
 
56
 
57
- def log_start(task: str, env: str, model: str) -> None:
58
- print(f"[START] task={task} env={env} model={model}", flush=True)
 
 
59
 
 
 
 
60
 
61
- def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
62
- error_val = error if error else "null"
63
- done_val = str(done).lower()
64
- print(
65
- f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
66
- flush=True,
67
- )
68
 
 
 
 
69
 
70
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
71
- rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
72
- print(
73
- f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
74
- flush=True,
75
- )
76
 
 
 
 
 
 
 
 
77
 
78
- def build_user_prompt(step: int, obs: dict, history: List[str]) -> str:
79
- history_block = "\n".join(history[-4:]) if history else "None"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  return textwrap.dedent(
81
  f"""
 
 
82
  Step: {step}
 
83
  Current state:
84
  {json.dumps(obs, separators=(",", ":"))}
85
 
86
  Recent decisions:
87
- {history_block}
88
 
89
  Choose the next SRE action.
90
  """
91
  ).strip()
92
 
93
 
94
- def compact_action(action: SREAction) -> str:
95
- payload = {
96
- "action_type": action.action_type.value,
97
- "target_node_id": action.target_node_id,
98
- "parameter": round(float(action.parameter), 4),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  }
100
- return json.dumps(payload, separators=(",", ":"))
101
 
102
 
103
- def extract_json_object(text: str) -> dict:
104
  stripped = text.strip()
105
- if not stripped:
106
- raise ValueError("empty model response")
107
-
108
  start = stripped.find("{")
109
  end = stripped.rfind("}")
110
  if start == -1 or end == -1 or end < start:
@@ -112,7 +156,7 @@ def extract_json_object(text: str) -> dict:
112
  return json.loads(stripped[start : end + 1])
113
 
114
 
115
- def parse_action(payload: dict) -> SREAction:
116
  action_type = str(payload.get("action_type", "NO_OP")).upper()
117
  target_node_id = str(payload.get("target_node_id", "node-0"))
118
  parameter = float(payload.get("parameter", 0.0))
@@ -123,139 +167,168 @@ def parse_action(payload: dict) -> SREAction:
123
  )
124
 
125
 
126
- async def get_model_action(
127
- client: AsyncOpenAI,
128
- step: int,
129
- obs: dict,
130
- history: List[str],
131
- ) -> SREAction:
132
- user_prompt = build_user_prompt(step, obs, history)
133
  try:
134
  completion = await client.chat.completions.create(
135
  model=MODEL_NAME,
136
  messages=[
137
  {"role": "system", "content": SYSTEM_PROMPT},
138
- {"role": "user", "content": user_prompt},
139
  ],
140
  temperature=TEMPERATURE,
141
  max_tokens=MAX_TOKENS,
142
  response_format={"type": "json_object"},
 
 
143
  )
144
  content = completion.choices[0].message.content or ""
145
- return parse_action(extract_json_object(content))
146
- except Exception:
147
- return SREAction(
148
- action_type=ActionType.NO_OP,
149
- target_node_id="node-0",
150
- parameter=0.0,
151
- )
152
 
153
 
154
- def observation_for_model(obs) -> dict:
155
- return {
156
- "task_id": obs.task_id,
157
- "mode": getattr(obs.mode, "value", str(obs.mode)),
158
- "step": obs.step,
159
- "max_steps": obs.max_steps,
160
- "lyapunov_energy": obs.lyapunov_energy,
161
- "average_latency_ms": obs.average_latency_ms,
162
- "error_rate": obs.error_rate,
163
- "total_queue_backlog": obs.total_queue_backlog,
164
- "sla_violations": obs.sla_violations,
165
- "invalid_action_count": obs.invalid_action_count,
166
- "nodes": [
167
- {
168
- "node_id": node.node_id,
169
- "status": getattr(node.status, "value", str(node.status)),
170
- "is_vip": node.is_vip,
171
- "queue_depth": node.queue_depth,
172
- "latency_ms": node.latency_ms,
173
- "incoming_request_rate": node.incoming_request_rate,
174
- "cpu_utilization": node.cpu_utilization,
175
- }
176
- for node in obs.nodes
177
- ],
178
  }
 
179
 
180
 
181
- async def run_episode() -> None:
 
 
 
 
 
 
 
182
  rewards: List[float] = []
183
  steps_taken = 0
184
- success = False
185
- score = 0.0
186
- env = None
187
- client = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- log_start(TASK_NAME, BENCHMARK, MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- try:
192
- if not API_KEY:
193
- raise RuntimeError("missing API key")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
196
- if LOCAL_IMAGE_NAME:
197
- env = await AntiAtroposEnv.from_docker_image(LOCAL_IMAGE_NAME)
198
- else:
199
- env = AntiAtroposEnv(base_url=ENV_URL)
200
- await env.__aenter__()
201
 
202
- grader = EpisodeGrader(task_id=TASK_NAME)
203
- history: List[str] = []
 
 
 
 
 
204
 
205
- result = await env.reset(task_id=TASK_NAME, mode=ENV_MODE)
206
- grader.record(result.observation)
 
 
 
207
 
208
- for step in range(1, MAX_STEPS + 1):
209
- if result.done:
210
- break
211
-
212
- obs = result.observation
213
- action = await get_model_action(
214
- client=client,
215
- step=step,
216
- obs=observation_for_model(obs),
217
- history=history,
218
- )
219
- result = await env.step(action)
220
- grader.record(result.observation)
221
-
222
- reward = float(result.reward or 0.0)
223
- rewards.append(reward)
224
- steps_taken = step
225
-
226
- ack_status = getattr(result.observation, "action_ack_status", "")
227
- error = ack_status if ack_status.startswith(("Rejected:", "Error:")) else None
228
- action_str = compact_action(action)
229
- log_step(step=step, action=action_str, reward=reward, done=result.done, error=error)
230
-
231
- history.append(
232
- f"step={step} action={action_str} reward={reward:.2f} ack={ack_status or 'null'}"
233
- )
234
-
235
- if result.done:
236
- break
237
-
238
- score = max(0.0, min(1.0, grader.score().composite))
239
- success = score >= SUCCESS_SCORE_THRESHOLD
240
- except Exception as e:
241
- print(f"[CRITICAL ERROR] Episode failed to initialize: {e}")
242
- success = False
243
  finally:
244
- if client is not None:
245
- try:
246
- await client.close()
247
- except Exception:
248
- pass
249
- if env is not None:
250
- try:
251
- await env.close()
252
- except Exception:
253
- pass
254
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
 
 
 
 
 
255
 
256
 
257
  def main() -> None:
258
- asyncio.run(run_episode())
259
 
260
 
261
  if __name__ == "__main__":
 
1
  import asyncio
2
  import json
3
  import os
4
+ import random
5
  import textwrap
6
+ import time
7
+ from contextlib import asynccontextmanager
8
+ from typing import Dict, List
9
+ from urllib.parse import urlparse
10
 
11
  from dotenv import load_dotenv
 
 
 
12
  from openai import AsyncOpenAI
13
 
14
  from AntiAtropos.client import AntiAtroposEnv
15
  from AntiAtropos.grader import EpisodeGrader
16
  from AntiAtropos.models import ActionType, SREAction
17
 
18
+ load_dotenv()
19
 
20
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
21
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
22
  API_KEY = (
23
+ os.getenv("GROQ_API_KEY") # prioritize Groq key since we default to groq API
24
  or os.getenv("OPENAI_API_KEY")
25
  or os.getenv("API_KEY")
26
+ or os.getenv("HF_TOKEN")
27
  )
28
+
29
+ ENV_URL = os.getenv("ANTIATROPOS_ENV_URL", "https://pranavkk-antiatropos.hf.space")
 
 
30
  ENV_MODE = os.getenv("ANTIATROPOS_MODE", "simulated")
31
+ TASKS = ["task-1", "task-2", "task-3"]
32
+
33
+ TOTAL_BUDGET_SECONDS = 1080 # 18-minute limit
34
+ MIN_TASK_BUDGET_SECONDS = 60
35
+ MAX_STEPS_PER_TASK = 60 # 60 steps = ~5 minutes at this rate
36
+ MESSAGE_TIMEOUT_S = 300
37
+ MODEL_TIMEOUT_S = 25
38
+
39
+ TEMPERATURE = float(os.getenv("ANTIATROPOS_TEMPERATURE", "0.0"))
40
  MAX_TOKENS = int(os.getenv("ANTIATROPOS_MAX_TOKENS", "180"))
41
+ SEED = int(os.getenv("ANTIATROPOS_SEED", "42"))
42
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("ANTIATROPOS_SUCCESS_THRESHOLD", "0.55"))
43
 
44
+ TASK_BRIEFS: Dict[str, str] = {
45
+ "task-1": "Traffic increases linearly. Scale proactively to keep latency low and cost efficient.",
46
+ "task-2": "A node fails randomly. Detect quickly and recover with reroute/scale actions.",
47
+ "task-3": "Protect VIP node-0 under surges. Keep VIP healthy without invalid actions.",
48
+ }
49
+
50
  SYSTEM_PROMPT = textwrap.dedent(
51
  """
52
  You are an autonomous SRE controller managing a five-node microservice cluster.
53
 
54
+ Return exactly one JSON object:
 
 
 
 
 
 
 
55
  {
56
  "action_type": "SCALE_UP" | "SCALE_DOWN" | "REROUTE_TRAFFIC" | "SHED_LOAD" | "NO_OP",
57
  "target_node_id": "node-0" | "node-1" | "node-2" | "node-3" | "node-4",
 
61
  ).strip()
62
 
63
 
64
+ def _seed_everything(seed: int) -> None:
65
+ random.seed(seed)
66
+ try:
67
+ import numpy as np
68
 
69
+ np.random.seed(seed)
70
+ except Exception:
71
+ pass
72
 
 
 
 
 
 
 
 
73
 
74
+ def _task_seed(base_seed: int, task_id: str) -> int:
75
+ offsets = {"task-1": 0, "task-2": 1, "task-3": 2}
76
+ return int(base_seed + offsets.get(task_id, 0))
77
 
 
 
 
 
 
 
78
 
79
+ def _hf_web_fallback_url(base_url: str) -> str:
80
+ parsed = urlparse(base_url)
81
+ host = parsed.netloc.lower()
82
+ path = parsed.path.rstrip("/")
83
+ if host.endswith(".hf.space") and path == "":
84
+ return base_url.rstrip("/") + "/web"
85
+ return base_url
86
 
87
+
88
+ @asynccontextmanager
89
+ async def open_env_with_ws_fallback(base_url: str, message_timeout_s: int):
90
+ try:
91
+ async with AntiAtroposEnv(base_url, message_timeout_s=message_timeout_s) as env:
92
+ yield env
93
+ return
94
+ except ConnectionError as e:
95
+ fallback_url = _hf_web_fallback_url(base_url)
96
+ if fallback_url == base_url or "404" not in str(e):
97
+ raise
98
+ print(f"[connect] ws 404 on {base_url}; retrying with {fallback_url}", flush=True)
99
+ async with AntiAtroposEnv(fallback_url, message_timeout_s=message_timeout_s) as env:
100
+ yield env
101
+
102
+
103
+ def build_user_prompt(task_id: str, step: int, obs: dict, history: List[str]) -> str:
104
+ recent = "\n".join(history[-4:]) if history else "None"
105
+ brief = TASK_BRIEFS.get(task_id, "Maintain SLA, stability, and efficient cost.")
106
  return textwrap.dedent(
107
  f"""
108
+ Task: {task_id}
109
+ Objective: {brief}
110
  Step: {step}
111
+
112
  Current state:
113
  {json.dumps(obs, separators=(",", ":"))}
114
 
115
  Recent decisions:
116
+ {recent}
117
 
118
  Choose the next SRE action.
119
  """
120
  ).strip()
121
 
122
 
123
+ def observation_for_model(obs) -> dict:
124
+ return {
125
+ "task_id": obs.task_id,
126
+ "mode": getattr(obs.mode, "value", str(obs.mode)),
127
+ "step": obs.step,
128
+ "max_steps": obs.max_steps,
129
+ "lyapunov_energy": obs.lyapunov_energy,
130
+ "average_latency_ms": obs.average_latency_ms,
131
+ "error_rate": obs.error_rate,
132
+ "total_queue_backlog": obs.total_queue_backlog,
133
+ "sla_violations": obs.sla_violations,
134
+ "invalid_action_count": obs.invalid_action_count,
135
+ "nodes": [
136
+ {
137
+ "node_id": node.node_id,
138
+ "status": getattr(node.status, "value", str(node.status)),
139
+ "is_vip": node.is_vip,
140
+ "queue_depth": node.queue_depth,
141
+ "latency_ms": node.latency_ms,
142
+ "incoming_request_rate": node.incoming_request_rate,
143
+ "cpu_utilization": node.cpu_utilization,
144
+ }
145
+ for node in obs.nodes
146
+ ],
147
  }
 
148
 
149
 
150
+ def _extract_json_object(text: str) -> dict:
151
  stripped = text.strip()
 
 
 
152
  start = stripped.find("{")
153
  end = stripped.rfind("}")
154
  if start == -1 or end == -1 or end < start:
 
156
  return json.loads(stripped[start : end + 1])
157
 
158
 
159
+ def _parse_action(payload: dict) -> SREAction:
160
  action_type = str(payload.get("action_type", "NO_OP")).upper()
161
  target_node_id = str(payload.get("target_node_id", "node-0"))
162
  parameter = float(payload.get("parameter", 0.0))
 
167
  )
168
 
169
 
170
+ async def get_model_action(client: AsyncOpenAI, task_id: str, step: int, obs: dict, history: List[str]) -> SREAction:
171
+ prompt = build_user_prompt(task_id=task_id, step=step, obs=obs, history=history)
 
 
 
 
 
172
  try:
173
  completion = await client.chat.completions.create(
174
  model=MODEL_NAME,
175
  messages=[
176
  {"role": "system", "content": SYSTEM_PROMPT},
177
+ {"role": "user", "content": prompt},
178
  ],
179
  temperature=TEMPERATURE,
180
  max_tokens=MAX_TOKENS,
181
  response_format={"type": "json_object"},
182
+ timeout=MODEL_TIMEOUT_S,
183
+ seed=SEED,
184
  )
185
  content = completion.choices[0].message.content or ""
186
+ return _parse_action(_extract_json_object(content))
187
+ except Exception as e:
188
+ print(f"[LLM_ERROR] task={task_id} step={step} error={type(e).__name__}: {e}", flush=True)
189
+ return SREAction(action_type=ActionType.NO_OP, target_node_id="node-0", parameter=0.0)
 
 
 
190
 
191
 
192
+ def _compact_action(action: SREAction) -> str:
193
+ payload = {
194
+ "action_type": action.action_type.value,
195
+ "target_node_id": action.target_node_id,
196
+ "parameter": round(float(action.parameter), 4),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  }
198
+ return json.dumps(payload, separators=(",", ":"))
199
 
200
 
201
+ async def run_single_task(env: AntiAtroposEnv, client: AsyncOpenAI, task_id: str, deadline: float) -> dict:
202
+ start = time.monotonic()
203
+ task_seed = _task_seed(SEED, task_id)
204
+ result = await env.reset(task_id=task_id, mode=ENV_MODE, seed=task_seed)
205
+
206
+ grader = EpisodeGrader(task_id=task_id)
207
+ grader.record(result.observation)
208
+ history: List[str] = []
209
  rewards: List[float] = []
210
  steps_taken = 0
211
+ timed_out = False
212
+
213
+ for step in range(1, MAX_STEPS_PER_TASK + 1):
214
+ if time.monotonic() >= deadline:
215
+ timed_out = True
216
+ break
217
+ if result.done:
218
+ break
219
+
220
+ action = await get_model_action(
221
+ client=client,
222
+ task_id=task_id,
223
+ step=step,
224
+ obs=observation_for_model(result.observation),
225
+ history=history,
226
+ )
227
+ result = await env.step(action)
228
+ grader.record(result.observation)
229
 
230
+ reward = float(result.reward or 0.0)
231
+ rewards.append(reward)
232
+ steps_taken = step
233
+ ack = getattr(result.observation, "action_ack_status", "")
234
+ action_str = _compact_action(action)
235
+ history.append(f"step={step} action={action_str} reward={reward:.4f} ack={ack or 'null'}")
236
+
237
+ error = ack if ack.startswith(("Rejected:", "Error:")) else None
238
+ print(
239
+ f"[STEP] task={task_id} step={step} action={action_str} reward={reward:.4f} done={str(result.done).lower()} error={error or 'null'}",
240
+ flush=True,
241
+ )
242
 
243
+ grade = grader.score()
244
+ score = max(0.0, min(1.0, float(grade.composite)))
245
+ elapsed = time.monotonic() - start
246
+ success = score >= SUCCESS_SCORE_THRESHOLD and not timed_out
247
+ print(
248
+ f"[TASK_END] task={task_id} success={str(success).lower()} score={score:.4f} "
249
+ f"steps={steps_taken} elapsed_s={elapsed:.1f} timed_out={str(timed_out).lower()} seed={task_seed}",
250
+ flush=True,
251
+ )
252
+ return {
253
+ "task_id": task_id,
254
+ "success": success,
255
+ "score": score,
256
+ "steps": steps_taken,
257
+ "elapsed_seconds": elapsed,
258
+ "timed_out": timed_out,
259
+ "grade_summary": grade.summary(),
260
+ "rewards": rewards,
261
+ }
262
 
 
 
 
 
 
 
263
 
264
+ async def run_all_tasks() -> None:
265
+ _seed_everything(SEED)
266
+ tasks = [task for task in TASKS if task in {"task-1", "task-2", "task-3"}]
267
+ if not tasks:
268
+ raise RuntimeError("ANTIATROPOS_TASKS must include at least one of: task-1,task-2,task-3")
269
+ if not API_KEY:
270
+ raise RuntimeError("Missing API key (HF_TOKEN/OPENAI_API_KEY/API_KEY/GROQ_API_KEY).")
271
 
272
+ print(
273
+ f"[START] tasks={','.join(tasks)} env={ENV_URL} mode={ENV_MODE} model={MODEL_NAME} "
274
+ f"budget_s={TOTAL_BUDGET_SECONDS} seed={SEED}",
275
+ flush=True,
276
+ )
277
 
278
+ start = time.monotonic()
279
+ deadline = start + TOTAL_BUDGET_SECONDS
280
+ reports: List[dict] = []
281
+
282
+ client = AsyncOpenAI(base_url=API_BASE_URL, api_key=API_KEY)
283
+ try:
284
+ async with open_env_with_ws_fallback(ENV_URL, MESSAGE_TIMEOUT_S) as env:
285
+ for idx, task_id in enumerate(tasks):
286
+ now = time.monotonic()
287
+ if now >= deadline:
288
+ print(f"[BUDGET] stopping before {task_id}; time budget exhausted", flush=True)
289
+ break
290
+
291
+ remaining_tasks = len(tasks) - idx
292
+ remaining_seconds = max(0.0, deadline - now)
293
+ allocated_seconds = max(
294
+ float(MIN_TASK_BUDGET_SECONDS),
295
+ remaining_seconds / float(remaining_tasks),
296
+ )
297
+ task_deadline = min(deadline, now + allocated_seconds)
298
+ print(
299
+ f"[BUDGET] task={task_id} allocated_s={allocated_seconds:.1f} "
300
+ f"remaining_s={remaining_seconds:.1f} remaining_tasks={remaining_tasks}",
301
+ flush=True,
302
+ )
303
+
304
+ report = await run_single_task(
305
+ env=env,
306
+ client=client,
307
+ task_id=task_id,
308
+ deadline=task_deadline,
309
+ )
310
+ reports.append(report)
 
 
311
  finally:
312
+ await client.close()
313
+
314
+ total_elapsed = time.monotonic() - start
315
+ completed_scores = [r["score"] for r in reports]
316
+ aggregate_score = sum(completed_scores) / len(completed_scores) if completed_scores else 0.0
317
+ aggregate_score = max(0.0, min(1.0, aggregate_score))
318
+ all_success = len(reports) == len(tasks) and all(r["success"] for r in reports)
319
+
320
+ for report in reports:
321
+ print(f"[GRADE] {report['grade_summary']}", flush=True)
322
+
323
+ print(
324
+ f"[END] success={str(all_success).lower()} completed_tasks={len(reports)}/{len(tasks)} "
325
+ f"aggregate_score={aggregate_score:.4f} elapsed_s={total_elapsed:.1f}",
326
+ flush=True,
327
+ )
328
 
329
 
330
  def main() -> None:
331
+ asyncio.run(run_all_tasks())
332
 
333
 
334
  if __name__ == "__main__":
server/AntiAtropos_environment.py CHANGED
@@ -79,7 +79,7 @@ class AntiAtroposEnvironment(Environment):
79
  self._reward_output_mode = "normalized"
80
  self._last_metric_time: float = 0.0
81
 
82
- def reset(self, task_id: str = "task-1", mode: str = "simulated") -> ClusterObservation:
83
  """
84
  Start a fresh episode with a specific task profile and mode.
85
  """
@@ -110,7 +110,7 @@ class AntiAtroposEnvironment(Environment):
110
  # self._telemetry = PrometheusClient(url=os.getenv("PROMETHEUS_URL"))
111
  pass
112
 
113
- self._sim.reset(task_id=task_id)
114
 
115
  # If in hybrid mode, immediately pull a baseline
116
  if self._mode in [EnvironmentMode.HYBRID, EnvironmentMode.LIVE]:
@@ -396,4 +396,3 @@ class AntiAtroposEnvironment(Environment):
396
  reward=0.0,
397
  )
398
 
399
-
 
79
  self._reward_output_mode = "normalized"
80
  self._last_metric_time: float = 0.0
81
 
82
+ def reset(self, task_id: str = "task-1", mode: str = "simulated", seed: int | None = None) -> ClusterObservation:
83
  """
84
  Start a fresh episode with a specific task profile and mode.
85
  """
 
110
  # self._telemetry = PrometheusClient(url=os.getenv("PROMETHEUS_URL"))
111
  pass
112
 
113
+ self._sim.reset(task_id=task_id, seed=seed)
114
 
115
  # If in hybrid mode, immediately pull a baseline
116
  if self._mode in [EnvironmentMode.HYBRID, EnvironmentMode.LIVE]:
 
396
  reward=0.0,
397
  )
398
 
 
simulator.py CHANGED
@@ -133,6 +133,7 @@ class ClusterSimulator:
133
  # Default to non-deterministic RNG seeding so fresh simulator instances
134
  # do not replay identical domain-randomization sequences.
135
  # Pass an explicit seed for reproducible experiments.
 
136
  self._rng = random.Random(seed)
137
  self._tick_count: int = 0
138
  self._failed_node_id: Optional[str] = None
@@ -179,8 +180,12 @@ class ClusterSimulator:
179
  for i in range(self._n_nodes)
180
  ]
181
 
182
- def reset(self, task_id: str = "task-1") -> None:
183
  """Restart the simulator for a fresh episode."""
 
 
 
 
184
  self._task_id = task_id
185
  self._tick_count = 0
186
  self._failed_node_id = None
@@ -482,4 +487,3 @@ class ClusterSimulator:
482
  # We'll skip it if we just reconciled to keep the blended values, OR refine it.
483
  # For now, let's just make sure statuses are updated based on new queue depths.
484
  self._update_statuses()
485
-
 
133
  # Default to non-deterministic RNG seeding so fresh simulator instances
134
  # do not replay identical domain-randomization sequences.
135
  # Pass an explicit seed for reproducible experiments.
136
+ self._seed: Optional[int] = seed
137
  self._rng = random.Random(seed)
138
  self._tick_count: int = 0
139
  self._failed_node_id: Optional[str] = None
 
180
  for i in range(self._n_nodes)
181
  ]
182
 
183
+ def reset(self, task_id: str = "task-1", seed: Optional[int] = None) -> None:
184
  """Restart the simulator for a fresh episode."""
185
+ if seed is not None:
186
+ self._seed = seed
187
+ # Reinitialize RNG so episode generation is reproducible for a given seed.
188
+ self._rng = random.Random(seed)
189
  self._task_id = task_id
190
  self._tick_count = 0
191
  self._failed_node_id = None
 
487
  # We'll skip it if we just reconciled to keep the blended values, OR refine it.
488
  # For now, let's just make sure statuses are updated based on new queue depths.
489
  self._update_statuses()