# cicd-debugger-env-final/env/graders/deterministic.py
# Lishika's picture
# final fix
# 32445fd
from __future__ import annotations
import re
from difflib import SequenceMatcher
from typing import Any
import yaml
class DeterministicGrader:
    """Deterministic correctness scoring for CI/CD config fixes.

    A candidate config is compared with the expected config along three
    axes that are blended into a single reward in ``[0, 1]``:

    * syntax     (weight 0.20) -- does the candidate parse as YAML?
    * functional (weight 0.60) -- are the expected commands present, is the
      issue described by ``metadata`` resolved, and is the candidate free of
      known command typos?
    * similarity (weight 0.20) -- character-level similarity of the
      whitespace/case-normalized texts.
    """

    # Mapping keys whose values are treated as shell commands.
    COMMAND_KEYS = {
        "script",
        "scripts",
        "run",
        "command",
        "commands",
        "steps",
        "before_script",
        "after_script",
    }

    # Regexes for well-known command typos; any hit costs a fixed penalty.
    BROKEN_COMMAND_PATTERNS = (
        r"\bnpm\s+tset\b",
        r"\bpyhton\b",
        r"\bpip\s+isntall\b",
        r"\bgo\s+tset\b",
    )

    # Substrings that mark a raw text line as a command in the text fallback.
    _TEXT_COMMAND_TOKENS = (
        "npm",
        "pytest",
        "python",
        "yarn",
        "pnpm",
        "go test",
        "mvn test",
    )

    def grade(self, current_config, expected_config, metadata=None) -> dict[str, Any]:
        """Grade ``current_config`` against ``expected_config``.

        Args:
            current_config: Candidate config text; ``None`` is treated as "".
            expected_config: Reference config text; ``None`` is treated as "".
            metadata: Optional dict; the keys ``broken_token`` and
                ``fixed_token`` are consulted when present.

        Returns:
            ``{"reward": float, "is_valid": bool}`` where ``reward`` is the
            blended score and ``is_valid`` is True only when both texts are
            identical after stripping surrounding whitespace.
        """
        metadata = metadata or {}
        # Normalize here as well, so the equality check below cannot fail on
        # None inputs that _compute_score already accepts.
        current_config = current_config or ""
        expected_config = expected_config or ""

        score = self._compute_score(current_config, expected_config, metadata)
        is_valid = current_config.strip() == expected_config.strip()
        return {
            "reward": float(score),
            "is_valid": bool(is_valid),
        }

    def _compute_score(self, current_config, expected_config, metadata=None) -> float:
        """Return the weighted score in ``[0, 1]``, rounded to 4 decimals."""
        metadata = metadata or {}
        current_config = current_config or ""
        expected_config = expected_config or ""

        syntax_score = self._syntax_score(current_config)
        functional_score = self._functional_score(current_config, expected_config, metadata)
        similarity_score = self._similarity_score(current_config, expected_config)

        total = (0.20 * syntax_score) + (0.60 * functional_score) + (0.20 * similarity_score)
        # A candidate that is empty or fails to parse is capped at 0.30.
        if syntax_score == 0.0:
            total = min(total, 0.30)
        return round(self._clamp_01(total), 4)

    def _syntax_score(self, config_text: str) -> float:
        """Return 1.0 if the text is non-blank and parses as YAML, else 0.0."""
        if not (config_text or "").strip():
            return 0.0
        try:
            yaml.safe_load(config_text)
        except yaml.YAMLError:
            return 0.0
        return 1.0

    def _functional_score(self, current_config: str, expected_config: str, metadata: dict[str, Any]) -> float:
        """Score command coverage and issue resolution, minus a typo penalty.

        The command score is the fraction of expected commands that have a
        match among the current commands. When no expected commands can be
        extracted, plain text similarity is used in its place.
        """
        expected_commands = self._extract_commands(expected_config)
        current_commands = self._extract_commands(current_config)

        if expected_commands:
            matched = sum(
                1
                for expected in expected_commands
                if any(self._commands_match(expected, current) for current in current_commands)
            )
            command_score = matched / len(expected_commands)
        else:
            command_score = self._similarity_score(current_config, expected_config)

        issue_score = self._issue_resolution_score(current_config, metadata)
        broken_penalty = 0.35 if self._has_known_broken_command(current_config) else 0.0
        combined = (0.80 * command_score) + (0.20 * issue_score) - broken_penalty
        return self._clamp_01(combined)

    def _metadata_token(self, metadata: dict[str, Any], key: str) -> str:
        """Return the normalized token stored under ``key``, or "" if unset."""
        raw = metadata.get(key)
        # str(None) would produce the literal text "None"; treat it as absent.
        return self._normalize_text("" if raw is None else str(raw))

    def _issue_resolution_score(self, current_config: str, metadata: dict[str, Any]) -> float:
        """Return 1.0 when the issue described by ``metadata`` looks resolved.

        The issue counts as unresolved (0.0) when the broken token is still
        present in the config, or when a fixed token is given but missing.
        Without either token there is nothing to check and the score is 1.0.
        """
        metadata = metadata or {}
        broken_token = self._metadata_token(metadata, "broken_token")
        fixed_token = self._metadata_token(metadata, "fixed_token")
        current_normalized = self._normalize_text(current_config)

        if not broken_token and not fixed_token:
            return 1.0
        if broken_token and broken_token in current_normalized:
            return 0.0
        if fixed_token and fixed_token not in current_normalized:
            return 0.0
        return 1.0

    def _extract_commands(self, config_text: str) -> list[str]:
        """Return the normalized, de-duplicated commands found in the config.

        Commands are collected from the parsed YAML first; if that yields
        nothing (including when parsing fails), a line-based text scan is
        used instead. First-seen order is preserved.
        """
        # Blank input has no commands, so there is no need to parse it.
        if not (config_text or "").strip():
            return []

        commands: list[str] = []
        try:
            parsed = yaml.safe_load(config_text)
        except yaml.YAMLError:
            parsed = None
        if parsed is not None:
            self._walk_yaml(parsed, commands)
        if not commands:
            commands.extend(self._extract_commands_from_text(config_text))

        deduped: list[str] = []
        seen: set[str] = set()
        for command in commands:
            normalized = self._normalize_text(command)
            if normalized and normalized not in seen:
                seen.add(normalized)
                deduped.append(normalized)
        return deduped

    def _walk_yaml(self, node: Any, commands: list[str]) -> None:
        """Recursively append command strings found under COMMAND_KEYS."""
        if isinstance(node, dict):
            for key, value in node.items():
                if str(key).lower() in self.COMMAND_KEYS:
                    commands.extend(self._extract_string_values(value))
                # Always descend: command keys may appear at any depth.
                self._walk_yaml(value, commands)
        elif isinstance(node, list):
            for item in node:
                self._walk_yaml(item, commands)

    def _extract_string_values(self, value: Any) -> list[str]:
        """Collect the strings held by a command key's value.

        Strings are returned as-is. Lists contribute their string items and
        the strings of nested lists. Mappings contribute the strings of all
        their values. Mappings that are list items are skipped here.
        """
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            output: list[str] = []
            for item in value:
                # Descend into nested lists so their strings are not lost.
                if isinstance(item, (str, list)):
                    output.extend(self._extract_string_values(item))
            return output
        if isinstance(value, dict):
            output = []
            for nested in value.values():
                output.extend(self._extract_string_values(nested))
            return output
        return []

    def _extract_commands_from_text(self, config_text: str) -> list[str]:
        """Line-based fallback: return lines that mention a known tool."""
        commands: list[str] = []
        for raw_line in (config_text or "").splitlines():
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            # A non-list line ending in ":" is a section header, not a command.
            if line.endswith(":") and not line.startswith("-"):
                continue
            line = line.lstrip("-").strip()
            lowered = line.lower()
            if any(token in lowered for token in self._TEXT_COMMAND_TOKENS):
                commands.append(line)
        return commands

    def _has_known_broken_command(self, config_text: str) -> bool:
        """Return True if any BROKEN_COMMAND_PATTERNS regex matches the text."""
        text = config_text or ""
        return any(
            re.search(pattern, text, flags=re.IGNORECASE)
            for pattern in self.BROKEN_COMMAND_PATTERNS
        )

    def _commands_match(self, expected: str, current: str) -> bool:
        """Return True if ``current`` satisfies the ``expected`` command.

        After normalization the commands match when they are equal, when the
        expected command is contained in the current one, or when the current
        command is longer than 6 characters and contained in the expected one.
        An empty expected command never matches.
        """
        expected_normalized = self._normalize_text(expected)
        current_normalized = self._normalize_text(current)
        # "" is a substring of every string; without this guard an empty
        # expected command would match anything.
        if not expected_normalized:
            return False
        if expected_normalized == current_normalized:
            return True
        if expected_normalized in current_normalized:
            return True
        # The length floor keeps short fragments from matching long commands.
        return len(current_normalized) > 6 and current_normalized in expected_normalized

    def _similarity_score(self, current_config: str, expected_config: str) -> float:
        """Return the SequenceMatcher ratio of the normalized texts.

        Two empty texts score 1.0; exactly one empty text scores 0.0.
        """
        left = self._normalize_text(current_config)
        right = self._normalize_text(expected_config)
        if not left and not right:
            return 1.0
        if not left or not right:
            return 0.0
        return self._clamp_01(SequenceMatcher(None, left, right).ratio())

    def _normalize_text(self, value: str) -> str:
        """Collapse whitespace runs to one space, strip, and lowercase."""
        return re.sub(r"\s+", " ", (value or "")).strip().lower()

    def _clamp_01(self, value: float) -> float:
        """Clamp ``value`` into the closed interval ``[0.0, 1.0]``."""
        return max(0.0, min(1.0, float(value)))