dev-yuje commited on
Commit
e7d3bfe
Β·
1 Parent(s): 3e23aae

feat: implement self-reflection loop in graph builder and unify models to gpt-4o-mini

Browse files
src/graphBuilder/neo4j/finGraph.py CHANGED
@@ -50,7 +50,7 @@ def get_neo4j_driver() -> neo4j.Driver:
50
  driver = None
51
 
52
  chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
53
- rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
54
  embedder = OpenAIEmbeddings(model="text-embedding-3-small")
55
 
56
  INDEX_NAME = "content_vector_index"
@@ -67,6 +67,8 @@ class ArticleState(TypedDict):
67
  is_ai_related: bool
68
  entities: List[Dict]
69
  relations: List[Dict]
 
 
70
 
71
 
72
  def check_ai_relevance(state: ArticleState) -> ArticleState:
@@ -83,8 +85,19 @@ def check_ai_relevance(state: ArticleState) -> ArticleState:
83
 
84
 
85
  def extract_entities(state: ArticleState) -> ArticleState:
86
- """Node 2: μ—”ν‹°ν‹° μΆ”μΆœ"""
87
- prompt = f"""λ‹€μŒ AI λ‰΄μŠ€μ—μ„œ μ—”ν‹°ν‹°λ₯Ό μΆ”μΆœν•˜μ„Έμš”.
 
 
 
 
 
 
 
 
 
 
 
88
  μ—”ν‹°ν‹° μœ ν˜•:
89
  - AICompany: κΈ°μ—…/κΈ°κ΄€ (예: μ‚Όμ„±μ „μž, OpenAI)
90
  - AITechnology: AI 기술 (예: λŒ€κ·œλͺ¨μ–Έμ–΄λͺ¨λΈ, κ°•ν™”ν•™μŠ΅)
@@ -92,18 +105,58 @@ def extract_entities(state: ArticleState) -> ArticleState:
92
  - AIField: 적용 λΆ„μ•Ό (예: 금육AI, AI λ°˜λ„μ²΄)
93
 
94
  제λͺ©: {state["title"]}
95
- λ³Έλ¬Έ: {state["text"][:900]}
96
 
97
- JSON으둜만 응닡:{{"entities":[{{"name":"...","type":"AICompany|AITechnology|AIService|AIField","description":"..."}}]}}"""
 
98
  res = chat_llm.invoke(prompt)
 
 
 
99
  try:
100
  raw = str(res.content).strip()
101
  if "```" in raw:
102
  raw = raw.split("```")[1].lstrip("json")
103
- entities = json.loads(raw).get("entities", [])
104
- except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  entities = []
106
- return {**state, "entities": entities}
 
 
 
 
 
 
 
107
 
108
 
109
  def extract_relations(state: ArticleState) -> ArticleState:
@@ -134,13 +187,40 @@ def route_after_check(state: ArticleState) -> str:
134
  return "extract_entities" if state["is_ai_related"] else END
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  builder = StateGraph(ArticleState)
138
  builder.add_node("check_ai", check_ai_relevance)
139
  builder.add_node("extract_entities", extract_entities)
140
  builder.add_node("extract_relations", extract_relations)
141
  builder.set_entry_point("check_ai")
142
  builder.add_conditional_edges("check_ai", route_after_check)
143
- builder.add_edge("extract_entities", "extract_relations")
 
 
 
 
 
 
 
 
 
 
144
  builder.add_edge("extract_relations", END)
145
  pipeline = builder.compile()
146
 
@@ -303,6 +383,8 @@ def main() -> None:
303
  is_ai_related=False,
304
  entities=[],
305
  relations=[],
 
 
306
  )
307
  out = pipeline.invoke(state)
308
  if out["is_ai_related"]:
 
50
  driver = None
51
 
52
  chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
53
+ rag_llm = OpenAILLM(model_name="gpt-4o-mini", model_params={"temperature": 0})
54
  embedder = OpenAIEmbeddings(model="text-embedding-3-small")
55
 
56
  INDEX_NAME = "content_vector_index"
 
67
  is_ai_related: bool
68
  entities: List[Dict]
69
  relations: List[Dict]
70
+ retry_count: int
71
+ reflection_feedback: str
72
 
73
 
74
  def check_ai_relevance(state: ArticleState) -> ArticleState:
 
85
 
86
 
87
  def extract_entities(state: ArticleState) -> ArticleState:
