"""
Layer 3 – Task Grader, Reward Manager, and Episode Orchestrator.

TaskGrader    : checks whether the current VFS state passes the hidden tests.
RewardManager : computes per-step and terminal rewards.
EpisodeManager: ties Layers 1-2-3 together and drives the OpenEnv step loop.
"""
| from __future__ import annotations |
| import uuid |
| import time |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from app.models import ( |
| Observation, ActionResult, EpisodeResult, IncidentReport, |
| Alert, ServiceMetrics, |
| ) |
| from app.engine.sandbox import VirtualFileSystem |
| from app.engine.observability import MetricsEngine, DifficultyController |
|
|
|
|
| |
| |
| |
|
|
# Static catalogue of incident scenarios, keyed by task id.
# Each entry carries:
#   description       – incident brief (surfaced as Observation.task_description)
#   sla_budget        – step budget used for budget_remaining / fixed_within_sla
#   difficulty        – free-form difficulty label
#   bug_category      – coarse tag for the kind of defect planted
#   affected_services – services named as impacted by the incident
TASK_DEFINITIONS: Dict[int, Dict] = {
    # Task 1: hard crash in the ranker.
    1: {
        "description": (
            "INCIDENT: All /rank requests are returning HTTP 500. "
            "The ad-ranking service is crashing on every call. "
            "Find and fix the bug in ad_ranking/ranker.py."
        ),
        "sla_budget": 15,
        "difficulty": "easy",
        "bug_category": "data_corruption",
        "affected_services": ["ad_ranking"],
    },
    # Task 2: silent data corruption, nothing is crashing.
    2: {
        "description": (
            "INCIDENT: ROAS (Return on Ad Spend) has dropped 68% vs last week. "
            "No services are crashing. "
            "Ad-ranking allocation decisions appear to be based on conversion "
            "data from 1970. "
            "Trace the root cause."
        ),
        "sla_budget": 20,
        "difficulty": "medium",
        "bug_category": "data_corruption",
        "affected_services": ["capi_pipeline", "ad_ranking"],
    },
    # Task 3: resource leak that only shows under load.
    3: {
        "description": (
            "INCIDENT: WhatsApp message sync works fine under normal load "
            "but hangs under peak traffic (>50 concurrent users). "
            "DB connection pool is exhausted. "
            "Fix the resource leak."
        ),
        "sla_budget": 20,
        "difficulty": "medium-hard",
        "bug_category": "async_bugs",
        "affected_services": ["whatsapp_sync"],
    },
    # Task 4: one root cause behind several simultaneous alerts.
    4: {
        "description": (
            "INCIDENT: Three services degraded simultaneously after the "
            "02:14 UTC deploy. "
            "Multiple P1 alerts are firing. "
            "Find the single root cause and fix it — "
            "do NOT chase individual service symptoms."
        ),
        "sla_budget": 25,
        "difficulty": "hard",
        "bug_category": "red_herrings",
        "affected_services": ["whatsapp_sync", "ad_ranking", "capi_pipeline"],
    },
    # Task 5: security issue that the unit tests do not catch.
    5: {
        "description": (
            "INCIDENT: Security scan flagged unusual /ingest response sizes. "
            "Standard unit tests all pass. "
            "Find and close the data-exposure vulnerability in the CAPI ingestor. "
            "Write a P0 incident report."
        ),
        "sla_budget": 20,
        "difficulty": "hard",
        "bug_category": "security_bugs",
        "affected_services": ["capi_pipeline"],
    },
}
|
|
# Service dependency map: service name -> list of service names it is mapped to.
# Exposed to the agent through Observation.dependency_graph.
DEPENDENCY_GRAPH: Dict[str, List[str]] = dict(
    ad_ranking=["capi_pipeline"],
    capi_pipeline=[],
    whatsapp_sync=["capi_pipeline"],
)
|
|
|
|
| |
| |
| |
|
|
class TaskGrader:
    """
    Checks the VFS content against hidden test criteria.

    No code is executed: every ``_grade_taskN`` method reads one file from the
    virtual file system and decides pass/fail from substring checks on its
    text, then returns a canned test transcript.

    Every grader returns ``(passed, test_output_string, partial_score 0-1)``.
    """

    # Substrings accepted as the task-1 fix. Both quote styles are listed so
    # that ``ad.get("clicks", 0)`` is graded the same as ``ad.get('clicks', 0)``.
    _TASK1_FIX_MARKERS: Tuple[str, ...] = (
        "ad.get('clicks'",
        'ad.get("clicks"',
        "ad['clicks']",
        'ad["clicks"]',
    )

    # Per-task expectations used by grade_incident_report().
    _REPORT_CRITERIA: Dict[int, Dict[str, Any]] = {
        1: {
            "root_cause_keywords": ["get_clicks", "attributeerror", "dict", "attribute"],
            "expected_services": ["ad_ranking"],
            "severity": "P0",
        },
        2: {
            "root_cause_keywords": ["timestamp", "1000", "normalize", "capi", "transformer"],
            "expected_services": ["capi_pipeline", "ad_ranking"],
            "severity": "P1",
        },
        3: {
            "root_cause_keywords": ["connection", "pool", "release", "finally", "async"],
            "expected_services": ["whatsapp_sync"],
            "severity": "P1",
        },
        4: {
            "root_cause_keywords": ["migration", "003", "foreign key", "circular", "fk"],
            "expected_services": ["whatsapp_sync"],
            "severity": "P0",
        },
        5: {
            "root_cause_keywords": ["debug", "pii", "exposure", "ingest", "security"],
            "expected_services": ["capi_pipeline"],
            "severity": "P0",
        },
    }

    def __init__(self, vfs: VirtualFileSystem):
        self.vfs = vfs

    def run(self, task_id: int, suite: str = "unit") -> Tuple[bool, str, float]:
        """Grade ``task_id`` with the named test ``suite``.

        Returns ``(False, "Unknown task", 0.0)`` for a task id without a grader.
        A suite name a grader does not recognise is treated as that task's
        final suite (see the individual graders).
        """
        graders = {
            1: self._grade_task1,
            2: self._grade_task2,
            3: self._grade_task3,
            4: self._grade_task4,
            5: self._grade_task5,
        }
        fn = graders.get(task_id)
        if fn is None:
            return False, "Unknown task", 0.0
        return fn(suite)

    def _grade_task1(self, suite: str) -> Tuple[bool, str, float]:
        """Task 1: ``ad.get_clicks()`` must be replaced by a dict accessor.

        ``suite`` is ignored; the same verdict is returned for every suite.
        """
        _, content = self.vfs.read_file("ad_ranking", "ranker.py")
        has_bug = "ad.get_clicks()" in content
        has_fix = any(marker in content for marker in self._TASK1_FIX_MARKERS)

        # The original call anywhere in the file is an immediate failure,
        # even if a correct accessor was also added.
        if has_bug:
            return False, (
                "FAIL [unit] test_score_ads:\n"
                " AttributeError: 'dict' object has no attribute 'get_clicks'\n"
                " Line 22 still contains ad.get_clicks()\n"
                " 1 test failed, 0 passed"
            ), 0.0

        if has_fix:
            return True, (
                "PASS [unit] test_score_ads: OK\n"
                "PASS [unit] test_rank_returns_sorted_list: OK\n"
                "PASS [unit] test_fetch_candidate_ads: OK\n"
                "3 tests passed in 0.04 s"
            ), 1.0

        # Buggy call removed but no recognised accessor: small partial credit.
        return False, (
            "FAIL [unit] test_score_ads:\n"
            " Fix applied but ad click-rate accessor is incorrect.\n"
            " Expected: ad.get('clicks', 0) or ad['clicks']\n"
            " 1 test failed"
        ), 0.2

    def _grade_task2(self, suite: str) -> Tuple[bool, str, float]:
        """Task 2: the timestamp threshold must be raised to 10**12.

        ``"unit"`` passes with partial credit (0.4) once the old threshold is
        gone; ``"integration"`` additionally requires the new threshold. Any
        other suite name, and ``"unit"`` while the bug is present, is graded
        as ``"integration"``.
        """
        _, content = self.vfs.read_file("capi_pipeline", "transformer.py")

        # Drop whole-line comments so that neither a commented-out copy of the
        # old condition nor a comment merely mentioning the new threshold
        # influences the verdict. Trailing inline comments are not stripped.
        code_only = "\n".join(
            line for line in content.splitlines()
            if not line.strip().startswith("#")
        )

        has_bug = "1_000_000_000:" in code_only or "1000000000:" in code_only
        has_fix = not has_bug and (
            "1_000_000_000_000" in code_only
            or "1000000000000" in code_only
            or "1e12" in code_only
            or "10**12" in code_only
        )

        if suite == "unit" and not has_bug:
            return True, (
                "PASS [unit] test_transform_purchase: OK\n"
                "PASS [unit] test_batch_transform: OK\n"
                "2 tests passed"
            ), 0.4

        if suite == "integration":
            if has_fix:
                return True, (
                    "PASS [integration] test_timestamp_normalisation: OK\n"
                    " event_time 1700000000 → 1700000000 ✓\n"
                    " event_time 1700000000000 → 1700000000 ✓\n"
                    "PASS [integration] test_roas_attribution_accuracy: OK\n"
                    " ROAS attribution error: 0.2% (threshold: 5%)\n"
                    "2 tests passed"
                ), 1.0
            return False, (
                "FAIL [integration] test_timestamp_normalisation:\n"
                " event_time 1700000000 → 1700000 (expected: 1700000000)\n"
                " Timestamps are being divided by 1000 incorrectly.\n"
                " Root cause: threshold condition in _normalize_timestamp()\n"
                "1 test failed"
            ), 0.0

        return self._grade_task2("integration")

    def _grade_task3(self, suite: str) -> Tuple[bool, str, float]:
        """Task 3: the connection must be released in a ``finally`` block.

        ``"unit"`` only looks for ``finally:``; ``"load"`` also requires a
        ``release(conn)`` call. Any other suite name is graded as ``"load"``.
        """
        _, content = self.vfs.read_file("whatsapp_sync", "handler.py")
        has_finally = "finally:" in content
        # "release(conn)" also matches "db_pool.release(conn)".
        has_release = "release(conn)" in content

        if suite == "unit":
            if not has_finally:
                # Reported as not passed although the transcript shows PASS
                # lines: the leak is only observable under load.
                return False, (
                    "PASS [unit] test_sync_messages_basic: OK\n"
                    "PASS [unit] test_process_queue_empty: OK\n"
                    "WARNING: Unit tests pass but connection leak not detectable without load test\n"
                    "Run: run_tests('whatsapp_sync', 'load')"
                ), 0.3

            return True, (
                "PASS [unit] test_sync_messages_basic: OK\n"
                "PASS [unit] test_connection_released_on_success: OK\n"
                "PASS [unit] test_connection_released_on_exception: OK\n"
                "3 tests passed"
            ), 0.6

        if suite == "load":
            if has_finally and has_release:
                return True, (
                    "PASS [load] test_100_concurrent_syncs:\n"
                    " Peak connections: 18/100 (nominal)\n"
                    " All 100 requests completed\n"
                    " Memory stable at 210 MB\n"
                    "PASS [load] test_connection_pool_not_exhausted: OK\n"
                    "2 load tests passed"
                ), 1.0
            return False, (
                "FAIL [load] test_100_concurrent_syncs:\n"
                " TooManyConnectionsError after 23 concurrent requests\n"
                " Connection pool exhausted — connections not being released\n"
                " Hint: Check sync_user_messages() for missing finally block\n"
                "1 load test failed"
            ), 0.0

        return self._grade_task3("load")

    def _grade_task4(self, suite: str) -> Tuple[bool, str, float]:
        """Task 4: the circular foreign key in ``db.py`` must be removed.

        The circular FK is considered present while the file contains both
        ``REFERENCES message_threads`` and ``REFERENCES messages``. ``"unit"``
        fails while it is present; otherwise, and for any other suite name,
        the ``"integration"`` verdict is returned.
        """
        _, content = self.vfs.read_file("whatsapp_sync", "db.py")
        has_circular_fk = (
            "REFERENCES message_threads" in content
            and "REFERENCES messages" in content
        )

        if suite == "unit" and has_circular_fk:
            return False, (
                "FAIL [unit] test_migration_003:\n"
                " ForeignKeyViolationError: circular FK detected\n"
                " messages.thread_id → message_threads.id\n"
                " message_threads.parent_message_id → messages.id\n"
                " Fix: remove ALTER TABLE messages ADD COLUMN thread_id ...\n"
                "1 test failed"
            ), 0.0

        if suite == "integration":
            if not has_circular_fk:
                return True, (
                    "PASS [integration] test_all_services_recover: OK\n"
                    " whatsapp_sync: healthy ✓\n"
                    " ad_ranking: healthy ✓\n"
                    " capi_pipeline: healthy ✓\n"
                    "PASS [integration] test_message_thread_creation: OK\n"
                    "2 tests passed"
                ), 1.0
            return False, (
                "FAIL [integration] test_all_services_recover:\n"
                " whatsapp_sync still critical — migration 003 circular FK active\n"
                " Use rollback() tool or edit db.py to remove the circular reference\n"
                "1 test failed"
            ), 0.0

        return self._grade_task4("integration")

    def _grade_task5(self, suite: str) -> Tuple[bool, str, float]:
        """Task 5: ``DEBUG_MODE`` must be switched off in the ingestor.

        ``"unit"`` always passes with 0.2, since the unit suite does not cover
        the vulnerability. ``"security"`` requires the literal text
        ``DEBUG_MODE = False`` and the absence of ``DEBUG_MODE = True``. Any
        other suite name is graded as ``"security"``.
        """
        _, content = self.vfs.read_file("capi_pipeline", "ingestor.py")
        has_debug_true = "DEBUG_MODE = True" in content
        has_debug_false = "DEBUG_MODE = False" in content

        if suite == "unit":
            return True, (
                "PASS [unit] test_ingest_valid_payload: OK\n"
                "PASS [unit] test_ingest_empty_payload: OK\n"
                "WARNING: Unit tests do not cover security — run security suite\n"
                "2 tests passed (security NOT verified)"
            ), 0.2

        if suite == "security":
            if has_debug_true:
                return False, (
                    "FAIL [security] test_no_pii_in_response:\n"
                    " Response body contains 'debug_data' key with raw user PII\n"
                    " debug_data.user_emails exposed in HTTP response\n"
                    " debug_data.raw_payload exposed in HTTP response\n"
                    " CVSS score: 7.5 (High) — data exposure\n"
                    " Fix: set DEBUG_MODE = False\n"
                    "1 security test failed"
                ), 0.0
            if has_debug_false:
                return True, (
                    "PASS [security] test_no_pii_in_response: OK\n"
                    " Response body: {'status': 'ok', 'processed': N} — no PII\n"
                    "PASS [security] test_response_size_nominal: OK\n"
                    " Avg response size: 48 bytes (was 14 KB)\n"
                    "PASS [security] test_debug_data_absent: OK\n"
                    "3 security tests passed — vulnerability closed"
                ), 1.0
            return False, (
                "FAIL [security] test_no_pii_in_response:\n"
                " DEBUG_MODE value is ambiguous or missing\n"
                " Expected: DEBUG_MODE = False\n"
                "1 security test failed"
            ), 0.1

        return self._grade_task5("security")

    def grade_incident_report(
        self, task_id: int, report: IncidentReport
    ) -> float:
        """Score 0.0–1.0 for incident report accuracy.

        Weighting: 0.5 for the fraction of expected keywords found in
        ``report.root_cause`` (case-insensitive substring match), 0.3 for the
        fraction of expected services listed in ``report.services_affected``
        (case-insensitive), 0.2 for an exact match of
        ``report.severity_classification``. Unknown task ids score 0.0.
        Missing (``None``) root cause or service list is treated as empty.
        """
        criteria = self._REPORT_CRITERIA.get(task_id)
        if not criteria:
            return 0.0

        root_cause_lower = (report.root_cause or "").lower()
        keywords = criteria["root_cause_keywords"]
        keyword_hits = sum(1 for kw in keywords if kw in root_cause_lower)
        score = min(keyword_hits / max(len(keywords), 1), 1.0) * 0.5

        expected_svcs = set(criteria["expected_services"])
        reported_svcs = {s.lower() for s in (report.services_affected or [])}
        svc_score = len(expected_svcs & reported_svcs) / max(len(expected_svcs), 1)
        score += svc_score * 0.3

        if report.severity_classification == criteria["severity"]:
            score += 0.2

        return round(score, 3)
