# Veer15's picture
# Upload folder using huggingface_hub
# 3d75167 verified
import json
import os
import re
from typing import Any, Iterator
import httpx
from server.constants import (
DEFAULT_BASELINE_TASK_ENUMS,
NO_COMMAND_PROVIDED_SENTINEL,
TASK_MAX_STEPS,
TaskName,
)
from server.models import Action, Observation, StepResult
# Runtime configuration, read once from environment variables at import time.
# Base URL handed to the OpenAI-compatible chat client.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier sent with every chat completion request.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# API credential: HF_TOKEN wins over API_KEY; None/empty when neither is set.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# Base URL of the environment server that is driven via /reset and /step.
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
# Benchmark label printed in the [START] log line.
BENCHMARK = "distributed-systems-debug-env"
# Optional global cap on steps per episode; values <= 0 mean "use the per-task budget".
MAX_STEPS_CAP = int(os.getenv("MAX_STEPS", "0"))
# Sampling temperature passed to the chat completion call.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
# Completion length limit, sent as either max_tokens or max_completion_tokens.
MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "2048"))
# Shared decoder used to pull JSON objects out of free-form model output.
_JSON_DECODER = json.JSONDecoder()
def _chat_token_limit_kwargs() -> dict[str, int]:
    """Return the keyword argument that caps completion length for the API in use.

    OpenAI `gpt-5.*` / some models require `max_completion_tokens`, not
    `max_tokens`. The `CHAT_TOKEN_LIMIT_PARAM` environment variable forces one
    of the two names; otherwise `max_completion_tokens` is used when the base
    URL contains ``api.openai.com`` and `max_tokens` everywhere else.

    Returns:
        A single-entry dict mapping the chosen parameter name to
        ``MAX_COMPLETION_TOKENS``.
    """
    forced_param = os.getenv("CHAT_TOKEN_LIMIT_PARAM", "").strip().lower()
    if forced_param in ("max_tokens", "max_completion_tokens"):
        return {forced_param: MAX_COMPLETION_TOKENS}

    endpoint = API_BASE_URL or ""
    if "api.openai.com" in endpoint:
        param_name = "max_completion_tokens"
    else:
        param_name = "max_tokens"
    return {param_name: MAX_COMPLETION_TOKENS}
# System message placed first in every episode's chat history. It tells the
# model to answer each turn with one JSON object carrying the next bash command.
SYSTEM_PROMPT = """You have bash access to a distributed job processing pipeline that is experiencing a failure.
Use bash commands to investigate system behavior and narrow down likely fault conditions.
Standard Unix tools are available: ps, ls, cat, grep, tail, curl, jq, redis-cli, kill, sed.
Work iteratively across multiple steps; each response must provide the next bash command only.
Respond with compact JSON where `command` is required: {"command":"<bash command>","reasoning":"optional concise reason"}.
No markdown. No explanation outside JSON."""
# Per-task symptom descriptions rendered as a bullet list in every step prompt.
# The wording describes what is observed rather than naming the injected fault
# outright -- presumably so the model still has to investigate (TODO confirm).
TASK_SYMPTOMS: dict[TaskName, tuple[str, ...]] = {
    TaskName.CASCADING_TIMEOUT: (
        "Requests intermittently fail even when services appear up.",
        "Latency spikes sharply during traffic bursts.",
    ),
    TaskName.BYZANTINE_QUEUE_FAULT: (
        "Worker throughput degrades after specific jobs enter the queue.",
        "Queue backlog grows despite workers being alive.",
    ),
    TaskName.DISTRIBUTED_LOCK_STARVATION: (
        "One or more workers appear blocked for extended periods.",
        "Work completion remains low without full service outage.",
    ),
    TaskName.BACKPRESSURE_CASCADE: (
        "Queue depth trends upward over time under steady load.",
    ),
    TaskName.ROUTE_PARTITION: (
        "Gateway requests intermittently fail despite local process health.",
        "Signals point to a connectivity path issue rather than a full service outage.",
    ),
    TaskName.REGISTRY_CORRUPTION: (
        "Gateway requests fail even though the gateway process is still healthy.",
        "Logs and config inspection suggest a bad upstream registry entry.",
    ),
    TaskName.JOB_GENERATOR_RUNAWAY: (
        "Queue backlog grows while the worker stays alive.",
        "Producer pressure appears higher than the system can sustainably drain.",
    ),
}
class DistributedDebugEnvClient:
    """Small synchronous HTTP client for the debug environment server.

    Wraps the two endpoints used by the baseline: ``POST /reset`` and
    ``POST /step``. Non-2xx responses raise via ``raise_for_status()``.
    """

    def __init__(self, base_url: str) -> None:
        """Create the underlying ``httpx.Client`` bound to *base_url*."""
        self._client = httpx.Client(base_url=base_url, timeout=45.0)

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def reset(self, task_name: str) -> Observation:
        """Reset the environment for *task_name* and return the first observation.

        The task name is sent as a query parameter, not in the request body.
        """
        reply = self._client.post("/reset", params={"task_name": task_name})
        reply.raise_for_status()
        body = reply.json()
        return Observation.model_validate(body)

    def step(self, action: Action) -> StepResult:
        """Send *action* to the environment and return the validated step result."""
        request_body = action.model_dump()
        reply = self._client.post("/step", json=request_body)
        reply.raise_for_status()
        body = reply.json()
        return StepResult.model_validate(body)