88
+ """Node 2: μ—”ν‹°ν‹° μΆ”μΆœ (μžκΈ°λ°˜μ„± ν”Όλ“œλ°± 반영 및 νƒ€μž… μ •ν•©μ„± 검증)"""
89
+ retry_count = state.get("retry_count", 0) + 1
90
+ feedback = state.get("reflection_feedback", "")
91
+
92
+ feedback_prompt = ""
93
+ if feedback:
94
+ feedback_prompt = (
95
+ f"\n\n⚠️ [이전 μ‹œλ„μ— λŒ€ν•œ 검증 였λ₯˜ ν”Όλ“œλ°±]:\n{feedback}\n"
96
+ "μœ„ 였λ₯˜λ₯Ό λ°˜λ“œμ‹œ λΆ„μ„ν•˜μ—¬, μ΄λ²ˆμ—λŠ” μ€‘λ³΅λ˜κ±°λ‚˜ λΉ„μ–΄μžˆκ±°λ‚˜ λΆˆμ™„μ „ν•˜μ§€ μ•Šκ³  "
97
+ "μ •ν™•ν•œ νƒ€μž…κ³Ό μ„€λͺ…을 κ°–μΆ˜ μ˜¬λ°”λ₯Έ μ—”ν‹°ν‹°λ§Œ μ—„κ²©ν•˜κ²Œ JSON으둜 μΆ”μΆœν•΄μ£Όμ„Έμš”."
98
+ )
99
+
100
+ prompt = f"""λ‹€μŒ AI λ‰΄μŠ€μ—μ„œ 핡심 엔티티듀을 μΆ”μΆœν•˜μ„Έμš”.
101
  μ—”ν‹°ν‹° μœ ν˜•:
102
  - AICompany: κΈ°μ—…/κΈ°κ΄€ (예: μ‚Όμ„±μ „μž, OpenAI)
103
  - AITechnology: AI 기술 (예: λŒ€κ·œλͺ¨μ–Έμ–΄λͺ¨λΈ, κ°•ν™”ν•™μŠ΅)
 
105
  - AIField: 적용 λΆ„μ•Ό (예: 금육AI, AI λ°˜λ„μ²΄)
106
 
107
  제λͺ©: {state["title"]}
108
+ λ³Έλ¬Έ: {state["text"][:900]}{feedback_prompt}
109
 
110
+ JSON으둜만 응닡: {{"entities":[{{"name":"...","type":"AICompany|AITechnology|AIService|AIField","description":"..."}}]}}"""
111
+
112
  res = chat_llm.invoke(prompt)
113
+ entities = []
114
+ new_feedback = ""
115
+
116
  try:
117
  raw = str(res.content).strip()
118
  if "```" in raw:
119
  raw = raw.split("```")[1].lstrip("json")
120
+ data = json.loads(raw)
121
+ extracted = data.get("entities", [])
122
+
123
+ allowed_types = {"AICompany", "AITechnology", "AIService", "AIField"}
124
+ valid_entities = []
125
+ for e in extracted:
126
+ name = e.get("name", "").strip()
127
+ etype = e.get("type", "").strip()
128
+ desc = e.get("description", "").strip()
129
+
130
+ if not name:
131
+ new_feedback += "- μ—”ν‹°ν‹°μ˜ 이름(name) ν•„λ“œκ°€ λˆ„λ½λ˜μ—ˆκ±°λ‚˜ λΉ„μ–΄μžˆμŠ΅λ‹ˆλ‹€.\n"
132
+ continue
133
+ if etype not in allowed_types:
134
+ new_feedback += f"- μ—”ν‹°ν‹° '{name}'의 νƒ€μž… '{etype}'은 ν—ˆμš©λœ μ’…λ₯˜({', '.join(allowed_types)})κ°€ μ•„λ‹™λ‹ˆλ‹€.\n"
135
+ continue
136
+ if not desc:
137
+ new_feedback += f"- μ—”ν‹°ν‹° '{name}'에 λŒ€ν•œ μ„€λͺ…(description)이 μƒλž΅λ˜μ—ˆμŠ΅λ‹ˆλ‹€.\n"
138
+ continue
139
+
140
+ valid_entities.append({
141
+ "name": name,
142
+ "type": etype,
143
+ "description": desc
144
+ })
145
+
146
+ entities = valid_entities
147
+ if not entities:
148
+ new_feedback = "μœ νš¨ν•œ μ—”ν‹°ν‹°κ°€ ν•˜λ‚˜λ„ μΆ”μΆœλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."
149
+
150
+ except Exception as err:
151
  entities = []
