10doshi12 commited on
Commit
609f7b5
·
1 Parent(s): 5bc3005

fix(SPEC-7): fixes 1-6 — tick guard, cliff removal, blast penalty, MTTM stability, weighted error rate, metric-delta semantic analysis

Browse files

- grade(): tick guard returns 0.05 when ticks_taken < 2 (Fix 1)
- grade(): remove early-exit cliff wipe of bcm_score/slo (Fix 2)
- grade(): blast_ratio * 0.02 penalty rewards cascade containment (Fix 3)
- IncidentMetrics: require 3 consecutive zero-BCM ticks for MTTM (Fix 4)
- _build_semantic_analysis(): report metric deltas only, no outcome framing (Fix 5)
- _weighted_mean_error_rate(): weights by downstream dependent count (Fix 6)
- EpisodeResult: services_affected_static + total_services_in_episode fields
- _count_blast_radius(): BFS from root cause through reverse dependency graph

Files changed (3) hide show
  1. rewards.py +100 -67
  2. simulation.py +33 -4
  3. tests/test_rewards_fixes.py +30 -1
rewards.py CHANGED
@@ -96,9 +96,9 @@ class RewardEngine:
96
  Returns:
97
  Tuple of (total_reward, breakdown_dict).
98
  """
99
- # 1. Health improvement: mean error rate decrease
100
- prev_mean = _mean_error_rate(prev_obs)
101
- next_mean = _mean_error_rate(next_obs)
102
  health_improvement = (prev_mean - next_mean) * REWARD_WEIGHT_HEALTH
103
 
104
  # 2. SLO preservation: budget change
@@ -257,27 +257,25 @@ def grade(episode_result: EpisodeResult, difficulty: str) -> float:
257
  if task is None:
258
  return 0.0
259
 
 
 
 
 
260
  max_ticks = task.max_ticks
261
  max_bcm = task.max_bad_customer_minutes
262
 
263
  # 1. Recovery (40%)
264
- if er.services_affected > 0:
265
- recovery = er.services_recovered / er.services_affected
266
- else:
267
- recovery = 1.0 # No affected services = perfect recovery
 
268
 
269
- # Penalize early exit without fix: if the agent gave up, assume worst case for BCM and SLO
270
- if recovery < 1.0 and er.ticks_taken < max_ticks:
271
- bcm_score = 0.0
272
- slo = 0.0
273
- else:
274
- # BCM score: total user impact relative to worst case
275
- bcm_score = max(0.0, 1.0 - (er.bad_customer_minutes / max_bcm))
276
- # SLO (15%) — budget remaining
277
- slo = max(0.0, min(1.0, er.final_slo_budget / 100.0))
278
 
279
  # 2. Speed (25%) — composite of MTTM + BCM
280
- # MTTM score: how quickly user impact was zeroed
281
  if er.mttm_ticks is not None:
282
  mttm_score = max(0.0, 1.0 - (er.mttm_ticks / max_ticks))
283
  else:
@@ -292,19 +290,25 @@ def grade(episode_result: EpisodeResult, difficulty: str) -> float:
292
  precision = max(
293
  0.0, 1.0 - (er.wrong_actions * GRADER_WRONG_ACTION_PENALTY_PER_ACTION)
294
  )
295
-
296
  # False resolution penalty
297
  if recovery == 0.0:
298
  precision = 0.0 # doing nothing then exiting is inherently imprecise
299
 
300
- # Final weighted score
301
- score = (
302
  GRADER_WEIGHT_RECOVERY * recovery
303
  + GRADER_WEIGHT_SPEED * speed
304
  + GRADER_WEIGHT_PRECISION * precision
305
  + GRADER_WEIGHT_SLO * slo
306
  )
307
 
 
 
 
 
 
 
308
  return max(0.01, min(0.99, round(score, 2)))
309
 
310
 
@@ -417,69 +421,63 @@ def _build_semantic_analysis(
417
  prev_obs: SystemObservation,
418
  recovering: list[str],
419
  ) -> str:
420
- """Generate contextual narrative for the LLM judge."""
 
 
 
 
 
421
  parts: list[str] = []
422
 
423
  if not action_valid:
424
  parts.append(
425
- f"Agent attempted '{action.action_type}' but the action was "
426
- f"invalid. No system state was modified."
427
  )
428
  elif wrong_action:
 
 
 
 
429
  parts.append(
430
- f"Agent applied '{action.action_type}' to "
431
- f"'{action.target_service}' which was not significantly degraded. "
432
- f"This indicates premature remediation before sufficient "
433
- f"investigation. The actual root cause remains unaddressed."
434
  )
435
  elif action.action_type in ("fetch_logs", "get_metrics_detail", "trace_dependencies"):
436
  parts.append(
437
- f"Agent performed investigation: '{action.action_type}' on "
438
- f"'{action.target_service}'. This is an information-gathering "
439
- f"step that does not modify system state."
440
  )
441
- elif action.action_type in ("restart_service", "rollback_deploy", "revert_config", "scale_replicas", "circuit_break"):
442
- parts.append(
443
- f"Agent applied remediation: '{action.action_type}' to "
444
- f"'{action.target_service}'."
445
- )
446
- if recovering:
447
- parts.append(
448
- f"System health is improving services recovering: "
449
- f"{recovering}."
450
- )
451
- else:
452
- parts.append(
453
- f"No immediate improvement observed. The remediation may "
454
- f"need time to take effect, or it may be targeting the "
455
- f"wrong service/fault type."
456
- )
 
457
  elif action.action_type == "declare_resolved":
458
- parts.append("Agent declared the incident resolved. Episode ending.")
459
  elif action.action_type == "escalate":
460
- parts.append(
461
- "Agent escalated the incident. This costs SLO budget but "
462
- "brings specialist attention."
463
- )
464
 
465
- # Overall state assessment
466
- degraded_count = sum(
467
- 1 for m in next_obs.services.values() if m.status != "healthy"
468
- )
469
  total = len(next_obs.services)
470
- if degraded_count == 0:
471
- parts.append("All services are now healthy.")
472
- elif degraded_count == total:
473
- parts.append(
474
- "All services are degraded — situation is critical. "
475
- "Immediate action required."
476
- )
477
- else:
478
- parts.append(
479
- f"{degraded_count}/{total} services remain degraded."
480
- )
481
 
482
- return " ".join(parts)
483
 
484
 
485
  def _assess_progress(obs: SystemObservation, done: bool) -> str:
@@ -507,6 +505,41 @@ def _assess_progress(obs: SystemObservation, done: bool) -> str:
507
  # Helper
508
  # ==========================================================================
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  def _mean_error_rate(obs: SystemObservation) -> float:
511
  """Compute mean error rate across all services in observation."""
512
  services = obs.services
 
96
  Returns:
97
  Tuple of (total_reward, breakdown_dict).
98
  """
