""" Layer 4 – Perfect-Play Bot & JSONL Dataset Generator. Runs all 5 tasks optimally to generate training episodes. Outputs: training/dataset/training_data.jsonl Usage: python -m training.generator --episodes 40 --output training/dataset/training_data.jsonl """ from __future__ import annotations import sys, os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import json import random import argparse from pathlib import Path from typing import Any, Dict, List, Tuple from app.engine.manager import EpisodeManager, TASK_DEFINITIONS from app.engine.observability import DifficultyController # --------------------------------------------------------------------------- # Perfect-Play Scripts # Each script is a list of (tool, params) tuples that solve the task optimally. # Randomise is applied to variable names / details for dataset diversity. # --------------------------------------------------------------------------- def _vary(base: str, variants: List[str]) -> str: return random.choice([base] + variants) def perfect_play_task1(ep: EpisodeManager) -> List[Tuple[str, Dict]]: """Hallucinated attribute: ad.get_clicks() → ad.get('clicks', 0)""" return [ ("read_logs", {"service": "ad_ranking", "log_level": "ERROR", "last_n_lines": 20}), ("view_file", {"service": "ad_ranking", "filename": "ranker.py"}), ("git_blame", {"service": "ad_ranking", "filename": "ranker.py", "line_number": 22}), ("edit_line", { "service": "ad_ranking", "filename": "ranker.py", "line_number": 22, "new_code": _vary( " click_rate = ad.get('clicks', 0) / max(ad.get('impressions', 1), 1)", [" click_rate = ad['clicks'] / max(ad.get('impressions', 1), 1)"] ), }), ("run_tests", {"service": "ad_ranking", "suite": "unit"}), ("write_incident_report", { "root_cause": "AttributeError: dict has no attribute get_clicks() — Junior AI generated method call instead of dict accessor", "fix_applied": "Replaced ad.get_clicks() with ad.get('clicks', 0) on ranker.py line 22", 
"services_affected": ["ad_ranking"], "severity_classification": "P0", }), ] def perfect_play_task2(ep: EpisodeManager) -> List[Tuple[str, Dict]]: """Silent timestamp corruption: threshold 1e9 → 1e12""" return [ ("read_logs", {"service": "ad_ranking", "log_level": "WARN", "last_n_lines": 20}), ("check_dependency", {"service_a": "ad_ranking", "service_b": "capi_pipeline"}), ("query_metrics_history", {"service": "capi_pipeline", "metric": "error_rate", "hours_back": 6}), ("read_logs", {"service": "capi_pipeline", "log_level": "DEBUG", "last_n_lines": 20}), ("view_file", {"service": "capi_pipeline", "filename": "transformer.py"}), ("edit_line", { "service": "capi_pipeline", "filename": "transformer.py", "line_number": 43, "new_code": " if ts > 1_000_000_000_000:", }), ("run_tests", {"service": "capi_pipeline", "suite": "integration"}), ("write_incident_report", { "root_cause": "Timestamp normalisation threshold in capi_pipeline/transformer.py was 1e9 instead of 1e12 — unix-second timestamps treated as milliseconds, resulting in events attributed to 1970", "fix_applied": "Changed _normalize_timestamp threshold from 1_000_000_000 to 1_000_000_000_000 on transformer.py line 40", "services_affected": ["capi_pipeline", "ad_ranking"], "severity_classification": "P1", }), ] def perfect_play_task3(ep: EpisodeManager) -> List[Tuple[str, Dict]]: """Connection pool exhaustion: add finally: await db_pool.release(conn)""" return [ ("read_logs", {"service": "whatsapp_sync", "log_level": "ERROR", "last_n_lines": 20}), ("query_metrics_history", {"service": "whatsapp_sync", "metric": "request_queue", "hours_back": 4}), ("view_file", {"service": "whatsapp_sync", "filename": "handler.py"}), ("git_blame", {"service": "whatsapp_sync", "filename": "handler.py", "line_number": 35}), ("run_tests", {"service": "whatsapp_sync", "suite": "unit"}), ("edit_line", { "service": "whatsapp_sync", "filename": "handler.py", "line_number": 35, "new_code": " raise", }), ("edit_line", { "service": 
"whatsapp_sync", "filename": "handler.py", "line_number": 36, "new_code": " finally:", }), ("edit_line", { "service": "whatsapp_sync", "filename": "handler.py", "line_number": 37, "new_code": " await self.db_pool.release(conn)", }), ("run_tests", {"service": "whatsapp_sync", "suite": "load"}), ("write_incident_report", { "root_cause": "DB connection pool exhaustion in whatsapp_sync — sync_user_messages() acquires a connection but has no finally block to release it on exception, causing pool depletion under concurrent load", "fix_applied": "Added finally: await self.db_pool.release(conn) to sync_user_messages() in handler.py", "services_affected": ["whatsapp_sync"], "severity_classification": "P1", }), ] def perfect_play_task4(ep: EpisodeManager) -> List[Tuple[str, Dict]]: """Circular FK in migration 003 cascading to all services""" return [ ("read_logs", {"service": "whatsapp_sync", "log_level": "ERROR", "last_n_lines": 30}), ("query_metrics_history", {"service": "capi_pipeline", "metric": "p99_latency_ms", "hours_back": 6}), ("view_file", {"service": "whatsapp_sync", "filename": "db.py"}), ("git_blame", {"service": "whatsapp_sync", "filename": "db.py", "line_number": 45}), ("run_tests", {"service": "whatsapp_sync", "suite": "unit"}), ("rollback", {"service": "whatsapp_sync", "version": "003"}), ("run_tests", {"service": "whatsapp_sync", "suite": "integration"}), ("write_incident_report", { "root_cause": "Circular foreign key in migration 003: message_threads.parent_message_id references messages, and the ALTER TABLE added messages.thread_id referencing message_threads — PostgreSQL FK resolution failure cascaded to all DB pool consumers", "fix_applied": "Rolled back migration 003 to remove circular FK constraint", "services_affected": ["whatsapp_sync"], "severity_classification": "P0", }), ] def perfect_play_task5(ep: EpisodeManager) -> List[Tuple[str, Dict]]: """PII data leak: DEBUG_MODE = True → False""" return [ ("read_logs", {"service": "capi_pipeline", 
"log_level": "DEBUG", "last_n_lines": 20}), ("run_tests", {"service": "capi_pipeline", "suite": "unit"}), ("view_file", {"service": "capi_pipeline", "filename": "ingestor.py"}), ("git_blame", {"service": "capi_pipeline", "filename": "ingestor.py", "line_number": 7}), ("edit_line", { "service": "capi_pipeline", "filename": "ingestor.py", "line_number": 7, "new_code": "DEBUG_MODE = False # FIXED: must be False in production", }), ("run_tests", {"service": "capi_pipeline", "suite": "security"}), ("write_incident_report", { "root_cause": "PII data exposure: DEBUG_MODE=True in production caused /ingest to return raw user PII (emails, phone numbers) in HTTP response body — invisible to unit tests, caught by security suite", "fix_applied": "Set DEBUG_MODE = False in capi_pipeline/ingestor.py line 7", "services_affected": ["capi_pipeline"], "severity_classification": "P0", }), ] PERFECT_PLAY_SCRIPTS = { 1: perfect_play_task1, 2: perfect_play_task2, 3: perfect_play_task3, 4: perfect_play_task4, 5: perfect_play_task5, } # --------------------------------------------------------------------------- # Observation → prompt string formatter # --------------------------------------------------------------------------- def obs_to_prompt(obs: dict) -> str: """Format the observation dict as the LLM system+user prompt.""" metrics_summary = [] for svc, m in obs.get("system_metrics", {}).items(): if isinstance(m, dict): metrics_summary.append( f" {svc}: CPU={m.get('cpu_percent',0):.0f}% " f"MEM={m.get('memory_mb',0):.0f}MB " f"ERR={m.get('error_rate',0):.1f}/s " f"STATUS={m.get('status','?')}" ) alerts_summary = [] for a in obs.get("active_alerts", []): if isinstance(a, dict): alerts_summary.append( f" [{a.get('severity','?')}] {a.get('service','?')}: {a.get('message','')}" ) return ( f"INCIDENT: {obs.get('incident_id','')}\n" f"TASK: {obs.get('task_description','')}\n" f"STEP: {obs.get('step',0)} | BUDGET: {obs.get('budget_remaining',0)} steps remaining\n\n" f"SYSTEM METRICS:\n" + 
"\n".join(metrics_summary) + "\n\n" f"ACTIVE ALERTS:\n" + ("\n".join(alerts_summary) or " None") + "\n\n" f"TERMINAL:\n{obs.get('terminal_output','')}\n\n" f"SRE MEMORY:\n" + ("\n".join(f" {m}" for m in obs.get("sre_memory", [])) or " (empty)") + "\n" ) def action_to_response(tool: str, params: Dict) -> str: """Format agent action as the assistant turn in the conversation.""" return json.dumps({"tool": tool, "params": params}, indent=2) # --------------------------------------------------------------------------- # Episode runner # --------------------------------------------------------------------------- def run_episode(task_id: int, ep: EpisodeManager) -> List[Dict]: """Run one perfect-play episode. Returns conversation turns for JSONL.""" obs = ep.reset(task_id=task_id) script_fn = PERFECT_PLAY_SCRIPTS[task_id] actions = script_fn(ep) turns = [] obs_dict = obs.model_dump() system_prompt = ( "You are a Senior Site Reliability Engineer (SRE) at Meta. " "You are debugging a live production incident. " "Use the available tools methodically: read logs first, then inspect code, " "make surgical single-line edits, verify with tests, and close with an incident report. " "Never rewrite entire files. Always run tests after editing." 
) # Initial observation as first user turn turns.append({ "role": "system", "content": system_prompt, }) turns.append({ "role": "user", "content": obs_to_prompt(obs_dict), }) for tool, params in actions: # Assistant decides action turns.append({ "role": "assistant", "content": action_to_response(tool, params), }) # Execute in environment try: result = ep.step(tool=tool, params=params) obs_dict = result.observation.model_dump() # Next user turn = new observation turns.append({ "role": "user", "content": obs_to_prompt(obs_dict), }) if result.done: break except RuntimeError: break return turns # --------------------------------------------------------------------------- # Dataset generator # --------------------------------------------------------------------------- def generate_dataset( episodes_per_task: int = 40, output_path: str = "training/dataset/training_data.jsonl", seed: int = 42, ) -> None: random.seed(seed) Path(output_path).parent.mkdir(parents=True, exist_ok=True) ep = EpisodeManager(difficulty_controller=DifficultyController()) total = 0 task_counts = {t: 0 for t in range(1, 6)} with open(output_path, "w", encoding="utf-8") as f: for episode_idx in range(episodes_per_task * 5): task_id = (episode_idx % 5) + 1 try: turns = run_episode(task_id, ep) result = ep.get_episode_result() record = { "episode_id": f"ep_{episode_idx:04d}", "task_id": task_id, "normalized_score": result.normalized_score, "steps_taken": result.steps_taken, "messages": turns, } f.write(json.dumps(record) + "\n") total += 1 task_counts[task_id] += 1 if episode_idx % 10 == 0: print( f"[{episode_idx:4d}/{episodes_per_task*5}] " f"task={task_id} score={result.normalized_score:.3f} " f"steps={result.steps_taken}" ) except Exception as e: print(f"WARNING: episode {episode_idx} task {task_id} failed: {e}") print(f"\nDataset written to {output_path}") print(f"Total episodes: {total}") for t, c in task_counts.items(): print(f" Task {t}: {c} episodes") # 
# ---------------------------------------------------------------------------
# Baseline evaluator (for "before training" comparison)
# ---------------------------------------------------------------------------

