File size: 16,421 Bytes
5850885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
"""SQLDrift composed rubric (correctness, drift, speedup, gates, DBA tax).

Six child rubrics, one per reward component (:data:`REWARD_COMPONENT_KEYS`):

    r_correct      correctness vs ground-truth hash, gated on ≥ 1.2× speedup
    r_drift        bonus/penalty for (not) adapting to post-drift identifiers
    r_speedup      tanh-shaped speedup bonus, gated on r_correct > 0
    r_step_tax     base step tax plus bounded productive-action rebates
    r_gatekeepers  escalating tool-error / repeat-failing / no-op penalties
    r_consult_dba  DBA-oracle consult penalties (feature-flagged; 0 when off)

All child rubrics share a single ``ctx_provider`` that returns the private
:class:`engine.runtime.RuntimeEpisodeState`; this keeps the rubric
stateless relative to the environment and makes each component
individually unit-testable with a synthesized triple
``(RuntimeEpisodeState, SqlDriftAction, SqlDriftObservation)``.
"""

from __future__ import annotations

import math
import re
from collections.abc import Callable
from typing import TYPE_CHECKING

from openenv.core.rubrics import Rubric

from models import (
    SqlDriftAction,
    SqlDriftObservation,
    SubmitRewriteResult,
    ToolError,
    ToolName,
)

if TYPE_CHECKING:
    from engine.runtime import RuntimeEpisodeState


# Tunable thresholds pulled out to module level so tests and future
# curriculum code share a single source of truth.

# Minimum speedup a hash-correct submission needs for the full +1.0
# correctness score; below it the submission only earns +0.5.
SPEEDUP_MIN: float = 1.2
# Finite stand-in used by the speedup bonus when the effective speedup is
# infinite (the bonus formula takes log2 of the value).
SPEEDUP_CAP_FOR_INFTY: float = 64.0
# Base reward added on every non-terminal step (negative, i.e. a cost).
STEP_TAX: float = -0.03
# Per-tool rebate sizes. They are not read anywhere in this module;
# presumably the environment uses them to populate
# ``last_step_productive_rebate`` on the runtime state -- TODO confirm.
STEP_REBATE_LIST_TABLES: float = 0.04
STEP_REBATE_DESCRIBE_TABLE: float = 0.06
STEP_REBATE_SAMPLE_ROWS: float = 0.05
STEP_REBATE_RUN_QUERY: float = 0.04
STEP_REBATE_EXPLAIN_QUERY: float = 0.04
STEP_REBATE_READ_CHANGELOG: float = 0.08

# Gatekeeper penalties (all negative).
# Flat penalty whenever the step's tool result is a ToolError.
GATE_MALFORMED_TOOL_CALL: float = -0.3
# Extra penalty per additional consecutive tool error beyond the first.
GATE_CONSECUTIVE_TOOL_ERROR: float = -0.1
# Penalty per repeat of an already-failing query beyond the first.
GATE_REPEAT_FAILING_QUERY: float = -0.1
# Penalty for submitting SQL whose canonical form equals the baseline's.
GATE_BASELINE_VERBATIM: float = -0.2
# Upper bound on the multiplier applied to the two escalating gates above.
_MAX_ESCALATION_STEPS: int = 3

# Penalty tiers for DBA consults, indexed by how many consults were used.
CONSULT_ESCALATION: tuple[float, float, float] = (-0.1, -0.3, -0.8)


# =============================================================================
# Helpers
# =============================================================================


def canonicalize_sql(sql: str) -> str:
    """Return a lower-cased canonical rendering of ``sql`` for equality checks.

    The statement is parsed and re-emitted through sqlglot's duckdb dialect
    (comments dropped, ``normalize=True``), then stripped and lower-cased so
    formatting differences wash out. Note that the final ``lower()`` applies
    to the whole rendered text, string literals included.

    If sqlglot cannot be imported or rejects the input (e.g. an arbitrary
    agent-submitted blob during the baseline-verbatim check), the function
    degrades to lower-casing ``sql`` and collapsing every whitespace run to
    a single space.
    """
    try:
        import sqlglot

        tree = sqlglot.parse_one(sql, dialect="duckdb")
        rendered = tree.sql(dialect="duckdb", comments=False, normalize=True)
        return rendered.strip().lower()
    except Exception:
        # Deliberate best-effort: any failure above falls back to a plain
        # whitespace fold instead of propagating.
        tokens = sql.lower().split()
        return " ".join(tokens)


