File size: 8,435 Bytes
cb0d29a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import re
import asyncio
from typing import List
from api.schemas.action import ToolCall
from models.model_manager import model_manager
from tools.tool_registry import registry
from utils.logger import logger
from config import settings

# Built-in persona prompts, keyed by the agent's configured role name.
# Each value is the behavioural part of the system prompt; the strings are
# wrapped one sentence per source line via implicit literal concatenation.
ROLE_DEFINITIONS = {
    "INVESTIGATOR": (
        "You are an expert incident investigator with deep systems knowledge. "
        "Your job: form specific hypotheses, test them with tools, eliminate dead ends, and find the root cause. "
        "Be direct. "
        "Be technical. "
        "Never be vague."
    ),
    "VALIDATOR": (
        "You are an expert systems validator and devil's advocate. "
        "Your job: challenge every hypothesis with evidence, find edge cases, and verify fixes. "
        "Do NOT simply agree. "
        "If your partner is wrong, prove it with tools. "
        "If they found the root cause, verify it thoroughly before accepting."
    ),
    "FORENSIC_ANALYST": (
        "You are an elite digital forensic analyst. "
        "Focus heavily on reading logs, inspecting file timestamps, memory dumps, and tracing bash histories. "
        "Do not guess—look for precise digital fingerprints and anomalies deep in the system logs."
    ),
    "NETWORK_ENGINEER": (
        "You are a senior network engineer. "
        "Focus exclusively on port communication, routing tables, ping, active TCP/UDP connections, and firewall configurations. "
        "Your instinct is to determine how the threat or failure is propagating across the internal mesh."
    ),
    "SYSTEM_ADMIN": (
        "You are a grizzled system administrator. "
        "Focus on core OS services, user permissions, kernel messages (dmesg), cron jobs, and runaway processes. "
        "Fix operational misconfigurations logically and decisively."
    ),
    "SECURITY_ARCHITECT": (
        "You are a rigorous security architect. "
        "You look for structural flaws: overly permissive firewall rules, unencrypted traffic flows, and exposed secrets. "
        "Treat every incident as a potential systemic breach."
    ),
    "COMPLIANCE_OFFICER": (
        "You are a strict compliance and audit officer. "
        "You focus strictly on policy violations and unauthorized data access. "
        "Identify unapproved tools and non-compliant execution paths. "
        "Prioritize strict adherence over operational speed."
    ),
}

# Tool-usage rules appended to the system prompt by AgentRunner.run_step when
# settings.EXECUTION_MODE is anything other than "ssh". It teaches the model
# the `TOOL: tool_name(param="value")` syntax that
# AgentRunner.parse_tool_calls scans for.
# NOTE: this literal is runtime prompt text; its blank lines and trailing
# whitespace are part of the string and are sent to the model as-is.
TOOL_INSTRUCTIONS_SIMULATED = """

You have access to simulation tools. When calling a tool write exactly: TOOL: tool_name(param="value")

You can call multiple tools per message. You must use tools like update_config and restart_service to fix the system. 

Once the fix is verified entirely, call TOOL: submit_resolution(root_cause_service="...", root_cause_description="...", fix_applied="...") to end the investigation.

"""

# Tool-usage rules appended to the system prompt by AgentRunner.run_step when
# settings.EXECUTION_MODE == "ssh". It directs the model to gather evidence
# through the run_terminal_command tool, using the same
# `TOOL: name(key="value")` syntax that AgentRunner.parse_tool_calls scans for.
# NOTE: this literal is runtime prompt text; its blank lines and the
# indentation of the second example are part of the string sent to the model.
TOOL_INSTRUCTIONS_SSH = """

You are operating on a LIVE remote Linux server. You have a real bash terminal via the run_terminal_command tool. USE IT AGGRESSIVELY.

You have root access. Do NOT theorize without evidence — run actual commands to get facts.

When calling a tool write exactly: TOOL: run_terminal_command(command="your bash command here")

Examples: TOOL: run_terminal_command(command="journalctl -n 50 --no-pager")

          TOOL: run_terminal_command(command="systemctl status nginx")

You can also call propose_fix. Once the fix is verified, call TOOL: submit_resolution(root_cause_service="...", root_cause_description="...", fix_applied="...") to end the investigation.

"""

