# updated-policy / logger /trajectory.py
# srinjoyd's picture
# add files
# e60af4b
# /// script
# dependencies = [
# "requests",
# "huggingface_hub",
# ]
# ///
"""
=============================================================
SRE INCIDENT RESPONSE — COMPREHENSIVE TRAJECTORY COLLECTOR
=============================================================
Generates fine-tuning trajectories from the SRE incident simulator covering:
  • All 10 tasks (8 training + 2 held-out compound scenarios)
  • All 4 pools:
      A — Phase 1 only (incident response)
      B — Phase 2 only (code investigation, oracle belief injected)
      C — Joint P1→P2 (full two-phase pipeline)
      D — Held-out joint (generalization test)
  • Full 18-action action space across both phases (11 Phase 1 + 7 Phase 2)
  • Models from the MODELS list (round-robin rotation)
  • ALL episodes retained — negative-reward trajectories are kept as
    hard-negative examples for RL/GRPO training
Output files:
  sre_raw_trajectories.json — full episode records with score breakdowns
  sre_sft_dataset.jsonl     — per-step SFT samples (both phases, all rewards)
  sre_grpo_dataset.jsonl    — (prompt, chosen, rejected) pairs for GRPO/DPO
Usage:
export HF_TOKEN=hf_...
python sre_finetune_collector.py
Optional env vars:
NUM_EPISODES total episodes to collect (default: 100)
BASE_URL simulator URL (default: HF Space URL)
MAX_STEPS max steps per episode (default: 35)
SLEEP_BETWEEN seconds between steps (default: 0.6)
"""
from __future__ import annotations
import json
import os
import random
import time
import traceback
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import requests
def upload_checkpoint(api, repo_id, filenames=None):
    """Upload the collector's output files to a Hugging Face dataset repo.

    Every file that exists locally is pushed with ``api.upload_file`` to
    ``repo_id`` as a dataset file stored under its base name. Missing files
    are skipped silently; upload failures are printed and skipped so one bad
    file does not abort the whole checkpoint.

    Args:
        api: Client exposing ``upload_file(path_or_fileobj=..., path_in_repo=...,
            repo_id=..., repo_type=...)`` (presumably ``huggingface_hub.HfApi``).
        repo_id: Target dataset repository id.
        filenames: Optional iterable of local paths to upload. Defaults to the
            three standard collector output files.

    Returns:
        The list of local paths that were uploaded successfully.
    """
    if filenames is None:
        # NOTE(review): the module docstring names the raw output
        # "sre_raw_trajectories.json"; confirm which extension the writer uses.
        filenames = ["sre_raw_trajectories.jsonl", "sre_sft_dataset.jsonl", "sre_grpo_dataset.jsonl"]
    uploaded = []
    for fname in filenames:
        if not os.path.exists(fname):
            continue
        try:
            api.upload_file(
                path_or_fileobj=fname,
                # Store under the base name so callers may pass full paths.
                path_in_repo=os.path.basename(fname),
                repo_id=repo_id,
                repo_type="dataset",
            )
        except Exception as e:  # best-effort: keep going with the other files
            print(f"❌ Upload failed {fname}: {e}")
            continue
        print(f"βœ… Uploaded {fname}")
        uploaded.append(fname)
    return uploaded
# ──────────────────────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────────────────────
# Runtime knobs; each can be overridden through the environment variable of
# the same name.
HF_TOKEN = os.environ.get("HF_TOKEN")
BASE_URL = os.environ.get("BASE_URL", "https://meta-hf-hackathon-updated-policy.hf.space")
HF_ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "100"))
MAX_STEPS = int(os.environ.get("MAX_STEPS", "35"))
SLEEP_BETWEEN = float(os.environ.get("SLEEP_BETWEEN", "0.6"))
# ── Model ─────────────────────────────────────────────────────────────────────
MODELS: List[str] = [
    "Qwen/Qwen2.5-7B-Instruct:fastest",
]
# ── Task registry per pool ────────────────────────────────────────────────────
# Pool A: P1-only incident response (all 8 training tasks)
# Pool B: P2-only code investigation (oracle belief injected; the 7 training
#         tasks with code — every training task except circuit_breaker_noop)
# Pool C: Joint P1→P2 full pipeline (all 8 training tasks)
# Pool D: Held-out joint (2 compound scenarios — generalization evaluation)
_TRAINING_TASKS: List[str] = [
    "memory_leak", "cascading_failure", "distributed_deadlock",
    "circuit_breaker_noop", "aliased_fault", "severity_inversion",
    "confidence_inversion", "info_ordering",
]
POOL_TASKS: Dict[str, List[str]] = {
    # Separate copies so mutating one pool's list never affects another.
    "A": list(_TRAINING_TASKS),
    "B": [task for task in _TRAINING_TASKS if task != "circuit_breaker_noop"],
    "C": list(_TRAINING_TASKS),
    "D": [
        "heldout_aliased_severity", "heldout_confidence_ordering",
    ],
}
# Episode budget distribution across pools (must sum to 1.0; enforced below).
POOL_WEIGHTS: Dict[str, float] = {"A": 0.35, "B": 0.20, "C": 0.35, "D": 0.10}
# Fail fast on a misconfigured budget instead of silently skewing the dataset.
if set(POOL_WEIGHTS) != set(POOL_TASKS) or abs(sum(POOL_WEIGHTS.values()) - 1.0) > 1e-6:
    raise ValueError(
        f"POOL_WEIGHTS must cover exactly the pools in POOL_TASKS and sum to 1.0; got {POOL_WEIGHTS}"
    )
