Spaces:
Running
Running
feat: Implement pattern detection and integrate graph query engine with session insights and new tools.
Browse files- src/reachy_mini_conversation_app/graph_query_engine.py +307 -0
- src/reachy_mini_conversation_app/main.py +6 -3
- src/reachy_mini_conversation_app/memory_graph.py +87 -0
- src/reachy_mini_conversation_app/openai_realtime.py +24 -6
- src/reachy_mini_conversation_app/pattern_detector.py +424 -0
- src/reachy_mini_conversation_app/profiles/_reachy_mini_minder_locked_profile/tools.txt +1 -0
- src/reachy_mini_conversation_app/session_enrichment.py +33 -1
- src/reachy_mini_conversation_app/tools/check_medication.py +4 -7
- src/reachy_mini_conversation_app/tools/core_tools.py +1 -0
- src/reachy_mini_conversation_app/tools/query_health_history.py +133 -0
- tests/test_graph_query_engine.py +213 -0
- tests/test_pattern_detector.py +276 -0
src/reachy_mini_conversation_app/graph_query_engine.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dynamic Cypher query engine for natural language → Neo4j graph queries.
|
| 2 |
+
|
| 3 |
+
Translates natural language health questions into Cypher queries using an LLM,
|
| 4 |
+
executes them in a read-only Neo4j session, and returns human-readable answers.
|
| 5 |
+
|
| 6 |
+
Safety stack:
|
| 7 |
+
1. LLM system prompt restricts to read-only Cypher (MATCH/RETURN/WHERE only)
|
| 8 |
+
2. Regex validation rejects mutations before execution
|
| 9 |
+
3. Neo4j driver uses execute_read() for driver-level read-only enforcement
|
| 10 |
+
4. PII guard redacts patient info before LLM, hydrates after
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
from typing import Any, Dict, List, Optional
|
| 19 |
+
|
| 20 |
+
from openai import AsyncOpenAI
|
| 21 |
+
|
| 22 |
+
from reachy_mini_conversation_app.config import config
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---- Cypher Safety --------------------------------------------------------

# Mutation / procedure keywords that must never appear in a generated query.
# CALL is included deliberately: it blocks stored procedures (APOC etc.),
# which can mutate the graph even though the statement "looks" read-only.
_FORBIDDEN_KEYWORDS = re.compile(
    r"\b(CREATE|MERGE|DELETE|DETACH|SET|REMOVE|DROP|CALL|LOAD\s+CSV|FOREACH)\b",
    re.IGNORECASE,
)

# Allow-list of read-only clauses and functions a safe query may start with.
# NOTE(review): this pattern is not referenced by validate_cypher() in this
# module — it only checks _FORBIDDEN_KEYWORDS plus a leading-clause match.
# Either wire it in or remove it; confirm intent with the author.
_ALLOWED_PATTERN = re.compile(
    r"^\s*(MATCH|OPTIONAL\s+MATCH|WITH|WHERE|RETURN|ORDER\s+BY|LIMIT|SKIP|UNWIND|AS|AND|OR|NOT|IN|IS|NULL|COUNT|SUM|AVG|MIN|MAX|COLLECT|DISTINCT|CASE|WHEN|THEN|ELSE|END|EXISTS|SIZE|COALESCE|HEAD|LAST|RANGE|REDUCE|NONE|ANY|ALL|SINGLE|FILTER|EXTRACT|datetime|date|duration|toString|toInteger|toFloat|toLower|toUpper|trim|split|replace|substring|left|right|length|nodes|relationships|labels|type|id|properties|keys|startNode|endNode|point|distance)",
    re.IGNORECASE,
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def validate_cypher(cypher: str) -> bool:
    """Validate that a Cypher query is read-only.

    Defence-in-depth layer 2 (after the LLM prompt, before the
    driver-level read-only transaction):

    - rejects empty/blank queries,
    - rejects statement chaining via ';' so a second, mutating
      statement cannot ride along with a valid first one,
    - rejects any mutation keyword (CREATE, MERGE, DELETE, ...),
    - requires the query to start with a read clause.

    Returns True if the query appears safe to execute, False otherwise.
    """
    if not cypher or not cypher.strip():
        return False

    stripped = cypher.strip()

    # Reject multi-statement input; a trailing ';' is harmless, but any
    # semicolon before the end means a second statement follows.
    if ";" in stripped.rstrip(";"):
        logger.warning("Cypher rejected — multiple statements: %s", cypher[:200])
        return False

    # Reject any mutation keywords anywhere in the query.
    if _FORBIDDEN_KEYWORDS.search(cypher):
        logger.warning("Cypher rejected — contains forbidden keyword: %s", cypher[:200])
        return False

    # Must start with MATCH, OPTIONAL MATCH, WITH, RETURN, or UNWIND.
    if not re.match(
        r"^\s*(MATCH|OPTIONAL\s+MATCH|WITH|RETURN|UNWIND)\b", stripped, re.IGNORECASE
    ):
        logger.warning(
            "Cypher rejected — does not start with MATCH/WITH/RETURN: %s", cypher[:200]
        )
        return False

    return True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---- System prompt for Cypher generation -----------------------------------
|
| 67 |
+
|
| 68 |
+
_CYPHER_SYSTEM_PROMPT = """You are a Cypher query generator for a Neo4j health knowledge graph.
|
| 69 |
+
|
| 70 |
+
RULES — you MUST follow these strictly:
|
| 71 |
+
1. Generate ONLY read-only Cypher queries. Use MATCH, RETURN, WHERE, WITH, ORDER BY, LIMIT, SKIP, UNWIND.
|
| 72 |
+
2. NEVER use CREATE, MERGE, DELETE, SET, REMOVE, DROP, CALL, LOAD CSV, or FOREACH.
|
| 73 |
+
3. NEVER modify the graph in any way.
|
| 74 |
+
4. Return ONLY the Cypher query — no explanation, no markdown, no code fences.
|
| 75 |
+
5. Use parameterised values where possible (e.g., $patient_name).
|
| 76 |
+
6. Keep queries concise — prefer aggregation over returning raw data.
|
| 77 |
+
7. Use neutral, factual language in aliases (e.g., "count" not "caused_by_count").
|
| 78 |
+
8. When the user asks about "today", use `date()` for comparison.
|
| 79 |
+
9. For time ranges, use `datetime() - duration('P{n}D')` syntax.
|
| 80 |
+
|
| 81 |
+
{schema}
|
| 82 |
+
|
| 83 |
+
EXAMPLES:
|
| 84 |
+
Q: "How many headaches did I have this week?"
|
| 85 |
+
A: MATCH (p:Person {{name: $patient_name}})-[:EXPERIENCED]->(e:Event {{type: 'headache'}})
|
| 86 |
+
WHERE e.timestamp >= datetime() - duration('P7D')
|
| 87 |
+
RETURN count(e) AS headache_count
|
| 88 |
+
|
| 89 |
+
Q: "What medications am I taking?"
|
| 90 |
+
A: MATCH (p:Person {{name: $patient_name}})-[:TAKES]->(m:Medication)
|
| 91 |
+
RETURN m.name AS medication, m.dose AS dose, m.frequency AS frequency
|
| 92 |
+
|
| 93 |
+
Q: "When did I last see my doctor?"
|
| 94 |
+
A: MATCH (p:Person {{name: $patient_name}})-[:SEES]->(d:Person {{role: 'neurologist'}})
|
| 95 |
+
OPTIONAL MATCH (p)-[:EXPERIENCED]->(e:Event {{type: 'doctor_visit'}})
|
| 96 |
+
RETURN d.name AS doctor, e.timestamp AS last_visit
|
| 97 |
+
ORDER BY e.timestamp DESC LIMIT 1
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
_ANSWER_SYSTEM_PROMPT = """You are a friendly health companion summarising data from a knowledge graph.
|
| 101 |
+
|
| 102 |
+
RULES:
|
| 103 |
+
1. Summarise the query results in plain, conversational English.
|
| 104 |
+
2. Use factual language — say "logged", "recorded", "showed up" — never "caused", "triggered".
|
| 105 |
+
3. If the results are empty, say so helpfully (e.g., "I don't have any records of that yet").
|
| 106 |
+
4. Keep it concise — this will be spoken aloud by a robot.
|
| 107 |
+
5. Never give medical advice. State facts only.
|
| 108 |
+
6. If the data is a count, state the number naturally.
|
| 109 |
+
7. Refer to the user by name if provided, otherwise say "you".
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class GraphQueryEngine:
    """LLM-powered Cypher query generation and execution.

    Pipeline: natural-language question → LLM-generated Cypher →
    regex validation → read-only execution → LLM-formatted spoken answer.

    Uses gpt-4.1-mini for Cypher generation and answer formatting.
    All queries execute in read-only Neo4j transactions.
    """

    def __init__(self, graph_memory: Any) -> None:
        """Initialise with a connected GraphMemory instance."""
        self._graph = graph_memory
        # Created lazily so constructing the engine needs no API key.
        self._client: Optional[AsyncOpenAI] = None
        # Cached schema text; cleared via invalidate_schema_cache().
        self._schema_cache: Optional[str] = None

    def _get_client(self) -> AsyncOpenAI:
        """Lazy-init the OpenAI client."""
        if self._client is None:
            self._client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
        return self._client

    def _get_schema(self) -> str:
        """Get the graph schema description, with caching."""
        if self._schema_cache is None:
            self._schema_cache = self._graph.get_schema_description()
        return self._schema_cache

    def invalidate_schema_cache(self) -> None:
        """Clear the schema cache (e.g., after enrichment adds new node types)."""
        self._schema_cache = None

    async def generate_cypher(
        self, question: str, patient_name: str = "Patient"
    ) -> str:
        """Generate a read-only Cypher query from a natural language question.

        Args:
            question: The user's natural language health question.
            patient_name: The patient name to parameterise in the query.
                Intentionally NOT sent to the LLM here — it is bound as
                $patient_name at execution time (see module PII notes).

        Returns:
            A Cypher query string.

        Raises:
            ValueError: If the generated query fails safety validation.
        """
        client = self._get_client()
        schema = self._get_schema()

        system_prompt = _CYPHER_SYSTEM_PROMPT.format(schema=schema)

        response = await client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            temperature=0.0,
            max_tokens=500,
        )

        # message.content may be None (e.g., refusal / empty completion);
        # coerce to "" so validation rejects it cleanly instead of this
        # method raising AttributeError on .strip().
        cypher = (response.choices[0].message.content or "").strip()

        # Strip markdown code fences if the LLM wraps them despite instructions
        if cypher.startswith("```"):
            lines = cypher.split("\n")
            # Remove first and last line (```cypher and ```)
            cypher = "\n".join(
                line for line in lines if not line.strip().startswith("```")
            ).strip()

        if not validate_cypher(cypher):
            raise ValueError(
                f"Generated Cypher failed safety validation: {cypher[:200]}"
            )

        logger.info("Generated Cypher: %s", cypher)
        return cypher

    async def execute(
        self,
        cypher: str,
        patient_name: str = "Patient",
    ) -> List[Dict[str, Any]]:
        """Execute a validated Cypher query in a read-only transaction.

        Args:
            cypher: The Cypher query to execute.
            patient_name: The patient name to bind to $patient_name param.

        Returns:
            List of result records as dicts; empty list when the graph
            is unavailable.
        """
        if not self._graph or not self._graph.is_connected:
            logger.warning("Graph not connected, cannot execute query")
            return []

        params = {"patient_name": patient_name}
        # NOTE(review): execute_read is synchronous, so this blocks the
        # event loop for the query's duration — consider asyncio.to_thread
        # if queries grow slow.
        return self._graph.execute_read(cypher, params)

    async def format_answer(
        self,
        question: str,
        results: List[Dict[str, Any]],
        patient_name: str = "Patient",
    ) -> str:
        """Format query results into a natural language answer.

        Args:
            question: The original user question.
            results: The raw Cypher query results.
            patient_name: The patient's name for personalisation.

        Returns:
            A friendly, spoken-word summary of the results.
        """
        client = self._get_client()

        # Serialise results for the LLM
        if not results:
            results_text = "No results found."
        else:
            # Convert datetime objects etc. to strings for JSON serialisation
            serialisable = []
            for record in results:
                row = {}
                for k, v in record.items():
                    row[k] = (
                        str(v)
                        if not isinstance(v, (str, int, float, bool, type(None)))
                        else v
                    )
                serialisable.append(row)
            results_text = json.dumps(serialisable, indent=2, default=str)

        user_message = (
            f"User question: {question}\n"
            f"Patient name: {patient_name}\n"
            f"Query results:\n{results_text}"
        )

        response = await client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": _ANSWER_SYSTEM_PROMPT},
                {"role": "user", "content": user_message},
            ],
            temperature=0.3,
            max_tokens=300,
        )

        # Guard against a None completion, same as in generate_cypher.
        return (response.choices[0].message.content or "").strip()

    async def query(
        self,
        question: str,
        patient_name: str = "Patient",
    ) -> Dict[str, Any]:
        """End-to-end: question → Cypher → execute → natural language answer.

        Args:
            question: The user's natural language health question.
            patient_name: The patient's name.

        Returns:
            Dict with keys: answer (str), cypher (str), result_count (int),
            raw_results (list). On failure an "error" key is added and the
            answer is a user-safe apology.
        """
        try:
            cypher = await self.generate_cypher(question, patient_name)
            results = await self.execute(cypher, patient_name)
            answer = await self.format_answer(question, results, patient_name)

            return {
                "answer": answer,
                "cypher": cypher,
                "result_count": len(results),
                "raw_results": results,
            }
        except ValueError as e:
            # Safety validation rejected the generated Cypher.
            logger.warning("Query generation failed: %s", e)
            return {
                "answer": "I wasn't able to look that up in your health records right now. Could you try rephrasing?",
                "cypher": "",
                "result_count": 0,
                "raw_results": [],
                "error": str(e),
            }
        except Exception as e:
            # Boundary catch: LLM/network/driver errors become a spoken apology.
            logger.exception("Graph query failed: %s", e)
            return {
                "answer": "I had trouble accessing your health records. Let me try again later.",
                "cypher": "",
                "result_count": 0,
                "raw_results": [],
                "error": str(e),
            }
