File size: 11,201 Bytes
5850885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
"""Public data models for SQLDrift.

Rev 3 design notes enforced here:

- Action is a discriminated union over a public `kind: Literal[...]` tag on
  each payload sub-model. Pydantic v2 forbids leading-underscore names as
  discriminator keys (reserved for private attrs), so we keep the tag public.
- `SqlDriftAction` cross-validates that the envelope-level `tool` matches
  `payload.kind` (prevents inconsistent envelopes from being constructed).
- `SqlDriftObservation.tool_result` is itself a discriminated union over the
  eight concrete result types plus `ToolError` (for in-env semantic failures;
  envelope-level `ValidationError` is a transport-layer concern, not an in-env code).
- `SqlDriftState` is the public state snapshot shipped over `/state`. It
  never carries ground truth, DB handles, baseline runtime, or seeds;
  `extra="forbid"` guarantees no accidental leak as new fields are added.
  The private `RuntimeEpisodeState` lives in :mod:`engine.runtime`.
"""

from __future__ import annotations

from enum import StrEnum
from typing import Annotated, Any, Literal

from openenv.core.env_server.types import Action, Observation, State
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_core import PydanticCustomError

# =============================================================================
# Enums
# =============================================================================


# Closed set of tools an agent can invoke. Each member's value is the
# lowercase string used for the envelope-level `tool` field, and it is
# spelled the same as the `kind` literal of the matching payload model.
class ToolName(StrEnum):
    LIST_TABLES = "list_tables"
    DESCRIBE_TABLE = "describe_table"
    SAMPLE_ROWS = "sample_rows"
    RUN_QUERY = "run_query"
    EXPLAIN_QUERY = "explain_query"
    READ_CHANGELOG = "read_changelog"
    SUBMIT_REWRITE = "submit_rewrite"
    CONSULT_DBA = "consult_dba"


# Named phases of an episode, reported on both the observation and the
# public state via their `phase` fields.
# NOTE(review): the members presumably occur in the order listed
# (diagnose -> rewrite -> drift_recovery -> finalize); the transition rules
# are not defined in this module — confirm against the engine.
class EpisodePhase(StrEnum):
    DIAGNOSE = "diagnose"
    REWRITE = "rewrite"
    DRIFT_RECOVERY = "drift_recovery"
    FINALIZE = "finalize"


class ToolErrorCode(StrEnum):
    """In-environment semantic failure codes (API contract).

    Envelope-level `pydantic.ValidationError` is handled by the OpenEnv
    transport layer (HTTP 422 / `/ws` error frame) and never reaches
    `env.step`, so it has no code here.
    """

    # These members are the allowed values of `ToolError.code`.
    # NOTE(review): the conditions that trigger each code are decided by the
    # tool implementations, not here; the names are the only hint available
    # in this module.
    DB_ERROR = "db_error"
    UNKNOWN_TABLE = "unknown_table"
    QUERY_TIMEOUT = "query_timeout"
    RESULT_TOO_LARGE = "result_too_large"
    SUBMIT_BEFORE_DIAGNOSE = "submit_before_diagnose"
    INVALID_TOOL_ARGUMENT = "invalid_tool_argument"


# =============================================================================
# Tool payloads (request side of `SqlDriftAction`)
# =============================================================================


class _BasePayload(BaseModel):
    """Shared config for every tool-call payload."""

    # extra="forbid": unknown keys in a payload are rejected, not ignored.
    # validate_assignment=True: field constraints are re-checked when an
    # attribute is set after construction, not only at construction time.
    model_config = ConfigDict(extra="forbid", validate_assignment=True)


# Payload for the `list_tables` tool. It takes no arguments: the only field
# is the `kind` discriminator tag.
class ListTablesPayload(_BasePayload):
    kind: Literal["list_tables"] = "list_tables"


# Payload for the `describe_table` tool: names the table to describe.
class DescribeTablePayload(_BasePayload):
    kind: Literal["describe_table"] = "describe_table"
    # Required; 1 to 63 characters.
    # NOTE(review): 63 presumably mirrors a database identifier-length limit —
    # confirm against the backing engine.
    table: str = Field(min_length=1, max_length=63)


