Spaces:
Sleeping
Sleeping
| """SQLDrift composed rubric (correctness, drift, speedup, gates, DBA tax). | |
| Six child rubrics, one per reward component (:data:`REWARD_COMPONENT_KEYS`): | |
| r_correct correctness vs ground-truth hash, gated on ≥ 1.2× speedup | |
| r_drift bonus/penalty for (not) adapting to post-drift identifiers | |
| r_speedup tanh-shaped speedup bonus, gated on r_correct > 0 | |
| r_step_tax base step tax plus bounded productive-action rebates | |
| r_gatekeepers escalating tool-error / repeat-failing / no-op penalties | |
| r_consult_dba DBA-oracle consult penalties (feature-flagged; 0 when off) | |
| All child rubrics share a single ``ctx_provider`` that returns the private | |
| :class:`engine.runtime.RuntimeEpisodeState`; this keeps the rubric | |
| stateless relative to the environment and makes each component | |
| individually unit-testable with a synthesized triple | |
| ``(RuntimeEpisodeState, SqlDriftAction, SqlDriftObservation)``. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import re | |
| from collections.abc import Callable | |
| from typing import TYPE_CHECKING | |
| from openenv.core.rubrics import Rubric | |
| from models import ( | |
| SqlDriftAction, | |
| SqlDriftObservation, | |
| SubmitRewriteResult, | |
| ToolError, | |
| ToolName, | |
| ) | |
| if TYPE_CHECKING: | |
| from engine.runtime import RuntimeEpisodeState | |
# Tunable thresholds pulled out to module level so tests and future
# curriculum code share a single source of truth.

# Speedup shaping. SPEEDUP_MIN is the threshold Correctness compares the
# effective speedup against to pick the full vs. partial reward;
# SPEEDUP_CAP_FOR_INFTY is the finite value Speedup substitutes for an
# infinite speedup before taking log2.
SPEEDUP_MIN: float = 1.2
SPEEDUP_CAP_FOR_INFTY: float = 64.0

# Per-step tax applied by StepTax on every non-terminal step.
STEP_TAX: float = -0.03

# Per-tool rebate amounts. NOTE(review): no rubric in this module reads
# these directly -- StepTax only reads ``rt.last_step_productive_rebate``;
# presumably the environment uses these to populate it. Verify at the caller.
STEP_REBATE_LIST_TABLES: float = 0.04
STEP_REBATE_DESCRIBE_TABLE: float = 0.06
STEP_REBATE_SAMPLE_ROWS: float = 0.05
STEP_REBATE_RUN_QUERY: float = 0.04
STEP_REBATE_EXPLAIN_QUERY: float = 0.04
STEP_REBATE_READ_CHANGELOG: float = 0.08

# Gatekeepers penalties (all negative, summed when several apply).
GATE_MALFORMED_TOOL_CALL: float = -0.3
GATE_CONSECUTIVE_TOOL_ERROR: float = -0.1
GATE_REPEAT_FAILING_QUERY: float = -0.1
GATE_BASELINE_VERBATIM: float = -0.2

# Upper bound on the multiplier used for the two escalating gate penalties.
_MAX_ESCALATION_STEPS: int = 3

# ConsultDBA penalty ladder, indexed by consult tier (first, second, third+).
CONSULT_ESCALATION: tuple[float, float, float] = (-0.1, -0.3, -0.8)
| # ============================================================================= | |
| # Helpers | |
| # ============================================================================= | |
def _lower_outside_literals(text: str) -> str:
    """Lower-case ``text`` everywhere except inside ``'...'`` string literals."""
    # With a single capture group, re.split alternates between the text
    # outside a literal (even indexes) and the literal itself (odd indexes).
    # ``''`` inside a literal is the SQL escape for one quote character.
    pieces = re.split(r"('(?:[^']|'')*')", text)
    return "".join(
        piece if index % 2 else piece.lower()
        for index, piece in enumerate(pieces)
    )


def canonicalize_sql(sql: str) -> str:
    """Return a whitespace- and case-insensitive canonical form of ``sql``.

    The SQL is round-tripped through sqlglot's duckdb dialect (comments
    dropped, identifiers normalized) so reformatted variants of the same
    statement agree. If sqlglot is unavailable or rejects the SQL (e.g.
    during the baseline-verbatim check on an agent-submitted blob), the
    function falls back to a plain whitespace fold.

    Keywords and identifiers are lower-cased, but the contents of
    single-quoted string literals keep their case: string comparison is
    case-sensitive, so ``'ACTIVE'`` and ``'active'`` are different queries
    and must not share a canonical form. Otherwise a rewrite that only
    changes a literal would be indistinguishable from the baseline.

    Args:
        sql: Raw SQL text.

    Returns:
        The canonical string. Never raises; parse failures use the fallback.
    """
    try:
        import sqlglot

        expr = sqlglot.parse_one(sql, dialect="duckdb")
        rendered = expr.sql(dialect="duckdb", comments=False, normalize=True)
    except Exception:
        # Deliberate best-effort: any import/parse/generation failure
        # degrades to the whitespace fold instead of failing the caller.
        return " ".join(_lower_outside_literals(sql).split())
    return _lower_outside_literals(rendered.strip())
_IDENT_RE = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\b")


def _extract_identifiers(sql: str) -> frozenset[str]:
    """Collect every identifier-shaped token in ``sql``, preserving case.

    Single-quoted string literals are removed first, so a predicate such as
    ``WHERE x = 'ACTIVE'`` contributes ``WHERE`` and ``x`` but not
    ``ACTIVE``. No keyword filtering is done: SQL keywords are returned
    alongside table and column names.
    """
    without_literals = re.sub(r"'[^']*'", "", sql)
    # The pattern has one capture group spanning the whole match, so
    # findall yields the matched tokens themselves.
    return frozenset(_IDENT_RE.findall(without_literals))
def _extract_column_references(sql: str) -> frozenset[str]:
    """Return the names of column references in ``sql`` via the sqlglot AST.

    Only ``Column`` nodes are harvested, so alias labels are excluded: for
    drift-adapt scoring, ``SELECT account_id AS user_id`` references the
    new column and merely labels the output -- the alias must not count as
    a surviving pre-drift marker.

    Falls back to the regex extractor :func:`_extract_identifiers` when
    sqlglot cannot be imported, rejects the SQL, or yields no tree, so
    malformed SQL still scores something.

    Args:
        sql: SQL text to inspect.

    Returns:
        The non-empty ``name`` of every column node (case preserved).
    """
    try:
        import sqlglot

        tree = sqlglot.parse_one(sql, dialect="duckdb")
    except Exception:
        return _extract_identifiers(sql)
    if tree is None:
        return _extract_identifiers(sql)
    # find_all() yields the matching nodes themselves; filtering walk() by
    # hand depends on what walk() yields, which has differed between
    # sqlglot releases.
    return frozenset(
        column.name for column in tree.find_all(sqlglot.exp.Column) if column.name
    )
def _literals(sql: str) -> frozenset[str]:
    """Return the contents (quotes stripped) of every ``'...'`` literal in ``sql``."""
    return frozenset(match.group(1) for match in re.finditer(r"'([^']*)'", sql))
_AGENT_MS_EPSILON: float = 1e-6


def effective_speedup(rt: RuntimeEpisodeState) -> float | None:
    """Derive the speedup figure from a runtime snapshot.

    This is intended as the single definition of "speedup" shared by
    rubric scoring and other consumers, so the number cannot diverge
    between them.

    Args:
        rt: Runtime state exposing ``submitted_runtime_ms``,
            ``drift_fired``, ``baseline_postdrift_raises`` and
            ``baseline_runtime_ms``.

    Returns:
        * ``None`` when ``rt.submitted_runtime_ms`` is ``None`` (nothing
          has been timed yet).
        * ``math.inf`` when drift has fired and the baseline raises
          against the post-drift schema: any working submission is
          treated as infinitely faster than an unrunnable baseline.
        * ``baseline_runtime_ms / max(submitted_runtime_ms, epsilon)``
          otherwise. The epsilon floor turns zero or negative timings
          into a very large finite speedup instead of a division error.
    """
    agent_runtime = rt.submitted_runtime_ms
    if agent_runtime is None:
        return None

    baseline_unrunnable = rt.drift_fired and rt.baseline_postdrift_raises
    if baseline_unrunnable:
        return math.inf

    denominator = max(agent_runtime, _AGENT_MS_EPSILON)
    return rt.baseline_runtime_ms / denominator
def _speedup_for_reward(rt: RuntimeEpisodeState) -> float:
    """Rubric-facing speedup that is always a number.

    Wraps :func:`effective_speedup` and maps its ``None`` result (no
    timed submission) to ``0.0``. The rubric call sites only reach this
    after a terminal submission, so ``None`` is not expected there; the
    coercion is a defensive default that keeps the reward arithmetic from
    raising on ``None``.
    """
    speedup = effective_speedup(rt)
    if speedup is None:
        return 0.0
    return speedup
def _is_terminal_submission(
    action: SqlDriftAction,
    observation: SqlDriftObservation,
    rt: RuntimeEpisodeState,
) -> bool:
    """Report whether this step is the submission step.

    All three conditions must hold:

    * the runtime state is flagged ``submitted``;
    * the action's tool is ``ToolName.SUBMIT_REWRITE``;
    * the observation's ``tool_result`` is a :class:`SubmitRewriteResult`.

    Checking the action and the observation as well as the runtime flag
    stops later rubric calls against an already-submitted state from
    re-scoring the terminal rewards.
    """
    if not rt.submitted or action.tool != ToolName.SUBMIT_REWRITE:
        return False
    return isinstance(observation.tool_result, SubmitRewriteResult)
def _gt_hash(rt: RuntimeEpisodeState) -> str | None:
    """Pick the ground-truth result hash that applies to the current state.

    The post-drift hash is used once drift has fired and such a hash
    exists; in every other case the pre-drift hash is returned (which may
    itself be ``None``).
    """
    use_postdrift = rt.drift_fired and rt.gt_result_hash_postdrift is not None
    return rt.gt_result_hash_postdrift if use_postdrift else rt.gt_result_hash_predrift
| # ============================================================================= | |
| # Child rubrics | |
| # ============================================================================= | |
class _CtxChild(Rubric):
    """Base child rubric sharing the ctx provider.

    Subclasses call ``self._ctx()`` inside :meth:`forward` to obtain the
    current runtime state and must override :meth:`forward`.
    """

    def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
        """Store ``ctx_provider``, a zero-argument callable returning the runtime state."""
        super().__init__()
        # Written with object.__setattr__ instead of plain assignment.
        # NOTE(review): presumably this sidesteps attribute hooks on
        # Rubric so the callable is not treated as a child rubric --
        # confirm against openenv.core.rubrics.
        object.__setattr__(self, "_ctx", ctx_provider)

    def forward(
        self,
        action: SqlDriftAction,
        observation: SqlDriftObservation,
    ) -> float:
        """Score one step; abstract here, concrete subclasses implement it."""
        raise NotImplementedError
class Correctness(_CtxChild):
    """Terminal-only correctness reward.

    Scores ``-1.0`` for a result-hash mismatch, ``+1.0`` for a match whose
    speedup reaches ``SPEEDUP_MIN``, and ``+0.5`` for a match that does
    not. Non-terminal steps, and steps where either hash is missing,
    score ``0.0``.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the correctness component for this step."""
        rt = self._ctx()
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        expected_hash = _gt_hash(rt)
        submitted_hash = rt.submitted_result_hash
        # Without both hashes there is nothing to compare; stay neutral.
        if expected_hash is None or submitted_hash is None:
            return 0.0
        if submitted_hash != expected_hash:
            return -1.0

        # Correct result: the full reward additionally requires the speedup gate.
        return 1.0 if _speedup_for_reward(rt) >= SPEEDUP_MIN else 0.5
class DriftAdapt(_CtxChild):
    """+0.5 for a correctly-adapted submission, -0.5 for a submission that
    kept pre-drift markers and got the post-drift result wrong.

    Adaptation is detected against two scenario-declared marker sets read
    from ``rt.instance``:

    * ``postdrift_identifiers`` -- identifiers/literals that only a
      correct post-drift rewrite will introduce (e.g. ``account_id``
      after a column rename, ``'ACTIVE'`` after an enum split).
    * ``predrift_identifiers`` -- identifiers/literals a submission that
      ignored the drift would retain (e.g. ``user_id``, ``'active'``).

    A submission counts as "adapted" when it carries no pre-drift marker
    AND either surfaces a post-drift marker or the scenario declares no
    post-drift markers at all (e.g. a drift that keeps the same column
    name and only changes a literal's shape).

    Scoring, on the terminal submission step only:

    * result hash equals the post-drift ground truth and adapted: ``+0.5``
    * result hash differs and a pre-drift marker is present: ``-0.5``
    * anything else: ``0.0`` -- so a partial rewrite that is neither pre-
      nor post-flavoured never earns a penalty it cannot diagnose.

    The component is neutral (``0.0``) unless drift has fired and a
    post-drift ground-truth hash exists. Without that hash there is
    nothing to score against; in particular a missing submission hash
    must not "match" a missing ground truth.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the drift-adaptation component for this step."""
        rt = self._ctx()

        # Both reward branches require drift to have fired.
        if not rt.drift_fired:
            return 0.0
        gt_post = rt.gt_result_hash_postdrift
        # No post-drift ground truth: a None == None comparison would
        # otherwise award the bonus, and every real hash would be "wrong".
        if gt_post is None:
            return 0.0
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        inst = getattr(rt, "instance", None)
        # Coerce to frozenset so scenarios may declare markers as any
        # iterable of strings (list, tuple, set) without breaking the
        # set operations below.
        post_ids: frozenset[str] = frozenset(
            getattr(inst, "postdrift_identifiers", None) or ()
        )
        pre_ids: frozenset[str] = frozenset(
            getattr(inst, "predrift_identifiers", None) or ()
        )

        agent_sql = rt.submitted_sql or ""
        # Markers may be column references or string literals.
        markers = _extract_column_references(agent_sql) | _literals(agent_sql)
        uses_post = not post_ids.isdisjoint(markers)
        uses_pre = not pre_ids.isdisjoint(markers)

        if rt.submitted_result_hash == gt_post:
            # An empty post-drift set is satisfied by the absence of
            # pre-drift markers alone.
            adapted = (uses_post or not post_ids) and not uses_pre
            return 0.5 if adapted else 0.0

        # Wrong post-drift result: penalize only when the submission
        # visibly clings to pre-drift markers.
        return -0.5 if uses_pre else 0.0
class Speedup(_CtxChild):
    """Terminal-only speedup bonus: ``0.3 * tanh(log2(speedup) / 3)``.

    Paid only when the submitted result hash equals the applicable
    ground-truth hash. An infinite speedup is replaced by
    ``SPEEDUP_CAP_FOR_INFTY`` before the formula is applied, and a
    speedup of 1.0 or less earns nothing.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the speedup component for this step."""
        rt = self._ctx()
        if not _is_terminal_submission(action, observation, rt):
            return 0.0

        expected_hash = _gt_hash(rt)
        if expected_hash is None or rt.submitted_result_hash != expected_hash:
            return 0.0

        measured = _speedup_for_reward(rt)
        # log2 of infinity is unusable; substitute the finite cap.
        bounded = SPEEDUP_CAP_FOR_INFTY if math.isinf(measured) else measured
        if bounded <= 1.0:
            return 0.0
        return 0.3 * math.tanh(math.log2(bounded) / 3.0)
class StepTax(_CtxChild):
    """Per-step tax offset by a non-negative rebate.

    Every non-terminal step costs ``STEP_TAX``. The rebate is read from
    ``rt.last_step_productive_rebate`` (default 0.0) and clamped at zero,
    so it can only soften the tax. The terminal submission step is not
    taxed.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the step-tax component for this step."""
        rt = self._ctx()
        if _is_terminal_submission(action, observation, rt):
            return 0.0
        recorded_rebate = float(getattr(rt, "last_step_productive_rebate", 0.0))
        return STEP_TAX + max(0.0, recorded_rebate)
class Gatekeepers(_CtxChild):
    """Sum of three independent penalties; the repeat-based ones escalate
    up to ``_MAX_ESCALATION_STEPS``.

    1. Tool error this step: ``GATE_MALFORMED_TOOL_CALL`` plus
       ``GATE_CONSECUTIVE_TOOL_ERROR`` for each earlier error in the
       current streak.
    2. Repeated failing query: ``GATE_REPEAT_FAILING_QUERY`` for each
       repeat beyond the first occurrence.
    3. Baseline-verbatim submission: ``GATE_BASELINE_VERBATIM`` when the
       terminal submission canonicalizes to the baseline SQL.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the (non-positive) gatekeeper component for this step."""
        rt = self._ctx()
        total = 0.0

        # 1. Malformed / failed tool call -- a ToolError was emitted this step.
        if isinstance(observation.tool_result, ToolError):
            # The streak counter includes this step's error; escalate only
            # on the ones before it.
            earlier_errors = max(0, int(getattr(rt, "consecutive_tool_errors", 0)) - 1)
            total += GATE_MALFORMED_TOOL_CALL
            total += GATE_CONSECUTIVE_TOOL_ERROR * min(earlier_errors, _MAX_ESCALATION_STEPS)

        # 2. Repeat failing query -- the counter is read from the runtime
        #    state; the first occurrence is free.
        extra_repeats = max(
            0, int(getattr(rt, "last_step_repeat_failing_query_count", 0)) - 1
        )
        if extra_repeats > 0:
            total += GATE_REPEAT_FAILING_QUERY * min(extra_repeats, _MAX_ESCALATION_STEPS)

        # 3. Baseline-verbatim submission (Rev-3 gate -- stacks with
        #    correctness's +0.5 partial to cap the no-op rewrite at +0.3).
        #    _is_terminal_submission already requires the SUBMIT_REWRITE tool.
        if (
            _is_terminal_submission(action, observation, rt)
            and rt.submitted_sql_canonical == rt.baseline_sql_canonical
        ):
            total += GATE_BASELINE_VERBATIM

        return total
class ConsultDBA(_CtxChild):
    """Escalating consult penalty taken from ``CONSULT_ESCALATION``.

    Scores ``0.0`` unless ``rt.dba_oracle_enabled`` is truthy and the
    action's tool is ``ToolName.CONSULT_DBA``. Otherwise
    ``rt.consultations_used`` selects the tier: values of 1, 2 and 3+
    map to the first, second and third entries; 0 or less also maps to
    the first entry.
    """

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Return the DBA-consult component for this step."""
        rt = self._ctx()
        if not getattr(rt, "dba_oracle_enabled", False):
            return 0.0
        if action.tool != ToolName.CONSULT_DBA:
            return 0.0

        # NOTE(review): this indexing yields the -0.1 / -0.3 / -0.8 ladder
        # only if consultations_used already counts the current consult.
        # If the env passes the pre-increment count, the first two
        # consults both score -0.1 -- confirm against the environment.
        tier = min(rt.consultations_used, len(CONSULT_ESCALATION))
        index = 0 if tier <= 0 else tier - 1
        return CONSULT_ESCALATION[index]
| # ============================================================================= | |
| # Composite | |
| # ============================================================================= | |
class SqlDriftRubric(Rubric):
    """Composite rubric whose score is the sum of six child rubrics.

    Each child is stored as an attribute; per the original design this
    enrolls it in :meth:`Rubric.named_rubrics` so training loops can
    introspect per-component scores.
    """

    def __init__(self, ctx_provider: Callable[[], RuntimeEpisodeState]) -> None:
        """Build the six children around one shared ``ctx_provider``."""
        super().__init__()
        # Children are created in the order forward() evaluates them. Each
        # one derives its score from the runtime state returned by the
        # shared ctx_provider, so the order documents intent rather than
        # carrying a data dependency.
        self.correctness = Correctness(ctx_provider)
        self.drift_adapt = DriftAdapt(ctx_provider)
        self.speedup = Speedup(ctx_provider)
        self.step_tax = StepTax(ctx_provider)
        self.gatekeepers = Gatekeepers(ctx_provider)
        self.consult_dba = ConsultDBA(ctx_provider)

    def forward(self, action: SqlDriftAction, observation: SqlDriftObservation) -> float:
        """Evaluate every child on this step and return the summed reward."""
        r_correct = self.correctness(action, observation)
        r_drift = self.drift_adapt(action, observation)
        r_speedup = self.speedup(action, observation)
        r_step_tax = self.step_tax(action, observation)
        r_gatekeepers = self.gatekeepers(action, observation)
        r_consult_dba = self.consult_dba(action, observation)
        return r_correct + r_drift + r_speedup + r_step_tax + r_gatekeepers + r_consult_dba

    def component_scores(self) -> dict[str, float]:
        """Return the most-recent per-component scores, keyed for W&B.

        Keys are intended to match :data:`models.REWARD_COMPONENT_KEYS` so
        the observation envelope and the demo plots share one schema. A
        child whose ``last_score`` is unset or falsy reports ``0.0``.
        """
        children = (
            ("r_correct", self.correctness),
            ("r_drift", self.drift_adapt),
            ("r_speedup", self.speedup),
            ("r_step_tax", self.step_tax),
            ("r_gatekeepers", self.gatekeepers),
            ("r_consult_dba", self.consult_dba),
        )
        return {key: float(child.last_score or 0.0) for key, child in children}
# Public API of this module: the tunable constants, the child rubrics, the
# composite rubric, and the two public helpers. Underscore-prefixed helpers
# are intentionally left out.
__all__ = [
    "CONSULT_ESCALATION",
    "ConsultDBA",
    "Correctness",
    "DriftAdapt",
    "GATE_BASELINE_VERBATIM",
    "GATE_CONSECUTIVE_TOOL_ERROR",
    "GATE_MALFORMED_TOOL_CALL",
    "GATE_REPEAT_FAILING_QUERY",
    "Gatekeepers",
    "SPEEDUP_CAP_FOR_INFTY",
    "SPEEDUP_MIN",
    "STEP_REBATE_DESCRIBE_TABLE",
    "STEP_REBATE_EXPLAIN_QUERY",
    "STEP_REBATE_LIST_TABLES",
    "STEP_REBATE_READ_CHANGELOG",
    "STEP_REBATE_RUN_QUERY",
    "STEP_REBATE_SAMPLE_ROWS",
    "STEP_TAX",
    "Speedup",
    "SqlDriftRubric",
    "StepTax",
    "canonicalize_sql",
    "effective_speedup",
]