|
src/reachy_mini_conversation_app/main.py
CHANGED
|
@@ -138,6 +138,7 @@ def run(
|
|
| 138 |
logger.debug("Conversation log pruning skipped: %s", e)
|
| 139 |
|
| 140 |
# Initialize session enrichment pipeline (optional Neo4j connection)
|
|
|
|
| 141 |
try:
|
| 142 |
from reachy_mini_conversation_app.session_enrichment import (
|
| 143 |
init_session_enrichment,
|
|
@@ -145,15 +146,16 @@ def run(
|
|
| 145 |
from reachy_mini_conversation_app.memory_graph import GraphMemory
|
| 146 |
|
| 147 |
# Try to connect to Neo4j if available
|
| 148 |
-
|
| 149 |
-
if
|
| 150 |
-
init_session_enrichment(graph_memory=
|
| 151 |
enable_session_enrichment()
|
| 152 |
logger.info("Session enrichment enabled with Neo4j")
|
| 153 |
else:
|
| 154 |
init_session_enrichment(graph_memory=None)
|
| 155 |
enable_session_enrichment()
|
| 156 |
logger.info("Session enrichment enabled (no Neo4j)")
|
|
|
|
| 157 |
except Exception as e:
|
| 158 |
logger.warning("Session enrichment not available: %s", e)
|
| 159 |
|
|
@@ -165,6 +167,7 @@ def run(
|
|
| 165 |
head_wobbler=head_wobbler,
|
| 166 |
entry_state_manager=entry_state,
|
| 167 |
database=minder_db,
|
|
|
|
| 168 |
)
|
| 169 |
|
| 170 |
handler = create_handler(deps, instance_path=instance_path)
|
|
|
|
| 138 |
logger.debug("Conversation log pruning skipped: %s", e)
|
| 139 |
|
| 140 |
# Initialize session enrichment pipeline (optional Neo4j connection)
|
| 141 |
+
graph_memory = None
|
| 142 |
try:
|
| 143 |
from reachy_mini_conversation_app.session_enrichment import (
|
| 144 |
init_session_enrichment,
|
|
|
|
| 146 |
from reachy_mini_conversation_app.memory_graph import GraphMemory
|
| 147 |
|
| 148 |
# Try to connect to Neo4j if available
|
| 149 |
+
graph_memory = GraphMemory()
|
| 150 |
+
if graph_memory.connect():
|
| 151 |
+
init_session_enrichment(graph_memory=graph_memory)
|
| 152 |
enable_session_enrichment()
|
| 153 |
logger.info("Session enrichment enabled with Neo4j")
|
| 154 |
else:
|
| 155 |
init_session_enrichment(graph_memory=None)
|
| 156 |
enable_session_enrichment()
|
| 157 |
logger.info("Session enrichment enabled (no Neo4j)")
|
| 158 |
+
graph_memory = None # Clear ref — not connected
|
| 159 |
except Exception as e:
|
| 160 |
logger.warning("Session enrichment not available: %s", e)
|
| 161 |
|
|
|
|
| 167 |
head_wobbler=head_wobbler,
|
| 168 |
entry_state_manager=entry_state,
|
| 169 |
database=minder_db,
|
| 170 |
+
graph_memory=graph_memory,
|
| 171 |
)
|
| 172 |
|
| 173 |
handler = create_handler(deps, instance_path=instance_path)
|
src/reachy_mini_conversation_app/memory_graph.py
CHANGED
|
@@ -82,6 +82,11 @@ class GraphMemory:
|
|
| 82 |
self._driver.close()
|
| 83 |
self._driver = None
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def _execute(
|
| 86 |
self, query: str, parameters: Optional[Dict[str, Any]] = None
|
| 87 |
) -> List[Dict[str, Any]]:
|
|
@@ -94,6 +99,88 @@ class GraphMemory:
|
|
| 94 |
result = session.run(query, parameters or {})
|
| 95 |
return [dict(record) for record in result]
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# -------------------------------------------------------------------------
|
| 98 |
# Node Creation
|
| 99 |
# -------------------------------------------------------------------------
|
|
|
|
| 82 |
self._driver.close()
|
| 83 |
self._driver = None
|
| 84 |
|
| 85 |
+
@property
def is_connected(self) -> bool:
    """Check if the driver is connected.

    True once a driver has been created and until close() sets
    self._driver back to None. This is a local None-check only —
    it does not ping the Neo4j server.
    """
    return self._driver is not None
|
| 89 |
+
|
| 90 |
def _execute(
|
| 91 |
self, query: str, parameters: Optional[Dict[str, Any]] = None
|
| 92 |
) -> List[Dict[str, Any]]:
|
|
|
|
| 99 |
result = session.run(query, parameters or {})
|
| 100 |
return [dict(record) for record in result]
|
| 101 |
|
| 102 |
+
def execute_read(
    self, query: str, parameters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Execute a Cypher query in a read-only transaction.

    This is the safe path for LLM-generated queries — the driver
    will reject any write operations even if they slip past validation.

    Args:
        query: Cypher text to run.
        parameters: Optional parameter map bound into the query.

    Returns:
        All result records eagerly materialised as plain dicts;
        empty list when Neo4j is not connected.
    """
    if not self._driver:
        logger.warning("Neo4j not connected. Read query skipped.")
        return []

    def _read_tx(tx: Any) -> List[Dict[str, Any]]:
        # Materialise inside the transaction function — result records
        # cannot be consumed after the session closes.
        result = tx.run(query, parameters or {})
        return [dict(record) for record in result]

    with self._driver.session() as session:
        # session.execute_read runs _read_tx in a read-access transaction,
        # giving driver-level enforcement on top of regex validation.
        return session.execute_read(_read_tx)
|
| 120 |
+
|
| 121 |
+
def get_schema_description(self) -> str:
    """Return a human-readable description of the graph schema.

    Attempts live introspection via `db.labels()` and
    `db.relationshipTypes()`; falls back to the hardcoded schema
    constants when Neo4j is not connected or introspection fails.
    """
    if self._driver:
        try:
            labels = [
                r["label"]
                for r in self._execute("CALL db.labels() YIELD label RETURN label")
            ]
            rel_types = [
                r["relationshipType"]
                for r in self._execute(
                    "CALL db.relationshipTypes() YIELD relationshipType RETURN relationshipType"
                )
            ]
            # Get property keys for each label
            prop_lines = []
            for label in labels:
                # Backtick-quote the label (escaping embedded backticks)
                # so labels containing spaces or punctuation cannot break
                # — or inject into — the interpolated Cypher.
                safe_label = str(label).replace("`", "``")
                props = self._execute(
                    f"MATCH (n:`{safe_label}`) WITH keys(n) AS ks UNWIND ks AS k "
                    "RETURN DISTINCT k ORDER BY k LIMIT 20"
                )
                keys = [r["k"] for r in props]
                prop_lines.append(
                    f" (:{label}) — properties: {', '.join(keys) if keys else 'none'}"
                )

            lines = [
                "## Neo4j Graph Schema",
                "",
                "### Node Labels",
                *prop_lines,
                "",
                "### Relationship Types",
                *[f" {rt}" for rt in rel_types],
            ]
            return "\n".join(lines)
        except Exception as e:
            logger.debug("Live schema introspection failed, using fallback: %s", e)

    # Fallback to hardcoded schema from class docstring
    return (
        "## Neo4j Graph Schema\n\n"
        "### Node Labels\n"
        " (:Person) — properties: name, role, updated_at\n"
        " (:Medication) — properties: name, dose, frequency, symptom_category, updated_at\n"
        " (:Symptom) — properties: type, severity, updated_at\n"
        " (:Event) — properties: type, timestamp, notes\n"
        " (:Entity) — properties: name\n\n"
        "### Relationship Types\n"
        " TAKES — (Person)-[:TAKES]->(Medication)\n"
        " EXPERIENCED — (Person)-[:EXPERIENCED]->(Event)\n"
        " TRIGGERED_BY — (Event)-[:TRIGGERED_BY]->(Symptom)\n"
        " HAS_CAREGIVER — (Person)-[:HAS_CAREGIVER]->(Person)\n"
        " MONITORS — (Person)-[:MONITORS]->(Person)\n"
        " SEES — (Person)-[:SEES]->(Person)\n"
        " MENTIONED_WITH — (Entity)-[:MENTIONED_WITH]->(Entity)"
    )
|
| 183 |
+
|
| 184 |
# -------------------------------------------------------------------------
|
| 185 |
# Node Creation
|
| 186 |
# -------------------------------------------------------------------------
|
src/reachy_mini_conversation_app/openai_realtime.py
CHANGED
|
@@ -141,20 +141,38 @@ class OpenaiRealtimeHandler(RealtimeHandler):
|
|
| 141 |
try:
|
| 142 |
from reachy_mini_conversation_app.session_enrichment import (
|
| 143 |
get_session_enrichment,
|
|
|
|
| 144 |
)
|
| 145 |
|
| 146 |
enrichment = get_session_enrichment()
|
|
|
|
|
|
|
|
|
|
| 147 |
if enrichment and enrichment._graph:
|
| 148 |
patient_name = profile.get("name") if profile else None
|
| 149 |
if patient_name:
|
| 150 |
ctx = enrichment._graph.format_context_for_prompt(patient_name)
|
| 151 |
if ctx:
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
except Exception as e:
|
| 159 |
logger.debug("Graph context injection skipped: %s", e)
|
| 160 |
|
|
|
|
| 141 |
try:
|
| 142 |
from reachy_mini_conversation_app.session_enrichment import (
|
| 143 |
get_session_enrichment,
|
| 144 |
+
get_latest_insights,
|
| 145 |
)
|
| 146 |
|
| 147 |
enrichment = get_session_enrichment()
|
| 148 |
+
graph_parts = []
|
| 149 |
+
|
| 150 |
+
# Patient context from Neo4j
|
| 151 |
if enrichment and enrichment._graph:
|
| 152 |
patient_name = profile.get("name") if profile else None
|
| 153 |
if patient_name:
|
| 154 |
ctx = enrichment._graph.format_context_for_prompt(patient_name)
|
| 155 |
if ctx:
|
| 156 |
+
graph_parts.append(ctx)
|
| 157 |
+
|
| 158 |
+
# Pattern insights from last session's analysis
|
| 159 |
+
insights = get_latest_insights()
|
| 160 |
+
if insights:
|
| 161 |
+
from reachy_mini_conversation_app.pattern_detector import (
|
| 162 |
+
format_insights_for_prompt,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
insights_text = format_insights_for_prompt(insights)
|
| 166 |
+
if insights_text:
|
| 167 |
+
graph_parts.append(insights_text)
|
| 168 |
+
|
| 169 |
+
if graph_parts:
|
| 170 |
+
state.graph_context = "\n\n".join(graph_parts)
|
| 171 |
+
logger.info(
|
| 172 |
+
"Injected graph context (%d chars, %d insights)",
|
| 173 |
+
len(state.graph_context),
|
| 174 |
+
len(insights),
|
| 175 |
+
)
|
| 176 |
except Exception as e:
|
| 177 |
logger.debug("Graph context injection skipped: %s", e)
|
| 178 |
|
src/reachy_mini_conversation_app/pattern_detector.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pattern detection across the Neo4j health knowledge graph.
|
| 2 |
+
|
| 3 |
+
Analyses the graph for temporal correlations, frequency changes, and
|
| 4 |
+
medication-symptom co-occurrence patterns. All findings use neutral,
|
| 5 |
+
observational language — never causal claims.
|
| 6 |
+
|
| 7 |
+
Safety constraints:
|
| 8 |
+
- Uses "co-occurred", "appeared together", "trend" — never "caused" or "triggered"
|
| 9 |
+
- All queries are read-only (use execute_read)
|
| 10 |
+
- Results are scored by confidence (sample size × effect strength)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Any, Dict, List, Optional
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class Insight:
    """A detected health pattern or trend.

    Carries neutral, observational wording only — summaries are meant to
    be spoken aloud, so no causal claims (see module docstring).
    """

    pattern_type: str  # "correlation", "frequency_change", "temporal", "adherence"
    summary: str  # Human-readable one-liner (spoken aloud)
    detail: str  # Longer explanation
    confidence: float  # 0.0-1.0 based on sample size and effect strength
    entities: List[str] = field(default_factory=list)  # Entity names involved
    period_days: int = 30  # Analysis window

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict (e.g., for JSON persistence)."""
        return {
            "pattern_type": self.pattern_type,
            "summary": self.summary,
            "detail": self.detail,
            "confidence": self.confidence,
            "entities": self.entities,
            "period_days": self.period_days,
        }
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class PatternDetector:
    """Analyse the Neo4j knowledge graph for health patterns.

    All analysis queries are pre-written Cypher (not LLM-generated) —
    pattern templates are safe to hardcode since they are read-only.

    Each detector returns a list of ``Insight`` objects; ``run_analysis``
    runs every detector and keeps the most confident results.
    """

    # Minimum event count required for a pattern to be considered meaningful
    MIN_SAMPLE_SIZE = 3

    def __init__(self, graph_memory: Any) -> None:
        """Initialise with a connected GraphMemory instance (may be None)."""
        self._graph = graph_memory

    def _query(
        self, cypher: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Execute a read-only query. Returns empty list if graph unavailable."""
        if not self._graph or not self._graph.is_connected:
            return []
        try:
            return self._graph.execute_read(cypher, params or {})
        except Exception as e:
            logger.warning("Pattern query failed: %s", e)
            return []

    # ------------------------------------------------------------------
    # Pattern detectors
    # ------------------------------------------------------------------

    def detect_medication_symptom_correlation(
        self, patient_name: str, days: int = 30
    ) -> "List[Insight]":
        """Find medications and symptoms that co-occur within a time window.

        Looks for medications taken on the same day as symptom events.

        Args:
            patient_name: Patient node name in the graph.
            days: Analysis window in days.

        Returns:
            Up to 10 correlation insights, strongest co-occurrence first.
        """
        query = """
        MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication)
        MATCH (p)-[:EXPERIENCED]->(e:Event)
        WHERE e.timestamp >= datetime() - duration('P' + toString($days) + 'D')
          AND e.type IN ['headache', 'migraine_episode', 'confusion', 'fatigue',
                         'mood_change', 'dizziness', 'vision_change']
        WITH m.name AS medication, e.type AS symptom,
             count(e) AS co_occurrence_count,
             collect(DISTINCT date(e.timestamp)) AS dates
        WHERE co_occurrence_count >= $min_sample
        RETURN medication, symptom, co_occurrence_count,
               size(dates) AS distinct_days
        ORDER BY co_occurrence_count DESC
        LIMIT 10
        """
        results = self._query(
            query,
            {
                "patient_name": patient_name,
                "days": days,
                "min_sample": self.MIN_SAMPLE_SIZE,
            },
        )

        # Guard against days <= 0: the confidence ratio divides by the window.
        window = max(days, 1)

        insights: List[Insight] = []
        for r in results:
            med = r["medication"]
            sym = r["symptom"]
            count = r["co_occurrence_count"]
            d_days = r["distinct_days"]

            # Confidence: more days and more events = higher confidence
            confidence = min(1.0, (d_days / window) * (count / (window * 0.5)))

            insights.append(
                Insight(
                    pattern_type="correlation",
                    summary=f"{sym.replace('_', ' ').title()} appeared on {d_days} of the days you took {med} in the last {days} days.",
                    detail=f"{count} {sym.replace('_', ' ')} events co-occurred with {med} across {d_days} distinct days.",
                    confidence=round(confidence, 2),
                    entities=[med, sym],
                    period_days=days,
                )
            )

        return insights

    def detect_frequency_changes(
        self, patient_name: str, event_type: str = "headache", days: int = 30
    ) -> "List[Insight]":
        """Detect if event frequency increased or decreased vs the prior period.

        Compares the last N days to the N days before that. Only changes of
        at least 25% with enough total samples are reported.

        Args:
            patient_name: Patient node name in the graph.
            event_type: Event type to count (e.g. "headache").
            days: Size of each comparison window in days.
        """
        query = """
        MATCH (p:Person {name: $patient_name})-[:EXPERIENCED]->(e:Event {type: $event_type})
        WITH e,
             CASE WHEN e.timestamp >= datetime() - duration('P' + toString($days) + 'D')
                  THEN 'recent'
                  ELSE CASE WHEN e.timestamp >= datetime() - duration('P' + toString($days * 2) + 'D')
                            THEN 'prior'
                            ELSE 'older'
                       END
             END AS period
        WHERE period IN ['recent', 'prior']
        RETURN period, count(e) AS event_count
        ORDER BY period
        """
        results = self._query(
            query,
            {
                "patient_name": patient_name,
                "event_type": event_type,
                "days": days,
            },
        )

        if len(results) < 2:
            return []

        counts = {r["period"]: r["event_count"] for r in results}
        recent = counts.get("recent", 0)
        prior = counts.get("prior", 0)

        if recent == prior or (recent + prior) < self.MIN_SAMPLE_SIZE:
            return []

        if prior == 0:
            # No baseline to compare against — report only if the new
            # activity is itself large enough to be meaningful.
            if recent >= self.MIN_SAMPLE_SIZE:
                return [
                    Insight(
                        pattern_type="frequency_change",
                        summary=f"You've logged {recent} {event_type.replace('_', ' ')} events in the last {days} days — this is new.",
                        detail=f"{recent} events in last {days} days vs 0 in the prior {days} days.",
                        confidence=0.6,
                        entities=[event_type],
                        period_days=days,
                    )
                ]
            return []

        pct_change = ((recent - prior) / prior) * 100
        direction = "increased" if pct_change > 0 else "decreased"
        abs_pct = abs(int(pct_change))

        # Only report meaningful changes (>25%)
        if abs_pct < 25:
            return []

        # Guard against days <= 0: the confidence ratio divides by the window.
        window = max(days, 1)
        confidence = min(1.0, (recent + prior) / (window * 0.5) * (abs_pct / 100))

        friendly_type = event_type.replace("_", " ")
        return [
            Insight(
                pattern_type="frequency_change",
                summary=f"Your {friendly_type} frequency has {direction} by about {abs_pct}% compared to the previous {days} days.",
                detail=f"{recent} events in the last {days} days vs {prior} in the prior {days} days ({direction} {abs_pct}%).",
                confidence=round(min(confidence, 1.0), 2),
                entities=[event_type],
                period_days=days,
            )
        ]

    def detect_missed_medication_impact(
        self, patient_name: str, days: int = 30
    ) -> "List[Insight]":
        """Correlate missed medications with symptom severity the following day.

        Reports only medication/symptom pairs where more than 30% of symptom
        events were preceded (within 24h) by a missed dose.
        """
        query = """
        MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication)
        MATCH (p)-[:EXPERIENCED]->(e:Event)
        WHERE e.timestamp >= datetime() - duration('P' + toString($days) + 'D')
          AND e.type IN ['headache', 'migraine_episode', 'confusion', 'fatigue']
        OPTIONAL MATCH (p)-[:EXPERIENCED]->(missed:Event {type: 'medication_missed'})
        WHERE missed.timestamp >= e.timestamp - duration('P1D')
          AND missed.timestamp <= e.timestamp
        WITH m.name AS medication, e.type AS symptom,
             count(DISTINCT e) AS total_events,
             count(DISTINCT missed) AS preceded_by_miss
        WHERE total_events >= $min_sample AND preceded_by_miss > 0
        RETURN medication, symptom, total_events, preceded_by_miss,
               toFloat(preceded_by_miss) / toFloat(total_events) AS miss_ratio
        ORDER BY miss_ratio DESC
        LIMIT 5
        """
        results = self._query(
            query,
            {
                "patient_name": patient_name,
                "days": days,
                "min_sample": self.MIN_SAMPLE_SIZE,
            },
        )

        insights: List[Insight] = []
        for r in results:
            ratio = r["miss_ratio"]
            if ratio < 0.3:  # Only report if >30% correlation
                continue

            med = r["medication"]
            sym = r["symptom"].replace("_", " ")
            pct = int(ratio * 100)

            insights.append(
                Insight(
                    pattern_type="adherence",
                    summary=f"About {pct}% of your {sym} events were preceded by missing {med} the day before.",
                    detail=f"{r['preceded_by_miss']} of {r['total_events']} {sym} events came after a missed {med} dose.",
                    confidence=round(min(ratio, 1.0), 2),
                    entities=[med, r["symptom"]],
                    period_days=days,
                )
            )

        return insights

    def detect_temporal_patterns(
        self, patient_name: str, days: int = 30
    ) -> "List[Insight]":
        """Find time-of-day and day-of-week clustering of symptoms.

        Reports a symptom only when more than half of its events cluster in
        a single time-of-day bucket (early_morning/morning/afternoon/evening).
        """
        query = """
        MATCH (p:Person {name: $patient_name})-[:EXPERIENCED]->(e:Event)
        WHERE e.timestamp >= datetime() - duration('P' + toString($days) + 'D')
          AND e.type IN ['headache', 'migraine_episode', 'confusion', 'fatigue',
                         'mood_change', 'dizziness']
        WITH e.type AS symptom,
             CASE
                 WHEN e.timestamp.hour < 6 THEN 'early_morning'
                 WHEN e.timestamp.hour < 12 THEN 'morning'
                 WHEN e.timestamp.hour < 18 THEN 'afternoon'
                 ELSE 'evening'
             END AS time_of_day,
             e.timestamp.dayOfWeek AS day_of_week,
             count(e) AS event_count
        RETURN symptom, time_of_day, day_of_week, event_count
        ORDER BY event_count DESC
        LIMIT 20
        """
        results = self._query(
            query,
            {
                "patient_name": patient_name,
                "days": days,
            },
        )

        if not results:
            return []

        # Group by symptom and find dominant time pattern
        symptom_times: Dict[str, Dict[str, int]] = {}
        for r in results:
            sym = r["symptom"]
            tod = r["time_of_day"]
            count = r["event_count"]
            if sym not in symptom_times:
                symptom_times[sym] = {}
            symptom_times[sym][tod] = symptom_times[sym].get(tod, 0) + count

        insights: List[Insight] = []
        for sym, time_counts in symptom_times.items():
            total = sum(time_counts.values())
            if total < self.MIN_SAMPLE_SIZE:
                continue

            # Find dominant time period
            dominant_time = max(time_counts, key=time_counts.get)
            dominant_count = time_counts[dominant_time]
            ratio = dominant_count / total

            # Only report if >50% of events cluster in one time period
            if ratio < 0.5:
                continue

            pct = int(ratio * 100)
            friendly_sym = sym.replace("_", " ")
            friendly_time = dominant_time.replace("_", " ")

            insights.append(
                Insight(
                    pattern_type="temporal",
                    summary=f"About {pct}% of your {friendly_sym} events tend to happen in the {friendly_time}.",
                    detail=f"{dominant_count} of {total} {friendly_sym} events occurred during {friendly_time} hours over the last {days} days.",
                    confidence=round(min(ratio, 1.0), 2),
                    entities=[sym, dominant_time],
                    period_days=days,
                )
            )

        return insights

    # ------------------------------------------------------------------
    # Run all detectors
    # ------------------------------------------------------------------

    def run_analysis(self, patient_name: str, days: int = 30) -> "List[Insight]":
        """Run all pattern detectors and return sorted insights.

        Args:
            patient_name: The patient's name for graph queries.
            days: Analysis window in days.

        Returns:
            List of Insight objects sorted by confidence (highest first),
            capped at the 5 most confident. Empty if the graph is unavailable.
        """
        if not self._graph or not self._graph.is_connected:
            logger.debug("Graph not available for pattern detection")
            return []

        all_insights: List[Insight] = []

        # Run each detector; one failure must not abort the others.
        for detector_name, detector_fn in [
            (
                "medication_symptom_correlation",
                self.detect_medication_symptom_correlation,
            ),
            (
                "frequency_changes_headache",
                lambda p, d: self.detect_frequency_changes(p, "headache", d),
            ),
            (
                "frequency_changes_migraine",
                lambda p, d: self.detect_frequency_changes(p, "migraine_episode", d),
            ),
            (
                "frequency_changes_confusion",
                lambda p, d: self.detect_frequency_changes(p, "confusion", d),
            ),
            ("missed_medication_impact", self.detect_missed_medication_impact),
            ("temporal_patterns", self.detect_temporal_patterns),
        ]:
            try:
                insights = detector_fn(patient_name, days)
                all_insights.extend(insights)
                if insights:
                    logger.info(
                        "Pattern detector '%s' found %d insights",
                        detector_name,
                        len(insights),
                    )
            except Exception as e:
                logger.warning("Pattern detector '%s' failed: %s", detector_name, e)

        # Sort by confidence descending, keep top insights
        all_insights.sort(key=lambda i: i.confidence, reverse=True)

        # Cap at 5 most confident insights
        top_insights = all_insights[:5]

        if top_insights:
            logger.info(
                "Pattern detection complete: %d insights (top confidence: %.2f)",
                len(top_insights),
                top_insights[0].confidence,
            )

        return top_insights
| 400 |
+
|
| 401 |
+
def format_insights_for_prompt(insights: "List[Insight]") -> str:
    """Format pattern insights for injection into the system prompt.

    Returns a block suitable for appending to graph_context in SessionState,
    or the empty string when there are no insights to surface.
    """
    if not insights:
        return ""

    header = [
        "",
        "## Recent Health Insights",
        "The following patterns were detected in the patient's health data.",
        "Mention them naturally if relevant to the conversation — don't force them.",
        "Use observational language only (e.g., 'I noticed', 'it looks like').",
        "",
    ]
    numbered = [
        f"{idx}. **{item.pattern_type.replace('_', ' ').title()}** "
        f"(confidence: {item.confidence:.0%}): {item.summary}"
        for idx, item in enumerate(insights, start=1)
    ]
    return "\n".join(header + numbered)
|
src/reachy_mini_conversation_app/profiles/_reachy_mini_minder_locked_profile/tools.txt
CHANGED
|
@@ -7,6 +7,7 @@ log_entry
|
|
| 7 |
entry_control
|
| 8 |
get_recent_entries
|
| 9 |
check_medication
|
|
|
|
| 10 |
|
| 11 |
# Onboarding & setup tools (unchanged)
|
| 12 |
get_current_datetime
|
|
|
|
| 7 |
entry_control
|
| 8 |
get_recent_entries
|
| 9 |
check_medication
|
| 10 |
+
query_health_history
|
| 11 |
|
| 12 |
# Onboarding & setup tools (unchanged)
|
| 13 |
get_current_datetime
|
src/reachy_mini_conversation_app/session_enrichment.py
CHANGED
|
@@ -290,6 +290,7 @@ class SessionEnrichment:
|
|
| 290 |
# -------------------------------------------------------------------------
|
| 291 |
|
| 292 |
_enrichment_instance: Optional[SessionEnrichment] = None
|
|
|
|
| 293 |
|
| 294 |
|
| 295 |
def init_session_enrichment(graph_memory: Optional[Any] = None) -> SessionEnrichment:
|
|
@@ -312,6 +313,14 @@ def get_session_enrichment() -> Optional[SessionEnrichment]:
|
|
| 312 |
return _enrichment_instance
|
| 313 |
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
async def on_session_end(
|
| 316 |
session_id: str,
|
| 317 |
turns: List[Dict[str, Any]],
|
|
@@ -319,6 +328,9 @@ async def on_session_end(
|
|
| 319 |
) -> Dict[str, int]:
|
| 320 |
"""Hook called when a conversation session ends.
|
| 321 |
|
|
|
|
|
|
|
|
|
|
| 322 |
Args:
|
| 323 |
session_id: The session that ended.
|
| 324 |
turns: All turns from that session.
|
|
@@ -327,12 +339,32 @@ async def on_session_end(
|
|
| 327 |
Returns:
|
| 328 |
Enrichment result counts.
|
| 329 |
"""
|
|
|
|
|
|
|
| 330 |
if not _enrichment_instance:
|
| 331 |
logger.warning("Session enrichment not initialized")
|
| 332 |
return {}
|
| 333 |
|
| 334 |
-
|
| 335 |
session_id=session_id,
|
| 336 |
turns=turns,
|
| 337 |
patient_name=patient_name,
|
| 338 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
# -------------------------------------------------------------------------
|
| 291 |
|
| 292 |
_enrichment_instance: Optional[SessionEnrichment] = None
|
| 293 |
+
_latest_insights: List[Any] = [] # Cached pattern insights
|
| 294 |
|
| 295 |
|
| 296 |
def init_session_enrichment(graph_memory: Optional[Any] = None) -> SessionEnrichment:
|
|
|
|
| 313 |
return _enrichment_instance
|
| 314 |
|
| 315 |
|
| 316 |
+
def get_latest_insights() -> List[Any]:
    """Return the insights cached by the most recent pattern-detection run.

    The module-level cache is populated at session end (see
    ``on_session_end``); an empty list means no run has produced insights yet.
    """
    return _latest_insights
|
| 322 |
+
|
| 323 |
+
|
| 324 |
async def on_session_end(
|
| 325 |
session_id: str,
|
| 326 |
turns: List[Dict[str, Any]],
|
|
|
|
| 328 |
) -> Dict[str, int]:
|
| 329 |
"""Hook called when a conversation session ends.
|
| 330 |
|
| 331 |
+
After enrichment, runs pattern detection if the graph is available
|
| 332 |
+
and caches insights for the next session's system prompt.
|
| 333 |
+
|
| 334 |
Args:
|
| 335 |
session_id: The session that ended.
|
| 336 |
turns: All turns from that session.
|
|
|
|
| 339 |
Returns:
|
| 340 |
Enrichment result counts.
|
| 341 |
"""
|
| 342 |
+
global _latest_insights
|
| 343 |
+
|
| 344 |
if not _enrichment_instance:
|
| 345 |
logger.warning("Session enrichment not initialized")
|
| 346 |
return {}
|
| 347 |
|
| 348 |
+
counts = await _enrichment_instance.enrich_session(
|
| 349 |
session_id=session_id,
|
| 350 |
turns=turns,
|
| 351 |
patient_name=patient_name,
|
| 352 |
)
|
| 353 |
+
|
| 354 |
+
# Run pattern detection after enrichment
|
| 355 |
+
if _enrichment_instance._graph and patient_name:
|
| 356 |
+
try:
|
| 357 |
+
from reachy_mini_conversation_app.pattern_detector import PatternDetector
|
| 358 |
+
|
| 359 |
+
detector = PatternDetector(_enrichment_instance._graph)
|
| 360 |
+
_latest_insights = detector.run_analysis(patient_name, days=30)
|
| 361 |
+
if _latest_insights:
|
| 362 |
+
logger.info(
|
| 363 |
+
"Pattern detection found %d insights after session %s",
|
| 364 |
+
len(_latest_insights),
|
| 365 |
+
session_id[:8],
|
| 366 |
+
)
|
| 367 |
+
except Exception as e:
|
| 368 |
+
logger.warning("Pattern detection failed: %s", e)
|
| 369 |
+
|
| 370 |
+
return counts
|
src/reachy_mini_conversation_app/tools/check_medication.py
CHANGED
|
@@ -112,7 +112,10 @@ class CheckMedicationTool(Tool):
|
|
| 112 |
) -> Dict[str, Any]:
|
| 113 |
"""Query Neo4j for medication events."""
|
| 114 |
try:
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# Get patient name from profile
|
| 118 |
patient_name = "Patient" # Default
|
|
@@ -120,17 +123,11 @@ class CheckMedicationTool(Tool):
|
|
| 120 |
profile = deps.database.get_or_create_profile()
|
| 121 |
patient_name = profile.get("name", "Patient")
|
| 122 |
|
| 123 |
-
graph = GraphMemory()
|
| 124 |
-
if not graph.connect():
|
| 125 |
-
logger.debug("Neo4j not available for medication check")
|
| 126 |
-
return {"logged": False}
|
| 127 |
-
|
| 128 |
result = graph.check_medication_today(
|
| 129 |
patient_name=patient_name,
|
| 130 |
medication_name=medication_name,
|
| 131 |
time_of_day=time_of_day,
|
| 132 |
)
|
| 133 |
-
graph.close()
|
| 134 |
|
| 135 |
if result.get("logged"):
|
| 136 |
return self._format_neo4j_result(result, medication_name, time_of_day)
|
|
|
|
| 112 |
) -> Dict[str, Any]:
|
| 113 |
"""Query Neo4j for medication events."""
|
| 114 |
try:
|
| 115 |
+
graph = deps.graph_memory
|
| 116 |
+
if not graph or not graph.is_connected:
|
| 117 |
+
logger.debug("Neo4j not available for medication check")
|
| 118 |
+
return {"logged": False}
|
| 119 |
|
| 120 |
# Get patient name from profile
|
| 121 |
patient_name = "Patient" # Default
|
|
|
|
| 123 |
profile = deps.database.get_or_create_profile()
|
| 124 |
patient_name = profile.get("name", "Patient")
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
result = graph.check_medication_today(
|
| 127 |
patient_name=patient_name,
|
| 128 |
medication_name=medication_name,
|
| 129 |
time_of_day=time_of_day,
|
| 130 |
)
|
|
|
|
| 131 |
|
| 132 |
if result.get("logged"):
|
| 133 |
return self._format_neo4j_result(result, medication_name, time_of_day)
|
src/reachy_mini_conversation_app/tools/core_tools.py
CHANGED
|
@@ -61,6 +61,7 @@ class ToolDependencies:
|
|
| 61 |
# Mini-Minder deps
|
| 62 |
entry_state_manager: Any | None = None # EntryStateManager
|
| 63 |
database: Any | None = None # MiniMinderDB
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
# Tool base class
|
|
|
|
| 61 |
# Mini-Minder deps
|
| 62 |
entry_state_manager: Any | None = None # EntryStateManager
|
| 63 |
database: Any | None = None # MiniMinderDB
|
| 64 |
+
graph_memory: Any | None = None # GraphMemory (Neo4j)
|
| 65 |
|
| 66 |
|
| 67 |
# Tool base class
|
src/reachy_mini_conversation_app/tools/query_health_history.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Query health history tool.
|
| 2 |
+
|
| 3 |
+
Allows users to ask natural language questions about their health history
|
| 4 |
+
via voice. Uses GraphQueryEngine for LLM-generated Cypher queries against
|
| 5 |
+
the Neo4j knowledge graph.
|
| 6 |
+
|
| 7 |
+
Examples:
|
| 8 |
+
- "How many headaches did I have this week?"
|
| 9 |
+
- "What medications am I taking?"
|
| 10 |
+
- "When did I last see my doctor?"
|
| 11 |
+
- "Show me patterns in my symptoms"
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Any, Dict, Optional
|
| 18 |
+
|
| 19 |
+
from reachy_mini_conversation_app.tools.core_tools import Tool, ToolDependencies
|
| 20 |
+
from reachy_mini_conversation_app.stream_api import emit_ui_component
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class QueryHealthHistoryTool(Tool):
    """Answer natural language questions about health history via Neo4j graph."""

    name = "query_health_history"
    description = (
        "Answer questions about the user's health history, medications, symptoms, "
        "and patterns using the knowledge graph. Use when the user asks things like "
        "'How many headaches did I have this week?', 'When did I last see Dr Patel?', "
        "'What medications am I taking?', 'Show me patterns in my symptoms', "
        "'Have my migraines been getting worse?', etc."
    )
    parameters_schema = {
        "type": "object",
        "properties": {
            "question": {
                "type": "string",
                "description": "The health history question to answer",
            },
        },
        "required": ["question"],
    }

    async def __call__(
        self,
        deps: ToolDependencies,
        question: str = "",
    ) -> Dict[str, Any]:
        """Query the Neo4j knowledge graph with a natural language question.

        Args:
            deps: Shared tool dependencies (graph memory, database, ...).
            question: The user's natural-language health question.

        Returns dict with:
            - answer: str (natural language, suitable for spoken response)
            - result_count: int (absent on the early "unavailable" returns)
            - source: str ("graph" or "unavailable")
        """
        if not question:
            return {
                "answer": "I didn't catch a question. Could you ask me something specific about your health history?",
                "source": "unavailable",
            }

        # Check graph availability
        graph = deps.graph_memory
        if not graph or not graph.is_connected:
            return {
                "answer": (
                    "I don't have access to your health history graph right now. "
                    "I can still help with things I remember from our conversations."
                ),
                "source": "unavailable",
            }

        # Get patient name for query parameterisation
        patient_name = "Patient"
        if deps.database:
            profile = deps.database.get_or_create_profile()
            patient_name = profile.get("display_name") or profile.get("name", "Patient")

        # PII guard: redact before sending to LLM. Keep a reference to the
        # guard so hydration below reuses the exact instance that produced
        # the mapping (a fresh guard might not recognise the placeholders).
        redacted_question = question
        pii_mapping: Dict[str, str] = {}
        guard: Optional[Any] = None
        try:
            from reachy_mini_conversation_app.pii_guard import get_pii_guard

            guard = get_pii_guard()
            if guard:
                redacted_question, pii_mapping = guard.redact(question)
                if pii_mapping:
                    logger.debug(
                        "PII redacted from question: %s", list(pii_mapping.keys())
                    )
        except Exception as e:
            logger.debug("PII guard not available: %s", e)

        # Run the query engine
        from reachy_mini_conversation_app.graph_query_engine import GraphQueryEngine

        engine = GraphQueryEngine(graph)
        result = await engine.query(redacted_question, patient_name)

        # PII guard: hydrate the answer using the same guard that redacted it
        answer = result.get("answer", "I couldn't find that information.")
        if pii_mapping and guard:
            try:
                answer = guard.hydrate(answer, pii_mapping)
            except Exception:
                pass  # Answer is still useful without hydration

        # Emit a GenUI component if we got interesting results
        result_count = result.get("result_count", 0)
        if result_count > 0:
            try:
                emit_ui_component(
                    "InsightCard",
                    {
                        "title": "Health Query",
                        "summary": answer,
                        "detail": f"Based on {result_count} record(s) in your health graph.",
                        "source": "graph_query",
                    },
                )
            except Exception as e:
                logger.debug("InsightCard emission failed: %s", e)

        return {
            "answer": answer,
            "result_count": result_count,
            "source": "graph",
        }
|
tests/test_graph_query_engine.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the GraphQueryEngine — Cypher validation and safety guardrails."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 5 |
+
|
| 6 |
+
from reachy_mini_conversation_app.graph_query_engine import (
|
| 7 |
+
validate_cypher,
|
| 8 |
+
GraphQueryEngine,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestCypherValidation:
|
| 13 |
+
"""Test the Cypher safety validation function."""
|
| 14 |
+
|
| 15 |
+
def test_rejects_create(self):
|
| 16 |
+
assert validate_cypher("CREATE (n:Person {name: 'test'})") is False
|
| 17 |
+
|
| 18 |
+
def test_rejects_merge(self):
|
| 19 |
+
assert validate_cypher("MERGE (n:Person {name: 'test'})") is False
|
| 20 |
+
|
| 21 |
+
def test_rejects_delete(self):
|
| 22 |
+
assert validate_cypher("MATCH (n) DELETE n") is False
|
| 23 |
+
|
| 24 |
+
def test_rejects_detach_delete(self):
|
| 25 |
+
assert validate_cypher("MATCH (n) DETACH DELETE n") is False
|
| 26 |
+
|
| 27 |
+
def test_rejects_set(self):
|
| 28 |
+
assert validate_cypher("MATCH (n) SET n.name = 'test'") is False
|
| 29 |
+
|
| 30 |
+
def test_rejects_remove(self):
|
| 31 |
+
assert validate_cypher("MATCH (n) REMOVE n.name") is False
|
| 32 |
+
|
| 33 |
+
def test_rejects_drop(self):
|
| 34 |
+
assert validate_cypher("DROP INDEX ON :Person(name)") is False
|
| 35 |
+
|
| 36 |
+
def test_rejects_call(self):
|
| 37 |
+
assert validate_cypher("CALL db.labels()") is False
|
| 38 |
+
|
| 39 |
+
def test_rejects_load_csv(self):
|
| 40 |
+
assert validate_cypher("LOAD CSV FROM 'file:///data.csv' AS line") is False
|
| 41 |
+
|
| 42 |
+
def test_rejects_foreach(self):
|
| 43 |
+
assert (
|
| 44 |
+
validate_cypher("MATCH (n) FOREACH (x IN [1,2] | SET n.val = x)") is False
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def test_rejects_empty_string(self):
|
| 48 |
+
assert validate_cypher("") is False
|
| 49 |
+
|
| 50 |
+
def test_rejects_whitespace_only(self):
|
| 51 |
+
assert validate_cypher(" ") is False
|
| 52 |
+
|
| 53 |
+
def test_allows_simple_match(self):
|
| 54 |
+
assert validate_cypher("MATCH (n) RETURN n") is True
|
| 55 |
+
|
| 56 |
+
def test_allows_match_with_where(self):
|
| 57 |
+
assert (
|
| 58 |
+
validate_cypher(
|
| 59 |
+
"MATCH (p:Person {name: $patient_name})-[:TAKES]->(m:Medication) "
|
| 60 |
+
"WHERE m.dose IS NOT NULL RETURN m.name, m.dose"
|
| 61 |
+
)
|
| 62 |
+
is True
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def test_allows_optional_match(self):
|
| 66 |
+
assert (
|
| 67 |
+
validate_cypher(
|
| 68 |
+
"OPTIONAL MATCH (p:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e)"
|
| 69 |
+
)
|
| 70 |
+
is True
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def test_allows_with_clause(self):
|
| 74 |
+
assert (
|
| 75 |
+
validate_cypher(
|
| 76 |
+
"MATCH (p:Person {name: $patient_name}) "
|
| 77 |
+
"WITH p "
|
| 78 |
+
"MATCH (p)-[:TAKES]->(m:Medication) "
|
| 79 |
+
"RETURN m.name"
|
| 80 |
+
)
|
| 81 |
+
is True
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
def test_allows_aggregation(self):
|
| 85 |
+
assert (
|
| 86 |
+
validate_cypher(
|
| 87 |
+
"MATCH (p:Person)-[:EXPERIENCED]->(e:Event {type: 'headache'}) "
|
| 88 |
+
"RETURN count(e) AS headache_count, avg(e.severity) AS avg_severity"
|
| 89 |
+
)
|
| 90 |
+
is True
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def test_allows_order_by_limit(self):
|
| 94 |
+
assert (
|
| 95 |
+
validate_cypher(
|
| 96 |
+
"MATCH (p:Person)-[:EXPERIENCED]->(e:Event) "
|
| 97 |
+
"RETURN e.type, e.timestamp "
|
| 98 |
+
"ORDER BY e.timestamp DESC LIMIT 10"
|
| 99 |
+
)
|
| 100 |
+
is True
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def test_allows_unwind(self):
|
| 104 |
+
assert (
|
| 105 |
+
validate_cypher(
|
| 106 |
+
"UNWIND $names AS name " "MATCH (p:Person {name: name}) RETURN p"
|
| 107 |
+
)
|
| 108 |
+
is True
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def test_rejects_create_case_insensitive(self):
    """Mutations should be caught regardless of case."""
    for query in (
        "match (n) Create (m:Test)",
        "MATCH (n) create (m:Test)",
        "match (n) CREATE (m:Test)",
    ):
        assert validate_cypher(query) is False
|
| 116 |
+
|
| 117 |
+
def test_rejects_query_not_starting_with_match(self):
    """Queries must start with MATCH, WITH, RETURN, or UNWIND."""
    allowed = validate_cypher("RETURN 1")
    disallowed = validate_cypher("EXPLAIN MATCH (n) RETURN n")
    assert allowed is True
    assert disallowed is False
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class TestGraphQueryEngineSchemaDescription:
    """Test schema description from GraphMemory."""

    def test_fallback_schema_when_not_connected(self):
        """Schema description should return a fallback when not connected."""
        graph = MagicMock()
        graph.is_connected = False
        graph.get_schema_description.return_value = (
            "## Neo4j Graph Schema\n\n### Node Labels\n (:Person)"
        )

        schema = GraphQueryEngine(graph)._get_schema()

        assert "Neo4j Graph Schema" in schema

    def test_schema_caching(self):
        """Schema should only be fetched once."""
        graph = MagicMock()
        graph.get_schema_description.return_value = "cached schema"
        engine = GraphQueryEngine(graph)

        engine._get_schema()
        engine._get_schema()

        # The second call must be served from the cache, not the graph.
        graph.get_schema_description.assert_called_once()

    def test_cache_invalidation(self):
        """After invalidation, schema should be re-fetched."""
        graph = MagicMock()
        graph.get_schema_description.return_value = "schema v1"
        engine = GraphQueryEngine(graph)

        assert engine._get_schema() == "schema v1"

        # Change the backing schema, invalidate, and expect a fresh read.
        graph.get_schema_description.return_value = "schema v2"
        engine.invalidate_schema_cache()
        assert engine._get_schema() == "schema v2"
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@pytest.mark.asyncio
class TestGraphQueryEngineExecution:
    """Test query execution path (mocked)."""

    async def test_execute_uses_read_session(self):
        """Ensure execute() calls execute_read, not _execute."""
        graph = MagicMock()
        graph.is_connected = True
        graph.execute_read.return_value = [{"count": 3}]

        engine = GraphQueryEngine(graph)
        rows = await engine.execute(
            "MATCH (n:Person)-[:EXPERIENCED]->(e:Event) RETURN count(e) AS count",
            patient_name="Elena",
        )

        graph.execute_read.assert_called_once()
        assert rows == [{"count": 3}]

    async def test_execute_returns_empty_when_disconnected(self):
        """Ensure execute() returns empty list when graph is not connected."""
        graph = MagicMock()
        graph.is_connected = False

        rows = await GraphQueryEngine(graph).execute("MATCH (n) RETURN n")

        assert rows == []
        graph.execute_read.assert_not_called()

    async def test_query_handles_generation_error_gracefully(self):
        """Full query() should return a friendly error message on failure."""
        graph = MagicMock()
        graph.is_connected = True
        graph.get_schema_description.return_value = "test schema"
        engine = GraphQueryEngine(graph)

        # Make Cypher generation fail inside the mocked OpenAI client.
        with patch.object(engine, "_get_client") as client:
            client.return_value.chat.completions.create = AsyncMock(
                side_effect=Exception("API error")
            )
            result = await engine.query("How many headaches?", "Elena")

        assert "error" in result
        assert "answer" in result
        assert result["result_count"] == 0
|
tests/test_pattern_detector.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the PatternDetector and Insight formatting."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from unittest.mock import MagicMock
|
| 5 |
+
|
| 6 |
+
from reachy_mini_conversation_app.pattern_detector import (
|
| 7 |
+
PatternDetector,
|
| 8 |
+
Insight,
|
| 9 |
+
format_insights_for_prompt,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestInsight:
    """Test the Insight dataclass."""

    def test_to_dict(self):
        insight = Insight(
            pattern_type="correlation",
            summary="Test summary",
            detail="Test detail",
            confidence=0.75,
            entities=["med1", "headache"],
            period_days=30,
        )

        serialized = insight.to_dict()

        assert serialized["pattern_type"] == "correlation"
        assert serialized["confidence"] == 0.75
        assert "med1" in serialized["entities"]

    def test_default_entities_and_period(self):
        # Only the required fields: defaults should fill in the rest.
        insight = Insight(pattern_type="test", summary="s", detail="d", confidence=0.5)

        assert insight.entities == []
        assert insight.period_days == 30
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestFormatInsightsForPrompt:
    """Test formatting insights for system prompt injection."""

    def test_empty_insights_returns_empty(self):
        assert format_insights_for_prompt([]) == ""

    def test_single_insight_format(self):
        rendered = format_insights_for_prompt(
            [
                Insight(
                    pattern_type="correlation",
                    summary="Headache appeared on 5 days.",
                    detail="detail",
                    confidence=0.8,
                )
            ]
        )

        assert "Recent Health Insights" in rendered
        assert "Correlation" in rendered
        assert "80%" in rendered
        assert "Headache appeared on 5 days." in rendered

    def test_multiple_insights_numbered(self):
        first = Insight(
            pattern_type="correlation", summary="s1", detail="d1", confidence=0.9
        )
        second = Insight(
            pattern_type="frequency_change", summary="s2", detail="d2", confidence=0.7
        )

        rendered = format_insights_for_prompt([first, second])

        assert "1." in rendered
        assert "2." in rendered

    def test_observational_language_guidance(self):
        """Prompt should instruct the model to use observational language."""
        rendered = format_insights_for_prompt(
            [Insight(pattern_type="test", summary="s", detail="d", confidence=0.5)]
        )
        assert "observational" in rendered.lower() or "I noticed" in rendered
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TestPatternDetectorRunAnalysis:
    """Test the run_analysis orchestration."""

    @staticmethod
    def _connected_graph(read_results):
        """Build a connected GraphMemory mock whose execute_read yields *read_results* in order."""
        graph = MagicMock()
        graph.is_connected = True
        graph.execute_read.side_effect = read_results
        return graph

    def test_returns_empty_when_disconnected(self):
        graph = MagicMock()
        graph.is_connected = False

        assert PatternDetector(graph).run_analysis("Elena", days=30) == []

    def test_returns_empty_when_no_graph(self):
        assert PatternDetector(None).run_analysis("Elena", days=30) == []

    def test_continues_on_individual_detector_failure(self):
        """If one detector fails, the others should still run."""
        graph = self._connected_graph(
            [
                Exception("Neo4j error"),  # medication_symptom_correlation
                [],  # frequency_changes headache
                [],  # frequency_changes migraine
                [],  # frequency_changes confusion
                [],  # missed_medication_impact
                [],  # temporal_patterns
            ]
        )

        # Should not raise, just log warnings.
        insights = PatternDetector(graph).run_analysis("Elena", days=30)
        assert isinstance(insights, list)

    def test_sorts_by_confidence_descending(self):
        """Insights should be sorted by confidence (highest first)."""
        correlations = [
            {
                "medication": "Med A",
                "symptom": "headache",
                "co_occurrence_count": 10,
                "distinct_days": 8,
            },
            {
                "medication": "Med B",
                "symptom": "fatigue",
                "co_occurrence_count": 3,
                "distinct_days": 3,
            },
        ]
        # First read feeds medication_symptom_correlation; the rest are empty.
        graph = self._connected_graph([correlations, [], [], [], [], []])

        insights = PatternDetector(graph).run_analysis("Elena", days=30)

        if len(insights) >= 2:
            assert insights[0].confidence >= insights[1].confidence

    def test_caps_at_five_insights(self):
        """Should return at most 5 insights."""
        correlations = [
            {
                "medication": f"Med{i}",
                "symptom": f"sym{i}",
                "co_occurrence_count": 5,
                "distinct_days": 5,
            }
            for i in range(10)
        ]
        graph = self._connected_graph([correlations, [], [], [], [], []])

        insights = PatternDetector(graph).run_analysis("Elena", days=30)
        assert len(insights) <= 5
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class TestPatternDetectorInsightLanguage:
    """Verify insights never use causal language."""

    def test_correlation_summary_is_neutral(self):
        graph = MagicMock()
        graph.is_connected = True
        graph.execute_read.side_effect = [
            [
                {
                    "medication": "Topiramate",
                    "symptom": "headache",
                    "co_occurrence_count": 5,
                    "distinct_days": 5,
                }
            ],
            [],
            [],
            [],
            [],
            [],
        ]

        insights = PatternDetector(graph).run_analysis("Elena", days=30)

        # Correlation is not causation: summaries must stay observational.
        for insight in insights:
            text = insight.summary.lower()
            for banned in ("caused", "triggered", "because"):
                assert banned not in text, f"Causal language in: {insight.summary}"
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class TestPatternDetectorFrequencyChanges:
    """Test the frequency change detector."""

    @staticmethod
    def _graph_with_counts(rows):
        """Connected GraphMemory mock whose execute_read always returns *rows*."""
        graph = MagicMock()
        graph.is_connected = True
        graph.execute_read.return_value = rows
        return graph

    def test_detects_increase(self):
        graph = self._graph_with_counts(
            [
                {"period": "prior", "event_count": 2},
                {"period": "recent", "event_count": 6},
            ]
        )

        insights = PatternDetector(graph).detect_frequency_changes(
            "Elena", "headache", days=30
        )

        assert len(insights) == 1
        assert "increased" in insights[0].summary.lower()

    def test_detects_decrease(self):
        graph = self._graph_with_counts(
            [
                {"period": "prior", "event_count": 10},
                {"period": "recent", "event_count": 3},
            ]
        )

        insights = PatternDetector(graph).detect_frequency_changes(
            "Elena", "headache", days=30
        )

        assert len(insights) == 1
        assert "decreased" in insights[0].summary.lower()

    def test_ignores_small_changes(self):
        """Changes under 25% should not generate insights."""
        graph = self._graph_with_counts(
            [
                {"period": "prior", "event_count": 10},
                {"period": "recent", "event_count": 11},
            ]
        )

        insights = PatternDetector(graph).detect_frequency_changes(
            "Elena", "headache", days=30
        )
        assert len(insights) == 0

    def test_handles_insufficient_data(self):
        """Should return nothing with fewer than MIN_SAMPLE_SIZE events."""
        graph = self._graph_with_counts([{"period": "recent", "event_count": 1}])

        insights = PatternDetector(graph).detect_frequency_changes(
            "Elena", "headache", days=30
        )
        assert len(insights) == 0
|