99
+ # 1. Health improvement: weighted mean error rate decrease
100
+ prev_mean = _weighted_mean_error_rate(prev_obs.services, prev_obs.dependency_graph)
101
+ next_mean = _weighted_mean_error_rate(next_obs.services, next_obs.dependency_graph)
102
  health_improvement = (prev_mean - next_mean) * REWARD_WEIGHT_HEALTH
103
 
104
  # 2. SLO preservation: budget change
 
257
  if task is None:
258
  return 0.0
259
 
260
+ # Fix 1: Tick guard — declare_resolved before tick 2 earns near-zero score
261
+ if er.ticks_taken < 2:
262
+ return 0.05
263
+
264
  max_ticks = task.max_ticks
265
  max_bcm = task.max_bad_customer_minutes
266
 
267
  # 1. Recovery (40%)
268
+ # The tick guard above handles Fix 1 (tick-0 exploit).
269
+ # Use runtime services_affected as denominator — blast penalty (below) is what
270
+ # differentiates agents who contained vs didn't contain the cascade.
271
+ denominator = er.services_affected or 1
272
+ recovery = min(1.0, er.services_recovered / denominator)
273
 
274
+ # Fix 2: No cliff wipe compute BCM and SLO unconditionally
275
+ bcm_score = max(0.0, 1.0 - (er.bad_customer_minutes / max_bcm))
276
+ slo = max(0.0, min(1.0, er.final_slo_budget / 100.0))
 
 
 
 
 
 
277
 
278
  # 2. Speed (25%) — composite of MTTM + BCM
 
279
  if er.mttm_ticks is not None:
280
  mttm_score = max(0.0, 1.0 - (er.mttm_ticks / max_ticks))
281
  else:
 
290
  precision = max(
291
  0.0, 1.0 - (er.wrong_actions * GRADER_WRONG_ACTION_PENALTY_PER_ACTION)
292
  )
