sql-drift-env / engine /reward.py
visheshrathi's picture
Upload folder using huggingface_hub
5850885 verified
"""SQLDrift composed rubric (correctness, drift, speedup, gates, DBA tax).
Six child rubrics, one per reward component (:data:`REWARD_COMPONENT_KEYS`):
r_correct correctness vs ground-truth hash, gated on ≥ 1.2× speedup
r_drift bonus/penalty for (not) adapting to post-drift identifiers
r_speedup tanh-shaped speedup bonus, gated on r_correct > 0
r_step_tax base step tax plus bounded productive-action rebates
r_gatekeepers escalating tool-error / repeat-failing / no-op penalties
r_consult_dba DBA-oracle consult penalties (feature-flagged; 0 when off)
All child rubrics share a single ``ctx_provider`` that returns the private
:class:`engine.runtime.RuntimeEpisodeState`; this keeps the rubric
stateless relative to the environment and makes each component
individually unit-testable with a synthesized triple
``(RuntimeEpisodeState, SqlDriftAction, SqlDriftObservation)``.
"""
from __future__ import annotations
import math
import re
from collections.abc import Callable
from typing import TYPE_CHECKING
from openenv.core.rubrics import Rubric
from models import (
SqlDriftAction,
SqlDriftObservation,
SubmitRewriteResult,
ToolError,
ToolName,
)
if TYPE_CHECKING:
from engine.runtime import RuntimeEpisodeState
# Tunable thresholds pulled out to module level so tests and future
# curriculum code share a single source of truth.
# Minimum effective speedup a hash-correct submission needs for the full
# correctness score; below it Correctness returns the 0.5 partial.
SPEEDUP_MIN: float = 1.2
# Finite stand-in the Speedup rubric substitutes for an infinite speedup
# before taking log2.
SPEEDUP_CAP_FOR_INFTY: float = 64.0
# Base reward StepTax adds on every non-terminal step (before any rebate).
STEP_TAX: float = -0.03
# Per-tool rebate amounts. They are not read anywhere in this module;
# presumably the environment uses them to fill
# ``last_step_productive_rebate`` on the runtime state -- TODO confirm.
STEP_REBATE_LIST_TABLES: float = 0.04
STEP_REBATE_DESCRIBE_TABLE: float = 0.06
STEP_REBATE_SAMPLE_ROWS: float = 0.05
STEP_REBATE_RUN_QUERY: float = 0.04
STEP_REBATE_EXPLAIN_QUERY: float = 0.04
STEP_REBATE_READ_CHANGELOG: float = 0.08
# Gatekeepers penalties: flat charge when the step produced a ToolError ...
GATE_MALFORMED_TOOL_CALL: float = -0.3
# ... plus this amount per additional consecutive error (capped below).
GATE_CONSECUTIVE_TOOL_ERROR: float = -0.1
# Charged per repeat of a failing query beyond the first (capped below).
GATE_REPEAT_FAILING_QUERY: float = -0.1
# Charged when the submitted SQL canonicalizes to the baseline SQL.
GATE_BASELINE_VERBATIM: float = -0.2
# Upper bound on the multiplier applied to the two escalating penalties.
_MAX_ESCALATION_STEPS: int = 3
# ConsultDBA penalty table, indexed by consult tier (first, second, third+).
CONSULT_ESCALATION: tuple[float, float, float] = (-0.1, -0.3, -0.8)
# =============================================================================
# Helpers
# =============================================================================
def canonicalize_sql(sql: str) -> str:
    """Return a lower-cased canonical rendering of ``sql`` for equality checks.

    The statement is parsed and re-generated through sqlglot's DuckDB
    dialect (comments dropped, ``normalize=True``), then stripped and
    lower-cased, so differently formatted copies of one query compare equal.

    If any step of that path raises -- sqlglot is unavailable, the SQL does
    not parse, or generation fails -- the function falls back to lower-casing
    the raw text and collapsing every whitespace run to a single space.
    """
    try:
        import sqlglot

        tree = sqlglot.parse_one(sql, dialect="duckdb")
        rendered = tree.sql(dialect="duckdb", comments=False, normalize=True)
        return rendered.strip().lower()
    except Exception:
        # Deliberate best effort: agent-submitted blobs may be arbitrary
        # text, and the caller still needs *some* comparable string.
        folded = sql.lower()
        return " ".join(folded.split())
