Spaces:
Sleeping
Sleeping
"""
RL type definitions and feature engineering.

Mirrors the TypeScript types.ts exactly:
  - 8 error classes, 8 repair actions
  - FEATURE_DIM = 20
  - featurize() builds the state vector
"""
from __future__ import annotations

from enum import IntEnum, unique
from typing import Optional, List, Dict, Any

from pydantic import BaseModel
# ─── Error Taxonomy ─────────────────────────────────────────────
@unique
class ErrorClass(IntEnum):
    """Error categories recorded in an ``RLState``.

    The integer value of each member doubles as its slot in the one-hot
    section of the feature vector (``featurize`` indexes the vector with it),
    so values must stay contiguous from 0. ``@unique`` rejects accidental
    aliases, which would make two classes share one slot.
    """

    NO_SUCH_COLUMN = 0
    NO_SUCH_TABLE = 1
    SYNTAX_ERROR = 2
    AMBIGUOUS_COLUMN = 3
    DATATYPE_MISMATCH = 4
    NO_SUCH_FUNCTION = 5
    AGGREGATION_ERROR = 6
    OTHER = 7
# snake_case label for every ErrorClass member.
ERROR_CLASS_NAMES: Dict[ErrorClass, str] = {
    ErrorClass.NO_SUCH_COLUMN: "no_such_column",
    ErrorClass.NO_SUCH_TABLE: "no_such_table",
    ErrorClass.SYNTAX_ERROR: "syntax_error",
    ErrorClass.AMBIGUOUS_COLUMN: "ambiguous_column",
    ErrorClass.DATATYPE_MISMATCH: "datatype_mismatch",
    ErrorClass.NO_SUCH_FUNCTION: "no_such_function",
    ErrorClass.AGGREGATION_ERROR: "aggregation_error",
    ErrorClass.OTHER: "other",
}

# Derived from the enum (currently 8) so the count cannot drift from the
# member list.
NUM_ERROR_CLASSES = len(ErrorClass)
# ─── Repair Actions ─────────────────────────────────────────────
@unique
class RepairAction(IntEnum):
    """Repair actions that can be recorded for an episode step.

    The integer value of each member is used as an offset into the
    previous-action one-hot section of the feature vector (``featurize``
    writes slot ``9 + value``), so values must stay contiguous from 0.
    ``@unique`` rejects accidental aliases that would share a slot.
    """

    REWRITE_FULL = 0
    FIX_COLUMN = 1
    FIX_TABLE = 2
    ADD_GROUPBY = 3
    REWRITE_CTE = 4
    FIX_SYNTAX = 5
    CHANGE_DIALECT = 6
    RELAX_FILTER = 7
# snake_case label for every RepairAction member.
REPAIR_ACTION_NAMES: Dict[RepairAction, str] = {
    RepairAction.REWRITE_FULL: "rewrite_full",
    RepairAction.FIX_COLUMN: "fix_column",
    RepairAction.FIX_TABLE: "fix_table",
    RepairAction.ADD_GROUPBY: "add_groupby",
    RepairAction.REWRITE_CTE: "rewrite_cte",
    RepairAction.FIX_SYNTAX: "fix_syntax",
    RepairAction.CHANGE_DIALECT: "change_dialect",
    RepairAction.RELAX_FILTER: "relax_filter",
}

# Inverse map: name -> enum
REPAIR_ACTION_BY_NAME: Dict[str, RepairAction] = {
    name: action for action, name in REPAIR_ACTION_NAMES.items()
}

# Derived from the enum (currently 8) so the count cannot drift from the
# member list.
NUM_ACTIONS = len(RepairAction)
# Feature vector layout (index range, meaning, width):
#   [0..7]   error class one-hot      (8)
#   [8]      attempt / 5.0            (1)
#   [9..16]  prev action one-hot      (8)
#   [17]     error_changed            (1)
#   [18]     consec_count / 5.0       (1)
#   [19]     bias = 1.0               (1)
#   total = 20
FEATURE_DIM = 20
# ─── State ──────────────────────────────────────────────────────
class RLState(BaseModel):
    """Raw (un-featurized) state; ``featurize`` turns it into a vector."""

    # Category of the current error.
    error_class: ErrorClass
    # 1-indexed attempt counter.
    attempt_number: int
    # Action taken previously; None when there is none to record.
    previous_action: Optional[RepairAction] = None
    # Whether the error changed (presumably relative to the previous
    # attempt -- confirm against the producer of this state).
    error_changed: bool = False
    # Count of consecutive occurrences of the same error; defaults to 1.
    consecutive_same_error: int = 1
