# CI_CD_Doctor/core/grading/grader.py
# (uploaded via huggingface_hub by samrat-rm, commit e365f21)
"""
Grader for the CI/CD Doctor environment.
Grade composition:
fixes_applied_fraction * 0.20 proportional credit for structurally valid fixes
pipeline_passed +0.50 pipeline_status == "passed" (terminal)
balance_score() rewards STATE TRANSITIONS through the debugging workflow
and penalizes anti-patterns (blind edits, edit spam, stalling).
Milestone progression (ordinal):
0 start -- episode just began
1 investigated -- ran pipeline, saw what's broken
2 diagnosed -- read diagnostic files (error source + fix target)
3 fix_applied -- at least one structurally valid fix in filesystem
4 verified -- pipeline passes after fix
Transition rewards:
0->1 +0.10 first pipeline run reveals the problem
1->2 +0.10 reading diagnostic files to understand the error
2->3 +0.15 applying a correct edit
3->4 +0.50 (handled by grade() terminal bonus)
Penalties:
stalling (same milestone, no progress) -0.05
blind_edit (edit without reading file) -0.10
edit_spam (>2 edits to same file) -0.05 per extra
regression (fix undone) -0.15
idle pipeline run (no fs change) -0.05
over ideal step count -0.02 per step over
"""
import re
from dataclasses import dataclass, field
from models import PipelineState
from core.validation.validator import validate_ci_stages
# Maximum credit for structurally valid fixes; grade() scales this by the
# fraction of answer-key fixes actually applied.
CORRECT_FILE_EDITED_TOTAL = 0.2

# Ordinal levels of the debugging workflow (see module docstring).
MILESTONE_LEVEL: dict[str, int] = {
    "start": 0,
    "investigated": 1,
    "diagnosed": 2,
    "fix_applied": 3,
    "pipeline_passed": 4,
}

# Reward paid by balance_score() for crossing each adjacent milestone pair.
TRANSITION_REWARDS: dict[tuple[int, int], float] = {
    (0, 1): 0.10,  # start -> investigated
    (1, 2): 0.10,  # investigated -> diagnosed
    (2, 3): 0.15,  # diagnosed -> fix_applied
    # 3->4 is the pipeline_passed bonus in grade(), not here
}

# Negative per-step adjustments applied by balance_score() for anti-patterns.
PENALTIES: dict[str, float] = {
    "stalling": -0.05,           # same milestone level, no progress
    "regression": -0.15,         # milestone level dropped (fix undone)
    "blind_edit": -0.10,         # edited a file never read beforehand
    "edit_spam": -0.05,          # per extra edit beyond 2 to the same file
    "idle_pipeline_run": -0.05,  # reran pipeline with no filesystem change
    "over_ideal_step": -0.02,    # applied once per step past ideal_steps
}

# Positive per-step adjustments applied by balance_score().
BONUSES: dict[str, float] = {
    "correct_diagnosis": 0.10,
    "cross_reference": 0.05,
}
@dataclass
class StepContext:
    """Per-step bookkeeping supplied by the environment to balance_score()."""

    # Type of the command just executed (e.g. "cat", "sed", "pipeline_run").
    cmd_type: str
    # File targeted by the command, when it has one.
    filename: str | None = None
    # Files read so far — presumably maintained by the environment; not
    # consulted directly by this module's scoring.
    files_read: set[str] = field(default_factory=set)
    # Whether the filesystem changed since the last pipeline run
    # (drives the idle_pipeline_run penalty).
    fs_changed_since_last_run: bool = True
    # Steps taken so far in the episode.
    step_count: int = 0
    # Hard cap on episode length (not read by this module's scoring).
    max_steps: int = 15
    # Step budget before the over_ideal_step penalty kicks in.
    ideal_steps: int = 6
    # Consecutive pipeline runs since the last edit (not read here).
    pipeline_runs_since_last_edit: int = 0
    # Milestone level observed at the previous step (for transition rewards).
    prev_milestone_level: int = 0
    # Edit counts per filename (drives the edit_spam penalty).
    edits_per_file: dict[str, int] = field(default_factory=dict)
    # Files edited before ever being read (drives the blind_edit penalty).
    files_edited_without_reading: set[str] = field(default_factory=set)
    # Whether the agent's diagnosis was judged correct (correct_diagnosis bonus).
    diagnosis_correct: bool = False
    # Whether the agent cross-referenced diagnostics (cross_reference bonus).
    cross_referenced: bool = False