_IDENT_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\b")


def _extract_identifiers(sql: str) -> frozenset[str]:
    """Collect every identifier-shaped word in ``sql``, preserving case.

    Single-quoted string literals are removed first, so a predicate such as
    ``WHERE x = 'ACTIVE'`` contributes ``WHERE`` and ``x`` but not
    ``ACTIVE``. The harvest is deliberately loose: SQL keywords are
    returned alongside table and column names.
    """
    without_literals = re.sub(r"'[^']*'", "", sql)
    return frozenset(_IDENT_RE.findall(without_literals))
def _extract_column_references(sql: str) -> frozenset[str]:
    """Return the names of all column references in ``sql`` (alias labels excluded).

    The SQL is parsed with sqlglot's DuckDB dialect and every
    ``sqlglot.exp.Column`` node in the tree contributes its ``name``. For
    drift-adapt scoring this means ``SELECT account_id AS user_id`` yields
    ``account_id`` only: the alias merely labels the output and must not
    count as a surviving pre-drift marker.

    Falls back to the regex-based :func:`_extract_identifiers` when sqlglot
    cannot be imported, the SQL fails to parse, or the parser returns
    ``None``, so malformed SQL still scores something.
    """
    try:
        import sqlglot

        expr = sqlglot.parse_one(sql, dialect="duckdb")
    except Exception:
        return _extract_identifiers(sql)
    if expr is None:
        return _extract_identifiers(sql)
    # ``find_all`` always yields expression nodes. Filtering ``walk()``
    # directly is version-sensitive: older sqlglot releases yield
    # ``(node, parent, key)`` tuples there, so an ``isinstance`` check on
    # the yielded item silently matches nothing.
    return frozenset(
        column.name for column in expr.find_all(sqlglot.exp.Column) if column.name
    )
def _literals(sql: str) -> frozenset[str]:
    """Return the contents of every single-quoted string literal in ``sql``.

    The surrounding quotes are dropped; an empty literal (``''``)
    contributes the empty string.
    """
    quoted = re.compile(r"'([^']*)'")
    return frozenset(match.group(1) for match in quoted.finditer(sql))
_AGENT_MS_EPSILON: float = 1e-6


def effective_speedup(rt: RuntimeEpisodeState) -> float | None:
    """Compute the effective speedup of the submission from the runtime snapshot.

    Intended as the single definition of "speedup" shared by the rubric and
    any other consumer, so the number cannot be computed two different ways.

    Returns:
        * ``None`` when ``rt.submitted_runtime_ms`` is ``None``, i.e. no
          submission timing has been recorded yet.
        * ``math.inf`` when drift has fired and the baseline SQL raises
          against the post-drift schema: any submission is treated as
          infinitely faster than a baseline that cannot run.
        * ``baseline_runtime_ms / max(submitted_runtime_ms, epsilon)``
          otherwise. The epsilon clamp turns zero or negative timings into
          a very large finite speedup instead of a division error.
    """
    submitted_ms = rt.submitted_runtime_ms
    if submitted_ms is None:
        return None
    baseline_unrunnable = rt.drift_fired and rt.baseline_postdrift_raises
    if baseline_unrunnable:
        return math.inf
    return rt.baseline_runtime_ms / max(submitted_ms, _AGENT_MS_EPSILON)
def _speedup_for_reward(rt: RuntimeEpisodeState) -> float:
    """Rubric-facing wrapper around :func:`effective_speedup` that never returns ``None``.

    The terminal rubrics call this only after a submission, where
    :func:`effective_speedup` is expected to produce a number. Should it
    return ``None`` anyway, the value is coerced to ``0.0`` so the reward
    arithmetic cannot hit a ``TypeError``.
    """
    speedup = effective_speedup(rt)
    if speedup is None:
        return 0.0
    return speedup
def _is_terminal_submission(
    action: SqlDriftAction,
    observation: SqlDriftObservation,
    rt: RuntimeEpisodeState,
) -> bool:
    """Report whether this step is the step that submitted the rewrite.

    All three signals must agree: the runtime state is marked submitted,
    the action's tool is ``SUBMIT_REWRITE``, and the observation carries a
    :class:`SubmitRewriteResult`. Terminal rewards are gated on this so a
    rubric call on any other step of an already-submitted episode does not
    score the submission again.
    """
    submitted_via_tool = rt.submitted and action.tool == ToolName.SUBMIT_REWRITE
    if not submitted_via_tool:
        return False
    return isinstance(observation.tool_result, SubmitRewriteResult)
