File size: 5,949 Bytes
59abb4f
 
 
 
 
 
 
 
 
 
 
 
bfeb61b
 
 
59abb4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfeb61b
 
 
 
 
 
 
 
 
 
 
 
 
bf49c73
 
 
 
bfeb61b
 
 
59abb4f
 
 
 
 
 
b40cc1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59abb4f
 
b40cc1f
59abb4f
 
b40cc1f
 
 
 
 
 
59abb4f
 
b40cc1f
 
 
 
 
 
59abb4f
 
 
 
 
 
 
 
 
 
bfeb61b
 
 
 
 
 
 
 
 
 
 
 
59abb4f
 
 
 
bfeb61b
 
 
 
59abb4f
 
 
 
 
 
 
bfeb61b
59abb4f
 
 
 
bfeb61b
59abb4f
 
 
 
1428c0f
59abb4f
1428c0f
59abb4f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import BaseMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration
import re
import os
from dotenv import load_dotenv

# Populate os.environ from a local .env file before any of the getters below
# read NEO4J_* / OPENAI_* settings.
load_dotenv()

# The graph connection and the QA chain are both built on first use instead of
# at import time, because the Neo4j server may not be reachable yet.
_graph = _graph_chain = None


# Compiled once: this runs on every LLM generation and every final answer.
_THINK_BLOCK_RE = re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL | re.IGNORECASE)
_THINK_CLOSE_RE = re.compile(r"</think(?:ing)?>", re.IGNORECASE)
_THINK_OPEN_RE = re.compile(r"<think(?:ing)?>", re.IGNORECASE)


def _strip_thinking(text: str) -> str:
    """Remove model "thinking" sections from *text* and trim surrounding whitespace.

    Tags are matched case-insensitively as ``<think>`` or ``<thinking>``. Three
    shapes of reasoning output are handled:

    * complete ``<think>…</think>`` blocks are deleted wherever they appear;
    * a closing tag left without an opener (e.g. when the serving stack pre-fills
      the opening tag) causes everything up to and including the last such
      closing tag to be dropped;
    * an opening tag that is never closed (e.g. generation cut off mid-thought)
      causes everything from that tag onward to be dropped.

    Args:
        text: Raw model output.

    Returns:
        The remaining text, stripped of leading/trailing whitespace.
    """
    text = _THINK_BLOCK_RE.sub("", text)
    # Text before a leftover closing tag is reasoning whose opener never appeared.
    text = _THINK_CLOSE_RE.split(text)[-1]
    # Text after an unmatched opening tag is reasoning that never finished.
    text = _THINK_OPEN_RE.split(text, maxsplit=1)[0]
    return text.strip()


class _ThinkStrippedLLM(ChatOpenAI):
    """ChatOpenAI variant that removes ``<think>``/``<thinking>`` sections from replies.

    Reasoning-style models may emit a thinking section before the real answer.
    Left in place, that text would end up in the generated Cypher and in the
    final answer, so every generation is cleaned with ``_strip_thinking``.
    """

    def _create_chat_result(self, response, generation_info=None) -> ChatResult:
        """Post-process the parent's ChatResult, stripping thinking text from each generation.

        Only plain-string message content is rewritten. Structured (list) content
        is passed through unchanged, because it cannot be regex-processed. The
        original message object is kept, so tool calls, usage and response
        metadata survive. Each ChatGeneration is rebuilt so that its
        message-derived ``text`` reflects the cleaned content.
        """
        result: ChatResult = super()._create_chat_result(response, generation_info)
        cleaned = []
        for gen in result.generations:
            message = gen.message
            content = message.content or ""
            if isinstance(content, str):
                # Safe to mutate: the message was freshly built by the parent call above.
                message.content = _strip_thinking(content)
            cleaned.append(ChatGeneration(message=message, generation_info=gen.generation_info))
        return ChatResult(generations=cleaned, llm_output=result.llm_output)


def _get_llm():
    """Build the chat model that is handed to the Cypher QA chain.

    Settings come from the environment:
      OPENAI_MODEL    -- model name, falling back to "qwen/qwen3-32b"
      OPENAI_API_KEY  -- API key (passed through as-is, may be None)
      OPENAI_BASE_URL -- OpenAI-compatible endpoint (passed through as-is, may be None)
    Temperature is fixed at 0 so that Cypher generation is as repeatable as possible.
    """
    env = os.getenv
    settings = {
        "model": env("OPENAI_MODEL", "qwen/qwen3-32b"),
        "openai_api_key": env("OPENAI_API_KEY"),
        "openai_api_base": env("OPENAI_BASE_URL"),
        "temperature": 0,
    }
    return _ThinkStrippedLLM(**settings)


