File size: 2,501 Bytes
269f632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bc9d26
269f632
 
0bc9d26
 
 
 
 
 
269f632
 
 
 
 
 
 
 
 
 
 
0bc9d26
 
 
 
 
 
269f632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Strict (0, 1) grader for SQL repair tasks.

Phase 2 hard requirement: scores MUST be in the OPEN interval (0, 1).
Validator rejects exactly 0.0 and exactly 1.0. NaN/inf are also rejected,
so we coerce them to 0.5 (a neutral, in-range fallback).
"""
from __future__ import annotations

import math
from typing import Any

# Shared with inference.py — keep names and values stable across modules.
SCORE_MIN: float = 1e-3   # 0.001 — strictly > 0
SCORE_MAX: float = 0.999  # strictly < 1


def strict_clamp(value: Any) -> float:
    """Force *value* into a float strictly between 0 and 1 (exclusive).

    Anything that cannot be interpreted as a finite number — NaN, ±inf,
    or a non-numeric object — maps to the neutral fallback 0.5.

    Two invariants (from Canary's Phase 2 validator failures):
      1. The endpoints 0.0 and 1.0 are never returned.
      2. The invariant survives 4-decimal display rounding: a raw score
         such as 0.00004 rounds to 0.0000, so both the raw and the
         rounded value are checked against the open interval, falling
         back to SCORE_MIN / SCORE_MAX at either boundary.
    """
    try:
        numeric = float(value)
    except (TypeError, ValueError):
        return 0.5
    if not math.isfinite(numeric):
        return 0.5

    rounded = round(numeric, 4)
    # Both the raw value and its rounded form must sit strictly inside
    # (0, 1); whichever one escapes first decides the clamped endpoint.
    for candidate in (numeric, rounded):
        if candidate <= 0.0:
            return SCORE_MIN
        if candidate >= 1.0:
            return SCORE_MAX
    return rounded


def grade_task(state, task_id: str) -> float:
    """Return a strict (0, 1) score for *state* against *task_id*.

    Additive components, clamped at the end via ``strict_clamp``:
      - 0.05  : at least one query was submitted
      - 0.25  : the last query executed without error
      - 0.60  : result rows matched the (non-empty) expected rows
      - 0.09  : efficiency bonus, scaled by steps remaining

    A state graded against the wrong task scores SCORE_MIN outright.
    Worst case (no submission) clamps 0.0 up to 0.001; the best
    one-step solve lands at 0.99 — both strictly inside (0, 1).
    """
    # Deferred import: env_core imports this module, so a top-level
    # import would be circular.
    from .env_core import MAX_STEPS

    if state.task_id != task_id:
        return SCORE_MIN

    submitted = bool(state.last_query)
    ran_clean = state.last_error is None and state.last_result is not None
    matched = bool(state.expected_rows) and state.last_result == state.expected_rows

    total = 0.0
    if submitted:
        total += 0.05
    if ran_clean:
        total += 0.25
    if matched:
        total += 0.60
    if state.solved and state.step_count > 0:
        steps_left = max(0, MAX_STEPS - state.step_count)
        total += 0.09 * steps_left / MAX_STEPS

    return strict_clamp(total)