from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import BaseMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration
import re
import os
from dotenv import load_dotenv
# Load variables from a .env file into the process environment so the
# os.getenv() lookups further down can see them.
load_dotenv()
# Module-level caches for the Neo4j graph wrapper and the QA chain.
# Lazily initialised — Neo4j may not be ready at import time
_graph = None
_graph_chain = None
def _strip_thinking(text: str) -> str:
text = re.sub(r".*?", "", text, flags=re.DOTALL | re.IGNORECASE)
return text.strip()
class _ThinkStrippedLLM(ChatOpenAI):
    """ChatOpenAI subclass whose generations are cleaned by ``_strip_thinking``."""

    def _create_chat_result(self, response, generation_info=None) -> ChatResult:
        """Build the parent's chat result, then rebuild each generation with cleaned text.

        Each generation is replaced by a new ``ChatGeneration`` wrapping a
        fresh ``AIMessage``. Only the message ``content`` is carried over into
        that new message; ``generation_info`` and ``llm_output`` are forwarded
        unchanged.
        """
        upstream: ChatResult = super()._create_chat_result(response, generation_info)
        rebuilt = [
            ChatGeneration(
                # A missing/empty content is treated as the empty string.
                message=AIMessage(content=_strip_thinking(item.message.content or "")),
                generation_info=item.generation_info,
            )
            for item in upstream.generations
        ]
        return ChatResult(generations=rebuilt, llm_output=upstream.llm_output)
def _get_llm():
    """Create a ``_ThinkStrippedLLM`` configured from environment variables.

    Reads ``OPENAI_MODEL`` (defaulting to ``"qwen/qwen3-32b"``),
    ``OPENAI_API_KEY`` and ``OPENAI_BASE_URL``; temperature is fixed at 0.
    A new instance is returned on every call.
    """
    settings = {
        "model": os.getenv("OPENAI_MODEL", "qwen/qwen3-32b"),
        "openai_api_key": os.getenv("OPENAI_API_KEY"),
        "openai_api_base": os.getenv("OPENAI_BASE_URL"),
        "temperature": 0,
    }
    return _ThinkStrippedLLM(**settings)
def _get_graph():
    """Return the cached ``Neo4jGraph``, constructing it on the first call.

    Connection settings come from ``NEO4J_URI``, ``NEO4J_USERNAME``,
    ``NEO4J_PASSWORD`` and ``NEO4J_DATABASE``. Because the fallbacks are
    applied with ``or``, an unset *or empty* variable uses the default.
    """
    global _graph
    if _graph is not None:
        return _graph

    # NOTE(review): a default password is embedded in source here.
    connection = {
        "url": os.getenv("NEO4J_URI") or "bolt://127.0.0.1:7687",
        "username": os.getenv("NEO4J_USERNAME") or "neo4j",
        "password": os.getenv("NEO4J_PASSWORD") or "clinicalmatch2024",
        "database": os.getenv("NEO4J_DATABASE") or "neo4j",
    }
    _graph = Neo4jGraph(**connection)
    return _graph