# ── Action space definitions ──────────────────────────────────────────────────
# Phase 1 — incident response.
P1_DIAGNOSTIC = [
    "view_alerts",  # global: the only diagnostic without a target_service
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_deploy_history",
    "run_health_check",
]
P1_REMEDIATION = ["restart_service", "rollback_deploy", "scale_service"]
P1_TERMINAL = ["declare_root_cause", "transition_to_phase2"]
P1_ACTIONS = [*P1_DIAGNOSTIC, *P1_REMEDIATION, *P1_TERMINAL]
# Phase 2 — code investigation.
P2_DIAGNOSTIC = ["list_dir", "read_file", "search_code", "get_git_log", "get_file_diff"]
P2_TERMINAL = ["propose_patch", "declare_no_change"]
P2_ACTIONS = [*P2_DIAGNOSTIC, *P2_TERMINAL]
ALL_SERVICES = ["api_gateway", "auth", "orders", "payment", "cache", "database", "queue"]
# Phase 1 actions that act on one service: every diagnostic except
# view_alerts, plus every remediation.
TARGETED_ACTIONS = {
    name for name in (*P1_DIAGNOSTIC, *P1_REMEDIATION) if name != "view_alerts"
}
# Service dependency graph (for smarter fallbacks): service -> the upstream
# services listed for it in the SYSTEM_PROMPT_P1 topology.
DEPENDENCY_GRAPH: Dict[str, List[str]] = {
    "api_gateway": ["auth", "orders", "cache"],
    "auth": ["database"],
    "orders": ["database", "payment", "auth"],
    "payment": ["queue", "database"],
    "cache": [],
    "database": [],
    "queue": [],
}
# ──────────────────────────────────────────────────────────────────────────────
# System Prompts
# ──────────────────────────────────────────────────────────────────────────────
# Phase 1 (incident response) system prompt: service topology, the JSON schema
# of every Phase 1 action, and the expected investigation order. The model is
# told to answer with exactly one bare JSON action per turn.
# NOTE(review): sequences such as "β€”" look like UTF-8 em dashes decoded with
# the wrong codec; this text is sent to the model verbatim, so confirm the
# file's encoding before editing these strings.
SYSTEM_PROMPT_P1 = """You are an expert SRE handling a production incident in a microservices system.
## Service Topology (downstream ← upstream)
api_gateway ← auth, orders, cache
auth ← database
orders ← database, payment, auth
payment ← queue, database
cache, database, queue ← (no dependencies)
## Phase 1 Action Space
Output EXACTLY ONE valid JSON action per turn. No markdown, no explanation.
Diagnostic (read-only):
{"action_type": "view_alerts"}
{"action_type": "query_logs", "target_service": "<svc>", "parameters": {"level": "ERROR", "keyword": "<optional>", "limit": 20}}
{"action_type": "check_metrics", "target_service": "<svc>"}
{"action_type": "check_dependencies", "target_service": "<svc>"}
{"action_type": "check_deploy_history", "target_service": "<svc>"}
{"action_type": "run_health_check", "target_service": "<svc>"}
Remediation (mutates state):
{"action_type": "restart_service", "target_service": "<svc>"}
{"action_type": "rollback_deploy", "target_service": "<svc>"}
{"action_type": "scale_service", "target_service": "<svc>", "parameters": {"replicas": 5}}
Declare root cause (ALL tasks β€” always call this once you have a diagnosis):
{"action_type": "declare_root_cause", "parameters": {"root_cause": "<specific diagnosis β€” service, what failed, why>"}}
Then for joint-mode tasks, ALSO transition to code investigation:
{"action_type": "transition_to_phase2", "parameters": {"belief": {
"suspected_service": "<root_cause_svc>",
"suspected_fault_class": "memory_leak|config_change|deadlock|dep_upgrade|none",
"service_confidence": 0.85,
"fault_confidence": 0.80,
"evidence_gaps": ["<what_you_didnt_check>"],
"estimated_p2_cost": "low|medium|high",
"decision": "transition",
"reasoning": "<concise evidence summary>"
}}}
## Investigation Strategy
1. ALWAYS start with view_alerts to understand severity and scope
2. check_metrics on the highest-alert service first
3. query_logs (level=ERROR) on degraded/down services
4. check_dependencies on the affected service to find upstream causes
5. check_deploy_history before any rollback
6. Remediate the ROOT CAUSE service, not the symptom
7. After 6-8 diagnostic steps you MUST call declare_root_cause with your diagnosis.
For P1-only tasks this ends the episode. For joint-mode tasks, follow it immediately
with transition_to_phase2. Do NOT keep diagnosing indefinitely β€” commit to a conclusion.
CRITICAL: Output ONLY valid JSON. No markdown. No explanation. No code blocks."""
# Phase 2 (code investigation) system prompt: repository-exploration and
# terminal (propose_patch / declare_no_change) action schemas, plus an
# investigation order centred on the bad commit SHA carried over from Phase 1.
SYSTEM_PROMPT_P2 = """You are an expert SRE investigating a code-level fault in a sandboxed repository.
## Phase 2 Action Space
Output EXACTLY ONE valid JSON action per turn. No markdown, no explanation.
Code Exploration:
{"action_type": "list_dir", "parameters": {"path": "."}}
{"action_type": "read_file", "parameters": {"path": "relative/path/to/file.py"}}
{"action_type": "search_code", "parameters": {"query": "<search string>", "file_pattern": "*.py", "max_hits": 20}}
{"action_type": "get_git_log", "parameters": {"path": ".", "n_commits": 15}}
{"action_type": "get_file_diff", "parameters": {"commit_sha": "<sha>", "path": "relative/path/file.py"}}
Terminal:
{"action_type": "propose_patch", "parameters": {"diff": "<unified diff β€” minimal, correct, applies cleanly>"}}
{"action_type": "declare_no_change", "parameters": {"reason": "<why no code fix is needed β€” infrastructure issue, not code>"}}
## Investigation Strategy
1. list_dir "." to understand project structure
2. get_git_log to find recent commits β€” especially the bad_commit_sha from Phase 1 context
3. get_file_diff on the suspicious commit SHA to see what changed
4. read_file on affected files to understand the bug
5. search_code to find related patterns or the fault injection site
6. If you found a code bug: propose_patch with a minimal, syntactically valid unified diff.
The bad_commit_sha in your context tells you exactly what changed β€” read that diff and revert/fix it.
7. declare_no_change ONLY if Phase 1 confirmed a spurious alert / circuit-breaker false positive
with no deployment or code change involved. If there IS a bad commit in the git log, propose_patch.
CRITICAL: Output ONLY valid JSON. No markdown. No explanation. No code blocks."""
# ──────────────────────────────────────────────────────────────────────────────
# Observation Formatters
# ──────────────────────────────────────────────────────────────────────────────
def _fmt_service_statuses(statuses: Dict[str, str]) -> str:
symbols = {"healthy": "βœ“", "degraded": "~", "down": "βœ—"}
return " ".join(
f"{symbols.get(v,'?')}{svc}({v})"
for svc, v in sorted(statuses.items())
)
def _fmt_action_result(result: Any, max_chars: int = 3000) -> str:
if result is None:
return "(no result)"
text = json.dumps(result, indent=2) if not isinstance(result, str) else result
if len(text) > max_chars:
text = text[:max_chars] + f"\n... [truncated {len(text)-max_chars} chars]"
return text
def format_initial_p1_obs(obs: dict, info: dict) -> str:
    """Build the opening Phase 1 user message from the /reset response.

    Combines the ``info`` metadata (pool, mode, task) with the observation's
    incident summary, severity, time and step budgets, phase, service
    statuses, active alert count and valid actions. The message ends by
    asking for the first action.
    """
    statuses_line = _fmt_service_statuses(obs.get("service_statuses", {}))
    lines = [
        "INCIDENT RESPONSE"
        f" | Pool {info.get('pool', '?')}"
        f" | Mode: {info.get('mode', 'unknown')}"
        f" | Task: {info.get('task_name', 'unknown')}",
        "─" * 60,
        f"Summary: {obs.get('incident_summary', 'No summary available')}",
        f"Severity: {obs.get('severity', '?')} | "
        f"Time Budget: {obs.get('time_budget_minutes', '?')} min | "
        f"Max Steps: {obs.get('max_steps', MAX_STEPS)}",
        f"Phase: {obs.get('current_phase', 1)}",
        "",
        "Service Statuses:",
        f" {statuses_line}",
        f"Active Alerts: {obs.get('active_alerts_count', 0)}",
        "",
        f"Valid Actions: {obs.get('valid_actions', P1_ACTIONS)}",
        "",
        "What is your FIRST action?",
    ]
    return "\n".join(lines)
def format_step_result_p1(obs: dict, reward: float) -> str:
    """Render a Phase 1 step result as the next user turn.

    Contents, in order:
      - the action outcome and its (truncated) result payload;
      - service statuses and the active alert count;
      - step and time budget usage;
      - the step reward and cumulative reward;
      - the bad commit SHA, when the observation carries one;
      - the currently valid actions.
    """
    services = _fmt_service_statuses(obs.get("service_statuses", {}))
    payload = _fmt_action_result(obs.get("action_result"))
    parts = [
        f"Action Result (success={obs.get('action_success', '?')}): {obs.get('action_message', '')}",
        "\n" + payload,
        "\n" + "─" * 40,
        f"Services: {services}",
        f"Alerts: {obs.get('active_alerts_count', 0)} active",
        (
            f"Step: {obs.get('steps_taken', '?')}/{obs.get('max_steps', MAX_STEPS)} "
            f"| Time: {obs.get('time_elapsed_minutes', '?')}/{obs.get('time_budget_minutes', '?')} min"
        ),
        f"Reward: {reward:+.3f} | Cumulative: {obs.get('cumulative_reward', 0):+.3f}",
    ]
    commit = obs.get("bad_commit_sha")
    if commit:
        parts.append(f"Bad Commit SHA: {commit} (remember for Phase 2)")
    parts += [
        f"\nValid Actions: {obs.get('valid_actions', P1_ACTIONS)}",
        "\nWhat is your next action?",
    ]
    return "\n".join(parts)
