import json
import math
import os
import re
from typing import Any, Iterator

import httpx

from server.constants import (
    DEFAULT_BASELINE_TASK_ENUMS,
    NO_COMMAND_PROVIDED_SENTINEL,
    TASK_MAX_STEPS,
    TaskName,
)
from server.models import Action, Observation, StepResult


def _env_int(name: str, default: int) -> int:
    """Return env var *name* parsed as an int, or *default* when unset or blank.

    Raises:
        ValueError: if the variable is set to a non-blank, non-integer value.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from exc


def _env_float(name: str, default: float) -> float:
    """Return env var *name* parsed as a float, or *default* when unset or blank.

    Raises:
        ValueError: if the variable is set to a non-blank, non-numeric value.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError as exc:
        raise ValueError(f"{name} must be a number, got {raw!r}") from exc


API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
BENCHMARK = "distributed-systems-debug-env"
# 0 (the default) means "no cap": each task then uses its own TASK_MAX_STEPS budget.
MAX_STEPS_CAP = _env_int("MAX_STEPS", 0)
TEMPERATURE = _env_float("TEMPERATURE", 0.2)
MAX_COMPLETION_TOKENS = _env_int("MAX_COMPLETION_TOKENS", 2048)

_JSON_DECODER = json.JSONDecoder()
# Last-resort matcher for flat (non-nested) JSON objects embedded in free text.
_FLAT_JSON_OBJECT_RE = re.compile(r"\{[^{}]*\}")


def _chat_token_limit_kwargs() -> dict[str, int]:
    """OpenAI `gpt-5.*` / some models require `max_completion_tokens`, not `max_tokens`.

    ``CHAT_TOKEN_LIMIT_PARAM`` forces either spelling. Otherwise
    ``max_completion_tokens`` is used for api.openai.com and ``max_tokens``
    for every other endpoint.
    """
    override = os.getenv("CHAT_TOKEN_LIMIT_PARAM", "").strip().lower()
    if override == "max_tokens":
        return {"max_tokens": MAX_COMPLETION_TOKENS}
    if override == "max_completion_tokens":
        return {"max_completion_tokens": MAX_COMPLETION_TOKENS}
    base = API_BASE_URL or ""
    if "api.openai.com" in base:
        return {"max_completion_tokens": MAX_COMPLETION_TOKENS}
    return {"max_tokens": MAX_COMPLETION_TOKENS}


SYSTEM_PROMPT = (
    "You have bash access to a distributed job processing pipeline that is "
    "experiencing a failure. Use bash commands to investigate system behavior "
    "and narrow down likely fault conditions. Standard Unix tools are available: "
    "ps, ls, cat, grep, tail, curl, jq, redis-cli, kill, sed. Work iteratively "
    "across multiple steps; each response must provide the next bash command "
    "only. Respond with compact JSON where `command` is required: "
    '{"command":"","reasoning":"optional concise reason"}. No markdown. '
    "No explanation outside JSON."
)

# Human-readable symptom hints injected into every prompt for the given task.
TASK_SYMPTOMS: dict[TaskName, tuple[str, ...]] = {
    TaskName.CASCADING_TIMEOUT: (
        "Requests intermittently fail even when services appear up.",
        "Latency spikes sharply during traffic bursts.",
    ),
    TaskName.BYZANTINE_QUEUE_FAULT: (
        "Worker throughput degrades after specific jobs enter the queue.",
        "Queue backlog grows despite workers being alive.",
    ),
    TaskName.DISTRIBUTED_LOCK_STARVATION: (
        "One or more workers appear blocked for extended periods.",
        "Work completion remains low without full service outage.",
    ),
    TaskName.BACKPRESSURE_CASCADE: (
        "Queue depth trends upward over time under steady load.",
    ),
    TaskName.ROUTE_PARTITION: (
        "Gateway requests intermittently fail despite local process health.",
        "Signals point to a connectivity path issue rather than a full service outage.",
    ),
    TaskName.REGISTRY_CORRUPTION: (
        "Gateway requests fail even though the gateway process is still healthy.",
        "Logs and config inspection suggest a bad upstream registry entry.",
    ),
    TaskName.JOB_GENERATOR_RUNAWAY: (
        "Queue backlog grows while the worker stays alive.",
        "Producer pressure appears higher than the system can sustainably drain.",
    ),
}


