Spaces:
Running
Running
| """Tests for the GraphQueryEngine — Cypher validation and safety guardrails.""" | |
| import pytest | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| from reachy_mini_conversation_app.graph_query_engine import ( | |
| validate_cypher, | |
| GraphQueryEngine, | |
| ) | |
class TestCypherValidation:
    """Exercise ``validate_cypher`` against queries it must refuse or accept."""

    # ------------------------------------------------------------------
    # Queries the validator is expected to refuse.
    # ------------------------------------------------------------------

    def test_rejects_create(self):
        query = "CREATE (n:Person {name: 'test'})"
        assert validate_cypher(query) is False

    def test_rejects_merge(self):
        query = "MERGE (n:Person {name: 'test'})"
        assert validate_cypher(query) is False

    def test_rejects_delete(self):
        query = "MATCH (n) DELETE n"
        assert validate_cypher(query) is False

    def test_rejects_detach_delete(self):
        query = "MATCH (n) DETACH DELETE n"
        assert validate_cypher(query) is False

    def test_rejects_set(self):
        query = "MATCH (n) SET n.name = 'test'"
        assert validate_cypher(query) is False

    def test_rejects_remove(self):
        query = "MATCH (n) REMOVE n.name"
        assert validate_cypher(query) is False

    def test_rejects_drop(self):
        query = "DROP INDEX ON :Person(name)"
        assert validate_cypher(query) is False

    def test_rejects_call(self):
        query = "CALL db.labels()"
        assert validate_cypher(query) is False

    def test_rejects_load_csv(self):
        query = "LOAD CSV FROM 'file:///data.csv' AS line"
        assert validate_cypher(query) is False

    def test_rejects_foreach(self):
        query = "MATCH (n) FOREACH (x IN [1,2] | SET n.val = x)"
        assert validate_cypher(query) is False

    def test_rejects_empty_string(self):
        query = ""
        assert validate_cypher(query) is False

    def test_rejects_whitespace_only(self):
        query = " "
        assert validate_cypher(query) is False

    # ------------------------------------------------------------------
    # Queries the validator is expected to accept.
    # ------------------------------------------------------------------

    def test_allows_simple_match(self):
        query = "MATCH (n) RETURN n"
        assert validate_cypher(query) is True

    def test_allows_match_with_where(self):
        query = (
            "MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication) "
            "WHERE m.dose IS NOT NULL RETURN m.name, m.dose"
        )
        assert validate_cypher(query) is True

    def test_allows_optional_match(self):
        query = "OPTIONAL MATCH (p:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e)"
        assert validate_cypher(query) is True

    def test_allows_with_clause(self):
        query = (
            "MATCH (p:Person {name: $patient_name}) "
            "WITH p "
            "MATCH (p)-[:TAKES]->(m:Medication) "
            "RETURN m.name"
        )
        assert validate_cypher(query) is True

    def test_allows_aggregation(self):
        query = (
            "MATCH (p:Person)-[:EXPERIENCED]->(e:Event {type: 'headache'}) "
            "RETURN count(e) AS headache_count, avg(e.severity) AS avg_severity"
        )
        assert validate_cypher(query) is True

    def test_allows_order_by_limit(self):
        query = (
            "MATCH (p:Person)-[:EXPERIENCED]->(e:Event) "
            "RETURN e.type, e.timestamp "
            "ORDER BY e.timestamp DESC LIMIT 10"
        )
        assert validate_cypher(query) is True

    def test_allows_unwind(self):
        query = "UNWIND $names AS name MATCH (p:Person {name: name}) RETURN p"
        assert validate_cypher(query) is True

    # ------------------------------------------------------------------
    # Keyword casing and leading-keyword checks.
    # ------------------------------------------------------------------

    def test_rejects_create_case_insensitive(self):
        """CREATE is refused however the keywords are capitalised."""
        casing_variants = (
            "match (n) Create (m:Test)",
            "MATCH (n) create (m:Test)",
            "match (n) CREATE (m:Test)",
        )
        for query in casing_variants:
            assert validate_cypher(query) is False

    def test_rejects_query_not_starting_with_match(self):
        """A bare RETURN is accepted; an EXPLAIN-prefixed query is refused."""
        accepted = "RETURN 1"
        refused = "EXPLAIN MATCH (n) RETURN n"
        assert validate_cypher(accepted) is True
        assert validate_cypher(refused) is False
