# /// script # dependencies = [ # "requests", # "huggingface_hub", # ] # /// """ ============================================================= SRE INCIDENT RESPONSE — COMPREHENSIVE TRAJECTORY COLLECTOR ============================================================= Generates fine-tuning trajectories from the SRE incident simulator covering: • All 10 tasks (8 training + 2 held-out compound scenarios) • All 4 pools: A — Phase 1 only (incident response) B — Phase 2 only (code investigation, oracle belief injected) C — Joint P1→P2 (full two-phase pipeline) D — Held-out joint (generalization test) • Full 17-action action space across both phases • Multiple models from 1.5B to 70B+ (round-robin rotation) • ALL episodes retained — negative-reward trajectories are kept as hard-negative examples for RL/GRPO training Output files: sre_raw_trajectories.json — full episode records with score breakdowns sre_sft_dataset.jsonl — per-step SFT samples (both phases, all rewards) sre_grpo_dataset.jsonl — (prompt, chosen, rejected) pairs for GRPO/DPO Usage: export HF_TOKEN=hf_... python sre_finetune_collector.py Optional env vars: NUM_EPISODES total episodes to collect (default: 200) BASE_URL simulator URL (default: HF Space URL) MAX_STEPS max steps per episode (default: 35) SLEEP_BETWEEN seconds between steps (default: 0.6) """ from __future__ import annotations import json import os import random import time import traceback from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple import requests def upload_checkpoint(api, repo_id): for fname in ["sre_raw_trajectories.jsonl", "sre_sft_dataset.jsonl", "sre_grpo_dataset.jsonl"]: if os.path.exists(fname): try: api.upload_file( path_or_fileobj=fname, path_in_repo=fname, repo_id=repo_id, repo_type="dataset", ) print(f"✅ Uploaded {fname}") except Exception as e: print(f"❌ Upload failed {fname}: {e}") # ────────────────────────────────────────────────────────────────────────────── # Configuration # ────────────────────────────────────────────────────────────────────────────── HF_TOKEN = os.environ.get("HF_TOKEN") BASE_URL = os.environ.get("BASE_URL", "https://meta-hf-hackathon-updated-policy.hf.space") HF_ROUTER_URL = "https://router.huggingface.co/v1/chat/completions" NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "100")) MAX_STEPS = int(os.environ.get("MAX_STEPS", "35")) SLEEP_BETWEEN = float(os.environ.get("SLEEP_BETWEEN", "0.6")) # ── Model ───────────────────────────────────────────────────────────────────── MODELS: List[str] = [ "Qwen/Qwen2.5-7B-Instruct:fastest", ] # ── Task registry per pool ──────────────────────────────────────────────────── # Pool A: P1-only incident response (all 8 training tasks) # Pool B: P2-only code investigation (oracle belief injected; 7 tasks with code) # Pool C: Joint P1→P2 full pipeline (all 8 training tasks) # Pool D: Held-out joint (2 compound scenarios — generalization evaluation) POOL_TASKS: Dict[str, List[str]] = { "A": [ "memory_leak", "cascading_failure", "distributed_deadlock", "circuit_breaker_noop", "aliased_fault", "severity_inversion", "confidence_inversion", "info_ordering", ], "B": [ "memory_leak", "cascading_failure", "distributed_deadlock", "aliased_fault", "severity_inversion", "confidence_inversion", "info_ordering", ], "C": [ "memory_leak", "cascading_failure", "distributed_deadlock", "circuit_breaker_noop", "aliased_fault", "severity_inversion", "confidence_inversion", "info_ordering", ], "D": [ "heldout_aliased_severity", "heldout_confidence_ordering", ], } # Episode budget distribution across pools (must sum to 1.0) POOL_WEIGHTS: Dict[str, float] = {"A": 0.35, "B": 0.20, "C": 0.35, "D": 0.10} # ── Action space definitions ────────────────────────────────────────────────── P1_DIAGNOSTIC = ["view_alerts", "query_logs", "check_metrics", "check_dependencies", "check_deploy_history", "run_health_check"] P1_REMEDIATION = ["restart_service", "rollback_deploy", "scale_service"] P1_TERMINAL = ["declare_root_cause", "transition_to_phase2"] P1_ACTIONS = P1_DIAGNOSTIC + P1_REMEDIATION + P1_TERMINAL P2_DIAGNOSTIC = ["list_dir", "read_file", "search_code", "get_git_log", "get_file_diff"] P2_TERMINAL = ["propose_patch", "declare_no_change"] P2_ACTIONS = P2_DIAGNOSTIC + P2_TERMINAL ALL_SERVICES = ["api_gateway", "auth", "orders", "payment", "cache", "database", "queue"] TARGETED_ACTIONS = { "query_logs", "check_metrics", "check_dependencies", "check_deploy_history", "run_health_check", "restart_service", "rollback_deploy", "scale_service", } # Service dependency graph (for smarter fallbacks) DEPENDENCY_GRAPH: Dict[str, List[str]] = { "api_gateway": ["auth", "orders", "cache"], "auth": ["database"], "orders": ["database", "payment", "auth"], "payment": ["queue", "database"], "cache": [], "database": [], "queue": [], } # ────────────────────────────────────────────────────────────────────────────── # System Prompts # ────────────────────────────────────────────────────────────────────────────── SYSTEM_PROMPT_P1 = """You are an expert SRE handling a production incident in a microservices system. ## Service Topology (downstream ← upstream) api_gateway ← auth, orders, cache auth ← database orders ← database, payment, auth payment ← queue, database cache, database, queue ← (no dependencies) ## Phase 1 Action Space Output EXACTLY ONE valid JSON action per turn. No markdown, no explanation. Diagnostic (read-only): {"action_type": "view_alerts"} {"action_type": "query_logs", "target_service": "", "parameters": {"level": "ERROR", "keyword": "", "limit": 20}} {"action_type": "check_metrics", "target_service": ""} {"action_type": "check_dependencies", "target_service": ""} {"action_type": "check_deploy_history", "target_service": ""} {"action_type": "run_health_check", "target_service": ""} Remediation (mutates state): {"action_type": "restart_service", "target_service": ""} {"action_type": "rollback_deploy", "target_service": ""} {"action_type": "scale_service", "target_service": "", "parameters": {"replicas": 5}} Declare root cause (ALL tasks — always call this once you have a diagnosis): {"action_type": "declare_root_cause", "parameters": {"root_cause": ""}} Then for joint-mode tasks, ALSO transition to code investigation: {"action_type": "transition_to_phase2", "parameters": {"belief": { "suspected_service": "", "suspected_fault_class": "memory_leak|config_change|deadlock|dep_upgrade|none", "service_confidence": 0.85, "fault_confidence": 0.80, "evidence_gaps": [""], "estimated_p2_cost": "low|medium|high", "decision": "transition", "reasoning": "" }}} ## Investigation Strategy 1. ALWAYS start with view_alerts to understand severity and scope 2. check_metrics on the highest-alert service first 3. query_logs (level=ERROR) on degraded/down services 4. check_dependencies on the affected service to find upstream causes 5. check_deploy_history before any rollback 6. Remediate the ROOT CAUSE service, not the symptom 7. After 6-8 diagnostic steps you MUST call declare_root_cause with your diagnosis. For P1-only tasks this ends the episode. For joint-mode tasks, follow it immediately with transition_to_phase2. Do NOT keep diagnosing indefinitely — commit to a conclusion. CRITICAL: Output ONLY valid JSON. No markdown. No explanation. No code blocks.""" SYSTEM_PROMPT_P2 = """You are an expert SRE investigating a code-level fault in a sandboxed repository. ## Phase 2 Action Space Output EXACTLY ONE valid JSON action per turn. No markdown, no explanation. Code Exploration: {"action_type": "list_dir", "parameters": {"path": "."}} {"action_type": "read_file", "parameters": {"path": "relative/path/to/file.py"}} {"action_type": "search_code", "parameters": {"query": "", "file_pattern": "*.py", "max_hits": 20}} {"action_type": "get_git_log", "parameters": {"path": ".", "n_commits": 15}} {"action_type": "get_file_diff", "parameters": {"commit_sha": "", "path": "relative/path/file.py"}} Terminal: {"action_type": "propose_patch", "parameters": {"diff": ""}} {"action_type": "declare_no_change", "parameters": {"reason": ""}} ## Investigation Strategy 1. list_dir "." to understand project structure 2. get_git_log to find recent commits — especially the bad_commit_sha from Phase 1 context 3. get_file_diff on the suspicious commit SHA to see what changed 4. read_file on affected files to understand the bug 5. search_code to find related patterns or the fault injection site 6. If you found a code bug: propose_patch with a minimal, syntactically valid unified diff. The bad_commit_sha in your context tells you exactly what changed — read that diff and revert/fix it. 7. declare_no_change ONLY if Phase 1 confirmed a spurious alert / circuit-breaker false positive with no deployment or code change involved. If there IS a bad commit in the git log, propose_patch. CRITICAL: Output ONLY valid JSON. No markdown. No explanation. No code blocks.""" # ────────────────────────────────────────────────────────────────────────────── # Observation Formatters # ────────────────────────────────────────────────────────────────────────────── def _fmt_service_statuses(statuses: Dict[str, str]) -> str: symbols = {"healthy": "✓", "degraded": "~", "down": "✗"} return " ".join( f"{symbols.get(v,'?')}{svc}({v})" for svc, v in sorted(statuses.items()) ) def _fmt_action_result(result: Any, max_chars: int = 3000) -> str: if result is None: return "(no result)" text = json.dumps(result, indent=2) if not isinstance(result, str) else result if len(text) > max_chars: text = text[:max_chars] + f"\n... [truncated {len(text)-max_chars} chars]" return text def format_initial_p1_obs(obs: dict, info: dict) -> str: """Format the very first observation for Phase 1.""" task = info.get("task_name", "unknown") pool = info.get("pool", "?") mode = info.get("mode", "unknown") phase = obs.get("current_phase", 1) svc_line = _fmt_service_statuses(obs.get("service_statuses", {})) valid = obs.get("valid_actions", P1_ACTIONS) return ( f"INCIDENT RESPONSE | Pool {pool} | Mode: {mode} | Task: {task}\n" f"{'─'*60}\n" f"Summary: {obs.get('incident_summary', 'No summary available')}\n" f"Severity: {obs.get('severity', '?')} | " f"Time Budget: {obs.get('time_budget_minutes', '?')} min | " f"Max Steps: {obs.get('max_steps', MAX_STEPS)}\n" f"Phase: {phase}\n" f"\nService Statuses:\n {svc_line}\n" f"Active Alerts: {obs.get('active_alerts_count', 0)}\n" f"\nValid Actions: {valid}\n" f"\nWhat is your FIRST action?" ) def format_step_result_p1(obs: dict, reward: float) -> str: """Format a step result during Phase 1.""" svc_line = _fmt_service_statuses(obs.get("service_statuses", {})) result = _fmt_action_result(obs.get("action_result")) lines = [ f"Action Result (success={obs.get('action_success', '?')}): " f"{obs.get('action_message', '')}", f"\n{result}", f"\n{'─'*40}", f"Services: {svc_line}", f"Alerts: {obs.get('active_alerts_count', 0)} active", f"Step: {obs.get('steps_taken','?')}/{obs.get('max_steps', MAX_STEPS)} " f"| Time: {obs.get('time_elapsed_minutes','?')}/{obs.get('time_budget_minutes','?')} min", f"Reward: {reward:+.3f} | Cumulative: {obs.get('cumulative_reward', 0):+.3f}", ] if obs.get("bad_commit_sha"): lines.append(f"Bad Commit SHA: {obs['bad_commit_sha']} (remember for Phase 2)") valid = obs.get("valid_actions", P1_ACTIONS) lines.append(f"\nValid Actions: {valid}") lines.append("\nWhat is your next action?") return "\n".join(lines) def format_initial_p2_obs(obs: dict, info: dict, belief: Optional[dict]) -> str: """Format the first Phase 2 observation (after transition or Pool B auto-start).""" task = info.get("task_name", "unknown") pool = info.get("pool", "?") belief_text = "" if belief: belief_text = ( f"\n[Phase 1 Belief]\n" f" Suspected service: {belief.get('suspected_service', '?')}\n" f" Suspected fault: {belief.get('suspected_fault_class', '?')}\n" f" Service confidence: {belief.get('service_confidence', 0):.0%}\n" f" Fault confidence: {belief.get('fault_confidence', 0):.0%}\n" f" Reasoning: {belief.get('reasoning', '')}\n" f" P2 cost estimate: {belief.get('estimated_p2_cost', '?')}\n" ) sha_line = ( f"Bad Commit SHA: {obs.get('bad_commit_sha', '(check git log)')}\n" if obs.get("bad_commit_sha") else "" ) return ( f"CODE INVESTIGATION | Pool {pool} | Task: {task}\n" f"{'─'*60}\n" f"{sha_line}" f"{belief_text}\n" f"Step: {obs.get('steps_taken', 0)}/{obs.get('max_steps', MAX_STEPS)} " f"| Cumulative Reward: {obs.get('cumulative_reward', 0):+.3f}\n" f"\nValid Actions: {obs.get('valid_actions', P2_ACTIONS)}\n" f"\nWhat is your first Phase 2 action?" ) def format_step_result_p2(obs: dict, reward: float) -> str: """Format a step result during Phase 2.""" result = _fmt_action_result(obs.get("action_result")) lines = [ f"Action Result (success={obs.get('action_success', '?')}): " f"{obs.get('action_message', '')}", f"\n{result}", f"\n{'─'*40}", f"Step: {obs.get('steps_taken','?')}/{obs.get('max_steps', MAX_STEPS)}", f"Reward: {reward:+.3f} | Cumulative: {obs.get('cumulative_reward', 0):+.3f}", f"\nValid Actions: {obs.get('valid_actions', P2_ACTIONS)}", "\nWhat is your next action?", ] return "\n".join(lines) # ────────────────────────────────────────────────────────────────────────────── # Message Builder # ────────────────────────────────────────────────────────────────────────────── def build_messages( history: List[Dict], initial_user_msg: str, phase: int, max_recent: int = 10, ) -> List[Dict]: """ Build the full OpenAI-format messages list. history: [{"action_json": str, "result_text": str, "reward": float}, ...] max_recent caps how many turns are included to avoid context-length 422s. """ system = SYSTEM_PROMPT_P1 if phase == 1 else SYSTEM_PROMPT_P2 messages: List[Dict] = [ {"role": "system", "content": system}, {"role": "user", "content": initial_user_msg}, ] for entry in history[-max_recent:]: messages.append({"role": "assistant", "content": entry["action_json"]}) messages.append({"role": "user", "content": entry["result_text"]}) return messages # ────────────────────────────────────────────────────────────────────────────── # Model Caller # ────────────────────────────────────────────────────────────────────────────── def call_model( messages: List[Dict], model: str, temperature: float = 0.5, max_tokens: int = 512, retries: int = 3, ) -> str: if not HF_TOKEN: raise ValueError("HF_TOKEN is not set.") payload = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } last_exc: Exception = RuntimeError("No attempts made") for attempt in range(retries): try: resp = requests.post( HF_ROUTER_URL, headers={ "Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json", }, json=payload, timeout=90, ) resp.raise_for_status() return resp.json()["choices"][0]["message"]["content"].strip() except requests.HTTPError as e: code = e.response.status_code if e.response is not None else 0 if code in (400, 422): raise # client-format errors — retrying won't help; let caller handle last_exc = e wait = 2 ** attempt print(f" [model retry {attempt+1}/{retries}] {e} — waiting {wait}s") time.sleep(wait) except Exception as e: last_exc = e wait = 2 ** attempt print(f" [model retry {attempt+1}/{retries}] {e} — waiting {wait}s") time.sleep(wait) raise last_exc def _merge_system_into_user(messages: List[Dict]) -> List[Dict]: """Fold system prompt into the first user message for models without system role.""" if not messages or messages[0]["role"] != "system": return messages system_text = messages[0]["content"] rest = messages[1:] if not rest or rest[0]["role"] != "user": return rest merged_first = {"role": "user", "content": f"{system_text}\n\n{rest[0]['content']}"} return [merged_first] + rest[1:] # Models confirmed to reject the system role — merged format used from the start. _MODELS_NEEDING_MERGE: set = set() def call_model_adaptive( history: List[Dict], initial_msg: str, phase: int, model: str, temperature: float = 0.5, ) -> str: """ Call model with two layers of fallback: 400 (system role) → merge system into first user message and cache the result so every subsequent step skips the wasted probe. 400 (after merge) → content still too long; halve history window. 422 (ctx length) → halve history window. """ use_merge = model in _MODELS_NEEDING_MERGE probed_merge = use_merge # True = already confirmed in a prior step, no re-probe needed max_recent = 10 while True: messages = build_messages(history, initial_msg, phase, max_recent=max_recent) if use_merge: messages = _merge_system_into_user(messages) try: result = call_model(messages, model=model, temperature=temperature) if use_merge and not probed_merge: # Merged succeeded for the first time — cache it _MODELS_NEEDING_MERGE.add(model) print(f" [merged format confirmed for {model.split('/')[1]}, cached]") return result except requests.HTTPError as e: code = e.response.status_code if e.response is not None else 0 if code == 400 and not use_merge: use_merge = True probed_merge = False print(f" [400: probing merged format for {model.split('/')[1]}]") elif code in (400, 422) and max_recent > 1: max_recent = max(1, max_recent // 2) print(f" [ctx truncated to last {max_recent} turns]") else: raise # ────────────────────────────────────────────────────────────────────────────── # Action Parsers # ────────────────────────────────────────────────────────────────────────────── def _extract_json(raw: str) -> dict: """Extract the first JSON object from model output and normalise colon-format action types. Some models output {"action_type": "check_metrics:api_gateway"} instead of the correct {"action_type": "check_metrics", "target_service": "api_gateway"}. Split and normalise so the environment never sees an invalid action_type. """ start = raw.find("{") end = raw.rfind("}") + 1 if start == -1 or end == 0: raise ValueError("No JSON object in model output") action = json.loads(raw[start:end]) atype = action.get("action_type", "") if ":" in atype: parts = atype.split(":", 1) action["action_type"] = parts[0] if parts[1] in ALL_SERVICES and "target_service" not in action: action["target_service"] = parts[1] return action def _recent_sigs(recent_actions: List[dict], n: int = 3) -> set: return {(a.get("action_type"), a.get("target_service")) for a in recent_actions[-n:]} def _diversify_p1(obs: dict, recent_actions: List[dict]) -> dict: """Return the next logical diagnostic action that hasn't been done recently.""" statuses = obs.get("service_statuses") or {} bad_svcs = [s for s, st in statuses.items() if st != "healthy"] used_sigs = _recent_sigs(recent_actions, n=4) used_svcs = {a.get("target_service") for a in recent_actions[-6:]} used_types = {a.get("action_type") for a in recent_actions[-4:]} # Build a uniform list of (score, action) tuples candidates: List[Tuple[int, dict]] = [] for atype in P1_DIAGNOSTIC: if atype == "view_alerts": score = 0 if ("view_alerts", None) in used_sigs else 2 candidates.append((score, {"action_type": "view_alerts"})) continue for svc in (bad_svcs or ALL_SERVICES): a: dict = {"action_type": atype, "target_service": svc} if atype == "query_logs": a["parameters"] = {"level": "ERROR", "limit": 20} already_used = (atype, svc) in used_sigs score = (not already_used) * 2 + (svc not in used_svcs) + (atype not in used_types) candidates.append((score, a)) candidates.sort(key=lambda x: -x[0]) if candidates: return candidates[0][1] # Last resort svc = next((s for s in ALL_SERVICES if s not in used_svcs), random.choice(ALL_SERVICES)) return {"action_type": "query_logs", "target_service": svc, "parameters": {"level": "ERROR", "limit": 20}} def parse_p1_action(raw: str, step: int, obs: dict, recent_actions: Optional[List[dict]] = None) -> dict: """Parse Phase 1 action with smart fallbacks and anti-repetition.""" recent_actions = recent_actions or [] valid = set(obs.get("valid_actions") or P1_ACTIONS) try: action = _extract_json(raw) atype = action.get("action_type", "") if atype not in valid: action = _diversify_p1(obs, recent_actions) if step > 0 else {"action_type": "view_alerts"} atype = action["action_type"] # Ensure target_service for targeted actions if atype in TARGETED_ACTIONS: if "target_service" not in action or action["target_service"] not in ALL_SERVICES: svcs = obs.get("available_services") or ALL_SERVICES action["target_service"] = random.choice(svcs) # Anti-repetition: if this exact (type, service) was used recently, diversify sig = (action.get("action_type"), action.get("target_service")) if sig in _recent_sigs(recent_actions, n=2) and atype not in ("declare_root_cause", "transition_to_phase2"): action = _diversify_p1(obs, recent_actions) atype = action["action_type"] # Validate transition_to_phase2 belief structure if atype == "transition_to_phase2": params = action.setdefault("parameters", {}) belief = params.setdefault("belief", {}) degraded = [s for s, st in (obs.get("service_statuses") or {}).items() if st != "healthy"] belief.setdefault("suspected_service", degraded[0] if degraded else random.choice(ALL_SERVICES)) belief.setdefault("suspected_fault_class", "memory_leak") belief.setdefault("service_confidence", 0.7) belief.setdefault("fault_confidence", 0.65) belief.setdefault("evidence_gaps", []) belief.setdefault("estimated_p2_cost", "medium") belief.setdefault("decision", "transition") belief.setdefault("reasoning", "Transitioning based on collected evidence") return action except Exception: if step == 0: return {"action_type": "view_alerts"} return _diversify_p1(obs, recent_actions) def _force_p1_terminal(obs: dict) -> dict: """Build a best-effort terminal action from observed state.""" valid = set(obs.get("valid_actions") or P1_ACTIONS) statuses = obs.get("service_statuses") or {} degraded = [s for s, st in statuses.items() if st != "healthy"] if "transition_to_phase2" in valid: svc = degraded[0] if degraded else random.choice(ALL_SERVICES) return { "action_type": "transition_to_phase2", "parameters": {"belief": { "suspected_service": svc, "suspected_fault_class": "memory_leak", "service_confidence": 0.5, "fault_confidence": 0.5, "evidence_gaps": ["forced_terminal_after_step_limit"], "estimated_p2_cost": "medium", "decision": "transition", "reasoning": f"Forced transition: degraded={degraded}", }}, } cause = (f"Degradation detected in: {', '.join(degraded)}" if degraded else "Root cause undetermined within step budget") return {"action_type": "declare_root_cause", "parameters": {"root_cause": cause}} def parse_p2_action(raw: str, step: int, obs: dict, recent_actions: Optional[List[dict]] = None) -> dict: """Parse Phase 2 action with smart fallbacks and anti-repetition.""" recent_actions = recent_actions or [] valid = set(obs.get("valid_actions") or P2_ACTIONS) used_sigs = _recent_sigs(recent_actions, n=3) p2_fallback_sequence = [ {"action_type": "list_dir", "parameters": {"path": "."}}, {"action_type": "get_git_log", "parameters": {"path": ".", "n_commits": 15}}, {"action_type": "search_code", "parameters": {"query": "error", "file_pattern": "*.py", "max_hits": 15}}, {"action_type": "search_code", "parameters": {"query": "def ", "file_pattern": "*.py", "max_hits": 10}}, {"action_type": "list_dir", "parameters": {"path": "src"}}, ] try: action = _extract_json(raw) atype = action.get("action_type", "") if atype not in valid: action = p2_fallback_sequence[step % len(p2_fallback_sequence)] atype = action["action_type"] # Ensure required params params = action.setdefault("parameters", {}) if atype == "list_dir": params.setdefault("path", ".") elif atype == "read_file": params.setdefault("path", ".") elif atype == "search_code": params.setdefault("query", "error") params.setdefault("file_pattern", "*.py") params.setdefault("max_hits", 15) elif atype == "get_git_log": params.setdefault("path", ".") params.setdefault("n_commits", 10) elif atype == "get_file_diff": sha = obs.get("bad_commit_sha") or "HEAD" params.setdefault("commit_sha", sha) params.setdefault("path", ".") elif atype == "propose_patch" and "diff" not in params: action = {"action_type": "declare_no_change", "parameters": {"reason": "Unable to determine code fix from available evidence"}} elif atype == "declare_no_change": params.setdefault("reason", "No code-level fix required based on investigation") # Anti-repetition for non-terminal actions sig = (action.get("action_type"), str(action.get("parameters", {}).get("path", ""))) if sig in used_sigs and atype not in P2_TERMINAL: action = p2_fallback_sequence[(step + len(recent_actions)) % len(p2_fallback_sequence)] return action except Exception: return p2_fallback_sequence[step % len(p2_fallback_sequence)] # ────────────────────────────────────────────────────────────────────────────── # Environment HTTP Helpers # ────────────────────────────────────────────────────────────────────────────── def _mask_p1_obs(obs: dict, pool: str) -> dict: """Pool A is p1_only — remove transition_to_phase2 the server incorrectly exposes.""" if pool == "A" and obs.get("valid_actions"): obs = dict(obs) obs["valid_actions"] = [a for a in obs["valid_actions"] if a != "transition_to_phase2"] return obs def env_reset(task_name: str, pool: str, seed: Optional[int] = None) -> dict: body: dict = {"task_name": task_name, "pool": pool} if seed is not None: body["seed"] = seed resp = requests.post(f"{BASE_URL}/reset", json=body, timeout=30) resp.raise_for_status() return resp.json() def env_step(action: dict) -> dict: resp = requests.post(f"{BASE_URL}/step", json=action, timeout=30) resp.raise_for_status() return resp.json() def env_score(declared_patch: Optional[str], declared_no_change: bool, belief_history: List[dict]) -> dict: """Fetch unified grader scores for the completed episode.""" try: resp = requests.post( f"{BASE_URL}/score", json={ "declared_patch": declared_patch, "declared_no_change": declared_no_change, "belief_history": belief_history, }, timeout=30, ) resp.raise_for_status() return resp.json() except Exception as e: print(f" [score] {e}") return {} def env_get_trajectory() -> dict: """Fetch the full trajectory from the server.""" try: resp = requests.get(f"{BASE_URL}/trajectory", timeout=30) resp.raise_for_status() return resp.json() except Exception: return {} # ────────────────────────────────────────────────────────────────────────────── # Episode Runner # ────────────────────────────────────────────────────────────────────────────── def run_episode( task_name: str, pool: str, model: str, episode_id: int, seed: Optional[int] = None, ) -> dict: """ Run one full episode (Phase 1, Phase 2, or Joint) through the HTTP API. Returns a rich episode record including: - step records with (action, raw_model_output, observation, reward) - final score breakdown from /score - SFT-ready message sequences per step """ print(f"\n{'═'*60}") print(f" Ep {episode_id+1:>3} | Pool {pool} | Task: {task_name}") print(f" Model: {model}") print(f"{'─'*60}") reset_resp = env_reset(task_name, pool, seed) obs = _mask_p1_obs(reset_resp.get("observation", {}), pool) info = reset_resp.get("info", {}) initial_phase = obs.get("current_phase", 1) # Tracks for the episode p1_steps: List[dict] = [] p2_steps: List[dict] = [] belief_history: List[dict] = [] declared_patch: Optional[str] = None declared_no_change: bool = False # Conversation history per phase (for message building) p1_history: List[dict] = [] p2_history: List[dict] = [] last_belief: Optional[dict] = None # Recent actions (flattened, for anti-repetition) recent_actions: List[dict] = [] consecutive_errors = [0] # consecutive model call failures consecutive_negative = [0] # consecutive negative-reward steps (patience) # Initial user messages initial_p1_msg = format_initial_p1_obs(obs, info) initial_p2_msg: Optional[str] = None # set on transition current_phase = initial_phase done = False for step_idx in range(MAX_STEPS): if done: break # Pool B: transition to phase 2 only if the env actually started in phase 1. # If the env auto-transitioned during reset, current_phase is already 2 — skip. if pool == "B" and current_phase == 1 and len(p1_steps) == 0 and "transition_to_phase2" in (obs.get("valid_actions") or []): raw = "{}" action = { "action_type": "transition_to_phase2", "parameters": {"belief": { "suspected_service": None, "suspected_fault_class": None, "service_confidence": 0.0, "fault_confidence": 0.0, "evidence_gaps": [], "estimated_p2_cost": "unknown", "decision": "transition", "reasoning": "Pool B: oracle belief injected by environment", }}, } else: # Hard ceiling: force terminal if too close to max_steps p1_hard_limit = MAX_STEPS - 8 if pool in ("C", "D") else MAX_STEPS - 3 if current_phase == 1 and step_idx >= p1_hard_limit: action = _force_p1_terminal(obs) raw = json.dumps(action) print(f" [step limit: forcing terminal]") else: # Call model with adaptive history truncation on 422 cur_history = p1_history if current_phase == 1 else p2_history cur_initial = initial_p1_msg if current_phase == 1 else ( initial_p2_msg or format_initial_p2_obs(obs, info, last_belief) ) if current_phase == 2 and initial_p2_msg is None: initial_p2_msg = cur_initial model_ok = True try: raw = call_model_adaptive(cur_history, cur_initial, current_phase, model) except Exception as e: print(f" [model error] {e}") raw = "{}" model_ok = False # If model keeps failing in P1, force terminal after 8 consecutive errors if not model_ok: consecutive_errors[0] += 1 else: consecutive_errors[0] = 0 if current_phase == 1 and consecutive_errors[0] >= 8: action = _force_p1_terminal(obs) raw = json.dumps(action) consecutive_errors[0] = 0 print(f" [8 consecutive model errors: forcing terminal]") elif current_phase == 1: action = parse_p1_action(raw, step_idx, obs, recent_actions) else: action = parse_p2_action(raw, step_idx, obs, recent_actions) print(f" step {step_idx+1:>2} | ph{current_phase} | {action.get('action_type')}" + (f"({action.get('target_service','')})" if action.get("target_service") else "")) # Track terminal/transition actions before stepping atype = action.get("action_type", "") if atype == "transition_to_phase2": belief = action.get("parameters", {}).get("belief", {}) last_belief = belief belief_history.append(belief) if atype == "propose_patch": declared_patch = action.get("parameters", {}).get("diff", "") if atype == "declare_no_change": declared_no_change = True # Step environment try: step_resp = env_step(action) except Exception as e: print(f" [env error] {e}") break reward = float(step_resp.get("reward", 0.0)) done = step_resp.get("done", False) new_obs = step_resp.get("observation", {}) new_phase = new_obs.get("current_phase", current_phase) print(f" reward={reward:+.3f} cumulative={new_obs.get('cumulative_reward', 0):+.3f}" + (" DONE" if done else "")) # Build result text for next turn if current_phase == 1: result_text = format_step_result_p1(new_obs, reward) else: result_text = format_step_result_p2(new_obs, reward) step_record = { "step": step_idx, "phase": current_phase, "action": action, "raw_output": raw, "observation": new_obs, "reward": reward, "result_text": result_text, # stored for SFT building } if current_phase == 1: p1_steps.append(step_record) p1_history.append({"action_json": json.dumps(action), "result_text": result_text}) else: p2_steps.append(step_record) p2_history.append({"action_json": json.dumps(action), "result_text": result_text}) recent_actions.append(action) # Patience: 10 consecutive negative rewards → force terminal immediately if reward < 0: consecutive_negative[0] += 1 else: consecutive_negative[0] = 0 if consecutive_negative[0] >= 10 and not done: print(f" [patience exhausted: 10 consecutive negatives — forcing terminal]") if current_phase == 1: term_action = _force_p1_terminal(new_obs) else: term_action = {"action_type": "declare_no_change", "parameters": {"reason": "Patience exhausted — no progress detected"}} try: term_resp = env_step(term_action) term_reward = float(term_resp.get("reward", 0.0)) done = term_resp.get("done", False) term_obs = term_resp.get("observation", new_obs) print(f" [forced terminal] reward={term_reward:+.3f} cumulative={term_obs.get('cumulative_reward',0):+.3f} DONE") steps_list = p1_steps if current_phase == 1 else p2_steps steps_list.append({"step": step_idx + 1, "phase": current_phase, "action": term_action, "raw_output": "{}", "observation": term_obs, "reward": term_reward, "result_text": ""}) new_obs = term_obs except Exception as e: print(f" [forced terminal env error] {e}") break # Detect phase transition if new_phase != current_phase and new_phase == 2: print(" ── Phase 1 → Phase 2 ──") initial_p2_msg = format_initial_p2_obs(new_obs, info, last_belief) recent_actions.clear() # reset repetition tracking for the new phase consecutive_negative[0] = 0 # reset patience on phase change current_phase = new_phase obs = _mask_p1_obs(new_obs, pool) time.sleep(SLEEP_BETWEEN) # Fetch unified scores score = env_score(declared_patch, declared_no_change, belief_history) cumulative = obs.get("cumulative_reward", 0.0) print(f" Final cumulative reward: {cumulative:.3f}") if score: print(f" Scores: {json.dumps({k: round(v, 3) for k, v in score.items()})}") return { "episode_id": episode_id, "task_name": task_name, "pool": pool, "model": model, "seed": seed, "p1_steps": p1_steps, "p2_steps": p2_steps, "num_p1_steps": len(p1_steps), "num_p2_steps": len(p2_steps), "cumulative_reward": round(cumulative, 4), "score_breakdown": score, "declared_patch": declared_patch, "declared_no_change": declared_no_change, "belief_history": belief_history, "done": done, # Reconstructed conversation contexts for SFT building "_initial_p1_msg": initial_p1_msg, "_initial_p2_msg": initial_p2_msg, "_p1_history": p1_history, "_p2_history": p2_history, } # ────────────────────────────────────────────────────────────────────────────── # SFT Dataset Formatter # ────────────────────────────────────────────────────────────────────────────── def episode_to_sft_samples(ep: dict) -> List[dict]: """ Convert one episode into per-step SFT samples for BOTH phases. ALL steps are included regardless of reward — negative-reward steps provide hard-negative signal critical for RL/preference training. The `reward` field is preserved so the training code can filter or weight. """ samples: List[dict] = [] def _extract_samples(steps, phase, initial_msg, history_key): history_so_far: List[dict] = [] for i, step_rec in enumerate(steps): system = SYSTEM_PROMPT_P1 if phase == 1 else SYSTEM_PROMPT_P2 messages = build_messages(history_so_far, initial_msg, phase=phase) messages.append({ "role": "assistant", "content": json.dumps(step_rec["action"]), }) samples.append({ "messages": messages, "reward": step_rec["reward"], "phase": phase, "action_type": step_rec["action"].get("action_type"), "task_name": ep["task_name"], "pool": ep["pool"], "model": ep["model"], "episode_id": ep["episode_id"], "step": i, }) history_so_far.append({ "action_json": json.dumps(step_rec["action"]), "result_text": step_rec.get("result_text", ""), }) if ep.get("p1_steps") and ep.get("_initial_p1_msg"): _extract_samples(ep["p1_steps"], 1, ep["_initial_p1_msg"], "_p1_history") if ep.get("p2_steps") and ep.get("_initial_p2_msg"): _extract_samples(ep["p2_steps"], 2, ep["_initial_p2_msg"], "_p2_history") return samples # ────────────────────────────────────────────────────────────────────────────── # GRPO / DPO Dataset Formatter # ────────────────────────────────────────────────────────────────────────────── def episodes_to_grpo_pairs(episodes: List[dict]) -> List[dict]: """ Build (prompt, chosen, rejected) triplets for GRPO/DPO training. Three pairing strategies: 1. Within-episode: best vs worst step (same prompt context) 2. Cross-episode: same task+pool, different models, different outcomes 3. Phase-specific: separate within-phase pairs for P2 Chosen = action with higher reward. Rejected = action with lower reward. Both are kept regardless of absolute reward sign. """ pairs: List[dict] = [] # ── Strategy 1: within-episode best/worst per phase ─────────────────────── for ep in episodes: for phase, steps, initial_msg in [ (1, ep.get("p1_steps", []), ep.get("_initial_p1_msg", "")), (2, ep.get("p2_steps", []), ep.get("_initial_p2_msg", "")), ]: if len(steps) < 2 or not initial_msg: continue best = max(steps, key=lambda s: s["reward"]) worst = min(steps, key=lambda s: s["reward"]) if best["reward"] == worst["reward"]: continue if best is worst: continue prompt_msgs = build_messages([], initial_msg, phase=phase) pairs.append({ "prompt": prompt_msgs, "chosen": json.dumps(best["action"]), "rejected": json.dumps(worst["action"]), "chosen_reward": best["reward"], "rejected_reward": worst["reward"], "margin": best["reward"] - worst["reward"], "task_name": ep["task_name"], "pool": ep["pool"], "phase": phase, "strategy": "within_episode", "episode_id": ep["episode_id"], }) # ── Strategy 2: cross-episode, same task+pool ───────────────────────────── by_task_pool: Dict[str, List[dict]] = defaultdict(list) for ep in episodes: key = f"{ep['task_name']}_{ep['pool']}" by_task_pool[key].append(ep) for key, task_eps in by_task_pool.items(): if len(task_eps) < 2: continue # Sort by cumulative reward; pair best vs worst episode sorted_eps = sorted(task_eps, key=lambda e: e["cumulative_reward"]) best_ep = sorted_eps[-1] worst_ep = sorted_eps[0] if best_ep["cumulative_reward"] == worst_ep["cumulative_reward"]: continue if best_ep["episode_id"] == worst_ep["episode_id"]: continue # Use the first non-view_alerts action as representative def _first_substantive_action(ep_inner, phase): steps = ep_inner.get(f"p{phase}_steps", []) for s in steps: if s["action"].get("action_type") != "view_alerts": return s return steps[0] if steps else None for phase in [1, 2]: best_step = _first_substantive_action(best_ep, phase) worst_step = _first_substantive_action(worst_ep, phase) initial_msg = best_ep.get(f"_initial_p{phase}_msg", "") if not best_step or not worst_step or not initial_msg: continue prompt_msgs = build_messages([], initial_msg, phase=phase) pairs.append({ "prompt": prompt_msgs, "chosen": json.dumps(best_step["action"]), "rejected": json.dumps(worst_step["action"]), "chosen_reward": best_ep["cumulative_reward"], "rejected_reward": worst_ep["cumulative_reward"], "margin": best_ep["cumulative_reward"] - worst_ep["cumulative_reward"], "task_name": best_ep["task_name"], "pool": best_ep["pool"], "phase": phase, "strategy": "cross_episode", "best_model": best_ep["model"], "worst_model": worst_ep["model"], }) return pairs # ────────────────────────────────────────────────────────────────────────────── # Episode Schedule Builder # ────────────────────────────────────────────────────────────────────────────── def build_episode_schedule(n: int) -> List[Tuple[str, str, str, int]]: """ Return list of (task_name, pool, model, seed) tuples. Distribution: - Pools weighted by POOL_WEIGHTS - Tasks within each pool: round-robin - Models: round-robin across all MODELS - Seeds: random per episode (for reproducibility, logged in output) """ schedule: List[Tuple[str, str, str, int]] = [] pool_counts = { pool: max(1, round(n * weight)) for pool, weight in POOL_WEIGHTS.items() } # Adjust to exactly n total = sum(pool_counts.values()) diff = n - total if diff > 0: pool_counts["C"] += diff elif diff < 0: pool_counts["A"] += diff # reduce A if over model_idx = 0 for pool, count in pool_counts.items(): tasks = POOL_TASKS[pool] for i in range(count): task = tasks[i % len(tasks)] model = MODELS[model_idx % len(MODELS)] seed = random.randint(0, 99999) schedule.append((task, pool, model, seed)) model_idx += 1 random.shuffle(schedule) return schedule # ────────────────────────────────────────────────────────────────────────────── # Main # ────────────────────────────────────────────────────────────────────────────── def _flush_episode(ep: dict, raw_f, sft_f) -> Tuple[int, int, int]: """Append one episode to the open raw and SFT files. Returns (pos, zer, neg) step counts.""" clean = {k: v for k, v in ep.items() if not k.startswith("_")} raw_f.write(json.dumps(clean) + "\n") raw_f.flush() samples = episode_to_sft_samples(ep) for s in samples: sft_f.write(json.dumps(s) + "\n") sft_f.flush() pos = sum(1 for s in samples if s["reward"] > 0) zer = sum(1 for s in samples if s["reward"] == 0) neg = sum(1 for s in samples if s["reward"] < 0) return pos, zer, neg def _finalize(all_episodes: List[dict], stats: Dict[str, List[float]]) -> None: """Generate GRPO pairs and print final statistics from whatever was collected.""" print(f"\n{'═'*60}") print(f"✅ Collected {len(all_episodes)} episodes") grpo_pairs = episodes_to_grpo_pairs(all_episodes) grpo_path = "sre_grpo_dataset.jsonl" with open(grpo_path, "w") as f: for p in grpo_pairs: f.write(json.dumps(p) + "\n") within = sum(1 for p in grpo_pairs if p["strategy"] == "within_episode") cross = sum(1 for p in grpo_pairs if p["strategy"] == "cross_episode") print(f"💾 GRPO dataset ({len(grpo_pairs)} pairs) → {grpo_path}") print(f" Pairs: {within} within-episode + {cross} cross-episode") if not all_episodes: return all_rewards = [ep["cumulative_reward"] for ep in all_episodes] print(f"\n📈 Reward statistics:") print(f" Overall: avg={sum(all_rewards)/len(all_rewards):.3f} " f"max={max(all_rewards):.3f} min={min(all_rewards):.3f}") print(f"\n By pool:") for pool in ["A", "B", "C", "D"]: rs = stats.get(f"pool_{pool}", []) if rs: print(f" Pool {pool}: n={len(rs):>3} avg={sum(rs)/len(rs):.3f} " f"max={max(rs):.3f} min={min(rs):.3f}") print(f"\n By task:") for task in sorted(set(ep["task_name"] for ep in all_episodes)): rs = stats.get(f"task_{task}", []) if rs: print(f" {task:<35} n={len(rs):>2} avg={sum(rs)/len(rs):.3f}") print(f"\n By model tier:") model_short_names = set() for ep in all_episodes: model_short_names.add(ep["model"].split("/")[1].split(":")[0]) for mname in sorted(model_short_names): rs = stats.get(f"model_{mname}", []) if rs: print(f" {mname:<40} n={len(rs):>2} avg={sum(rs)/len(rs):.3f}") def main(): from huggingface_hub import HfApi api = HfApi(token=HF_TOKEN) api.create_repo(repo_id="srinjoyd/sre-data", repo_type="dataset", exist_ok=True) if not HF_TOKEN: print("❌ HF_TOKEN is not set.\n export HF_TOKEN=hf_...") return print(f"🚀 SRE Trajectory Collector") print(f" Episodes: {NUM_EPISODES}") print(f" Models: {len(MODELS)} (rotating)") print(f" Tasks: {len(set(t for ts in POOL_TASKS.values() for t in ts))} unique") print(f" Pools: A / B / C / D") print(f" Base URL: {BASE_URL}") print(f" Keeping ALL episodes (negative reward = hard negatives for RL)") print(f" Saving incrementally — Ctrl+C safe\n") schedule = build_episode_schedule(NUM_EPISODES) all_episodes: List[dict] = [] stats: Dict[str, List[float]] = defaultdict(list) total_pos = total_zer = total_neg = 0 raw_path = "sre_raw_trajectories.jsonl" sft_path = "sre_sft_dataset.jsonl" print(f"💾 Writing to: {raw_path} | {sft_path} (appending per episode)") print(f" GRPO pairs written at end (or on Ctrl+C)\n") with open(raw_path, "a") as raw_f, open(sft_path, "a") as sft_f: try: for ep_id, (task, pool, model, seed) in enumerate(schedule): try: ep = run_episode(task, pool, model, ep_id, seed=seed) except Exception as e: print(f" [!] Episode {ep_id+1} FAILED: {e}") traceback.print_exc() time.sleep(2) continue all_episodes.append(ep) pos, zer, neg = _flush_episode(ep, raw_f, sft_f) upload_checkpoint(api, "srinjoyd/sre-data") total_pos += pos total_zer += zer total_neg += neg r = ep["cumulative_reward"] stats[f"pool_{pool}"].append(r) stats[f"task_{task}"].append(r) stats[f"model_{model.split('/')[1].split(':')[0]}"].append(r) print(f" [saved] ep {ep_id+1}/{NUM_EPISODES} | " f"SFT steps so far: +{total_pos}/0:{total_zer}/-{total_neg}") time.sleep(1.0) except KeyboardInterrupt: print(f"\n\n⚠️ Interrupted after {len(all_episodes)} episodes — saving what we have...") finally: upload_checkpoint(api, "srinjoyd/sre-data") # always runs, even on crash _finalize(all_episodes, stats) print(f"\n💾 Raw trajectories ({len(all_episodes)} eps) → {raw_path}") print(f"💾 SFT dataset ({total_pos+total_zer+total_neg} steps) → {sft_path}") print(f" Reward split: +{total_pos} / 0:{total_zer} / -{total_neg}") _finalize(all_episodes, stats) print(f"\n💡 To upload to HuggingFace Hub:") print(f" from datasets import Dataset") print(f" import json") print(f" sft = [json.loads(l) for l in open('sre_sft_dataset.jsonl')]") print(f" grpo = [json.loads(l) for l in open('sre_grpo_dataset.jsonl')]") print(f" Dataset.from_list(sft).push_to_hub('your-username/sre-sft-data')") print(f" Dataset.from_list(grpo).push_to_hub('your-username/sre-grpo-data')") if __name__ == "__main__": main()