# Meta-SRE / app/engine/manager.py
# Anvit25 — ad6248e: "Deploy Meta-SRE OpenEnv benchmark FastAPI server"
"""
Layer 3 – Task Grader, Reward Manager, and Episode Orchestrator.
TaskGrader : checks whether the current VFS state passes the hidden tests.
RewardManager: computes per-step and terminal rewards.
EpisodeManager: ties Layers 1-2-3 together and drives the OpenEnv step loop.
"""
from __future__ import annotations

import re
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple

from app.models import (
    Observation, ActionResult, EpisodeResult, IncidentReport,
    Alert, ServiceMetrics,
)
from app.engine.sandbox import VirtualFileSystem
from app.engine.observability import MetricsEngine, DifficultyController
# ---------------------------------------------------------------------------
# Task definitions
# ---------------------------------------------------------------------------
# Static description of every incident scenario, keyed by task id.
# Each entry carries the agent-facing incident text, the step budget used as
# the SLA, a difficulty label, the bug category, and the services it touches.
TASK_DEFINITIONS: Dict[int, Dict] = {
    1: dict(
        description=(
            "INCIDENT: All /rank requests are returning HTTP 500. The ad-ranking "
            "service is crashing on every call. Find and fix the bug in "
            "ad_ranking/ranker.py."
        ),
        sla_budget=15,
        difficulty="easy",
        bug_category="data_corruption",
        affected_services=["ad_ranking"],
    ),
    2: dict(
        description=(
            "INCIDENT: ROAS (Return on Ad Spend) has dropped 68% vs last week. No "
            "services are crashing. Ad-ranking allocation decisions appear to be based "
            "on conversion data from 1970. Trace the root cause."
        ),
        sla_budget=20,
        difficulty="medium",
        bug_category="data_corruption",
        affected_services=["capi_pipeline", "ad_ranking"],
    ),
    3: dict(
        description=(
            "INCIDENT: WhatsApp message sync works fine under normal load but hangs "
            "under peak traffic (>50 concurrent users). DB connection pool is "
            "exhausted. Fix the resource leak."
        ),
        sla_budget=20,
        difficulty="medium-hard",
        bug_category="async_bugs",
        affected_services=["whatsapp_sync"],
    ),
    4: dict(
        description=(
            "INCIDENT: Three services degraded simultaneously after the 02:14 UTC "
            "deploy. Multiple P1 alerts are firing. Find the single root cause and "
            "fix it — do NOT chase individual service symptoms."
        ),
        sla_budget=25,
        difficulty="hard",
        bug_category="red_herrings",
        affected_services=["whatsapp_sync", "ad_ranking", "capi_pipeline"],
    ),
    5: dict(
        description=(
            "INCIDENT: Security scan flagged unusual /ingest response sizes. Standard "
            "unit tests all pass. Find and close the data-exposure vulnerability in "
            "the CAPI ingestor. Write a P0 incident report."
        ),
        sla_budget=20,
        difficulty="hard",
        bug_category="security_bugs",
        affected_services=["capi_pipeline"],
    ),
}

