"""
RL type definitions and feature engineering.

Mirrors the TypeScript types.ts exactly:
  - 8 error classes, 8 repair actions
  - FEATURE_DIM = 20
  - featurize() builds the state vector
"""

from __future__ import annotations

from enum import IntEnum
from typing import Optional, List, Dict, Any

from pydantic import BaseModel


# ─── Error Taxonomy ─────────────────────────────────────────────

class ErrorClass(IntEnum):
    """Error categories.

    The integer value of each member is its slot in the error-class one-hot
    segment of the feature vector (indices 0..7).
    """

    NO_SUCH_COLUMN = 0
    NO_SUCH_TABLE = 1
    SYNTAX_ERROR = 2
    AMBIGUOUS_COLUMN = 3
    DATATYPE_MISMATCH = 4
    NO_SUCH_FUNCTION = 5
    AGGREGATION_ERROR = 6
    OTHER = 7


# Canonical string names; these must stay byte-identical to types.ts.
ERROR_CLASS_NAMES: Dict[ErrorClass, str] = {
    ErrorClass.NO_SUCH_COLUMN: "no_such_column",
    ErrorClass.NO_SUCH_TABLE: "no_such_table",
    ErrorClass.SYNTAX_ERROR: "syntax_error",
    ErrorClass.AMBIGUOUS_COLUMN: "ambiguous_column",
    ErrorClass.DATATYPE_MISMATCH: "datatype_mismatch",
    ErrorClass.NO_SUCH_FUNCTION: "no_such_function",
    ErrorClass.AGGREGATION_ERROR: "aggregation_error",
    ErrorClass.OTHER: "other",
}

# Derived from the enum rather than hard-coded, so it cannot drift if a
# member is added or removed (currently 8).
NUM_ERROR_CLASSES = len(ErrorClass)


# ─── Repair Actions ─────────────────────────────────────────────

class RepairAction(IntEnum):
    """Repair actions.

    The integer value of each member is its offset within the previous-action
    one-hot segment of the feature vector (indices 9..16).
    """

    REWRITE_FULL = 0
    FIX_COLUMN = 1
    FIX_TABLE = 2
    ADD_GROUPBY = 3
    REWRITE_CTE = 4
    FIX_SYNTAX = 5
    CHANGE_DIALECT = 6
    RELAX_FILTER = 7


# Canonical string names; these must stay byte-identical to types.ts.
REPAIR_ACTION_NAMES: Dict[RepairAction, str] = {
    RepairAction.REWRITE_FULL: "rewrite_full",
    RepairAction.FIX_COLUMN: "fix_column",
    RepairAction.FIX_TABLE: "fix_table",
    RepairAction.ADD_GROUPBY: "add_groupby",
    RepairAction.REWRITE_CTE: "rewrite_cte",
    RepairAction.FIX_SYNTAX: "fix_syntax",
    RepairAction.CHANGE_DIALECT: "change_dialect",
    RepairAction.RELAX_FILTER: "relax_filter",
}

# Inverse map: name → enum
REPAIR_ACTION_BY_NAME: Dict[str, RepairAction] = {
    v: k for k, v in REPAIR_ACTION_NAMES.items()
}

# Derived from the enum rather than hard-coded (currently 8).
NUM_ACTIONS = len(RepairAction)


# ─── Feature Layout ─────────────────────────────────────────────
# Feature vector:
#   [0..7]    error class one-hot   (NUM_ERROR_CLASSES)
#   [8]       attempt / 5.0         (1)
#   [9..16]   prev action one-hot   (NUM_ACTIONS)
#   [17]      error_changed         (1)
#   [18]      consec_count / 5.0    (1)
#   [19]      bias = 1.0            (1)
#   total = 20
FEATURE_DIM = (
    NUM_ERROR_CLASSES  # error class one-hot
    + 1                # attempt number
    + NUM_ACTIONS      # previous action one-hot
    + 1                # error_changed flag
    + 1                # consecutive same-error count
    + 1                # bias
)  # == 20; must match FEATURE_DIM in types.ts
# ─── State ──────────────────────────────────────────────────────