_IDENT_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\b")


def _extract_identifiers(sql: str) -> frozenset[str]:
    """Loose case-preserving identifier harvest.

    Strips string literals before tokenizing so e.g. `WHERE x = 'ACTIVE'`
    does not leak 'ACTIVE' into the identifier set.
    """
    stripped = re.sub(r"'[^']*'", "", sql)
    return frozenset(m.group(0) for m in _IDENT_RE.finditer(stripped))


def _extract_column_references(sql: str) -> frozenset[str]:
    """Column-reference names via the sqlglot AST (alias labels excluded).

    For drift-adapt scoring, ``SELECT account_id AS user_id`` references
    the new column and merely labels the output, so the alias must not
    count as a surviving pre-drift marker. Only ``Column`` nodes are
    collected, which leaves alias labels out.

    Falls back to the regex-based :func:`_extract_identifiers` when
    sqlglot cannot be imported, fails to parse ``sql``, or yields no
    expression, so malformed SQL still scores something.

    Args:
        sql: Raw SQL text submitted by the agent.

    Returns:
        The set of distinct, non-empty column names referenced in ``sql``.
    """
    try:
        import sqlglot

        expr = sqlglot.parse_one(sql, dialect="duckdb")
    except Exception:
        return _extract_identifiers(sql)
    if expr is None:
        return _extract_identifiers(sql)
    # ``find_all`` yields matching nodes directly. Iterating ``walk()`` and
    # filtering by isinstance silently matches nothing on sqlglot releases
    # where ``walk()`` yields ``(node, parent, key)`` tuples.
    return frozenset(
        col.name for col in expr.find_all(sqlglot.exp.Column) if col.name
    )


def _literals(sql: str) -> frozenset[str]:
    """All `'..'`-quoted string literals in `sql`."""
    return frozenset(re.findall(r"'([^']*)'", sql))


_AGENT_MS_EPSILON: float = 1e-6


def effective_speedup(rt: RuntimeEpisodeState) -> float | None:
    """Derive the speedup figure from a runtime snapshot.

    Intended as the single definition of "speedup" so that every consumer
    of the number agrees on it.

    Returns:

    * ``None`` when ``rt.submitted_runtime_ms`` is ``None``, i.e. no
      submission timing has been recorded.
    * ``math.inf`` when drift has fired and the baseline SQL raises on
      the post-drift schema (``rt.baseline_postdrift_raises``): there is
      no runnable baseline to compare against.
    * ``baseline_runtime_ms / max(submitted_runtime_ms, eps)`` otherwise.
      The ``eps`` clamp (``_AGENT_MS_EPSILON``) turns zero or negative
      timings into a very large but finite speedup instead of a division
      error or a negative ratio.
    """
    agent_ms = rt.submitted_runtime_ms
    if agent_ms is None:
        return None

    baseline_unrunnable = rt.drift_fired and rt.baseline_postdrift_raises
    if baseline_unrunnable:
        return math.inf

    return rt.baseline_runtime_ms / max(agent_ms, _AGENT_MS_EPSILON)


def _speedup_for_reward(rt: RuntimeEpisodeState) -> float:
    """Speedup for reward math, guaranteed to be a number.

    Wraps :func:`effective_speedup` and substitutes ``0.0`` when it
    reports ``None`` (no submission timing), so callers never hit a
    ``TypeError`` when comparing or taking logs of the result. No
    assertion is made; the coercion is purely defensive.
    """
    speedup = effective_speedup(rt)
    if speedup is None:
        return 0.0
    return speedup


def _is_terminal_submission(
    action: SqlDriftAction,
    observation: SqlDriftObservation,
    rt: RuntimeEpisodeState,
) -> bool:
    """Tell whether the current step is the submission step.

    All three conditions must hold:

    * the runtime state is flagged as submitted (``rt.submitted``),
    * the action's tool is ``ToolName.SUBMIT_REWRITE``,
    * the observation carries a :class:`SubmitRewriteResult`.

    Terminal reward components gate on this so that scoring any other
    step of an already-submitted episode does not score the submission
    again. ``observation.done`` is not consulted here.
    """
    if not rt.submitted or action.tool != ToolName.SUBMIT_REWRITE:
        return False
    return isinstance(observation.tool_result, SubmitRewriteResult)


