"""
RL type definitions and feature engineering.
Mirrors the TypeScript types.ts exactly:
- 8 error classes, 8 repair actions
- FEATURE_DIM = 20
- featurize() builds the state vector
"""
from __future__ import annotations
from enum import IntEnum
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
# ─── Error Taxonomy ─────────────────────────────────────────────
class ErrorClass(IntEnum):
    """Integer-coded error categories (eight members, values 0 through 7).

    featurize() uses a member's integer value directly as its index in the
    error-class one-hot segment of the feature vector, so the values are
    expected to stay contiguous and start at 0.
    """

    NO_SUCH_COLUMN = 0
    NO_SUCH_TABLE = 1
    SYNTAX_ERROR = 2
    AMBIGUOUS_COLUMN = 3
    DATATYPE_MISMATCH = 4
    NO_SUCH_FUNCTION = 5
    AGGREGATION_ERROR = 6
    # Catch-all bucket; keeps the highest value.
    OTHER = 7
# String label for every error class.  Each label is the lower-cased member
# name (NO_SUCH_COLUMN -> "no_such_column"), so renaming an enum member also
# changes its label.
ERROR_CLASS_NAMES: Dict[ErrorClass, str] = {
    error_class: error_class.name.lower() for error_class in ErrorClass
}

# Number of error classes (8); also the width of the error-class one-hot.
NUM_ERROR_CLASSES = len(ErrorClass)
# ─── Repair Actions ─────────────────────────────────────────────
class RepairAction(IntEnum):
    """Integer-coded repair actions (eight members, values 0 through 7).

    featurize() adds a member's integer value to offset 9 to locate its slot
    in the previous-action one-hot segment, so the values are expected to
    stay contiguous and start at 0.
    """

    REWRITE_FULL = 0
    FIX_COLUMN = 1
    FIX_TABLE = 2
    ADD_GROUPBY = 3
    REWRITE_CTE = 4
    FIX_SYNTAX = 5
    CHANGE_DIALECT = 6
    RELAX_FILTER = 7
# String label for every repair action.  Each label is the lower-cased member
# name (REWRITE_FULL -> "rewrite_full"), so renaming an enum member also
# changes its label.
REPAIR_ACTION_NAMES: Dict[RepairAction, str] = {
    action: action.name.lower() for action in RepairAction
}

# Inverse map: name -> enum
REPAIR_ACTION_BY_NAME: Dict[str, RepairAction] = {
    label: action for action, label in REPAIR_ACTION_NAMES.items()
}

# Number of repair actions (8); also the width of the previous-action one-hot.
NUM_ACTIONS = len(RepairAction)
# Layout of the feature vector built by featurize():
#   indices 0-7    error class one-hot                               (8 slots)
#   index   8      attempt number / 5.0                              (1 slot)
#   indices 9-16   previous action one-hot                           (8 slots)
#   index   17     error_changed flag                                (1 slot)
#   index   18     consecutive-same-error count, capped at 5, / 5.0  (1 slot)
#   index   19     constant bias of 1.0                              (1 slot)
FEATURE_DIM = (
    8  # error class one-hot
    + 1  # attempt number
    + 8  # previous action one-hot
    + 1  # error_changed flag
    + 1  # consecutive same error
    + 1  # bias
)  # total = 20
# ─── State ──────────────────────────────────────────────────────
class RLState(BaseModel):
    """State description that featurize() converts into a feature vector."""

    # Category of the current error; its integer value selects the one-hot slot.
    error_class: ErrorClass

    # Attempt counter, 1-indexed.
    attempt_number: int

    # Action recorded for the previous attempt; None leaves the
    # previous-action one-hot all zeros.
    previous_action: Optional[RepairAction] = None

    # Flag copied into the feature vector as 1.0 / 0.0.
    # NOTE(review): presumably "the error differs from the previous attempt's
    # error" -- confirm against the caller that populates it.
    error_changed: bool = False

    # Run length of the same error; featurize() caps it at 5 before scaling.
    consecutive_same_error: int = 1
