"""
Core Environment implementation.

Per-step execution order: validate → mutate → tick → reward → observe.

Two-phase architecture:
  Phase 1 — ops/SRE diagnostic loop (existing behavior).
  Phase 2 — code attribution loop, sandboxed under a CodeWorkspace.

Mode selection is automatic per scenario:
  - Scenario with `code_context = None`  → legacy P1-only episode
                                            (declare_root_cause terminates)
  - Scenario with `code_context != None` → unified P1 → P2 episode
                                            (declare_root_cause is silent;
                                             transition_to_phase2 switches phase;
                                             propose_patch / declare_no_change
                                             terminate the episode)
A `pool` / `mode` passed to reset() can override this: "p1_only" makes
declare_root_cause terminate even when code_context is set, and "p2_only"
starts the episode directly in Phase 2 with an oracle handoff belief.

The environment uses oracle-shaped per-step rewards for training. The
oracle-INDEPENDENT graders live on `BaseScenario` and `scenarios.grader_p2`.
"""
from __future__ import annotations
import uuid
import random
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Tuple
from ..models import (
ActionType,
IncidentAction,
IncidentState,
StepRecord,
BeliefState,
DIAGNOSTIC_ACTIONS,
REMEDIATION_ACTIONS,
TARGETED_ACTIONS,
PHASE1_ACTIONS,
PHASE2_ACTIONS,
PHASE2_DIAGNOSTIC_ACTIONS,
PHASE2_TERMINAL_ACTIONS,
)
from ..simulation.infrastructure import Infrastructure, SERVICE_NAMES
from ..tasks import get_scenario, TASK_NAMES
from ..scenarios.base import BaseScenario
from ..pools import POOLS, get_pool, sample_task, oracle_belief
from .code_workspace import CodeWorkspace, CodeWorkspaceError
# ---------------------------------------------------------------------------
# Per-step reward shaping constants
# ---------------------------------------------------------------------------
# Baseline charged on every step whose reward is computed.
_STEP_PENALTY = -0.02
# Added on top of the baseline when the same action is repeated.
_REPEAT_PENALTY = -0.05
# Charged for rejected actions and for failed Phase-2 workspace operations.
_INVALID_PENALTY = -0.05