def featurize(state: RLState) -> List[float]:
    """Build the 20-dimensional feature vector from an RLState.

    Layout of the returned list (length ``FEATURE_DIM``):
        [0..7]   one-hot of ``state.error_class``
        [8]      ``state.attempt_number / 5.0`` (not clamped)
        [9..16]  one-hot of ``state.previous_action``; all zeros when None
        [17]     1.0 if ``state.error_changed`` else 0.0
        [18]     ``min(state.consecutive_same_error, 5) / 5.0``
        [19]     constant bias, 1.0

    Args:
        state: The raw state to encode.

    Returns:
        A new list of ``FEATURE_DIM`` floats.
    """
    x = [0.0] * FEATURE_DIM

    # Error class one-hot [0..7]: the enum value is the slot index.
    x[int(state.error_class)] = 1.0

    # Attempt number normalized [8]. Unlike slot 18 this is not capped,
    # so attempts beyond 5 yield values above 1.0.
    x[8] = state.attempt_number / 5.0

    # Previous action one-hot [9..16]; left all-zero when there is none.
    if state.previous_action is not None:
        x[9 + int(state.previous_action)] = 1.0

    # Error changed flag [17]
    x[17] = 1.0 if state.error_changed else 0.0

    # Consecutive same error normalized [18], capped at 5 -> 1.0
    x[18] = min(state.consecutive_same_error, 5) / 5.0

    # Bias term [19]
    x[19] = 1.0

    return x
# ─── Experience / Episode ───────────────────────────────────────
class EpisodeStep(BaseModel):
    """One step of an episode: a state, the action taken, and its outcome."""

    # Raw state at this step.
    state: RLState
    # Feature vector for the step (presumably featurize(state) -- confirm
    # against the code that builds steps).
    featurized: List[float]
    # Repair action recorded for this step.
    action: RepairAction
    # Reward recorded for this step.
    reward: float
    # Error message text associated with this step.
    error_message: str
    # SQL text associated with this step.
    sql: str
    # Whether this step was successful.
    success: bool
class Episode(BaseModel):
    """A sequence of ``EpisodeStep`` records for a single question."""

    # Episode identifier.
    id: str
    # The question this episode belongs to.
    question: str
    # Steps of the episode (presumably in the order taken -- confirm).
    steps: List[EpisodeStep]
    # Total reward for the episode (presumably the sum over steps -- confirm).
    total_reward: float
    # Whether the episode as a whole succeeded.
    success: bool
    # Timestamp of the episode; units are not shown here (TODO confirm).
    timestamp: float
class Experience(BaseModel):
    """A single transition record: state, action, reward, next state."""

    # Feature vector of the state the action was taken in.
    state: List[float]
    # Action taken.
    action: RepairAction
    # Reward recorded for the transition.
    reward: float
    # Feature vector of the following state; optional, defaults to None
    # (presumably None for terminal transitions -- confirm).
    next_state: Optional[List[float]] = None
    # Whether the transition is flagged as done.
    done: bool
    # Timestamp of the transition; units are not shown here (TODO confirm).
    timestamp: float
    # Free-form extra data; required, with no default.
    metadata: Dict[str, Any]
# ─── Metrics ────────────────────────────────────────────────────
class RLMetrics(BaseModel):
    """Aggregate counters and histories describing RL activity."""

    # Number of episodes counted.
    total_episodes: int
    # Number of steps counted.
    total_steps: int
    # Running total of reward.
    cumulative_reward: float
    # Success rate (presumably a fraction in [0, 1] -- confirm).
    success_rate: float
    # Average number of attempts (presumably per episode -- confirm).
    avg_attempts: float
    # Count per action, keyed by string (presumably the labels in
    # REPAIR_ACTION_NAMES -- confirm).
    action_distribution: Dict[str, int]
    # Count per error, keyed by string (presumably the labels in
    # ERROR_CLASS_NAMES -- confirm).
    error_distribution: Dict[str, int]
    # Sequence of recorded reward values.
    reward_history: List[float]