"""
Grader for the CI/CD Doctor environment.
Grade composition:
fixes_applied_fraction * 0.20 proportional credit for structurally valid fixes
pipeline_passed +0.50 pipeline_status == "passed" (terminal)
balance_score() rewards STATE TRANSITIONS through the debugging workflow
and penalizes anti-patterns (blind edits, edit spam, stalling).
Milestone progression (ordinal):
0 start -- episode just began
1 investigated -- ran pipeline, saw what's broken
2 diagnosed -- read diagnostic files (error source + fix target)
3 fix_applied -- at least one structurally valid fix in filesystem
4 verified -- pipeline passes after fix
Transition rewards:
0->1 +0.10 first pipeline run reveals the problem
1->2 +0.10 reading diagnostic files to understand the error
2->3 +0.15 applying a correct edit
3->4 +0.50 (handled by grade() terminal bonus)
Penalties:
stalling (same milestone, no progress) -0.05
blind_edit (edit without reading file) -0.10
edit_spam (>2 edits to same file) -0.05 per extra
regression (fix undone) -0.15
idle pipeline run (no fs change) -0.05
over ideal step count -0.02 per step over
"""
import re
from dataclasses import dataclass, field
from models import PipelineState
from core.validation.validator import validate_ci_stages
# Maximum total credit for structurally valid fixes; grade() scales this by
# the fraction of answer-key fixes actually present in the filesystem.
CORRECT_FILE_EDITED_TOTAL = 0.2
# Ordinal milestone levels of the debugging workflow (see module docstring).
MILESTONE_LEVEL: dict[str, int] = {
    "start": 0,
    "investigated": 1,
    "diagnosed": 2,
    "fix_applied": 3,
    "pipeline_passed": 4,
}
# Shaped rewards paid by balance_score() for advancing one milestone level.
TRANSITION_REWARDS: dict[tuple[int, int], float] = {
    (0, 1): 0.10,  # start -> investigated
    (1, 2): 0.10,  # investigated -> diagnosed
    (2, 3): 0.15,  # diagnosed -> fix_applied
    # 3->4 is the pipeline_passed bonus in grade(), not here
}
# Negative per-step adjustments applied by balance_score() for anti-patterns.
PENALTIES: dict[str, float] = {
    "stalling": -0.05,           # same milestone, no productive command
    "regression": -0.15,         # milestone level dropped (fix undone)
    "blind_edit": -0.10,         # edited a file that was never read
    "edit_spam": -0.05,          # per edit beyond 2 on the same file
    "idle_pipeline_run": -0.05,  # re-ran pipeline with no filesystem change
    "over_ideal_step": -0.02,    # per step taken beyond ideal_steps
}
# Positive per-step adjustments for good diagnostic behavior.
BONUSES: dict[str, float] = {
    "correct_diagnosis": 0.10,
    "cross_reference": 0.05,
}
@dataclass
class StepContext:
    """Per-step bookkeeping consumed by balance_score().

    Tracks what the agent just did and its running history so the shaped
    reward can detect milestone transitions, stalling, and edit anti-patterns.
    """
    # Category of the command just executed (balance_score branches on
    # "cat", "echo_append", "sed", "diagnose", "pipeline_run").
    cmd_type: str
    # File targeted by the command, when applicable.
    filename: str | None = None
    # Files the agent has read so far (not consulted by balance_score here;
    # presumably maintained by the environment -- see files_edited_without_reading).
    files_read: set[str] = field(default_factory=set)
    # Whether the filesystem changed since the last pipeline run; used to
    # penalize idle pipeline re-runs.
    fs_changed_since_last_run: bool = True
    # Steps taken so far this episode.
    step_count: int = 0
    # Hard step budget for the episode.
    max_steps: int = 15
    # Step count considered efficient; each step beyond it is penalized.
    ideal_steps: int = 6
    # Consecutive pipeline runs since the last edit (not read by balance_score).
    pipeline_runs_since_last_edit: int = 0
    # Milestone level observed at the previous step, for transition detection.
    prev_milestone_level: int = 0
    # Edit count per filename, for edit-spam detection.
    edits_per_file: dict[str, int] = field(default_factory=dict)
    # Files that were edited before ever being read (blind edits).
    files_edited_without_reading: set[str] = field(default_factory=set)
    # Set by the environment when the agent's stated diagnosis was correct.
    diagnosis_correct: bool = False
    # Set by the environment when the agent cross-referenced diagnostic sources.
    cross_referenced: bool = False
