Spaces:
Sleeping
Sleeping
| # /// script | |
| # dependencies = [ | |
| # "requests", | |
| # "huggingface_hub", | |
| # ] | |
| # /// | |
| """ | |
| ============================================================= | |
| SRE INCIDENT RESPONSE — COMPREHENSIVE TRAJECTORY COLLECTOR | |
| ============================================================= | |
| Generates fine-tuning trajectories from the SRE incident simulator covering: | |
| • All 10 tasks (8 training + 2 held-out compound scenarios) | |
| • All 4 pools: | |
| A — Phase 1 only (incident response) | |
| B — Phase 2 only (code investigation, oracle belief injected) | |
| C — Joint P1→P2 (full two-phase pipeline) | |
| D — Held-out joint (generalization test) | |
| • Full 17-action action space across both phases | |
| • Multiple models from 1.5B to 70B+ (round-robin rotation) | |
| • ALL episodes retained — negative-reward trajectories are kept as | |
| hard-negative examples for RL/GRPO training | |
| Output files: | |
| sre_raw_trajectories.json — full episode records with score breakdowns | |
| sre_sft_dataset.jsonl — per-step SFT samples (both phases, all rewards) | |
| sre_grpo_dataset.jsonl — (prompt, chosen, rejected) pairs for GRPO/DPO | |
| Usage: | |
| export HF_TOKEN=hf_... | |
| python sre_finetune_collector.py | |
| Optional env vars: | |
| NUM_EPISODES total episodes to collect (default: 100) | |
| BASE_URL simulator URL (default: HF Space URL) | |
| MAX_STEPS max steps per episode (default: 35) | |
| SLEEP_BETWEEN seconds between steps (default: 0.6) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import random | |
| import time | |
| import traceback | |
| from collections import defaultdict | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import requests | |
def upload_checkpoint(api, repo_id):
    """Upload whichever of the three dataset files exist in the working directory.

    ``api`` must expose ``upload_file(path_or_fileobj=..., path_in_repo=...,
    repo_id=..., repo_type=...)`` (presumably a ``huggingface_hub.HfApi`` --
    confirm against the caller). Every upload is best-effort: a failure is
    printed and the remaining files are still attempted.
    """
    checkpoint_files = (
        "sre_raw_trajectories.jsonl",
        "sre_sft_dataset.jsonl",
        "sre_grpo_dataset.jsonl",
    )
    for fname in checkpoint_files:
        # Files that have not been written yet are simply skipped.
        if not os.path.exists(fname):
            continue
        upload_kwargs = {
            "path_or_fileobj": fname,
            "path_in_repo": fname,
            "repo_id": repo_id,
            "repo_type": "dataset",
        }
        try:
            api.upload_file(**upload_kwargs)
            print(f"β Uploaded {fname}")
        except Exception as e:
            print(f"β Upload failed {fname}: {e}")
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Configuration | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Credentials and endpoints. HF_TOKEN has no default; call_model refuses to run without it.
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_URL = os.getenv("BASE_URL", "https://meta-hf-hackathon-updated-policy.hf.space")
HF_ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"

# Collection budget; each value can be overridden through the environment.
NUM_EPISODES = int(os.getenv("NUM_EPISODES", "100"))
MAX_STEPS = int(os.getenv("MAX_STEPS", "35"))
SLEEP_BETWEEN = float(os.getenv("SLEEP_BETWEEN", "0.6"))

# Model ids sent in the "model" field of the chat-completions payload.
MODELS: List[str] = ["Qwen/Qwen2.5-7B-Instruct:fastest"]
| # ββ Task registry per pool ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Pool A: P1-only incident response (all 8 training tasks) | |
| # Pool B: P2-only code investigation (oracle belief injected; 7 tasks with code) | |
| # Pool C: Joint P1βP2 full pipeline (all 8 training tasks) | |
| # Pool D: Held-out joint (2 compound scenarios β generalization evaluation) | |
# The eight training tasks, in the order used by pools A and C.
_TRAINING_TASKS: List[str] = [
    "memory_leak",
    "cascading_failure",
    "distributed_deadlock",
    "circuit_breaker_noop",
    "aliased_fault",
    "severity_inversion",
    "confidence_inversion",
    "info_ordering",
]

# Pool B covers only the seven tasks with code, so this one is left out.
_TASKS_WITHOUT_CODE = {"circuit_breaker_noop"}

POOL_TASKS: Dict[str, List[str]] = {
    "A": list(_TRAINING_TASKS),
    "B": [task for task in _TRAINING_TASKS if task not in _TASKS_WITHOUT_CODE],
    "C": list(_TRAINING_TASKS),
    "D": ["heldout_aliased_severity", "heldout_confidence_ordering"],
}

# Episode budget distribution across pools (must sum to 1.0)
POOL_WEIGHTS: Dict[str, float] = {
    "A": 0.35,
    "B": 0.20,
    "C": 0.35,
    "D": 0.10,
}
| # ββ Action space definitions ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Phase 1 (incident response) action names, grouped by role.
P1_DIAGNOSTIC = [
    "view_alerts",
    "query_logs",
    "check_metrics",
    "check_dependencies",
    "check_deploy_history",
    "run_health_check",
]
P1_REMEDIATION = ["restart_service", "rollback_deploy", "scale_service"]
P1_TERMINAL = ["declare_root_cause", "transition_to_phase2"]
P1_ACTIONS = [*P1_DIAGNOSTIC, *P1_REMEDIATION, *P1_TERMINAL]

# Phase 2 (code investigation) action names.
P2_DIAGNOSTIC = ["list_dir", "read_file", "search_code", "get_git_log", "get_file_diff"]
P2_TERMINAL = ["propose_patch", "declare_no_change"]
P2_ACTIONS = [*P2_DIAGNOSTIC, *P2_TERMINAL]

ALL_SERVICES = ["api_gateway", "auth", "orders", "payment", "cache", "database", "queue"]

# Actions that need a "target_service": every diagnostic and remediation
# action except view_alerts.
TARGETED_ACTIONS = {
    name for name in P1_DIAGNOSTIC + P1_REMEDIATION if name != "view_alerts"
}

# Service dependency graph (for smarter fallbacks)
DEPENDENCY_GRAPH: Dict[str, List[str]] = {
    "api_gateway": ["auth", "orders", "cache"],
    "auth": ["database"],
    "orders": ["database", "payment", "auth"],
    "payment": ["queue", "database"],
    "cache": [],
    "database": [],
    "queue": [],
}
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # System Prompts | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# System prompt used when build_messages is called with phase == 1.
# It gives the model the service topology, the JSON schema of every Phase 1
# action (diagnostic, remediation, declare_root_cause, transition_to_phase2)
# and an ordered investigation strategy, and demands bare-JSON output.
SYSTEM_PROMPT_P1 = """You are an expert SRE handling a production incident in a microservices system.
## Service Topology (downstream β upstream)
api_gateway β auth, orders, cache
auth β database
orders β database, payment, auth
payment β queue, database
cache, database, queue β (no dependencies)
## Phase 1 Action Space
Output EXACTLY ONE valid JSON action per turn. No markdown, no explanation.
Diagnostic (read-only):
{"action_type": "view_alerts"}
{"action_type": "query_logs", "target_service": "<svc>", "parameters": {"level": "ERROR", "keyword": "<optional>", "limit": 20}}
{"action_type": "check_metrics", "target_service": "<svc>"}
{"action_type": "check_dependencies", "target_service": "<svc>"}
{"action_type": "check_deploy_history", "target_service": "<svc>"}
{"action_type": "run_health_check", "target_service": "<svc>"}
Remediation (mutates state):
{"action_type": "restart_service", "target_service": "<svc>"}
{"action_type": "rollback_deploy", "target_service": "<svc>"}
{"action_type": "scale_service", "target_service": "<svc>", "parameters": {"replicas": 5}}
Declare root cause (ALL tasks β always call this once you have a diagnosis):
{"action_type": "declare_root_cause", "parameters": {"root_cause": "<specific diagnosis β service, what failed, why>"}}
Then for joint-mode tasks, ALSO transition to code investigation:
{"action_type": "transition_to_phase2", "parameters": {"belief": {
"suspected_service": "<root_cause_svc>",
"suspected_fault_class": "memory_leak|config_change|deadlock|dep_upgrade|none",
"service_confidence": 0.85,
"fault_confidence": 0.80,
"evidence_gaps": ["<what_you_didnt_check>"],
"estimated_p2_cost": "low|medium|high",
"decision": "transition",
"reasoning": "<concise evidence summary>"
}}}
## Investigation Strategy
1. ALWAYS start with view_alerts to understand severity and scope
2. check_metrics on the highest-alert service first
3. query_logs (level=ERROR) on degraded/down services
4. check_dependencies on the affected service to find upstream causes
5. check_deploy_history before any rollback
6. Remediate the ROOT CAUSE service, not the symptom
7. After 6-8 diagnostic steps you MUST call declare_root_cause with your diagnosis.
For P1-only tasks this ends the episode. For joint-mode tasks, follow it immediately
with transition_to_phase2. Do NOT keep diagnosing indefinitely β commit to a conclusion.
CRITICAL: Output ONLY valid JSON. No markdown. No explanation. No code blocks."""
# System prompt used when build_messages is called with any phase other than 1.
# It lists the JSON schema of every Phase 2 action (code exploration plus the
# two terminal actions) and the investigation order the model is told to
# follow, and demands bare-JSON output.
SYSTEM_PROMPT_P2 = """You are an expert SRE investigating a code-level fault in a sandboxed repository.
## Phase 2 Action Space
Output EXACTLY ONE valid JSON action per turn. No markdown, no explanation.
Code Exploration:
{"action_type": "list_dir", "parameters": {"path": "."}}
{"action_type": "read_file", "parameters": {"path": "relative/path/to/file.py"}}
{"action_type": "search_code", "parameters": {"query": "<search string>", "file_pattern": "*.py", "max_hits": 20}}
{"action_type": "get_git_log", "parameters": {"path": ".", "n_commits": 15}}
{"action_type": "get_file_diff", "parameters": {"commit_sha": "<sha>", "path": "relative/path/file.py"}}
Terminal:
{"action_type": "propose_patch", "parameters": {"diff": "<unified diff β minimal, correct, applies cleanly>"}}
{"action_type": "declare_no_change", "parameters": {"reason": "<why no code fix is needed β infrastructure issue, not code>"}}
## Investigation Strategy
1. list_dir "." to understand project structure
2. get_git_log to find recent commits β especially the bad_commit_sha from Phase 1 context
3. get_file_diff on the suspicious commit SHA to see what changed
4. read_file on affected files to understand the bug
5. search_code to find related patterns or the fault injection site
6. If you found a code bug: propose_patch with a minimal, syntactically valid unified diff.
The bad_commit_sha in your context tells you exactly what changed β read that diff and revert/fix it.
7. declare_no_change ONLY if Phase 1 confirmed a spurious alert / circuit-breaker false positive
with no deployment or code change involved. If there IS a bad commit in the git log, propose_patch.
CRITICAL: Output ONLY valid JSON. No markdown. No explanation. No code blocks."""
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Observation Formatters | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _fmt_service_statuses(statuses: Dict[str, str]) -> str: | |
| symbols = {"healthy": "β", "degraded": "~", "down": "β"} | |
| return " ".join( | |
| f"{symbols.get(v,'?')}{svc}({v})" | |
| for svc, v in sorted(statuses.items()) | |
| ) | |
| def _fmt_action_result(result: Any, max_chars: int = 3000) -> str: | |
| if result is None: | |
| return "(no result)" | |
| text = json.dumps(result, indent=2) if not isinstance(result, str) else result | |
| if len(text) > max_chars: | |
| text = text[:max_chars] + f"\n... [truncated {len(text)-max_chars} chars]" | |
| return text | |
def format_initial_p1_obs(obs: dict, info: dict) -> str:
    """Build the opening Phase 1 user message from the reset observation.

    Task, pool and mode labels come from ``info``; everything else is read
    from ``obs``, with placeholder defaults for any missing key.
    """
    header = (
        f"INCIDENT RESPONSE | Pool {info.get('pool', '?')} "
        f"| Mode: {info.get('mode', 'unknown')} "
        f"| Task: {info.get('task_name', 'unknown')}"
    )
    phase = obs.get("current_phase", 1)
    svc_line = _fmt_service_statuses(obs.get("service_statuses", {}))
    valid = obs.get("valid_actions", P1_ACTIONS)
    budget = (
        f"Severity: {obs.get('severity', '?')} | "
        f"Time Budget: {obs.get('time_budget_minutes', '?')} min | "
        f"Max Steps: {obs.get('max_steps', MAX_STEPS)}"
    )
    # Empty entries become the blank lines between sections once joined.
    lines = [
        header,
        "β" * 60,
        f"Summary: {obs.get('incident_summary', 'No summary available')}",
        budget,
        f"Phase: {phase}",
        "",
        "Service Statuses:",
        f" {svc_line}",
        f"Active Alerts: {obs.get('active_alerts_count', 0)}",
        "",
        f"Valid Actions: {valid}",
        "",
        "What is your FIRST action?",
    ]
    return "\n".join(lines)
def format_step_result_p1(obs: dict, reward: float) -> str:
    """Build the user message that reports the outcome of one Phase 1 step.

    Shows the action outcome and payload, current service statuses, alert
    count, step/time progress and rewards, then the valid actions. A
    ``bad_commit_sha`` in ``obs`` is surfaced as a reminder for Phase 2.
    """
    svc_line = _fmt_service_statuses(obs.get("service_statuses", {}))
    payload = _fmt_action_result(obs.get("action_result"))
    outcome = (
        f"Action Result (success={obs.get('action_success', '?')}): "
        f"{obs.get('action_message', '')}"
    )
    progress = (
        f"Step: {obs.get('steps_taken','?')}/{obs.get('max_steps', MAX_STEPS)} "
        f"| Time: {obs.get('time_elapsed_minutes','?')}/{obs.get('time_budget_minutes','?')} min"
    )
    parts = [
        outcome,
        "\n" + payload,
        "\n" + "β" * 40,
        "Services: " + svc_line,
        f"Alerts: {obs.get('active_alerts_count', 0)} active",
        progress,
        f"Reward: {reward:+.3f} | Cumulative: {obs.get('cumulative_reward', 0):+.3f}",
    ]
    if obs.get("bad_commit_sha"):
        parts.append(f"Bad Commit SHA: {obs['bad_commit_sha']} (remember for Phase 2)")
    parts.extend(
        [
            f"\nValid Actions: {obs.get('valid_actions', P1_ACTIONS)}",
            "\nWhat is your next action?",
        ]
    )
    return "\n".join(parts)
def _fmt_confidence(value: Any) -> str:
    """Format a belief confidence as a whole-number percentage, or "?" if it is not numeric."""
    try:
        return f"{float(value):.0%}"
    except (TypeError, ValueError):
        return "?"


def format_initial_p2_obs(obs: dict, info: dict, belief: Optional[dict]) -> str:
    """Format the first Phase 2 observation (after transition or Pool B auto-start).

    Args:
        obs: current observation; ``bad_commit_sha``, ``steps_taken``,
            ``max_steps``, ``cumulative_reward`` and ``valid_actions`` are read.
        info: reset info; ``task_name`` and ``pool`` are read.
        belief: the Phase 1 belief to show, or ``None``/empty to omit the section.

    Returns:
        The user-message text that opens the Phase 2 conversation.
    """
    task = info.get("task_name", "unknown")
    pool = info.get("pool", "?")

    belief_text = ""
    if belief:
        # Confidences can be None or free text (a forced or model-written
        # belief); formatting them with ':.0%' directly would raise.
        belief_text = (
            f"\n[Phase 1 Belief]\n"
            f" Suspected service: {belief.get('suspected_service', '?')}\n"
            f" Suspected fault: {belief.get('suspected_fault_class', '?')}\n"
            f" Service confidence: {_fmt_confidence(belief.get('service_confidence', 0))}\n"
            f" Fault confidence: {_fmt_confidence(belief.get('fault_confidence', 0))}\n"
            f" Reasoning: {belief.get('reasoning', '')}\n"
            f" P2 cost estimate: {belief.get('estimated_p2_cost', '?')}\n"
        )

    # The SHA line only appears when the observation actually carries one.
    sha = obs.get("bad_commit_sha")
    sha_line = f"Bad Commit SHA: {sha}\n" if sha else ""

    return (
        f"CODE INVESTIGATION | Pool {pool} | Task: {task}\n"
        f"{'β'*60}\n"
        f"{sha_line}"
        f"{belief_text}\n"
        f"Step: {obs.get('steps_taken', 0)}/{obs.get('max_steps', MAX_STEPS)} "
        f"| Cumulative Reward: {obs.get('cumulative_reward', 0):+.3f}\n"
        f"\nValid Actions: {obs.get('valid_actions', P2_ACTIONS)}\n"
        f"\nWhat is your first Phase 2 action?"
    )
def format_step_result_p2(obs: dict, reward: float) -> str:
    """Build the user message that reports the outcome of one Phase 2 step.

    Shows the action outcome and payload, step progress, step and cumulative
    reward, and the valid actions for the next turn.
    """
    payload = _fmt_action_result(obs.get("action_result"))
    outcome = (
        f"Action Result (success={obs.get('action_success', '?')}): "
        f"{obs.get('action_message', '')}"
    )
    parts = [
        outcome,
        "\n" + payload,
        "\n" + "β" * 40,
        f"Step: {obs.get('steps_taken','?')}/{obs.get('max_steps', MAX_STEPS)}",
        f"Reward: {reward:+.3f} | Cumulative: {obs.get('cumulative_reward', 0):+.3f}",
        f"\nValid Actions: {obs.get('valid_actions', P2_ACTIONS)}",
        "\nWhat is your next action?",
    ]
    return "\n".join(parts)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Message Builder | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def build_messages(
    history: List[Dict],
    initial_user_msg: str,
    phase: int,
    max_recent: int = 10,
) -> List[Dict]:
    """Build the OpenAI-format messages list for one model call.

    Args:
        history: past turns, each a dict with at least ``"action_json"`` (the
            assistant's action text) and ``"result_text"`` (the reply shown to
            the model).
        initial_user_msg: the first user message of the phase; always included.
        phase: 1 selects ``SYSTEM_PROMPT_P1``; any other value selects
            ``SYSTEM_PROMPT_P2``.
        max_recent: maximum number of most-recent turns to include. Zero or a
            negative value includes no history turns.

    Returns:
        ``[system, user, (assistant, user) * k]`` where ``k <= max_recent``.
    """
    system = SYSTEM_PROMPT_P1 if phase == 1 else SYSTEM_PROMPT_P2
    messages: List[Dict] = [
        {"role": "system", "content": system},
        {"role": "user", "content": initial_user_msg},
    ]
    # history[-0:] is the whole list, so a zero cap needs an explicit guard.
    recent = history[-max_recent:] if max_recent > 0 else []
    for entry in recent:
        messages.append({"role": "assistant", "content": entry["action_json"]})
        messages.append({"role": "user", "content": entry["result_text"]})
    return messages
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Model Caller | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def call_model(
    messages: List[Dict],
    model: str,
    temperature: float = 0.5,
    max_tokens: int = 512,
    retries: int = 3,
) -> str:
    """POST a chat-completion request to ``HF_ROUTER_URL`` and return the reply text.

    Args:
        messages: OpenAI-format message list.
        model: model id placed in the payload's "model" field.
        temperature: sampling temperature forwarded in the payload.
        max_tokens: completion length limit forwarded in the payload.
        retries: total number of attempts.

    Returns:
        The stripped ``choices[0].message.content`` string.

    Raises:
        ValueError: ``HF_TOKEN`` is not set.
        requests.HTTPError: immediately on HTTP 400/422 (never retried), or
            after the last attempt for any other HTTP status.
        Exception: the last error seen once every attempt has failed, or
            ``RuntimeError`` when ``retries`` allows no attempt at all.
    """
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN is not set.")
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }
    last_exc: Exception = RuntimeError("No attempts made")
    for attempt in range(retries):
        try:
            resp = requests.post(HF_ROUTER_URL, headers=headers, json=payload, timeout=90)
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"].strip()
        except requests.HTTPError as e:
            code = e.response.status_code if e.response is not None else 0
            if code in (400, 422):
                raise  # client-format errors: retrying won't help; let caller handle
            last_exc = e
        except Exception as e:
            # Network failures and malformed response bodies are retried alike.
            last_exc = e
        if attempt + 1 < retries:
            wait = 2 ** attempt
            print(f" [model retry {attempt+1}/{retries}] {last_exc} β waiting {wait}s")
            time.sleep(wait)
        else:
            # No attempt follows, so backing off here would only waste time.
            print(f" [model attempt {attempt+1}/{retries} failed] {last_exc}")
    raise last_exc
| def _merge_system_into_user(messages: List[Dict]) -> List[Dict]: | |
| """Fold system prompt into the first user message for models without system role.""" | |
| if not messages or messages[0]["role"] != "system": | |
| return messages | |
| system_text = messages[0]["content"] | |
| rest = messages[1:] | |
| if not rest or rest[0]["role"] != "user": | |
| return rest | |
| merged_first = {"role": "user", "content": f"{system_text}\n\n{rest[0]['content']}"} | |
| return [merged_first] + rest[1:] | |
# Models confirmed to reject the system role; the merged format is used for
# them from the first attempt of every later call.
_MODELS_NEEDING_MERGE: set = set()


def call_model_adaptive(
    history: List[Dict],
    initial_msg: str,
    phase: int,
    model: str,
    temperature: float = 0.5,
) -> str:
    """Call the model, adapting the request when the router rejects it.

    Fallback ladder, driven by the HTTP status of the failure:
      * 400 while the system role is in use: retry with the system prompt
        folded into the first user message; once that succeeds the model is
        remembered in ``_MODELS_NEEDING_MERGE`` so later calls skip the probe.
      * 400 while already merged, or 422: halve the history window (down to a
        single turn) and retry.
      * anything else, or a window that cannot shrink further: re-raise.

    Returns:
        The model's reply text.

    Raises:
        requests.HTTPError: when no fallback is left.
        Exception: whatever non-HTTP error ``call_model`` raises.
    """
    # Label used only in log lines. [-1] also works for ids without a "/",
    # where indexing [1] would raise IndexError inside the error handler and
    # hide the HTTPError being handled.
    short_name = model.split("/")[-1]

    use_merge = model in _MODELS_NEEDING_MERGE
    probed_merge = use_merge  # True = already confirmed in a prior step, no re-probe needed
    max_recent = 10
    while True:
        messages = build_messages(history, initial_msg, phase, max_recent=max_recent)
        if use_merge:
            messages = _merge_system_into_user(messages)
        try:
            result = call_model(messages, model=model, temperature=temperature)
            if use_merge and not probed_merge:
                # Merged format succeeded for the first time: cache it.
                _MODELS_NEEDING_MERGE.add(model)
                print(f" [merged format confirmed for {short_name}, cached]")
            return result
        except requests.HTTPError as e:
            code = e.response.status_code if e.response is not None else 0
            if code == 400 and not use_merge:
                use_merge = True
                probed_merge = False
                print(f" [400: probing merged format for {short_name}]")
            elif code in (400, 422) and max_recent > 1:
                max_recent = max(1, max_recent // 2)
                print(f" [ctx truncated to last {max_recent} turns]")
            else:
                raise
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Action Parsers | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _extract_json(raw: str) -> dict: | |
| """Extract the first JSON object from model output and normalise colon-format action types. | |
| Some models output {"action_type": "check_metrics:api_gateway"} instead of | |
| the correct {"action_type": "check_metrics", "target_service": "api_gateway"}. | |
| Split and normalise so the environment never sees an invalid action_type. | |
| """ | |
| start = raw.find("{") | |
| end = raw.rfind("}") + 1 | |
| if start == -1 or end == 0: | |
| raise ValueError("No JSON object in model output") | |
| action = json.loads(raw[start:end]) | |
| atype = action.get("action_type", "") | |
| if ":" in atype: | |
| parts = atype.split(":", 1) | |
| action["action_type"] = parts[0] | |
| if parts[1] in ALL_SERVICES and "target_service" not in action: | |
| action["target_service"] = parts[1] | |
| return action | |
| def _recent_sigs(recent_actions: List[dict], n: int = 3) -> set: | |
| return {(a.get("action_type"), a.get("target_service")) for a in recent_actions[-n:]} | |
def _diversify_p1(obs: dict, recent_actions: List[dict]) -> dict:
    """Pick a Phase 1 diagnostic action that differs from what was done recently.

    Every diagnostic action is scored against the recent history and the
    first highest-scoring candidate is returned. Targeted actions are
    generated for the non-healthy services, or for all services when every
    one is healthy (or no statuses are known).
    """
    statuses = obs.get("service_statuses") or {}
    unhealthy = [name for name, state in statuses.items() if state != "healthy"]
    recent_pairs = _recent_sigs(recent_actions, n=4)
    recent_services = {a.get("target_service") for a in recent_actions[-6:]}
    recent_types = {a.get("action_type") for a in recent_actions[-4:]}
    targets = unhealthy or ALL_SERVICES

    scored: List[Tuple[int, dict]] = []
    for kind in P1_DIAGNOSTIC:
        if kind == "view_alerts":
            # Untargeted: worth 2 unless it was among the recent actions.
            alert_score = 0 if ("view_alerts", None) in recent_pairs else 2
            scored.append((alert_score, {"action_type": "view_alerts"}))
            continue
        for svc in targets:
            candidate: dict = {"action_type": kind, "target_service": svc}
            if kind == "query_logs":
                candidate["parameters"] = {"level": "ERROR", "limit": 20}
            # 2 for an unseen (type, service) pair, plus 1 each for a service
            # and an action type not used recently.
            points = 0
            if (kind, svc) not in recent_pairs:
                points += 2
            if svc not in recent_services:
                points += 1
            if kind not in recent_types:
                points += 1
            scored.append((points, candidate))

    if scored:
        # max() returns the first maximal entry, i.e. the earliest-generated
        # candidate among ties.
        return max(scored, key=lambda pair: pair[0])[1]

    # Last resort
    svc = next((s for s in ALL_SERVICES if s not in recent_services), random.choice(ALL_SERVICES))
    return {"action_type": "query_logs", "target_service": svc,
            "parameters": {"level": "ERROR", "limit": 20}}
def parse_p1_action(raw: str, step: int, obs: dict, recent_actions: Optional[List[dict]] = None) -> dict:
    """Turn raw model output into an executable Phase 1 action.

    Repairs applied in order: an action type outside the valid set is
    replaced, a targeted action gets a usable ``target_service``, an action
    repeating one of the last two is swapped for a diversified one (terminal
    actions are exempt), and a ``transition_to_phase2`` belief has its missing
    fields filled with defaults. Any failure along the way falls back to
    ``view_alerts`` on step 0 and to a diversified diagnostic afterwards.
    """
    history = recent_actions or []
    allowed = set(obs.get("valid_actions") or P1_ACTIONS)
    terminal_kinds = ("declare_root_cause", "transition_to_phase2")
    try:
        action = _extract_json(raw)
        kind = action.get("action_type", "")

        if kind not in allowed:
            if step > 0:
                action = _diversify_p1(obs, history)
            else:
                action = {"action_type": "view_alerts"}
            kind = action["action_type"]

        # Ensure target_service for targeted actions
        if kind in TARGETED_ACTIONS:
            has_known_target = (
                "target_service" in action and action["target_service"] in ALL_SERVICES
            )
            if not has_known_target:
                pool_of_services = obs.get("available_services") or ALL_SERVICES
                action["target_service"] = random.choice(pool_of_services)

        # Anti-repetition: an exact (type, service) repeat of a recent action is replaced.
        signature = (action.get("action_type"), action.get("target_service"))
        if signature in _recent_sigs(history, n=2) and kind not in terminal_kinds:
            action = _diversify_p1(obs, history)
            kind = action["action_type"]

        # Fill in any belief fields the model left out of a transition.
        if kind == "transition_to_phase2":
            params = action.setdefault("parameters", {})
            belief = params.setdefault("belief", {})
            unhealthy = [
                svc
                for svc, state in (obs.get("service_statuses") or {}).items()
                if state != "healthy"
            ]
            belief.setdefault(
                "suspected_service",
                unhealthy[0] if unhealthy else random.choice(ALL_SERVICES),
            )
            belief.setdefault("suspected_fault_class", "memory_leak")
            belief.setdefault("service_confidence", 0.7)
            belief.setdefault("fault_confidence", 0.65)
            belief.setdefault("evidence_gaps", [])
            belief.setdefault("estimated_p2_cost", "medium")
            belief.setdefault("decision", "transition")
            belief.setdefault("reasoning", "Transitioning based on collected evidence")
        return action
    except Exception:
        if step == 0:
            return {"action_type": "view_alerts"}
        return _diversify_p1(obs, history)
def _force_p1_terminal(obs: dict) -> dict:
    """Build a terminal Phase 1 action from the observed state alone.

    If ``transition_to_phase2`` is valid, a transition with a 0.5-confidence
    belief is returned, pointing at the first non-healthy service (or a
    random one when all are healthy). Otherwise a ``declare_root_cause``
    naming the non-healthy services is returned.
    """
    allowed = set(obs.get("valid_actions") or P1_ACTIONS)
    statuses = obs.get("service_statuses") or {}
    unhealthy = [svc for svc, state in statuses.items() if state != "healthy"]

    if "transition_to_phase2" not in allowed:
        if unhealthy:
            cause = f"Degradation detected in: {', '.join(unhealthy)}"
        else:
            cause = "Root cause undetermined within step budget"
        return {"action_type": "declare_root_cause",
                "parameters": {"root_cause": cause}}

    suspect = unhealthy[0] if unhealthy else random.choice(ALL_SERVICES)
    forced_belief = {
        "suspected_service": suspect,
        "suspected_fault_class": "memory_leak",
        "service_confidence": 0.5,
        "fault_confidence": 0.5,
        "evidence_gaps": ["forced_terminal_after_step_limit"],
        "estimated_p2_cost": "medium",
        "decision": "transition",
        "reasoning": f"Forced transition: degraded={unhealthy}",
    }
    return {
        "action_type": "transition_to_phase2",
        "parameters": {"belief": forced_belief},
    }
def _p2_signature(action: dict) -> Tuple[Any, str]:
    """Identity of a Phase 2 action for repetition checks: its type plus canonical parameters."""
    params = action.get("parameters")
    if not isinstance(params, dict):
        params = {}
    # default=str is deliberate: the result is only a comparison key, never sent anywhere.
    return action.get("action_type"), json.dumps(params, sort_keys=True, default=str)


def parse_p2_action(raw: str, step: int, obs: dict, recent_actions: Optional[List[dict]] = None) -> dict:
    """Parse Phase 2 action with smart fallbacks and anti-repetition.

    Args:
        raw: raw model output expected to contain one JSON action.
        step: current step index; selects the entry of the fallback sequence.
        obs: current observation; ``valid_actions`` and ``bad_commit_sha`` are read.
        recent_actions: previously executed action dicts, oldest first.

    Returns:
        An action dict with all required parameters present. Output that
        cannot be parsed, or an action type outside the valid set, is replaced
        by an entry of the fixed exploration sequence. A ``propose_patch``
        without a diff becomes ``declare_no_change``. A non-terminal action
        identical (same type and parameters) to one of the last three actions
        is replaced by another entry of the exploration sequence.
    """
    recent_actions = recent_actions or []
    valid = set(obs.get("valid_actions") or P2_ACTIONS)
    # Signatures must be built the same way for past and current actions. The
    # previous check compared (type, path) against (type, target_service)
    # pairs, which never match for Phase 2 actions, so it could not fire.
    used_sigs = {_p2_signature(a) for a in recent_actions[-3:] if isinstance(a, dict)}
    p2_fallback_sequence = [
        {"action_type": "list_dir", "parameters": {"path": "."}},
        {"action_type": "get_git_log", "parameters": {"path": ".", "n_commits": 15}},
        {"action_type": "search_code", "parameters": {"query": "error", "file_pattern": "*.py", "max_hits": 15}},
        {"action_type": "search_code", "parameters": {"query": "def ", "file_pattern": "*.py", "max_hits": 10}},
        {"action_type": "list_dir", "parameters": {"path": "src"}},
    ]
    try:
        action = _extract_json(raw)
        atype = action.get("action_type", "")
        if atype not in valid:
            action = p2_fallback_sequence[step % len(p2_fallback_sequence)]
            atype = action["action_type"]

        # Ensure required params
        params = action.setdefault("parameters", {})
        if atype == "list_dir":
            params.setdefault("path", ".")
        elif atype == "read_file":
            params.setdefault("path", ".")
        elif atype == "search_code":
            params.setdefault("query", "error")
            params.setdefault("file_pattern", "*.py")
            params.setdefault("max_hits", 15)
        elif atype == "get_git_log":
            params.setdefault("path", ".")
            params.setdefault("n_commits", 10)
        elif atype == "get_file_diff":
            sha = obs.get("bad_commit_sha") or "HEAD"
            params.setdefault("commit_sha", sha)
            params.setdefault("path", ".")
        elif atype == "propose_patch" and "diff" not in params:
            # A patch proposal without a diff cannot be applied; end without a change.
            action = {"action_type": "declare_no_change",
                      "parameters": {"reason": "Unable to determine code fix from available evidence"}}
        elif atype == "declare_no_change":
            params.setdefault("reason", "No code-level fix required based on investigation")

        # Anti-repetition for non-terminal actions
        if atype not in P2_TERMINAL and _p2_signature(action) in used_sigs:
            action = p2_fallback_sequence[(step + len(recent_actions)) % len(p2_fallback_sequence)]
        return action
    except Exception:
        return p2_fallback_sequence[step % len(p2_fallback_sequence)]
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Environment HTTP Helpers | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _mask_p1_obs(obs: dict, pool: str) -> dict: | |
| """Pool A is p1_only β remove transition_to_phase2 the server incorrectly exposes.""" | |
| if pool == "A" and obs.get("valid_actions"): | |
| obs = dict(obs) | |
| obs["valid_actions"] = [a for a in obs["valid_actions"] if a != "transition_to_phase2"] | |
| return obs | |
def env_reset(task_name: str, pool: str, seed: Optional[int] = None) -> dict:
    """POST to ``{BASE_URL}/reset`` and return the decoded JSON response.

    The request body carries ``task_name`` and ``pool``; ``seed`` is added
    only when one is given. HTTP error statuses raise via ``raise_for_status``.
    """
    body: dict = {"task_name": task_name, "pool": pool}
    if seed is not None:
        body.update(seed=seed)
    response = requests.post(f"{BASE_URL}/reset", json=body, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(action: dict) -> dict:
    """POST ``action`` to ``{BASE_URL}/step`` and return the decoded JSON response.

    HTTP error statuses raise via ``raise_for_status``.
    """
    response = requests.post(f"{BASE_URL}/step", json=action, timeout=30)
    response.raise_for_status()
    return response.json()
def env_score(declared_patch: Optional[str], declared_no_change: bool,
              belief_history: List[dict]) -> dict:
    """Request the grader scores for the finished episode from ``{BASE_URL}/score``.

    Best-effort: any failure (network, HTTP status, bad JSON) is printed and
    an empty dict is returned instead of raising.
    """
    payload = {
        "declared_patch": declared_patch,
        "declared_no_change": declared_no_change,
        "belief_history": belief_history,
    }
    try:
        response = requests.post(f"{BASE_URL}/score", json=payload, timeout=30)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        print(f" [score] {e}")
        return {}
def env_get_trajectory() -> dict:
    """Fetch the full trajectory from ``{BASE_URL}/trajectory``.

    Best-effort, like ``env_score``: any failure is printed and an empty dict
    is returned instead of raising, so a missing trajectory is visible in the
    run output rather than silent.
    """
    try:
        resp = requests.get(f"{BASE_URL}/trajectory", timeout=30)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        print(f" [trajectory] {e}")
        return {}
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Episode Runner | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def run_episode(
    task_name: str,
    pool: str,
    model: str,
    episode_id: int,
    seed: Optional[int] = None,
) -> dict:
    """
    Run one full episode (Phase 1, Phase 2, or Joint) through the HTTP API.

    Args:
        task_name: Simulator task passed to ``env_reset``.
        pool: Pool letter ("A", "B", "C" or "D"); selects the step budget for
            Phase 1 and whether the Pool B transition is injected.
        model: Model identifier forwarded to ``call_model_adaptive``.
        episode_id: Zero-based episode index (used for logging and the record).
        seed: Optional seed forwarded to ``env_reset``.

    Returns:
        A rich episode record including:
          - step records with (action, raw_model_output, observation, reward)
          - final score breakdown from /score
          - the initial messages and per-phase histories needed to rebuild
            SFT message sequences (keys prefixed with ``_``)

    Every action sent to the environment -- including the terminal action
    forced when patience runs out -- is reflected in ``declared_patch``,
    ``declared_no_change`` and ``belief_history`` before scoring.
    """
    print(f"\n{'β'*60}")
    print(f" Ep {episode_id+1:>3} | Pool {pool} | Task: {task_name}")
    print(f" Model: {model}")
    print(f"{'β'*60}")

    reset_resp = env_reset(task_name, pool, seed)
    obs = _mask_p1_obs(reset_resp.get("observation", {}), pool)
    info = reset_resp.get("info", {})
    initial_phase = obs.get("current_phase", 1)

    # Tracks for the episode
    p1_steps: List[dict] = []
    p2_steps: List[dict] = []
    belief_history: List[dict] = []
    declared_patch: Optional[str] = None
    declared_no_change: bool = False

    # Conversation history per phase (for message building)
    p1_history: List[dict] = []
    p2_history: List[dict] = []
    last_belief: Optional[dict] = None

    # Recent actions (flattened, for anti-repetition)
    recent_actions: List[dict] = []
    consecutive_errors = 0      # consecutive model call failures
    consecutive_negative = 0    # consecutive negative-reward steps (patience)

    # Initial user messages
    initial_p1_msg = format_initial_p1_obs(obs, info)
    initial_p2_msg: Optional[str] = None  # set on transition

    current_phase = initial_phase
    done = False

    def _track_declaration(act: dict) -> None:
        """Record transition / patch / no-change declarations for scoring."""
        nonlocal last_belief, declared_patch, declared_no_change
        kind = act.get("action_type", "")
        params = act.get("parameters") or {}
        if kind == "transition_to_phase2":
            belief = params.get("belief", {})
            last_belief = belief
            belief_history.append(belief)
        if kind == "propose_patch":
            declared_patch = params.get("diff", "")
        if kind == "declare_no_change":
            declared_no_change = True

    for step_idx in range(MAX_STEPS):
        if done:
            break

        # Pool B: transition to phase 2 only if the env actually started in phase 1.
        # If the env auto-transitioned during reset, current_phase is already 2 - skip.
        if (
            pool == "B"
            and current_phase == 1
            and len(p1_steps) == 0
            and "transition_to_phase2" in (obs.get("valid_actions") or [])
        ):
            raw = "{}"
            action = {
                "action_type": "transition_to_phase2",
                "parameters": {"belief": {
                    "suspected_service": None,
                    "suspected_fault_class": None,
                    "service_confidence": 0.0,
                    "fault_confidence": 0.0,
                    "evidence_gaps": [],
                    "estimated_p2_cost": "unknown",
                    "decision": "transition",
                    "reasoning": "Pool B: oracle belief injected by environment",
                }},
            }
        else:
            # Hard ceiling: force terminal if too close to max_steps
            p1_hard_limit = MAX_STEPS - 8 if pool in ("C", "D") else MAX_STEPS - 3
            if current_phase == 1 and step_idx >= p1_hard_limit:
                action = _force_p1_terminal(obs)
                raw = json.dumps(action)
                print(f" [step limit: forcing terminal]")
            else:
                # Call model with adaptive history truncation on 422
                cur_history = p1_history if current_phase == 1 else p2_history
                cur_initial = initial_p1_msg if current_phase == 1 else (
                    initial_p2_msg or format_initial_p2_obs(obs, info, last_belief)
                )
                if current_phase == 2 and initial_p2_msg is None:
                    initial_p2_msg = cur_initial
                try:
                    raw = call_model_adaptive(cur_history, cur_initial, current_phase, model)
                    consecutive_errors = 0
                except Exception as e:
                    print(f" [model error] {e}")
                    raw = "{}"
                    consecutive_errors += 1

                # If model keeps failing in P1, force terminal after 8 consecutive errors
                if current_phase == 1 and consecutive_errors >= 8:
                    action = _force_p1_terminal(obs)
                    raw = json.dumps(action)
                    consecutive_errors = 0
                    print(f" [8 consecutive model errors: forcing terminal]")
                elif current_phase == 1:
                    action = parse_p1_action(raw, step_idx, obs, recent_actions)
                else:
                    action = parse_p2_action(raw, step_idx, obs, recent_actions)

        print(f" step {step_idx+1:>2} | ph{current_phase} | {action.get('action_type')}"
              + (f"({action.get('target_service','')})" if action.get("target_service") else ""))

        # Track terminal/transition actions before stepping
        _track_declaration(action)

        # Step environment
        try:
            step_resp = env_step(action)
        except Exception as e:
            print(f" [env error] {e}")
            break

        reward = float(step_resp.get("reward", 0.0))
        done = step_resp.get("done", False)
        new_obs = step_resp.get("observation", {})
        new_phase = new_obs.get("current_phase", current_phase)
        print(f" reward={reward:+.3f} cumulative={new_obs.get('cumulative_reward', 0):+.3f}"
              + (" DONE" if done else ""))

        # Build result text for next turn
        if current_phase == 1:
            result_text = format_step_result_p1(new_obs, reward)
        else:
            result_text = format_step_result_p2(new_obs, reward)

        step_record = {
            "step": step_idx,
            "phase": current_phase,
            "action": action,
            "raw_output": raw,
            "observation": new_obs,
            "reward": reward,
            "result_text": result_text,  # stored for SFT building
        }
        history_turn = {"action_json": json.dumps(action), "result_text": result_text}
        if current_phase == 1:
            p1_steps.append(step_record)
            p1_history.append(history_turn)
        else:
            p2_steps.append(step_record)
            p2_history.append(history_turn)
        recent_actions.append(action)

        # Patience: 10 consecutive negative rewards -> force terminal immediately
        if reward < 0:
            consecutive_negative += 1
        else:
            consecutive_negative = 0

        if consecutive_negative >= 10 and not done:
            print(f" [patience exhausted: 10 consecutive negatives β forcing terminal]")
            if current_phase == 1:
                term_action = _force_p1_terminal(new_obs)
            else:
                term_action = {"action_type": "declare_no_change",
                               "parameters": {"reason": "Patience exhausted β no progress detected"}}
            # The forced action is a real declaration: record it so /score sees it.
            _track_declaration(term_action)
            try:
                term_resp = env_step(term_action)
                term_reward = float(term_resp.get("reward", 0.0))
                done = term_resp.get("done", False)
                term_obs = term_resp.get("observation", new_obs)
                print(f" [forced terminal] reward={term_reward:+.3f} cumulative={term_obs.get('cumulative_reward',0):+.3f} DONE")
                steps_list = p1_steps if current_phase == 1 else p2_steps
                steps_list.append({"step": step_idx + 1, "phase": current_phase,
                                   "action": term_action, "raw_output": "{}",
                                   "observation": term_obs, "reward": term_reward,
                                   "result_text": ""})
                # Fall through with the refreshed observation AND phase so that a
                # forced P1 -> P2 transition is detected below; when the env
                # reports done, the loop exits on its next `if done` check.
                new_obs = term_obs
                new_phase = term_obs.get("current_phase", new_phase)
            except Exception as e:
                print(f" [forced terminal env error] {e}")
                # Keep the last successful observation for the final reward.
                obs = _mask_p1_obs(new_obs, pool)
                break

        # Detect phase transition
        if new_phase != current_phase and new_phase == 2:
            print(" ββ Phase 1 β Phase 2 ββ")
            initial_p2_msg = format_initial_p2_obs(new_obs, info, last_belief)
            recent_actions.clear()      # reset repetition tracking for the new phase
            consecutive_negative = 0    # reset patience on phase change

        current_phase = new_phase
        obs = _mask_p1_obs(new_obs, pool)
        time.sleep(SLEEP_BETWEEN)

    # Fetch unified scores
    score = env_score(declared_patch, declared_no_change, belief_history)
    cumulative = obs.get("cumulative_reward") or 0.0
    print(f" Final cumulative reward: {cumulative:.3f}")
    if score:
        # Only round numeric entries; a non-numeric field must not abort the episode.
        printable = {
            k: round(v, 3) if isinstance(v, (int, float)) and not isinstance(v, bool) else v
            for k, v in score.items()
        }
        print(f" Scores: {json.dumps(printable)}")

    return {
        "episode_id": episode_id,
        "task_name": task_name,
        "pool": pool,
        "model": model,
        "seed": seed,
        "p1_steps": p1_steps,
        "p2_steps": p2_steps,
        "num_p1_steps": len(p1_steps),
        "num_p2_steps": len(p2_steps),
        "cumulative_reward": round(cumulative, 4),
        "score_breakdown": score,
        "declared_patch": declared_patch,
        "declared_no_change": declared_no_change,
        "belief_history": belief_history,
        "done": done,
        # Reconstructed conversation contexts for SFT building
        "_initial_p1_msg": initial_p1_msg,
        "_initial_p2_msg": initial_p2_msg,
        "_p1_history": p1_history,
        "_p2_history": p2_history,
    }
# ──────────────────────────────────────────────────────────────────────────────
# SFT Dataset Formatter
# ──────────────────────────────────────────────────────────────────────────────
def episode_to_sft_samples(ep: dict) -> List[dict]:
    """
    Flatten one episode record into per-step SFT samples, Phase 1 first and
    then Phase 2.

    A phase contributes samples only when the episode has both steps and an
    initial message for it. Each sample's ``messages`` is the conversation
    built from the steps that preceded it, followed by the step's action as
    the assistant turn. No step is dropped on account of its reward; the
    ``reward`` value is carried along so training code can filter or weight.
    """
    collected: List[dict] = []

    for phase_no in (1, 2):
        phase_steps = ep.get(f"p{phase_no}_steps")
        opening_msg = ep.get(f"_initial_p{phase_no}_msg")
        if not (phase_steps and opening_msg):
            continue

        # Turns replayed so far; grows by one entry after every emitted sample.
        prior_turns: List[dict] = []
        for position, record in enumerate(phase_steps):
            conversation = build_messages(prior_turns, opening_msg, phase=phase_no)
            action_text = json.dumps(record["action"])
            conversation.append({"role": "assistant", "content": action_text})

            collected.append({
                "messages": conversation,
                "reward": record["reward"],
                "phase": phase_no,
                "action_type": record["action"].get("action_type"),
                "task_name": ep["task_name"],
                "pool": ep["pool"],
                "model": ep["model"],
                "episode_id": ep["episode_id"],
                "step": position,
            })

            prior_turns.append({
                "action_json": action_text,
                "result_text": record.get("result_text", ""),
            })

    return collected
# ──────────────────────────────────────────────────────────────────────────────
# GRPO / DPO Dataset Formatter
# ──────────────────────────────────────────────────────────────────────────────
def episodes_to_grpo_pairs(episodes: List[dict]) -> List[dict]:
    """
    Build (prompt, chosen, rejected) triplets for GRPO/DPO training.

    Two pairing strategies, each applied separately to Phase 1 and Phase 2:
      1. Within-episode: highest-reward step vs lowest-reward step.
      2. Cross-episode: for each (task, pool) group, the first substantive
         action of the best episode vs that of the worst episode, ranked by
         cumulative reward.

    In both strategies the prompt is built from the phase's *initial* message
    only (no step history); for cross-episode pairs it is the best episode's
    initial message.

    Chosen = action with higher reward; rejected = action with lower reward.
    Both are kept regardless of absolute reward sign. Pairs whose chosen and
    rejected actions serialize identically are skipped: they carry no
    preference signal (e.g. the fixed transition injected for Pool B is the
    first Phase 1 step of every such episode).

    Args:
        episodes: Episode records as returned by ``run_episode``.

    Returns:
        A list of pair dicts tagged with ``strategy`` ("within_episode" or
        "cross_episode"), ``phase``, ``task_name`` and ``pool``.
    """
    pairs: List[dict] = []

    def _first_substantive_step(ep_inner: dict, phase: int) -> Optional[dict]:
        """Return the first step that is not `view_alerts`, else the first step."""
        steps = ep_inner.get(f"p{phase}_steps") or []
        for s in steps:
            if s["action"].get("action_type") != "view_alerts":
                return s
        return steps[0] if steps else None

    # -- Strategy 1: within-episode best/worst per phase -----------------------
    for ep in episodes:
        for phase in (1, 2):
            steps = ep.get(f"p{phase}_steps") or []
            initial_msg = ep.get(f"_initial_p{phase}_msg") or ""
            if len(steps) < 2 or not initial_msg:
                continue
            best = max(steps, key=lambda s: s["reward"])
            worst = min(steps, key=lambda s: s["reward"])
            if best["reward"] == worst["reward"]:
                continue
            chosen = json.dumps(best["action"])
            rejected = json.dumps(worst["action"])
            if chosen == rejected:
                # Same action at both extremes: nothing to prefer.
                continue
            pairs.append({
                "prompt": build_messages([], initial_msg, phase=phase),
                "chosen": chosen,
                "rejected": rejected,
                "chosen_reward": best["reward"],
                "rejected_reward": worst["reward"],
                "margin": best["reward"] - worst["reward"],
                "task_name": ep["task_name"],
                "pool": ep["pool"],
                "phase": phase,
                "strategy": "within_episode",
                "episode_id": ep["episode_id"],
            })

    # -- Strategy 2: cross-episode, same task+pool -----------------------------
    by_task_pool: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
    for ep in episodes:
        by_task_pool[(ep["task_name"], ep["pool"])].append(ep)

    for task_eps in by_task_pool.values():
        if len(task_eps) < 2:
            continue
        # Sort by cumulative reward; pair best vs worst episode
        sorted_eps = sorted(task_eps, key=lambda e: e["cumulative_reward"])
        best_ep = sorted_eps[-1]
        worst_ep = sorted_eps[0]
        if best_ep["cumulative_reward"] == worst_ep["cumulative_reward"]:
            continue
        if best_ep["episode_id"] == worst_ep["episode_id"]:
            continue

        for phase in (1, 2):
            # Use the first non-view_alerts action as representative
            best_step = _first_substantive_step(best_ep, phase)
            worst_step = _first_substantive_step(worst_ep, phase)
            initial_msg = best_ep.get(f"_initial_p{phase}_msg") or ""
            if not best_step or not worst_step or not initial_msg:
                continue
            chosen = json.dumps(best_step["action"])
            rejected = json.dumps(worst_step["action"])
            if chosen == rejected:
                # Both episodes opened with the same action: degenerate pair.
                continue
            pairs.append({
                "prompt": build_messages([], initial_msg, phase=phase),
                "chosen": chosen,
                "rejected": rejected,
                "chosen_reward": best_ep["cumulative_reward"],
                "rejected_reward": worst_ep["cumulative_reward"],
                "margin": best_ep["cumulative_reward"] - worst_ep["cumulative_reward"],
                "task_name": best_ep["task_name"],
                "pool": best_ep["pool"],
                "phase": phase,
                "strategy": "cross_episode",
                "best_model": best_ep["model"],
                "worst_model": worst_ep["model"],
            })

    return pairs
# ──────────────────────────────────────────────────────────────────────────────
# Episode Schedule Builder
# ──────────────────────────────────────────────────────────────────────────────
def build_episode_schedule(n: int) -> List[Tuple[str, str, str, int]]:
    """
    Return a shuffled list of exactly ``n`` (task_name, pool, model, seed) tuples.

    Distribution:
      - Pools weighted by POOL_WEIGHTS (each pool initially gets at least one
        episode, then the counts are adjusted so they sum to ``n``)
      - Tasks within each pool: round-robin over POOL_TASKS[pool]
      - Models: round-robin across all MODELS
      - Seeds: random per episode (logged in the output so a run can be
        reproduced)

    Args:
        n: Total number of episodes to schedule. Values <= 0 give an empty
           schedule.
    """
    if n <= 0 or not POOL_WEIGHTS:
        return []

    pool_counts = {
        pool: max(1, round(n * weight))
        for pool, weight in POOL_WEIGHTS.items()
    }

    # Adjust to exactly n
    diff = n - sum(pool_counts.values())
    if diff > 0:
        # Give the shortfall to C when present, otherwise to the largest pool.
        grow = "C" if "C" in pool_counts else max(pool_counts, key=pool_counts.get)
        pool_counts[grow] += diff
    elif diff < 0:
        # Trim A first, then the largest remaining pools. A count never drops
        # below zero, so the schedule length is exactly n even for tiny n.
        surplus = -diff
        others = sorted(
            (p for p in pool_counts if p != "A"),
            key=pool_counts.get,
            reverse=True,
        )
        trim_order = (["A"] if "A" in pool_counts else []) + others
        for pool in trim_order:
            cut = min(surplus, pool_counts[pool])
            pool_counts[pool] -= cut
            surplus -= cut
            if surplus == 0:
                break

    schedule: List[Tuple[str, str, str, int]] = []
    model_idx = 0
    for pool, count in pool_counts.items():
        tasks = POOL_TASKS[pool]
        for i in range(count):
            task = tasks[i % len(tasks)]
            model = MODELS[model_idx % len(MODELS)]
            seed = random.randint(0, 99999)
            schedule.append((task, pool, model, seed))
            model_idx += 1

    random.shuffle(schedule)
    return schedule
# ──────────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────────
def _flush_episode(ep: dict, raw_f, sft_f) -> Tuple[int, int, int]:
    """
    Persist a single episode to the already-open output files.

    One JSON line with the episode's public fields (keys not starting with
    "_") goes to ``raw_f``; one JSON line per SFT sample goes to ``sft_f``.
    Both files are flushed before returning.

    Returns:
        ``(pos, zer, neg)``: how many of the episode's SFT samples have a
        positive, zero and negative reward respectively.
    """
    public_fields = {}
    for key, value in ep.items():
        if key.startswith("_"):
            continue  # private reconstruction context, not part of the raw record
        public_fields[key] = value
    raw_f.write(json.dumps(public_fields) + "\n")
    raw_f.flush()

    sft_rows = episode_to_sft_samples(ep)
    for row in sft_rows:
        sft_f.write(json.dumps(row) + "\n")
    sft_f.flush()

    pos = zer = neg = 0
    for row in sft_rows:
        step_reward = row["reward"]
        if step_reward > 0:
            pos += 1
        elif step_reward == 0:
            zer += 1
        elif step_reward < 0:
            neg += 1
    return pos, zer, neg
def _finalize(all_episodes: List[dict], stats: Dict[str, List[float]]) -> None:
    """
    Generate GRPO pairs and print final statistics from whatever was collected.

    The pairs are written to ``sre_grpo_dataset.jsonl`` (overwriting any
    previous file), then reward statistics are printed overall and per pool,
    task and model.

    Args:
        all_episodes: Episode records collected in this run.
        stats: Cumulative rewards keyed by ``pool_<X>``, ``task_<name>`` and
            ``model_<short name>``.
    """

    def _short_model_name(model_id: str) -> str:
        # "org/name:provider" -> "name". An id without a slash is used as-is
        # (minus any ":suffix") instead of raising IndexError.
        parts = model_id.split("/")
        return (parts[1] if len(parts) > 1 else parts[0]).split(":")[0]

    print(f"\n{'β'*60}")
    print(f"β Collected {len(all_episodes)} episodes")

    grpo_pairs = episodes_to_grpo_pairs(all_episodes)
    grpo_path = "sre_grpo_dataset.jsonl"
    with open(grpo_path, "w", encoding="utf-8") as f:
        for p in grpo_pairs:
            f.write(json.dumps(p) + "\n")

    within = sum(1 for p in grpo_pairs if p["strategy"] == "within_episode")
    cross = sum(1 for p in grpo_pairs if p["strategy"] == "cross_episode")
    print(f"πΎ GRPO dataset ({len(grpo_pairs)} pairs) β {grpo_path}")
    print(f" Pairs: {within} within-episode + {cross} cross-episode")

    if not all_episodes:
        return

    all_rewards = [ep["cumulative_reward"] for ep in all_episodes]
    print(f"\nπ Reward statistics:")
    print(f" Overall: avg={sum(all_rewards)/len(all_rewards):.3f} "
          f"max={max(all_rewards):.3f} min={min(all_rewards):.3f}")

    print(f"\n By pool:")
    for pool in ["A", "B", "C", "D"]:
        rs = stats.get(f"pool_{pool}", [])
        if rs:
            print(f" Pool {pool}: n={len(rs):>3} avg={sum(rs)/len(rs):.3f} "
                  f"max={max(rs):.3f} min={min(rs):.3f}")

    print(f"\n By task:")
    for task in sorted({ep["task_name"] for ep in all_episodes}):
        rs = stats.get(f"task_{task}", [])
        if rs:
            print(f" {task:<35} n={len(rs):>2} avg={sum(rs)/len(rs):.3f}")

    print(f"\n By model tier:")
    model_short_names = {_short_model_name(ep["model"]) for ep in all_episodes}
    for mname in sorted(model_short_names):
        rs = stats.get(f"model_{mname}", [])
        if rs:
            print(f" {mname:<40} n={len(rs):>2} avg={sum(rs)/len(rs):.3f}")
def main():
    """
    Collect NUM_EPISODES episodes, saving and uploading after each one.

    Raw trajectories and SFT samples are appended per episode, so an
    interrupted run keeps everything gathered so far. When the loop ends
    (normally, on Ctrl+C, or on an unexpected error) the GRPO pairs are
    written once and a final checkpoint upload is made *afterwards*, so the
    uploaded checkpoint includes the GRPO file.
    """
    # Check the token before touching the Hub so a missing token yields the
    # friendly message instead of an API error.
    if not HF_TOKEN:
        print("β HF_TOKEN is not set.\n export HF_TOKEN=hf_...")
        return

    from huggingface_hub import HfApi

    repo_id = "srinjoyd/sre-data"
    api = HfApi(token=HF_TOKEN)
    api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)

    print(f"π SRE Trajectory Collector")
    print(f" Episodes: {NUM_EPISODES}")
    print(f" Models: {len(MODELS)} (rotating)")
    print(f" Tasks: {len(set(t for ts in POOL_TASKS.values() for t in ts))} unique")
    print(f" Pools: A / B / C / D")
    print(f" Base URL: {BASE_URL}")
    print(f" Keeping ALL episodes (negative reward = hard negatives for RL)")
    print(f" Saving incrementally β Ctrl+C safe\n")

    schedule = build_episode_schedule(NUM_EPISODES)
    all_episodes: List[dict] = []
    stats: Dict[str, List[float]] = defaultdict(list)
    total_pos = total_zer = total_neg = 0

    raw_path = "sre_raw_trajectories.jsonl"
    sft_path = "sre_sft_dataset.jsonl"
    print(f"πΎ Writing to: {raw_path} | {sft_path} (appending per episode)")
    print(f" GRPO pairs written at end (or on Ctrl+C)\n")

    with open(raw_path, "a") as raw_f, open(sft_path, "a") as sft_f:
        try:
            for ep_id, (task, pool, model, seed) in enumerate(schedule):
                try:
                    ep = run_episode(task, pool, model, ep_id, seed=seed)
                except Exception as e:
                    print(f" [!] Episode {ep_id+1} FAILED: {e}")
                    traceback.print_exc()
                    time.sleep(2)
                    continue

                all_episodes.append(ep)
                pos, zer, neg = _flush_episode(ep, raw_f, sft_f)
                upload_checkpoint(api, repo_id)
                total_pos += pos
                total_zer += zer
                total_neg += neg

                r = ep["cumulative_reward"]
                # "org/name:provider" -> "name"; tolerate ids without a slash.
                model_parts = model.split("/")
                model_short = (model_parts[1] if len(model_parts) > 1 else model_parts[0]).split(":")[0]
                stats[f"pool_{pool}"].append(r)
                stats[f"task_{task}"].append(r)
                stats[f"model_{model_short}"].append(r)

                print(f" [saved] ep {ep_id+1}/{NUM_EPISODES} | "
                      f"SFT steps so far: +{total_pos}/0:{total_zer}/-{total_neg}")
                time.sleep(1.0)
        except KeyboardInterrupt:
            print(f"\n\nβ οΈ Interrupted after {len(all_episodes)} episodes β saving what we have...")
        finally:
            # Always runs, even on crash: write GRPO pairs once, then upload
            # so the final checkpoint contains them.
            try:
                _finalize(all_episodes, stats)
            finally:
                upload_checkpoint(api, repo_id)

    print(f"\nπΎ Raw trajectories ({len(all_episodes)} eps) β {raw_path}")
    print(f"πΎ SFT dataset ({total_pos+total_zer+total_neg} steps) β {sft_path}")
    print(f" Reward split: +{total_pos} / 0:{total_zer} / -{total_neg}")

    print(f"\nπ‘ To upload to HuggingFace Hub:")
    print(f" from datasets import Dataset")
    print(f" import json")
    print(f" sft = [json.loads(l) for l in open('sre_sft_dataset.jsonl')]")
    print(f" grpo = [json.loads(l) for l in open('sre_grpo_dataset.jsonl')]")
    print(f" Dataset.from_list(sft).push_to_hub('your-username/sre-sft-data')")
    print(f" Dataset.from_list(grpo).push_to_hub('your-username/sre-grpo-data')")
# Script entry point: run the collector only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()