"""
finRetrieval.py — GraphRAG search module
========================================
Imported by app.py to connect the GraphRAG pipeline to the Gradio chatbot.

Usage:
    from src.retrieval.finRetrieval import graphrag

    # Example query: "What are Samsung Electronics' AI services?"
    response = graphrag.search(query_text="삼성전자 AI 서비스는?")
    print(response.answer)
"""
|
|
import logging
import os

import dotenv
import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG, RagTemplate
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import (
    Text2CypherRetriever,
    ToolsRetriever,
    VectorCypherRetriever,
    VectorRetriever,
)
|
|
# Load a local .env file into the process environment so the os.getenv()
# lookups in this module can see the values defined there.
dotenv.load_dotenv()
|
|
|
|
def _open_verified_driver(uri: str, auth: tuple) -> neo4j.Driver:
    """Create a driver for ``uri``/``auth`` and return it only if it can connect.

    The driver is closed before the error is re-raised when
    ``verify_connectivity()`` fails, so a failed attempt does not leak it.
    """
    driver = neo4j.GraphDatabase.driver(uri, auth=auth)
    try:
        driver.verify_connectivity()
    except Exception:
        driver.close()
        raise
    return driver


def get_neo4j_driver() -> neo4j.Driver:
    """Return a connected Neo4j driver configured from environment variables.

    Environment variables read:
        NEO4J_URI: connection URI (default ``neo4j://localhost:7687``).
        NEO4J_CLIENT_ID / NEO4J_CLIENT_SECRET: tried first when both are set.
        NEO4J_USERNAME / NEO4J_PASSWORD: fallback credentials
            (defaults ``neo4j`` / ``password``).

    Returns:
        A driver whose connectivity has been verified.

    Raises:
        Exception: whatever the neo4j driver raises when the fallback
            username/password connection cannot be created or verified.
    """
    uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
    client_id = os.getenv("NEO4J_CLIENT_ID")
    client_secret = os.getenv("NEO4J_CLIENT_SECRET")

    if client_id and client_secret:
        try:
            return _open_verified_driver(uri, (client_id, client_secret))
        except Exception as exc:
            # Best-effort first attempt: record why it failed, then fall
            # through to the username/password credentials below.
            logging.getLogger(__name__).warning(
                "Neo4j connection with NEO4J_CLIENT_ID failed (%s: %s); "
                "falling back to NEO4J_USERNAME/NEO4J_PASSWORD.",
                type(exc).__name__,
                exc,
            )

    username = os.getenv("NEO4J_USERNAME", "neo4j")
    password = os.getenv("NEO4J_PASSWORD", "password")
    return _open_verified_driver(uri, (username, password))
|
|
|
|
# OpenAI chat model and embedding model, created once at import time and
# shared by the retrievers and the GraphRAG pipeline built in this module.
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
embedder = OpenAIEmbeddings(model="text-embedding-3-small")

