dev-yuje commited on
Commit
9675b2d
Β·
1 Parent(s): d6ce198

refactor: eliminate import-time DB connection anti-pattern with LazyGraphRAG proxy and local driver initialization

Browse files
src/graphBuilder/neo4j/finGraph.py CHANGED
@@ -42,19 +42,12 @@ def get_neo4j_driver() -> neo4j.Driver:
42
 
43
  username = os.getenv("NEO4J_USERNAME", "neo4j")
44
  password = os.getenv("NEO4J_PASSWORD", "password")
45
- try:
46
- d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
47
- d.verify_connectivity()
48
- return d
49
- except Exception as e:
50
- import sys
51
- if "pytest" in sys.modules or os.getenv("GITHUB_ACTIONS") == "true":
52
- print(f"⚠️ [TEST/CI ENVIRONMENT] Neo4j connection failed at import time: {e}. (Proceeding with dummy None driver)")
53
- return None
54
- raise e
55
 
56
 
57
- driver = get_neo4j_driver()
58
 
59
  chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
60
  rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
@@ -265,6 +258,9 @@ def is_article_loaded(tx, aid: str) -> bool:
265
 
266
 
267
  def main() -> None:
 
 
 
268
  # 1. λͺ¨λ“  μ—‘μ…€ 파일 λ‘œλ“œ ν›„ 병합 및 고유 κΈ°μ‚¬λ§Œ 필터링
269
  xlsx_files = sorted(glob.glob("Articles_*.xlsx"))
270
  if not xlsx_files:
 
42
 
43
  username = os.getenv("NEO4J_USERNAME", "neo4j")
44
  password = os.getenv("NEO4J_PASSWORD", "password")
45
+ d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
46
+ d.verify_connectivity()
47
+ return d
 
 
 
 
 
 
 
48
 
49
 
50
+ driver = None
51
 
52
  chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
53
  rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
 
258
 
259
 
260
  def main() -> None:
261
+ global driver
262
+ driver = get_neo4j_driver()
263
+
264
  # 1. λͺ¨λ“  μ—‘μ…€ 파일 λ‘œλ“œ ν›„ 병합 및 고유 κΈ°μ‚¬λ§Œ 필터링
265
  xlsx_files = sorted(glob.glob("Articles_*.xlsx"))
266
  if not xlsx_files:
