Spaces:
Running
Running
| from langchain_community.graphs import Neo4jGraph | |
| from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.messages import BaseMessage, AIMessage | |
| from langchain_core.outputs import ChatResult, ChatGeneration | |
| import re | |
| import os | |
| from dotenv import load_dotenv | |
# Load variables from a .env file (if present) before the helpers below read os.environ.
load_dotenv()

# Module-level caches, filled on first use by _get_graph() and _get_chain().
# Lazily initialised — Neo4j may not be ready at import time
_graph = None  # cached Neo4jGraph connection
_graph_chain = None  # cached GraphCypherQAChain built on top of the graph
| def _strip_thinking(text: str) -> str: | |
| text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL | re.IGNORECASE) | |
| return text.strip() | |
class _ThinkStrippedLLM(ChatOpenAI):
    """ChatOpenAI subclass that strips ``<think>`` reasoning from each reply.

    The cleaning happens in ``_create_chat_result`` so that every consumer of
    the chat result sees text with the reasoning already removed.
    """

    def _create_chat_result(self, response, generation_info=None) -> ChatResult:
        """Build the chat result via the parent class, then clean each message.

        Args:
            response: Raw provider response, passed through to the parent.
            generation_info: Optional generation metadata, passed through.

        Returns:
            A ChatResult whose string message contents have had thinking
            blocks removed; ``llm_output`` is carried over unchanged.
        """
        result: ChatResult = super()._create_chat_result(response, generation_info)
        cleaned = []
        for gen in result.generations:
            content = gen.message.content
            if not isinstance(content, str):
                # Non-string content (e.g. a list of content blocks) cannot go
                # through the regex-based stripper; keep the generation as-is.
                cleaned.append(gen)
                continue
            # Copy the message instead of rebuilding a bare AIMessage so its
            # other fields (tool calls, usage/response metadata, id) survive.
            # pydantic v2 models expose model_copy(); v1 models only copy().
            copy_message = getattr(gen.message, "model_copy", None) or gen.message.copy
            message = copy_message(update={"content": _strip_thinking(content)})
            # Wrap in a fresh ChatGeneration rather than mutating the old one,
            # so anything it derives from the message reflects the cleaned text.
            cleaned.append(
                ChatGeneration(message=message, generation_info=gen.generation_info)
            )
        return ChatResult(generations=cleaned, llm_output=result.llm_output)
def _get_llm():
    """Create the think-stripping chat model configured from the environment.

    Reads ``OPENAI_MODEL`` (default ``"qwen/qwen3-32b"``), ``OPENAI_API_KEY``
    and ``OPENAI_BASE_URL``; the temperature is fixed at 0.

    Returns:
        A new ``_ThinkStrippedLLM`` instance (not cached).
    """
    model_name = os.getenv("OPENAI_MODEL", "qwen/qwen3-32b")
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_BASE_URL")
    return _ThinkStrippedLLM(
        model=model_name,
        openai_api_key=api_key,
        openai_api_base=base_url,
        temperature=0,
    )
def _get_graph():
    """Return the shared Neo4jGraph, connecting on the first call.

    Connection settings come from ``NEO4J_URI``, ``NEO4J_USERNAME``,
    ``NEO4J_PASSWORD`` and ``NEO4J_DATABASE``; an unset or empty variable
    falls back to the built-in default for that setting.

    Returns:
        The module-wide ``Neo4jGraph`` instance.
    """
    global _graph
    if _graph is not None:
        return _graph

    # `or` (rather than a getenv default) so empty strings also fall back.
    # NOTE(review): the password fallback is a credential hard-coded in source.
    settings = {
        "url": os.getenv("NEO4J_URI") or "bolt://127.0.0.1:7687",
        "username": os.getenv("NEO4J_USERNAME") or "neo4j",
        "password": os.getenv("NEO4J_PASSWORD") or "clinicalmatch2024",
        "database": os.getenv("NEO4J_DATABASE") or "neo4j",
    }
    _graph = Neo4jGraph(**settings)
    return _graph
