reachy_mini_minder / tests /test_graph_query_engine.py
Boopster's picture
feat: Implement pattern detection and integrate graph query engine with session insights and new tools.
2880ca9
"""Tests for the GraphQueryEngine — Cypher validation and safety guardrails."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from reachy_mini_conversation_app.graph_query_engine import (
validate_cypher,
GraphQueryEngine,
)
class TestCypherValidation:
    """Test the Cypher safety validation function.

    ``validate_cypher`` must return ``False`` for queries that write to the
    graph, run procedures, load external data, or start with a clause other
    than an allowed read clause, and ``True`` for read-only queries.
    """

    # --- Rejected: write / admin / external-access clauses ------------------

    def test_rejects_create(self):
        assert validate_cypher("CREATE (n:Person {name: 'test'})") is False

    def test_rejects_merge(self):
        assert validate_cypher("MERGE (n:Person {name: 'test'})") is False

    def test_rejects_delete(self):
        assert validate_cypher("MATCH (n) DELETE n") is False

    def test_rejects_detach_delete(self):
        assert validate_cypher("MATCH (n) DETACH DELETE n") is False

    def test_rejects_set(self):
        assert validate_cypher("MATCH (n) SET n.name = 'test'") is False

    def test_rejects_remove(self):
        assert validate_cypher("MATCH (n) REMOVE n.name") is False

    def test_rejects_drop(self):
        assert validate_cypher("DROP INDEX ON :Person(name)") is False

    def test_rejects_call(self):
        assert validate_cypher("CALL db.labels()") is False

    def test_rejects_load_csv(self):
        assert validate_cypher("LOAD CSV FROM 'file:///data.csv' AS line") is False

    def test_rejects_foreach(self):
        assert (
            validate_cypher("MATCH (n) FOREACH (x IN [1,2] | SET n.val = x)") is False
        )

    def test_rejects_forbidden_clauses_after_read_prefix(self):
        """A valid leading MATCH must not smuggle in a forbidden clause.

        The MERGE and CALL tests above start with the forbidden keyword, so
        they would also pass on the leading-clause rule alone; these cases
        exercise the forbidden-keyword check itself.
        """
        assert validate_cypher("MATCH (n) MERGE (m:Test {name: 'x'})") is False
        assert (
            validate_cypher("MATCH (n) CALL db.labels() YIELD label RETURN label")
            is False
        )

    # --- Rejected: empty input ----------------------------------------------

    def test_rejects_empty_string(self):
        assert validate_cypher("") is False

    def test_rejects_whitespace_only(self):
        assert validate_cypher(" ") is False

    # --- Allowed: read-only queries -----------------------------------------

    def test_allows_simple_match(self):
        assert validate_cypher("MATCH (n) RETURN n") is True

    def test_allows_match_with_where(self):
        assert (
            validate_cypher(
                "MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication) "
                "WHERE m.dose IS NOT NULL RETURN m.name, m.dose"
            )
            is True
        )

    def test_allows_optional_match(self):
        assert (
            validate_cypher(
                "OPTIONAL MATCH (p:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e)"
            )
            is True
        )

    def test_allows_with_clause(self):
        assert (
            validate_cypher(
                "MATCH (p:Person {name: $patient_name}) "
                "WITH p "
                "MATCH (p)-[:TAKES]->(m:Medication) "
                "RETURN m.name"
            )
            is True
        )

    def test_allows_aggregation(self):
        assert (
            validate_cypher(
                "MATCH (p:Person)-[:EXPERIENCED]->(e:Event {type: 'headache'}) "
                "RETURN count(e) AS headache_count, avg(e.severity) AS avg_severity"
            )
            is True
        )

    def test_allows_order_by_limit(self):
        assert (
            validate_cypher(
                "MATCH (p:Person)-[:EXPERIENCED]->(e:Event) "
                "RETURN e.type, e.timestamp "
                "ORDER BY e.timestamp DESC LIMIT 10"
            )
            is True
        )

    def test_allows_unwind(self):
        assert (
            validate_cypher(
                "UNWIND $names AS name MATCH (p:Person {name: name}) RETURN p"
            )
            is True
        )

    # --- Case handling and leading-clause rule ------------------------------

    def test_rejects_create_case_insensitive(self):
        """Mutations should be caught regardless of case."""
        assert validate_cypher("match (n) Create (m:Test)") is False
        assert validate_cypher("MATCH (n) create (m:Test)") is False
        assert validate_cypher("match (n) CREATE (m:Test)") is False

    def test_rejects_query_not_starting_with_match(self):
        """Queries must begin with an allowed read clause.

        Allowed leading clauses are MATCH, OPTIONAL MATCH, WITH, RETURN and
        UNWIND; anything else (here ``EXPLAIN``) is rejected.
        """
        assert validate_cypher("EXPLAIN MATCH (n) RETURN n") is False

    def test_allows_return_only_query(self):
        """RETURN is an allowed leading clause, so a bare RETURN query passes."""
        assert validate_cypher("RETURN 1") is True
class TestGraphQueryEngineSchemaDescription:
    """Test how GraphQueryEngine fetches and caches the GraphMemory schema."""

    def test_fallback_schema_when_not_connected(self):
        """The schema description is still returned when the graph is disconnected.

        The engine must pass through whatever ``get_schema_description``
        yields even though ``is_connected`` is False.
        """
        # NOTE(review): presumably GraphMemory.get_schema_description supplies
        # its own fallback text when disconnected; the mock stands in for it.
        mock_graph = MagicMock()
        mock_graph.is_connected = False
        mock_graph.get_schema_description.return_value = (
            "## Neo4j Graph Schema\n\n### Node Labels\n (:Person)"
        )
        engine = GraphQueryEngine(mock_graph)
        schema = engine._get_schema()
        assert "Neo4j Graph Schema" in schema

    def test_schema_caching(self):
        """Schema should only be fetched once."""
        mock_graph = MagicMock()
        mock_graph.get_schema_description.return_value = "cached schema"
        engine = GraphQueryEngine(mock_graph)
        _ = engine._get_schema()
        _ = engine._get_schema()
        # Should only be called once
        mock_graph.get_schema_description.assert_called_once()

    def test_cache_invalidation(self):
        """After invalidation, schema should be re-fetched."""
        mock_graph = MagicMock()
        mock_graph.get_schema_description.return_value = "schema v1"
        engine = GraphQueryEngine(mock_graph)
        s1 = engine._get_schema()
        assert s1 == "schema v1"

        mock_graph.get_schema_description.return_value = "schema v2"
        # Before invalidation the cached value must still be served, which
        # proves the refresh below is caused by invalidate_schema_cache().
        assert engine._get_schema() == "schema v1"

        engine.invalidate_schema_cache()
        s2 = engine._get_schema()
        assert s2 == "schema v2"
        # One fetch for the initial load, one after invalidation.
        assert mock_graph.get_schema_description.call_count == 2
@pytest.mark.asyncio
class TestGraphQueryEngineExecution:
    """Test query execution path (mocked)."""

    async def test_execute_uses_read_session(self):
        """Ensure execute() calls execute_read, not _execute."""
        mock_graph = MagicMock()
        mock_graph.is_connected = True
        mock_graph.execute_read.return_value = [{"count": 3}]
        engine = GraphQueryEngine(mock_graph)

        results = await engine.execute(
            "MATCH (n:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e) AS count",
            patient_name="Elena",
        )

        mock_graph.execute_read.assert_called_once()
        # MagicMock auto-creates `_execute` on access, so this actually
        # enforces the "not _execute" half of the docstring.
        mock_graph._execute.assert_not_called()
        assert results == [{"count": 3}]

    async def test_execute_returns_empty_when_disconnected(self):
        """Ensure execute() returns empty list when graph is not connected."""
        mock_graph = MagicMock()
        mock_graph.is_connected = False
        engine = GraphQueryEngine(mock_graph)

        results = await engine.execute("MATCH (n) RETURN n")

        assert results == []
        mock_graph.execute_read.assert_not_called()

    async def test_query_handles_generation_error_gracefully(self):
        """Full query() should return a friendly error message on failure."""
        mock_graph = MagicMock()
        mock_graph.is_connected = True
        mock_graph.get_schema_description.return_value = "test schema"
        engine = GraphQueryEngine(mock_graph)

        # Make the (mocked) OpenAI chat-completion call raise, so query()
        # exercises its error path.
        with patch.object(engine, "_get_client") as mock_client:
            mock_client.return_value.chat.completions.create = AsyncMock(
                side_effect=Exception("API error")
            )
            result = await engine.query("How many headaches?", "Elena")

        assert "error" in result
        assert "answer" in result
        assert result["result_count"] == 0