TheQuantEd commited on
Commit
bfeb61b
·
1 Parent(s): 8bd7457

Fix backend crash: make Neo4jGraph lazy-init in graphrag.py — was failing at import time

Browse files
Files changed (1) hide show
  1. backend/graphrag.py +43 -33
backend/graphrag.py CHANGED
@@ -10,24 +10,17 @@ from dotenv import load_dotenv
10
 
11
  load_dotenv()
12
 
13
- graph = Neo4jGraph(
14
- url=os.getenv("NEO4J_URI"),
15
- username=os.getenv("NEO4J_USERNAME"),
16
- password=os.getenv("NEO4J_PASSWORD"),
17
- database=os.getenv("NEO4J_DATABASE", "neo4j"),
18
- )
19
 
20
 
21
  def _strip_thinking(text: str) -> str:
22
- """Remove <think>...</think> blocks that reasoning models emit before the actual answer."""
23
- # Strip block tags (including variations like <thinking>)
24
  text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL | re.IGNORECASE)
25
  return text.strip()
26
 
27
 
28
  class _ThinkStrippedLLM(ChatOpenAI):
29
- """ChatOpenAI wrapper that strips <think> reasoning tokens from every response."""
30
-
31
  def _create_chat_result(self, response, generation_info=None) -> ChatResult:
32
  result: ChatResult = super()._create_chat_result(response, generation_info)
33
  cleaned = []
@@ -38,12 +31,26 @@ class _ThinkStrippedLLM(ChatOpenAI):
38
  return ChatResult(generations=cleaned, llm_output=result.llm_output)
39
 
40
 
