Spaces:
Running
Running
fix(SPEC-7): fixes 1-6 — tick guard, cliff removal, blast penalty, MTTM stability, weighted error rate, metric-delta semantic analysis
Browse files
- grade(): tick guard returns 0.05 when ticks_taken < 2 (Fix 1)
- grade(): remove early-exit cliff wipe of bcm_score/slo (Fix 2)
- grade(): blast_ratio * 0.02 penalty rewards cascade containment (Fix 3)
- IncidentMetrics: require 3 consecutive zero-BCM ticks for MTTM (Fix 4)
- _build_semantic_analysis(): report metric deltas only, no outcome framing (Fix 5)
- _weighted_mean_error_rate(): weights by downstream dependent count (Fix 6)
- EpisodeResult: services_affected_static + total_services_in_episode fields
- _count_blast_radius(): BFS from root cause through reverse dependency graph
- rewards.py +100 -67
- simulation.py +33 -4
- tests/test_rewards_fixes.py +30 -1
rewards.py
CHANGED
|
@@ -96,9 +96,9 @@ class RewardEngine:
|
|
| 96 |
Returns:
|
| 97 |
Tuple of (total_reward, breakdown_dict).
|
| 98 |
"""
|
| 99 |
-
# 1. Health improvement: mean error rate decrease
|
| 100 |
-
prev_mean =
|
| 101 |
-
next_mean =
|
| 102 |
health_improvement = (prev_mean - next_mean) * REWARD_WEIGHT_HEALTH
|
| 103 |
|
| 104 |
# 2. SLO preservation: budget change
|
|
@@ -257,27 +257,25 @@ def grade(episode_result: EpisodeResult, difficulty: str) -> float:
|
|
| 257 |
if task is None:
|
| 258 |
return 0.0
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
max_ticks = task.max_ticks
|
| 261 |
max_bcm = task.max_bad_customer_minutes
|
| 262 |
|
| 263 |
# 1. Recovery (40%)
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
|
|
|
| 268 |
|
| 269 |
-
#
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
slo = 0.0
|
| 273 |
-
else:
|
| 274 |
-
# BCM score: total user impact relative to worst case
|
| 275 |
-
bcm_score = max(0.0, 1.0 - (er.bad_customer_minutes / max_bcm))
|
| 276 |
-
# SLO (15%) — budget remaining
|
| 277 |
-
slo = max(0.0, min(1.0, er.final_slo_budget / 100.0))
|
| 278 |
|
| 279 |
# 2. Speed (25%) — composite of MTTM + BCM
|
| 280 |
-
# MTTM score: how quickly user impact was zeroed
|
| 281 |
if er.mttm_ticks is not None:
|
| 282 |
mttm_score = max(0.0, 1.0 - (er.mttm_ticks / max_ticks))
|
| 283 |
else:
|
|
@@ -292,19 +290,25 @@ def grade(episode_result: EpisodeResult, difficulty: str) -> float:
|
|
| 292 |
precision = max(
|
| 293 |
0.0, 1.0 - (er.wrong_actions * GRADER_WRONG_ACTION_PENALTY_PER_ACTION)
|
| 294 |
)
|
| 295 |
-
|
| 296 |
# False resolution penalty
|
| 297 |
if recovery == 0.0:
|
| 298 |
precision = 0.0 # doing nothing then exiting is inherently imprecise
|
| 299 |
|
| 300 |
-
#
|
| 301 |
-
|
| 302 |
GRADER_WEIGHT_RECOVERY * recovery
|
| 303 |
+ GRADER_WEIGHT_SPEED * speed
|
| 304 |
+ GRADER_WEIGHT_PRECISION * precision
|
| 305 |
+ GRADER_WEIGHT_SLO * slo
|
| 306 |
)
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
return max(0.01, min(0.99, round(score, 2)))
|
| 309 |
|
| 310 |
|
|
@@ -417,69 +421,63 @@ def _build_semantic_analysis(
|
|
| 417 |
prev_obs: SystemObservation,
|
| 418 |
recovering: list[str],
|
| 419 |
) -> str:
|
| 420 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
parts: list[str] = []
|
| 422 |
|
| 423 |
if not action_valid:
|
| 424 |
parts.append(
|
| 425 |
-
f"
|
| 426 |
-
f"invalid. No system state was modified."
|
| 427 |
)
|
| 428 |
elif wrong_action:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
parts.append(
|
| 430 |
-
f"
|
| 431 |
-
f"
|
| 432 |
-
f"This indicates premature remediation before sufficient "
|
| 433 |
-
f"investigation. The actual root cause remains unaddressed."
|
| 434 |
)
|
| 435 |
elif action.action_type in ("fetch_logs", "get_metrics_detail", "trace_dependencies"):
|
| 436 |
parts.append(
|
| 437 |
-
f"
|
| 438 |
-
f"
|
| 439 |
-
f"step that does not modify system state."
|
| 440 |
)
|
| 441 |
-
elif action.action_type in (
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
|
|
|
| 457 |
elif action.action_type == "declare_resolved":
|
| 458 |
-
parts.append("Agent declared
|
| 459 |
elif action.action_type == "escalate":
|
| 460 |
-
parts.append(
|
| 461 |
-
"Agent escalated the incident. This costs SLO budget but "
|
| 462 |
-
"brings specialist attention."
|
| 463 |
-
)
|
| 464 |
|
| 465 |
-
#
|
| 466 |
-
degraded_count = sum(
|
| 467 |
-
1 for m in next_obs.services.values() if m.status != "healthy"
|
| 468 |
-
)
|
| 469 |
total = len(next_obs.services)
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
parts.append(
|
| 474 |
-
"All services are degraded — situation is critical. "
|
| 475 |
-
"Immediate action required."
|
| 476 |
-
)
|
| 477 |
-
else:
|
| 478 |
-
parts.append(
|
| 479 |
-
f"{degraded_count}/{total} services remain degraded."
|
| 480 |
-
)
|
| 481 |
|
| 482 |
-
return " ".join(parts)
|
| 483 |
|
| 484 |
|
| 485 |
def _assess_progress(obs: SystemObservation, done: bool) -> str:
|
|
@@ -507,6 +505,41 @@ def _assess_progress(obs: SystemObservation, done: bool) -> str:
|
|
| 507 |
# Helper
|
| 508 |
# ==========================================================================
|
| 509 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
def _mean_error_rate(obs: SystemObservation) -> float:
|
| 511 |
"""Compute mean error rate across all services in observation."""
|
| 512 |
services = obs.services
|
|
|
|
| 96 |
Returns:
|
| 97 |
Tuple of (total_reward, breakdown_dict).
|
| 98 |
"""
|
| 99 |
+
# 1. Health improvement: weighted mean error rate decrease
|
| 100 |
+
prev_mean = _weighted_mean_error_rate(prev_obs.services, prev_obs.dependency_graph)
|
| 101 |
+
next_mean = _weighted_mean_error_rate(next_obs.services, next_obs.dependency_graph)
|
| 102 |
health_improvement = (prev_mean - next_mean) * REWARD_WEIGHT_HEALTH
|
| 103 |
|
| 104 |
# 2. SLO preservation: budget change
|
|
|
|
| 257 |
if task is None:
|
| 258 |
return 0.0
|
| 259 |
|
| 260 |
+
# Fix 1: Tick guard — declare_resolved before tick 2 earns near-zero score
|
| 261 |
+
if er.ticks_taken < 2:
|
| 262 |
+
return 0.05
|
| 263 |
+
|
| 264 |
max_ticks = task.max_ticks
|
| 265 |
max_bcm = task.max_bad_customer_minutes
|
| 266 |
|
| 267 |
# 1. Recovery (40%)
|
| 268 |
+
# The tick guard above handles Fix 1 (tick-0 exploit).
|
| 269 |
+
# Use runtime services_affected as denominator — blast penalty (below) is what
|
| 270 |
+
# differentiates agents who contained vs didn't contain the cascade.
|
| 271 |
+
denominator = er.services_affected or 1
|
| 272 |
+
recovery = min(1.0, er.services_recovered / denominator)
|
| 273 |
|
| 274 |
+
# Fix 2: No cliff wipe — compute BCM and SLO unconditionally
|
| 275 |
+
bcm_score = max(0.0, 1.0 - (er.bad_customer_minutes / max_bcm))
|
| 276 |
+
slo = max(0.0, min(1.0, er.final_slo_budget / 100.0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
# 2. Speed (25%) — composite of MTTM + BCM
|
|
|
|
| 279 |
if er.mttm_ticks is not None:
|
| 280 |
mttm_score = max(0.0, 1.0 - (er.mttm_ticks / max_ticks))
|
| 281 |
else:
|
|
|
|
| 290 |
precision = max(
|
| 291 |
0.0, 1.0 - (er.wrong_actions * GRADER_WRONG_ACTION_PENALTY_PER_ACTION)
|
| 292 |
)
|
| 293 |
+
|
| 294 |
# False resolution penalty
|
| 295 |
if recovery == 0.0:
|
| 296 |
precision = 0.0 # doing nothing then exiting is inherently imprecise
|
| 297 |
|
| 298 |
+
# Raw weighted score
|
| 299 |
+
raw = (
|
| 300 |
GRADER_WEIGHT_RECOVERY * recovery
|
| 301 |
+ GRADER_WEIGHT_SPEED * speed
|
| 302 |
+ GRADER_WEIGHT_PRECISION * precision
|
| 303 |
+ GRADER_WEIGHT_SLO * slo
|
| 304 |
)
|
| 305 |
|
| 306 |
+
# Fix 3: Blast radius penalty — reward containing cascade, not just fixing it
|
| 307 |
+
total_services = er.total_services_in_episode or denominator
|
| 308 |
+
blast_ratio = er.services_affected / total_services if total_services > 0 else 0.0
|
| 309 |
+
blast_penalty = blast_ratio * 0.02
|
| 310 |
+
|
| 311 |
+
score = max(0.0, raw - blast_penalty)
|
| 312 |
return max(0.01, min(0.99, round(score, 2)))
|
| 313 |
|
| 314 |
|
|
|
|
| 421 |
prev_obs: SystemObservation,
|
| 422 |
recovering: list[str],
|
| 423 |
) -> str:
|
| 424 |
+
"""
|
| 425 |
+
Generate metric-delta context for the step info dict.
|
| 426 |
+
|
| 427 |
+
Reports WHAT changed (metric values and deltas), not WHETHER it was good.
|
| 428 |
+
The agent must interpret the numbers itself — no outcome framing.
|
| 429 |
+
"""
|
| 430 |
parts: list[str] = []
|
| 431 |
|
| 432 |
if not action_valid:
|
| 433 |
parts.append(
|
| 434 |
+
f"Action '{action.action_type}' was invalid. No state change."
|
|
|
|
| 435 |
)
|
| 436 |
elif wrong_action:
|
| 437 |
+
# Report metric context only — no interpretation
|
| 438 |
+
svc = action.target_service or ""
|
| 439 |
+
curr_er = next_obs.services[svc].http_server_error_rate if svc in next_obs.services else None
|
| 440 |
+
er_str = f"error_rate={curr_er:.2f}" if curr_er is not None else "error_rate=unknown"
|
| 441 |
parts.append(
|
| 442 |
+
f"Action '{action.action_type}' targeted '{svc}' ({er_str}). "
|
| 443 |
+
f"Wrong-action penalty applied (threshold: 0.10)."
|
|
|
|
|
|
|
| 444 |
)
|
| 445 |
elif action.action_type in ("fetch_logs", "get_metrics_detail", "trace_dependencies"):
|
| 446 |
parts.append(
|
| 447 |
+
f"Investigation '{action.action_type}' on '{action.target_service}'. "
|
| 448 |
+
f"No state mutation."
|
|
|
|
| 449 |
)
|
| 450 |
+
elif action.action_type in (
|
| 451 |
+
"restart_service", "rollback_deploy", "revert_config",
|
| 452 |
+
"scale_replicas", "circuit_break",
|
| 453 |
+
):
|
| 454 |
+
parts.append(f"Remediation '{action.action_type}' applied to '{action.target_service}'.")
|
| 455 |
+
# Report metric deltas — no interpretation of good/bad
|
| 456 |
+
if prev_obs:
|
| 457 |
+
for svc_name, curr in next_obs.services.items():
|
| 458 |
+
prev_svc = prev_obs.services.get(svc_name)
|
| 459 |
+
if prev_svc:
|
| 460 |
+
delta = curr.http_server_error_rate - prev_svc.http_server_error_rate
|
| 461 |
+
if abs(delta) > 0.05:
|
| 462 |
+
direction = "increased" if delta > 0 else "decreased"
|
| 463 |
+
parts.append(
|
| 464 |
+
f"{svc_name} error_rate {direction} by {abs(delta):.2f} "
|
| 465 |
+
f"(now {curr.http_server_error_rate:.2f})."
|
| 466 |
+
)
|
| 467 |
elif action.action_type == "declare_resolved":
|
| 468 |
+
parts.append("Agent declared incident resolved. Episode ending.")
|
| 469 |
elif action.action_type == "escalate":
|
| 470 |
+
parts.append("Agent escalated incident.")
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
+
# Current state counts — factual only
|
| 473 |
+
degraded_count = sum(1 for m in next_obs.services.values() if m.status != "healthy")
|
|
|
|
|
|
|
| 474 |
total = len(next_obs.services)
|
| 475 |
+
parts.append(f"{degraded_count}/{total} services non-healthy.")
|
| 476 |
+
|
| 477 |
+
if feedback:
|
| 478 |
+
parts.append(f"Feedback: {feedback}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
|
| 480 |
+
return " ".join(parts) if parts else "No significant changes this tick."
|
| 481 |
|
| 482 |
|
| 483 |
def _assess_progress(obs: SystemObservation, done: bool) -> str:
|
|
|
|
| 505 |
# Helper
|
| 506 |
# ==========================================================================
|
| 507 |
|
| 508 |
+
def _weighted_mean_error_rate(services: dict, dependency_graph: dict) -> float:
|
| 509 |
+
"""
|
| 510 |
+
Compute mean error rate across services, weighted by downstream dependent count.
|
| 511 |
+
|
| 512 |
+
Weight formula: weight(svc) = 1 + count(other services that list svc as a dependency)
|
| 513 |
+
Example: api-gateway with 3 dependents → weight=4; cache leaf → weight=1.
|
| 514 |
+
|
| 515 |
+
Args:
|
| 516 |
+
services: Dict mapping service_name → ServiceMetrics (must have http_server_error_rate).
|
| 517 |
+
dependency_graph: Dict mapping service_name → list[dependency_name].
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
Weighted mean error rate in [0.0, 1.0].
|
| 521 |
+
"""
|
| 522 |
+
if not services:
|
| 523 |
+
return 0.0
|
| 524 |
+
|
| 525 |
+
# Count how many services-in-this-episode depend on each service
|
| 526 |
+
dependent_count: dict[str, int] = {svc: 0 for svc in services}
|
| 527 |
+
for svc, deps in dependency_graph.items():
|
| 528 |
+
if svc in services:
|
| 529 |
+
for dep in deps:
|
| 530 |
+
if dep in dependent_count:
|
| 531 |
+
dependent_count[dep] = dependent_count.get(dep, 0) + 1
|
| 532 |
+
|
| 533 |
+
total_weight = 0.0
|
| 534 |
+
weighted_error = 0.0
|
| 535 |
+
for svc_name, metrics in services.items():
|
| 536 |
+
weight = 1 + dependent_count.get(svc_name, 0)
|
| 537 |
+
weighted_error += metrics.http_server_error_rate * weight
|
| 538 |
+
total_weight += weight
|
| 539 |
+
|
| 540 |
+
return weighted_error / total_weight if total_weight > 0 else 0.0
|
| 541 |
+
|
| 542 |
+
|
| 543 |
def _mean_error_rate(obs: SystemObservation) -> float:
|
| 544 |
"""Compute mean error rate across all services in observation."""
|
| 545 |
services = obs.services
|
simulation.py
CHANGED
|
@@ -106,13 +106,18 @@ class IncidentMetrics:
|
|
| 106 |
bad_customer_minutes: float = 0.0
|
| 107 |
mttm_achieved_tick: int | None = None
|
| 108 |
_mttm_locked: bool = field(default=False, repr=False)
|
|
|
|
| 109 |
|
| 110 |
def update(self, bcm_delta: float, current_tick: int) -> None:
|
| 111 |
-
"""Update BCM and check MTTM achievement."""
|
| 112 |
self.bad_customer_minutes += bcm_delta
|
| 113 |
-
if bcm_delta <= 0.0 and
|
| 114 |
-
self.
|
| 115 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
# ==========================================================================
|
|
@@ -709,6 +714,29 @@ def generate_episode(
|
|
| 709 |
return mesh, fault_config
|
| 710 |
|
| 711 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
# ==========================================================================
|
| 713 |
# Public API
|
| 714 |
# ==========================================================================
|
|
@@ -718,4 +746,5 @@ __all__ = [
|
|
| 718 |
"IncidentMetrics",
|
| 719 |
"ServiceMesh",
|
| 720 |
"generate_episode",
|
|
|
|
| 721 |
]
|
|
|
|
| 106 |
bad_customer_minutes: float = 0.0
|
| 107 |
mttm_achieved_tick: int | None = None
|
| 108 |
_mttm_locked: bool = field(default=False, repr=False)
|
| 109 |
+
_zero_bcm_streak: int = field(default=0, repr=False)
|
| 110 |
|
| 111 |
def update(
    self,
    bcm_delta: float,
    current_tick: int,
    *,
    required_zero_ticks: int = 3,
) -> None:
    """Update BCM and check MTTM achievement (requires 3 consecutive zero-BCM ticks).

    ``bcm_delta`` is always added to ``bad_customer_minutes``. A call counts
    towards the zero-BCM streak when ``bcm_delta <= 0.0`` and
    ``current_tick > 0``; any other call resets the streak to zero. Once the
    streak reaches ``required_zero_ticks`` for the first time, MTTM is locked
    and never changed again.

    The recorded ``mttm_achieved_tick`` is back-dated by
    ``required_zero_ticks - 1`` ticks, i.e. to the tick that opened the
    streak. NOTE(review): this assumes update() is called exactly once per
    consecutive tick — confirm against the caller.

    Args:
        bcm_delta: Bad-customer-minutes accrued since the previous call.
        current_tick: Tick number this update belongs to.
        required_zero_ticks: Length of the unbroken zero-BCM streak needed
            before MTTM is granted. Defaults to 3, the original behaviour.

    Raises:
        ValueError: If ``required_zero_ticks`` is smaller than 1.
    """
    # Validate before mutating any state so a bad call leaves metrics untouched.
    if required_zero_ticks < 1:
        raise ValueError("required_zero_ticks must be >= 1")

    self.bad_customer_minutes += bcm_delta

    if bcm_delta <= 0.0 and current_tick > 0:
        self._zero_bcm_streak += 1
        if self._zero_bcm_streak >= required_zero_ticks and not self._mttm_locked:
            # Threshold and back-dating offset come from the same number, so
            # they cannot drift apart the way the literals 3 and 2 could.
            self.mttm_achieved_tick = current_tick - (required_zero_ticks - 1)
            self._mttm_locked = True
    else:
        self._zero_bcm_streak = 0
|
| 121 |
|
| 122 |
|
| 123 |
# ==========================================================================
|
|
|
|
| 714 |
return mesh, fault_config
|
| 715 |
|
| 716 |
|
| 717 |
+
def _count_blast_radius(mesh: "ServiceMesh", fault_config: "FaultConfig") -> int:
    """
    Count services that will be affected by this fault at full cascade propagation.

    Walks breadth-first from ``fault_config.root_cause_service`` along reversed
    edges of ``mesh.dependency_graph`` (from a service to the services that
    list it as a dependency), for at most ``CASCADE_MAX_DEPTH`` hops.

    NOTE(review): the grade() comment in rewards.py says it divides by the
    runtime ``services_affected`` count, so this static figure is presumably
    what feeds ``EpisodeResult.services_affected_static`` — confirm at the
    call site.

    Args:
        mesh: Object exposing ``dependency_graph`` as a mapping of
            service_name → iterable of dependency names.
        fault_config: Object exposing ``root_cause_service``.

    Returns:
        max(1, number of services reachable from root cause within CASCADE_MAX_DEPTH hops)
        The root cause itself is always included in the count.
    """
    # Invert the graph once (dependency → its dependents) instead of scanning
    # every edge again for each frontier node at each depth.
    dependents: dict[str, list[str]] = {}
    for downstream, deps in mesh.dependency_graph.items():
        for dep in deps:
            dependents.setdefault(dep, []).append(downstream)

    root = fault_config.root_cause_service
    affected: set[str] = {root}
    frontier: list[str] = [root]
    for _ in range(CASCADE_MAX_DEPTH):
        if not frontier:
            # Cascade exhausted before the depth limit; nothing left to expand.
            break
        next_frontier: list[str] = []
        for svc in frontier:
            for downstream in dependents.get(svc, ()):
                if downstream not in affected:
                    affected.add(downstream)
                    next_frontier.append(downstream)
        frontier = next_frontier
    return max(1, len(affected))
|
| 738 |
+
|
| 739 |
+
|
| 740 |
# ==========================================================================
|
| 741 |
# Public API
|
| 742 |
# ==========================================================================
|
|
|
|
| 746 |
"IncidentMetrics",
|
| 747 |
"ServiceMesh",
|
| 748 |
"generate_episode",
|
| 749 |
+
"_count_blast_radius",
|
| 750 |
]
|
tests/test_rewards_fixes.py
CHANGED
|
@@ -5,7 +5,7 @@ All 5 tests must pass after implementing fixes 1–6.
|
|
| 5 |
import types
|
| 6 |
import pytest
|
| 7 |
|
| 8 |
-
from firewatch_env.rewards import grade, EpisodeResult
|
| 9 |
|
| 10 |
|
| 11 |
def _er(affected, recovered, ticks, wrong, slo, bcm, static=None, total=None):
|
|
@@ -70,6 +70,7 @@ def test_blast_radius_fast_agent_scores_higher():
|
|
| 70 |
|
| 71 |
def test_weighted_mean_error_rate_weights_by_dependents():
|
| 72 |
"""api-gateway with 3 dependents dominates over a leaf service."""
|
|
|
|
| 73 |
def _svc(er):
|
| 74 |
return types.SimpleNamespace(http_server_error_rate=er)
|
| 75 |
|
|
@@ -108,3 +109,31 @@ def test_variance_check():
|
|
| 108 |
assert zero < 0.10, f"zero={zero:.3f}, expected < 0.10"
|
| 109 |
assert 0.10 <= wrong <= 0.60, f"wrong={wrong:.3f}, expected in [0.10, 0.60]"
|
| 110 |
assert perfect - zero >= 0.50, f"gap={perfect - zero:.3f}, expected >= 0.50"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import types
|
| 6 |
import pytest
|
| 7 |
|
| 8 |
+
from firewatch_env.rewards import grade, EpisodeResult
|
| 9 |
|
| 10 |
|
| 11 |
def _er(affected, recovered, ticks, wrong, slo, bcm, static=None, total=None):
|
|
|
|
| 70 |
|
| 71 |
def test_weighted_mean_error_rate_weights_by_dependents():
|
| 72 |
"""api-gateway with 3 dependents dominates over a leaf service."""
|
| 73 |
+
from firewatch_env.rewards import _weighted_mean_error_rate # added in Task 4
|
| 74 |
def _svc(er):
|
| 75 |
return types.SimpleNamespace(http_server_error_rate=er)
|
| 76 |
|
|
|
|
| 109 |
assert zero < 0.10, f"zero={zero:.3f}, expected < 0.10"
|
| 110 |
assert 0.10 <= wrong <= 0.60, f"wrong={wrong:.3f}, expected in [0.10, 0.60]"
|
| 111 |
assert perfect - zero >= 0.50, f"gap={perfect - zero:.3f}, expected >= 0.50"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ── Fix 4: MTTM requires 3 consecutive zero-BCM ticks ─────────────────────
|
| 115 |
+
|
| 116 |
+
def test_mttm_requires_3_consecutive_zero_bcm_ticks():
    """MTTM stays unset after two zero-BCM ticks and is granted on the third."""
    from firewatch_env.simulation import IncidentMetrics

    m = IncidentMetrics()

    # Tick 1 still accrues BCM; ticks 2 and 3 build a streak of only two.
    for tick, delta in enumerate((1.0, 0.0, 0.0), start=1):
        m.update(bcm_delta=delta, current_tick=tick)
    assert m.mttm_achieved_tick is None, "must not grant MTTM after only 2 consecutive zeros"

    # Third consecutive zero at tick 4 → MTTM back-dated to tick 4 - 2 = 2.
    m.update(bcm_delta=0.0, current_tick=4)
    assert m.mttm_achieved_tick == 2, f"expected mttm_achieved_tick=2, got {m.mttm_achieved_tick}"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def test_mttm_streak_resets_on_nonzero():
    """A non-zero BCM tick breaks the streak; MTTM needs three fresh zeros after it."""
    from firewatch_env.simulation import IncidentMetrics

    m = IncidentMetrics()

    # Two zeros, one non-zero (reset), then two zeros again: streak is back at 2.
    deltas = (0.0, 0.0, 1.0, 0.0, 0.0)
    for tick, delta in enumerate(deltas, start=1):
        m.update(bcm_delta=delta, current_tick=tick)
    assert m.mttm_achieved_tick is None, "streak was reset; MTTM must not be granted yet"

    # Third unbroken zero at tick 6 → MTTM back-dated to tick 6 - 2 = 4.
    m.update(bcm_delta=0.0, current_tick=6)
    assert m.mttm_achieved_tick == 4, f"expected mttm_achieved_tick=4, got {m.mttm_achieved_tick}"
|