src/retrieval/finRetrieval.py CHANGED
@@ -26,10 +26,6 @@ from neo4j_graphrag.retrievers import (
26
 
27
  dotenv.load_dotenv()
28
 
29
- # ──────────────────────────────────────────
30
- # 1. DB / LLM / Embedder μ΄ˆκΈ°ν™”
31
- # ──────────────────────────────────────────
32
-
33
 
34
  def get_neo4j_driver() -> neo4j.Driver:
35
  uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
@@ -46,37 +42,20 @@ def get_neo4j_driver() -> neo4j.Driver:
46
 
47
  username = os.getenv("NEO4J_USERNAME", "neo4j")
48
  password = os.getenv("NEO4J_PASSWORD", "password")
49
- try:
50
- d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
51
- d.verify_connectivity()
52
- return d
53
- except Exception as e:
54
- import sys
55
- if "pytest" in sys.modules or os.getenv("GITHUB_ACTIONS") == "true":
56
- print(f"⚠️ [TEST/CI ENVIRONMENT] Neo4j connection failed at import time: {e}. (Proceeding with dummy None driver)")
57
- return None
58
- raise e
59
 
60
 
61
- driver = get_neo4j_driver()
62
-
63
  rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
64
  embedder = OpenAIEmbeddings(model="text-embedding-3-small")
65
 
66
  INDEX_NAME = "content_vector_index"
67
 
68
  # ──────────────────────────────────────────
69
- # 2. Retriever μ„Έ μ’…λ₯˜ μ΄ˆκΈ°ν™”
70
  # ──────────────────────────────────────────
71
 
72
- # (1) λ³Έλ¬Έ 청크 의미 μœ μ‚¬λ„ 검색
73
- vector_retriever = VectorRetriever(
74
- driver=driver,
75
- index_name=INDEX_NAME,
76
- embedder=embedder,
77
- )
78
-
79
- # (2) 벑터 검색 ν›„ κ·Έλž˜ν”„ 탐색 (κΈ°μ—…Β·κΈ°μˆ Β·μ„œλΉ„μŠ€ ν•¨κ»˜ λ°˜ν™˜)
80
  _retrieval_query = """
81
  MATCH (content:Content)<-[:HAS_CHUNK]-(article:Article)
82
  OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
@@ -96,16 +75,8 @@ ORDER BY article.published_date DESC
96
  LIMIT 3
97
  """
98
 
99
- vector_cypher_retriever = VectorCypherRetriever(
100
- driver=driver,
101
- index_name=INDEX_NAME,
102
- retrieval_query=_retrieval_query,
103
- embedder=embedder,
104
- )
105
-
106
 
107
- # (3) μžμ—°μ–΄ β†’ Cypher μžλ™ λ³€ν™˜ 검색
108
- def _get_schema() -> str:
109
  with driver.session() as s:
110
  nodes = s.run(
111
  "CALL db.schema.nodeTypeProperties() "
@@ -151,36 +122,10 @@ CYPHER QUERY:
151
  LIMIT 3""",
152
  ]
153
 
154
- text2cypher_retriever = Text2CypherRetriever(
155
- driver=driver,
156
- llm=rag_llm,
157
- neo4j_schema=_get_schema(),
158
- examples=_examples,
159
- )
160
-
161
  # ──────────────────────────────────────────
162
  # 3. ToolsRetriever + GraphRAG 쑰립
163
  # ──────────────────────────────────────────
164
 
165
- tools_retriever = ToolsRetriever(
166
- driver=driver,
167
- llm=rag_llm,
168
- tools=[
169
- vector_retriever.convert_to_tool(
170
- name="vector_retriever",
171
- description="λ‰΄μŠ€ 본문의 의미(λ‚΄μš©) μœ μ‚¬λ„ 기반 검색. AI κΈ°μˆ Β·μ„œλΉ„μŠ€ κ΄€λ ¨ ν…μŠ€νŠΈλ₯Ό 찾을 λ•Œ μ‚¬μš©.",
172
- ),
173
- vector_cypher_retriever.convert_to_tool(
174
- name="vectorcypher_retriever",
175
- description="벑터 검색 ν›„ ν•΄λ‹Ή κΈ°μ‚¬μ—μ„œ μ–ΈκΈ‰λœ κΈ°μ—…Β·κΈ°μˆ Β·μ„œλΉ„μŠ€ κ·Έλž˜ν”„λ₯Ό ν•¨κ»˜ λ°˜ν™˜. κΈ°μ—… AI νŠΈλ Œλ“œ 뢄석에 졜적.",
176
- ),
177
- text2cypher_retriever.convert_to_tool(
178
- name="text2cypher_retriever",
179
- description="μžμ—°μ–΄λ₯Ό Cypher둜 λ³€ν™˜. νŠΉμ • κΈ°μ—… μ„œλΉ„μŠ€ λͺ©λ‘, 기술 보유 κΈ°μ—… λ“± ꡬ쑰적 μ§ˆμ˜μ— μ‚¬μš©.",
180
- ),
181
- ],
182
- )
183
-
184
  from typing import Any
185
 
186
  from neo4j_graphrag.retrievers.base import Retriever
@@ -205,13 +150,6 @@ class HybridFallbackRetriever(Retriever):
205
  return res
206
 
207
 
208
- # ν•˜μ΄λΈŒλ¦¬λ“œ 검색 μΈμŠ€ν„΄μŠ€ μž₯μ°©
209
- hybrid_retriever = HybridFallbackRetriever(
210
- tools_retriever=tools_retriever,
211
- fallback_retriever=vector_cypher_retriever,
212
- )
213
-
214
-
215
  class CustomRagTemplate(RagTemplate):
216
  EXPECTED_INPUTS = ["context", "query_text"]
217
 
@@ -238,9 +176,76 @@ _prompt_template = CustomRagTemplate(
238
  expected_inputs=["context", "query_text"]
239
  )
240
 
241
- # app.pyμ—μ„œ 이 객체λ₯Ό 직접 importν•˜μ—¬ μ‚¬μš©ν•©λ‹ˆλ‹€.
242
- graphrag = GraphRAG(
243
- llm=rag_llm,
244
- retriever=hybrid_retriever,
245
- prompt_template=_prompt_template,
246
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  dotenv.load_dotenv()
28
 
 
 
 
 
29
 
30
  def get_neo4j_driver() -> neo4j.Driver:
31
  uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
 
42
 
43
  username = os.getenv("NEO4J_USERNAME", "neo4j")
44
  password = os.getenv("NEO4J_PASSWORD", "password")
45
+ d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
46
+ d.verify_connectivity()
47
+ return d
 
 
 
 
 
 
 
48
 
49
 
 
 
50
  rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
51
  embedder = OpenAIEmbeddings(model="text-embedding-3-small")
52
 
53
  INDEX_NAME = "content_vector_index"
54
 
55
  # ──────────────────────────────────────────
56
+ # 2. Retriever κ΄€λ ¨ μƒμˆ˜ 및 μ„€μ •
57
  # ──────────────────────────────────────────
58
 
 
 
 
 
 
 
 
 
59
  _retrieval_query = """
60
  MATCH (content:Content)<-[:HAS_CHUNK]-(article:Article)
61
  OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
 
75
  LIMIT 3
76
  """
77
 
 
 
 
 
 
 
 
78
 
79
+ def _get_schema(driver: neo4j.Driver) -> str:
 
80
  with driver.session() as s:
81
  nodes = s.run(
82
  "CALL db.schema.nodeTypeProperties() "
 
122
  LIMIT 3""",
123
  ]
124
 
 
 
 
 
 
 
 
125
  # ──────────────────────────────────────────
126
  # 3. ToolsRetriever + GraphRAG 쑰립
127
  # ──────────────────────────────────────────
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  from typing import Any
130
 
131
  from neo4j_graphrag.retrievers.base import Retriever
 
150
  return res
151
 
152
 
 
 
 
 
 
 
 
153
  class CustomRagTemplate(RagTemplate):
154
  EXPECTED_INPUTS = ["context", "query_text"]
155
 
 
176
  expected_inputs=["context", "query_text"]
177
  )
