import json
import random
from typing import Dict, List, Optional, Tuple

from scenarios.scenario_loader import scenario_loader
from core.state_manager import EpisodeState
from core.reward_engine import compute_reward
from core.agent_runner import AgentRunner
from scenarios.graders.easy_grader import EasyGrader
from scenarios.graders.medium_grader import MediumGrader
from scenarios.graders.hard_grader import HardGrader
from api.schemas.action import NexusAction
from api.schemas.observation import NexusObservation, ToolResult
from config import settings

# Tools offered to the agent when settings.EXECUTION_MODE is not "ssh".
SIMULATED_TOOLS = [
    "read_logs",
    "check_config",
    "query_database",
    "check_service_status",
    "run_diagnostic",
    "update_config",
    "restart_service",
    "propose_fix",
    "verify_fix",
    "submit_resolution",
]

# Tools offered to the agent when settings.EXECUTION_MODE == "ssh".
SSH_TOOLS = ["run_terminal_command", "propose_fix", "verify_fix", "submit_resolution"]

# Task name -> grader difficulty. Tasks not listed here default to "easy".
_TASK_DIFFICULTY = {
    "business-process-failure": "medium",
    "cascade-system-failure": "hard",
}

# Substrings that mark a tool result as a clue. They are matched against the
# lower-cased result text.
_CLUE_MARKERS = ("status: degraded", "error", "anomaly", "warning")

# Tools whose results are always recorded as clues, regardless of their content.
_CLUE_TOOLS = frozenset({"propose_fix", "verify_fix"})


def _available_tools() -> List[str]:
    """Return a fresh copy of the tool list for the configured execution mode."""
    tools = SSH_TOOLS if settings.EXECUTION_MODE == "ssh" else SIMULATED_TOOLS
    return list(tools)


def _is_clue(tool_result: dict) -> bool:
    """Return True if a raw tool-result dict should be recorded as a clue.

    The dict is expected to carry ``"tool_name"`` and a string ``"result"``.
    Those are the keys the clue check reads; the full dict is also passed to
    ``ToolResult(**...)`` by the caller.
    """
    if tool_result["tool_name"] in _CLUE_TOOLS:
        return True
    text = tool_result["result"].lower()  # lower-case once, not per marker
    return any(marker in text for marker in _CLUE_MARKERS)


