ashishMenon05 commited on
Commit
fd5d7f9
Β·
1 Parent(s): 3168d77

fix: full resubmission patch - fix [STEP] format, add close(), expose system_state, fix /state endpoint, improve reward variance

Browse files
backend/api/routes/openenv.py CHANGED
@@ -150,11 +150,13 @@ async def step_env(action: NexusAction):
150
  except Exception as e:
151
  raise HTTPException(status_code=500, detail=str(e))
152
 
153
- @router.get("/state", response_model=NexusState)
154
  def get_state():
 
155
  state = episode_manager.env.state()
156
- if not state:
157
- raise HTTPException(status_code=400, detail="No active episode")
 
158
  return state
159
 
160
  @router.get("/telemetry")
 
150
  except Exception as e:
151
  raise HTTPException(status_code=500, detail=str(e))
152
 
153
+ @router.get("/state")
154
  def get_state():
155
+ """Returns the current episode state. Returns idle status if no episode is active."""
156
  state = episode_manager.env.state()
157
+ # state() now always returns something β€” either a NexusState pydantic object or an idle dict.
158
+ if hasattr(state, "model_dump"):
159
+ return state.model_dump()
160
  return state
161
 
162
  @router.get("/telemetry")
backend/core/environment.py CHANGED
@@ -71,7 +71,7 @@ class NexusEnvironment:
71
  obs = NexusObservation(
72
  partner_message="",
73
  tool_results=[],
74
- system_state={},
75
  investigation_stage="investigating",
76
  round=1,
77
  available_tools=available_tools,
@@ -143,7 +143,7 @@ class NexusEnvironment:
143
  obs = NexusObservation(
144
  partner_message=action.message,
145
  tool_results=tool_results_objs,
146
- system_state={"total_tools_run": len(ep.tool_calls_made)},
147
  investigation_stage=ep.investigation_stage,
148
  round=ep.current_round,
149
  available_tools=SSH_TOOLS if settings.EXECUTION_MODE == "ssh" else SIMULATED_TOOLS,
@@ -156,5 +156,11 @@ class NexusEnvironment:
156
 
157
  def state(self):
158
  if not self.active_episode:
159
- return None
 
160
  return self.active_episode.to_pydantic()
 
 
 
 
 
 
71
  obs = NexusObservation(
72
  partner_message="",
73
  tool_results=[],
74
+ system_state=self.active_episode.system_state, # Expose real state so agent sees initial conditions
75
  investigation_stage="investigating",
76
  round=1,
77
  available_tools=available_tools,
 
143
  obs = NexusObservation(
144
  partner_message=action.message,
145
  tool_results=tool_results_objs,
146
+ system_state=ep.system_state, # Return real mutated state so agent sees the effect of its actions
147
  investigation_stage=ep.investigation_stage,
148
  round=ep.current_round,
149
  available_tools=SSH_TOOLS if settings.EXECUTION_MODE == "ssh" else SIMULATED_TOOLS,
 
156
 
157
  def state(self):
158
  if not self.active_episode:
159
+ # Return a valid default state so the /state endpoint always responds
160
+ return {"status": "idle", "message": "No active episode. Call /reset to start."}
161
  return self.active_episode.to_pydantic()
162
+
163
+ async def close(self):
164
+ """Clean up the active episode. Required by OpenEnv spec."""
165
+ self.active_episode = None
166
+ self.active_scenario = None
backend/core/reward_engine.py CHANGED
@@ -3,115 +3,143 @@ import logging
3
 
4
  logger = logging.getLogger("nexus.reward_engine")
5
 
 
 
 
 
 
 
 
6
  def compute_reward(message: str, tool_calls: list, tool_results: list, episode_state, scenario: dict) -> tuple[float, dict]:
7
  breakdown = {}
8
-
9
  msg_lower = message.lower()
10
-
11
  ep = episode_state
12
  sc = scenario
13
-
14
- # 1. HYPOTHESIS SPECIFICITY (0.0-0.20)
15
- specificity_indicators = ["shows", "value", "config", "log", "found", "confirmed",
16
- "set to", "equals", "returns", "indicates", "trace", "root cause"]
17
- breakdown['specificity'] = min(0.20,
18
- sum(0.025 for word in specificity_indicators if word in msg_lower)
19
- )
20
-
21
- # 2. TOOL EXECUTION SUCCESS (0.0-0.25)
 
 
 
22
  tool_score = 0.0
23
  if tool_calls:
24
- new_tools = 0
 
25
  for t in tool_calls:
26
  sig = f"{t.tool_name}:{str(t.params)}"
27
  if sig not in ep.previous_tool_calls:
28
- new_tools += 1
29
-
30
- # Reward for using different tool categories
31
- tool_categories = set()
32
- for tc in tool_calls:
33
- if tc.tool_name in ["read_logs", "check_config", "query_database", "check_service_status"]:
34
- tool_categories.add("investigation")
35
- elif tc.tool_name in ["update_config", "restart_service"]:
36
- tool_categories.add("fix_action")
37
- elif tc.tool_name in ["propose_fix", "verify_fix"]:
38
- tool_categories.add("resolution")
39
-
40
- tool_score = min(0.25, len(tool_categories) * 0.08 + new_tools * 0.05)
41
- breakdown['tool_usage'] = tool_score
42
-
43
- # 3. TOOL RESULT QUALITY (0.0-0.15)
44
  result_score = 0.0
 
45
  for tr in tool_results:
46
- result_text = tr.get('result', '').lower() if isinstance(tr, dict) else str(tr).lower()
47
- # Positive signals in tool results
48
- if any(kw in result_text for kw in ['error', 'fail', 'degraded', 'anomaly', 'threshold']):
49
- result_score += 0.03 # Found something useful
50
- if any(kw in result_text for kw in ['rate_limit', 'nginx', 'config', 'timeout', 'connection']):
51
- result_score += 0.02 # Found relevant clue
52
- if 'success' in result_text or 'running' in result_text or 'healthy' in result_text:
53
- result_score += 0.01 # Status info
54
- breakdown['result_quality'] = min(0.15, result_score)
55
-
56
- # 4. CLUE DISCOVERY (0.0-0.20)
57
- clue_score = 0.0
58
- if hasattr(ep, 'clues_found') and ep.clues_found:
59
- clue_score = min(0.20, len(ep.clues_found) * 0.05)
60
- breakdown['clue_discovery'] = clue_score
61
-
62
- # 5. INVESTIGATION STAGE PROGRESS (0.0-0.15)
63
- stage_score = 0.0
64
- if hasattr(ep, 'investigation_stage'):
65
- stage_map = {'investigating': 0.02, 'narrowing': 0.08, 'hypothesizing': 0.12, 'found': 0.15, 'verified': 0.15}
66
- stage_score = stage_map.get(ep.investigation_stage, 0.02)
67
- breakdown['stage_progress'] = stage_score
68
-
69
- # 6. SEMANTIC SIMILARITY TO ROOT CAUSE (0.0-0.10)
 
 
 
 
 
70
  similarity_score = 0.0
71
  try:
72
- root_cause_desc = scenario.get('root_cause', {}).get('description', '')
73
- if root_cause_desc:
74
  msg_emb = get_embedding(message)
75
  rc_emb = get_embedding(root_cause_desc)
76
- sim = cos_sim(msg_emb, rc_emb)
77
- # Only reward if embedding is not fallback (has meaningful variance)
78
- if len(msg_emb) == 384 and sum(msg_emb) != 0:
79
- similarity_score = min(0.10, sim * 0.15)
80
- except:
81
  pass
82
- breakdown['semantic_similarity'] = similarity_score
83
-
84
- # 7. NOVELTY BONUS (0.0-0.05)
85
- novelty_score = 0.0
86
- if hasattr(ep, 'all_messages') and ep.all_messages:
87
- try:
88
  msg_emb = get_embedding(message)
89
- max_sim = 0
90
  for prev in ep.all_messages[-3:]:
91
  prev_emb = get_embedding(prev)
92
  sim = cos_sim(msg_emb, prev_emb)
93
- max_sim = max(max_sim, sim)
94
- novelty_score = max(0.0, 0.05 * (1 - max_sim))
95
- except:
96
- novelty_score = 0.03
97
- else:
98
- novelty_score = 0.05
99
- breakdown['novelty'] = novelty_score
100
-
101
- # PENALTIES
102
  penalty = 0.0
103
- msg_len = len(message.split())
104
- if msg_len < 5:
105
- penalty += 0.10 # Too terse
106
- if len(message) > 2000:
107
- penalty += 0.05 # Too verbose without action
108
- if breakdown['novelty'] < 0.01:
109
- penalty += 0.10 # Circular/duplicate message
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  total = sum(breakdown.values()) - penalty
112
  final_score = round(max(0.0, min(1.0, total)), 4)
113
-
114
  ep.reward_history.append(final_score)
115
- ep.cumulative_reward += final_score
116
-
117
  return final_score, breakdown
 
3
 
4
  logger = logging.getLogger("nexus.reward_engine")
5
 
6
+ # Root-cause keywords per difficulty β€” pre-defined for fast matching
7
+ DIFFICULTY_ROOT_CAUSE_HINTS = {
8
+ "easy": ["rate_limit", "nginx", "rate limit", "429", "proxy", "throttle"],
9
+ "medium": ["approval", "process", "workflow", "sla", "escalation", "manual"],
10
+ "hard": ["postgres", "connection pool", "long_running_query", "max_connections", "deadlock", "timeout"],
11
+ }
12
+
13
  def compute_reward(message: str, tool_calls: list, tool_results: list, episode_state, scenario: dict) -> tuple[float, dict]:
14
  breakdown = {}
 
15
  msg_lower = message.lower()
 
16
  ep = episode_state
17
  sc = scenario
18
+ difficulty = getattr(ep, "difficulty", "easy")
19
+
20
+ # 1. HYPOTHESIS QUALITY β€” Reward specificity and domain alignment (0.0-0.20)
21
+ # Check if message mentions domain-specific terms relevant to this difficulty
22
+ domain_hints = DIFFICULTY_ROOT_CAUSE_HINTS.get(difficulty, [])
23
+ domain_hits = sum(1 for hint in domain_hints if hint in msg_lower)
24
+ # General specificity β€” mentions numbers, config keys, service names
25
+ generic_specificity = sum(0.01 for word in ["set to", "equals", "config", "found", "confirmed", "root cause",
26
+ "value", "log", "trace", "indicates", "returns"] if word in msg_lower)
27
+ breakdown["hypothesis_quality"] = min(0.20, domain_hits * 0.04 + generic_specificity)
28
+
29
+ # 2. TOOL USAGE QUALITY β€” Correct tools, no repeating same call (0.0-0.25)
30
  tool_score = 0.0
31
  if tool_calls:
32
+ tool_categories = set()
33
+ new_calls = 0
34
  for t in tool_calls:
35
  sig = f"{t.tool_name}:{str(t.params)}"
36
  if sig not in ep.previous_tool_calls:
37
+ new_calls += 1
38
+ if t.tool_name in ["read_logs", "check_config", "query_database", "check_service_status", "run_diagnostic"]:
39
+ tool_categories.add("investigate")
40
+ elif t.tool_name in ["update_config", "restart_service"]:
41
+ tool_categories.add("fix")
42
+ elif t.tool_name in ["propose_fix", "verify_fix", "submit_resolution"]:
43
+ tool_categories.add("resolve")
44
+
45
+ # Reward for covering investigation before jumping to fixes
46
+ stage_coverage = len(tool_categories)
47
+ tool_score = min(0.25, stage_coverage * 0.07 + new_calls * 0.04)
48
+
49
+ breakdown["tool_usage"] = tool_score
50
+
51
+ # 3. TOOL RESULT QUALITY β€” Did the tools find actionable info? (0.0-0.15)
 
52
  result_score = 0.0
53
+ domain_found = False
54
  for tr in tool_results:
55
+ result_text = tr.get("result", "").lower() if isinstance(tr, dict) else str(tr).lower()
56
+ if any(kw in result_text for kw in ["error", "fail", "degraded", "anomaly", "threshold", "critical"]):
57
+ result_score += 0.04 # Found a symptom
58
+ if any(hint in result_text for hint in domain_hints):
59
+ result_score += 0.05 # Found a domain-specific clue
60
+ domain_found = True
61
+ if "success" in result_text or "fixed" in result_text:
62
+ result_score += 0.02 # Fix confirmed by tool
63
+ breakdown["result_quality"] = min(0.15, result_score)
64
+
65
+ # 4. CLUE ACCUMULATION β€” Discovering new clues (0.0-0.15)
66
+ # Cap at 3 clues to prevent reward hacking via tool spam
67
+ if hasattr(ep, "clues_found") and ep.clues_found:
68
+ breakdown["clue_discovery"] = min(0.15, len(ep.clues_found) * 0.04)
69
+ else:
70
+ breakdown["clue_discovery"] = 0.0
71
+
72
+ # 5. INVESTIGATION STAGE PROGRESSION β€” Rewards forward momentum (0.0-0.15)
73
+ stage_map = {
74
+ "investigating": 0.03,
75
+ "narrowing": 0.08,
76
+ "hypothesizing": 0.12,
77
+ "found": 0.15,
78
+ "verified": 0.15,
79
+ }
80
+ breakdown["stage_progress"] = stage_map.get(getattr(ep, "investigation_stage", "investigating"), 0.03)
81
+
82
+ # 6. SEMANTIC SIMILARITY TO ROOT CAUSE β€” Only if embeddings are available (0.0-0.15)
83
+ # More weight than before β€” this is the real quality signal
84
  similarity_score = 0.0
85
  try:
86
+ root_cause_desc = sc.get("root_cause", {}).get("description", "")
87
+ if root_cause_desc and message.strip():
88
  msg_emb = get_embedding(message)
89
  rc_emb = get_embedding(root_cause_desc)
90
+ # Only reward if not using the zero-variance fallback embedding
91
+ if len(msg_emb) == 384 and abs(sum(msg_emb)) > 0.001:
92
+ sim = cos_sim(msg_emb, rc_emb)
93
+ similarity_score = min(0.15, sim * 0.20)
94
+ except Exception:
95
  pass
96
+ breakdown["semantic_similarity"] = similarity_score
97
+
98
+ # 7. NOVELTY β€” Penalize circular/repetitive reasoning (0.0-0.05)
99
+ novelty_score = 0.05 # Start assuming novel
100
+ try:
101
+ if hasattr(ep, "all_messages") and len(ep.all_messages) > 1:
102
  msg_emb = get_embedding(message)
103
+ max_sim = 0.0
104
  for prev in ep.all_messages[-3:]:
105
  prev_emb = get_embedding(prev)
106
  sim = cos_sim(msg_emb, prev_emb)
107
+ if sim > max_sim:
108
+ max_sim = sim
109
+ novelty_score = max(0.0, 0.05 * (1.0 - max_sim))
110
+ except Exception:
111
+ novelty_score = 0.03
112
+ breakdown["novelty"] = novelty_score
113
+
114
+ # ── PENALTIES ─────────────────────────────────────────────────────────────
 
115
  penalty = 0.0
116
+
117
+ # Too terse: no useful reasoning
118
+ if len(message.split()) < 8:
119
+ penalty += 0.10
120
+
121
+ # Too verbose without any tool calls: wall of text, no action
122
+ if len(message) > 1200 and not tool_calls:
123
+ penalty += 0.05
124
+
125
+ # Circular reasoning: nearly identical to a recent message
126
+ if breakdown["novelty"] < 0.005:
127
+ penalty += 0.12
128
+
129
+ # Wrong domain: confidently blaming the wrong service
130
+ # Check if agent blames a red-herring service mentioned in scenario
131
+ red_herrings = sc.get("red_herrings", [])
132
+ if red_herrings:
133
+ for rh in red_herrings:
134
+ rh_lower = str(rh).lower()
135
+ if rh_lower in msg_lower and "not" not in msg_lower:
136
+ penalty += 0.05 # Fell for the red herring
137
+ break
138
+
139
  total = sum(breakdown.values()) - penalty
140
  final_score = round(max(0.0, min(1.0, total)), 4)
141
+
142
  ep.reward_history.append(final_score)
143
+ ep.cumulative_reward = round(ep.cumulative_reward + final_score, 4)
144
+
145
  return final_score, breakdown
inference.py CHANGED
@@ -5,7 +5,6 @@ NEXUS Inference Script β€” OpenEnv Competition Submission
5
 
6
  import os
7
  import sys
8
- import asyncio
9
  import re
10
  from pathlib import Path
11
 
@@ -13,140 +12,193 @@ ROOT = Path(__file__).resolve().parent
13
  sys.path.insert(0, str(ROOT))
14
  sys.path.insert(0, str(ROOT / "backend"))
15
 
16
- from dotenv import load_dotenv
17
-
18
- if (ROOT / ".env").exists():
19
- load_dotenv(ROOT / ".env", override=True)
20
- elif (ROOT / "backend" / ".env").exists():
21
- load_dotenv(ROOT / "backend" / ".env", override=True)
22
-
23
- # Fallback for defaults, will NOT override the .env we just loaded
24
- load_dotenv(ROOT / "default.env", override=False)
25
-
26
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
27
- MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
28
- HF_TOKEN = os.getenv("HF_TOKEN", "")
29
 
30
- if "API_BASE_URL" not in os.environ or not os.environ["API_BASE_URL"]:
31
- os.environ["API_BASE_URL"] = API_BASE_URL
32
- if "API_KEY" not in os.environ or not os.environ["API_KEY"]:
33
- os.environ["API_KEY"] = "none"
34
 
35
- # The client should NOT be initialized here at the module level.
36
- # If the evaluator imports this file before patching os.environ, it will permanently bind to fallbacks.
37
 
 
 
38
  from backend.core.environment import NexusEnvironment
39
  from backend.api.schemas.action import NexusAction, ToolCall
40
 
41
- def parse_tool_calls(text: str) -> list:
42
- tool_calls = []
43
- for match in re.finditer(r"TOOL:\s*([a-zA-Z0-9_]+)\(([^)]*)\)", text):
44
- name = match.group(1)
45
- args_s = match.group(2)
46
- params = {}
47
- for kv in re.finditer(r"(\w+)=['\"]?([^,'\"]+)['\"]?", args_s):
48
- params[kv.group(1)] = kv.group(2)
49
- tool_calls.append(ToolCall(tool_name=name, params=params))
50
- return tool_calls
51
 
52
  TASKS = [
53
- {"name": "software-incident", "difficulty": "easy"},
54
- {"name": "business-process-failure", "difficulty": "medium"},
55
- {"name": "cascade-system-failure", "difficulty": "hard"},
56
  ]
57
 
58
  SYSTEM_PROMPT = (
59
  "You are an expert incident investigator. "
60
- "Format tool calls as: TOOL: tool_name(param='value') "
 
61
  "Available tools: read_logs, check_config, query_database, check_service_status, "
62
- "propose_fix, verify_fix"
 
 
 
 
 
 
63
  )
64
 
65
- MAX_STEPS = int(os.environ.get("MAX_STEPS", "8"))
66
-
67
  def _print(line: str):
68
  print(line, flush=True)
69
 
70
- async def run():
71
- # Initialize client dynamically at runtime to correctly capture evaluator's patched os.environ
72
- from openai import OpenAI
73
- client = OpenAI(base_url=os.environ["API_BASE_URL"], api_key=os.environ["API_KEY"])
74
-
75
- try:
76
- env = NexusEnvironment()
77
-
78
- for task in TASKS:
79
- _print(f"[START] task={task['name']} env=nexus-incident-investigation model={MODEL_NAME}")
80
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  try:
82
- obs = await env.reset(task=task["name"], seed=42)
83
- except Exception as e:
84
- _print(f"[STEP] step=1 error=\"reset failed: {str(e)[:100]}\"")
85
- continue
86
-
87
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
88
- done = False
89
- step_n = 0
90
- rewards = []
91
-
92
- while not done and step_n < MAX_STEPS:
93
- step_n += 1
94
-
95
- user_content = (
96
- f"Scenario: {obs.scenario_description}\n"
97
- f"Context: {obs.scenario_context}\n"
98
- f"Round {obs.round}. Investigate and call tools."
99
- )
100
- messages.append({"role": "user", "content": user_content})
101
-
102
- action_text = ""
103
- try:
104
- resp = client.chat.completions.create(
105
- model=MODEL_NAME,
106
- messages=messages,
107
- max_tokens=300,
108
- temperature=0.7,
109
- timeout=120.0
110
- )
111
- action_text = resp.choices[0].message.content or ""
112
- except Exception as e:
113
- _print(f"[STEP] step={step_n} error=\"{str(e)[:100]}\"")
114
- break
115
-
116
- messages.append({"role": "assistant", "content": action_text})
117
-
118
- tool_calls = parse_tool_calls(action_text)
119
- action = NexusAction(
120
- agent_id="agent_a",
121
- message=action_text,
122
- tool_calls=tool_calls,
123
- confidence=0.8
124
- )
125
-
126
- try:
127
- obs, reward, done, info = await env.step(action)
128
- except Exception as e:
129
- _print(f"[STEP] step={step_n} error=\"step failed: {str(e)[:100]}\"")
130
- break
131
-
132
- rewards.append(reward)
133
-
134
- clean = action_text.replace("\n", " ")[:200]
135
- _print(
136
- f'[STEP] step={step_n} action="{clean}" '
137
- f'reward={reward:.2f} done={str(done).lower()} error=null'
138
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- final_score = info.get("final_score", rewards[-1] if rewards else 0.0) if 'info' in dir() else 0.0
141
- success = final_score >= 0.5
142
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
 
 
 
 
 
 
 
 
 
 
 
143
  _print(
144
- f"[END] success={str(success).lower()} steps={step_n} "
145
- f"score={final_score:.3f} rewards={rewards_str}"
146
  )
147
- except Exception as e:
148
- _print(f"[ERROR] {str(e)}")
149
- sys.exit(1)
 
 
 
 
 
 
 
 
 
 
150
 
151
  if __name__ == "__main__":
152
- asyncio.run(run())
 
 
 
 
 
 
5
 
6
  import os
7
  import sys
 
8
  import re
9
  from pathlib import Path
10
 
 
12
  sys.path.insert(0, str(ROOT))
13
  sys.path.insert(0, str(ROOT / "backend"))
14
 
15
+ # ── Environment Variables (spec-required) ──────────────────────────────────────
 
 
 
 
 
 
 
 
 
16
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
17
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
18
+ HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
+ if HF_TOKEN is None:
21
+ raise ValueError("HF_TOKEN environment variable is required")
 
 
22
 
23
+ API_KEY = os.getenv("API_KEY", HF_TOKEN)
 
24
 
25
+ # Import AFTER path setup
26
+ from openai import OpenAI # sync client β€” matches spec example exactly
27
  from backend.core.environment import NexusEnvironment
28
  from backend.api.schemas.action import NexusAction, ToolCall
29
 
30
+ # ── Config ─────────────────────────────────────────────────────────────────────
31
+ MAX_STEPS = int(os.environ.get("MAX_STEPS", "8"))
 
 
 
 
 
 
 
 
32
 
33
  TASKS = [
34
+ {"name": "software-incident", "difficulty": "easy"},
35
+ {"name": "business-process-failure", "difficulty": "medium"},
36
+ {"name": "cascade-system-failure", "difficulty": "hard"},
37
  ]
38
 
39
  SYSTEM_PROMPT = (
40
  "You are an expert incident investigator. "
41
+ "Your goal is to identify the root cause of system incidents and apply the correct fix. "
42
+ "You have access to these tools β€” call them by writing: TOOL: tool_name(param='value')\n"
43
  "Available tools: read_logs, check_config, query_database, check_service_status, "
44
+ "update_config, restart_service, propose_fix, verify_fix, submit_resolution\n\n"
45
+ "Strategy:\n"
46
+ "1. Use read_logs and check_service_status to gather evidence.\n"
47
+ "2. Use update_config or restart_service to apply your fix.\n"
48
+ "3. Use verify_fix to confirm the fix worked.\n"
49
+ "4. Call submit_resolution with root_cause_service, root_cause_description, and fix_applied.\n"
50
+ "After each tool result, update your hypothesis. The system state shown to you reflects real changes."
51
  )
52
 
53
+ # ── Helpers ────────────────────────────────────────────────────────────────────
 
54
  def _print(line: str):
55
  print(line, flush=True)
56
 
57
+ def _safe_action(text: str) -> str:
58
+ """Strip newlines and truncate for the [STEP] action field β€” NO quotes."""
59
+ return text.replace("\n", " ").replace("\r", "").strip()[:300]
60
+
61
+ def _safe_error(error: str) -> str:
62
+ """Format error for [STEP] β€” raw string, no quotes, or null."""
63
+ if not error:
64
+ return "null"
65
+ return error.replace("\n", " ").strip()[:200]
66
+
67
+ def parse_tool_calls(text: str) -> list[ToolCall]:
68
+ tool_calls = []
69
+ for match in re.finditer(r"TOOL:\s*([a-zA-Z0-9_]+)\(([^)]*)\)", text):
70
+ name = match.group(1)
71
+ args_s = match.group(2)
72
+ params = {}
73
+ for kv in re.finditer(r"(\w+)=['\"]?([^,'\"]+)['\"]?", args_s):
74
+ params[kv.group(1)] = kv.group(2).strip()
75
+ tool_calls.append(ToolCall(tool_name=name, params=params))
76
+ return tool_calls
77
+
78
+ def build_user_content(obs) -> str:
79
+ """Build the user message from the current observation, including system state."""
80
+ parts = [
81
+ f"Scenario: {obs.scenario_description}",
82
+ f"Context: {obs.scenario_context}",
83
+ f"Round: {obs.round}",
84
+ ]
85
+
86
+ # Show the agent what the system state currently looks like
87
+ if hasattr(obs, "system_state") and obs.system_state:
88
+ parts.append(f"Current system state: {obs.system_state}")
89
+
90
+ # Show tool results from last step
91
+ if hasattr(obs, "tool_results") and obs.tool_results:
92
+ results_str = "; ".join(
93
+ f"{tr.tool_name}: {tr.result}" for tr in obs.tool_results
94
+ )
95
+ parts.append(f"Tool results: {results_str}")
96
+
97
+ # Show clues found so far
98
+ if hasattr(obs, "clues_found") and obs.clues_found:
99
+ parts.append(f"Clues found: {', '.join(obs.clues_found[-5:])}")
100
+
101
+ parts.append("Investigate and call tools to find and fix the root cause.")
102
+ return "\n".join(parts)
103
+
104
+ # ── Main Inference Loop ────────────────────────────────────────────────────────
105
+ def run():
106
+ import asyncio
107
+
108
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
109
+ env = NexusEnvironment()
110
+
111
+ for task in TASKS:
112
+ _print(f"[START] task={task['name']} env=nexus-incident-investigation model={MODEL_NAME}")
113
+
114
+ # Reset environment
115
+ try:
116
+ obs = asyncio.run(env.reset(task=task["name"], seed=42))
117
+ except Exception as e:
118
+ err = _safe_error(f"reset failed: {str(e)}")
119
+ _print(f"[STEP] step=1 action=reset_attempted reward=0.00 done=true error={err}")
120
+ _print("[END] success=false steps=1 rewards=0.00")
121
+ continue
122
+
123
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
124
+ done = False
125
+ step_n = 0
126
+ rewards = []
127
+ last_error = "null"
128
+
129
+ while not done and step_n < MAX_STEPS:
130
+ step_n += 1
131
+
132
+ # Build user message from observation (including system state feedback)
133
+ user_content = build_user_content(obs)
134
+ messages.append({"role": "user", "content": user_content})
135
+
136
+ # Call LLM
137
+ action_text = ""
138
+ last_error = "null"
139
  try:
140
+ resp = client.chat.completions.create(
141
+ model=MODEL_NAME,
142
+ messages=messages,
143
+ max_tokens=400,
144
+ temperature=0.5,
145
+ timeout=120.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
+ action_text = resp.choices[0].message.content or ""
148
+ except Exception as e:
149
+ last_error = _safe_error(str(e))
150
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.00"
151
+ _print(f"[STEP] step={step_n} action=llm_call_failed reward=0.00 done=true error={last_error}")
152
+ _print(f"[END] success=false steps={step_n} rewards={rewards_str}")
153
+ break
154
+
155
+ messages.append({"role": "assistant", "content": action_text})
156
+
157
+ # Parse tool calls from LLM response
158
+ tool_calls = parse_tool_calls(action_text)
159
+ action = NexusAction(
160
+ agent_id="agent_a",
161
+ message=action_text,
162
+ tool_calls=tool_calls,
163
+ confidence=0.8
164
+ )
165
 
166
+ # Step the environment
167
+ try:
168
+ obs, reward, done, info = asyncio.run(env.step(action))
169
+ except Exception as e:
170
+ last_error = _safe_error(str(e))
171
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.00"
172
+ _print(f"[STEP] step={step_n} action={_safe_action(action_text)} reward=0.00 done=true error={last_error}")
173
+ _print(f"[END] success=false steps={step_n} rewards={rewards_str}")
174
+ break
175
+
176
+ rewards.append(reward)
177
+
178
+ # Emit [STEP] β€” NO quotes around action or error values
179
+ action_str = _safe_action(action_text)
180
  _print(
181
+ f"[STEP] step={step_n} action={action_str} "
182
+ f"reward={reward:.2f} done={str(done).lower()} error={last_error}"
183
  )
184
+ else:
185
+ # Normal loop completion β€” emit [END]
186
+ final_score = info.get("final_score", rewards[-1] if rewards else 0.0) if rewards else 0.0
187
+ success = final_score >= 0.5
188
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.00"
189
+ _print(f"[END] success={str(success).lower()} steps={step_n} rewards={rewards_str}")
190
+
191
+ # Always close
192
+ try:
193
+ asyncio.run(env.close())
194
+ except Exception:
195
+ pass
196
+
197
 
198
  if __name__ == "__main__":
199
+ try:
200
+ run()
201
+ except Exception as e:
202
+ # Even on fatal error, emit a valid [END] if possible
203
+ print(f"[END] success=false steps=0 rewards=0.00", flush=True)
204
+ sys.exit(1)