def _validate_fix(filename: str, content: str, fix_desc: dict) -> bool:
"""
Structurally validate that a fix has been correctly applied.
Dispatches based on fix_desc["type"].
"""
fix_type = fix_desc.get("type", "")
if fix_type == "package_present":
return _validate_package_present(content, fix_desc["package"])
elif fix_type == "package_version":
return _validate_package_version(content, fix_desc["package"], fix_desc["expected_version"])
elif fix_type == "dockerfile_base":
return _validate_dockerfile_base(content, fix_desc["expected_tag"])
elif fix_type == "env_var_present":
return _validate_env_var_present(content, fix_desc["variable"])
elif fix_type == "config_value":
return _validate_config_value(content, fix_desc["key"], fix_desc["expected_value"])
elif fix_type == "makefile_command":
return _validate_makefile_command(content, fix_desc["expected_command"])
elif fix_type == "port_value":
return _validate_port_value(content, fix_desc["expected_port"])
elif fix_type == "ci_stage_order":
return _validate_ci_stage_order(content)
return False
def _validate_package_present(content: str, package: str) -> bool:
"""Check that package exists as a standalone line in requirements.txt."""
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
# Strip version specifiers to get the package name
pkg_name = re.split(r"[=<>!~\[]", line, 1)[0].strip().lower()
if pkg_name == package.lower():
return True
return False
def _validate_package_version(content: str, package: str, expected_version: str) -> bool:
"""Check that a package is pinned to a compatible version."""
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
pkg_name = re.split(r"[=<>!~\[]", line, 1)[0].strip().lower()
if pkg_name == package.lower():
if expected_version in line:
return True
if "==" not in line and "<" not in line:
return True
return False
def _validate_dockerfile_base(content: str, expected_tag: str) -> bool:
"""Check that the FROM instruction uses the expected Python tag."""
for line in content.splitlines():
line = line.strip()
if line.upper().startswith("FROM"):
# Match FROM python:<tag> with optional extras after tag
if f"python:{expected_tag}" in line:
# Reject alpine when expecting slim
if expected_tag == "3.11-slim" and "alpine" in line:
return False
if expected_tag == "3.11" and "alpine" in line:
return False
return True
return False
return False
def _validate_env_var_present(content: str, variable: str) -> bool:
"""Check that a variable is defined as a key in key=value format."""
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" in line:
key = line.split("=", 1)[0].strip()
if key == variable:
return True
return False
def _validate_config_value(content: str, key: str, expected_value: str) -> bool:
"""Check that a YAML-like config has the correct key:value."""
pattern = re.compile(rf"^\s*{re.escape(key)}\s*:\s*(.+)\s*$", re.MULTILINE)
match = pattern.search(content)
if match:
actual = match.group(1).strip().strip('"').strip("'")
return actual == expected_value
return False
def _validate_makefile_command(content: str, expected_command: str) -> bool:
"""Check that the test target in a Makefile uses the expected command."""
has_expected = expected_command in content
no_bad_flags = (
"--collect-only" not in content
and "--dry-run" not in content
and "unittest" not in content
)
return has_expected and no_bad_flags
def _validate_port_value(content: str, expected_port: int) -> bool:
"""Check that a port field in YAML has the expected value."""
pattern = re.compile(rf"^\s*port\s*:\s*(\d+)\s*$", re.MULTILINE)
match = pattern.search(content)
if match:
return int(match.group(1)) == expected_port
return False
def _validate_ci_stage_order(content: str) -> bool:
    """Delegate to the shared validator; a ValueError means invalid order."""
    try:
        validate_ci_stages(content)
    except ValueError:
        return False
    return True