# Name of the Neo4j vector index handed to the vector retrievers.
INDEX_NAME = "content_vector_index"
|
|
| |
| |
| |
|
|
# Cypher fragment passed to VectorCypherRetriever as ``retrieval_query``.
# The retriever runs it after the vector index lookup, which yields ``node``
# (the matched chunk) and ``score``. The traversal is anchored on ``node`` so
# that only articles owning the matched chunks are expanded; without that
# anchor the MATCH would scan every Content node and ignore the search hits.
_retrieval_query = """
WITH node AS content
MATCH (content)<-[:HAS_CHUNK]-(article:Article)
OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
OPTIONAL MATCH (company)-[:DEVELOPS]->(tech:AITechnology)
OPTIONAL MATCH (company)-[:DEVELOPS]->(svc:AIService)
OPTIONAL MATCH (article)-[:MENTIONS]->(field:AIField)
RETURN
    content.chunk AS chunk,
    article.title AS article_title,
    article.url AS article_url,
    article.published_date AS article_date,
    collect(DISTINCT company.name) AS companies,
    collect(DISTINCT tech.name) AS technologies,
    collect(DISTINCT svc.name) AS services,
    collect(DISTINCT field.name) AS fields
ORDER BY article.published_date DESC
LIMIT 3
"""
|
|
|
|
def _get_schema(driver: neo4j.Driver) -> str:
    """Return a plain-text summary of the graph schema.

    The text is passed to ``Text2CypherRetriever`` as ``neo4j_schema``.

    Two queries are run in one session:
      * ``db.schema.nodeTypeProperties()`` for node types and their properties;
      * a sample of up to 30 distinct (source label, relationship type,
        target label) triples, using only the first label of each node.

    Args:
        driver: an open Neo4j driver.

    Returns:
        A string with a header, a node section ("노드:") and a relationship
        section ("관계:"), one ``- ...`` line per entry, ending with a newline.
    """
    # NOTE(review): the two Korean section headers were reconstructed from
    # mis-decoded source text ("nodes" / "relationships") -- confirm.
    with driver.session() as session:
        node_rows = session.run(
            "CALL db.schema.nodeTypeProperties() "
            "YIELD nodeType, propertyName "
            "RETURN nodeType, collect(propertyName) as props"
        ).data()
        rel_rows = session.run(
            "MATCH (n)-[r]->(m) RETURN DISTINCT labels(n)[0] as src, type(r) as rel, labels(m)[0] as tgt LIMIT 30"
        ).data()

    lines = ["=== Neo4j Schema ===", "노드:"]
    lines.extend(f"- {row['nodeType']}: {row['props']}" for row in node_rows)
    # The empty entry produces the blank line that separates the two sections.
    lines.extend(["", "관계:"])
    lines.extend(
        f"- ({row['src']})-[:{row['rel']}]->({row['tgt']})" for row in rel_rows
    )
    return "\n".join(lines) + "\n"
|
|
|
|
# Few-shot examples handed to Text2CypherRetriever: each entry pairs a user
# question with the Cypher query that answers it.
# NOTE(review): the Korean text was reconstructed from mis-decoded source
# bytes; a few particles are ambiguous -- confirm against the original wording.
_examples = [
    """USER INPUT: 카카오의 AI 서비스 목록을 알려주세요
CYPHER QUERY:
MATCH (c:AICompany {name:"카카오"})-[:DEVELOPS]->(s:AIService)
RETURN s.name, s.description""",
    """USER INPUT: 삼성전자가 개발 중인 AI 기술은?
CYPHER QUERY:
MATCH (c:AICompany {name:"삼성전자"})-[:DEVELOPS]->(t:AITechnology)
RETURN t.name, t.description""",
    """USER INPUT: 어떤 기업이 LLM 기술을 개발하나요?
CYPHER QUERY:
MATCH (c:AICompany)-[:DEVELOPS]->(t:AITechnology)
WHERE t.name CONTAINS "언어모델" OR t.name CONTAINS "LLM"
RETURN c.name, t.name""",
    """USER INPUT: 금융이나 핀테크 분야에 기술을 적용하고 있는 기업들은 어디야?
CYPHER QUERY:
MATCH (c:AICompany)-[:DEVELOPS]->(t)-[:USED_IN]->(f:AIField)
WHERE f.name CONTAINS "금융" OR f.name CONTAINS "핀테크"
RETURN DISTINCT c.name, t.name, f.name""",
    """USER INPUT: 금융AI 분야에 가장 적극적인 기업 TOP 3에 대한 서비스
CYPHER QUERY:
MATCH (c:AICompany)-[:DEVELOPS]->(s)-[:USED_IN]->(f:AIField)
WHERE f.name CONTAINS "금융" OR f.name CONTAINS "핀테크"
RETURN DISTINCT c.name, s.name, f.name
LIMIT 3""",
]
|
|
| |
| |
| |
|
|
| from typing import Any |
|
|
| from neo4j_graphrag.retrievers.base import Retriever |
| from neo4j_graphrag.types import RawSearchResult, RetrieverResult |
|
|
|
|
class HybridFallbackRetriever(Retriever):
    """Retriever that queries a primary retriever and falls back to a second one.

    ``search`` is overridden directly: it asks ``tools_retriever`` first and,
    when that yields nothing (a falsy result or an empty ``items``), returns
    the result of ``fallback_retriever`` for the same query instead.
    """

    # NOTE(review): presumably tells the base Retriever constructor to skip
    # its Neo4j version check -- confirm against the neo4j_graphrag base class.
    VERIFY_NEO4J_VERSION = False

    def __init__(self, tools_retriever: Retriever, fallback_retriever: Retriever) -> None:
        """Store both retrievers and initialise the base with the primary's driver.

        Args:
            tools_retriever: retriever consulted first.
            fallback_retriever: retriever used when the first returns no items.
        """
        self.tools_retriever = tools_retriever
        self.fallback_retriever = fallback_retriever
        super().__init__(driver=tools_retriever.driver)

    def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
        """Return an empty raw result; this class's ``search`` never calls it."""
        return RawSearchResult(records=[])

    def search(self, query_text: str = "", **kwargs: Any) -> RetrieverResult:
        """Search with the primary retriever, falling back when it finds nothing.

        Args:
            query_text: the user question.
            **kwargs: forwarded unchanged to whichever retriever is called.

        Returns:
            The primary retriever's result when it contains items, otherwise
            the fallback retriever's result.
        """
        primary = self.tools_retriever.search(query_text=query_text, **kwargs)
        if primary and primary.items:
            return primary
        return self.fallback_retriever.search(query_text=query_text, **kwargs)
