# Config debugger environment: an agent iteratively repairs a broken YAML
# configuration and is scored against the task's target specification.
| from __future__ import annotations | |
| import copy | |
| import hashlib | |
| from typing import Any | |
| import yaml | |
| from .data import TASK_REGISTRY, TaskSpec | |
| from .models import ConfigAction, ConfigObservation, ConfigReward, EnvState, TaskType | |
class ConfigDebuggerEnv:
    """Episodic environment in which an agent repairs a broken YAML config.

    An episode starts from a task's broken config text (see ``reset``); the
    agent submits path-based edit actions (``step``) and is rewarded by how
    close the config gets to the task's target (schema match + logic checks).
    """

    def __init__(self) -> None:
        """Create an uninitialized environment; call reset() before step()."""
        # Task selection — filled in by reset().
        self.task_spec: TaskSpec | None = None
        self.task_id: TaskType | None = None
        # Episode progress.
        self.current_config_text: str = ""
        self.previous_score: float = 0.0
        self.step_count: int = 0
        self.max_steps: int = 15
        self.done: bool = False
        self.last_reward: ConfigReward | None = None
        # sha1(config_text) -> visit count, used to penalize state revisits.
        self._state_visit_count: dict[str, int] = {}
| def reset(self, task_id: TaskType | str) -> ConfigObservation: | |
| normalized_task_id = task_id.value if isinstance(task_id, TaskType) else str(task_id) | |
| if normalized_task_id not in TASK_REGISTRY: | |
| valid = ", ".join(TASK_REGISTRY.keys()) | |
| raise ValueError(f"Unknown task_id '{task_id}'. Valid task ids: {valid}") | |
| spec = TASK_REGISTRY[normalized_task_id] | |
| self.task_spec = spec | |
| self.task_id = TaskType(normalized_task_id) | |
| self.current_config_text = spec.broken | |
| self.step_count = 0 | |
| self.done = False | |
| self.max_steps = spec.max_steps | |
| self._state_visit_count = {} | |
| initial_score = self._grade(self.current_config_text)["overall"] | |
| self.previous_score = initial_score | |
| self.last_reward = None | |
| self._track_state_visit(self.current_config_text) | |
| return self._build_observation() | |
| def step(self, action: ConfigAction) -> tuple[ConfigObservation, ConfigReward, bool, dict[str, Any]]: | |
| if self.task_spec is None or self.task_id is None: | |
| raise RuntimeError("Environment is not initialized. Call reset() first.") | |
| if self.done: | |
| obs = self._build_observation() | |
| reward = ConfigReward( | |
| value=0.0, | |
| previous_score=self.previous_score, | |
| current_score=self.previous_score, | |
| delta=0.0, | |
| penalties=["episode_already_done"], | |
| ) | |
| self.last_reward = reward | |
| return obs, reward, True, {"reason": "episode_already_done"} | |
| self.step_count += 1 | |
| penalties: list[str] = [] | |
| try: | |
| new_config_text, action_penalties = self._apply_action(self.current_config_text, action) | |
| penalties.extend(action_penalties) | |
| self.current_config_text = new_config_text | |
| except Exception as exc: | |
| penalties.append(f"invalid_action:{exc}") | |
| grading = self._grade(self.current_config_text) | |
| current_score = grading["overall"] | |
| delta = round(current_score - self.previous_score, 4) | |
| loop_penalty = self._track_state_visit(self.current_config_text) | |
| if loop_penalty > 0: | |
| penalties.append(f"loop_penalty:{loop_penalty:.2f}") | |
| reward_value = self._compute_reward(current_score, delta, penalties, loop_penalty) | |
| reward = ConfigReward( | |
| value=reward_value, | |
| previous_score=round(self.previous_score, 4), | |
| current_score=round(current_score, 4), | |
| delta=delta, | |
| penalties=penalties, | |
| ) | |
| self.previous_score = current_score | |
| self.done = current_score >= 0.98 or self.step_count >= self.max_steps | |
| self.last_reward = reward | |
| info = { | |
| "task_id": self.task_id.value, | |
| "schema_score": grading["schema"], | |
| "logic_score": grading["logic"], | |
| "syntax_valid": grading["syntax_valid"], | |
| } | |
| return self._build_observation(grading), reward, self.done, info | |
| def state(self) -> EnvState: | |
| observation = self._build_observation() if self.task_spec is not None else None | |
| return EnvState( | |
| task_id=self.task_id, | |
| done=self.done, | |
| step_count=self.step_count, | |
| max_steps=self.max_steps, | |
| observation=observation, | |
| last_reward=self.last_reward, | |
| ) | |
| def _build_observation(self, grading: dict[str, Any] | None = None) -> ConfigObservation: | |
| if self.task_spec is None or self.task_id is None: | |
| raise RuntimeError("Environment is not initialized. Call reset() first.") | |
| if grading is None: | |
| grading = self._grade(self.current_config_text) | |
| return ConfigObservation( | |
| task_id=self.task_id, | |
| task_description=self.task_spec.description, | |
| current_config=self.current_config_text, | |
| syntax_valid=grading["syntax_valid"], | |
| validation_errors=grading["errors"], | |
| schema_score=grading["schema"], | |
| logic_score=grading["logic"], | |
| overall_score=grading["overall"], | |
| step_count=self.step_count, | |
| max_steps=self.max_steps, | |
| ) | |
| def _compute_reward(self, current_score: float, delta: float, penalties: list[str], loop_penalty: float) -> float: | |
| reward = current_score | |
| if delta > 0: | |
| reward += min(0.15, delta) | |
| elif delta < 0: | |
| reward += delta * 0.4 | |
| penalty_total = loop_penalty | |
| if any(p.startswith("invalid_action") for p in penalties): | |
| penalty_total += 0.10 | |
| if any(p.startswith("destructive_delete") for p in penalties): | |
| penalty_total += 0.08 | |
| reward -= penalty_total | |
| if current_score >= 0.98: | |
| reward += 0.05 | |
| return round(max(0.0, min(1.0, reward)), 4) | |
| def _track_state_visit(self, config_text: str) -> float: | |
| state_hash = hashlib.sha1(config_text.encode("utf-8")).hexdigest() | |
| count = self._state_visit_count.get(state_hash, 0) + 1 | |
| self._state_visit_count[state_hash] = count | |
| # Penalize repeated states to discourage loops. | |
| if count <= 1: | |
| return 0.0 | |
| return min(0.03 * (count - 1), 0.12) | |
| def _apply_action(self, config_text: str, action: ConfigAction) -> tuple[str, list[str]]: | |
| penalties: list[str] = [] | |
| data = yaml.safe_load(config_text) | |
| if data is None: | |
| data = {} | |
| if not isinstance(data, dict): | |
| raise ValueError("current config is not a dictionary-like YAML document") | |
| root = copy.deepcopy(data) | |
| tokens = self._parse_path(action.path) | |
| if action.operation == "delete" and tokens and isinstance(tokens[0], str): | |
| if tokens[0] in {"services", "spec", "training", "hardware"} and len(tokens) == 1: | |
| penalties.append("destructive_delete:top_level_critical_key") | |
| if action.operation in {"edit", "add"}: | |
| self._set_path(root, tokens, action.value) | |
| else: | |
| deleted = self._delete_path(root, tokens) | |
| if not deleted: | |
| penalties.append("delete_noop") | |
| dumped = yaml.safe_dump(root, sort_keys=False) | |
| return dumped, penalties | |
| def _parse_path(self, path: str) -> list[str | int]: | |
| tokens: list[str | int] = [] | |
| for chunk in path.split("."): | |
| chunk = chunk.strip() | |
| if chunk == "": | |
| raise ValueError("path contains empty token") | |
| if chunk.isdigit(): | |
| tokens.append(int(chunk)) | |
| else: | |
| tokens.append(chunk) | |
| return tokens | |
| def _set_path(self, root: dict[str, Any], tokens: list[str | int], value: Any) -> None: | |
| if not tokens: | |
| raise ValueError("cannot set empty path") | |
| cursor: Any = root | |
| for i, token in enumerate(tokens[:-1]): | |
| nxt = tokens[i + 1] | |
| if isinstance(token, int): | |
| if not isinstance(cursor, list): | |
| raise ValueError("list index used on non-list node") | |
| while token >= len(cursor): | |
| cursor.append({} if isinstance(nxt, str) else []) | |
| if cursor[token] is None: | |
| cursor[token] = {} if isinstance(nxt, str) else [] | |
| cursor = cursor[token] | |
| else: | |
| if not isinstance(cursor, dict): | |
| raise ValueError("dict key used on non-dict node") | |
| if token not in cursor or cursor[token] is None: | |
| cursor[token] = {} if isinstance(nxt, str) else [] | |
| cursor = cursor[token] | |
| final = tokens[-1] | |
| if isinstance(final, int): | |
| if not isinstance(cursor, list): | |
| raise ValueError("final list index used on non-list node") | |
| while final >= len(cursor): | |
| cursor.append(None) | |
| cursor[final] = value | |
| else: | |
| if not isinstance(cursor, dict): | |
| raise ValueError("final dict key used on non-dict node") | |
| cursor[final] = value | |
| def _delete_path(self, root: dict[str, Any], tokens: list[str | int]) -> bool: | |
| if not tokens: | |
| return False | |
| cursor: Any = root | |
| for token in tokens[:-1]: | |
| if isinstance(token, int): | |
| if not isinstance(cursor, list) or token >= len(cursor): | |
| return False | |
| cursor = cursor[token] | |
| else: | |
| if not isinstance(cursor, dict) or token not in cursor: | |
| return False | |
| cursor = cursor[token] | |
| final = tokens[-1] | |
| if isinstance(final, int): | |
| if not isinstance(cursor, list) or final >= len(cursor): | |
| return False | |
| cursor.pop(final) | |
| return True | |
| if not isinstance(cursor, dict) or final not in cursor: | |
| return False | |
| del cursor[final] | |
| return True | |
| def _grade(self, config_text: str) -> dict[str, Any]: | |
| assert self.task_spec is not None | |
| errors: list[str] = [] | |
| try: | |
| parsed = yaml.safe_load(config_text) | |
| except Exception as exc: | |
| return { | |
| "syntax_valid": False, | |
| "schema": 0.0, | |
| "logic": 0.0, | |
| "overall": 0.0, | |
| "errors": [f"YAML syntax error: {exc}"], | |
| } | |
| if parsed is None: | |
| parsed = {} | |
| if not isinstance(parsed, dict): | |
| return { | |
| "syntax_valid": True, | |
| "schema": 0.0, | |
| "logic": 0.0, | |
| "overall": 0.0, | |
| "errors": ["Root document must be a mapping/dict"], | |
| } | |
| schema_score, schema_errors = self._grade_schema(parsed) | |
| logic_score, logic_errors = self._grade_logic(parsed) | |
| errors.extend(schema_errors) | |
| errors.extend(logic_errors) | |
| overall = round((0.60 * schema_score) + (0.40 * logic_score), 4) | |
| return { | |
| "syntax_valid": True, | |
| "schema": schema_score, | |
| "logic": logic_score, | |
| "overall": overall, | |
| "errors": errors[:20], | |
| } | |
| def _grade_schema(self, parsed: dict[str, Any]) -> tuple[float, list[str]]: | |
| assert self.task_spec is not None | |
| total_weight = 0.0 | |
| matched_weight = 0.0 | |
| errors: list[str] = [] | |
| for path, weight in self.task_spec.required_paths.items(): | |
| total_weight += weight | |
| expected = self._read_path(self.task_spec.target, self._parse_path(path)) | |
| got, exists = self._safe_read(parsed, self._parse_path(path)) | |
| if not exists: | |
| errors.append(f"Missing required path: {path}") | |
| continue | |
| if got == expected: | |
| matched_weight += weight | |
| else: | |
| errors.append(f"Mismatch at {path}: expected={expected!r}, got={got!r}") | |
| score = 0.0 if total_weight == 0 else round(matched_weight / total_weight, 4) | |
| return score, errors | |
| def _grade_logic(self, parsed: dict[str, Any]) -> tuple[float, list[str]]: | |
| assert self.task_spec is not None | |
| checks: list[tuple[str, bool]] = [] | |
| t = self.task_spec.task_id | |
| if t == "easy_docker": | |
| web_ports = self._safe_get(parsed, ["services", "web", "ports"], default=[]) | |
| db_ports = self._safe_get(parsed, ["services", "db", "ports"], default=[]) | |
| env_node = self._safe_get(parsed, ["services", "web", "environment"], default={}) | |
| checks.append(("web ports must be list", isinstance(web_ports, list))) | |
| checks.append(("all web ports must contain ':'", all(isinstance(p, str) and ":" in p for p in web_ports))) | |
| checks.append(("db port must include host and container", "5432:5432" in db_ports if isinstance(db_ports, list) else False)) | |
| checks.append(("environment must be dict", isinstance(env_node, dict))) | |
| elif t == "medium_k8s": | |
| replicas = self._safe_get(parsed, ["spec", "replicas"], default=None) | |
| limits_mem = self._safe_get( | |
| parsed, | |
| ["spec", "template", "spec", "containers", 0, "resources", "limits", "memory"], | |
| default="", | |
| ) | |
| req_mem = self._safe_get( | |
| parsed, | |
| ["spec", "template", "spec", "containers", 0, "resources", "requests", "memory"], | |
| default="", | |
| ) | |
| req_cpu = self._safe_get( | |
| parsed, | |
| ["spec", "template", "spec", "containers", 0, "resources", "requests", "cpu"], | |
| default="", | |
| ) | |
| checks.append(("replicas should be int", isinstance(replicas, int))) | |
| checks.append(("limits memory must include unit", isinstance(limits_mem, str) and limits_mem.endswith(("Mi", "Gi")))) | |
| checks.append(("requests memory must include unit", isinstance(req_mem, str) and req_mem.endswith(("Mi", "Gi")))) | |
| checks.append(("cpu request should be millicore string", isinstance(req_cpu, str) and req_cpu.endswith("m"))) | |
| elif t == "hard_ml_config": | |
| warmup = self._safe_get(parsed, ["training", "warmup_steps"], default=0) | |
| max_steps = self._safe_get(parsed, ["training", "max_steps"], default=0) | |
| use_cuda = self._safe_get(parsed, ["hardware", "use_cuda"], default=False) | |
| gpu_count = self._safe_get(parsed, ["hardware", "gpu_count"], default=0) | |
| batch_size = self._safe_get(parsed, ["training", "batch_size"], default=0) | |
| train_batch = self._safe_get(parsed, ["data", "train_batch_size"], default=0) | |
| log_interval = self._safe_get(parsed, ["logging", "log_interval"], default=999999) | |
| checks.append(("warmup_steps < max_steps", isinstance(warmup, int) and isinstance(max_steps, int) and warmup < max_steps)) | |
| checks.append(("gpu_count >=1 when use_cuda", (not use_cuda) or (isinstance(gpu_count, int) and gpu_count >= 1))) | |
| checks.append(("train_batch_size equals 2 * batch_size", isinstance(batch_size, int) and isinstance(train_batch, int) and train_batch == 2 * batch_size)) | |
| checks.append(("log_interval <= 100", isinstance(log_interval, int) and log_interval <= 100)) | |
| total = len(checks) | |
| passed = sum(1 for _, ok in checks if ok) | |
| errors = [msg for msg, ok in checks if not ok] | |
| score = 0.0 if total == 0 else round(passed / total, 4) | |
| return score, errors | |
| def _read_path(self, source: Any, tokens: list[str | int]) -> Any: | |
| cursor = source | |
| for token in tokens: | |
| if isinstance(token, int): | |
| cursor = cursor[token] | |
| else: | |
| cursor = cursor[token] | |
| return cursor | |
| def _safe_read(self, source: Any, tokens: list[str | int]) -> tuple[Any, bool]: | |
| cursor = source | |
| for token in tokens: | |
| try: | |
| if isinstance(token, int): | |
| if not isinstance(cursor, list): | |
| return None, False | |
| cursor = cursor[token] | |
| else: | |
| if not isinstance(cursor, dict) or token not in cursor: | |
| return None, False | |
| cursor = cursor[token] | |
| except Exception: | |
| return None, False | |
| return cursor, True | |
| def _safe_get(self, source: Any, tokens: list[str | int], default: Any) -> Any: | |
| value, exists = self._safe_read(source, tokens) | |
| return value if exists else default | |