# Service -> list of services it depends on; surfaced verbatim in every Observation.
DEPENDENCY_GRAPH: Dict[str, List[str]] = dict(
    ad_ranking=["capi_pipeline"],
    capi_pipeline=[],
    whatsapp_sync=["capi_pipeline"],
)
# ---------------------------------------------------------------------------
# Hidden graders — one per task
# ---------------------------------------------------------------------------
class TaskGrader:
    """
    Checks the VFS content against hidden test criteria.

    Every grader returns ``(passed, test_output_string, partial_score)`` with
    ``partial_score`` in ``[0.0, 1.0]``. Grading is purely textual: the
    relevant source file is read from the VFS and inspected; nothing is run.
    Full-line ``#`` comments are ignored, so commenting out a buggy line
    counts as removing it, and a commented-out fix does not count as a fix.
    """

    # Expected incident-report contents per task (see grade_incident_report).
    _REPORT_EXPECTATIONS: Dict[int, Dict[str, Any]] = {
        1: {
            "root_cause_keywords": ["get_clicks", "attributeerror", "dict", "attribute"],
            "expected_services": ["ad_ranking"],
            "severity": "P0",
        },
        2: {
            "root_cause_keywords": ["timestamp", "1000", "normalize", "capi", "transformer"],
            "expected_services": ["capi_pipeline", "ad_ranking"],
            "severity": "P1",
        },
        3: {
            "root_cause_keywords": ["connection", "pool", "release", "finally", "async"],
            "expected_services": ["whatsapp_sync"],
            "severity": "P1",
        },
        4: {
            "root_cause_keywords": ["migration", "003", "foreign key", "circular", "fk"],
            "expected_services": ["whatsapp_sync"],
            "severity": "P0",
        },
        5: {
            "root_cause_keywords": ["debug", "pii", "exposure", "ingest", "security"],
            "expected_services": ["capi_pipeline"],
            "severity": "P0",
        },
    }

    # A line-leading assignment of DEBUG_MODE (optionally annotated ``: bool``),
    # capturing the boolean literal. Tolerates any spacing around ``=`` and
    # trailing comments; commented-out lines never match.
    _DEBUG_ASSIGN_RE = re.compile(
        r"^\s*DEBUG_MODE\s*(?::\s*bool\s*)?=\s*(True|False)\b", re.MULTILINE
    )

    # Accepted spellings of the Task 1 fix (either quote style).
    _TASK1_FIXES = (
        "ad.get('clicks'",
        'ad.get("clicks"',
        "ad['clicks']",
        'ad["clicks"]',
    )

    def __init__(self, vfs: VirtualFileSystem):
        self.vfs = vfs

    def run(self, task_id: int, suite: str = "unit") -> Tuple[bool, str, float]:
        """
        Grade the current VFS state for *task_id* using test *suite*.

        Unknown task ids yield ``(False, "Unknown task", 0.0)``. Each grader
        decides for itself how to treat suite names it does not recognise.
        """
        graders = {
            1: self._grade_task1,
            2: self._grade_task2,
            3: self._grade_task3,
            4: self._grade_task4,
            5: self._grade_task5,
        }
        fn = graders.get(task_id)
        if fn is None:
            return False, "Unknown task", 0.0
        return fn(suite)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _read(self, service: str, filename: str) -> str:
        """Return the VFS content of *filename* in *service* ("" if None)."""
        _, content = self.vfs.read_file(service, filename)
        # Guard against a None payload so substring checks never raise.
        return content or ""

    @staticmethod
    def _code_only(content: str) -> str:
        """Return *content* with every full-line ``#`` comment removed."""
        return "\n".join(
            line for line in content.splitlines()
            if not line.strip().startswith("#")
        )

    # ------------------------------------------------------------------
    # Task 1 – fix ad.get_clicks() → ad.get('clicks', 0)
    # ------------------------------------------------------------------
    def _grade_task1(self, suite: str) -> Tuple[bool, str, float]:
        """Grade ranker.py; every suite name gets the same unit verdict."""
        code = self._code_only(self._read("ad_ranking", "ranker.py"))
        has_bug = "ad.get_clicks()" in code
        has_fix = any(fix in code for fix in self._TASK1_FIXES)
        if has_bug:
            return False, (
                "FAIL [unit] test_score_ads:\n"
                " AttributeError: 'dict' object has no attribute 'get_clicks'\n"
                " Line 22 still contains ad.get_clicks()\n"
                " 1 test failed, 0 passed"
            ), 0.0
        if has_fix:
            return True, (
                "PASS [unit] test_score_ads: OK\n"
                "PASS [unit] test_rank_returns_sorted_list: OK\n"
                "PASS [unit] test_fetch_candidate_ads: OK\n"
                "3 tests passed in 0.04 s"
            ), 1.0
        # Buggy call removed but replaced with something unrecognised.
        return False, (
            "FAIL [unit] test_score_ads:\n"
            " Fix applied but ad click-rate accessor is incorrect.\n"
            " Expected: ad.get('clicks', 0) or ad['clicks']\n"
            " 1 test failed"
        ), 0.2

    # ------------------------------------------------------------------
    # Task 2 – fix timestamp threshold 1_000_000_000 → 1_000_000_000_000
    # ------------------------------------------------------------------
    def _grade_task2(self, suite: str) -> Tuple[bool, str, float]:
        """Grade transformer.py; suites other than "unit"/"integration" run integration."""
        code = self._code_only(self._read("capi_pipeline", "transformer.py"))
        has_bug = "1_000_000_000:" in code or "1000000000:" in code
        # Fixed only if the buggy threshold is gone AND the correct one is in
        # code (comments are excluded — the original comment already mentions
        # 1_000_000_000_000).
        has_fix = not has_bug and any(
            threshold in code
            for threshold in ("1_000_000_000_000", "1000000000000", "1e12", "10**12")
        )
        if suite == "unit" and not has_bug:
            # Unit tests pass once the buggy threshold is gone: they don't
            # check timestamp edge cases, so only the integration suite
            # verifies that the replacement threshold is correct.
            return True, (
                "PASS [unit] test_transform_purchase: OK\n"
                "PASS [unit] test_batch_transform: OK\n"
                "2 tests passed"
            ), 0.4
        if suite == "integration":
            if has_fix:
                return True, (
                    "PASS [integration] test_timestamp_normalisation: OK\n"
                    " event_time 1700000000 → 1700000000 ✓\n"
                    " event_time 1700000000000 → 1700000000 ✓\n"
                    "PASS [integration] test_roas_attribution_accuracy: OK\n"
                    " ROAS attribution error: 0.2% (threshold: 5%)\n"
                    "2 tests passed"
                ), 1.0
            return False, (
                "FAIL [integration] test_timestamp_normalisation:\n"
                " event_time 1700000000 → 1700000 (expected: 1700000000)\n"
                " Timestamps are being divided by 1000 incorrectly.\n"
                " Root cause: threshold condition in _normalize_timestamp()\n"
                "1 test failed"
            ), 0.0
        # Default (including "unit" while the bug is present): run integration.
        return self._grade_task2("integration")

    # ------------------------------------------------------------------
    # Task 3 – add finally: await self.db_pool.release(conn)
    # ------------------------------------------------------------------
    def _grade_task3(self, suite: str) -> Tuple[bool, str, float]:
        """Grade handler.py; suites other than "unit"/"load" run the load suite."""
        code = self._code_only(self._read("whatsapp_sync", "handler.py"))
        has_finally = "finally:" in code
        # Also covers "db_pool.release(conn)".
        has_release = "release(conn)" in code
        if suite == "unit":
            if not has_finally:
                # Deliberate trap: unit output looks green but the leak remains.
                return False, (
                    "PASS [unit] test_sync_messages_basic: OK\n"
                    "PASS [unit] test_process_queue_empty: OK\n"
                    "WARNING: Unit tests pass but connection leak not detectable without load test\n"
                    "Run: run_tests('whatsapp_sync', 'load')"
                ), 0.3
            return True, (
                "PASS [unit] test_sync_messages_basic: OK\n"
                "PASS [unit] test_connection_released_on_success: OK\n"
                "PASS [unit] test_connection_released_on_exception: OK\n"
                "3 tests passed"
            ), 0.6
        if suite == "load":
            if has_finally and has_release:
                return True, (
                    "PASS [load] test_100_concurrent_syncs:\n"
                    " Peak connections: 18/100 (nominal)\n"
                    " All 100 requests completed\n"
                    " Memory stable at 210 MB\n"
                    "PASS [load] test_connection_pool_not_exhausted: OK\n"
                    "2 load tests passed"
                ), 1.0
            return False, (
                "FAIL [load] test_100_concurrent_syncs:\n"
                " TooManyConnectionsError after 23 concurrent requests\n"
                " Connection pool exhausted — connections not being released\n"
                " Hint: Check sync_user_messages() for missing finally block\n"
                "1 load test failed"
            ), 0.0
        return self._grade_task3("load")

    # ------------------------------------------------------------------
    # Task 4 – rollback migration 003 (remove circular FK)
    # ------------------------------------------------------------------
    def _grade_task4(self, suite: str) -> Tuple[bool, str, float]:
        """Grade db.py for the circular FK introduced by migration 003."""
        code = self._code_only(self._read("whatsapp_sync", "db.py"))
        has_circular_fk = (
            "REFERENCES message_threads" in code and
            "REFERENCES messages" in code
        )
        if suite == "unit" and has_circular_fk:
            return False, (
                "FAIL [unit] test_migration_003:\n"
                " ForeignKeyViolationError: circular FK detected\n"
                " messages.thread_id → message_threads.id\n"
                " message_threads.parent_message_id → messages.id\n"
                " Fix: remove ALTER TABLE messages ADD COLUMN thread_id ...\n"
                "1 test failed"
            ), 0.0
        if suite == "integration":
            if not has_circular_fk:
                return True, (
                    "PASS [integration] test_all_services_recover: OK\n"
                    " whatsapp_sync: healthy ✓\n"
                    " ad_ranking: healthy ✓\n"
                    " capi_pipeline: healthy ✓\n"
                    "PASS [integration] test_message_thread_creation: OK\n"
                    "2 tests passed"
                ), 1.0
            return False, (
                "FAIL [integration] test_all_services_recover:\n"
                " whatsapp_sync still critical — migration 003 circular FK active\n"
                " Use rollback() tool or edit db.py to remove the circular reference\n"
                "1 test failed"
            ), 0.0
        # Any other case (including "unit" once the FK is gone) runs integration.
        return self._grade_task4("integration")

    # ------------------------------------------------------------------
    # Task 5 – set DEBUG_MODE = False
    # ------------------------------------------------------------------
    def _grade_task5(self, suite: str) -> Tuple[bool, str, float]:
        """Grade ingestor.py; suites other than "unit"/"security" run security."""
        assigned = set(
            self._DEBUG_ASSIGN_RE.findall(self._read("capi_pipeline", "ingestor.py"))
        )
        # Any remaining True assignment is treated as still vulnerable.
        has_debug_true = "True" in assigned
        has_debug_false = "False" in assigned
        if suite == "unit":
            # Deliberate trap: unit tests never exercise the security issue.
            return True, (
                "PASS [unit] test_ingest_valid_payload: OK\n"
                "PASS [unit] test_ingest_empty_payload: OK\n"
                "WARNING: Unit tests do not cover security — run security suite\n"
                "2 tests passed (security NOT verified)"
            ), 0.2
        if suite == "security":
            if has_debug_true:
                return False, (
                    "FAIL [security] test_no_pii_in_response:\n"
                    " Response body contains 'debug_data' key with raw user PII\n"
                    " debug_data.user_emails exposed in HTTP response\n"
                    " debug_data.raw_payload exposed in HTTP response\n"
                    " CVSS score: 7.5 (High) — data exposure\n"
                    " Fix: set DEBUG_MODE = False\n"
                    "1 security test failed"
                ), 0.0
            if has_debug_false:
                return True, (
                    "PASS [security] test_no_pii_in_response: OK\n"
                    " Response body: {'status': 'ok', 'processed': N} — no PII\n"
                    "PASS [security] test_response_size_nominal: OK\n"
                    " Avg response size: 48 bytes (was 14 KB)\n"
                    "PASS [security] test_debug_data_absent: OK\n"
                    "3 security tests passed — vulnerability closed"
                ), 1.0
            return False, (
                "FAIL [security] test_no_pii_in_response:\n"
                " DEBUG_MODE value is ambiguous or missing\n"
                " Expected: DEBUG_MODE = False\n"
                "1 security test failed"
            ), 0.1
        return self._grade_task5("security")

    # ------------------------------------------------------------------
    # Incident report grader
    # ------------------------------------------------------------------
    def grade_incident_report(
        self, task_id: int, report: IncidentReport
    ) -> float:
        """
        Score 0.0–1.0 for incident report accuracy.

        Weighting: 50% root-cause keyword coverage (case-insensitive
        substring match), 30% overlap with the expected affected services
        (case/whitespace-insensitive), 20% for the expected severity
        (case/whitespace-insensitive when given as a string).
        Unknown task ids score 0.0.
        """
        cfg = self._REPORT_EXPECTATIONS.get(task_id)
        if not cfg:
            return 0.0
        score = 0.0

        root_cause_lower = (report.root_cause or "").lower()
        keywords = cfg["root_cause_keywords"]
        keyword_hits = sum(1 for kw in keywords if kw in root_cause_lower)
        score += min(keyword_hits / max(len(keywords), 1), 1.0) * 0.5

        expected_svcs = set(cfg["expected_services"])
        reported_svcs = {s.strip().lower() for s in (report.services_affected or ())}
        svc_score = len(expected_svcs & reported_svcs) / max(len(expected_svcs), 1)
        score += svc_score * 0.3

        severity = report.severity_classification
        if isinstance(severity, str):
            # Accept "p0" / " P0 "; str-based enums collapse to their value.
            severity = severity.strip().upper()
        if severity == cfg["severity"]:
            score += 0.2
        return round(score, 3)