152
+ new_feedback = f"응닡 JSON νŒŒμ‹± μ‹€νŒ¨ λ˜λŠ” ν˜•μ‹μ΄ μ˜¬λ°”λ₯΄μ§€ μ•ŠμŠ΅λ‹ˆλ‹€. μ—λŸ¬: {str(err)}"
153
+
154
+ return {
155
+ **state,
156
+ "entities": entities,
157
+ "retry_count": retry_count,
158
+ "reflection_feedback": new_feedback.strip()
159
+ }
160
 
161
 
162
  def extract_relations(state: ArticleState) -> ArticleState:
 
187
  return "extract_entities" if state["is_ai_related"] else END
188
 
189
 
190
+ def validate_entities(state: ArticleState) -> str:
191
+ """μΆ”μΆœλœ μ—”ν‹°ν‹°μ˜ ν’ˆμ§ˆμ„ κ²€μ¦ν•˜κ³ , 미달할 경우 μ΅œλŒ€ 3νšŒκΉŒμ§€ μžκΈ°λ°˜μ„±(Self-Reflection) 루프λ₯Ό λ™μž‘μ‹œν‚΅λ‹ˆλ‹€."""
192
+ retry_count = state.get("retry_count", 0)
193
+ feedback = state.get("reflection_feedback", "")
194
+ entities = state.get("entities", [])
195
+
196
+ # μΆ”μΆœμ— 문제점이 있고 아직 μ΅œλŒ€ 3회 μž¬μ‹œλ„λ₯Ό μ΄ˆκ³Όν•˜μ§€ μ•Šμ€ 경우
197
+ if (feedback or not entities) and retry_count < 3:
198
+ print(f" ⚠️ [Self-Reflection] μ—”ν‹°ν‹° ν’ˆμ§ˆ 미달 (μ‹œλ„ {retry_count}/3). ν”Όλ“œλ°±: {feedback[:100]}...")
199
+ return "extract_entities" # μžκΈ°λ°˜μ„± λ£¨ν”„λ‘œ 볡귀
200
+
201
+ if feedback and retry_count >= 3:
202
+ print(f" 🚨 [Self-Reflection] μ—”ν‹°ν‹° 3회 μ‹œλ„ 초과. 검증 였λ₯˜κ°€ μžˆμ§€λ§Œ νŒ¨μŠ€ν•©λ‹ˆλ‹€. ν”Όλ“œλ°±: {feedback[:100]}...")
203
+
204
+ return "extract_relations" # 검증을 정상 ν†΅κ³Όν–ˆκ±°λ‚˜ μ΅œλŒ€ 3회 ν•œλ„μ— λ„λ‹¬ν•œ 경우 톡과
205
+
206
+
207
  builder = StateGraph(ArticleState)
208
  builder.add_node("check_ai", check_ai_relevance)
209
  builder.add_node("extract_entities", extract_entities)
210
  builder.add_node("extract_relations", extract_relations)
211
  builder.set_entry_point("check_ai")
212
  builder.add_conditional_edges("check_ai", route_after_check)
213
+
214
+ # μžκΈ°λ°˜μ„± 쑰건뢀 μ—£μ§€ λ§€ν•‘
215
+ builder.add_conditional_edges(
216
+ "extract_entities",
217
+ validate_entities,
218
+ {
219
+ "extract_entities": "extract_entities",
220
+ "extract_relations": "extract_relations"
221
+ }
222
+ )
223
+
224
  builder.add_edge("extract_relations", END)
225
  pipeline = builder.compile()
226
 
 
383
  is_ai_related=False,
384
  entities=[],
385
  relations=[],
386
+ retry_count=0,
387
+ reflection_feedback="",
388
  )
389
  out = pipeline.invoke(state)
390
  if out["is_ai_related"]:
src/retrieval/finRetrieval.py CHANGED
@@ -38,7 +38,7 @@ class HybridResult:
38
  """GraphRAG λ˜λŠ” 일반 지식 기반 톡합 응닡 κ²°κ³Ό"""
39
 
40
  answer: str # μ΅œμ’… λ‹΅λ³€ λ¬Έμžμ—΄
41
- mode: str # "graph": κ·Έλž˜ν”„ 검색 기반 | "general": GPT-4o 일반 지식 기반
42
  retriever_result: Any = None # RetrieverResult (mode="graph"일 λ•Œλ§Œ 유효)
43
 
44
 
@@ -294,7 +294,7 @@ class LazyGraphRAG:
294
  return
295
 
296
  # OpenAI ν΄λΌμ΄μ–ΈνŠΈ 및 μž„λ² λ” μ§€μ—° μ΄ˆκΈ°ν™” (CI ν¬λž˜μ‹œ λ°©μ§€)