# Prompt text handed to the Cypher-generating LLM. It has two placeholders,
# {schema} and {question}; every other brace in the text is doubled ({{ }})
# so it is emitted literally rather than treated as a template variable.
_CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.
Schema:
{schema}
Node labels and their exact property names:
- Patient: id (e.g. "P_C50_000001"), name, age (integer), sex ("MALE"/"FEMALE"), ecog (integer 0-3),
condition (lowercase, e.g. "breast cancer"), stage ("I"/"II"/"III"/"IV"),
city, state, ethnicity, insurance, icd10_prefix,
biomarkers (list of biomarker ids), medications (list of drug names),
comorbidities (list), prior_chemo (boolean), prior_radiation (boolean),
prior_surgery (boolean), prior_lines_of_therapy (integer), source
- Trial: id (NCT id, e.g. "NCT04567890"), title, condition (lowercase), phase, status,
brief_summary, eligibility_criteria, min_age, max_age, sex, enrollment,
start_date, completion_date, sponsor, location_count, source
- Diagnosis: code (ICD-10, e.g. "C50.919"), name (e.g. "Malignant neoplasm of breast"), source
- Biomarker: id (e.g. "HER2_POS"), name (e.g. "HER2 Positive"), gene (e.g. "ERBB2"), loinc, source
- Medication: rxcui, name, tty, generic_name, source
- StudySite: facility, city, state, country, lat, lon, source
- ConditionNode: name (e.g. "breast cancer")
- Publication: pmid, title, journal, pub_date, authors, source
Relationships:
- (Patient)-[:ELIGIBLE_FOR {{score: float, matched_at: datetime}}]->(Trial)
- (Patient)-[:HAS_DIAGNOSIS]->(Diagnosis)
- (Patient)-[:HAS_BIOMARKER]->(Biomarker)
- (Trial)-[:CONDUCTED_AT]->(StudySite)
- (ConditionNode)-[:HAS_TRIAL]->(Trial)
- (Diagnosis)-[:MAPS_TO_CONDITION]->(ConditionNode)
- (Biomarker)-[:RELEVANT_TO]->(ConditionNode)
- (Biomarker)-[:MAY_QUALIFY_FOR]->(Trial)
- (Publication)-[:SUPPORTS_RESEARCH_ON]->(ConditionNode)
Rules:
- Biomarker lookups use the `id` property: `{{id: 'HER2_POS'}}`
- Diagnosis lookups use `code` (not `id`): `{{code: 'C50.919'}}`
- Medication lookups use `rxcui` or `name` (not `id`)
- Condition lookups on Trial nodes use lowercase: `t.condition = 'breast cancer'`
- Patient-to-trial eligibility: `(p:Patient)-[:ELIGIBLE_FOR]->(t:Trial)`
- ecog property on Patient is `ecog` (integer), NOT `ecog_score`
- Limit results to 25 unless asked for more
Question: {question}
Cypher query:"""
# Prompt object built from the template above; passed to the QA chain as its
# cypher_prompt.
_CYPHER_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
template=_CYPHER_GENERATION_TEMPLATE,
)
def _get_chain():
    """Return the cached ``GraphCypherQAChain``, building it on the first call.

    The chain is assembled from ``_get_llm()``, ``_get_graph()`` and
    ``_CYPHER_PROMPT``, with verbose output enabled and
    ``allow_dangerous_requests=True``.
    """
    global _graph_chain
    if _graph_chain is not None:
        return _graph_chain

    chain_options = {
        "llm": _get_llm(),
        "graph": _get_graph(),
        "verbose": True,
        "allow_dangerous_requests": True,
        "cypher_prompt": _CYPHER_PROMPT,
    }
    _graph_chain = GraphCypherQAChain.from_llm(**chain_options)
    return _graph_chain
def retrieve_patient_trial_matches(patient_id: str) -> list:
    """Look up trials linked to a patient through their diagnoses.

    The patient id is sent to Neo4j as a query parameter instead of being
    spliced into the Cypher text, so ids containing quotes or Cypher syntax
    cannot alter the query.

    Args:
        patient_id: Value of the ``id`` property on the ``Patient`` node.

    Returns:
        The rows returned by the graph query (keys ``patient``, ``diagnosis``,
        ``trial``, ``phase``, ``condition``), or an empty list if the query
        raised. Errors are printed rather than propagated.
    """
    # NOTE(review): this pattern walks Diagnosis-[:ELIGIBLE_FOR]->Trial, while
    # the schema described in _CYPHER_GENERATION_TEMPLATE lists ELIGIBLE_FOR
    # as Patient->Trial. Confirm against the real graph which one is intended.
    cypher = """
        MATCH (p:Patient {id: $patient_id})-[:HAS_DIAGNOSIS]->(d:Diagnosis)-[:ELIGIBLE_FOR]->(t:Trial)
        RETURN p.id as patient, d.name as diagnosis, t.id as trial, t.phase as phase, t.condition as condition
        """
    try:
        return _get_graph().query(cypher, params={"patient_id": patient_id})
    except Exception as e:
        # Best-effort lookup: report the failure and hand back no matches.
        print(f"[graphrag] query error: {e}")
        return []
def rag_query(question: str) -> str:
    """Answer a natural-language question using the graph QA chain.

    Args:
        question: The user's question, passed to the chain as-is.

    Returns:
        The chain's answer with reasoning blocks stripped. If the chain
        returns nothing, or nothing is left after stripping, the fixed text
        ``"No results found."`` is returned. If the chain raises, a
        user-facing message is returned instead of propagating the exception:
        a "rephrase" hint when the error text mentions a ``<think>`` tag or a
        ``SyntaxError``, otherwise ``"Graph query error: <details>"``.
    """
    try:
        result = _get_chain().run(question)
        answer = _strip_thinking(result) if result else ""
        return answer or "No results found."
    except Exception as e:
        err = str(e)
        # Leaked reasoning tags or a Cypher syntax error both indicate the
        # model produced something that was not a clean query.
        if "<think>" in err.lower() or "SyntaxError" in err:
            return "The query model returned unexpected output. Please rephrase your question."
        return f"Graph query error: {err}"
def get_graph_stats() -> dict:
    """Return node counts for patients, trials and diagnoses plus a status.

    Returns:
        A dict with integer ``patients``, ``trials`` and ``diagnoses`` counts
        and a ``status`` string. ``status`` is ``"connected"`` on success;
        on any failure (including the connection module failing to import)
        the counts are 0 and ``status`` holds the error text.
    """
    counts = {"patients": 0, "trials": 0, "diagnoses": 0}
    try:
        # Imported here so an import failure is reported through ``status``.
        from neo4j_setup import neo4j_conn

        # OPTIONAL MATCH keeps the single carried row alive when a label has
        # no nodes; a plain MATCH would yield zero rows and lose the counts
        # already computed.
        rows = neo4j_conn.run_query("""
            MATCH (p:Patient) WITH count(p) as patients
            OPTIONAL MATCH (t:Trial) WITH patients, count(t) as trials
            OPTIONAL MATCH (d:Diagnosis) WITH patients, trials, count(d) as diagnoses
            RETURN patients, trials, diagnoses
            """)
        first = rows[0] if rows else {}
        return {**counts, **first, "status": "connected"}
    except Exception as e:
        return {**counts, "status": str(e)}