def _gt_hash(rt: RuntimeEpisodeState) -> str | None:
    """Pick the ground-truth result hash that applies to the current episode state.

    The post-drift hash is used once drift has fired and such a hash
    exists; in every other case the pre-drift hash is returned (which may
    itself be ``None``).
    """
    postdrift = rt.gt_result_hash_postdrift if rt.drift_fired else None
    return rt.gt_result_hash_predrift if postdrift is None else postdrift
# =============================================================================
# Child rubrics
# =============================================================================
class _CtxChild(Rubric):
    """Base class for the child rubrics: stores the shared context provider.

    ``ctx_provider`` is a zero-argument callable returning the current
    ``RuntimeEpisodeState``; subclasses call ``self._ctx()`` at the start of
    ``forward`` to read it.
    """

    def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
        super().__init__()
        # Assigned via ``object.__setattr__`` rather than plain attribute
        # assignment. NOTE(review): presumably this sidesteps attribute
        # handling in ``Rubric.__setattr__`` (e.g. child registration) --
        # confirm against the openenv implementation.
        object.__setattr__(self, "_ctx", ctx_provider)

    def forward(
        self,
        action: SqlDriftAction,
        observation: SqlDriftObservation,
    ) -> float:
        """Score one step; every concrete child rubric must override this."""
        raise NotImplementedError
class Correctness(_CtxChild):
    """Terminal-only correctness score.

    * ``0.0``  -- not the submission step, or either hash is missing.
    * ``-1.0`` -- the submitted result hash differs from the ground truth.
    * ``+1.0`` -- hashes match and the speedup reaches ``SPEEDUP_MIN``.
    * ``+0.5`` -- hashes match but the speedup falls short of ``SPEEDUP_MIN``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the correctness component for this step."""
        rt = self._ctx()
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        expected = _gt_hash(rt)
        submitted = rt.submitted_result_hash
        if expected is None or submitted is None:
            # Nothing to compare against: stay neutral rather than punish.
            return 0.0
        if submitted != expected:
            return -1.0

        fast_enough = _speedup_for_reward(rt) >= SPEEDUP_MIN
        return 1.0 if fast_enough else 0.5
class DriftAdapt(_CtxChild):
    """+0.5 for a correctly-adapted submission, -0.5 for a pre-drift-only
    submission after drift fired.

    Adaptation is detected against two scenario-declared identifier sets
    read from ``rt.instance``:

    * ``postdrift_identifiers`` -- identifiers/literals that only a
      correct post-drift rewrite will introduce (e.g. ``account_id``
      after a column rename, ``'ACTIVE'`` after an enum split).
    * ``predrift_identifiers`` -- identifiers/literals a submission that
      ignored the drift would retain (e.g. ``user_id``, ``'active'``,
      the ISO anchor strings under date-format drift).

    A submission is "adapted" when it retains no pre-drift marker AND
    either surfaces a post-drift marker or the scenario declares no
    distinctive post-drift identifiers (e.g. date-format drift keeps the
    column name and only the literal shape changes).

    The bonus requires drift to have fired, an adapted submission, and a
    submitted result hash that is present and equal to the post-drift
    ground truth. The penalty fires only when drift has fired, the
    submission still carries a pre-drift marker, and its hash differs from
    the post-drift ground truth -- so a merely partial rewrite (neither
    pre- nor post-flavoured) is never penalised.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the drift-adaptation component for this step."""
        rt = self._ctx()
        # Only drift scenarios participate: both the bonus and the penalty
        # require drift to have fired.
        if not rt.drift_fired:
            return 0.0
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        inst = getattr(rt, "instance", None)
        # Coerce to frozenset so a scenario that declares its markers as a
        # list/tuple/set (or leaves them as None) still supports ``&``.
        post_ids: frozenset[str] = frozenset(
            getattr(inst, "postdrift_identifiers", None) or ()
        )
        pre_ids: frozenset[str] = frozenset(
            getattr(inst, "predrift_identifiers", None) or ()
        )

        agent_sql = rt.submitted_sql or ""
        markers = _extract_column_references(agent_sql) | _literals(agent_sql)
        uses_post = bool(post_ids & markers)
        uses_pre = bool(pre_ids & markers)
        # Treat "no distinctive post identifier" scenarios as satisfied by
        # absence-of-pre (see class docstring).
        adapted = (uses_post or not post_ids) and not uses_pre

        agent_hash = rt.submitted_result_hash
        gt_post = rt.gt_result_hash_postdrift
        # A missing hash must never count as a match: without this guard
        # ``None == None`` would award the bonus to a submission that
        # produced no result when no post-drift ground truth exists.
        hash_matches = agent_hash is not None and agent_hash == gt_post
        if hash_matches and adapted:
            return 0.5
        if uses_pre and agent_hash != gt_post:
            return -0.5
        return 0.0