def format_initial_p2_obs(obs: dict, info: dict, belief: Optional[dict]) -> str:
    """Format the first Phase 2 observation (after transition or Pool B auto-start).

    Args:
        obs: Observation at the start of Phase 2. ``bad_commit_sha``,
            ``steps_taken``, ``max_steps``, ``cumulative_reward`` and
            ``valid_actions`` are read.
        info: Episode info from /reset; ``task_name`` and ``pool`` are read.
        belief: Belief dict passed with ``transition_to_phase2``, or ``None``.
            It may come straight from model output, so confidences are
            formatted defensively.

    Returns:
        The opening user message for Phase 2.
    """
    def _pct(value: Any) -> str:
        # Model-written beliefs can carry None or numeric strings here; a bare
        # ``:.0%`` format would raise and abort the whole episode.
        try:
            return f"{float(value):.0%}"
        except (TypeError, ValueError):
            return "?"

    task = info.get("task_name", "unknown")
    pool = info.get("pool", "?")
    belief_text = ""
    if isinstance(belief, dict) and belief:
        belief_text = (
            f"\n[Phase 1 Belief]\n"
            f" Suspected service: {belief.get('suspected_service', '?')}\n"
            f" Suspected fault: {belief.get('suspected_fault_class', '?')}\n"
            f" Service confidence: {_pct(belief.get('service_confidence', 0))}\n"
            f" Fault confidence: {_pct(belief.get('fault_confidence', 0))}\n"
            f" Reasoning: {belief.get('reasoning', '')}\n"
            f" P2 cost estimate: {belief.get('estimated_p2_cost', '?')}\n"
        )
    sha = obs.get("bad_commit_sha")
    sha_line = f"Bad Commit SHA: {sha}\n" if sha else ""
    return (
        f"CODE INVESTIGATION | Pool {pool} | Task: {task}\n"
        f"{'─'*60}\n"
        f"{sha_line}"
        f"{belief_text}\n"
        f"Step: {obs.get('steps_taken', 0)}/{obs.get('max_steps', MAX_STEPS)} "
        f"| Cumulative Reward: {obs.get('cumulative_reward', 0):+.3f}\n"
        f"\nValid Actions: {obs.get('valid_actions', P2_ACTIONS)}\n"
        f"\nWhat is your first Phase 2 action?"
    )
def format_step_result_p2(obs: dict, reward: float) -> str:
    """Render a Phase 2 step result as the next user turn.

    Shows the action outcome and its (truncated) result payload, then the
    step count, the step and cumulative rewards, and the currently valid
    actions.
    """
    body = _fmt_action_result(obs.get("action_result"))
    outcome = f"Action Result (success={obs.get('action_success', '?')}): {obs.get('action_message', '')}"
    progress = f"Step: {obs.get('steps_taken', '?')}/{obs.get('max_steps', MAX_STEPS)}"
    scoring = f"Reward: {reward:+.3f} | Cumulative: {obs.get('cumulative_reward', 0):+.3f}"
    return "\n".join((
        outcome,
        "\n" + body,
        "\n" + "─" * 40,
        progress,
        scoring,
        f"\nValid Actions: {obs.get('valid_actions', P2_ACTIONS)}",
        "\nWhat is your next action?",
    ))
# ──────────────────────────────────────────────────────────────────────────────
# Message Builder
# ──────────────────────────────────────────────────────────────────────────────
def build_messages(
    history: List[Dict],
    initial_user_msg: str,
    phase: int,
    max_recent: Optional[int] = 10,
) -> List[Dict]:
    """
    Build the full OpenAI-format messages list.

    The list is made of:
      - the phase's system prompt;
      - the initial user message;
      - one (assistant action, user result) pair per retained history entry.

    Args:
        history: ``[{"action_json": str, "result_text": str, ...}, ...]`` in
            chronological order.
        initial_user_msg: First user turn (the formatted initial observation).
        phase: 1 selects SYSTEM_PROMPT_P1; any other value selects
            SYSTEM_PROMPT_P2.
        max_recent: How many of the most recent history entries to include,
            which avoids context-length 422s. ``0`` or a negative value
            includes none; ``None`` includes the whole history.
    """
    system = SYSTEM_PROMPT_P1 if phase == 1 else SYSTEM_PROMPT_P2
    messages: List[Dict] = [
        {"role": "system", "content": system},
        {"role": "user", "content": initial_user_msg},
    ]
    if max_recent is None:
        recent = history
    elif max_recent > 0:
        recent = history[-max_recent:]
    else:
        # history[-0:] is the WHOLE list, the opposite of "no recent turns".
        recent = []
    for entry in recent:
        messages.append({"role": "assistant", "content": entry["action_json"]})
        messages.append({"role": "user", "content": entry["result_text"]})
    return messages
# ──────────────────────────────────────────────────────────────────────────────
# Model Caller
# ──────────────────────────────────────────────────────────────────────────────
def call_model(
    messages: List[Dict],
    model: str,
    temperature: float = 0.5,
    max_tokens: int = 512,
    retries: int = 3,
) -> str:
    """Send a chat-completion request to the HF router and return the reply text.

    Transient failures are retried up to ``retries`` times, with exponential
    backoff (1s, 2s, 4s, ...) between attempts. They include network errors,
    HTTP errors other than 400/422, and malformed or empty responses. There
    is no sleep after the final attempt.

    Raises:
        ValueError: if HF_TOKEN is not set.
        requests.HTTPError: immediately for 400/422 (client-format errors,
            which retrying will not fix).
        Exception: the last error seen once retries are exhausted, or
            ``RuntimeError("No attempts made")`` when ``retries < 1``.
    """
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN is not set.")
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    last_exc: Exception = RuntimeError("No attempts made")
    for attempt in range(retries):
        try:
            resp = requests.post(HF_ROUTER_URL, headers=headers, json=payload, timeout=90)
            resp.raise_for_status()
            content = resp.json()["choices"][0]["message"]["content"]
            if content is None:
                raise ValueError("Model returned no message content")
            return content.strip()
        except requests.HTTPError as e:
            code = e.response.status_code if e.response is not None else 0
            if code in (400, 422):
                raise  # client-format errors — retrying won't help; let caller handle
            last_exc = e
        except Exception as e:
            last_exc = e
        # Only back off when another attempt will actually follow.
        if attempt + 1 < retries:
            wait = 2 ** attempt
            print(f" [model retry {attempt+1}/{retries}] {last_exc} β€” waiting {wait}s")
            time.sleep(wait)
    raise last_exc
def _merge_system_into_user(messages: List[Dict]) -> List[Dict]:
"""Fold system prompt into the first user message for models without system role."""
if not messages or messages[0]["role"] != "system":
return messages
system_text = messages[0]["content"]
rest = messages[1:]
if not rest or rest[0]["role"] != "user":
return rest
merged_first = {"role": "user", "content": f"{system_text}\n\n{rest[0]['content']}"}
return [merged_first] + rest[1:]
# Models confirmed to reject the system role — merged format used from the start.
_MODELS_NEEDING_MERGE: set = set()


