File size: 11,944 Bytes
eb1ebe6
 
 
8fa7af1
43f41de
 
8fa7af1
43f41de
 
 
 
8fa7af1
 
43f41de
eb1ebe6
 
 
 
43f41de
eb1ebe6
 
8fa7af1
 
 
5869d56
8fa7af1
5869d56
eb1ebe6
 
 
 
 
43f41de
 
 
 
 
8fa7af1
 
 
 
43f41de
 
8fa7af1
 
43f41de
 
8fa7af1
 
 
 
 
 
 
43f41de
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
8fa7af1
 
 
 
 
 
43f41de
 
 
 
 
eb1ebe6
43f41de
 
eb1ebe6
 
 
 
 
 
 
 
 
43f41de
8fa7af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
8fa7af1
 
43f41de
 
 
eb1ebe6
43f41de
 
8fa7af1
 
 
 
 
 
43f41de
eb1ebe6
43f41de
8fa7af1
 
 
 
 
 
 
43f41de
 
 
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
eb1ebe6
43f41de
 
eb1ebe6
 
 
 
43f41de
eb1ebe6
43f41de
eb1ebe6
 
43f41de
 
 
eb1ebe6
43f41de
eb1ebe6
 
 
 
 
 
 
 
 
43f41de
eb1ebe6
 
 
 
 
43f41de
eb1ebe6
 
 
 
 
 
 
 
43f41de
eb1ebe6
8fa7af1
eb1ebe6
 
 
43f41de
eb1ebe6
 
8fa7af1
eb1ebe6
43f41de
eb1ebe6
8fa7af1
 
 
 
 
 
eb1ebe6
 
43f41de
 
 
eb1ebe6
 
 
 
 
 
 
 
 
8fa7af1
 
eb1ebe6
43f41de
 
8fa7af1
 
 
43f41de
8fa7af1
 
 
 
 
 
eb1ebe6
 
 
8fa7af1
 
 
 
 
eb1ebe6
8fa7af1
 
eb1ebe6
43f41de
8fa7af1
 
 
 
eb1ebe6
 
43f41de
8fa7af1
43f41de
8fa7af1
 
43f41de
 
 
 
 
eb1ebe6
8fa7af1
 
eb1ebe6
8fa7af1
eb1ebe6
 
 
43f41de
 
8fa7af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
 
 
 
 
 
 
 
 
8fa7af1
43f41de
 
 
 
 
8fa7af1
 
 
43f41de
8fa7af1
 
43f41de
8fa7af1
43f41de
 
5869d56
43f41de
 
 
8fa7af1
43f41de
 
 
 
 
8fa7af1
 
 
 
 
43f41de
 
 
8fa7af1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
"""Reward components for the generation phase.

After exploration, the agent generates marimo/manim code. Rewards measure
validity, task alignment, artifact structure, and research usage.

Scoring model:
  quality = weighted sum of (validity, task alignment, structure, research usage)
  total   = quality × gate

Gates (multiplicative):
  - code doesn't parse  → total = 0
  - static check fails  → total = quality × small static-fail multiplier
  - code doesn't run    → total = quality × execution-fail multiplier
  - code runs           → total = quality × 1.0
"""

from __future__ import annotations

import re
from typing import TYPE_CHECKING

from .sandbox import ast_parses, check_marimo, extract_scene_class

try:
    from ..constants import MAX_REPAIR_REWARD, clamp_action_reward
except ImportError:  # pragma: no cover - supports direct test execution
    from constants import MAX_REPAIR_REWARD, clamp_action_reward

if TYPE_CHECKING:
    from ..task_bank import Task


# ---------------------------------------------------------------------------
# Component weights
# ---------------------------------------------------------------------------

_WEIGHTS = {
    "validity": 0.15,
    "task_alignment": 0.30,
    "structure": 0.30,
    "research_usage": 0.25,
}

GATE_STATIC_FAIL = 0.12
GATE_RUNS_FAIL = 0.30  # quality multiplier when static checks pass but execution fails