class Speedup(_CtxChild):
    """Terminal-only speedup bonus: ``0.3 * tanh(log2(speedup) / 3)``.

    Paid only when the submitted result hash equals the applicable ground
    truth and the speedup exceeds 1.0; an infinite speedup is replaced by
    ``SPEEDUP_CAP_FOR_INFTY`` before the formula is applied.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the speedup component for this step."""
        rt = self._ctx()
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        expected = _gt_hash(rt)
        if expected is None or rt.submitted_result_hash != expected:
            # Incorrect (or unverifiable) results earn no speed credit.
            return 0.0

        speedup = _speedup_for_reward(rt)
        capped = SPEEDUP_CAP_FOR_INFTY if math.isinf(speedup) else speedup
        if capped <= 1.0:
            return 0.0
        return 0.3 * math.tanh(math.log2(capped) / 3.0)
class StepTax(_CtxChild):
    """Per-step tax on non-terminal steps, offset by a non-negative rebate.

    The rebate is read from ``rt.last_step_productive_rebate`` (defaulting
    to ``0.0`` when the attribute is absent) and clamped at zero from
    below, so it can only reduce the tax, never add to it.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return ``STEP_TAX`` plus the clamped rebate; ``0.0`` on the submission step."""
        rt = self._ctx()
        if _is_terminal_submission(action, observation, rt):
            return 0.0
        raw_rebate = float(getattr(rt, "last_step_productive_rebate", 0.0))
        clamped_rebate = max(0.0, raw_rebate)
        return STEP_TAX + clamped_rebate
class Gatekeepers(_CtxChild):
    """Sum of three independent penalties; the repeat-based ones escalate up to a cap.

    1. Tool error this step: ``GATE_MALFORMED_TOOL_CALL`` plus
       ``GATE_CONSECUTIVE_TOOL_ERROR`` per additional consecutive error.
    2. Repeated failing query: ``GATE_REPEAT_FAILING_QUERY`` per repeat
       beyond the first.
    3. Submission whose canonical SQL equals the baseline's:
       ``GATE_BASELINE_VERBATIM``.

    Both escalation multipliers are capped at ``_MAX_ESCALATION_STEPS``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the (non-positive) gatekeeper component for this step."""
        rt = self._ctx()
        total = 0.0

        # 1. Malformed / failed tool call: a ToolError was emitted this step.
        #    The first error in a streak pays only the flat charge.
        if isinstance(observation.tool_result, ToolError):
            total += GATE_MALFORMED_TOOL_CALL
            error_count = int(getattr(rt, "consecutive_tool_errors", 0))
            extra_errors = max(0, error_count - 1)
            total += GATE_CONSECUTIVE_TOOL_ERROR * min(extra_errors, _MAX_ESCALATION_STEPS)

        # 2. Repeat failing query: the counter lives on the runtime state
        #    and is read as-is; the first occurrence is free.
        repeat_count = int(getattr(rt, "last_step_repeat_failing_query_count", 0))
        extra_repeats = max(0, repeat_count - 1)
        if extra_repeats > 0:
            total += GATE_REPEAT_FAILING_QUERY * min(extra_repeats, _MAX_ESCALATION_STEPS)

        # 3. Baseline-verbatim submission. Combined with Correctness's
        #    +0.5 partial this nets +0.3 for a correct but unchanged query.
        is_submit = action.tool == ToolName.SUBMIT_REWRITE
        if (
            is_submit
            and _is_terminal_submission(action, observation, rt)
            and rt.submitted_sql_canonical == rt.baseline_sql_canonical
        ):
            total += GATE_BASELINE_VERBATIM

        return total
