File size: 10,149 Bytes
250ab26
35ea9cd
 
250ab26
9347ce5
35ea9cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6d1ff0
af2ccc5
b6d1ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35ea9cd
250ab26
 
 
 
 
 
35ea9cd
 
 
af2ccc5
35ea9cd
af2ccc5
35ea9cd
250ab26
35ea9cd
 
 
 
 
 
 
 
 
 
af2ccc5
35ea9cd
af2ccc5
35ea9cd
250ab26
35ea9cd
 
af2ccc5
35ea9cd
 
 
 
 
 
 
250ab26
 
 
 
 
35ea9cd
 
250ab26
 
 
 
 
 
35ea9cd
 
 
 
 
 
 
250ab26
 
b6d1ff0
250ab26
35ea9cd
250ab26
35ea9cd
 
250ab26
35ea9cd
 
af2ccc5
35ea9cd
 
250ab26
 
35ea9cd
 
 
 
 
 
 
18aa055
35ea9cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250ab26
35ea9cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6d1ff0
 
35ea9cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6d1ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import random
import uuid

from incidents import TICKETS
from graders import GRADERS
from models import (
    IncidentAction,
    IncidentObservation,
    IncidentReward,
    IncidentState,
    RecommendedAction,
    RootCauseCategory,
    SeverityLevel,
    StepResult,
    TaskType,
)

def _make_task_spec(*, name, difficulty, expected_field, value_enum, description) -> dict:
    """Assemble one TASK_SPECS entry.

    ``allowed_values`` is the list of ``.value`` for every member of
    *value_enum*, in the enum's iteration order.
    """
    return {
        "name": name,
        "difficulty": difficulty,
        "expected_field": expected_field,
        "allowed_values": [member.value for member in value_enum],
        "description": description,
    }


# Per-task metadata, keyed by TaskType. "expected_field" names the single
# action / ground-truth field that the task is answered through.
TASK_SPECS = {
    TaskType.TASK1: _make_task_spec(
        name="Severity Classification",
        difficulty="easy",
        expected_field="severity",
        value_enum=SeverityLevel,
        description="Classify the severity of the incident using blast radius, user impact, and business risk.",
    ),
    TaskType.TASK2: _make_task_spec(
        name="Root Cause Classification",
        difficulty="medium",
        expected_field="root_cause",
        value_enum=RootCauseCategory,
        description="Identify the most likely failure domain from the incident evidence.",
    ),
    TaskType.TASK3: _make_task_spec(
        name="Recommended Action",
        difficulty="hard",
        expected_field="action",
        value_enum=RecommendedAction,
        description="Choose the best immediate operational response for stabilizing the incident.",
    ),
}
# Seed used for ticket selection when reset() is called without an explicit seed.
DEFAULT_RESET_SEED = 42
# Reward value reported before any action has been graded.
INITIAL_REWARD = 0.01


def validate_ticket_dataset(tickets: list[dict]) -> None:
    """Verify that every ticket is usable by the environment.

    For each ticket this checks that:
      * it carries an ``incident_id`` key,
      * its ``task_type`` is a valid ``TaskType`` value,
      * its ``ground_truth`` is a non-empty dict containing exactly the
        field named by ``TASK_SPECS[task_type]["expected_field"]``,
      * that field's value is neither ``None`` nor the empty string.

    Args:
        tickets: Ticket dictionaries to validate.

    Raises:
        RuntimeError: On the first ticket that violates any rule above.
    """
    for index, ticket in enumerate(tickets):
        # The module indexes tickets by "incident_id" immediately after
        # validation; without this check a ticket lacking the key would pass
        # here and then surface as a bare KeyError with no hint about which
        # ticket is broken.
        if "incident_id" not in ticket:
            raise RuntimeError(f"Ticket at index {index} is missing 'incident_id'.")
        incident_id = ticket["incident_id"]

        task_type_raw = ticket.get("task_type")
        try:
            task_type = TaskType(task_type_raw)
        except ValueError as exc:
            raise RuntimeError(
                f"Ticket '{incident_id}' has unsupported task_type '{task_type_raw}'."
            ) from exc

        expected_field = TASK_SPECS[task_type]["expected_field"]
        ground_truth = ticket.get("ground_truth")
        if not isinstance(ground_truth, dict) or not ground_truth:
            raise RuntimeError(f"Ticket '{incident_id}' has empty ground_truth.")
        # Exactly one key is allowed: the field this task type is graded on.
        if set(ground_truth) != {expected_field}:
            raise RuntimeError(
                f"Ticket '{incident_id}' must define only '{expected_field}' in ground_truth."
            )
        if ground_truth.get(expected_field) in (None, ""):
            raise RuntimeError(
                f"Ticket '{incident_id}' has missing value for ground_truth['{expected_field}']."
            )


