"""Public data models for SQLDrift.

Rev 3 design notes enforced here:

- Action is a discriminated union over a public `kind: Literal[...]` tag on
  each payload sub-model. Pydantic v2 forbids leading-underscore names as
  discriminator keys (reserved for private attrs), so we keep the tag public.
- `SqlDriftAction` cross-validates that the envelope-level `tool` matches
  `payload.kind` (prevents inconsistent envelopes from being constructed).
- `SqlDriftObservation.tool_result` is itself a discriminated union over the
  eight concrete result types plus `ToolError` (for in-env semantic failures;
  envelope-level `ValidationError` is a transport-layer concern, not an
  in-env code).
- `SqlDriftState` is the public state snapshot shipped over `/state`. It
  never carries ground truth, DB handles, baseline runtime, or seeds;
  `extra="forbid"` rejects undeclared fields, so data can only appear in the
  snapshot by being declared on the model explicitly. The private
  `RuntimeEpisodeState` lives in :mod:`engine.runtime`.
"""

from __future__ import annotations

from enum import StrEnum
from typing import Annotated, Any, Literal

from openenv.core.env_server.types import Action, Observation, State
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic_core import PydanticCustomError

# =============================================================================
# Enums
# =============================================================================


class ToolName(StrEnum):
    """Tools an agent can invoke via :class:`SqlDriftAction`.

    Each member's value is also the `kind` tag of its request payload
    (see :data:`TOOL_TO_PAYLOAD_KIND`).
    """

    LIST_TABLES = "list_tables"
    DESCRIBE_TABLE = "describe_table"
    SAMPLE_ROWS = "sample_rows"
    RUN_QUERY = "run_query"
    EXPLAIN_QUERY = "explain_query"
    READ_CHANGELOG = "read_changelog"
    SUBMIT_REWRITE = "submit_rewrite"
    CONSULT_DBA = "consult_dba"


class EpisodePhase(StrEnum):
    """Episode phase reported on observations and on the public state."""

    DIAGNOSE = "diagnose"
    REWRITE = "rewrite"
    DRIFT_RECOVERY = "drift_recovery"
    FINALIZE = "finalize"


class ToolErrorCode(StrEnum):
    """In-environment semantic failure codes (API contract).

    Envelope-level `pydantic.ValidationError` is handled by the OpenEnv
    transport layer (HTTP 422 / `/ws` error frame) and never reaches
    `env.step`, so it has no code here.
    """

    DB_ERROR = "db_error"
    UNKNOWN_TABLE = "unknown_table"
    QUERY_TIMEOUT = "query_timeout"
    RESULT_TOO_LARGE = "result_too_large"
    SUBMIT_BEFORE_DIAGNOSE = "submit_before_diagnose"
    INVALID_TOOL_ARGUMENT = "invalid_tool_argument"


# =============================================================================
# Tool payloads (request side of `SqlDriftAction`)
# =============================================================================


class _BasePayload(BaseModel):
    """Shared config for every tool-call payload."""

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


class ListTablesPayload(_BasePayload):
    """Arguments for `list_tables` (none besides the tag)."""

    kind: Literal["list_tables"] = "list_tables"


class DescribeTablePayload(_BasePayload):
    """Arguments for `describe_table`: the table to describe."""

    kind: Literal["describe_table"] = "describe_table"
    table: str = Field(min_length=1, max_length=63)


class SampleRowsPayload(_BasePayload):
    """Arguments for `sample_rows`: a table and a row count of 1 to 5."""

    kind: Literal["sample_rows"] = "sample_rows"
    table: str = Field(min_length=1, max_length=63)
    limit: int = Field(default=5, ge=1, le=5)


class RunQueryPayload(_BasePayload):
    """Arguments for `run_query`: the SQL text to execute."""

    kind: Literal["run_query"] = "run_query"
    sql: str = Field(min_length=1, max_length=10_000)


class ExplainQueryPayload(_BasePayload):
    """Arguments for `explain_query`: the SQL text to explain."""

    kind: Literal["explain_query"] = "explain_query"
    sql: str = Field(min_length=1, max_length=10_000)


class ReadChangelogPayload(_BasePayload):
    """Arguments for `read_changelog` (none besides the tag)."""

    kind: Literal["read_changelog"] = "read_changelog"


class SubmitRewritePayload(_BasePayload):
    """Arguments for `submit_rewrite`: the candidate rewritten SQL."""

    kind: Literal["submit_rewrite"] = "submit_rewrite"
    sql: str = Field(min_length=1, max_length=10_000)


class ConsultDBAPayload(_BasePayload):
    """Arguments for `consult_dba`: a free-text question."""

    kind: Literal["consult_dba"] = "consult_dba"
    question: str = Field(min_length=1, max_length=400)


