| """ |
| Layer 4 – Perfect-Play Bot & JSONL Dataset Generator. |
| |
| Runs all 5 tasks optimally to generate training episodes. |
| Outputs: training/dataset/training_data.jsonl |
| |
| Usage: |
| python -m training.generator --episodes 40 --output training/dataset/training_data.jsonl |
| """ |
| from __future__ import annotations |
| import sys, os |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| import json |
| import random |
| import argparse |
| from pathlib import Path |
| from typing import Any, Dict, List, Tuple |
|
|
| from app.engine.manager import EpisodeManager, TASK_DEFINITIONS |
| from app.engine.observability import DifficultyController |
|
|
|
|
| |
| |
| |
| |
| |
|
|
def _vary(base: str, variants: List[str]) -> str:
    """Pick one phrasing uniformly at random from *base* plus *variants*.

    Uses the module-level ``random`` generator, so results are reproducible
    after ``random.seed(...)``.
    """
    candidates = [base, *variants]
    return random.choice(candidates)
|
|
|
|
def perfect_play_task1(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Scripted optimal play for task 1: hallucinated ``ad.get_clicks()`` call.

    The bug is a generated method call on a plain dict in
    ``ad_ranking/ranker.py`` line 22; the fix replaces it with a dict accessor.

    Args:
        ep: Active episode manager. Unused here; ``run_episode`` calls every
            script with it so all scripts share one signature.

    Returns:
        Ordered list of ``(tool_name, params)`` actions ending with the
        incident report.
    """
    safe_fix = " click_rate = ad.get('clicks', 0) / max(ad.get('impressions', 1), 1)"
    # Vary the phrasing of the fix (module-level RNG) for dataset diversity.
    new_code = _vary(
        safe_fix,
        [" click_rate = ad['clicks'] / max(ad.get('impressions', 1), 1)"],
    )
    # Describe the accessor that was actually written so the report never
    # contradicts the edit (it previously always claimed ad.get('clicks', 0)).
    accessor = "ad.get('clicks', 0)" if new_code == safe_fix else "ad['clicks']"

    return [
        ("read_logs", {"service": "ad_ranking", "log_level": "ERROR", "last_n_lines": 20}),
        ("view_file", {"service": "ad_ranking", "filename": "ranker.py"}),
        ("git_blame", {"service": "ad_ranking", "filename": "ranker.py", "line_number": 22}),
        ("edit_line", {
            "service": "ad_ranking",
            "filename": "ranker.py",
            "line_number": 22,
            "new_code": new_code,
        }),
        ("run_tests", {"service": "ad_ranking", "suite": "unit"}),
        ("write_incident_report", {
            "root_cause": "AttributeError: dict has no attribute get_clicks() — Junior AI generated method call instead of dict accessor",
            "fix_applied": f"Replaced ad.get_clicks() with {accessor} on ranker.py line 22",
            "services_affected": ["ad_ranking"],
            "severity_classification": "P0",
        }),
    ]
|
|
|
|
def perfect_play_task2(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Scripted optimal play for task 2: silent timestamp corruption.

    The ``capi_pipeline`` timestamp normaliser compares against ``1e9``
    instead of ``1e12``. The script traces the symptom from ``ad_ranking``
    warnings to ``capi_pipeline/transformer.py`` and patches the threshold.

    Args:
        ep: Active episode manager. Unused here; ``run_episode`` calls every
            script with it so all scripts share one signature.

    Returns:
        Ordered list of ``(tool_name, params)`` actions ending with the
        incident report.
    """
    # Single source of truth for the patched line. The report used to say
    # "line 40" while the edit targets line 43; every other script reports
    # the line it actually edits.
    fix_line = 43
    return [
        ("read_logs", {"service": "ad_ranking", "log_level": "WARN", "last_n_lines": 20}),
        ("check_dependency", {"service_a": "ad_ranking", "service_b": "capi_pipeline"}),
        ("query_metrics_history", {"service": "capi_pipeline", "metric": "error_rate", "hours_back": 6}),
        ("read_logs", {"service": "capi_pipeline", "log_level": "DEBUG", "last_n_lines": 20}),
        ("view_file", {"service": "capi_pipeline", "filename": "transformer.py"}),
        ("edit_line", {
            "service": "capi_pipeline",
            "filename": "transformer.py",
            "line_number": fix_line,
            "new_code": " if ts > 1_000_000_000_000:",
        }),
        ("run_tests", {"service": "capi_pipeline", "suite": "integration"}),
        ("write_incident_report", {
            "root_cause": "Timestamp normalisation threshold in capi_pipeline/transformer.py was 1e9 instead of 1e12 — unix-second timestamps treated as milliseconds, resulting in events attributed to 1970",
            "fix_applied": f"Changed _normalize_timestamp threshold from 1_000_000_000 to 1_000_000_000_000 on transformer.py line {fix_line}",
            "services_affected": ["capi_pipeline", "ad_ranking"],
            "severity_classification": "P1",
        }),
    ]
|
|
|
|
def perfect_play_task3(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Scripted optimal play for task 3: DB connection-pool exhaustion.

    Investigates ``whatsapp_sync/handler.py``, rewrites lines 35-37 so the
    acquired connection is released in a ``finally`` block, then verifies
    with the load suite.

    Args:
        ep: Active episode manager. Unused here; ``run_episode`` calls every
            script with it so all scripts share one signature.

    Returns:
        Ordered list of ``(tool_name, params)`` actions ending with the
        incident report.
    """
    svc = "whatsapp_sync"
    src = "handler.py"

    def edit(line_no: int, code: str) -> Tuple[str, Dict]:
        # Key order matters: params are later serialised with json.dumps.
        return ("edit_line", {
            "service": svc,
            "filename": src,
            "line_number": line_no,
            "new_code": code,
        })

    patch = [
        edit(35, " raise"),
        edit(36, " finally:"),
        edit(37, " await self.db_pool.release(conn)"),
    ]
    report = {
        "root_cause": "DB connection pool exhaustion in whatsapp_sync — sync_user_messages() acquires a connection but has no finally block to release it on exception, causing pool depletion under concurrent load",
        "fix_applied": "Added finally: await self.db_pool.release(conn) to sync_user_messages() in handler.py",
        "services_affected": [svc],
        "severity_classification": "P1",
    }
    return [
        ("read_logs", {"service": svc, "log_level": "ERROR", "last_n_lines": 20}),
        ("query_metrics_history", {"service": svc, "metric": "request_queue", "hours_back": 4}),
        ("view_file", {"service": svc, "filename": src}),
        ("git_blame", {"service": svc, "filename": src, "line_number": 35}),
        ("run_tests", {"service": svc, "suite": "unit"}),
        *patch,
        ("run_tests", {"service": svc, "suite": "load"}),
        ("write_incident_report", report),
    ]
|
|
|
|
def perfect_play_task4(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Scripted optimal play for task 4: circular FK from migration 003.

    Inspects the ``whatsapp_sync`` DB layer, issues a ``rollback`` with
    ``version="003"`` and re-runs the integration suite before reporting.

    Args:
        ep: Active episode manager. Unused here; ``run_episode`` calls every
            script with it so all scripts share one signature.

    Returns:
        Ordered list of ``(tool_name, params)`` actions ending with the
        incident report.
    """
    primary = "whatsapp_sync"
    report = {
        "root_cause": "Circular foreign key in migration 003: message_threads.parent_message_id references messages, and the ALTER TABLE added messages.thread_id referencing message_threads — PostgreSQL FK resolution failure cascaded to all DB pool consumers",
        "fix_applied": "Rolled back migration 003 to remove circular FK constraint",
        "services_affected": [primary],
        "severity_classification": "P0",
    }
    diagnosis = [
        ("read_logs", {"service": primary, "log_level": "ERROR", "last_n_lines": 30}),
        ("query_metrics_history", {"service": "capi_pipeline", "metric": "p99_latency_ms", "hours_back": 6}),
        ("view_file", {"service": primary, "filename": "db.py"}),
        ("git_blame", {"service": primary, "filename": "db.py", "line_number": 45}),
        ("run_tests", {"service": primary, "suite": "unit"}),
    ]
    remediation = [
        ("rollback", {"service": primary, "version": "003"}),
        ("run_tests", {"service": primary, "suite": "integration"}),
    ]
    return diagnosis + remediation + [("write_incident_report", report)]
|
|
|
|
def perfect_play_task5(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Scripted optimal play for task 5: PII leak via ``DEBUG_MODE = True``.

    The script runs the unit suite during investigation. It patches line 7
    of ``capi_pipeline/ingestor.py`` and then verifies with the security
    suite, which the incident report says is what catches the leak.

    Args:
        ep: Active episode manager. Unused here; ``run_episode`` calls every
            script with it so all scripts share one signature.

    Returns:
        Ordered list of ``(tool_name, params)`` actions ending with the
        incident report.
    """
    svc, module = "capi_pipeline", "ingestor.py"
    culprit_line = 7

    investigate = [
        ("read_logs", {"service": svc, "log_level": "DEBUG", "last_n_lines": 20}),
        ("run_tests", {"service": svc, "suite": "unit"}),
        ("view_file", {"service": svc, "filename": module}),
        ("git_blame", {"service": svc, "filename": module, "line_number": culprit_line}),
    ]
    remediate = [
        ("edit_line", {
            "service": svc,
            "filename": module,
            "line_number": culprit_line,
            "new_code": "DEBUG_MODE = False # FIXED: must be False in production",
        }),
        ("run_tests", {"service": svc, "suite": "security"}),
    ]
    closing = ("write_incident_report", {
        "root_cause": "PII data exposure: DEBUG_MODE=True in production caused /ingest to return raw user PII (emails, phone numbers) in HTTP response body — invisible to unit tests, caught by security suite",
        "fix_applied": "Set DEBUG_MODE = False in capi_pipeline/ingestor.py line 7",
        "services_affected": [svc],
        "severity_classification": "P0",
    })
    return [*investigate, *remediate, closing]
|
|
|
|
# Registry of perfect-play scripts keyed by 1-based task id (listed in task
# order). run_episode looks up the script for the requested task here.
PERFECT_PLAY_SCRIPTS = dict(
    enumerate(
        (
            perfect_play_task1,
            perfect_play_task2,
            perfect_play_task3,
            perfect_play_task4,
            perfect_play_task5,
        ),
        start=1,
    )
)
|
|
|
|
| |
| |
| |
|
|
def obs_to_prompt(obs: dict) -> str:
    """Render an environment observation as the plain-text user prompt.

    The prompt has these sections, in order:

    - the incident header (id, task, step and remaining budget);
    - one line per ``system_metrics`` entry whose value is a dict;
    - one line per dict in ``active_alerts``, or ``None`` when there are none;
    - the raw ``terminal_output``;
    - the ``sre_memory`` entries, or ``(empty)`` when there are none.

    Non-dict metric and alert entries are skipped. Missing keys fall back to
    neutral defaults.

    Args:
        obs: Observation mapping (e.g. the ``model_dump()`` of an observation).

    Returns:
        The prompt string, terminated by a newline.
    """
    def _metric_line(name, stats):
        return (
            f" {name}: CPU={stats.get('cpu_percent',0):.0f}% "
            f"MEM={stats.get('memory_mb',0):.0f}MB "
            f"ERR={stats.get('error_rate',0):.1f}/s "
            f"STATUS={stats.get('status','?')}"
        )

    metric_block = "\n".join(
        _metric_line(name, stats)
        for name, stats in obs.get("system_metrics", {}).items()
        if isinstance(stats, dict)
    )
    alert_block = "\n".join(
        f" [{alert.get('severity','?')}] {alert.get('service','?')}: {alert.get('message','')}"
        for alert in obs.get("active_alerts", [])
        if isinstance(alert, dict)
    ) or " None"

    header = (
        f"INCIDENT: {obs.get('incident_id','')}\n"
        f"TASK: {obs.get('task_description','')}\n"
        f"STEP: {obs.get('step',0)} | BUDGET: {obs.get('budget_remaining',0)} steps remaining\n"
    )
    terminal = f"TERMINAL:\n{obs.get('terminal_output','')}\n"
    memory_block = "\n".join(f" {note}" for note in obs.get("sre_memory", [])) or " (empty)"

    # Each section ends with one newline; joining with "\n" yields the blank
    # line between sections, and the last section keeps a single trailing "\n".
    sections = [
        header,
        f"SYSTEM METRICS:\n{metric_block}\n",
        f"ACTIVE ALERTS:\n{alert_block}\n",
        terminal,
        f"SRE MEMORY:\n{memory_block}\n",
    ]
    return "\n".join(sections)
|
|
|
|
def action_to_response(tool: str, params: Dict) -> str:
    """Serialise one tool call as the assistant message content.

    Returns 2-space-indented JSON of the form
    ``{"tool": <name>, "params": {...}}``, with keys in that order.
    """
    envelope = dict(tool=tool, params=params)
    return json.dumps(envelope, indent=2)
|
|
|
|
| |
| |
| |
|
|
def run_episode(task_id: int, ep: EpisodeManager) -> List[Dict]:
    """Play one scripted perfect-play episode and record it as chat turns.

    Resets *ep* to *task_id*, then asks the matching script in
    ``PERFECT_PLAY_SCRIPTS`` for its actions. It replays them one by one,
    recording an assistant turn for each action and a user turn for each
    resulting observation.

    Args:
        task_id: Key into ``PERFECT_PLAY_SCRIPTS``.
        ep: Episode manager used to reset and step the environment.

    Returns:
        ``{"role", "content"}`` message dicts: a system prompt, the initial
        observation, then alternating assistant-action / user-observation
        turns. Playback stops once a step result reports ``done``, which
        leaves the final observation as the last turn. It also stops if
        stepping (or formatting the resulting observation) raises
        ``RuntimeError``; the conversation then ends on that assistant turn.
    """
    initial_obs = ep.reset(task_id=task_id)
    scripted_actions = PERFECT_PLAY_SCRIPTS[task_id](ep)
    initial_view = initial_obs.model_dump()

    sre_persona = (
        "You are a Senior Site Reliability Engineer (SRE) at Meta. "
        "You are debugging a live production incident. "
        "Use the available tools methodically: read logs first, then inspect code, "
        "make surgical single-line edits, verify with tests, and close with an incident report. "
        "Never rewrite entire files. Always run tests after editing."
    )

    conversation: List[Dict] = [
        {"role": "system", "content": sre_persona},
        {"role": "user", "content": obs_to_prompt(initial_view)},
    ]

    for tool, params in scripted_actions:
        # The action is recorded before it is executed, so it stays in the
        # transcript even if the environment rejects it below.
        conversation.append({"role": "assistant", "content": action_to_response(tool, params)})
        try:
            outcome = ep.step(tool=tool, params=params)
            conversation.append({
                "role": "user",
                "content": obs_to_prompt(outcome.observation.model_dump()),
            })
            if outcome.done:
                break
        except RuntimeError:
            break

    return conversation
|
|
|
|
| |
| |
| |
|
|
def generate_dataset(
    episodes_per_task: int = 40,
    output_path: str = "training/dataset/training_data.jsonl",
    seed: int = 42,
) -> None:
    """Generate the perfect-play JSONL training dataset.

    Episodes cycle round-robin through every task registered in
    ``PERFECT_PLAY_SCRIPTS``, so each task gets ``episodes_per_task``
    attempts. Each successful episode becomes one JSON line with the episode
    id, task id, final normalized score, step count and the chat
    ``messages`` produced by ``run_episode``.

    Args:
        episodes_per_task: Number of episodes to attempt for each task.
        output_path: Destination JSONL file. Parent directories are created,
            and any existing file is overwritten.
        seed: Seed for the module-level ``random`` generator, which drives
            the ``_vary`` phrasing choices in the scripts.

    A failing episode is reported as a warning and skipped rather than
    aborting the whole run. The final summary includes how many were skipped.
    """
    random.seed(seed)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Derive the task list from the script registry instead of hard-coding 5,
    # so a newly registered perfect-play script is picked up automatically.
    task_ids = sorted(PERFECT_PLAY_SCRIPTS)
    total_episodes = episodes_per_task * len(task_ids)

    ep = EpisodeManager(difficulty_controller=DifficultyController())
    total = 0
    failed = 0
    task_counts = {t: 0 for t in task_ids}

    with open(output_path, "w", encoding="utf-8") as f:
        for episode_idx in range(total_episodes):
            task_id = task_ids[episode_idx % len(task_ids)]

            try:
                turns = run_episode(task_id, ep)
                result = ep.get_episode_result()

                record = {
                    "episode_id": f"ep_{episode_idx:04d}",
                    "task_id": task_id,
                    "normalized_score": result.normalized_score,
                    "steps_taken": result.steps_taken,
                    "messages": turns,
                }
                f.write(json.dumps(record) + "\n")
                total += 1
                task_counts[task_id] += 1

                if episode_idx % 10 == 0:
                    print(
                        f"[{episode_idx:4d}/{total_episodes}] "
                        f"task={task_id} score={result.normalized_score:.3f} "
                        f"steps={result.steps_taken}"
                    )
            except Exception as e:
                # Best-effort: one broken episode must not abort the whole
                # dataset, but it is counted so the summary shows the loss.
                failed += 1
                print(f"WARNING: episode {episode_idx} task {task_id} failed: {e}")

    print(f"\nDataset written to {output_path}")
    print(f"Total episodes: {total}")
    if failed:
        print(f"Skipped (failed) episodes: {failed}")
    for t, c in task_counts.items():
        print(f"  Task {t}: {c} episodes")
|
|
|
|
| |
| |
| |
|
|
def run_baseline_naive(task_id: int) -> float:
    """Score a deliberately naive agent on *task_id*.

    The agent skips all investigation. Whatever the task, it overwrites
    line 1 of ``ad_ranking/ranker.py`` and immediately files a vague P1
    incident report. The original author's note expects a normalized score
    of roughly 0.18.

    Returns:
        The episode's normalized score from ``reward.normalized_score()``.
    """
    env = EpisodeManager()
    env.reset(task_id=task_id)

    naive_actions = (
        # Clobber the top of a file instead of making a targeted fix.
        ("edit_line", {
            "service": "ad_ranking",
            "filename": "ranker.py",
            "line_number": 1,
            "new_code": "# rewriting entire file... (hallucination)",
        }),
        # Close the incident without any real diagnosis.
        ("write_incident_report", {
            "root_cause": "unknown error in the code",
            "fix_applied": "rewrote the file",
            "services_affected": ["ad_ranking"],
            "severity_classification": "P1",
        }),
    )
    for tool_name, tool_params in naive_actions:
        env.step(tool_name, tool_params)

    return env.reward.normalized_score()
|
|
|
|
def evaluate_model(
    model_name: str,
    call_fn,
    n_tasks: int = 5,
    max_steps: int = 30,
) -> Dict[str, Any]:
    """Evaluate an arbitrary model against the environment.

    For each task id in ``1..n_tasks`` the environment is reset. The model
    is then queried until the episode reports ``done`` or ``max_steps``
    steps have been taken.

    Args:
        model_name: Label used in the printed summary.
        call_fn: Callable that receives the prompt string from
            ``obs_to_prompt`` and returns a JSON string shaped like
            ``{"tool": ..., "params": {...}}``.
        n_tasks: Number of tasks to evaluate, starting from task 1.
        max_steps: Per-task step cap, compared against the manager's
            internal step counter. The default of 30 matches the previous
            hard-coded limit.

    Returns:
        ``{"task_<id>": score, ..., "average": mean}``. The average is
        rounded to 4 decimals, and is 0.0 when no tasks were evaluated.

    Any exception raised while calling the model, parsing its response or
    stepping the environment ends that task early. The error is printed,
    and the task is scored on whatever progress was made.
    """
    ep = EpisodeManager()
    scores: Dict[str, Any] = {}

    for task_id in range(1, n_tasks + 1):
        obs = ep.reset(task_id=task_id)
        done = False
        # NOTE(review): reads the private ``_step`` counter; no public
        # accessor is visible from this file.
        while not done and ep._step < max_steps:
            # model_dump() (Pydantic v2) matches run_episode; .dict() is the
            # deprecated v1 spelling.
            prompt = obs_to_prompt(obs.model_dump())
            try:
                response = call_fn(prompt)
                action = json.loads(response)
                result = ep.step(action["tool"], action.get("params", {}))
                obs = result.observation
                done = result.done
            except Exception as e:
                print(f"Model error on task {task_id}: {e}")
                break
        scores[f"task_{task_id}"] = ep.reward.normalized_score()

    # Guard against n_tasks <= 0, which previously raised ZeroDivisionError.
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    scores["average"] = round(avg, 4)
    print(f"\n{model_name} evaluation results:")
    for k, v in scores.items():
        print(f"  {k}: {v:.3f}")
    return scores
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # CLI entry point: parse options and write the perfect-play dataset.
    parser = argparse.ArgumentParser(description="Meta-SRE dataset generator")
    parser.add_argument("--episodes", type=int, default=40,
                        help="Episodes per task (default: 40 → 200 total)")
    parser.add_argument("--output", type=str,
                        default="training/dataset/training_data.jsonl")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # generate_dataset opens the output with mode "w" before running any
    # episode, so a non-positive count would silently truncate an existing
    # dataset to an empty file. Reject it up front (exits with status 2).
    if args.episodes < 1:
        parser.error("--episodes must be a positive integer")

    generate_dataset(
        episodes_per_task=args.episodes,
        output_path=args.output,
        seed=args.seed,
    )
|
|