|
|
|
|
| |
| |
| |
|
|
class RewardManager:
    """Keeps the running reward total for one episode.

    Three kinds of reward are booked into the same cumulative total:
    per-step penalties, intermediate progress bonuses, and the terminal
    bonus. Step and progress amounts are also appended to an internal
    per-step history; the terminal bonus is not.
    """

    # Costs charged by step_reward().
    STEP_PENALTY = -0.1
    SYNTAX_ERROR_PENALTY = -0.5
    ROLLBACK_PENALTY = -1.0
    SENIOR_SRE_PENALTY = -0.2
    SYMPTOM_FIX_PENALTY = -0.3

    # Bonuses granted by progress_reward().
    PROGRESS_ERROR_DROP = 0.3
    PROGRESS_SERVICE_ID = 0.2
    PROGRESS_FILE_FOUND = 0.2

    # Components of terminal_reward().
    TERMINAL_TESTS_PASS = 1.0
    TERMINAL_REPORT_MAX = 0.5
    TERMINAL_SLA_BONUS = 0.3
    TERMINAL_NO_REGRESS = 0.2
    TERMINAL_SECURITY_PATCH = 0.5

    # Divisor used by normalized_score(); the ratio is clamped to [0, 1].
    MAX_POSSIBLE = 3.0

    def __init__(self):
        self._cumulative = 0.0
        self._step_rewards: List[float] = []

    def _book(self, amount: float, record: bool) -> float:
        """Add ``amount`` to the total, optionally log it, return it rounded."""
        self._cumulative += amount
        if record:
            self._step_rewards.append(amount)
        return round(amount, 3)

    def reset(self):
        """Zero the running total and empty the per-step history in place."""
        self._step_rewards.clear()
        self._cumulative = 0.0

    def step_reward(self, action: str, syntax_error: bool = False,
                    symptom_fix: bool = False) -> float:
        """Charge the base step cost plus any surcharges that apply.

        Surcharges: syntax error, ``action == "rollback"``,
        ``action == "ask_senior_sre"``, symptom-only fix. Returns the charged
        amount rounded to 3 decimals.
        """
        charge = self.STEP_PENALTY
        surcharges = (
            (syntax_error, self.SYNTAX_ERROR_PENALTY),
            (action == "rollback", self.ROLLBACK_PENALTY),
            (action == "ask_senior_sre", self.SENIOR_SRE_PENALTY),
            (symptom_fix, self.SYMPTOM_FIX_PENALTY),
        )
        for applies, amount in surcharges:
            if applies:
                charge += amount
        return self._book(charge, record=True)

    def progress_reward(self, reason: str) -> float:
        """Grant the bonus registered for ``reason`` (0.0 when unrecognised).

        The amount is booked and logged even when it is 0.0.
        """
        bonus_by_reason = dict(
            error_drop=self.PROGRESS_ERROR_DROP,
            service_id=self.PROGRESS_SERVICE_ID,
            file_found=self.PROGRESS_FILE_FOUND,
        )
        return self._book(bonus_by_reason.get(reason, 0.0), record=True)

    def terminal_reward(
        self,
        tests_passed: bool,
        report_accuracy: float,
        fixed_within_sla: bool,
        no_regressions: bool,
        task_id: int,
    ) -> float:
        """Compute and book the end-of-episode bonus.

        The bonus is added to the running total but not to the per-step
        history. Returns it rounded to 3 decimals.
        """
        bonus = self.TERMINAL_TESTS_PASS if tests_passed else 0.0
        bonus += report_accuracy * self.TERMINAL_REPORT_MAX
        if fixed_within_sla:
            bonus += self.TERMINAL_SLA_BONUS
        if no_regressions:
            bonus += self.TERMINAL_NO_REGRESS
        # Extra credit for closing the vulnerability applies to task 5 only.
        if tests_passed and task_id == 5:
            bonus += self.TERMINAL_SECURITY_PATCH
        return self._book(bonus, record=False)

    def normalized_score(self) -> float:
        """Running total divided by MAX_POSSIBLE, clamped to [0, 1], 4 decimals."""
        ratio = self._cumulative / self.MAX_POSSIBLE
        capped = min(ratio, 1.0)
        return round(max(0.0, capped), 4)

    @property
    def total(self) -> float:
        """Running total rounded to 4 decimals."""
        return round(self._cumulative, 4)