def _validate_fix(filename: str, content: str, fix_desc: dict) -> bool:
"""
Structurally validate that a fix has been correctly applied.
Dispatches based on fix_desc["type"].
"""
fix_type = fix_desc.get("type", "")
if fix_type == "package_present":
return _validate_package_present(content, fix_desc["package"])
elif fix_type == "package_version":
return _validate_package_version(content, fix_desc["package"], fix_desc["expected_version"])
elif fix_type == "dockerfile_base":
return _validate_dockerfile_base(content, fix_desc["expected_tag"])
elif fix_type == "env_var_present":
return _validate_env_var_present(content, fix_desc["variable"])
elif fix_type == "config_value":
return _validate_config_value(content, fix_desc["key"], fix_desc["expected_value"])
elif fix_type == "makefile_command":
return _validate_makefile_command(content, fix_desc["expected_command"])
elif fix_type == "port_value":
return _validate_port_value(content, fix_desc["expected_port"])
elif fix_type == "ci_stage_order":
return _validate_ci_stage_order(content)
return False
def _validate_package_present(content: str, package: str) -> bool:
"""Check that package exists as a standalone line in requirements.txt."""
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
# Strip version specifiers to get the package name
pkg_name = re.split(r"[=<>!~\[]", line, 1)[0].strip().lower()
if pkg_name == package.lower():
return True
return False
def _validate_package_version(content: str, package: str, expected_version: str) -> bool:
"""Check that a package is pinned to a compatible version."""
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
pkg_name = re.split(r"[=<>!~\[]", line, 1)[0].strip().lower()
if pkg_name == package.lower():
if expected_version in line:
return True
if "==" not in line and "<" not in line:
return True
return False
def _validate_dockerfile_base(content: str, expected_tag: str) -> bool:
"""Check that the FROM instruction uses the expected Python tag."""
for line in content.splitlines():
line = line.strip()
if line.upper().startswith("FROM"):
# Match FROM python:<tag> with optional extras after tag
if f"python:{expected_tag}" in line:
# Reject alpine when expecting slim
if expected_tag == "3.11-slim" and "alpine" in line:
return False
if expected_tag == "3.11" and "alpine" in line:
return False
return True
return False
return False
def _validate_env_var_present(content: str, variable: str) -> bool:
"""Check that a variable is defined as a key in key=value format."""
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" in line:
key = line.split("=", 1)[0].strip()
if key == variable:
return True
return False
def _validate_config_value(content: str, key: str, expected_value: str) -> bool:
"""Check that a YAML-like config has the correct key:value."""
pattern = re.compile(rf"^\s*{re.escape(key)}\s*:\s*(.+)\s*$", re.MULTILINE)
match = pattern.search(content)
if match:
actual = match.group(1).strip().strip('"').strip("'")
return actual == expected_value
return False
def _validate_makefile_command(content: str, expected_command: str) -> bool:
"""Check that the test target in a Makefile uses the expected command."""
has_expected = expected_command in content
no_bad_flags = (
"--collect-only" not in content
and "--dry-run" not in content
and "unittest" not in content
)
return has_expected and no_bad_flags
def _validate_port_value(content: str, expected_port: int) -> bool:
"""Check that a port field in YAML has the expected value."""
pattern = re.compile(rf"^\s*port\s*:\s*(\d+)\s*$", re.MULTILINE)
match = pattern.search(content)
if match:
return int(match.group(1)) == expected_port
return False
def _validate_ci_stage_order(content: str) -> bool:
    """Return True if ci.yml stages are in valid order.

    Delegates to the project validator, which signals an invalid order
    by raising ValueError.
    """
    try:
        validate_ci_stages(content)
    except ValueError:
        return False
    return True