ToolPayload = Annotated[
    ListTablesPayload
    | DescribeTablePayload
    | SampleRowsPayload
    | RunQueryPayload
    | ExplainQueryPayload
    | ReadChangelogPayload
    | SubmitRewritePayload
    | ConsultDBAPayload,
    Field(discriminator="kind"),
]

# Tool -> payload-kind mapping; single source of truth for cross-validation
# and for the server-side dispatcher in P7.
TOOL_TO_PAYLOAD_KIND: dict[ToolName, str] = {
    ToolName.LIST_TABLES: "list_tables",
    ToolName.DESCRIBE_TABLE: "describe_table",
    ToolName.SAMPLE_ROWS: "sample_rows",
    ToolName.RUN_QUERY: "run_query",
    ToolName.EXPLAIN_QUERY: "explain_query",
    ToolName.READ_CHANGELOG: "read_changelog",
    ToolName.SUBMIT_REWRITE: "submit_rewrite",
    ToolName.CONSULT_DBA: "consult_dba",
}


def _check_tool_mapping_is_complete() -> None:
    """Fail fast at import time if a `ToolName` has no payload-kind entry.

    Without this, a missing entry would surface as a bare ``KeyError`` inside
    ``SqlDriftAction``'s model validator. Pydantic only converts ``ValueError``,
    ``AssertionError`` and ``PydanticCustomError`` into validation errors, so
    the ``KeyError`` would escape as an unhandled server error.

    Raises:
        RuntimeError: if any ``ToolName`` member is absent from
            ``TOOL_TO_PAYLOAD_KIND``.
    """
    unmapped = [tool.value for tool in ToolName if tool not in TOOL_TO_PAYLOAD_KIND]
    if unmapped:
        raise RuntimeError(
            "TOOL_TO_PAYLOAD_KIND is missing entries for: " + ", ".join(unmapped)
        )


_check_tool_mapping_is_complete()


# =============================================================================
# SqlDriftAction envelope
# =============================================================================


class SqlDriftAction(Action):
    """Tool-call envelope.

    JSON wire format::

        {"tool": "run_query", "payload": {"kind": "run_query", "sql": "..."}}

    The `tool` field and `payload.kind` must agree; mismatch raises at
    validation time.
    """

    tool: ToolName
    payload: ToolPayload

    @model_validator(mode="after")
    def _tool_matches_payload(self) -> SqlDriftAction:
        """Reject envelopes whose `tool` disagrees with `payload.kind`.

        Raises:
            PydanticCustomError: ``tool_payload_mismatch`` with
                ``tool`` / ``expected`` / ``got`` string context.
        """
        # Lookup cannot miss: completeness is checked at import time.
        expected = TOOL_TO_PAYLOAD_KIND[self.tool]
        if self.payload.kind != expected:
            # PydanticCustomError keeps ``ctx`` JSON-serializable (plain
            # strings only), unlike a bare ``ValueError`` which Pydantic
            # wraps with ``ctx={"error": ValueError(...)}`` and breaks
            # FastAPI HTTPException JSON encoder (422 responses).
            raise PydanticCustomError(
                "tool_payload_mismatch",
                "tool/payload mismatch: tool={tool} expects payload.kind={expected}, got {got}",
                {
                    "tool": self.tool.value,
                    "expected": expected,
                    "got": self.payload.kind,
                },
            )
        return self


# =============================================================================
# Tool results (response side of `SqlDriftObservation.tool_result`)
# =============================================================================


class _BaseResult(BaseModel):
    """Shared config for every tool result (strict, assignment-validated)."""

    model_config = ConfigDict(extra="forbid", validate_assignment=True)


class ListTablesResult(_BaseResult):
    """Result of `list_tables`."""

    kind: Literal["list_tables_result"] = "list_tables_result"
    tables: list[str]


class DescribeTableResult(_BaseResult):
    """Result of `describe_table`."""

    kind: Literal["describe_table_result"] = "describe_table_result"
    table: str
    columns: list[dict[str, str]]  # [{"name": "...", "type": "..."}]


class SampleRowsResult(_BaseResult):
    """Result of `sample_rows`: column names plus row values."""

    kind: Literal["sample_rows_result"] = "sample_rows_result"
    table: str
    columns: list[str]
    rows: list[list[Any]]


class RunQueryResult(_BaseResult):
    """Result of `run_query`: columns, rows, timing and row count."""

    kind: Literal["run_query_result"] = "run_query_result"
    columns: list[str]
    rows: list[list[Any]]
    runtime_ms: float
    row_count: int


class ExplainQueryResult(_BaseResult):
    """Result of `explain_query`: the plan rendered as text."""

    kind: Literal["explain_query_result"] = "explain_query_result"
    plan: str


class ReadChangelogResult(_BaseResult):
    """Result of `read_changelog`."""

    kind: Literal["read_changelog_result"] = "read_changelog_result"
    entries: list[str]