def _gt_hash(rt: RuntimeEpisodeState) -> str | None:
    if rt.drift_fired and rt.gt_result_hash_postdrift is not None:
        return rt.gt_result_hash_postdrift
    return rt.gt_result_hash_predrift


# =============================================================================
# Child rubrics
# =============================================================================


class _CtxChild(Rubric):
    """Base class for child rubrics that read state via a shared provider.

    The zero-argument ``ctx_provider`` callable is kept as ``self._ctx``;
    subclasses call it inside ``forward`` to obtain the current
    :class:`RuntimeEpisodeState` each time they score a step.
    """

    def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
        super().__init__()
        # Stored with object.__setattr__ instead of plain assignment --
        # presumably to bypass an attribute hook on ``Rubric`` (e.g. child
        # auto-registration). NOTE(review): confirm against openenv's Rubric.
        object.__setattr__(self, "_ctx", ctx_provider)

    def forward(
        self,
        action: SqlDriftAction,
        observation: SqlDriftObservation,
    ) -> float:
        """Score one step; every concrete child must override this."""
        raise NotImplementedError


class Correctness(_CtxChild):
    """Terminal-only correctness score.

    * ``0.0``  on non-submission steps, or when either the ground-truth
      hash or the submitted hash is missing.
    * ``-1.0`` when the submitted result hash differs from ground truth.
    * ``+1.0`` when the hashes match and speedup is at least
      ``SPEEDUP_MIN``.
    * ``+0.5`` when the hashes match but the speedup falls short.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        rt = self._ctx()
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        expected = _gt_hash(rt)
        submitted = rt.submitted_result_hash
        if expected is None or submitted is None:
            # Nothing to compare against: stay neutral.
            return 0.0
        if submitted != expected:
            return -1.0

        fast_enough = _speedup_for_reward(rt) >= SPEEDUP_MIN
        return 1.0 if fast_enough else 0.5


class DriftAdapt(_CtxChild):
    """+0.5 for a correctly-adapted submission, -0.5 for a pre-drift-only
    submission after drift fired.

    Adaptation is detected against two scenario-declared identifier sets
    read from ``rt.instance`` (both default to empty when absent):

    * ``postdrift_identifiers`` -- identifiers/literals that only a
      correct post-drift rewrite will introduce (e.g. ``account_id``
      after a column rename, ``'ACTIVE'`` after an enum split).
    * ``predrift_identifiers`` -- identifiers/literals a submission that
      ignored the drift would retain (e.g. ``user_id``, ``'active'``,
      the ISO anchor strings under date-format drift).

    A submission counts as "adapted" when it surfaces a post-drift marker
    (or the scenario declares no post-drift identifiers at all, e.g.
    date-format drift where only the literal shape changes) AND it does
    not retain any pre-drift marker.

    Scoring, all on the terminal submission step after drift has fired:

    * ``+0.5`` -- the submitted result hash exists, equals the post-drift
      ground-truth hash, and the submission is adapted.
    * ``-0.5`` -- the submission still carries a pre-drift marker and its
      result hash differs from the post-drift ground truth (a missing
      submitted hash counts as differing).
    * ``0.0``  -- everything else, so a merely partial rewrite (neither
      pre- nor post-flavoured) never earns a penalty it can't diagnose.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        rt = self._ctx()
        # Only drift scenarios participate.
        if rt.gt_result_hash_postdrift is None and not rt.drift_fired:
            return 0.0
        if not _is_terminal_submission(action, observation, rt):
            return 0.0
        # Both the bonus and the penalty require drift to have fired, so
        # there is no point parsing the SQL before that.
        if not rt.drift_fired:
            return 0.0

        inst = getattr(rt, "instance", None)
        post_ids: frozenset[str] = (
            getattr(inst, "postdrift_identifiers", frozenset()) or frozenset()
        )
        pre_ids: frozenset[str] = (
            getattr(inst, "predrift_identifiers", frozenset()) or frozenset()
        )
        agent_sql = rt.submitted_sql or ""

        # Markers are column references plus string literals, because the
        # scenario sets may name either kind.
        markers = _extract_column_references(agent_sql) | _literals(agent_sql)
        uses_post = bool(post_ids & markers)
        uses_pre = bool(pre_ids & markers)

        # Treat "no distinctive post identifier" scenarios as
        # satisfied by absence-of-pre (see class docstring).
        adapted = (uses_post or not post_ids) and not uses_pre

        agent_hash = rt.submitted_result_hash
        gt_post = rt.gt_result_hash_postdrift

        # A missing submitted hash is "no result", never a match. Without
        # this guard ``None == None`` would award the bonus whenever both
        # the submission hash and the post-drift ground truth are absent.
        matches_post = agent_hash is not None and agent_hash == gt_post

        if matches_post and adapted:
            return 0.5
        if uses_pre and agent_hash != gt_post:
            return -0.5
        return 0.0