class DistributedDebugEnvClient:
    """Synchronous HTTP client for the environment's ``/reset`` and ``/step`` endpoints.

    Usable as a context manager; leaving the ``with`` block closes the
    underlying ``httpx.Client``.
    """

    def __init__(self, base_url: str) -> None:
        self._client = httpx.Client(base_url=base_url, timeout=45.0)

    def __enter__(self) -> "DistributedDebugEnvClient":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        self._client.close()

    def reset(self, task_name: str) -> Observation:
        """Start a new episode for *task_name* and return its initial observation.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        response = self._client.post("/reset", params={"task_name": task_name})
        response.raise_for_status()
        return Observation.model_validate(response.json())

    def step(self, action: Action) -> StepResult:
        """Submit *action* to the environment and return the resulting step.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        response = self._client.post("/step", json=action.model_dump())
        response.raise_for_status()
        return StepResult.model_validate(response.json())


def _parse_tasks() -> list[TaskName]:
    """Return the tasks to run from ``TASKS_CSV``, or the baseline defaults.

    Blank entries are skipped. If nothing remains (unset, blank, or only
    separators), the defaults are used instead of running zero tasks.
    """
    csv = os.getenv("TASKS_CSV", "").strip()
    tasks = [
        TaskName.parse(value.strip()) for value in csv.split(",") if value.strip()
    ]
    return tasks or list(DEFAULT_BASELINE_TASK_ENUMS)


def _bool(value: bool) -> str:
    """Render a bool as the lowercase ``true``/``false`` used in log lines."""
    return "true" if value else "false"


def _single_line(text: str) -> str:
    """Collapse tabs and line breaks into spaces and strip the ends."""
    return " ".join(text.replace("\t", " ").splitlines()).strip()


def _command_from_dict(payload: dict[str, Any]) -> tuple[str | None, str | None]:
    """Extract ``(command, reasoning)`` from a decoded action object.

    Returns ``(None, None)`` unless ``command`` is a non-blank string. Empty
    or non-string reasoning becomes ``None``.
    """
    command_value = payload.get("command")
    command = command_value.strip() if isinstance(command_value, str) else ""
    if not command:
        return None, None
    reasoning_value = payload.get("reasoning")
    reasoning = reasoning_value.strip() if isinstance(reasoning_value, str) else ""
    return command, (reasoning or None)


def _parse_action_payload(text: str) -> tuple[str | None, str | None]:
    """Parse *text* as a whole JSON object and extract its action fields."""
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        return None, None
    if not isinstance(payload, dict):
        return None, None
    return _command_from_dict(payload)


def _iter_decoded_json_objects(text: str) -> Iterator[Any]:
    """Yield each JSON object that decodes starting at a ``{`` in *text*.

    After a successful decode, scanning resumes past the decoded value, so
    objects nested inside an already-yielded object are not yielded again.
    """
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = _JSON_DECODER.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
            continue
        yield obj
        pos = text.find("{", end)


def _assistant_message_text(message: Any) -> str:
    """Flatten a chat-completion message's ``content`` into plain text.

    Handles ``None`` (returns ""), plain strings, and lists of parts given
    either as dicts with a ``"text"`` key or as objects with a ``.text``
    attribute. Parts without text are skipped.
    """
    content = getattr(message, "content", None)
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for part in content:
            if isinstance(part, dict):
                text_val = part.get("text")
            else:
                text_val = getattr(part, "text", None)
            if text_val is not None:
                parts.append(str(text_val))
        return "\n".join(parts)
    return str(content)


def _strip_code_fence(text: str) -> str:
    """Remove a leading ```` ```lang ```` line and a trailing ```` ``` ```` marker.

    Only the fence markers are removed. Content sharing a line with the
    closing fence (e.g. ``}```​``) is kept. Text that does not start with a
    fence is returned unchanged.
    """
    if not text.startswith("```"):
        return text
    _, _, body = text.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[: -len("```")]
    return body.strip()


def _extract_from_text(text: str) -> tuple[str | None, str | None]:
    """Try each parse strategy on *text*, strictest first."""
    command, reasoning = _parse_action_payload(text)
    if command:
        return command, reasoning
    # Any decodable object embedded in surrounding prose.
    for obj in _iter_decoded_json_objects(text):
        if isinstance(obj, dict):
            command, reasoning = _command_from_dict(obj)
            if command:
                return command, reasoning
    # Flat objects, including ones nested inside an outer object that lacked a command.
    for match in _FLAT_JSON_OBJECT_RE.finditer(text):
        command, reasoning = _parse_action_payload(match.group(0))
        if command:
            return command, reasoning
    first_line = text.split("\n")[0].strip()
    return _parse_action_payload(first_line)


def extract_action_payload(llm_response: str) -> tuple[str | None, str | None]:
    """Extract ``(command, reasoning)`` from a raw LLM reply.

    A markdown code fence, if present, is stripped and the inner text is
    tried first. The full reply is then tried as a fallback, so a
    malformed fence cannot hide a valid JSON object. Returns
    ``(None, None)`` when no object with a non-blank ``command`` is found.
    """
    response = llm_response.strip()
    if not response:
        return None, None
    candidates = [response]
    unfenced = _strip_code_fence(response)
    if unfenced and unfenced != response:
        candidates.insert(0, unfenced)
    for candidate in candidates:
        command, reasoning = _extract_from_text(candidate)
        if command:
            return command, reasoning
    return None, None


