File size: 4,428 Bytes
5db5e49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations

import copy
import uuid

from env import DevOpsIncidentEnv
from models import Action


class DualAgentSession:
    """Two-agent wrapper around a single ``DevOpsIncidentEnv`` episode.

    Agent A (the "observer") is shown alerts, logs and the evidence log and
    can only record free-text findings via :meth:`step_a`.  Agent B (the
    "responder") is shown service state plus A's findings, and is the only
    agent whose actions are forwarded to the underlying environment via
    :meth:`step_b`.
    """

    def __init__(self, task_id: str, seed: int = 42):
        """Create the environment, reset it and initialise session state.

        Args:
            task_id: Task identifier passed to ``DevOpsIncidentEnv``.
            seed: Seed passed to both the environment constructor and its
                ``reset`` call.
        """
        self.session_id = str(uuid.uuid4())
        self.task_id = task_id
        self.seed = seed
        self.env = DevOpsIncidentEnv(task_id=task_id, seed=seed)
        # Latest full (unfiltered) observation; replaced after every step_b.
        self.full_obs = self.env.reset(seed=seed)
        # Findings recorded by agent A, in the order they were shared.
        self.findings_log: list[dict] = []
        # Incremented by step_b only; step_a tags findings with its value.
        self.step_count = 0
        self.done = False

    def _observation_dict(self) -> dict:
        """Return the current full observation as a plain, detached dict.

        NOTE(review): the ``model_dump`` / ``dict`` probing presumably
        targets pydantic v2 / v1 models -- confirm against ``env``.  Any
        other object is deep-copied and returned as-is, so it is expected
        to already behave like a dict.
        """
        if hasattr(self.full_obs, "model_dump"):
            return self.full_obs.model_dump()
        if hasattr(self.full_obs, "dict"):
            return self.full_obs.dict()
        return copy.deepcopy(self.full_obs)

    def get_observation_a(self) -> dict:
        """Build the observer (agent A) view of the current observation.

        The view carries alerts, recent logs and the evidence log but no
        service data.  ``findings_from_b`` is always an empty list.

        Raises:
            KeyError: If the observation lacks one of the mandatory keys
                (``step``, ``max_steps``, ``task_id``, ``task_description``,
                ``elapsed_minutes``, ``incident_start_time``).
        """
        obs = self._observation_dict()
        return {
            "step": obs["step"],
            "max_steps": obs["max_steps"],
            "task_id": obs["task_id"],
            "task_description": obs["task_description"],
            "active_alerts": copy.deepcopy(obs.get("active_alerts", [])),
            "recent_logs": copy.deepcopy(obs.get("recent_logs", {})),
            "evidence_log": copy.deepcopy(obs.get("evidence_log", [])),
            "last_action_result": obs.get("last_action_result"),
            "last_action_error": obs.get("last_action_error"),
            "elapsed_minutes": obs["elapsed_minutes"],
            "incident_start_time": obs["incident_start_time"],
            "role": "observer",
            "instructions": (
                "You are the Observer. You can ONLY call share_finding. "
                "Read logs and alerts carefully, then share findings with "
                "the Responder agent."
            ),
            "findings_from_b": [],
        }

    def get_observation_b(self) -> dict:
        """Build the responder (agent B) view of the current observation.

        The view carries services, service dependencies and SLA status plus
        a deep copy of every finding agent A has recorded; it does not
        include alerts, logs or the evidence log.

        Raises:
            KeyError: If the observation lacks one of the mandatory keys
                (``step``, ``max_steps``, ``task_id``, ``task_description``,
                ``elapsed_minutes``, ``incident_start_time``).
        """
        obs = self._observation_dict()
        return {
            "step": obs["step"],
            "max_steps": obs["max_steps"],
            "task_id": obs["task_id"],
            "task_description": obs["task_description"],
            "services": copy.deepcopy(obs.get("services", [])),
            "service_dependencies": copy.deepcopy(obs.get("service_dependencies", [])),
            "sla_status": copy.deepcopy(obs.get("sla_status", {})),
            "last_action_result": obs.get("last_action_result"),
            "last_action_error": obs.get("last_action_error"),
            "elapsed_minutes": obs["elapsed_minutes"],
            "incident_start_time": obs["incident_start_time"],
            "role": "responder",
            "instructions": (
                "You are the Responder. Use Agent A findings plus service "
                "metrics to diagnose and fix the incident."
            ),
            "agent_a_findings": copy.deepcopy(self.findings_log),
        }

    def step_a(self, finding_text: str) -> dict:
        """Record a finding from agent A.

        The environment is not stepped and ``step_count`` is not changed.

        Args:
            finding_text: Free-text finding; surrounding whitespace is
                stripped before it is stored.

        Returns:
            ``{"error": "episode complete"}`` once the episode is done;
            ``{"error": "finding too short", "reward": 0.0}`` when the
            text is empty/``None`` or shorter than 5 characters after
            stripping; otherwise an acceptance dict with a fixed reward of
            0.05, a copy of the recorded entry, the running findings count
            and a fresh observer view.
        """
        if self.done:
            return {"error": "episode complete"}
        text = finding_text.strip() if finding_text else ""
        if len(text) < 5:
            return {"error": "finding too short", "reward": 0.0}
        entry = {
            "agent": "A",
            "step": self.step_count,
            "finding": text,
        }
        self.findings_log.append(entry)
        return {
            "accepted": True,
            "reward": 0.05,
            # Hand out a copy: every other accessor copies the log, and
            # returning the stored dict itself would let callers mutate it.
            "finding_recorded": dict(entry),
            "total_findings": len(self.findings_log),
            "observation": self.get_observation_a(),
        }

    def step_b(self, action: Action) -> dict:
        """Forward agent B's action to the environment.

        Args:
            action: Action passed unchanged to ``self.env.step``.

        Returns:
            ``{"error": "episode complete"}`` once the episode is done;
            otherwise a dict with the new responder view, the reward,
            ``done`` and ``info`` reported by the environment, and the
            number of findings agent A has recorded so far.
        """
        if self.done:
            return {"error": "episode complete"}
        self.step_count += 1
        result = self.env.step(action)
        self.full_obs = result.observation
        if result.done:
            self.done = True
        return {
            "observation": self.get_observation_b(),
            "reward": result.reward,
            "done": result.done,
            "info": result.info,
            "agent_a_findings_count": len(self.findings_log),
        }

    def get_state(self) -> dict:
        """Return a snapshot of the session: ids, counters and both views.

        The findings log is deep-copied, so mutating the snapshot does not
        affect the session.
        """
        return {
            "session_id": self.session_id,
            "task_id": self.task_id,
            "seed": self.seed,
            "step": self.step_count,
            "done": self.done,
            "findings_log": copy.deepcopy(self.findings_log),
            "observation_a": self.get_observation_a(),
            "observation_b": self.get_observation_b(),
        }