def _parse_tasks() -> list[TaskName]:
    """Resolve which tasks to run from the ``TASKS_CSV`` environment variable.

    Returns:
        The default baseline task list when ``TASKS_CSV`` is unset or blank;
        otherwise one ``TaskName`` per non-empty comma-separated entry, in the
        order given. Whatever ``TaskName.parse`` raises for an unknown name
        propagates to the caller.
    """
    raw_csv = os.getenv("TASKS_CSV", "").strip()
    if not raw_csv:
        return list(DEFAULT_BASELINE_TASK_ENUMS)

    # Blank entries (e.g. from a trailing comma) are skipped.
    stripped_names = (piece.strip() for piece in raw_csv.split(","))
    return [TaskName.parse(name) for name in stripped_names if name]
def _bool(value: bool) -> str:
return "true" if value else "false"
def _single_line(text: str) -> str:
return " ".join(text.replace("\t", " ").splitlines()).strip()
def _command_from_dict(payload: dict[str, Any]) -> tuple[str | None, str | None]:
command_value = payload.get("command")
command = command_value.strip() if isinstance(command_value, str) else ""
if not command:
return None, None
reasoning_value = payload.get("reasoning")
reasoning = reasoning_value.strip() if isinstance(reasoning_value, str) else ""
return command, (reasoning or None)
def _parse_action_payload(text: str) -> tuple[str | None, str | None]:
try:
payload = json.loads(text)
except json.JSONDecodeError:
return None, None
if not isinstance(payload, dict):
return None, None
return _command_from_dict(payload)
def _iter_decoded_json_objects(text: str) -> Iterator[Any]:
i = 0
while i < len(text):
if text[i] != "{":
i += 1
continue
try:
obj, end = _JSON_DECODER.raw_decode(text, i)
except json.JSONDecodeError:
i += 1
continue
yield obj
i = end
def _assistant_message_text(message: Any) -> str:
content = getattr(message, "content", None)
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if isinstance(part, dict):
text_val = part.get("text")
if text_val is not None:
parts.append(str(text_val))
else:
text_attr = getattr(part, "text", None)
if text_attr is not None:
parts.append(str(text_attr))
return "\n".join(parts)
return str(content)
def extract_action_payload(llm_response: str) -> tuple[str | None, str | None]:
response = llm_response.strip()
if not response:
return None, None
if response.startswith("```"):
lines = response.split("\n")
if len(lines) > 2:
response = "\n".join(lines[1:-1]).strip()
direct_command, direct_reasoning = _parse_action_payload(response)
if direct_command:
return direct_command, direct_reasoning
for obj in _iter_decoded_json_objects(response):
if isinstance(obj, dict):
embedded_command, embedded_reasoning = _command_from_dict(obj)
if embedded_command:
return embedded_command, embedded_reasoning
for match in re.finditer(r"\{[^{}]*\}", response, flags=re.DOTALL):
embedded_command, embedded_reasoning = _parse_action_payload(match.group(0))
if embedded_command:
return embedded_command, embedded_reasoning
first_line = response.split("\n")[0].strip()
return _parse_action_payload(first_line)
def extract_command(llm_response: str) -> str | None:
    """Return only the command parsed from *llm_response*, or ``None``."""
    command, _ = extract_action_payload(llm_response)
    return command
def extract_reasoning(llm_response: str) -> str | None:
    """Return only the reasoning parsed from *llm_response*, or ``None``."""
    _, reasoning = extract_action_payload(llm_response)
    return reasoning
