# FinGraph / src/retrieval/finRetrieval.py
# Author: dev-yuje — commit 9675b2d (9.56 kB)
# Commit message: "refactor: eliminate import-time DB connection anti-pattern
# with LazyGraphRAG proxy and local driver initialization"
"""
finRetrieval.py β€” GraphRAG 검색 λͺ¨λ“ˆ
=====================================
app.pyμ—μ„œ importν•˜μ—¬ Gradio 챗봇과 μ—°λ™ν•©λ‹ˆλ‹€.
μ‚¬μš©λ²•:
from src.retrieval.finRetrieval import graphrag
response = graphrag.search(query_text="μ‚Όμ„±μ „μž AI μ„œλΉ„μŠ€λŠ”?")
print(response.answer)
"""
import logging
import os
import threading

import dotenv
import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG, RagTemplate
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import (
    Text2CypherRetriever,
    ToolsRetriever,
    VectorCypherRetriever,
    VectorRetriever,
)
dotenv.load_dotenv()


def _open_verified_driver(uri: str, auth: tuple[str, str]) -> neo4j.Driver:
    """Open a Neo4j driver for *uri* with *auth* and check that it can connect.

    If ``verify_connectivity()`` fails, the driver is closed before the error
    is re-raised, so a failed attempt does not leave an unused driver (and its
    connection pool) behind.
    """
    driver = neo4j.GraphDatabase.driver(uri, auth=auth)
    try:
        driver.verify_connectivity()
    except Exception:
        driver.close()
        raise
    return driver


def get_neo4j_driver() -> neo4j.Driver:
    """Create a Neo4j driver that has successfully verified connectivity.

    Settings are read from the environment (``.env`` is loaded by
    ``dotenv.load_dotenv()`` above):

    * ``NEO4J_URI`` -- default ``neo4j://localhost:7687``.
    * ``NEO4J_CLIENT_ID`` / ``NEO4J_CLIENT_SECRET`` -- tried first when both are
      set; a failure is logged as a warning and the next option is tried.
    * ``NEO4J_USERNAME`` / ``NEO4J_PASSWORD`` -- fallback, defaults
      ``neo4j`` / ``password``.

    Returns:
        A driver whose ``verify_connectivity()`` call succeeded.

    Raises:
        Whatever ``neo4j.GraphDatabase.driver`` or ``verify_connectivity``
        raises for the username/password attempt.
    """
    uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")

    client_id = os.getenv("NEO4J_CLIENT_ID")
    client_secret = os.getenv("NEO4J_CLIENT_SECRET")
    if client_id and client_secret:
        try:
            return _open_verified_driver(uri, (client_id, client_secret))
        except Exception as exc:
            # Deliberate best-effort: fall back to username/password below,
            # but leave a trace instead of swallowing the failure silently.
            logging.getLogger(__name__).warning(
                "Neo4j login with NEO4J_CLIENT_ID/NEO4J_CLIENT_SECRET failed "
                "(%s: %s); falling back to NEO4J_USERNAME/NEO4J_PASSWORD",
                type(exc).__name__,
                exc,
            )

    username = os.getenv("NEO4J_USERNAME", "neo4j")
    # NOTE(review): the hard-coded default password only makes sense for local
    # development -- consider requiring NEO4J_PASSWORD in deployed environments.
    password = os.getenv("NEO4J_PASSWORD", "password")
    return _open_verified_driver(uri, (username, password))