def extract_command(llm_response: str) -> str | None:
    """Return only the command from *llm_response* (see `extract_action_payload`)."""
    return extract_action_payload(llm_response)[0]


def extract_reasoning(llm_response: str) -> str | None:
    """Return only the reasoning from *llm_response* (see `extract_action_payload`)."""
    return extract_action_payload(llm_response)[1]


def _sanitize_reasoning_for_step(reasoning: str) -> str:
    """Make reasoning safe to embed in a ``[STEP]`` line, truncated to 160 chars.

    `` reward=``, `` done=`` and `` error=`` are rewritten with ``:`` so the
    free text cannot mimic the line's own ``key=value`` fields.
    """
    sanitized = _single_line(reasoning)
    sanitized = sanitized.replace(" reward=", " reward:")
    sanitized = sanitized.replace(" done=", " done:")
    sanitized = sanitized.replace(" error=", " error:")
    return sanitized[:160]


def _format_step_action(command: str, reasoning: str | None) -> str:
    """Render the ``action=`` value: the command, plus ``| reasoning=...`` if any."""
    action = _single_line(command)
    if not reasoning:
        return action
    sanitized_reasoning = _sanitize_reasoning_for_step(reasoning)
    if not sanitized_reasoning:
        return action
    return f"{action} | reasoning={sanitized_reasoning}"


def _episode_score(rewards: list[float]) -> float:
    """Return the episode score: the last reward clamped to ``[0.01, 0.99]``.

    An empty episode, or a NaN final reward, scores 0.01.
    """
    # Score is terminal task progress signal and must stay normalized for evaluator checks.
    if not rewards:
        return 0.01
    final = float(rewards[-1])
    if math.isnan(final):
        # min(0.99, nan) returns 0.99, so without this guard a NaN reward
        # would be reported as a near-perfect score.
        return 0.01
    return max(0.01, min(0.99, final))


def _format_end_line(
    *, success: bool, steps: int, score: float, rewards: list[float]
) -> str:
    """Build the ``[END]`` summary line with 2-decimal score and rewards."""
    rewards_csv = ",".join(f"{reward:.2f}" for reward in rewards)
    return (
        f"[END] success={_bool(success)} steps={steps} "
        f"score={score:.2f} rewards={rewards_csv}"
    )


def _task_symptom_block(task_name: TaskName) -> str:
    """Render the task's symptoms as a bullet list, or ``- none`` if it has none."""
    symptoms = TASK_SYMPTOMS.get(task_name, ())
    if not symptoms:
        return "- none"
    return "\n".join(f"- {symptom}" for symptom in symptoms)


def _attempt_history_block(attempt_history: list[dict[str, Any]]) -> str:
    """Render prior attempts as one truncated, single-line bullet per step.

    Each bullet includes the command and error, and adds reasoning and an
    output preview when they are non-empty.
    """
    if not attempt_history:
        return "- none"
    lines: list[str] = []
    for attempt in attempt_history:
        command = _single_line(str(attempt["command"]))[:120]
        reasoning = _single_line(str(attempt.get("reasoning") or ""))[:120]
        output_preview = _single_line(str(attempt.get("output") or ""))[:140]
        error = attempt.get("error")
        error_text = _single_line(str(error))[:80] if error else "none"
        line = f"- step {attempt['step']}: command={command}; error={error_text}"
        if reasoning:
            line = f"{line}; reasoning={reasoning}"
        if output_preview:
            line = f"{line}; output={output_preview}"
        lines.append(line)
    return "\n".join(lines)


def build_prompt(
    obs: Observation,
    step_num: int,
    task_name: TaskName,
    attempt_history: list[dict[str, Any]],
) -> str:
    """Build the per-step user prompt.

    The prompt contains the symptoms, previous attempts, metrics, service
    status, and the first 2000 characters of the latest command output.
    """
    return (
        f"Step {step_num}. Current system state:\n\n"
        "TASK SYMPTOMS:\n"
        f"{_task_symptom_block(task_name)}\n\n"
        "PREVIOUS ATTEMPTS:\n"
        f"{_attempt_history_block(attempt_history)}\n\n"
        "METRICS:\n"
        f"- Gateway success rate: {obs.metrics.gateway_success_rate:.1%}\n"
        f"- Gateway P99 latency: {obs.metrics.gateway_p99_latency_ms:.0f}ms\n"
        f"- Queue depth: {obs.metrics.queue_depth}\n"
        f"- Worker restarts: {obs.metrics.worker_restart_count}\n"
        f"- Consumer stall count: {obs.metrics.consumer_stall_count}\n\n"
        "SERVICE STATUS:\n"
        f"{obs.process_status}\n\n"
        "LATEST COMMAND OUTPUT:\n"
        f"{obs.command_output[:2000]}\n\n"
        "Solve this over multiple steps as needed. "
        "For this step, return only the single next bash command.\n"
        'Respond with compact JSON where command is required: {"command":"","reasoning":"optional concise reason"}.'
    )


