Spaces:
Sleeping
Sleeping
File size: 2,787 Bytes
"""
SQL error classifier: maps raw SQLite error messages to one of 8
canonical ErrorClass values.
Severity ordering (lower = less severe / closer to correct):
OTHER=5, SYNTAX_ERROR=4, NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3,
DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2,
NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1
"""
import re
from typing import Optional
from rl.types import ErrorClass
# Severity rank for every ErrorClass, built tier by tier from the highest
# rank (5) down to the lowest (1). Classes listed in the same tier share a
# rank; insertion order follows the tiers as written.
_SEVERITY: dict[ErrorClass, int] = {
    member: rank
    for rank, tier in (
        (5, (ErrorClass.OTHER,)),
        (4, (ErrorClass.SYNTAX_ERROR,)),
        (3, (ErrorClass.NO_SUCH_FUNCTION, ErrorClass.NO_SUCH_TABLE)),
        (2, (ErrorClass.DATATYPE_MISMATCH, ErrorClass.AGGREGATION_ERROR)),
        (1, (ErrorClass.NO_SUCH_COLUMN, ErrorClass.AMBIGUOUS_COLUMN)),
    )
    for member in tier
}
def error_severity(error_class: ErrorClass) -> int:
    """Look up the severity rank stored for *error_class* in ``_SEVERITY``.

    Raises:
        KeyError: if *error_class* has no entry in the table.
    """
    rank = _SEVERITY[error_class]
    return rank
def classify_error(error_message: str) -> ErrorClass:
    """Map a raw SQLite error message onto an ``ErrorClass`` member.

    Matching is case-insensitive substring matching on the lowercased
    message. Rules are tried in a fixed order and the first hit wins, so the
    narrow "no such ..." rules take precedence over the broad syntax rule.
    Messages that match no rule are reported as ``ErrorClass.OTHER``.

    NOTE(review): messages such as "incomplete input" or
    "unrecognized token: ..." match none of the rules below and therefore
    end up as OTHER -- confirm that this is intended.
    """
    text = error_message.lower()

    def mentions(*phrases: str) -> bool:
        # True when at least one of the phrases occurs in the message.
        return any(phrase in text for phrase in phrases)

    # Missing / ambiguous identifiers: columns first, then tables, functions.
    if mentions("no such column"):
        return ErrorClass.NO_SUCH_COLUMN
    if mentions("ambiguous column"):
        return ErrorClass.AMBIGUOUS_COLUMN
    if mentions("no such table"):
        return ErrorClass.NO_SUCH_TABLE
    if mentions("no such function"):
        return ErrorClass.NO_SUCH_FUNCTION

    # Aggregate misuse, including "... must appear in the GROUP BY ..." style
    # complaints (both phrases have to be present for that variant).
    if (
        mentions("not an aggregate", "misuse of aggregate")
        or (mentions("group by") and mentions("must appear"))
        or mentions("must be an aggregate")
    ):
        return ErrorClass.AGGREGATION_ERROR

    # Broad syntax rule: deliberately placed after the specific rules above.
    if mentions("syntax error") or re.search(r'near\s+"', text) is not None:
        return ErrorClass.SYNTAX_ERROR

    if mentions("datatype mismatch", "type mismatch"):
        return ErrorClass.DATATYPE_MISMATCH

    return ErrorClass.OTHER
# Ordered patterns (first match wins); the single capture group of each is
# the offending token. The ambiguous-column pattern comes last so that any
# message matched by one of the earlier patterns keeps the result it had
# before that pattern was added.
_OFFENDING_TOKEN_PATTERNS: tuple[re.Pattern[str], ...] = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"no such column:\s*(\S+)",        # "no such column: X"
        r"no such table:\s*(\S+)",         # "no such table: X"
        r'near\s+"([^"]+)"',               # 'near "X": syntax error'
        r"no such function:\s*(\S+)",      # "no such function: X"
        r"ambiguous column name:\s*(\S+)", # "ambiguous column name: X"
    )
)


def extract_offending_token(error_message: str) -> Optional[str]:
    """Extract the offending token from a SQLite error message.

    The message is searched (case-insensitively) for the known
    "<kind>: <token>" and 'near "<token>"' shapes, in a fixed order; the
    token captured by the first pattern that matches is returned. For the
    "<kind>: <token>" shapes the token is the run of non-whitespace
    characters after the colon; for the ``near`` shape it is the text
    between the double quotes.

    Args:
        error_message: Raw error text to inspect.

    Returns:
        The captured token, or ``None`` if no pattern matches.
    """
    for pattern in _OFFENDING_TOKEN_PATTERNS:
        match = pattern.search(error_message)
        if match:
            return match.group(1)
    return None
|