# Payload for the `sample_rows` tool: names a table and how many rows to
# request from it.
class SampleRowsPayload(_BasePayload):
    kind: Literal["sample_rows"] = "sample_rows"
    # Required; 1 to 63 characters.
    table: str = Field(min_length=1, max_length=63)
    # Optional; defaults to 5 and must lie in 1..5, so the default is also
    # the maximum a caller can ask for.
    limit: int = Field(default=5, ge=1, le=5)


# Payload for the `run_query` tool: carries the SQL text to run.
class RunQueryPayload(_BasePayload):
    kind: Literal["run_query"] = "run_query"
    # Required; 1 to 10,000 characters. Only the length is validated here.
    sql: str = Field(min_length=1, max_length=10_000)


# Payload for the `explain_query` tool: carries the SQL text to explain.
class ExplainQueryPayload(_BasePayload):
    kind: Literal["explain_query"] = "explain_query"
    # Required; 1 to 10,000 characters. Only the length is validated here.
    sql: str = Field(min_length=1, max_length=10_000)


# Payload for the `read_changelog` tool. It takes no arguments: the only
# field is the `kind` discriminator tag.
class ReadChangelogPayload(_BasePayload):
    kind: Literal["read_changelog"] = "read_changelog"


# Payload for the `submit_rewrite` tool: carries the rewritten SQL being
# submitted.
class SubmitRewritePayload(_BasePayload):
    kind: Literal["submit_rewrite"] = "submit_rewrite"
    # Required; 1 to 10,000 characters. Only the length is validated here.
    sql: str = Field(min_length=1, max_length=10_000)


# Payload for the `consult_dba` tool: carries a free-text question.
class ConsultDBAPayload(_BasePayload):
    kind: Literal["consult_dba"] = "consult_dba"
    # Required; 1 to 400 characters.
    question: str = Field(min_length=1, max_length=400)


# Discriminated union of the eight tool-call payload models. The
# `Field(discriminator="kind")` annotation tells Pydantic to choose the
# concrete model from the payload's `kind` tag.
ToolPayload = Annotated[
    ListTablesPayload
    | DescribeTablePayload
    | SampleRowsPayload
    | RunQueryPayload
    | ExplainQueryPayload
    | ReadChangelogPayload
    | SubmitRewritePayload
    | ConsultDBAPayload,
    Field(discriminator="kind"),
]


# Maps every tool to the `kind` tag its payload model must carry. It is the
# single source of truth for envelope cross-validation and for the
# server-side dispatcher in P7.
# Every payload kind is spelled exactly like its tool's string value, so the
# table is derived from the enum rather than listed entry by entry.
TOOL_TO_PAYLOAD_KIND: dict[ToolName, str] = {tool: tool.value for tool in ToolName}


# =============================================================================
# SqlDriftAction envelope
# =============================================================================


class SqlDriftAction(Action):
    """Tool-call envelope.

    JSON wire format::

        {"tool": "run_query", "payload": {"kind": "run_query", "sql": "..."}}

    The `tool` field and `payload.kind` must agree; mismatch raises at
    validation time.
    """

    tool: ToolName
    payload: ToolPayload

    @model_validator(mode="after")
    def _tool_matches_payload(self) -> SqlDriftAction:
        """Reject an envelope whose ``tool`` disagrees with ``payload.kind``.

        Runs after field validation, so both fields are already parsed.

        Returns:
            The validated instance, unchanged, when the two fields agree.

        Raises:
            PydanticCustomError: with type ``tool_payload_mismatch`` when the
                payload's ``kind`` is not the one registered for ``tool``.
        """
        required_kind = TOOL_TO_PAYLOAD_KIND[self.tool]
        actual_kind = self.payload.kind
        if actual_kind == required_kind:
            return self

        # Raise PydanticCustomError instead of a bare ValueError: Pydantic
        # wraps a ValueError as ctx={"error": ValueError(...)}, which is not
        # JSON-serializable and breaks FastAPI's HTTPException JSON encoder
        # on 422 responses. The context below holds plain strings only.
        mismatch_ctx = {
            "tool": self.tool.value,
            "expected": required_kind,
            "got": actual_kind,
        }
        raise PydanticCustomError(
            "tool_payload_mismatch",
            "tool/payload mismatch: tool={tool} expects payload.kind={expected}, got {got}",
            mismatch_ctx,
        )


# =============================================================================
# Tool results (response side of `SqlDriftObservation.tool_result`)
# =============================================================================