def _get_graph():
    """Return the module-wide Neo4jGraph, creating it on the first call.

    Connection settings are read from NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD
    and NEO4J_DATABASE. An unset or empty variable falls back to a local default.
    """
    global _graph
    if _graph is not None:
        return _graph

    def setting(name, fallback):
        # `or` rather than getenv's default, so an empty string also falls back.
        return os.getenv(name) or fallback

    _graph = Neo4jGraph(
        url=setting("NEO4J_URI", "bolt://127.0.0.1:7687"),
        username=setting("NEO4J_USERNAME", "neo4j"),
        # NOTE(review): a hard-coded fallback password ships with the source;
        # consider requiring NEO4J_PASSWORD outside local development.
        password=setting("NEO4J_PASSWORD", "clinicalmatch2024"),
        database=setting("NEO4J_DATABASE", "neo4j"),
    )
    return _graph


# Prompt that turns a natural-language question into a Cypher query.
# It is rendered as an f-string-style PromptTemplate: `{schema}` and `{question}`
# are placeholders, and literal braces in Cypher snippets are doubled (`{{` / `}}`).
# The read-only rule is only an instruction to the model, not a security boundary;
# enforce it with a read-only Neo4j role as well.
_CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.

Schema:
{schema}

Node labels and their exact property names:
- Patient: id (e.g. "P_C50_000001"), name, age (integer), sex ("MALE"/"FEMALE"), ecog (integer 0-3),
           condition (lowercase, e.g. "breast cancer"), stage ("I"/"II"/"III"/"IV"),
           city, state, ethnicity, insurance, icd10_prefix,
           biomarkers (list of biomarker ids), medications (list of drug names),
           comorbidities (list), prior_chemo (boolean), prior_radiation (boolean),
           prior_surgery (boolean), prior_lines_of_therapy (integer), source
- Trial: id (NCT id, e.g. "NCT04567890"), title, condition (lowercase), phase, status,
         brief_summary, eligibility_criteria, min_age, max_age, sex, enrollment,
         start_date, completion_date, sponsor, location_count, source
- Diagnosis: code (ICD-10, e.g. "C50.919"), name (e.g. "Malignant neoplasm of breast"), source
- Biomarker: id (e.g. "HER2_POS"), name (e.g. "HER2 Positive"), gene (e.g. "ERBB2"), loinc, source
- Medication: rxcui, name, tty, generic_name, source
- StudySite: facility, city, state, country, lat, lon, source
- ConditionNode: name (e.g. "breast cancer")
- Publication: pmid, title, journal, pub_date, authors, source

Relationships:
- (Patient)-[:ELIGIBLE_FOR {{score: float, matched_at: datetime}}]->(Trial)
- (Patient)-[:HAS_DIAGNOSIS]->(Diagnosis)
- (Patient)-[:HAS_BIOMARKER]->(Biomarker)
- (Trial)-[:CONDUCTED_AT]->(StudySite)
- (ConditionNode)-[:HAS_TRIAL]->(Trial)
- (Diagnosis)-[:MAPS_TO_CONDITION]->(ConditionNode)
- (Biomarker)-[:RELEVANT_TO]->(ConditionNode)
- (Biomarker)-[:MAY_QUALIFY_FOR]->(Trial)
- (Publication)-[:SUPPORTS_RESEARCH_ON]->(ConditionNode)

Rules:
- Biomarker lookups use the `id` property: `{{id: 'HER2_POS'}}`
- Diagnosis lookups use `code` (not `id`): `{{code: 'C50.919'}}`
- Medication lookups use `rxcui` or `name` (not `id`)
- Condition lookups on Trial nodes use lowercase: `t.condition = 'breast cancer'`
- Patient-to-trial eligibility: `(p:Patient)-[:ELIGIBLE_FOR]->(t:Trial)`
- ecog property on Patient is `ecog` (integer), NOT `ecog_score`
- Limit results to 25 unless asked for more
- Generate read-only queries only: never use CREATE, MERGE, SET, REMOVE, DELETE, DETACH DELETE, DROP, or LOAD CSV
- Output only the Cypher statement: no explanations, apologies, or markdown code fences