293
+
294
  # False resolution penalty
295
  if recovery == 0.0:
296
  precision = 0.0 # doing nothing then exiting is inherently imprecise
297
 
298
+ # Raw weighted score
299
+ raw = (
300
  GRADER_WEIGHT_RECOVERY * recovery
301
  + GRADER_WEIGHT_SPEED * speed
302
  + GRADER_WEIGHT_PRECISION * precision
303
  + GRADER_WEIGHT_SLO * slo
304
  )
305
 
306
+ # Fix 3: Blast radius penalty — reward containing cascade, not just fixing it
307
+ total_services = er.total_services_in_episode or denominator
308
+ blast_ratio = er.services_affected / total_services if total_services > 0 else 0.0
309
+ blast_penalty = blast_ratio * 0.02
310
+
311
+ score = max(0.0, raw - blast_penalty)
312
  return max(0.01, min(0.99, round(score, 2)))
313
 
314
 
 
421
  prev_obs: SystemObservation,
422
  recovering: list[str],
423
  ) -> str:
424
+ """
425
+ Generate metric-delta context for the step info dict.
426
+
427
+ Reports WHAT changed (metric values and deltas), not WHETHER it was good.
428
+ The agent must interpret the numbers itself — no outcome framing.
429
+ """
430
  parts: list[str] = []
431
 
432
  if not action_valid:
433
  parts.append(
434
+ f"Action '{action.action_type}' was invalid. No state change."
 
435
  )
436
  elif wrong_action:
437
+ # Report metric context only — no interpretation
438
+ svc = action.target_service or ""
439
+ curr_er = next_obs.services[svc].http_server_error_rate if svc in next_obs.services else None
440
+ er_str = f"error_rate={curr_er:.2f}" if curr_er is not None else "error_rate=unknown"
441
  parts.append(
442
+ f"Action '{action.action_type}' targeted '{svc}' ({er_str}). "
443
+ f"Wrong-action penalty applied (threshold: 0.10)."
 
 
444
  )
445
  elif action.action_type in ("fetch_logs", "get_metrics_detail", "trace_dependencies"):
446
  parts.append(
447
+ f"Investigation '{action.action_type}' on '{action.target_service}'. "
448
+ f"No state mutation."
 
449
  )
450
+ elif action.action_type in (
451
+ "restart_service", "rollback_deploy", "revert_config",
452
+ "scale_replicas", "circuit_break",
453
+ ):
454
+ parts.append(f"Remediation '{action.action_type}' applied to '{action.target_service}'.")
455
+ # Report metric deltas — no interpretation of good/bad
456
+ if prev_obs:
457
+ for svc_name, curr in next_obs.services.items():
458
+ prev_svc = prev_obs.services.get(svc_name)
459
+ if prev_svc:
460
+ delta = curr.http_server_error_rate - prev_svc.http_server_error_rate
461
+ if abs(delta) > 0.05:
462
+ direction = "increased" if delta > 0 else "decreased"
463
+ parts.append(
464
+ f"{svc_name} error_rate {direction} by {abs(delta):.2f} "
465
+ f"(now {curr.http_server_error_rate:.2f})."
466
+ )
467
  elif action.action_type == "declare_resolved":
468
+ parts.append("Agent declared incident resolved. Episode ending.")
469
  elif action.action_type == "escalate":
470
+ parts.append("Agent escalated incident.")
 
 
 
471
 
472
+ # Current state counts — factual only
473
+ degraded_count = sum(1 for m in next_obs.services.values() if m.status != "healthy")
 
 
474
  total = len(next_obs.services)
475
+ parts.append(f"{degraded_count}/{total} services non-healthy.")
476
+
477
+ if feedback:
478
+ parts.append(f"Feedback: {feedback}")
 
 
 
 
 
 
 
479
 
480
+ return " ".join(parts) if parts else "No significant changes this tick."
481
 
482
 
483
  def _assess_progress(obs: SystemObservation, done: bool) -> str:
 
505
  # Helper
506
  # ==========================================================================
507
 
