Spaces:
Running
Running
File size: 7,236 Bytes
2880ca9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
"""Tests for the GraphQueryEngine — Cypher validation and safety guardrails."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from reachy_mini_conversation_app.graph_query_engine import (
validate_cypher,
GraphQueryEngine,
)
class TestCypherValidation:
    """Test the Cypher safety validation function.

    Mutating, administrative, procedure-call, and file-loading clauses must
    be rejected (regardless of keyword case), as must empty input. Read-only
    queries built from MATCH / OPTIONAL MATCH / WITH / UNWIND / RETURN must
    be accepted.
    """

    # --- Rejected: mutating / administrative / procedure / file access ---

    def test_rejects_create(self):
        assert validate_cypher("CREATE (n:Person {name: 'test'})") is False

    def test_rejects_merge(self):
        assert validate_cypher("MERGE (n:Person {name: 'test'})") is False

    def test_rejects_delete(self):
        assert validate_cypher("MATCH (n) DELETE n") is False

    def test_rejects_detach_delete(self):
        assert validate_cypher("MATCH (n) DETACH DELETE n") is False

    def test_rejects_set(self):
        assert validate_cypher("MATCH (n) SET n.name = 'test'") is False

    def test_rejects_remove(self):
        assert validate_cypher("MATCH (n) REMOVE n.name") is False

    def test_rejects_drop(self):
        assert validate_cypher("DROP INDEX ON :Person(name)") is False

    def test_rejects_call(self):
        assert validate_cypher("CALL db.labels()") is False

    def test_rejects_load_csv(self):
        assert validate_cypher("LOAD CSV FROM 'file:///data.csv' AS line") is False

    def test_rejects_foreach(self):
        assert (
            validate_cypher("MATCH (n) FOREACH (x IN [1,2] | SET n.val = x)") is False
        )

    # --- Rejected: degenerate input ---

    def test_rejects_empty_string(self):
        assert validate_cypher("") is False

    def test_rejects_whitespace_only(self):
        assert validate_cypher(" ") is False

    # --- Accepted: read-only queries ---

    def test_allows_simple_match(self):
        assert validate_cypher("MATCH (n) RETURN n") is True

    def test_allows_match_with_where(self):
        assert (
            validate_cypher(
                "MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication) "
                "WHERE m.dose IS NOT NULL RETURN m.name, m.dose"
            )
            is True
        )

    def test_allows_optional_match(self):
        assert (
            validate_cypher(
                "OPTIONAL MATCH (p:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e)"
            )
            is True
        )

    def test_allows_with_clause(self):
        assert (
            validate_cypher(
                "MATCH (p:Person {name: $patient_name}) "
                "WITH p "
                "MATCH (p)-[:TAKES]->(m:Medication) "
                "RETURN m.name"
            )
            is True
        )

    def test_allows_aggregation(self):
        assert (
            validate_cypher(
                "MATCH (p:Person)-[:EXPERIENCED]->(e:Event {type: 'headache'}) "
                "RETURN count(e) AS headache_count, avg(e.severity) AS avg_severity"
            )
            is True
        )

    def test_allows_order_by_limit(self):
        assert (
            validate_cypher(
                "MATCH (p:Person)-[:EXPERIENCED]->(e:Event) "
                "RETURN e.type, e.timestamp "
                "ORDER BY e.timestamp DESC LIMIT 10"
            )
            is True
        )

    def test_allows_unwind(self):
        assert (
            validate_cypher(
                "UNWIND $names AS name MATCH (p:Person {name: name}) RETURN p"
            )
            is True
        )

    def test_allows_bare_return(self):
        """RETURN is an allowed leading keyword, so a bare RETURN is accepted."""
        assert validate_cypher("RETURN 1") is True

    # --- Rejected: case-insensitivity and leading-keyword rules ---

    def test_rejects_create_case_insensitive(self):
        """Mutations should be caught regardless of case."""
        assert validate_cypher("match (n) Create (m:Test)") is False
        assert validate_cypher("MATCH (n) create (m:Test)") is False
        assert validate_cypher("match (n) CREATE (m:Test)") is False

    def test_rejects_query_not_starting_with_match(self):
        """Queries must start with MATCH, OPTIONAL MATCH, WITH, RETURN, or UNWIND.

        An otherwise read-only query behind a disallowed leading keyword
        (here ``EXPLAIN``) must still be rejected.
        """
        assert validate_cypher("EXPLAIN MATCH (n) RETURN n") is False
class TestGraphQueryEngineSchemaDescription:
    """Test schema retrieval and caching via GraphMemory.get_schema_description."""

    def test_fallback_schema_when_not_connected(self):
        """Schema description should return a fallback when not connected."""
        mock_graph = MagicMock()
        mock_graph.is_connected = False
        # NOTE(review): the schema text here is supplied by the mock, so this
        # only checks that _get_schema() surfaces get_schema_description()'s
        # output while disconnected, not where the fallback text originates.
        mock_graph.get_schema_description.return_value = (
            "## Neo4j Graph Schema\n\n### Node Labels\n (:Person)"
        )
        engine = GraphQueryEngine(mock_graph)
        schema = engine._get_schema()
        assert "Neo4j Graph Schema" in schema

    def test_schema_caching(self):
        """Schema should only be fetched once and then served from the cache."""
        mock_graph = MagicMock()
        mock_graph.get_schema_description.return_value = "cached schema"
        engine = GraphQueryEngine(mock_graph)
        first = engine._get_schema()
        second = engine._get_schema()
        # Both calls must yield the fetched text...
        assert first == second == "cached schema"
        # ...but only the first may reach the graph backend.
        mock_graph.get_schema_description.assert_called_once()

    def test_cache_invalidation(self):
        """After invalidation, schema should be re-fetched."""
        mock_graph = MagicMock()
        mock_graph.get_schema_description.return_value = "schema v1"
        engine = GraphQueryEngine(mock_graph)
        s1 = engine._get_schema()
        assert s1 == "schema v1"

        # Change what the backend reports, then drop the cached copy.
        mock_graph.get_schema_description.return_value = "schema v2"
        engine.invalidate_schema_cache()
        s2 = engine._get_schema()
        assert s2 == "schema v2"
        # One fetch before invalidation plus exactly one re-fetch after it.
        assert mock_graph.get_schema_description.call_count == 2
@pytest.mark.asyncio
class TestGraphQueryEngineExecution:
    """Test the query execution path against a mocked graph backend."""

    async def test_execute_uses_read_session(self):
        """Ensure execute() calls execute_read, not _execute."""
        mock_graph = MagicMock()
        mock_graph.is_connected = True
        mock_graph.execute_read.return_value = [{"count": 3}]
        engine = GraphQueryEngine(mock_graph)
        results = await engine.execute(
            "MATCH (n:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e) AS count",
            patient_name="Elena",
        )
        mock_graph.execute_read.assert_called_once()
        # Check the negative half of the contract too: the generic
        # _execute path must not be used for these queries.
        mock_graph._execute.assert_not_called()
        assert results == [{"count": 3}]

    async def test_execute_returns_empty_when_disconnected(self):
        """Ensure execute() returns empty list when graph is not connected."""
        mock_graph = MagicMock()
        mock_graph.is_connected = False
        engine = GraphQueryEngine(mock_graph)
        results = await engine.execute("MATCH (n) RETURN n")
        assert results == []
        mock_graph.execute_read.assert_not_called()

    async def test_query_handles_generation_error_gracefully(self):
        """Full query() should return a friendly error message on failure."""
        mock_graph = MagicMock()
        mock_graph.is_connected = True
        mock_graph.get_schema_description.return_value = "test schema"
        engine = GraphQueryEngine(mock_graph)
        failing_create = AsyncMock(side_effect=Exception("API error"))
        # Mock the OpenAI client so that the completion call raises.
        with patch.object(engine, "_get_client") as mock_client:
            mock_client.return_value.chat.completions.create = failing_create
            result = await engine.query("How many headaches?", "Elena")
        # Confirm the failure came from the mocked completion call and not
        # from some earlier step that happened to produce a similar result.
        failing_create.assert_awaited()
        assert "error" in result
        assert "answer" in result
        assert result["result_count"] == 0
|