# Shared model clients: the LLM drives tool routing, Text2Cypher and the final
# answer; the embedder vectorises queries for the vector-index retrievers.
# NOTE(review): unlike the Neo4j driver (created lazily by LazyGraphRAG), these
# clients are built at import time -- confirm their constructors need nothing
# beyond the environment loaded by dotenv.
rag_llm = OpenAILLM(
    model_name="gpt-4o",
    model_params={"temperature": 0},  # temperature 0 minimises sampling randomness
)
embedder = OpenAIEmbeddings(
    model="text-embedding-3-small",
)
# Name of the Neo4j vector index searched by the vector-based retrievers.
INDEX_NAME = "content_vector_index"
# ──────────────────────────────────────────
# 2. Retriever 관련 상수 및 설정 (retriever constants and settings)
# ──────────────────────────────────────────
_retrieval_query = """
MATCH (content:Content)<-[:HAS_CHUNK]-(article:Article)
OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
OPTIONAL MATCH (company)-[:DEVELOPS]->(tech:AITechnology)
OPTIONAL MATCH (company)-[:DEVELOPS]->(svc:AIService)
OPTIONAL MATCH (article)-[:MENTIONS]->(field:AIField)
RETURN
content.chunk AS chunk,
article.title AS article_title,
article.url AS article_url,
article.published_date AS article_date,
collect(DISTINCT company.name) AS companies,
collect(DISTINCT tech.name) AS technologies,
collect(DISTINCT svc.name) AS services,
collect(DISTINCT field.name) AS fields
ORDER BY article.published_date DESC
LIMIT 3
"""
def _get_schema(driver: neo4j.Driver) -> str:
with driver.session() as s:
nodes = s.run(
"CALL db.schema.nodeTypeProperties() "
"YIELD nodeType, propertyName "
"RETURN nodeType, collect(propertyName) as props"
).data()
rels = s.run(
"MATCH (n)-[r]->(m) RETURN DISTINCT labels(n)[0] as src, type(r) as rel, labels(m)[0] as tgt LIMIT 30"
).data()
txt = "=== Neo4j Schema ===\nλ…Έλ“œ:\n"
for n in nodes:
txt += f"- {n['nodeType']}: {n['props']}\n"
txt += "\n관계:\n"
for r in rels:
txt += f"- ({r['src']})-[:{r['rel']}]->({r['tgt']})\n"
return txt
_examples = [
"""USER INPUT: 카카였의 AI μ„œλΉ„μŠ€ λͺ©λ‘μ„ μ•Œλ €μ£Όμ„Έμš”
CYPHER QUERY:
MATCH (c:AICompany {name:"카카였"})-[:DEVELOPS]->(s:AIService)
RETURN s.name, s.description""",
"""USER INPUT: μ‚Όμ„±μ „μžκ°€ 개발 쀑인 AI κΈ°μˆ μ€?
CYPHER QUERY:
MATCH (c:AICompany {name:"μ‚Όμ„±μ „μž"})-[:DEVELOPS]->(t:AITechnology)
RETURN t.name, t.description""",
"""USER INPUT: μ–΄λ–€ 기업이 LLM κΈ°μˆ μ„ κ°œλ°œν•˜λ‚˜μš”?
CYPHER QUERY:
MATCH (c:AICompany)-[:DEVELOPS]->(t:AITechnology)
WHERE t.name CONTAINS "μ–Έμ–΄λͺ¨λΈ" OR t.name CONTAINS "LLM"
RETURN c.name, t.name""",
"""USER INPUT: κΈˆμœ΅μ΄λ‚˜ ν•€ν…Œν¬ 뢄야에 κΈ°μˆ μ„ μ μš©ν•˜κ³  μžˆλŠ” 기업듀은 μ–΄λ””μ•Ό?
CYPHER QUERY:
MATCH (c:AICompany)-[:DEVELOPS]->(t)-[:USED_IN]->(f:AIField)
WHERE f.name CONTAINS "금육" OR f.name CONTAINS "ν•€ν…Œν¬"
RETURN DISTINCT c.name, t.name, f.name""",
"""USER INPUT: 금육AI 뢄야에 κ°€μž₯ 적극적인 κΈ°μ—… TOP 3와 λŒ€ν‘œ μ„œλΉ„μŠ€
CYPHER QUERY:
MATCH (c:AICompany)-[:DEVELOPS]->(s)-[:USED_IN]->(f:AIField)
WHERE f.name CONTAINS "금육" OR f.name CONTAINS "ν•€ν…Œν¬"
RETURN DISTINCT c.name, s.name, f.name
LIMIT 3""",
]
# ──────────────────────────────────────────
# 3. ToolsRetriever + GraphRAG 조립 (assembly)
# ──────────────────────────────────────────
from typing import Any
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RawSearchResult, RetrieverResult
class HybridFallbackRetriever(Retriever):
    """Try an LLM-routed primary retriever first, then fall back to a fixed one.

    The fallback retriever is used when the primary retriever returns no
    items, or when it raises (for example because an LLM-generated Cypher
    query is invalid); such a failure is logged instead of being surfaced to
    the caller.
    """

    # Disable the server-version check that the Retriever base class performs
    # when this flag is set (TODO confirm against the neo4j-graphrag version in use).
    VERIFY_NEO4J_VERSION = False

    def __init__(self, tools_retriever: Retriever, fallback_retriever: Retriever) -> None:
        self.tools_retriever = tools_retriever
        self.fallback_retriever = fallback_retriever
        # The base class requires a driver; reuse the primary retriever's.
        super().__init__(driver=tools_retriever.driver)

    def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
        # Required by the abstract base class. Never used, because search()
        # is overridden to delegate directly to the wrapped retrievers.
        return RawSearchResult(records=[])

    def search(self, query_text: str = "", **kwargs: Any) -> RetrieverResult:
        """Return the primary retriever's result, or the fallback's if it is empty or fails.

        Args:
            query_text: The user question, forwarded to both retrievers.
            **kwargs: Extra retriever options, forwarded unchanged to both.
        """
        try:
            res = self.tools_retriever.search(query_text=query_text, **kwargs)
        except Exception:
            # Fallback boundary: a failing primary path should degrade to the
            # fallback retriever rather than abort the whole RAG answer.
            logging.getLogger(__name__).warning(
                "Primary retriever failed; using fallback retriever", exc_info=True
            )
            res = None
        if not res or not res.items:
            return self.fallback_retriever.search(query_text=query_text, **kwargs)
        return res