def _check_file_integrity(original: str, current: str) -> float:
"""
Returns a penalty (0.0 to -0.10) if a file has been corrupted beyond
the necessary fix. Detects garbage appending and line duplication.
"""
orig_lines = original.splitlines()
curr_lines = current.splitlines()
if len(orig_lines) > 0:
growth_ratio = len(curr_lines) / len(orig_lines)
if growth_ratio > 2.0:
return -0.10
elif growth_ratio > 1.5:
return -0.05
return 0.0
def _fixes_applied_fraction(state: PipelineState) -> float:
    """Fraction (0.0-1.0) of answer-key fixes structurally present on disk.

    Dict-valued fix descriptions go through _validate_fix(); string-valued
    ones are plain substring checks against the file's current content.
    Returns 0.0 when the answer key lists no fixes.
    """
    expected = state.answer_key.get("fixes", {})
    if not expected:
        return 0.0
    hits = 0
    for fname, spec in expected.items():
        body = state.filesystem.get(fname, "")
        if isinstance(spec, dict):
            if _validate_fix(fname, body, spec):
                hits += 1
        elif isinstance(spec, str) and spec in body:
            hits += 1
    return hits / len(expected)
def current_milestone_level(state: PipelineState) -> int:
    """Return the highest milestone level the agent has reached.

    Checks run highest-first so the best achievement wins: a passing
    pipeline trumps applied fixes, which trump diagnostic milestones.
    """
    if state.pipeline_status == "passed":
        return MILESTONE_LEVEL["pipeline_passed"]
    if _fixes_applied_fraction(state) > 0:
        return MILESTONE_LEVEL["fix_applied"]
    reached = set(state.milestones)
    # Any of these markers counts as having diagnosed the failure.
    if reached & {"diagnosed", "correct_file_located", "logs_read"}:
        return MILESTONE_LEVEL["diagnosed"]
    if "investigated" in reached:
        return MILESTONE_LEVEL["investigated"]
    return MILESTONE_LEVEL["start"]
def grade(state: PipelineState) -> float:
    """Compute the total terminal grade for an episode.

    Composition: up to CORRECT_FILE_EDITED_TOTAL proportional credit for
    structurally valid fixes, +0.50 when the pipeline passed, minus any
    file-integrity penalties for files bloated beyond the fix.  The result
    is clamped at 0.0 and rounded to two decimals.
    """
    total = CORRECT_FILE_EDITED_TOTAL * _fixes_applied_fraction(state)
    if state.pipeline_status == "passed":
        total += 0.50
    baseline_fs = state.answer_key.get("original_filesystem", {})
    if baseline_fs:
        # Only files named in the answer key are integrity-checked.
        for fname in state.answer_key.get("fixes", {}):
            total += _check_file_integrity(
                baseline_fs.get(fname, ""),
                state.filesystem.get(fname, ""),
            )
    return round(max(total, 0.0), 2)
def balance_score(state: PipelineState, ctx: StepContext) -> float:
    """Per-step shaped reward driven by milestone TRANSITIONS, not commands.

    Advancing through the workflow earns TRANSITION_REWARDS; dropping a
    level earns the regression penalty; otherwise a non-diagnostic,
    non-editing command counts as stalling.  Idle pipeline reruns, blind
    edits, edit spam, overlong episodes, and the diagnosis/cross-reference
    bonuses are applied on top.  Result is rounded to two decimals.
    """
    delta = 0.0
    level_now = current_milestone_level(state)
    level_before = ctx.prev_milestone_level

    if level_now > level_before:
        # Credit every transition crossed this step (levels can be skipped).
        for lvl in range(level_before, level_now):
            delta += TRANSITION_REWARDS.get((lvl, lvl + 1), 0.0)
    elif level_now < level_before:
        delta += PENALTIES["regression"]
    elif ctx.cmd_type not in ("cat", "echo_append", "sed", "diagnose"):
        # No progress and not an info-gathering/editing command: stalling.
        delta += PENALTIES["stalling"]

    if ctx.cmd_type == "pipeline_run" and not ctx.fs_changed_since_last_run:
        delta += PENALTIES["idle_pipeline_run"]

    if ctx.cmd_type in ("echo_append", "sed") and ctx.filename:
        if ctx.filename in ctx.files_edited_without_reading:
            delta += PENALTIES["blind_edit"]
        prior_edits = ctx.edits_per_file.get(ctx.filename, 0)
        if prior_edits > 2:
            delta += PENALTIES["edit_spam"] * (prior_edits - 2)

    if ctx.step_count > ctx.ideal_steps:
        delta += PENALTIES["over_ideal_step"]
    if ctx.diagnosis_correct:
        delta += BONUSES["correct_diagnosis"]
    if ctx.cross_referenced:
        delta += BONUSES["cross_reference"]
    return round(delta, 2)