|
|
|
|
| |
| |
| |
|
|
class EpisodeManager:
    """
    Ties together VFS, MetricsEngine, TaskGrader, and RewardManager.
    Exposes reset() and step() matching the OpenEnv contract.
    """

    # Suite whose verdict decides whether a task counts as solved. These are
    # the suites each TaskGrader._grade_taskN falls through to when it is
    # given a suite name it does not recognise (task 1 ignores the suite).
    _FINAL_SUITE: Dict[int, str] = {
        1: "unit",
        2: "integration",
        3: "load",
        4: "integration",
        5: "security",
    }

    def __init__(self, difficulty_controller: Optional[DifficultyController] = None):
        self.vfs = VirtualFileSystem()
        self.metrics = MetricsEngine()
        # Created in reset(); None until the first episode starts.
        self.grader: Optional[TaskGrader] = None
        self.reward = RewardManager()
        self.dc = difficulty_controller or DifficultyController()

        # Per-episode state, re-initialised by reset().
        self._task_id: int = 0
        self._step: int = 0
        self._done: bool = False
        self._incident_id: str = ""
        self._sre_memory: List[str] = []
        self._tool_call_log: List[Dict] = []
        self._last_terminal: str = ""
        self._incident_report: Optional[IncidentReport] = None
        self._start_time: float = 0.0

    def reset(self, task_id: Optional[int] = None) -> Observation:
        """Start a new episode and return its first observation.

        When ``task_id`` is omitted (or falsy) the difficulty controller
        picks the task. All per-episode state, the VFS, the metrics engine
        and the reward accumulator are reset.
        """
        self._task_id = task_id or self.dc.next_task_id()
        self._step = 0
        self._done = False
        self._incident_id = f"INC-{self._task_id}-{uuid.uuid4().hex[:6].upper()}"
        self._sre_memory = []
        self._tool_call_log = []
        self._last_terminal = ""
        self._incident_report = None
        self._start_time = time.time()

        self.vfs.reset(self._task_id)
        self.metrics.reset(self._task_id)
        self.grader = TaskGrader(self.vfs)
        self.reward.reset()

        return self._build_observation()

    def step(self, tool: str, params: Dict[str, Any]) -> ActionResult:
        """Execute one tool call and return its result with a fresh observation.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        self._step += 1
        self._tool_call_log.append({"step": self._step, "tool": tool, "params": params})

        # Imported here rather than at module level — presumably to avoid a
        # circular import, since ToolDispatcher is handed this manager.
        from app.tools.definitions import ToolDispatcher
        dispatcher = ToolDispatcher(self)
        reward_delta, done, tool_output = dispatcher.dispatch(tool, params)

        self._done = done
        obs = self._build_observation()

        return ActionResult(
            tool=tool,
            output=tool_output,
            reward_delta=reward_delta,
            done=done,
            observation=obs,
        )

    def _build_observation(self) -> Observation:
        """Assemble the observation for the current step.

        ``budget_remaining`` is the task's SLA budget (20 when the task is
        unknown) minus the steps taken, floored at 0.
        """
        task_def = TASK_DEFINITIONS.get(self._task_id, {})
        budget = task_def.get("sla_budget", 20) - self._step

        return Observation(
            step=self._step,
            incident_id=self._incident_id,
            system_metrics=self.metrics.get_metrics(self._step),
            active_alerts=self.metrics.get_alerts(self._step),
            terminal_output=self.metrics.get_terminal_output(
                self._step, self._last_terminal or None
            ),
            git_diff=self.vfs.build_git_diff(),
            # Hand out copies, like sre_memory below, so a consumer cannot
            # mutate the shared module-level graph.
            dependency_graph={svc: list(deps) for svc, deps in DEPENDENCY_GRAPH.items()},
            sre_memory=list(self._sre_memory),
            budget_remaining=max(budget, 0),
            task_id=self._task_id,
            task_description=task_def.get("description", ""),
        )

    def add_memory(self, entry: str) -> None:
        """Append ``entry`` to the SRE memory, prefixed with the current step."""
        self._sre_memory.append(f"[step {self._step}] {entry}")

    def _final_tests_passed(self) -> bool:
        """Return whether the current VFS state passes the task's final suite.

        Returns False when no episode has been started (no grader yet).
        """
        if self.grader is None:
            return False
        suite = self._FINAL_SUITE.get(self._task_id, "unit")
        passed, _output, _score = self.grader.run(self._task_id, suite)
        return bool(passed)

    def get_episode_result(self) -> EpisodeResult:
        """Summarise the episode so far.

        ``tests_passed`` is obtained by grading the current VFS state with the
        task's final suite; ``incident_report_accuracy`` is 0.0 when no report
        has been stored on the manager.
        """
        tests_passed = self._final_tests_passed()

        report_accuracy = 0.0
        if self._incident_report and self.grader is not None:
            report_accuracy = self.grader.grade_incident_report(
                self._task_id, self._incident_report
            )

        task_def = TASK_DEFINITIONS.get(self._task_id, {})
        fixed_within_sla = self._step <= task_def.get("sla_budget", 20)

        return EpisodeResult(
            incident_id=self._incident_id,
            task_id=self._task_id,
            steps_taken=self._step,
            total_reward=self.reward.total,
            normalized_score=self.reward.normalized_score(),
            tests_passed=tests_passed,
            incident_report_accuracy=report_accuracy,
            fixed_within_sla=fixed_within_sla,
            tool_call_log=list(self._tool_call_log),
            weakness_tags=self.dc.weakness_tags(),
        )
|
|