""" Layer 3 – Task Grader, Reward Manager, and Episode Orchestrator. TaskGrader : checks whether the current VFS state passes the hidden tests. RewardManager: computes per-step and terminal rewards. EpisodeManager: ties Layers 1-2-3 together and drives the OpenEnv step loop. """ from __future__ import annotations import uuid import time from typing import Any, Dict, List, Optional, Tuple from app.models import ( Observation, ActionResult, EpisodeResult, IncidentReport, Alert, ServiceMetrics, ) from app.engine.sandbox import VirtualFileSystem from app.engine.observability import MetricsEngine, DifficultyController # --------------------------------------------------------------------------- # Task definitions # --------------------------------------------------------------------------- TASK_DEFINITIONS: Dict[int, Dict] = { 1: { "description": ( "INCIDENT: All /rank requests are returning HTTP 500. " "The ad-ranking service is crashing on every call. " "Find and fix the bug in ad_ranking/ranker.py." ), "sla_budget": 15, "difficulty": "easy", "bug_category": "data_corruption", "affected_services": ["ad_ranking"], }, 2: { "description": ( "INCIDENT: ROAS (Return on Ad Spend) has dropped 68% vs last week. " "No services are crashing. Ad-ranking allocation decisions appear to be " "based on conversion data from 1970. Trace the root cause." ), "sla_budget": 20, "difficulty": "medium", "bug_category": "data_corruption", "affected_services": ["capi_pipeline", "ad_ranking"], }, 3: { "description": ( "INCIDENT: WhatsApp message sync works fine under normal load but " "hangs under peak traffic (>50 concurrent users). DB connection pool " "is exhausted. Fix the resource leak." ), "sla_budget": 20, "difficulty": "medium-hard", "bug_category": "async_bugs", "affected_services": ["whatsapp_sync"], }, 4: { "description": ( "INCIDENT: Three services degraded simultaneously after the 02:14 UTC deploy. " "Multiple P1 alerts are firing. Find the single root cause and fix it — " "do NOT chase individual service symptoms." ), "sla_budget": 25, "difficulty": "hard", "bug_category": "red_herrings", "affected_services": ["whatsapp_sync", "ad_ranking", "capi_pipeline"], }, 5: { "description": ( "INCIDENT: Security scan flagged unusual /ingest response sizes. " "Standard unit tests all pass. Find and close the data-exposure vulnerability " "in the CAPI ingestor. Write a P0 incident report." ), "sla_budget": 20, "difficulty": "hard", "bug_category": "security_bugs", "affected_services": ["capi_pipeline"], }, } DEPENDENCY_GRAPH: Dict[str, List[str]] = { "ad_ranking": ["capi_pipeline"], "capi_pipeline": [], "whatsapp_sync": ["capi_pipeline"], } # --------------------------------------------------------------------------- # Hidden graders — one per task # --------------------------------------------------------------------------- class TaskGrader: """ Checks the VFS content against hidden test criteria. Returns (passed, test_output_string, partial_score 0-1). """ def __init__(self, vfs: VirtualFileSystem): self.vfs = vfs def run(self, task_id: int, suite: str = "unit") -> Tuple[bool, str, float]: graders = { 1: self._grade_task1, 2: self._grade_task2, 3: self._grade_task3, 4: self._grade_task4, 5: self._grade_task5, } fn = graders.get(task_id) if fn is None: return False, "Unknown task", 0.0 return fn(suite) # ------------------------------------------------------------------ # Task 1 – fix ad.get_clicks() → ad.get('clicks', 0) # ------------------------------------------------------------------ def _grade_task1(self, suite: str) -> Tuple[bool, str, float]: _, content = self.vfs.read_file("ad_ranking", "ranker.py") has_bug = "ad.get_clicks()" in content has_fix = "ad.get('clicks'" in content or "ad['clicks']" in content if has_bug: return False, ( "FAIL [unit] test_score_ads:\n" " AttributeError: 'dict' object has no attribute 'get_clicks'\n" " Line 22 still contains ad.get_clicks()\n" " 1 test failed, 0 passed" ), 0.0 if has_fix: return True, ( "PASS [unit] test_score_ads: OK\n" "PASS [unit] test_rank_returns_sorted_list: OK\n" "PASS [unit] test_fetch_candidate_ads: OK\n" "3 tests passed in 0.04 s" ), 1.0 return False, ( "FAIL [unit] test_score_ads:\n" " Fix applied but ad click-rate accessor is incorrect.\n" " Expected: ad.get('clicks', 0) or ad['clicks']\n" " 1 test failed" ), 0.2 # ------------------------------------------------------------------ # Task 2 – fix timestamp threshold 1_000_000_000 → 1_000_000_000_000 # ------------------------------------------------------------------ def _grade_task2(self, suite: str) -> Tuple[bool, str, float]: _, content = self.vfs.read_file("capi_pipeline", "transformer.py") has_bug = "1_000_000_000:" in content or "1000000000:" in content # Only count as fixed if the bug line is gone AND the correct threshold is in code # (not just in comments — the comment already contains 1_000_000_000_000) code_lines = [l for l in content.splitlines() if not l.strip().startswith("#")] code_only = "\n".join(code_lines) has_fix = not has_bug and ( "1_000_000_000_000" in code_only or "1000000000000" in code_only or "1e12" in code_only or "10**12" in code_only ) if suite == "unit" and not has_bug: # Unit tests always pass because they don't check timestamp edge cases return True, ( "PASS [unit] test_transform_purchase: OK\n" "PASS [unit] test_batch_transform: OK\n" "2 tests passed" ), 0.4 if suite == "integration": if has_fix: return True, ( "PASS [integration] test_timestamp_normalisation: OK\n" " event_time 1700000000 → 1700000000 ✓\n" " event_time 1700000000000 → 1700000000 ✓\n" "PASS [integration] test_roas_attribution_accuracy: OK\n" " ROAS attribution error: 0.2% (threshold: 5%)\n" "2 tests passed" ), 1.0 else: return False, ( "FAIL [integration] test_timestamp_normalisation:\n" " event_time 1700000000 → 1700000 (expected: 1700000000)\n" " Timestamps are being divided by 1000 incorrectly.\n" " Root cause: threshold condition in _normalize_timestamp()\n" "1 test failed" ), 0.0 # Default: run integration test return self._grade_task2("integration") # ------------------------------------------------------------------ # Task 3 – add finally: await self.db_pool.release(conn) # ------------------------------------------------------------------ def _grade_task3(self, suite: str) -> Tuple[bool, str, float]: _, content = self.vfs.read_file("whatsapp_sync", "handler.py") has_finally = "finally:" in content has_release = "db_pool.release(conn)" in content or "release(conn)" in content if suite == "unit": if not has_finally: return False, ( "PASS [unit] test_sync_messages_basic: OK\n" "PASS [unit] test_process_queue_empty: OK\n" "WARNING: Unit tests pass but connection leak not detectable without load test\n" "Run: run_tests('whatsapp_sync', 'load')" ), 0.3 return True, ( "PASS [unit] test_sync_messages_basic: OK\n" "PASS [unit] test_connection_released_on_success: OK\n" "PASS [unit] test_connection_released_on_exception: OK\n" "3 tests passed" ), 0.6 if suite == "load": if has_finally and has_release: return True, ( "PASS [load] test_100_concurrent_syncs:\n" " Peak connections: 18/100 (nominal)\n" " All 100 requests completed\n" " Memory stable at 210 MB\n" "PASS [load] test_connection_pool_not_exhausted: OK\n" "2 load tests passed" ), 1.0 else: return False, ( "FAIL [load] test_100_concurrent_syncs:\n" " TooManyConnectionsError after 23 concurrent requests\n" " Connection pool exhausted — connections not being released\n" " Hint: Check sync_user_messages() for missing finally block\n" "1 load test failed" ), 0.0 return self._grade_task3("load") # ------------------------------------------------------------------ # Task 4 – rollback migration 003 (remove circular FK) # ------------------------------------------------------------------ def _grade_task4(self, suite: str) -> Tuple[bool, str, float]: _, content = self.vfs.read_file("whatsapp_sync", "db.py") has_circular_fk = ( "REFERENCES message_threads" in content and "REFERENCES messages" in content ) migration_003_present = '"version": "003"' in content or "'version': '003'" in content if suite == "unit": if has_circular_fk: return False, ( "FAIL [unit] test_migration_003:\n" " ForeignKeyViolationError: circular FK detected\n" " messages.thread_id → message_threads.id\n" " message_threads.parent_message_id → messages.id\n" " Fix: remove ALTER TABLE messages ADD COLUMN thread_id ...\n" "1 test failed" ), 0.0 if suite == "integration": if not has_circular_fk: return True, ( "PASS [integration] test_all_services_recover: OK\n" " whatsapp_sync: healthy ✓\n" " ad_ranking: healthy ✓\n" " capi_pipeline: healthy ✓\n" "PASS [integration] test_message_thread_creation: OK\n" "2 tests passed" ), 1.0 else: return False, ( "FAIL [integration] test_all_services_recover:\n" " whatsapp_sync still critical — migration 003 circular FK active\n" " Use rollback() tool or edit db.py to remove the circular reference\n" "1 test failed" ), 0.0 return self._grade_task4("integration") # ------------------------------------------------------------------ # Task 5 – set DEBUG_MODE = False # ------------------------------------------------------------------ def _grade_task5(self, suite: str) -> Tuple[bool, str, float]: _, content = self.vfs.read_file("capi_pipeline", "ingestor.py") has_debug_true = "DEBUG_MODE = True" in content has_debug_false = "DEBUG_MODE = False" in content if suite == "unit": return True, ( "PASS [unit] test_ingest_valid_payload: OK\n" "PASS [unit] test_ingest_empty_payload: OK\n" "WARNING: Unit tests do not cover security — run security suite\n" "2 tests passed (security NOT verified)" ), 0.2 if suite == "security": if has_debug_true: return False, ( "FAIL [security] test_no_pii_in_response:\n" " Response body contains 'debug_data' key with raw user PII\n" " debug_data.user_emails exposed in HTTP response\n" " debug_data.raw_payload exposed in HTTP response\n" " CVSS score: 7.5 (High) — data exposure\n" " Fix: set DEBUG_MODE = False\n" "1 security test failed" ), 0.0 elif has_debug_false: return True, ( "PASS [security] test_no_pii_in_response: OK\n" " Response body: {'status': 'ok', 'processed': N} — no PII\n" "PASS [security] test_response_size_nominal: OK\n" " Avg response size: 48 bytes (was 14 KB)\n" "PASS [security] test_debug_data_absent: OK\n" "3 security tests passed — vulnerability closed" ), 1.0 else: return False, ( "FAIL [security] test_no_pii_in_response:\n" " DEBUG_MODE value is ambiguous or missing\n" " Expected: DEBUG_MODE = False\n" "1 security test failed" ), 0.1 return self._grade_task5("security") # ------------------------------------------------------------------ # Incident report grader # ------------------------------------------------------------------ def grade_incident_report( self, task_id: int, report: IncidentReport ) -> float: """Score 0.0–1.0 for incident report accuracy.""" expected = { 1: { "root_cause_keywords": ["get_clicks", "attributeerror", "dict", "attribute"], "expected_services": ["ad_ranking"], "severity": "P0", }, 2: { "root_cause_keywords": ["timestamp", "1000", "normalize", "capi", "transformer"], "expected_services": ["capi_pipeline", "ad_ranking"], "severity": "P1", }, 3: { "root_cause_keywords": ["connection", "pool", "release", "finally", "async"], "expected_services": ["whatsapp_sync"], "severity": "P1", }, 4: { "root_cause_keywords": ["migration", "003", "foreign key", "circular", "fk"], "expected_services": ["whatsapp_sync"], "severity": "P0", }, 5: { "root_cause_keywords": ["debug", "pii", "exposure", "ingest", "security"], "expected_services": ["capi_pipeline"], "severity": "P0", }, } cfg = expected.get(task_id, {}) if not cfg: return 0.0 score = 0.0 root_cause_lower = report.root_cause.lower() keywords = cfg.get("root_cause_keywords", []) keyword_hits = sum(1 for kw in keywords if kw in root_cause_lower) score += min(keyword_hits / max(len(keywords), 1), 1.0) * 0.5 expected_svcs = set(cfg.get("expected_services", [])) reported_svcs = set(s.lower() for s in report.services_affected) svc_score = len(expected_svcs & reported_svcs) / max(len(expected_svcs), 1) score += svc_score * 0.3 if report.severity_classification == cfg.get("severity"): score += 0.2 return round(score, 3) # --------------------------------------------------------------------------- # Reward Manager # --------------------------------------------------------------------------- class RewardManager: """Computes step-level and terminal rewards.""" STEP_PENALTY = -0.1 SYNTAX_ERROR_PENALTY = -0.5 ROLLBACK_PENALTY = -1.0 SENIOR_SRE_PENALTY = -0.2 SYMPTOM_FIX_PENALTY = -0.3 # for Task 4 — fixing red herring services PROGRESS_ERROR_DROP = +0.3 # error_rate drops >50% PROGRESS_SERVICE_ID = +0.2 # correct root-cause service identified PROGRESS_FILE_FOUND = +0.2 # correct file opened/edited TERMINAL_TESTS_PASS = +1.0 TERMINAL_REPORT_MAX = +0.5 TERMINAL_SLA_BONUS = +0.3 TERMINAL_NO_REGRESS = +0.2 TERMINAL_SECURITY_PATCH = +0.5 # Task 5 only MAX_POSSIBLE = 3.0 def __init__(self): self._cumulative = 0.0 self._step_rewards: List[float] = [] def reset(self): self._cumulative = 0.0 self._step_rewards.clear() def step_reward(self, action: str, syntax_error: bool = False, symptom_fix: bool = False) -> float: r = self.STEP_PENALTY if syntax_error: r += self.SYNTAX_ERROR_PENALTY if action == "rollback": r += self.ROLLBACK_PENALTY if action == "ask_senior_sre": r += self.SENIOR_SRE_PENALTY if symptom_fix: r += self.SYMPTOM_FIX_PENALTY self._cumulative += r self._step_rewards.append(r) return round(r, 3) def progress_reward(self, reason: str) -> float: mapping = { "error_drop": self.PROGRESS_ERROR_DROP, "service_id": self.PROGRESS_SERVICE_ID, "file_found": self.PROGRESS_FILE_FOUND, } r = mapping.get(reason, 0.0) self._cumulative += r self._step_rewards.append(r) return round(r, 3) def terminal_reward( self, tests_passed: bool, report_accuracy: float, fixed_within_sla: bool, no_regressions: bool, task_id: int, ) -> float: r = 0.0 if tests_passed: r += self.TERMINAL_TESTS_PASS r += report_accuracy * self.TERMINAL_REPORT_MAX if fixed_within_sla: r += self.TERMINAL_SLA_BONUS if no_regressions: r += self.TERMINAL_NO_REGRESS if task_id == 5 and tests_passed: r += self.TERMINAL_SECURITY_PATCH self._cumulative += r return round(r, 3) def normalized_score(self) -> float: return round(max(0.0, min(self._cumulative / self.MAX_POSSIBLE, 1.0)), 4) @property def total(self) -> float: return round(self._cumulative, 4) # --------------------------------------------------------------------------- # Episode Manager – the main orchestrator # --------------------------------------------------------------------------- class EpisodeManager: """ Ties together VFS, MetricsEngine, TaskGrader, and RewardManager. Exposes reset() and step() matching the OpenEnv contract. """ def __init__(self, difficulty_controller: Optional[DifficultyController] = None): self.vfs = VirtualFileSystem() self.metrics = MetricsEngine() self.grader: Optional[TaskGrader] = None self.reward = RewardManager() self.dc = difficulty_controller or DifficultyController() self._task_id: int = 0 self._step: int = 0 self._done: bool = False self._incident_id: str = "" self._sre_memory: List[str] = [] self._tool_call_log: List[Dict] = [] self._last_terminal: str = "" self._incident_report: Optional[IncidentReport] = None self._start_time: float = 0.0 # ------------------------------------------------------------------ # OpenEnv: reset # ------------------------------------------------------------------ def reset(self, task_id: Optional[int] = None) -> Observation: self._task_id = task_id or self.dc.next_task_id() self._step = 0 self._done = False self._incident_id = f"INC-{self._task_id}-{uuid.uuid4().hex[:6].upper()}" self._sre_memory = [] self._tool_call_log = [] self._last_terminal = "" self._incident_report = None self._start_time = time.time() self.vfs.reset(self._task_id) self.metrics.reset(self._task_id) self.grader = TaskGrader(self.vfs) self.reward.reset() return self._build_observation() # ------------------------------------------------------------------ # OpenEnv: step # ------------------------------------------------------------------ def step(self, tool: str, params: Dict[str, Any]) -> ActionResult: if self._done: raise RuntimeError("Episode is done. Call reset() to start a new episode.") self._step += 1 self._tool_call_log.append({"step": self._step, "tool": tool, "params": params}) # Dispatch to tool handler from app.tools.definitions import ToolDispatcher dispatcher = ToolDispatcher(self) reward_delta, done, tool_output = dispatcher.dispatch(tool, params) self._done = done obs = self._build_observation() return ActionResult( tool=tool, output=tool_output, reward_delta=reward_delta, done=done, observation=obs, ) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _build_observation(self) -> Observation: task_def = TASK_DEFINITIONS.get(self._task_id, {}) budget = task_def.get("sla_budget", 20) - self._step return Observation( step=self._step, incident_id=self._incident_id, system_metrics=self.metrics.get_metrics(self._step), active_alerts=self.metrics.get_alerts(self._step), terminal_output=self.metrics.get_terminal_output( self._step, self._last_terminal or None ), git_diff=self.vfs.build_git_diff(), dependency_graph=DEPENDENCY_GRAPH, sre_memory=list(self._sre_memory), budget_remaining=max(budget, 0), task_id=self._task_id, task_description=task_def.get("description", ""), ) def add_memory(self, entry: str) -> None: self._sre_memory.append(f"[step {self._step}] {entry}") def get_episode_result(self) -> EpisodeResult: tests_passed = False report_accuracy = 0.0 if self._incident_report: report_accuracy = self.grader.grade_incident_report( self._task_id, self._incident_report ) task_def = TASK_DEFINITIONS.get(self._task_id, {}) fixed_within_sla = self._step <= task_def.get("sla_budget", 20) return EpisodeResult( incident_id=self._incident_id, task_id=self._task_id, steps_taken=self._step, total_reward=self.reward.total, normalized_score=self.reward.normalized_score(), tests_passed=tests_passed, incident_report_accuracy=report_accuracy, fixed_within_sla=fixed_within_sla, tool_call_log=list(self._tool_call_log), weakness_tags=self.dc.weakness_tags(), )