import json
import os
import subprocess
import time
from pathlib import Path
from typing import Any

from .constants import (
    DEFAULT_CONFIGS,
    NO_COMMAND_PROVIDED_SENTINEL,
    TASK_MAX_STEPS,
    TaskName,
)
from .fault_injector import inject_fault
from .graders import grade_task
from .metrics_poller import MetricsPoller
from .models import Action, Observation, StepResult
from .process_manager import ProcessManager


class DistributedDebugEnv:
    """OpenEnv-compatible distributed systems debugging environment.

    An episode starts with :meth:`reset`. It restores default configs,
    truncates logs, flushes Redis, restarts the services and injects the
    fault for the chosen task. The agent then runs shell commands through
    :meth:`step`. Each step is scored by the task grader plus the shaping
    heuristics in :meth:`_compute_reward`.
    """

    def __init__(
        self, project_root: Path | None = None, mesh_root: Path | None = None
    ) -> None:
        """Create the environment without starting any services.

        Args:
            project_root: Project root directory. Defaults to the parent of
                this module's package directory.
            mesh_root: Directory holding the mesh config files. Defaults to
                the ``MESH_ROOT`` environment variable, else
                ``<project_root>/mesh``.
        """
        self.project_root = (
            project_root or Path(__file__).resolve().parent.parent
        ).resolve()
        self.mesh_root = (
            mesh_root or Path(os.getenv("MESH_ROOT", self.project_root / "mesh"))
        ).resolve()
        self._process_manager = ProcessManager(
            project_root=self.project_root, mesh_root=self.mesh_root
        )
        self._metrics_poller = MetricsPoller(poll_interval_s=2.0)
        self.current_task: TaskName | None = None
        self.max_steps: int = 0
        self.step_count: int = 0
        self.last_exit_code: int = 0
        self.prev_observation: Observation | None = None
        # Counter values captured in reset() right after fault injection and
        # merged into every grader context (see _build_grader_context).
        self._baselines: dict[str, int] = {
            "baseline_worker_restart_count": 0,
            "baseline_consumer_stall_count": 0,
        }
        # Normalized diagnostic commands already rewarded in this episode.
        self._seen_diagnostic_signatures: set[str] = set()
        # Number of times each normalized command was issued in this episode.
        self._command_counts: dict[str, int] = {}
        # Grader score after the previous reset()/step(); used for deltas.
        self._last_grader_score: float = 0.0

    def start(self) -> None:
        """Start the background metrics poller if it is not already running."""
        if not self._metrics_poller.is_alive():
            self._metrics_poller.start()

    def close(self) -> None:
        """Stop the background metrics poller."""
        self._metrics_poller.stop()

    def _write_json(self, path: Path, payload: dict[str, Any]) -> None:
        """Write *payload* as indented JSON to *path*, creating parent dirs."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")

    def _restore_defaults(self) -> None:
        """Overwrite the registry and every service config with defaults."""
        self._write_json(
            self.mesh_root / "registry.json",
            {
                "services": {
                    "auth": {"host": "localhost", "port": 3001, "protocol": "http"},
                    "redis": {"host": "localhost", "port": 6379, "protocol": "tcp"},
                    "worker": {
                        "host": "localhost",
                        "port": None,
                        "protocol": "internal",
                    },
                }
            },
        )
        self._write_json(
            self.mesh_root / "auth" / "config.json", DEFAULT_CONFIGS["auth"]
        )
        self._write_json(
            self.mesh_root / "gateway" / "config.json", DEFAULT_CONFIGS["gateway"]
        )
        self._write_json(
            self.mesh_root / "gateway" / "blocked_routes.json",
            DEFAULT_CONFIGS["blocked_routes"],
        )
        self._write_json(
            self.mesh_root / "worker" / "config.json", DEFAULT_CONFIGS["worker"]
        )
        self._write_json(
            self.mesh_root / "worker" / "job_generator_config.json",
            DEFAULT_CONFIGS["job_generator"],
        )

    def _truncate_logs(self) -> None:
        """Empty the per-service log files under /tmp."""
        for service in ["gateway", "auth", "worker", "job_gen"]:
            Path(f"/tmp/{service}.log").write_text("", encoding="utf-8")

    def _reset_runtime_counters(self) -> None:
        """Reset the worker-restart and consumer-stall counter files to 0."""
        Path("/tmp/worker_restart_count").write_text("0", encoding="utf-8")
        Path("/tmp/consumer_stall_count").write_text("0", encoding="utf-8")

    def _redis_flush(self) -> None:
        """Run ``redis-cli FLUSHDB``.

        Raises:
            subprocess.CalledProcessError: if redis-cli exits non-zero.
            subprocess.TimeoutExpired: if redis-cli does not finish in time.
                Without this, an unresponsive Redis would hang reset().
        """
        subprocess.run(
            ["redis-cli", "FLUSHDB"],
            check=True,
            capture_output=True,
            text=True,
            timeout=10,
        )

    def _read_float(self, value: Any, default: float = 0.0) -> float:
        """Convert *value* (any JSON value) to float, or return *default*."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _is_route_blocked(self) -> bool:
        """Return True if ``gateway->redis`` is listed in blocked_routes.json.

        Unreadable or malformed files count as "not blocked".
        """
        blocked_file = self.mesh_root / "gateway" / "blocked_routes.json"
        try:
            payload = json.loads(blocked_file.read_text(encoding="utf-8"))
            blocked = payload.get("blocked", [])
            return "gateway->redis" in blocked
        except Exception:
            return False

    def _is_lock_present(self) -> bool:
        """Return True if Redis reports that ``LOCK:job_processor`` exists.

        Any outcome other than stdout ``"1"`` counts as absent. That includes
        connection errors, a timeout, or a missing redis-cli binary.
        """
        try:
            result = subprocess.run(
                ["redis-cli", "EXISTS", "LOCK:job_processor"],
                capture_output=True,
                text=True,
                timeout=2,
                check=False,
            )
        except (subprocess.TimeoutExpired, OSError):
            # Previously these propagated and crashed step()/reset() via
            # _build_grader_context. Treat them like any other failed probe.
            return False
        return result.stdout.strip() == "1"

    def _is_cascading_timeout_resolved(self) -> bool:
        """Return True if auth ``delay_ms`` <= gateway ``auth_timeout_ms``.

        Returns False if either config is unreadable or not a JSON object,
        or if the timeout is missing or non-positive.
        """
        auth_config_file = self.mesh_root / "auth" / "config.json"
        gateway_config_file = self.mesh_root / "gateway" / "config.json"
        try:
            auth_payload = json.loads(auth_config_file.read_text(encoding="utf-8"))
            gateway_payload = json.loads(
                gateway_config_file.read_text(encoding="utf-8")
            )
        except Exception:
            return False
        # The agent may write arbitrary JSON. Guard before calling .get().
        if not isinstance(auth_payload, dict) or not isinstance(
            gateway_payload, dict
        ):
            return False
        auth_delay_ms = self._read_float(auth_payload.get("delay_ms"), default=0.0)
        auth_timeout_ms = self._read_float(
            gateway_payload.get("auth_timeout_ms"), default=0.0
        )
        if auth_timeout_ms <= 0:
            return False
        return auth_delay_ms <= auth_timeout_ms

    def _is_registry_auth_default(self) -> bool:
        """Return True if the registry's auth entry is localhost:3001/http.

        Any read/parse/shape problem (including a non-numeric port) yields
        False instead of raising.
        """
        registry_file = self.mesh_root / "registry.json"
        try:
            payload = json.loads(registry_file.read_text(encoding="utf-8"))
            auth_service = payload["services"]["auth"]
            # These calls are inside the try: a null/non-numeric/Infinity port
            # or a non-object entry must not crash grading.
            port = int(auth_service.get("port", 0))
            host = auth_service.get("host")
            protocol = auth_service.get("protocol")
        except Exception:
            return False
        return host == "localhost" and port == 3001 and protocol == "http"

    def _job_generator_interval_ms(self) -> int:
        """Return ``interval_ms`` from job_generator_config.json, or 0.

        Returns 0 when the file is unreadable, is not a JSON object, or holds
        a value that cannot be converted to int.
        """
        config_file = self.mesh_root / "worker" / "job_generator_config.json"
        try:
            payload = json.loads(config_file.read_text(encoding="utf-8"))
        except Exception:
            return 0
        if not isinstance(payload, dict):
            return 0
        try:
            return int(payload.get("interval_ms", 0))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts Infinity, which int() rejects.
            return 0

    def _is_job_generator_rate_resolved(self) -> bool:
        """Return True if the job generator interval is at least the default."""
        return self._job_generator_interval_ms() >= int(
            DEFAULT_CONFIGS["job_generator"]["interval_ms"]
        )

    def _build_grader_context(self) -> dict[str, Any]:
        """Collect baselines and live environment probes for the grader."""
        return {
            **self._baselines,
            "route_blocked": self._is_route_blocked(),
            "lock_exists": self._is_lock_present(),
            "cascading_timeout_resolved": self._is_cascading_timeout_resolved(),
            "registry_auth_matches_default": self._is_registry_auth_default(),
            "job_generator_interval_ms": self._job_generator_interval_ms(),
            "job_generator_rate_resolved": self._is_job_generator_rate_resolved(),
        }

    def _blocked_command(self, command: str) -> bool:
        """Return True if *command* matches a known-destructive pattern.

        The command is lowercased and runs of whitespace are collapsed before
        matching. Extra spaces (e.g. ``rm  -rf /``) therefore cannot slip past
        the single-spaced patterns.
        """
        dangerous_patterns = [
            "rm -rf /",
            "kill -9 1",
            "pkill -f uvicorn",
            "> /tmp/gateway.log",
            "> /tmp/auth.log",
            "> /tmp/worker.log",
        ]
        normalized = " ".join(command.lower().split())
        return any(pattern in normalized for pattern in dangerous_patterns)

    def _run_command(self, command: str) -> tuple[str, str | None]:
        """Execute an agent command in a shell and capture its output.

        Sets ``self.last_exit_code`` as a side effect. That is 2 for the
        no-command sentinel, 1 for blocked commands and execution errors,
        and 124 on timeout.

        Returns:
            ``(output, error)``. *output* is combined stdout+stderr or a status
            message. *error* is None on normal execution, otherwise one of
            ``"no_command_provided"``, ``"blocked_command"``, ``"timeout"``, or
            the text of an unexpected exception.
        """
        if command.strip() == NO_COMMAND_PROVIDED_SENTINEL:
            self.last_exit_code = 2
            return (
                "No command provided by model. Expected JSON with a command field.",
                "no_command_provided",
            )
        if self._blocked_command(command):
            self.last_exit_code = 1
            return (
                "BLOCKED: This command would damage the environment infrastructure.",
                "blocked_command",
            )
        try:
            # shell=True is intentional: the agent's task is to issue shell
            # commands. _blocked_command above is the only guard.
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                timeout=10,
                cwd="/",
                env={
                    **os.environ,
                    "PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
                },
                check=False,
            )
            self.last_exit_code = result.returncode
            output = (result.stdout + result.stderr).strip() or "(no output)"
            return output, None
        except subprocess.TimeoutExpired:
            self.last_exit_code = 124
            return "Command timed out after 10 seconds.", "timeout"
        except Exception as exc:
            self.last_exit_code = 1
            return f"Command execution error: {exc}", str(exc)

    def _command_signature(self, command: str) -> str:
        """Normalize *command* (lowercase, collapsed whitespace) for counting."""
        return " ".join(command.strip().lower().split())

    def _is_diagnostic_command(self, command: str) -> bool:
        """Heuristically decide whether *command* is a read-only inspection.

        NOTE(review): this is plain substring matching. Short keywords such as
        "ls", "ps" or "get" also match inside longer words (e.g. "false",
        "target").
        """
        diagnostic_keywords = [
            "cat",
            "curl",
            "redis-cli",
            "ps",
            "ls",
            "grep",
            "tail",
            "jq",
            "lrange",
            "llen",
            "keys",
            "ttl",
            "get",
        ]
        normalized = command.lower()
        return any(keyword in normalized for keyword in diagnostic_keywords)

    def _is_state_change_command(self, command: str) -> bool:
        """Heuristically decide whether *command* mutates services or configs."""
        normalized = command.lower()
        state_change_patterns = [
            "kill -hup",
            "redis-cli del",
            "redis-cli lrem",
            "redis-cli set",
            "redis-cli flushdb",
            "echo '{",
            "> /mesh/",
            "tee /mesh/",
        ]
        return any(pattern in normalized for pattern in state_change_patterns)

    def _compute_reward(
        self,
        command: str,
        current: Observation,
        previous: Observation,
        grader_score: float,
        previous_grader_score: float,
        command_error: str | None,
    ) -> float:
        """Compute the shaped step reward, clamped to [0.01, 0.99].

        The base is 0.75 * grader_score. Bonuses:
        - first use of a diagnostic command;
        - a state-changing command that raises the grader score (a
          state-changing command that does not raise it is penalized);
        - improved success rate, queue depth, restart or stall counters.

        Penalties:
        - repeating the same command;
        - trivial commands;
        - non-zero exit codes;
        - blocked commands.

        Side effect: updates the per-episode command counters and the set of
        rewarded diagnostic signatures. This does not happen on the two early
        returns.
        """
        if command_error == "no_command_provided":
            return 0.01
        if grader_score >= 0.95:
            return 0.99

        reward = grader_score * 0.75
        signature = self._command_signature(command)
        signature_count = self._command_counts.get(signature, 0) + 1
        self._command_counts[signature] = signature_count

        if (
            self._is_diagnostic_command(command)
            and signature not in self._seen_diagnostic_signatures
        ):
            reward += 0.02
            self._seen_diagnostic_signatures.add(signature)

        if self._is_state_change_command(command):
            reward += 0.03
            if grader_score > previous_grader_score + 1e-4:
                reward += 0.15
            else:
                reward -= 0.05

        if (
            current.metrics.gateway_success_rate
            > previous.metrics.gateway_success_rate + 1e-3
        ):
            reward += 0.05
        if current.metrics.queue_depth < previous.metrics.queue_depth:
            reward += 0.05
        if current.metrics.worker_restart_count < previous.metrics.worker_restart_count:
            reward += 0.03
        if current.metrics.consumer_stall_count < previous.metrics.consumer_stall_count:
            reward += 0.03

        if signature_count > 1:
            # Escalating repeat penalty, capped at 0.12.
            reward -= min(0.12, 0.04 * (signature_count - 1))
        if command.strip().lower() in {
            "echo",
            "pwd",
            "whoami",
            "date",
            "true",
            "false",
        }:
            reward -= 0.08
        if self.last_exit_code != 0 and command_error not in {
            "blocked_command",
            "no_command_provided",
        }:
            reward -= 0.08
        if command_error == "blocked_command":
            reward -= 0.25

        return max(0.01, min(0.99, reward))

    def _status_block(self, metrics: Any) -> str:
        """Render the post-reset status text shown as the initial observation.

        NOTE(review): the service lines are hard-coded as "running". They are
        not derived from the process manager's status.
        """
        return (
            "=== pipeline status after reset ===\n"
            "gateway: running\n"
            "auth: running\n"
            "worker: running\n"
            f"queue_depth: {metrics.queue_depth}\n"
            f"gateway_success_rate: {metrics.gateway_success_rate:.2f}"
        )

    def reset(self, task_name: TaskName | str) -> Observation:
        """Start a new episode for *task_name* and return the first observation.

        Restores configs, logs, Redis and counters. Restarts all services,
        waits for health, injects the task's fault, then records baseline
        counters and the initial grader score.

        Raises:
            RuntimeError: if services fail health checks within 30 seconds.
        """
        task = TaskName.parse(task_name) if isinstance(task_name, str) else task_name
        self.current_task = task
        self.max_steps = TASK_MAX_STEPS[task]
        self.step_count = 0
        self._seen_diagnostic_signatures = set()
        self._command_counts = {}
        self._last_grader_score = 0.0

        self._truncate_logs()
        self._restore_defaults()
        self._redis_flush()
        self._reset_runtime_counters()
        Path("/tmp/current_task").write_text(task.value, encoding="utf-8")

        self._process_manager.restart_all()
        if not self._process_manager.wait_healthy(timeout_s=30):
            raise RuntimeError("Services failed health checks after reset")

        inject_fault(task, self._process_manager)
        # Give the injected fault a moment to take effect before sampling.
        time.sleep(1.0)

        self._metrics_poller.poll_once()
        metrics = self._metrics_poller.get_current_metrics()
        self._baselines = {
            "baseline_worker_restart_count": metrics.worker_restart_count,
            "baseline_consumer_stall_count": metrics.consumer_stall_count,
        }
        self._last_grader_score = grade_task(
            task, metrics, self._build_grader_context()
        )

        observation = Observation(
            command_output=self._status_block(metrics),
            metrics=metrics,
            process_status=self._process_manager.get_status(),
        )
        self.prev_observation = observation
        return observation

    def step(self, action: Action) -> StepResult:
        """Run one agent command and return observation, reward, done and info.

        The episode ends when the grader score reaches 0.95 or ``max_steps``
        is hit. A missing command only ends it on ``max_steps``.

        Raises:
            RuntimeError: if :meth:`reset` has not been called.
        """
        if not self.current_task:
            raise RuntimeError(
                "Environment not initialized. Call reset(task_name) first."
            )

        self.step_count += 1
        command_output, command_error = self._run_command(action.command)

        self._metrics_poller.poll_once()
        metrics = self._metrics_poller.get_current_metrics()
        observation = Observation(
            command_output=command_output,
            metrics=metrics,
            process_status=self._process_manager.get_status(),
        )

        previous = self.prev_observation or observation
        previous_grader_score = self._last_grader_score
        grader_score = grade_task(
            self.current_task, metrics, self._build_grader_context()
        )
        reward = self._compute_reward(
            action.command,
            observation,
            previous,
            grader_score,
            previous_grader_score,
            command_error,
        )

        if command_error == "no_command_provided":
            done = self.step_count >= self.max_steps
        else:
            done = grader_score >= 0.95 or self.step_count >= self.max_steps

        self._last_grader_score = grader_score
        self.prev_observation = observation

        info: dict[str, Any] = {
            "grader_score": round(grader_score, 4),
            "error": command_error,
            "exit_code": self.last_exit_code,
            "task": self.current_task.value if self.current_task else None,
        }
        return StepResult(observation=observation, reward=reward, done=done, info=info)

    def state(self) -> dict[str, Any]:
        """Return a fresh snapshot of the task, step counters, metrics and baselines."""
        self._metrics_poller.poll_once()
        metrics = self._metrics_poller.get_current_metrics()
        return {
            "task": self.current_task.value if self.current_task else None,
            "step_count": self.step_count,
            "max_steps": self.max_steps,
            "metrics": metrics.model_dump(),
            "process_status": self._process_manager.get_status(),
            "baselines": dict(self._baselines),
        }