Spaces:
Running
Running
File size: 5,949 Bytes
59abb4f bfeb61b 59abb4f bfeb61b bf49c73 bfeb61b 59abb4f b40cc1f 59abb4f b40cc1f 59abb4f b40cc1f 59abb4f b40cc1f 59abb4f bfeb61b 59abb4f bfeb61b 59abb4f bfeb61b 59abb4f bfeb61b 59abb4f 1428c0f 59abb4f 1428c0f 59abb4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import BaseMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration
import re
import os
from dotenv import load_dotenv
load_dotenv()
# Lazily initialised — Neo4j may not be ready at import time
_graph = None
_graph_chain = None
def _strip_thinking(text: str) -> str:
text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL | re.IGNORECASE)
return text.strip()
class _ThinkStrippedLLM(ChatOpenAI):
    """ChatOpenAI variant that scrubs <think> blocks from every generation.

    Reasoning models (e.g. qwen3) interleave chain-of-thought markup with the
    answer; downstream Cypher parsing chokes on it, so it is removed at the
    chat-result boundary.
    """

    def _create_chat_result(self, response, generation_info=None) -> ChatResult:
        original: ChatResult = super()._create_chat_result(response, generation_info)
        scrubbed = [
            ChatGeneration(
                message=AIMessage(content=_strip_thinking(gen.message.content or "")),
                generation_info=gen.generation_info,
            )
            for gen in original.generations
        ]
        return ChatResult(generations=scrubbed, llm_output=original.llm_output)
def _get_llm():
    """Build a think-stripping chat model from environment configuration.

    Temperature is pinned to 0 because the model's output is executed as
    Cypher — determinism matters more than variety here.
    """
    model_name = os.getenv("OPENAI_MODEL", "qwen/qwen3-32b")
    return _ThinkStrippedLLM(
        model=model_name,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        openai_api_base=os.getenv("OPENAI_BASE_URL"),
        temperature=0,
    )
def _get_graph():
    """Return the shared Neo4jGraph, connecting on first use.

    Lazy singleton: Neo4j may not be reachable at import time, so the
    connection is deferred until the first caller actually needs it.
    """
    global _graph
    if _graph is not None:
        return _graph
    _graph = Neo4jGraph(
        url=os.getenv("NEO4J_URI") or "bolt://127.0.0.1:7687",
        username=os.getenv("NEO4J_USERNAME") or "neo4j",
        password=os.getenv("NEO4J_PASSWORD") or "clinicalmatch2024",
        database=os.getenv("NEO4J_DATABASE") or "neo4j",
    )
    return _graph
