# Page-capture header (not code): "Spaces: Sleeping", "File size: 6,277 Bytes", "6a6a0f9".
import json
from typing import Tuple, Dict
from scenarios.scenario_loader import scenario_loader
from core.state_manager import EpisodeState
from core.reward_engine import compute_reward
from core.agent_runner import AgentRunner
from scenarios.graders.easy_grader import EasyGrader
from scenarios.graders.medium_grader import MediumGrader
from scenarios.graders.hard_grader import HardGrader
from api.schemas.action import NexusAction
from api.schemas.observation import NexusObservation, ToolResult
from config import settings
# Tool names advertised in observations when settings.EXECUTION_MODE is
# anything other than "ssh".
SIMULATED_TOOLS = [
    "read_logs",
    "check_config",
    "query_database",
    "check_service_status",
    "run_diagnostic",
    "update_config",
    "restart_service",
    "propose_fix",
    "verify_fix",
    "submit_resolution",
]

# Tool names advertised in observations when settings.EXECUTION_MODE == "ssh".
SSH_TOOLS = [
    "run_terminal_command",
    "propose_fix",
    "verify_fix",
    "submit_resolution",
]
class NexusEnvironment:
    """Single-episode environment tying scenarios, tool execution, rewards and graders together.

    Typical use: ``await reset(...)`` to start an episode, then repeated
    ``await step(action)`` calls until the returned ``done`` flag is true.
    Only one episode is tracked at a time; ``reset`` replaces the previous one.
    """

    # Task name -> difficulty tier. Any task not listed here is treated as "easy".
    _TASK_DIFFICULTY = {
        "business-process-failure": "medium",
        "cascade-system-failure": "hard",
    }

    # Case-insensitive substrings that mark a tool result as a clue.
    _CLUE_MARKERS = ("status: degraded", "error", "anomaly", "warning")

    # Tools whose results are always recorded as clues, whatever their text.
    _CLUE_TOOLS = ("propose_fix", "verify_fix")

    def __init__(self):
        """Create the tool runner and one grader per difficulty; no episode is active yet."""
        self.runner = AgentRunner()
        self.active_episode = None
        self.active_scenario = None
        self.graders = {
            "easy": EasyGrader(),
            "medium": MediumGrader(),
            "hard": HardGrader(),
        }

    @staticmethod
    def _available_tools():
        """Return the tool-name list matching the configured execution mode."""
        return SSH_TOOLS if settings.EXECUTION_MODE == "ssh" else SIMULATED_TOOLS

    @classmethod
    def _is_clue(cls, tool_result) -> bool:
        """Return True if a raw tool-result mapping should be recorded as a clue.

        A result counts when its text contains one of ``_CLUE_MARKERS``
        (case-insensitive) or when it was produced by one of ``_CLUE_TOOLS``.
        """
        lowered = tool_result["result"].lower()  # lower-case once, not once per marker
        if any(marker in lowered for marker in cls._CLUE_MARKERS):
            return True
        return tool_result["tool_name"] in cls._CLUE_TOOLS

    async def reset(
        self,
        task: str = "software-incident",
        scenario_id: str = None,
        custom_scenario: dict = None,
        seed: int = None,
        max_steps: int = None,
    ) -> NexusObservation:
        """Start a new episode and return its initial observation.

        Scenario selection, in priority order:
          1. ``custom_scenario`` (if truthy): used directly, with defaults filled
             in for ``id``, ``description`` and ``context``. A ``difficulty``
             key, if present, overrides the task-derived difficulty.
          2. ``scenario_id`` (if truthy): looked up through ``scenario_loader``.
          3. Otherwise a random scenario of the task-derived difficulty.

        Args:
            task: Task name; mapped to a difficulty tier ("easy" unless listed
                in ``_TASK_DIFFICULTY``).
            scenario_id: Identifier of a specific scenario to load.
            custom_scenario: Caller-supplied scenario mapping. It is not
                modified; defaults are applied to a shallow copy.
            seed: Makes the random scenario choice reproducible. Only used when
                neither ``custom_scenario`` nor ``scenario_id`` is given.
            max_steps: Round limit for the episode; falls back to
                ``settings.MAX_STEPS`` when None.

        Returns:
            The first observation of the episode (round 1, no tool results).

        Raises:
            ValueError: If no scenario exists for the difficulty, or the
                ``scenario_id`` lookup yields None.
        """
        import random

        difficulty = self._TASK_DIFFICULTY.get(task, "easy")

        if custom_scenario:
            # Work on a shallow copy so the defaults below do not leak into the
            # caller's dict.
            scenario = dict(custom_scenario)
            scenario.setdefault("id", "custom-1")
            scenario.setdefault("description", "Custom imported scenario.")
            scenario.setdefault("context", "Custom uploaded environment.")
            if "difficulty" in scenario:
                difficulty = scenario["difficulty"].lower()
        elif scenario_id:
            scenario = scenario_loader.get_scenario(scenario_id)
            # NOTE(review): whether get_scenario returns None or raises for an
            # unknown id is not visible here; guard so a miss fails clearly
            # instead of as a TypeError on scenario["id"] below.
            if scenario is None:
                raise ValueError(f"No scenario found for id {scenario_id}")
        else:
            scenarios = scenario_loader.get_scenarios_by_difficulty(difficulty)
            if not scenarios:
                raise ValueError(f"No scenarios found for difficulty {difficulty}")
            # A private Random instance gives the same pick for a given seed as
            # random.seed(seed) + random.choice(...) without reseeding the
            # process-wide generator.
            rng = random.Random(seed) if seed is not None else random
            scenario = rng.choice(scenarios)

        self.active_scenario = scenario
        self.active_episode = EpisodeState(
            scenario_id=scenario["id"],
            task=task,
            difficulty=difficulty,
            max_rounds=max_steps if max_steps is not None else settings.MAX_STEPS,
            scenario_data=scenario,
        )

        return NexusObservation(
            partner_message="",
            tool_results=[],
            system_state={},
            investigation_stage="investigating",
            round=1,
            available_tools=self._available_tools(),
            clues_found=[],
            scenario_description=scenario["description"],
            scenario_context=scenario["context"],
        )

    def _finish_episode(self, ep, sc, breakdown) -> dict:
        """Mark the episode done, grade it and build the terminal ``info`` dict."""
        ep.done = True

        # If the round budget ran out without a verified fix, record a synthetic
        # resolution so the UI doesn't look broken.
        if not ep.fix_verified:
            ep.add_tool_call("submit_resolution", {
                "root_cause_service": "UNRESOLVED",
                "root_cause_description": "Investigation terminated: Maximum round limit reached without agent consensus.",
                "fix_applied": "No fix was submitted.",
            })

        # Final scoring overrides semantic cumulative reward in openenv
        # inference if grader is used; it is computed here for info.
        # Unknown difficulties (e.g. from a custom scenario) use the easy grader.
        grader = self.graders.get(ep.difficulty, self.graders["easy"])
        final_score = grader.grade(ep, sc)
        return {
            "breakdown": breakdown,
            "final_score": final_score,
            "success": final_score >= settings.SUCCESS_SCORE_THRESHOLD and ep.fix_verified,
        }

    async def step(self, action: NexusAction) -> Tuple[NexusObservation, float, bool, dict]:
        """Apply one agent action to the active episode.

        Records the agent's message, runs its tool calls, records clue-worthy
        results, computes the step reward and, when the episode ends, grades it.

        Args:
            action: The agent's message and tool calls for this step.

        Returns:
            ``(observation, reward, done, info)``. ``info`` always has
            ``"breakdown"``; on the terminal step it also has ``"final_score"``
            and ``"success"``.

        Raises:
            ValueError: If ``reset`` has not been called yet.
        """
        if not self.active_episode:
            raise ValueError("Environment must be reset before calling step")

        ep = self.active_episode
        sc = self.active_scenario

        # 1. Add agent message to state
        ep.add_message(action.agent_id, action.message)

        # 2. Execute tools
        tool_results_data = await self.runner.execute_tool_calls(
            action.tool_calls, sc, ep.current_round, ep
        )

        # Record clues and wrap each raw result for the observation.
        tool_results_objs = []
        for tr in tool_results_data:
            if self._is_clue(tr):
                ep.add_clue(tr["result"])
            tool_results_objs.append(ToolResult(**tr))

        # 3. Compute semantic reward dynamically
        reward, breakdown = compute_reward(
            action.message, action.tool_calls, tool_results_data, ep, sc
        )

        # The episode ends once a fix is verified or the round budget is used up.
        if ep.fix_verified or ep.steps_taken >= ep.max_rounds:
            info = self._finish_episode(ep, sc, breakdown)
        else:
            info = {"breakdown": breakdown}

        obs = NexusObservation(
            partner_message=action.message,
            tool_results=tool_results_objs,
            system_state={"total_tools_run": len(ep.tool_calls_made)},
            investigation_stage=ep.investigation_stage,
            round=ep.current_round,
            available_tools=self._available_tools(),
            clues_found=ep.clues_found,
            scenario_description=sc["description"],
            scenario_context=sc["context"],
        )
        return obs, reward, ep.done, info

    def state(self):
        """Return the active episode via ``to_pydantic()``, or None if none is active."""
        if not self.active_episode:
            return None
        return self.active_episode.to_pydantic()
# (stray "|" delimiter from the page capture removed)