File size: 5,230 Bytes
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
tests/unit/test_error_classifier.py
Tests for the error classifier node and LangGraph routing conditions.
"""

import json
import pytest


# ── Error Classifier ──────────────────────────────────────────────────────────

@pytest.mark.unit
class TestErrorClassifier:
    def test_classifies_nonexistent_column(self, sample_state, mocker):
        mocker.patch(
            "agent.nodes.error_classifier.get_groq_client"
        ).return_value.complete_system.return_value = json.dumps(
            {"error_class": "nonexistent_column", "hint": "use exact column name"}
        )
        from agent.nodes.error_classifier import error_classifier
        state = {**sample_state, "execution_error": 'column "revenue" does not exist'}
        result = error_classifier(state)
        assert result["error_class"] == "nonexistent_column"

    def test_classifies_syntax_error(self, sample_state, mocker):
        mocker.patch(
            "agent.nodes.error_classifier.get_groq_client"
        ).return_value.complete_system.return_value = json.dumps(
            {"error_class": "syntax", "hint": "fix syntax near WHERE"}
        )
        from agent.nodes.error_classifier import error_classifier
        state = {**sample_state, "execution_error": "syntax error at or near WHERE"}
        result = error_classifier(state)
        assert result["error_class"] == "syntax"

    def test_defaults_to_unknown_on_bad_json(self, sample_state, mocker):
        mocker.patch(
            "agent.nodes.error_classifier.get_groq_client"
        ).return_value.complete_system.return_value = "not json"
        from agent.nodes.error_classifier import error_classifier
        state = {**sample_state, "execution_error": "some error"}
        result = error_classifier(state)
        assert result["error_class"] == "unknown"

    def test_no_error_returns_state_unchanged(self, sample_state):
        from agent.nodes.error_classifier import error_classifier
        state = {**sample_state, "execution_error": None}
        result = error_classifier(state)
        # Should not have changed error_class
        assert result.get("error_class") == sample_state.get("error_class")

    def test_classifies_type_mismatch(self, sample_state, mocker):
        mocker.patch(
            "agent.nodes.error_classifier.get_groq_client"
        ).return_value.complete_system.return_value = json.dumps(
            {"error_class": "type_mismatch", "hint": "cast the column"}
        )
        from agent.nodes.error_classifier import error_classifier
        state = {**sample_state, "execution_error": "cannot compare integer with text"}
        result = error_classifier(state)
        assert result["error_class"] == "type_mismatch"


# ── Graph routing conditions ──────────────────────────────────────────────────

@pytest.mark.unit
class TestGraphRouting:
    def test_route_intent_sql(self, sample_state):
        from agent.graph import route_intent
        assert route_intent({**sample_state, "intent": "sql"}) == "sql"

    def test_route_intent_pandas(self, sample_state):
        from agent.graph import route_intent
        assert route_intent({**sample_state, "intent": "pandas"}) == "pandas"

    def test_route_intent_unsupported(self, sample_state):
        from agent.graph import route_intent
        from langgraph.graph import END
        assert route_intent({**sample_state, "intent": "unsupported"}) == "unsupported"

    def test_route_intent_insight_only(self, sample_state):
        from agent.graph import route_intent
        assert route_intent({**sample_state, "intent": "insight"}) == "insight_only"

    def test_route_after_execution_success(self, sample_state):
        from agent.graph import route_after_execution
        state = {**sample_state, "execution_error": None, "execution_result": [{"x": 1}]}
        assert route_after_execution(state) == "success"

    def test_route_after_execution_corrects_on_error(self, sample_state):
        from agent.graph import route_after_execution
        state = {**sample_state, "execution_error": "column missing", "correction_attempts": 0, "max_corrections": 3}
        assert route_after_execution(state) == "correct"

    def test_route_after_execution_gives_up_at_max(self, sample_state):
        from agent.graph import route_after_execution
        state = {**sample_state, "execution_error": "still broken", "correction_attempts": 3, "max_corrections": 3}
        assert route_after_execution(state) == "give_up"

    def test_route_after_validation_passes_safe_code(self, sample_state):
        from agent.graph import route_after_validation
        state = {**sample_state, "execution_error": None}
        assert route_after_validation(state) == "execute"

    def test_route_after_validation_blocks_unsafe_code(self, sample_state):
        from agent.graph import route_after_validation
        state = {**sample_state, "execution_error": "SAFETY_BLOCK: Drop operation"}
        assert route_after_validation(state) == "blocked"