class AgentRunner:
    """Run one agent turn: build its prompt, stream the model reply, and
    parse and execute the ``TOOL: name(key="value")`` calls found in a reply.
    """

    # Start of a tool invocation, up to and including the opening parenthesis.
    _TOOL_HEADER = re.compile(r'TOOL:\s*([a-zA-Z0-9_]+)\(')
    # Fallback used when the quote-aware scan finds no closing parenthesis:
    # everything up to the first ")" on the same line.
    _LEGACY_ARGS = re.compile(r'(.*?)\)')

    def parse_tool_calls(self, message: str) -> List[ToolCall]:
        """Extract every ``TOOL: tool_name(key="value", ...)`` call from *message*.

        Argument text is scanned with awareness of quotes and nested
        parentheses, so a quoted value may itself contain ``)``, ``,`` or the
        other kind of quote (e.g. ``command="echo $(date), 'done'"``), and a
        call may span several lines. All values are returned as strings with
        one pair of enclosing quotes removed. Argument fragments without an
        ``=`` are ignored.

        If the quote-aware scan cannot find the end of a call (for example
        because of an unbalanced quote), parsing falls back to a simpler rule:
        arguments end at the first ``)`` on the same line and are split on
        every comma.

        Args:
            message: Raw model output to scan.

        Returns:
            One ``ToolCall`` per invocation found, in order of appearance.
        """
        tool_calls = []
        pos = 0
        while True:
            header = self._TOOL_HEADER.search(message, pos)
            if header is None:
                break
            args_start = header.end()

            scanned = self._scan_call_args(message, args_start)
            if scanned is not None:
                pieces, pos = scanned
                params = self._pairs_to_params(pieces, self._unquote)
            else:
                legacy = self._LEGACY_ARGS.match(message, args_start)
                if legacy is None:
                    # No closing parenthesis at all: not a usable call.
                    pos = args_start
                    continue
                pos = legacy.end()
                params = self._pairs_to_params(
                    legacy.group(1).split(','), self._legacy_unquote
                )

            tool_calls.append(ToolCall(tool_name=header.group(1), params=params))

        return tool_calls

    @staticmethod
    def _scan_call_args(text: str, start: int):
        """Split the argument list that begins at ``text[start]`` (just after "(").

        Returns ``(pieces, end)`` where *pieces* are the raw top-level,
        comma-separated fragments and *end* is the index just past the closing
        ``)``; returns ``None`` when no matching ``)`` exists.
        """
        pieces = []
        current = []
        quote = ""    # the active quote character, or "" when outside quotes
        depth = 0     # nesting depth of unquoted parentheses inside the args
        prev = "("    # last non-space character seen outside escape pairs
        i = start
        size = len(text)
        while i < size:
            ch = text[i]
            if quote:
                if ch == "\\" and i + 1 < size:
                    # Keep escape pairs intact so an escaped quote cannot
                    # close the string.
                    current.append(text[i:i + 2])
                    i += 2
                    continue
                if ch == quote:
                    quote = ""
            elif ch in "\"'" and prev in "=,(":
                # Only a quote at the start of a value opens a string; an
                # apostrophe inside a bare word (it's) does not.
                quote = ch
            elif ch == "(":
                depth += 1
            elif ch == ")":
                if depth == 0:
                    pieces.append("".join(current))
                    return pieces, i + 1
                depth -= 1
            elif ch == "," and depth == 0:
                pieces.append("".join(current))
                current = []
                prev = ","
                i += 1
                continue
            current.append(ch)
            if not ch.isspace():
                prev = ch
            i += 1
        return None

    @staticmethod
    def _pairs_to_params(pieces, unquote) -> dict:
        """Turn raw ``key=value`` fragments into a dict, cleaning values with *unquote*."""
        params = {}
        for piece in pieces:
            key, sep, raw_value = piece.partition('=')
            if sep:
                params[key.strip()] = unquote(raw_value)
        return params

    @staticmethod
    def _unquote(raw: str) -> str:
        """Strip whitespace and one pair of matching enclosing quotes from *raw*.

        Backslash-escaped occurrences of the enclosing quote are unescaped;
        all other backslashes are left untouched.
        """
        value = raw.strip()
        if len(value) >= 2 and value[0] in "\"'" and value[-1] == value[0]:
            quote = value[0]
            return value[1:-1].replace("\\" + quote, quote)
        return value

    @staticmethod
    def _legacy_unquote(raw: str) -> str:
        """Fallback value cleanup: strip whitespace, then any leading/trailing quotes."""
        return raw.strip().strip('"').strip("'")

    async def execute_tool_calls(self, tool_calls: List[ToolCall], scenario: dict, round_num: int, episode_state) -> List[dict]:
        """Execute *tool_calls* sequentially through the in-process tool registry.

        Each call is also recorded on *episode_state* via ``add_tool_call``.

        Args:
            tool_calls: Calls to run, typically from ``parse_tool_calls``.
            scenario: Scenario dict forwarded to the registry.
            round_num: Current round number forwarded to the registry.
            episode_state: Episode state forwarded to the registry and used to
                record each call.

        Returns:
            One dict per call with keys ``tool_name``, ``result`` (the text
            returned by the registry) and ``success`` (``False`` when that
            text starts with ``"Error"``).
        """
        results = []
        for tc in tool_calls:
            # Tools are dispatched through the in-process registry;
            # episode_state is forwarded because some tools need it.
            res_str = registry.call_tool(tc.tool_name, tc.params, scenario, round_num, episode_state)

            # Record the call in the episode state.
            episode_state.add_tool_call(tc.tool_name, tc.params)

            results.append({
                "tool_name": tc.tool_name,
                "result": res_str,
                "success": not res_str.startswith("Error")
            })
        return results

    @staticmethod
    def _clip(text: str, limit: int) -> str:
        """Return *text* cut to *limit* characters, adding "..." only when it was cut."""
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def _build_system_prompt(agent_id: str) -> str:
        """Compose the system prompt: persona text followed by the tool rules."""
        is_ssh = settings.EXECUTION_MODE == "ssh"
        tool_rules = TOOL_INSTRUCTIONS_SSH if is_ssh else TOOL_INSTRUCTIONS_SIMULATED

        agent_config = next((a for a in settings.AGENTS if a["id"] == agent_id), {})
        role = agent_config.get("role", "INVESTIGATOR")
        custom_prompt = agent_config.get("system_prompt", "")

        if role.startswith("CUSTOM_") and custom_prompt:
            behavior = custom_prompt
        else:
            # Unknown roles fall back to the investigator persona.
            behavior = ROLE_DEFINITIONS.get(role, ROLE_DEFINITIONS["INVESTIGATOR"])
        return behavior + "\n\n" + tool_rules

    def _build_context(self, agent_id: str, episode_state, scenario: dict) -> str:
        """Compose the user message: incident, peers' latest messages, clues, recent history."""
        parts = [f"Current incident: {scenario.get('description', '')}\n"]

        peers = [a for a in settings.AGENTS if a["id"] != agent_id]
        if peers:
            parts.append(
                f"Other agents in this investigation: {', '.join(a['id'] for a in peers)}\n"
            )

        for peer in peers:
            peer_id = peer["id"]
            peer_msgs = episode_state.messages_by_agent.get(peer_id, [])
            if peer_msgs:
                peer_role = peer.get("role", "AGENT")
                parts.append(
                    f"\n[{peer_role}] {peer_id}'s latest insight: {self._clip(peer_msgs[-1], 300)}\n"
                )

        clues = getattr(episode_state, 'clues_found', None)
        if clues:
            parts.append("\nClues discovered so far:\n")
            parts.extend(f"- {clue[:200]}\n" for clue in clues[-5:])

        recent_msgs = episode_state.all_messages[-4:]
        if recent_msgs:
            parts.append("\nRecent conversation history:\n")
            parts.extend(f"- {self._clip(m, 150)}\n" for m in recent_msgs)

        return "".join(parts)

    async def run_step(self, agent_id: str, episode_state, scenario: dict, max_retries: int = 2):
        """Stream one model reply for *agent_id* as an async generator of text chunks.

        The request is retried with exponential backoff (1s, 2s, ...) only
        while nothing has been yielded yet. Once part of a reply has been
        streamed, a failure ends the turn with an interruption notice instead
        of retrying, because a retry would re-emit the reply from the start
        and duplicate text for the consumer.

        Args:
            agent_id: Id of the agent; selects the client/model and the
                agent's entry in ``settings.AGENTS``.
            episode_state: Source of peer messages, clues and recent history.
            scenario: Scenario dict; its ``description`` opens the context.
            max_retries: Number of retries after the first attempt.

        Yields:
            Non-empty content fragments of the reply. If every attempt fails
            before any content arrives, a single error message is yielded
            instead of raising.
        """
        client, model_name = model_manager.get_client(agent_id)

        messages = [
            {"role": "system", "content": self._build_system_prompt(agent_id)},
            {"role": "user", "content": self._build_context(agent_id, episode_state, scenario)},
        ]

        last_error = None
        streamed_any = False

        for attempt in range(max_retries + 1):
            try:
                stream = await client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    max_tokens=2048,
                    timeout=120.0,
                    stream=True
                )
                async for chunk in stream:
                    # Some OpenAI-compatible endpoints send chunks with an
                    # empty `choices` list (e.g. usage-only chunks); skip them
                    # instead of failing on choices[0].
                    if not chunk.choices:
                        continue
                    content = chunk.choices[0].delta.content or ""
                    if content:
                        streamed_any = True
                        yield content
                return
            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed for {model_name}: {e}")
                if streamed_any:
                    # Part of the reply is already with the consumer; do not retry.
                    break
                if attempt < max_retries:
                    await asyncio.sleep(2 ** attempt)

        if streamed_any:
            logger.error(f"Stream from {model_name} failed mid-response; not retrying: {last_error}")
            yield f"\n[Response interrupted: {last_error}]"
            return

        logger.error(f"All retries exhausted for {model_name}: {last_error}")
        yield f"I encountered an error: {last_error}. Please verify the model endpoint is accessible."