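# Provenance (web-page header captured together with the file; not module code):
#   uploader avatar caption: ar9avg's picture
#   commit message: Initial submission: SQL Agent OpenEnv for Meta+HF hackathon
#   commit: 3c665d2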
"""
RL type definitions and feature engineering.
Mirrors the TypeScript types.ts exactly:
- 8 error classes, 8 repair actions
- FEATURE_DIM = 20
- featurize() builds the state vector
"""
from __future__ import annotations
from enum import IntEnum
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
# ─── Error Taxonomy ─────────────────────────────────────────────
class ErrorClass(IntEnum):
    # Error taxonomy: eight categories numbered 0-7.
    # featurize() uses the integer value directly as a list index
    # (x[state.error_class]), so each value is also the position of that
    # class inside the one-hot segment of the feature vector.
    NO_SUCH_COLUMN = 0     # feature slot 0
    NO_SUCH_TABLE = 1      # feature slot 1
    SYNTAX_ERROR = 2       # feature slot 2
    AMBIGUOUS_COLUMN = 3   # feature slot 3
    DATATYPE_MISMATCH = 4  # feature slot 4
    NO_SUCH_FUNCTION = 5   # feature slot 5
    AGGREGATION_ERROR = 6  # feature slot 6
    OTHER = 7              # feature slot 7
# Lower-case string label for every ErrorClass member, listed in enum order.
ERROR_CLASS_NAMES: Dict[ErrorClass, str] = {
    ErrorClass.NO_SUCH_COLUMN: "no_such_column",
    ErrorClass.NO_SUCH_TABLE: "no_such_table",
    ErrorClass.SYNTAX_ERROR: "syntax_error",
    ErrorClass.AMBIGUOUS_COLUMN: "ambiguous_column",
    ErrorClass.DATATYPE_MISMATCH: "datatype_mismatch",
    ErrorClass.NO_SUCH_FUNCTION: "no_such_function",
    ErrorClass.AGGREGATION_ERROR: "aggregation_error",
    ErrorClass.OTHER: "other",
}

# Number of error classes (8); equals the width of the error-class one-hot
# segment at the start of the feature vector.
NUM_ERROR_CLASSES = len(ErrorClass)
# ─── Repair Actions ─────────────────────────────────────────────
class RepairAction(IntEnum):
    # The eight repair actions, numbered 0-7.
    # featurize() places the previous action's one-hot bit at index
    # 9 + int(action), so the integer values must stay within 0-7 to remain
    # inside feature slots 9-16.
    REWRITE_FULL = 0    # feature slot 9
    FIX_COLUMN = 1      # feature slot 10
    FIX_TABLE = 2       # feature slot 11
    ADD_GROUPBY = 3     # feature slot 12
    REWRITE_CTE = 4     # feature slot 13
    FIX_SYNTAX = 5      # feature slot 14
    CHANGE_DIALECT = 6  # feature slot 15
    RELAX_FILTER = 7    # feature slot 16
# Lower-case string label for every RepairAction member, listed in enum order.
REPAIR_ACTION_NAMES: Dict[RepairAction, str] = {
    RepairAction.REWRITE_FULL: "rewrite_full",
    RepairAction.FIX_COLUMN: "fix_column",
    RepairAction.FIX_TABLE: "fix_table",
    RepairAction.ADD_GROUPBY: "add_groupby",
    RepairAction.REWRITE_CTE: "rewrite_cte",
    RepairAction.FIX_SYNTAX: "fix_syntax",
    RepairAction.CHANGE_DIALECT: "change_dialect",
    RepairAction.RELAX_FILTER: "relax_filter",
}

# Inverse map: name → enum
REPAIR_ACTION_BY_NAME: Dict[str, RepairAction] = {
    label: action for action, label in REPAIR_ACTION_NAMES.items()
}

# Number of repair actions (8); equals the width of the previous-action
# one-hot segment of the feature vector.
NUM_ACTIONS = len(RepairAction)
# Feature vector layout (index range, meaning, width):
#   [0..7]   error class one-hot        (8)
#   [8]      attempt / 5.0              (1)
#   [9..16]  previous action one-hot    (8)
#   [17]     error_changed flag         (1)
#   [18]     consec_count / 5.0         (1)
#   [19]     bias = 1.0                 (1)
# The widths below are summed in the same order as the layout; total = 20.
FEATURE_DIM = 8 + 1 + 8 + 1 + 1 + 1
# ─── State ──────────────────────────────────────────────────────
class RLState(BaseModel):
    # Snapshot that featurize() turns into the numeric feature vector.
    # Only error_class and attempt_number are required; the remaining
    # fields have defaults.

    # Category of the current error; featurize() uses it as a one-hot index.
    error_class: ErrorClass
    # 1-indexed attempt counter; featurize() divides it by 5.0 without clamping.
    attempt_number: int
    # Action taken on the preceding step; None leaves the previous-action
    # one-hot segment all zeros.
    previous_action: Optional[RepairAction] = None
    # Encoded by featurize() as 1.0 when true, 0.0 otherwise.
    # NOTE(review): presumably "the error differs from the previous attempt's"
    # - confirm with the code that builds the state.
    error_changed: bool = False
    # Streak counter; featurize() caps it at 5 before dividing by 5.0.
    consecutive_same_error: int = 1
def featurize(state: RLState) -> List[float]:
    """Encode *state* as a flat list of ``FEATURE_DIM`` (20) floats.

    Layout of the returned list:
        [0..7]   one-hot of ``state.error_class``
        [8]      ``state.attempt_number / 5.0`` (not clamped)
        [9..16]  one-hot of ``state.previous_action`` (all zeros when None)
        [17]     1.0 if ``state.error_changed`` else 0.0
        [18]     ``min(state.consecutive_same_error, 5) / 5.0``
        [19]     constant bias term, always 1.0
    """
    features = [0.0] * FEATURE_DIM

    # Slots 0-7: the IntEnum value doubles as the list index.
    features[state.error_class] = 1.0

    # Slot 8: attempt counter scaled by 5.
    features[8] = state.attempt_number / 5.0

    # Slots 9-16: previous action, shifted past the first nine slots.
    last_action = state.previous_action
    if last_action is not None:
        features[9 + int(last_action)] = 1.0

    # Slot 17: boolean flag written as an explicit float.
    if state.error_changed:
        features[17] = 1.0
    else:
        features[17] = 0.0

    # Slot 18: streak length, capped at 5 so the value never exceeds 1.0.
    capped_streak = min(state.consecutive_same_error, 5)
    features[18] = capped_streak / 5.0

    # Slot 19: bias.
    features[19] = 1.0

    return features
# ─── Experience / Episode ────────────────────────────────────────
class EpisodeStep(BaseModel):
    # One step recorded inside an Episode. All fields are required.

    # State at this step.
    state: RLState
    # Numeric encoding stored alongside the state.
    # NOTE(review): presumably the output of featurize(state) - confirm at
    # the call site that builds steps.
    featurized: List[float]
    # Repair action recorded for this step.
    action: RepairAction
    # Scalar reward recorded for this step.
    reward: float
    # Error text associated with this step.
    error_message: str
    # SQL text associated with this step.
    sql: str
    # Success flag for this step.
    success: bool
class Episode(BaseModel):
    # A full episode: one question plus its ordered list of steps.
    # All fields are required.

    # Episode identifier (string).
    id: str
    # The question this episode belongs to.
    question: str
    # Recorded steps.
    steps: List[EpisodeStep]
    # Reward total for the episode.
    # NOTE(review): presumably the sum of the step rewards - this model does
    # not enforce it.
    total_reward: float
    # Overall success flag for the episode.
    success: bool
    # Numeric timestamp.
    # NOTE(review): unit and epoch are not visible here - confirm.
    timestamp: float
class Experience(BaseModel):
    # A single (state, action, reward, next_state, done) record, plus a
    # timestamp and free-form metadata. Only next_state has a default.

    # Numeric state vector.
    # NOTE(review): presumably a featurize() output of length FEATURE_DIM -
    # the length is not validated here.
    state: List[float]
    # Action taken from that state.
    action: RepairAction
    # Scalar reward for the transition.
    reward: float
    # Following state vector; defaults to None when omitted.
    next_state: Optional[List[float]] = None
    # Termination flag for the transition.
    done: bool
    # Numeric timestamp.
    # NOTE(review): unit and epoch are not visible here - confirm.
    timestamp: float
    # Arbitrary string-keyed extra data; required, with no default.
    metadata: Dict[str, Any]
# ─── Metrics ────────────────────────────────────────────────────
class RLMetrics(BaseModel):
    # Aggregate statistics container. All fields are required; this model
    # only declares their types and does not compute any of them.

    # Counters.
    total_episodes: int
    total_steps: int
    # Running reward total.
    cumulative_reward: float
    # Success ratio.
    # NOTE(review): presumably in [0, 1] - the range is not enforced here.
    success_rate: float
    # Mean number of attempts.
    avg_attempts: float
    # Integer counts keyed by string.
    # NOTE(review): keys are presumably the labels from REPAIR_ACTION_NAMES
    # and ERROR_CLASS_NAMES respectively - confirm with the producer.
    action_distribution: Dict[str, int]
    error_distribution: Dict[str, int]
    # Sequence of reward values.
    reward_history: List[float]