import os
import re

from dotenv import load_dotenv
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

load_dotenv()

# Lazily initialised — Neo4j may not be ready at import time
_graph = None
_graph_chain = None

# Reasoning models (e.g. the default qwen3 model below) can wrap their
# chain-of-thought in <think>...</think>; that text must never reach Neo4j
# or the user.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
_THINK_CLOSE_RE = re.compile(r"</think>", re.IGNORECASE)


def _strip_thinking(text: str) -> str:
    """Remove ``<think>...</think>`` reasoning from model output and trim it.

    Complete ``<think>`` blocks are deleted wherever they occur. If a stray
    closing ``</think>`` tag remains (the opening tag was not part of the
    generated text), everything up to and including the last closing tag is
    discarded as well.

    Args:
        text: Raw model output.

    Returns:
        The remaining text with surrounding whitespace stripped.
    """
    text = _THINK_BLOCK_RE.sub("", text)
    # Keep only what follows the last dangling </think>, if any.
    text = _THINK_CLOSE_RE.split(text)[-1]
    return text.strip()


class _ThinkStrippedLLM(ChatOpenAI):
    """``ChatOpenAI`` whose generations have reasoning blocks stripped.

    Cleaning happens at the chat-result level so that both the generated
    Cypher and the final natural-language answer produced through this model
    are free of ``<think>`` text.
    """

    def _create_chat_result(self, response, generation_info=None) -> ChatResult:
        result: ChatResult = super()._create_chat_result(response, generation_info)
        cleaned = []
        for gen in result.generations:
            message = gen.message
            raw = message.content or ""
            # Only plain-string content can be regex-cleaned; leave multi-part
            # (list) content untouched rather than crashing inside re.sub.
            clean = _strip_thinking(raw) if isinstance(raw, str) else raw
            cleaned.append(
                ChatGeneration(
                    message=AIMessage(
                        content=clean,
                        # Preserve provider metadata instead of dropping it.
                        additional_kwargs=message.additional_kwargs,
                        response_metadata=message.response_metadata,
                    ),
                    generation_info=gen.generation_info,
                )
            )
        return ChatResult(generations=cleaned, llm_output=result.llm_output)


def _get_llm():
    """Build the think-stripping chat model from ``OPENAI_*`` env variables.

    ``OPENAI_MODEL`` defaults to ``qwen/qwen3-32b``. Temperature 0 keeps
    Cypher generation as repeatable as possible.
    """
    return _ThinkStrippedLLM(
        model=os.getenv("OPENAI_MODEL", "qwen/qwen3-32b"),
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        openai_api_base=os.getenv("OPENAI_BASE_URL"),
        temperature=0,
    )


def _get_graph():
    """Return the shared ``Neo4jGraph``, connecting on first use.

    Connection settings come from ``NEO4J_URI``, ``NEO4J_USERNAME``,
    ``NEO4J_PASSWORD`` and ``NEO4J_DATABASE``; empty or unset values fall back
    to local development defaults.
    """
    global _graph
    if _graph is None:
        _graph = Neo4jGraph(
            url=os.getenv("NEO4J_URI") or "bolt://127.0.0.1:7687",
            username=os.getenv("NEO4J_USERNAME") or "neo4j",
            # NOTE(review): hard-coded fallback password is for local dev only;
            # production should always set NEO4J_PASSWORD.
            password=os.getenv("NEO4J_PASSWORD") or "clinicalmatch2024",
            database=os.getenv("NEO4J_DATABASE") or "neo4j",
        )
    return _graph


# Prompt for Cypher generation. Literal braces are doubled ({{ }}) because
# PromptTemplate uses f-string formatting for {schema} and {question}.
_CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.

Schema:
{schema}

Node labels and their exact property names:
- Patient: id (e.g. "P_C50_000001"), name, age (integer), sex ("MALE"/"FEMALE"), ecog (integer 0-3), condition (lowercase, e.g. "breast cancer"), stage ("I"/"II"/"III"/"IV"), city, state, ethnicity, insurance, icd10_prefix, biomarkers (list of biomarker ids), medications (list of drug names), comorbidities (list), prior_chemo (boolean), prior_radiation (boolean), prior_surgery (boolean), prior_lines_of_therapy (integer), source
- Trial: id (NCT id, e.g. "NCT04567890"), title, condition (lowercase), phase, status, brief_summary, eligibility_criteria, min_age, max_age, sex, enrollment, start_date, completion_date, sponsor, location_count, source
- Diagnosis: code (ICD-10, e.g. "C50.919"), name (e.g. "Malignant neoplasm of breast"), source
- Biomarker: id (e.g. "HER2_POS"), name (e.g. "HER2 Positive"), gene (e.g. "ERBB2"), loinc, source
- Medication: rxcui, name, tty, generic_name, source
- StudySite: facility, city, state, country, lat, lon, source
- ConditionNode: name (e.g. "breast cancer")
- Publication: pmid, title, journal, pub_date, authors, source

