from __future__ import annotations import re from difflib import SequenceMatcher from typing import Any import yaml class DeterministicGrader: """Deterministic correctness scoring for CI/CD config fixes.""" COMMAND_KEYS = { "script", "scripts", "run", "command", "commands", "steps", "before_script", "after_script", } BROKEN_COMMAND_PATTERNS = ( r"\bnpm\s+tset\b", r"\bpyhton\b", r"\bpip\s+isntall\b", r"\bgo\s+tset\b", ) def grade(self, current_config, expected_config, metadata=None): metadata = metadata or {} score = self._compute_score(current_config, expected_config, metadata) is_valid = ( current_config.strip() == expected_config.strip() ) return { "reward": float(score), "is_valid": bool(is_valid), } def _compute_score(self, current_config, expected_config, metadata=None): metadata = metadata or {} current_config = current_config or "" expected_config = expected_config or "" syntax_score = self._syntax_score(current_config) functional_score = self._functional_score(current_config, expected_config, metadata) similarity_score = self._similarity_score(current_config, expected_config) total = (0.20 * syntax_score) + (0.60 * functional_score) + (0.20 * similarity_score) if syntax_score == 0.0: total = min(total, 0.30) return round(self._clamp_01(total), 4) def _syntax_score(self, config_text: str) -> float: if not (config_text or "").strip(): return 0.0 try: yaml.safe_load(config_text) return 1.0 except yaml.YAMLError: return 0.0 def _functional_score(self, current_config: str, expected_config: str, metadata: dict[str, Any]) -> float: expected_commands = self._extract_commands(expected_config) current_commands = self._extract_commands(current_config) if expected_commands: matched = 0 for expected in expected_commands: if any(self._commands_match(expected, current) for current in current_commands): matched += 1 command_score = matched / len(expected_commands) else: command_score = self._similarity_score(current_config, expected_config) issue_score = self._issue_resolution_score(current_config, metadata) broken_penalty = 0.35 if self._has_known_broken_command(current_config) else 0.0 combined = (0.80 * command_score) + (0.20 * issue_score) - broken_penalty return self._clamp_01(combined) def _issue_resolution_score(self, current_config: str, metadata: dict[str, Any]) -> float: broken_token = self._normalize_text(str(metadata.get("broken_token", ""))) fixed_token = self._normalize_text(str(metadata.get("fixed_token", ""))) current_normalized = self._normalize_text(current_config) if not broken_token and not fixed_token: return 1.0 if broken_token and broken_token in current_normalized: return 0.0 if fixed_token and fixed_token not in current_normalized: return 0.0 return 1.0 def _extract_commands(self, config_text: str) -> list[str]: commands: list[str] = [] try: parsed = yaml.safe_load(config_text) except yaml.YAMLError: parsed = None if parsed is not None: self._walk_yaml(parsed, commands) if not commands: commands.extend(self._extract_commands_from_text(config_text)) deduped: list[str] = [] seen: set[str] = set() for command in commands: normalized = self._normalize_text(command) if normalized and normalized not in seen: seen.add(normalized) deduped.append(normalized) return deduped def _walk_yaml(self, node: Any, commands: list[str]) -> None: if isinstance(node, dict): for key, value in node.items(): key_name = str(key).lower() if key_name in self.COMMAND_KEYS: commands.extend(self._extract_string_values(value)) self._walk_yaml(value, commands) elif isinstance(node, list): for item in node: self._walk_yaml(item, commands) def _extract_string_values(self, value: Any) -> list[str]: if isinstance(value, str): return [value] if isinstance(value, list): return [item for item in value if isinstance(item, str)] if isinstance(value, dict): output: list[str] = [] for nested in value.values(): output.extend(self._extract_string_values(nested)) return output return [] def _extract_commands_from_text(self, config_text: str) -> list[str]: commands: list[str] = [] for raw_line in (config_text or "").splitlines(): line = raw_line.strip() if not line or line.startswith("#"): continue if ":" in line and not line.startswith("-") and line.endswith(":"): continue line = line.lstrip("-").strip() if any(token in line.lower() for token in ("npm", "pytest", "python", "yarn", "pnpm", "go test", "mvn test")): commands.append(line) return commands def _has_known_broken_command(self, config_text: str) -> bool: return any(re.search(pattern, config_text or "", flags=re.IGNORECASE) for pattern in self.BROKEN_COMMAND_PATTERNS) def _commands_match(self, expected: str, current: str) -> bool: expected_normalized = self._normalize_text(expected) current_normalized = self._normalize_text(current) if expected_normalized == current_normalized: return True if expected_normalized in current_normalized: return True if current_normalized in expected_normalized and len(current_normalized) > 6: return True return False def _similarity_score(self, current_config: str, expected_config: str) -> float: left = self._normalize_text(current_config) right = self._normalize_text(expected_config) if not left and not right: return 1.0 if not left or not right: return 0.0 return self._clamp_01(SequenceMatcher(None, left, right).ratio()) def _normalize_text(self, value: str) -> str: return re.sub(r"\s+", " ", (value or "")).strip().lower() def _clamp_01(self, value: float) -> float: return max(0.0, min(1.0, float(value)))