Question: {question}
Cypher query:"""

# Cypher-generation prompt passed to the QA chain; the chain is expected to
# supply the graph schema and the user's question as the two variables.
_CYPHER_PROMPT = PromptTemplate(
    template=_CYPHER_GENERATION_TEMPLATE,
    input_variables=["schema", "question"],
)


def _get_chain():
    """Return the module-wide GraphCypherQAChain, building it on the first call.

    The chain combines the think-stripping LLM from ``_get_llm()``, the lazily
    connected graph from ``_get_graph()`` and the custom clinical Cypher prompt.
    """
    global _graph_chain
    if _graph_chain is None:
        _graph_chain = GraphCypherQAChain.from_llm(
            llm=_get_llm(),
            graph=_get_graph(),
            # NOTE(review): verbose output may include patient records returned by
            # generated queries; consider disabling it outside development.
            verbose=True,
            # Required to execute LLM-generated Cypher. Nothing here restricts the
            # query to reads; NOTE(review): connect with a read-only Neo4j role.
            allow_dangerous_requests=True,
            cypher_prompt=_CYPHER_PROMPT,
            # Match the prompt's "Limit results to 25" rule. The chain's default
            # top_k (10) would otherwise silently truncate the rows passed to the
            # answering step.
            top_k=25,
        )
    return _graph_chain


def retrieve_patient_trial_matches(patient_id: str) -> list:
    """Return the trials a patient is linked to via an ELIGIBLE_FOR relationship.

    Each row has the keys ``patient``, ``diagnosis``, ``trial``, ``phase`` and
    ``condition``.

    ``diagnosis`` is the name of a patient diagnosis whose condition node has
    the trial (Diagnosis -MAPS_TO_CONDITION-> ConditionNode -HAS_TRIAL-> Trial).
    It is ``None`` when no such diagnosis is linked. A trial reachable through
    several diagnoses appears once per diagnosis.

    Args:
        patient_id: Patient ``id`` property, e.g. "P_C50_000001".

    Returns:
        The query rows, or ``[]`` if the query fails. The error is printed.
    """
    # ELIGIBLE_FOR connects Patient -> Trial (not Diagnosis -> Trial).
    # patient_id is bound as a query parameter rather than interpolated into the
    # Cypher text, so a crafted id cannot inject Cypher.
    query = """
        MATCH (p:Patient {id: $patient_id})-[:ELIGIBLE_FOR]->(t:Trial)
        OPTIONAL MATCH (p)-[:HAS_DIAGNOSIS]->(d:Diagnosis)-[:MAPS_TO_CONDITION]->(:ConditionNode)-[:HAS_TRIAL]->(t)
        RETURN p.id as patient, d.name as diagnosis, t.id as trial, t.phase as phase, t.condition as condition
    """
    try:
        return _get_graph().query(query, params={"patient_id": patient_id})
    except Exception as e:
        print(f"[graphrag] query error: {e}")
        return []


def rag_query(question: str) -> str:
    """Answer a natural-language *question* using the graph QA chain.

    Returns the chain's answer with any thinking sections removed. Returns
    "No results found." when the answer is empty. Exceptions are converted into
    a user-facing message rather than propagated.
    """
    try:
        chain = _get_chain()
        # `Chain.run` is deprecated; `invoke` takes and returns dicts keyed by
        # the chain's own input/output keys.
        output = chain.invoke({chain.input_key: question})
        result = output[chain.output_key]
        return _strip_thinking(result) if result else "No results found."
    except Exception as e:
        err = str(e)
        # Leaked reasoning text (<think>, <thinking>, any case) or malformed Cypher
        # both mean the model's output was unusable, not that the graph failed.
        if "<think" in err.lower() or "SyntaxError" in err:
            return "The query model returned unexpected output. Please rephrase your question."
        return f"Graph query error: {err}"


def get_graph_stats() -> dict:
    """Return Patient/Trial/Diagnosis node counts plus a connection status.

    The result always has ``patients``, ``trials``, ``diagnoses`` and ``status``
    keys. ``status`` is "connected" on success. On any failure it holds the
    error text and the counts are 0. Exceptions are not propagated.
    """
    empty = {"patients": 0, "trials": 0, "diagnoses": 0}
    try:
        # Imported here (inside the try) so that a missing or broken connection
        # module is reported through "status" like any other failure.
        from neo4j_setup import neo4j_conn

        # OPTIONAL MATCH keeps a row flowing when a label has no nodes. With a
        # plain MATCH, one empty label produced zero rows and dropped every count.
        # NOTE(review): assumes run_query returns a list of mapping rows.
        result = neo4j_conn.run_query("""
            OPTIONAL MATCH (p:Patient) WITH count(p) as patients
            OPTIONAL MATCH (t:Trial) WITH patients, count(t) as trials
            OPTIONAL MATCH (d:Diagnosis) WITH patients, trials, count(d) as diagnoses
            RETURN patients, trials, diagnoses
        """)
        return {**empty, **(result[0] if result else {}), "status": "connected"}
    except Exception as e:
        return {**empty, "status": str(e)}