"""
SQL error classifier.

Maps raw SQLite error messages to one of 8 canonical ErrorClass values.

Severity ordering (lower = less severe / closer to correct):
    OTHER=5,
    SYNTAX_ERROR=4,
    NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3,
    DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2,
    NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1
"""

import re
from typing import Optional

from rl.types import ErrorClass

# Severity rank per error class; the numbers mirror the table in the module
# docstring (lower = less severe).
_SEVERITY: dict[ErrorClass, int] = {
    ErrorClass.OTHER: 5,
    ErrorClass.SYNTAX_ERROR: 4,
    ErrorClass.NO_SUCH_FUNCTION: 3,
    ErrorClass.NO_SUCH_TABLE: 3,
    ErrorClass.DATATYPE_MISMATCH: 2,
    ErrorClass.AGGREGATION_ERROR: 2,
    ErrorClass.NO_SUCH_COLUMN: 1,
    ErrorClass.AMBIGUOUS_COLUMN: 1,
}


def error_severity(error_class: ErrorClass) -> int:
    """
    Return the severity rank recorded for ``error_class`` in ``_SEVERITY``.

    Raises KeyError if ``error_class`` has no entry in the table.
    """
    return _SEVERITY[error_class]


def classify_error(error_message: str) -> ErrorClass:
    """
    Map a raw SQLite error message onto one of the 8 canonical classes.

    Matching is case-insensitive (the message is lower-cased first) and the
    checks run in a fixed order, most specific first; the first check that
    matches decides the result. Messages matching nothing are reported as
    ErrorClass.OTHER.
    """
    lowered = error_message.lower()

    def mentions(*fragments: str) -> bool:
        # True when at least one fragment occurs somewhere in the message.
        return any(fragment in lowered for fragment in fragments)

    # Classes identified by a single phrase: column-level first, then
    # table-level, then function-level.
    single_phrase_rules = (
        ("no such column", ErrorClass.NO_SUCH_COLUMN),
        ("ambiguous column", ErrorClass.AMBIGUOUS_COLUMN),
        ("no such table", ErrorClass.NO_SUCH_TABLE),
        ("no such function", ErrorClass.NO_SUCH_FUNCTION),
    )
    for phrase, verdict in single_phrase_rules:
        if phrase in lowered:
            return verdict

    # Aggregation / GROUP BY: any one of three phrases, or the pair
    # "group by" + "must appear" occurring together.
    group_by_complaint = mentions("group by") and mentions("must appear")
    if group_by_complaint or mentions(
        "not an aggregate",
        "misuse of aggregate",
        "must be an aggregate",
    ):
        return ErrorClass.AGGREGATION_ERROR

    # Syntax errors are the broad bucket (an explicit "syntax error" or any
    # `near "` fragment), so they are only tried after the specific classes.
    if mentions("syntax error") or re.search(r'near\s+"', lowered):
        return ErrorClass.SYNTAX_ERROR

    # Type errors.
    if mentions("datatype mismatch", "type mismatch"):
        return ErrorClass.DATATYPE_MISMATCH

    return ErrorClass.OTHER


def extract_offending_token(error_message: str) -> Optional[str]:
    """
    Pull the offending token out of a SQLite error message.

    The patterns below are tried in order, case-insensitively, and the first
    capture group of the first pattern that matches is returned. Returns None
    when none of them match.
    """
    token_patterns = (
        r"no such column:\s*(\S+)",    # "no such column: X"
        r"no such table:\s*(\S+)",     # "no such table: X"
        r'near\s+"([^"]+)"',           # 'near "X": syntax error'
        r"no such function:\s*(\S+)",  # "no such function: X"
    )
    for pattern in token_patterns:
        found = re.search(pattern, error_message, re.IGNORECASE)
        if found:
            return found.group(1)
    return None