File size: 11,304 Bytes
f4dcf31
 
 
e365f21
 
f4dcf31
 
e365f21
 
224ace0
 
e365f21
 
 
 
 
224ace0
 
e365f21
 
 
 
224ace0
 
e365f21
 
 
 
 
 
f4dcf31
 
e365f21
f4dcf31
 
 
e365f21
f4dcf31
 
 
224ace0
 
 
 
 
 
 
 
 
e365f21
 
 
 
f4dcf31
 
 
e365f21
224ace0
e365f21
 
 
 
 
 
 
 
 
f4dcf31
 
224ace0
f4dcf31
 
 
 
 
 
 
 
 
 
224ace0
e365f21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4dcf31
e365f21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4dcf31
 
 
e365f21
f4dcf31
 
 
 
e365f21
 
 
 
 
 
 
f4dcf31
 
 
224ace0
 
 
 
 
 
 
 
 
 
 
e365f21
224ace0
 
 
 
 
 
 
 
f4dcf31
 
e365f21
f4dcf31
 
 
 
224ace0
f4dcf31
e365f21
 
 
 
 
 
 
 
f4dcf31
 
 
 
224ace0
f4dcf31
224ace0
e365f21
f4dcf31
 
224ace0
 
 
 
 
 
 
 
 
e365f21
224ace0
 
 
 
f4dcf31
e365f21
 
 
 
 
 
 
 
 
f4dcf31
e365f21
 
 
 
 
 
 
f4dcf31
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
Grader for the CI/CD Doctor environment.

Grade composition:
  fixes_applied_fraction * 0.20   proportional credit for structurally valid fixes
  pipeline_passed       +0.50     pipeline_status == "passed" (terminal)
  file_integrity        -0.05/-0.10 per fix file corrupted by excessive growth

balance_score() rewards STATE TRANSITIONS through the debugging workflow
and penalizes anti-patterns (blind edits, edit spam, stalling).

Milestone progression (ordinal):
  0  start          -- episode just began
  1  investigated   -- ran pipeline, saw what's broken
  2  diagnosed      -- read diagnostic files (error source + fix target)
  3  fix_applied    -- at least one structurally valid fix in filesystem
  4  verified       -- pipeline passes after fix

Transition rewards:
  0->1  +0.10   first pipeline run reveals the problem
  1->2  +0.10   reading diagnostic files to understand the error
  2->3  +0.15   applying a correct edit
  3->4  +0.50   (handled by grade() terminal bonus)

Penalties:
  stalling (same milestone, no progress)   -0.05
  blind_edit (edit without reading file)   -0.10
  edit_spam (>2 edits to same file)        -0.05 per extra
  regression (fix undone)                  -0.15
  idle pipeline run (no fs change)         -0.05
  over ideal step count                    -0.02 per step over
