refactor: eliminate import-time DB connection anti-pattern with LazyGraphRAG proxy and local driver initialization
Browse files- src/graphBuilder/neo4j/finGraph.py +7 -11
- src/retrieval/finRetrieval.py +78 -73
src/graphBuilder/neo4j/finGraph.py
CHANGED
|
@@ -42,19 +42,12 @@ def get_neo4j_driver() -> neo4j.Driver:
|
|
| 42 |
|
| 43 |
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 44 |
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
return d
|
| 49 |
-
except Exception as e:
|
| 50 |
-
import sys
|
| 51 |
-
if "pytest" in sys.modules or os.getenv("GITHUB_ACTIONS") == "true":
|
| 52 |
-
print(f"β οΈ [TEST/CI ENVIRONMENT] Neo4j connection failed at import time: {e}. (Proceeding with dummy None driver)")
|
| 53 |
-
return None
|
| 54 |
-
raise e
|
| 55 |
|
| 56 |
|
| 57 |
-
driver =
|
| 58 |
|
| 59 |
chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 60 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
|
@@ -265,6 +258,9 @@ def is_article_loaded(tx, aid: str) -> bool:
|
|
| 265 |
|
| 266 |
|
| 267 |
def main() -> None:
|
|
|
|
|
|
|
|
|
|
| 268 |
# 1. λͺ¨λ μμ
νμΌ λ‘λ ν λ³ν© λ° κ³ μ κΈ°μ¬λ§ νν°λ§
|
| 269 |
xlsx_files = sorted(glob.glob("Articles_*.xlsx"))
|
| 270 |
if not xlsx_files:
|
|
|
|
| 42 |
|
| 43 |
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 44 |
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 45 |
+
d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
|
| 46 |
+
d.verify_connectivity()
|
| 47 |
+
return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
+
driver = None
|
| 51 |
|
| 52 |
chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 53 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
|
|
|
| 258 |
|
| 259 |
|
| 260 |
def main() -> None:
|
| 261 |
+
global driver
|
| 262 |
+
driver = get_neo4j_driver()
|
| 263 |
+
|
| 264 |
# 1. λͺ¨λ μμ
νμΌ λ‘λ ν λ³ν© λ° κ³ μ κΈ°μ¬λ§ νν°λ§
|
| 265 |
xlsx_files = sorted(glob.glob("Articles_*.xlsx"))
|
| 266 |
if not xlsx_files:
|
src/retrieval/finRetrieval.py
CHANGED
|
@@ -26,10 +26,6 @@ from neo4j_graphrag.retrievers import (
|
|
| 26 |
|
| 27 |
dotenv.load_dotenv()
|
| 28 |
|
| 29 |
-
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
-
# 1. DB / LLM / Embedder μ΄κΈ°ν
|
| 31 |
-
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
-
|
| 33 |
|
| 34 |
def get_neo4j_driver() -> neo4j.Driver:
|
| 35 |
uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
|
@@ -46,37 +42,20 @@ def get_neo4j_driver() -> neo4j.Driver:
|
|
| 46 |
|
| 47 |
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 48 |
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
return d
|
| 53 |
-
except Exception as e:
|
| 54 |
-
import sys
|
| 55 |
-
if "pytest" in sys.modules or os.getenv("GITHUB_ACTIONS") == "true":
|
| 56 |
-
print(f"β οΈ [TEST/CI ENVIRONMENT] Neo4j connection failed at import time: {e}. (Proceeding with dummy None driver)")
|
| 57 |
-
return None
|
| 58 |
-
raise e
|
| 59 |
|
| 60 |
|
| 61 |
-
driver = get_neo4j_driver()
|
| 62 |
-
|
| 63 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
| 64 |
embedder = OpenAIEmbeddings(model="text-embedding-3-small")
|
| 65 |
|
| 66 |
INDEX_NAME = "content_vector_index"
|
| 67 |
|
| 68 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
-
# 2. Retriever
|
| 70 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
|
| 72 |
-
# (1) λ³Έλ¬Έ μ²ν¬ μλ―Έ μ μ¬λ κ²μ
|
| 73 |
-
vector_retriever = VectorRetriever(
|
| 74 |
-
driver=driver,
|
| 75 |
-
index_name=INDEX_NAME,
|
| 76 |
-
embedder=embedder,
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
# (2) λ²‘ν° κ²μ ν κ·Έλν νμ (κΈ°μ
Β·κΈ°μ Β·μλΉμ€ ν¨κ» λ°ν)
|
| 80 |
_retrieval_query = """
|
| 81 |
MATCH (content:Content)<-[:HAS_CHUNK]-(article:Article)
|
| 82 |
OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
|
|
@@ -96,16 +75,8 @@ ORDER BY article.published_date DESC
|
|
| 96 |
LIMIT 3
|
| 97 |
"""
|
| 98 |
|
| 99 |
-
vector_cypher_retriever = VectorCypherRetriever(
|
| 100 |
-
driver=driver,
|
| 101 |
-
index_name=INDEX_NAME,
|
| 102 |
-
retrieval_query=_retrieval_query,
|
| 103 |
-
embedder=embedder,
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
|
| 107 |
-
|
| 108 |
-
def _get_schema() -> str:
|
| 109 |
with driver.session() as s:
|
| 110 |
nodes = s.run(
|
| 111 |
"CALL db.schema.nodeTypeProperties() "
|
|
@@ -151,36 +122,10 @@ CYPHER QUERY:
|
|
| 151 |
LIMIT 3""",
|
| 152 |
]
|
| 153 |
|
| 154 |
-
text2cypher_retriever = Text2CypherRetriever(
|
| 155 |
-
driver=driver,
|
| 156 |
-
llm=rag_llm,
|
| 157 |
-
neo4j_schema=_get_schema(),
|
| 158 |
-
examples=_examples,
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 162 |
# 3. ToolsRetriever + GraphRAG 쑰립
|
| 163 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 164 |
|
| 165 |
-
tools_retriever = ToolsRetriever(
|
| 166 |
-
driver=driver,
|
| 167 |
-
llm=rag_llm,
|
| 168 |
-
tools=[
|
| 169 |
-
vector_retriever.convert_to_tool(
|
| 170 |
-
name="vector_retriever",
|
| 171 |
-
description="λ΄μ€ λ³Έλ¬Έμ μλ―Έ(λ΄μ©) μ μ¬λ κΈ°λ° κ²μ. AI κΈ°μ Β·μλΉμ€ κ΄λ ¨ ν
μ€νΈλ₯Ό μ°Ύμ λ μ¬μ©.",
|
| 172 |
-
),
|
| 173 |
-
vector_cypher_retriever.convert_to_tool(
|
| 174 |
-
name="vectorcypher_retriever",
|
| 175 |
-
description="λ²‘ν° κ²μ ν ν΄λΉ κΈ°μ¬μμ μΈκΈλ κΈ°μ
Β·κΈ°μ Β·μλΉμ€ κ·Έλνλ₯Ό ν¨κ» λ°ν. κΈ°μ
AI νΈλ λ λΆμμ μ΅μ .",
|
| 176 |
-
),
|
| 177 |
-
text2cypher_retriever.convert_to_tool(
|
| 178 |
-
name="text2cypher_retriever",
|
| 179 |
-
description="μμ°μ΄λ₯Ό Cypherλ‘ λ³ν. νΉμ κΈ°μ
μλΉμ€ λͺ©λ‘, κΈ°μ 보μ κΈ°μ
λ± κ΅¬μ‘°μ μ§μμ μ¬μ©.",
|
| 180 |
-
),
|
| 181 |
-
],
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
from typing import Any
|
| 185 |
|
| 186 |
from neo4j_graphrag.retrievers.base import Retriever
|
|
@@ -205,13 +150,6 @@ class HybridFallbackRetriever(Retriever):
|
|
| 205 |
return res
|
| 206 |
|
| 207 |
|
| 208 |
-
# νμ΄λΈλ¦¬λ κ²μ μΈμ€ν΄μ€ μ₯μ°©
|
| 209 |
-
hybrid_retriever = HybridFallbackRetriever(
|
| 210 |
-
tools_retriever=tools_retriever,
|
| 211 |
-
fallback_retriever=vector_cypher_retriever,
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
-
|
| 215 |
class CustomRagTemplate(RagTemplate):
|
| 216 |
EXPECTED_INPUTS = ["context", "query_text"]
|
| 217 |
|
|
@@ -238,9 +176,76 @@ _prompt_template = CustomRagTemplate(
|
|
| 238 |
expected_inputs=["context", "query_text"]
|
| 239 |
)
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
dotenv.load_dotenv()
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def get_neo4j_driver() -> neo4j.Driver:
|
| 31 |
uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
|
|
|
| 42 |
|
| 43 |
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 44 |
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 45 |
+
d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
|
| 46 |
+
d.verify_connectivity()
|
| 47 |
+
return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
| 50 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
| 51 |
embedder = OpenAIEmbeddings(model="text-embedding-3-small")
|
| 52 |
|
| 53 |
INDEX_NAME = "content_vector_index"
|
| 54 |
|
| 55 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
+
# 2. Retriever κ΄λ ¨ μμ λ° μ€μ
|
| 57 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
_retrieval_query = """
|
| 60 |
MATCH (content:Content)<-[:HAS_CHUNK]-(article:Article)
|
| 61 |
OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
|
|
|
|
| 75 |
LIMIT 3
|
| 76 |
"""
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
def _get_schema(driver: neo4j.Driver) -> str:
|
|
|
|
| 80 |
with driver.session() as s:
|
| 81 |
nodes = s.run(
|
| 82 |
"CALL db.schema.nodeTypeProperties() "
|
|
|
|
| 122 |
LIMIT 3""",
|
| 123 |
]
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
# 3. ToolsRetriever + GraphRAG 쑰립
|
| 127 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
from typing import Any
|
| 130 |
|
| 131 |
from neo4j_graphrag.retrievers.base import Retriever
|
|
|
|
| 150 |
return res
|
| 151 |
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
class CustomRagTemplate(RagTemplate):
|
| 154 |
EXPECTED_INPUTS = ["context", "query_text"]
|
| 155 |
|
|
|
|
| 176 |
expected_inputs=["context", "query_text"]
|
| 177 |
)
|
| 178 |
|
| 179 |
+
|
| 180 |
+
class LazyGraphRAG:
|
| 181 |
+
"""μν¬νΈ μμ μ DB μ°κ²°μ λ°©μ§νκ³ μ€μ νΈμΆλ λ GraphRAG μΈμ€ν΄μ€λ₯Ό μ΄κΈ°ννλ μ§μ° νκ° νλ‘μ"""
|
| 182 |
+
def __init__(self) -> None:
|
| 183 |
+
self._graphrag = None
|
| 184 |
+
|
| 185 |
+
def _init_once(self) -> None:
|
| 186 |
+
if self._graphrag is not None:
|
| 187 |
+
return
|
| 188 |
+
|
| 189 |
+
driver = get_neo4j_driver()
|
| 190 |
+
|
| 191 |
+
vector_retriever = VectorRetriever(
|
| 192 |
+
driver=driver,
|
| 193 |
+
index_name=INDEX_NAME,
|
| 194 |
+
embedder=embedder,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
vector_cypher_retriever = VectorCypherRetriever(
|
| 198 |
+
driver=driver,
|
| 199 |
+
index_name=INDEX_NAME,
|
| 200 |
+
retrieval_query=_retrieval_query,
|
| 201 |
+
embedder=embedder,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
text2cypher_retriever = Text2CypherRetriever(
|
| 205 |
+
driver=driver,
|
| 206 |
+
llm=rag_llm,
|
| 207 |
+
neo4j_schema=_get_schema(driver),
|
| 208 |
+
examples=_examples,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
tools_retriever = ToolsRetriever(
|
| 212 |
+
driver=driver,
|
| 213 |
+
llm=rag_llm,
|
| 214 |
+
tools=[
|
| 215 |
+
vector_retriever.convert_to_tool(
|
| 216 |
+
name="vector_retriever",
|
| 217 |
+
description="λ΄μ€ λ³Έλ¬Έμ μλ―Έ(λ΄μ©) μ μ¬λ κΈ°λ° κ²μ. AI κΈ°μ Β·μλΉμ€ κ΄λ ¨ ν
μ€νΈλ₯Ό μ°Ύμ λ μ¬μ©.",
|
| 218 |
+
),
|
| 219 |
+
vector_cypher_retriever.convert_to_tool(
|
| 220 |
+
name="vectorcypher_retriever",
|
| 221 |
+
description="λ²‘ν° κ²μ ν ν΄λΉ κΈ°μ¬μμ μΈκΈλ κΈ°μ
Β·κΈ°μ Β·μλΉμ€ κ·Έλνλ₯Ό ν¨κ» λ°ν. κΈ°μ
AI νΈλ λ λΆμμ μ΅μ .",
|
| 222 |
+
),
|
| 223 |
+
text2cypher_retriever.convert_to_tool(
|
| 224 |
+
name="text2cypher_retriever",
|
| 225 |
+
description="μμ°μ΄λ₯Ό Cypherλ‘ λ³ν. νΉμ κΈ°μ
μλΉμ€ λͺ©λ‘, κΈ°μ 보μ κΈ°μ
λ± κ΅¬μ‘°μ μ§μμ μ¬μ©.",
|
| 226 |
+
),
|
| 227 |
+
],
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
hybrid_retriever = HybridFallbackRetriever(
|
| 231 |
+
tools_retriever=tools_retriever,
|
| 232 |
+
fallback_retriever=vector_cypher_retriever,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
self._graphrag = GraphRAG(
|
| 236 |
+
llm=rag_llm,
|
| 237 |
+
retriever=hybrid_retriever,
|
| 238 |
+
prompt_template=_prompt_template,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
def search(self, *args: Any, **kwargs: Any) -> Any:
|
| 242 |
+
self._init_once()
|
| 243 |
+
return self._graphrag.search(*args, **kwargs)
|
| 244 |
+
|
| 245 |
+
def __getattr__(self, name: str) -> Any:
|
| 246 |
+
self._init_once()
|
| 247 |
+
return getattr(self._graphrag, name)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# app.pyμμ μ΄ κ°μ²΄λ₯Ό μ§μ importνμ¬ μ¬μ©ν©λλ€ (μ΄λλ DB μ°κ²°μ μλνμ§ μμ).
|
| 251 |
+
graphrag = LazyGraphRAG()
|