class CustomRagTemplate(RagTemplate):
    """RAG prompt template whose only inputs are ``context`` and ``query_text``.

    ``examples`` stays in the ``format`` signature for call compatibility but
    is ignored, so the template text needs no ``{examples}`` placeholder.
    """

    EXPECTED_INPUTS = ["context", "query_text"]

    def format(self, query_text: str, context: str, examples: str = "") -> str:
        # `examples` is intentionally not forwarded.
        values = {"query_text": query_text, "context": context}
        return self._format(**values)
# Answer-generation prompt passed to GraphRAG.
# The Korean instructions tell the model to answer only from the supplied
# [context], not to mention companies, services or URLs absent from it
# (e.g. fake 'example.com' links), and to say so when the context lacks the
# answer.
# Placeholders: {query_text} (the user question) and {context} (the formatted
# retriever results); they match `expected_inputs` and
# CustomRagTemplate.EXPECTED_INPUTS.
# NOTE(review): the Korean text in this literal appears mis-encoded in this
# copy of the file -- verify that the source is saved as UTF-8.
_prompt_template = CustomRagTemplate(
    template="""당신은 AI 기술 νŠΈλ Œλ“œ 뢄석 μ „λ¬Έκ°€μž…λ‹ˆλ‹€.
λ°˜λ“œμ‹œ μ•„λž˜ 제곡된 [μ»¨ν…μŠ€νŠΈ(Neo4j 지식 κ·Έλž˜ν”„ 검색 κ²°κ³Ό)]에 κΈ°λ°˜ν•΄μ„œλ§Œ λ‹΅λ³€ν•˜μ„Έμš”.
⚠️ [μ—„κ²©ν•œ μ£Όμ˜μ‚¬ν•­]
1. μ»¨ν…μŠ€νŠΈμ— μ—†λŠ” κΈ°μ—…, μ„œλΉ„μŠ€, 기술, ν•΄μ™Έ κΈ°μ—…(JPλͺ¨κ±΄ λ“±)은 μ ˆλŒ€ μ–ΈκΈ‰ν•˜μ§€ λ§ˆμ„Έμš”.
2. μ§ˆλ¬Έμ— ν•΄λ‹Ήν•˜λŠ” 정보가 μ»¨ν…μŠ€νŠΈμ— μ—†λ‹€λ©΄ μ§€μ–΄λ‚΄μ§€ 말고, "ν˜„μž¬ μˆ˜μ§‘λœ μ΅œμ‹  λ‰΄μŠ€ λ°μ΄ν„°μ—λŠ” κ΄€λ ¨ 정보가 μ—†μŠ΅λ‹ˆλ‹€"라고 μ •μ§ν•˜κ²Œ λ‹΅λ³€ν•˜μ„Έμš”.
3. 근거둜 μ œμ‹œν•  URL은 였직 μ»¨ν…μŠ€νŠΈμ— ν¬ν•¨λœ μ‹€μ œ κΈ°μ‚¬μ˜ URL만 μ‚¬μš©ν•˜λ©°, 'example.com' 같은 κ°€μ§œ λ§ν¬λŠ” μ ˆλŒ€ μƒμ„±ν•˜μ§€ λ§ˆμ„Έμš”.
4. μ·¨μ—… 쀀비생이 κΈ°μ—… 지원 동기λ₯Ό μž‘μ„±ν•  수 μžˆλ„λ‘, μ»¨ν…μŠ€νŠΈμ— μžˆλŠ” 팩트λ₯Ό 기반으둜 ꡬ체적이고 μ „λ¬Έμ μœΌλ‘œ λ‹΅λ³€ν•˜μ„Έμš”.
질문: {query_text}
[μ»¨ν…μŠ€νŠΈ]
{context}
λ‹΅λ³€:""",
    expected_inputs=["context", "query_text"]
)
class LazyGraphRAG:
    """Lazy proxy around a ``GraphRAG`` pipeline.

    Creating the proxy (and importing this module) opens no Neo4j connection.
    The driver, the retrievers and the ``GraphRAG`` object are built on the
    first call to :meth:`search`, or on the first lookup of a public attribute
    that is forwarded to ``GraphRAG``.

    Initialisation runs under a lock, so concurrent first requests (e.g. from
    a threaded web UI) build a single pipeline and open a single driver.
    """

    def __init__(self) -> None:
        self._graphrag = None  # the real GraphRAG instance, built on first use
        self._lock = threading.Lock()

    def _init_once(self) -> None:
        """Build the GraphRAG pipeline unless it has already been built."""
        if self._graphrag is not None:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if self._graphrag is not None:
                return
            driver = get_neo4j_driver()
            try:
                self._graphrag = self._build_graphrag(driver)
            except Exception:
                # Don't leak the connection pool if schema introspection or
                # retriever construction fails; the next call will retry.
                driver.close()
                raise

    @staticmethod
    def _build_graphrag(driver: "neo4j.Driver") -> "GraphRAG":
        """Assemble the three retrievers, the tool router and GraphRAG on *driver*."""
        vector_retriever = VectorRetriever(
            driver=driver,
            index_name=INDEX_NAME,
            embedder=embedder,
        )
        vector_cypher_retriever = VectorCypherRetriever(
            driver=driver,
            index_name=INDEX_NAME,
            retrieval_query=_retrieval_query,
            embedder=embedder,
        )
        text2cypher_retriever = Text2CypherRetriever(
            driver=driver,
            llm=rag_llm,
            neo4j_schema=_get_schema(driver),
            examples=_examples,
        )
        # The LLM picks one of these tools per question based on the descriptions.
        tools_retriever = ToolsRetriever(
            driver=driver,
            llm=rag_llm,
            tools=[
                vector_retriever.convert_to_tool(
                    name="vector_retriever",
                    description="λ‰΄μŠ€ 본문의 의미(λ‚΄μš©) μœ μ‚¬λ„ 기반 검색. AI κΈ°μˆ Β·μ„œλΉ„μŠ€ κ΄€λ ¨ ν…μŠ€νŠΈλ₯Ό 찾을 λ•Œ μ‚¬μš©.",
                ),
                vector_cypher_retriever.convert_to_tool(
                    name="vectorcypher_retriever",
                    description="벑터 검색 ν›„ ν•΄λ‹Ή κΈ°μ‚¬μ—μ„œ μ–ΈκΈ‰λœ κΈ°μ—…Β·κΈ°μˆ Β·μ„œλΉ„μŠ€ κ·Έλž˜ν”„λ₯Ό ν•¨κ»˜ λ°˜ν™˜. κΈ°μ—… AI νŠΈλ Œλ“œ 뢄석에 졜적.",
                ),
                text2cypher_retriever.convert_to_tool(
                    name="text2cypher_retriever",
                    description="μžμ—°μ–΄λ₯Ό Cypher둜 λ³€ν™˜. νŠΉμ • κΈ°μ—… μ„œλΉ„μŠ€ λͺ©λ‘, 기술 보유 κΈ°μ—… λ“± ꡬ쑰적 μ§ˆμ˜μ— μ‚¬μš©.",
                ),
            ],
        )
        # If tool routing yields nothing, answer from the vector+graph retriever.
        hybrid_retriever = HybridFallbackRetriever(
            tools_retriever=tools_retriever,
            fallback_retriever=vector_cypher_retriever,
        )
        return GraphRAG(
            llm=rag_llm,
            retriever=hybrid_retriever,
            prompt_template=_prompt_template,
        )

    def search(self, *args: Any, **kwargs: Any) -> Any:
        """Initialise on first use, then delegate to ``GraphRAG.search``."""
        self._init_once()
        return self._graphrag.search(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Forward public attribute lookups to the underlying GraphRAG instance.

        Python calls this only for attributes not found normally. Names
        starting with ``_`` are never forwarded. Otherwise, looking up
        ``_graphrag`` on a not-yet-initialised instance (e.g. during
        ``copy.copy`` or unpickling) would re-enter ``_init_once`` and recurse
        forever. Protocol probes such as ``__setstate__`` would also open a DB
        connection.
        """
        if name.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        self._init_once()
        return getattr(self._graphrag, name)
# Module-level entry point imported directly by app.py. Creating the proxy
# does not attempt a database connection; that is deferred until the proxy
# is first used.
graphrag: LazyGraphRAG = LazyGraphRAG()