File size: 9,136 Bytes
740ac53
 
 
 
 
 
 
 
d3b224f
740ac53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909dfde
740ac53
 
 
 
 
909dfde
740ac53
 
 
 
 
 
909dfde
 
740ac53
 
 
 
 
c6888af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a818334
740ac53
a818334
a91fb6a
a818334
 
740ac53
a818334
 
 
 
740ac53
a91fb6a
a818334
740ac53
 
 
 
17a43d0
 
740ac53
 
17a43d0
 
 
 
740ac53
 
 
 
c6888af
740ac53
c6888af
 
 
 
 
 
 
740ac53
 
c6888af
740ac53
 
 
 
 
 
 
 
 
c6888af
740ac53
 
 
 
 
 
 
243b472
 
740ac53
236cf5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740ac53
 
 
 
 
a818334
740ac53
 
 
 
 
 
d3b224f
740ac53
 
 
 
a818334
 
236cf5b
 
1288c52
 
 
 
d3b224f
740ac53
c6888af
 
 
 
 
 
740ac53
c6888af
740ac53
d3b224f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
Grading logic for WhyDidItFail.

grade() is the single entry point. It scores the full episode trajectory:

  diagnosis_score  (0.00 – 0.70)  was the diagnosis correct?
  evidence_score   (0.00 – 0.15)  did the agent inspect the right sources?
  efficiency_score (0.00 – 0.15)  did the agent act without waste?
  fix_bonus        (0.00 – 0.15)  did the agent suggest a valid fix? (bonus, capped at 0.90)

