# openenv2/server/env.py
from __future__ import annotations
import copy
import hashlib
from typing import Any
import yaml
from .data import TASK_REGISTRY, TaskSpec
from .models import ConfigAction, ConfigObservation, ConfigReward, EnvState, TaskType
class ConfigDebuggerEnv:
    """Episodic environment in which an agent repairs a broken YAML config.

    An episode starts with :meth:`reset`, which looks up a ``TaskSpec`` in
    ``TASK_REGISTRY`` and installs its broken config text. Each :meth:`step`
    applies one edit/add/delete at a dotted path, re-grades the config against
    the task's target (60% schema match, 40% task-specific logic checks), and
    returns a shaped reward. The episode terminates when the overall score
    reaches 0.98 or the per-task step budget is exhausted.
    """

    def __init__(self) -> None:
        # All episode state lives on the instance but only becomes meaningful
        # after reset(); step() guards against use before initialization.
        self.task_spec: TaskSpec | None = None
        self.task_id: TaskType | None = None
        self.current_config_text: str = ""
        self.previous_score: float = 0.0
        self.step_count: int = 0
        self.done: bool = False
        self.max_steps: int = 15  # default; overwritten per-task by reset()
        self.last_reward: ConfigReward | None = None
        # sha1(config_text) -> visit count; used to penalize revisiting states.
        self._state_visit_count: dict[str, int] = {}

    def reset(self, task_id: TaskType | str) -> ConfigObservation:
        """Start a new episode for ``task_id`` and return the first observation.

        Args:
            task_id: A ``TaskType`` member or its string value.

        Raises:
            ValueError: If ``task_id`` is not present in ``TASK_REGISTRY``.
        """
        normalized_task_id = task_id.value if isinstance(task_id, TaskType) else str(task_id)
        if normalized_task_id not in TASK_REGISTRY:
            valid = ", ".join(TASK_REGISTRY.keys())
            raise ValueError(f"Unknown task_id '{task_id}'. Valid task ids: {valid}")
        spec = TASK_REGISTRY[normalized_task_id]
        self.task_spec = spec
        self.task_id = TaskType(normalized_task_id)
        self.current_config_text = spec.broken
        self.step_count = 0
        self.done = False
        self.max_steps = spec.max_steps
        self._state_visit_count = {}
        # Grade the broken config once so the first step's delta is meaningful.
        initial_score = self._grade(self.current_config_text)["overall"]
        self.previous_score = initial_score
        self.last_reward = None
        # Count the initial state so an immediate no-op edit already registers
        # as a repeat visit.
        self._track_state_visit(self.current_config_text)
        return self._build_observation()

    def step(self, action: ConfigAction) -> tuple[ConfigObservation, ConfigReward, bool, dict[str, Any]]:
        """Apply ``action``, re-grade, and return ``(obs, reward, done, info)``.

        Malformed actions never raise out of this method: the config is left
        unchanged and the failure is surfaced as an ``invalid_action:*``
        penalty in the reward instead.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self.task_spec is None or self.task_id is None:
            raise RuntimeError("Environment is not initialized. Call reset() first.")
        if self.done:
            # Stepping a finished episode is a no-op with zero reward.
            obs = self._build_observation()
            reward = ConfigReward(
                value=0.0,
                previous_score=self.previous_score,
                current_score=self.previous_score,
                delta=0.0,
                penalties=["episode_already_done"],
            )
            self.last_reward = reward
            return obs, reward, True, {"reason": "episode_already_done"}
        self.step_count += 1
        penalties: list[str] = []
        try:
            new_config_text, action_penalties = self._apply_action(self.current_config_text, action)
            penalties.extend(action_penalties)
            self.current_config_text = new_config_text
        except Exception as exc:
            # Deliberately broad: any bad path/value becomes a penalty, not a crash.
            penalties.append(f"invalid_action:{exc}")
        grading = self._grade(self.current_config_text)
        current_score = grading["overall"]
        delta = round(current_score - self.previous_score, 4)
        loop_penalty = self._track_state_visit(self.current_config_text)
        if loop_penalty > 0:
            penalties.append(f"loop_penalty:{loop_penalty:.2f}")
        reward_value = self._compute_reward(current_score, delta, penalties, loop_penalty)
        reward = ConfigReward(
            value=reward_value,
            previous_score=round(self.previous_score, 4),
            current_score=round(current_score, 4),
            delta=delta,
            penalties=penalties,
        )
        self.previous_score = current_score
        # Terminate on a near-perfect config or when the step budget runs out.
        self.done = current_score >= 0.98 or self.step_count >= self.max_steps
        self.last_reward = reward
        info = {
            "task_id": self.task_id.value,
            "schema_score": grading["schema"],
            "logic_score": grading["logic"],
            "syntax_valid": grading["syntax_valid"],
        }
        return self._build_observation(grading), reward, self.done, info

    def state(self) -> EnvState:
        """Return a snapshot of the environment (observation is None pre-reset)."""
        observation = self._build_observation() if self.task_spec is not None else None
        return EnvState(
            task_id=self.task_id,
            done=self.done,
            step_count=self.step_count,
            max_steps=self.max_steps,
            observation=observation,
            last_reward=self.last_reward,
        )

    def _build_observation(self, grading: dict[str, Any] | None = None) -> ConfigObservation:
        """Build an observation, re-grading unless a fresh ``grading`` is supplied.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self.task_spec is None or self.task_id is None:
            raise RuntimeError("Environment is not initialized. Call reset() first.")
        if grading is None:
            grading = self._grade(self.current_config_text)
        return ConfigObservation(
            task_id=self.task_id,
            task_description=self.task_spec.description,
            current_config=self.current_config_text,
            syntax_valid=grading["syntax_valid"],
            validation_errors=grading["errors"],
            schema_score=grading["schema"],
            logic_score=grading["logic"],
            overall_score=grading["overall"],
            step_count=self.step_count,
            max_steps=self.max_steps,
        )

    def _compute_reward(self, current_score: float, delta: float, penalties: list[str], loop_penalty: float) -> float:
        """Shape the scalar reward from the score, its delta, and penalties.

        Improvement is bonused (capped at +0.15), regression is charged at 40%
        of the drop, invalid actions cost 0.10, destructive deletes 0.08, and a
        0.05 bonus is added for a near-perfect (>= 0.98) config. The result is
        clamped to [0, 1] and rounded to 4 decimals.
        """
        reward = current_score
        if delta > 0:
            reward += min(0.15, delta)
        elif delta < 0:
            reward += delta * 0.4
        penalty_total = loop_penalty
        if any(p.startswith("invalid_action") for p in penalties):
            penalty_total += 0.10
        if any(p.startswith("destructive_delete") for p in penalties):
            penalty_total += 0.08
        reward -= penalty_total
        if current_score >= 0.98:
            reward += 0.05
        return round(max(0.0, min(1.0, reward)), 4)

    def _track_state_visit(self, config_text: str) -> float:
        """Record a visit to this config state and return its loop penalty.

        First visit is free; each repeat adds 0.03, capped at 0.12, to
        discourage the agent from cycling between identical configs.
        """
        state_hash = hashlib.sha1(config_text.encode("utf-8")).hexdigest()
        count = self._state_visit_count.get(state_hash, 0) + 1
        self._state_visit_count[state_hash] = count
        # Penalize repeated states to discourage loops.
        if count <= 1:
            return 0.0
        return min(0.03 * (count - 1), 0.12)

    def _apply_action(self, config_text: str, action: ConfigAction) -> tuple[str, list[str]]:
        """Apply one edit/add/delete action and return ``(new_text, penalties)``.

        The original text is never mutated: the parsed document is deep-copied,
        modified, and re-dumped. Raises ``ValueError`` (caught by step()) for
        structurally invalid paths or a non-mapping root.
        """
        penalties: list[str] = []
        data = yaml.safe_load(config_text)
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError("current config is not a dictionary-like YAML document")
        root = copy.deepcopy(data)
        tokens = self._parse_path(action.path)
        if action.operation == "delete" and tokens and isinstance(tokens[0], str):
            # Deleting a whole top-level critical section is flagged, not blocked.
            if tokens[0] in {"services", "spec", "training", "hardware"} and len(tokens) == 1:
                penalties.append("destructive_delete:top_level_critical_key")
        if action.operation in {"edit", "add"}:
            self._set_path(root, tokens, action.value)
        else:
            deleted = self._delete_path(root, tokens)
            if not deleted:
                penalties.append("delete_noop")
        # sort_keys=False keeps the author's key order stable across steps.
        dumped = yaml.safe_dump(root, sort_keys=False)
        return dumped, penalties

    def _parse_path(self, path: str) -> list[str | int]:
        """Split a dotted path into tokens; all-digit chunks become list indices.

        Raises:
            ValueError: If any dot-separated chunk is empty.
        """
        tokens: list[str | int] = []
        for chunk in path.split("."):
            chunk = chunk.strip()
            if chunk == "":
                raise ValueError("path contains empty token")
            if chunk.isdigit():
                tokens.append(int(chunk))
            else:
                tokens.append(chunk)
        return tokens

    def _set_path(self, root: dict[str, Any], tokens: list[str | int], value: Any) -> None:
        """Set ``value`` at ``tokens``, creating intermediate dicts/lists in place.

        The container type created at each hop is chosen by the NEXT token:
        a str token implies a dict, an int token implies a list (lists are
        padded with empty containers / None as needed).

        Raises:
            ValueError: On an empty path or a token/container type mismatch.
        """
        if not tokens:
            raise ValueError("cannot set empty path")
        cursor: Any = root
        for i, token in enumerate(tokens[:-1]):
            nxt = tokens[i + 1]
            if isinstance(token, int):
                if not isinstance(cursor, list):
                    raise ValueError("list index used on non-list node")
                while token >= len(cursor):
                    cursor.append({} if isinstance(nxt, str) else [])
                if cursor[token] is None:
                    cursor[token] = {} if isinstance(nxt, str) else []
                cursor = cursor[token]
            else:
                if not isinstance(cursor, dict):
                    raise ValueError("dict key used on non-dict node")
                if token not in cursor or cursor[token] is None:
                    cursor[token] = {} if isinstance(nxt, str) else []
                cursor = cursor[token]
        final = tokens[-1]
        if isinstance(final, int):
            if not isinstance(cursor, list):
                raise ValueError("final list index used on non-list node")
            while final >= len(cursor):
                cursor.append(None)
            cursor[final] = value
        else:
            if not isinstance(cursor, dict):
                raise ValueError("final dict key used on non-dict node")
            cursor[final] = value

    def _delete_path(self, root: dict[str, Any], tokens: list[str | int]) -> bool:
        """Delete the node at ``tokens``; return False (no-op) if it is absent."""
        if not tokens:
            return False
        cursor: Any = root
        for token in tokens[:-1]:
            if isinstance(token, int):
                if not isinstance(cursor, list) or token >= len(cursor):
                    return False
                cursor = cursor[token]
            else:
                if not isinstance(cursor, dict) or token not in cursor:
                    return False
                cursor = cursor[token]
        final = tokens[-1]
        if isinstance(final, int):
            if not isinstance(cursor, list) or final >= len(cursor):
                return False
            cursor.pop(final)
            return True
        if not isinstance(cursor, dict) or final not in cursor:
            return False
        del cursor[final]
        return True

    def _grade(self, config_text: str) -> dict[str, Any]:
        """Grade ``config_text``: syntax, then 60% schema + 40% logic.

        Returns a dict with keys ``syntax_valid``, ``schema``, ``logic``,
        ``overall``, and ``errors`` (capped at 20 messages).
        """
        assert self.task_spec is not None
        errors: list[str] = []
        try:
            parsed = yaml.safe_load(config_text)
        except Exception as exc:
            # Unparseable YAML scores zero across the board.
            return {
                "syntax_valid": False,
                "schema": 0.0,
                "logic": 0.0,
                "overall": 0.0,
                "errors": [f"YAML syntax error: {exc}"],
            }
        if parsed is None:
            parsed = {}
        if not isinstance(parsed, dict):
            return {
                "syntax_valid": True,
                "schema": 0.0,
                "logic": 0.0,
                "overall": 0.0,
                "errors": ["Root document must be a mapping/dict"],
            }
        schema_score, schema_errors = self._grade_schema(parsed)
        logic_score, logic_errors = self._grade_logic(parsed)
        errors.extend(schema_errors)
        errors.extend(logic_errors)
        overall = round((0.60 * schema_score) + (0.40 * logic_score), 4)
        return {
            "syntax_valid": True,
            "schema": schema_score,
            "logic": logic_score,
            "overall": overall,
            "errors": errors[:20],
        }

    def _grade_schema(self, parsed: dict[str, Any]) -> tuple[float, list[str]]:
        """Score weighted exact-match of the task's required paths against target."""
        assert self.task_spec is not None
        total_weight = 0.0
        matched_weight = 0.0
        errors: list[str] = []
        for path, weight in self.task_spec.required_paths.items():
            total_weight += weight
            expected = self._read_path(self.task_spec.target, self._parse_path(path))
            got, exists = self._safe_read(parsed, self._parse_path(path))
            if not exists:
                errors.append(f"Missing required path: {path}")
                continue
            if got == expected:
                matched_weight += weight
            else:
                errors.append(f"Mismatch at {path}: expected={expected!r}, got={got!r}")
        score = 0.0 if total_weight == 0 else round(matched_weight / total_weight, 4)
        return score, errors

    def _grade_logic(self, parsed: dict[str, Any]) -> tuple[float, list[str]]:
        """Run per-task semantic checks; score is the fraction of checks passed."""
        assert self.task_spec is not None
        checks: list[tuple[str, bool]] = []
        t = self.task_spec.task_id
        if t == "easy_docker":
            web_ports = self._safe_get(parsed, ["services", "web", "ports"], default=[])
            db_ports = self._safe_get(parsed, ["services", "db", "ports"], default=[])
            env_node = self._safe_get(parsed, ["services", "web", "environment"], default={})
            checks.append(("web ports must be list", isinstance(web_ports, list)))
            checks.append(("all web ports must contain ':'", all(isinstance(p, str) and ":" in p for p in web_ports)))
            checks.append(("db port must include host and container", "5432:5432" in db_ports if isinstance(db_ports, list) else False))
            checks.append(("environment must be dict", isinstance(env_node, dict)))
        elif t == "medium_k8s":
            replicas = self._safe_get(parsed, ["spec", "replicas"], default=None)
            limits_mem = self._safe_get(
                parsed,
                ["spec", "template", "spec", "containers", 0, "resources", "limits", "memory"],
                default="",
            )
            req_mem = self._safe_get(
                parsed,
                ["spec", "template", "spec", "containers", 0, "resources", "requests", "memory"],
                default="",
            )
            req_cpu = self._safe_get(
                parsed,
                ["spec", "template", "spec", "containers", 0, "resources", "requests", "cpu"],
                default="",
            )
            checks.append(("replicas should be int", isinstance(replicas, int)))
            checks.append(("limits memory must include unit", isinstance(limits_mem, str) and limits_mem.endswith(("Mi", "Gi"))))
            checks.append(("requests memory must include unit", isinstance(req_mem, str) and req_mem.endswith(("Mi", "Gi"))))
            checks.append(("cpu request should be millicore string", isinstance(req_cpu, str) and req_cpu.endswith("m")))
        elif t == "hard_ml_config":
            warmup = self._safe_get(parsed, ["training", "warmup_steps"], default=0)
            max_steps = self._safe_get(parsed, ["training", "max_steps"], default=0)
            use_cuda = self._safe_get(parsed, ["hardware", "use_cuda"], default=False)
            gpu_count = self._safe_get(parsed, ["hardware", "gpu_count"], default=0)
            batch_size = self._safe_get(parsed, ["training", "batch_size"], default=0)
            train_batch = self._safe_get(parsed, ["data", "train_batch_size"], default=0)
            log_interval = self._safe_get(parsed, ["logging", "log_interval"], default=999999)
            checks.append(("warmup_steps < max_steps", isinstance(warmup, int) and isinstance(max_steps, int) and warmup < max_steps))
            checks.append(("gpu_count >=1 when use_cuda", (not use_cuda) or (isinstance(gpu_count, int) and gpu_count >= 1)))
            checks.append(("train_batch_size equals 2 * batch_size", isinstance(batch_size, int) and isinstance(train_batch, int) and train_batch == 2 * batch_size))
            checks.append(("log_interval <= 100", isinstance(log_interval, int) and log_interval <= 100))
        total = len(checks)
        passed = sum(1 for _, ok in checks if ok)
        errors = [msg for msg, ok in checks if not ok]
        score = 0.0 if total == 0 else round(passed / total, 4)
        return score, errors

    def _read_path(self, source: Any, tokens: list[str | int]) -> Any:
        """Follow ``tokens`` through nested containers; raises on any miss.

        Used only against the trusted task target, where every path must exist.
        """
        cursor = source
        for token in tokens:
            # int tokens index lists, str tokens key dicts; [] handles both.
            cursor = cursor[token]
        return cursor

    def _safe_read(self, source: Any, tokens: list[str | int]) -> tuple[Any, bool]:
        """Follow ``tokens`` defensively; return ``(value, True)`` or ``(None, False)``."""
        cursor = source
        for token in tokens:
            try:
                if isinstance(token, int):
                    if not isinstance(cursor, list):
                        return None, False
                    cursor = cursor[token]
                else:
                    if not isinstance(cursor, dict) or token not in cursor:
                        return None, False
                    cursor = cursor[token]
            except Exception:
                return None, False
        return cursor, True

    def _safe_get(self, source: Any, tokens: list[str | int], default: Any) -> Any:
        """Like :meth:`_safe_read`, but return ``default`` when the path is absent."""
        value, exists = self._safe_read(source, tokens)
        return value if exists else default