# Shared config for every tool-result model, mirroring the payload base:
# unknown keys are rejected (extra="forbid") and constraints are re-checked
# on attribute assignment (validate_assignment=True).
class _BaseResult(BaseModel):
    model_config = ConfigDict(extra="forbid", validate_assignment=True)


# Result model tagged `list_tables_result`: a list of table names.
class ListTablesResult(_BaseResult):
    kind: Literal["list_tables_result"] = "list_tables_result"
    tables: list[str]


# Result model tagged `describe_table_result`: the table name plus one
# string-to-string mapping per column.
class DescribeTableResult(_BaseResult):
    kind: Literal["describe_table_result"] = "describe_table_result"
    table: str
    # Expected entry shape is shown in the trailing comment; the type itself
    # only enforces str keys and str values, not the presence of those keys.
    columns: list[dict[str, str]]  # [{"name": "...", "type": "..."}]


# Result model tagged `sample_rows_result`: column names and row data for
# one table.
class SampleRowsResult(_BaseResult):
    kind: Literal["sample_rows_result"] = "sample_rows_result"
    table: str
    columns: list[str]
    # Each row is a list of arbitrary values.
    # NOTE(review): values are presumably positionally aligned with
    # `columns`; nothing in this model enforces matching lengths.
    rows: list[list[Any]]


# Result model tagged `run_query_result`: the result set plus timing and
# size metadata.
class RunQueryResult(_BaseResult):
    kind: Literal["run_query_result"] = "run_query_result"
    columns: list[str]
    # Each row is a list of arbitrary values.
    # NOTE(review): presumably positionally aligned with `columns` — not
    # enforced by this model.
    rows: list[list[Any]]
    # Presumably milliseconds, going by the `_ms` suffix.
    runtime_ms: float
    # NOTE(review): not validated against len(rows); confirm whether it can
    # legitimately differ (for example when rows are truncated).
    row_count: int


# Result model tagged `explain_query_result`: the query plan as a single
# string.
class ExplainQueryResult(_BaseResult):
    kind: Literal["explain_query_result"] = "explain_query_result"
    plan: str


# Result model tagged `read_changelog_result`: changelog entries, one
# string per entry.
class ReadChangelogResult(_BaseResult):
    kind: Literal["read_changelog_result"] = "read_changelog_result"
    entries: list[str]


# Result model tagged `submit_rewrite_result`: the verdict on a submitted
# rewrite.
class SubmitRewriteResult(_BaseResult):
    kind: Literal["submit_rewrite_result"] = "submit_rewrite_result"
    accepted: bool
    # Presumably milliseconds, going by the `_ms` suffix.
    runtime_ms: float
    # Only a boolean verdict is carried here, not the ground-truth query.
    # NOTE(review): how `accepted` relates to this flag is decided by the
    # engine, not by this model.
    matches_ground_truth: bool


# Result model tagged `consult_dba_result`: a hint and the tier it belongs
# to.
class ConsultDBAResult(_BaseResult):
    kind: Literal["consult_dba_result"] = "consult_dba_result"
    # Required; integer in 1..3.
    # NOTE(review): the meaning of each tier is not defined in this module.
    tier: int = Field(ge=1, le=3)
    hint: str


# Result model tagged `tool_error`. It is a member of the `ToolResult`
# union, so a failed tool call is reported through the same `tool_result`
# field as a successful one.
class ToolError(_BaseResult):
    kind: Literal["tool_error"] = "tool_error"
    code: ToolErrorCode
    # Required; at most 2,000 characters. A longer message fails validation
    # rather than being truncated, so callers must shorten it themselves.
    message: str = Field(max_length=2_000)


# Discriminated union of the eight concrete result models plus `ToolError`.
# The `Field(discriminator="kind")` annotation tells Pydantic to choose the
# concrete model from the result's `kind` tag.
ToolResult = Annotated[
    ListTablesResult
    | DescribeTableResult
    | SampleRowsResult
    | RunQueryResult
    | ExplainQueryResult
    | ReadChangelogResult
    | SubmitRewriteResult
    | ConsultDBAResult
    | ToolError,
    Field(discriminator="kind"),
]


