fix: implement robust neo4j driver fallback with verify_connectivity to prevent AuthError
Browse files- AGENTS.md +2 -1
- app.py +1 -0
- run_pipeline.py +1 -0
- src/graphBuilder/neo4j/finGraph.py +22 -5
- src/retrieval/finRetrieval.py +25 -5
- tests/smoke_test_rag.py +17 -4
- tests/test_retrieval.py +1 -0
AGENTS.md
CHANGED
|
@@ -79,8 +79,9 @@ def test_portfolio_showcase_aggregation_query():
|
|
| 79 |
assert any(indicator in response.answer for indicator in ["1.", "TOP", "κΈ°μ¬", "μΆμ²"]) # μΌμ’
μ skill
|
| 80 |
```
|
| 81 |
|
| 82 |
-
## μλ κ²μ¬
|
| 83 |
- λ‘컬 κ°λ° νκ²½μμ 컀λ°νκΈ° μ , λ°λμ ν°λ―Έλμ `ruff check .` λ° `mypy src tests --ignore-missing-imports` λͺ
λ Ήμ΄λ₯Ό μ§μ μ€ννμ¬ λ¦°νΈ λ° μ격ν νμ
μ€λ₯λ₯Ό νμ€νκ² νμΈνκ³ λͺ¨λ κ³ μΉ κ² (μ€λ₯κ° λ¨μμλ μνλ‘ μ»€λ° κΈμ§).
|
|
|
|
| 84 |
- μ»€λ° μ `pre-commit` μλ μ€ν
|
| 85 |
- `ruff`, `mypy` κ²μ¬ ν΅κ³Ό νμ
|
| 86 |
- κ²μ¬ μ€ν¨ μ μ»€λ° λΆκ°
|
|
|
|
| 79 |
assert any(indicator in response.answer for indicator in ["1.", "TOP", "κΈ°μ¬", "μΆμ²"]) # μΌμ’
μ skill
|
| 80 |
```
|
| 81 |
|
| 82 |
+
## μλ κ²μ¬ λ° λ°νμ μλ¬ λ°©μ§
|
| 83 |
- λ‘컬 κ°λ° νκ²½μμ 컀λ°νκΈ° μ , λ°λμ ν°λ―Έλμ `ruff check .` λ° `mypy src tests --ignore-missing-imports` λͺ
λ Ήμ΄λ₯Ό μ§μ μ€ννμ¬ λ¦°νΈ λ° μ격ν νμ
μ€λ₯λ₯Ό νμ€νκ² νμΈνκ³ λͺ¨λ κ³ μΉ κ² (μ€λ₯κ° λ¨μμλ μνλ‘ μ»€λ° κΈμ§).
|
| 84 |
+
- **λ°νμ Auth/μ°κ²° μλ¬ λ°©μ§**: λ¦°νΈ/νμ
κ²μ¬ ν λ°λμ `python tests/smoke_test_rag.py`λ₯Ό λ‘컬μμ μ€ννμ¬ `neo4j.exceptions.AuthError` λ±μ λ°νμ μλ¬κ° ν°μ§μ§ μκ³ μλ²½ν RAG κ²°κ³Όκ° μΆλ ₯λλμ§(DB μ μ μ μ) νμ₯ μ κ²(Smoke Test) ν νΈμν κ².
|
| 85 |
- μ»€λ° μ `pre-commit` μλ μ€ν
|
| 86 |
- `ruff`, `mypy` κ²μ¬ ν΅κ³Ό νμ
|
| 87 |
- κ²μ¬ μ€ν¨ μ μ»€λ° λΆκ°
|
app.py
CHANGED
|
@@ -22,6 +22,7 @@ dotenv.load_dotenv()
|
|
| 22 |
# 1. LangGraph μ±λ΄ State μ μ
|
| 23 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
|
|
|
|
| 25 |
class ChatState(TypedDict):
|
| 26 |
question: str # μ¬μ©μ μ§λ¬Έ
|
| 27 |
history: List[dict] # λν νμ€ν 리 [{"role": "user"/"assistant", "content": "..."}]
|
|
|
|
| 22 |
# 1. LangGraph μ±λ΄ State μ μ
|
| 23 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
|
| 25 |
+
|
| 26 |
class ChatState(TypedDict):
|
| 27 |
question: str # μ¬μ©μ μ§λ¬Έ
|
| 28 |
history: List[dict] # λν νμ€ν 리 [{"role": "user"/"assistant", "content": "..."}]
|
run_pipeline.py
CHANGED
|
@@ -59,5 +59,6 @@ def run_test():
|
|
| 59 |
else:
|
| 60 |
print("\nβοΈ AI κ΄λ ¨ κΈ°μ¬κ° μλλ―λ‘ κ·Έλν μμΈ λΆμ λ° λ²‘ν° μ μ¬λ₯Ό 건λλλλ€.")
|
| 61 |
|
|
|
|
| 62 |
if __name__ == "__main__":
|
| 63 |
run_test()
|
|
|
|
| 59 |
else:
|
| 60 |
print("\nβοΈ AI κ΄λ ¨ κΈ°μ¬κ° μλλ―λ‘ κ·Έλν μμΈ λΆμ λ° λ²‘ν° μ μ¬λ₯Ό 건λλλλ€.")
|
| 61 |
|
| 62 |
+
|
| 63 |
if __name__ == "__main__":
|
| 64 |
run_test()
|
src/graphBuilder/neo4j/finGraph.py
CHANGED
|
@@ -26,11 +26,28 @@ from neo4j_graphrag.llm import OpenAILLM
|
|
| 26 |
|
| 27 |
dotenv.load_dotenv()
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 36 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
|
|
|
| 26 |
|
| 27 |
dotenv.load_dotenv()
|
| 28 |
|
| 29 |
+
|
| 30 |
+
def get_neo4j_driver() -> neo4j.Driver:
|
| 31 |
+
uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
| 32 |
+
client_id = os.getenv("NEO4J_CLIENT_ID")
|
| 33 |
+
client_secret = os.getenv("NEO4J_CLIENT_SECRET")
|
| 34 |
+
|
| 35 |
+
if client_id and client_secret:
|
| 36 |
+
try:
|
| 37 |
+
d = neo4j.GraphDatabase.driver(uri, auth=(client_id, client_secret))
|
| 38 |
+
d.verify_connectivity()
|
| 39 |
+
return d
|
| 40 |
+
except Exception:
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 44 |
+
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 45 |
+
d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
|
| 46 |
+
d.verify_connectivity()
|
| 47 |
+
return d
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
driver = get_neo4j_driver()
|
| 51 |
|
| 52 |
chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 53 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
src/retrieval/finRetrieval.py
CHANGED
|
@@ -30,11 +30,28 @@ dotenv.load_dotenv()
|
|
| 30 |
# 1. DB / LLM / Embedder μ΄κΈ°ν
|
| 31 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
| 40 |
embedder = OpenAIEmbeddings(model="text-embedding-3-small")
|
|
@@ -180,18 +197,21 @@ class HybridFallbackRetriever(Retriever):
|
|
| 180 |
return self.fallback_retriever.search(query_text=query_text, **kwargs)
|
| 181 |
return res
|
| 182 |
|
|
|
|
| 183 |
# νμ΄λΈλ¦¬λ κ²μ μΈμ€ν΄μ€ μ₯μ°©
|
| 184 |
hybrid_retriever = HybridFallbackRetriever(
|
| 185 |
tools_retriever=tools_retriever,
|
| 186 |
fallback_retriever=vector_cypher_retriever,
|
| 187 |
)
|
| 188 |
|
|
|
|
| 189 |
class CustomRagTemplate(RagTemplate):
|
| 190 |
EXPECTED_INPUTS = ["context", "query_text"]
|
| 191 |
|
| 192 |
def format(self, query_text: str, context: str, examples: str = "") -> str:
|
| 193 |
return self._format(query_text=query_text, context=context)
|
| 194 |
|
|
|
|
| 195 |
_prompt_template = CustomRagTemplate(
|
| 196 |
template="""λΉμ μ AI κΈ°μ νΈλ λ λΆμ μ λ¬Έκ°μ
λλ€.
|
| 197 |
λ°λμ μλ μ 곡λ [컨ν
μ€νΈ(Neo4j μ§μ κ·Έλν κ²μ κ²°κ³Ό)]μ κΈ°λ°ν΄μλ§ λ΅λ³νμΈμ.
|
|
|
|
| 30 |
# 1. DB / LLM / Embedder μ΄κΈ°ν
|
| 31 |
# ββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
|
| 33 |
+
|
| 34 |
+
def get_neo4j_driver() -> neo4j.Driver:
|
| 35 |
+
uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
| 36 |
+
client_id = os.getenv("NEO4J_CLIENT_ID")
|
| 37 |
+
client_secret = os.getenv("NEO4J_CLIENT_SECRET")
|
| 38 |
+
|
| 39 |
+
if client_id and client_secret:
|
| 40 |
+
try:
|
| 41 |
+
d = neo4j.GraphDatabase.driver(uri, auth=(client_id, client_secret))
|
| 42 |
+
d.verify_connectivity()
|
| 43 |
+
return d
|
| 44 |
+
except Exception:
|
| 45 |
+
pass # Fallback to Username/Password
|
| 46 |
+
|
| 47 |
+
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 48 |
+
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 49 |
+
d = neo4j.GraphDatabase.driver(uri, auth=(username, password))
|
| 50 |
+
d.verify_connectivity()
|
| 51 |
+
return d
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
driver = get_neo4j_driver()
|
| 55 |
|
| 56 |
rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
|
| 57 |
embedder = OpenAIEmbeddings(model="text-embedding-3-small")
|
|
|
|
| 197 |
return self.fallback_retriever.search(query_text=query_text, **kwargs)
|
| 198 |
return res
|
| 199 |
|
| 200 |
+
|
| 201 |
# νμ΄λΈλ¦¬λ κ²μ μΈμ€ν΄μ€ μ₯μ°©
|
| 202 |
hybrid_retriever = HybridFallbackRetriever(
|
| 203 |
tools_retriever=tools_retriever,
|
| 204 |
fallback_retriever=vector_cypher_retriever,
|
| 205 |
)
|
| 206 |
|
| 207 |
+
|
| 208 |
class CustomRagTemplate(RagTemplate):
|
| 209 |
EXPECTED_INPUTS = ["context", "query_text"]
|
| 210 |
|
| 211 |
def format(self, query_text: str, context: str, examples: str = "") -> str:
|
| 212 |
return self._format(query_text=query_text, context=context)
|
| 213 |
|
| 214 |
+
|
| 215 |
_prompt_template = CustomRagTemplate(
|
| 216 |
template="""λΉμ μ AI κΈ°μ νΈλ λ λΆμ μ λ¬Έκ°μ
λλ€.
|
| 217 |
λ°λμ μλ μ 곡λ [컨ν
μ€νΈ(Neo4j μ§μ κ·Έλν κ²μ κ²°κ³Ό)]μ κΈ°λ°ν΄μλ§ λ΅λ³νμΈμ.
|
tests/smoke_test_rag.py
CHANGED
|
@@ -20,15 +20,28 @@ import dotenv
|
|
| 20 |
|
| 21 |
dotenv.load_dotenv()
|
| 22 |
|
|
|
|
| 23 |
# ββ 0. κ·Έλν κ΅¬μ± μ¬μ μ κ² (Neo4j λ
Έλ/κ΄κ³ ν΅κ³) βββββββββββββββββββββββββ
|
| 24 |
def check_graph_structure():
|
| 25 |
import neo4j
|
| 26 |
|
| 27 |
uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
driver =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
print("\n" + "=" * 60)
|
| 34 |
print("π [μ¬μ μ κ²] Neo4j κ·Έλν κ΅¬μ± νν©")
|
|
|
|
| 20 |
|
| 21 |
dotenv.load_dotenv()
|
| 22 |
|
| 23 |
+
|
| 24 |
# ββ 0. κ·Έλν κ΅¬μ± μ¬μ μ κ² (Neo4j λ
Έλ/κ΄κ³ ν΅κ³) βββββββββββββββββββββββββ
|
| 25 |
def check_graph_structure():
|
| 26 |
import neo4j
|
| 27 |
|
| 28 |
uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
| 29 |
+
client_id = os.getenv("NEO4J_CLIENT_ID")
|
| 30 |
+
client_secret = os.getenv("NEO4J_CLIENT_SECRET")
|
| 31 |
+
|
| 32 |
+
driver = None
|
| 33 |
+
if client_id and client_secret:
|
| 34 |
+
try:
|
| 35 |
+
driver = neo4j.GraphDatabase.driver(uri, auth=(client_id, client_secret))
|
| 36 |
+
driver.verify_connectivity()
|
| 37 |
+
except Exception:
|
| 38 |
+
driver = None
|
| 39 |
+
|
| 40 |
+
if not driver:
|
| 41 |
+
username = os.getenv("NEO4J_USERNAME", "neo4j")
|
| 42 |
+
password = os.getenv("NEO4J_PASSWORD", "password")
|
| 43 |
+
driver = neo4j.GraphDatabase.driver(uri, auth=(username, password))
|
| 44 |
+
driver.verify_connectivity()
|
| 45 |
|
| 46 |
print("\n" + "=" * 60)
|
| 47 |
print("π [μ¬μ μ κ²] Neo4j κ·Έλν κ΅¬μ± νν©")
|
tests/test_retrieval.py
CHANGED
|
@@ -10,6 +10,7 @@ has_credentials = (
|
|
| 10 |
os.getenv("NEO4J_URI") is not None
|
| 11 |
)
|
| 12 |
|
|
|
|
| 13 |
@pytest.mark.skipif(
|
| 14 |
not has_credentials,
|
| 15 |
reason="OpenAI API Key λλ Neo4j μ°κ²° νκ²½λ³μκ° μμΌλ―λ‘ ν΅ν© ν
μ€νΈλ₯Ό 건λλλλ€."
|
|
|
|
| 10 |
os.getenv("NEO4J_URI") is not None
|
| 11 |
)
|
| 12 |
|
| 13 |
+
|
| 14 |
@pytest.mark.skipif(
|
| 15 |
not has_credentials,
|
| 16 |
reason="OpenAI API Key λλ Neo4j μ°κ²° νκ²½λ³μκ° μμΌλ―λ‘ ν΅ν© ν
μ€νΈλ₯Ό 건λλλλ€."
|