class SubmitRewriteResult(_BaseResult):
    """Result of `submit_rewrite`."""

    kind: Literal["submit_rewrite_result"] = "submit_rewrite_result"
    accepted: bool
    runtime_ms: float
    matches_ground_truth: bool


class ConsultDBAResult(_BaseResult):
    """Result of `consult_dba`: a hint and its tier (1 to 3)."""

    kind: Literal["consult_dba_result"] = "consult_dba_result"
    tier: int = Field(ge=1, le=3)
    hint: str


class ToolError(_BaseResult):
    """In-environment semantic failure returned in place of a tool result."""

    kind: Literal["tool_error"] = "tool_error"
    code: ToolErrorCode
    message: str = Field(max_length=2_000)


ToolResult = Annotated[
    ListTablesResult
    | DescribeTableResult
    | SampleRowsResult
    | RunQueryResult
    | ExplainQueryResult
    | ReadChangelogResult
    | SubmitRewriteResult
    | ConsultDBAResult
    | ToolError,
    Field(discriminator="kind"),
]

# The six reward-component keys match the composed rubric; tests and telemetry
# rely on this exact schema.
REWARD_COMPONENT_KEYS: tuple[str, ...] = (
    "r_correct",
    "r_drift",
    "r_speedup",
    "r_step_tax",
    "r_gatekeepers",
    "r_consult_dba",
)


# =============================================================================
# SqlDriftObservation
# =============================================================================


def _zero_reward_components() -> dict[str, float]:
    """Six-key reward envelope initialised to zero.

    Every observation, including the reset observation, carries the full
    six-key schema so telemetry and tests can index it unconditionally.
    """
    return {key: 0.0 for key in REWARD_COMPONENT_KEYS}


class SqlDriftObservation(Observation):
    """Observation returned by :meth:`SqlDriftEnvironment.step`.

    Inherits `done: bool` and `reward: float | None` from base Observation.

    The task payload (`baseline_sql`, `schema_synopsis`) is delivered on the
    reset observation and kept empty on subsequent steps: the agent is
    expected to capture it once and hold it in its own context.

    `reward_components` always contains every key in
    :data:`REWARD_COMPONENT_KEYS`; keys missing from a caller-supplied dict
    are filled with ``0.0``.
    """

    step: int = Field(ge=0)
    phase: EpisodePhase
    last_tool: ToolName | None = None
    tool_result: ToolResult | None = None
    drift_fired: bool = False
    drift_acknowledged: bool = False
    learned_hints: str = Field(default="", max_length=800)
    baseline_sql: str = Field(default="", max_length=10_000)
    schema_synopsis: str = Field(default="", max_length=2_000)
    budget_steps_remaining: int = Field(ge=0)
    reward_components: dict[str, float] = Field(default_factory=_zero_reward_components)

    @field_validator("reward_components")
    @classmethod
    def _fill_missing_reward_components(
        cls, value: dict[str, float]
    ) -> dict[str, float]:
        """Back-fill absent reward keys with ``0.0``.

        This keeps the "full six-key schema" invariant even when a caller
        passes an explicit, partial dict. Complete dicts are returned
        unchanged. Extra keys are preserved.
        NOTE(review): decide whether unknown keys should be rejected instead.
        """
        if all(key in value for key in REWARD_COMPONENT_KEYS):
            return value
        return {**_zero_reward_components(), **value}


# =============================================================================
# SqlDriftState — PUBLIC state (sanitized)
# =============================================================================


class SqlDriftState(State):
    """Public state snapshot — serialized over `/state`.

    Ground truth, DB handles, seeds, and baseline SQL live in
    :class:`engine.runtime.RuntimeEpisodeState` and are never exposed here.
    `extra="forbid"` rejects any undeclared field, so nothing can reach this
    snapshot unless it is added here as an explicitly declared field.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
    )

    scenario_id: str
    phase: EpisodePhase
    budget_steps_remaining: int = Field(ge=0)
    drift_fired: bool = False
    consultations_used: int = Field(default=0, ge=0)
    submitted: bool = False


__all__ = [
    "ConsultDBAPayload",
    "ConsultDBAResult",
    "DescribeTablePayload",
    "DescribeTableResult",
    "EpisodePhase",
    "ExplainQueryPayload",
    "ExplainQueryResult",
    "ListTablesPayload",
    "ListTablesResult",
    "REWARD_COMPONENT_KEYS",
    "ReadChangelogPayload",
    "ReadChangelogResult",
    "RunQueryPayload",
    "RunQueryResult",
    "SampleRowsPayload",
    "SampleRowsResult",
    "SqlDriftAction",
    "SqlDriftObservation",
    "SqlDriftState",
    "SubmitRewritePayload",
    "SubmitRewriteResult",
    "TOOL_TO_PAYLOAD_KIND",
    "ToolError",
    "ToolErrorCode",
    "ToolName",
    "ToolPayload",
    "ToolResult",
]