297
- self._rag_llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
298
  embedder = OpenAIEmbeddings(model="text-embedding-3-small")
299
 
300
  driver = get_neo4j_driver()
@@ -349,7 +349,7 @@ class LazyGraphRAG:
349
  )
350
 
351
  def _is_context_sufficient(self, query_text: str, history: list, retriever_result: Any) -> bool:
352
- """κ²€μƒ‰λœ μ»¨ν…μŠ€νŠΈκ°€ 질문 및 이전 λŒ€ν™” 흐름에 μ‹€μ§ˆμ μœΌλ‘œ 도움이 λ˜λŠ” 금육/기술 λ‰΄μŠ€ 데이터인지 GPT-4o둜 νŒλ‹¨"""
353
  if retriever_result is None:
354
  return False
355
  if not hasattr(retriever_result, "items") or not retriever_result.items:
@@ -360,7 +360,7 @@ class LazyGraphRAG:
360
  if len(total_content) < 100:
361
  return False
362
 
363
- # GPT-4o 기반 μ§€λŠ₯적 μžκ°€ 진단 (이전 λŒ€ν™” νžˆμŠ€ν† λ¦¬ 및 질문의 λ§₯락 κ²°ν•© νŒμ •)
364
  try:
365
  assert self._rag_llm is not None
366
  context_snippet = total_content[:800]
@@ -414,12 +414,12 @@ class LazyGraphRAG:
414
  return normalized
415
 
416
  def _generate_general_answer(self, query_text: str, history: list) -> str:
417
- """κ·Έλž˜ν”„ 검색 κ²°κ³Ό 없이 GPT-4o 일반 μ§€μ‹μœΌλ‘œ λ‹΅λ³€ 생성 (λŒ€ν™” νžˆμŠ€ν† λ¦¬ 반영)"""
418
  assert self._rag_llm is not None
419
  system_prompt = (
420
  "당신은 AI 및 ν•€ν…Œν¬ 기술 νŠΈλ Œλ“œ μ „λ¬Έκ°€μ΄μž, μ·¨μ—… μ€€λΉ„μƒμ˜ μ—­λŸ‰ 뢄석을 λ•λŠ” μ „λž΅ μ»¨μ„€ν„΄νŠΈμž…λ‹ˆλ‹€.\n"
421
  "ν˜„μž¬ FinGraph 지식 κ·Έλž˜ν”„(Neo4j GraphRAG)μ—μ„œ κ΄€λ ¨ λ‰΄μŠ€ 기사λ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.\n"
422
- "이전 λŒ€ν™” λ§₯락을 μΆ©λΆ„νžˆ λ°˜μ˜ν•˜κ³ , GPT-4o의 일반 ν•™μŠ΅ 데이터에 κΈ°λ°˜ν•˜μ—¬ μ΅œμ„ μ„ λ‹€ν•΄ μ „λ¬Έμ μœΌλ‘œ λ‹΅λ³€ν•΄ μ£Όμ„Έμš”.\n\n"
423
  "[μ€‘μš” μ§€μΉ¨]\n"
424
  "- μ‹€μ œ μ‘΄μž¬ν•˜μ§€ μ•ŠλŠ” λ‰΄μŠ€ 링크, λ‚ μ§œ, κ°€μ§œ URL을 μ ˆλŒ€ μƒμ„±ν•˜μ§€ λ§ˆμ„Έμš”.\n"
425
  "- κ°€λŠ₯ν•˜λ‹€λ©΄ μ·¨μ—… 쀀비생이 λ©΄μ ‘/μžμ†Œμ„œμ— ν™œμš©ν•  수 μžˆλŠ” μ‹€μ§ˆμ μΈ μΈμ‚¬μ΄νŠΈλ₯Ό 포함해 μ£Όμ„Έμš”.\n"
@@ -460,7 +460,7 @@ class LazyGraphRAG:
460
  retriever_result=rag_result.retriever_result,
461
  )
462
  else:
463
- # 3b. 일반 지식 기반 -> νžˆμŠ€ν† λ¦¬ 포함 GPT-4o 직접 호좜
464
  answer = self._generate_general_answer(query_text, history)
465
  return HybridResult(answer=answer, mode="general", retriever_result=None)
466
 
 
38
  """GraphRAG λ˜λŠ” 일반 지식 기반 톡합 응닡 κ²°κ³Ό"""
39
 
40
  answer: str # μ΅œμ’… λ‹΅λ³€ λ¬Έμžμ—΄