# Prompt steering Cypher generation for GraphCypherQAChain. Doubled braces
# ({{...}}) are literal braces after PromptTemplate substitution; only
# {schema} and {question} are template variables (see input_variables below).
# The property-name rules exist because the LLM otherwise invents names like
# `ecog_score` or looks up Diagnosis by `id` instead of `code`.
_CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.
Schema:
{schema}
Node labels and their exact property names:
- Patient: id (e.g. "P_C50_000001"), name, age (integer), sex ("MALE"/"FEMALE"), ecog (integer 0-3),
condition (lowercase, e.g. "breast cancer"), stage ("I"/"II"/"III"/"IV"),
city, state, ethnicity, insurance, icd10_prefix,
biomarkers (list of biomarker ids), medications (list of drug names),
comorbidities (list), prior_chemo (boolean), prior_radiation (boolean),
prior_surgery (boolean), prior_lines_of_therapy (integer), source
- Trial: id (NCT id, e.g. "NCT04567890"), title, condition (lowercase), phase, status,
brief_summary, eligibility_criteria, min_age, max_age, sex, enrollment,
start_date, completion_date, sponsor, location_count, source
- Diagnosis: code (ICD-10, e.g. "C50.919"), name (e.g. "Malignant neoplasm of breast"), source
- Biomarker: id (e.g. "HER2_POS"), name (e.g. "HER2 Positive"), gene (e.g. "ERBB2"), loinc, source
- Medication: rxcui, name, tty, generic_name, source
- StudySite: facility, city, state, country, lat, lon, source
- ConditionNode: name (e.g. "breast cancer")
- Publication: pmid, title, journal, pub_date, authors, source
Relationships:
- (Patient)-[:ELIGIBLE_FOR {{score: float, matched_at: datetime}}]->(Trial)
- (Patient)-[:HAS_DIAGNOSIS]->(Diagnosis)
- (Patient)-[:HAS_BIOMARKER]->(Biomarker)
- (Trial)-[:CONDUCTED_AT]->(StudySite)
- (ConditionNode)-[:HAS_TRIAL]->(Trial)
- (Diagnosis)-[:MAPS_TO_CONDITION]->(ConditionNode)
- (Biomarker)-[:RELEVANT_TO]->(ConditionNode)
- (Biomarker)-[:MAY_QUALIFY_FOR]->(Trial)
- (Publication)-[:SUPPORTS_RESEARCH_ON]->(ConditionNode)
Rules:
- Biomarker lookups use the `id` property: `{{id: 'HER2_POS'}}`
- Diagnosis lookups use `code` (not `id`): `{{code: 'C50.919'}}`
- Medication lookups use `rxcui` or `name` (not `id`)
- Condition lookups on Trial nodes use lowercase: `t.condition = 'breast cancer'`
- Patient-to-trial eligibility: `(p:Patient)-[:ELIGIBLE_FOR]->(t:Trial)`
- ecog property on Patient is `ecog` (integer), NOT `ecog_score`
- Limit results to 25 unless asked for more
Question: {question}
Cypher query:"""
# Wraps the template so GraphCypherQAChain can inject the live graph schema
# and the user's question before sending it to the LLM.
_CYPHER_PROMPT = PromptTemplate(
    input_variables=["schema", "question"],
    template=_CYPHER_GENERATION_TEMPLATE,
)
def _get_chain():
    """Return the shared GraphCypherQAChain, building it on first use.

    Lazy singleton mirroring _get_graph(): construction touches both the
    LLM endpoint and Neo4j, neither of which may be up at import time.
    """
    global _graph_chain
    if _graph_chain is not None:
        return _graph_chain
    _graph_chain = GraphCypherQAChain.from_llm(
        llm=_get_llm(),
        graph=_get_graph(),
        verbose=True,
        # Required opt-in: the chain executes LLM-generated Cypher.
        allow_dangerous_requests=True,
        cypher_prompt=_CYPHER_PROMPT,
    )
    return _graph_chain
def retrieve_patient_trial_matches(patient_id: str) -> list:
    """Return trials the given patient is eligible for, with diagnosis context.

    Fixes two defects in the original:
    - The Cypher pattern routed ELIGIBLE_FOR from Diagnosis to Trial, but the
      schema (see _CYPHER_GENERATION_TEMPLATE) declares it as
      (Patient)-[:ELIGIBLE_FOR]->(Trial), so the old query could never match.
    - patient_id was f-string-interpolated into the query (Cypher injection
      risk); it is now passed as a bound parameter.

    The diagnosis is OPTIONAL MATCHed so eligible trials are still returned
    for patients without a recorded diagnosis (diagnosis comes back as null).

    Args:
        patient_id: Patient node id, e.g. "P_C50_000001".

    Returns:
        List of result dicts with keys patient/diagnosis/trial/phase/condition;
        [] on any query or connection error (best-effort, error is printed).
    """
    cypher = """
    MATCH (p:Patient {id: $patient_id})-[:ELIGIBLE_FOR]->(t:Trial)
    OPTIONAL MATCH (p)-[:HAS_DIAGNOSIS]->(d:Diagnosis)
    RETURN p.id as patient, d.name as diagnosis, t.id as trial,
           t.phase as phase, t.condition as condition
    """
    try:
        return _get_graph().query(cypher, params={"patient_id": patient_id})
    except Exception as e:
        print(f"[graphrag] query error: {e}")
        return []
def rag_query(question: str) -> str:
    """Answer a natural-language question over the graph via the QA chain.

    Returns a human-readable answer string; never raises. Known failure
    modes of the Cypher-generating model (leaked <think> markup, Cypher
    syntax errors) get a friendly retry message instead of the raw error.
    """
    try:
        answer = _get_chain().run(question)
        if not answer:
            return "No results found."
        return _strip_thinking(answer)
    except Exception as exc:
        detail = str(exc)
        model_noise = "<think>" in detail or "SyntaxError" in detail
        if model_noise:
            return "The query model returned unexpected output. Please rephrase your question."
        return f"Graph query error: {detail}"
def get_graph_stats() -> dict:
    """Count Patient/Trial/Diagnosis nodes and report connection health.

    Returns a dict with keys patients/trials/diagnoses plus "status":
    "connected" on success, or the error string (with zeroed counts) on
    failure. Never raises.
    """
    # Local import keeps module import free of a hard neo4j_setup dependency.
    from neo4j_setup import neo4j_conn

    try:
        rows = neo4j_conn.run_query("""
        MATCH (p:Patient) WITH count(p) as patients
        MATCH (t:Trial) WITH patients, count(t) as trials
        MATCH (d:Diagnosis) WITH patients, trials, count(d) as diagnoses
        RETURN patients, trials, diagnoses
        """)
    except Exception as exc:
        return {"patients": 0, "trials": 0, "diagnoses": 0, "status": str(exc)}
    first_row = rows[0] if rows else {}
    return {**first_row, "status": "connected"}
|