def featurize(state: RLState) -> List[float]:
    """Encode an RLState as a flat list of FEATURE_DIM (20) floats.

    Layout of the returned list:
        0-7    one-hot of ``state.error_class``
        8      ``state.attempt_number / 5.0`` (not capped)
        9-16   one-hot of ``state.previous_action`` (all zeros when it is None)
        17     1.0 if ``state.error_changed`` else 0.0
        18     ``min(state.consecutive_same_error, 5) / 5.0``
        19     constant bias of 1.0

    Args:
        state: The state to encode.

    Returns:
        A new list of floats of length FEATURE_DIM.
    """
    # Fixed positions of the non-error-class segments in the vector.
    attempt_slot = 8
    action_base = 9
    changed_slot = 17
    streak_slot = 18
    bias_slot = 19

    features: List[float] = [0.0 for _ in range(FEATURE_DIM)]

    # The error class's integer value is its own one-hot index (segment 0-7).
    features[state.error_class] = 1.0

    features[attempt_slot] = state.attempt_number / 5.0

    # A missing previous action leaves its whole segment at zero.
    prev_action = state.previous_action
    if prev_action is not None:
        features[action_base + int(prev_action)] = 1.0

    features[changed_slot] = float(bool(state.error_changed))

    # Only this count is capped (at 5), which bounds the feature at 1.0.
    features[streak_slot] = min(state.consecutive_same_error, 5) / 5.0

    features[bias_slot] = 1.0
    return features
# ─── Experience / Episode ───────────────────────────────────────
class EpisodeStep(BaseModel):
    """Record of a single step within an Episode."""

    # Structured state for this step.
    state: RLState

    # Feature vector for the step.
    # NOTE(review): presumably the output of featurize(state) -- nothing in
    # this model enforces that; confirm at the call site.
    featurized: List[float]

    # Repair action chosen for the step.
    action: RepairAction

    # Scalar reward assigned to the step.
    reward: float

    # Error text associated with the step.
    error_message: str

    # SQL text associated with the step.
    sql: str

    # Whether the step counted as a success.
    success: bool
class Episode(BaseModel):
    """A sequence of EpisodeStep records plus episode-level totals."""

    # Identifier of the episode.
    id: str

    # Question text the episode belongs to.
    question: str

    # Steps of the episode.
    steps: List[EpisodeStep]

    # Episode-level reward figure.
    # NOTE(review): presumably the sum of the step rewards -- not enforced here.
    total_reward: float

    # Whether the episode as a whole counted as a success.
    success: bool

    # Numeric timestamp; the unit (seconds vs. milliseconds) is not fixed by
    # this model -- confirm with the producer.
    timestamp: float
class Experience(BaseModel):
    """A single (state, action, reward, next_state, done) transition record."""

    # Feature vector of the state the action was taken in.
    state: List[float]

    # Action that was taken.
    action: RepairAction

    # Scalar reward for the transition.
    reward: float

    # Feature vector of the following state; optional, defaults to None.
    next_state: Optional[List[float]] = None

    # Termination flag for the transition.
    done: bool

    # Numeric timestamp; the unit is not fixed by this model.
    timestamp: float

    # Free-form extra data; required (no default is declared).
    metadata: Dict[str, Any]
# ─── Metrics ────────────────────────────────────────────────────
class RLMetrics(BaseModel):
    """Aggregate statistics container; every field is required."""

    # Counts of episodes and steps.
    total_episodes: int
    total_steps: int

    # Running reward total.
    cumulative_reward: float

    # Success ratio and mean attempts.
    # NOTE(review): the scale (0-1 vs. percent) and the averaging window are
    # not defined in this file -- confirm with the code that fills them in.
    success_rate: float
    avg_attempts: float

    # Integer counts keyed by string.
    # NOTE(review): keys are presumably the labels from REPAIR_ACTION_NAMES and
    # ERROR_CLASS_NAMES respectively -- verify against the producer.
    action_distribution: Dict[str, int]
    error_distribution: Dict[str, int]

    # Sequence of reward values.
    reward_history: List[float]
|