def call_model_adaptive(
    history: List[Dict],
    initial_msg: str,
    phase: int,
    model: str,
    temperature: float = 0.5,
) -> str:
    """
    Call model with two layers of fallback:
      400 (system role)  → merge system into first user message and cache the result
                           so every subsequent step skips the wasted probe.
      400 (after merge)  → content still too long; halve history window.
      422 (ctx length)   → halve history window.

    The history window starts at 10 turns and never drops below 1. Any other
    error, or a 400/422 once the window is already 1, is re-raised.
    """
    # rsplit keeps this safe for model ids without an "org/" prefix.
    short_name = model.rsplit("/", 1)[-1]
    use_merge = model in _MODELS_NEEDING_MERGE
    merge_already_cached = use_merge  # no need to re-confirm a cached model
    max_recent = 10
    while True:
        messages = build_messages(history, initial_msg, phase, max_recent=max_recent)
        if use_merge:
            messages = _merge_system_into_user(messages)
        try:
            result = call_model(messages, model=model, temperature=temperature)
        except requests.HTTPError as e:
            code = e.response.status_code if e.response is not None else 0
            if code == 400 and not use_merge:
                use_merge = True
                print(f" [400: probing merged format for {short_name}]")
            elif code in (400, 422) and max_recent > 1:
                max_recent = max(1, max_recent // 2)
                print(f" [ctx truncated to last {max_recent} turns]")
            else:
                raise
            continue
        if use_merge and not merge_already_cached:
            # Merged succeeded for the first time — cache it
            _MODELS_NEEDING_MERGE.add(model)
            print(f" [merged format confirmed for {short_name}, cached]")
        return result
# ──────────────────────────────────────────────────────────────────────────────
# Action Parsers
# ──────────────────────────────────────────────────────────────────────────────
def _extract_json(raw: str) -> dict:
"""Extract the first JSON object from model output and normalise colon-format action types.
Some models output {"action_type": "check_metrics:api_gateway"} instead of
the correct {"action_type": "check_metrics", "target_service": "api_gateway"}.
Split and normalise so the environment never sees an invalid action_type.
"""
start = raw.find("{")
end = raw.rfind("}") + 1
if start == -1 or end == 0:
raise ValueError("No JSON object in model output")
action = json.loads(raw[start:end])
atype = action.get("action_type", "")
if ":" in atype:
parts = atype.split(":", 1)
action["action_type"] = parts[0]
if parts[1] in ALL_SERVICES and "target_service" not in action:
action["target_service"] = parts[1]
return action
def _recent_sigs(recent_actions: List[dict], n: int = 3) -> set:
return {(a.get("action_type"), a.get("target_service")) for a in recent_actions[-n:]}
def _diversify_p1(obs: dict, recent_actions: List[dict]) -> dict:
    """Return the next logical diagnostic action that hasn't been done recently.

    Candidates are ``view_alerts`` plus every other Phase 1 diagnostic aimed
    at each non-healthy service (or at every service when all are healthy).
    The highest score wins, and the earliest candidate breaks ties.

    Scoring:
      - ``view_alerts``: 2, or 0 if it is among the last 4 actions.
      - Targeted candidates:
        - +2 if the (type, service) pair is not among the last 4 actions;
        - +1 if the service was not targeted in the last 6 actions;
        - +1 if the type was not used in the last 4 actions.

    ``query_logs`` candidates request ERROR-level logs with a limit of 20.
    """
    statuses = obs.get("service_statuses") or {}
    unhealthy = [svc for svc, state in statuses.items() if state != "healthy"]
    seen_pairs = _recent_sigs(recent_actions, n=4)
    seen_services = {act.get("target_service") for act in recent_actions[-6:]}
    seen_types = {act.get("action_type") for act in recent_actions[-4:]}

    def _scored_candidates():
        for kind in P1_DIAGNOSTIC:
            if kind == "view_alerts":
                yield (0 if ("view_alerts", None) in seen_pairs else 2), {"action_type": "view_alerts"}
                continue
            for svc in unhealthy or ALL_SERVICES:
                proposal = {"action_type": kind, "target_service": svc}
                if kind == "query_logs":
                    proposal["parameters"] = {"level": "ERROR", "limit": 20}
                novelty = (
                    2 * ((kind, svc) not in seen_pairs)
                    + (svc not in seen_services)
                    + (kind not in seen_types)
                )
                yield novelty, proposal

    # max() keeps the first maximal item, i.e. the same winner as a stable
    # descending sort over the candidates in generation order.
    best = max(_scored_candidates(), key=lambda scored: scored[0], default=None)
    if best is not None:
        return best[1]
    # Last resort (only reachable when no candidates were generated)
    fallback_svc = next(
        (s for s in ALL_SERVICES if s not in seen_services),
        random.choice(ALL_SERVICES),
    )
    return {"action_type": "query_logs", "target_service": fallback_svc,
            "parameters": {"level": "ERROR", "limit": 20}}
def parse_p1_action(raw: str, step: int, obs: dict, recent_actions: Optional[List[dict]] = None) -> dict:
    """Parse Phase 1 action with smart fallbacks and anti-repetition.

    Pipeline:
      1. Decode the first JSON object in ``raw``. If its ``action_type`` is
         not a valid action (from ``obs["valid_actions"]``, else
         P1_ACTIONS), replace it with ``view_alerts`` on step 0 and with a
         ``_diversify_p1`` suggestion otherwise.
      2. Give targeted actions that lack a known ``target_service`` a random
         one, drawn from ``obs["available_services"]`` or ALL_SERVICES.
      3. Swap a non-terminal (type, service) pair that appeared in the last
         2 actions for a ``_diversify_p1`` suggestion.
      4. Complete a ``transition_to_phase2`` belief with default fields.

    Any exception falls back to ``view_alerts`` on step 0, and to
    ``_diversify_p1`` on later steps.
    """
    history = recent_actions or []
    allowed = set(obs.get("valid_actions") or P1_ACTIONS)
    try:
        action = _extract_json(raw)
        kind = action.get("action_type", "")
        if kind not in allowed:
            if step > 0:
                action = _diversify_p1(obs, history)
            else:
                action = {"action_type": "view_alerts"}
            kind = action["action_type"]

        # A missing target reads as None, which is never a known service.
        if kind in TARGETED_ACTIONS and action.get("target_service") not in ALL_SERVICES:
            action["target_service"] = random.choice(obs.get("available_services") or ALL_SERVICES)

        signature = (action.get("action_type"), action.get("target_service"))
        if signature in _recent_sigs(history, n=2) and kind not in ("declare_root_cause", "transition_to_phase2"):
            action = _diversify_p1(obs, history)
            kind = action["action_type"]

        if kind == "transition_to_phase2":
            belief = action.setdefault("parameters", {}).setdefault("belief", {})
            fill = belief.setdefault  # fails before any default is drawn if belief is not a dict
            unhealthy = [svc for svc, state in (obs.get("service_statuses") or {}).items()
                         if state != "healthy"]
            defaults = {
                "suspected_service": unhealthy[0] if unhealthy else random.choice(ALL_SERVICES),
                "suspected_fault_class": "memory_leak",
                "service_confidence": 0.7,
                "fault_confidence": 0.65,
                "evidence_gaps": [],
                "estimated_p2_cost": "medium",
                "decision": "transition",
                "reasoning": "Transitioning based on collected evidence",
            }
            for key, value in defaults.items():
                fill(key, value)
        return action
    except Exception:
        if step == 0:
            return {"action_type": "view_alerts"}
        return _diversify_p1(obs, history)
def _force_p1_terminal(obs: dict) -> dict:
"""Build a best-effort terminal action from observed state."""
valid = set(obs.get("valid_actions") or P1_ACTIONS)
statuses = obs.get("service_statuses") or {}
degraded = [s for s, st in statuses.items() if st != "healthy"]
if "transition_to_phase2" in valid:
svc = degraded[0] if degraded else random.choice(ALL_SERVICES)
return {
"action_type": "transition_to_phase2",
"parameters": {"belief": {
"suspected_service": svc,
"suspected_fault_class": "memory_leak",
"service_confidence": 0.5,
"fault_confidence": 0.5,
"evidence_gaps": ["forced_terminal_after_step_limit"],
"estimated_p2_cost": "medium",
"decision": "transition",
"reasoning": f"Forced transition: degraded={degraded}",
}},
}
cause = (f"Degradation detected in: {', '.join(degraded)}"
if degraded else "Root cause undetermined within step budget")
return {"action_type": "declare_root_cause",
"parameters": {"root_cause": cause}}
def parse_p2_action(raw: str, step: int, obs: dict, recent_actions: Optional[List[dict]] = None) -> dict:
    """Parse Phase 2 action with smart fallbacks and anti-repetition.

    Pipeline:
      1. Decode the first JSON object in ``raw``. An ``action_type`` outside
         the valid actions (from ``obs["valid_actions"]``, else P2_ACTIONS)
         is replaced by an entry of a fixed fallback sequence, chosen by
         ``step``.
      2. Fill required parameters with defaults. ``get_file_diff`` defaults
         its commit to the observed ``bad_commit_sha``, else ``"HEAD"``. A
         ``propose_patch`` with a missing or blank diff becomes
         ``declare_no_change``.
      3. If a non-terminal action repeats one of the last 3 actions (same
         type, path, query and commit), replace it with a fallback entry.

    Any exception returns the fallback entry for ``step``.
    """
    recent_actions = recent_actions or []
    valid = set(obs.get("valid_actions") or P2_ACTIONS)
    p2_fallback_sequence = [
        {"action_type": "list_dir", "parameters": {"path": "."}},
        {"action_type": "get_git_log", "parameters": {"path": ".", "n_commits": 15}},
        {"action_type": "search_code", "parameters": {"query": "error", "file_pattern": "*.py", "max_hits": 15}},
        {"action_type": "search_code", "parameters": {"query": "def ", "file_pattern": "*.py", "max_hits": 10}},
        {"action_type": "list_dir", "parameters": {"path": "src"}},
    ]

    def _signature(act: dict) -> tuple:
        # Phase 2 actions carry no target_service, so the generic
        # (action_type, target_service) pairs from _recent_sigs could never
        # match a (type, path) signature and repetition went undetected.
        # Identify a call by its type and distinguishing parameters instead.
        params = act.get("parameters")
        if not isinstance(params, dict):
            params = {}
        return (
            act.get("action_type"),
            str(params.get("path", "")),
            str(params.get("query", "")),
            str(params.get("commit_sha", "")),
        )

    try:
        action = _extract_json(raw)
        atype = action.get("action_type", "")
        if atype not in valid:
            action = p2_fallback_sequence[step % len(p2_fallback_sequence)]
            atype = action["action_type"]
        # Ensure required params
        params = action.setdefault("parameters", {})
        if atype in ("list_dir", "read_file"):
            params.setdefault("path", ".")
        elif atype == "search_code":
            params.setdefault("query", "error")
            params.setdefault("file_pattern", "*.py")
            params.setdefault("max_hits", 15)
        elif atype == "get_git_log":
            params.setdefault("path", ".")
            params.setdefault("n_commits", 10)
        elif atype == "get_file_diff":
            params.setdefault("commit_sha", obs.get("bad_commit_sha") or "HEAD")
            params.setdefault("path", ".")
        elif atype == "propose_patch" and not str(params.get("diff") or "").strip():
            # A missing or blank diff cannot be applied; fall back to no-change.
            action = {"action_type": "declare_no_change",
                      "parameters": {"reason": "Unable to determine code fix from available evidence"}}
        elif atype == "declare_no_change":
            params.setdefault("reason", "No code-level fix required based on investigation")
        # Anti-repetition for non-terminal actions
        used = {_signature(a) for a in recent_actions[-3:]}
        if atype not in P2_TERMINAL and _signature(action) in used:
            action = p2_fallback_sequence[(step + len(recent_actions)) % len(p2_fallback_sequence)]
        return action
    except Exception:
        return p2_fallback_sequence[step % len(p2_fallback_sequence)]
# ──────────────────────────────────────────────────────────────────────────────
# Environment HTTP Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _mask_p1_obs(obs: dict, pool: str) -> dict:
"""Pool A is p1_only β€” remove transition_to_phase2 the server incorrectly exposes."""
if pool == "A" and obs.get("valid_actions"):
obs = dict(obs)
obs["valid_actions"] = [a for a in obs["valid_actions"] if a != "transition_to_phase2"]
return obs
def env_reset(task_name: str, pool: str, seed: Optional[int] = None) -> dict:
    """Start a new simulator episode via POST /reset and return the JSON reply.

    ``seed`` is only sent when it is given. Non-2xx responses raise
    ``requests.HTTPError``.
    """
    seed_field = {} if seed is None else {"seed": seed}
    response = requests.post(
        f"{BASE_URL}/reset",
        json={"task_name": task_name, "pool": pool, **seed_field},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()
def env_step(action: dict) -> dict:
    """Send one action to the simulator's /step endpoint and return the JSON reply.

    Non-2xx responses raise ``requests.HTTPError``.
    """
    reply = requests.post(f"{BASE_URL}/step", json=action, timeout=30)
    reply.raise_for_status()
    return reply.json()
def env_score(declared_patch: Optional[str], declared_no_change: bool,
              belief_history: List[dict]) -> dict:
    """Fetch unified grader scores for the completed episode from /score.

    Sends the declared patch, the no-change flag and the belief history. Any
    failure (network error, HTTP status, or JSON decoding) is printed, and an
    empty dict is returned so the episode can still be recorded.
    """
    request_body = {
        "declared_patch": declared_patch,
        "declared_no_change": declared_no_change,
        "belief_history": belief_history,
    }
    try:
        response = requests.post(f"{BASE_URL}/score", json=request_body, timeout=30)
        response.raise_for_status()
        return response.json()
    except Exception as exc:
        print(f" [score] {exc}")
        return {}
def env_get_trajectory() -> dict:
    """Fetch the full trajectory from the server via GET /trajectory.

    Returns the decoded JSON on success. On any failure the error is printed,
    matching env_score, and an empty dict is returned so callers can treat
    the trajectory as optional.
    """
    try:
        resp = requests.get(f"{BASE_URL}/trajectory", timeout=30)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        # Previously swallowed silently, which hid server/network problems.
        print(f" [trajectory] {e}")
        return {}
# ──────────────────────────────────────────────────────────────────────────────
# Episode Runner
# ──────────────────────────────────────────────────────────────────────────────
def run_episode(
    task_name: str,
    pool: str,
    model: str,
    episode_id: int,
    seed: Optional[int] = None,
) -> dict:
    """
    Run one full episode (Phase 1, Phase 2, or Joint) through the HTTP API.

    Actions come from the model, via the phase-specific parsers, except in
    these cases:
      - Pool B opens with an automatic ``transition_to_phase2``.
      - Phase 1 is forced to a terminal action near the step limit.
      - Phase 1 is forced to a terminal action after 8 consecutive
        model-call failures.
      - After 10 consecutive negative rewards, a terminal action is forced in
        whichever phase the environment is in.

    Forced terminals are recorded like any other step. Their beliefs and
    no-change flag reach /score, and a forced transition switches the loop
    into Phase 2.

    Args:
        task_name: Simulator task to reset into.
        pool: Pool letter; controls Pool A masking, the Pool B auto-transition
            and the Phase 1 step ceiling.
        model: Router model id used for every model call in the episode.
        episode_id: Zero-based index, used for logging and the record.
        seed: Optional seed forwarded to /reset.

    Returns:
        A rich episode record including:
          - step records with (action, raw_model_output, observation, reward)
          - final score breakdown from /score
          - SFT-ready message sequences per step
    """
    print(f"\n{'═'*60}")
    print(f" Ep {episode_id+1:>3} | Pool {pool} | Task: {task_name}")
    print(f" Model: {model}")
    print(f"{'─'*60}")
    reset_resp = env_reset(task_name, pool, seed)
    obs = _mask_p1_obs(reset_resp.get("observation", {}), pool)
    info = reset_resp.get("info", {})
    initial_phase = obs.get("current_phase", 1)

    # Per-phase step records and chat histories (the latter for message building).
    p1_steps: List[dict] = []
    p2_steps: List[dict] = []
    p1_history: List[dict] = []
    p2_history: List[dict] = []
    # Declarations forwarded to /score at the end of the episode.
    belief_history: List[dict] = []
    declared_patch: Optional[str] = None
    declared_no_change: bool = False
    last_belief: Optional[dict] = None
    # Recent actions of the current phase (flattened, for anti-repetition).
    recent_actions: List[dict] = []
    consecutive_errors = 0    # consecutive model call failures
    consecutive_negative = 0  # consecutive negative-reward steps (patience)

    def _record_declarations(act: dict) -> None:
        """Track transition beliefs, proposed patches and no-change declarations."""
        nonlocal last_belief, declared_patch, declared_no_change
        kind = act.get("action_type", "")
        params = act.get("parameters") or {}
        if kind == "transition_to_phase2":
            last_belief = params.get("belief", {})
            belief_history.append(last_belief)
        elif kind == "propose_patch":
            declared_patch = params.get("diff", "")
        elif kind == "declare_no_change":
            declared_no_change = True

    def _record_step(phase: int, step_no: int, act: dict, raw_out: str,
                     step_obs: dict, step_reward: float) -> None:
        """Append an executed step to its phase's records, chat history and recent actions."""
        fmt = format_step_result_p1 if phase == 1 else format_step_result_p2
        result_text = fmt(step_obs, step_reward)
        steps, history = (p1_steps, p1_history) if phase == 1 else (p2_steps, p2_history)
        steps.append({
            "step": step_no,
            "phase": phase,
            "action": act,
            "raw_output": raw_out,
            "observation": step_obs,
            "reward": step_reward,
            "result_text": result_text,  # stored for SFT building
        })
        history.append({"action_json": json.dumps(act), "result_text": result_text})
        recent_actions.append(act)

    # Initial user messages
    initial_p1_msg = format_initial_p1_obs(obs, info)
    initial_p2_msg: Optional[str] = None  # set on transition
    current_phase = initial_phase
    done = False

    for step_idx in range(MAX_STEPS):
        if done:
            break
        # Pool B: transition to phase 2 only if the env actually started in phase 1.
        # If the env auto-transitioned during reset, current_phase is already 2 — skip.
        if (pool == "B" and current_phase == 1 and not p1_steps
                and "transition_to_phase2" in (obs.get("valid_actions") or [])):
            raw = "{}"
            action = {
                "action_type": "transition_to_phase2",
                "parameters": {"belief": {
                    "suspected_service": None,
                    "suspected_fault_class": None,
                    "service_confidence": 0.0,
                    "fault_confidence": 0.0,
                    "evidence_gaps": [],
                    "estimated_p2_cost": "unknown",
                    "decision": "transition",
                    "reasoning": "Pool B: oracle belief injected by environment",
                }},
            }
        else:
            # Hard ceiling: force a Phase 1 terminal this close to MAX_STEPS
            # (the joint pools C/D stop 8 steps early, the others 3).
            p1_hard_limit = MAX_STEPS - 8 if pool in ("C", "D") else MAX_STEPS - 3
            if current_phase == 1 and step_idx >= p1_hard_limit:
                action = _force_p1_terminal(obs)
                raw = json.dumps(action)
                print(" [step limit: forcing terminal]")
            else:
                # Call model with adaptive history truncation on 400/422
                cur_history = p1_history if current_phase == 1 else p2_history
                cur_initial = initial_p1_msg if current_phase == 1 else (
                    initial_p2_msg or format_initial_p2_obs(obs, info, last_belief)
                )
                if current_phase == 2 and initial_p2_msg is None:
                    initial_p2_msg = cur_initial
                try:
                    raw = call_model_adaptive(cur_history, cur_initial, current_phase, model)
                    consecutive_errors = 0
                except Exception as e:
                    print(f" [model error] {e}")
                    raw = "{}"
                    consecutive_errors += 1
                # If model keeps failing in P1, force terminal after 8 consecutive errors
                if current_phase == 1 and consecutive_errors >= 8:
                    action = _force_p1_terminal(obs)
                    raw = json.dumps(action)
                    consecutive_errors = 0
                    print(" [8 consecutive model errors: forcing terminal]")
                elif current_phase == 1:
                    action = parse_p1_action(raw, step_idx, obs, recent_actions)
                else:
                    action = parse_p2_action(raw, step_idx, obs, recent_actions)
        print(f" step {step_idx+1:>2} | ph{current_phase} | {action.get('action_type')}"
              + (f"({action.get('target_service','')})" if action.get("target_service") else ""))

        # Track terminal/transition actions before stepping
        _record_declarations(action)

        # Step environment
        try:
            step_resp = env_step(action)
        except Exception as e:
            print(f" [env error] {e}")
            break
        # A null reward would otherwise crash float(); treat it as 0.
        reward = float(step_resp.get("reward") or 0.0)
        done = step_resp.get("done", False)
        new_obs = step_resp.get("observation", {})
        print(f" reward={reward:+.3f} cumulative={new_obs.get('cumulative_reward', 0):+.3f}"
              + (" DONE" if done else ""))
        _record_step(current_phase, step_idx, action, raw, new_obs, reward)

        # Patience: 10 consecutive negative rewards → force terminal immediately
        consecutive_negative = consecutive_negative + 1 if reward < 0 else 0
        if consecutive_negative >= 10 and not done:
            print(f" [patience exhausted: 10 consecutive negatives β€” forcing terminal]")
            # Use the phase the environment is in now (this step may have
            # already transitioned), and mask Pool A's observation as usual.
            active_phase = new_obs.get("current_phase", current_phase)
            if active_phase == 1:
                term_action = _force_p1_terminal(_mask_p1_obs(new_obs, pool))
            else:
                term_action = {"action_type": "declare_no_change",
                               "parameters": {"reason": "Patience exhausted β€” no progress detected"}}
            # Forced terminals must reach /score too (beliefs, no-change flag).
            _record_declarations(term_action)
            try:
                term_resp = env_step(term_action)
                term_reward = float(term_resp.get("reward") or 0.0)
                done = term_resp.get("done", False)
                term_obs = term_resp.get("observation", new_obs)
            except Exception as e:
                print(f" [forced terminal env error] {e}")
                break
            # A forced transition does not end the episode, so only print DONE when it did.
            print(f" [forced terminal] reward={term_reward:+.3f} cumulative={term_obs.get('cumulative_reward',0):+.3f}"
                  + (" DONE" if done else ""))
            _record_step(active_phase, step_idx + 1, term_action, "{}", term_obs, term_reward)
            new_obs = term_obs

        # Detect phase transition; new_obs may come from a forced terminal,
        # so the phase is read only now.
        new_phase = new_obs.get("current_phase", current_phase)
        if new_phase != current_phase and new_phase == 2:
            print(" ── Phase 1 β†’ Phase 2 ──")
            initial_p2_msg = format_initial_p2_obs(new_obs, info, last_belief)
            recent_actions.clear()    # reset repetition tracking for the new phase
            consecutive_negative = 0  # reset patience on phase change
        current_phase = new_phase
        obs = _mask_p1_obs(new_obs, pool)
        time.sleep(SLEEP_BETWEEN)

    # Fetch unified scores
    score = env_score(declared_patch, declared_no_change, belief_history)
    cumulative = obs.get("cumulative_reward") or 0.0
    print(f" Final cumulative reward: {cumulative:.3f}")
    if score:
        # /score may include non-numeric fields; only round the numbers.
        printable = {
            k: round(v, 3) if isinstance(v, (int, float)) and not isinstance(v, bool) else v
            for k, v in score.items()
        }
        print(f" Scores: {json.dumps(printable)}")
    return {
        "episode_id": episode_id,
        "task_name": task_name,
        "pool": pool,
        "model": model,
        "seed": seed,
        "p1_steps": p1_steps,
        "p2_steps": p2_steps,
        "num_p1_steps": len(p1_steps),
        "num_p2_steps": len(p2_steps),
        "cumulative_reward": round(cumulative, 4),
        "score_breakdown": score,
        "declared_patch": declared_patch,
        "declared_no_change": declared_no_change,
        "belief_history": belief_history,
        "done": done,
        # Reconstructed conversation contexts for SFT building
        "_initial_p1_msg": initial_p1_msg,
        "_initial_p2_msg": initial_p2_msg,
        "_p1_history": p1_history,
        "_p2_history": p2_history,
    }
# ──────────────────────────────────────────────────────────────────────────────
# SFT Dataset Formatter
# ──────────────────────────────────────────────────────────────────────────────
def episode_to_sft_samples(ep: dict) -> List[dict]:
    """
    Convert one episode into per-step SFT samples for BOTH phases.

    ALL steps are included regardless of reward - negative-reward steps
    provide hard-negative signal critical for RL/preference training.
    The ``reward`` field is preserved so the training code can filter or weight.

    For each phase that has both a non-empty ``p{phase}_steps`` list and a
    non-empty ``_initial_p{phase}_msg``, the conversation context is rebuilt
    step by step: sample *i* contains the messages produced by
    ``build_messages`` for the history of steps ``0..i-1`` followed by the
    assistant's action (as JSON) for step *i*.

    Args:
        ep: An episode record as returned by ``run_episode``.

    Returns:
        A list of sample dicts with keys ``messages``, ``reward``, ``phase``,
        ``action_type``, ``task_name``, ``pool``, ``model``, ``episode_id`` and
        ``step``. ``step`` is the 0-based index of the step *within its phase*
        (it is not the episode-global step number stored on the step record).
    """
    samples: List[dict] = []
    for phase in (1, 2):
        steps = ep.get(f"p{phase}_steps")
        initial_msg = ep.get(f"_initial_p{phase}_msg")
        # A phase with no steps or no initial observation cannot be rebuilt.
        if not steps or not initial_msg:
            continue
        history_so_far: List[dict] = []
        for i, step_rec in enumerate(steps):
            action_json = json.dumps(step_rec["action"])
            # Context = initial observation + every earlier step of this phase.
            messages = build_messages(history_so_far, initial_msg, phase=phase)
            messages.append({"role": "assistant", "content": action_json})
            samples.append({
                "messages": messages,
                "reward": step_rec["reward"],
                "phase": phase,
                "action_type": step_rec["action"].get("action_type"),
                "task_name": ep["task_name"],
                "pool": ep["pool"],
                "model": ep["model"],
                "episode_id": ep["episode_id"],
                "step": i,
            })
            history_so_far.append({
                "action_json": action_json,
                "result_text": step_rec.get("result_text", ""),
            })
    return samples
# ──────────────────────────────────────────────────────────────────────────────
# GRPO / DPO Dataset Formatter
# ──────────────────────────────────────────────────────────────────────────────
def episodes_to_grpo_pairs(episodes: List[dict]) -> List[dict]:
    """
    Build (prompt, chosen, rejected) triplets for GRPO/DPO training.

    Two pairing strategies, each applied separately to Phase 1 and Phase 2
    (the ``phase`` field records which):

    1. Within-episode: the highest- vs lowest-reward step of one episode.
       The prompt is the phase's initial observation with an empty history,
       so it approximates (rather than reproduces) the context in which each
       step was actually taken.
    2. Cross-episode: within each (task, pool) group, the highest- vs
       lowest-cumulative-reward episode. For each, the first substantive
       action (first non-``view_alerts`` step, else the first step) is used.
       The prompt comes from the best episode's initial message.

    Chosen = action with higher reward; rejected = action with lower reward.
    Pairs are kept regardless of absolute reward sign. A pair is skipped when
    the rewards tie, or when chosen and rejected serialize to the same action
    JSON, because such a pair carries no preference signal.

    Args:
        episodes: Episode records as returned by ``run_episode``.

    Returns:
        A list of pair dicts. ``strategy`` is ``"within_episode"`` or
        ``"cross_episode"``.
    """
    pairs: List[dict] = []

    def _first_substantive_action(ep_inner: dict, phase: int) -> Optional[dict]:
        """Return the first non-view_alerts step of ``phase``, else its first step, else None."""
        steps = ep_inner.get(f"p{phase}_steps") or []
        for s in steps:
            if s["action"].get("action_type") != "view_alerts":
                return s
        return steps[0] if steps else None

    # ── Strategy 1: within-episode best/worst per phase ───────────────────────
    for ep in episodes:
        for phase, steps, initial_msg in (
            (1, ep.get("p1_steps") or [], ep.get("_initial_p1_msg", "")),
            (2, ep.get("p2_steps") or [], ep.get("_initial_p2_msg", "")),
        ):
            if len(steps) < 2 or not initial_msg:
                continue
            best = max(steps, key=lambda s: s["reward"])
            worst = min(steps, key=lambda s: s["reward"])
            if best["reward"] == worst["reward"]:
                continue
            chosen = json.dumps(best["action"])
            rejected = json.dumps(worst["action"])
            # The same action can score differently (e.g. repetition penalties),
            # but an identical chosen/rejected pair is a contradictory sample.
            if chosen == rejected:
                continue
            pairs.append({
                "prompt": build_messages([], initial_msg, phase=phase),
                "chosen": chosen,
                "rejected": rejected,
                "chosen_reward": best["reward"],
                "rejected_reward": worst["reward"],
                "margin": best["reward"] - worst["reward"],
                "task_name": ep["task_name"],
                "pool": ep["pool"],
                "phase": phase,
                "strategy": "within_episode",
                "episode_id": ep["episode_id"],
            })

    # ── Strategy 2: cross-episode, same task+pool ─────────────────────────────
    by_task_pool: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
    for ep in episodes:
        by_task_pool[(ep["task_name"], ep["pool"])].append(ep)

    for task_eps in by_task_pool.values():
        if len(task_eps) < 2:
            continue
        # Sort by cumulative reward; pair best vs worst episode.
        sorted_eps = sorted(task_eps, key=lambda e: e["cumulative_reward"])
        best_ep = sorted_eps[-1]
        worst_ep = sorted_eps[0]
        if best_ep["cumulative_reward"] == worst_ep["cumulative_reward"]:
            continue
        # Guards against the same episode record appearing twice.
        if best_ep["episode_id"] == worst_ep["episode_id"]:
            continue
        for phase in (1, 2):
            best_step = _first_substantive_action(best_ep, phase)
            worst_step = _first_substantive_action(worst_ep, phase)
            initial_msg = best_ep.get(f"_initial_p{phase}_msg", "")
            if not best_step or not worst_step or not initial_msg:
                continue
            chosen = json.dumps(best_step["action"])
            rejected = json.dumps(worst_step["action"])
            if chosen == rejected:
                continue
            pairs.append({
                "prompt": build_messages([], initial_msg, phase=phase),
                "chosen": chosen,
                "rejected": rejected,
                "chosen_reward": best_ep["cumulative_reward"],
                "rejected_reward": worst_ep["cumulative_reward"],
                "margin": best_ep["cumulative_reward"] - worst_ep["cumulative_reward"],
                "task_name": best_ep["task_name"],
                "pool": best_ep["pool"],
                "phase": phase,
                "strategy": "cross_episode",
                "best_model": best_ep["model"],
                "worst_model": worst_ep["model"],
            })
    return pairs
# ──────────────────────────────────────────────────────────────────────────────
# Episode Schedule Builder
# ──────────────────────────────────────────────────────────────────────────────
def build_episode_schedule(
    n: int,
    pool_weights: Optional[Dict[str, float]] = None,
    pool_tasks: Optional[Dict[str, List[str]]] = None,
    models: Optional[List[str]] = None,
    rng: Optional[random.Random] = None,
) -> List[Tuple[str, str, str, int]]:
    """
    Return a shuffled list of exactly ``max(n, 0)`` (task_name, pool, model, seed) tuples.

    Distribution:
      - Pools weighted by ``pool_weights``. Each pool gets
        ``max(1, round(n * weight))``, then the total is adjusted to ``n``:
        any surplus goes to pool "C" (else the first pool). Any excess is
        removed from pool "A" first, then from the largest remaining pools,
        and no count goes below zero. Previously "A" could go negative,
        which made the schedule longer than ``n`` when ``n`` was small.
      - Tasks within each pool: round-robin.
      - Models: round-robin across all models.
      - Seeds: drawn per episode from ``rng`` (logged in the output for
        reproducibility).

    Args:
        n: Number of episodes to schedule.
        pool_weights: Pool -> weight mapping. Defaults to ``POOL_WEIGHTS``.
        pool_tasks: Pool -> task list mapping. Defaults to ``POOL_TASKS``.
        models: Models to rotate through. Defaults to ``MODELS``.
        rng: Random source for seeds and shuffling. Defaults to the global
            ``random`` module state, as before.
    """
    weights = POOL_WEIGHTS if pool_weights is None else pool_weights
    tasks_by_pool = POOL_TASKS if pool_tasks is None else pool_tasks
    model_list = MODELS if models is None else models
    rand: Any = random if rng is None else rng

    pool_counts = {
        pool: max(1, round(n * weight))
        for pool, weight in weights.items()
    }
    if not pool_counts:
        return []

    # Adjust to exactly n
    diff = n - sum(pool_counts.values())
    if diff > 0:
        target = "C" if "C" in pool_counts else next(iter(pool_counts))
        pool_counts[target] += diff
    elif diff < 0:
        excess = -diff
        # Reduce A first (original policy), then the largest other pools; never below 0.
        order = (["A"] if "A" in pool_counts else []) + sorted(
            (p for p in pool_counts if p != "A"),
            key=lambda p: pool_counts[p],
            reverse=True,
        )
        for pool in order:
            if excess == 0:
                break
            take = min(excess, pool_counts[pool])
            pool_counts[pool] -= take
            excess -= take

    schedule: List[Tuple[str, str, str, int]] = []
    model_idx = 0
    for pool, count in pool_counts.items():
        tasks = tasks_by_pool[pool]
        for i in range(count):
            task = tasks[i % len(tasks)]
            model = model_list[model_idx % len(model_list)]
            seed = rand.randint(0, 99999)
            schedule.append((task, pool, model, seed))
            model_idx += 1
    rand.shuffle(schedule)
    return schedule
# ──────────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────────
def _flush_episode(ep: dict, raw_f, sft_f) -> Tuple[int, int, int]:
    """
    Persist one finished episode and tally its SFT step rewards.

    The episode is written as one JSON line to ``raw_f``, with every key
    starting with ``_`` (the internal reconstruction fields) removed. Each of
    its per-step SFT samples is written as one JSON line to ``sft_f``. Both
    files are flushed right away so a crash loses at most the current episode.

    Returns:
        ``(positive, zero, negative)``: how many SFT samples have reward
        > 0, == 0 and < 0 respectively.
    """
    public_record = {
        key: value
        for key, value in ep.items()
        if not key.startswith("_")
    }
    raw_f.write(json.dumps(public_record) + "\n")
    raw_f.flush()

    sft_samples = episode_to_sft_samples(ep)
    sft_f.writelines(json.dumps(sample) + "\n" for sample in sft_samples)
    sft_f.flush()

    positive = zero = negative = 0
    for sample in sft_samples:
        reward = sample["reward"]
        if reward > 0:
            positive += 1
        elif reward == 0:
            zero += 1
        elif reward < 0:
            negative += 1
    return positive, zero, negative
def _finalize(all_episodes: List[dict], stats: Dict[str, List[float]]) -> None:
    """
    Write the GRPO/DPO pair file and print an end-of-run summary.

    Builds preference pairs from every collected episode and overwrites
    ``sre_grpo_dataset.jsonl`` with them, one JSON object per line. It then
    prints reward statistics: overall (from each episode's
    ``cumulative_reward``) and per pool / task / model short name (from the
    ``pool_<X>``, ``task_<name>`` and ``model_<short>`` lists in ``stats``).
    With no episodes, the (empty) GRPO file is still written and the
    statistics are skipped.
    """
    print(f"\n{'═'*60}")
    print(f"βœ… Collected {len(all_episodes)} episodes")

    grpo_pairs = episodes_to_grpo_pairs(all_episodes)
    grpo_path = "sre_grpo_dataset.jsonl"
    with open(grpo_path, "w") as out:
        out.writelines(json.dumps(pair) + "\n" for pair in grpo_pairs)

    strategy_counts = {"within_episode": 0, "cross_episode": 0}
    for pair in grpo_pairs:
        if pair["strategy"] in strategy_counts:
            strategy_counts[pair["strategy"]] += 1
    within = strategy_counts["within_episode"]
    cross = strategy_counts["cross_episode"]
    print(f"πŸ’Ύ GRPO dataset ({len(grpo_pairs)} pairs) β†’ {grpo_path}")
    print(f" Pairs: {within} within-episode + {cross} cross-episode")

    if not all_episodes:
        return

    all_rewards = [episode["cumulative_reward"] for episode in all_episodes]
    print(f"\nπŸ“ˆ Reward statistics:")
    print(f" Overall: avg={sum(all_rewards)/len(all_rewards):.3f} "
          f"max={max(all_rewards):.3f} min={min(all_rewards):.3f}")

    print(f"\n By pool:")
    for pool in ("A", "B", "C", "D"):
        rs = stats.get(f"pool_{pool}", [])
        if not rs:
            continue
        print(f" Pool {pool}: n={len(rs):>3} avg={sum(rs)/len(rs):.3f} "
              f"max={max(rs):.3f} min={min(rs):.3f}")

    print(f"\n By task:")
    for task in sorted({episode["task_name"] for episode in all_episodes}):
        rs = stats.get(f"task_{task}", [])
        if not rs:
            continue
        print(f" {task:<35} n={len(rs):>2} avg={sum(rs)/len(rs):.3f}")

    print(f"\n By model tier:")
    short_names = {
        episode["model"].split("/")[1].split(":")[0]
        for episode in all_episodes
    }
    for mname in sorted(short_names):
        rs = stats.get(f"model_{mname}", [])
        if not rs:
            continue
        print(f" {mname:<40} n={len(rs):>2} avg={sum(rs)/len(rs):.3f}")
def main():
    """
    Entry point: collect trajectories, stream them to disk, and publish to the Hub.

    Flow:
      1. Validate ``HF_TOKEN`` before touching the Hub. Previously the repo
         was created first, so a missing token surfaced as a Hub error
         instead of the message below.
      2. Run every scheduled episode. Each one's raw record and SFT samples
         are appended to ``sre_raw_trajectories.jsonl`` /
         ``sre_sft_dataset.jsonl``, and a checkpoint is uploaded after it.
      3. On normal completion, Ctrl+C, or crash: build the GRPO pairs and
         summary exactly once (previously ``_finalize`` ran twice), and only
         THEN upload. This way the final upload contains the freshly written
         GRPO file rather than a stale one.
    """
    if not HF_TOKEN:
        print("❌ HF_TOKEN is not set.\n export HF_TOKEN=hf_...")
        return

    from huggingface_hub import HfApi

    repo_id = "srinjoyd/sre-data"
    api = HfApi(token=HF_TOKEN)
    api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)

    print(f"πŸš€ SRE Trajectory Collector")
    print(f" Episodes: {NUM_EPISODES}")
    print(f" Models: {len(MODELS)} (rotating)")
    print(f" Tasks: {len(set(t for ts in POOL_TASKS.values() for t in ts))} unique")
    print(f" Pools: A / B / C / D")
    print(f" Base URL: {BASE_URL}")
    print(f" Keeping ALL episodes (negative reward = hard negatives for RL)")
    print(f" Saving incrementally β€” Ctrl+C safe\n")
    schedule = build_episode_schedule(NUM_EPISODES)
    all_episodes: List[dict] = []
    stats: Dict[str, List[float]] = defaultdict(list)
    total_pos = total_zer = total_neg = 0
    raw_path = "sre_raw_trajectories.jsonl"
    sft_path = "sre_sft_dataset.jsonl"
    print(f"πŸ’Ύ Writing to: {raw_path} | {sft_path} (appending per episode)")
    print(f" GRPO pairs written at end (or on Ctrl+C)\n")
    # Append mode keeps records from earlier runs (episode ids restart at 0 each run).
    with open(raw_path, "a") as raw_f, open(sft_path, "a") as sft_f:
        try:
            for ep_id, (task, pool, model, seed) in enumerate(schedule):
                try:
                    ep = run_episode(task, pool, model, ep_id, seed=seed)
                except Exception as e:
                    print(f" [!] Episode {ep_id+1} FAILED: {e}")
                    traceback.print_exc()
                    time.sleep(2)
                    continue
                all_episodes.append(ep)
                pos, zer, neg = _flush_episode(ep, raw_f, sft_f)
                # NOTE: mid-run checkpoints also push whatever GRPO file is on disk
                # (possibly from a previous run); the fresh one is written at the end.
                upload_checkpoint(api, repo_id)
                total_pos += pos
                total_zer += zer
                total_neg += neg
                r = ep["cumulative_reward"]
                stats[f"pool_{pool}"].append(r)
                stats[f"task_{task}"].append(r)
                stats[f"model_{model.split('/')[1].split(':')[0]}"].append(r)
                print(f" [saved] ep {ep_id+1}/{NUM_EPISODES} | "
                      f"SFT steps so far: +{total_pos}/0:{total_zer}/-{total_neg}")
                time.sleep(1.0)
        except KeyboardInterrupt:
            print(f"\n\n⚠️ Interrupted after {len(all_episodes)} episodes β€” saving what we have...")
        finally:
            try:
                # Writes sre_grpo_dataset.jsonl, so it must precede the final upload.
                _finalize(all_episodes, stats)
            finally:
                upload_checkpoint(api, repo_id)  # always runs, even on crash
    print(f"\nπŸ’Ύ Raw trajectories ({len(all_episodes)} eps) β†’ {raw_path}")
    print(f"πŸ’Ύ SFT dataset ({total_pos+total_zer+total_neg} steps) β†’ {sft_path}")
    print(f" Reward split: +{total_pos} / 0:{total_zer} / -{total_neg}")
    print(f"\nπŸ’‘ To upload to HuggingFace Hub:")
    print(f" from datasets import Dataset")
    print(f" import json")
    print(f" sft = [json.loads(l) for l in open('sre_sft_dataset.jsonl')]")
    print(f" grpo = [json.loads(l) for l in open('sre_grpo_dataset.jsonl')]")
    print(f" Dataset.from_list(sft).push_to_hub('your-username/sre-sft-data')")
    print(f" Dataset.from_list(grpo).push_to_hub('your-username/sre-grpo-data')")


# Run the collector only when executed as a script, not on import.
if __name__ == "__main__":
    main()