def run_baseline_naive(task_id: int) -> float:
    """Simulate a naive LLM that immediately tries to rewrite a whole file.

    The simulated agent overwrites line 1 of ``ad_ranking/ranker.py`` with a
    junk comment and then files an uninformative incident report. The same
    two actions are issued whatever ``task_id`` is.

    Args:
        task_id: Task the fresh ``EpisodeManager`` is reset to.

    Returns:
        The normalized score reported by the episode's reward object
        (the original author expected roughly 0.18).
    """
    ep = EpisodeManager()
    ep.reset(task_id=task_id)

    # Naive agent: immediately tries to edit line 1 with garbage
    ep.step("edit_line", {
        "service": "ad_ranking",
        "filename": "ranker.py",
        "line_number": 1,
        "new_code": "# rewriting entire file... (hallucination)",
    })

    # Then writes incident report without fixing anything
    ep.step("write_incident_report", {
        "root_cause": "unknown error in the code",
        "fix_applied": "rewrote the file",
        "services_affected": ["ad_ranking"],
        "severity_classification": "P1",
    })

    return ep.reward.normalized_score()


def evaluate_model(
    model_name: str,
    call_fn,  # callable(prompt: str) -> str (returns JSON action)
    n_tasks: int = 5,
    max_steps: int = 30,
) -> Dict[str, Any]:
    """Evaluate any model against the environment.

    For each task id in ``1..n_tasks`` the episode is reset and ``call_fn`` is
    queried repeatedly until the environment reports ``done``, the step cap is
    reached, or the model/environment raises.

    Args:
        model_name: Label used only in the printed summary.
        call_fn: Receives the observation prompt string and returns a JSON
            string with the keys ``tool`` and (optionally) ``params``.
        n_tasks: Number of tasks to evaluate, starting from task 1.
        max_steps: Upper bound on steps per task (previously a literal 30).

    Returns:
        A dict with one ``task_<id>`` score per task plus ``average``, rounded
        to 4 decimals. ``average`` is 0.0 when no task was evaluated.
    """
    ep = EpisodeManager()
    scores: Dict[str, Any] = {}

    for task_id in range(1, n_tasks + 1):
        obs = ep.reset(task_id=task_id)
        done = False

        # NOTE(review): ep._step is a private counter on EpisodeManager;
        # prefer a public accessor if one exists.
        while not done and ep._step < max_steps:
            # model_dump() matches how run_episode serialises the same
            # observation type; .dict() is the deprecated spelling.
            prompt = obs_to_prompt(obs.model_dump())
            try:
                response = call_fn(prompt)
                action = json.loads(response)
                result = ep.step(action["tool"], action.get("params", {}))
                obs = result.observation
                done = result.done
            except Exception as e:
                # Any model or environment failure ends this task; whatever
                # score has accumulated so far is still recorded below.
                print(f"Model error on task {task_id}: {e}")
                break

        scores[f"task_{task_id}"] = ep.reward.normalized_score()

    # Guard the empty case (n_tasks <= 0) instead of dividing by zero.
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    scores["average"] = round(avg, 4)

    print(f"\n{model_name} evaluation results:")
    for k, v in scores.items():
        print(f" {k}: {v:.3f}")

    return scores


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main() -> None:
    """Parse command-line arguments and generate the dataset."""
    parser = argparse.ArgumentParser(description="Meta-SRE dataset generator")
    parser.add_argument("--episodes", type=int, default=40,
                        help="Episodes per task (default: 40 → 200 total)")
    parser.add_argument("--output", type=str,
                        default="training/dataset/training_data.jsonl")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    generate_dataset(
        episodes_per_task=args.episodes,
        output_path=args.output,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()