| """ |
| finRetrieval.py โ GraphRAG ๊ฒ์ ๋ชจ๋ |
| ===================================== |
| app.py์์ importํ์ฌ Gradio ์ฑ๋ด๊ณผ ์ฐ๋ํฉ๋๋ค. |
| |
| ์ฌ์ฉ๋ฒ: |
| from src.retrieval.finRetrieval import graphrag |
| |
| response = graphrag.search(query_text="์ผ์ฑ์ ์ AI ์๋น์ค๋?") |
| print(response.answer) |
| """ |
|
|
| import logging |
| import os |
| from dataclasses import dataclass |
| from typing import Any |
|
|
| |
# Raise the neo4j driver loggers to ERROR so only errors from the driver reach
# the application's log output.
for _noisy_logger in ("neo4j", "neo4j.notifications"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger
|
|
| import dotenv |
| import neo4j |
| from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings |
| from neo4j_graphrag.generation import GraphRAG, RagTemplate |
| from neo4j_graphrag.llm import OpenAILLM |
| from neo4j_graphrag.retrievers import ( |
| Text2CypherRetriever, |
| ToolsRetriever, |
| VectorCypherRetriever, |
| ) |
|
|
# Runs at import time: load variables from a .env file into the process
# environment so the os.getenv() lookups in this module (e.g. NEO4J_URI,
# NEO4J_PASSWORD) can see values defined there.
dotenv.load_dotenv()
|
|
|
|
@dataclass()
class HybridResult:
    """Unified answer produced by either the GraphRAG path or the general-knowledge path."""

    # Final answer text.
    answer: str
    # Label of the path that produced the answer (e.g. "graph" or "general").
    mode: str
    # Retrieval output backing the answer, if the producer supplies one.
    retriever_result: Any = None
|
|
|
|
def get_neo4j_driver() -> neo4j.Driver:
    """Create a Neo4j driver whose connectivity has already been verified.

    Connection settings come from the environment:

    * ``NEO4J_URI`` (default ``neo4j://localhost:7687``).
    * If both ``NEO4J_CLIENT_ID`` and ``NEO4J_CLIENT_SECRET`` are set, they are
      tried first as the auth pair. This attempt is best-effort: on failure the
      driver is closed, the error is logged, and the next step is used.
    * ``NEO4J_USERNAME`` / ``NEO4J_PASSWORD`` (defaults ``neo4j`` / ``password``).

    Returns:
        A driver on which ``verify_connectivity()`` succeeded.

    Raises:
        Whatever ``neo4j.GraphDatabase.driver`` or ``verify_connectivity`` raise
        for the username/password attempt. The driver is closed before the
        error propagates.
    """
    uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
    client_id = os.getenv("NEO4J_CLIENT_ID")
    client_secret = os.getenv("NEO4J_CLIENT_SECRET")

    if client_id and client_secret:
        driver = None
        try:
            driver = neo4j.GraphDatabase.driver(uri, auth=(client_id, client_secret))
            driver.verify_connectivity()
            return driver
        except Exception:
            # Best-effort path: release the unusable driver's connection pool
            # and leave a trace of why we are falling back.
            logging.getLogger(__name__).warning(
                "Neo4j client-credential login failed; falling back to username/password",
                exc_info=True,
            )
            if driver is not None:
                driver.close()

    username = os.getenv("NEO4J_USERNAME", "neo4j")
    password = os.getenv("NEO4J_PASSWORD", "password")
    driver = neo4j.GraphDatabase.driver(uri, auth=(username, password))
    try:
        driver.verify_connectivity()
    except Exception:
        # Do not leak the pool of a driver the caller never receives.
        driver.close()
        raise
    return driver
|
|
|
|
# Name of the Neo4j vector index queried by the VectorCypherRetriever built in
# LazyGraphRAG._init_once.
INDEX_NAME: str = "content_vector_index"
|
|
| |
| |
| |
|
|
# Cypher appended to the vector search by VectorCypherRetriever. For each
# matched chunk `node` it returns:
#   * the chunk text and its parent Article (title, url, date);
#   * the article's mentioned companies and fields, plus what those companies
#     develop (technologies, services);
#   * up to 3 other articles mentioning the same company/technology/service.
#     related_article_titles[i] and related_article_urls[i] always describe
#     the same article.
_retrieval_query = """
MATCH (node)<-[:HAS_CHUNK]-(article:Article)

// Cross-article expansion: other articles that mention the same company,
// technology or service. Done once per chunk, before the entity OPTIONAL
// MATCHes below multiply the row count.
OPTIONAL MATCH (related_article:Article)
WHERE related_article <> article
  AND (
    EXISTS { (related_article)-[:MENTIONS]->(:AICompany)<-[:MENTIONS]-(article) }
    OR EXISTS { (related_article)-[:MENTIONS]->(:AITechnology)<-[:MENTIONS]-(article) }
    OR EXISTS { (related_article)-[:MENTIONS]->(:AIService)<-[:MENTIONS]-(article) }
  )
// Collect whole nodes so titles and URLs stay index-aligned; collecting the
// two properties separately with DISTINCT can drop or shift entries.
WITH node, article, collect(DISTINCT related_article)[..3] AS related

OPTIONAL MATCH (article)-[:MENTIONS]->(company:AICompany)
OPTIONAL MATCH (company)-[:DEVELOPS]->(tech:AITechnology)
OPTIONAL MATCH (company)-[:DEVELOPS]->(svc:AIService)
OPTIONAL MATCH (article)-[:MENTIONS]->(field:AIField)
RETURN
    node.chunk AS chunk,
    article.title AS article_title,
    article.url AS article_url,
    article.published_date AS article_date,
    collect(DISTINCT company.name) AS companies,
    collect(DISTINCT tech.name) AS technologies,
    collect(DISTINCT svc.name) AS services,
    collect(DISTINCT field.name) AS fields,
    [r IN related | r.title] AS related_article_titles,
    [r IN related | r.url] AS related_article_urls
"""
|
|
|
|
def _get_schema(driver: neo4j.Driver) -> str:
    """Render a compact, human-readable Neo4j schema summary for Text2Cypher.

    Args:
        driver: A connected Neo4j driver.

    Returns:
        A text block with a header, one ``- <nodeType>: [props]`` line per node
        type, a blank line, then up to 30 distinct
        ``- (<src label>)-[:<TYPE>]->(<tgt label>)`` lines. Only the first label
        of each endpoint node is reported. The text ends with a newline.
    """
    with driver.session() as session:
        node_rows = session.run(
            "CALL db.schema.nodeTypeProperties() "
            "YIELD nodeType, propertyName "
            "RETURN nodeType, collect(propertyName) as props"
        ).data()
        rel_rows = session.run(
            "MATCH (n)-[r]->(m) RETURN DISTINCT labels(n)[0] as src, type(r) as rel, labels(m)[0] as tgt LIMIT 30"
        ).data()

    # Build a list and join once instead of repeated string concatenation.
    lines = ["=== Neo4j Schema ===", "노드:"]
    lines.extend(f"- {row['nodeType']}: {row['props']}" for row in node_rows)
    lines.extend(["", "관계:"])
    lines.extend(f"- ({row['src']})-[:{row['rel']}]->({row['tgt']})" for row in rel_rows)
    return "\n".join(lines) + "\n"
|
|
|
|
| _examples = [ |
| """USER INPUT: ์นด์นด์คํ์ด์ AI ์๋น์ค ๋ชฉ๋ก์ ์๋ ค์ฃผ์ธ์ |
| CYPHER QUERY: |
| MATCH (c:AICompany {name:"์นด์นด์คํ์ด"})-[:DEVELOPS]->(s:AIService) |
| OPTIONAL MATCH (a:Article)-[:MENTIONS]->(s) |
| RETURN s.name AS name, s.description AS description, a.title AS article_title, a.url AS article_url""", |
| """USER INPUT: ์ ํ์ํ์ด ๊ฐ๋ฐ ์ค์ธ AI ๊ธฐ์ ์? |
| CYPHER QUERY: |
| MATCH (c:AICompany {name:"์ ํ์ํ"})-[:DEVELOPS]->(t:AITechnology) |
| OPTIONAL MATCH (a:Article)-[:MENTIONS]->(t) |
| RETURN t.name AS name, t.description AS description, a.title AS article_title, a.url AS article_url""", |
| """USER INPUT: ์ด๋ค ๊ธ์ต์ฌ๊ฐ ๋ก๋ณด์ด๋๋ฐ์ด์ ๊ธฐ์ ์ ๊ฐ๋ฐํ๋์? |
| CYPHER QUERY: |
| MATCH (c:AICompany)-[:DEVELOPS]->(t:AITechnology) |
| WHERE t.name CONTAINS "๋ก๋ณด์ด๋๋ฐ์ด์ " OR t.name CONTAINS "์๊ณ ๋ฆฌ์ฆ" |
| OPTIONAL MATCH (a:Article)-[:MENTIONS]->(t) |
| RETURN c.name AS company_name, t.name AS tech_name, a.title AS article_title, a.url AS article_url""", |
| """USER INPUT: ๊ธ์ต์ด๋ ํํ
ํฌ ๋ถ์ผ์ ๊ธฐ์ ์ ์ ์ฉํ๊ณ ์๋ ๊ธฐ์
๋ค์ ์ด๋์ผ? |
| CYPHER QUERY: |
| MATCH (c:AICompany)-[:DEVELOPS]->(t)-[:USED_IN]->(f:AIField) |
| WHERE f.name CONTAINS "๊ธ์ต" OR f.name CONTAINS "ํํ
ํฌ" |
| OPTIONAL MATCH (a:Article)-[:MENTIONS]->(t) |
| RETURN DISTINCT c.name AS company_name, t.name AS tech_name, f.name AS field_name, a.title AS article_title, a.url AS article_url""", |
| """USER INPUT: ๊ธ์ตAI ๋ถ์ผ์ ๊ฐ์ฅ ์ ๊ทน์ ์ธ ๊ธฐ์
TOP 3์ ๋ํ ์๋น์ค |
| CYPHER QUERY: |
| MATCH (c:AICompany)-[:DEVELOPS]->(s)-[:USED_IN]->(f:AIField) |
| WHERE f.name CONTAINS "๊ธ์ต" OR f.name CONTAINS "ํํ
ํฌ" |
| OPTIONAL MATCH (a:Article)-[:MENTIONS]->(s) |
| RETURN DISTINCT c.name AS company_name, s.name AS service_name, f.name AS field_name, a.title AS article_title, a.url AS article_url |
| LIMIT 3""", |
| """USER INPUT: ์ต๊ทผ ๊ธ์ต AI ๊ด๋ จ ๋ด์ค ๊ธฐ์ฌ๋ฅผ ์์ฝํด์ค |
| CYPHER QUERY: |
| MATCH (a:Article)-[:HAS_CHUNK]->(c:Content) |
| RETURN a.title AS title, a.url AS url, a.published_date AS published_date, c.chunk AS chunk |
| ORDER BY a.published_date DESC |
| LIMIT 3""", |
| """USER INPUT: ์ต๊ทผ ๊ฐ์ฅ ๊ด์ฌ์ด ๋์ ๊ธ์ต AI ๊ธฐ์ ์ด ๋ญ์ผ? |
| CYPHER QUERY: |
| MATCH (a:Article)-[:MENTIONS]->(t:AITechnology) |
| OPTIONAL MATCH (c:AICompany)-[:DEVELOPS]->(t) |
| WITH t, count(DISTINCT a) AS article_count, collect(DISTINCT c.name)[..3] AS companies, collect(DISTINCT a.title)[..3] AS article_titles, collect(DISTINCT a.url)[..3] AS article_urls |
| ORDER BY article_count DESC |
| RETURN t.name AS tech_name, t.description AS description, article_count, companies, article_titles, article_urls |
| LIMIT 5""", |
| """USER INPUT: ๊ธ์ต AI ๊ธฐ์ ํธ๋ ๋๋ฅผ ๋ถ์ํด์ค |
| CYPHER QUERY: |
| MATCH (a:Article)-[:MENTIONS]->(t:AITechnology) |
| OPTIONAL MATCH (c:AICompany)-[:DEVELOPS]->(t) |
| WITH t, count(DISTINCT a) AS article_count, collect(DISTINCT c.name)[..3] AS companies, collect(DISTINCT a.title)[..2] AS article_titles, collect(DISTINCT a.url)[..2] AS article_urls |
| ORDER BY article_count DESC |
| RETURN t.name AS tech_name, article_count, companies, article_titles, article_urls |
| LIMIT 5""", |
| """USER INPUT: ํ ์ค ๋๋ ์นด์นด์คํ์ด ๊ด๋ จ ๊ธ์ต AI ๋ด์ค ์๋ ค์ค |
| CYPHER QUERY: |
| MATCH (a:Article)-[:MENTIONS]->(c:AICompany) |
| WHERE c.name CONTAINS 'ํ ์ค' OR c.name CONTAINS '์นด์นด์คํ์ด' |
| OPTIONAL MATCH (a)-[:MENTIONS]->(t:AITechnology) |
| OPTIONAL MATCH (a)-[:MENTIONS]->(s:AIService) |
| RETURN a.title AS article_title, a.url AS article_url, a.published_date AS article_date, |
| collect(DISTINCT c.name) AS companies, collect(DISTINCT t.name) AS technologies, collect(DISTINCT s.name) AS services |
| ORDER BY a.published_date DESC LIMIT 5""", |
| ] |
|
|
| |
| |
| |
|
|
|
|
| from neo4j_graphrag.retrievers.base import Retriever |
| from neo4j_graphrag.types import RawSearchResult, RetrieverResult |
|
|
|
|
class HybridFallbackRetriever(Retriever):
    """Try an LLM-routed tools retriever first, then fall back to a plain one.

    The fallback retriever is used when the tools retriever returns nothing (a
    falsy result or one with no ``items``) or raises, e.g. because the LLM
    tool call or a generated Cypher query failed. The failure is logged rather
    than aborting the whole request.
    """

    # Opt out of the base class's Neo4j version verification; this wrapper
    # runs no queries of its own.
    VERIFY_NEO4J_VERSION = False

    def __init__(self, tools_retriever: Retriever, fallback_retriever: Retriever) -> None:
        self.tools_retriever = tools_retriever
        self.fallback_retriever = fallback_retriever
        super().__init__(driver=tools_retriever.driver)

    def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
        # Required by the abstract base class. Not reached through ``search``,
        # which is overridden below to delegate to the wrapped retrievers.
        return RawSearchResult(records=[])

    def search(self, query_text: str = "", **kwargs: Any) -> RetrieverResult:
        """Return the tools retriever's result, or the fallback's if it is empty or fails."""
        try:
            res = self.tools_retriever.search(query_text=query_text, **kwargs)
        except Exception:
            logging.getLogger(__name__).warning(
                "Tools retriever failed; falling back to %s",
                type(self.fallback_retriever).__name__,
                exc_info=True,
            )
            res = None
        if not res or not res.items:
            return self.fallback_retriever.search(query_text=query_text, **kwargs)
        return res
|
|
|
|
class CustomRagTemplate(RagTemplate):
    """RAG prompt template that renders only ``context`` and ``query_text``.

    Callers may still pass an ``examples`` argument to ``format``. This
    template has no slot for it, so it is accepted and discarded.
    """

    EXPECTED_INPUTS = ["context", "query_text"]

    def format(self, query_text: str, context: str, examples: str = "") -> str:
        del examples  # kept in the signature for caller compatibility only
        return self._format(query_text=query_text, context=context)
|
|
|
|
# Answer-generation prompt used by GraphRAG. It has two placeholders,
# {query_text} and {context}, which CustomRagTemplate.format fills; any
# `examples` argument is ignored by that class.
# NOTE(review): in the reviewed copy of this file, the Korean text and emoji in
# this template appear mis-encoded: UTF-8 bytes are shown as TIS-620 /
# ISO-8859-11 characters, and some bytes are rendered as line breaks. Confirm
# the file is saved as UTF-8, and restore the template text from version
# control if the runtime string really contains these characters.
_prompt_template = CustomRagTemplate(
    template="""๋น์ ์ AI ๋ฐ ํํ
ํฌ ๊ธฐ์ ํธ๋ ๋ ์ ๋ฌธ๊ฐ์ด์, ์ทจ์
์ค๋น์์ ์ญ๋ ๋ถ์์ ๋๋ ์ ๋ต ์ปจ์คํดํธ์
๋๋ค.
๋ฐ๋์ ์๋ ์ ๊ณต๋ [์ปจํ
์คํธ(Neo4j ์ง์ ๊ทธ๋ํ ๊ฒ์ ๊ฒฐ๊ณผ)]์ ๊ธฐ๋ฐํด์๋ง ๋ต๋ณํ๊ณ , ์ปจํ
์คํธ์ ๊ทผ๊ฑฐํ์ง ์์ ์ฌ์ค์ ์ง์ด๋ด๊ฑฐ๋ ๊ฐ์์ ๋งํฌ(example.com ๋ฑ)๋ฅผ ์ ๋ ์์ฑํ์ง ๋ง์ธ์.

๋ต๋ณ์ ๋์ค์ด๋ ์ทจ์
์ค๋น์์ด ์ค์ง์ ์ผ๋ก ํธ๋ ๋๋ฅผ ๊น์ด ์๊ฒ ํ์
ํ๊ณ ์์์/๋ฉด์ ๋ฑ์ ์ฆ๊ฐ ํ์ฉํ ์ ์๋๋ก, ์๋์ [๊ณ ์ ๋ธ๋ฆฌํ ๋ณด๊ณ ์ ํฌ๋งท]์ **ํ ์จ ํ๋ ํ๋ฆฌ์ง ์๊ณ ์๊ฒฉํ ์ค์**ํ์ฌ ๋งค์ฐ ์ฒด๊ณ์ ์ด๊ณ ๊น๋ํ ๋งํฌ๋ค์ด ์์์ผ๋ก ์ ์ฑ์ค๋ฝ๊ฒ ๋ธ๋ฆฌํํด ์ฃผ์ธ์.

โ
[์ค์ - ๊ฐ๋
์ฑ ๋ฐ ๊ฐํ ๊ท์น]:
๊ฐ ์ฃผ์ ์น์
(###) ์ฌ์ด์๋ ๋ฌด์กฐ๊ฑด ๋น ์ค์ 2์ค ์ด์ ์ถ๊ฐํ๊ณ , ๋ชจ๋ ๊ฐ๋ณ ๋ชฉ๋ก ๊ธฐํธ(- ๋ฐ **) ํญ๋ชฉ ์ฌ์ด์ฌ์ด์๋ ๋ฐ๋์ 1์ค ์ด์์ ๋น ์ค(๊ฐํ)์ ์ฝ์
ํ์ฌ ์๊ฐ์ ๊ฐ๋
์ฑ์ ๊ทน๋ํํด ์ฃผ์ธ์.

---

# ๐ [FinGraph AI ๋ถ์ ๋ธ๋ฆฌํ]

### 1. ๐ ํ ์ค ์์ฝ & ํต์ฌ ํธ๋ ๋

- **ํ ์ค ์์ฝ**: [ํด๋น ํธ๋ ๋์ ํต์ฌ ์์ ์ ๋จ ํ ์ค๋ก ๋ช
๋ฃํ๊ฒ ์์ฝ]

- **์ฃผ์ ์ธ์ฌ์ดํธ**: [์ด ์ด์๊ฐ ํ์ฌ IT/AI ๋ฐ ๊ธ์ต ํํ
ํฌ ์
๊ณ ์ ์ฒด์ ๋์ง๋ ํต์ฌ ํ๋ ๊ธฐ์ฌ]


### 2. ๐ ์์ธ ๋ถ์ ๋ฐ ํฉํธ ์ ๋ฆฌ

[์ปจํ
์คํธ์ ๊ธฐ๋ก๋ ์ค์ ์ฌ์ค ๊ด๊ณ๋ค์ ๊ทผ๊ฑฐ๋ก ๊ตฌ์ฒด์ ์ฌ์ค์ ์ ๋ฆฌ]

- **์ด์ ์ ๊ฐ**: [๊ตฌ์ฒด์ ์ธ ์ด์ ๋ฐ์ ๋ฐฐ๊ฒฝ ๋ฐ ์งํ ๊ฒฝ๊ณผ]

- **๊ธฐ์
๋ํฅ**: [๊ด๋ จ ํต์ฌ ๊ธฐ์
๋ค์ ์ค๋ฌผ ๋น์ฆ๋์ค ์์ง์ ๋ฐ ๋์ ํ๋ณด. ์ปจํ
์คํธ์ ์ฌ๋ฌ ๊ธฐ์
/๊ธฐ์ ์ด ์๋ค๋ฉด ๋ชจ๋ ์ธ๊ธ]

- **๊ธฐ์ ํธ๋ ๋**: [์ปจํ
์คํธ์ ๋ฑ์ฅํ๋ ํต์ฌ AI ๊ธฐ์ ๋ค์ ๋น๊ต/๋ถ๋ฅํ์ฌ ์ ์ฒด ํธ๋ ๋ ํ๋ฆ ๋ถ์]

- **์ธํ๋ผ/์ฌํ์ ์์ธ**: [์ ๋ ฅ๋ง ๋ถ์กฑ, ๋์ค์ ๋ถ์๊ฐ, ํ๋์จ์ด์ ์ ์ฝ ์ฌํญ ๋ฑ ํต์ฌ ์์ธ]


### 3. ๐ก ์ทจ์
/์์์/๋ฉด์ ์ค์ ๊ฐ์ด๋

[์ง์์๊ฐ ๋ฉด์ ์ด๋ ์๊ธฐ์๊ฐ์์์ ์ฐจ๋ณํ๋ ํต์ฐฐ์ ๋ณด์ฌ์ค ์ ์๋ ๋ฐฉ๋ฒ ์ ์]

- **๊ธ์ต/IT ์
๊ณ ์์ฌ์ **: [๊ฑฐ์์ ์ธ ํ๊ธํจ๊ณผ์ ์ง์๊ฐ๋ฅ์ฑ ๊ด์ ์ ์]

- **์ค์ ์์์/๋ฉด์ ํ์ฉ Tip**: [์ง์๋๊ธฐ๋ ์ญ๋ ๊ธฐ์ ์ ์์ฑ ์ ๋ณธ์ธ์ ์ญ๋๊ณผ ์ด๋ป๊ฒ ์ฐ๊ณํ์ฌ ํ์ด๋ผ์ง์ ๋ํ ๋ง์ถค ๊ฐ์ด๋]


### ๐ฐ 4. ๊ทผ๊ฑฐ ๋ด์ค ์ถ์ฒ (GraphRAG ๊ฒ์ ๊ธฐ์ฌ)

> ์ปจํ
์คํธ์ ์ค์ ๋ก ์กด์ฌํ๋ ๊ธฐ์ฌ URL๋ง ๊ธฐ์ฌํ๊ณ , ์กด์ฌํ์ง ์๋ ๊ธฐ์ฌ๋ ์ ๋ ์ง์ด๋ด์ง ๋ง์ธ์.
> ๊ฒ์๋ ๊ธฐ์ฌ๊ฐ ์๋ ๊ฒฝ์ฐ ์๋ ํ์์ผ๋ก ์ด๊ฑฐํ๊ณ , ์์ผ๋ฉด ์ด ์น์
์ ์๋ตํ์ธ์.
>
> ์์:
> - *[๊ธฐ์ฌ ์ ๋ชฉ](๊ธฐ์ฌ URL)* โ ๋ณด๋์ผ์

---

์ง๋ฌธ: {query_text}

[์ปจํ
์คํธ]
{context}

๋ต๋ณ:""",
    expected_inputs=["context", "query_text"]
)
|
|
|
|
| class LazyGraphRAG: |
| """์ํฌํธ ์์ ์ DB ์ฐ๊ฒฐ์ ๋ฐฉ์งํ๊ณ ์ค์ ํธ์ถ๋ ๋ GraphRAG ์ธ์คํด์ค๋ฅผ ์ด๊ธฐํํ๋ ์ง์ฐ ํ๊ฐ ํ๋ก์""" |
|
|
| def __init__(self) -> None: |
| self._graphrag: Any = None |
| self._hybrid_retriever: Any = None |
| self._rag_llm: Any = None |
|
|
| def _init_once(self) -> None: |
| if self._graphrag is not None: |
| return |
| |
| |
| self._rag_llm = OpenAILLM(model_name="gpt-4o-mini", model_params={"temperature": 0}) |
| embedder = OpenAIEmbeddings(model="text-embedding-3-small") |
|
|
| driver = get_neo4j_driver() |
| |
| vector_cypher_retriever = VectorCypherRetriever( |
| driver=driver, |
| index_name=INDEX_NAME, |
| retrieval_query=_retrieval_query, |
| embedder=embedder, |
| ) |
| |
| text2cypher_retriever = Text2CypherRetriever( |
| driver=driver, |
| llm=self._rag_llm, |
| neo4j_schema=_get_schema(driver), |
| examples=_examples, |
| ) |
| |
| tools_retriever = ToolsRetriever( |
| driver=driver, |
| llm=self._rag_llm, |
| tools=[ |
| vector_cypher_retriever.convert_to_tool( |
| name="vector_retriever", |
| description=( |
| "๋ด์ค ๋ณธ๋ฌธ ์๋ฏธ ์ ์ฌ๋ ๊ธฐ๋ฐ ๊ฒ์ + ์ฐ๊ฒฐ๋ ์ํฐํฐ(๊ธฐ์
ยท๊ธฐ์ ยท์๋น์คยท๋ถ์ผ) ๊ด๊ณ ๊ทธ๋ํ ํ์. " |
| "ํน์ ์ฃผ์ /๊ธฐ์
/๊ธฐ์ ์ ๋ํด ๋ด์ค ๊ธฐ์ฌ ๋ฐ ๊ด๋ จ ๊ทธ๋ํ ๊ด๊ณ๋ฅผ ํจ๊ป ๋ถ์ํ ๋ ์ฌ์ฉ. " |
| "์: 'ํ๋์ฐจ AI ๋ด์ค', 'ํน์ ๊ธฐ์ ์ ์ ์ฉ ์ฌ๋ก'." |
| ), |
| ), |
| text2cypher_retriever.convert_to_tool( |
| name="text2cypher_retriever", |
| description=( |
| "์์ฐ์ด๋ฅผ Neo4j Cypher ์ฟผ๋ฆฌ๋ก ๋ณํํ์ฌ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ์ง๊ณยทํ์. " |
| "'๊ฐ์ฅ ๋ง์ด ์ธ๊ธ๋ ๊ธฐ์ ', 'ํธ๋ ๋ ๋ถ์', 'ํน์ ๊ธฐ์
์ ์๋น์ค ๋ชฉ๋ก', " |
| "'์ด๋ค ๊ธฐ์
์ด X ๊ธฐ์ ์ ๊ฐ๋ฐํ๋', '์ต๊ทผ ๋ด์ค ์์ฝ' ๋ฑ " |
| "์ง๊ณ(COUNT/ORDER BY)๋ ๊ตฌ์กฐ์ ๊ด๊ณ ์ง์์ ๋ฐ๋์ ์ฌ์ฉ." |
| ), |
| ), |
| ], |
| ) |
|
|
| self._hybrid_retriever = HybridFallbackRetriever( |
| tools_retriever=tools_retriever, |
| fallback_retriever=vector_cypher_retriever, |
| ) |
|
|
| self._graphrag = GraphRAG( |
| llm=self._rag_llm, |
| retriever=self._hybrid_retriever, |
| prompt_template=_prompt_template, |
| ) |
|
|
| def _is_context_sufficient(self, query_text: str, history: list, retriever_result: Any) -> bool: |
| """๊ฒ์๋ ์ปจํ
์คํธ๊ฐ ์ง๋ฌธ ๋ฐ ์ด์ ๋ํ ํ๋ฆ์ ์ค์ง์ ์ผ๋ก ๋์์ด ๋๋ ๊ธ์ต/๊ธฐ์ ๋ด์ค ๋ฐ์ดํฐ์ธ์ง GPT-4o-mini๋ก ํ๋จ""" |
| if retriever_result is None: |
| return False |
| if not hasattr(retriever_result, "items") or not retriever_result.items: |
| return False |
| total_content = " ".join( |
| getattr(item, "content", "") for item in retriever_result.items |
| ).strip() |
| if len(total_content) < 100: |
| return False |
|
|
| |
| try: |
| assert self._rag_llm is not None |
| context_snippet = total_content[:800] |
|
|
| |
| normalized_history = self._normalize_history(history) |
| history_summary = "์์" |
| if normalized_history: |
| history_summary = "\n".join( |
| f"- {msg['role']}: {msg['content'][:150]}" |
| for msg in normalized_history[-3:] |
| ) |
|
|
| routing_prompt = ( |
| "๋น์ ์ ๊ธ์ต/๊ธฐ์ ํธ๋ ๋ RAG ์์คํ
์ ์ง๋ฅํ ๋ผ์ฐํฐ์
๋๋ค.\n" |
| "์ฌ์ฉ์์ [ํ์ฌ ์ง๋ฌธ] ๋ฐ [์ต๊ทผ ๋ํ ํ์คํ ๋ฆฌ]๊ฐ ์๋ ์ ๊ณต๋ [๊ฒ์๋ ๋ด์ค ๋ฐ์ดํฐ]์ ์๋ฏธ์ ์ผ๋ก ๋ฐ์ ํ๊ฒ ์ฐ๊ด๋์ด ์๊ณ , " |
| "ํด๋น ๋ฐ์ดํฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ์ค์ ๊ตฌ์ฒด์ ์ด๊ณ ์ ๋ขฐํ ์ ์๋ ๋ต๋ณ์ ์ ๊ณตํ ์ ์๋์ง ํ๊ฐํ์ธ์.\n\n" |
| "ํนํ, ํ์ฌ ์ง๋ฌธ์ด '๊ทธ๊ฑฐ์ ๋ํด ์ข ๋ ์ค๋ช
ํด์ค'๋ '์์์ ํ์ ๋ ๋ค๋ฌ์ด์ค'์ ๊ฐ์ ํ์ ๋ํํ ์ง๋ฌธ์ผ ๊ฒฝ์ฐ, " |
| "[์ต๊ทผ ๋ํ ํ์คํ ๋ฆฌ]์ ๋ช
์๋ ์ฃผ์ ๊ธ์ต/๊ธฐ์ ํธ๋ ๋ ์ฃผ์ (์: ์ผ์ฑ์ ์ AI, ์นด์นด์ค AI ๋ฑ)๊ฐ " |
| "์๋ ๋ด์ค ๋ฐ์ดํฐ์ ํต์ฌ ๋ด์ฉ๊ณผ ์ผ์นํ๋์ง ์ข
ํฉ์ ์ผ๋ก ๊ณ ๋ คํด์ผ ํฉ๋๋ค.\n\n" |
| "๋ง์ฝ ์ง๋ฌธ ๋ฐ ๋ํ ๋งฅ๋ฝ์ด ์๋ ๋ด์ค ๋ฐ์ดํฐ์ ์ ํ ๋ฌด๊ดํ ์ผ๋ฐ ์์, ์ผ์์ ์ธ ๋ํ, ์ํ, ์์ ๋ฑ " |
| "์ง์ ๊ทธ๋ํ(๋ด์ค ๋ฐ์ดํฐ๋ฒ ์ด์ค)์ ์๋ ์ฃผ์ ์ ์ง๋ฌธ์ด๋ผ๋ฉด ๋ฐ๋์ 'NO'๋ผ๊ณ ๋ตํด์ผ ํฉ๋๋ค.\n" |
| "๋ด์ค ํฉํธ ๋ฐ์ดํฐ๋ฅผ ๊ฒฐํฉํ์ฌ ์ฌ๋ฐ๋ฅธ ๋ต๋ณ์ ์์ฑํ ์ ์๋ ๋งฅ๋ฝ์ด๋ผ๋ฉด 'YES', ๊ทธ๋ ์ง ์๋ค๋ฉด 'NO'๋ผ๊ณ ๋ง ๋ตํ์ธ์.\n\n" |
| f"[์ต๊ทผ ๋ํ ํ์คํ ๋ฆฌ]\n{history_summary}\n\n" |
| f"[ํ์ฌ ์ง๋ฌธ]\n{query_text}\n\n" |
| f"[๊ฒ์๋ ๋ด์ค ๋ฐ์ดํฐ]\n{context_snippet}\n\n" |
| "ํ์ (YES ๋๋ NO๋ก๋ง ๋ต๋ณ):" |
| ) |
| |
| response = self._rag_llm.invoke( |
| input=routing_prompt, |
| model_params={"temperature": 0, "max_tokens": 5} |
| ) |
| decision = str(response.content).strip().upper() |
| return "YES" in decision |
| except Exception: |
| |
| return len(total_content) >= 100 |
|
|
| def _normalize_history(self, history: list) -> list: |
| """Gradio ํ์คํ ๋ฆฌ(dict ๋๋ tuple ํ์)๋ฅผ LLM message_history ํ์์ผ๋ก ์ ๊ทํ""" |
| normalized: list = [] |
| for msg in history: |
| if isinstance(msg, dict) and "role" in msg and "content" in msg: |
| normalized.append({"role": msg["role"], "content": str(msg["content"])}) |
| elif isinstance(msg, (list, tuple)) and len(msg) == 2: |
| if msg[0]: |
| normalized.append({"role": "user", "content": str(msg[0])}) |
| if msg[1]: |
| normalized.append({"role": "assistant", "content": str(msg[1])}) |
| return normalized |
|
|
| def _generate_general_answer(self, query_text: str, history: list) -> str: |
| """๊ทธ๋ํ ๊ฒ์ ๊ฒฐ๊ณผ ์์ด GPT-4o-mini ์ผ๋ฐ ์ง์์ผ๋ก ๋ต๋ณ ์์ฑ (๋ํ ํ์คํ ๋ฆฌ ๋ฐ์)""" |
| assert self._rag_llm is not None |
| system_prompt = ( |
| "๋น์ ์ AI ๋ฐ ํํ
ํฌ ๊ธฐ์ ํธ๋ ๋ ์ ๋ฌธ๊ฐ์ด์, ์ทจ์
์ค๋น์์ ์ญ๋ ๋ถ์์ ๋๋ ์ ๋ต ์ปจ์คํดํธ์
๋๋ค.\n" |
| "ํ์ฌ FinGraph ์ง์ ๊ทธ๋ํ(Neo4j GraphRAG)์์ ๊ด๋ จ ๋ด์ค ๊ธฐ์ฌ๋ฅผ ์ฐพ์ง ๋ชปํ์ต๋๋ค.\n" |
| "์ด์ ๋ํ ๋งฅ๋ฝ์ ์ถฉ๋ถํ ๋ฐ์ํ๊ณ , GPT-4o-mini์ ์ผ๋ฐ ํ์ต ๋ฐ์ดํฐ์ ๊ธฐ๋ฐํ์ฌ ์ต์ ์ ๋คํด ์ ๋ฌธ์ ์ผ๋ก ๋ต๋ณํด ์ฃผ์ธ์.\n\n" |
| "[์ค์ ์ง์นจ]\n" |
| "- ์ค์ ์กด์ฌํ์ง ์๋ ๋ด์ค ๋งํฌ, ๋ ์ง, ๊ฐ์ง URL์ ์ ๋ ์์ฑํ์ง ๋ง์ธ์.\n" |
| "- ๊ฐ๋ฅํ๋ค๋ฉด ์ทจ์
์ค๋น์์ด ๋ฉด์ /์์์์ ํ์ฉํ ์ ์๋ ์ค์ง์ ์ธ ์ธ์ฌ์ดํธ๋ฅผ ํฌํจํด ์ฃผ์ธ์.\n" |
| "- ๋ต๋ณ์ด ์ผ๋ฐ AI ํ์ต ๋ฐ์ดํฐ ๊ธฐ๋ฐ์์ ์จ๊ธฐ์ง ๋ง๊ณ ์์ฐ์ค๋ฝ๊ฒ ์ธ๊ธํ๋ฉฐ ์์ํ์ธ์." |
| ) |
| normalized_history = self._normalize_history(history) |
| response = self._rag_llm.invoke( |
| input=query_text, |
| message_history=normalized_history, |
| system_instruction=system_prompt, |
| ) |
| return str(response.content) |
|
|
| def search_with_fallback(self, query_text: str, history: list) -> HybridResult: |
| """GraphRAG ๊ฒ์ -> ์ปจํ
์คํธ ํ์ง ํ๊ฐ -> ์ผ๋ฐ ์ง์ Fallback ํตํฉ ๋ฉ์๋. |
| |
| Args: |
| query_text: ์ฌ์ฉ์ ์ง๋ฌธ ํ
์คํธ |
| history: ์ด์ ๋ํ ํ์คํ ๋ฆฌ (Gradio ํ์) |
| |
| Returns: |
| HybridResult: ๋ต๋ณ, ๋ชจ๋("graph"|"general"), RetrieverResult |
| """ |
| self._init_once() |
| assert self._hybrid_retriever is not None |
| assert self._graphrag is not None |
|
|
| |
| retriever_result = self._hybrid_retriever.search(query_text=query_text) |
|
|
| |
| if self._is_context_sufficient(query_text, history, retriever_result): |
| |
| rag_result = self._graphrag.search(query_text=query_text) |
| return HybridResult( |
| answer=rag_result.answer, |
| mode="graph", |
| retriever_result=rag_result.retriever_result, |
| ) |
| else: |
| |
| answer = self._generate_general_answer(query_text, history) |
| return HybridResult(answer=answer, mode="general", retriever_result=None) |
|
|
| def search(self, *args: Any, **kwargs: Any) -> Any: |
| self._init_once() |
| assert self._graphrag is not None |
| return self._graphrag.search(*args, **kwargs) |
|
|
| def __getattr__(self, name: str) -> Any: |
| self._init_once() |
| return getattr(self._graphrag, name) |
|
|
|
|
| |
# Module-level singleton imported by app.py. Construction is cheap because
# LazyGraphRAG defers every connection and client until first use.
graphrag: LazyGraphRAG = LazyGraphRAG()
|
|