"""Tests for the GraphQueryEngine — Cypher validation and safety guardrails.""" import pytest from unittest.mock import AsyncMock, MagicMock, patch from reachy_mini_conversation_app.graph_query_engine import ( validate_cypher, GraphQueryEngine, ) class TestCypherValidation: """Test the Cypher safety validation function.""" def test_rejects_create(self): assert validate_cypher("CREATE (n:Person {name: 'test'})") is False def test_rejects_merge(self): assert validate_cypher("MERGE (n:Person {name: 'test'})") is False def test_rejects_delete(self): assert validate_cypher("MATCH (n) DELETE n") is False def test_rejects_detach_delete(self): assert validate_cypher("MATCH (n) DETACH DELETE n") is False def test_rejects_set(self): assert validate_cypher("MATCH (n) SET n.name = 'test'") is False def test_rejects_remove(self): assert validate_cypher("MATCH (n) REMOVE n.name") is False def test_rejects_drop(self): assert validate_cypher("DROP INDEX ON :Person(name)") is False def test_rejects_call(self): assert validate_cypher("CALL db.labels()") is False def test_rejects_load_csv(self): assert validate_cypher("LOAD CSV FROM 'file:///data.csv' AS line") is False def test_rejects_foreach(self): assert ( validate_cypher("MATCH (n) FOREACH (x IN [1,2] | SET n.val = x)") is False ) def test_rejects_empty_string(self): assert validate_cypher("") is False def test_rejects_whitespace_only(self): assert validate_cypher(" ") is False def test_allows_simple_match(self): assert validate_cypher("MATCH (n) RETURN n") is True def test_allows_match_with_where(self): assert ( validate_cypher( "MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication) " "WHERE m.dose IS NOT NULL RETURN m.name, m.dose" ) is True ) def test_allows_optional_match(self): assert ( validate_cypher( "OPTIONAL MATCH (p:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e)" ) is True ) def test_allows_with_clause(self): assert ( validate_cypher( "MATCH (p:Person {name: $patient_name}) " "WITH p " "MATCH (p)-[:TAKES]->(m:Medication) " "RETURN m.name" ) is True ) def test_allows_aggregation(self): assert ( validate_cypher( "MATCH (p:Person)-[:EXPERIENCED]->(e:Event {type: 'headache'}) " "RETURN count(e) AS headache_count, avg(e.severity) AS avg_severity" ) is True ) def test_allows_order_by_limit(self): assert ( validate_cypher( "MATCH (p:Person)-[:EXPERIENCED]->(e:Event) " "RETURN e.type, e.timestamp " "ORDER BY e.timestamp DESC LIMIT 10" ) is True ) def test_allows_unwind(self): assert ( validate_cypher( "UNWIND $names AS name " "MATCH (p:Person {name: name}) RETURN p" ) is True ) def test_rejects_create_case_insensitive(self): """Mutations should be caught regardless of case.""" assert validate_cypher("match (n) Create (m:Test)") is False assert validate_cypher("MATCH (n) create (m:Test)") is False assert validate_cypher("match (n) CREATE (m:Test)") is False def test_rejects_query_not_starting_with_match(self): """Queries must start with MATCH, WITH, RETURN, or UNWIND.""" assert validate_cypher("RETURN 1") is True assert validate_cypher("EXPLAIN MATCH (n) RETURN n") is False class TestGraphQueryEngineSchemaDescription: """Test schema description from GraphMemory.""" def test_fallback_schema_when_not_connected(self): """Schema description should return a fallback when not connected.""" mock_graph = MagicMock() mock_graph.is_connected = False mock_graph.get_schema_description.return_value = ( "## Neo4j Graph Schema\n\n### Node Labels\n (:Person)" ) engine = GraphQueryEngine(mock_graph) schema = engine._get_schema() assert "Neo4j Graph Schema" in schema def test_schema_caching(self): """Schema should only be fetched once.""" mock_graph = MagicMock() mock_graph.get_schema_description.return_value = "cached schema" engine = GraphQueryEngine(mock_graph) _ = engine._get_schema() _ = engine._get_schema() # Should only be called once mock_graph.get_schema_description.assert_called_once() def test_cache_invalidation(self): """After invalidation, schema should be re-fetched.""" mock_graph = MagicMock() mock_graph.get_schema_description.return_value = "schema v1" engine = GraphQueryEngine(mock_graph) s1 = engine._get_schema() assert s1 == "schema v1" mock_graph.get_schema_description.return_value = "schema v2" engine.invalidate_schema_cache() s2 = engine._get_schema() assert s2 == "schema v2" @pytest.mark.asyncio class TestGraphQueryEngineExecution: """Test query execution path (mocked).""" async def test_execute_uses_read_session(self): """Ensure execute() calls execute_read, not _execute.""" mock_graph = MagicMock() mock_graph.is_connected = True mock_graph.execute_read.return_value = [{"count": 3}] engine = GraphQueryEngine(mock_graph) results = await engine.execute( "MATCH (n:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e) AS count", patient_name="Elena", ) mock_graph.execute_read.assert_called_once() assert results == [{"count": 3}] async def test_execute_returns_empty_when_disconnected(self): """Ensure execute() returns empty list when graph is not connected.""" mock_graph = MagicMock() mock_graph.is_connected = False engine = GraphQueryEngine(mock_graph) results = await engine.execute("MATCH (n) RETURN n") assert results == [] mock_graph.execute_read.assert_not_called() async def test_query_handles_generation_error_gracefully(self): """Full query() should return a friendly error message on failure.""" mock_graph = MagicMock() mock_graph.is_connected = True mock_graph.get_schema_description.return_value = "test schema" engine = GraphQueryEngine(mock_graph) # Mock the OpenAI client to raise an error with patch.object(engine, "_get_client") as mock_client: mock_client.return_value.chat.completions.create = AsyncMock( side_effect=Exception("API error") ) result = await engine.query("How many headaches?", "Elena") assert "error" in result assert "answer" in result assert result["result_count"] == 0