def _run_episode(
    client: Any, env: DistributedDebugEnvClient, task_name: TaskName
) -> None:
    """Run one task episode and print its structured log lines.

    Prints ``[START]`` once and one ``[STEP]`` per environment step. If an
    exception occurs it prints ``[ERROR]`` without re-raising, so the
    caller can move on to the next task. A final ``[END]`` line is always
    printed.
    """
    messages: list[dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]
    rewards: list[float] = []
    done = False
    step = 0
    attempt_history: list[dict[str, Any]] = []

    print(
        f"[START] task={task_name.value} env={BENCHMARK} model={MODEL_NAME}",
        flush=True,
    )

    task_budget = TASK_MAX_STEPS[task_name]
    max_steps = min(task_budget, MAX_STEPS_CAP) if MAX_STEPS_CAP > 0 else task_budget

    try:
        obs = env.reset(task_name=task_name.value)
        while not done and step < max_steps:
            next_step = step + 1
            user_prompt = build_prompt(obs, next_step, task_name, attempt_history)
            messages.append({"role": "user", "content": user_prompt})

            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                **_chat_token_limit_kwargs(),
            )
            raw_response = _assistant_message_text(completion.choices[0].message)
            command, reasoning = extract_action_payload(raw_response)

            if not command:
                # Keep the unusable reply in context and ask for the required format.
                # The sentinel is still sent to the environment as this step's action.
                messages.append({"role": "assistant", "content": raw_response})
                messages.append(
                    {
                        "role": "user",
                        "content": (
                            "No command was provided. Respond with compact JSON where command is required: "
                            '{"command":"","reasoning":"optional concise reason"}.'
                        ),
                    }
                )
                command = NO_COMMAND_PROVIDED_SENTINEL
                reasoning = None
            else:
                # Record a normalized JSON reply rather than the raw text, which
                # may have contained fences or surrounding prose.
                assistant_payload: dict[str, str] = {"command": command}
                if reasoning:
                    assistant_payload["reasoning"] = reasoning
                messages.append(
                    {"role": "assistant", "content": json.dumps(assistant_payload)}
                )

            result = env.step(Action(command=command))
            # The environment accepted the step. Count it now so `steps` in
            # [END] matches len(rewards) even if a later statement raises.
            step = next_step
            obs = result.observation
            rewards.append(result.reward)
            done = result.done

            error_value = result.info.get("error")
            # Empty strings and the literal "None" are treated the same as no error.
            last_error = None if error_value in (None, "", "None") else str(error_value)
            error_field = "null" if last_error is None else _single_line(last_error)

            attempt_history.append(
                {
                    "step": next_step,
                    "command": command,
                    "reasoning": reasoning,
                    "output": obs.command_output,
                    "error": last_error,
                }
            )

            print(
                f"[STEP] step={next_step} action={_format_step_action(command, reasoning)} "
                f"reward={result.reward:.2f} done={_bool(done)} error={error_field}",
                flush=True,
            )
    except Exception as exc:
        print(
            f"[ERROR] task={task_name.value} {type(exc).__name__}: {exc}",
            flush=True,
        )
    finally:
        score = _episode_score(rewards)
        success = bool(done and score >= 0.95)
        print(
            _format_end_line(success=success, steps=step, score=score, rewards=rewards),
            flush=True,
        )


def main() -> None:
    """Run every configured task sequentially against the environment.

    Both the LLM client and the environment client are closed on exit,
    including on error.

    Raises:
        RuntimeError: if neither ``HF_TOKEN`` nor ``API_KEY`` is set.
    """
    if not API_KEY:
        raise RuntimeError("HF_TOKEN (or API_KEY) must be set")
    tasks = _parse_tasks()

    # Imported here rather than at module top, as in the original; presumably
    # so the parsing helpers stay importable without the SDK installed.
    from openai import OpenAI

    client = OpenAI(
        api_key=API_KEY,
        base_url=API_BASE_URL,
        timeout=30.0,
        max_retries=2,
    )
    try:
        with DistributedDebugEnvClient(base_url=ENV_URL) as env:
            for task_name in tasks:
                _run_episode(client, env, task_name)
    finally:
        client.close()


if __name__ == "__main__":
    main()