Meta-SRE / training /generator.py
Anvit25's picture
Deploy Meta-SRE OpenEnv benchmark FastAPI server
ad6248e
"""
Layer 4 – Perfect-Play Bot & JSONL Dataset Generator.
Runs all 5 tasks optimally to generate training episodes.
Outputs: training/dataset/training_data.jsonl
Usage:
python -m training.generator --episodes 40 --output training/dataset/training_data.jsonl
"""
from __future__ import annotations
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import random
import argparse
from pathlib import Path
from typing import Any, Dict, List, Tuple
from app.engine.manager import EpisodeManager, TASK_DEFINITIONS
from app.engine.observability import DifficultyController
# ---------------------------------------------------------------------------
# Perfect-Play Scripts
# Each script is a list of (tool, params) tuples that solve the task optimally.
# Randomisation is applied to variable names / details for dataset diversity.
# ---------------------------------------------------------------------------
def _vary(base: str, variants: List[str]) -> str:
    """Pick one string at random from ``base`` and its ``variants``.

    The candidate pool is ``[base] + variants``; a single ``random.choice``
    call selects from it, so the result is reproducible under ``random.seed``.
    With an empty ``variants`` list the result is always ``base``.
    """
    candidates = [base] + variants
    return random.choice(candidates)
def perfect_play_task1(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Build the scripted solution for task 1 (method call on a dict).

    The script reads the error logs, inspects and blames ``ranker.py`` line 22,
    replaces that line, runs the unit suite and files the incident report.
    The replacement line is drawn from two equivalent spellings via ``_vary``.

    ``ep`` is accepted so all task scripts share one signature; it is not
    used here.

    Returns:
        Ordered ``(tool, params)`` tuples.
    """
    service = "ad_ranking"
    source_file = "ranker.py"
    faulty_line = 22

    # One random draw per call: either the ``.get`` form or the subscript form.
    patched_code = _vary(
        " click_rate = ad.get('clicks', 0) / max(ad.get('impressions', 1), 1)",
        [" click_rate = ad['clicks'] / max(ad.get('impressions', 1), 1)"],
    )

    incident_report = {
        "root_cause": "AttributeError: dict has no attribute get_clicks() — Junior AI generated method call instead of dict accessor",
        "fix_applied": "Replaced ad.get_clicks() with ad.get('clicks', 0) on ranker.py line 22",
        "services_affected": ["ad_ranking"],
        "severity_classification": "P0",
    }

    return [
        ("read_logs", {"service": service, "log_level": "ERROR", "last_n_lines": 20}),
        ("view_file", {"service": service, "filename": source_file}),
        ("git_blame", {"service": service, "filename": source_file, "line_number": faulty_line}),
        (
            "edit_line",
            {
                "service": service,
                "filename": source_file,
                "line_number": faulty_line,
                "new_code": patched_code,
            },
        ),
        ("run_tests", {"service": service, "suite": "unit"}),
        ("write_incident_report", incident_report),
    ]
def perfect_play_task2(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Build the scripted solution for task 2 (timestamp threshold 1e9 -> 1e12).

    The script follows the symptom from ``ad_ranking`` to its dependency
    ``capi_pipeline``, inspects ``transformer.py``, replaces the threshold
    comparison, runs the integration suite and files the incident report.

    ``ep`` is accepted so all task scripts share one signature; it is not
    used here.

    Returns:
        Ordered ``(tool, params)`` tuples.
    """
    upstream = "ad_ranking"
    culprit = "capi_pipeline"
    source_file = "transformer.py"
    # Single source of truth for the patched line. The report previously cited
    # "line 40" while the edit action targeted line 43; the report now names
    # the line that the edit action actually modifies.
    threshold_line = 43

    incident_report = {
        "root_cause": "Timestamp normalisation threshold in capi_pipeline/transformer.py was 1e9 instead of 1e12 — unix-second timestamps treated as milliseconds, resulting in events attributed to 1970",
        "fix_applied": (
            "Changed _normalize_timestamp threshold from 1_000_000_000 to "
            f"1_000_000_000_000 on transformer.py line {threshold_line}"
        ),
        "services_affected": ["capi_pipeline", "ad_ranking"],
        "severity_classification": "P1",
    }

    return [
        ("read_logs", {"service": upstream, "log_level": "WARN", "last_n_lines": 20}),
        ("check_dependency", {"service_a": upstream, "service_b": culprit}),
        ("query_metrics_history", {"service": culprit, "metric": "error_rate", "hours_back": 6}),
        ("read_logs", {"service": culprit, "log_level": "DEBUG", "last_n_lines": 20}),
        ("view_file", {"service": culprit, "filename": source_file}),
        (
            "edit_line",
            {
                "service": culprit,
                "filename": source_file,
                "line_number": threshold_line,
                "new_code": " if ts > 1_000_000_000_000:",
            },
        ),
        ("run_tests", {"service": culprit, "suite": "integration"}),
        ("write_incident_report", incident_report),
    ]
def perfect_play_task3(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Build the scripted solution for task 3 (connection pool exhaustion).

    The script reads logs and queue metrics for ``whatsapp_sync``, inspects
    and blames ``handler.py``, runs the unit suite, rewrites lines 35-37 so
    the connection is released in a ``finally`` clause, runs the load suite
    and files the incident report.

    ``ep`` is accepted so all task scripts share one signature; it is not
    used here.

    Returns:
        Ordered ``(tool, params)`` tuples.
    """
    service = "whatsapp_sync"
    source_file = "handler.py"

    # (line number, replacement text) for the three consecutive edits.
    patch = (
        (35, " raise"),
        (36, " finally:"),
        (37, " await self.db_pool.release(conn)"),
    )
    edit_steps = [
        (
            "edit_line",
            {
                "service": service,
                "filename": source_file,
                "line_number": line_no,
                "new_code": code,
            },
        )
        for line_no, code in patch
    ]

    incident_report = {
        "root_cause": "DB connection pool exhaustion in whatsapp_sync — sync_user_messages() acquires a connection but has no finally block to release it on exception, causing pool depletion under concurrent load",
        "fix_applied": "Added finally: await self.db_pool.release(conn) to sync_user_messages() in handler.py",
        "services_affected": ["whatsapp_sync"],
        "severity_classification": "P1",
    }

    investigation = [
        ("read_logs", {"service": service, "log_level": "ERROR", "last_n_lines": 20}),
        ("query_metrics_history", {"service": service, "metric": "request_queue", "hours_back": 4}),
        ("view_file", {"service": service, "filename": source_file}),
        ("git_blame", {"service": service, "filename": source_file, "line_number": 35}),
        ("run_tests", {"service": service, "suite": "unit"}),
    ]
    wrap_up = [
        ("run_tests", {"service": service, "suite": "load"}),
        ("write_incident_report", incident_report),
    ]
    return investigation + edit_steps + wrap_up
def perfect_play_task4(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Build the scripted solution for task 4 (circular FK in migration 003).

    The script reads ``whatsapp_sync`` error logs, checks ``capi_pipeline``
    latency history, inspects and blames ``db.py``, runs the unit suite,
    rolls back version ``"003"``, runs the integration suite and files the
    incident report.

    ``ep`` is accepted so all task scripts share one signature; it is not
    used here.

    Returns:
        Ordered ``(tool, params)`` tuples.
    """
    service = "whatsapp_sync"
    source_file = "db.py"

    incident_report = {
        "root_cause": "Circular foreign key in migration 003: message_threads.parent_message_id references messages, and the ALTER TABLE added messages.thread_id referencing message_threads — PostgreSQL FK resolution failure cascaded to all DB pool consumers",
        "fix_applied": "Rolled back migration 003 to remove circular FK constraint",
        "services_affected": ["whatsapp_sync"],
        "severity_classification": "P0",
    }

    script: List[Tuple[str, Dict]] = []
    script.append(("read_logs", {"service": service, "log_level": "ERROR", "last_n_lines": 30}))
    script.append(
        ("query_metrics_history", {"service": "capi_pipeline", "metric": "p99_latency_ms", "hours_back": 6})
    )
    script.append(("view_file", {"service": service, "filename": source_file}))
    script.append(("git_blame", {"service": service, "filename": source_file, "line_number": 45}))
    script.append(("run_tests", {"service": service, "suite": "unit"}))
    script.append(("rollback", {"service": service, "version": "003"}))
    script.append(("run_tests", {"service": service, "suite": "integration"}))
    script.append(("write_incident_report", incident_report))
    return script
def perfect_play_task5(ep: EpisodeManager) -> List[Tuple[str, Dict]]:
    """Build the scripted solution for task 5 (``DEBUG_MODE`` left on).

    The script reads ``capi_pipeline`` debug logs, runs the unit suite,
    inspects and blames ``ingestor.py`` line 7, sets ``DEBUG_MODE = False``
    on that line, runs the security suite and files the incident report.

    ``ep`` is accepted so all task scripts share one signature; it is not
    used here.

    Returns:
        Ordered ``(tool, params)`` tuples.
    """
    service = "capi_pipeline"
    source_file = "ingestor.py"
    flag_line = 7

    def tests(suite: str) -> Tuple[str, Dict]:
        # Both test runs target the same service and differ only in suite.
        return ("run_tests", {"service": service, "suite": suite})

    incident_report = {
        "root_cause": "PII data exposure: DEBUG_MODE=True in production caused /ingest to return raw user PII (emails, phone numbers) in HTTP response body — invisible to unit tests, caught by security suite",
        "fix_applied": "Set DEBUG_MODE = False in capi_pipeline/ingestor.py line 7",
        "services_affected": ["capi_pipeline"],
        "severity_classification": "P0",
    }

    return [
        ("read_logs", {"service": service, "log_level": "DEBUG", "last_n_lines": 20}),
        tests("unit"),
        ("view_file", {"service": service, "filename": source_file}),
        ("git_blame", {"service": service, "filename": source_file, "line_number": flag_line}),
        (
            "edit_line",
            {
                "service": service,
                "filename": source_file,
                "line_number": flag_line,
                "new_code": "DEBUG_MODE = False # FIXED: must be False in production",
            },
        ),
        tests("security"),
        ("write_incident_report", incident_report),
    ]
# Registry of scripted solutions, keyed by task id 1..5 in definition order.
PERFECT_PLAY_SCRIPTS = dict(
    enumerate(
        (
            perfect_play_task1,
            perfect_play_task2,
            perfect_play_task3,
            perfect_play_task4,
            perfect_play_task5,
        ),
        start=1,
    )
)
# ---------------------------------------------------------------------------
# Observation → prompt string formatter
# ---------------------------------------------------------------------------
def obs_to_prompt(obs: dict) -> str:
    """Render an observation dict as the text prompt shown to the model.

    The prompt contains, in order: incident id, task description, step and
    remaining budget, one line per service in ``system_metrics``, one line per
    entry in ``active_alerts``, the terminal output, and the ``sre_memory``
    entries. Metric and alert entries that are not dicts are skipped.

    Missing keys fall back to empty/zero defaults. Keys that are present but
    hold ``None`` in ``system_metrics``, ``active_alerts``, ``sre_memory`` or
    the numeric metric fields are treated the same way instead of raising.

    Args:
        obs: Observation as a plain dict (e.g. a model dump).

    Returns:
        The formatted multi-line prompt string.
    """
    # ``dict.get(key, default)`` only covers a missing key; ``or`` also covers
    # an explicit None value, which would otherwise crash on iteration.
    metrics = obs.get("system_metrics") or {}
    alerts = obs.get("active_alerts") or []
    memory = obs.get("sre_memory") or []

    metrics_summary = [
        f" {svc}: CPU={(m.get('cpu_percent') or 0):.0f}% "
        f"MEM={(m.get('memory_mb') or 0):.0f}MB "
        f"ERR={(m.get('error_rate') or 0):.1f}/s "
        f"STATUS={m.get('status','?')}"
        for svc, m in metrics.items()
        if isinstance(m, dict)
    ]
    alerts_summary = [
        f" [{a.get('severity','?')}] {a.get('service','?')}: {a.get('message','')}"
        for a in alerts
        if isinstance(a, dict)
    ]

    metrics_block = "\n".join(metrics_summary)
    alerts_block = "\n".join(alerts_summary) or " None"
    memory_block = "\n".join(f" {m}" for m in memory) or " (empty)"

    return (
        f"INCIDENT: {obs.get('incident_id','')}\n"
        f"TASK: {obs.get('task_description','')}\n"
        f"STEP: {obs.get('step',0)} | BUDGET: {obs.get('budget_remaining',0)} steps remaining\n\n"
        f"SYSTEM METRICS:\n{metrics_block}\n\n"
        f"ACTIVE ALERTS:\n{alerts_block}\n\n"
        f"TERMINAL:\n{obs.get('terminal_output','')}\n\n"
        f"SRE MEMORY:\n{memory_block}\n"
    )
def action_to_response(tool: str, params: Dict) -> str:
    """Serialise one agent action as the assistant message body.

    Returns a JSON object with keys ``tool`` and ``params`` (in that order),
    pretty-printed with a two-space indent.
    """
    payload = {"tool": tool, "params": params}
    return json.dumps(payload, indent=2)
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_episode(task_id: int, ep: EpisodeManager) -> List[Dict]:
    """Play one scripted episode and return it as chat messages.

    The environment is reset for ``task_id``, the matching script from
    ``PERFECT_PLAY_SCRIPTS`` is built, and each scripted action is executed in
    order. The transcript starts with a system message and the initial
    observation, then alternates assistant action / user observation.

    Playback stops early when a step reports ``done`` or when stepping raises
    ``RuntimeError``; in the latter case the last assistant message has no
    following observation.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts.
    """
    first_obs = ep.reset(task_id=task_id)
    # Build the script only after the reset, matching the original call order.
    scripted_actions = PERFECT_PLAY_SCRIPTS[task_id](ep)
    current = first_obs.model_dump()

    system_prompt = (
        "You are a Senior Site Reliability Engineer (SRE) at Meta. "
        "You are debugging a live production incident. "
        "Use the available tools methodically: read logs first, then inspect code, "
        "make surgical single-line edits, verify with tests, and close with an incident report. "
        "Never rewrite entire files. Always run tests after editing."
    )

    # Transcript opens with the system prompt followed by the initial observation.
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": obs_to_prompt(current)},
    ]

    for tool, params in scripted_actions:
        # The assistant message is recorded before the action is executed.
        conversation.append(
            {"role": "assistant", "content": action_to_response(tool, params)}
        )
        try:
            outcome = ep.step(tool=tool, params=params)
            current = outcome.observation.model_dump()
            # The resulting observation becomes the next user message.
            conversation.append({"role": "user", "content": obs_to_prompt(current)})
            finished = outcome.done
        except RuntimeError:
            break
        if finished:
            break

    return conversation
# ---------------------------------------------------------------------------
# Dataset generator
# ---------------------------------------------------------------------------
def generate_dataset(
    episodes_per_task: int = 40,
    output_path: str = "training/dataset/training_data.jsonl",
    seed: int = 42,
) -> None:
    """Run scripted episodes for tasks 1-5 and write them as JSONL.

    Seeds ``random``, creates the output directory if needed, then plays
    ``episodes_per_task * 5`` episodes, cycling through task ids 1..5. Each
    successful episode becomes one JSON line with its id, task id, score, step
    count and messages. An episode that raises is reported on stdout and
    skipped. Progress is printed every tenth episode and a summary at the end.

    Args:
        episodes_per_task: Number of episodes to attempt for each task.
        output_path: Destination JSONL file; overwritten if it exists.
        seed: Seed passed to ``random.seed``.
    """
    random.seed(seed)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    manager = EpisodeManager(difficulty_controller=DifficultyController())

    planned = episodes_per_task * 5
    written = 0
    per_task = dict.fromkeys(range(1, 6), 0)

    with open(output_path, "w", encoding="utf-8") as sink:
        for idx in range(planned):
            # Round-robin over task ids 1..5.
            task_id = idx % 5 + 1
            try:
                messages = run_episode(task_id, manager)
                outcome = manager.get_episode_result()
                row = {
                    "episode_id": f"ep_{idx:04d}",
                    "task_id": task_id,
                    "normalized_score": outcome.normalized_score,
                    "steps_taken": outcome.steps_taken,
                    "messages": messages,
                }
                sink.write(json.dumps(row) + "\n")
                written += 1
                per_task[task_id] += 1
                if idx % 10 == 0:
                    print(
                        f"[{idx:4d}/{planned}] "
                        f"task={task_id} score={outcome.normalized_score:.3f} "
                        f"steps={outcome.steps_taken}"
                    )
            except Exception as exc:
                # Best effort: one failed episode must not abort the whole run.
                print(f"WARNING: episode {idx} task {task_id} failed: {exc}")

    print(f"\nDataset written to {output_path}")
    print(f"Total episodes: {written}")
    for task, count in per_task.items():
        print(f" Task {task}: {count} episodes")
# ---------------------------------------------------------------------------
# Baseline evaluator (for "before training" comparison)
# ---------------------------------------------------------------------------
def run_baseline_naive(task_id: int) -> float:
    """Score a deliberately poor two-step agent on ``task_id``.

    A fresh ``EpisodeManager`` is reset for the task. The agent overwrites
    line 1 of ``ad_ranking/ranker.py`` with a placeholder comment and then
    files a vague incident report without fixing anything. The same two
    actions are used for every task.

    Returns:
        ``ep.reward.normalized_score()`` after the two steps. The original
        author's note gives an expected value of about 0.18.
    """
    ep = EpisodeManager()
    ep.reset(task_id=task_id)

    # Step 1: blind edit of the first line with meaningless content.
    blind_edit = {
        "service": "ad_ranking",
        "filename": "ranker.py",
        "line_number": 1,
        "new_code": "# rewriting entire file... (hallucination)",
    }
    # Step 2: incident report that identifies no real cause or fix.
    vague_report = {
        "root_cause": "unknown error in the code",
        "fix_applied": "rewrote the file",
        "services_affected": ["ad_ranking"],
        "severity_classification": "P1",
    }

    ep.step("edit_line", blind_edit)
    ep.step("write_incident_report", vague_report)
    return ep.reward.normalized_score()
def evaluate_model(
    model_name: str,
    call_fn,  # callable(prompt: str) -> str (returns JSON action)
    n_tasks: int = 5,
    max_steps: int = 30,
) -> Dict[str, Any]:
    """Evaluate any model against the environment, one episode per task.

    For each task id in ``1..n_tasks`` the environment is reset and the model
    is queried in a loop. Each reply must be a JSON string with a ``tool`` key
    and an optional ``params`` key. The loop ends when the environment reports
    ``done``, when the step counter reaches ``max_steps``, or when querying,
    parsing or stepping raises; the error is printed and the task is scored
    as it stands.

    Args:
        model_name: Label used in the printed summary.
        call_fn: Callable taking the prompt string and returning the JSON
            action string.
        n_tasks: Number of tasks to evaluate, starting from task 1.
        max_steps: Upper bound on environment steps per task (default 30, the
            previously hard-coded limit).

    Returns:
        Dict with one ``task_<id>`` score per task plus ``average`` (rounded
        to 4 decimals; 0.0 when no task was evaluated).
    """
    ep = EpisodeManager()
    scores: Dict[str, Any] = {}
    for task_id in range(1, n_tasks + 1):
        obs = ep.reset(task_id=task_id)
        done = False
        # NOTE(review): reads the manager's private step counter; switch to a
        # public accessor if EpisodeManager offers one.
        while not done and ep._step < max_steps:
            # model_dump() matches how run_episode serialises the same
            # observation type (previously the deprecated .dict()).
            prompt = obs_to_prompt(obs.model_dump())
            try:
                response = call_fn(prompt)
                action = json.loads(response)
                result = ep.step(action["tool"], action.get("params", {}))
                obs = result.observation
                done = result.done
            except Exception as e:
                # Deliberate best effort: a bad reply ends this task only.
                print(f"Model error on task {task_id}: {e}")
                break
        scores[f"task_{task_id}"] = ep.reward.normalized_score()

    # Guard against n_tasks <= 0, which previously raised ZeroDivisionError.
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    scores["average"] = round(avg, 4)

    print(f"\n{model_name} evaluation results:")
    for k, v in scores.items():
        print(f" {k}: {v:.3f}")
    return scores
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: parse the three options and generate the dataset.
    cli = argparse.ArgumentParser(description="Meta-SRE dataset generator")
    cli.add_argument(
        "--episodes",
        type=int,
        default=40,
        help="Episodes per task (default: 40 → 200 total)",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="training/dataset/training_data.jsonl",
    )
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    generate_dataset(
        episodes_per_task=opts.episodes,
        output_path=opts.output,
        seed=opts.seed,
    )