# cicd-debugger-env-final / env/environment.py
# Last commit by Lishika: 32445fd ("final fix")
from __future__ import annotations
from dataclasses import dataclass, field
from difflib import SequenceMatcher
import random
import re
from typing import Any
import yaml
from env.models import Action, EnvStateSnapshot, Observation, Reward
from env.rewards import RewardCalculator
from env.tasks import get_task_by_id, get_tasks_by_difficulty
from env.tasks.task_types import CICDTask
# Allow-list of tool names accepted by CICDDebuggerEnvironment.step, in the order
# they are advertised to agents as ``available_tools``; every name also gets a
# per-episode progress flag that starts out False on reset.
REQUIRED_TOOLS = (
    "read_file read_logs analyze_error edit_config "
    "run_pipeline_stage run_tests validate_fix submit_solution"
).split()

# Default per-episode step budget; reaching it ends the episode.
MAX_STEPS = 30
@dataclass
class EnvironmentState:
    """Mutable bookkeeping for one episode of ``CICDDebuggerEnvironment``."""

    # The task being solved, the workflow text after the latest step, and the
    # text as it was when that step began.
    task: CICDTask
    current_config: str
    previous_config: str

    # Episode progress: steps taken, terminal flag, and which tools were used.
    step_count: int = field(default=0)
    done: bool = field(default=False)
    progress_flags: dict[str, bool] = field(default_factory=dict)

    # Edit accounting: successful edits and the cumulative changed-line count.
    file_modification_count: int = field(default=0)
    total_changed_lines: int = field(default=0)

    # Most recent hidden-test pass rate reported by validation/submission.
    hidden_test_pass_rate: float = field(default=0.0)

    # Ordered tool names invoked so far, and pass/fail per simulated stage name.
    action_history: list[str] = field(default_factory=list)
    stage_results: dict[str, bool] = field(default_factory=dict)

    # Counters consulted by the brute-force / anti-hacking heuristics.
    failed_validations: int = field(default=0)
    consecutive_edit_actions: int = field(default=0)

    # Text surfaced to the agent: latest tool output, latest error, the error
    # raised by the most recent action (None when it succeeded), and its info dict.
    current_logs: str = field(default="")
    last_error: str = field(default="")
    last_action_error: str | None = field(default=None)
    last_info: dict[str, Any] = field(default_factory=dict)