class TestGraphQueryEngineSchemaDescription:
    """Check how ``GraphQueryEngine._get_schema`` obtains and caches schema text."""

    def test_fallback_schema_when_not_connected(self):
        """Schema text is still produced when the graph reports not connected."""
        graph = MagicMock(is_connected=False)
        graph.get_schema_description.return_value = (
            "## Neo4j Graph Schema\n\n### Node Labels\n (:Person)"
        )

        description = GraphQueryEngine(graph)._get_schema()

        assert "Neo4j Graph Schema" in description

    def test_schema_caching(self):
        """Two consecutive lookups hit ``get_schema_description`` only once."""
        graph = MagicMock()
        graph.get_schema_description.return_value = "cached schema"
        engine = GraphQueryEngine(graph)

        for _ in range(2):
            engine._get_schema()

        # The second lookup must be served without asking the graph again.
        assert graph.get_schema_description.call_count == 1

    def test_cache_invalidation(self):
        """``invalidate_schema_cache`` makes the next lookup return fresh text."""
        graph = MagicMock()
        describe = graph.get_schema_description
        describe.return_value = "schema v1"
        engine = GraphQueryEngine(graph)

        assert engine._get_schema() == "schema v1"

        # Change what the graph reports, then drop the engine's cached copy.
        describe.return_value = "schema v2"
        engine.invalidate_schema_cache()

        assert engine._get_schema() == "schema v2"
class TestGraphQueryEngineExecution:
    """Test the query execution path of ``GraphQueryEngine`` against a mocked graph.

    NOTE(review): every test here is a coroutine function, so it only runs
    under an asyncio-aware pytest plugin (for example pytest-asyncio in
    ``auto`` mode). That configuration is not visible from this file --
    confirm it, or mark the tests explicitly.
    """

    async def test_execute_uses_read_session(self):
        """Ensure execute() calls execute_read, not _execute."""
        mock_graph = MagicMock()
        mock_graph.is_connected = True
        mock_graph.execute_read.return_value = [{"count": 3}]
        engine = GraphQueryEngine(mock_graph)

        results = await engine.execute(
            "MATCH (n:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e) AS count",
            patient_name="Elena",
        )

        mock_graph.execute_read.assert_called_once()
        # Verify the second half of the contract as well: without this check
        # the test would still pass if execute() also went through _execute.
        mock_graph._execute.assert_not_called()
        assert results == [{"count": 3}]

    async def test_execute_returns_empty_when_disconnected(self):
        """Ensure execute() returns empty list when graph is not connected."""
        mock_graph = MagicMock()
        mock_graph.is_connected = False
        engine = GraphQueryEngine(mock_graph)

        results = await engine.execute("MATCH (n) RETURN n")

        assert results == []
        # A disconnected graph must not be queried at all.
        mock_graph.execute_read.assert_not_called()

    async def test_query_handles_generation_error_gracefully(self):
        """Full query() should return an error-bearing result dict on failure."""
        mock_graph = MagicMock()
        mock_graph.is_connected = True
        mock_graph.get_schema_description.return_value = "test schema"
        engine = GraphQueryEngine(mock_graph)

        # Make the client's chat-completion call raise, simulating an API outage.
        with patch.object(engine, "_get_client") as mock_client:
            mock_client.return_value.chat.completions.create = AsyncMock(
                side_effect=Exception("API error")
            )
            result = await engine.query("How many headaches?", "Elena")

        # The failure is reported in the result, not raised to the caller.
        assert "error" in result
        assert "answer" in result
        assert result["result_count"] == 0