# Veer15 - Upload folder using huggingface_hub - 3d75167 (verified)
import json
import os
import re
import subprocess
import time
from pathlib import Path
from typing import Any

from .constants import (
    DEFAULT_CONFIGS,
    NO_COMMAND_PROVIDED_SENTINEL,
    TASK_MAX_STEPS,
    TaskName,
)
from .fault_injector import inject_fault
from .graders import grade_task
from .metrics_poller import MetricsPoller
from .models import Action, Observation, StepResult
from .process_manager import ProcessManager
class DistributedDebugEnv:
    """OpenEnv-compatible distributed systems debugging environment.

    ``reset`` restores the mesh configuration files to their defaults,
    restarts the services and injects the fault for the requested task.
    ``step`` then runs one shell command supplied by the agent, re-polls the
    metrics, asks the task grader for a score and turns that score plus a few
    shaping terms into a reward.
    """

    # Registry written on every reset. The "auth" entry is also the reference
    # that ``_is_registry_auth_default`` compares the on-disk registry against.
    _DEFAULT_REGISTRY: dict[str, Any] = {
        "services": {
            "auth": {"host": "localhost", "port": 3001, "protocol": "http"},
            "redis": {"host": "localhost", "port": 6379, "protocol": "tcp"},
            "worker": {
                "host": "localhost",
                "port": None,
                "protocol": "internal",
            },
        }
    }

    # Services whose /tmp/<service>.log file is emptied on reset.
    _LOG_SERVICES = ("gateway", "auth", "worker", "job_gen")

    # Substrings (lower-case, single-spaced) that cause a command to be refused.
    _DANGEROUS_PATTERNS = (
        "rm -rf /",
        "kill -9 1",
        "pkill -f uvicorn",
        "> /tmp/gateway.log",
        "> /tmp/auth.log",
        "> /tmp/worker.log",
    )

    # Whole-word tokens that mark a command as diagnostic.
    _DIAGNOSTIC_KEYWORDS = frozenset(
        {
            "cat",
            "curl",
            "redis-cli",
            "ps",
            "ls",
            "grep",
            "tail",
            "jq",
            "lrange",
            "llen",
            "keys",
            "ttl",
            "get",
        }
    )

    # Substrings (lower-case, single-spaced) that mark a state-changing command.
    _STATE_CHANGE_PATTERNS = (
        "kill -hup",
        "redis-cli del",
        "redis-cli lrem",
        "redis-cli set",
        "redis-cli flushdb",
        "echo '{",
        "> /mesh/",
        "tee /mesh/",
    )

    # Bare commands that are penalised in the reward.
    _TRIVIAL_COMMANDS = frozenset({"echo", "pwd", "whoami", "date", "true", "false"})

    # Splits a lower-cased command line into word-like tokens; "-" and "_" stay
    # inside a token so that "redis-cli" is one token and "apt-get" is not "get".
    _TOKEN_RE = re.compile(r"[a-z0-9_-]+")

    def __init__(
        self, project_root: Path | None = None, mesh_root: Path | None = None
    ) -> None:
        """Create the environment without starting any background work.

        Args:
            project_root: Project directory; defaults to the parent of the
                directory containing this file.
            mesh_root: Directory holding the mesh config files; defaults to
                the ``MESH_ROOT`` environment variable, else
                ``<project_root>/mesh``.
        """
        self.project_root = (
            project_root or Path(__file__).resolve().parent.parent
        ).resolve()
        # ``or`` (instead of a getenv default) makes an empty MESH_ROOT fall back
        # to the project default rather than resolving to the working directory.
        self.mesh_root = (
            mesh_root or Path(os.getenv("MESH_ROOT") or self.project_root / "mesh")
        ).resolve()
        self._process_manager = ProcessManager(
            project_root=self.project_root, mesh_root=self.mesh_root
        )
        self._metrics_poller = MetricsPoller(poll_interval_s=2.0)
        self.current_task: TaskName | None = None
        self.max_steps: int = 0
        self.step_count: int = 0
        self.last_exit_code: int = 0
        self.prev_observation: Observation | None = None
        self._baselines: dict[str, int] = {
            "baseline_worker_restart_count": 0,
            "baseline_consumer_stall_count": 0,
        }
        self._seen_diagnostic_signatures: set[str] = set()
        self._command_counts: dict[str, int] = {}
        self._last_grader_score: float = 0.0

    def start(self) -> None:
        """Start the metrics poller unless it is already alive."""
        if not self._metrics_poller.is_alive():
            self._metrics_poller.start()

    def close(self) -> None:
        """Stop the metrics poller."""
        self._metrics_poller.stop()

    def _write_json(self, path: Path, payload: dict[str, Any]) -> None:
        """Write *payload* to *path* as indented JSON, creating parent dirs."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")

    def _load_json_object(self, path: Path) -> dict[str, Any] | None:
        """Return the JSON object stored at *path*, or ``None`` if unusable.

        ``None`` covers an unreadable file, undecodable or invalid JSON, and
        valid JSON whose top level is not an object. The config files are
        editable by the agent, so every probe must tolerate all of these.
        """
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError, RecursionError):
            return None
        return payload if isinstance(payload, dict) else None

    def _restore_defaults(self) -> None:
        """Overwrite the registry and every service config with defaults."""
        self._write_json(self.mesh_root / "registry.json", self._DEFAULT_REGISTRY)
        self._write_json(
            self.mesh_root / "auth" / "config.json", DEFAULT_CONFIGS["auth"]
        )
        self._write_json(
            self.mesh_root / "gateway" / "config.json", DEFAULT_CONFIGS["gateway"]
        )
        self._write_json(
            self.mesh_root / "gateway" / "blocked_routes.json",
            DEFAULT_CONFIGS["blocked_routes"],
        )
        self._write_json(
            self.mesh_root / "worker" / "config.json", DEFAULT_CONFIGS["worker"]
        )
        self._write_json(
            self.mesh_root / "worker" / "job_generator_config.json",
            DEFAULT_CONFIGS["job_generator"],
        )

    def _truncate_logs(self) -> None:
        """Empty the per-service log files under /tmp."""
        for service in self._LOG_SERVICES:
            Path(f"/tmp/{service}.log").write_text("", encoding="utf-8")

    def _reset_runtime_counters(self) -> None:
        """Reset the restart and stall counter files under /tmp to "0"."""
        Path("/tmp/worker_restart_count").write_text("0", encoding="utf-8")
        Path("/tmp/consumer_stall_count").write_text("0", encoding="utf-8")

    def _redis_flush(self) -> None:
        """Run ``redis-cli FLUSHDB``.

        Raises:
            subprocess.CalledProcessError: If redis-cli exits non-zero.
        """
        subprocess.run(
            ["redis-cli", "FLUSHDB"], check=True, capture_output=True, text=True
        )

    def _read_float(self, value: Any, default: float = 0.0) -> float:
        """Convert *value* to float, returning *default* if that fails."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _is_route_blocked(self) -> bool:
        """Return whether "gateway->redis" is listed in blocked_routes.json.

        Any problem reading or interpreting the file counts as "not blocked".
        """
        payload = self._load_json_object(
            self.mesh_root / "gateway" / "blocked_routes.json"
        )
        if payload is None:
            return False
        try:
            return "gateway->redis" in payload.get("blocked", [])
        except TypeError:
            # "blocked" holds something that does not support ``in`` (e.g. null).
            return False

    def _is_lock_present(self) -> bool:
        """Return whether redis reports the ``LOCK:job_processor`` key.

        A redis-cli that cannot be launched or does not answer within two
        seconds is reported as "no lock", the same result as any reply other
        than "1", instead of raising out of ``step``.
        """
        try:
            result = subprocess.run(
                ["redis-cli", "EXISTS", "LOCK:job_processor"],
                capture_output=True,
                text=True,
                timeout=2,
                check=False,
            )
        except (subprocess.SubprocessError, OSError):
            return False
        return result.stdout.strip() == "1"

    def _is_cascading_timeout_resolved(self) -> bool:
        """Return whether auth's delay fits within the gateway's auth timeout.

        True only when both config files parse to JSON objects, the gateway's
        ``auth_timeout_ms`` is positive, and auth's ``delay_ms`` does not
        exceed it. Missing or non-numeric values are read as 0.
        """
        auth_payload = self._load_json_object(self.mesh_root / "auth" / "config.json")
        gateway_payload = self._load_json_object(
            self.mesh_root / "gateway" / "config.json"
        )
        if auth_payload is None or gateway_payload is None:
            return False
        auth_delay_ms = self._read_float(auth_payload.get("delay_ms"), default=0.0)
        auth_timeout_ms = self._read_float(
            gateway_payload.get("auth_timeout_ms"), default=0.0
        )
        if auth_timeout_ms <= 0:
            return False
        return auth_delay_ms <= auth_timeout_ms

    def _is_registry_auth_default(self) -> bool:
        """Return whether the registry's auth entry equals the default.

        A missing, malformed or non-numeric entry yields False. The registry
        is agent-editable, so a bad ``port`` value must not raise.
        """
        payload = self._load_json_object(self.mesh_root / "registry.json")
        if payload is None:
            return False
        expected = self._DEFAULT_REGISTRY["services"]["auth"]
        try:
            auth_service = payload["services"]["auth"]
            return (
                auth_service.get("host") == expected["host"]
                and int(auth_service.get("port", 0)) == expected["port"]
                and auth_service.get("protocol") == expected["protocol"]
            )
        except (KeyError, TypeError, ValueError, AttributeError, OverflowError):
            return False

    def _job_generator_interval_ms(self) -> int:
        """Return ``interval_ms`` from the job generator config, or 0.

        0 is returned when the file is unusable or the value cannot be
        converted to an int.
        """
        payload = self._load_json_object(
            self.mesh_root / "worker" / "job_generator_config.json"
        )
        if payload is None:
            return 0
        try:
            return int(payload.get("interval_ms", 0))
        except (TypeError, ValueError, OverflowError):
            return 0

    def _is_job_generator_rate_resolved(self) -> bool:
        """Return whether the configured interval is at least the default."""
        return self._job_generator_interval_ms() >= int(
            DEFAULT_CONFIGS["job_generator"]["interval_ms"]
        )

    def _build_grader_context(self) -> dict[str, Any]:
        """Collect the baselines and live probe results passed to the grader."""
        return {
            **self._baselines,
            "route_blocked": self._is_route_blocked(),
            "lock_exists": self._is_lock_present(),
            "cascading_timeout_resolved": self._is_cascading_timeout_resolved(),
            "registry_auth_matches_default": self._is_registry_auth_default(),
            "job_generator_interval_ms": self._job_generator_interval_ms(),
            "job_generator_rate_resolved": self._is_job_generator_rate_resolved(),
        }

    def _blocked_command(self, command: str) -> bool:
        """Return whether *command* contains a denylisted pattern.

        Matching is case-insensitive and whitespace-collapsed, so extra spaces
        or tabs between words cannot be used to slip past a pattern.
        """
        normalized = self._command_signature(command)
        return any(pattern in normalized for pattern in self._DANGEROUS_PATTERNS)

    def _run_command(self, command: str) -> tuple[str, str | None]:
        """Execute *command* and return ``(output, error_tag)``.

        ``self.last_exit_code`` is always updated: 2 when the model supplied no
        command, 1 for a blocked command or an execution error, 124 on
        timeout, otherwise the process return code. ``error_tag`` is ``None``
        when the command actually ran to completion.
        """
        if command.strip() == NO_COMMAND_PROVIDED_SENTINEL:
            self.last_exit_code = 2
            return (
                "No command provided by model. Expected JSON with a command field.",
                "no_command_provided",
            )
        if self._blocked_command(command):
            self.last_exit_code = 1
            return (
                "BLOCKED: This command would damage the environment infrastructure.",
                "blocked_command",
            )
        try:
            # NOTE: shell=True is deliberate here: the agent's command line is
            # run through a shell as-is. The denylist above is the only guard.
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                timeout=10,
                cwd="/",
                env={
                    **os.environ,
                    "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
                },
                check=False,
            )
            self.last_exit_code = result.returncode
            output = (result.stdout + result.stderr).strip() or "(no output)"
            return output, None
        except subprocess.TimeoutExpired:
            self.last_exit_code = 124
            return "Command timed out after 10 seconds.", "timeout"
        except Exception as exc:
            # Best-effort boundary: any other failure is reported to the agent
            # as command output rather than aborting the step.
            self.last_exit_code = 1
            return f"Command execution error: {exc}", str(exc)

    def _command_signature(self, command: str) -> str:
        """Return *command* lower-cased with whitespace runs collapsed."""
        return " ".join(command.strip().lower().split())

    def _is_diagnostic_command(self, command: str) -> bool:
        """Return whether *command* uses one of the diagnostic tools.

        Keywords are matched against whole tokens, not raw substrings, so that
        e.g. ``false`` (which contains "ls") or ``pgrep`` is not mistaken for a
        diagnostic command.
        """
        tokens = self._TOKEN_RE.findall(command.lower())
        return not self._DIAGNOSTIC_KEYWORDS.isdisjoint(tokens)

    def _is_state_change_command(self, command: str) -> bool:
        """Return whether *command* contains a state-changing pattern.

        Matching is case-insensitive and whitespace-collapsed.
        """
        normalized = self._command_signature(command)
        return any(pattern in normalized for pattern in self._STATE_CHANGE_PATTERNS)

    def _compute_reward(
        self,
        command: str,
        current: Observation,
        previous: Observation,
        grader_score: float,
        previous_grader_score: float,
        command_error: str | None,
    ) -> float:
        """Turn the grader score and step outcome into a reward in [0.01, 0.99].

        Returns 0.01 immediately when no command was provided and 0.99 when
        the grader score reaches 0.95. Otherwise the reward starts at
        ``0.75 * grader_score`` and is adjusted by bonuses (first use of a
        diagnostic command, state changes that raise the grader score, metric
        improvements) and penalties (state changes that do not help, repeated
        commands, trivial commands, failing or blocked commands).

        Side effects: updates the per-signature command counter and the set of
        diagnostic signatures already rewarded.
        """
        if command_error == "no_command_provided":
            return 0.01
        if grader_score >= 0.95:
            return 0.99

        reward = grader_score * 0.75

        signature = self._command_signature(command)
        signature_count = self._command_counts.get(signature, 0) + 1
        self._command_counts[signature] = signature_count

        # Exploration bonus, granted once per distinct diagnostic command.
        if (
            self._is_diagnostic_command(command)
            and signature not in self._seen_diagnostic_signatures
        ):
            reward += 0.02
            self._seen_diagnostic_signatures.add(signature)

        # State changes pay off only when the grader score actually moves up.
        if self._is_state_change_command(command):
            reward += 0.03
            if grader_score > previous_grader_score + 1e-4:
                reward += 0.15
            else:
                reward -= 0.05

        if (
            current.metrics.gateway_success_rate
            > previous.metrics.gateway_success_rate + 1e-3
        ):
            reward += 0.05
        if current.metrics.queue_depth < previous.metrics.queue_depth:
            reward += 0.05
        if current.metrics.worker_restart_count < previous.metrics.worker_restart_count:
            reward += 0.03
        if current.metrics.consumer_stall_count < previous.metrics.consumer_stall_count:
            reward -= -0.03

        # Repetition penalty grows with each repeat and is capped at 0.12.
        if signature_count > 1:
            reward -= min(0.12, 0.04 * (signature_count - 1))

        if command.strip().lower() in self._TRIVIAL_COMMANDS:
            reward -= 0.08

        if self.last_exit_code != 0 and command_error not in {
            "blocked_command",
            "no_command_provided",
        }:
            reward -= 0.08
        if command_error == "blocked_command":
            reward -= 0.25

        return max(0.01, min(0.99, reward))

    def _status_block(self, metrics: Any) -> str:
        """Format the status text used as the first observation after reset.

        The three service lines are fixed text; only the queue depth and
        gateway success rate come from *metrics*.
        """
        return (
            "=== pipeline status after reset ===\n"
            "gateway: running\n"
            "auth: running\n"
            "worker: running\n"
            f"queue_depth: {metrics.queue_depth}\n"
            f"gateway_success_rate: {metrics.gateway_success_rate:.2f}"
        )

    def reset(self, task_name: TaskName | str) -> Observation:
        """Start a new episode for *task_name* and return its first observation.

        Clears per-episode bookkeeping, restores logs, configs, redis and
        counters, restarts all services, injects the task's fault, and records
        the baseline counters and initial grader score.

        Raises:
            RuntimeError: If the services are not healthy within 30 seconds.
        """
        task = TaskName.parse(task_name) if isinstance(task_name, str) else task_name
        self.current_task = task
        self.max_steps = TASK_MAX_STEPS[task]
        self.step_count = 0
        self._seen_diagnostic_signatures = set()
        self._command_counts = {}
        self._last_grader_score = 0.0

        self._truncate_logs()
        self._restore_defaults()
        self._redis_flush()
        self._reset_runtime_counters()
        Path("/tmp/current_task").write_text(task.value, encoding="utf-8")

        self._process_manager.restart_all()
        if not self._process_manager.wait_healthy(timeout_s=30):
            raise RuntimeError("Services failed health checks after reset")

        inject_fault(task, self._process_manager)
        # Pause for a second after injecting the fault before sampling metrics.
        time.sleep(1.0)
        self._metrics_poller.poll_once()
        metrics = self._metrics_poller.get_current_metrics()

        self._baselines = {
            "baseline_worker_restart_count": metrics.worker_restart_count,
            "baseline_consumer_stall_count": metrics.consumer_stall_count,
        }
        self._last_grader_score = grade_task(
            task, metrics, self._build_grader_context()
        )

        observation = Observation(
            command_output=self._status_block(metrics),
            metrics=metrics,
            process_status=self._process_manager.get_status(),
        )
        self.prev_observation = observation
        return observation

    def step(self, action: Action) -> StepResult:
        """Run one agent command and return the observation, reward and done flag.

        The episode ends when the step budget is exhausted or when the grader
        score reaches 0.95. A step without a command can only end the episode
        through the step budget.

        Raises:
            RuntimeError: If ``reset`` has not been called yet.
        """
        if not self.current_task:
            raise RuntimeError(
                "Environment not initialized. Call reset(task_name) first."
            )
        self.step_count += 1

        command_output, command_error = self._run_command(action.command)
        self._metrics_poller.poll_once()
        metrics = self._metrics_poller.get_current_metrics()
        observation = Observation(
            command_output=command_output,
            metrics=metrics,
            process_status=self._process_manager.get_status(),
        )

        previous = self.prev_observation or observation
        previous_grader_score = self._last_grader_score
        grader_score = grade_task(
            self.current_task, metrics, self._build_grader_context()
        )
        reward = self._compute_reward(
            action.command,
            observation,
            previous,
            grader_score,
            previous_grader_score,
            command_error,
        )

        if command_error == "no_command_provided":
            done = self.step_count >= self.max_steps
        else:
            done = grader_score >= 0.95 or self.step_count >= self.max_steps

        self._last_grader_score = grader_score
        self.prev_observation = observation
        info: dict[str, Any] = {
            "grader_score": round(grader_score, 4),
            "error": command_error,
            "exit_code": self.last_exit_code,
            "task": self.current_task.value if self.current_task else None,
        }
        return StepResult(observation=observation, reward=reward, done=done, info=info)

    def state(self) -> dict[str, Any]:
        """Poll metrics once and return a snapshot of the environment state."""
        self._metrics_poller.poll_once()
        metrics = self._metrics_poller.get_current_metrics()
        return {
            "task": self.current_task.value if self.current_task else None,
            "step_count": self.step_count,
            "max_steps": self.max_steps,
            "metrics": metrics.model_dump(),
            "process_status": self._process_manager.get_status(),
            "baselines": dict(self._baselines),
        }