File size: 6,277 Bytes
6a6a0f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
from typing import Tuple, Dict

from scenarios.scenario_loader import scenario_loader
from core.state_manager import EpisodeState
from core.reward_engine import compute_reward
from core.agent_runner import AgentRunner
from scenarios.graders.easy_grader import EasyGrader
from scenarios.graders.medium_grader import MediumGrader
from scenarios.graders.hard_grader import HardGrader
from api.schemas.action import NexusAction
from api.schemas.observation import NexusObservation, ToolResult
from config import settings

# Tools exposed to the agent when running against the simulated scenario backend.
SIMULATED_TOOLS = [
    "read_logs",
    "check_config",
    "query_database",
    "check_service_status",
    "run_diagnostic",
    "update_config",
    "restart_service",
    "propose_fix",
    "verify_fix",
    "submit_resolution",
]
# Tools exposed to the agent when settings.EXECUTION_MODE == "ssh": a raw terminal
# plus the same fix/verify/submit workflow tools as the simulated mode.
SSH_TOOLS = [
    "run_terminal_command",
    "propose_fix",
    "verify_fix",
    "submit_resolution",
]

class NexusEnvironment:
    """Drive a single incident-investigation episode: reset, step, and inspect state.

    At most one episode is active at a time. ``reset`` selects a scenario and
    creates a fresh ``EpisodeState``. ``step`` executes the agent's tool calls,
    records clues, computes a reward, and ends the episode once the fix is
    verified or the step budget is exhausted.
    """

    # Task name -> difficulty level; any task not listed here is treated as "easy".
    _TASK_DIFFICULTY = {
        "business-process-failure": "medium",
        "cascade-system-failure": "hard",
    }
    # Lower-cased substrings that mark a tool result as a clue worth recording.
    _CLUE_MARKERS = ("status: degraded", "error", "anomaly", "warning")
    # Tools whose output is always recorded as a clue, regardless of its content.
    _CLUE_TOOLS = frozenset({"propose_fix", "verify_fix"})

    def __init__(self):
        self.runner = AgentRunner()
        self.active_episode = None
        self.active_scenario = None

        # One grader per difficulty level; step() falls back to "easy" for unknown levels.
        self.graders = {
            "easy": EasyGrader(),
            "medium": MediumGrader(),
            "hard": HardGrader(),
        }

    @staticmethod
    def _available_tools():
        """Return the tool list matching the configured execution mode."""
        return SSH_TOOLS if settings.EXECUTION_MODE == "ssh" else SIMULATED_TOOLS

    @classmethod
    def _is_clue(cls, tool_result: dict) -> bool:
        """Return True if a tool result should be recorded as an investigation clue.

        A result counts as a clue when it comes from a fix-related tool, or when its
        text contains one of the anomaly markers (case-insensitive).
        """
        text = tool_result["result"].lower()
        if tool_result["tool_name"] in cls._CLUE_TOOLS:
            return True
        return any(marker in text for marker in cls._CLUE_MARKERS)

    async def reset(self, task: str = "software-incident", scenario_id: str = None, custom_scenario: dict = None, seed: int = None, max_steps: int = None) -> NexusObservation:
        """Start a new episode and return its initial observation.

        Scenario selection precedence: ``custom_scenario`` (if truthy), then
        ``scenario_id``, then a random scenario matching the task's difficulty.

        Args:
            task: Task name; it determines the difficulty unless a custom scenario
                provides its own ``difficulty``.
            scenario_id: Explicit scenario to load via ``scenario_loader``.
            custom_scenario: Caller-provided scenario dict. Missing ``id``,
                ``description`` and ``context`` keys are filled with defaults on a
                copy; the caller's dict is not modified.
            seed: If given, seeds the ``random`` module before the random pick.
            max_steps: Step budget; defaults to ``settings.MAX_STEPS``.

        Raises:
            ValueError: if ``scenario_id`` does not resolve to a scenario, or no
                scenario exists for the selected difficulty.
        """
        difficulty = self._TASK_DIFFICULTY.get(task, "easy")

        if custom_scenario:
            # Shallow copy so filling in defaults does not mutate the caller's dict.
            scenario = dict(custom_scenario)
            scenario.setdefault("id", "custom-1")
            scenario.setdefault("description", "Custom imported scenario.")
            scenario.setdefault("context", "Custom uploaded environment.")
            if "difficulty" in scenario:
                # str() guards against non-string values from uploaded JSON.
                difficulty = str(scenario["difficulty"]).lower()
        elif scenario_id:
            scenario = scenario_loader.get_scenario(scenario_id)
            if not scenario:
                raise ValueError(f"Scenario not found: {scenario_id}")
        else:
            scenarios = scenario_loader.get_scenarios_by_difficulty(difficulty)
            if not scenarios:
                raise ValueError(f"No scenarios found for difficulty {difficulty}")
            import random
            # NOTE(review): this seeds the global `random` module, which also affects
            # any other code using it; kept as-is in case episode-wide determinism
            # relies on it.
            if seed is not None:
                random.seed(seed)
            scenario = random.choice(scenarios)

        self.active_scenario = scenario
        self.active_episode = EpisodeState(
            scenario_id=scenario["id"],
            task=task,
            difficulty=difficulty,
            max_rounds=max_steps if max_steps is not None else settings.MAX_STEPS,
            scenario_data=scenario,
        )

        return NexusObservation(
            partner_message="",
            tool_results=[],
            system_state={},
            investigation_stage="investigating",
            round=1,
            available_tools=self._available_tools(),
            clues_found=[],
            scenario_description=scenario["description"],
            scenario_context=scenario["context"],
        )

    async def step(self, action: NexusAction) -> Tuple[NexusObservation, float, bool, dict]:
        """Advance the active episode by one agent action.

        Returns:
            ``(observation, reward, done, info)``. ``info`` always has ``"breakdown"``
            (the reward breakdown). When the episode ends it also has
            ``"final_score"`` (from the difficulty's grader) and ``"success"``.

        Raises:
            ValueError: if called before ``reset``.
        """
        if not self.active_episode:
            raise ValueError("Environment must be reset before calling step")

        ep = self.active_episode
        sc = self.active_scenario

        # 1. Record the agent's message.
        ep.add_message(action.agent_id, action.message)

        # 2. Execute the requested tool calls and harvest clues from their output.
        tool_results_data = await self.runner.execute_tool_calls(action.tool_calls, sc, ep.current_round, ep)
        tool_results_objs = []
        for tr in tool_results_data:
            if self._is_clue(tr):
                ep.add_clue(tr["result"])
            tool_results_objs.append(ToolResult(**tr))

        # 3. Compute the per-step reward.
        reward, breakdown = compute_reward(action.message, action.tool_calls, tool_results_data, ep, sc)

        # Stop once the fix is verified or the step budget is exhausted.
        if ep.fix_verified or ep.steps_taken >= ep.max_rounds:
            ep.done = True

            # Budget exhausted without a verified fix: inject a synthetic report so
            # downstream consumers (e.g. the UI, the grader) still see a resolution.
            if not ep.fix_verified:
                ep.add_tool_call("submit_resolution", {
                    "root_cause_service": "UNRESOLVED",
                    "root_cause_description": "Investigation terminated: Maximum round limit reached without agent consensus.",
                    "fix_applied": "No fix was submitted.",
                })

            # The grader's final score is reported for information; it does not
            # replace the per-step reward returned here.
            grader = self.graders.get(ep.difficulty, self.graders["easy"])
            final_score = grader.grade(ep, sc)

            info = {
                "breakdown": breakdown,
                "final_score": final_score,
                "success": final_score >= settings.SUCCESS_SCORE_THRESHOLD and ep.fix_verified,
            }
        else:
            info = {"breakdown": breakdown}

        obs = NexusObservation(
            partner_message=action.message,
            tool_results=tool_results_objs,
            system_state={"total_tools_run": len(ep.tool_calls_made)},
            investigation_stage=ep.investigation_stage,
            round=ep.current_round,
            available_tools=self._available_tools(),
            clues_found=ep.clues_found,
            scenario_description=sc["description"],
            scenario_context=sc["context"],
        )

        return obs, reward, ep.done, info

    def state(self):
        """Return the active episode via ``EpisodeState.to_pydantic()``, or None before reset."""
        if not self.active_episode:
            return None
        return self.active_episode.to_pydantic()