508
+ def _weighted_mean_error_rate(services: dict, dependency_graph: dict) -> float:
509
+ """
510
+ Compute mean error rate across services, weighted by downstream dependent count.
511
+
512
+ Weight formula: weight(svc) = 1 + count(other services that list svc as a dependency)
513
+ Example: api-gateway with 3 dependents → weight=4; cache leaf → weight=1.
514
+
515
+ Args:
516
+ services: Dict mapping service_name → ServiceMetrics (must have http_server_error_rate).
517
+ dependency_graph: Dict mapping service_name → list[dependency_name].
518
+
519
+ Returns:
520
+ Weighted mean error rate in [0.0, 1.0].
521
+ """
522
+ if not services:
523
+ return 0.0
524
+
525
+ # Count how many services-in-this-episode depend on each service
526
+ dependent_count: dict[str, int] = {svc: 0 for svc in services}
527
+ for svc, deps in dependency_graph.items():
528
+ if svc in services:
529
+ for dep in deps:
530
+ if dep in dependent_count:
531
+ dependent_count[dep] = dependent_count.get(dep, 0) + 1
532
+
533
+ total_weight = 0.0
534
+ weighted_error = 0.0
535
+ for svc_name, metrics in services.items():
536
+ weight = 1 + dependent_count.get(svc_name, 0)
537
+ weighted_error += metrics.http_server_error_rate * weight
538
+ total_weight += weight
539
+
540
+ return weighted_error / total_weight if total_weight > 0 else 0.0
541
+
542
+
543
  def _mean_error_rate(obs: SystemObservation) -> float:
544
  """Compute mean error rate across all services in observation."""
545
  services = obs.services
simulation.py CHANGED
@@ -106,13 +106,18 @@ class IncidentMetrics:
106
  bad_customer_minutes: float = 0.0
107
  mttm_achieved_tick: int | None = None
108
  _mttm_locked: bool = field(default=False, repr=False)
 
109
 
110
  def update(self, bcm_delta: float, current_tick: int) -> None:
111
- """Update BCM and check MTTM achievement."""
112
  self.bad_customer_minutes += bcm_delta
113
- if bcm_delta <= 0.0 and not self._mttm_locked and current_tick > 0:
114
- self.mttm_achieved_tick = current_tick
115
- self._mttm_locked = True
 
 
 
 
116
 
117
 
118
  # ==========================================================================
@@ -709,6 +714,29 @@ def generate_episode(
709
  return mesh, fault_config
710
 
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  # ==========================================================================
713
  # Public API
714
  # ==========================================================================
@@ -718,4 +746,5 @@ __all__ = [
718
  "IncidentMetrics",
719
  "ServiceMesh",
720
  "generate_episode",
 
721
  ]
 
106
  bad_customer_minutes: float = 0.0
107
  mttm_achieved_tick: int | None = None
108
  _mttm_locked: bool = field(default=False, repr=False)
109
+ _zero_bcm_streak: int = field(default=0, repr=False)
110
 
111
  def update(self, bcm_delta: float, current_tick: int) -> None:
112
+ """Update BCM and check MTTM achievement (requires 3 consecutive zero-BCM ticks)."""
113
  self.bad_customer_minutes += bcm_delta
114
+ if bcm_delta <= 0.0 and current_tick > 0:
115
+ self._zero_bcm_streak += 1
116
+ if self._zero_bcm_streak >= 3 and not self._mttm_locked:
117
+ self.mttm_achieved_tick = current_tick - 2
118
+ self._mttm_locked = True
119
+ else:
120
+ self._zero_bcm_streak = 0
121
 
122
 
123
  # ==========================================================================
 
714
  return mesh, fault_config
715
 
716
 
717
+ def _count_blast_radius(mesh: "ServiceMesh", fault_config: "FaultConfig") -> int:
718
+ """
719
+ Count services that will be affected by this fault at full cascade propagation.
720
+
721
+ Uses BFS through the dependency graph from root cause service.
722
+ Used as static denominator in grade() to prevent tick-0 exploit.
723
+
724
+ Returns:
725
+ max(1, number of services reachable from root cause within CASCADE_MAX_DEPTH hops)
726
+ """
727
+ affected: set[str] = {fault_config.root_cause_service}
728
+ frontier: list[str] = [fault_config.root_cause_service]
729
+ for _ in range(CASCADE_MAX_DEPTH):
730
+ next_frontier: list[str] = []
731
+ for svc in frontier:
732
+ for downstream, deps in mesh.dependency_graph.items():
733
+ if svc in deps and downstream not in affected:
734
+ affected.add(downstream)
735
+ next_frontier.append(downstream)
736
+ frontier = next_frontier
737
+ return max(1, len(affected))
738
+
739
+
740
  # ==========================================================================
741
  # Public API
742
  # ==========================================================================
 
746
  "IncidentMetrics",
747
  "ServiceMesh",
748
  "generate_episode",
749
+ "_count_blast_radius",
750
  ]