# ---------------------------------------------------------------------------
# Reward Manager
# ---------------------------------------------------------------------------
class RewardManager:
    """
    Accumulate the reward signal for a single episode.

    Step and progress rewards are both added to the running total and logged
    individually in ``_step_rewards``; terminal rewards only feed the total.
    ``normalized_score()`` divides the total by ``MAX_POSSIBLE`` and clamps
    the result to the range [0, 1].
    """

    # -- per-step penalties ---------------------------------------------
    STEP_PENALTY = -0.1
    SYNTAX_ERROR_PENALTY = -0.5
    ROLLBACK_PENALTY = -1.0
    SENIOR_SRE_PENALTY = -0.2
    SYMPTOM_FIX_PENALTY = -0.3  # for Task 4 — fixing red herring services
    # -- progress bonuses -----------------------------------------------
    PROGRESS_ERROR_DROP = +0.3  # error_rate drops >50%
    PROGRESS_SERVICE_ID = +0.2  # correct root-cause service identified
    PROGRESS_FILE_FOUND = +0.2  # correct file opened/edited
    # -- terminal bonuses -----------------------------------------------
    TERMINAL_TESTS_PASS = +1.0
    TERMINAL_REPORT_MAX = +0.5
    TERMINAL_SLA_BONUS = +0.3
    TERMINAL_NO_REGRESS = +0.2
    TERMINAL_SECURITY_PATCH = +0.5  # Task 5 only
    MAX_POSSIBLE = 3.0  # divisor used by normalized_score()

    def __init__(self):
        self._cumulative = 0.0
        self._step_rewards: List[float] = []

    def reset(self):
        """Zero the running total and empty the reward log in place."""
        self._step_rewards.clear()
        self._cumulative = 0.0

    def _record(self, amount: float) -> float:
        """Add *amount* to the total, log it, and hand it back rounded to 3 dp."""
        self._cumulative += amount
        self._step_rewards.append(amount)
        return round(amount, 3)

    def step_reward(self, action: str, syntax_error: bool = False,
                    symptom_fix: bool = False) -> float:
        """
        Charge the base per-step cost plus any penalties triggered by *action*.

        Returns the step's reward rounded to 3 decimals; the unrounded value
        is what gets accumulated.
        """
        amount = self.STEP_PENALTY
        # Penalties are applied in a fixed order so the float sum is reproducible.
        penalties = (
            (syntax_error, self.SYNTAX_ERROR_PENALTY),
            (action == "rollback", self.ROLLBACK_PENALTY),
            (action == "ask_senior_sre", self.SENIOR_SRE_PENALTY),
            (symptom_fix, self.SYMPTOM_FIX_PENALTY),
        )
        for triggered, penalty in penalties:
            if triggered:
                amount += penalty
        return self._record(amount)

    def progress_reward(self, reason: str) -> float:
        """Grant the bonus named by *reason*; unknown reasons grant (and log) 0.0."""
        bonus_by_reason = {
            "error_drop": self.PROGRESS_ERROR_DROP,
            "service_id": self.PROGRESS_SERVICE_ID,
            "file_found": self.PROGRESS_FILE_FOUND,
        }
        return self._record(bonus_by_reason.get(reason, 0.0))

    def terminal_reward(
        self,
        tests_passed: bool,
        report_accuracy: float,
        fixed_within_sla: bool,
        no_regressions: bool,
        task_id: int,
    ) -> float:
        """
        Compute the end-of-episode bonus and add it to the running total.

        The report component scales linearly with *report_accuracy*; the
        security bonus only applies to task 5 when the tests passed. The
        terminal amount is not appended to the step log.
        """
        amount = 0.0
        # Same order as the reward design table, so float rounding is stable.
        components = (
            (tests_passed, self.TERMINAL_TESTS_PASS),
            (True, report_accuracy * self.TERMINAL_REPORT_MAX),
            (fixed_within_sla, self.TERMINAL_SLA_BONUS),
            (no_regressions, self.TERMINAL_NO_REGRESS),
            (task_id == 5 and tests_passed, self.TERMINAL_SECURITY_PATCH),
        )
        for applies, bonus in components:
            if applies:
                amount += bonus
        self._cumulative += amount
        return round(amount, 3)

    def normalized_score(self) -> float:
        """Running total divided by MAX_POSSIBLE, clamped to [0, 1], 4 dp."""
        capped = min(self._cumulative / self.MAX_POSSIBLE, 1.0)
        return round(max(0.0, capped), 4)

    @property
    def total(self) -> float:
        """Raw running reward rounded to 4 decimal places."""
        return round(self._cumulative, 4)
