File size: 2,787 Bytes
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
SQL error classifier: maps raw SQLite error messages to one of 8
canonical ErrorClass values.

Severity ordering (lower = less severe / closer to correct):
  OTHER=5, SYNTAX_ERROR=4, NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3,
  DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2,
  NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1
"""

import re
from typing import Optional

from rl.types import ErrorClass

# Severity rank for every ErrorClass member, looked up by error_severity().
# Lower rank = less severe / closer to a correct query (see module docstring).
# Several classes deliberately share a rank, so ranks define tiers rather
# than a strict total order.
_SEVERITY: dict[ErrorClass, int] = {
    ErrorClass.OTHER:             5,
    ErrorClass.SYNTAX_ERROR:      4,
    ErrorClass.NO_SUCH_FUNCTION:  3,
    ErrorClass.NO_SUCH_TABLE:     3,
    ErrorClass.DATATYPE_MISMATCH: 2,
    ErrorClass.AGGREGATION_ERROR: 2,
    ErrorClass.NO_SUCH_COLUMN:    1,
    ErrorClass.AMBIGUOUS_COLUMN:  1,
}


def error_severity(error_class: ErrorClass) -> int:
    """
    Look up the severity rank of an error class in the _SEVERITY table.
    Lower ranks are less severe; a class missing from the table raises
    KeyError.
    """
    rank = _SEVERITY[error_class]
    return rank


def classify_error(error_message: str) -> ErrorClass:
    """
    Classify a raw SQLite error message into one of 8 canonical classes.

    Matching is case-insensitive and the checks run in a fixed order
    (column, table, function, aggregation, syntax, datatype); the first
    check that matches decides the result. A message that matches nothing,
    as well as an empty or None message, is classified as ErrorClass.OTHER.
    """
    # A missing message carries no information; report it as unclassifiable
    # instead of failing on .lower().
    if not error_message:
        return ErrorClass.OTHER

    msg = error_message.lower()

    # Column-level errors
    if "no such column" in msg:
        return ErrorClass.NO_SUCH_COLUMN
    if "ambiguous column" in msg:
        return ErrorClass.AMBIGUOUS_COLUMN

    # Table-level errors
    if "no such table" in msg:
        return ErrorClass.NO_SUCH_TABLE

    # Function errors
    if "no such function" in msg:
        return ErrorClass.NO_SUCH_FUNCTION

    # Aggregation / GROUP BY
    if (
        "not an aggregate" in msg
        or "misuse of aggregate" in msg
        or ("group by" in msg and "must appear" in msg)
        or "must be an aggregate" in msg
    ):
        return ErrorClass.AGGREGATION_ERROR

    # Syntax errors (broad — must come after more specific patterns).
    # Besides 'near "X": syntax error', SQLite reports malformed SQL as
    # 'unrecognized token: "X"' (tokenizer) and 'incomplete input' (parser
    # reached the end of the statement early); all three are syntax problems.
    if (
        "syntax error" in msg
        or re.search(r'near\s+"', msg)
        or "unrecognized token" in msg
        or "incomplete input" in msg
    ):
        return ErrorClass.SYNTAX_ERROR

    # Type errors
    if "datatype mismatch" in msg or "type mismatch" in msg:
        return ErrorClass.DATATYPE_MISMATCH

    return ErrorClass.OTHER


# Patterns tried in order by extract_offending_token(); group 1 of the first
# pattern that matches is the offending token. The ambiguous-column pattern is
# listed last so that every message matched by one of the earlier patterns
# keeps yielding exactly the same token as before it was added.
_OFFENDING_TOKEN_PATTERNS = (
    # "no such column: X"
    re.compile(r"no such column:\s*(\S+)", re.IGNORECASE),
    # "no such table: X"
    re.compile(r"no such table:\s*(\S+)", re.IGNORECASE),
    # 'near "X": syntax error'
    re.compile(r'near\s+"([^"]+)"', re.IGNORECASE),
    # "no such function: X"
    re.compile(r"no such function:\s*(\S+)", re.IGNORECASE),
    # "ambiguous column name: X"
    re.compile(r"ambiguous column name:\s*(\S+)", re.IGNORECASE),
)


def extract_offending_token(error_message: str) -> Optional[str]:
    """
    Extract the offending token from a SQLite error message.

    Recognised message shapes (matched case-insensitively, first match wins):
    "no such column: X", "no such table: X", 'near "X": syntax error',
    "no such function: X" and "ambiguous column name: X". For the
    colon-style messages the token is the first run of non-whitespace
    characters after the colon; for the 'near' form it is the text between
    the double quotes.

    Returns None if no specific token can be identified, including when the
    message is empty or None.
    """
    if not error_message:
        return None

    for pattern in _OFFENDING_TOKEN_PATTERNS:
        found = pattern.search(error_message)
        if found:
            return found.group(1)

    return None