Relationships:
- (Patient)-[:ELIGIBLE_FOR {{score: float, matched_at: datetime}}]->(Trial)
- (Patient)-[:HAS_DIAGNOSIS]->(Diagnosis)
- (Patient)-[:HAS_BIOMARKER]->(Biomarker)
- (Trial)-[:CONDUCTED_AT]->(StudySite)
- (ConditionNode)-[:HAS_TRIAL]->(Trial)
- (Diagnosis)-[:MAPS_TO_CONDITION]->(ConditionNode)
- (Biomarker)-[:RELEVANT_TO]->(ConditionNode)
- (Biomarker)-[:MAY_QUALIFY_FOR]->(Trial)
- (Publication)-[:SUPPORTS_RESEARCH_ON]->(ConditionNode)

Rules:
- Biomarker lookups use the `id` property: `{{id: 'HER2_POS'}}`
- Diagnosis lookups use `code` (not `id`): `{{code: 'C50.919'}}`
- Medication lookups use `rxcui` or `name` (not `id`)
- Condition lookups on Trial nodes use lowercase: `t.condition = 'breast cancer'`
- Patient-to-trial eligibility: `(p:Patient)-[:ELIGIBLE_FOR]->(t:Trial)`
- ecog property on Patient is `ecog` (integer), NOT `ecog_score`
- Limit results to 25 unless asked for more

Question: {question}

Cypher query:"""

_CYPHER_PROMPT = PromptTemplate(
    input_variables=["schema", "question"],
    template=_CYPHER_GENERATION_TEMPLATE,
)
def _get_chain():
    """Return the shared ``GraphCypherQAChain``, building it on first use."""
    global _graph_chain
    if _graph_chain is None:
        _graph_chain = GraphCypherQAChain.from_llm(
            llm=_get_llm(),
            graph=_get_graph(),
            verbose=True,
            # Explicit opt-in: the chain executes LLM-generated Cypher against
            # the database. Prefer connecting with a read-only Neo4j user.
            allow_dangerous_requests=True,
            cypher_prompt=_CYPHER_PROMPT,
        )
    return _graph_chain


# ELIGIBLE_FOR links Patient -> Trial (see the relationship list in
# _CYPHER_GENERATION_TEMPLATE), so eligibility is matched from the patient,
# not from the Diagnosis node. The id is passed as a query parameter rather
# than interpolated, which prevents Cypher injection through patient_id.
_PATIENT_TRIAL_MATCHES_QUERY = """
MATCH (p:Patient {id: $patient_id})-[:HAS_DIAGNOSIS]->(d:Diagnosis)
MATCH (p)-[:ELIGIBLE_FOR]->(t:Trial)
RETURN p.id AS patient, d.name AS diagnosis, t.id AS trial,
       t.phase AS phase, t.condition AS condition
"""


def retrieve_patient_trial_matches(patient_id: str) -> list:
    """Return the trials a patient is eligible for, alongside their diagnoses.

    Args:
        patient_id: ``Patient.id`` value, e.g. ``"P_C50_000001"``.

    Returns:
        One dict per (diagnosis, eligible trial) pair, with keys ``patient``,
        ``diagnosis``, ``trial``, ``phase`` and ``condition``. Returns an empty
        list for an empty id or when the query fails (the error is printed).
    """
    if not patient_id:
        return []
    try:
        return _get_graph().query(
            _PATIENT_TRIAL_MATCHES_QUERY, params={"patient_id": patient_id}
        )
    except Exception as e:
        print(f"[graphrag] query error: {e}")
        return []


# Markers that mean the model's raw reasoning leaked into the generated query.
_THINK_TAGS = ("<think>", "</think>")


def rag_query(question: str) -> str:
    """Answer a natural-language question by generating and running Cypher.

    Returns:
        The chain's answer with any reasoning text removed, a fixed
        "No results found." message for an empty answer, or a user-facing
        error message. This function does not raise.
    """
    try:
        chain = _get_chain()
        outputs = chain.invoke({chain.input_key: question})
        result = outputs[chain.output_key]
        return _strip_thinking(result) if result else "No results found."
    except Exception as e:
        err = str(e)
        lowered = err.lower()
        # Leaked reasoning tags or invalid Cypher mean the *model output* was
        # unusable, as opposed to a database/connection failure.
        if any(tag in lowered for tag in _THINK_TAGS) or "SyntaxError" in err:
            return "The query model returned unexpected output. Please rephrase your question."
        return f"Graph query error: {err}"


def get_graph_stats() -> dict:
    """Return node counts for Patient, Trial and Diagnosis plus a status.

    Returns:
        ``{"patients", "trials", "diagnoses", "status"}``. ``status`` is
        ``"connected"`` on success or the error text on failure, in which case
        all counts are 0.
    """
    from neo4j_setup import neo4j_conn

    counts = {"patients": 0, "trials": 0, "diagnoses": 0}
    try:
        # OPTIONAL MATCH keeps exactly one row even when a label has no nodes;
        # with plain MATCH, a zero count for Trial or Diagnosis drops the row.
        result = neo4j_conn.run_query("""
            OPTIONAL MATCH (p:Patient) WITH count(p) AS patients
            OPTIONAL MATCH (t:Trial) WITH patients, count(t) AS trials
            OPTIONAL MATCH (d:Diagnosis) WITH patients, trials, count(d) AS diagnoses
            RETURN patients, trials, diagnoses
        """)
        return {**counts, **(result[0] if result else {}), "status": "connected"}
    except Exception as e:
        return {**counts, "status": str(e)}