178
 
179
+
180
+ class LazyGraphRAG:
181
+ """μž„ν¬νŠΈ μ‹œμ μ— DB 연결을 λ°©μ§€ν•˜κ³  μ‹€μ œ 호좜될 λ•Œ GraphRAG μΈμŠ€ν„΄μŠ€λ₯Ό μ΄ˆκΈ°ν™”ν•˜λŠ” μ§€μ—° 평가 ν”„λ‘μ‹œ"""
182
+ def __init__(self) -> None:
183
+ self._graphrag = None
184
+
185
+ def _init_once(self) -> None:
186
+ if self._graphrag is not None:
187
+ return
188
+
189
+ driver = get_neo4j_driver()
190
+
191
+ vector_retriever = VectorRetriever(
192
+ driver=driver,
193
+ index_name=INDEX_NAME,
194
+ embedder=embedder,
195
+ )
196
+
197
+ vector_cypher_retriever = VectorCypherRetriever(
198
+ driver=driver,
199
+ index_name=INDEX_NAME,
200
+ retrieval_query=_retrieval_query,
201
+ embedder=embedder,
202
+ )
203
+
204
+ text2cypher_retriever = Text2CypherRetriever(
205
+ driver=driver,
206
+ llm=rag_llm,
207
+ neo4j_schema=_get_schema(driver),
208
+ examples=_examples,
209
+ )
210
+
211
+ tools_retriever = ToolsRetriever(
212
+ driver=driver,
213
+ llm=rag_llm,
214
+ tools=[
215
+ vector_retriever.convert_to_tool(
216
+ name="vector_retriever",
217
+ description="λ‰΄μŠ€ 본문의 의미(λ‚΄μš©) μœ μ‚¬λ„ 기반 검색. AI κΈ°μˆ Β·μ„œλΉ„μŠ€ κ΄€λ ¨ ν…μŠ€νŠΈλ₯Ό 찾을 λ•Œ μ‚¬μš©.",
218
+ ),
219
+ vector_cypher_retriever.convert_to_tool(
220
+ name="vectorcypher_retriever",
221
+ description="벑터 검색 ν›„ ν•΄λ‹Ή κΈ°μ‚¬μ—μ„œ μ–ΈκΈ‰λœ κΈ°μ—…Β·κΈ°μˆ Β·μ„œλΉ„μŠ€ κ·Έλž˜ν”„λ₯Ό ν•¨κ»˜ λ°˜ν™˜. κΈ°μ—… AI νŠΈλ Œλ“œ 뢄석에 졜적.",
222
+ ),
223
+ text2cypher_retriever.convert_to_tool(
224
+ name="text2cypher_retriever",
225
+ description="μžμ—°μ–΄λ₯Ό Cypher둜 λ³€ν™˜. νŠΉμ • κΈ°μ—… μ„œλΉ„μŠ€ λͺ©λ‘, 기술 보유 κΈ°μ—… λ“± ꡬ쑰적 μ§ˆμ˜μ— μ‚¬μš©.",
226
+ ),
227
+ ],
228
+ )
229
+
230
+ hybrid_retriever = HybridFallbackRetriever(
231
+ tools_retriever=tools_retriever,
232
+ fallback_retriever=vector_cypher_retriever,
233
+ )
234
+
235
+ self._graphrag = GraphRAG(
236
+ llm=rag_llm,
237
+ retriever=hybrid_retriever,
238
+ prompt_template=_prompt_template,
239
+ )
240
+
241
+ def search(self, *args: Any, **kwargs: Any) -> Any:
242
+ self._init_once()
243
+ return self._graphrag.search(*args, **kwargs)
244
+
245
+ def __getattr__(self, name: str) -> Any:
246
+ self._init_once()
247
+ return getattr(self._graphrag, name)
248
+
249
+
250
+ # app.pyμ—μ„œ 이 객체λ₯Ό 직접 importν•˜μ—¬ μ‚¬μš©ν•©λ‹ˆλ‹€ (μ΄λ•ŒλŠ” DB 연결을 μ‹œλ„ν•˜μ§€ μ•ŠμŒ).
251
+ graphrag = LazyGraphRAG()