Spaces:
Paused
Paused
File size: 12,482 Bytes
399f3c6 90b33eb 94a7032 5858246 399f3c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
"""
知识图谱模块
实现GraphRAG的核心功能:图谱构建、社区检测、层次化摘要
"""
import networkx as nx
from typing import List, Dict, Set, Tuple, Optional
from collections import defaultdict
import json
try:
from community import community_louvain # python-louvain
LOUVAIN_AVAILABLE = True
except ImportError:
LOUVAIN_AVAILABLE = False
print("⚠️ python-louvain未安装,社区检测功能受限")
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
try:
from langchain_core.prompts import PromptTemplate
except ImportError:
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from config import LOCAL_LLM
class KnowledgeGraph:
    """Knowledge graph backed by an undirected NetworkX graph.

    Entities are stored twice: as graph nodes (for traversal and community
    detection) and in ``self.entities`` (for direct lookup by name).
    """

    def __init__(self):
        self.graph = nx.Graph()  # undirected graph; nodes are entity names
        self.entities = {}  # entity name -> attribute dict
        self.communities = {}  # node name -> community id
        self.community_summaries = {}  # community id -> summary text

    def add_entity(self, name: str, entity_type: str, description: str = "", **kwargs):
        """Add an entity node, overwriting any existing entity of the same name.

        Args:
            name: Entity name; used as the node key.
            entity_type: Stored under the ``type`` attribute.
            description: Free-text description.
            **kwargs: Extra attributes stored on both the node and the entity dict.
        """
        self.graph.add_node(
            name,
            type=entity_type,
            description=description,
            **kwargs
        )
        self.entities[name] = {
            "name": name,
            "type": entity_type,
            "description": description,
            **kwargs
        }

    def add_relation(self, source: str, target: str, relation_type: str,
                     description: str = "", weight: float = 1.0):
        """Add a relation edge between ``source`` and ``target``."""
        self.graph.add_edge(
            source,
            target,
            relation_type=relation_type,
            description=description,
            weight=weight
        )

    def build_from_extractions(self, extraction_results: List[Dict]):
        """Populate the graph from entity/relation extraction results.

        Args:
            extraction_results: List of dicts, each optionally holding an
                ``entities`` list and a ``relations`` list.

        Entities without a ``name`` are skipped, as are relations whose
        endpoints are not already nodes of the graph.
        """
        print("🔨 开始构建知识图谱...")
        total_entities = 0
        total_relations = 0
        for result in extraction_results:
            # Entities first, so this result's relations can refer to them.
            for entity in result.get("entities", []):
                name = entity.get("name")
                if not name:
                    # Malformed entry: skip it, mirroring how bad relations are handled.
                    continue
                self.add_entity(
                    name=name,
                    entity_type=entity.get("type", "UNKNOWN"),
                    description=entity.get("description", "")
                )
                total_entities += 1
            for relation in result.get("relations", []):
                source = relation.get("source")
                target = relation.get("target")
                # Only connect nodes that already exist.
                if source in self.graph and target in self.graph:
                    self.add_relation(
                        source=source,
                        target=target,
                        relation_type=relation.get("relation_type", "RELATED_TO"),
                        description=relation.get("description", "")
                    )
                    total_relations += 1
        print(f"✅ 图谱构建完成: {total_entities} 个实体, {total_relations} 个关系")
        print(f" 实际节点数: {self.graph.number_of_nodes()}")
        print(f" 实际边数: {self.graph.number_of_edges()}")

    @staticmethod
    def _index_communities(community_sets) -> Dict[str, int]:
        """Convert an iterable of node sets into a node -> community index map."""
        return {
            node: idx
            for idx, community_set in enumerate(community_sets)
            for node in community_set
        }

    def detect_communities(self, algorithm: str = "louvain") -> Dict[str, int]:
        """Partition the graph into communities and store the result.

        Args:
            algorithm: ``'louvain'``, ``'greedy'`` or ``'label_propagation'``.
                Unknown names, and ``'louvain'`` when python-louvain is not
                installed, fall back to greedy modularity.

        Returns:
            Mapping of node to community id. Empty if the graph is empty or
            detection fails; in both cases ``self.communities`` is left unchanged.
        """
        print(f"🔍 开始社区检测 (算法: {algorithm})...")
        if self.graph.number_of_nodes() == 0:
            print("⚠️ 图谱为空,跳过社区检测")
            return {}
        try:
            if algorithm == "louvain" and LOUVAIN_AVAILABLE:
                communities = community_louvain.best_partition(self.graph)
            elif algorithm == "label_propagation":
                communities = self._index_communities(
                    nx.community.label_propagation_communities(self.graph)
                )
            else:
                if algorithm == "louvain":
                    # Louvain is a known algorithm; only its package is missing.
                    print("⚠️ python-louvain未安装,使用贪婪算法")
                elif algorithm != "greedy":
                    print(f"⚠️ 未知算法 {algorithm},使用贪婪算法")
                communities = self._index_communities(
                    nx.community.greedy_modularity_communities(self.graph)
                )
        except Exception as e:
            # Best effort: report the failure and return an empty mapping.
            print(f"❌ 社区检测失败: {e}")
            return {}
        self.communities = communities
        num_communities = len(set(communities.values()))
        print(f"✅ 检测到 {num_communities} 个社区")
        return communities

    def get_community_members(self, community_id: int) -> List[str]:
        """Return all nodes assigned to ``community_id``."""
        return [node for node, cid in self.communities.items() if cid == community_id]

    def get_community_subgraph(self, community_id: int) -> "nx.Graph":
        """Return the subgraph induced by the members of ``community_id``."""
        members = self.get_community_members(community_id)
        return self.graph.subgraph(members)

    def get_node_neighbors(self, node: str, depth: int = 1) -> Set[str]:
        """Return the nodes within ``depth`` hops of ``node``.

        The result includes ``node`` itself. An unknown node yields an empty set.
        """
        if node not in self.graph:
            return set()
        visited = {node}
        frontier = {node}
        for _ in range(depth):
            discovered = set()
            for current in frontier:
                discovered.update(self.graph.neighbors(current))
            # Expand only nodes not seen before; revisiting adds nothing.
            frontier = discovered - visited
            if not frontier:
                break
            visited |= frontier
        return visited

    def get_entity_info(self, entity_name: str) -> Optional[Dict]:
        """Return the stored attribute dict for an entity, or ``None``."""
        return self.entities.get(entity_name)

    def search_entities_by_type(self, entity_type: str) -> List[str]:
        """Return the names of all entities whose ``type`` equals ``entity_type``."""
        return [
            name for name, data in self.entities.items()
            if data.get("type") == entity_type
        ]

    def get_statistics(self) -> Dict:
        """Return node/edge/community counts, density and entity type distribution."""
        stats = {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_communities": len(set(self.communities.values())) if self.communities else 0,
            "density": nx.density(self.graph) if self.graph.number_of_nodes() > 0 else 0,
            "entity_types": {}
        }
        type_counts = stats["entity_types"]
        for entity in self.entities.values():
            etype = entity.get("type", "UNKNOWN")
            type_counts[etype] = type_counts.get(etype, 0) + 1
        return stats

    def save_to_file(self, filepath: str):
        """Serialize entities, edges, communities and summaries to a JSON file."""
        payload = {
            "entities": self.entities,
            "edges": [
                {
                    "source": u,
                    "target": v,
                    "data": attrs
                }
                for u, v, attrs in self.graph.edges(data=True)
            ],
            "communities": self.communities,
            "community_summaries": self.community_summaries
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        print(f"✅ 图谱已保存到: {filepath}")

    @staticmethod
    def _restore_int_keys(mapping: Dict) -> Dict:
        """Turn integer-like string keys back into ints.

        JSON object keys are always strings, so integer community ids come
        back from a saved file as ``"0"``, ``"1"``, ... Keys that are not
        integer-like are kept unchanged.
        """
        restored = {}
        for key, value in mapping.items():
            try:
                restored[int(key)] = value
            except (TypeError, ValueError):
                restored[key] = value
        return restored

    def load_from_file(self, filepath: str):
        """Replace the current graph with the contents of a JSON file.

        The file is expected to have the layout written by ``save_to_file``.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            payload = json.load(f)
        stored_entities = payload.get("entities", {})
        self.communities = payload.get("communities", {})
        self.community_summaries = self._restore_int_keys(
            payload.get("community_summaries", {})
        )
        # Rebuild the graph and entity table from scratch.
        self.graph.clear()
        self.entities = {}
        for key, stored in stored_entities.items():
            attrs = dict(stored)
            # Stored dicts use "type", but add_entity takes "entity_type";
            # unpacking the dict directly would raise TypeError.
            name = attrs.pop("name", key)
            entity_type = attrs.pop("type", "UNKNOWN")
            description = attrs.pop("description", "")
            self.add_entity(name, entity_type, description, **attrs)
        for edge in payload.get("edges", []):
            self.graph.add_edge(
                edge["source"],
                edge["target"],
                **edge.get("data", {})
            )
        print(f"✅ 图谱已从文件加载: {filepath}")
class CommunitySummarizer:
    """Generates short LLM summaries for the communities of a knowledge graph."""

    def __init__(self):
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)
        # The template text is runtime data sent to the model; keep it verbatim.
        self.summary_prompt = PromptTemplate(
            template="""你是一个知识图谱分析专家。请为以下社区生成一个综合摘要。
社区成员(实体):
{entities}
实体间的关系:
{relations}
请生成一个简洁的摘要,描述:
1. 这个社区的主题是什么
2. 主要包含哪些核心概念
3. 实体之间的关键关系
摘要(2-3句话):
""",
            input_variables=["entities", "relations"]
        )
        self.summary_chain = self.summary_prompt | self.llm | StrOutputParser()

    def summarize_community(self, kg: "KnowledgeGraph", community_id: int,
                            max_entities: int = 20, max_relations: int = 15) -> str:
        """Generate a summary for one community.

        Args:
            kg: Knowledge graph providing members, subgraph and entity info.
            community_id: Id of the community to summarize.
            max_entities: Maximum number of member entities listed in the prompt.
            max_relations: Maximum number of relations listed in the prompt.

        Returns:
            The stripped summary text, or a fallback string naming the
            community and its member count if the chain invocation fails.
        """
        members = kg.get_community_members(community_id)
        subgraph = kg.get_community_subgraph(community_id)
        entity_limit = max(0, max_entities)
        relation_limit = max(0, max_relations)

        # Describe at most entity_limit members to keep the prompt bounded.
        entity_lines = []
        for member in members[:entity_limit]:
            info = kg.get_entity_info(member)
            if info:
                # `or` rather than a .get default: the key is normally present
                # but may hold an empty string.
                description = info.get('description') or '无描述'
                entity_lines.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {description}"
                )

        # Stop at the limit instead of formatting every edge and slicing later.
        relation_lines = []
        for u, v, data in subgraph.edges(data=True):
            if len(relation_lines) >= relation_limit:
                break
            relation_lines.append(
                f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}"
            )

        entities_text = "\n".join(entity_lines) if entity_lines else "无实体"
        relations_text = "\n".join(relation_lines) if relation_lines else "无关系"
        try:
            summary = self.summary_chain.invoke({
                "entities": entities_text,
                "relations": relations_text
            })
            return summary.strip()
        except Exception as e:
            # Best effort: one failed community must not abort a whole run.
            print(f"❌ 社区 {community_id} 摘要生成失败: {e}")
            return f"社区{community_id}: 包含{len(members)}个实体"

    def summarize_all_communities(self, kg: "KnowledgeGraph") -> Dict[int, str]:
        """Summarize every detected community and store the results on ``kg``.

        Returns:
            Mapping of community id to summary; empty if ``kg.communities`` is empty.
        """
        if not kg.communities:
            print("⚠️ 未检测到社区,请先运行社区检测")
            return {}
        community_ids = set(kg.communities.values())
        print(f"📝 开始为 {len(community_ids)} 个社区生成摘要...")
        summaries = {}
        for cid in community_ids:
            print(f" 处理社区 {cid}...")
            summary = self.summarize_community(kg, cid)
            summaries[cid] = summary
            kg.community_summaries[cid] = summary
        print("✅ 所有社区摘要生成完成")
        return summaries
def initialize_knowledge_graph():
    """Create and return a new, empty KnowledgeGraph."""
    knowledge_graph = KnowledgeGraph()
    return knowledge_graph
def initialize_community_summarizer():
    """Create and return a new CommunitySummarizer."""
    summarizer = CommunitySummarizer()
    return summarizer
|