class CICDDebuggerEnvironment:
    """RL-style CI/CD debugging environment with strict tool-based actions.

    An episode starts with :meth:`reset`, which selects a task (by id or by
    difficulty) and loads its broken workflow config. Each :meth:`step` runs
    exactly one tool from ``REQUIRED_TOOLS`` against a simulated pipeline and
    returns a gym-style ``(observation, reward, done, info)`` tuple. The episode
    ends when a submission is accepted, a ``validate_fix`` call passes, or the
    step budget is exhausted.
    """

    def __init__(
        self,
        max_steps: int = MAX_STEPS,
        seed: int | None = None,
        llm_judge: Any | None = None,
    ) -> None:
        """Create an environment.

        Args:
            max_steps: Step budget per episode; values below 1 are clamped to 1.
            seed: Seed for the private RNG used to sample tasks by difficulty.
            llm_judge: Optional judge handed to ``RewardCalculator``; the
                calculator's ``llm_judge`` (if any) is also consulted when a fix
                is validated.
        """
        self.max_steps = max(1, int(max_steps))
        self.random = random.Random(seed)
        self.reward_calculator = RewardCalculator(llm_judge=llm_judge)
        self._state: EnvironmentState | None = None

    async def reset(self, task_id: str | None = None, difficulty: str | None = None) -> dict[str, Any]:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Exact task to load; takes precedence over ``difficulty``.
            difficulty: Pool to sample from when ``task_id`` is not given.

        Raises:
            ValueError: If ``task_id`` is unknown or no task matches ``difficulty``.
        """
        task = self._select_task(task_id=task_id, difficulty=difficulty)
        self._state = EnvironmentState(
            task=task,
            current_config=task.broken_config,
            previous_config=task.broken_config,
            progress_flags={tool: False for tool in REQUIRED_TOOLS},
            current_logs=task.logs,
            last_error=task.error_message,
        )
        # _build_observation already validates through the Observation model.
        return self._build_observation()

    async def step(self, action: Any) -> tuple[dict[str, Any], float, bool, dict[str, Any]]:
        """Execute one tool action and return ``(observation, reward, done, info)``.

        ``action`` is anything ``Action.from_input`` accepts. An unknown tool is
        not fatal: it consumes a step and is reported through ``info["error"]``.
        Calling ``step`` after the episode has ended returns a zero reward and
        leaves the state untouched.

        Raises:
            RuntimeError: If :meth:`reset` has not been called yet.
        """
        state = self._state
        if state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if state.done:
            reward_model = Reward(value=0.0, components={"total": 0.0})
            return self._build_observation(), float(reward_model.value), True, {
                "tool": "none",
                "message": "episode already completed",
                "error": None,
                "reward_model": reward_model.model_dump(),
            }

        tool, payload = self._parse_action(action)
        state.step_count += 1
        state.previous_config = state.current_config
        state.action_history.append(tool)
        state.last_action_error = None

        info: dict[str, Any] = {"tool": tool, "message": "", "error": None}
        result = self._initial_result(state)

        handler = self._tool_handlers().get(tool) if tool in REQUIRED_TOOLS else None
        if handler is None:
            info["message"] = "unsupported action tool"
            info["error"] = f"tool '{tool}' is not allowed"
            state.last_action_error = str(info["error"])
            changed_lines = 0
        else:
            changed_lines = handler(state, payload, result, info)

        # Refresh the parts of the result that depend on what the tool did.
        result["hacking_attempt"] = self._detect_hacking_attempt(tool, payload, state.current_config)
        result["current_config"] = state.current_config
        result["fixed_config"] = state.current_config
        result["changed_files_count"] = 1 if changed_lines > 0 else 0
        result["changed_lines_count"] = changed_lines
        result["edit_count"] = {
            "changed_files_count": state.file_modification_count,
            "changed_lines_count": state.total_changed_lines,
        }
        if info["error"]:
            state.last_error = str(info["error"])
        result["error"] = state.last_error

        if state.step_count >= self.max_steps and not state.done:
            state.done = True
            if not info["error"]:
                info["error"] = "max_steps_reached"
            info["message"] = "max steps reached"

        reward = self.reward_calculator.calculate_step_reward(
            state={
                "step_count": state.step_count,
                "previous_config": state.previous_config,
                "expected_config": state.task.expected_config,
                "original_config": state.task.broken_config,
                "error": state.last_error,
                "changed_files_count": state.file_modification_count,
                "changed_lines_count": state.total_changed_lines,
                "consecutive_edit_actions": state.consecutive_edit_actions,
                "failed_validations": state.failed_validations,
            },
            action=tool,
            result=result,
            original_config=state.task.broken_config,
            fixed_config=state.current_config,
            error_message=state.last_error,
            expected_config=state.task.expected_config,
            metadata=state.task.metadata,
        )

        # Grading override: validation and submission steps are scored as a strict
        # binary 1.0 / 0.0 on ``is_valid``, replacing the calculator's shaped
        # reward, and a passing ``validate_fix`` ends the episode just like an
        # accepted submission. NOTE(review): the original comment attributed this
        # to an external ("Scaler") grader; confirm with the owner before changing.
        if tool in ("validate_fix", "submit_solution"):
            if result.get("is_valid"):
                reward = 1.0
                state.done = True
            else:
                reward = 0.0

        reward_model = Reward(value=float(reward), components={"total": float(reward)})
        info["reward_model"] = reward_model.model_dump()
        state.last_info = info
        return self._build_observation(), float(reward_model.value), bool(state.done), info

    def _initial_result(self, state: EnvironmentState) -> dict[str, Any]:
        """Return the per-step result record with every outcome flag cleared."""
        return {
            "previous_config": state.previous_config,
            "current_config": state.current_config,
            "fixed_config": state.current_config,
            "expected_config": state.task.expected_config,
            "error": state.last_error,
            "logs_analyzed": False,
            "error_diagnosed": False,
            "fix_proposed": False,
            "pipeline_run": False,
            "tests_passed": False,
            "command_succeeded": False,
            "changed_files_count": 0,
            "changed_lines_count": 0,
            "edit_count": {
                "changed_files_count": state.file_modification_count,
                "changed_lines_count": state.total_changed_lines,
            },
            "deterministic_score": None,
            "hidden_test_pass_rate": None,
            "judge_scores": None,
            "hacking_attempt": False,
        }

    def _tool_handlers(self) -> dict[str, Any]:
        """Map each supported tool name to its handler.

        Every handler takes ``(state, payload, result, info)``, mutates those in
        place, and returns how many config lines it changed (0 unless it edited).
        """
        return {
            "read_file": self._handle_read_file,
            "read_logs": self._handle_read_logs,
            "analyze_error": self._handle_analyze_error,
            "edit_config": self._handle_edit_config,
            "run_pipeline_stage": self._handle_run_pipeline_stage,
            "run_tests": self._handle_run_tests,
            "validate_fix": self._handle_validate_fix,
            "submit_solution": self._handle_submit_solution,
        }

    def _handle_read_file(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Expose the current workflow config through the logs channel."""
        state.progress_flags["read_file"] = True
        result["command_succeeded"] = True
        info["message"] = "returned current workflow config"
        state.current_logs = state.current_config
        state.consecutive_edit_actions = 0
        return 0

    def _handle_read_logs(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Expose the task's original pipeline failure logs."""
        state.progress_flags["read_logs"] = True
        result["logs_analyzed"] = True
        result["command_succeeded"] = True
        info["message"] = "returned pipeline failure logs"
        state.current_logs = state.task.logs
        state.consecutive_edit_actions = 0
        return 0

    def _handle_analyze_error(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Report a heuristic root-cause diagnosis of the current config."""
        state.progress_flags["analyze_error"] = True
        result["error_diagnosed"] = True
        result["command_succeeded"] = True
        root_cause = self._detect_root_cause(state.current_config, state.task)
        info["message"] = f"root cause: {root_cause}"
        state.current_logs = f"analysis result: {root_cause}"
        state.consecutive_edit_actions = 0
        return 0

    def _handle_edit_config(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Apply the payload's edit; a no-op edit is reported as an error."""
        state.progress_flags["edit_config"] = True
        updated_config, summary = self._apply_edit(state.current_config, payload, state.task)
        changed_lines = self._count_changed_lines(state.current_config, updated_config)
        if changed_lines > 0:
            state.current_config = updated_config
            state.file_modification_count += 1
            state.total_changed_lines += changed_lines
            result["fix_proposed"] = True
            result["command_succeeded"] = True
            info["message"] = summary
            state.current_logs = f"edit applied: {summary}"
        else:
            result["command_succeeded"] = False
            info["message"] = "no config changes applied"
            info["error"] = "edit_config did not modify workflow"
            state.last_action_error = str(info["error"])
            state.current_logs = "edit action produced no changes"
        # Counted whether or not the edit changed anything (brute-force heuristic).
        state.consecutive_edit_actions += 1
        return changed_lines

    def _handle_run_pipeline_stage(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Simulate the requested stage (default: the task's failure stage)."""
        state.progress_flags["run_pipeline_stage"] = True
        stage = self._extract_stage(payload, fallback=state.task.failure_stage)
        success, stage_logs = self._simulate_stage(state.current_config, stage, state.task)
        state.stage_results[stage] = success
        result["pipeline_run"] = True
        result["command_succeeded"] = success
        info["message"] = f"stage '{stage}' {'passed' if success else 'failed'}"
        if not success:
            info["error"] = stage_logs
            state.last_action_error = stage_logs
            state.last_error = stage_logs
        state.current_logs = stage_logs
        state.consecutive_edit_actions = 0
        return 0

    def _handle_run_tests(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Run the simulated build/test checks on the current config."""
        state.progress_flags["run_tests"] = True
        tests_passed, test_logs = self._run_tests(state.current_config, state.task)
        result["pipeline_run"] = True
        result["tests_passed"] = tests_passed
        result["command_succeeded"] = tests_passed
        info["message"] = "tests passed" if tests_passed else "tests failed"
        if not tests_passed:
            info["error"] = test_logs
            state.last_action_error = test_logs
            state.last_error = test_logs
        state.current_logs = test_logs
        state.consecutive_edit_actions = 0
        return 0

    def _handle_validate_fix(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Run full validation without submitting; failures count against the agent."""
        state.progress_flags["validate_fix"] = True
        validation = self._validate_current_fix(state)
        result.update(validation)
        result["pipeline_run"] = True
        is_valid = bool(validation.get("is_valid"))
        result["command_succeeded"] = is_valid
        if not is_valid:
            state.failed_validations += 1
            info["error"] = str(validation.get("summary", "validation failed"))
            state.last_action_error = str(info["error"])
        info["message"] = "validation passed" if is_valid else "validation failed"
        state.hidden_test_pass_rate = float(validation.get("hidden_test_pass_rate") or 0.0)
        state.current_logs = str(validation.get("summary", "validation complete"))
        state.consecutive_edit_actions = 0
        return 0

    def _handle_submit_solution(
        self, state: EnvironmentState, payload: dict[str, Any], result: dict[str, Any], info: dict[str, Any]
    ) -> int:
        """Validate and submit; an accepted solution ends the episode."""
        validation = self._validate_current_fix(state)
        result.update(validation)
        result["pipeline_run"] = True
        state.progress_flags["submit_solution"] = True
        accepted = bool(validation.get("is_valid"))
        result["command_succeeded"] = accepted
        if accepted:
            state.done = True
            info["message"] = "solution accepted"
            state.current_logs = "submission accepted"
        else:
            state.failed_validations += 1
            info["message"] = "solution rejected"
            info["error"] = "submission failed quality checks"
            state.last_action_error = str(info["error"])
            state.current_logs = str(validation.get("summary", "submission rejected"))
        state.hidden_test_pass_rate = float(validation.get("hidden_test_pass_rate") or 0.0)
        state.consecutive_edit_actions = 0
        return 0

    async def close(self) -> None:
        """Release resources; currently nothing to release."""
        return None

    def get_state(self) -> dict[str, Any]:
        """Return a validated snapshot of the full episode state.

        NOTE(review): unlike the observation, this snapshot includes the expected
        config (``correct_solution``); presumably meant for evaluation/debugging
        rather than for agents — confirm before exposing it to policies.
        """
        state = self._state
        if state is None:
            return EnvStateSnapshot(initialized=False).model_dump()
        snapshot = {
            "initialized": True,
            "task_id": state.task.task_id,
            "difficulty": state.task.difficulty,
            "actual_bug": state.task.actual_bug,
            "correct_solution": state.task.expected_config,
            "failure_stage": state.task.failure_stage,
            "step_count": state.step_count,
            "done": state.done,
            "progress_flags": dict(state.progress_flags),
            "file_modification_count": state.file_modification_count,
            "total_changed_lines": state.total_changed_lines,
            "hidden_test_pass_rate": state.hidden_test_pass_rate,
            "stage_results": dict(state.stage_results),
            "failed_validations": state.failed_validations,
            "last_action_error": state.last_action_error,
            "last_error": state.last_error,
        }
        return EnvStateSnapshot.model_validate(snapshot).model_dump()

    def state(self) -> dict[str, Any]:
        """Alias of :meth:`get_state`."""
        return self.get_state()

    def _build_observation(self) -> dict[str, Any]:
        """Return the agent-facing observation, validated through ``Observation``.

        Raises:
            RuntimeError: If no episode is active.
        """
        state = self._state
        if state is None:
            raise RuntimeError("Environment not initialized")
        observation = {
            "task_id": state.task.task_id,
            "difficulty": state.task.difficulty,
            "failure_stage": state.task.failure_stage,
            "actual_bug": state.task.actual_bug,
            "config": state.current_config,
            "logs": state.current_logs,
            "error_message": state.last_error,
            "available_tools": list(REQUIRED_TOOLS),
            "progress_flags": dict(state.progress_flags),
            "file_modification_count": state.file_modification_count,
            "hidden_test_pass_rate": state.hidden_test_pass_rate,
            "step_count": state.step_count,
            "last_action_error": state.last_action_error,
        }
        return Observation.model_validate(observation).model_dump()

    def _select_task(self, task_id: str | None, difficulty: str | None) -> CICDTask:
        """Resolve the task for a new episode.

        An explicit ``task_id`` wins; otherwise a task is drawn with the seeded
        RNG from ``get_tasks_by_difficulty(difficulty)``.

        Raises:
            ValueError: If ``task_id`` is unknown or the difficulty pool is empty.
        """
        if task_id:
            task = get_task_by_id(task_id)
            if task is None:
                raise ValueError(f"Unknown task_id: {task_id}")
            return task
        filtered = get_tasks_by_difficulty(difficulty)
        if not filtered:
            raise ValueError(f"No tasks available for difficulty: {difficulty}")
        return self.random.choice(filtered)

    def _parse_action(self, action: Any) -> tuple[str, dict[str, Any]]:
        """Normalize a raw action into ``(tool_name, payload_copy)``."""
        parsed = Action.from_input(action)
        return parsed.tool, dict(parsed.payload)

    def _extract_stage(self, payload: dict[str, Any], fallback: str) -> str:
        """Pick the stage to run from ``payload["stage"]``, then from a stage name
        mentioned in ``payload["raw"]``, else ``fallback``."""
        direct_stage = str(payload.get("stage") or "").strip().lower()
        if direct_stage in {"build", "test", "deploy"}:
            return direct_stage
        raw = str(payload.get("raw") or "").lower()
        for stage in ("build", "test", "deploy"):
            if stage in raw:
                return stage
        return fallback

    @staticmethod
    def _metadata_token(task: CICDTask, key: str) -> str:
        """Return ``task.metadata[key]`` as text; a missing or ``None`` value gives ``""``.

        Avoids ``str(None)`` turning an absent token into the literal ``"None"``.
        """
        value = task.metadata.get(key)
        return "" if value is None else str(value)

    @staticmethod
    def _indent_width(line: str) -> int:
        """Return the number of leading spaces on ``line``."""
        return len(line) - len(line.lstrip(" "))

    def _detect_root_cause(self, config_text: str, task: CICDTask) -> str:
        """Return a short diagnosis of why ``config_text`` still fails for ``task``.

        Checks, in order: the known broken token is still present (reports the
        task's ``actual_bug``), the YAML is not a valid mapping, the expected fix
        token is missing; otherwise a generic deviation message.
        """
        normalized = self._normalize(config_text)
        broken_token = self._normalize(self._metadata_token(task, "broken_token"))
        if broken_token and broken_token in normalized:
            return task.actual_bug
        if not self._is_yaml_valid(config_text):
            return "workflow YAML is invalid"
        raw_fixed_token = self._metadata_token(task, "fixed_token")
        fixed_token = self._normalize(raw_fixed_token)
        if fixed_token and fixed_token not in normalized:
            return f"missing expected fix token: {raw_fixed_token}"
        return "configuration still deviates from expected pipeline behavior"

    def _apply_edit(self, current_config: str, payload: dict[str, Any], task: CICDTask) -> tuple[str, str]:
        """Interpret an ``edit_config`` payload and return ``(new_config, summary)``.

        Strategies, in priority order:

        1. A non-blank ``payload["new_config"]`` replaces the whole config (stripped).
        2. Directives in ``payload["raw"]``: ``replace X with Y``, a mention of
           checkout, a mention of permissions / ``actions: write``.
        3. If nothing applied yet: YAML repair when the text mentions
           yaml/indent/syntax.
        4. If nothing applied yet: patch the config from the task's
           ``broken_token`` / ``fixed_token`` metadata.
        5. If nothing applied yet and no ``broken_token`` is defined: append a
           missing ``fixed_token``.
        6. If nothing applied yet and the text asks for it: use the expected config.

        The summary is ``"no-op edit"`` when no strategy changed anything.
        """
        new_config = payload.get("new_config")
        if isinstance(new_config, str) and new_config.strip():
            return new_config.strip(), "applied payload new_config"

        candidate = current_config
        edits: list[str] = []
        raw = str(payload.get("raw") or "")
        raw_lower = raw.lower()

        replace_match = re.search(
            r"replace\s+['\"]?(.+?)['\"]?\s+with\s+['\"]?(.+?)['\"]?\s*$",
            raw,
            flags=re.IGNORECASE,
        )
        if replace_match:
            old = replace_match.group(1).strip()
            new = replace_match.group(2).strip()
            if old and old in candidate:
                candidate = candidate.replace(old, new)
                edits.append(f"replaced '{old}' with '{new}'")

        if "checkout" in raw_lower and "actions/checkout@v4" not in candidate:
            updated = self._ensure_checkout(candidate)
            if updated != candidate:
                candidate = updated
                edits.append("inserted actions/checkout@v4 step")

        if "permissions" in raw_lower or "actions: write" in raw_lower:
            updated = self._ensure_actions_write(candidate)
            if updated != candidate:
                candidate = updated
                edits.append("added actions: write permission")

        if not edits and any(token in raw_lower for token in ("yaml", "indent", "syntax")):
            updated = self._repair_yaml(candidate, task.expected_config)
            if updated != candidate:
                candidate = updated
                edits.append("repaired YAML structure")

        broken_token = self._metadata_token(task, "broken_token")
        fixed_token = self._metadata_token(task, "fixed_token")
        # NOTE: this metadata-driven fallback fires whenever no directive above
        # applied, regardless of what the payload text says.
        if not edits and broken_token and fixed_token and broken_token in candidate:
            if candidate.count(broken_token) > 1:
                # Ambiguous target: fall back to the canonical expected config.
                candidate = task.expected_config
                edits.append("applied canonical fix for ambiguous token")
            elif fixed_token.strip().endswith(":"):
                # The fix is a YAML key: insert its whole block from the expected config.
                expected_block = self._extract_expected_block(task.expected_config, fixed_token)
                if expected_block and expected_block not in candidate:
                    candidate = candidate.replace(broken_token, f"{broken_token}\n{expected_block}", 1)
                    edits.append("inserted expected YAML block")
                else:
                    candidate = candidate.replace(broken_token, fixed_token, 1)
                    edits.append("applied metadata token replacement")
            else:
                expected_line = self._find_line_containing(task.expected_config, fixed_token)
                replacement = expected_line.strip() if expected_line else fixed_token
                candidate = candidate.replace(broken_token, replacement, 1)
                edits.append("applied metadata token replacement")

        if not edits and fixed_token and fixed_token not in candidate and not broken_token:
            updated = self._append_missing_token(candidate, fixed_token)
            if updated != candidate:
                candidate = updated
                edits.append("appended expected token")

        if not edits and any(token in raw_lower for token in ("expected config", "apply expected", "canonical fix")):
            candidate = task.expected_config
            edits.append("replaced with expected task config")

        summary = "; ".join(edits) if edits else "no-op edit"
        return candidate, summary

    def _ensure_checkout(self, config_text: str) -> str:
        """Insert ``- uses: actions/checkout@v4`` as the first item of the first ``steps:`` list.

        The new item copies the indentation of the existing first step (or sits
        two spaces under ``steps:`` when the list is empty), so the result stays
        valid YAML. Returns the text unchanged if checkout is already present or
        there is no bare ``steps:`` key.
        """
        if "actions/checkout@v4" in config_text:
            return config_text
        lines = config_text.splitlines(keepends=True)
        for idx, line in enumerate(lines):
            if line.strip() != "steps:":
                continue
            steps_indent = self._indent_width(line)
            item_indent = steps_indent + 2
            for following in lines[idx + 1 :]:
                if not following.strip():
                    continue
                # Match the indentation of an existing sequence item, if any.
                if following.lstrip().startswith("-") and self._indent_width(following) >= steps_indent:
                    item_indent = self._indent_width(following)
                break
            if not line.endswith(("\n", "\r")):
                lines[idx] = line + "\n"
            lines.insert(idx + 1, " " * item_indent + "- uses: actions/checkout@v4\n")
            return "".join(lines)
        return config_text

    def _ensure_actions_write(self, config_text: str) -> str:
        """Grant ``actions: write`` in the workflow permissions.

        Placement, in order of preference:

        1. Right after the first ``contents:`` line, at the same indentation.
        2. As the first child of the first block-style ``permissions:`` key
           (avoids creating a duplicate ``permissions`` key, which PyYAML would
           resolve by silently keeping only the last one).
        3. Otherwise (including when there is no ``permissions:`` at all) a new
           top-level ``permissions:`` block is prepended.

        Returns the text unchanged if ``actions: write`` is already present.
        """
        if "actions: write" in config_text:
            return config_text
        if "permissions:" in config_text:
            lines = config_text.splitlines()
            insert_at: int | None = None
            indent = 0
            for idx, line in enumerate(lines):
                if line.strip().startswith("contents:"):
                    insert_at, indent = idx + 1, self._indent_width(line)
                    break
            if insert_at is None:
                for idx, line in enumerate(lines):
                    if line.strip() != "permissions:":
                        continue
                    base_indent = self._indent_width(line)
                    indent = base_indent + 2
                    for following in lines[idx + 1 :]:
                        if following.strip():
                            # Reuse the indentation of an existing child entry.
                            if self._indent_width(following) > base_indent:
                                indent = self._indent_width(following)
                            break
                    insert_at = idx + 1
                    break
            if insert_at is not None:
                lines.insert(insert_at, " " * indent + "actions: write")
                updated = "\n".join(lines)
                return updated + "\n" if config_text.endswith("\n") else updated
        return "permissions:\n actions: write\n" + config_text

    def _append_missing_token(self, config_text: str, token: str) -> str:
        """Add ``token`` to the config, using a structural insert for known tokens."""
        if not token or token in config_text:
            return config_text
        lower_token = token.lower()
        if "actions/checkout@v4" in lower_token:
            return self._ensure_checkout(config_text)
        if "actions: write" in lower_token:
            return self._ensure_actions_write(config_text)
        return config_text + "\n" + token

    def _repair_yaml(self, current_config: str, expected_config: str) -> str:
        """Return the expected config when the current one is invalid YAML and the
        expected one parses; otherwise return the current config unchanged."""
        if self._is_yaml_valid(current_config):
            return current_config
        if expected_config and self._is_yaml_valid(expected_config):
            return expected_config
        return current_config

    def _find_line_containing(self, config_text: str, token: str) -> str | None:
        """Return the first line of ``config_text`` containing ``token`` (stripped), or None."""
        target = (token or "").strip()
        if not target:
            return None
        for line in (config_text or "").splitlines():
            if target in line:
                return line
        return None

    def _extract_expected_block(self, config_text: str, token: str) -> str:
        """Return the first line containing ``token`` plus its more-indented children.

        The block stops at the first blank line or the first line indented no
        deeper than the matched line. Returns ``""`` when ``token`` is not found.
        """
        lines = (config_text or "").splitlines()
        target = (token or "").strip()
        if not target:
            return ""
        for idx, line in enumerate(lines):
            if target not in line:
                continue
            base_indent = self._indent_width(line)
            block = [line]
            for next_line in lines[idx + 1 :]:
                if not next_line.strip():
                    break
                if self._indent_width(next_line) <= base_indent:
                    break
                block.append(next_line)
            return "\n".join(block)
        return ""

    def _simulate_stage(self, config_text: str, stage: str, task: CICDTask) -> tuple[bool, str]:
        """Simulate one pipeline stage; return ``(passed, logs)``.

        The stage fails when the YAML is invalid, when a stage defined by the
        expected config is missing, when a shortcut/hacking pattern appears, when
        this is the task's failure stage and the bug is still visible (the task's
        original logs are replayed), or when no command of the right kind exists.
        NOTE: commands are collected from the whole workflow, not only from this
        stage's job.
        """
        if not self._is_yaml_valid(config_text):
            return False, "invalid workflow YAML"
        expected_has_stage = self._stage_exists(task.expected_config, stage)
        current_has_stage = self._stage_exists(config_text, stage)
        if expected_has_stage and not current_has_stage:
            return False, f"required stage '{stage}' is missing"
        if not expected_has_stage and not current_has_stage:
            return True, f"{stage} stage not required for this task"

        normalized = self._normalize(config_text)
        broken_token = self._normalize(self._metadata_token(task, "broken_token"))
        fixed_token = self._normalize(self._metadata_token(task, "fixed_token"))
        if self._contains_hacking_pattern(config_text):
            return False, "unsafe shortcut pattern detected"
        if stage == task.failure_stage and broken_token and broken_token in normalized:
            return False, task.logs
        if stage == task.failure_stage and fixed_token and fixed_token not in normalized:
            return False, task.logs

        stage_requirements = {
            "build": (
                ("npm ci", "npm install", "pip install", "go build", "mvn", "yarn install", "pnpm install"),
                "build stage has no install/build command",
            ),
            "test": (
                ("npm test", "pytest", "go test", "mvn test", "yarn test", "pnpm test"),
                "test stage has no test command",
            ),
            "deploy": (
                ("deploy", "publish", "upload-artifact", "release"),
                "deploy stage has no deployment command",
            ),
        }
        requirement = stage_requirements.get(stage)
        if requirement is not None:
            tokens, failure_message = requirement
            commands = self._extract_commands(config_text)
            if not any(token in cmd for cmd in commands for token in tokens):
                return False, failure_message
        return True, f"{stage} stage passed"

    def _run_tests(self, config_text: str, task: CICDTask) -> tuple[bool, str]:
        """Simulate the build and test stages the expected config defines, then
        require SequenceMatcher similarity >= 0.45 to the expected config."""
        if self._stage_exists(task.expected_config, "build"):
            build_ok, build_logs = self._simulate_stage(config_text, "build", task)
            if not build_ok:
                return False, build_logs
        if self._stage_exists(task.expected_config, "test"):
            test_ok, test_logs = self._simulate_stage(config_text, "test", task)
            if not test_ok:
                return False, test_logs
        similarity = SequenceMatcher(None, self._normalize(config_text), self._normalize(task.expected_config)).ratio()
        if similarity < 0.45:
            return False, "tests failed: fix diverges significantly from expected pipeline"
        return True, "tests passed"

    def _validate_current_fix(self, state: EnvironmentState) -> dict[str, Any]:
        """Grade ``state.current_config`` against the task and decide acceptance.

        The fix is valid only if every gate passes: simulated tests, the task's
        failure stage, the broken/fixed token constraints, deterministic score
        >= 0.72, hidden-test pass rate >= 0.65, and a mean LLM-judge score >= 0.5
        (treated as passing when no judge result is available).

        Returns a dict with ``deterministic_score``, ``hidden_test_pass_rate``,
        ``judge_scores``, ``tests_passed``, ``is_valid`` and a ``summary`` (the
        logs of the first failing test/stage/token gate, else a score line).
        """
        current = state.current_config
        task = state.task
        deterministic_score = self.reward_calculator.deterministic_grader.grade(
            current,
            task.expected_config,
            metadata=task.metadata,
        )
        hidden_test_pass_rate = self.reward_calculator.hidden_test_runner.evaluate_fix(
            fixed_config=current,
            expected_config=task.expected_config,
            metadata=task.metadata,
        )
        judge_scores = None
        if self.reward_calculator.llm_judge is not None:
            try:
                judge_scores = self.reward_calculator.llm_judge.evaluate_fix(
                    task.broken_config,
                    current,
                    state.last_error,
                )
            except Exception:
                # Best-effort: a failing judge is treated as absent, which leaves
                # judge_average at 1.0 below (i.e. the judge gate passes).
                judge_scores = None

        tests_passed, test_logs = self._run_tests(current, task)
        stage_ok, stage_logs = self._simulate_stage(current, task.failure_stage, task)

        broken_token = self._normalize(self._metadata_token(task, "broken_token"))
        fixed_token = self._normalize(self._metadata_token(task, "fixed_token"))
        normalized_current = self._normalize(current)
        token_constraints_met = not (
            (broken_token and broken_token in normalized_current)
            or (fixed_token and fixed_token not in normalized_current)
        )

        judge_average = 1.0
        if isinstance(judge_scores, dict):
            judge_average = (
                float(judge_scores.get("correctness", 0.0))
                + float(judge_scores.get("minimalism", 0.0))
                + float(judge_scores.get("quality", 0.0))
            ) / 3.0

        is_valid = (
            tests_passed
            and stage_ok
            and token_constraints_met
            and deterministic_score >= 0.72
            and hidden_test_pass_rate >= 0.65
            and judge_average >= 0.5
        )

        if not tests_passed:
            summary = test_logs
        elif not stage_ok:
            summary = stage_logs
        elif not token_constraints_met:
            summary = "validation failed: required bug-fix token constraints not satisfied"
        else:
            summary = (
                f"deterministic={deterministic_score:.2f}; "
                f"hidden={hidden_test_pass_rate:.2f}; "
                f"tests={'pass' if tests_passed else 'fail'}; "
                f"failure_stage={'pass' if stage_ok else 'fail'}"
            )
        return {
            "deterministic_score": deterministic_score,
            "hidden_test_pass_rate": hidden_test_pass_rate,
            "judge_scores": judge_scores,
            "tests_passed": tests_passed,
            "is_valid": is_valid,
            "summary": summary,
        }

    def _detect_hacking_attempt(self, tool: str, payload: dict[str, Any], config_text: str) -> bool:
        """Flag brute-forcing or reward-hacking for the current step.

        Brute force means >= 6 consecutive edits or >= 3 failed validations. A
        shortcut pattern in the payload's raw text or in the config is also
        flagged. ``tool`` is accepted for interface stability but currently unused.
        """
        payload_text = str(payload.get("raw") or "")
        brute_force_detected = bool(
            self._state and (self._state.consecutive_edit_actions >= 6 or self._state.failed_validations >= 3)
        )
        return (
            brute_force_detected
            or self._contains_hacking_pattern(payload_text)
            or self._contains_hacking_pattern(config_text)
        )

    def _contains_hacking_pattern(self, text: str) -> bool:
        """Return True if ``text`` matches any anti-hacking regex (case-insensitive).

        The stage-skip, fake-success and hard-coded-output pattern lists are taken
        from the reward calculator's ``anti_hacking_detector``.
        """
        if not text:
            return False
        detector = self.reward_calculator.anti_hacking_detector
        patterns = (
            *detector.STAGE_SKIP_PATTERNS,
            *detector.FAKE_SUCCESS_PATTERNS,
            *detector.HARDCODED_OUTPUT_PATTERNS,
        )
        return any(re.search(pattern, text, flags=re.IGNORECASE) for pattern in patterns)

    def _extract_commands(self, config_text: str) -> list[str]:
        """Collect lower-cased shell commands from the workflow's ``run`` entries.

        Recognizes ``- run: cmd``, ``- run cmd``, ``run: cmd`` under a named step,
        and multi-line block scalars (``run: |`` / ``run: >`` and their chomping
        variants), whose command lines are those indented deeper than the ``run``
        key.
        """
        block_markers = ("|", ">", "|-", ">-", "|+", ">+")
        commands: list[str] = []
        block_key_indent: int | None = None
        for raw_line in config_text.splitlines():
            line = raw_line.strip().lower()
            if block_key_indent is not None:
                if not line:
                    continue
                if self._indent_width(raw_line) > block_key_indent:
                    commands.append(line)
                    continue
                block_key_indent = None
            if "- run:" in line:
                command = line.replace("- run:", "").strip()
            elif line.startswith("- run "):
                command = line.replace("- run ", "", 1).strip()
            elif line.startswith("run:"):
                command = line[len("run:"):].strip()
            else:
                continue
            if command in block_markers:
                # Column of the `run` key itself (after any leading "- ").
                block_key_indent = len(raw_line) - len(raw_line.lstrip(" -"))
                continue
            commands.append(command)
        return commands

    def _load_yaml_mapping(self, config_text: str) -> dict[Any, Any] | None:
        """Parse ``config_text`` with ``yaml.safe_load``; return it only if it is a mapping."""
        try:
            parsed = yaml.safe_load(config_text)
        except yaml.YAMLError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def _is_yaml_valid(self, config_text: str) -> bool:
        """Return True if ``config_text`` parses to a YAML mapping."""
        return self._load_yaml_mapping(config_text) is not None

    def _stage_exists(self, config_text: str, stage: str) -> bool:
        """Return True if ``stage`` is a key of ``jobs`` or appears in ``stages``
        (mapping keys or list items) of the parsed config."""
        parsed = self._load_yaml_mapping(config_text)
        if parsed is None:
            return False
        jobs = parsed.get("jobs")
        if isinstance(jobs, dict) and stage in jobs:
            return True
        stages = parsed.get("stages")
        return isinstance(stages, (dict, list)) and stage in stages

    def _count_changed_lines(self, previous: str, current: str) -> int:
        """Count lines changed between two config versions using a line diff.

        Each non-equal diff hunk contributes ``max(removed, added)`` lines, so a
        single inserted line counts as 1 instead of shifting every later line.
        Returns 0 exactly when both texts have identical lines.
        """
        matcher = SequenceMatcher(None, previous.splitlines(), current.splitlines(), autojunk=False)
        return sum(
            max(i2 - i1, j2 - j1)
            for tag, i1, i2, j1, j2 in matcher.get_opcodes()
            if tag != "equal"
        )

    def _normalize(self, value: str) -> str:
        """Lower-case, strip, and collapse whitespace runs to single spaces."""
        return re.sub(r"\s+", " ", value.strip().lower())