class Speedup(_CtxChild):
    """Terminal-only speedup bonus: ``0.3 * tanh(log2(speedup) / 3)``.

    Paid only when the submitted result hash equals the applicable
    ground-truth hash. An infinite speedup is replaced by
    ``SPEEDUP_CAP_FOR_INFTY`` before the formula is applied, and a
    speedup of at most ``1.0`` earns nothing.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        rt = self._ctx()
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        expected = _gt_hash(rt)
        if expected is None or rt.submitted_result_hash != expected:
            return 0.0

        ratio = _speedup_for_reward(rt)
        # log2(inf) would saturate tanh anyway; the cap keeps the math finite.
        ratio = SPEEDUP_CAP_FOR_INFTY if math.isinf(ratio) else ratio
        if ratio <= 1.0:
            return 0.0
        scaled = math.log2(ratio) / 3.0
        return 0.3 * math.tanh(scaled)


class StepTax(_CtxChild):
    """Per-step cost: ``STEP_TAX`` plus a non-negative exploration rebate.

    The rebate is read from ``rt.last_step_productive_rebate`` (treated
    as ``0.0`` when the attribute is absent) and clamped at zero from
    below. The terminal submission step is exempt and scores ``0.0``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        rt = self._ctx()
        if _is_terminal_submission(action, observation, rt):
            return 0.0
        raw_rebate = float(getattr(rt, "last_step_productive_rebate", 0.0))
        # A negative rebate must never deepen the tax.
        return STEP_TAX + max(0.0, raw_rebate)


class Gatekeepers(_CtxChild):
    """Sum of three independent penalties; the first two escalate to a cap.

    1. Tool error -- the step's tool result is a :class:`ToolError`:
       ``GATE_MALFORMED_TOOL_CALL`` plus ``GATE_CONSECUTIVE_TOOL_ERROR``
       for each consecutive error beyond the first.
    2. Repeat failing query -- ``GATE_REPEAT_FAILING_QUERY`` for each
       repeat beyond the first, read from
       ``rt.last_step_repeat_failing_query_count``.
    3. Baseline-verbatim submission -- ``GATE_BASELINE_VERBATIM`` when
       the terminal submission's canonical SQL equals the baseline's.

    Escalation multipliers are capped at ``_MAX_ESCALATION_STEPS``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        rt = self._ctx()
        total = 0.0

        # Gate 1: a ToolError was emitted this step.
        if isinstance(observation.tool_result, ToolError):
            extra_errors = max(0, int(getattr(rt, "consecutive_tool_errors", 0)) - 1)
            total += GATE_MALFORMED_TOOL_CALL
            total += GATE_CONSECUTIVE_TOOL_ERROR * min(extra_errors, _MAX_ESCALATION_STEPS)

        # Gate 2: the same failing query was re-issued. The count is read
        # from the runtime state; the first occurrence is free.
        extra_repeats = max(
            0, int(getattr(rt, "last_step_repeat_failing_query_count", 0)) - 1
        )
        if extra_repeats > 0:
            total += GATE_REPEAT_FAILING_QUERY * min(extra_repeats, _MAX_ESCALATION_STEPS)

        # Gate 3: the submission is the baseline verbatim. Stacked on the
        # +0.5 partial correctness score this nets a no-op rewrite +0.3.
        is_submit = action.tool == ToolName.SUBMIT_REWRITE
        if is_submit and _is_terminal_submission(action, observation, rt):
            if rt.submitted_sql_canonical == rt.baseline_sql_canonical:
                total += GATE_BASELINE_VERBATIM

        return total


class ConsultDBA(_CtxChild):
    """Escalating consult penalties (-0.1 / -0.3 / -0.8) when the flag is on.

    Scores ``0.0`` unless ``rt.dba_oracle_enabled`` is truthy and the
    action's tool is ``ToolName.CONSULT_DBA``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        rt = self._ctx()
        if not getattr(rt, "dba_oracle_enabled", False) or action.tool != ToolName.CONSULT_DBA:
            return 0.0
        # Tier mapping from ``rt.consultations_used``: a count of 0 or 1
        # hits the first tier, 2 the second, and 3 or more the last.
        # NOTE(review): this reads as if the env has already counted the
        # current consult by the time the rubric runs -- confirm.
        tier = min(rt.consultations_used, len(CONSULT_ESCALATION))
        index = max(tier - 1, 0)
        return CONSULT_ESCALATION[index]