# Prompt text used to turn a natural-language question into a Cypher query.
# `{schema}` and `{question}` are the two template variables declared on
# _CYPHER_PROMPT below. Doubled braces (`{{ ... }}`) stand for literal braces
# in the Cypher snippets — presumably escaped for f-string-style formatting.
_CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.
Schema:
{schema}
Node labels and their exact property names:
- Patient: id (e.g. "P_C50_000001"), name, age (integer), sex ("MALE"/"FEMALE"), ecog (integer 0-3),
condition (lowercase, e.g. "breast cancer"), stage ("I"/"II"/"III"/"IV"),
city, state, ethnicity, insurance, icd10_prefix,
biomarkers (list of biomarker ids), medications (list of drug names),
comorbidities (list), prior_chemo (boolean), prior_radiation (boolean),
prior_surgery (boolean), prior_lines_of_therapy (integer), source
- Trial: id (NCT id, e.g. "NCT04567890"), title, condition (lowercase), phase, status,
brief_summary, eligibility_criteria, min_age, max_age, sex, enrollment,
start_date, completion_date, sponsor, location_count, source
- Diagnosis: code (ICD-10, e.g. "C50.919"), name (e.g. "Malignant neoplasm of breast"), source
- Biomarker: id (e.g. "HER2_POS"), name (e.g. "HER2 Positive"), gene (e.g. "ERBB2"), loinc, source
- Medication: rxcui, name, tty, generic_name, source
- StudySite: facility, city, state, country, lat, lon, source
- ConditionNode: name (e.g. "breast cancer")
- Publication: pmid, title, journal, pub_date, authors, source
Relationships:
- (Patient)-[:ELIGIBLE_FOR {{score: float, matched_at: datetime}}]->(Trial)
- (Patient)-[:HAS_DIAGNOSIS]->(Diagnosis)
- (Patient)-[:HAS_BIOMARKER]->(Biomarker)
- (Trial)-[:CONDUCTED_AT]->(StudySite)
- (ConditionNode)-[:HAS_TRIAL]->(Trial)
- (Diagnosis)-[:MAPS_TO_CONDITION]->(ConditionNode)
- (Biomarker)-[:RELEVANT_TO]->(ConditionNode)
- (Biomarker)-[:MAY_QUALIFY_FOR]->(Trial)
- (Publication)-[:SUPPORTS_RESEARCH_ON]->(ConditionNode)
Rules:
- Biomarker lookups use the `id` property: `{{id: 'HER2_POS'}}`
- Diagnosis lookups use `code` (not `id`): `{{code: 'C50.919'}}`
- Medication lookups use `rxcui` or `name` (not `id`)
- Condition lookups on Trial nodes use lowercase: `t.condition = 'breast cancer'`
- Patient-to-trial eligibility: `(p:Patient)-[:ELIGIBLE_FOR]->(t:Trial)`
- ecog property on Patient is `ecog` (integer), NOT `ecog_score`
- Limit results to 25 unless asked for more
Question: {question}
Cypher query:"""

# Prompt object handed to GraphCypherQAChain as its Cypher-generation prompt.
_CYPHER_PROMPT = PromptTemplate(
    input_variables=["schema", "question"],
    template=_CYPHER_GENERATION_TEMPLATE,
)
def _get_chain():
    """Return the shared GraphCypherQAChain, building it on the first call.

    The chain combines the think-stripping LLM, the shared Neo4j graph and the
    custom Cypher-generation prompt. It is created with ``verbose=True`` and
    ``allow_dangerous_requests=True``.

    Returns:
        The module-wide ``GraphCypherQAChain`` instance.
    """
    global _graph_chain
    if _graph_chain is not None:
        return _graph_chain

    chain_options = {
        "llm": _get_llm(),
        "graph": _get_graph(),
        "verbose": True,
        "allow_dangerous_requests": True,
        "cypher_prompt": _CYPHER_PROMPT,
    }
    _graph_chain = GraphCypherQAChain.from_llm(**chain_options)
    return _graph_chain
def retrieve_patient_trial_matches(patient_id: str) -> list:
    """Fetch diagnosis/trial rows linked to one patient from the graph.

    The patient id is sent as a Cypher query parameter rather than being
    interpolated into the query text, so ids containing quotes or Cypher
    syntax cannot alter the query.

    Args:
        patient_id: Value matched against ``Patient.id``.

    Returns:
        A list of rows with keys ``patient``, ``diagnosis``, ``trial``,
        ``phase`` and ``condition``; an empty list if the query fails.
    """
    # NOTE(review): this traverses Diagnosis-[:ELIGIBLE_FOR]->Trial, whereas the
    # schema described in the Cypher prompt template lists ELIGIBLE_FOR as
    # Patient->Trial. Confirm against the actual graph; if the prompt is
    # right, this pattern never matches.
    cypher = """
    MATCH (p:Patient {id: $patient_id})-[:HAS_DIAGNOSIS]->(d:Diagnosis)-[:ELIGIBLE_FOR]->(t:Trial)
    RETURN p.id as patient, d.name as diagnosis, t.id as trial, t.phase as phase, t.condition as condition
    """
    try:
        return _get_graph().query(cypher, params={"patient_id": patient_id})
    except Exception as e:
        # Best-effort lookup: report the failure and return no matches.
        print(f"[graphrag] query error: {e}")
        return []
def rag_query(question: str) -> str:
    """Answer a natural-language question through the graph QA chain.

    Args:
        question: The user's question.

    Returns:
        The chain's answer with thinking blocks removed, ``"No results
        found."`` when the answer is empty (including when it contained only
        reasoning), or a user-facing error message if the chain raised.
    """
    try:
        # `invoke` replaces the deprecated `Chain.run`; the chain takes its
        # input under "query" and returns the answer under "result".
        output = _get_chain().invoke({"query": question})
        raw = output.get("result") if isinstance(output, dict) else output
        # Strip first, then apply the fallback, so a reply made only of
        # reasoning does not come back as an empty string.
        answer = _strip_thinking(str(raw)) if raw else ""
        return answer or "No results found."
    except Exception as e:
        err = str(e)
        # Reasoning text that leaked into the generated Cypher shows up as a
        # syntax error; give the user something actionable instead.
        if "<think>" in err or "SyntaxError" in err:
            return "The query model returned unexpected output. Please rephrase your question."
        return f"Graph query error: {err}"
def get_graph_stats() -> dict:
    """Report node counts for patients, trials and diagnoses.

    Returns:
        A dict with ``patients``, ``trials``, ``diagnoses`` and ``status``.
        ``status`` is ``"connected"`` on success; on failure the counts are 0
        and ``status`` holds the error text.
    """
    # Imported here rather than at module top, as in the original.
    from neo4j_setup import neo4j_conn

    stats = {"patients": 0, "trials": 0, "diagnoses": 0}
    try:
        # OPTIONAL MATCH keeps the single row alive when a label has no nodes;
        # a plain MATCH followed by a grouped count() would yield zero rows
        # and drop every count.
        rows = neo4j_conn.run_query("""
        MATCH (p:Patient) WITH count(p) as patients
        OPTIONAL MATCH (t:Trial) WITH patients, count(t) as trials
        OPTIONAL MATCH (d:Diagnosis) WITH patients, trials, count(d) as diagnoses
        RETURN patients, trials, diagnoses
        """)
        if rows:
            stats.update(rows[0])
        # Default keys guarantee callers always see all three counts.
        return {**stats, "status": "connected"}
    except Exception as e:
        return {"patients": 0, "trials": 0, "diagnoses": 0, "status": str(e)}