_STOPWORDS = {
    "about", "after", "again", "against", "also", "because", "before", "being",
    "between", "class", "code", "construct", "could", "from", "have", "into",
    "like", "make", "more", "most", "only", "self", "show", "step", "than",
    "that", "their", "then", "there", "these", "this", "through", "using",
    "value", "where", "with", "would",
}


# ---------------------------------------------------------------------------
# Individual scorers
# ---------------------------------------------------------------------------


def keyword_coverage(code: str, keywords_csv: str) -> float:
    """Return the share of comma-separated task keywords that occur in ``code``.

    Matching is a case-insensitive substring test. Blank entries in the CSV
    are ignored; an empty or missing keyword list scores 0.0.
    """
    if not keywords_csv:
        return 0.0
    normalized = [piece.strip().lower() for piece in keywords_csv.split(",")]
    wanted = [kw for kw in normalized if kw]
    if not wanted:
        return 0.0
    haystack = code.lower()
    found = len([kw for kw in wanted if kw in haystack])
    return found / len(wanted)


def format_match(chosen_format: str, task: Task) -> float:
    """Score how well ``chosen_format`` fits the task's preferred format.

    Tasks without a preference (``preferred_format is None``) accept any
    format with 1.0. Otherwise a matching format scores 1.0 and a mismatch
    scores 0.3.
    """
    preferred = task.preferred_format
    if preferred is not None and chosen_format != preferred:
        return 0.3
    return 1.0