# The six reward-component keys match the composed rubric; tests and telemetry
# rely on this exact schema.
# This tuple also supplies the keys of the default `reward_components`
# mapping on `SqlDriftObservation` (via `_zero_reward_components`).
REWARD_COMPONENT_KEYS: tuple[str, ...] = (
    "r_correct",
    "r_drift",
    "r_speedup",
    "r_step_tax",
    "r_gatekeepers",
    "r_consult_dba",
)


# =============================================================================
# SqlDriftObservation
# =============================================================================


def _zero_reward_components() -> dict[str, float]:
    """Build a fresh reward mapping with every component set to ``0.0``.

    Used as the default factory for the observation's reward components, so
    every observation (the reset observation included) exposes the complete
    six-key schema and consumers can index it without checking for keys.

    Returns:
        A new dict keyed by ``REWARD_COMPONENT_KEYS``, in that order.
    """
    return dict.fromkeys(REWARD_COMPONENT_KEYS, 0.0)


class SqlDriftObservation(Observation):
    """Observation returned by :meth:`SqlDriftEnvironment.step`.

    Inherits `done: bool` and `reward: float | None` from base Observation.

    The task payload (`baseline_sql`, `schema_synopsis`) is delivered on
    the reset observation and kept empty on subsequent steps: the agent
    is expected to capture it once and hold it in its own context.
    """

    # Required, non-negative.
    # NOTE(review): presumably the number of steps taken so far — confirm.
    step: int = Field(ge=0)
    phase: EpisodePhase
    # Both default to None, which lets an observation carry no tool call
    # (presumably the reset observation).
    last_tool: ToolName | None = None
    tool_result: ToolResult | None = None
    drift_fired: bool = False
    drift_acknowledged: bool = False
    # Defaults to empty; capped at 800 characters.
    learned_hints: str = Field(default="", max_length=800)
    # Task payload: both default to empty, capped at 10,000 and 2,000
    # characters respectively.
    baseline_sql: str = Field(default="", max_length=10_000)
    schema_synopsis: str = Field(default="", max_length=2_000)
    # Required, non-negative.
    budget_steps_remaining: int = Field(ge=0)
    # Defaults to a fresh all-zero dict over REWARD_COMPONENT_KEYS. The type
    # only enforces str -> float; a caller-supplied dict is not checked for
    # the six-key schema.
    reward_components: dict[str, float] = Field(default_factory=_zero_reward_components)


# =============================================================================
# SqlDriftState — PUBLIC state (sanitized)
# =============================================================================


class SqlDriftState(State):
    """Public state snapshot — serialized over `/state`.

    Ground truth, DB handles, seeds, and baseline SQL live in
    :class:`engine.runtime.RuntimeEpisodeState` and are never exposed here.
    `extra="forbid"` guarantees no accidental leak via future field additions.
    """

    # extra="forbid" rejects any key that is not declared below;
    # validate_assignment=True re-checks constraints on attribute assignment.
    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
    )

    scenario_id: str
    phase: EpisodePhase
    # Required, non-negative.
    budget_steps_remaining: int = Field(ge=0)
    drift_fired: bool = False
    # Defaults to 0; non-negative.
    # NOTE(review): presumably counts `consult_dba` calls made in this
    # episode — confirm.
    consultations_used: int = Field(default=0, ge=0)
    submitted: bool = False


# Explicit public API of this module, i.e. the names exported by a star
# import. The underscore-prefixed helpers (`_BasePayload`, `_BaseResult`,
# `_zero_reward_components`) are not listed.
__all__ = [
    "ConsultDBAPayload",
    "ConsultDBAResult",
    "DescribeTablePayload",
    "DescribeTableResult",
    "EpisodePhase",
    "ExplainQueryPayload",
    "ExplainQueryResult",
    "ListTablesPayload",
    "ListTablesResult",
    "REWARD_COMPONENT_KEYS",
    "ReadChangelogPayload",
    "ReadChangelogResult",
    "RunQueryPayload",
    "RunQueryResult",
    "SampleRowsPayload",
    "SampleRowsResult",
    "SqlDriftAction",
    "SqlDriftObservation",
    "SqlDriftState",
    "SubmitRewritePayload",
    "SubmitRewriteResult",
    "TOOL_TO_PAYLOAD_KIND",
    "ToolError",
    "ToolErrorCode",
    "ToolName",
    "ToolPayload",
    "ToolResult",
]