# Phase 2 shaping is deliberately small: patch quality is graded post-hoc,
# these values only nudge exploration and termination.
_P2_DIAG_REWARD = 0.05
_P2_TERMINAL_BONUS = 0.10
class IncidentEnvironment:
    """
    SRE Incident Response Environment.

    Implements the three OpenEnv methods:
      - reset(task_name) → initial observation + info
      - step(action)     → dict with observation, reward, done, info
      - state            → IncidentState for monitoring (property)
    Plus two extras used by the unified evaluator:
      - get_trajectory() → P1 + P2 step records
      - score_unified(…) → component scores for unified grader
    """

    def __init__(self) -> None:
        self._infra: Optional[Infrastructure] = None
        self._scenario: Optional[BaseScenario] = None
        self._state = IncidentState()

        # ---- Per-episode mutable state (re-initialised by reset) ----
        self._phase: int = 1
        self._workspace: Optional[CodeWorkspace] = None
        self._belief_at_transition: Optional[BeliefState] = None
        self._p1_trajectory: List[StepRecord] = []
        self._p2_trajectory: List[StepRecord] = []
        self._declared_patch: Optional[str] = None
        self._declared_no_change: bool = False
        self._declared_root_cause: Optional[str] = None
        self._cumulative_reward: float = 0.0
        self._done: bool = False
        # Score reported on the terminating step; replayed by
        # _final_step_response so post-episode calls report the same value.
        self._final_score: Optional[float] = None

        # ---- Pool / mode (set by reset, drives episode semantics) ----
        # mode in {"joint" (default), "p1_only" (Pool A), "p2_only" (Pool B)}
        self._pool: Optional[str] = None
        self._mode: str = "joint"
        self._inject_oracle_belief: bool = False

        # P2-only tracking (for repeat detection inside P2)
        self._p2_actions_taken: List[Tuple[str, str]] = []  # (atype, primary_param)

    # ==================================================================
    # reset()
    # ==================================================================
    def reset(
        self,
        task_name: Optional[str] = None,
        seed: Optional[int] = None,
        pool: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Initialize a new incident episode.

        Args:
            task_name: Scenario to load. When omitted it is sampled from
                `pool` (if given) or chosen from TASK_NAMES.
            seed: Seeds the module-level `random` generator (global state)
                before any sampling.
            pool: Training pool A/B/C/D; sets the default mode and whether
                an oracle belief is injected.
            mode: Forces episode semantics ("p1_only" | "p2_only" | "joint").
                An explicit `mode` always wins over pool defaults.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Dict with "observation", "reward", "done" and "info". In
            p2_only mode (scenario with code_context) the episode starts
            directly in Phase 2; if that handoff is rejected, the rejection
            observation is returned with info["phase"] == 1.
        """
        if seed is not None:
            random.seed(seed)

        # ---- Pool / task selection ----
        if pool:
            pool_obj = get_pool(pool)
            if task_name is None:
                task_name = sample_task(pool, rng=random)
            self._pool = pool_obj.name
            self._mode = pool_obj.mode
            self._inject_oracle_belief = pool_obj.inject_oracle_belief
        else:
            self._pool = None
            self._mode = "joint"
            self._inject_oracle_belief = False

        if mode:
            self._mode = mode
            if mode == "p2_only":
                self._inject_oracle_belief = True

        if task_name is None:
            task_name = random.choice(TASK_NAMES)

        self._infra = Infrastructure()
        self._scenario = get_scenario(task_name)
        self._infra.time_budget_minutes = self._scenario.time_budget_minutes
        self._scenario.inject(self._infra)

        # Let cascades propagate a few minutes
        for _ in range(3):
            self._infra.tick()

        self._state = IncidentState(
            episode_id = str(uuid.uuid4()),
            task_name = task_name,
            step_count = 0,
            time_elapsed_minutes = self._infra.current_minute,
            done = False,
            cumulative_reward = 0.0,
        )

        self._phase = 1
        self._workspace = None
        self._belief_at_transition = None
        self._p1_trajectory = []
        self._p2_trajectory = []
        self._declared_patch = None
        self._declared_no_change = False
        self._declared_root_cause = None
        self._cumulative_reward = 0.0
        self._done = False
        self._final_score = None
        self._p2_actions_taken = []

        # ---- Pool B (p2_only) auto-handoff with oracle belief --------
        # The agent never sees Phase 1; we synthesise a perfect handoff and
        # immediately switch the env into Phase 2.
        if self._mode == "p2_only" and self._scenario.code_context is not None:
            belief = oracle_belief(self._scenario)
            transition = self._handle_transition(IncidentAction(
                action_type = ActionType.TRANSITION_TO_PHASE2.value,
                target_service = None,
                parameters = {"belief": asdict(belief)},
            ))
            if self._phase != 2:
                # The handoff was rejected (e.g. the repo snapshot could not
                # be opened). Surface that failure instead of reporting a
                # Phase-2 start the environment is not actually in.
                info = dict(transition.get("info") or {})
                info.update({
                    "task_name": task_name,
                    "pool": self._pool,
                    "mode": self._mode,
                    "has_phase2": self._scenario.code_context is not None,
                    "phase": 1,
                })
                return {
                    "observation": transition["observation"],
                    "reward": 0.01,
                    "done": transition.get("done", False),
                    "info": info,
                }

            # Handoff succeeded: present a reset observation that carries
            # the Phase-2 issue and repository context.
            obs = self._build_observation(
                action_result = {
                    "message": "[Pool B] Auto-handoff with oracle Phase-1 belief.",
                    "issue": self._scenario.build_p2_issue(belief),
                    "file_tree": (self._workspace.file_tree(max_depth=4)
                                  if self._workspace else []),
                    "bad_commit_sha": self._scenario.code_context.bad_commit_sha,
                    "bad_commit": (self._workspace.bad_commit_metadata()
                                   if self._workspace else None),
                },
                action_success = True,
                action_message = "Episode started in Pool B (P2-only) mode",
                reward = 0.0,
            )
            return {
                "observation": obs,
                "reward": 0.01,
                "done": False,
                "info": {"task_name": task_name,
                         "pool": self._pool,
                         "mode": self._mode,
                         "has_phase2": True,
                         "phase": 2},
            }

        obs = self._build_observation(
            action_result = {"message": "Incident triggered. Begin investigation."},
            action_success = True,
            action_message = "Episode started",
            reward = 0.0,
        )
        return {
            "observation": obs,
            "reward": 0.01,
            "done": False,
            "info": {"task_name": task_name,
                     "pool": self._pool,
                     "mode": self._mode,
                     "has_phase2": self._scenario.code_context is not None},
        }

    # ==================================================================
    # step()
    # ==================================================================
    def step(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute one agent action — phase-aware dispatch.

        `action_data` is read via .get() for "action_type", "target_service"
        and "parameters". Unknown action types and actions not allowed in
        the current phase get an invalid-action response (penalised and
        recorded, no simulation tick).
        """
        if self._done:
            return self._final_step_response()
        if self._infra is None or self._scenario is None:
            return self._not_initialized_response()

        action = IncidentAction(
            action_type = action_data.get("action_type", ""),
            target_service = action_data.get("target_service"),
            parameters = action_data.get("parameters", {}) or {},
        )

        # ---- Type validation ----------------------------------------
        try:
            atype = ActionType(action.action_type)
        except ValueError:
            return self._invalid_action_response(
                f"Unknown action type: {action.action_type!r}",
                action,
            )

        # ---- Phase-aware dispatch -----------------------------------
        # The transition handler enforces its own preconditions.
        if atype == ActionType.TRANSITION_TO_PHASE2:
            return self._handle_transition(action)

        if self._phase == 1:
            if atype not in PHASE1_ACTIONS:
                return self._invalid_action_response(
                    f"Action {atype.value!r} not allowed in Phase 1", action,
                )
            return self._step_phase1(action, atype)

        # Phase 2
        if atype not in PHASE2_ACTIONS:
            return self._invalid_action_response(
                f"Action {atype.value!r} not allowed in Phase 2", action,
            )
        return self._step_phase2(action, atype)

    # ------------------------------------------------------------------
    # Phase 1 step
    # ------------------------------------------------------------------
    def _step_phase1(
        self,
        action: IncidentAction,
        atype: ActionType,
    ) -> Dict[str, Any]:
        """Run one Phase-1 action: validate → mutate → tick → reward → observe → record."""
        # Validate target / preconditions via Infrastructure
        is_valid, err = self._infra.validate_action(
            action.action_type, action.target_service)
        if not is_valid:
            return self._invalid_action_response(err, action)

        # Mutate
        action_result, action_msg = self._execute_p1_action(action, atype)

        # Tick simulation
        self._infra.tick()
        self._state.step_count += 1
        self._state.time_elapsed_minutes = self._infra.current_minute

        # Reward (compute BEFORE recording so repeat-detection sees prior actions)
        reward = self._compute_p1_reward(action, atype)
        self._infra.record_action(action.action_type, action.target_service)
        self._cumulative_reward += reward
        self._state.cumulative_reward = self._cumulative_reward

        # Done check
        done = self._check_done_p1(atype)
        self._done = done
        self._state.done = done

        obs = self._build_observation(
            action_result = action_result,
            action_success = True,
            action_message = action_msg,
            reward = reward,
        )

        record = StepRecord(
            step_number = self._state.step_count,
            action = action,
            reward = reward,
            observation_summary = {
                "action_message": obs.get("action_message", ""),
                "active_alerts_count": obs.get("active_alerts_count", 0),
            },
            service_statuses_after = dict(obs.get("service_statuses", {})),
            timestamp_minutes = self._infra.current_minute,
            phase = 1,
        )
        self._p1_trajectory.append(record)

        info: Dict[str, Any] = {}
        if done:
            info["score"] = self._finalize_score()
            info["task_name"] = self._scenario.task_name
            info["steps_taken"] = self._state.step_count
            info["trajectory_length"] = len(self._p1_trajectory)

        return {"observation": obs, "reward": reward, "done": done, "info": info}

    # ------------------------------------------------------------------
    # Phase 2 step
    # ------------------------------------------------------------------
    def _step_phase2(
        self,
        action: IncidentAction,
        atype: ActionType,
    ) -> Dict[str, Any]:
        """
        Run one Phase-2 (code attribution) action against the workspace.

        A CodeWorkspaceError is reported as an unsuccessful step (the
        simulation still ticks and _compute_p2_reward applies the invalid
        penalty) rather than as an invalid action. propose_patch and
        declare_no_change terminate the episode, as does the step budget.
        """
        if self._workspace is None:
            return self._invalid_action_response(
                "Phase 2 not initialised — must transition_to_phase2 first.",
                action,
            )

        params = action.parameters or {}

        # Parse numeric parameters up front so a malformed value becomes an
        # invalid action instead of an uncaught ValueError/TypeError.
        n_commits = 10
        if atype == ActionType.GET_GIT_LOG:
            try:
                n_commits = int(params.get("n_commits", 10))
            except (TypeError, ValueError):
                return self._invalid_action_response(
                    f"n_commits must be an integer, got {params.get('n_commits')!r}",
                    action,
                )

        # ---- Execute action ----
        try:
            if atype == ActionType.LIST_DIR:
                result = self._workspace.list_dir(params.get("path", "."))
                msg = f"Listed {result.get('count', 0)} entries in {result.get('path', '.')}"
            elif atype == ActionType.READ_FILE:
                result = self._workspace.read_file(params.get("path", ""))
                msg = f"Read {result.get('path')} ({result.get('size', 0)} bytes)"
            elif atype == ActionType.SEARCH_CODE:
                result = self._workspace.search_code(
                    query = params.get("query", ""),
                    file_pattern = params.get("file_pattern", "*.py"),
                    max_hits = params.get("max_hits"),
                )
                msg = f"Found {result.get('count', 0)} hit(s) for {params.get('query')!r}"
            elif atype == ActionType.GET_GIT_LOG:
                result = self._workspace.get_git_log(
                    path = params.get("path", ""),
                    n_commits = n_commits,
                )
                msg = f"Returned {result.get('count', 0)} commit(s)"
            elif atype == ActionType.GET_FILE_DIFF:
                result = self._workspace.get_file_diff(
                    commit_sha = params.get("commit_sha", ""),
                    path = params.get("path", ""),
                )
                # Either field may be missing/None; don't crash while
                # formatting the summary message.
                sha = str(result.get("commit_sha") or "")
                diff_text = result.get("diff") or ""
                msg = f"Diff for {sha[:8]} ({len(diff_text)} bytes)"
            elif atype == ActionType.PROPOSE_PATCH:
                # A missing/None diff is an empty patch; non-strings are
                # stringified so len() and downstream grading stay safe.
                diff = params.get("diff") or ""
                if not isinstance(diff, str):
                    diff = str(diff)
                self._declared_patch = diff
                result = {"accepted": True, "patch_bytes": len(diff)}
                msg = "Patch proposal accepted — episode terminating."
            elif atype == ActionType.DECLARE_NO_CHANGE:
                self._declared_no_change = True
                reason = params.get("reason", "")
                result = {"accepted": True, "reason": reason}
                msg = "no-change declaration accepted — episode terminating."
            else:
                return self._invalid_action_response(
                    f"Unhandled P2 action type: {atype.value!r}", action,
                )
            success = True
        except CodeWorkspaceError as e:
            result = {"error": str(e)}
            msg = f"Workspace error: {e}"
            success = False

        # ---- Tick (simulation time still advances during P2) ----
        self._infra.tick()
        self._state.step_count += 1
        self._state.time_elapsed_minutes = self._infra.current_minute

        # ---- Reward (before this action is added to the repeat history) ----
        reward = self._compute_p2_reward(action, atype, success)
        self._cumulative_reward += reward
        self._state.cumulative_reward = self._cumulative_reward

        # ---- Done ----
        done = (atype in PHASE2_TERMINAL_ACTIONS) or self._exceeded_step_budget()
        self._done = done
        self._state.done = done

        obs = self._build_observation(
            action_result = result,
            action_success = success,
            action_message = msg,
            reward = reward,
        )

        # Record
        record = StepRecord(
            step_number = self._state.step_count,
            action = action,
            reward = reward,
            observation_summary = {
                "action_message": obs.get("action_message", ""),
                "p2_action": atype.value,
            },
            service_statuses_after = dict(obs.get("service_statuses", {})),
            timestamp_minutes = self._infra.current_minute,
            phase = 2,
        )
        self._p2_trajectory.append(record)

        # Track repeats inside P2
        prim_param = self._p2_primary_param(atype, params)
        self._p2_actions_taken.append((atype.value, prim_param))

        info: Dict[str, Any] = {}
        if done:
            info["score"] = self._finalize_score()
            info["task_name"] = self._scenario.task_name
            info["steps_taken"] = self._state.step_count
            info["trajectory_length"] = len(self._p1_trajectory) + len(self._p2_trajectory)

        return {"observation": obs, "reward": reward, "done": done, "info": info}

    # ------------------------------------------------------------------
    # transition_to_phase2 handler
    # ------------------------------------------------------------------
    def _handle_transition(self, action: IncidentAction) -> Dict[str, Any]:
        """
        Switch the episode from Phase 1 to Phase 2.

        Opens a CodeWorkspace on the scenario's repo snapshot, records the
        handoff belief from action.parameters["belief"], ticks the
        simulation once and returns the initial Phase-2 observation (issue
        text, file tree, bad-commit metadata). Rejected with an
        invalid-action response when already in Phase 2, when the scenario
        has no code_context, or when the snapshot cannot be opened.
        """
        if self._phase != 1:
            return self._invalid_action_response(
                "Already in Phase 2 — cannot transition again.", action,
            )
        if self._scenario is None or self._scenario.code_context is None:
            return self._invalid_action_response(
                "Scenario has no code_context — Phase 2 unavailable.", action,
            )

        ctx = self._scenario.code_context

        # Construct workspace
        try:
            self._workspace = CodeWorkspace(
                snapshot_root = ctx.repo_snapshot_path,
                bad_commit_sha = ctx.bad_commit_sha,
            )
        except CodeWorkspaceError as e:
            return self._invalid_action_response(
                f"Cannot open snapshot: {e}", action,
            )

        # Capture handoff belief (_coerce_belief tolerates missing or
        # malformed agent-supplied beliefs).
        raw_belief = (action.parameters or {}).get("belief")
        self._belief_at_transition = self._coerce_belief(raw_belief)

        # Switch phase
        self._phase = 2
        self._state.step_count += 1
        self._infra.tick()
        self._state.time_elapsed_minutes = self._infra.current_minute

        # Initial P2 obs
        issue_text = self._scenario.build_p2_issue(self._belief_at_transition)
        file_tree = self._workspace.file_tree(max_depth=4)
        commit_meta = self._workspace.bad_commit_metadata()

        action_result = {
            "phase": 2,
            "issue": issue_text,
            "file_tree": file_tree,
            "bad_commit_sha": ctx.bad_commit_sha,
            "bad_commit": commit_meta,
            "snapshot_root": str(self._workspace.tree_root),
        }

        # Reward: small handoff bonus only when belief is non-trivial
        reward = 0.0
        if self._belief_at_transition.suspected_service:
            reward += 0.05
        self._cumulative_reward += reward
        self._state.cumulative_reward = self._cumulative_reward

        obs = self._build_observation(
            action_result = action_result,
            action_success = True,
            action_message = "Transitioned to Phase 2 (code attribution).",
            reward = reward,
        )

        record = StepRecord(
            step_number = self._state.step_count,
            action = action,
            reward = reward,
            observation_summary = {
                "action_message": "transition_to_phase2",
                "transition": True,
            },
            service_statuses_after = dict(obs.get("service_statuses", {})),
            timestamp_minutes = self._infra.current_minute,
            phase = 2,
            belief_state_snapshot = asdict(self._belief_at_transition),
        )
        self._p2_trajectory.append(record)

        return {"observation": obs, "reward": reward, "done": False, "info": {}}

    @staticmethod
    def _to_float(value: Any, default: float = 0.0) -> float:
        """Parse `value` as a float; return `default` when it is missing or not numeric."""
        if value is None or value == "":
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _coerce_belief(d: Dict[str, Any]) -> BeliefState:
        """
        Best-effort: turn an inference-side dict into the canonical BeliefState.

        Tolerates malformed agent output: a non-dict is treated as empty;
        `evidence_gaps` may be None, a list, or a comma-separated string
        (a literal "none" entry is dropped); confidences that are missing
        or not numeric become 0.0.
        """
        if not isinstance(d, dict):
            d = {}
        gaps = d.get("evidence_gaps") or []
        if isinstance(gaps, str):
            gaps = [g.strip() for g in gaps.split(",")
                    if g.strip() and g.strip() != "none"]
        elif not isinstance(gaps, (list, tuple, set, frozenset)):
            gaps = [gaps]
        return BeliefState(
            suspected_service = d.get("suspected_service") or None,
            suspected_fault_class = d.get("suspected_fault_class") or None,
            service_confidence = IncidentEnvironment._to_float(d.get("service_confidence")),
            fault_confidence = IncidentEnvironment._to_float(d.get("fault_confidence")),
            evidence_gaps = list(gaps),
            estimated_p2_cost = d.get("estimated_p2_cost") or "unknown",
            decision = d.get("decision") or "transition",
            reasoning = d.get("reasoning") or "",
        )

    # ==================================================================
    # state()
    # ==================================================================
    @property
    def state(self) -> IncidentState:
        """Live IncidentState object for the current episode."""
        return self._state

    def get_state(self) -> Dict[str, Any]:
        """
        Return a dict snapshot of episode progress.

        `phase_transition_at` is the step number of the first
        transition_to_phase2 record in the Phase-2 trajectory, or None.
        """
        return {
            "episode_id": self._state.episode_id,
            "task_name": self._state.task_name,
            "step_count": self._state.step_count,
            "time_elapsed_minutes": self._state.time_elapsed_minutes,
            "done": self._state.done,
            "cumulative_reward": round(self._state.cumulative_reward, 3),
            "declared_root_cause": self._declared_root_cause,
            "declared_patch": self._declared_patch,
            "declared_no_change": self._declared_no_change,
            "phase": self._phase,
            "phase_transition_at": next(
                (r.step_number for r in self._p2_trajectory
                 if r.action.action_type == ActionType.TRANSITION_TO_PHASE2.value),
                None,
            ),
        }

    # ==================================================================
    # Phase 1 action execution
    # ==================================================================
    def _execute_p1_action(
        self,
        action: IncidentAction,
        atype: ActionType,
    ) -> Tuple[Dict[str, Any], str]:
        """
        Apply one Phase-1 action to the simulated infrastructure.

        Returns (action_result, action_message). Diagnostic actions only
        read from Infrastructure; remediation actions call into the target
        service; declare_root_cause stores the declared text.
        """
        target = action.target_service
        params = action.parameters or {}

        if atype == ActionType.VIEW_ALERTS:
            alerts = self._infra.get_alerts()
            return {"alerts": alerts, "count": len(alerts)}, \
                f"Viewing {len(alerts)} active alerts"

        if atype == ActionType.QUERY_LOGS:
            level = params.get("level")
            keyword = params.get("keyword")
            limit = params.get("limit", 15)
            logs = self._infra.get_logs_for_service(target, level, keyword, limit)
            return {"logs": logs, "count": len(logs), "service": target}, \
                f"Queried {len(logs)} logs from {target}"

        if atype == ActionType.CHECK_METRICS:
            metrics = self._infra.get_metrics_for_service(target)
            return {"metrics": metrics, "service": target,
                    "data_points": len(metrics)}, \
                f"Retrieved {len(metrics)} metric points for {target}"

        if atype == ActionType.CHECK_DEPENDENCIES:
            deps = self._infra.get_dependencies_for_service(target)
            return {"dependencies": deps, "service": target}, \
                f"Retrieved dependency map for {target}"

        if atype == ActionType.CHECK_DEPLOY_HISTORY:
            deploys = self._infra.get_deploy_history_for_service(target)
            return {"deploys": deploys, "service": target,
                    "count": len(deploys)}, \
                f"Retrieved {len(deploys)} deploys for {target}"

        if atype == ActionType.RUN_HEALTH_CHECK:
            h = self._infra.run_health_check(target)
            return {"health_check": h, "service": target}, \
                f"Health check for {target}: {h['status']}"

        if atype == ActionType.RESTART_SERVICE:
            svc = self._infra.get_service(target)
            msg = svc.restart(self._infra.current_minute) if svc else "Service not found"
            return {"result": msg, "service": target}, msg

        if atype == ActionType.ROLLBACK_DEPLOY:
            svc = self._infra.get_service(target)
            msg = svc.rollback_deploy(self._infra.current_minute) \
                if svc else "Service not found"
            return {"result": msg, "service": target}, msg

        if atype == ActionType.SCALE_SERVICE:
            svc = self._infra.get_service(target)
            new_replicas = params.get("replicas", 5)
            msg = svc.scale(new_replicas, self._infra.current_minute) \
                if svc else "Service not found"
            return {"result": msg, "service": target}, msg

        if atype == ActionType.DECLARE_ROOT_CAUSE:
            # A missing/None root cause is treated as an empty declaration.
            rc = str(params.get("root_cause") or "")
            self._declared_root_cause = rc
            self._state.declared_root_cause = rc
            # Message must agree with _check_done_p1 (p1_only mode ends the
            # episode even when the scenario has code_context).
            return {
                "declared": rc,
                "message": ("Root cause declared. " +
                            ("Episode will end after this step."
                             if self._declaration_terminates()
                             else "Episode continues — Phase 2 awaits.")),
            }, f"Root cause declared: {rc[:120]}"

        return {"error": f"Unhandled action type: {atype.value}"}, "Unknown action"

    # ==================================================================
    # Reward computation
    # ==================================================================
    def _compute_p1_reward(
        self,
        action: IncidentAction,
        atype: ActionType,
    ) -> float:
        """
        Oracle-shaped reward for one Phase-1 action, rounded to 3 decimals.

        Starts from _STEP_PENALTY; a repeated (action_type, target) pair
        returns immediately with _REPEAT_PENALTY added. Diagnostics on
        involved services, remediation of the root-cause service and a
        declaration matching the scenario keywords earn bonuses;
        off-target remediation and weak declarations are penalised. When
        the declaration ends the episode, a completion bonus (all services
        healthy) and an overtime penalty are applied.
        """
        scenario = self._scenario
        target = action.target_service
        reward = _STEP_PENALTY

        if self._infra.was_action_taken(action.action_type, target):
            return round(reward + _REPEAT_PENALTY, 3)

        if atype in DIAGNOSTIC_ACTIONS:
            if target and target in scenario.involved_services:
                reward += 0.15
            elif target:
                reward += 0.05
        elif atype == ActionType.VIEW_ALERTS:
            reward += 0.15
        elif atype in REMEDIATION_ACTIONS:
            if target == scenario.root_cause_service:
                reward += 0.30
            elif target and target in scenario.involved_services:
                reward += 0.10
            else:
                reward -= 0.15
        elif atype == ActionType.DECLARE_ROOT_CAUSE:
            declared = str((action.parameters or {}).get("root_cause") or "").lower()
            kws = scenario.root_cause_keywords
            if kws:
                # Case-insensitive on both sides.
                hits = sum(1 for k in kws if str(k).lower() in declared)
                ratio = hits / len(kws)
                if ratio >= 0.6:
                    reward += 0.40
                elif ratio >= 0.3:
                    reward += 0.15
                else:
                    reward -= 0.20
            else:
                reward -= 0.20

        # Completion bonus / overtime penalty when the declaration
        # terminates the episode (legacy scenarios and p1_only mode).
        if self._declared_root_cause and self._declaration_terminates():
            if self._infra.all_services_healthy():
                reward += 0.20
            if self._infra.current_minute > self._infra.time_budget_minutes:
                reward -= 0.10

        return round(reward, 3)

    def _compute_p2_reward(
        self,
        action: IncidentAction,
        atype: ActionType,
        success: bool,
    ) -> float:
        """
        Small shaping reward for one Phase-2 action, rounded to 3 decimals.

        Failed workspace operations get _INVALID_PENALTY, repeats (same
        action type and primary parameter) get _REPEAT_PENALTY, otherwise
        diagnostics earn _P2_DIAG_REWARD and terminal actions
        _P2_TERMINAL_BONUS — all on top of _STEP_PENALTY.
        """
        params = action.parameters or {}
        prim = self._p2_primary_param(atype, params)
        reward = _STEP_PENALTY

        if not success:
            return round(reward + _INVALID_PENALTY, 3)
        if (atype.value, prim) in self._p2_actions_taken:
            return round(reward + _REPEAT_PENALTY, 3)

        if atype in PHASE2_DIAGNOSTIC_ACTIONS:
            reward += _P2_DIAG_REWARD
        elif atype in PHASE2_TERMINAL_ACTIONS:
            reward += _P2_TERMINAL_BONUS
        return round(reward, 3)

    @staticmethod
    def _p2_primary_param(atype: ActionType, params: Dict[str, Any]) -> str:
        """Key used for Phase-2 repeat detection: the parameter that identifies the action's target."""
        if atype == ActionType.LIST_DIR:
            return params.get("path", ".")
        if atype == ActionType.READ_FILE:
            return params.get("path", "")
        if atype == ActionType.SEARCH_CODE:
            return params.get("query", "")
        if atype == ActionType.GET_GIT_LOG:
            return params.get("path", "")
        if atype == ActionType.GET_FILE_DIFF:
            return f'{params.get("commit_sha", "")}:{params.get("path", "")}'
        return ""

    # ==================================================================
    # Done logic
    # ==================================================================
    def _declaration_terminates(self) -> bool:
        """
        Whether declare_root_cause ends the episode.

        True in Pool A / explicit p1_only mode, and for legacy scenarios
        without code_context; in joint episodes the declaration is silent.
        """
        return self._mode == "p1_only" or self._scenario.code_context is None

    def _check_done_p1(self, atype: ActionType) -> bool:
        """Phase-1 termination: a terminating declaration or an exhausted step budget."""
        if atype == ActionType.DECLARE_ROOT_CAUSE and self._declaration_terminates():
            return True
        return self._exceeded_step_budget()

    def _exceeded_step_budget(self) -> bool:
        """
        True once step_count reaches the budget.

        Budget is scenario.max_steps (20 with no scenario), plus 15 when the
        scenario has code_context to leave headroom for P2 exploration.
        step_count also counts invalid actions and the transition step.
        """
        budget = self._scenario.max_steps if self._scenario else 20
        if self._scenario and self._scenario.code_context is not None:
            budget = budget + 15
        return self._state.step_count >= budget

    # ==================================================================
    # Observation builder
    # ==================================================================
    def _build_observation(
        self,
        action_result: Dict[str, Any],
        action_success: bool,
        action_message: str,
        reward: float,
    ) -> Dict[str, Any]:
        """Assemble the observation dict; tolerates an uninitialised env (defaults used)."""
        statuses = self._infra.get_all_statuses() if self._infra else {}
        alerts = self._infra.get_alerts() if self._infra else []
        valid_actions = self._valid_actions_for_phase()
        return {
            "incident_summary": self._scenario.incident_summary if self._scenario else "",
            "severity": self._scenario.severity if self._scenario else "SEV3",
            "time_elapsed_minutes": self._infra.current_minute if self._infra else 0,
            "time_budget_minutes": self._infra.time_budget_minutes if self._infra else 30,
            "action_result": action_result,
            "action_success": action_success,
            "action_message": action_message,
            "service_statuses": statuses,
            "active_alerts_count": len(alerts),
            "valid_actions": valid_actions,
            "available_services": list(SERVICE_NAMES),
            "current_phase": self._phase,
            "current_reward": reward,
            "cumulative_reward": round(self._cumulative_reward, 3),
            "steps_taken": self._state.step_count,
            # NOTE(review): _exceeded_step_budget allows +15 steps when
            # code_context is set; this reports the raw scenario value —
            # confirm which budget agents should see.
            "max_steps": self._scenario.max_steps if self._scenario else 20,
            "done": self._done,
            # Convenience field surfaced after transition (so the inference loop
            # can grab it without re-issuing a step) — only meaningful after
            # transition_to_phase2 has been called.
            "bad_commit_sha": (self._scenario.code_context.bad_commit_sha
                               if self._scenario and self._scenario.code_context else None),
        }

    def _valid_actions_for_phase(self) -> List[str]:
        """Action names the agent may use now (P1 list plus transition when available, or the P2 list)."""
        if self._phase == 1:
            base = self._infra.get_valid_actions() if self._infra else []
            p1_names = {at.value for at in PHASE1_ACTIONS}
            # Filter on the part before the first ":" (entries may look
            # like "action:target").
            valid = [a for a in base if a.split(":", 1)[0] in p1_names]
            if self._scenario and self._scenario.code_context is not None:
                valid.append(ActionType.TRANSITION_TO_PHASE2.value)
            return valid
        # Phase 2
        return [at.value for at in PHASE2_ACTIONS]

    # ==================================================================
    # Trajectory access (used by /score endpoint and Pool runners)
    # ==================================================================
    def get_trajectory(self) -> List[StepRecord]:
        """Copy of all step records: Phase 1 followed by Phase 2."""
        return list(self._p1_trajectory) + list(self._p2_trajectory)

    def get_p1_trajectory(self) -> List[StepRecord]:
        """Copy of the Phase-1 step records."""
        return list(self._p1_trajectory)

    def get_p2_trajectory(self) -> List[StepRecord]:
        """Copy of the Phase-2 step records (including the transition record)."""
        return list(self._p2_trajectory)

    def get_belief_at_transition(self) -> Optional[BeliefState]:
        """Belief captured at transition_to_phase2, or None before the transition."""
        return self._belief_at_transition

    # ==================================================================
    # Final unified scoring
    # ==================================================================
    def _compute_unified_final_score(self) -> float:
        """Quick wrapper for the in-step `info.score` field."""
        from ..tasks import grade_trajectory_unified
        if self._scenario is None:
            return 0.01
        breakdown = grade_trajectory_unified(
            task_name = self._scenario.task_name,
            p1_trajectory = self._p1_trajectory,
            p2_trajectory = self._p2_trajectory,
            declared_patch = self._declared_patch,
            declared_no_change = self._declared_no_change,
            p1_belief_history = [],
        )
        return float(breakdown.get("final", 0.01))

    def _finalize_score(self) -> float:
        """
        Compute and cache the score for a terminating step.

        Phase 2 uses the unified grader; Phase 1 uses the scenario's P1
        grader. The cached value is replayed by _final_step_response.
        """
        if self._phase == 2:
            score = self._compute_unified_final_score()
        else:
            score = self._scenario.grade(self._p1_trajectory)
        self._final_score = score
        return score

    def score_unified(
        self,
        belief_history: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, float]:
        """Public wrapper exposed by the /score endpoint."""
        from ..tasks import grade_trajectory_unified
        if self._scenario is None:
            return {"final": 0.01}
        return grade_trajectory_unified(
            task_name = self._scenario.task_name,
            p1_trajectory = self._p1_trajectory,
            p2_trajectory = self._p2_trajectory,
            declared_patch = self._declared_patch,
            declared_no_change = self._declared_no_change,
            p1_belief_history = belief_history or [],
        )

    # ==================================================================
    # Error / fallback responses
    # ==================================================================
    def _invalid_action_response(
        self,
        msg: str,
        action: IncidentAction,
    ) -> Dict[str, Any]:
        """
        Penalise and record a rejected action (no simulation tick).

        The step still counts toward the step budget; if the budget is
        exhausted the episode ends and the final score is reported, so an
        agent emitting only invalid actions cannot loop forever.
        """
        reward = _INVALID_PENALTY
        self._cumulative_reward += reward
        # Keep the monitored state in sync with the internal running total.
        self._state.cumulative_reward = self._cumulative_reward
        self._state.step_count += 1

        done = self._scenario is not None and self._exceeded_step_budget()
        if done:
            self._done = True
            self._state.done = True

        obs = self._build_observation(
            action_result = {"error": msg},
            action_success = False,
            action_message = f"Invalid action: {msg}",
            reward = reward,
        )
        # Still record the failed attempt so trajectory analysis sees it
        record = StepRecord(
            step_number = self._state.step_count,
            action = action,
            reward = reward,
            observation_summary = {"action_message": f"invalid: {msg}"},
            service_statuses_after = dict(obs.get("service_statuses", {})),
            timestamp_minutes = self._infra.current_minute if self._infra else 0,
            phase = self._phase,
        )
        if self._phase == 1:
            self._p1_trajectory.append(record)
        else:
            self._p2_trajectory.append(record)

        info: Dict[str, Any] = {"error": msg}
        if done:
            info["score"] = self._finalize_score()
            info["task_name"] = self._scenario.task_name
            info["steps_taken"] = self._state.step_count
            info["trajectory_length"] = len(self._p1_trajectory) + len(self._p2_trajectory)
        return {"observation": obs, "reward": reward, "done": done, "info": info}

    def _final_step_response(self) -> Dict[str, Any]:
        """Response for step() after the episode ended; repeats the terminal score."""
        obs = self._build_observation(
            action_result = {"error": "Episode is already done."},
            action_success = False,
            action_message = "Episode already finished",
            reward = 0.0,
        )
        if self._final_score is not None:
            score = self._final_score
        elif self._scenario is None:
            score = 0.01
        elif self._scenario.code_context:
            score = self._compute_unified_final_score()
        else:
            score = self._scenario.grade(self._p1_trajectory)
        return {"observation": obs, "reward": 0.01, "done": True,
                "info": {"score": score}}

    def _not_initialized_response(self) -> Dict[str, Any]:
        """Response for step() before reset() has been called."""
        obs = self._build_observation(
            action_result = {"error": "Environment not initialized. Call reset() first."},
            action_success = False,
            action_message = "Not initialized",
            reward = 0.0,
        )
        return {"observation": obs, "reward": 0.01, "done": False, "info": {}}