def marimo_structure(
    code: str,
    task: Task,
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> float:
    """Score structural quality of a marimo notebook (0-1).

    Additive scoring for good patterns, penalties from ``marimo check``
    for breaking violations (duplicate defs, cycles, etc.).

    Every signal below is a plain substring test on ``code``; the notebook
    is not parsed here.

    Args:
        code: Notebook source text.
        task: Only ``task.data_available`` (scales the plot bonus) and
            ``task.tier`` (minimum cell count for the tier bonus; unknown
            tiers use 2) are read.
        static_check_passed: Precomputed static-check result. When ``None``,
            ``check_marimo(code)`` is called here and its violations are used.
        error_codes: Violation codes penalized when ``static_check_passed``
            is supplied and falsy. ``None`` is treated as no violations.

    Returns:
        The accumulated score clamped to ``[0.0, 1.0]``.
    """
    score = 0.0

    # Positive signals
    if "import marimo" in code or "from marimo" in code:
        score += 0.2
    if "marimo.App" in code or "mo.App" in code:
        score += 0.1
    cell_count = code.count("@app.cell")
    if cell_count >= 3:
        score += 0.2
    elif cell_count >= 1:
        score += 0.1

    # Each distinct UI helper present adds 0.06; the UI bonus is capped at 0.22.
    ui_patterns = [
        "mo.md(",
        "mo.Html",
        "mo.accordion",
        "mo.callout",
        "mo.hstack(",
        "mo.vstack(",
        "mo.ui.slider",
        "mo.ui.dropdown",
        "mo.ui.table",
        "mo.ui.dataframe",
    ]
    score += min(0.22, sum(0.06 for p in ui_patterns if p in code))

    # Plot bonus: only the first matching tier below applies, and each tier
    # pays more when ``task.data_available`` is truthy.
    reactive_plot_patterns = [
        "mo.ui.matplotlib(",
        "mo.ui.plotly(",
        "mo.ui.altair_chart(",
    ]
    raw_plot_patterns = [
        "plt.",
        "matplotlib.pyplot",
        "px.",
        "plotly.",
        "alt.Chart",
    ]
    if "mo.ui.matplotlib(plt.gca())" in code:
        score += 0.24 if task.data_available else 0.16
    elif any(p in code for p in reactive_plot_patterns):
        score += 0.18 if task.data_available else 0.10
    elif any(p in code for p in raw_plot_patterns):
        # Raw (non-mo.ui) plotting: a small bonus offset by a 0.08 penalty,
        # netting about 0.0 with data and -0.05 without.
        score += 0.08 if task.data_available else 0.03
        score -= 0.08

    # Discouraged calls. NOTE(review): presumably tight_layout is unwanted in
    # notebook output and ``np.math`` is a deprecated NumPy alias — confirm.
    if "plt.tight_layout(" in code:
        score -= 0.12

    if "np.math." in code:
        score -= 0.15

    # Bonus when the notebook has at least the tier's minimum number of cells.
    tier_thresholds = {"advanced": 6, "intermediate": 4, "beginner": 2}
    if cell_count >= tier_thresholds.get(task.tier, 2):
        score += 0.1

    # Marimo check: penalize breaking violations, bonus for clean code
    if static_check_passed is None:
        passed, _, violations = check_marimo(code)
    else:
        passed = static_check_passed
        violations = error_codes or []

    if passed:
        score += 0.1
    else:
        # Per-violation penalty. Unlisted codes cost 0.15, and a code that
        # appears several times is penalized once per occurrence.
        penalty = {
            "MB002": 0.35,
            "MB003": 0.4,
            "MB005": 0.25,
            "MB001": 0.3,
            "MB004": 0.2,
        }
        for v in violations:
            score -= penalty.get(v, 0.15)

    return max(0.0, min(1.0, score))


def manim_structure(code: str, task: Task) -> float:
    """Score structural quality of a manim scene (0-1).

    Substring-based heuristics on ``code``:

    * 0.2 for a manim import, 0.2 when ``extract_scene_class(code)`` returns
      a non-None result, and 0.1 for a ``def construct`` method;
    * 0.05 per distinct animation-call pattern present, capped at 0.3;
    * 0.1 when any math/plot mobject pattern (``MathTex(``, ``Axes(``, ...)
      appears;
    * 0.1 when the number of distinct animation patterns meets the threshold
      for ``task.tier`` (unknown tiers use 2).

    Args:
        code: Scene source text.
        task: Only ``task.tier`` is read.

    Returns:
        The accumulated score, capped at 1.0.
    """
    # ``extract_scene_class`` is provided by the module-level ``.sandbox``
    # import; no function-local re-import is needed.
    score = 0.0
    if "from manim" in code or "import manim" in code:
        score += 0.2
    if extract_scene_class(code) is not None:
        score += 0.2
    if "def construct" in code:
        score += 0.1

    anim_patterns = [
        "self.play(", "self.wait(", "Create(", "FadeIn(", "FadeOut(",
        "Transform(", "Write(", "MoveToTarget", "Indicate(",
        "ReplacementTransform(",
    ]
    # Each pattern counts at most once, however often it occurs.
    anim_hits = sum(1 for p in anim_patterns if p in code)
    score += min(0.3, anim_hits * 0.05)

    math_patterns = ["MathTex(", "Tex(", "Axes(", "NumberPlane(", "Graph("]
    if any(p in code for p in math_patterns):
        score += 0.1

    tier_thresholds = {"advanced": 6, "intermediate": 4, "beginner": 2}
    if anim_hits >= tier_thresholds.get(task.tier, 2):
        score += 0.1

    return min(1.0, score)


def narration_score(narration: str, fmt: str) -> float:
    """Rate narration text on a 0-1 scale; only manim scenes use narration.

    Any format other than ``"manim"`` scores 1.0. For manim, empty or
    whitespace-only narration scores 0.0. Otherwise points come from two
    word-count tiers plus scene-sequencing markers (0.1 each, capped at 0.3).
    Markers are matched as case-insensitive substrings.
    """
    if fmt != "manim":
        return 1.0
    if not narration or not narration.strip():
        return 0.0

    n_words = len(narration.split())
    lowered = narration.lower()

    # First length tier: 0.4 for 30+ words, 0.2 for 10+.
    if n_words >= 30:
        length_bonus = 0.4
    elif n_words >= 10:
        length_bonus = 0.2
    else:
        length_bonus = 0.0

    markers = ("scene", "step", "first", "next", "then", "finally", "now")
    marker_bonus = min(0.3, sum(0.1 for marker in markers if marker in lowered))

    # Second length tier: a further 0.3 for 50+ words, 0.15 for 20+.
    if n_words >= 50:
        extra_length_bonus = 0.3
    elif n_words >= 20:
        extra_length_bonus = 0.15
    else:
        extra_length_bonus = 0.0

    # Summed left to right, in the same order the pieces are accumulated.
    return min(1.0, length_bonus + marker_bonus + extra_length_bonus)


def context_usage(code: str, accumulated_context: list[str]) -> float:
    """Estimate how much of the gathered research vocabulary the code reuses (0-1).

    Context snippets and code are tokenized with ``_tokens``. The score is
    the number of distinct shared tokens divided by ``min(#context tokens,
    24)``, scaled by 2.5 and capped at 1.0. A few meaningful shared terms earn
    some credit, but a high score needs a substantial slice of the context.
    Missing context, context without usable tokens, and zero overlap all
    score 0.0.
    """
    if not accumulated_context:
        return 0.0

    research_vocab = {tok for snippet in accumulated_context for tok in _tokens(snippet)}
    if not research_vocab:
        return 0.0

    shared = research_vocab.intersection(_tokens(code))
    if not shared:
        return 0.0

    target = min(len(research_vocab), 24)  # vocab is non-empty, so target >= 1
    return min(1.0, len(shared) / target * 2.5)


# ---------------------------------------------------------------------------
# Main reward function
# ---------------------------------------------------------------------------


def compute_generate_reward(
    code: str,
    fmt: str,
    narration: str,
    task: Task,
    exec_success: bool,
    accumulated_context: list[str],
    static_check_passed: bool | None = None,
    error_codes: list[str] | None = None,
) -> tuple[float, dict]:
    """Compute the generation-phase reward. Returns ``(total, components)``.

    Args:
        code: Generated artifact source.
        fmt: Chosen format. ``"marimo"`` uses the marimo structure rules; any
            other value uses the manim scene rules plus narration scoring.
        narration: Narration text (only scored when ``fmt == "manim"``).
        task: Supplies ``keywords``, ``preferred_format``, ``tier`` and
            ``data_available`` to the component scorers.
        exec_success: Whether the artifact executed successfully.
        accumulated_context: Research snippets gathered during exploration.
        static_check_passed: Format-specific static-check result. When
            ``None`` it is inferred here: ``check_marimo`` for marimo, scene
            class detection for manim, ``False`` otherwise.
        error_codes: Static-check error codes. When ``None`` and the marimo
            check is run here, its violation codes are used instead.

    The Python parse result, the static check, and ``exec_success`` act as
    multiplicative gates on the weighted quality score (see the module
    docstring). ``components`` holds the rounded component scores and the
    rounded ``generate_total``.
    """
    parse_valid = ast_parses(code)
    c_parse = 1.0 if parse_valid else 0.0
    if static_check_passed is None:
        if parse_valid and fmt == "marimo":
            # Run the marimo check once and keep its violation codes.
            # Inferring only the boolean would hand marimo_structure
            # ``passed=False`` with no codes, silently skipping every MBxxx
            # penalty and the harsher MB static-fail multiplier.
            static_check_passed, _, inferred_codes = check_marimo(code)
            if error_codes is None:
                error_codes = list(inferred_codes or [])
        else:
            static_check_passed = _infer_static_check(code, fmt, parse_valid)

    c_static = 1.0 if parse_valid and static_check_passed else 0.0
    c_runs = 1.0 if exec_success else 0.0
    c_coverage = keyword_coverage(code, task.keywords)
    c_format = format_match(fmt, task)
    if fmt == "marimo":
        c_struct = marimo_structure(code, task, static_check_passed, error_codes)
    else:
        scene_structure = manim_structure(code, task)
        c_struct = 0.75 * scene_structure + 0.25 * narration_score(narration, fmt)
    c_ctx = context_usage(code, accumulated_context)
    c_validity = _validity_score(c_parse, c_static, c_runs)
    c_alignment = 0.75 * c_coverage + 0.25 * c_format

    quality = (
        _WEIGHTS["validity"] * c_validity
        + _WEIGHTS["task_alignment"] * c_alignment
        + _WEIGHTS["structure"] * c_struct
        + _WEIGHTS["research_usage"] * c_ctx
    )

    # Apply gates: the first failing gate decides the multiplier.
    if c_parse == 0.0:
        total = 0.0
    elif c_static == 0.0:
        total = quality * _static_fail_multiplier(error_codes or [])
    elif c_runs == 0.0:
        total = quality * GATE_RUNS_FAIL
    else:
        total = quality

    components = {
        "validity": round(c_validity, 3),
        "task_alignment": round(c_alignment, 3),
        "structure": round(c_struct, 3),
        "research_usage": round(c_ctx, 3),
        "generate_total": round(total, 4),
    }
    return total, components


def _infer_static_check(code: str, fmt: str, parse_valid: bool) -> bool:
    if not parse_valid:
        return False
    if fmt == "marimo":
        passed, _, _ = check_marimo(code)
        return passed
    if fmt == "manim":
        return extract_scene_class(code) is not None
    return False


def _static_fail_multiplier(error_codes: list[str]) -> float:
    """Pick the quality multiplier for an artifact that failed its static check.

    Any ``MB``-prefixed violation code gets the harshest multiplier,
    ``GATE_STATIC_FAIL``. Other failures get 1.5x that, never exceeding
    ``GATE_RUNS_FAIL``, so parseable but structurally invalid artifacts
    cannot score high.
    """
    has_mb_violation = any(err.startswith("MB") for err in error_codes)
    if not has_mb_violation:
        return min(GATE_RUNS_FAIL, GATE_STATIC_FAIL * 1.5)
    return GATE_STATIC_FAIL


def _validity_score(
    parse_valid: float,
    static_check_passed: float,
    code_runs: float,
) -> float:
    if parse_valid == 0.0:
        return 0.0
    if static_check_passed == 0.0:
        return 0.35
    if code_runs == 0.0:
        return 0.70
    return 1.0


def adjust_repair_reward(
    base_reward: float,
    *,
    repair_success: bool,
    previous_error_codes: list[str],
    new_error_codes: list[str],
    previous_code: str,
    repaired_code: str,
) -> tuple[float, dict]:
    """Discount repaired code but reward fixing the specific prior failure.

    A successful repair keeps 60% of ``base_reward``. It earns +0.08 when
    there were prior error codes and none of them reappear, and +0.04 when
    the code actually changed. A failed repair keeps 25%, plus +0.04 for
    clearing the prior codes. Leaving the code unchanged (ignoring
    whitespace) costs 0.15. The result goes through ``clamp_action_reward``
    and is capped at ``MAX_REPAIR_REWARD``.

    Returns:
        ``(reward, components)``. ``components`` holds the three boolean
        signals as 0/1 floats and the rounded ``repair_total``.
    """
    code_changed = _fingerprint(previous_code) != _fingerprint(repaired_code)
    # Only counts as a fix when there was a prior failure to fix.
    fixed_prior = bool(previous_error_codes) and set(
        previous_error_codes
    ).isdisjoint(new_error_codes)

    if repair_success:
        reward = (
            base_reward * 0.60
            + (0.08 if fixed_prior else 0.0)
            + (0.04 if code_changed else 0.0)
        )
    else:
        reward = base_reward * 0.25 + (0.04 if fixed_prior else 0.0)

    if not code_changed:
        reward -= 0.15

    reward = min(MAX_REPAIR_REWARD, clamp_action_reward(reward))
    components = {
        "repair_success": 1.0 if repair_success else 0.0,
        "fixed_prior_errors": 1.0 if fixed_prior else 0.0,
        "changed_code": 1.0 if code_changed else 0.0,
        "repair_total": round(reward, 4),
    }
    return reward, components


def _tokens(text: str) -> list[str]:
    return [
        w
        for w in re.findall(r"\w+", text.lower())
        if len(w) > 3 and w not in _STOPWORDS
    ]


def _fingerprint(code: str) -> str:
    return re.sub(r"\s+", "", code)