def _check_file_integrity(original: str, current: str) -> float:
"""
Returns a penalty (0.0 to -0.10) if a file has been corrupted beyond
the necessary fix. Detects garbage appending and line duplication.
"""
orig_lines = original.splitlines()
curr_lines = current.splitlines()
if len(orig_lines) > 0:
growth_ratio = len(curr_lines) / len(orig_lines)
if growth_ratio > 2.0:
return -0.10
elif growth_ratio > 1.5:
return -0.05
return 0.0
def _fixes_applied_fraction(state: PipelineState) -> float:
    """Return the fraction of answer_key fixes structurally valid on disk.

    Dict fix descriptors go through _validate_fix; plain-string descriptors
    pass when the string appears verbatim in the file. Returns 0.0 when the
    answer key declares no fixes.
    """
    fixes = state.answer_key.get("fixes", {})
    if not fixes:
        return 0.0

    def _is_applied(fname: str, desc) -> bool:
        body = state.filesystem.get(fname, "")
        if isinstance(desc, dict):
            return _validate_fix(fname, body, desc)
        return isinstance(desc, str) and desc in body

    applied = sum(1 for fname, desc in fixes.items() if _is_applied(fname, desc))
    return applied / len(fixes)
def current_milestone_level(state: PipelineState) -> int:
    """Return the highest milestone level the agent has reached.

    Checked from highest to lowest: a passing pipeline, any applied fix,
    any diagnostic milestone, then the investigated milestone.
    """
    if state.pipeline_status == "passed":
        return MILESTONE_LEVEL["pipeline_passed"]
    if _fixes_applied_fraction(state) > 0:
        return MILESTONE_LEVEL["fix_applied"]
    reached = set(state.milestones)
    # Any of these environment milestones counts as having diagnosed the fault.
    if reached & {"diagnosed", "correct_file_located", "logs_read"}:
        return MILESTONE_LEVEL["diagnosed"]
    if "investigated" in reached:
        return MILESTONE_LEVEL["investigated"]
    return MILESTONE_LEVEL["start"]
def grade(state: PipelineState) -> float:
    """Compute the total earned grade from *state*.

    Proportional fix credit plus a 0.50 terminal bonus for a passing
    pipeline, minus any file-corruption penalties; clamped at zero and
    rounded to two decimals.
    """
    total = CORRECT_FILE_EDITED_TOTAL * _fixes_applied_fraction(state)
    if state.pipeline_status == "passed":
        total += 0.50
    originals = state.answer_key.get("original_filesystem", {})
    if originals:
        # Penalize files the agent bloated while applying fixes.
        for fname in state.answer_key.get("fixes", {}):
            total += _check_file_integrity(
                originals.get(fname, ""), state.filesystem.get(fname, "")
            )
    return round(max(total, 0.0), 2)
def balance_score(state: PipelineState, ctx: StepContext) -> float:
    """Per-step shaped reward based on milestone TRANSITIONS, not commands.

    Rewards advancing through the debugging workflow. Penalizes stalling,
    regression, idle reruns, blind edits, edit spam, and inefficiency.
    """
    delta = 0.0
    level_now = current_milestone_level(state)
    level_before = ctx.prev_milestone_level

    if level_now > level_before:
        # Pay every transition crossed this step (may skip levels).
        delta += sum(
            TRANSITION_REWARDS.get((lvl, lvl + 1), 0.0)
            for lvl in range(level_before, level_now)
        )
    elif level_now < level_before:
        delta += PENALTIES["regression"]
    elif ctx.cmd_type not in ("cat", "echo_append", "sed", "diagnose"):
        # No progress and not an investigative/editing command: stalling.
        delta += PENALTIES["stalling"]

    if ctx.cmd_type == "pipeline_run" and not ctx.fs_changed_since_last_run:
        delta += PENALTIES["idle_pipeline_run"]

    if ctx.cmd_type in ("echo_append", "sed") and ctx.filename:
        if ctx.filename in ctx.files_edited_without_reading:
            delta += PENALTIES["blind_edit"]
        extra_edits = ctx.edits_per_file.get(ctx.filename, 0) - 2
        if extra_edits > 0:
            delta += PENALTIES["edit_spam"] * extra_edits

    if ctx.step_count > ctx.ideal_steps:
        delta += PENALTIES["over_ideal_step"]
    if ctx.diagnosis_correct:
        delta += BONUSES["correct_diagnosis"]
    if ctx.cross_referenced:
        delta += BONUSES["cross_reference"]
    return round(delta, 2)