class NexusEnvironment:
    """Incident-investigation environment holding one active episode at a time.

    ``reset`` picks a scenario and creates a fresh ``EpisodeState``. ``step``
    then does four things: it runs the agent's tool calls, records clue-bearing
    results, computes the step reward, and grades the episode once it ends.
    """

    def __init__(self):
        self.runner = AgentRunner()
        self.active_episode: Optional[EpisodeState] = None
        self.active_scenario: Optional[dict] = None
        # One grader per difficulty. step() falls back to "easy" for any other value.
        self.graders = {
            "easy": EasyGrader(),
            "medium": MediumGrader(),
            "hard": HardGrader(),
        }

    async def reset(
        self,
        task: str = "software-incident",
        scenario_id: Optional[str] = None,
        custom_scenario: Optional[dict] = None,
        seed: Optional[int] = None,
        max_steps: Optional[int] = None,
    ) -> NexusObservation:
        """Start a new episode and return its initial observation.

        The scenario is chosen in this priority order:

        1. ``custom_scenario``, if it is non-empty.
        2. ``scenario_id``.
        3. A random scenario matching the task's difficulty.

        Args:
            task: Task name. It determines the default difficulty.
            scenario_id: Scenario to load via ``scenario_loader.get_scenario``.
            custom_scenario: User-supplied scenario dict. It is shallow-copied,
                so the caller's dict is left unchanged. Missing ``id``,
                ``description`` and ``context`` keys are filled with defaults.
                A ``difficulty`` key overrides the task-derived difficulty.
            seed: Seeds the module-level ``random`` RNG before the random pick.
                It only applies in the random-selection path.
            max_steps: Step budget for the episode. Defaults to
                ``settings.MAX_STEPS``.

        Raises:
            ValueError: If no scenarios exist for the difficulty, or if
                ``scenario_id`` does not resolve to a scenario.
        """
        difficulty = _TASK_DIFFICULTY.get(task, "easy")

        if custom_scenario:
            # Copy before filling defaults so the caller's dict is not mutated.
            scenario = dict(custom_scenario)
            scenario.setdefault("id", "custom-1")
            scenario.setdefault("description", "Custom imported scenario.")
            scenario.setdefault("context", "Custom uploaded environment.")
            if "difficulty" in scenario:
                difficulty = scenario["difficulty"].lower()
        elif scenario_id:
            scenario = scenario_loader.get_scenario(scenario_id)
            if scenario is None:
                # Fail with a clear message instead of a TypeError on scenario["id"].
                raise ValueError(f"Scenario not found: {scenario_id}")
        else:
            scenarios = scenario_loader.get_scenarios_by_difficulty(difficulty)
            if not scenarios:
                raise ValueError(f"No scenarios found for difficulty {difficulty}")
            if seed is not None:
                # NOTE(review): this seeds the module-level RNG, which also
                # affects any other code using `random`. Confirm this is intended.
                random.seed(seed)
            scenario = random.choice(scenarios)

        self.active_scenario = scenario
        self.active_episode = EpisodeState(
            scenario_id=scenario["id"],
            task=task,
            difficulty=difficulty,
            max_rounds=max_steps if max_steps is not None else settings.MAX_STEPS,
            scenario_data=scenario,
        )

        return NexusObservation(
            partner_message="",
            tool_results=[],
            system_state={},
            investigation_stage="investigating",
            round=1,
            available_tools=_available_tools(),
            clues_found=[],
            scenario_description=scenario["description"],
            scenario_context=scenario["context"],
        )

    async def step(self, action: NexusAction) -> Tuple[NexusObservation, float, bool, dict]:
        """Advance the active episode by one agent action.

        Returns:
            A tuple ``(observation, reward, done, info)``. ``info`` always
            contains ``"breakdown"`` from ``compute_reward``. On the
            terminating step it also contains ``"final_score"``, taken from the
            grader for the episode's difficulty, and ``"success"``.

        Raises:
            ValueError: If called before ``reset``.
        """
        if not self.active_episode:
            raise ValueError("Environment must be reset before calling step")

        ep = self.active_episode
        sc = self.active_scenario

        # 1. Record the agent's message in the episode.
        ep.add_message(action.agent_id, action.message)

        # 2. Execute the requested tool calls.
        tool_results_data = await self.runner.execute_tool_calls(
            action.tool_calls, sc, ep.current_round, ep
        )

        # Record clue-bearing results and wrap each result for the observation.
        tool_results_objs = []
        for tr in tool_results_data:
            if _is_clue(tr):
                ep.add_clue(tr["result"])
            tool_results_objs.append(ToolResult(**tr))

        # 3. Compute the per-step reward.
        reward, breakdown = compute_reward(
            action.message, action.tool_calls, tool_results_data, ep, sc
        )

        # Terminate once the fix is verified or the step budget is used up.
        if ep.fix_verified or ep.steps_taken >= ep.max_rounds:
            ep.done = True
            if not ep.fix_verified:
                # The budget ran out without a verified fix. Record a synthetic
                # UNRESOLVED submission so consumers (e.g. the UI) always see a
                # resolution entry. It is added before grading runs.
                ep.add_tool_call("submit_resolution", {
                    "root_cause_service": "UNRESOLVED",
                    "root_cause_description": "Investigation terminated: Maximum round limit reached without agent consensus.",
                    "fix_applied": "No fix was submitted.",
                })
            # The final grader score is reported in info, separately from the
            # per-step reward.
            grader = self.graders.get(ep.difficulty, self.graders["easy"])
            final_score = grader.grade(ep, sc)
            info = {
                "breakdown": breakdown,
                "final_score": final_score,
                "success": final_score >= settings.SUCCESS_SCORE_THRESHOLD and ep.fix_verified,
            }
        else:
            info = {"breakdown": breakdown}

        obs = NexusObservation(
            partner_message=action.message,
            tool_results=tool_results_objs,
            system_state={"total_tools_run": len(ep.tool_calls_made)},
            investigation_stage=ep.investigation_stage,
            round=ep.current_round,
            available_tools=_available_tools(),
            clues_found=ep.clues_found,
            scenario_description=sc["description"],
            scenario_context=sc["context"],
        )
        return obs, reward, ep.done, info

    def state(self):
        """Return ``EpisodeState.to_pydantic()`` for the active episode, or None before ``reset``."""
        if not self.active_episode:
            return None
        return self.active_episode.to_pydantic()