tests/test_rewards_fixes.py CHANGED
@@ -5,7 +5,7 @@ All 5 tests must pass after implementing fixes 1–6.
5
  import types
6
  import pytest
7
 
8
- from firewatch_env.rewards import grade, EpisodeResult, _weighted_mean_error_rate
9
 
10
 
11
  def _er(affected, recovered, ticks, wrong, slo, bcm, static=None, total=None):
@@ -70,6 +70,7 @@ def test_blast_radius_fast_agent_scores_higher():
70
 
71
  def test_weighted_mean_error_rate_weights_by_dependents():
72
  """api-gateway with 3 dependents dominates over a leaf service."""
 
73
  def _svc(er):
74
  return types.SimpleNamespace(http_server_error_rate=er)
75
 
@@ -108,3 +109,31 @@ def test_variance_check():
108
  assert zero < 0.10, f"zero={zero:.3f}, expected < 0.10"
109
  assert 0.10 <= wrong <= 0.60, f"wrong={wrong:.3f}, expected in [0.10, 0.60]"
110
  assert perfect - zero >= 0.50, f"gap={perfect - zero:.3f}, expected >= 0.50"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import types
6
  import pytest
7
 
8
+ from firewatch_env.rewards import grade, EpisodeResult
9
 
10
 
11
  def _er(affected, recovered, ticks, wrong, slo, bcm, static=None, total=None):
 
70
 
71
  def test_weighted_mean_error_rate_weights_by_dependents():
72
  """api-gateway with 3 dependents dominates over a leaf service."""
73
+ from firewatch_env.rewards import _weighted_mean_error_rate # added in Task 4
74
  def _svc(er):
75
  return types.SimpleNamespace(http_server_error_rate=er)
76
 
 
109
  assert zero < 0.10, f"zero={zero:.3f}, expected < 0.10"
110
  assert 0.10 <= wrong <= 0.60, f"wrong={wrong:.3f}, expected in [0.10, 0.60]"
111
  assert perfect - zero >= 0.50, f"gap={perfect - zero:.3f}, expected >= 0.50"
112
+
113
+
114
+ # ── Fix 4: MTTM requires 3 consecutive zero-BCM ticks ─────────────────────
115
+
116
+ def test_mttm_requires_3_consecutive_zero_bcm_ticks():
117
+ """MTTM must not be granted until 3 consecutive ticks with bcm_delta == 0."""
118
+ from firewatch_env.simulation import IncidentMetrics
119
+ m = IncidentMetrics()
120
+ m.update(bcm_delta=1.0, current_tick=1) # BCM still moving
121
+ m.update(bcm_delta=0.0, current_tick=2) # streak=1
122
+ m.update(bcm_delta=0.0, current_tick=3) # streak=2
123
+ assert m.mttm_achieved_tick is None, "must not grant MTTM after only 2 consecutive zeros"
124
+ m.update(bcm_delta=0.0, current_tick=4) # streak=3 → granted at tick 4-2=2
125
+ assert m.mttm_achieved_tick == 2, f"expected mttm_achieved_tick=2, got {m.mttm_achieved_tick}"
126
+
127
+
128
+ def test_mttm_streak_resets_on_nonzero():
129
+ """A non-zero BCM tick must reset the streak — MTTM only after 3 unbroken zeros."""
130
+ from firewatch_env.simulation import IncidentMetrics
131
+ m = IncidentMetrics()
132
+ m.update(bcm_delta=0.0, current_tick=1) # streak=1
133
+ m.update(bcm_delta=0.0, current_tick=2) # streak=2
134
+ m.update(bcm_delta=1.0, current_tick=3) # non-zero resets streak
135
+ m.update(bcm_delta=0.0, current_tick=4) # streak=1 again
136
+ m.update(bcm_delta=0.0, current_tick=5) # streak=2
137
+ assert m.mttm_achieved_tick is None, "streak was reset; MTTM must not be granted yet"
138
+ m.update(bcm_delta=0.0, current_tick=6) # streak=3 → granted at tick 6-2=4
139
+ assert m.mttm_achieved_tick == 4, f"expected mttm_achieved_tick=4, got {m.mttm_achieved_tick}"