|
|
|
|
class CustomRagTemplate(RagTemplate):
    """``RagTemplate`` variant whose prompt takes only ``context`` and ``query_text``."""

    EXPECTED_INPUTS = ["context", "query_text"]

    def format(self, query_text: str, context: str, examples: str = "") -> str:
        """Render the template with ``query_text`` and ``context``.

        Args:
            query_text: the user question.
            context: the retrieved context to answer from.
            examples: accepted but not forwarded to ``_format``; the expected
                inputs declared above do not include it.

        Returns:
            The formatted prompt string.
        """
        return self._format(query_text=query_text, context=context)
|
|
|
|
# Korean answer prompt for the GraphRAG pipeline. It tells the model to answer
# only from the supplied Neo4j context, to admit when the context has no
# relevant information, to cite only URLs present in the context, and to write
# concrete answers aimed at job seekers drafting application statements.
# NOTE(review): the Korean text was reconstructed from mis-decoded source
# bytes; a few particles are ambiguous -- confirm against the original wording.
_prompt_template = CustomRagTemplate(
    template="""당신은 AI 기술 트렌드 분석 전문가입니다.
반드시 아래 제공된 [컨텍스트(Neo4j 지식 그래프 검색 결과)]에 기반해서만 답변하세요.

⚠️ [엄격한 주의사항]
1. 컨텍스트에 없는 기업, 서비스, 기술, 해외 기업(JP모건 등)은 절대 언급하지 마세요.
2. 질문에 해당하는 정보가 컨텍스트에 없다면 지어내지 말고, "현재 수집된 최신 뉴스 데이터에는 관련 정보가 없습니다"라고 정직하게 답변하세요.
3. 근거로 제시할 URL은 오직 컨텍스트에 포함된 실제 기사의 URL만 사용하며, 'example.com' 같은 가짜 링크는 절대 생성하지 마세요.
4. 취업준비생이 기업 지원 동기를 작성할 수 있도록, 컨텍스트에 있는 팩트를 기반으로 구체적이고 전문적으로 답변하세요.

질문: {query_text}

[컨텍스트]
{context}

답변:""",
    expected_inputs=["context", "query_text"],
)
|
|
|
|
class LazyGraphRAG:
    """Lazy proxy around ``GraphRAG``.

    Creating an instance does not open a Neo4j connection. The driver, the
    retrievers and the ``GraphRAG`` pipeline are built on the first call to
    ``search`` (or the first access to a forwarded attribute) and then reused.
    """

    def __init__(self) -> None:
        # Filled in by _init_once(); None means "not built yet".
        self._graphrag = None

    @staticmethod
    def _build_pipeline(driver: Any) -> Any:
        """Assemble the retrievers and return the ``GraphRAG`` pipeline for ``driver``."""
        # NOTE(review): the Korean tool descriptions below were reconstructed
        # from mis-decoded source bytes -- confirm against the original wording.
        vector_retriever = VectorRetriever(
            driver=driver,
            index_name=INDEX_NAME,
            embedder=embedder,
        )

        vector_cypher_retriever = VectorCypherRetriever(
            driver=driver,
            index_name=INDEX_NAME,
            retrieval_query=_retrieval_query,
            embedder=embedder,
        )

        text2cypher_retriever = Text2CypherRetriever(
            driver=driver,
            llm=rag_llm,
            neo4j_schema=_get_schema(driver),
            examples=_examples,
        )

        # The LLM chooses among these tools based on their descriptions.
        tools_retriever = ToolsRetriever(
            driver=driver,
            llm=rag_llm,
            tools=[
                vector_retriever.convert_to_tool(
                    name="vector_retriever",
                    description="뉴스 본문의 의미(내용) 유사도 기반 검색. AI 기술·서비스 관련 텍스트를 찾을 때 사용.",
                ),
                vector_cypher_retriever.convert_to_tool(
                    name="vectorcypher_retriever",
                    description="벡터 검색 후 해당 기사에서 언급된 기업·기술·서비스 그래프를 함께 반환. 기업 AI 트렌드 분석에 최적.",
                ),
                text2cypher_retriever.convert_to_tool(
                    name="text2cypher_retriever",
                    description="자연어를 Cypher로 변환. 특정 기업 서비스 목록, 기술 보유 기업 등 구조적 질의에 사용.",
                ),
            ],
        )

        # If the tool-based search returns no items, fall back to the
        # vector + graph retriever.
        hybrid_retriever = HybridFallbackRetriever(
            tools_retriever=tools_retriever,
            fallback_retriever=vector_cypher_retriever,
        )

        return GraphRAG(
            llm=rag_llm,
            retriever=hybrid_retriever,
            prompt_template=_prompt_template,
        )

    def _init_once(self) -> None:
        """Build the pipeline on first use; do nothing once it exists.

        Raises:
            Exception: whatever driver creation, the schema query or the
                retriever constructors raise. The driver opened here is
                closed before the error propagates, so a failed attempt
                does not leak a connection and a later call can retry.
        """
        if self._graphrag is not None:
            return

        driver = get_neo4j_driver()
        try:
            self._graphrag = self._build_pipeline(driver)
        except Exception:
            driver.close()
            raise

    def search(self, *args: Any, **kwargs: Any) -> Any:
        """Initialise the pipeline if needed and delegate to ``GraphRAG.search``."""
        self._init_once()
        return self._graphrag.search(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attributes to the underlying ``GraphRAG`` instance.

        Only invoked when normal attribute lookup fails. Dunder probes (made
        by ``copy``, ``pickle``, ``hasattr`` checks and similar) and a missing
        ``_graphrag`` raise ``AttributeError`` immediately. Otherwise such a
        probe would open the database connection, and a missing ``_graphrag``
        would make ``_init_once`` re-enter this method without end.
        """
        if name == "_graphrag" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        self._init_once()
        return getattr(self._graphrag, name)
|
|
|
|
| |
# Module-level instance that callers import; LazyGraphRAG() itself does not
# connect to Neo4j -- the pipeline is built on first use.
graphrag = LazyGraphRAG()
|
|