Step-level partial rewards are returned by the environment's step() on every action,
giving the agent a signal over the full trajectory before the episode ends.
"""

# Strong-signal phrases per failure mode: if any phrase appears verbatim in the
# lowercased diagnosis text, it counts as an exact match (+0.4 each in
# _diagnosis_score, and qualifies as "correct" in _evidence_diagnosis_penalty).
EXACT_KEYWORDS: dict[str, list[str]] = {
    "exploding_gradients":           ["exploding gradients", "exploding"],
    "learning_rate_too_high":        ["learning rate too high", "lr too high"],
    "overfitting":                   ["overfitting", "overfit"],
    "underfitting":                  ["underfitting", "underfit"],
    "learning_rate_too_low":         ["learning rate too low", "lr too low"],
    "missing_regularization":        ["missing regularization", "no regularization", "lack of regularization"],
    "batch_size_too_small":          ["batch size too small", "small batch size"],
    "optimizer_misconfiguration":    ["optimizer misconfiguration", "optimizer misconfig", "wrong optimizer"],
    "vanishing_gradients":           ["vanishing gradients", "vanishing"],
    "dying_relu":                    ["dying relu", "dead relu"],
    "bad_weight_initialization":     ["bad weight initialization", "poor initialization", "wrong initialization"],
    "lr_scheduler_misconfiguration": ["lr scheduler misconfiguration", "scheduler misconfiguration"],
}

# Weak-signal fragments per failure mode: substrings (some deliberately
# truncated, e.g. "oscillat", "regulariz") that indicate the diagnosis is in
# the right category. Each match adds +0.1 in _diagnosis_score.
CATEGORY_KEYWORDS: dict[str, list[str]] = {
    "exploding_gradients":           ["nan", "gradient", "overflow", "diverge"],
    "learning_rate_too_high":        ["learning rate", "lr", "oscillat", "unstable"],
    "overfitting":                   ["generalization", "val loss", "memoriz"],
    "underfitting":                  ["plateau", "not learning", "too simple", "high bias"],
    "learning_rate_too_low":         ["slow converge", "converge", "too slow"],
    "missing_regularization":        ["regulariz", "dropout", "weight decay"],
    "batch_size_too_small":          ["batch", "noisy gradient", "gradient noise"],
    "optimizer_misconfiguration":    ["optimizer", "momentum", "sgd"],
    "vanishing_gradients":           ["gradient", "vanish", "sigmoid", "stuck"],
    "dying_relu":                    ["relu", "dead", "zero gradient", "activation"],
    "bad_weight_initialization":     ["initializ", "weight init", "nan"],
    "lr_scheduler_misconfiguration": ["scheduler", "spike", "periodic", "step_lr"],
}


def _diagnosis_score(diagnosis: str, scenario: dict) -> float:
    """
    Score the free-text diagnosis against the scenario's correct failure mode.

    Scoring is additive, clamped to [0.0, 0.7]:
      +0.4 for each EXACT_KEYWORDS phrase found in the diagnosis
      +0.1 for each CATEGORY_KEYWORDS fragment found in the diagnosis
      -0.1 vague-answer penalty (fewer than 3 words) when no exact phrase matched

    Note: a single exact phrase alone yields 0.4; reaching the 0.7 cap requires
    multiple keyword hits (exact and/or category).
    """
    correct = scenario.get("correct_diagnosis", "")
    d = diagnosis.strip().lower()

    score = 0.0
    exact_matched = False

    # exact keyword matches (strong signal)
    # falls back to the raw diagnosis id if the mode is not in EXACT_KEYWORDS
    for kw in EXACT_KEYWORDS.get(correct, [correct]):
        if kw in d:
            score += 0.4
            exact_matched = True

    # category matches (weaker signal)
    for kw in CATEGORY_KEYWORDS.get(correct, []):
        if kw in d:
            score += 0.1

    # penalize vague answers only when no exact match found
    if not exact_matched and len(d.split()) < 3:
        score -= 0.1

    return max(0.0, min(0.7, score))


def _evidence_diagnosis_penalty(
    diagnosis: str,
    scenario: dict,
    inspection_order: list[str],
) -> float:
    """
    Penalise a reasoning failure: the agent held the evidence yet concluded wrongly.

    Returns:
        -0.10 — every required source was inspected, diagnosis still wrong
        -0.05 — only some required sources inspected, diagnosis wrong
         0.00 — correct diagnosis, or a blind submission (no required source
                inspected; evidence_score already penalises that case)
    """
    target = scenario.get("correct_diagnosis", "")
    answer = diagnosis.strip().lower()

    # Any exact-keyword hit means the diagnosis is correct: no penalty.
    # Falls back to the raw diagnosis id when the mode has no keyword entry.
    for keyword in EXACT_KEYWORDS.get(target, [target]):
        if keyword in answer:
            return 0.0

    required_sources = set(scenario.get("required_sources", ["logs"]))
    seen_required = {src for src in inspection_order if src in required_sources}

    if seen_required == required_sources:
        return -0.10   # had everything, still wrong
    if seen_required:
        return -0.05   # partial evidence, still wrong
    return 0.0         # blind submission — evidence_score already penalises


def _evidence_score(inspection_order: list[str], required: set[str]) -> float:
    """
    +0.08 per required source inspected  (max +0.24 for 3 sources)
    βˆ’0.10 for each required source NOT inspected at submit time
    βˆ’0.02 per irrelevant source inspected
    Clamped to [βˆ’0.15, +0.25].
    """
    inspected_set = set(inspection_order)
    relevant   = inspected_set & required
    missing    = required - inspected_set
    irrelevant = inspected_set - required

    score = (len(relevant) * 0.08) - (len(missing) * 0.10) - (len(irrelevant) * 0.02)
    return max(-0.15, min(0.25, score))


def _efficiency_score(steps_taken: int, min_steps: int) -> float:
    """
    0.15 at minimum steps.
    Decays for extra steps (wasted) or missing steps (early submission).
    min_steps = number of required sources + 1 (the submit action).
    """
    if steps_taken < min_steps:
        missing_steps = min_steps - steps_taken
        return max(0.0, 0.15 - 0.05 * missing_steps)
    extra_steps = steps_taken - min_steps
    penalty = 0.02 * (extra_steps ** 1.2)
    return max(0.0, 0.15 - penalty)


def _fix_score(suggested_fix: str | None, scenario: dict) -> float:
    """
    Uniform fix score applied to every scenario.

      βˆ’0.05 β€” no fix provided          (always expected)
       0.00 β€” fix provided but wrong
      +0.05 β€” β‰₯30% keyword match
      +0.10 β€” β‰₯60% keyword match
      +0.15 β€” all keywords match
    """
    if not suggested_fix:
        return -0.05   # fix is always expected; omitting it costs points

    fix         = suggested_fix.strip().lower()
    correct_fix = scenario.get("correct_fix", "").strip().lower()
    stop        = {"to", "a", "the", "and", "or", "use", "set", "by"}
    keywords    = [w for w in correct_fix.split() if w not in stop and len(w) > 2]

    if not keywords:
        return 0.0

    ratio = sum(1 for kw in keywords if kw in fix) / len(keywords)

    if ratio == 1.0:
        return 0.15
    elif ratio >= 0.6:
        return 0.10
    elif ratio >= 0.3:
        return 0.05
    return 0.0


def _ordering_bonus(inspection_order: list[str], required_sources: list[str]) -> float:
    """
    +0.05 if the agent inspected required sources in the canonical order.

    Canonical order is defined by required_sources in the scenario
    (always a prefix of logs β†’ config β†’ gradients).

    Only the order of required sources matters β€” irrelevant sources inspected
    in between are ignored when checking sequence.
    """
    required_set = set(required_sources)
    # Subsequence of inspection_order containing only required sources
    inspected_required = [s for s in inspection_order if s in required_set]

    # Check against canonical order, limited to what was actually inspected
    canonical = [s for s in required_sources if s in inspected_required]

    return 0.05 if inspected_required == canonical else 0.0


def grade(
    diagnosis: str,
    suggested_fix: str | None = None,
    scenario: dict | None = None,
    steps_taken: int = 0,
    inspection_order: list[str] | None = None,
    difficulty: str = "easy",   # kept for API compat — not used in scoring logic
) -> float:
    """
    Single unified grade function. Scores every scenario identically.

    Total = diagnosis score + evidence/diagnosis penalty + evidence score
            + efficiency score + fix score + ordering bonus,
    clamped to [0.10, 0.90] and rounded to two decimals.

    Exceeding the hard step ceiling (3 x required sources + 2) short-circuits
    to the 0.10 floor regardless of the other components.
    """
    scenario = scenario or {}
    order = inspection_order or []

    required_sources = scenario.get("required_sources", ["logs"])   # ordered list
    required = set(required_sources)                                # membership checks
    min_steps = len(required) + 1          # inspect all required sources + submit
    max_steps = len(required) * 3 + 2      # hard ceiling; exceeding it = total failure

    if steps_taken > max_steps:
        return 0.10

    total = (
        _diagnosis_score(diagnosis, scenario)
        + _evidence_diagnosis_penalty(diagnosis, scenario, order)
        + _evidence_score(order, required)
        + _efficiency_score(steps_taken, min_steps)
        + _fix_score(suggested_fix, scenario)
        + _ordering_bonus(order, required_sources)
    )

    return round(max(0.10, min(0.90, total)), 2)