| """ |
| tests/unit/test_error_classifier.py |
| Tests for the error classifier node and LangGraph routing conditions. |
| """ |
|
|
| import json |
| import pytest |
|
|
|
|
| |
|
|
def _patch_llm(mocker, response=""):
    """Replace the Groq client used by ``agent.nodes.error_classifier``.

    Args:
        mocker: The pytest-mock ``mocker`` fixture.
        response: Raw text that the mocked ``complete_system`` call returns.

    Returns:
        The mock standing in for the object ``get_groq_client()`` returns, so
        callers can assert on how ``complete_system`` was used.
    """
    client = mocker.patch("agent.nodes.error_classifier.get_groq_client").return_value
    client.complete_system.return_value = response
    return client


def _classify(sample_state, execution_error):
    """Run the ``error_classifier`` node on ``sample_state`` with ``execution_error`` set.

    Returns whatever the node returns for that state.
    """
    # Imported inside the helper (as the original tests did) so an import
    # failure is reported by the individual test rather than at collection.
    from agent.nodes.error_classifier import error_classifier

    return error_classifier({**sample_state, "execution_error": execution_error})


@pytest.mark.unit
class TestErrorClassifier:
    """Unit tests for the ``error_classifier`` LangGraph node.

    Every test patches the LLM client, so none of them reach the network.
    """

    def test_classifies_nonexistent_column(self, sample_state, mocker):
        payload = {"error_class": "nonexistent_column", "hint": "use exact column name"}
        _patch_llm(mocker, json.dumps(payload))
        result = _classify(sample_state, 'column "revenue" does not exist')
        assert result["error_class"] == "nonexistent_column"

    def test_classifies_syntax_error(self, sample_state, mocker):
        payload = {"error_class": "syntax", "hint": "fix syntax near WHERE"}
        _patch_llm(mocker, json.dumps(payload))
        result = _classify(sample_state, "syntax error at or near WHERE")
        assert result["error_class"] == "syntax"

    def test_defaults_to_unknown_on_bad_json(self, sample_state, mocker):
        # The LLM reply is not valid JSON; the node must fall back to "unknown".
        _patch_llm(mocker, "not json")
        result = _classify(sample_state, "some error")
        assert result["error_class"] == "unknown"

    def test_no_error_returns_state_unchanged(self, sample_state, mocker):
        # Patch the client so this test can never make a real LLM call; with
        # no execution error there is nothing to classify.
        client = _patch_llm(mocker)
        result = _classify(sample_state, None)
        assert result.get("error_class") == sample_state.get("error_class")
        client.complete_system.assert_not_called()

    def test_classifies_type_mismatch(self, sample_state, mocker):
        payload = {"error_class": "type_mismatch", "hint": "cast the column"}
        _patch_llm(mocker, json.dumps(payload))
        result = _classify(sample_state, "cannot compare integer with text")
        assert result["error_class"] == "type_mismatch"
|
|
|
|
| |
|
|
@pytest.mark.unit
class TestGraphRouting:
    """Unit tests for the LangGraph routing functions in ``agent.graph``."""

    def test_route_intent_sql(self, sample_state):
        from agent.graph import route_intent

        assert route_intent({**sample_state, "intent": "sql"}) == "sql"

    def test_route_intent_pandas(self, sample_state):
        from agent.graph import route_intent

        assert route_intent({**sample_state, "intent": "pandas"}) == "pandas"

    def test_route_intent_unsupported(self, sample_state):
        from agent.graph import route_intent

        assert route_intent({**sample_state, "intent": "unsupported"}) == "unsupported"

    def test_route_intent_insight_only(self, sample_state):
        from agent.graph import route_intent

        # The "insight" intent is routed under the "insight_only" edge name.
        assert route_intent({**sample_state, "intent": "insight"}) == "insight_only"

    def test_route_after_execution_success(self, sample_state):
        from agent.graph import route_after_execution

        state = {**sample_state, "execution_error": None, "execution_result": [{"x": 1}]}
        assert route_after_execution(state) == "success"

    def test_route_after_execution_corrects_on_error(self, sample_state):
        from agent.graph import route_after_execution

        # Error present and correction budget not yet used.
        state = {
            **sample_state,
            "execution_error": "column missing",
            "correction_attempts": 0,
            "max_corrections": 3,
        }
        assert route_after_execution(state) == "correct"

    def test_route_after_execution_gives_up_at_max(self, sample_state):
        from agent.graph import route_after_execution

        # Error present and correction_attempts has reached max_corrections.
        state = {
            **sample_state,
            "execution_error": "still broken",
            "correction_attempts": 3,
            "max_corrections": 3,
        }
        assert route_after_execution(state) == "give_up"

    def test_route_after_validation_passes_safe_code(self, sample_state):
        from agent.graph import route_after_validation

        state = {**sample_state, "execution_error": None}
        assert route_after_validation(state) == "execute"

    def test_route_after_validation_blocks_unsafe_code(self, sample_state):
        from agent.graph import route_after_validation

        state = {**sample_state, "execution_error": "SAFETY_BLOCK: Drop operation"}
        assert route_after_validation(state) == "blocked"
|
|