41
+ mode: str # "graph": κ·Έλž˜ν”„ 검색 기반 | "general": GPT-4o-mini 일반 지식 기반
42
  retriever_result: Any = None # RetrieverResult (mode="graph"일 λ•Œλ§Œ 유효)
43
 
44
 
 
294
  return
295
 
296
  # OpenAI ν΄λΌμ΄μ–ΈνŠΈ 및 μž„λ² λ” μ§€μ—° μ΄ˆκΈ°ν™” (CI ν¬λž˜μ‹œ λ°©μ§€)
297
+ self._rag_llm = OpenAILLM(model_name="gpt-4o-mini", model_params={"temperature": 0})
298
  embedder = OpenAIEmbeddings(model="text-embedding-3-small")
299
 
300
  driver = get_neo4j_driver()
 
349
  )
350
 
351
  def _is_context_sufficient(self, query_text: str, history: list, retriever_result: Any) -> bool:
352
+ """κ²€μƒ‰λœ μ»¨ν…μŠ€νŠΈκ°€ 질문 및 이전 λŒ€ν™” 흐름에 μ‹€μ§ˆμ μœΌλ‘œ 도움이 λ˜λŠ” 금육/기술 λ‰΄μŠ€ 데이터인지 GPT-4o-mini둜 νŒλ‹¨"""
353
  if retriever_result is None:
354
  return False
355
  if not hasattr(retriever_result, "items") or not retriever_result.items:
 
360
  if len(total_content) < 100:
361
  return False
362
 
363
+ # GPT-4o-mini 기반 μ§€λŠ₯적 μžκ°€ 진단 (이전 λŒ€ν™” νžˆμŠ€ν† λ¦¬ 및 질문의 λ§₯락 κ²°ν•© νŒμ •)
364
  try:
365
  assert self._rag_llm is not None
366
  context_snippet = total_content[:800]
 
414
  return normalized
415
 
416
  def _generate_general_answer(self, query_text: str, history: list) -> str:
417
+ """κ·Έλž˜ν”„ 검색 κ²°κ³Ό 없이 GPT-4o-mini 일반 μ§€μ‹μœΌλ‘œ λ‹΅λ³€ 생성 (λŒ€ν™” νžˆμŠ€ν† λ¦¬ 반영)"""
418
  assert self._rag_llm is not None
419
  system_prompt = (
420
  "당신은 AI 및 ν•€ν…Œν¬ 기술 νŠΈλ Œλ“œ μ „λ¬Έκ°€μ΄μž, μ·¨μ—… μ€€λΉ„μƒμ˜ μ—­λŸ‰ 뢄석을 λ•λŠ” μ „λž΅ μ»¨μ„€ν„΄νŠΈμž…λ‹ˆλ‹€.\n"
421
  "ν˜„μž¬ FinGraph 지식 κ·Έλž˜ν”„(Neo4j GraphRAG)μ—μ„œ κ΄€λ ¨ λ‰΄μŠ€ 기사λ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.\n"
422
+ "이전 λŒ€ν™” λ§₯락을 μΆ©λΆ„νžˆ λ°˜μ˜ν•˜κ³ , GPT-4o-mini의 일반 ν•™μŠ΅ 데이터에 κΈ°λ°˜ν•˜μ—¬ μ΅œμ„ μ„ λ‹€ν•΄ μ „λ¬Έμ μœΌλ‘œ λ‹΅λ³€ν•΄ μ£Όμ„Έμš”.\n\n"
423
  "[μ€‘μš” μ§€μΉ¨]\n"
424
  "- μ‹€μ œ μ‘΄μž¬ν•˜μ§€ μ•ŠλŠ” λ‰΄μŠ€ 링크, λ‚ μ§œ, κ°€μ§œ URL을 μ ˆλŒ€ μƒμ„±ν•˜μ§€ λ§ˆμ„Έμš”.\n"
425
  "- κ°€λŠ₯ν•˜λ‹€λ©΄ μ·¨μ—… 쀀비생이 λ©΄μ ‘/μžμ†Œμ„œμ— ν™œμš©ν•  수 μžˆλŠ” μ‹€μ§ˆμ μΈ μΈμ‚¬μ΄νŠΈλ₯Ό 포함해 μ£Όμ„Έμš”.\n"
 
460
  retriever_result=rag_result.retriever_result,
461
  )
462
  else:
463
+ # 3b. 일반 지식 기반 -> νžˆμŠ€ν† λ¦¬ 포함 GPT-4o-mini 직접 호좜
464
  answer = self._generate_general_answer(query_text, history)
465
  return HybridResult(answer=answer, mode="general", retriever_result=None)
466