# =============================================================================
# Composite
# =============================================================================


class SqlDriftRubric(Rubric):
    """Composite rubric whose score is the sum of six child rubrics.

    Each child is assigned as an attribute, which auto-enrolls it in
    :meth:`Rubric.named_rubrics` so training loops can introspect
    per-component scores.
    """

    def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
        super().__init__()
        # Every child receives the same provider and recomputes what it
        # needs from the runtime state; no child reads another child's
        # score. The assignment order below simply mirrors the order in
        # which ``forward`` sums the components.
        self.correctness = Correctness(ctx_provider)
        self.drift_adapt = DriftAdapt(ctx_provider)
        self.speedup = Speedup(ctx_provider)
        self.step_tax = StepTax(ctx_provider)
        self.gatekeepers = Gatekeepers(ctx_provider)
        self.consult_dba = ConsultDBA(ctx_provider)

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Score one step by calling each child and summing left to right."""
        first, *rest = (
            self.correctness,
            self.drift_adapt,
            self.speedup,
            self.step_tax,
            self.gatekeepers,
            self.consult_dba,
        )
        total = first(action, observation)
        for child in rest:
            total = total + child(action, observation)
        return total

    def component_scores(self) -> dict[str, float]:
        """Return each child's most recent score under its reward key.

        Keys are meant to match :data:`models.REWARD_COMPONENT_KEYS`. A
        child whose ``last_score`` is falsy (e.g. ``None``) reports ``0.0``.
        """
        by_key = {
            "r_correct": self.correctness,
            "r_drift": self.drift_adapt,
            "r_speedup": self.speedup,
            "r_step_tax": self.step_tax,
            "r_gatekeepers": self.gatekeepers,
            "r_consult_dba": self.consult_dba,
        }
        return {key: float(child.last_score or 0.0) for key, child in by_key.items()}


# Public API of this module: the tunable constants, the six child rubrics,
# the composite rubric, and the two shared helpers. Underscore-prefixed
# helpers are intentionally left out.
__all__ = [
    "CONSULT_ESCALATION",
    "ConsultDBA",
    "Correctness",
    "DriftAdapt",
    "GATE_BASELINE_VERBATIM",
    "GATE_CONSECUTIVE_TOOL_ERROR",
    "GATE_MALFORMED_TOOL_CALL",
    "GATE_REPEAT_FAILING_QUERY",
    "Gatekeepers",
    "SPEEDUP_CAP_FOR_INFTY",
    "SPEEDUP_MIN",
    "STEP_REBATE_DESCRIBE_TABLE",
    "STEP_REBATE_EXPLAIN_QUERY",
    "STEP_REBATE_LIST_TABLES",
    "STEP_REBATE_READ_CHANGELOG",
    "STEP_REBATE_RUN_QUERY",
    "STEP_REBATE_SAMPLE_ROWS",
    "STEP_TAX",
    "Speedup",
    "SqlDriftRubric",
    "StepTax",
    "canonicalize_sql",
    "effective_speedup",
]