Arijit-07 commited on
Commit
5db5e49
·
1 Parent(s): 5e1d2a3

fix: add multi_agent package to HF Space

Browse files
Files changed (2) hide show
  1. multi_agent/__init__.py +3 -0
  2. multi_agent/session.py +118 -0
multi_agent/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from multi_agent.session import DualAgentSession
2
+
3
+ __all__ = ["DualAgentSession"]
multi_agent/session.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import uuid
5
+
6
+ from env import DevOpsIncidentEnv
7
+ from models import Action
8
+
9
+
10
+ class DualAgentSession:
11
+ def __init__(self, task_id: str, seed: int = 42):
12
+ self.session_id = str(uuid.uuid4())
13
+ self.task_id = task_id
14
+ self.seed = seed
15
+ self.env = DevOpsIncidentEnv(task_id=task_id, seed=seed)
16
+ self.full_obs = self.env.reset(seed=seed)
17
+ self.findings_log = []
18
+ self.step_count = 0
19
+ self.done = False
20
+
21
+ def _observation_dict(self) -> dict:
22
+ if hasattr(self.full_obs, "model_dump"):
23
+ return self.full_obs.model_dump()
24
+ if hasattr(self.full_obs, "dict"):
25
+ return self.full_obs.dict()
26
+ return copy.deepcopy(self.full_obs)
27
+
28
+ def get_observation_a(self) -> dict:
29
+ obs = self._observation_dict()
30
+ return {
31
+ "step": obs["step"],
32
+ "max_steps": obs["max_steps"],
33
+ "task_id": obs["task_id"],
34
+ "task_description": obs["task_description"],
35
+ "active_alerts": copy.deepcopy(obs.get("active_alerts", [])),
36
+ "recent_logs": copy.deepcopy(obs.get("recent_logs", {})),
37
+ "evidence_log": copy.deepcopy(obs.get("evidence_log", [])),
38
+ "last_action_result": obs.get("last_action_result"),
39
+ "last_action_error": obs.get("last_action_error"),
40
+ "elapsed_minutes": obs["elapsed_minutes"],
41
+ "incident_start_time": obs["incident_start_time"],
42
+ "role": "observer",
43
+ "instructions": (
44
+ "You are the Observer. You can ONLY call share_finding. "
45
+ "Read logs and alerts carefully, then share findings with "
46
+ "the Responder agent."
47
+ ),
48
+ "findings_from_b": [],
49
+ }
50
+
51
+ def get_observation_b(self) -> dict:
52
+ obs = self._observation_dict()
53
+ return {
54
+ "step": obs["step"],
55
+ "max_steps": obs["max_steps"],
56
+ "task_id": obs["task_id"],
57
+ "task_description": obs["task_description"],
58
+ "services": copy.deepcopy(obs.get("services", [])),
59
+ "service_dependencies": copy.deepcopy(obs.get("service_dependencies", [])),
60
+ "sla_status": copy.deepcopy(obs.get("sla_status", {})),
61
+ "last_action_result": obs.get("last_action_result"),
62
+ "last_action_error": obs.get("last_action_error"),
63
+ "elapsed_minutes": obs["elapsed_minutes"],
64
+ "incident_start_time": obs["incident_start_time"],
65
+ "role": "responder",
66
+ "instructions": (
67
+ "You are the Responder. Use Agent A findings plus service "
68
+ "metrics to diagnose and fix the incident."
69
+ ),
70
+ "agent_a_findings": copy.deepcopy(self.findings_log),
71
+ }
72
+
73
+ def step_a(self, finding_text: str) -> dict:
74
+ if self.done:
75
+ return {"error": "episode complete"}
76
+ if not finding_text or len(finding_text.strip()) < 5:
77
+ return {"error": "finding too short", "reward": 0.0}
78
+ entry = {
79
+ "agent": "A",
80
+ "step": self.step_count,
81
+ "finding": finding_text.strip(),
82
+ }
83
+ self.findings_log.append(entry)
84
+ return {
85
+ "accepted": True,
86
+ "reward": 0.05,
87
+ "finding_recorded": entry,
88
+ "total_findings": len(self.findings_log),
89
+ "observation": self.get_observation_a(),
90
+ }
91
+
92
+ def step_b(self, action: Action) -> dict:
93
+ if self.done:
94
+ return {"error": "episode complete"}
95
+ self.step_count += 1
96
+ result = self.env.step(action)
97
+ self.full_obs = result.observation
98
+ if result.done:
99
+ self.done = True
100
+ return {
101
+ "observation": self.get_observation_b(),
102
+ "reward": result.reward,
103
+ "done": result.done,
104
+ "info": result.info,
105
+ "agent_a_findings_count": len(self.findings_log),
106
+ }
107
+
108
+ def get_state(self) -> dict:
109
+ return {
110
+ "session_id": self.session_id,
111
+ "task_id": self.task_id,
112
+ "seed": self.seed,
113
+ "step": self.step_count,
114
+ "done": self.done,
115
+ "findings_log": copy.deepcopy(self.findings_log),
116
+ "observation_a": self.get_observation_a(),
117
+ "observation_b": self.get_observation_b(),
118
+ }