"""

import re
from dataclasses import dataclass, field

from models import PipelineState
from core.validation.validator import validate_ci_stages

# Maximum credit earned from structurally valid fixes (scaled by the
# fraction of answer-key fixes actually applied).
CORRECT_FILE_EDITED_TOTAL = 0.2

# Ordinal milestone levels; higher = further through the debugging workflow.
MILESTONE_LEVEL: dict[str, int] = {
    "start": 0,
    "investigated": 1,
    "diagnosed": 2,
    "fix_applied": 3,
    "pipeline_passed": 4,
}

# One-shot rewards for crossing a milestone boundary, keyed by
# (from_level, to_level); paid out by balance_score().
TRANSITION_REWARDS: dict[tuple[int, int], float] = {
    (0, 1): 0.10,   # start -> investigated
    (1, 2): 0.10,   # investigated -> diagnosed
    (2, 3): 0.15,   # diagnosed -> fix_applied
    # 3->4 is the pipeline_passed bonus in grade(), not here
}

# Negative per-step adjustments for anti-patterns (applied in balance_score()).
PENALTIES: dict[str, float] = {
    "stalling": -0.05,
    "regression": -0.15,
    "blind_edit": -0.10,
    "edit_spam": -0.05,
    "idle_pipeline_run": -0.05,
    "over_ideal_step": -0.02,
}

# Positive per-step adjustments for good investigative behavior.
BONUSES: dict[str, float] = {
    "correct_diagnosis": 0.10,
    "cross_reference": 0.05,
}


@dataclass
class StepContext:
    """Per-step bookkeeping consumed by balance_score().

    Carries the command just executed plus flags and counters the
    environment tracks across the episode.
    """

    # Type of the command executed this step
    # (e.g. "cat", "sed", "echo_append", "diagnose", "pipeline_run").
    cmd_type: str
    # File targeted by the command, when the command takes a file argument.
    filename: str | None = None
    # Filenames the agent has read so far (not consulted by balance_score directly).
    files_read: set[str] = field(default_factory=set)
    # False when the pipeline is rerun with no filesystem change since the
    # previous run; triggers the idle_pipeline_run penalty.
    fs_changed_since_last_run: bool = True
    # Steps taken so far in the episode.
    step_count: int = 0
    # Hard step budget for the episode.
    max_steps: int = 15
    # Step count beyond which the over_ideal_step penalty applies.
    ideal_steps: int = 6
    # Pipeline runs since the most recent edit (not consulted by balance_score directly).
    pipeline_runs_since_last_edit: int = 0
    # Milestone level observed at the previous step; used to detect transitions.
    prev_milestone_level: int = 0
    # Per-file edit counters; feeds the edit_spam penalty.
    edits_per_file: dict[str, int] = field(default_factory=dict)
    # Files edited before ever being read; feeds the blind_edit penalty.
    files_edited_without_reading: set[str] = field(default_factory=set)
    # Set by the environment when the agent's diagnosis was correct.
    diagnosis_correct: bool = False
    # Set when the agent cross-referenced diagnostic sources.
    cross_referenced: bool = False

def _validate_fix(filename: str, content: str, fix_desc: dict) -> bool:
    """Structurally validate that the fix described by *fix_desc* is present in *content*.

    Dispatches on fix_desc["type"]; an unknown or missing type fails
    validation. *filename* is part of the call contract but unused here.
    """
    dispatch = {
        "package_present": lambda d: _validate_package_present(content, d["package"]),
        "package_version": lambda d: _validate_package_version(
            content, d["package"], d["expected_version"]
        ),
        "dockerfile_base": lambda d: _validate_dockerfile_base(content, d["expected_tag"]),
        "env_var_present": lambda d: _validate_env_var_present(content, d["variable"]),
        "config_value": lambda d: _validate_config_value(
            content, d["key"], d["expected_value"]
        ),
        "makefile_command": lambda d: _validate_makefile_command(
            content, d["expected_command"]
        ),
        "port_value": lambda d: _validate_port_value(content, d["expected_port"]),
        "ci_stage_order": lambda d: _validate_ci_stage_order(content),
    }
    handler = dispatch.get(fix_desc.get("type", ""))
    return handler(fix_desc) if handler is not None else False


def _validate_package_present(content: str, package: str) -> bool:
    """Check that package exists as a standalone line in requirements.txt."""
    for line in content.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        # Strip version specifiers to get the package name
        pkg_name = re.split(r"[=<>!~\[]", line, 1)[0].strip().lower()
        if pkg_name == package.lower():
            return True
    return False


def _validate_package_version(content: str, package: str, expected_version: str) -> bool:
    """Check that a package is pinned to a compatible version."""
    for line in content.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        pkg_name = re.split(r"[=<>!~\[]", line, 1)[0].strip().lower()
        if pkg_name == package.lower():
            if expected_version in line:
                return True
            if "==" not in line and "<" not in line:
                return True
    return False


def _validate_dockerfile_base(content: str, expected_tag: str) -> bool:
    """Check that the FROM instruction uses the expected Python tag."""
    for line in content.splitlines():
        line = line.strip()
        if line.upper().startswith("FROM"):
            # Match FROM python:<tag> with optional extras after tag
            if f"python:{expected_tag}" in line:
                # Reject alpine when expecting slim
                if expected_tag == "3.11-slim" and "alpine" in line:
                    return False
                if expected_tag == "3.11" and "alpine" in line:
                    return False
                return True
            return False
    return False


def _validate_env_var_present(content: str, variable: str) -> bool:
    """Check that a variable is defined as a key in key=value format."""
    for line in content.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if "=" in line:
            key = line.split("=", 1)[0].strip()
            if key == variable:
                return True
    return False


def _validate_config_value(content: str, key: str, expected_value: str) -> bool:
    """Check that a YAML-like config has the correct key:value."""
    pattern = re.compile(rf"^\s*{re.escape(key)}\s*:\s*(.+)\s*$", re.MULTILINE)
    match = pattern.search(content)
    if match:
        actual = match.group(1).strip().strip('"').strip("'")
        return actual == expected_value
    return False


def _validate_makefile_command(content: str, expected_command: str) -> bool:
    """Check that the test target in a Makefile uses the expected command."""
    has_expected = expected_command in content
    no_bad_flags = (
        "--collect-only" not in content
        and "--dry-run" not in content
        and "unittest" not in content
    )
    return has_expected and no_bad_flags


def _validate_port_value(content: str, expected_port: int) -> bool:
    """Check that a port field in YAML has the expected value."""
    pattern = re.compile(rf"^\s*port\s*:\s*(\d+)\s*$", re.MULTILINE)
    match = pattern.search(content)
    if match:
        return int(match.group(1)) == expected_port
    return False


def _validate_ci_stage_order(content: str) -> bool:
    """Return True when validate_ci_stages() accepts the ci.yml content.

    The validator signals an invalid stage ordering by raising ValueError.
    """
    try:
        validate_ci_stages(content)
    except ValueError:
        return False
    return True

def _check_file_integrity(original: str, current: str) -> float:
    """
    Returns a penalty (0.0 to -0.10) if a file has been corrupted beyond
    the necessary fix. Detects garbage appending and line duplication.
    """
    orig_lines = original.splitlines()
    curr_lines = current.splitlines()

    if len(orig_lines) > 0:
        growth_ratio = len(curr_lines) / len(orig_lines)
        if growth_ratio > 2.0:
            return -0.10
        elif growth_ratio > 1.5:
            return -0.05

    return 0.0

def _fixes_applied_fraction(state: PipelineState) -> float:
    """Fraction (0.0-1.0) of answer-key fixes structurally present in the filesystem.

    Dict fix descriptions are checked via _validate_fix(); string descriptions
    are treated as required substrings of the file content.
    """
    fixes = state.answer_key.get("fixes", {})
    if not fixes:
        return 0.0

    def _is_applied(fname: str, desc) -> bool:
        # Missing files validate against the empty string.
        body = state.filesystem.get(fname, "")
        if isinstance(desc, dict):
            return _validate_fix(fname, body, desc)
        return isinstance(desc, str) and desc in body

    hits = sum(1 for fname, desc in fixes.items() if _is_applied(fname, desc))
    return hits / len(fixes)


def current_milestone_level(state: PipelineState) -> int:
    """Return the highest milestone level (0-4) the agent has reached.

    Checks from the terminal state downwards; the first satisfied
    condition wins.
    """
    if state.pipeline_status == "passed":
        return MILESTONE_LEVEL["pipeline_passed"]
    if _fixes_applied_fraction(state) > 0:
        return MILESTONE_LEVEL["fix_applied"]

    reached = set(state.milestones)
    # Any of these markers counts as having diagnosed the failure.
    if reached & {"diagnosed", "correct_file_located", "logs_read"}:
        return MILESTONE_LEVEL["diagnosed"]
    if "investigated" in reached:
        return MILESTONE_LEVEL["investigated"]
    return MILESTONE_LEVEL["start"]


def grade(state: PipelineState) -> float:
    """Compute the total earned grade from *state*.

    Sum of: proportional fix credit, the 0.50 pipeline-passed bonus, and
    per-file integrity penalties for the answer-key fix files. Floored at
    0.0 and rounded to two decimals.
    """
    total = CORRECT_FILE_EDITED_TOTAL * _fixes_applied_fraction(state)
    if state.pipeline_status == "passed":
        total += 0.50

    originals = state.answer_key.get("original_filesystem", {})
    if originals:
        # Integrity penalties apply only to the files the answer key expects fixed.
        for fname in state.answer_key.get("fixes", {}):
            total += _check_file_integrity(
                originals.get(fname, ""), state.filesystem.get(fname, "")
            )

    return round(max(total, 0.0), 2)


def balance_score(state: PipelineState, ctx: StepContext) -> float:
    """Per-step shaped reward based on milestone TRANSITIONS, not commands.

    Rewards advancing through the debugging workflow; penalizes regression,
    stalling, idle pipeline reruns, blind edits, edit spam, and exceeding
    the ideal step count. Adds bonuses for a correct diagnosis and for
    cross-referencing diagnostic sources. Result is rounded to two decimals.
    """
    delta = 0.0
    level_now = current_milestone_level(state)
    level_before = ctx.prev_milestone_level

    if level_now > level_before:
        # Credit every boundary crossed this step (the agent may skip levels).
        delta += sum(
            TRANSITION_REWARDS.get((lvl, lvl + 1), 0.0)
            for lvl in range(level_before, level_now)
        )
    elif level_now < level_before:
        delta += PENALTIES["regression"]
    elif ctx.cmd_type not in ("cat", "echo_append", "sed", "diagnose"):
        # No movement and not an investigative/editing command: stalling.
        delta += PENALTIES["stalling"]

    if ctx.cmd_type == "pipeline_run" and not ctx.fs_changed_since_last_run:
        delta += PENALTIES["idle_pipeline_run"]

    is_edit = ctx.cmd_type in ("echo_append", "sed") and ctx.filename
    if is_edit and ctx.filename in ctx.files_edited_without_reading:
        delta += PENALTIES["blind_edit"]
    if is_edit:
        # Penalize each edit beyond the second on the same file.
        extra_edits = ctx.edits_per_file.get(ctx.filename, 0) - 2
        if extra_edits > 0:
            delta += PENALTIES["edit_spam"] * extra_edits

    if ctx.step_count > ctx.ideal_steps:
        delta += PENALTIES["over_ideal_step"]
    if ctx.diagnosis_correct:
        delta += BONUSES["correct_diagnosis"]
    if ctx.cross_referenced:
        delta += BONUSES["cross_reference"]

    return round(delta, 2)