def _sanitize_reasoning_for_step(reasoning: str) -> str:
    """Make *reasoning* safe to embed in a ``[STEP]`` log line.

    The text is collapsed to one line, occurrences of `` reward=``, `` done=``
    and `` error=`` are rewritten with a colon, and the result is cut to 160
    characters. Those three markers are the field names of the ``[STEP]``
    line -- presumably rewritten so free-form reasoning cannot be mistaken
    for a real field by whatever parses the log (TODO confirm).
    """
    flattened = _single_line(reasoning)
    for marker, replacement in (
        (" reward=", " reward:"),
        (" done=", " done:"),
        (" error=", " error:"),
    ):
        flattened = flattened.replace(marker, replacement)
    return flattened[:160]
def _format_step_action(command: str, reasoning: str | None) -> str:
    """Build the ``action=`` value of a ``[STEP]`` log line.

    Returns the single-line command, followed by `` | reasoning=<text>`` when
    there is reasoning that is still non-empty after sanitizing.
    """
    action = _single_line(command)
    sanitized = _sanitize_reasoning_for_step(reasoning) if reasoning else ""
    if sanitized:
        return f"{action} | reasoning={sanitized}"
    return action
def _episode_score(rewards: list[float]) -> float:
# Score is terminal task progress signal and must stay normalized for evaluator checks.
if not rewards:
return 0.01
return max(0.01, min(0.99, float(rewards[-1])))
def _format_end_line(
    *, success: bool, steps: int, score: float, rewards: list[float]
) -> str:
    """Format the ``[END]`` summary line for an episode.

    Args:
        success: Whether the episode counts as solved.
        steps: Number of completed steps.
        score: Final episode score, printed with two decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted_rewards = [f"{reward:.2f}" for reward in rewards]
    rewards_csv = ",".join(formatted_rewards)
    success_text = _bool(success)
    return (
        f"[END] success={success_text} steps={steps} "
        f"score={score:.2f} rewards={rewards_csv}"
    )
def _task_symptom_block(task_name: TaskName) -> str:
    """Render the symptoms for *task_name* as a ``- ``-prefixed bullet list.

    Raises:
        KeyError: If *task_name* has no entry in ``TASK_SYMPTOMS``.
    """
    bullet_lines = [f"- {symptom}" for symptom in TASK_SYMPTOMS[task_name]]
    return "\n".join(bullet_lines)
def _attempt_history_block(attempt_history: list[dict[str, Any]]) -> str:
if not attempt_history:
return "- none"
lines: list[str] = []
for attempt in attempt_history:
command = _single_line(str(attempt["command"]))[:120]
reasoning = _single_line(str(attempt.get("reasoning") or ""))[:120]
output_preview = _single_line(str(attempt.get("output") or ""))[:140]
error = attempt.get("error")
error_text = _single_line(str(error))[:80] if error else "none"
line = f"- step {attempt['step']}: command={command}; error={error_text}"
if reasoning:
line = f"{line}; reasoning={reasoning}"
if output_preview:
line = f"{line}; output={output_preview}"
lines.append(line)
return "\n".join(lines)
def build_prompt(
    obs: Observation,
    step_num: int,
    task_name: TaskName,
    attempt_history: list[dict[str, Any]],
) -> str:
    """Assemble the user prompt for one step.

    The prompt has, in order: the step header, task symptoms, earlier
    attempts, current metrics, service status, the latest command output
    (first 2000 characters), and the response-format reminder.

    Args:
        obs: Latest observation from the environment.
        step_num: 1-based number of the step about to be taken.
        task_name: Task whose symptom list is shown.
        attempt_history: Records of earlier steps in this episode.
    """
    metrics = obs.metrics
    sections = [
        f"Step {step_num}. Current system state:\n\n",
        "TASK SYMPTOMS:\n",
        f"{_task_symptom_block(task_name)}\n\n",
        "PREVIOUS ATTEMPTS:\n",
        f"{_attempt_history_block(attempt_history)}\n\n",
        "METRICS:\n",
        f"- Gateway success rate: {metrics.gateway_success_rate:.1%}\n",
        f"- Gateway P99 latency: {metrics.gateway_p99_latency_ms:.0f}ms\n",
        f"- Queue depth: {metrics.queue_depth}\n",
        f"- Worker restarts: {metrics.worker_restart_count}\n",
        f"- Consumer stall count: {metrics.consumer_stall_count}\n\n",
        "SERVICE STATUS:\n",
        f"{obs.process_status}\n\n",
        "LATEST COMMAND OUTPUT:\n",
        # Output is truncated so one noisy command cannot dominate the prompt.
        f"{obs.command_output[:2000]}\n\n",
        "Solve this over multiple steps as needed. For this step, return only the single next bash command.\n",
        'Respond with compact JSON where command is required: {"command":"<bash command>","reasoning":"optional concise reason"}.',
    ]
    return "".join(sections)
def _run_episode(
    client: Any, env: DistributedDebugEnvClient, task_name: TaskName
) -> None:
    """Run one episode of *task_name* and print its log lines to stdout.

    Emits one ``[START]`` line, one ``[STEP]`` line per environment step, an
    ``[ERROR]`` line if anything raises, and always a final ``[END]`` line.
    The loop stops when the environment reports ``done`` or the step budget
    is used up. The budget is the per-task value from ``TASK_MAX_STEPS``,
    lowered to ``MAX_STEPS_CAP`` when that cap is positive and smaller.

    Args:
        client: Chat client exposing ``chat.completions.create``.
        env: Environment client used for ``reset`` and ``step``.
        task_name: Task to run.
    """
    messages: list[dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}]
    rewards: list[float] = []
    attempt_history: list[dict[str, Any]] = []
    done = False
    step = 0
    last_error: str | None = None

    print(
        f"[START] task={task_name.value} env={BENCHMARK} model={MODEL_NAME}", flush=True
    )

    task_budget = TASK_MAX_STEPS[task_name]
    if MAX_STEPS_CAP > 0:
        max_steps = min(task_budget, MAX_STEPS_CAP)
    else:
        max_steps = task_budget

    try:
        obs = env.reset(task_name=task_name.value)
        while not done and step < max_steps:
            next_step = step + 1
            user_prompt = build_prompt(obs, next_step, task_name, attempt_history)
            messages.append({"role": "user", "content": user_prompt})

            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                **_chat_token_limit_kwargs(),
            )
            raw_response = _assistant_message_text(completion.choices[0].message)
            command, reasoning = extract_action_payload(raw_response)

            if command:
                # Store the normalized action, not the raw text, in the history.
                assistant_payload: dict[str, str] = {"command": command}
                if reasoning:
                    assistant_payload["reasoning"] = reasoning
                messages.append(
                    {"role": "assistant", "content": json.dumps(assistant_payload)}
                )
            else:
                # Keep the unusable reply plus a corrective instruction in the
                # history, and still take a step using the sentinel command.
                messages.append({"role": "assistant", "content": raw_response})
                messages.append(
                    {
                        "role": "user",
                        "content": (
                            "No command was provided. Respond with compact JSON where command is required: "
                            '{"command":"<bash command>","reasoning":"optional concise reason"}.'
                        ),
                    }
                )
                command = NO_COMMAND_PROVIDED_SENTINEL
                reasoning = None

            result = env.step(Action(command=command))
            obs = result.observation
            rewards.append(result.reward)
            done = result.done

            # Missing, empty, and the literal string "None" all mean "no error".
            error_value = result.info.get("error")
            if error_value in (None, "", "None"):
                last_error = None
                error_field = "null"
            else:
                last_error = str(error_value)
                error_field = _single_line(last_error)

            attempt_history.append(
                {
                    "step": next_step,
                    "command": command,
                    "reasoning": reasoning,
                    "output": obs.command_output,
                    "error": last_error,
                }
            )
            print(
                f"[STEP] step={next_step} action={_format_step_action(command, reasoning)} "
                f"reward={result.reward:.2f} done={_bool(done)} error={error_field}",
                flush=True,
            )
            # Only count the step once it has been fully processed and logged.
            step = next_step
    except Exception as exc:
        # Episode boundary: report the failure and fall through to [END].
        last_error = str(exc)
        print(
            f"[ERROR] task={task_name.value} {type(exc).__name__}: {exc}",
            flush=True,
        )
    finally:
        score = _episode_score(rewards)
        # Success needs both a finished episode and a near-maximal score.
        success = bool(done and score >= 0.95)
        print(
            _format_end_line(success=success, steps=step, score=score, rewards=rewards),
            flush=True,
        )
def main() -> None:
    """Run every configured task once against the environment server.

    Raises:
        RuntimeError: If neither ``HF_TOKEN`` nor ``API_KEY`` is set.
    """
    if not API_KEY:
        raise RuntimeError("HF_TOKEN (or API_KEY) must be set")
    tasks = _parse_tasks()

    # Imported here rather than at module top, so the openai package is only
    # needed when main() actually runs.
    from openai import OpenAI

    llm_client = OpenAI(
        api_key=API_KEY,
        base_url=API_BASE_URL,
        timeout=30.0,
        max_retries=2,
    )
    env = DistributedDebugEnvClient(base_url=ENV_URL)
    try:
        for task_name in tasks:
            _run_episode(llm_client, env, task_name)
    finally:
        # Close the environment HTTP client even if an episode raises.
        env.close()
# Run the baseline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()