class ConsultDBA(_CtxChild):
    """Escalating DBA-consult penalties taken from ``CONSULT_ESCALATION``.

    Returns ``0.0`` unless ``rt.dba_oracle_enabled`` is truthy and the
    action's tool is ``CONSULT_DBA``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the consult penalty for this step."""
        rt = self._ctx()
        if not getattr(rt, "dba_oracle_enabled", False):
            return 0.0
        if action.tool != ToolName.CONSULT_DBA:
            return 0.0

        # ``consultations_used`` is clamped to the table length and treated
        # as a 1-based tier; values <= 0 fall back to the first entry.
        # NOTE(review): with this indexing a counter of 0 and a counter of
        # 1 both pay CONSULT_ESCALATION[0]. Whether the environment bumps
        # the counter before or after the rubric runs decides if that is
        # intended -- confirm against the env step logic.
        tier = min(rt.consultations_used, len(CONSULT_ESCALATION))
        index = tier - 1 if tier > 0 else 0
        return CONSULT_ESCALATION[index]
# =============================================================================
# Composite
# =============================================================================
class SqlDriftRubric(Rubric):
    """Composite rubric whose score is the sum of the six child rubrics.

    The children are stored as attributes so that they are enrolled in
    :meth:`Rubric.named_rubrics` and training loops can introspect
    per-component scores.
    """

    def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
        super().__init__()
        # Children are created in the same order ``forward`` evaluates
        # them. Each child reads only the runtime state returned by the
        # shared ``ctx_provider``, so the order documents intent rather
        # than affecting the individual scores.
        self.correctness = Correctness(ctx_provider)
        self.drift_adapt = DriftAdapt(ctx_provider)
        self.speedup = Speedup(ctx_provider)
        self.step_tax = StepTax(ctx_provider)
        self.gatekeepers = Gatekeepers(ctx_provider)
        self.consult_dba = ConsultDBA(ctx_provider)

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Evaluate every child on this step and return the running sum."""
        # Accumulated left to right so the floating-point result matches a
        # plain chained ``a + b + c + ...`` expression.
        total = self.correctness(action, observation)
        total = total + self.drift_adapt(action, observation)
        total = total + self.speedup(action, observation)
        total = total + self.step_tax(action, observation)
        total = total + self.gatekeepers(action, observation)
        total = total + self.consult_dba(action, observation)
        return total

    def component_scores(self) -> dict[str, float]:
        """Return each child's ``last_score`` as a float, keyed for W&B.

        Keys match :data:`models.REWARD_COMPONENT_KEYS` so the observation
        envelope and the demo plots agree on a stable schema. A child whose
        ``last_score`` is falsy (e.g. ``None``) reports ``0.0``.
        """
        children = {
            "r_correct": self.correctness,
            "r_drift": self.drift_adapt,
            "r_speedup": self.speedup,
            "r_step_tax": self.step_tax,
            "r_gatekeepers": self.gatekeepers,
            "r_consult_dba": self.consult_dba,
        }
        return {key: float(child.last_score or 0.0) for key, child in children.items()}
# Public API of this module: the names exported by a star-import. The
# underscore-prefixed helpers and ``_CtxChild`` are intentionally absent.
__all__ = [
    "CONSULT_ESCALATION",
    "ConsultDBA",
    "Correctness",
    "DriftAdapt",
    "GATE_BASELINE_VERBATIM",
    "GATE_CONSECUTIVE_TOOL_ERROR",
    "GATE_MALFORMED_TOOL_CALL",
    "GATE_REPEAT_FAILING_QUERY",
    "Gatekeepers",
    "SPEEDUP_CAP_FOR_INFTY",
    "SPEEDUP_MIN",
    "STEP_REBATE_DESCRIBE_TABLE",
    "STEP_REBATE_EXPLAIN_QUERY",
    "STEP_REBATE_LIST_TABLES",
    "STEP_REBATE_READ_CHANGELOG",
    "STEP_REBATE_RUN_QUERY",
    "STEP_REBATE_SAMPLE_ROWS",
    "STEP_TAX",
    "Speedup",
    "SqlDriftRubric",
    "StepTax",
    "canonicalize_sql",
    "effective_speedup",
]