class RLState(BaseModel):
    """Snapshot of a repair episode at one attempt; the input to featurize()."""

    # Error category; its integer value selects the one-hot slot in featurize().
    error_class: ErrorClass
    attempt_number: int  # 1-indexed
    # None leaves the previous-action one-hot segment all zeros.
    previous_action: Optional[RepairAction] = None
    # Encoded as 1.0 / 0.0 in featurize().
    error_changed: bool = False
    # Clamped to 5 and scaled by 1/5 in featurize().
    consecutive_same_error: int = 1


# Index layout of the feature vector, derived from the enum sizes so that
# featurize() stays consistent with FEATURE_DIM and the documented layout
# instead of relying on hand-maintained magic numbers.
_ATTEMPT_IDX = NUM_ERROR_CLASSES                         # 8
_PREV_ACTION_OFFSET = _ATTEMPT_IDX + 1                   # 9  (spans 9..16)
_ERROR_CHANGED_IDX = _PREV_ACTION_OFFSET + NUM_ACTIONS   # 17
_CONSEC_IDX = _ERROR_CHANGED_IDX + 1                     # 18
_BIAS_IDX = _CONSEC_IDX + 1                              # 19

# Fail fast at import time if the layout and FEATURE_DIM ever disagree,
# rather than producing overlapping features or an IndexError later.
if _BIAS_IDX + 1 != FEATURE_DIM:
    raise RuntimeError(
        f"feature layout needs {_BIAS_IDX + 1} slots "
        f"but FEATURE_DIM is {FEATURE_DIM}"
    )


def featurize(state: RLState) -> List[float]:
    """Build the 20-dimensional feature vector from an RLState.

    Layout:
      [0..7]   one-hot of ``state.error_class``
      [8]      ``state.attempt_number / 5.0``
      [9..16]  one-hot of ``state.previous_action`` (all zeros when None)
      [17]     1.0 if ``state.error_changed`` else 0.0
      [18]     ``min(state.consecutive_same_error, 5) / 5.0``
      [19]     bias, always 1.0

    Args:
        state: The RL state to encode.

    Returns:
        A new list of ``FEATURE_DIM`` floats.
    """
    x = [0.0] * FEATURE_DIM

    # Error class one-hot.
    x[int(state.error_class)] = 1.0

    # NOTE(review): unlike the consecutive count below, the attempt number
    # is not clamped, so it exceeds 1.0 after attempt 5. It is kept as-is for
    # parity with types.ts — confirm against the TS implementation.
    x[_ATTEMPT_IDX] = state.attempt_number / 5.0

    # Previous action one-hot; left all zeros when there is no previous action.
    if state.previous_action is not None:
        x[_PREV_ACTION_OFFSET + int(state.previous_action)] = 1.0

    x[_ERROR_CHANGED_IDX] = 1.0 if state.error_changed else 0.0
    x[_CONSEC_IDX] = min(state.consecutive_same_error, 5) / 5.0
    x[_BIAS_IDX] = 1.0
    return x


# ─── Experience / Episode ────────────────────────────────────────

class EpisodeStep(BaseModel):
    """One repair attempt recorded within an Episode."""

    state: RLState
    # Presumably featurize(state) at record time — confirm at the call site.
    featurized: List[float]
    action: RepairAction
    reward: float
    error_message: str
    sql: str
    success: bool


class Episode(BaseModel):
    """A full repair episode: an ordered list of steps for one question."""

    id: str
    question: str
    steps: List[EpisodeStep]
    # NOTE(review): likely the sum of step rewards; not enforced by this model.
    total_reward: float
    success: bool
    # Units (epoch seconds vs. milliseconds) are not visible here — confirm.
    timestamp: float


class Experience(BaseModel):
    """A single transition record (state, action, reward, next_state, done)."""

    # Feature vector; presumably FEATURE_DIM long — not validated here.
    state: List[float]
    action: RepairAction
    reward: float
    # Defaults to None; presumably left unset for terminal transitions.
    next_state: Optional[List[float]] = None
    done: bool
    timestamp: float
    metadata: Dict[str, Any]


# ─── Metrics ────────────────────────────────────────────────────

class RLMetrics(BaseModel):
    """Aggregate training/usage metrics for the RL repair loop."""

    total_episodes: int
    total_steps: int
    cumulative_reward: float
    success_rate: float
    avg_attempts: float
    # Keys are presumably REPAIR_ACTION_NAMES / ERROR_CLASS_NAMES values — confirm.
    action_distribution: Dict[str, int]
    error_distribution: Dict[str, int]
    reward_history: List[float]