# Validate the bundled dataset at import time so a malformed ticket fails
# immediately rather than in the middle of an episode.
validate_ticket_dataset(TICKETS)
# Lookup table used when reset() is asked for a specific ticket_id.
TICKETS_BY_ID = {entry["incident_id"]: entry for entry in TICKETS}


class IncidentEnv:
    """Incident-triage environment exposing ``reset()``, ``step()`` and ``state()``.

    An episode is built around one ticket from ``TICKETS``. ``reset()`` picks
    the ticket, ``step()`` grades a submitted action with the grader registered
    in ``GRADERS`` for the ticket's task type, and ``state()`` reports a
    snapshot of the episode. ``max_steps`` is 1, so the first graded action
    completes the episode.
    """

    def __init__(self):
        """Create an environment with no active ticket; call ``reset()`` to begin."""
        self.current_ticket = None
        self.episode_id = ""
        self.step_count = 0
        self.max_steps = 1
        self.total_reward = INITIAL_REWARD
        self.done = False
        self.last_reward = INITIAL_REWARD
        self.last_action_summary = None

    def reset(
        self,
        task_type: TaskType | str | None = None,
        ticket_id: str | None = None,
        seed: int | None = None,
    ) -> StepResult:
        """Start a new episode and return its initial ``StepResult``.

        Args:
            task_type: Optional task filter; a truthy value is converted with
                ``TaskType(...)``, a falsy one means "any task".
            ticket_id: Optional id of the exact ticket to load.
            seed: Optional seed for ticket selection (see ``_select_ticket``).

        Raises:
            ValueError: If ``task_type`` is not a valid ``TaskType`` value or
                no ticket satisfies the request.
        """
        requested_task = None
        if task_type:
            requested_task = TaskType(task_type)

        # Ticket selection happens first: if it raises, the previous episode's
        # fields are left as they were.
        self.current_ticket = self._select_ticket(requested_task, ticket_id, seed)
        self.episode_id = str(uuid.uuid4())
        self.step_count = 0
        self.total_reward = INITIAL_REWARD
        self.done = False
        self.last_reward = INITIAL_REWARD
        self.last_action_summary = None

        observation = self._build_observation()
        opening_reward = IncidentReward(
            value=INITIAL_REWARD,
            reason="Episode initialized and awaiting first action.",
        )
        spec = self._task_spec()
        return StepResult(
            observation=observation,
            reward=opening_reward,
            done=False,
            info={
                "episode_id": self.episode_id,
                "task_name": spec["name"],
                "difficulty": spec["difficulty"],
                "max_steps": self.max_steps,
            },
        )

    def step(self, action: IncidentAction) -> StepResult:
        """Grade *action* against the current ticket and advance the episode.

        Raises:
            RuntimeError: If no episode is active, the episode is already
                done, or the ticket's ground truth is unusable.
            ValueError: If the action targets a different incident or task
                type, or does not populate exactly the expected field.
        """
        ticket = self.current_ticket
        if ticket is None:
            raise RuntimeError("Call reset() before step()")
        if self.done:
            raise RuntimeError("Episode already completed. Call reset() to start a new one.")

        if action.incident_id != ticket["incident_id"]:
            raise ValueError(
                f"Action incident_id '{action.incident_id}' does not match "
                f"current ticket '{ticket['incident_id']}'"
            )
        if action.task_type != TaskType(ticket["task_type"]):
            raise ValueError(
                f"Action task_type '{action.task_type.value}' does not match "
                f"current ticket task_type '{ticket['task_type']}'"
            )

        self._validate_action(action)

        truth, truth_value = self._validated_ground_truth()
        # GRADERS is keyed by the ticket's raw task_type value.
        grade = GRADERS[ticket["task_type"]]
        score, rationale = grade(action, truth)

        submitted_value = action.selected_value() or "NONE"
        submitted_field = action.selected_field() or "NONE"

        self.step_count += 1
        self.last_reward = score
        # total_reward is set to the latest score, not added to a running sum.
        self.total_reward = score
        self.done = self.step_count >= self.max_steps
        self.last_action_summary = f"Submitted {submitted_field}={submitted_value}"

        observation = self._build_observation()
        reward = IncidentReward(value=score, reason=rationale)
        spec = self._task_spec()
        return StepResult(
            observation=observation,
            reward=reward,
            done=self.done,
            info={
                "episode_id": self.episode_id,
                "task_name": spec["name"],
                "difficulty": spec["difficulty"],
                "correct": submitted_value == truth_value,
                "ground_truth": truth_value,
                "agent_answer": submitted_value,
                "selected_field": submitted_field,
                "max_steps": self.max_steps,
            },
        )

    def state(self, session_id: str | None = None) -> IncidentState:
        """Return an ``IncidentState`` snapshot of the active episode.

        Args:
            session_id: Passed through unchanged into the returned state.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        ticket = self.current_ticket
        if ticket is None:
            raise RuntimeError("No active episode. Call reset() first.")

        status = "completed" if self.done else "awaiting_action"
        return IncidentState(
            episode_id=self.episode_id,
            session_id=session_id,
            step_count=self.step_count,
            max_steps=self.max_steps,
            total_reward=self.total_reward,
            done=self.done,
            incident_id=ticket["incident_id"],
            task_type=TaskType(ticket["task_type"]),
            difficulty=self._task_spec()["difficulty"],
            status=status,
            last_reward=self.last_reward,
        )

    def _select_ticket(
        self,
        task_type: TaskType | None = None,
        ticket_id: str | None = None,
        seed: int | None = None,
    ) -> dict:
        """Pick the ticket for a new episode.

        A truthy ``ticket_id`` selects that exact ticket (and must agree with
        ``task_type`` when one is given). Otherwise a ticket is drawn from the
        pool, optionally filtered by ``task_type``, using a ``random.Random``
        seeded with ``seed`` or, when ``seed`` is ``None``, with
        ``DEFAULT_RESET_SEED``.

        Raises:
            ValueError: If the id is unknown, the id/task pair is inconsistent,
                or the candidate pool is empty.
        """
        if ticket_id:
            match = TICKETS_BY_ID.get(ticket_id)
            if match is None:
                raise ValueError(f"No ticket found for ticket_id: {ticket_id}")
            if task_type and match["task_type"] != task_type.value:
                raise ValueError(
                    f"Ticket '{ticket_id}' belongs to task_type '{match['task_type']}', "
                    f"not '{task_type.value}'"
                )
            return match

        if task_type:
            candidates = [entry for entry in TICKETS if entry["task_type"] == task_type.value]
        else:
            candidates = TICKETS
        if not candidates:
            raise ValueError(f"No tickets found for task_type: {task_type}")

        rng_seed = DEFAULT_RESET_SEED if seed is None else seed
        return random.Random(rng_seed).choice(candidates)

    def _task_spec(self) -> dict:
        """Return the ``TASK_SPECS`` entry for the current ticket's task type.

        Raises:
            RuntimeError: If there is no active ticket.
        """
        ticket = self.current_ticket
        if ticket is None:
            raise RuntimeError("No active episode. Call reset() first.")
        task = TaskType(ticket["task_type"])
        return TASK_SPECS[task]

    def _build_observation(self) -> IncidentObservation:
        """Build the ``IncidentObservation`` for the current ticket and progress."""
        spec = self._task_spec()
        ticket = self.current_ticket
        status = "completed" if self.done else "awaiting_action"
        return IncidentObservation(
            incident_id=ticket["incident_id"],
            task_type=TaskType(ticket["task_type"]),
            difficulty=spec["difficulty"],
            task_description=spec["description"],
            alert_text=ticket["alert_text"],
            context=ticket["context"],
            expected_field=spec["expected_field"],
            allowed_values=spec["allowed_values"],
            step_count=self.step_count,
            max_steps=self.max_steps,
            last_action_summary=self.last_action_summary,
            last_reward=self.last_reward,
            episode_status=status,
        )

    def _validate_action(self, action: IncidentAction) -> None:
        """Require *action* to populate exactly the field the current task expects.

        Raises:
            ValueError: If the number of populated fields is not one, or the
                populated field is not the task's expected field.
        """
        populated = action.populated_fields()
        if len(populated) != 1:
            raise ValueError("Action must populate exactly one of severity, root_cause, or action.")

        expected_field = self._task_spec()["expected_field"]
        if expected_field in populated:
            return
        raise ValueError(
            f"Task '{self.current_ticket['task_type']}' expects field '{expected_field}', "
            f"but got '{next(iter(populated))}'."
        )

    def _validated_ground_truth(self) -> tuple[dict, str]:
        """Return the ticket's ground-truth dict and its expected value as a string.

        Raises:
            RuntimeError: If there is no active ticket, the ground truth is
                not a non-empty dict, or the expected field is absent, ``None``
                or the empty string.
        """
        ticket = self.current_ticket
        if ticket is None:
            raise RuntimeError("No active episode. Call reset() first.")

        incident_id = ticket["incident_id"]
        expected_field = self._task_spec()["expected_field"]
        ground_truth = ticket.get("ground_truth")
        if not (isinstance(ground_truth, dict) and ground_truth):
            raise RuntimeError(
                f"Ticket '{incident_id}' has empty ground_truth. This is a dataset integrity error."
            )

        # A missing key reads as None here, which the check below rejects
        # together with explicit None and "".
        expected_value = ground_truth.get(expected_field)
        if expected_value in (None, ""):
            raise RuntimeError(
                f"Ticket '{incident_id}' is missing ground_truth['{expected_field}']."
            )
        return ground_truth, str(expected_value)