Spaces:
Paused
Paused
| """ | |
| 知识图谱模块 | |
| 实现GraphRAG的核心功能:图谱构建、社区检测、层次化摘要 | |
| """ | |
| import networkx as nx | |
| from typing import List, Dict, Set, Tuple, Optional | |
| from collections import defaultdict | |
| import json | |
| try: | |
| from community import community_louvain # python-louvain | |
| LOUVAIN_AVAILABLE = True | |
| except ImportError: | |
| LOUVAIN_AVAILABLE = False | |
| print("⚠️ python-louvain未安装,社区检测功能受限") | |
| try: | |
| from langchain_core.prompts import PromptTemplate | |
| except ImportError: | |
| try: | |
| from langchain_core.prompts import PromptTemplate | |
| except ImportError: | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.chat_models import ChatOllama | |
| from langchain_core.output_parsers import StrOutputParser | |
| from config import LOCAL_LLM | |
class KnowledgeGraph:
    """Knowledge graph built and managed on top of an undirected NetworkX graph.

    Nodes are entity names and edges are relations between entities. Besides
    the graph itself the object keeps a plain-dict copy of every entity
    record, the node-to-community assignment produced by
    ``detect_communities`` and a table of per-community summaries.
    """

    # Keys of a stored entity record that correspond to named parameters of
    # ``add_entity``; every other key is passed through as an extra attribute.
    _CORE_ENTITY_FIELDS = ("name", "type", "description")

    def __init__(self):
        """Create an empty graph with no entities, communities or summaries."""
        self.graph = nx.Graph()  # undirected graph
        self.entities = {}  # entity name -> entity record
        self.communities = {}  # node name -> community id
        self.community_summaries = {}  # community id -> summary text

    @staticmethod
    def _partition_to_mapping(groups) -> Dict:
        """Flatten an iterable of node collections into ``{node: group index}``."""
        return {
            node: index
            for index, group in enumerate(groups)
            for node in group
        }

    @staticmethod
    def _restore_int_keys(mapping: Dict) -> Dict:
        """Turn keys that are the exact decimal text of an int back into ints.

        JSON object keys are always strings, so a dict keyed by integer
        community ids comes back from ``json.load`` keyed by ``"0"``, ``"1"``
        and so on. Keys that do not round-trip exactly through ``int`` (for
        example ``"007"`` or ``"abc"``) are left untouched.
        """
        restored = {}
        for key, value in mapping.items():
            try:
                as_int = int(key)
            except (TypeError, ValueError):
                restored[key] = value
                continue
            restored[as_int if str(as_int) == key else key] = value
        return restored

    @staticmethod
    def _entity_record_to_kwargs(name: str, record: Dict) -> Dict:
        """Translate a stored entity record into keyword arguments for ``add_entity``.

        Records store the entity type under ``"type"`` while ``add_entity``
        takes it as ``entity_type``, so the record cannot be splatted directly.
        """
        call_kwargs = {
            key: value
            for key, value in record.items()
            if key not in KnowledgeGraph._CORE_ENTITY_FIELDS
        }
        call_kwargs["name"] = record.get("name", name)
        call_kwargs["entity_type"] = record.get("type", "UNKNOWN")
        call_kwargs["description"] = record.get("description", "")
        return call_kwargs

    def add_entity(self, name: str, entity_type: str, description: str = "", **kwargs):
        """Add (or overwrite) an entity node.

        Args:
            name: Entity name; used as the node key in the graph.
            entity_type: Entity category, stored under the ``type`` attribute.
            description: Free-text description of the entity.
            **kwargs: Extra attributes stored on both the node and the record.
        """
        self.graph.add_node(
            name,
            type=entity_type,
            description=description,
            **kwargs
        )
        self.entities[name] = {
            "name": name,
            "type": entity_type,
            "description": description,
            **kwargs
        }

    def add_relation(self, source: str, target: str, relation_type: str,
                     description: str = "", weight: float = 1.0):
        """Add (or overwrite) an undirected relation edge between two nodes.

        Args:
            source: Name of one endpoint.
            target: Name of the other endpoint.
            relation_type: Label of the relation.
            description: Free-text description of the relation.
            weight: Edge weight.
        """
        self.graph.add_edge(
            source,
            target,
            relation_type=relation_type,
            description=description,
            weight=weight
        )

    def build_from_extractions(self, extraction_results: List[Dict]):
        """Populate the graph from entity/relation extraction results.

        Entities without a name are skipped. A relation is only added when
        both of its endpoints are already nodes of the graph; a ``"weight"``
        supplied by the extraction is honoured and defaults to 1.0.

        Args:
            extraction_results: List of dicts, each optionally holding an
                ``"entities"`` list and a ``"relations"`` list.
        """
        print("🔨 开始构建知识图谱...")
        total_entities = 0
        total_relations = 0
        for result in extraction_results:
            # Entities first, so that relations of the same result can refer to them.
            for entity in result.get("entities", []):
                name = entity.get("name")
                if not name:
                    continue
                self.add_entity(
                    name=name,
                    entity_type=entity.get("type", "UNKNOWN"),
                    description=entity.get("description", "")
                )
                total_entities += 1
            for relation in result.get("relations", []):
                source = relation.get("source")
                target = relation.get("target")
                # Only connect nodes that already exist in the graph.
                if source in self.graph and target in self.graph:
                    self.add_relation(
                        source=source,
                        target=target,
                        relation_type=relation.get("relation_type", "RELATED_TO"),
                        description=relation.get("description", ""),
                        weight=relation.get("weight", 1.0)
                    )
                    total_relations += 1
        print(f"✅ 图谱构建完成: {total_entities} 个实体, {total_relations} 个关系")
        print(f" 实际节点数: {self.graph.number_of_nodes()}")
        print(f" 实际边数: {self.graph.number_of_edges()}")

    def detect_communities(self, algorithm: str = "louvain") -> Dict[str, int]:
        """Partition the graph into communities and remember the result.

        Args:
            algorithm: ``'louvain'``, ``'greedy'`` or ``'label_propagation'``.
                Unknown names, and ``'louvain'`` when python-louvain is not
                installed, fall back to greedy modularity.

        Returns:
            Mapping from node to community id; an empty dict when the graph
            is empty or detection raised.
        """
        print(f"🔍 开始社区检测 (算法: {algorithm})...")
        if self.graph.number_of_nodes() == 0:
            print("⚠️ 图谱为空,跳过社区检测")
            return {}
        try:
            if algorithm == "louvain" and LOUVAIN_AVAILABLE:
                communities = community_louvain.best_partition(self.graph)
            elif algorithm == "label_propagation":
                communities = self._partition_to_mapping(
                    nx.community.label_propagation_communities(self.graph)
                )
            else:
                if algorithm == "louvain":
                    # The name is valid; only the optional package is missing.
                    print("⚠️ python-louvain未安装,使用贪婪算法")
                elif algorithm != "greedy":
                    print(f"⚠️ 未知算法 {algorithm},使用贪婪算法")
                communities = self._partition_to_mapping(
                    nx.community.greedy_modularity_communities(self.graph)
                )
            self.communities = communities
            num_communities = len(set(communities.values()))
            print(f"✅ 检测到 {num_communities} 个社区")
            return communities
        except Exception as e:
            # Best effort: report the failure and hand back an empty result.
            print(f"❌ 社区检测失败: {e}")
            return {}

    def get_community_members(self, community_id: int) -> List[str]:
        """Return every node assigned to ``community_id``."""
        return [node for node, cid in self.communities.items() if cid == community_id]

    def get_community_subgraph(self, community_id: int) -> "nx.Graph":
        """Return the subgraph induced by the members of ``community_id``."""
        members = self.get_community_members(community_id)
        return self.graph.subgraph(members)

    def get_node_neighbors(self, node: str, depth: int = 1) -> Set[str]:
        """Return the nodes within ``depth`` hops of ``node``.

        The result includes ``node`` itself. An unknown node yields an empty set.
        """
        if node not in self.graph:
            return set()
        visited = {node}
        frontier = {node}
        for _ in range(depth):
            reached = set()
            for current in frontier:
                reached.update(self.graph.neighbors(current))
            # Expand only nodes not seen before; stop once nothing new appears.
            frontier = reached - visited
            if not frontier:
                break
            visited |= frontier
        return visited

    def get_entity_info(self, entity_name: str) -> Optional[Dict]:
        """Return the stored record of an entity, or ``None`` if unknown."""
        return self.entities.get(entity_name)

    def search_entities_by_type(self, entity_type: str) -> List[str]:
        """Return the names of all entities whose type equals ``entity_type``."""
        return [
            name for name, data in self.entities.items()
            if data.get("type") == entity_type
        ]

    def get_statistics(self) -> Dict:
        """Return node/edge/community counts, density and entity-type counts."""
        node_count = self.graph.number_of_nodes()
        type_counts = {}
        for entity in self.entities.values():
            etype = entity.get("type", "UNKNOWN")
            type_counts[etype] = type_counts.get(etype, 0) + 1
        return {
            "num_nodes": node_count,
            "num_edges": self.graph.number_of_edges(),
            "num_communities": len(set(self.communities.values())) if self.communities else 0,
            "density": nx.density(self.graph) if node_count > 0 else 0,
            "entity_types": type_counts
        }

    def save_to_file(self, filepath: str):
        """Write entities, edges, communities and summaries to a JSON file.

        Args:
            filepath: Destination path; the file is written as UTF-8.
        """
        edge_records = [
            {"source": u, "target": v, "data": attrs}
            for u, v, attrs in self.graph.edges(data=True)
        ]
        payload = {
            "entities": self.entities,
            "edges": edge_records,
            "communities": self.communities,
            "community_summaries": self.community_summaries
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        print(f"✅ 图谱已保存到: {filepath}")

    def load_from_file(self, filepath: str):
        """Replace the current graph with the one stored in a JSON file.

        Integer community ids used as keys of the summary table are restored
        from their JSON string form.

        Args:
            filepath: Path of a file written by ``save_to_file``.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        stored_entities = data.get("entities", {})
        self.communities = data.get("communities", {})
        self.community_summaries = self._restore_int_keys(
            data.get("community_summaries", {})
        )
        # Rebuild nodes and entity records through add_entity so both stay in sync.
        self.graph.clear()
        self.entities = {}
        for name, record in stored_entities.items():
            self.add_entity(**self._entity_record_to_kwargs(name, record))
        for edge in data.get("edges", []):
            self.graph.add_edge(
                edge["source"],
                edge["target"],
                **edge.get("data", {})
            )
        print(f"✅ 图谱已从文件加载: {filepath}")
class CommunitySummarizer:
    """Produces LLM-written summaries for the communities of a knowledge graph."""

    def __init__(self):
        """Set up the chat model, the summary prompt and the prompt -> model -> text chain."""
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)
        self.summary_prompt = PromptTemplate(
            template="""你是一个知识图谱分析专家。请为以下社区生成一个综合摘要。
社区成员(实体):
{entities}
实体间的关系:
{relations}
请生成一个简洁的摘要,描述:
1. 这个社区的主题是什么
2. 主要包含哪些核心概念
3. 实体之间的关键关系
摘要(2-3句话):
""",
            input_variables=["entities", "relations"]
        )
        self.summary_chain = self.summary_prompt | self.llm | StrOutputParser()

    def summarize_community(self, kg: KnowledgeGraph, community_id: int) -> str:
        """Summarize a single community.

        Args:
            kg: Knowledge graph that holds the community assignment.
            community_id: Id of the community to summarize.

        Returns:
            The stripped summary text from the chain, or a short fallback
            line naming the community and its member count when the chain
            raises.
        """
        members = kg.get_community_members(community_id)
        subgraph = kg.get_community_subgraph(community_id)

        # Entity lines: only the first 20 members, and only those with a stored record.
        records = (kg.get_entity_info(member) for member in members[:20])
        entity_lines = [
            f"- {rec['name']} ({rec.get('type', 'UNKNOWN')}): {rec.get('description', '无描述')}"
            for rec in records
            if rec
        ]

        # Relation lines: one per edge inside the community; the prompt gets the first 15.
        relation_lines = [
            f"- {u} --[{attrs.get('relation_type', 'RELATED')}]--> {v}"
            for u, v, attrs in subgraph.edges(data=True)
        ]

        if entity_lines:
            entities_text = "\n".join(entity_lines)
        else:
            entities_text = "无实体"
        if relation_lines:
            relations_text = "\n".join(relation_lines[:15])
        else:
            relations_text = "无关系"

        prompt_inputs = {
            "entities": entities_text,
            "relations": relations_text
        }
        try:
            return self.summary_chain.invoke(prompt_inputs).strip()
        except Exception as e:
            print(f"❌ 社区 {community_id} 摘要生成失败: {e}")
            return f"社区{community_id}: 包含{len(members)}个实体"

    def summarize_all_communities(self, kg: KnowledgeGraph) -> Dict[int, str]:
        """Summarize every detected community and store the results on ``kg``.

        Returns:
            Mapping from community id to summary; empty when ``kg`` has no
            community assignment yet.
        """
        if not kg.communities:
            print("⚠️ 未检测到社区,请先运行社区检测")
            return {}

        distinct_ids = set(kg.communities.values())
        print(f"📝 开始为 {len(distinct_ids)} 个社区生成摘要...")

        results = {}
        for community_id in distinct_ids:
            print(f" 处理社区 {community_id}...")
            text = self.summarize_community(kg, community_id)
            results[community_id] = text
            # Mirror the summary onto the graph object as well.
            kg.community_summaries[community_id] = text

        print("✅ 所有社区摘要生成完成")
        return results
def initialize_knowledge_graph():
    """Create and return a new, empty ``KnowledgeGraph``."""
    knowledge_graph = KnowledgeGraph()
    return knowledge_graph
def initialize_community_summarizer():
    """Create and return a new ``CommunitySummarizer``."""
    summarizer = CommunitySummarizer()
    return summarizer