# ---------------------------------------------------------------------------
# Episode Manager – the main orchestrator
# ---------------------------------------------------------------------------
class EpisodeManager:
    """
    Ties together VFS, MetricsEngine, TaskGrader, and RewardManager.
    Exposes reset() and step() matching the OpenEnv contract.
    """

    # SLA step budget assumed when a task id has no entry in TASK_DEFINITIONS.
    _DEFAULT_SLA_BUDGET = 20

    # Suite whose verdict decides EpisodeResult.tests_passed for each task.
    # For tasks 2, 3 and 5 the "unit" grader can report success without the
    # real fix, so the definitive suite is used instead.
    _VERIFICATION_SUITE: Dict[int, str] = {
        1: "unit",
        2: "integration",
        3: "load",
        4: "integration",
        5: "security",
    }

    def __init__(self, difficulty_controller: Optional[DifficultyController] = None):
        self.vfs = VirtualFileSystem()
        self.metrics = MetricsEngine()
        self.grader: Optional[TaskGrader] = None  # created by reset()
        self.reward = RewardManager()
        self.dc = difficulty_controller or DifficultyController()
        self._task_id: int = 0
        self._step: int = 0
        self._done: bool = False
        self._incident_id: str = ""
        self._sre_memory: List[str] = []
        self._tool_call_log: List[Dict] = []
        self._last_terminal: str = ""
        self._incident_report: Optional[IncidentReport] = None
        self._start_time: float = 0.0

    # ------------------------------------------------------------------
    # OpenEnv: reset
    # ------------------------------------------------------------------
    def reset(self, task_id: Optional[int] = None) -> Observation:
        """
        Start a new episode and return its initial observation.

        A falsy *task_id* (None or 0) lets the difficulty controller choose
        the next task. All per-episode state, the VFS, the metrics engine,
        the grader and the reward manager are reset for the chosen task.
        """
        self._task_id = task_id or self.dc.next_task_id()
        self._step = 0
        self._done = False
        self._incident_id = f"INC-{self._task_id}-{uuid.uuid4().hex[:6].upper()}"
        self._sre_memory = []
        self._tool_call_log = []
        self._last_terminal = ""
        self._incident_report = None
        self._start_time = time.time()
        self.vfs.reset(self._task_id)
        self.metrics.reset(self._task_id)
        self.grader = TaskGrader(self.vfs)
        self.reward.reset()
        return self._build_observation()

    # ------------------------------------------------------------------
    # OpenEnv: step
    # ------------------------------------------------------------------
    def step(self, tool: str, params: Dict[str, Any]) -> ActionResult:
        """
        Execute one tool call and return its result plus the new observation.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        self._step += 1
        # Log a shallow snapshot so later mutation of *params* cannot rewrite history.
        self._tool_call_log.append(
            {"step": self._step, "tool": tool, "params": dict(params or {})}
        )
        # Imported lazily — presumably to avoid an import cycle with
        # app.tools.definitions (TODO confirm).
        from app.tools.definitions import ToolDispatcher
        dispatcher = ToolDispatcher(self)
        reward_delta, done, tool_output = dispatcher.dispatch(tool, params)
        self._done = done
        obs = self._build_observation()
        return ActionResult(
            tool=tool,
            output=tool_output,
            reward_delta=reward_delta,
            done=done,
            observation=obs,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _build_observation(self) -> Observation:
        """Snapshot the current episode state into an Observation."""
        task_def = TASK_DEFINITIONS.get(self._task_id, {})
        budget = task_def.get("sla_budget", self._DEFAULT_SLA_BUDGET) - self._step
        return Observation(
            step=self._step,
            incident_id=self._incident_id,
            system_metrics=self.metrics.get_metrics(self._step),
            active_alerts=self.metrics.get_alerts(self._step),
            terminal_output=self.metrics.get_terminal_output(
                self._step, self._last_terminal or None
            ),
            git_diff=self.vfs.build_git_diff(),
            dependency_graph=DEPENDENCY_GRAPH,
            sre_memory=list(self._sre_memory),
            budget_remaining=max(budget, 0),
            task_id=self._task_id,
            task_description=task_def.get("description", ""),
        )

    def add_memory(self, entry: str) -> None:
        """Append *entry* to the SRE memory, tagged with the current step."""
        self._sre_memory.append(f"[step {self._step}] {entry}")

    def get_episode_result(self) -> EpisodeResult:
        """
        Summarise the current episode.

        ``tests_passed`` is the verdict of the task's verification suite on
        the current VFS. ``fixed_within_sla`` additionally requires that
        verdict; a run that stayed within budget without fixing anything
        is not reported as fixed. Before reset() no grader exists, so
        both are False.
        """
        tests_passed = False
        report_accuracy = 0.0
        if self.grader is not None:
            suite = self._VERIFICATION_SUITE.get(self._task_id, "unit")
            tests_passed, _, _ = self.grader.run(self._task_id, suite)
            if self._incident_report is not None:
                report_accuracy = self.grader.grade_incident_report(
                    self._task_id, self._incident_report
                )
        task_def = TASK_DEFINITIONS.get(self._task_id, {})
        within_budget = self._step <= task_def.get("sla_budget", self._DEFAULT_SLA_BUDGET)
        fixed_within_sla = bool(tests_passed) and within_budget
        return EpisodeResult(
            incident_id=self._incident_id,
            task_id=self._task_id,
            steps_taken=self._step,
            total_reward=self.reward.total,
            normalized_score=self.reward.normalized_score(),
            tests_passed=tests_passed,
            incident_report_accuracy=report_accuracy,
            fixed_within_sla=fixed_within_sla,
            tool_call_log=list(self._tool_call_log),
            weakness_tags=self.dc.weakness_tags(),
        )