lanny xu committed
Commit 9f144ed · 1 Parent(s): 0e0e0db
Files changed (2)
  1. graph_retriever.py +209 -2
  2. main_graphrag.py +6 -2
graph_retriever.py CHANGED
@@ -4,6 +4,12 @@ GraphRAG retriever
 """
 
 from typing import List, Dict, Set, Tuple
+import time
+import networkx as nx
+try:
+    from langchain_core.documents import Document
+except ImportError:
+    from langchain.schema import Document
 try:
     from langchain_core.prompts import PromptTemplate
 except ImportError:
@@ -17,6 +23,8 @@ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
 
 from knowledge_graph import KnowledgeGraph
 from config import LOCAL_LLM
+from retrieval_evaluation import RetrievalEvaluator, RetrievalResult
+from routers_and_graders import HallucinationGrader
 
 
 class GraphRetriever:
@@ -25,6 +33,7 @@ class GraphRetriever:
     def __init__(self, knowledge_graph: KnowledgeGraph):
         self.kg = knowledge_graph
         self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)
+        self.hallucination_grader = HallucinationGrader()
 
         # Entity recognition prompt
         self.entity_recognition_prompt = PromptTemplate(
@@ -124,6 +133,54 @@ class GraphRetriever:
             print(f"❌ Entity recognition failed: {e}")
             return []
 
+    def _normalize_map(self, values: Dict[str, float], keys: List[str]) -> Dict[str, float]:
+        arr = [values.get(k, 0.0) for k in keys]
+        if not arr:
+            return {k: 0.0 for k in keys}
+        mn = min(arr)
+        mx = max(arr)
+        if mx == mn:
+            return {k: 0.5 for k in keys}
+        return {k: (values.get(k, 0.0) - mn) / (mx - mn) for k in keys}
+
+    def _rank_entities(self, mentioned_entities: List[str], candidate_entities: List[str]) -> List[str]:
+        G = self.kg.graph
+        nodes = list(set(candidate_entities) | set(mentioned_entities))
+        if not nodes:
+            return []
+        subG = G.subgraph(nodes)
+        deg = nx.degree_centrality(subG)
+        btw = nx.betweenness_centrality(subG, normalized=True)
+        weight_to_mentioned = {}
+        path_prox = {}
+        for n in candidate_entities:
+            w_sum = 0.0
+            best_len = None
+            for m in mentioned_entities:
+                if G.has_edge(n, m):
+                    data = G.get_edge_data(n, m)
+                    if isinstance(data, dict):
+                        w_sum += float(data.get('weight', 1.0))
+                    else:
+                        w_sum += 1.0
+                try:
+                    dist = nx.shortest_path_length(G, source=m, target=n)
+                    if best_len is None or dist < best_len:
+                        best_len = dist
+                except (nx.NetworkXNoPath, nx.NodeNotFound):
+                    pass
+            weight_to_mentioned[n] = w_sum
+            path_prox[n] = 0.0 if best_len is None else 1.0 / (1.0 + best_len)
+        deg_n = self._normalize_map(deg, candidate_entities)
+        btw_n = self._normalize_map(btw, candidate_entities)
+        w_n = self._normalize_map(weight_to_mentioned, candidate_entities)
+        prox_n = self._normalize_map(path_prox, candidate_entities)
+        scores = {}
+        for n in candidate_entities:
+            scores[n] = 0.3 * deg_n.get(n, 0.0) + 0.3 * btw_n.get(n, 0.0) + 0.2 * w_n.get(n, 0.0) + 0.2 * prox_n.get(n, 0.0)
+        ranked = sorted(candidate_entities, key=lambda x: scores.get(x, 0.0), reverse=True)
+        return ranked
+
     def local_query(self, question: str, max_hops: int = 2, top_k: int = 10) -> str:
         """
         Local query - retrieval based on the entities in the question and their neighborhoods
@@ -152,8 +209,8 @@ class GraphRetriever:
         for entity in mentioned_entities:
             neighbors = self.kg.get_node_neighbors(entity, depth=max_hops)
             relevant_entities.update(neighbors)
-
-        relevant_entities = list(relevant_entities)[:top_k]
+        ranked_entities = self._rank_entities(mentioned_entities, list(relevant_entities))
+        relevant_entities = ranked_entities[:top_k]
 
         # 3. Collect entity information
         entity_info_list = []
@@ -226,6 +283,156 @@ class GraphRetriever:
             print(f"❌ Global query failed: {e}")
             return "Query failed, please try again."
 
+    def local_query_with_metrics(self, question: str, max_hops: int = 2, top_k: int = 10, k_values: List[int] = [1, 3, 5]) -> Tuple[str, Dict]:
+        print("\n🔎 Running local query with evaluation...")
+        start_t = time.time()
+        mentioned_entities = self.recognize_entities(question)
+        if not mentioned_entities:
+            return "No relevant entities were found in the knowledge graph.", {
+                "error": "no_entities",
+                "latency": 0.0,
+                "retrieved_docs_count": 0
+            }
+        relevant_entities = set()
+        for entity in mentioned_entities:
+            neighbors = self.kg.get_node_neighbors(entity, depth=max_hops)
+            relevant_entities.update(neighbors)
+        ranked_entities = self._rank_entities(mentioned_entities, list(relevant_entities))
+        relevant_entities = ranked_entities[:top_k]
+        entity_info_list = []
+        for entity in relevant_entities:
+            info = self.kg.get_entity_info(entity)
+            if info:
+                entity_info_list.append(f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', 'no description')}")
+        relation_list = []
+        for u, v, data in self.kg.graph.edges(data=True):
+            if u in relevant_entities and v in relevant_entities:
+                relation_list.append(f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}: {data.get('description', '')}")
+        entity_info_text = "\n".join(entity_info_list) if entity_info_list else "No relevant entity information"
+        relations_text = "\n".join(relation_list[:20]) if relation_list else "No relevant relations"
+        try:
+            answer = self.local_query_chain.invoke({
+                "question": question,
+                "entity_info": entity_info_text,
+                "relations": relations_text
+            }).strip()
+        except Exception:
+            answer = "Query failed, please try again."
+        retrieved_docs = []
+        for entity in relevant_entities:
+            info = self.kg.get_entity_info(entity) or {"name": entity}
+            content = f"{info.get('name', entity)} {info.get('type', '')} {info.get('description', '')}".strip()
+            retrieved_docs.append(Document(page_content=content, metadata={"entity": info.get('name', entity)}))
+        try:
+            hallucination_grade = self.hallucination_grader.grade(answer, retrieved_docs)
+        except Exception:
+            hallucination_grade = "unknown"
+        relevant_docs = []
+        for entity in mentioned_entities:
+            info = self.kg.get_entity_info(entity) or {"name": entity}
+            content = f"{info.get('name', entity)} {info.get('type', '')} {info.get('description', '')}".strip()
+            relevant_docs.append(Document(page_content=content, metadata={"entity": info.get('name', entity)}))
+        latency = time.time() - start_t
+        try:
+            evaluator = RetrievalEvaluator()
+            result = RetrievalResult(query=question, retrieved_docs=retrieved_docs, relevant_docs=relevant_docs, retrieval_time=latency)
+            metrics_obj = evaluator.evaluate_retrieval([result], k_values=k_values)
+            metrics = {
+                "precision_at_1": metrics_obj.precision_at_k.get(1, 0),
+                "precision_at_3": metrics_obj.precision_at_k.get(3, 0),
+                "precision_at_5": metrics_obj.precision_at_k.get(5, 0),
+                "recall_at_1": metrics_obj.recall_at_k.get(1, 0),
+                "recall_at_3": metrics_obj.recall_at_k.get(3, 0),
+                "recall_at_5": metrics_obj.recall_at_k.get(5, 0),
+                "map_score": metrics_obj.map_score,
+                "mrr": metrics_obj.mrr,
+                "latency": metrics_obj.latency,
+                "retrieved_docs_count": len(retrieved_docs),
+                "hallucination": hallucination_grade
+            }
+        except Exception:
+            metrics = {"latency": latency, "retrieved_docs_count": len(retrieved_docs), "hallucination": hallucination_grade}
+        return answer, metrics
+
+    def global_query_with_metrics(self, question: str, top_k_communities: int = 5, k_values: List[int] = [1, 3, 5]) -> Tuple[str, Dict]:
+        print("\n🌍 Running global query with evaluation...")
+        start_t = time.time()
+        mentioned_entities = self.recognize_entities(question)
+        if not self.kg.community_summaries:
+            return "The knowledge graph has no community summaries yet; run the indexing pipeline first.", {
+                "error": "no_summaries",
+                "latency": 0.0,
+                "retrieved_docs_count": 0
+            }
+        community_summaries = []
+        for cid, summary in list(self.kg.community_summaries.items())[:top_k_communities]:
+            community_summaries.append((cid, summary))
+        summaries_text = "\n".join([f"Community {cid}:\n{summary}\n" for cid, summary in community_summaries])
+        try:
+            answer = self.global_query_chain.invoke({
+                "question": question,
+                "community_summaries": summaries_text
+            }).strip()
+        except Exception:
+            answer = "Query failed, please try again."
+        retrieved_docs = []
+        for cid, summary in community_summaries:
+            retrieved_docs.append(Document(page_content=summary, metadata={"community_id": str(cid)}))
+        try:
+            hallucination_grade = self.hallucination_grader.grade(answer, retrieved_docs)
+        except Exception:
+            hallucination_grade = "unknown"
+        relevant_docs = []
+        query_tokens = [t for t in question.split() if t]
+        for cid, summary in community_summaries:
+            ok = False
+            for ent in mentioned_entities:
+                if ent and ent.lower() in summary.lower():
+                    ok = True
+                    break
+            if not ok:
+                for t in query_tokens:
+                    if t and t.lower() in summary.lower():
+                        ok = True
+                        break
+            if ok:
+                relevant_docs.append(Document(page_content=summary, metadata={"community_id": str(cid)}))
+        latency = time.time() - start_t
+        try:
+            evaluator = RetrievalEvaluator()
+            result = RetrievalResult(query=question, retrieved_docs=retrieved_docs, relevant_docs=relevant_docs, retrieval_time=latency)
+            metrics_obj = evaluator.evaluate_retrieval([result], k_values=k_values)
+            metrics = {
+                "precision_at_1": metrics_obj.precision_at_k.get(1, 0),
+                "precision_at_3": metrics_obj.precision_at_k.get(3, 0),
+                "precision_at_5": metrics_obj.precision_at_k.get(5, 0),
+                "recall_at_1": metrics_obj.recall_at_k.get(1, 0),
+                "recall_at_3": metrics_obj.recall_at_k.get(3, 0),
+                "recall_at_5": metrics_obj.recall_at_k.get(5, 0),
+                "map_score": metrics_obj.map_score,
+                "mrr": metrics_obj.mrr,
+                "latency": metrics_obj.latency,
+                "retrieved_docs_count": len(retrieved_docs),
+                "hallucination": hallucination_grade
+            }
+        except Exception:
+            metrics = {"latency": latency, "retrieved_docs_count": len(retrieved_docs), "hallucination": hallucination_grade}
+        return answer, metrics
+
+    def hybrid_query_with_metrics(self, question: str) -> Dict[str, object]:
+        print("\n🔀 Running hybrid query with evaluation...")
+        local_answer, local_metrics = self.local_query_with_metrics(question)
+        global_answer, global_metrics = self.global_query_with_metrics(question)
+        return {
+            "local": local_answer,
+            "global": global_answer,
+            "local_hallucination": local_metrics.get("hallucination"),
+            "global_hallucination": global_metrics.get("hallucination"),
+            "local_metrics": local_metrics,
+            "global_metrics": global_metrics,
+            "question": question
+        }
+
     def hybrid_query(self, question: str) -> Dict[str, str]:
         """
         Hybrid query - run the local and global queries together and return both results
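
Note on the new ranking step: _rank_entities replaces the old arbitrary list(relevant_entities)[:top_k] slice with a blended score: degree and betweenness centrality over the candidate subgraph (0.3 each), plus edge-weight mass toward the question's entities and inverse shortest-path proximity (0.2 each), all min-max normalized. Below is a minimal self-contained sketch of the same scoring idea; the toy graph, node names, and weights are invented for illustration.

import networkx as nx

# Toy graph standing in for self.kg.graph; names and weights are illustrative only.
G = nx.Graph()
G.add_edge("llm", "rag", weight=3.0)
G.add_edge("rag", "vector-db", weight=2.0)
G.add_edge("rag", "knowledge-graph", weight=1.0)
G.add_edge("knowledge-graph", "community", weight=1.0)

mentioned = ["rag"]  # entities recognized in the question
candidates = ["llm", "vector-db", "knowledge-graph", "community"]

def minmax(vals, keys):
    # Min-max normalize over the candidate keys; ties collapse to 0.5.
    arr = [vals.get(k, 0.0) for k in keys]
    lo, hi = min(arr), max(arr)
    return {k: 0.5 if hi == lo else (vals.get(k, 0.0) - lo) / (hi - lo) for k in keys}

# Centralities are computed on the induced subgraph, as in the diff.
sub = G.subgraph(set(candidates) | set(mentioned))
deg = minmax(nx.degree_centrality(sub), candidates)
btw = minmax(nx.betweenness_centrality(sub, normalized=True), candidates)

w, prox = {}, {}
for n in candidates:
    # Sum of direct edge weights to the mentioned entities.
    w[n] = sum(float(G[n][m].get("weight", 1.0)) for m in mentioned if G.has_edge(n, m))
    # Inverse hop distance to the nearest mentioned entity.
    hops = min((nx.shortest_path_length(G, m, n) for m in mentioned
                if nx.has_path(G, m, n)), default=None)
    prox[n] = 0.0 if hops is None else 1.0 / (1.0 + hops)
w, prox = minmax(w, candidates), minmax(prox, candidates)

scores = {n: 0.3 * deg[n] + 0.3 * btw[n] + 0.2 * w[n] + 0.2 * prox[n] for n in candidates}
print(sorted(candidates, key=scores.get, reverse=True))

The 0.5 tie value keeps the weighted sum well-defined when every candidate scores identically on one signal, which otherwise would divide by zero in the normalization.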
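RetrievalEvaluator and RetrievalResult come from retrieval_evaluation.py, which is outside this diff, so their exact behavior is an assumption here. Under the standard definitions of the metrics named in the returned dict, precision@k, recall@k, and MRR reduce to the toy re-derivation below (illustrative only, not the retrieval_evaluation.py implementation).

# Standard IR metric definitions the metrics dict presumably reflects.
def precision_at_k(retrieved, relevant, k):
    # Fraction of the top-k retrieved items that are relevant.
    top = retrieved[:k]
    return sum(1 for d in top if d in relevant) / k if k else 0.0

def recall_at_k(retrieved, relevant, k):
    # Fraction of all relevant items found in the top k.
    if not relevant:
        return 0.0
    top = retrieved[:k]
    return sum(1 for d in top if d in relevant) / len(relevant)

def mrr(retrieved, relevant):
    # Reciprocal rank of the first relevant hit.
    for rank, d in enumerate(retrieved, start=1):
        if d in relevant:
            return 1.0 / rank
    return 0.0

retrieved = ["e3", "e1", "e7", "e2", "e9"]  # ids of top-ranked entities (invented)
relevant = ["e1", "e2"]                     # entities actually named in the question
print(precision_at_k(retrieved, relevant, 3))  # 1/3
print(recall_at_k(retrieved, relevant, 3))     # 1/2
print(mrr(retrieved, relevant))                # first hit at rank 2 -> 0.5

Also worth noting: local_query_with_metrics builds relevant_docs from the question's own mentioned entities, so these scores measure how well the ranked neighborhood recovers the query entities, a self-referential proxy rather than gold relevance labels.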
main_graphrag.py CHANGED
@@ -146,7 +146,7 @@ class AdaptiveRAGWithGraph:
         vector_context = self.doc_processor.format_docs(vector_docs[:3])
 
         # Graph query
-        graph_results = self.graph_retriever.hybrid_query(question)
+        graph_results = self.graph_retriever.hybrid_query_with_metrics(question)
 
         result = {
             "question": question,
@@ -155,7 +155,11 @@ class AdaptiveRAGWithGraph:
                 "context": vector_context[:500] + "..." if len(vector_context) > 500 else vector_context
             },
             "graph_local": graph_results["local"],
-            "graph_global": graph_results["global"]
+            "graph_global": graph_results["global"],
+            "graph_local_hallucination": graph_results.get("local_hallucination"),
+            "graph_global_hallucination": graph_results.get("global_hallucination"),
+            "graph_local_metrics": graph_results.get("local_metrics"),
+            "graph_global_metrics": graph_results.get("global_metrics")
         }
 
         print("\n📊 Results summary:")