sql-agent-openenv / backend /rl /error_classifier.py
ar9avg's picture
Initial submission: SQL Agent OpenEnv for Meta+HF hackathon
3c665d2
"""
SQL error classifier: maps raw SQLite error messages to one of 8
canonical ErrorClass values.
Severity ordering (lower = less severe / closer to correct):
OTHER=5, SYNTAX_ERROR=4, NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3,
DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2,
NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1
"""
import re
from typing import Optional
from rl.types import ErrorClass
# Severity rank for each ErrorClass. A lower rank means the failing query was
# closer to being correct; classes that share a rank sit on the same tier.
_SEVERITY: dict[ErrorClass, int] = {
    error_class: rank
    for rank, tier in (
        (5, (ErrorClass.OTHER,)),
        (4, (ErrorClass.SYNTAX_ERROR,)),
        (3, (ErrorClass.NO_SUCH_FUNCTION, ErrorClass.NO_SUCH_TABLE)),
        (2, (ErrorClass.DATATYPE_MISMATCH, ErrorClass.AGGREGATION_ERROR)),
        (1, (ErrorClass.NO_SUCH_COLUMN, ErrorClass.AMBIGUOUS_COLUMN)),
    )
    for error_class in tier
}
def error_severity(error_class: ErrorClass) -> int:
    """Return the severity rank stored in ``_SEVERITY`` for *error_class*.

    Lower ranks denote less severe errors.

    Raises:
        KeyError: if *error_class* has no entry in ``_SEVERITY``.
    """
    rank = _SEVERITY[error_class]
    return rank
def classify_error(error_message: str) -> ErrorClass:
    """
    Map a raw SQLite error message onto one of the canonical ErrorClass values.

    Matching is case-insensitive (the message is lowercased first) and the
    rules are tried in a fixed order, specific phrases before broad ones.
    Messages that match no rule are reported as ``ErrorClass.OTHER``.
    """
    lowered = error_message.lower()

    def mentions(*phrases: str) -> bool:
        # True when at least one phrase occurs somewhere in the message.
        return any(phrase in lowered for phrase in phrases)

    # Single-phrase rules, in priority order: column-level, then table-level,
    # then function errors.
    phrase_rules = (
        ("no such column", ErrorClass.NO_SUCH_COLUMN),
        ("ambiguous column", ErrorClass.AMBIGUOUS_COLUMN),
        ("no such table", ErrorClass.NO_SUCH_TABLE),
        ("no such function", ErrorClass.NO_SUCH_FUNCTION),
    )
    for phrase, verdict in phrase_rules:
        if phrase in lowered:
            return verdict

    # Aggregation / GROUP BY
    group_by_complaint = "group by" in lowered and "must appear" in lowered
    if group_by_complaint or mentions(
        "not an aggregate",
        "misuse of aggregate",
        "must be an aggregate",
    ):
        return ErrorClass.AGGREGATION_ERROR

    # Syntax errors are broad, so they are only considered once the more
    # specific phrases above have been ruled out.
    if mentions("syntax error") or re.search(r'near\s+"', lowered) is not None:
        return ErrorClass.SYNTAX_ERROR

    # Type errors
    if mentions("datatype mismatch", "type mismatch"):
        return ErrorClass.DATATYPE_MISMATCH

    return ErrorClass.OTHER
def extract_offending_token(error_message: str) -> Optional[str]:
"""
Extract the offending token from a SQLite error message.
Returns None if no specific token can be identified.
"""
# "no such column: X"
m = re.search(r"no such column:\s*(\S+)", error_message, re.IGNORECASE)
if m:
return m.group(1)
# "no such table: X"
m = re.search(r"no such table:\s*(\S+)", error_message, re.IGNORECASE)
if m:
return m.group(1)
# 'near "X": syntax error'
m = re.search(r'near\s+"([^"]+)"', error_message, re.IGNORECASE)
if m:
return m.group(1)
# "no such function: X"
m = re.search(r"no such function:\s*(\S+)", error_message, re.IGNORECASE)
if m:
return m.group(1)
return None