Spaces:
Sleeping
Sleeping
| """ | |
| SQL error classifier: maps raw SQLite error messages to one of 8 | |
| canonical ErrorClass values. | |
| Severity ordering (lower = less severe / closer to correct): | |
| OTHER=5, SYNTAX_ERROR=4, NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3, | |
| DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2, | |
| NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1 | |
| """ | |
| import re | |
| from typing import Optional | |
| from rl.types import ErrorClass | |
# Severity rank per error class, built tier by tier from worst (5) to
# mildest (1).  Lower ranks mean the query was closer to being correct.
# Entries are inserted in the same order as the tiers listed below.
_SEVERITY: dict[ErrorClass, int] = {
    error_cls: rank
    for rank, tier in (
        (5, (ErrorClass.OTHER,)),
        (4, (ErrorClass.SYNTAX_ERROR,)),
        (3, (ErrorClass.NO_SUCH_FUNCTION, ErrorClass.NO_SUCH_TABLE)),
        (2, (ErrorClass.DATATYPE_MISMATCH, ErrorClass.AGGREGATION_ERROR)),
        (1, (ErrorClass.NO_SUCH_COLUMN, ErrorClass.AMBIGUOUS_COLUMN)),
    )
    for error_cls in tier
}
def error_severity(error_class: ErrorClass) -> int:
    """
    Return the numeric severity rank of *error_class*.

    Ranks come from the module-level ``_SEVERITY`` table: 1 is the mildest
    class and 5 (``ErrorClass.OTHER``) the most severe.

    Raises:
        KeyError: if *error_class* has no entry in ``_SEVERITY``.
    """
    rank = _SEVERITY[error_class]
    return rank
def classify_error(error_message: str) -> ErrorClass:
    """
    Classify a raw SQLite error message into one of 8 canonical classes.

    Matching is case-insensitive. Patterns are ordered most-specific-first
    to avoid false matches: the broad syntax-error checks run only after the
    column/table/function/aggregation checks have failed.

    Args:
        error_message: The raw error text (e.g. ``str(exc)`` of a caught
            database exception).

    Returns:
        The matching ``ErrorClass``, or ``ErrorClass.OTHER`` when no
        pattern applies.
    """
    msg = error_message.lower()

    # Column-level errors
    if "no such column" in msg:
        return ErrorClass.NO_SUCH_COLUMN
    if "ambiguous column" in msg:
        return ErrorClass.AMBIGUOUS_COLUMN

    # Table-level errors
    if "no such table" in msg:
        return ErrorClass.NO_SUCH_TABLE

    # Function errors
    if "no such function" in msg:
        return ErrorClass.NO_SUCH_FUNCTION

    # Aggregation / GROUP BY.  SQLite words these as "misuse of aggregate: f()",
    # "misuse of aggregate function f()", "misuse of aliased aggregate x" and
    # "aggregate functions are not allowed in the GROUP BY clause"; the
    # remaining phrases are kept from the original pattern set.
    if (
        re.search(r"misuse of (?:aliased )?aggregate", msg)
        or "aggregate functions are not allowed" in msg
        or "not an aggregate" in msg
        or ("group by" in msg and "must appear" in msg)
        or "must be an aggregate" in msg
    ):
        return ErrorClass.AGGREGATION_ERROR

    # Syntax errors (broad — must come after more specific patterns).
    # Besides the parser's 'near "X": syntax error', SQLite's tokenizer
    # reports 'unrecognized token: "..."' and a truncated statement yields
    # 'incomplete input'; both are syntax-level failures, not OTHER.
    if (
        "syntax error" in msg
        or "unrecognized token" in msg
        or "incomplete input" in msg
        or re.search(r'near\s+"', msg)
    ):
        return ErrorClass.SYNTAX_ERROR

    # Type errors ("datatype mismatch" contains the "type mismatch" substring).
    if "type mismatch" in msg:
        return ErrorClass.DATATYPE_MISMATCH

    return ErrorClass.OTHER
def extract_offending_token(error_message: str) -> Optional[str]:
    """
    Extract the offending token from a SQLite error message.

    Recognised message shapes (case-insensitive), tried in this order:
      * ``no such column: X``
      * ``no such table: X``
      * ``ambiguous column name: X``
      * ``near "X": syntax error`` -- X may itself contain double quotes,
        e.g. a quoted identifier ``near ""name"": syntax error`` yields
        ``"name"``
      * ``near "X"`` without the trailing ``: syntax error``
      * ``no such function: X``

    Args:
        error_message: Raw error text.

    Returns:
        The token text, or None if no specific token can be identified.
    """
    patterns = (
        r"no such column:\s*(\S+)",
        r"no such table:\s*(\S+)",
        r"ambiguous column(?: name)?:\s*(\S+)",
        # SQLite prints the token verbatim between the quotes, so a quoted
        # identifier produces doubled quotes; anchor on '": syntax error'
        # instead of stopping at the first inner quote.
        r'near\s+"(.+?)":\s*syntax error',
        r'near\s+"([^"]+)"',
        r"no such function:\s*(\S+)",
    )
    for pattern in patterns:
        match = re.search(pattern, error_message, re.IGNORECASE)
        if match:
            return match.group(1)
    return None