#!/usr/bin/env python3 """Submission inference script with validator-compatible stdout logs.""" from __future__ import annotations import json import os import re from dataclasses import dataclass, field from typing import Any from openai import OpenAI from unified_incident_env.client import UnifiedIncidentEnv from unified_incident_env.models import ( PostmortemPayload, SecurityContext, UnifiedIncidentAction, UnifiedIncidentObservation, ) from unified_incident_env.server.challenge import SCENARIOS API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1" MODEL_NAME = os.getenv("MODEL_NAME") or "qwen2.5:1.5b" HF_TOKEN = os.getenv("HF_TOKEN") ENV_BASE_URL = os.getenv("ENV_BASE_URL") or UnifiedIncidentEnv.DEFAULT_BASE_URL ENV_NAME = "unified-incident-env" MAX_TOKENS = 220 INFERENCE_MODE = os.getenv("INFERENCE_MODE", "judge").strip().lower() POLICY_CARD_WORD_BUDGET_COMPACT = int(os.getenv("POLICY_CARD_WORD_BUDGET_COMPACT", "60")) POLICY_CARD_RULES = [ "Return JSON only.", "Use action_type.", "Use only allowed actions.", "No explanation text.", ] STAGE_GOALS = { "diagnosis": "find the most relevant next investigation step", "root_cause_analysis": "confirm the root-cause evidence and avoid unnecessary recovery", "security_subquest": "complete the security fix before infrastructure recovery", "remediation": "recover services in the correct order", "verification": "verify that recovery and security remediation are complete", "postmortem": "submit the final incident summary", "done": "complete the benchmark", } ACTION_KEYS = { "action_type", "service", "metric", "vulnerability_type", "patch_id", "postmortem", } KNOWN_ACTIONS = { "query_logs", "query_metrics", "query_dependencies", "restart_service", "rollback_deploy", "inspect_code", "classify_vulnerability", "apply_patch", "verify_security_fix", "submit_security_fix", "submit_postmortem", } LOCAL_ENDPOINT_MARKERS = ("127.0.0.1", "localhost") SERVICE_PRIORITY = ("database", "cache", "api-gateway", 
"worker") VULNERABILITY_KEYWORDS = { "sql_injection": ("sql injection", "sqli", "query", "parameter", "login"), "broken_access_control": ("access control", "authorization", "admin", "role", "permission"), "command_injection": ("command injection", "shell", "subprocess", "filename", "worker"), } PATCH_KEYWORDS = { "sql_injection": ("parameter", "prepared", "query"), "broken_access_control": ("admin", "role", "authoriz"), "command_injection": ("avoid_shell", "argv", "shell", "subprocess"), } SYSTEM_PROMPT = """You are solving a deterministic incident-response benchmark. Return exactly one JSON object and nothing else. Rules: - Choose only from the allowed action types shown in the user message. - Use only the required fields for the chosen action. - Do not include explanation text. - Do not include markdown. - Do not include code fences. - Do not repeat an action that already failed or made no progress. - If patching is required, use only one of the listed patch IDs. """ USER_PROMPT_TEMPLATE = """Current stage: {stage} Current goal: {goal} Allowed actions: {allowed_actions_block} Required fields: {required_fields_block} {patch_ids_block}{transition_block}{negative_reward_block}{loop_warning_block}Current environment state: {state_block} Valid example: {valid_example} Return exactly one JSON object. 
""" @dataclass class PolicyNote: stage: str failure_type: str mistake: str correction: str valid_example: dict[str, Any] action_family: str | None = None @dataclass class PolicyCardState: schema_notes: list[PolicyNote] = field(default_factory=list) failure_notes: list[PolicyNote] = field(default_factory=list) recovery_notes: list[PolicyNote] = field(default_factory=list) def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None: error_val = error if error else "null" done_val = str(done).lower() print( f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, rewards: list[float]) -> None: rewards_str = ",".join(f"{reward:.2f}" for reward in rewards) print( f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True, ) def action_to_log_string(action: UnifiedIncidentAction) -> str: return json.dumps( action.model_dump(exclude_none=True, exclude={"metadata"}), separators=(",", ":"), ) def create_client() -> OpenAI | None: if HF_TOKEN is None: raise ValueError("HF_TOKEN environment variable is required") try: return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN, timeout=45.0) except Exception: return None def _inference_mode() -> str: return "small" if os.getenv("INFERENCE_MODE", INFERENCE_MODE).strip().lower() == "small" else "judge" def _is_local_ollama() -> bool: return any(marker in API_BASE_URL for marker in LOCAL_ENDPOINT_MARKERS) def _extract_json_candidate(raw: str) -> str: text = raw.strip() if "```" in text: parts = text.split("```") if len(parts) >= 2: text = parts[1] if text.startswith("json"): text = text[4:] start = text.find("{") end = text.rfind("}") if start != -1 and end != -1 and start < end: return text[start : end + 1] return text def parse_action( raw: str, observation: 
UnifiedIncidentObservation, *, scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> UnifiedIncidentAction | None: stage_allowed_actions = _narrow_allowed_actions( observation, scenario_id=scenario_id, history=history or [], ) text = raw.strip() if not text: return None bare = text.strip().strip('"').strip("'") if bare in stage_allowed_actions and bare in KNOWN_ACTIONS: fields = observation.required_fields_by_action.get(bare, []) if not fields: return UnifiedIncidentAction(action_type=bare) example = observation.valid_action_example or {} if example.get("action_type") == bare: try: return UnifiedIncidentAction(**example) except Exception: return None return None try: payload = json.loads(_extract_json_candidate(text)) except Exception: return None if not isinstance(payload, dict): return None cleaned = {key: value for key, value in payload.items() if key in ACTION_KEYS} if "action_type" not in cleaned and isinstance(payload.get("action"), str): cleaned["action_type"] = payload["action"] if "vulnerability_type" not in cleaned and isinstance(payload.get("vulnerability"), str): cleaned["vulnerability_type"] = payload["vulnerability"] metrics_value = payload.get("metrics") if "metric" not in cleaned and isinstance(metrics_value, list) and len(metrics_value) == 1: cleaned["metric"] = metrics_value[0] action_type = cleaned.get("action_type") if action_type not in stage_allowed_actions: return None try: return UnifiedIncidentAction(**cleaned) except Exception: return None def choose_investigation_service(observation: UnifiedIncidentObservation) -> str: critical_alerts = [ alert.service for alert in observation.active_alerts if alert.severity == "critical" ] if critical_alerts: return critical_alerts[0] for service in SERVICE_PRIORITY: health = observation.service_health.get(service) if health and health.status == "crashed": return service for service in SERVICE_PRIORITY: health = observation.service_health.get(service) if health and 
health.status == "degraded": return service return "api-gateway" def choose_recovery_service(observation: UnifiedIncidentObservation) -> str: for service in SERVICE_PRIORITY: health = observation.service_health.get(service) if health and health.status == "crashed": return service for service in SERVICE_PRIORITY: health = observation.service_health.get(service) if health and health.status == "degraded": return service return "api-gateway" def infer_vulnerability(observation: UnifiedIncidentObservation, history: list[dict[str, Any]]) -> str: text_parts = [ observation.prompt_text, observation.tool_output or "", observation.security_unlock_reason or "", observation.last_action_result, observation.why_failed or "", ] text_parts.extend(str(item.get("result", "")) for item in history[-4:]) haystack = " ".join(text_parts).lower() best = "sql_injection" best_score = -1 for vulnerability, keywords in VULNERABILITY_KEYWORDS.items(): score = sum(1 for keyword in keywords if keyword in haystack) if score > best_score: best = vulnerability best_score = score return best def extract_patch_options(observation: UnifiedIncidentObservation) -> list[str]: sources = [observation.tool_output or "", observation.prompt_text] for source in sources: match = re.search(r"Patch options:\s*([^\n]+)", source) if not match: continue return [option.strip() for option in match.group(1).split(",") if option.strip()] return [] def _allowed_patch_ids(observation: UnifiedIncidentObservation) -> list[str]: options = extract_patch_options(observation) if not options: options = ["parameterized_query", "enforce_admin_role", "avoid_shell"] # If vulnerability is already classified, filter options to matching family vuln = observation.security_context.selected_vulnerability if vuln: keywords = PATCH_KEYWORDS.get(vuln, []) filtered = [ opt for opt in options if any(k in opt.lower() for k in keywords) ] if filtered: return filtered return options def _stage_hint( observation: UnifiedIncidentObservation, *, 
scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> str: hard = _hard_transition_state( scenario_id=scenario_id, observation=observation, history=history or [], ) if hard["next_required_action"] is not None: return hard["next_required_action"] if hard["next_required_action_family"] is not None: return f"Next required action family: {hard['next_required_action_family']}." stage = observation.workflow_stage if stage == "diagnosis": return "Find the root cause with investigation before moving to security or recovery." if stage == "root_cause_analysis": return "Confirm the root cause and avoid broad extra queries." if stage == "security_subquest": return "Solve the security subquest with the next security action." if stage == "remediation": return "Recover the system with the allowed remediation action." if stage == "verification": return "Verify the fix before submitting the security fix." if stage == "postmortem": return "Submit the postmortem after the incident is resolved." return "Follow the current stage goal and allowed actions." def _stop_investigating_hint( observation: UnifiedIncidentObservation, *, scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> str | None: hard = _hard_transition_state( scenario_id=scenario_id, observation=observation, history=history or [], ) if hard["stop_investigating"]: return hard["stop_message"] if observation.loop_warning: return "Stop repeating the same no-progress action; choose a different allowed action family." if observation.workflow_stage == "root_cause_analysis": return "Avoid broad investigation; confirm the root cause or move to the next stage." if observation.workflow_stage in {"security_subquest", "remediation", "verification", "postmortem"}: return "Avoid extra query_* investigation actions unless required by the current stage." 
return None def choose_patch_id(observation: UnifiedIncidentObservation, history: list[dict[str, Any]]) -> str: options = extract_patch_options(observation) vulnerability = infer_vulnerability(observation, history) keywords = PATCH_KEYWORDS[vulnerability] for option in options: lowered = option.lower() if any(keyword in lowered for keyword in keywords): return option if options: return options[0] defaults = { "sql_injection": "parameterized_query", "broken_access_control": "enforce_admin_role", "command_injection": "avoid_shell", } return defaults[vulnerability] def _timeline_entry(action: UnifiedIncidentAction) -> str: if action.action_type in {"query_logs", "query_dependencies"} and action.service: return f"{action.action_type} {action.service}" if action.action_type == "query_metrics" and action.service and action.metric: return f"query_metrics {action.service}.{action.metric}" if action.action_type in {"restart_service", "rollback_deploy"} and action.service: return f"{action.action_type} {action.service}" if action.action_type == "classify_vulnerability" and action.vulnerability_type: return f"classify_vulnerability {action.vulnerability_type}" if action.action_type == "apply_patch" and action.patch_id: return f"apply_patch {action.patch_id}" return action.action_type def _action_family(action_type: str | None) -> str | None: if action_type in {"query_logs", "query_metrics", "query_dependencies"}: return "investigate" if action_type in { "inspect_code", "classify_vulnerability", "apply_patch", "verify_security_fix", "submit_security_fix", }: return "security" if action_type in {"restart_service", "rollback_deploy"}: return "recovery" if action_type == "submit_postmortem": return "postmortem" return None def build_postmortem( observation: UnifiedIncidentObservation, history: list[dict[str, Any]], ) -> PostmortemPayload: vulnerability = infer_vulnerability(observation, history) selected_patch = observation.security_context.selected_patch root_cause_map = { 
"sql_injection": "SQL injection crashed the database and caused gateway errors.", "broken_access_control": "Broken access control on an admin path caused cache abuse and database degradation.", "command_injection": "Command injection in the worker poisoned downstream services after a bad deploy.", } attack_vector_map = { "sql_injection": "Unsanitized login input abused the SQL query path.", "broken_access_control": "Missing admin authorization exposed an internal cache-management route.", "command_injection": "Unsafe shell command construction allowed attacker-controlled filenames to execute commands.", } prevention_map = { "sql_injection": ["Parameterized queries", "Database abuse alerting"], "broken_access_control": ["Admin role enforcement", "Authorization tests"], "command_injection": ["Avoid shell invocation", "Safer deploy validation"], } timeline = [_timeline_entry(item["action"]) for item in history if "action" in item] remediation_steps = [] if selected_patch: remediation_steps.append(selected_patch.replace("_", " ")) remediation_steps.extend( item["action"].service.replace("-", " ") for item in history if "action" in item and item["action"].action_type in {"restart_service", "rollback_deploy"} and item["action"].service ) return PostmortemPayload( root_cause=root_cause_map[vulnerability], attack_vector=attack_vector_map[vulnerability], timeline=timeline[-6:], remediation_steps=remediation_steps[:4], prevention_steps=prevention_map[vulnerability], ) def build_fallback_action( observation: UnifiedIncidentObservation, history: list[dict[str, Any]], *, scenario_id: str | None = None, ) -> UnifiedIncidentAction: hard = _hard_transition_state( scenario_id=scenario_id, observation=observation, history=history, ) example = observation.valid_action_example or {} last_action = ( history[-1]["action"].model_dump(exclude_none=True, exclude={"metadata"}) if history and "action" in history[-1] else None ) narrowed_allowed_actions = _narrow_allowed_actions( observation, 
scenario_id=scenario_id, history=history, ) if example.get("action_type") in narrowed_allowed_actions and example != last_action: try: return UnifiedIncidentAction(**example) except Exception: pass stage = observation.workflow_stage security: SecurityContext = observation.security_context if stage in {"diagnosis", "root_cause_analysis"}: if hard["needs_unlock_bridge"]: return UnifiedIncidentAction( action_type="query_dependencies", service="api-gateway", ) if stage == "root_cause_analysis" and "query_dependencies" in observation.allowed_actions: return UnifiedIncidentAction( action_type="query_dependencies", service="api-gateway", ) if "query_logs" in observation.allowed_actions: return UnifiedIncidentAction( action_type="query_logs", service=choose_investigation_service(observation), ) if "query_dependencies" in observation.allowed_actions: return UnifiedIncidentAction( action_type="query_dependencies", service=choose_investigation_service(observation), ) return UnifiedIncidentAction( action_type="query_metrics", service=choose_investigation_service(observation), metric="cpu", ) if stage == "security_subquest": if not security.code_visible: return UnifiedIncidentAction(action_type="inspect_code") if security.selected_vulnerability is None: return UnifiedIncidentAction( action_type="classify_vulnerability", vulnerability_type=infer_vulnerability(observation, history), ) if security.selected_patch is None: return UnifiedIncidentAction( action_type="apply_patch", patch_id=choose_patch_id(observation, history), ) if security.exploit_blocked is not True or security.functionality_preserved is not True: return UnifiedIncidentAction(action_type="verify_security_fix") return UnifiedIncidentAction(action_type="submit_security_fix") if stage in {"remediation", "verification"}: if hard["force_worker_rollback"]: return UnifiedIncidentAction(action_type="rollback_deploy", service="worker") worker = observation.service_health.get("worker") if ( "rollback_deploy" in 
observation.allowed_actions and worker is not None and worker.status != "healthy" ): return UnifiedIncidentAction(action_type="rollback_deploy", service="worker") return UnifiedIncidentAction( action_type="restart_service", service=choose_recovery_service(observation), ) return UnifiedIncidentAction( action_type="submit_postmortem", postmortem=build_postmortem(observation, history), ) def build_compact_policy_card( observation: UnifiedIncidentObservation, state: PolicyCardState, history: list[dict[str, Any]] | None = None, *, scenario_id: str | None = None, ) -> str: """Brutally small policy card for weak backends.""" if history is None: history = [] stage_allowed_actions = _narrow_allowed_actions( observation, scenario_id=scenario_id, history=history, ) lines = [ f"STAGE: {observation.workflow_stage}", f"GOAL: {STAGE_GOALS.get(observation.workflow_stage, 'Pick one valid action.')}", f"HINT: {_stage_hint(observation, scenario_id=scenario_id, history=history)}", f"ALLOWED: {', '.join(stage_allowed_actions)}", ] stop_hint = _stop_investigating_hint( observation, scenario_id=scenario_id, history=history, ) if stop_hint: lines.append(f"STOP_INVESTIGATING: {stop_hint}") if observation.loop_warning: lines.append("LESSON: Do not repeat the same no-progress action.") elif state.failure_notes: lines.append(f"LESSON: {state.failure_notes[-1].correction}") example = observation.valid_action_example or {"action_type": stage_allowed_actions[0]} lines.append(f"EXAMPLE: {json.dumps(example, separators=(',', ':'))}") if "apply_patch" in stage_allowed_actions: lines.append(f"PATCH_IDS: {', '.join(_allowed_patch_ids(observation))}") lines.append("Return exactly one JSON object.") return _limit_words("\n".join(lines), max_words=POLICY_CARD_WORD_BUDGET_COMPACT) def build_policy_card( observation: UnifiedIncidentObservation, state: PolicyCardState, history: list[dict[str, Any]] | None = None, *, scenario_id: str | None = None, ) -> str: """Always use compact mode for small-model 
inference.""" return build_compact_policy_card( observation, state, history or [], scenario_id=scenario_id, ) def update_policy_card( state: PolicyCardState, *, before: UnifiedIncidentObservation, action: UnifiedIncidentAction, after: UnifiedIncidentObservation, model_error: str | None, ) -> None: if model_error == "invalid_model_output": state.schema_notes.append( PolicyNote( stage=before.workflow_stage, failure_type="invalid_model_output", mistake="The previous response was not one valid JSON action object.", correction="Return exactly one valid JSON action using only allowed actions.", valid_example=before.valid_action_example or {"action_type": before.allowed_actions[0]}, action_family=_action_family((before.valid_action_example or {}).get("action_type")), ) ) state.schema_notes = state.schema_notes[-4:] if after.failure_type and after.why_failed: example = after.valid_action_example or before.valid_action_example or {"action_type": before.allowed_actions[0]} family = after.best_recovery_action_family or _action_family(example.get("action_type")) correction = ( f"If this happens again, prefer {family} actions." if family else "Follow the current stage example and allowed actions." 
) state.failure_notes.append( PolicyNote( stage=before.workflow_stage, failure_type=after.failure_type, mistake=after.why_failed, correction=correction, valid_example=example, action_family=family, ) ) state.failure_notes = state.failure_notes[-4:] if after.reward > 0 and after.failure_type is None: state.recovery_notes.append( PolicyNote( stage=before.workflow_stage, failure_type="successful_step", mistake="A weaker choice would likely have lost progress.", correction=f"This stage can progress with {_timeline_entry(action)}.", valid_example=action.model_dump(exclude_none=True, exclude={"metadata"}), action_family=_action_family(action.action_type), ) ) state.recovery_notes = state.recovery_notes[-4:] def _build_required_fields_block( required_fields_by_action: dict[str, list[str]], allowed_actions: list[str], ) -> str: lines = [] for action in allowed_actions: fields = required_fields_by_action.get(action, []) if fields: lines.append(f"- {action} -> {', '.join(fields)}") else: lines.append(f"- {action} -> none") return "\n".join(lines) or "- none" def _build_patch_ids_block(patch_ids: list[str]) -> str: if not patch_ids: return "" lines = ["Available patch IDs:"] lines.extend(f"- {patch_id}" for patch_id in patch_ids) lines.append("") return "\n".join(lines) def _build_transition_block(transition_hint: str | None) -> str: if not transition_hint: return "" return f"Important transition hint:\n- {transition_hint}\n\n" def _build_negative_reward_block(correction_hint: str | None) -> str: if not correction_hint: return "" return f"Previous action correction:\n- {correction_hint}\n\n" def _build_loop_warning_block(loop_warning: str | None) -> str: if not loop_warning: return "" return f"Loop warning:\n- {loop_warning}\n\n" def _bool_text(value: bool | None) -> str: if value is None: return "unknown" return str(value).lower() def _render_tool_output(observation: UnifiedIncidentObservation) -> str: if not observation.tool_output: return "" if observation.workflow_stage 
in {"security_subquest", "verification"}: lines = [line.rstrip() for line in observation.tool_output.splitlines() if line.strip()] return "\n".join(lines[:6]) return observation.tool_output.splitlines()[0] def _build_state_block(observation: UnifiedIncidentObservation) -> str: lines: list[str] = [] if observation.active_alerts: lines.append("Active alerts:") for alert in observation.active_alerts[:3]: lines.append(f"- {alert.service}: {alert.severity} - {alert.message}") lines.append(f"Final score: {observation.final_score:.4f}") if observation.last_action_result: lines.append(f"Last action result: {observation.last_action_result}") if observation.tool_output: rendered_tool_output = _render_tool_output(observation) if "\n" in rendered_tool_output: lines.append("Tool output:") lines.extend(rendered_tool_output.splitlines()) else: lines.append(f"Tool output: {rendered_tool_output}") security = observation.security_context if observation.workflow_stage in {"security_subquest", "verification"}: lines.append( "Security status: " f"code visible = {str(security.code_visible).lower()}, " f"vulnerability classified = {str(security.selected_vulnerability is not None).lower()}, " f"patch applied = {str(security.selected_patch is not None).lower()}, " f"exploit blocked = {_bool_text(security.exploit_blocked)}, " f"functionality preserved = {_bool_text(security.functionality_preserved)}" ) if observation.security_unlock_reason: lines.append(f"Security unlock reason: {observation.security_unlock_reason}") if observation.blocked_until_security_complete: lines.append("Recovery gate: security must be completed before recovery.") return "\n".join(lines) or "- none" def _extract_policy_hint(policy_card: str) -> str | None: for prefix in ("LESSON:", "STOP_INVESTIGATING:"): for line in policy_card.splitlines(): if line.startswith(prefix): return line.split(":", 1)[1].strip() return None def _user_prompt_example( observation: UnifiedIncidentObservation, allowed_actions: list[str], *, 
scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: example = observation.valid_action_example or {} if example.get("action_type") in allowed_actions: return example fallback = build_fallback_action( observation, history or [], scenario_id=scenario_id, ) return fallback.model_dump(exclude_none=True, exclude={"metadata"}) def build_user_prompt( observation: UnifiedIncidentObservation, policy_card: str, *, scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> str: stage_allowed_actions = _narrow_allowed_actions( observation, scenario_id=scenario_id, history=history or [], ) required_fields = observation.required_fields_by_action or { action: [] for action in stage_allowed_actions } transition_hint = _stop_investigating_hint( observation, scenario_id=scenario_id, history=history or [], ) or _stage_hint( observation, scenario_id=scenario_id, history=history or [], ) correction_hint = None if observation.failure_type and observation.why_failed: correction_hint = observation.why_failed elif policy_card: correction_hint = _extract_policy_hint(policy_card) valid_example = _user_prompt_example( observation, stage_allowed_actions, scenario_id=scenario_id, history=history, ) return USER_PROMPT_TEMPLATE.format( stage=observation.workflow_stage, goal=STAGE_GOALS.get(observation.workflow_stage, "take the best next action"), allowed_actions_block="\n".join(f"- {action}" for action in stage_allowed_actions) or "- none", required_fields_block=_build_required_fields_block(required_fields, stage_allowed_actions), patch_ids_block=_build_patch_ids_block( _allowed_patch_ids(observation) if "apply_patch" in stage_allowed_actions else [] ), transition_block=_build_transition_block(transition_hint), negative_reward_block=_build_negative_reward_block(correction_hint), loop_warning_block=_build_loop_warning_block(observation.loop_warning), state_block=_build_state_block(observation), 
valid_example=json.dumps(valid_example, separators=(",", ":")), ) def _build_tool_schema( observation: UnifiedIncidentObservation, *, scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: allowed_actions = _narrow_allowed_actions( observation, scenario_id=scenario_id, history=history or [], ) properties: dict[str, Any] = { "action_type": {"type": "string", "enum": allowed_actions}, } if any(action in allowed_actions for action in {"query_logs", "query_metrics", "query_dependencies", "restart_service", "rollback_deploy"}): properties["service"] = { "type": "string", "enum": sorted(observation.service_health.keys()), } if "query_metrics" in allowed_actions: properties["metric"] = { "type": "string", "enum": ["cpu", "memory", "latency", "error_rate", "throughput"], } if "classify_vulnerability" in allowed_actions: properties["vulnerability_type"] = { "type": "string", "enum": ["sql_injection", "broken_access_control", "command_injection"], } if "apply_patch" in allowed_actions: properties["patch_id"] = { "type": "string", "enum": _allowed_patch_ids(observation), } if "submit_postmortem" in allowed_actions: properties["postmortem"] = {"type": "object"} required = ["action_type"] example = observation.valid_action_example or {} for field in ("service", "metric", "vulnerability_type", "patch_id", "postmortem"): if field in properties and field in example: required.append(field) return { "type": "object", "properties": properties, "required": required, "additionalProperties": False, } def _extract_completion_text(completion) -> str: message = completion.choices[0].message tool_calls = getattr(message, "tool_calls", None) or [] if tool_calls: function = getattr(tool_calls[0], "function", None) if function is not None and getattr(function, "arguments", None): return function.arguments return (message.content or "").strip() def _request_action_completion( client: OpenAI, observation: UnifiedIncidentObservation, user_prompt: str, *, 
temperature: float, scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> str: import time max_retries = 3 last_exc = None schema = _build_tool_schema( observation, scenario_id=scenario_id, history=history or [], ) for attempt in range(max_retries): try: create_kwargs = { "model": MODEL_NAME, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], "temperature": temperature, "max_tokens": MAX_TOKENS, "stream": False, } if _is_local_ollama(): create_kwargs["extra_body"] = {"format": schema} completion = client.chat.completions.create(**create_kwargs) return _extract_completion_text(completion) try: # Try tool calling first completion = client.chat.completions.create( **create_kwargs, tools=[ { "type": "function", "function": { "name": "emit_action", "description": "Emit exactly one environment action.", "parameters": schema, }, } ], tool_choice={"type": "function", "function": {"name": "emit_action"}}, ) return _extract_completion_text(completion) except Exception: # Fallback to JSON mode completion = client.chat.completions.create( **create_kwargs, response_format={ "type": "json_schema", "json_schema": { "name": "unified_incident_action", "strict": True, "schema": schema, }, }, ) return _extract_completion_text(completion) except Exception as e: last_exc = e if attempt < max_retries - 1: time.sleep(2.0 * (attempt + 1)) continue raise last_exc return "" def attempt_repair( client: OpenAI, observation: UnifiedIncidentObservation, raw_output: str, *, scenario_id: str | None = None, history: list[dict[str, Any]] | None = None, ) -> UnifiedIncidentAction | None: example = observation.valid_action_example or { "action_type": _narrow_allowed_actions( observation, scenario_id=scenario_id, history=history or [], )[0] } repair_prompt = ( "Your previous response was invalid.\n" "Return exactly one valid JSON object.\n" "No explanation.\n" f"Example: {json.dumps(example, separators=(',', ':'))}\n" 
f"Previous response: {raw_output}" ) try: repaired = _request_action_completion( client, observation, repair_prompt, temperature=0.0, scenario_id=scenario_id, history=history or [], ) except Exception: return None return parse_action( repaired, observation, scenario_id=scenario_id, history=history or [], ) def get_model_action( client: OpenAI | None, observation: UnifiedIncidentObservation, history: list[dict[str, Any]], policy_state: PolicyCardState, *, scenario_id: str | None = None, ) -> tuple[UnifiedIncidentAction, str | None, bool, bool]: fallback = build_fallback_action(observation, history, scenario_id=scenario_id) mode = _inference_mode() if client is None: return fallback, "model_unavailable", False, True try: policy_card = ( build_policy_card( observation, policy_state, history, scenario_id=scenario_id, ) if mode == "small" else "" ) raw = _request_action_completion( client, observation, build_user_prompt( observation, policy_card, scenario_id=scenario_id, history=history, ), temperature=0.0, scenario_id=scenario_id, history=history, ) except Exception: return fallback, "model_request_failed", False, True parsed = parse_action( raw, observation, scenario_id=scenario_id, history=history, ) if parsed is None: repaired = attempt_repair( client, observation, raw, scenario_id=scenario_id, history=history, ) if repaired is not None: return repaired, "repair_retry_used", True, False return fallback, "invalid_model_output", True, True return parsed, None, False, False def run_scenario(client: OpenAI | None, scenario_id: str) -> dict[str, Any]: import time started = time.perf_counter() with UnifiedIncidentEnv(base_url=ENV_BASE_URL).sync() as env: observation = env.reset(scenario_id=scenario_id).observation history: list[dict[str, Any]] = [] rewards: list[float] = [] policy_state = PolicyCardState() repair_retry_count = 0 fallback_count = 0 log_start(task=scenario_id, env=ENV_NAME, model=MODEL_NAME) step = 0 while not observation.done: before = observation action, 
error, used_repair_retry, used_fallback = get_model_action( client, observation, history, policy_state, scenario_id=scenario_id, ) if used_repair_retry: repair_retry_count += 1 if used_fallback: fallback_count += 1 result = env.step(action) observation = result.observation reward = result.reward or 0.0 step += 1 rewards.append(reward) history.append( { "action": action, "reward": reward, "result": observation.last_action_result, "error": error, } ) if _inference_mode() == "small": update_policy_card( policy_state, before=before, action=action, after=observation, model_error=error, ) log_step( step=step, action=action_to_log_string(action), reward=reward, done=bool(result.done), error=error, ) success = bool( observation.done and observation.incident_resolved and observation.security_subquest_status == "completed" ) log_end( success=success, steps=step, rewards=rewards, ) return { "scenario_id": scenario_id, "score": observation.final_score, "success": success, "steps": step, "repair_retry_triggered": repair_retry_count > 0, "repair_retry_count": repair_retry_count, "fallback_triggered": fallback_count > 0, "fallback_count": fallback_count, "elapsed_s": round(time.perf_counter() - started, 4), } def main() -> None: client = create_client() for scenario_id in SCENARIOS: run_scenario(client, scenario_id) def _limit_words(text: str, *, max_words: int) -> str: words = text.split() if len(words) <= max_words: return text return " ".join(words[:max_words]).strip() + " ..." 
def _narrow_allowed_actions(
    observation: UnifiedIncidentObservation,
    *,
    scenario_id: str | None = None,
    history: list[dict[str, Any]] | None = None,
) -> list[str]:
    """Shrink the allowed-action list to what should be offered next.

    Starts from ``observation.allowed_actions`` (or every known action, sorted,
    when that list is empty) and applies, in order:

    1. hard overrides from ``_hard_transition_state`` (forced worker rollback,
       then the ``query_dependencies`` unlock bridge);
    2. the "security only" restriction, when it leaves at least one action;
    3. inside the ``security_subquest`` / ``verification`` stages, the single
       next security step implied by ``observation.security_context``.

    A narrowing is only applied when its target action is actually allowed;
    otherwise the (possibly security-restricted) list is returned as is.
    """
    candidates = observation.allowed_actions or sorted(KNOWN_ACTIONS)
    transition = _hard_transition_state(
        scenario_id=scenario_id,
        observation=observation,
        history=history or [],
    )

    # Hard overrides win over everything else, in this priority order.
    for flag, forced_action in (
        ("force_worker_rollback", "rollback_deploy"),
        ("needs_unlock_bridge", "query_dependencies"),
    ):
        if transition[flag] and forced_action in candidates:
            return [forced_action]

    if transition["security_only"]:
        security_steps = {
            "inspect_code",
            "classify_vulnerability",
            "apply_patch",
            "verify_security_fix",
            "submit_security_fix",
        }
        restricted = [name for name in candidates if name in security_steps]
        # Keep the unrestricted list if no security action is on offer.
        if restricted:
            candidates = restricted

    in_security_flow = observation.workflow_stage in {"security_subquest", "verification"}
    if not in_security_flow:
        return candidates

    ctx = observation.security_context

    def offered(name: str) -> bool:
        return name in candidates

    # Walk the security pipeline: inspect -> classify -> patch -> verify -> submit.
    if not ctx.code_visible:
        if offered("inspect_code"):
            return ["inspect_code"]
    elif ctx.selected_vulnerability is None and offered("classify_vulnerability"):
        return ["classify_vulnerability"]

    if (
        ctx.selected_vulnerability is not None
        and ctx.selected_patch is None
        and offered("apply_patch")
    ):
        return ["apply_patch"]

    fix_verified = ctx.exploit_blocked is True and ctx.functionality_preserved is True
    if ctx.selected_patch is not None and not fix_verified and offered("verify_security_fix"):
        return ["verify_security_fix"]
    if fix_verified and offered("submit_security_fix"):
        return ["submit_security_fix"]
    return candidates


def _hard_transition_state(
    *,
    scenario_id: str | None,
    observation: UnifiedIncidentObservation,
    history: list[dict[str, Any]],
) -> dict[str, Any]:
    """Compute hard steering flags for the worker bad-deploy scenario.

    For any scenario other than ``"worker_bad_deploy_command_injection"`` the
    all-clear baseline is returned. Otherwise the history of worker log
    queries and supporting queries, together with the security subquest
    status and worker health from ``observation``, selects at most one of:
    forced worker rollback, the dependency "unlock bridge", security-only
    mode, or a plain "stop investigating" hint.

    The returned dict always has the same eight keys.
    """
    baseline: dict[str, Any] = {
        "investigation_saturated": False,
        "stop_investigating": False,
        "stop_message": None,
        "next_required_action_family": None,
        "next_required_action": None,
        "needs_unlock_bridge": False,
        "security_only": False,
        "force_worker_rollback": False,
    }
    if scenario_id != "worker_bad_deploy_command_injection":
        return baseline

    def with_overrides(**overrides: Any) -> dict[str, Any]:
        return {**baseline, **overrides}

    # Single pass over history: count worker log queries and the supporting
    # metric / dependency queries.
    worker_log_queries = 0
    support_queries = 0
    for entry in history:
        taken = entry.get("action")
        if taken is None:
            continue
        kind = taken.action_type
        if kind == "query_logs":
            if taken.service == "worker":
                worker_log_queries += 1
        elif kind == "query_metrics":
            if taken.service in {"worker", "database"}:
                support_queries += 1
        elif kind == "query_dependencies":
            if taken.service == "api-gateway":
                support_queries += 1

    investigation_saturated = worker_log_queries >= 1 and (
        support_queries >= 1 or observation.workflow_stage != "diagnosis"
    )
    security_completed = observation.security_subquest_status == "completed"
    security_unlocked = observation.security_subquest_status != "locked"
    worker_health = observation.service_health.get("worker")
    worker_unhealthy = worker_health is not None and worker_health.status != "healthy"

    if security_completed and worker_unhealthy:
        return with_overrides(
            investigation_saturated=True,
            stop_investigating=True,
            stop_message="Investigation is complete. The bad worker deploy is still active. Choose rollback_deploy on worker next.",
            next_required_action_family="recovery",
            next_required_action="Next required action: rollback_deploy on worker.",
            force_worker_rollback=True,
        )

    if investigation_saturated:
        if not security_unlocked:
            return with_overrides(
                investigation_saturated=True,
                stop_investigating=True,
                stop_message="You already have enough evidence from worker investigation. Do not query worker logs again. Use query_dependencies on api-gateway to unlock the exploit path.",
                next_required_action_family="security",
                next_required_action="Next bridge action: query_dependencies on api-gateway, then move to security.",
                needs_unlock_bridge=True,
            )
        if not security_completed:
            return with_overrides(
                investigation_saturated=True,
                stop_investigating=True,
                stop_message="Repeated worker investigation is making no progress. Investigation is complete. Choose a security action now.",
                next_required_action_family="security",
                next_required_action="Current goal: inspect and patch the worker exploit path.",
                security_only=True,
            )

    # Not saturated, but the worker logs were already queried repeatedly.
    if worker_log_queries >= 2 and not security_completed:
        return with_overrides(
            stop_investigating=True,
            stop_message="Repeated worker investigation is making no progress. Choose a different allowed action. Investigation is complete.",
        )
    return baseline


if __name__ == "__main__":
    main()