File size: 5,230 Bytes
abd4352 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | """
tests/unit/test_error_classifier.py
Tests for the error classifier node and LangGraph routing conditions.
"""
import json
import pytest
# ── Error Classifier ──────────────────────────────────────────────────────────
@pytest.mark.unit
class TestErrorClassifier:
    """Unit tests for the ``error_classifier`` node.

    Every test patches the Groq client factory so no test can reach the
    network; the classification tests feed a canned JSON (or malformed)
    LLM response and check the ``error_class`` the node returns.
    """

    # Patch target: the client factory as looked up inside the node's module.
    _CLIENT_FACTORY = "agent.nodes.error_classifier.get_groq_client"

    @classmethod
    def _patch_llm(cls, mocker, response):
        """Patch the Groq client so ``complete_system`` returns *response*.

        Args:
            mocker: the pytest-mock ``mocker`` fixture.
            response: the raw string the mocked ``complete_system`` returns.

        Returns:
            The mocked ``complete_system`` method, so callers can assert on
            whether/how it was invoked.
        """
        client = mocker.patch(cls._CLIENT_FACTORY).return_value
        client.complete_system.return_value = response
        return client.complete_system

    def test_classifies_nonexistent_column(self, sample_state, mocker):
        self._patch_llm(
            mocker,
            json.dumps(
                {"error_class": "nonexistent_column", "hint": "use exact column name"}
            ),
        )
        from agent.nodes.error_classifier import error_classifier

        state = {**sample_state, "execution_error": 'column "revenue" does not exist'}
        result = error_classifier(state)
        assert result["error_class"] == "nonexistent_column"

    def test_classifies_syntax_error(self, sample_state, mocker):
        self._patch_llm(
            mocker,
            json.dumps({"error_class": "syntax", "hint": "fix syntax near WHERE"}),
        )
        from agent.nodes.error_classifier import error_classifier

        state = {**sample_state, "execution_error": "syntax error at or near WHERE"}
        result = error_classifier(state)
        assert result["error_class"] == "syntax"

    def test_defaults_to_unknown_on_bad_json(self, sample_state, mocker):
        # A non-JSON LLM reply must not crash the node; it falls back to "unknown".
        self._patch_llm(mocker, "not json")
        from agent.nodes.error_classifier import error_classifier

        state = {**sample_state, "execution_error": "some error"}
        result = error_classifier(state)
        assert result["error_class"] == "unknown"

    def test_no_error_returns_state_unchanged(self, sample_state, mocker):
        # Patch defensively: if a regression makes the node consult the LLM on
        # a clean run, this test fails deterministically instead of making a
        # real network call (or failing on a missing API key).
        complete_system = self._patch_llm(
            mocker, json.dumps({"error_class": "unknown", "hint": ""})
        )
        from agent.nodes.error_classifier import error_classifier

        state = {**sample_state, "execution_error": None}
        result = error_classifier(state)
        # Nothing to classify: the LLM is not consulted and error_class is untouched.
        complete_system.assert_not_called()
        assert result.get("error_class") == sample_state.get("error_class")

    def test_classifies_type_mismatch(self, sample_state, mocker):
        self._patch_llm(
            mocker,
            json.dumps({"error_class": "type_mismatch", "hint": "cast the column"}),
        )
        from agent.nodes.error_classifier import error_classifier

        state = {**sample_state, "execution_error": "cannot compare integer with text"}
        result = error_classifier(state)
        assert result["error_class"] == "type_mismatch"
# ── Graph routing conditions ──────────────────────────────────────────────────
@pytest.mark.unit
class TestGraphRouting:
    """Unit tests for the pure routing functions exported by ``agent.graph``.

    Each router takes a state dict and returns a string route key; the tests
    build states from ``sample_state`` and pin the key for each scenario.
    """

    def test_route_intent_sql(self, sample_state):
        from agent.graph import route_intent

        assert route_intent({**sample_state, "intent": "sql"}) == "sql"

    def test_route_intent_pandas(self, sample_state):
        from agent.graph import route_intent

        assert route_intent({**sample_state, "intent": "pandas"}) == "pandas"

    def test_route_intent_unsupported(self, sample_state):
        from agent.graph import route_intent

        # The router returns the string key; presumably the graph maps it to
        # END in its conditional-edge table — TODO confirm in agent.graph.
        assert route_intent({**sample_state, "intent": "unsupported"}) == "unsupported"

    def test_route_intent_insight_only(self, sample_state):
        from agent.graph import route_intent

        # Note the intent value ("insight") differs from the route key ("insight_only").
        assert route_intent({**sample_state, "intent": "insight"}) == "insight_only"

    def test_route_after_execution_success(self, sample_state):
        from agent.graph import route_after_execution

        state = {**sample_state, "execution_error": None, "execution_result": [{"x": 1}]}
        assert route_after_execution(state) == "success"

    def test_route_after_execution_corrects_on_error(self, sample_state):
        from agent.graph import route_after_execution

        # An error with correction budget remaining routes back to correction.
        state = {
            **sample_state,
            "execution_error": "column missing",
            "correction_attempts": 0,
            "max_corrections": 3,
        }
        assert route_after_execution(state) == "correct"

    def test_route_after_execution_gives_up_at_max(self, sample_state):
        from agent.graph import route_after_execution

        # Once attempts reach max_corrections, the router stops retrying.
        state = {
            **sample_state,
            "execution_error": "still broken",
            "correction_attempts": 3,
            "max_corrections": 3,
        }
        assert route_after_execution(state) == "give_up"

    def test_route_after_validation_passes_safe_code(self, sample_state):
        from agent.graph import route_after_validation

        state = {**sample_state, "execution_error": None}
        assert route_after_validation(state) == "execute"

    def test_route_after_validation_blocks_unsafe_code(self, sample_state):
        from agent.graph import route_after_validation

        # A SAFETY_BLOCK-prefixed error from validation must divert to "blocked".
        state = {**sample_state, "execution_error": "SAFETY_BLOCK: Drop operation"}
        assert route_after_validation(state) == "blocked"
|