File size: 12,482 Bytes
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90b33eb
 
94a7032
 
5858246
 
 
399f3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
知识图谱模块
实现GraphRAG的核心功能:图谱构建、社区检测、层次化摘要
"""

import networkx as nx
from typing import List, Dict, Set, Tuple, Optional
from collections import defaultdict
import json

try:
    from community import community_louvain  # python-louvain
    LOUVAIN_AVAILABLE = True
except ImportError:
    LOUVAIN_AVAILABLE = False
    print("⚠️ python-louvain未安装,社区检测功能受限")

try:
    from langchain_core.prompts import PromptTemplate
except ImportError:
    try:
        from langchain_core.prompts import PromptTemplate
    except ImportError:
        from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from config import LOCAL_LLM


class KnowledgeGraph:
    """Knowledge graph backed by an undirected NetworkX graph.

    Stores entity nodes and relation edges, plus the output of community
    detection (``communities``: node name -> community id) and any
    per-community summaries (``community_summaries``: community id -> text).
    """

    def __init__(self):
        self.graph = nx.Graph()  # undirected: relation direction is not preserved
        self.entities = {}  # entity name -> record dict (name/type/description/extra attrs)
        self.communities = {}  # node name -> community id, set by detect_communities()
        self.community_summaries = {}  # community id -> summary text

    def add_entity(self, name: str, entity_type: str, description: str = "", **kwargs):
        """Add (or overwrite) an entity node.

        Args:
            name: Entity name, used as the node key.
            entity_type: Entity category, stored under the ``type`` attribute.
            description: Free-text description of the entity.
            **kwargs: Extra attributes stored on the node and in the entity record.
        """
        self.graph.add_node(
            name,
            type=entity_type,
            description=description,
            **kwargs
        )
        self.entities[name] = {
            "name": name,
            "type": entity_type,
            "description": description,
            **kwargs
        }

    def add_relation(self, source: str, target: str, relation_type: str,
                     description: str = "", weight: float = 1.0):
        """Add a relation edge between two entities.

        Because the graph is undirected, adding an edge between an already
        connected pair overwrites/updates that edge's attributes.
        """
        self.graph.add_edge(
            source,
            target,
            relation_type=relation_type,
            description=description,
            weight=weight
        )

    def build_from_extractions(self, extraction_results: List[Dict]):
        """Build the graph from entity/relation extraction results.

        Args:
            extraction_results: List of dicts, each optionally holding an
                ``"entities"`` list (dicts with ``name``/``type``/``description``)
                and a ``"relations"`` list (dicts with ``source``/``target``/
                ``relation_type``/``description``).

        Entities without a name are skipped. Relations are only added when
        both endpoints already exist as nodes.
        """
        print("🔨 开始构建知识图谱...")

        total_entities = 0
        total_relations = 0

        for result in extraction_results:
            # Add entities
            for entity in result.get("entities", []):
                name = entity.get("name")
                if not name:
                    # Malformed extraction output: no usable node key.
                    continue
                self.add_entity(
                    name=name,
                    entity_type=entity.get("type", "UNKNOWN"),
                    description=entity.get("description", "")
                )
                total_entities += 1

            # Add relations
            for relation in result.get("relations", []):
                source = relation.get("source")
                target = relation.get("target")

                # Make sure both endpoints exist; never create implicit nodes.
                if source in self.graph and target in self.graph:
                    self.add_relation(
                        source=source,
                        target=target,
                        relation_type=relation.get("relation_type", "RELATED_TO"),
                        description=relation.get("description", "")
                    )
                    total_relations += 1

        # Totals count every extraction; node/edge counts reflect deduplication.
        print(f"✅ 图谱构建完成: {total_entities} 个实体, {total_relations} 个关系")
        print(f"   实际节点数: {self.graph.number_of_nodes()}")
        print(f"   实际边数: {self.graph.number_of_edges()}")

    @staticmethod
    def _partition_to_mapping(partition) -> Dict[str, int]:
        """Convert an iterable of node sets into a node -> community-index mapping."""
        return {
            node: idx
            for idx, members in enumerate(partition)
            for node in members
        }

    def detect_communities(self, algorithm: str = "louvain") -> Dict[str, int]:
        """Run community detection and store the result in ``self.communities``.

        Args:
            algorithm: ``'louvain'``, ``'greedy'`` or ``'label_propagation'``.
                Louvain falls back to greedy modularity when python-louvain is
                not installed. Unknown names also fall back to greedy modularity.

        Returns:
            Mapping of node name to community id. On an empty graph or on
            failure, ``{}`` is returned and ``self.communities`` is left unchanged.
        """
        print(f"🔍 开始社区检测 (算法: {algorithm})...")

        if self.graph.number_of_nodes() == 0:
            print("⚠️ 图谱为空,跳过社区检测")
            return {}

        try:
            if algorithm == "louvain" and LOUVAIN_AVAILABLE:
                communities = community_louvain.best_partition(self.graph)
            elif algorithm == "label_propagation":
                communities = self._partition_to_mapping(
                    nx.community.label_propagation_communities(self.graph)
                )
            else:
                # Greedy modularity is both an explicit choice and the fallback.
                if algorithm == "louvain":
                    print("⚠️ python-louvain未安装,使用贪婪算法")
                elif algorithm != "greedy":
                    print(f"⚠️ 未知算法 {algorithm},使用贪婪算法")
                communities = self._partition_to_mapping(
                    nx.community.greedy_modularity_communities(self.graph)
                )

            self.communities = communities
            num_communities = len(set(communities.values()))
            print(f"✅ 检测到 {num_communities} 个社区")

            return communities

        except Exception as e:
            print(f"❌ 社区检测失败: {e}")
            return {}

    def get_community_members(self, community_id: int) -> List[str]:
        """Return all node names assigned to ``community_id``."""
        return [node for node, cid in self.communities.items() if cid == community_id]

    def get_community_subgraph(self, community_id: int) -> nx.Graph:
        """Return the (read-only view) subgraph induced by a community's members."""
        members = self.get_community_members(community_id)
        return self.graph.subgraph(members)

    def get_node_neighbors(self, node: str, depth: int = 1) -> Set[str]:
        """Return ``node`` plus every node reachable within ``depth`` hops.

        Returns an empty set when ``node`` is not in the graph.
        """
        if node not in self.graph:
            return set()

        visited = {node}
        frontier = {node}

        for _ in range(depth):
            next_frontier = set()
            for n in frontier:
                next_frontier.update(self.graph.neighbors(n))
            # Only expand newly discovered nodes on the next hop.
            next_frontier -= visited
            if not next_frontier:
                break
            visited |= next_frontier
            frontier = next_frontier

        return visited

    def get_entity_info(self, entity_name: str) -> Optional[Dict]:
        """Return the stored entity record, or None if unknown."""
        return self.entities.get(entity_name)

    def search_entities_by_type(self, entity_type: str) -> List[str]:
        """Return the names of all entities whose ``type`` equals ``entity_type``."""
        return [
            name for name, data in self.entities.items()
            if data.get("type") == entity_type
        ]

    def get_statistics(self) -> Dict:
        """Return node/edge/community counts, density and entity-type distribution."""
        stats = {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "num_communities": len(set(self.communities.values())) if self.communities else 0,
            "density": nx.density(self.graph) if self.graph.number_of_nodes() > 0 else 0,
            "entity_types": {}
        }

        # Entity-type distribution
        for entity in self.entities.values():
            etype = entity.get("type", "UNKNOWN")
            stats["entity_types"][etype] = stats["entity_types"].get(etype, 0) + 1

        return stats

    def save_to_file(self, filepath: str):
        """Serialize entities, edges, communities and summaries to a UTF-8 JSON file.

        Note: JSON object keys are strings, so integer community ids used as
        keys of ``community_summaries`` are written as strings.
        """
        payload = {
            "entities": self.entities,
            "edges": [
                {
                    "source": u,
                    "target": v,
                    "data": edge_attrs
                }
                for u, v, edge_attrs in self.graph.edges(data=True)
            ],
            "communities": self.communities,
            "community_summaries": self.community_summaries
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

        print(f"✅ 图谱已保存到: {filepath}")

    @staticmethod
    def _restore_int_key(key):
        """Return ``int(key)`` when ``key`` is an integer string, else ``key`` unchanged."""
        try:
            return int(key)
        except (TypeError, ValueError):
            return key

    def load_from_file(self, filepath: str):
        """Load a graph written by ``save_to_file``, replacing the current state."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        loaded_entities = data.get("entities", {})
        self.communities = data.get("communities", {})
        # JSON stringified the int community-id keys; restore them so that
        # community_summaries[cid] lookups with int ids keep working.
        self.community_summaries = {
            self._restore_int_key(cid): summary
            for cid, summary in data.get("community_summaries", {}).items()
        }

        # Rebuild the graph; add_entity repopulates self.entities.
        self.graph.clear()
        self.entities = {}
        for key, entity in loaded_entities.items():
            attrs = dict(entity)
            # Records store the type under "type", but add_entity expects
            # "entity_type"; splatting the record directly raises TypeError.
            name = attrs.pop("name", key)
            entity_type = attrs.pop("type", "UNKNOWN")
            description = attrs.pop("description", "")
            self.add_entity(name, entity_type, description, **attrs)

        for edge in data.get("edges", []):
            self.graph.add_edge(
                edge["source"],
                edge["target"],
                **edge["data"]
            )

        print(f"✅ 图谱已从文件加载: {filepath}")


class CommunitySummarizer:
    """Generate LLM summaries for knowledge-graph communities (GraphRAG step)."""

    def __init__(self):
        # Local Ollama chat model; model name comes from config.LOCAL_LLM.
        self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)

        self.summary_prompt = PromptTemplate(
            template="""你是一个知识图谱分析专家。请为以下社区生成一个综合摘要。

社区成员(实体):
{entities}

实体间的关系:
{relations}

请生成一个简洁的摘要,描述:
1. 这个社区的主题是什么
2. 主要包含哪些核心概念
3. 实体之间的关键关系

摘要(2-3句话):
""",
            input_variables=["entities", "relations"]
        )

        # prompt -> LLM -> plain string
        self.summary_chain = self.summary_prompt | self.llm | StrOutputParser()

    def summarize_community(self, kg: KnowledgeGraph, community_id: int) -> str:
        """Generate a summary for one community.

        Args:
            kg: Knowledge graph whose ``communities`` have been detected.
            community_id: Id of the community to summarize.

        Returns:
            The stripped LLM summary. If the LLM call fails, a short fallback
            string with the member count is returned instead.
        """
        members = kg.get_community_members(community_id)
        subgraph = kg.get_community_subgraph(community_id)

        # Entity lines (capped at 20 to bound prompt size)
        entity_info = []
        for member in members[:20]:
            info = kg.get_entity_info(member)
            if info:
                # add_entity stores description="" by default, so use `or`
                # (not a .get default) for the placeholder to ever apply.
                description = info.get('description') or '无描述'
                entity_info.append(
                    f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {description}"
                )

        # Relation lines; the graph is undirected, so the arrow direction is
        # just the order NetworkX reports the edge endpoints in.
        relation_info = [
            f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}"
            for u, v, data in subgraph.edges(data=True)
        ]

        entities_text = "\n".join(entity_info) if entity_info else "无实体"
        # Capped at 15 relations to bound prompt size
        relations_text = "\n".join(relation_info[:15]) if relation_info else "无关系"

        try:
            summary = self.summary_chain.invoke({
                "entities": entities_text,
                "relations": relations_text
            })
            return summary.strip()
        except Exception as e:
            print(f"❌ 社区 {community_id} 摘要生成失败: {e}")
            return f"社区{community_id}: 包含{len(members)}个实体"

    def summarize_all_communities(self, kg: KnowledgeGraph) -> Dict[int, str]:
        """Summarize every community, also storing results in ``kg.community_summaries``.

        Communities are processed in sorted id order so runs are reproducible.

        Returns:
            Mapping of community id to summary, or ``{}`` if no communities exist.
        """
        if not kg.communities:
            print("⚠️ 未检测到社区,请先运行社区检测")
            return {}

        community_ids = sorted(set(kg.communities.values()))
        print(f"📝 开始为 {len(community_ids)} 个社区生成摘要...")

        summaries = {}
        for cid in community_ids:
            print(f"   处理社区 {cid}...")
            summary = self.summarize_community(kg, cid)
            summaries[cid] = summary
            kg.community_summaries[cid] = summary

        print("✅ 所有社区摘要生成完成")
        return summaries


def initialize_knowledge_graph():
    """Factory: build and return a new, empty KnowledgeGraph."""
    fresh_graph = KnowledgeGraph()
    return fresh_graph


def initialize_community_summarizer():
    """Factory: build and return a CommunitySummarizer bound to the local LLM."""
    summarizer = CommunitySummarizer()
    return summarizer