41
- llm = _ThinkStrippedLLM(
42
- model=os.getenv("OPENAI_MODEL", "qwen/qwen3-32b"),
43
- openai_api_key=os.getenv("OPENAI_API_KEY"),
44
- openai_api_base=os.getenv("OPENAI_BASE_URL"),
45
- temperature=0,
46
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  _CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.
49
 
@@ -66,7 +73,7 @@ Relationships:
66
  - (Trial)-[:LOCATED_AT]->(StudySite)
67
 
68
  Rules:
69
- - For biomarker lookups, use the `id` property with uppercase underscore format, e.g. `{{id: 'HER2_POS'}}` NOT `{{name: 'HER2', status: 'positive'}}`
70
  - For condition lookups on Trial nodes, use lowercase: `t.condition = 'breast cancer'`
71
  - Always use relationship pattern (Patient)-[:ELIGIBLE_FOR]->(Trial) to find eligible patients
72
  - Limit results to 25 unless asked for more
@@ -79,22 +86,26 @@ _CYPHER_PROMPT = PromptTemplate(
79
  template=_CYPHER_GENERATION_TEMPLATE,
80
  )
81
 
82
- graph_chain = GraphCypherQAChain.from_llm(
83
- llm=llm,
84
- graph=graph,
85
- verbose=True,
86
- allow_dangerous_requests=True,
87
- cypher_prompt=_CYPHER_PROMPT,
88
- )
 
 
 
 
 
89
 
90
 
91
  def retrieve_patient_trial_matches(patient_id: str) -> list:
92
- query = f"""
93
- MATCH (p:Patient {{id: '{patient_id}'}})-[:HAS_DIAGNOSIS]->(d:Diagnosis)-[:ELIGIBLE_FOR]->(t:Trial)
94
- RETURN p.id as patient, d.name as diagnosis, t.id as trial, t.phase as phase, t.condition as condition
95
- """
96
  try:
97
- return graph.query(query)
 
 
 
98
  except Exception as e:
99
  print(f"[graphrag] query error: {e}")
100
  return []
@@ -102,19 +113,18 @@ def retrieve_patient_trial_matches(patient_id: str) -> list:
102
 
103
  def rag_query(question: str) -> str:
104
  try:
105
- result = graph_chain.run(question)
106
  return _strip_thinking(result) if result else "No results found."
107
  except Exception as e:
108
  err = str(e)
109
- # Surface a clean message instead of the raw Neo4j stack trace
110
  if "<think>" in err or "SyntaxError" in err:
111
- return "The query model returned unexpected output. Please rephrase your question (e.g. 'List patients eligible for breast cancer trials')."
112
  return f"Graph query error: {err}"
113
 
114
 
115
  def get_graph_stats() -> dict:
116
  try:
117
- result = graph.query("""
118
  MATCH (p:Patient) WITH count(p) as patients
119
  MATCH (t:Trial) WITH patients, count(t) as trials
120
  MATCH (d:Diagnosis) WITH patients, trials, count(d) as diagnoses
 
10
 
11
  load_dotenv()
12
 
13
+ # Lazily initialised — Neo4j may not be ready at import time
14
+ _graph = None
15
+ _graph_chain = None
 
 
 
16
 
17
 
18
  def _strip_thinking(text: str) -> str:
 
 
19
  text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL | re.IGNORECASE)
20
  return text.strip()
21
 
22
 
23
  class _ThinkStrippedLLM(ChatOpenAI):
 
 
24
  def _create_chat_result(self, response, generation_info=None) -> ChatResult:
25
  result: ChatResult = super()._create_chat_result(response, generation_info)
26
  cleaned = []
 
31
  return ChatResult(generations=cleaned, llm_output=result.llm_output)
32
 
33
 
34
+ def _get_llm():
35
+ return _ThinkStrippedLLM(
36
+ model=os.getenv("OPENAI_MODEL", "qwen/qwen3-32b"),
37
+ openai_api_key=os.getenv("OPENAI_API_KEY"),
38
+ openai_api_base=os.getenv("OPENAI_BASE_URL"),
39
+ temperature=0,
40
+ )
41
+
42
+
43
+ def _get_graph():
44
+ global _graph
45
+ if _graph is None:
46
+ _graph = Neo4jGraph(
47
+ url=os.getenv("NEO4J_URI", "bolt://127.0.0.1:7687"),
48
+ username=os.getenv("NEO4J_USERNAME", "neo4j"),
49
+ password=os.getenv("NEO4J_PASSWORD", "clinicalmatch2024"),
50
+ database=os.getenv("NEO4J_DATABASE", "neo4j"),
51
+ )
52
+ return _graph
53
+
54
 
55
  _CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher query writer for a clinical trial matching system.
56
 
 
73
  - (Trial)-[:LOCATED_AT]->(StudySite)
74
 
75
  Rules:
76
+ - For biomarker lookups, use the `id` property with uppercase underscore format, e.g. `{{id: 'HER2_POS'}}`
77
  - For condition lookups on Trial nodes, use lowercase: `t.condition = 'breast cancer'`
78
  - Always use relationship pattern (Patient)-[:ELIGIBLE_FOR]->(Trial) to find eligible patients
79
  - Limit results to 25 unless asked for more
 
86
  template=_CYPHER_GENERATION_TEMPLATE,
87
  )
88
 
89
+
90
+ def _get_chain():
91
+ global _graph_chain
92
+ if _graph_chain is None:
93
+ _graph_chain = GraphCypherQAChain.from_llm(
94
+ llm=_get_llm(),
95
+ graph=_get_graph(),
96
+ verbose=True,
97
+ allow_dangerous_requests=True,
98
+ cypher_prompt=_CYPHER_PROMPT,
99
+ )
100
+ return _graph_chain
101
 
102
 
103
  def retrieve_patient_trial_matches(patient_id: str) -> list:
 
 
 
 
104
  try:
105
+ return _get_graph().query(f"""
106
+ MATCH (p:Patient {{id: '{patient_id}'}})-[:HAS_DIAGNOSIS]->(d:Diagnosis)-[:ELIGIBLE_FOR]->(t:Trial)
107
+ RETURN p.id as patient, d.name as diagnosis, t.id as trial, t.phase as phase, t.condition as condition
108
+ """)
109
  except Exception as e:
110
  print(f"[graphrag] query error: {e}")
111
  return []
 
113
 
114
  def rag_query(question: str) -> str:
115
  try:
116
+ result = _get_chain().run(question)
117
  return _strip_thinking(result) if result else "No results found."
118
  except Exception as e:
119
  err = str(e)
 
120
  if "<think>" in err or "SyntaxError" in err:
121
+ return "The query model returned unexpected output. Please rephrase your question."
122
  return f"Graph query error: {err}"
123
 
124
 
125
  def get_graph_stats() -> dict:
126
  try:
127
+ result = _get_graph().query("""
128
  MATCH (p:Patient) WITH count(p) as patients
129
  MATCH (t:Trial) WITH patients, count(t) as trials
130
  MATCH (d:Diagnosis) WITH patients, trials, count(d) as diagnoses