adaptive_rag / knowledge_graph.py
lanny xu
add cuda
5858246
"""
知识图谱模块
实现GraphRAG的核心功能:图谱构建、社区检测、层次化摘要
"""
import networkx as nx
from typing import List, Dict, Set, Tuple, Optional
from collections import defaultdict
import json
try:
from community import community_louvain # python-louvain
LOUVAIN_AVAILABLE = True
except ImportError:
LOUVAIN_AVAILABLE = False
print("⚠️ python-louvain未安装,社区检测功能受限")
# Import PromptTemplate from langchain_core when available; otherwise fall
# back to the langchain.prompts path. (The previous nested try retried the
# exact same langchain_core import a second time before falling back.)
try:
    from langchain_core.prompts import PromptTemplate
except ImportError:
    from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from config import LOCAL_LLM
class KnowledgeGraph:
    """Knowledge graph stored in an undirected NetworkX graph.

    Besides the graph itself, the instance keeps a per-entity record dict,
    the latest community assignment and the per-community summaries, and can
    persist all of them to / restore them from a JSON file.
    """

    def __init__(self):
        """Create an empty graph with no entities, communities or summaries."""
        self.graph = nx.Graph()  # undirected graph of entities and relations
        self.entities = {}  # entity name -> {"name", "type", "description", ...}
        self.communities = {}  # node name -> community id
        self.community_summaries = {}  # community id -> summary text

    @staticmethod
    def _index_by_node(node_sets) -> Dict[str, int]:
        """Turn an iterable of node collections into a node -> index mapping."""
        return {
            node: idx
            for idx, members in enumerate(node_sets)
            for node in members
        }

    @staticmethod
    def _entity_kwargs(name: str, record: Dict) -> Dict:
        """Translate a stored entity record into ``add_entity`` keyword arguments.

        Stored records use the key ``"type"`` while ``add_entity`` expects
        ``entity_type``; passing the record through unchanged would raise a
        ``TypeError``. The record itself is not modified.

        Args:
            name: Key the record was stored under; used when the record has
                no ``"name"`` field of its own.
            record: Entity record as written by ``add_entity``.

        Returns:
            Keyword arguments suitable for ``self.add_entity(**kwargs)``.
        """
        extra = dict(record)
        return {
            "name": extra.pop("name", name),
            "entity_type": extra.pop("type", "UNKNOWN"),
            "description": extra.pop("description", ""),
            **extra,
        }

    @staticmethod
    def _restore_int_keys(mapping: Dict) -> Dict:
        """Convert keys that look like integers back to ``int``.

        JSON object keys are always strings, so integer community ids come
        back as ``"0"``, ``"1"``, ... after a save/load round trip. Keys that
        cannot be parsed as integers are kept unchanged.
        """
        restored = {}
        for key, value in mapping.items():
            try:
                restored[int(key)] = value
            except (TypeError, ValueError):
                restored[key] = value
        return restored

    def add_entity(self, name: str, entity_type: str, description: str = "", **kwargs):
        """Add (or overwrite) an entity node and its record.

        Args:
            name: Node identifier.
            entity_type: Stored under the ``type`` attribute / record key.
            description: Free-text description of the entity.
            **kwargs: Extra attributes stored both on the node and in the record.
        """
        self.graph.add_node(
            name,
            type=entity_type,
            description=description,
            **kwargs
        )
        self.entities[name] = {
            "name": name,
            "type": entity_type,
            "description": description,
            **kwargs
        }

    def add_relation(self, source: str, target: str, relation_type: str,
                     description: str = "", weight: float = 1.0):
        """Add an edge between ``source`` and ``target``.

        Only the graph is updated; ``self.entities`` is not touched, so
        callers that need entity records should add the entities first.

        Args:
            source: Name of one endpoint.
            target: Name of the other endpoint.
            relation_type: Label stored on the edge as ``relation_type``.
            description: Free-text description stored on the edge.
            weight: Edge weight stored on the edge.
        """
        self.graph.add_edge(
            source,
            target,
            relation_type=relation_type,
            description=description,
            weight=weight
        )

    def build_from_extractions(self, extraction_results: List[Dict]):
        """Populate the graph from entity/relation extraction results.

        Each result may carry an ``"entities"`` list (items need a ``"name"``;
        ``"type"`` and ``"description"`` are optional) and a ``"relations"``
        list (items with ``"source"``, ``"target"`` and optional
        ``"relation_type"`` / ``"description"``). A relation is only added when
        both endpoints are already nodes of the graph; others are skipped.

        Args:
            extraction_results: List of extraction result dicts.
        """
        print("🔨 开始构建知识图谱...")
        # These count processed records; repeated names are counted each time,
        # which is why the real node/edge totals are printed separately below.
        total_entities = 0
        total_relations = 0
        for result in extraction_results:
            for entity in result.get("entities", []):
                self.add_entity(
                    name=entity["name"],
                    entity_type=entity.get("type", "UNKNOWN"),
                    description=entity.get("description", "")
                )
                total_entities += 1
            for relation in result.get("relations", []):
                source = relation.get("source")
                target = relation.get("target")
                # Skip relations that reference entities we never saw.
                if source in self.graph and target in self.graph:
                    self.add_relation(
                        source=source,
                        target=target,
                        relation_type=relation.get("relation_type", "RELATED_TO"),
                        description=relation.get("description", "")
                    )
                    total_relations += 1
        print(f"✅ 图谱构建完成: {total_entities} 个实体, {total_relations} 个关系")
        print(f" 实际节点数: {self.graph.number_of_nodes()}")
        print(f" 实际边数: {self.graph.number_of_edges()}")

    def detect_communities(self, algorithm: str = "louvain") -> Dict[str, int]:
        """Partition the graph into communities and remember the result.

        Args:
            algorithm: ``'louvain'``, ``'greedy'`` or ``'label_propagation'``.
                Unknown names, and ``'louvain'`` when python-louvain is not
                installed, fall back to greedy modularity with a warning.

        Returns:
            Mapping of node name to community id. An empty dict is returned
            when the graph has no nodes or when detection raised; in both
            cases ``self.communities`` is left unchanged.
        """
        print(f"🔍 开始社区检测 (算法: {algorithm})...")
        if self.graph.number_of_nodes() == 0:
            print("⚠️ 图谱为空,跳过社区检测")
            return {}
        try:
            if algorithm == "louvain" and LOUVAIN_AVAILABLE:
                communities = community_louvain.best_partition(self.graph)
            elif algorithm == "label_propagation":
                communities = self._index_by_node(
                    nx.community.label_propagation_communities(self.graph)
                )
            else:
                if algorithm == "louvain":
                    # Known algorithm, but the optional dependency is missing.
                    print("⚠️ python-louvain未安装,改用贪婪算法")
                elif algorithm != "greedy":
                    print(f"⚠️ 未知算法 {algorithm},使用贪婪算法")
                communities = self._index_by_node(
                    nx.community.greedy_modularity_communities(self.graph)
                )
            self.communities = communities
            num_communities = len(set(communities.values()))
            print(f"✅ 检测到 {num_communities} 个社区")
            return communities
        except Exception as e:
            # Deliberate best-effort: report the failure and return no partition.
            print(f"❌ 社区检测失败: {e}")
            return {}

    def get_community_members(self, community_id: int) -> List[str]:
        """Return the nodes assigned to ``community_id``."""
        return [node for node, cid in self.communities.items() if cid == community_id]

    def get_community_subgraph(self, community_id: int) -> "nx.Graph":
        """Return the subgraph induced by the members of ``community_id``."""
        members = self.get_community_members(community_id)
        return self.graph.subgraph(members)

    def get_node_neighbors(self, node: str, depth: int = 1) -> Set[str]:
        """Return every node within ``depth`` hops of ``node``.

        The result includes ``node`` itself. An empty set is returned when
        ``node`` is not in the graph.
        """
        if node not in self.graph:
            return set()
        visited = {node}
        frontier = {node}
        for _ in range(depth):
            discovered = set()
            for current in frontier:
                discovered.update(self.graph.neighbors(current))
            # Only expand nodes not seen before; stop once nothing new appears.
            frontier = discovered - visited
            if not frontier:
                break
            visited |= frontier
        return visited

    def get_entity_info(self, entity_name: str) -> Optional[Dict]:
        """Return the stored record for ``entity_name``, or ``None`` if unknown."""
        return self.entities.get(entity_name)

    def search_entities_by_type(self, entity_type: str) -> List[str]:
        """Return the names of all entities whose ``type`` equals ``entity_type``."""
        return [
            name for name, data in self.entities.items()
            if data.get("type") == entity_type
        ]

    def get_statistics(self) -> Dict:
        """Return node/edge/community counts, density and entity-type counts."""
        stats = {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_communities": len(set(self.communities.values())) if self.communities else 0,
            "density": nx.density(self.graph) if self.graph.number_of_nodes() > 0 else 0,
            "entity_types": {}
        }
        type_counts = stats["entity_types"]
        for entity in self.entities.values():
            etype = entity.get("type", "UNKNOWN")
            type_counts[etype] = type_counts.get(etype, 0) + 1
        return stats

    def save_to_file(self, filepath: str):
        """Write entities, edges, communities and summaries to a JSON file.

        Args:
            filepath: Destination path; the file is written as UTF-8.
        """
        payload = {
            "entities": self.entities,
            "edges": [
                {
                    "source": u,
                    "target": v,
                    "data": attrs
                }
                for u, v, attrs in self.graph.edges(data=True)
            ],
            "communities": self.communities,
            "community_summaries": self.community_summaries
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        print(f"✅ 图谱已保存到: {filepath}")

    def load_from_file(self, filepath: str):
        """Replace the current graph with the one stored in a JSON file.

        Expects the layout written by ``save_to_file``. Entity records are
        re-added through ``add_entity`` so nodes and records stay in sync, and
        community-summary keys are converted back to integers where possible.

        Args:
            filepath: Path of a UTF-8 JSON file.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        stored_entities = data.get("entities", {})
        self.communities = data.get("communities", {})
        self.community_summaries = self._restore_int_keys(
            data.get("community_summaries", {})
        )
        # Rebuild the graph and the entity records from scratch.
        self.graph.clear()
        self.entities = {}
        for name, entity in stored_entities.items():
            self.add_entity(**self._entity_kwargs(name, entity))
        for edge in data.get("edges", []):
            self.graph.add_edge(
                edge["source"],
                edge["target"],
                **edge["data"]
            )
        print(f"✅ 图谱已从文件加载: {filepath}")
class CommunitySummarizer:
    """Generates natural-language summaries for graph communities via an LLM."""

    def __init__(self):
        """Build the chat model, the summary prompt and the prompt|llm|parser chain."""
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)
        self.summary_prompt = PromptTemplate(
            template="""你是一个知识图谱分析专家。请为以下社区生成一个综合摘要。
社区成员(实体):
{entities}
实体间的关系:
{relations}
请生成一个简洁的摘要,描述:
1. 这个社区的主题是什么
2. 主要包含哪些核心概念
3. 实体之间的关键关系
摘要(2-3句话):
""",
            input_variables=["entities", "relations"]
        )
        self.summary_chain = self.summary_prompt | self.llm | StrOutputParser()

    def summarize_community(self, kg: "KnowledgeGraph", community_id: int,
                            max_entities: int = 20, max_relations: int = 15) -> str:
        """Generate a summary for one community.

        Args:
            kg: Knowledge graph providing members, subgraph and entity records.
            community_id: Id of the community to summarise.
            max_entities: Maximum number of community members listed in the
                prompt (default 20).
            max_relations: Maximum number of relations listed in the prompt
                (default 15).

        Returns:
            The stripped chain output, or a short fallback string naming the
            community and its member count when the chain call raises.
        """
        members = kg.get_community_members(community_id)
        subgraph = kg.get_community_subgraph(community_id)

        # Describe members; those without a stored record are left out.
        entity_info = []
        for member in members[:max_entities]:
            info = kg.get_entity_info(member)
            if info:
                entity_info.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', '无描述')}"
                )

        # Stop formatting edges once the prompt limit is reached instead of
        # building a line for every edge and discarding most of them.
        relation_info = []
        for u, v, data in subgraph.edges(data=True):
            if len(relation_info) >= max_relations:
                break
            relation_info.append(
                f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}"
            )

        entities_text = "\n".join(entity_info) if entity_info else "无实体"
        relations_text = "\n".join(relation_info) if relation_info else "无关系"
        try:
            summary = self.summary_chain.invoke({
                "entities": entities_text,
                "relations": relations_text
            })
            return summary.strip()
        except Exception as e:
            # Best-effort: a failed LLM call must not abort the whole run.
            print(f"❌ 社区 {community_id} 摘要生成失败: {e}")
            return f"社区{community_id}: 包含{len(members)}个实体"

    def summarize_all_communities(self, kg: "KnowledgeGraph") -> Dict[int, str]:
        """Summarise every community in ``kg.communities``.

        Each summary is also stored in ``kg.community_summaries``.

        Returns:
            Mapping of community id to summary; empty when ``kg`` has no
            community assignment yet.
        """
        if not kg.communities:
            print("⚠️ 未检测到社区,请先运行社区检测")
            return {}
        community_ids = set(kg.communities.values())
        print(f"📝 开始为 {len(community_ids)} 个社区生成摘要...")
        summaries = {}
        for cid in community_ids:
            print(f" 处理社区 {cid}...")
            summary = self.summarize_community(kg, cid)
            summaries[cid] = summary
            kg.community_summaries[cid] = summary
        print("✅ 所有社区摘要生成完成")
        return summaries
def initialize_knowledge_graph():
    """Create and return a new ``KnowledgeGraph`` instance."""
    knowledge_graph = KnowledgeGraph()
    return knowledge_graph
def initialize_community_summarizer():
    """Create and return a new ``CommunitySummarizer`` instance."""
    summarizer = CommunitySummarizer()
    return summarizer