lanny xu commited on
Commit
399f3c6
·
0 Parent(s):

Initial commit

Browse files
Files changed (48) hide show
  1. .env +2 -0
  2. COLAB_FILES_SUMMARY.md +305 -0
  3. COLAB_GPU_GUIDE.md +271 -0
  4. COLAB_OLLAMA_GUIDE.md +421 -0
  5. DEPLOYMENT_GUIDE.md +440 -0
  6. Dockerfile.gpu +58 -0
  7. GRAPHRAG_GUIDE.md +401 -0
  8. GRAPHRAG_INTEGRATION_SUMMARY.md +427 -0
  9. QUICKSTART.md +73 -0
  10. README.md +157 -0
  11. RERANKING_PRINCIPLES.md +359 -0
  12. __pycache__/config.cpython-310.pyc +0 -0
  13. __pycache__/config.cpython-313.pyc +0 -0
  14. __pycache__/document_processor.cpython-310.pyc +0 -0
  15. __pycache__/document_processor.cpython-313.pyc +0 -0
  16. __pycache__/entity_extractor.cpython-310.pyc +0 -0
  17. __pycache__/graph_indexer.cpython-310.pyc +0 -0
  18. __pycache__/graph_retriever.cpython-310.pyc +0 -0
  19. __pycache__/knowledge_graph.cpython-310.pyc +0 -0
  20. __pycache__/main.cpython-310.pyc +0 -0
  21. __pycache__/main.cpython-313.pyc +0 -0
  22. __pycache__/reranker.cpython-310.pyc +0 -0
  23. __pycache__/routers_and_graders.cpython-310.pyc +0 -0
  24. __pycache__/routers_and_graders.cpython-313.pyc +0 -0
  25. __pycache__/workflow_nodes.cpython-310.pyc +0 -0
  26. __pycache__/workflow_nodes.cpython-313.pyc +0 -0
  27. colab_gpu_demo.ipynb +588 -0
  28. colab_gpu_test.py +269 -0
  29. colab_quick_test.py +278 -0
  30. colab_setup_and_run.py +375 -0
  31. config.py +87 -0
  32. deploy_gpu.sh +240 -0
  33. docker-compose.gpu.yml +70 -0
  34. document_processor.py +195 -0
  35. entity_extractor.py +229 -0
  36. graph_indexer.py +145 -0
  37. graph_retriever.py +275 -0
  38. knowledge_graph.py +347 -0
  39. local_llm_rag.py +428 -0
  40. main.py +174 -0
  41. main_graphrag.py +293 -0
  42. requirements.txt +44 -0
  43. requirements_gpu.txt +56 -0
  44. requirements_graphrag.txt +31 -0
  45. reranker.py +350 -0
  46. routers_and_graders.py +147 -0
  47. test_reranking.py +218 -0
  48. workflow_nodes.py +240 -0
.env ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ TAVILY_API_KEY="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"
2
+ # NOMIC_API_KEY="nk-kt4Tu3UdwFpIlDdxLcd9AK3a7cfdAKhoXvPbJ78oVlE"
COLAB_FILES_SUMMARY.md ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📦 Google Colab GPU测试文件总结
2
+
3
+ ## ✅ 已创建的文件
4
+
5
+ | 文件名 | 类型 | 用途 | 推荐度 |
6
+ |--------|------|------|--------|
7
+ | **colab_gpu_demo.ipynb** | Jupyter Notebook | 完整的交互式GPU测试 | ⭐⭐⭐⭐⭐ |
8
+ | **colab_quick_test.py** | Python脚本 | 一键快速GPU测试 | ⭐⭐⭐⭐⭐ |
9
+ | **colab_gpu_test.py** | Python脚本 | 模块化GPU测试工具 | ⭐⭐⭐⭐ |
10
+ | **COLAB_GPU_GUIDE.md** | 文档 | 详细使用指南 | ⭐⭐⭐⭐⭐ |
11
+
12
+ ---
13
+
14
+ ## 🚀 快速开始(3种方式)
15
+
16
+ ### 方式1: Notebook交互式测试 ⭐推荐
17
+
18
+ **适合**: 第一次使用,想要详细了解每个步骤
19
+
20
+ ```bash
21
+ # 步骤1: 上传文件
22
+ 上传 colab_gpu_demo.ipynb 到 Google Colab
23
+
24
+ # 步骤2: 启用GPU
25
+ 运行时 → 更改运行时类型 → GPU
26
+
27
+ # 步骤3: 运行
28
+ 运行时 → 全部运行
29
+ ```
30
+
31
+ **优势**:
32
+ - ✅ 可视化输出
33
+ - ✅ 分步执行,易于理解
34
+ - ✅ 支持实时修改
35
+ - ✅ Markdown说明清晰
36
+
37
+ ---
38
+
39
+ ### 方式2: 快速一键测试 ⭐最快
40
+
41
+ **适合**: 快速验证GPU性能
42
+
43
+ ```python
44
+ # 在Colab新建笔记本,运行以下代码:
45
+
46
+ # 1. 启用GPU (运行时 → GPU)
47
+
48
+ # 2. 复制并运行
49
+ !wget https://your-repo/colab_quick_test.py
50
+ !python colab_quick_test.py
51
+
52
+ # 或直接复制代码到单元格运行
53
+ ```
54
+
55
+ **优势**:
56
+ - ✅ 零配置
57
+ - ✅ 自动安装依赖
58
+ - ✅ 5分钟完成全部测试
59
+ - ✅ 一次性输出完整报告
60
+
61
+ ---
62
+
63
+ ### 方式3: 模块化测试工具
64
+
65
+ **适合**: 开发者深度定制
66
+
67
+ ```python
68
+ # 在Colab中
69
+ !wget https://your-repo/colab_gpu_test.py
70
+ !python colab_gpu_test.py
71
+ ```
72
+
73
+ **优势**:
74
+ - ✅ 代码结构清晰
75
+ - ✅ 易于扩展
76
+ - ✅ 可集成到其他项目
77
+
78
+ ---
79
+
80
+ ## 📊 测试内容对比
81
+
82
+ | 测试项目 | Notebook | Quick Test | GPU Test |
83
+ |---------|----------|------------|----------|
84
+ | GPU环境检测 | ✅ | ✅ | ✅ |
85
+ | 矩阵运算测试 | ✅ | ✅ | ✅ |
86
+ | 文本嵌入测试 | ✅ | ✅ | ✅ |
87
+ | GraphRAG组件 | ✅ | ❌ | ❌ |
88
+ | 显存监控 | ✅ | ✅ | ✅ |
89
+ | 性能报告 | ✅ | ✅ | ✅ |
90
+ | 交互式说明 | ✅ | ❌ | ❌ |
91
+ | nvidia-smi | ✅ | ✅ | ✅ |
92
+
93
+ ---
94
+
95
+ ## 🎯 使用场景推荐
96
+
97
+ ### 场景1: 首次测试GPU
98
+ **推荐**: `colab_gpu_demo.ipynb`
99
+ - 详细的说明文档
100
+ - 分步执行,便于学习
101
+ - 可视化效果好
102
+
103
+ ### 场景2: 快速验证性能
104
+ **推荐**: `colab_quick_test.py`
105
+ - 一键运行
106
+ - 5分钟得到结果
107
+ - 完整性能报告
108
+
109
+ ### 场景3: 集成到CI/CD
110
+ **推荐**: `colab_gpu_test.py`
111
+ - 模块化设计
112
+ - 易于自动化
113
+ - 返回标准化结果
114
+
115
+ ### 场景4: 学习GPU优化
116
+ **推荐**: `COLAB_GPU_GUIDE.md` + `colab_gpu_demo.ipynb`
117
+ - 理论+实践
118
+ - 详细的性能分析
119
+ - 优化建议
120
+
121
+ ---
122
+
123
+ ## 📈 预期性能提升
124
+
125
+ ### Google Colab T4 GPU (免费版)
126
+
127
+ | 任务 | CPU | GPU | 加速比 |
128
+ |------|-----|-----|--------|
129
+ | 矩阵运算 (5000x5000) | 8秒 | 0.3秒 | **25x** |
130
+ | 文本嵌入 (1000条) | 35秒 | 6秒 | **6x** |
131
+ | GraphRAG索引 (100文档) | 15分钟 | 4分钟 | **3.8x** |
132
+
133
+ ### Google Colab A100 GPU (Pro版)
134
+
135
+ | 任务 | CPU | GPU | 加速比 |
136
+ |------|-----|-----|--------|
137
+ | 矩阵运算 | 8秒 | 0.2秒 | **40x** |
138
+ | 文本嵌入 | 35秒 | 3秒 | **12x** |
139
+ | GraphRAG索引 | 15分钟 | 2.5分钟 | **6x** |
140
+
141
+ ---
142
+
143
+ ## 🔧 完整GraphRAG部署流程
144
+
145
+ ### 步骤1: GPU性能测试
146
+ ```python
147
+ # 运行quick test验证GPU
148
+ !python colab_quick_test.py
149
+ ```
150
+
151
+ ### 步骤2: 上传项目文件
152
+ ```python
153
+ # 方式A: 从Google Drive
154
+ from google.colab import drive
155
+ drive.mount('/content/drive')
156
+ !cp -r /content/drive/MyDrive/adaptive_RAG /content/
157
+ %cd /content/adaptive_RAG
158
+
159
+ # 方式B: 从GitHub
160
+ !git clone YOUR_REPO_URL
161
+ %cd adaptive_RAG
162
+ ```
163
+
164
+ ### 步骤3: 安装依赖
165
+ ```python
166
+ !pip install -q -r requirements.txt
167
+ !pip install -q -r requirements_graphrag.txt
168
+ ```
169
+
170
+ ### 步骤4: 配置API密钥
171
+ ```python
172
+ import os
173
+ from getpass import getpass
174
+ os.environ['TAVILY_API_KEY'] = getpass('TAVILY_API_KEY: ')
175
+ ```
176
+
177
+ ### 步骤5: 运行GraphRAG
178
+ ```python
179
+ !python main_graphrag.py
180
+ ```
181
+
182
+ ### 步骤6: 下载结果
183
+ ```python
184
+ from google.colab import files
185
+ files.download('data/knowledge_graph.json')
186
+ ```
187
+
188
+ ---
189
+
190
+ ## 💡 优化技巧
191
+
192
+ ### 1. 批处理大小
193
+ ```python
194
+ # config.py
195
+ GRAPHRAG_BATCH_SIZE = 20 # GPU环境可增大
196
+ ```
197
+
198
+ ### 2. 嵌入模型选择
199
+ ```python
200
+ # GPU环境使用更大模型
201
+ EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
202
+ ```
203
+
204
+ ### 3. 混合精度训练
205
+ ```python
206
+ import torch
207
+ torch.set_float32_matmul_precision('medium')
208
+ ```
209
+
210
+ ### 4. 数据持久化
211
+ ```python
212
+ # 定期保存到Drive
213
+ import shutil
214
+ shutil.copy(
215
+ 'data/knowledge_graph.json',
216
+ '/content/drive/MyDrive/backup.json'
217
+ )
218
+ ```
219
+
220
+ ---
221
+
222
+ ## ⚠️ 注意事项
223
+
224
+ ### Colab免费版限制
225
+ - ⏰ 连续使用: 最多12小时
226
+ - 🔄 GPU配额: 每周有限
227
+ - ⏸️ 闲置超时: 90分钟
228
+
229
+ ### 建议
230
+ - 💾 定期保存进度
231
+ - ⬇️ 及时下载结果
232
+ - 🔄 使用后台任务保持活跃
233
+
234
+ ---
235
+
236
+ ## 📚 文件使用优先级
237
+
238
+ ### 新手用户
239
+ 1. 📖 先阅读 `COLAB_GPU_GUIDE.md`
240
+ 2. 🚀 运行 `colab_gpu_demo.ipynb`
241
+ 3. ✅ 验证性能后部署完整项目
242
+
243
+ ### 高级用户
244
+ 1. ⚡ 直接运行 `colab_quick_test.py`
245
+ 2. 📊 查看性能报告
246
+ 3. 🔧 根据需求调整配置
247
+
248
+ ### 开发者
249
+ 1. 🔍 研究 `colab_gpu_test.py` 源码
250
+ 2. 🛠️ 根据需求定制功能
251
+ 3. 🔄 集成到自动化流程
252
+
253
+ ---
254
+
255
+ ## 🎯 关键性能指标
256
+
257
+ ### 必须达到的基准
258
+ - ✅ GPU检测: CUDA可用
259
+ - ✅ 矩阵加速: >10x
260
+ - ✅ 嵌入加速: >5x
261
+ - ✅ 显存使用: <80%
262
+
263
+ ### 如果低于基准
264
+ 1. 检查GPU类型 (应该是T4或A100)
265
+ 2. 重启运行时
266
+ 3. 检查依赖版本
267
+
268
+ ---
269
+
270
+ ## 📞 获取帮助
271
+
272
+ ### 常见问题
273
+ - 查看 `COLAB_GPU_GUIDE.md` 的FAQ部分
274
+
275
+ ### 性能问题
276
+ - 运行 `colab_quick_test.py` 获取诊断报告
277
+
278
+ ### 技术支持
279
+ - 提供测试报告输出
280
+ - 说明具体错误信息
281
+
282
+ ---
283
+
284
+ ## ✅ 总结
285
+
286
+ | 文件 | 何时使用 |
287
+ |------|---------|
288
+ | `colab_gpu_demo.ipynb` | 首次使用、学习、演示 |
289
+ | `colab_quick_test.py` | 快速验证、CI/CD、批量测试 |
290
+ | `colab_gpu_test.py` | 深度定制、集成开发 |
291
+ | `COLAB_GPU_GUIDE.md` | 参考文档、问题排查 |
292
+
293
+ **推荐流程**:
294
+ 1. 阅读 `COLAB_GPU_GUIDE.md` (5分钟)
295
+ 2. 运行 `colab_quick_test.py` (5分钟)
296
+ 3. 如果性能符合预期,部署完整GraphRAG项目
297
+
298
+ **预期结果**:
299
+ - GPU可用 ✅
300
+ - 3-6倍整体加速 ✅
301
+ - 节省10+分钟时间 ✅
302
+
303
+ ---
304
+
305
+ 🚀 **立即开始**: 上传任一文件到 [Google Colab](https://colab.research.google.com/) 并启用GPU!
COLAB_GPU_GUIDE.md ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 Google Colab GPU 测试指南
2
+
3
+ ## 📋 概述
4
+
5
+ 我为您创建了两个文件用于在Google Colab上测试GPU性能:
6
+
7
+ 1. **`colab_gpu_demo.ipynb`** - Jupyter Notebook版本(推荐)
8
+ 2. **`colab_gpu_test.py`** - Python脚本版本
9
+
10
+ ## 🎯 使用方法
11
+
12
+ ### 方法1: 使用Notebook(推荐)
13
+
14
+ #### 步骤1: 上传到Colab
15
+
16
+ 1. 打开 [Google Colab](https://colab.research.google.com/)
17
+ 2. 点击 `文件` → `上传笔记本`
18
+ 3. 选择 `colab_gpu_demo.ipynb`
19
+
20
+ #### 步骤2: 启用GPU
21
+
22
+ 1. 点击顶部菜单 `运行时` → `更改运行时类型`
23
+ 2. 硬件加速器选择 `GPU`
24
+ 3. GPU类型选择 `T4`(免费版)或 `A100`(Colab Pro)
25
+ 4. 点击 `保存`
26
+
27
+ #### 步骤3: 运行测试
28
+
29
+ 1. 点击 `运行时` → `全部运行`
30
+ 2. 或者逐个单元格运行(Shift + Enter)
31
+
32
+ ### 方法2: 使用Python脚本
33
+
34
+ #### 步骤1: 上传文件
35
+
36
+ 1. 在Colab中创建新笔记本
37
+ 2. 点击左侧文件夹图标
38
+ 3. 上传 `colab_gpu_test.py`
39
+
40
+ #### 步骤2: 运行脚本
41
+
42
+ ```python
43
+ # 在Colab单元格中运行
44
+ !python colab_gpu_test.py
45
+ ```
46
+
47
+ ## 📊 测试内容
48
+
49
+ ### 1. GPU环境检测 ✅
50
+ - CUDA可用性检查
51
+ - GPU型号和显存信息
52
+ - nvidia-smi输出
53
+
54
+ ### 2. 矩阵运算性能测试 ⚡
55
+ - CPU vs GPU 5000x5000矩阵乘法
56
+ - 预期加速比: **10-50x**
57
+
58
+ ### 3. 文本嵌入性能测试 📝
59
+ - 使用sentence-transformers
60
+ - 1000个文本的嵌入生成
61
+ - CPU vs GPU对比
62
+ - 预期加速比: **5-10x**
63
+
64
+ ### 4. GraphRAG组件测试 🔍
65
+ - 简化版知识图谱构建
66
+ - 实体和关系管理
67
+ - GPU加速的向量检索
68
+
69
+ ### 5. 显存监控 💾
70
+ - 实时显存使用情况
71
+ - 内存分配统计
72
+
73
+ ## 📈 预期结果
74
+
75
+ ### Google Colab 免费版 (T4 GPU)
76
+
77
+ | 测试项目 | CPU时间 | GPU时间 | 加速比 |
78
+ |---------|---------|---------|--------|
79
+ | 矩阵运算 (5000x5000) | ~8-10秒 | ~0.3-0.5秒 | 20-30x |
80
+ | 文本嵌入 (1000文本) | ~30-40秒 | ~5-8秒 | 5-7x |
81
+ | GraphRAG索引 (100文档) | ~15分钟 | ~3-5分钟 | 3-5x |
82
+
83
+ ### Google Colab Pro (A100 GPU)
84
+
85
+ | 测试项目 | CPU时间 | GPU时间 | 加速比 |
86
+ |---------|---------|---------|--------|
87
+ | 矩阵运算 | ~8秒 | ~0.2秒 | 40x |
88
+ | 文本嵌入 | ~35秒 | ~3秒 | 10-12x |
89
+ | GraphRAG索引 | ~15分钟 | ~2-3分钟 | 5-7x |
90
+
91
+ ## 🔧 运行完整GraphRAG项目
92
+
93
+ 如果GPU测试成功,可以在Colab上运行完整的GraphRAG项目:
94
+
95
+ ### 步骤1: 上传项目文件
96
+
97
+ 在Colab中创建新的单元格:
98
+
99
+ ```python
100
+ # 方式1: 从Google Drive加载
101
+ from google.colab import drive
102
+ drive.mount('/content/drive')
103
+
104
+ # 复制项目文件
105
+ !cp -r /content/drive/MyDrive/adaptive_RAG /content/
106
+ %cd /content/adaptive_RAG
107
+ ```
108
+
109
+ 或者:
110
+
111
+ ```python
112
+ # 方式2: 从GitHub克隆
113
+ !git clone YOUR_GITHUB_REPO_URL
114
+ %cd adaptive_RAG
115
+ ```
116
+
117
+ ### 步骤2: 安装依赖
118
+
119
+ ```python
120
+ # 安装基础依赖
121
+ !pip install -q -r requirements.txt
122
+
123
+ # 安装GraphRAG依赖
124
+ !pip install -q -r requirements_graphrag.txt
125
+ ```
126
+
127
+ ### 步骤3: 配置API密钥
128
+
129
+ ```python
130
+ import os
131
+ from getpass import getpass
132
+
133
+ # 安全输入API密钥
134
+ os.environ['TAVILY_API_KEY'] = getpass('输入 TAVILY_API_KEY: ')
135
+
136
+ # 验证
137
+ print("✅ API密钥已设置")
138
+ ```
139
+
140
+ ### 步骤4: 运行GraphRAG
141
+
142
+ ```python
143
+ # 运行主程序
144
+ !python main_graphrag.py
145
+ ```
146
+
147
+ ### 步骤5: 下载结果
148
+
149
+ ```python
150
+ # 下载构建好的知识图谱
151
+ from google.colab import files
152
+
153
+ # 下载图谱文件
154
+ files.download('data/knowledge_graph.json')
155
+
156
+ print("✅ 图谱已下载到本地")
157
+ ```
158
+
159
+ ## 💡 优化建议
160
+
161
+ ### 1. 批处理大小优化
162
+
163
+ 在 `config.py` 中调整:
164
+
165
+ ```python
166
+ # GPU优化配置
167
+ GRAPHRAG_BATCH_SIZE = 20 # GPU可以处理更大批次
168
+ ```
169
+
170
+ ### 2. 使用GPU优化的模型
171
+
172
+ ```python
173
+ # 使用更大的嵌入模型(GPU环境)
174
+ EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
175
+ ```
176
+
177
+ ### 3. 启用混合精度
178
+
179
+ ```python
180
+ # 在entity_extractor.py中
181
+ import torch
182
+ torch.set_float32_matmul_precision('medium') # 提升性能
183
+ ```
184
+
185
+ ## ⚠️ 注意事项
186
+
187
+ ### Colab资源限制
188
+
189
+ 1. **免费版限制**:
190
+ - 连续使用时间: 最多12小时
191
+ - GPU使用配额: 每周有限
192
+ - 闲置超时: 90分钟自动断开
193
+
194
+ 2. **建议**:
195
+ - 定期保存进度到Google Drive
196
+ - 使用`files.download()`下载重要结果
197
+ - 避免长时间空闲
198
+
199
+ ### 数据持久化
200
+
201
+ ```python
202
+ # 定期保存到Google Drive
203
+ from google.colab import drive
204
+ drive.mount('/content/drive')
205
+
206
+ # 保存图谱
207
+ import shutil
208
+ shutil.copy(
209
+ 'data/knowledge_graph.json',
210
+ '/content/drive/MyDrive/graphrag_backup.json'
211
+ )
212
+ ```
213
+
214
+ ## 🐛 常见问题
215
+
216
+ ### Q1: GPU连接失败
217
+
218
+ **A**: 检查运行时类型
219
+ ```python
220
+ import torch
221
+ print(f"CUDA可用: {torch.cuda.is_available()}")
222
+ # 如果False,重新设置运行时类型
223
+ ```
224
+
225
+ ### Q2: 内存不足
226
+
227
+ **A**: 减小批处理大小
228
+ ```python
229
+ GRAPHRAG_BATCH_SIZE = 5 # 降低批次
230
+ ```
231
+
232
+ ### Q3: 会话超时
233
+
234
+ **A**: 使用Colab Pro或定期运行代码保持活跃
235
+ ```python
236
+ # 在后台定期执行
237
+ import time
238
+ while True:
239
+ print("Keep alive...")
240
+ time.sleep(300) # 每5分钟执行一次
241
+ ```
242
+
243
+ ## 📚 参考资源
244
+
245
+ - [Google Colab官方文档](https://colab.research.google.com/notebooks/intro.ipynb)
246
+ - [GPU加速指南](https://colab.research.google.com/notebooks/gpu.ipynb)
247
+ - [Colab Pro定价](https://colab.research.google.com/signup)
248
+
249
+ ## 🎓 下一步学习
250
+
251
+ 1. **理解GPU加速原理**: 查看测试代码中的性能对比
252
+ 2. **优化GraphRAG参数**: 根据GPU性能调整配置
253
+ 3. **扩展到生产环境**: 考虑使用AWS/GCP的GPU实例
254
+
255
+ ---
256
+
257
+ ## ✅ 总结
258
+
259
+ | 优势 | 说明 |
260
+ |------|------|
261
+ | 🆓 免费GPU | T4 GPU免费使用 |
262
+ | ⚡ 高性能 | 3-10倍加速 |
263
+ | 🔄 零配置 | 无需本地安装 |
264
+ | 💾 自动保存 | 集成Google Drive |
265
+ | 🌐 随时访问 | 仅需浏览器 |
266
+
267
+ **推荐**: 在本地CPU环境速度慢时,使用Colab GPU可以大幅提升GraphRAG索引构建速度!
268
+
269
+ ---
270
+
271
+ **立即开始**: 上传 `colab_gpu_demo.ipynb` 到 [Google Colab](https://colab.research.google.com/) 并启用GPU! 🚀
COLAB_OLLAMA_GUIDE.md ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GraphRAG Colab 完整运行指南
2
+
3
+ ## 🎯 在Colab中运行Ollama的3种方法
4
+
5
+ ### 方法1: 后台运行Ollama(推荐)⭐⭐⭐⭐⭐
6
+
7
+ 在Colab中,您可以在单个单元格中后台启动Ollama,然后在另一个单元格运行GraphRAG。
8
+
9
+ #### 步骤1: 安装Ollama
10
+
11
+ ```bash
12
+ # 单元格1: 安装Ollama
13
+ !curl -fsSL https://ollama.com/install.sh | sh
14
+ ```
15
+
16
+ #### 步骤2: 后台启动Ollama服务
17
+
18
+ ```python
19
+ # 单元格2: 后台启动Ollama
20
+ import subprocess
21
+ import time
22
+ import os
23
+
24
+ # 启动Ollama服务(后台)
25
+ ollama_process = subprocess.Popen(
26
+ ["ollama", "serve"],
27
+ stdout=subprocess.PIPE,
28
+ stderr=subprocess.PIPE,
29
+ preexec_fn=os.setpgrp
30
+ )
31
+
32
+ print("⏳ 等待Ollama服务启动...")
33
+ time.sleep(5)
34
+
35
+ # 验证服务是否启动
36
+ !curl -s http://localhost:11434/api/tags | head -5
37
+
38
+ print(f"✅ Ollama服务已启动 (PID: {ollama_process.pid})")
39
+ ```
40
+
41
+ #### 步骤3: 下载Mistral模型
42
+
43
+ ```bash
44
+ # 单元格3: 下载模型
45
+ !ollama pull mistral
46
+ ```
47
+
48
+ #### 步骤4: 安装Python依赖
49
+
50
+ ```bash
51
+ # 单元格4: 安装依赖
52
+ !pip install -q langchain langchain-community langchain-core langgraph
53
+ !pip install -q chromadb sentence-transformers tiktoken
54
+ !pip install -q tavily-python python-dotenv networkx python-louvain
55
+ ```
56
+
57
+ #### 步骤5: 配置API密钥
58
+
59
+ ```python
60
+ # 单元格5: 配置环境
61
+ import os
62
+ from getpass import getpass
63
+
64
+ os.environ['TAVILY_API_KEY'] = getpass('输入TAVILY_API_KEY: ')
65
+ print("✅ API密钥已设置")
66
+ ```
67
+
68
+ #### 步骤6: 运行GraphRAG
69
+
70
+ ```python
71
+ # 单元格6: 运行GraphRAG
72
+ !python main_graphrag.py
73
+ ```
74
+
75
+ #### 步骤7: 下载结果(可选)
76
+
77
+ ```python
78
+ # 单元格7: 下载生成的图谱
79
+ from google.colab import files
80
+ files.download('data/knowledge_graph.json')
81
+ ```
82
+
83
+ ---
84
+
85
+ ### 方法2: 使用tmux(高级)⭐⭐⭐⭐
86
+
87
+ ```bash
88
+ # 单元格1: 安装tmux
89
+ !apt-get install -y tmux
90
+
91
+ # 单元格2: 在tmux会话中启动Ollama
92
+ !tmux new-session -d -s ollama 'ollama serve'
93
+
94
+ # 单元格3: 检查会话
95
+ !tmux ls
96
+
97
+ # 单元格4: 下载模型
98
+ !ollama pull mistral
99
+
100
+ # 单元格5: 运行GraphRAG
101
+ !python main_graphrag.py
102
+
103
+ # 单元格6: 停止tmux会话(清理)
104
+ !tmux kill-session -t ollama
105
+ ```
106
+
107
+ ---
108
+
109
+ ### 方法3: 使用nohup(简单)⭐⭐⭐
110
+
111
+ ```bash
112
+ # 单元格1: 后台启动Ollama
113
+ !nohup ollama serve > /tmp/ollama.log 2>&1 &
114
+
115
+ # 单元格2: 等待启动
116
+ import time
117
+ time.sleep(5)
118
+
119
+ # 单元格3: 检查日志
120
+ !tail -20 /tmp/ollama.log
121
+
122
+ # 单元格4: 下载模型
123
+ !ollama pull mistral
124
+
125
+ # 单元格5: 运行GraphRAG
126
+ !python main_graphrag.py
127
+
128
+ # 单元格6: 停止Ollama(清理)
129
+ !pkill -f 'ollama serve'
130
+ ```
131
+
132
+ ---
133
+
134
+ ## 🚀 一键运行脚本(最简单)⭐⭐⭐⭐⭐
135
+
136
+ 我已经为您创建了一个自动化脚本 `colab_setup_and_run.py`,它会:
137
+ 1. ✅ 自动安装Ollama
138
+ 2. ✅ 后台启动服务
139
+ 3. ✅ 下载Mistral模型
140
+ 4. ✅ 安装Python依赖
141
+ 5. ✅ 配置环境变量
142
+ 6. ✅ 运行GraphRAG
143
+
144
+ ### 使用方法:
145
+
146
+ ```bash
147
+ # 方法A: 直接运行脚本
148
+ !python colab_setup_and_run.py
149
+
150
+ # 方法B: 或者在Python中
151
+ import subprocess
152
+ subprocess.run(["python", "colab_setup_and_run.py"])
153
+ ```
154
+
155
+ ---
156
+
157
+ ## 📊 完整的Colab Notebook示例
158
+
159
+ 创建一个新的Colab笔记本,按顺序运行以下单元格:
160
+
161
+ ### 单元格1: 环境准备
162
+
163
+ ```python
164
+ # 检测GPU
165
+ import torch
166
+ print(f"GPU可用: {torch.cuda.is_available()}")
167
+ if torch.cuda.is_available():
168
+ print(f"GPU型号: {torch.cuda.get_device_name(0)}")
169
+ ```
170
+
171
+ ### 单元格2: 安装Ollama
172
+
173
+ ```bash
174
+ %%bash
175
+ curl -fsSL https://ollama.com/install.sh | sh
176
+ echo "✅ Ollama安装完成"
177
+ ```
178
+
179
+ ### 单元格3: 后台启动Ollama
180
+
181
+ ```python
182
+ import subprocess
183
+ import time
184
+ import os
185
+
186
+ print("🔄 启动Ollama服务...")
187
+
188
+ # 后台启动
189
+ process = subprocess.Popen(
190
+ ["ollama", "serve"],
191
+ stdout=subprocess.PIPE,
192
+ stderr=subprocess.PIPE,
193
+ preexec_fn=os.setpgrp
194
+ )
195
+
196
+ # 等待启动
197
+ time.sleep(5)
198
+
199
+ # 验证
200
+ import requests
201
+ try:
202
+ response = requests.get("http://localhost:11434/api/tags", timeout=3)
203
+ if response.status_code == 200:
204
+ print(f"✅ Ollama服务运行正常 (PID: {process.pid})")
205
+ else:
206
+ print("⚠️ 服务响应异常")
207
+ except:
208
+ print("⚠️ 无法连接服务,但进程已启动")
209
+
210
+ # 保存进程ID(重要!)
211
+ ollama_pid = process.pid
212
+ print(f"📝 保存的PID: {ollama_pid}")
213
+ ```
214
+
215
+ ### 单元格4: 下载模型
216
+
217
+ ```bash
218
+ %%bash
219
+ echo "📥 下载Mistral模型..."
220
+ ollama pull mistral
221
+ echo "✅ 模型下载完成"
222
+ ollama list
223
+ ```
224
+
225
+ ### 单元格5: 上传项目文件
226
+
227
+ ```python
228
+ # 方式A: 从Google Drive
229
+ from google.colab import drive
230
+ drive.mount('/content/drive')
231
+
232
+ # 复制项目文件
233
+ !cp -r /content/drive/MyDrive/adaptive_RAG /content/
234
+ %cd /content/adaptive_RAG
235
+
236
+ # 方式B: 手动上传
237
+ # from google.colab import files
238
+ # uploaded = files.upload()
239
+ ```
240
+
241
+ ### 单元格6: 安装依赖
242
+
243
+ ```bash
244
+ %%bash
245
+ pip install -q -r requirements.txt
246
+ pip install -q -r requirements_graphrag.txt
247
+ echo "✅ 依赖安装完成"
248
+ ```
249
+
250
+ ### 单元格7: 配置环境
251
+
252
+ ```python
253
+ import os
254
+ from getpass import getpass
255
+
256
+ # 设置API密钥
257
+ if not os.path.exists('.env'):
258
+ api_key = getpass('输入TAVILY_API_KEY: ')
259
+ with open('.env', 'w') as f:
260
+ f.write(f'TAVILY_API_KEY={api_key}\n')
261
+ print("✅ .env文件已创建")
262
+ else:
263
+ print("✅ 使用现有.env文件")
264
+ ```
265
+
266
+ ### 单元格8: 运行GraphRAG
267
+
268
+ ```python
269
+ # 方式A: 直接运行
270
+ !python main_graphrag.py
271
+
272
+ # 方式B: 在Python中运行(可以捕获输出)
273
+ import subprocess
274
+
275
+ result = subprocess.run(
276
+ ["python", "main_graphrag.py"],
277
+ capture_output=True,
278
+ text=True
279
+ )
280
+
281
+ print(result.stdout)
282
+ if result.returncode != 0:
283
+ print("错误信息:")
284
+ print(result.stderr)
285
+ ```
286
+
287
+ ### 单元格9: 下载结果
288
+
289
+ ```python
290
+ # 下载生成的知识图谱
291
+ from google.colab import files
292
+
293
+ if os.path.exists('data/knowledge_graph.json'):
294
+ files.download('data/knowledge_graph.json')
295
+ print("✅ 文件已下载")
296
+ else:
297
+ print("❌ 未找到图谱文件")
298
+
299
+ # 保存到Google Drive
300
+ import shutil
301
+ shutil.copy(
302
+ 'data/knowledge_graph.json',
303
+ '/content/drive/MyDrive/graphrag_backup.json'
304
+ )
305
+ print("✅ 已备份到Google Drive")
306
+ ```
307
+
308
+ ### 单元格10: 清理(可选)
309
+
310
+ ```python
311
+ # 停止Ollama服务
312
+ import os
313
+ import signal
314
+
315
+ try:
316
+ os.kill(ollama_pid, signal.SIGTERM)
317
+ print(f"✅ Ollama服务已停止 (PID: {ollama_pid})")
318
+ except:
319
+ print("⚠️ 停止服务失败,手动停止:")
320
+ !pkill -f 'ollama serve'
321
+ ```
322
+
323
+ ---
324
+
325
+ ## ⚠️ 常见问题
326
+
327
+ ### Q1: Ollama服务启动后立即退出
328
+
329
+ **A**: 使用 `subprocess.Popen` 而不是 `subprocess.run`:
330
+
331
+ ```python
332
+ # ❌ 错误方式
333
+ !ollama serve & # 会立即退出
334
+
335
+ # ✅ 正确方式
336
+ import subprocess
337
+ process = subprocess.Popen(["ollama", "serve"])
338
+ ```
339
+
340
+ ### Q2: 连接被拒绝 (Connection refused)
341
+
342
+ **A**: 等待服务完全启动:
343
+
344
+ ```python
345
+ import time
346
+ time.sleep(10) # 增加等待时间
347
+ ```
348
+
349
+ ### Q3: 进程管理困难
350
+
351
+ **A**: 使用PID文件:
352
+
353
+ ```python
354
+ # 保存PID
355
+ with open('/tmp/ollama.pid', 'w') as f:
356
+ f.write(str(process.pid))
357
+
358
+ # 后续停止
359
+ with open('/tmp/ollama.pid', 'r') as f:
360
+ pid = int(f.read())
361
+ os.kill(pid, signal.SIGTERM)
362
+ ```
363
+
364
+ ### Q4: 会话超时导致服务停止
365
+
366
+ **A**: 定期执行代码保持活跃:
367
+
368
+ ```python
369
+ import time
370
+ while True:
371
+ print("Keep alive...")
372
+ time.sleep(300) # 每5分钟
373
+ ```
374
+
375
+ ---
376
+
377
+ ## 📚 推荐的完整流程
378
+
379
+ 1. ✅ **运行自动化脚本** - `!python colab_setup_and_run.py`
380
+ 2. ✅ **或按照Notebook示例** - 逐步执行每个单元格
381
+ 3. ✅ **定期保存结果** - 到Google Drive
382
+
383
+ ---
384
+
385
+ ## 💡 最佳实践
386
+
387
+ 1. **始终保存Ollama的PID**: 方便后续管理
388
+ 2. **使用try-finally**: 确保清理后台进程
389
+ 3. **定期备份**: 保存中间结果到Drive
390
+ 4. **监控显存**: 避免OOM错误
391
+
392
+ ```python
393
+ # 最佳实践示例
394
+ import subprocess
395
+ import atexit
396
+ import signal
397
+
398
+ # 启动Ollama
399
+ ollama_process = subprocess.Popen(["ollama", "serve"])
400
+
401
+ # 注册清理函数
402
+ def cleanup():
403
+ try:
404
+ ollama_process.terminate()
405
+ print("✅ Ollama已停止")
406
+ except:
407
+ pass
408
+
409
+ atexit.register(cleanup)
410
+
411
+ # 运行您的代码
412
+ try:
413
+ # ... 您的GraphRAG代码 ...
414
+ pass
415
+ finally:
416
+ cleanup()
417
+ ```
418
+
419
+ ---
420
+
421
+ **推荐**: 直接使用 `colab_setup_and_run.py` 脚本,它已经处理了所有这些细节!🚀
DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Linux GPU部署指南 (RTX 4090)
2
+
3
+ ## 🚀 自适应RAG系统在Linux RTX 4090环境部署
4
+
5
+ 本指南将详细介绍如何在配备NVIDIA RTX 4090 GPU的Linux服务器上部署自适应RAG系统。
6
+
7
+ ## 📋 环境要求
8
+
9
+ ### 硬件要求
10
+ - NVIDIA RTX 4090 GPU
11
+ - 至少16GB内存(推荐32GB)
12
+ - 50GB+可用磁盘空间
13
+ - Ubuntu 20.04+ / CentOS 8+ / RHEL 8+
14
+
15
+ ### 软件要求
16
+ - Linux操作系统(推荐Ubuntu 22.04 LTS)
17
+ - NVIDIA驱动程序(推荐535+)
18
+ - CUDA 12.0+
19
+ - Docker(可选但推荐)
20
+ - Python 3.8-3.11
21
+
22
+ ## 🔧 步骤1:系统准备
23
+
24
+ ### 1.1 更新系统
25
+ ```bash
26
+ sudo apt update && sudo apt upgrade -y
27
+ sudo apt install -y curl wget git build-essential python3-pip python3-venv
28
+ ```
29
+
30
+ ### 1.2 安装NVIDIA驱动和CUDA
31
+ ```bash
32
+ # 检查GPU
33
+ lspci | grep -i nvidia
34
+
35
+ # 添加NVIDIA软件源
36
+ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
37
+ sudo dpkg -i cuda-keyring_1.0-1_all.deb
38
+ sudo apt-get update
39
+
40
+ # 安装NVIDIA驱动和CUDA
41
+ sudo apt-get install -y nvidia-driver-535 cuda-12-2
42
+
43
+ # 重启系统
44
+ sudo reboot
45
+ ```
46
+
47
+ ### 1.3 验证GPU安装
48
+ ```bash
49
+ # 重启后验证
50
+ nvidia-smi
51
+ nvcc --version
52
+ ```
53
+
54
+ ## 🐳 步骤2:Docker环境配置(推荐)
55
+
56
+ ### 2.1 安装Docker
57
+ ```bash
58
+ # 安装Docker
59
+ curl -fsSL https://get.docker.com -o get-docker.sh
60
+ sudo sh get-docker.sh
61
+ sudo usermod -aG docker $USER
62
+ ```
63
+
64
+ ### 2.2 安装NVIDIA Container Toolkit
65
+ ```bash
66
+ # 添加NVIDIA Docker源
67
+ distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
68
+ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add -
69
+ curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
70
+
71
+ # 安装nvidia-container-toolkit
72
+ sudo apt-get update
73
+ sudo apt-get install -y nvidia-container-toolkit
74
+ sudo systemctl restart docker
75
+ ```
76
+
77
+ ### 2.3 创建Dockerfile
78
+ ```dockerfile
79
+ # 创建 Dockerfile
80
+ cat > Dockerfile << 'EOF'
81
+ FROM nvidia/cuda:12.2-devel-ubuntu22.04
82
+
83
+ # 设置非交互模式
84
+ ENV DEBIAN_FRONTEND=noninteractive
85
+ ENV PYTHONUNBUFFERED=1
86
+
87
+ # 更新系统并安装Python
88
+ RUN apt-get update && apt-get install -y \
89
+ python3 \
90
+ python3-pip \
91
+ python3-venv \
92
+ git \
93
+ curl \
94
+ && rm -rf /var/lib/apt/lists/*
95
+
96
+ # 创建工作目录
97
+ WORKDIR /app
98
+
99
+ # 复制项目文件
100
+ COPY requirements.txt .
101
+ COPY *.py .
102
+ COPY *.md .
103
+
104
+ # 安装Python依赖
105
+ RUN pip3 install --no-cache-dir -r requirements.txt
106
+
107
+ # 暴露端口(如果需要Web界面)
108
+ EXPOSE 8000
109
+
110
+ # 启动命令
111
+ CMD ["python3", "main.py"]
112
+ EOF
113
+ ```
114
+
115
+ ## 🐍 步骤3:Python环境配置(直接部署)
116
+
117
+ ### 3.1 创建Python虚拟环境
118
+ ```bash
119
+ # 克隆项目
120
+ git clone <your-repo-url> adaptive_rag
121
+ cd adaptive_rag
122
+
123
+ # 创建虚拟环境
124
+ python3 -m venv rag_env
125
+ source rag_env/bin/activate
126
+
127
+ # 升级pip
128
+ pip install --upgrade pip
129
+ ```
130
+
131
+ ### 3.2 修改requirements.txt以支持GPU
132
+ 需要更新requirements.txt以优化GPU使用:
133
+
134
+ ```bash
135
+ # 创建GPU优化的requirements文件
136
+ cat > requirements_gpu.txt << 'EOF'
137
+ # 核心框架
138
+ langchain>=0.1.0
139
+ langgraph>=0.0.40
140
+ langchain-community>=0.0.20
141
+ langchain-core>=0.1.0
142
+
143
+ # LLM集成
144
+ langchain-ollama>=0.1.0
145
+
146
+ # 向量数据库和嵌入(GPU优化版本)
147
+ chromadb>=0.4.0
148
+ sentence-transformers>=2.2.0
149
+ torch>=2.0.0+cu118 --index-url https://download.pytorch.org/whl/cu118
150
+ torchvision>=0.15.0+cu118 --index-url https://download.pytorch.org/whl/cu118
151
+ transformers>=4.30.0
152
+ accelerate>=0.20.0
153
+
154
+ # 文档处理
155
+ tiktoken>=0.5.0
156
+ beautifulsoup4>=4.12.0
157
+ requests>=2.31.0
158
+
159
+ # 网络搜索
160
+ tavily-python>=0.3.0
161
+
162
+ # 数据处理
163
+ numpy>=1.24.0,<2.0
164
+ pandas>=2.0.0
165
+
166
+ # 工具库
167
+ python-dotenv>=1.0.0
168
+ pydantic>=2.0.0
169
+ typing-extensions>=4.0.0
170
+
171
+ # GPU加速库
172
+ cupy-cuda12x>=12.0.0
173
+ faiss-gpu>=1.7.4
174
+ EOF
175
+ ```
176
+
177
+ ### 3.3 安装依赖
178
+ ```bash
179
+ # 安装GPU优化依赖
180
+ pip install -r requirements_gpu.txt
181
+ ```
182
+
183
+ ## 🛠️ 步骤4:修改配置以优化GPU使用
184
+
185
+ ### 4.1 更新document_processor.py以使用GPU
186
+ 需要修改嵌入模型配置:
187
+
188
+ ```python
189
+ # 在document_processor.py中修改
190
+ self.embeddings = HuggingFaceEmbeddings(
191
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
192
+ model_kwargs={'device': 'cuda'}, # 使用GPU
193
+ encode_kwargs={'normalize_embeddings': True}
194
+ )
195
+ ```
196
+
197
+ ### 4.2 创建GPU优化配置
198
+ ```python
199
+ # 创建 gpu_config.py
200
+ cat > gpu_config.py << 'EOF'
201
+ import torch
202
+ import os
203
+
204
+ # GPU配置
205
+ if torch.cuda.is_available():
206
+ DEVICE = "cuda"
207
+ GPU_COUNT = torch.cuda.device_count()
208
+ GPU_NAME = torch.cuda.get_device_name(0)
209
+ print(f"发现 {GPU_COUNT} 个GPU: {GPU_NAME}")
210
+
211
+ # 设置CUDA优化
212
+ torch.backends.cudnn.benchmark = True
213
+ torch.backends.cudnn.deterministic = False
214
+
215
+ # 设置GPU内存管理
216
+ torch.cuda.empty_cache()
217
+ else:
218
+ DEVICE = "cpu"
219
+ print("未发现GPU,使用CPU模式")
220
+
221
+ # 优化设置
222
+ EMBEDDING_BATCH_SIZE = 32 if DEVICE == "cuda" else 8
223
+ MAX_WORKERS = 4 if DEVICE == "cuda" else 2
224
+ EOF
225
+ ```
226
+
227
+ ## 🤖 步骤5:安装和配置Ollama
228
+
229
+ ### 5.1 安装Ollama
230
+ ```bash
231
+ # 下载并安装Ollama
232
+ curl -fsSL https://ollama.ai/install.sh | sh
233
+
234
+ # 或者使用Docker
235
+ # docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
236
+ ```
237
+
238
+ ### 5.2 下载模型
239
+ ```bash
240
+ # 下载Mistral模型
241
+ ollama pull mistral
242
+
243
+ # 或者下载更大的模型(如果GPU内存足够)
244
+ ollama pull llama2:13b
245
+ ollama pull codellama:34b
246
+ ```
247
+
248
+ ### 5.3 启动Ollama服务
249
+ ```bash
250
+ # 启动Ollama服务
251
+ ollama serve &
252
+
253
+ # 验证服务
254
+ curl http://localhost:11434/api/version
255
+ ```
256
+
257
+ ## 🔐 步骤6:环境变量配置
258
+
259
+ ### 6.1 创建.env文件
260
+ ```bash
261
+ cat > .env << 'EOF'
262
+ # API密钥
263
+ TAVILY_API_KEY=your_tavily_api_key_here
264
+
265
+ # GPU配置
266
+ CUDA_VISIBLE_DEVICES=0
267
+ TORCH_CUDA_ARCH_LIST="8.9" # RTX 4090架构
268
+
269
+ # 模型配置
270
+ HF_HOME=/app/models
271
+ TRANSFORMERS_CACHE=/app/models
272
+
273
+ # 性能优化
274
+ OMP_NUM_THREADS=8
275
+ MKL_NUM_THREADS=8
276
+ EOF
277
+ ```
278
+
279
+ ## 🚀 步骤7:部署和启动
280
+
281
+ ### 7.1 使用Docker部署
282
+ ```bash
283
+ # 构建镜像
284
+ docker build -t adaptive-rag:gpu .
285
+
286
+ # 运行容器
287
+ docker run -d \
288
+ --gpus all \
289
+ --name adaptive-rag \
290
+ --env-file .env \
291
+ -p 8000:8000 \
292
+ -v $(pwd)/data:/app/data \
293
+ adaptive-rag:gpu
294
+ ```
295
+
296
+ ### 7.2 直接Python部署
297
+ ```bash
298
+ # 激活虚拟环境
299
+ source rag_env/bin/activate
300
+
301
+ # 启动系统
302
+ python main.py
303
+ ```
304
+
305
+ ## 📊 步骤8:性能监控
306
+
307
+ ### 8.1 创建监控脚本
308
+ ```bash
309
+ cat > monitor_gpu.py << 'EOF'
310
+ import psutil
311
+ import GPUtil
312
+ import time
313
+
314
+ def monitor_system():
315
+ while True:
316
+ # GPU监控
317
+ gpus = GPUtil.getGPUs()
318
+ for gpu in gpus:
319
+ print(f"GPU {gpu.id}: {gpu.load*100}% | 内存: {gpu.memoryUsed}MB/{gpu.memoryTotal}MB")
320
+
321
+ # CPU和内存监控
322
+ print(f"CPU: {psutil.cpu_percent()}% | 内存: {psutil.virtual_memory().percent}%")
323
+ print("-" * 50)
324
+ time.sleep(5)
325
+
326
+ if __name__ == "__main__":
327
+ monitor_system()
328
+ EOF
329
+
330
+ pip install gputil
331
+ python monitor_gpu.py
332
+ ```
333
+
334
+ ## 🔧 步骤9:性能优化配置
335
+
336
+ ### 9.1 创建优化启动脚本
337
+ ```bash
338
+ cat > start_optimized.sh << 'EOF'
339
+ #!/bin/bash
340
+
341
+ # 设置GPU优化环境变量
342
+ export CUDA_VISIBLE_DEVICES=0
343
+ export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
344
+ export TOKENIZERS_PARALLELISM=false
345
+
346
+ # 启动系统
347
+ source rag_env/bin/activate
348
+ python main.py
349
+ EOF
350
+
351
+ chmod +x start_optimized.sh
352
+ ```
353
+
354
+ ### 9.2 创建系统服务
355
+ ```bash
356
+ # 创建systemd服务
357
+ sudo tee /etc/systemd/system/adaptive-rag.service > /dev/null << 'EOF'
358
+ [Unit]
359
+ Description=Adaptive RAG System
360
+ After=network.target
361
+
362
+ [Service]
363
+ Type=simple
364
+ User=your_username
365
+ WorkingDirectory=/path/to/adaptive_rag
366
+ Environment=PATH=/path/to/adaptive_rag/rag_env/bin
367
+ ExecStart=/path/to/adaptive_rag/rag_env/bin/python main.py
368
+ Restart=always
369
+ RestartSec=10
370
+
371
+ [Install]
372
+ WantedBy=multi-user.target
373
+ EOF
374
+
375
+ # 启用服务
376
+ sudo systemctl daemon-reload
377
+ sudo systemctl enable adaptive-rag
378
+ sudo systemctl start adaptive-rag
379
+ ```
380
+
381
+ ## 🐛 步骤10:故障排除
382
+
383
+ ### 10.1 常见问题
384
+
385
+ 1. **CUDA内存不足**
386
+ ```bash
387
+ # 减少批处理大小
388
+ export EMBEDDING_BATCH_SIZE=16
389
+ # 或者启用梯度检查点
390
+ export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256
391
+ ```
392
+
393
+ 2. **Ollama连接问题**
394
+ ```bash
395
+ # 检查Ollama状态
396
+ sudo systemctl status ollama
397
+ # 重启Ollama
398
+ sudo systemctl restart ollama
399
+ ```
400
+
401
+ 3. **权限问题**
402
+ ```bash
403
+ # 添加用户到docker组
404
+ sudo usermod -aG docker $USER
405
+ # 重新登录
406
+ ```
407
+
408
+ ### 10.2 性能调优
409
+
410
+ ```bash
411
+ # GPU性能模式
412
+ sudo nvidia-smi -pm 1
413
+ sudo nvidia-smi -ac 9251,2100
414
+
415
+ # 系统优化
416
+ echo 'vm.swappiness=10' | sudo tee -a /etc/sysctl.conf
417
+ sudo sysctl -p
418
+ ```
419
+
420
+ ## 📈 预期性能
421
+
422
+ 在RTX 4090环境下的预期性能:
423
+ - **文档嵌入**: ~1000 documents/second
424
+ - **查询响应**: ~2-5 seconds per query
425
+ - **GPU利用率**: 60-80%
426
+ - **内存使用**: 8-12GB GPU memory
427
+
428
+ ## 🎯 验证部署
429
+
430
+ ```bash
431
+ # 测试GPU可用性
432
+ python -c "import torch; print(f'CUDA可用: {torch.cuda.is_available()}'); print(f'GPU数量: {torch.cuda.device_count()}')"
433
+
434
+ # 测试系统
435
+ curl -X POST http://localhost:8000/query \
436
+ -H "Content-Type: application/json" \
437
+ -d '{"question": "什么是LLM智能体?"}'
438
+ ```
439
+
440
+ 这个部署指南提供了完整的Linux GPU环境配置,确保您的自适应RAG系统能够充分利用RTX 4090的计算能力。
Dockerfile.gpu ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GPU优化Dockerfile - 针对RTX 4090
2
+ FROM nvidia/cuda:12.2-devel-ubuntu22.04
3
+
4
+ # 设置非交互模式和环境变量
5
+ ENV DEBIAN_FRONTEND=noninteractive
6
+ ENV PYTHONUNBUFFERED=1
7
+ ENV CUDA_VISIBLE_DEVICES=0
8
+ ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
9
+
10
+ # 更新系统并安装必要软件
11
+ RUN apt-get update && apt-get install -y \
12
+ python3 \
13
+ python3-pip \
14
+ python3-venv \
15
+ git \
16
+ curl \
17
+ wget \
18
+ build-essential \
19
+ && rm -rf /var/lib/apt/lists/*
20
+
21
+ # 创建应用目录
22
+ WORKDIR /app
23
+
24
+ # 创建必要目录
25
+ RUN mkdir -p /app/data /app/models /app/logs
26
+
27
+ # 复制依赖文件
28
+ COPY requirements_gpu.txt .
29
+
30
+ # 升级pip并安装Python依赖
31
+ RUN pip3 install --no-cache-dir --upgrade pip && \
32
+ pip3 install --no-cache-dir -r requirements_gpu.txt
33
+
34
+ # 复制应用文件
35
+ COPY *.py .
36
+ COPY *.md .
37
+ COPY .env.example .
38
+
39
+ # 设置Python路径
40
+ ENV PYTHONPATH=/app
41
+
42
+ # 创建启动脚本
43
+ RUN echo '#!/bin/bash\n\
44
+ export CUDA_VISIBLE_DEVICES=0\n\
45
+ export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512\n\
46
+ export TOKENIZERS_PARALLELISM=false\n\
47
+ python3 -c "import torch; print(f'"'"'CUDA可用: {torch.cuda.is_available()}'"'"'); print(f'"'"'GPU数量: {torch.cuda.device_count()}'"'"')"\n\
48
+ python3 main.py' > /app/start.sh && chmod +x /app/start.sh
49
+
50
+ # 暴露端口
51
+ EXPOSE 8000 8001
52
+
53
+ # 健康检查
54
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
55
+ CMD curl -f http://localhost:8000/health || exit 1
56
+
57
+ # 启动命令
58
+ CMD ["/app/start.sh"]
GRAPHRAG_GUIDE.md ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GraphRAG 集成指南
2
+
3
+ ## 📋 概述
4
+
5
+ 本项目已集成**Microsoft GraphRAG**架构,通过知识图谱增强传统向量检索,提供更精准的信息提取和推理能力。
6
+
7
+ ## 🏗️ GraphRAG 架构
8
+
9
+ ### 核心组件
10
+
11
+ ```
12
+ 文档集合
13
+
14
+ ┌─────────────────────────────────────┐
15
+ │ 实体和关系提取 (Entity Extraction) │
16
+ │ - 使用LLM识别实体 │
17
+ │ - 提取实体间关系 │
18
+ └─────────────────────────────────────┘
19
+
20
+ ┌─────────────────────────────────────┐
21
+ │ 知识图谱构建 (Graph Construction) │
22
+ │ - 实体去重 │
23
+ │ - 构建图结构 │
24
+ └─────────────────────────────────────┘
25
+
26
+ ┌─────────────────────────────────────┐
27
+ │ 社区检测 (Community Detection) │
28
+ │ - Louvain算法 │
29
+ │ - 层次化聚类 │
30
+ └─────────────────────────────────────┘
31
+
32
+ ┌─────────────────────────────────────┐
33
+ │ 社区摘要生成 (Community Summaries) │
34
+ │ - LLM生成摘要 │
35
+ │ - 多层次索引 │
36
+ └─────────────────────────────────────┘
37
+
38
+ 查询阶段
39
+
40
+ ┌──────────────┬──────────────┐
41
+ │ 本地查询 │ 全局查询 │
42
+ │ (Local Query)│(Global Query)│
43
+ │ │ │
44
+ │ 实体邻域检索 │ 社区摘要查询 │
45
+ └──────────────┴──────────────┘
46
+ ```
47
+
48
+ ## 📦 新增文件说明
49
+
50
+ ### 1. **entity_extractor.py** - 实体提取器
51
+ ```python
52
+ EntityExtractor
53
+ ├── extract_entities() # 从文本提取实体
54
+ ├── extract_relations() # 提取实体关系
55
+ └── extract_from_document() # 完整文档处理
56
+
57
+ EntityDeduplicator
58
+ └── deduplicate_entities() # 实体去重
59
+ ```
60
+
61
+ **功能**:
62
+ - 使用LLM识别6种实体类型 (PERSON, ORGANIZATION, CONCEPT, TECHNOLOGY, PAPER, EVENT)
63
+ - 提取8种关系类型 (AUTHOR_OF, USES, BASED_ON, etc.)
64
+ - 智能实体去重和合并
65
+
66
+ ### 2. **knowledge_graph.py** - 知识图谱核心
67
+ ```python
68
+ KnowledgeGraph
69
+ ├── add_entity() # 添加节点
70
+ ├── add_relation() # 添加边
71
+ ├── build_from_extractions() # 构建图谱
72
+ ├── detect_communities() # 社区检测
73
+ ├── get_community_members() # 获取社区成员
74
+ └── get_statistics() # 统计信息
75
+
76
+ CommunitySummarizer
77
+ ├── summarize_community() # 单社区摘要
78
+ └── summarize_all_communities() # 全部社区摘要
79
+ ```
80
+
81
+ **功能**:
82
+ - 基于NetworkX的图谱管理
83
+ - 支持3种社区检测算法 (Louvain, Greedy, Label Propagation)
84
+ - LLM驱动的社区摘要生成
85
+ - 图谱持久化存储
86
+
87
+ ### 3. **graph_indexer.py** - 索引构建器
88
+ ```python
89
+ GraphRAGIndexer
90
+ ├── index_documents() # 构建索引
91
+ ├── get_graph() # 获取图谱
92
+ └── load_index() # 加载索引
93
+ ```
94
+
95
+ **流程**:
96
+ 1. 批量实体提取
97
+ 2. 实体去重合并
98
+ 3. 构建知识图谱
99
+ 4. 社区检测
100
+ 5. 生成摘要
101
+
102
+ ### 4. **graph_retriever.py** - 图谱检索器
103
+ ```python
104
+ GraphRetriever
105
+ ├── recognize_entities() # 识别问题中的实体
106
+ ├── local_query() # 本地查询
107
+ ├── global_query() # 全局查询
108
+ ├── hybrid_query() # 混合查询
109
+ └── smart_query() # 智能查询
110
+ ```
111
+
112
+ **查询模式**:
113
+ - **本地查询**: 针对特定实体的详细问题
114
+ - **全局查询**: 需要整体理解的概括性问题
115
+ - **智能查询**: 自动选择最佳策略
116
+
117
+ ### 5. **main_graphrag.py** - GraphRAG集成示例
118
+ 完整的使用示例和交互式界面
119
+
120
+ ### 6. **requirements_graphrag.txt** - 额外依赖
121
+ GraphRAG所需的图处理库
122
+
123
+ ## 🚀 快速开始
124
+
125
+ ### 安装依赖
126
+
127
+ ```bash
128
+ # 安装基础依赖
129
+ pip install -r requirements.txt
130
+
131
+ # 安装GraphRAG依赖
132
+ pip install -r requirements_graphrag.txt
133
+ ```
134
+
135
+ ### 首次使用
136
+
137
+ ```python
138
+ # 方式1: 使用集成示例
139
+ python main_graphrag.py
140
+
141
+ # 方式2: 在代码中集成
142
+ from config import setup_environment
143
+ from document_processor import initialize_document_processor
144
+ from graph_indexer import initialize_graph_indexer
145
+ from graph_retriever import initialize_graph_retriever
146
+
147
+ # 初始化
148
+ setup_environment()
149
+ processor, vectorstore, retriever, doc_splits = initialize_document_processor()
150
+
151
+ # 构建GraphRAG索引
152
+ graph_indexer = initialize_graph_indexer()
153
+ knowledge_graph = graph_indexer.index_documents(
154
+ documents=doc_splits,
155
+ save_path="./data/knowledge_graph.json"
156
+ )
157
+
158
+ # 初始化检索器
159
+ graph_retriever = initialize_graph_retriever(knowledge_graph)
160
+
161
+ # 查询
162
+ answer = graph_retriever.smart_query("LLM Agent的核心组件是什么?")
163
+ print(answer)
164
+ ```
165
+
166
+ ## 🔧 配置说明
167
+
168
+ 在 `config.py` 中添加了以下配置:
169
+
170
+ ```python
171
+ # GraphRAG配置
172
+ ENABLE_GRAPHRAG = True # 是否启用GraphRAG
173
+ GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json" # 图谱存储路径
174
+ GRAPHRAG_COMMUNITY_ALGORITHM = "louvain" # 社区检测算法
175
+ GRAPHRAG_MAX_HOPS = 2 # 本地查询最大跳数
176
+ GRAPHRAG_TOP_K_COMMUNITIES = 5 # 全局查询使用的社区数
177
+ GRAPHRAG_BATCH_SIZE = 10 # 实体提取批大小
178
+ ```
179
+
180
+ ## 📊 使用场景对比
181
+
182
+ ### 传统向量检索 vs GraphRAG
183
+
184
+ | 场景 | 向量检索 | GraphRAG | 推荐 |
185
+ |-----|---------|----------|------|
186
+ | "AlphaCodium的作者是谁?" | ⚠️ 可能找到但不精确 | ✅ 直接查询实体关系 | GraphRAG本地查询 |
187
+ | "这些文档讨论什么主题?" | ⚠️ 需要读取多个片段 | ✅ 社区摘要直接回答 | GraphRAG全局查询 |
188
+ | "提示工程的应用场景" | ✅ 语义匹配效果好 | ✅ 可追踪关系链 | 混合查询 |
189
+ | "最新技术发展" | ✅ 适合模糊查询 | ❌ 需要明确实体 | 向量检索 |
190
+
191
+ ## 🎯 查询策略选择
192
+
193
+ ### 本地查询 (Local Query)
194
+ **适用**: 针对特定实体的详细问题
195
+
196
+ ```python
197
+ # 示例问题
198
+ "LLM Agent包含哪些组件?"
199
+ "Transformer模型的作者是谁?"
200
+ "AlphaCodium使用了什么技术?"
201
+
202
+ # 代码
203
+ answer = graph_retriever.local_query(question, max_hops=2)
204
+ ```
205
+
206
+ **工作原理**:
207
+ 1. 识别问题中的实体
208
+ 2. 扩展到邻居节点(支持多跳)
209
+ 3. 收集实体信息和关系
210
+ 4. 基于子图生成答案
211
+
212
+ ### 全局查询 (Global Query)
213
+ **适用**: 需要整体视角的概括性问题
214
+
215
+ ```python
216
+ # 示例问题
217
+ "这些文档的主要主题是什么?"
218
+ "涵盖了哪些研究领域?"
219
+ "关键的技术趋势有哪些?"
220
+
221
+ # 代码
222
+ answer = graph_retriever.global_query(question, top_k_communities=5)
223
+ ```
224
+
225
+ **工作原理**:
226
+ 1. 获取社区摘要
227
+ 2. 基于摘要理解全局结构
228
+ 3. 综合多个社区的信息
229
+ 4. 生成高层次答案
230
+
231
+ ### 智能查询 (Smart Query)
232
+ **适用**: 自动选择最佳策略
233
+
234
+ ```python
235
+ # 自动判断使用本地还是全局查询
236
+ answer = graph_retriever.smart_query(question)
237
+ ```
238
+
239
+ **决策逻辑**:
240
+ - 包含具体实体名称 → 本地查询
241
+ - 包含"主要"、"总体"、"概述"等关键词 → 全局查询
242
+ - 默认 → 本地查询
243
+
244
+ ### 混合查询 (Hybrid Query)
245
+ **适用**: 需要多种视角的复杂问题
246
+
247
+ ```python
248
+ result = graph_retriever.hybrid_query(question)
249
+ # 返回: {"local": "...", "global": "..."}
250
+ ```
251
+
252
+ ## 📈 性能优化
253
+
254
+ ### 索引构建优化
255
+
256
+ ```python
257
+ # 1. 批处理大小
258
+ graph_indexer.index_documents(
259
+ documents=doc_splits,
260
+ batch_size=20 # 增大批处理提高速度
261
+ )
262
+
263
+ # 2. 增量索引(开发中)
264
+ # 避免每次重建整个图谱
265
+
266
+ # 3. 缓存已有索引
267
+ if os.path.exists(GRAPHRAG_INDEX_PATH):
268
+ knowledge_graph = graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
269
+ ```
270
+
271
+ ### 查询优化
272
+
273
+ ```python
274
+ # 1. 调整跳数
275
+ answer = graph_retriever.local_query(question, max_hops=1) # 减少跳数提速
276
+
277
+ # 2. 限制社区数量
278
+ answer = graph_retriever.global_query(question, top_k_communities=3) # 减少社区数
279
+
280
+ # 3. 实体识别缓存(开发中)
281
+ ```
282
+
283
+ ## 🔍 调试和可视化
284
+
285
+ ### 查看图谱统计
286
+
287
+ ```python
288
+ stats = knowledge_graph.get_statistics()
289
+ print(f"节点数: {stats['num_nodes']}")
290
+ print(f"边数: {stats['num_edges']}")
291
+ print(f"社区数: {stats['num_communities']}")
292
+ ```
293
+
294
+ ### 导出图谱
295
+
296
+ ```python
297
+ # 保存为JSON
298
+ knowledge_graph.save_to_file("my_graph.json")
299
+
300
+ # 加载图谱
301
+ knowledge_graph.load_from_file("my_graph.json")
302
+ ```
303
+
304
+ ### 可视化(可选)
305
+
306
+ ```python
307
+ # 需要额外安装: pip install pyvis
308
+ from pyvis.network import Network
309
+
310
+ def visualize_graph(kg, output="graph.html"):
311
+ net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white")
312
+
313
+ for node, data in kg.graph.nodes(data=True):
314
+ net.add_node(node, label=node, title=data.get('description', ''))
315
+
316
+ for u, v, data in kg.graph.edges(data=True):
317
+ net.add_edge(u, v, title=data.get('relation_type', ''))
318
+
319
+ net.show(output)
320
+ print(f"图谱已保存到: {output}")
321
+ ```
322
+
323
+ ## ⚠️ 常见问题
324
+
325
+ ### Q1: 实体提取质量不高?
326
+ **A**:
327
+ - 调整LLM温度参数
328
+ - 优化实体提取提示词
329
+ - 使用更强大的LLM模型
330
+
331
+ ### Q2: 索引构建时间长?
332
+ **A**:
333
+ - 增大批处理大小
334
+ - 减少文档数量进行测试
335
+ - 使用缓存的索引文件
336
+
337
+ ### Q3: 查询结果不相关?
338
+ **A**:
339
+ - 检查实体识别是否准确
340
+ - 调整查询策略(本地/全局)
341
+ - 增加邻居跳数
342
+
343
+ ### Q4: 内存占用过大?
344
+ **A**:
345
+ - 使用更轻量的图数据库
346
+ - 分批处理大文档集
347
+ - 限制社区检测的迭代次数
348
+
349
+ ## 🔄 与现有系统集成
350
+
351
+ ### 修改现有 main.py
352
+
353
+ ```python
354
+ from config import ENABLE_GRAPHRAG
355
+ from graph_indexer import initialize_graph_indexer
356
+ from graph_retriever import initialize_graph_retriever
357
+
358
+ class AdaptiveRAGSystem:
359
+ def __init__(self):
360
+ # ... 现有初始化代码 ...
361
+
362
+ # 添加GraphRAG支持
363
+ if ENABLE_GRAPHRAG:
364
+ self._setup_graphrag()
365
+
366
+ def _setup_graphrag(self):
367
+ self.graph_indexer = initialize_graph_indexer()
368
+ # ... 索引构建 ...
369
+ self.graph_retriever = initialize_graph_retriever(self.knowledge_graph)
370
+
371
+ def query(self, question: str):
372
+ # 混合使用向量检索和图谱查询
373
+ vector_docs = self.retriever.get_relevant_documents(question)
374
+
375
+ if ENABLE_GRAPHRAG:
376
+ graph_answer = self.graph_retriever.smart_query(question)
377
+ # 融合两种结果
378
+ return self._merge_results(vector_docs, graph_answer)
379
+
380
+ return self._generate_from_docs(vector_docs)
381
+ ```
382
+
383
+ ## 📚 参考资料
384
+
385
+ - [Microsoft GraphRAG 论文](https://arxiv.org/abs/2404.16130)
386
+ - [NetworkX 文档](https://networkx.org/)
387
+ - [Louvain 社区检测算法](https://en.wikipedia.org/wiki/Louvain_method)
388
+
389
+ ## 🛣️ 未来增强
390
+
391
+ - [ ] 增量索引更新
392
+ - [ ] 多模态知识图谱
393
+ - [ ] 图谱可视化界面
394
+ - [ ] Neo4j集成(生产环境)
395
+ - [ ] 知识图谱推理引擎
396
+ - [ ] 实体链接优化
397
+ - [ ] 自动实体消歧
398
+
399
+ ---
400
+
401
+ **提示**: 首次使用建议先在小数据集上测试,验证效果后再应用到完整数据集。
GRAPHRAG_INTEGRATION_SUMMARY.md ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GraphRAG 集成完成总结
2
+
3
+ ## ✅ 已完成的工作
4
+
5
+ ### 🆕 新增文件 (7个)
6
+
7
+ | 文件 | 行数 | 主要功能 |
8
+ |------|------|---------|
9
+ | **entity_extractor.py** | 225 | 实体和关系提取、实体去重 |
10
+ | **knowledge_graph.py** | 348 | 图谱构建、社区检测、摘要生成 |
11
+ | **graph_indexer.py** | 146 | GraphRAG索引构建流程 |
12
+ | **graph_retriever.py** | 276 | 本地/全局/智能查询 |
13
+ | **main_graphrag.py** | 294 | 完整使用示例和交互界面 |
14
+ | **requirements_graphrag.txt** | 32 | GraphRAG额外依赖 |
15
+ | **GRAPHRAG_GUIDE.md** | 402 | 详细使用指南 |
16
+
17
+ ### 🔧 修改的文件 (3个)
18
+
19
+ | 文件 | 修改内容 |
20
+ |------|---------|
21
+ | **config.py** | 添加7个GraphRAG配置参数 |
22
+ | **document_processor.py** | 修改`setup_knowledge_base()`返回doc_splits |
23
+ | **requirements.txt** | 添加networkx和python-louvain依赖 |
24
+
25
+ ---
26
+
27
+ ## 📋 文件修改详情
28
+
29
+ ### 1. config.py - 新增配置
30
+
31
+ ```python
32
+ # GraphRAG配置
33
+ ENABLE_GRAPHRAG = True
34
+ GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json"
35
+ GRAPHRAG_COMMUNITY_ALGORITHM = "louvain"
36
+ GRAPHRAG_MAX_HOPS = 2
37
+ GRAPHRAG_TOP_K_COMMUNITIES = 5
38
+ GRAPHRAG_BATCH_SIZE = 10
39
+ ```
40
+
41
+ ### 2. document_processor.py - 函数修改
42
+
43
+ ```python
44
+ # 修改前
45
+ def setup_knowledge_base(self, urls=None):
46
+ ...
47
+ return vectorstore, retriever
48
+
49
+ # 修改后
50
+ def setup_knowledge_base(self, urls=None, enable_graphrag=False):
51
+ ...
52
+ return vectorstore, retriever, doc_splits # 新增返回doc_splits
53
+
54
+ # 同步修改
55
+ def initialize_document_processor():
56
+ ...
57
+ return processor, vectorstore, retriever, doc_splits # 新增doc_splits
58
+ ```
59
+
60
+ ### 3. requirements.txt - 新增依赖
61
+
62
+ ```txt
63
+ # GraphRAG相关(可选)
64
+ networkx>=3.1
65
+ python-louvain>=0.16
66
+ ```
67
+
68
+ ---
69
+
70
+ ## 🏗️ GraphRAG 架构概览
71
+
72
+ ```
73
+ ┌─────────────────────────────────────────────────────────────┐
74
+ │ 文档处理层 │
75
+ │ document_processor.py → doc_splits │
76
+ └────────────────────────┬────────────────────────────────────┘
77
+
78
+ ┌─────────────────────────────────────────────────────────────┐
79
+ │ 实体提取层 │
80
+ │ entity_extractor.py │
81
+ │ ├── EntityExtractor (实体和关系提取) │
82
+ │ └── EntityDeduplicator (实体去重) │
83
+ └────────────────────────┬────────────────────────────────────┘
84
+
85
+ ┌─────────────────────────────────────────────────────────────┐
86
+ │ 图谱构建层 │
87
+ │ knowledge_graph.py │
88
+ │ ├── KnowledgeGraph (图谱管理) │
89
+ │ │ ├── NetworkX图结构 │
90
+ │ │ ├── 社区检测 (Louvain/Greedy/LabelProp) │
91
+ │ │ └── 统计分析 │
92
+ │ └── CommunitySummarizer (社区摘要) │
93
+ └────────────────────────┬────────────────────────────────────┘
94
+
95
+ ┌─────────────────────────────────────────────────────────────┐
96
+ │ 索引构建层 │
97
+ │ graph_indexer.py │
98
+ │ └── GraphRAGIndexer │
99
+ │ ├── 5步索引流程 │
100
+ │ └── 图谱持久化 │
101
+ └────────────────────────┬────────────────────────────────────┘
102
+
103
+ ┌─────────────────────────────────────────────────────────────┐
104
+ │ 检索查询层 │
105
+ │ graph_retriever.py │
106
+ │ └── GraphRetriever │
107
+ │ ├── 本地查询 (Local Query) │
108
+ │ ├── 全局查询 (Global Query) │
109
+ │ ├── 混合查询 (Hybrid Query) │
110
+ │ └── 智能查询 (Smart Query) │
111
+ └────────────────────────┬────────────────────────────────────┘
112
+
113
+ ┌─────────────────────────────────────────────────────────────┐
114
+ │ 应用层 │
115
+ │ main_graphrag.py │
116
+ │ └── AdaptiveRAGWithGraph │
117
+ │ ├── 5种查询模式 │
118
+ │ ├── 统计信息展示 │
119
+ │ └── 交互式界面 │
120
+ └─────────────────────────────────────────────────────────────┘
121
+ ```
122
+
123
+ ---
124
+
125
+ ## 🚀 使用流程
126
+
127
+ ### 方式1: 直接运行示例
128
+
129
+ ```bash
130
+ # 1. 安装依赖
131
+ pip install -r requirements.txt
132
+ pip install -r requirements_graphrag.txt
133
+
134
+ # 2. 运行GraphRAG示例
135
+ python main_graphrag.py
136
+
137
+ # 首次运行会自动构建索引,后续运行会加载缓存
138
+ ```
139
+
140
+ ### 方式2: 集成到现有代码
141
+
142
+ ```python
143
+ # 在 main.py 中集成
144
+ from config import ENABLE_GRAPHRAG, GRAPHRAG_INDEX_PATH
145
+ from graph_indexer import initialize_graph_indexer
146
+ from graph_retriever import initialize_graph_retriever
147
+
148
+ class AdaptiveRAGSystem:
149
+ def __init__(self):
150
+ # ... 现有初始化 ...
151
+
152
+ if ENABLE_GRAPHRAG:
153
+ # 构建/加载图谱
154
+ self.graph_indexer = initialize_graph_indexer()
155
+
156
+ if os.path.exists(GRAPHRAG_INDEX_PATH):
157
+ self.kg = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
158
+ else:
159
+ self.kg = self.graph_indexer.index_documents(
160
+ self.doc_splits,
161
+ save_path=GRAPHRAG_INDEX_PATH
162
+ )
163
+
164
+ # 初始化检索器
165
+ self.graph_retriever = initialize_graph_retriever(self.kg)
166
+
167
+ def query(self, question: str):
168
+ if ENABLE_GRAPHRAG:
169
+ # 使用图谱智能查询
170
+ return self.graph_retriever.smart_query(question)
171
+ else:
172
+ # 原有逻辑
173
+ ...
174
+ ```
175
+
176
+ ---
177
+
178
+ ## 📊 功能对比
179
+
180
+ ### 原系统 vs GraphRAG增强
181
+
182
+ | 功能 | 原系统 | GraphRAG增强 | 提升 |
183
+ |------|--------|--------------|------|
184
+ | **检索方式** | 向量相似度 | 向量 + 图谱 | ✅ 多模态检索 |
185
+ | **关系理解** | ❌ 无 | ✅ 显式关系 | ✅ 关系推理能力 |
186
+ | **多跳推理** | ❌ 有限 | ✅ 支持N跳 | ✅ 复杂推理 |
187
+ | **全局理解** | ⚠️ 需读取多文档 | ✅ 社区摘要 | ✅ 高效概览 |
188
+ | **实体消歧** | ❌ 无 | ✅ 图谱上下文 | ✅ 准确识别 |
189
+ | **事实验证** | 基于文档匹配 | 基于关系验证 | ✅ 更严格 |
190
+
191
+ ---
192
+
193
+ ## 🎯 适用场景
194
+
195
+ ### GraphRAG特别适合:
196
+
197
+ ✅ **知识密集型领域**
198
+ - 学术论文、技术文档
199
+ - 需要理解实体关系
200
+ - 例: "AlphaCodium的作者研究了哪些其他技术?"
201
+
202
+ ✅ **需要推理的问题**
203
+ - 多跳关系查询
204
+ - 因果关系分析
205
+ - 例: "提示工程如何应用于对抗性攻击防御?"
206
+
207
+ ✅ **概览性问题**
208
+ - 主题归纳
209
+ - 研究趋势
210
+ - 例: "这个领域的主要研究方向有哪些?"
211
+
212
+ ### 仍使用向量检索:
213
+
214
+ ⚠️ **模糊语义查询**
215
+ - 没有明确实体
216
+ - 需要语义相似匹配
217
+
218
+ ⚠️ **最新资讯查询**
219
+ - 图谱未覆盖的新内容
220
+ - 需要网络搜索
221
+
222
+ ---
223
+
224
+ ## 🔧 配置参数说明
225
+
226
+ ```python
227
+ # config.py
228
+
229
+ ENABLE_GRAPHRAG = True
230
+ # 是否启用GraphRAG,False则回退到纯向量检索
231
+
232
+ GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json"
233
+ # 图谱持久化路径,避免每次重建
234
+
235
+ GRAPHRAG_COMMUNITY_ALGORITHM = "louvain"
236
+ # 社区检测算法:
237
+ # - "louvain": 最优质量(推荐)
238
+ # - "greedy": 更快速度
239
+ # - "label_propagation": 快速近似
240
+
241
+ GRAPHRAG_MAX_HOPS = 2
242
+ # 本地查询时扩展的邻居深度
243
+ # 1: 只看直接邻居
244
+ # 2: 二跳邻居(推荐)
245
+ # 3+: 可能包含过多噪声
246
+
247
+ GRAPHRAG_TOP_K_COMMUNITIES = 5
248
+ # 全局查询时使用的社区数量
249
+ # 更多社区 = 更全面但更慢
250
+
251
+ GRAPHRAG_BATCH_SIZE = 10
252
+ # 实体提取的批处理大小
253
+ # 更大批次 = 更快但更耗内存
254
+ ```
255
+
256
+ ---
257
+
258
+ ## 📈 性能特征
259
+
260
+ ### 索引构建时间
261
+
262
+ | 文档数量 | 实体数 | 关系数 | 社区数 | 构建时间* |
263
+ |---------|--------|--------|--------|----------|
264
+ | 10个文档块 | ~50 | ~30 | 3-5 | ~2分钟 |
265
+ | 50个文档块 | ~200 | ~150 | 8-12 | ~8分钟 |
266
+ | 100个文档块 | ~400 | ~300 | 15-20 | ~15分钟 |
267
+
268
+ *基于Mistral模型,实际时间取决于LLM速度
269
+
270
+ ### 查询速度
271
+
272
+ | 查询类型 | 平均耗时 | 说明 |
273
+ |---------|---------|------|
274
+ | 本地查询 | 2-5秒 | 需要LLM生成答案 |
275
+ | 全局查询 | 3-8秒 | 需要处理多个社区摘要 |
276
+ | 智能查询 | 2-8秒 | 取决于选择的策略 |
277
+ | 混合查询 | 5-12秒 | 执行两种查询 |
278
+
279
+ ### 存储需求
280
+
281
+ - **图谱索引**: 100个文档块 ≈ 1-5 MB (JSON格式)
282
+ - **内存占用**: 运行时 ≈ 200-500 MB (取决于图大小)
283
+
284
+ ---
285
+
286
+ ## 🐛 故障排查
287
+
288
+ ### 问题1: 实体提取失败
289
+ ```
290
+ ❌ 实体提取失败: timeout
291
+ ```
292
+
293
+ **解决方案**:
294
+ - 检查Ollama服务是否运行: `ollama serve`
295
+ - 减少批处理大小: `GRAPHRAG_BATCH_SIZE = 5`
296
+ - 使用更快的LLM模型
297
+
298
+ ### 问题2: 社区检测失败
299
+ ```
300
+ ⚠️ python-louvain未安装
301
+ ```
302
+
303
+ **解决方案**:
304
+ ```bash
305
+ pip install python-louvain
306
+ # 或使用其他算法
307
+ GRAPHRAG_COMMUNITY_ALGORITHM = "greedy"
308
+ ```
309
+
310
+ ### 问题3: 查询无结果
311
+ ```
312
+ 未能在知识图谱中找到相关实体
313
+ ```
314
+
315
+ **解决方案**:
316
+ - 检查图谱是否构建: `rag_system.get_graph_statistics()`
317
+ - 使用全局查询代替本地查询
318
+ - 检查实体提取质量
319
+
320
+ ### 问题4: 内存不足
321
+ ```
322
+ MemoryError
323
+ ```
324
+
325
+ **解决方案**:
326
+ - 减少文档数量测试
327
+ - 增加批处理间隔
328
+ - 使用轻量级图存储
329
+
330
+ ---
331
+
332
+ ## 📝 代码示例
333
+
334
+ ### 示例1: 基本使用
335
+
336
+ ```python
337
+ from main_graphrag import AdaptiveRAGWithGraph
338
+
339
+ # 初始化系统
340
+ rag = AdaptiveRAGWithGraph(enable_graphrag=True)
341
+
342
+ # 本地查询(针对特定实体)
343
+ answer = rag.query_graph_local("LLM Agent的主要组件是什么?")
344
+
345
+ # 全局查询(概览性问题)
346
+ answer = rag.query_graph_global("这些文档讨论了哪些主题?")
347
+
348
+ # 智能查询(自动选择策略)
349
+ answer = rag.query_smart("如何防御对抗性攻击?")
350
+ ```
351
+
352
+ ### 示例2: 混合检索
353
+
354
+ ```python
355
+ # 同时使用向量和图谱
356
+ result = rag.query_hybrid("提示工程在LLM中的应用")
357
+
358
+ print("向量检索:", result["vector_retrieval"]["context"])
359
+ print("图谱本地:", result["graph_local"])
360
+ print("图谱全局:", result["graph_global"])
361
+ ```
362
+
363
+ ### 示例3: 手动控制
364
+
365
+ ```python
366
+ from graph_indexer import initialize_graph_indexer
367
+ from graph_retriever import initialize_graph_retriever
368
+
369
+ # 构建索引
370
+ indexer = initialize_graph_indexer()
371
+ kg = indexer.index_documents(documents, save_path="my_graph.json")
372
+
373
+ # 查看统计
374
+ stats = kg.get_statistics()
375
+ print(f"实体: {stats['num_nodes']}, 关系: {stats['num_edges']}")
376
+
377
+ # 查询
378
+ retriever = initialize_graph_retriever(kg)
379
+ answer = retriever.local_query("specific question", max_hops=3)
380
+ ```
381
+
382
+ ---
383
+
384
+ ## 🎓 学习资源
385
+
386
+ ### 推荐阅读顺序
387
+
388
+ 1. **GRAPHRAG_GUIDE.md** - 详细使用指南
389
+ 2. **entity_extractor.py** - 了解实体提取
390
+ 3. **knowledge_graph.py** - 理解图谱构建
391
+ 4. **graph_retriever.py** - 学习查询策略
392
+ 5. **main_graphrag.py** - 完整实践示例
393
+
394
+ ### 关键概念
395
+
396
+ - **实体 (Entity)**: 图中的节点,如人物、概念、技术
397
+ - **关系 (Relation)**: 图中的边,连接两个实体
398
+ - **社区 (Community)**: 紧密连接的节点群组
399
+ - **本地查询**: 基于实体邻域的精确查询
400
+ - **全局查询**: 基于社区摘要的概览查询
401
+
402
+ ---
403
+
404
+ ## 🔮 未来计划
405
+
406
+ - [ ] **增量索引**: 添加新文档无需重建整个图谱
407
+ - [ ] **Neo4j集成**: 生产环境使用专业图数据库
408
+ - [ ] **可视化界面**: Web界面展示知识图谱
409
+ - [ ] **多模型融合**: 结合多个LLM提高提取质量
410
+ - [ ] **实时更新**: 动态更新图谱结构
411
+ - [ ] **知识推理**: 基于图谱的推理引擎
412
+ - [ ] **性能优化**: 并行处理、缓存机制
413
+
414
+ ---
415
+
416
+ ## 📞 支持
417
+
418
+ 遇到问题?
419
+
420
+ 1. 查看 **GRAPHRAG_GUIDE.md** 的"常见问题"章节
421
+ 2. 检查日志输出中的错误信息
422
+ 3. 运行 `python main_graphrag.py` 测试基本功能
423
+ 4. 使用 `get_graph_statistics()` 检查图谱状态
424
+
425
+ ---
426
+
427
+ **总结**: GraphRAG已成功集成到自适应RAG系统中,提供了从实体提取到智能查询的完整工作流。通过合理选择查询策略,可以显著提升复杂问题的回答质量。
QUICKSTART.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 快速开始指南
2
+
3
+ ## 安装依赖
4
+
5
+ ```bash
6
+ pip install -r requirements.txt
7
+ ```
8
+
9
+ ## 环境配置
10
+
11
+ ### 1. 设置API密钥
12
+
13
+ 复制 `.env.example` 文件为 `.env`:
14
+ ```bash
15
+ cp .env.example .env
16
+ ```
17
+
18
+ 编辑 `.env` 文件,填入您的实际API密钥:
19
+ ```bash
20
+ # Tavily API密钥 - 用于网络搜索功能
21
+ # 从 https://tavily.com/ 获取
22
+ TAVILY_API_KEY=your_actual_tavily_api_key
23
+
24
+ # Nomic API密钥 - 用于文本嵌入服务
25
+ # 从 https://atlas.nomic.ai/ 获取
26
+ NOMIC_API_KEY=your_actual_nomic_api_key
27
+ ```
28
+
29
+ ### 2. 确保本地模型可用
30
+
31
+ 确保您已安装并运行Ollama,并下载了Mistral模型:
32
+ ```bash
33
+ # 安装Ollama
34
+ # 访问 https://ollama.ai/ 下载安装
35
+
36
+ # 下载Mistral模型
37
+ ollama pull mistral
38
+ ```
39
+
40
+ ## 运行系统
41
+
42
+ ```bash
43
+ python main.py
44
+ ```
45
+
46
+ ## 项目结构
47
+
48
+ ```
49
+ adaptive_RAG/
50
+ ├── config.py # 配置和环境设置
51
+ ├── document_processor.py # 文档处理和向量化
52
+ ├── routers_and_graders.py # 路由器和评分器
53
+ ├── workflow_nodes.py # 工作流节点
54
+ ├── main.py # 主应用程序入口
55
+ ├── requirements.txt # 依赖管理
56
+ ├── README.md # 项目说明
57
+ └── QUICKSTART.md # 快速开始指南
58
+ ```
59
+
60
+ ## 功能模块说明
61
+
62
+ 1. **config.py**: 包含所有配置项、API密钥管理和环境变量设置
63
+ 2. **document_processor.py**: 负责文档加载、分块、向量化和检索器设置
64
+ 3. **routers_and_graders.py**: 实现查询路由、文档评分、答案质量评估等功能
65
+ 4. **workflow_nodes.py**: 定义所有工作流节点和状态管理
66
+ 5. **main.py**: 系统集成和用户交互界面
67
+
68
+ ## 使用示例
69
+
70
+ 系统启动后会自动进入交互模式,你可以:
71
+ - 询问关于LLM、提示工程、对抗性攻击的问题(使用本地知识库)
72
+ - 询问其他问题(自动路由到网络搜索)
73
+ - 输入 'quit' 或 'exit' 退出系统
README.md ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 自适应检索增强生成系统
2
+
3
+ ## 📋 项目描述
4
+
5
+ 本项目实现了一个**自适应检索增强生成(Adaptive RAG)**系统,能够智能地在本地向量数据库检索和网络搜索之间路由用户查询。系统使用复杂的工作流程,通过自适应选择最佳信息源并通过文档评分和查询转换持续改进响应质量,以提供准确、上下文相关的答案。
6
+
7
+ ## ✨ 核心功能
8
+
9
+ ### 🔄 智能查询路由
10
+ - 根据查询内容自动判断是使用本地向量存储还是网络搜索
11
+ - 将关于LLM智能体、提示工程和对抗性攻击的问题路由到向量存储
12
+ - 对于一般查询回退到网络搜索
13
+
14
+ ### 📚 高级文档处理
15
+ - 使用tiktoken编码器进行网络内容加载和分块
16
+ - 使用Nomic嵌入(本地推理)进行向量嵌入
17
+ - Chroma向量数据库实现高效相似性搜索
18
+
19
+ ### 🎯 质量保证流水线
20
+ - **文档相关性评分**:过滤与查询相关的检索文档
21
+ - **答案质量评估**:评估生成的答案是否解决了问题
22
+ - **幻觉检测**:确保回答基于源文档
23
+ - **查询转换**:改进未产生满意结果的查询
24
+
25
+ ### 🌐 混合信息检索
26
+ - 用于特定领域知识的本地向量数据库
27
+ - 通过Tavily API集成网络搜索获取最新信息
28
+ - 需要时无缝结合两种信息源
29
+
30
+ ## 🛠️ 技术栈
31
+
32
+ ### 核心框架
33
+ - **LangChain**:LLM应用程序编排框架
34
+ - **LangGraph**:复杂工作流的状态图实现
35
+
36
+ ### 语言模型
37
+ - **Ollama**:本地LLM推理(Mistral模型)
38
+ - **ChatOllama**:Ollama的LangChain集成
39
+
40
+ ### 向量数据库与嵌入
41
+ - **Chroma**:用于文档存储和检索的向量数据库
42
+ - **Nomic Embeddings**:本地文本嵌入模型(nomic-embed-text-v1.5)
43
+
44
+ ### 文档处理
45
+ - **WebBaseLoader**:网络内容提取
46
+ - **RecursiveCharacterTextSplitter**:使用tiktoken的智能文本分块
47
+
48
+ ### 外部API
49
+ - **Tavily API**:网络搜索功能
50
+ - **Nomic API**:嵌入服务
51
+
52
+ ### 工作流管理
53
+ - **StateGraph**:管理复杂RAG工作流状态
54
+ - **TypedDict**:类型安全的状态管理
55
+
56
+ ## 🏗️ 系统架构
57
+
58
+ 系统实现了一个包含以下组件的复杂状态机:
59
+
60
+ ```
61
+ 用户查询 → 路由器 → [向量存储 | 网络搜索]
62
+ ↓ ↓
63
+ 检索文档 网络搜索
64
+ ↓ ↓
65
+ 文档评分 --------→ 生成答案
66
+ ↓ ↓
67
+ [足够|不足够] 质量检查
68
+ ↓ ↓ ↓ ↓
69
+ 生成答案 [有用] [无用] [不支持]
70
+ ↓ ↓ ↓ ↓
71
+ 质量检查 [结束] 转换查询 重新生成
72
+ ↓ ↓ ↓
73
+ [有用|无用|不支持] 重新检索 ↑
74
+ ↓ ↓ |
75
+ [最终答案] ←----------
76
+ ```
77
+
78
+ ### 工作流节点
79
+ 1. **retrieve**:从向量存储获取相关文档
80
+ 2. **web_search**:搜索网络获取最新信息
81
+ 3. **grade_documents**:评估文档相关性
82
+ 4. **generate**:使用RAG链创建答案
83
+ 5. **transform_query**:改进查询表述
84
+
85
+ ### 决策点
86
+ - **route_question**:在向量存储和网络搜索之间选择
87
+ - **decide_to_generate**:判断文档是否足够
88
+ - **grade_generation**:验证答案质量和基础性
89
+
90
+ ## 📖 功能模块
91
+
92
+ ### 1. **环境设置**
93
+ - 安全的API密钥管理
94
+ - 本地LLM模型配置
95
+
96
+ ### 2. **知识库创建**
97
+ - 从指定URL加载网络内容
98
+ - 文档预处理和向量化
99
+ - 向量数据库初始化
100
+
101
+ ### 3. **查询处理**
102
+ - 基于内容分析的智能路由
103
+ - 文档检索和相关性评分
104
+ - 查询优化和转换
105
+
106
+ ### 4. **答案生成**
107
+ - 上下文感知的响应生成
108
+ - 多源信息综合
109
+ - 质量验证和改进
110
+
111
+ ### 5. **质量控制**
112
+ - 幻觉检测
113
+ - 答案相关性评分
114
+ - 迭代改进机制
115
+
116
+ ## 🚀 使用示例
117
+
118
+ 系统通过自适应工作流处理查询:
119
+
120
+ ```python
121
+ # 查询处理示例
122
+ inputs = {"question": "AlphaCodium论文讲的是什么?"}
123
+ for output in app.stream(inputs):
124
+ for key, value in output.items():
125
+ print(f"节点 '{key}':")
126
+ ```
127
+
128
+ ## 📊 数据源
129
+
130
+ ### 默认知识库
131
+ - LLM智能体(Lilian Weng的博客)
132
+ - 提示工程技术
133
+ - LLM对抗性攻击
134
+
135
+ ### 动态数据源
136
+ - 实时网络搜索结果
137
+ - 上下文相关的文档检索
138
+
139
+ ## 🔧 配置说明
140
+
141
+ ### 必需的API密钥
142
+ - `TAVILY_API_KEY`:用于网络搜索功能
143
+ - `NOMIC_API_KEY`:用于嵌入服务
144
+
145
+ ### 本地模型
146
+ - **模型**:Mistral(通过Ollama)
147
+ - **温度**:0(确定性响应)
148
+ - **格式**:结构化输出的JSON
149
+
150
+ ## 💡 核心创新
151
+
152
+ 1. **自适应路由**:基于查询语义的动态源选择
153
+ 2. **多层验证**:文档相关性、答案质量和幻觉检查
154
+ 3. **自我改进查询**:自动查询转换以获得更好结果
155
+ 4. **混合架构**:本地和基于网络的信息源无缝集成
156
+
157
+ 这个自适应RAG系统代表了信息检索和生成的先进方法,通过智能工作流管理��持续质量评估确保高质量、相关的响应。
RERANKING_PRINCIPLES.md ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 向量重排(Vector Reranking)原理详解
2
+
3
+ ## 🎯 什么是向量重排
4
+
5
+ 向量重排是检索增强生成(RAG)系统中的一种高级技术,用于在初始向量检索后对候选文档进行重新排序,以提高最终检索结果的质量和相关性。
6
+
7
+ ## 🔍 为什么需要重排
8
+
9
+ ### 初始检索的局限性
10
+
11
+ 1. **语义距离偏差**
12
+ - 向量相似度可能无法完全捕捉语义相关性
13
+ - 某些相关文档可能因为表达方式不同而排名靠后
14
+
15
+ 2. **上下文理解不足**
16
+ - 简单的余弦相似度无法理解复杂的查询意图
17
+ - 缺乏对查询和文档交互关系的深度理解
18
+
19
+ 3. **多样性问题**
20
+ - 初始检索可能返回内容相似的重复文档
21
+ - 缺乏结果的多样性和全面性
22
+
23
+ ## 🧠 重排的核心原理
24
+
25
+ ### 1. 双阶段检索架构
26
+
27
+ ```
28
+ 查询 → 粗排(向量检索)→ 精排(重排模型)→ 最终结果
29
+ ↓ ↓
30
+ 召回候选集 重新排序打分
31
+ (100-1000篇) (选择前k篇)
32
+ ```
33
+
34
+ ### 2. 重排模型类型
35
+
36
+ #### A. 交叉编码器(Cross-Encoder)
37
+ ```python
38
+ # 原理示意
39
+ def cross_encoder_rerank(query, documents):
40
+ scores = []
41
+ for doc in documents:
42
+ # 查询和文档一起编码
43
+ input_text = f"[CLS] {query} [SEP] {doc} [SEP]"
44
+ score = model(input_text) # 直接输出相关性分数
45
+ scores.append(score)
46
+ return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
47
+ ```
48
+
49
+ #### B. 双编码器重排(Bi-Encoder Reranking)
50
+ ```python
51
+ def bi_encoder_rerank(query, documents):
52
+ query_embedding = query_encoder(query)
53
+ doc_embeddings = [doc_encoder(doc) for doc in documents]
54
+
55
+ # 使用更复杂的相似度计算
56
+ scores = []
57
+ for doc_emb in doc_embeddings:
58
+ score = complex_similarity(query_embedding, doc_emb)
59
+ scores.append(score)
60
+ return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
61
+ ```
62
+
63
+ ## 🔬 重排算法详解
64
+
65
+ ### 1. 基于机器学习的重排
66
+
67
+ #### Learning to Rank (LTR)
68
+ ```python
69
+ class LearnToRankReranker:
70
+ def __init__(self):
71
+ self.model = None # XGBoost, LambdaMART等
72
+
73
+ def extract_features(self, query, document):
74
+ """提取查询-文档特征"""
75
+ features = [
76
+ # 文本匹配特征
77
+ jaccard_similarity(query, document),
78
+ tf_idf_score(query, document),
79
+ bm25_score(query, document),
80
+
81
+ # 语义特征
82
+ cosine_similarity(query_emb, doc_emb),
83
+ bert_score(query, document),
84
+
85
+ # 文档特征
86
+ document_length(document),
87
+ document_quality_score(document),
88
+
89
+ # 查询特征
90
+ query_complexity(query),
91
+ query_type_classification(query)
92
+ ]
93
+ return features
94
+
95
+ def rerank(self, query, documents):
96
+ features_matrix = []
97
+ for doc in documents:
98
+ features = self.extract_features(query, doc)
99
+ features_matrix.append(features)
100
+
101
+ scores = self.model.predict(features_matrix)
102
+ return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
103
+ ```
104
+
105
+ ### 2. 基于深度学习的重排
106
+
107
+ #### Transformer重排模型
108
+ ```python
109
+ class TransformerReranker:
110
+ def __init__(self, model_name="microsoft/DialoGPT-medium"):
111
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
112
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
113
+
114
+ def rerank(self, query, documents, top_k=5):
115
+ scores = []
116
+
117
+ for doc in documents:
118
+ # 构造输入
119
+ inputs = self.tokenizer(
120
+ query, doc,
121
+ padding=True,
122
+ truncation=True,
123
+ max_length=512,
124
+ return_tensors="pt"
125
+ )
126
+
127
+ # 获取相关性分数
128
+ with torch.no_grad():
129
+ outputs = self.model(**inputs)
130
+ score = torch.softmax(outputs.logits, dim=-1)[0][1] # 相关性概率
131
+ scores.append(score.item())
132
+
133
+ # 重新排序
134
+ ranked_results = sorted(
135
+ zip(documents, scores),
136
+ key=lambda x: x[1],
137
+ reverse=True
138
+ )
139
+
140
+ return ranked_results[:top_k]
141
+ ```
142
+
143
+ ### 3. 多策略融合重排
144
+
145
+ ```python
146
+ class MultiStrategyReranker:
147
+ def __init__(self):
148
+ self.semantic_weight = 0.4
149
+ self.lexical_weight = 0.3
150
+ self.diversity_weight = 0.2
151
+ self.freshness_weight = 0.1
152
+
153
+ def rerank(self, query, documents):
154
+ # 1. 语义相关性分数
155
+ semantic_scores = self.compute_semantic_scores(query, documents)
156
+
157
+ # 2. 词汇匹配分数
158
+ lexical_scores = self.compute_lexical_scores(query, documents)
159
+
160
+ # 3. 多样性分数
161
+ diversity_scores = self.compute_diversity_scores(documents)
162
+
163
+ # 4. 时效性分数
164
+ freshness_scores = self.compute_freshness_scores(documents)
165
+
166
+ # 5. 加权融合
167
+ final_scores = []
168
+ for i in range(len(documents)):
169
+ score = (
170
+ self.semantic_weight * semantic_scores[i] +
171
+ self.lexical_weight * lexical_scores[i] +
172
+ self.diversity_weight * diversity_scores[i] +
173
+ self.freshness_weight * freshness_scores[i]
174
+ )
175
+ final_scores.append(score)
176
+
177
+ return sorted(zip(documents, final_scores), key=lambda x: x[1], reverse=True)
178
+ ```
179
+
180
+ ## 🎛️ 重排特征工程
181
+
182
+ ### 1. 文本匹配特征
183
+ ```python
184
+ def extract_text_features(query, document):
185
+ return {
186
+ # 精确匹配
187
+ 'exact_match_ratio': exact_match_count(query, document) / len(query.split()),
188
+
189
+ # 模糊匹配
190
+ 'fuzzy_match_score': fuzz.ratio(query, document) / 100,
191
+
192
+ # N-gram重叠
193
+ 'bigram_overlap': ngram_overlap(query, document, n=2),
194
+ 'trigram_overlap': ngram_overlap(query, document, n=3),
195
+
196
+ # TF-IDF相似度
197
+ 'tfidf_similarity': tfidf_cosine_similarity(query, document),
198
+
199
+ # BM25分数
200
+ 'bm25_score': compute_bm25(query, document)
201
+ }
202
+ ```
203
+
204
+ ### 2. 语义特征
205
+ ```python
206
+ def extract_semantic_features(query, document, embeddings):
207
+ query_emb = embeddings['query']
208
+ doc_emb = embeddings['document']
209
+
210
+ return {
211
+ # 余弦相似度
212
+ 'cosine_similarity': cosine_sim(query_emb, doc_emb),
213
+
214
+ # 欧几里得距离
215
+ 'euclidean_distance': euclidean_distance(query_emb, doc_emb),
216
+
217
+ # 曼哈顿距离
218
+ 'manhattan_distance': manhattan_distance(query_emb, doc_emb),
219
+
220
+ # BERT分数
221
+ 'bert_score': bert_score_f1(query, document),
222
+
223
+ # 语义角度
224
+ 'semantic_angle': semantic_angle(query_emb, doc_emb)
225
+ }
226
+ ```
227
+
228
+ ### 3. 文档质量特征
229
+ ```python
230
+ def extract_quality_features(document):
231
+ return {
232
+ # 长度特征
233
+ 'doc_length': len(document.split()),
234
+ 'sentence_count': len(sent_tokenize(document)),
235
+
236
+ # 可读性特征
237
+ 'readability_score': textstat.flesch_reading_ease(document),
238
+ 'complexity_score': textstat.flesch_kincaid_grade(document),
239
+
240
+ # 信息密度
241
+ 'unique_word_ratio': len(set(document.split())) / len(document.split()),
242
+ 'stopword_ratio': stopword_count(document) / len(document.split()),
243
+
244
+ # 结构特征
245
+ 'has_headers': bool(re.search(r'^#+\s', document, re.MULTILINE)),
246
+ 'has_lists': bool(re.search(r'^\s*[-*+]\s', document, re.MULTILINE))
247
+ }
248
+ ```
249
+
250
+ ## 🚀 实际应用示例
251
+
252
+ ### 集成到RAG系统中
253
+ ```python
254
+ class AdaptiveRAGWithReranking:
255
+ def __init__(self):
256
+ self.initial_retriever = VectorRetriever()
257
+ self.reranker = TransformerReranker()
258
+ self.generator = LanguageModel()
259
+
260
+ def query(self, question, top_k=5, rerank_candidates=20):
261
+ # 1. 初始检索(获取更多候选)
262
+ initial_docs = self.initial_retriever.retrieve(
263
+ question,
264
+ top_k=rerank_candidates
265
+ )
266
+
267
+ # 2. 重排
268
+ reranked_docs = self.reranker.rerank(
269
+ question,
270
+ initial_docs,
271
+ top_k=top_k
272
+ )
273
+
274
+ # 3. 生成答案
275
+ context = "\n\n".join([doc[0] for doc in reranked_docs])
276
+ answer = self.generator.generate(question, context)
277
+
278
+ return {
279
+ 'answer': answer,
280
+ 'sources': reranked_docs,
281
+ 'confidence': self.calculate_confidence(reranked_docs)
282
+ }
283
+ ```
284
+
285
+ ## 📊 性能评估指标
286
+
287
+ ### 1. 排序质量指标
288
+ ```python
289
+ def evaluate_reranking(original_ranking, reranked_results, ground_truth):
290
+ return {
291
+ # NDCG (Normalized Discounted Cumulative Gain)
292
+ 'ndcg@5': ndcg_score(ground_truth, reranked_results, k=5),
293
+ 'ndcg@10': ndcg_score(ground_truth, reranked_results, k=10),
294
+
295
+ # MAP (Mean Average Precision)
296
+ 'map': mean_average_precision(ground_truth, reranked_results),
297
+
298
+ # MRR (Mean Reciprocal Rank)
299
+ 'mrr': mean_reciprocal_rank(ground_truth, reranked_results),
300
+
301
+ # 排序改进度
302
+ 'ranking_improvement': kendall_tau(original_ranking, reranked_results)
303
+ }
304
+ ```
305
+
306
+ ### 2. 端到端效果评估
307
+ ```python
308
+ def evaluate_rag_with_reranking(test_questions, ground_truth_answers):
309
+ results = []
310
+
311
+ for question, gt_answer in zip(test_questions, ground_truth_answers):
312
+ # 无重排
313
+ original_answer = rag_without_rerank(question)
314
+
315
+ # 有重排
316
+ reranked_answer = rag_with_rerank(question)
317
+
318
+ results.append({
319
+ 'question': question,
320
+ 'original_score': evaluate_answer(original_answer, gt_answer),
321
+ 'reranked_score': evaluate_answer(reranked_answer, gt_answer),
322
+ 'improvement': evaluate_answer(reranked_answer, gt_answer) -
323
+ evaluate_answer(original_answer, gt_answer)
324
+ })
325
+
326
+ return results
327
+ ```
328
+
329
+ ## 💡 最佳实践
330
+
331
+ ### 1. 重排策略选择
332
+ - **实时性要求高**: 使用轻量级规则或简单ML模型
333
+ - **精度要求高**: 使用深度学习重排模型
334
+ - **平衡性能**: 多策略融合 + 缓存优化
335
+
336
+ ### 2. 特征选择原则
337
+ - **相关性特征**: 语义相似度、词汇匹配
338
+ - **质量特征**: 文档权威性、完整性
339
+ - **多样性特征**: 避免结果冗余
340
+ - **时效性特征**: 信息新鲜度
341
+
342
+ ### 3. 系统优化
343
+ ```python
344
+ class OptimizedReranker:
345
+ def __init__(self):
346
+ self.cache = LRUCache(maxsize=1000)
347
+ self.batch_size = 32
348
+
349
+ @lru_cache(maxsize=1000)
350
+ def cached_rerank(self, query_hash, doc_hashes):
351
+ """缓存重排结果"""
352
+ pass
353
+
354
+ def batch_rerank(self, queries, documents):
355
+ """批量重排优化"""
356
+ pass
357
+ ```
358
+
359
+ 重排向量是提升RAG系统检索精度的关键技术,通过多层次的相关性评估和智能排序,显著提高了最终答案的质量和准确性。
__pycache__/config.cpython-310.pyc ADDED
Binary file (2.34 kB). View file
 
__pycache__/config.cpython-313.pyc ADDED
Binary file (2.95 kB). View file
 
__pycache__/document_processor.cpython-310.pyc ADDED
Binary file (7.21 kB). View file
 
__pycache__/document_processor.cpython-313.pyc ADDED
Binary file (4.56 kB). View file
 
__pycache__/entity_extractor.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
__pycache__/graph_indexer.cpython-310.pyc ADDED
Binary file (4.54 kB). View file
 
__pycache__/graph_retriever.cpython-310.pyc ADDED
Binary file (7.51 kB). View file
 
__pycache__/knowledge_graph.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
__pycache__/main.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
__pycache__/main.cpython-313.pyc ADDED
Binary file (6.55 kB). View file
 
__pycache__/reranker.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
__pycache__/routers_and_graders.cpython-310.pyc ADDED
Binary file (6.07 kB). View file
 
__pycache__/routers_and_graders.cpython-313.pyc ADDED
Binary file (8.04 kB). View file
 
__pycache__/workflow_nodes.cpython-310.pyc ADDED
Binary file (7.32 kB). View file
 
__pycache__/workflow_nodes.cpython-313.pyc ADDED
Binary file (8.54 kB). View file
 
colab_gpu_demo.ipynb ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🚀 GraphRAG GPU检测与测试 - Google Colab版本\n",
8
+ "\n",
9
+ "本Notebook用于在Google Colab上检测GPU可用性并测试GraphRAG系统的性能。\n",
10
+ "\n",
11
+ "## 📋 使用步骤\n",
12
+ "\n",
13
+ "1. **启用GPU**: 运行时 → 更改运行时类型 → 硬件加速器 → GPU (T4)\n",
14
+ "2. **运行所有单元格**: 依次执行下面的代码\n",
15
+ "3. **查看结果**: 检查GPU加速效果\n",
16
+ "\n",
17
+ "---"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## 1️⃣ GPU环境检测"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "# 检测GPU可用性\n",
34
+ "import torch\n",
35
+ "import subprocess\n",
36
+ "import sys\n",
37
+ "\n",
38
+ "print(\"=\"*60)\n",
39
+ "print(\"🔍 GPU环境检测\")\n",
40
+ "print(\"=\"*60)\n",
41
+ "\n",
42
+ "# PyTorch GPU检测\n",
43
+ "cuda_available = torch.cuda.is_available()\n",
44
+ "print(f\"\\n✅ CUDA可用: {cuda_available}\")\n",
45
+ "\n",
46
+ "if cuda_available:\n",
47
+ " print(f\" GPU数量: {torch.cuda.device_count()}\")\n",
48
+ " print(f\" 当前GPU: {torch.cuda.current_device()}\")\n",
49
+ " print(f\" GPU名称: {torch.cuda.get_device_name(0)}\")\n",
50
+ " print(f\" CUDA版本: {torch.version.cuda}\")\n",
51
+ " \n",
52
+ " # 显存信息\n",
53
+ " total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)\n",
54
+ " print(f\" 总显存: {total_memory:.2f} GB\")\n",
55
+ " \n",
56
+ " # nvidia-smi信息\n",
57
+ " print(\"\\n📊 nvidia-smi 输出:\")\n",
58
+ " print(\"-\"*60)\n",
59
+ " !nvidia-smi\n",
60
+ "else:\n",
61
+ " print(\"\\n⚠️ 警告: 未检测到GPU\")\n",
62
+ " print(\" 请检查: 运行时 → 更改运行时类型 → 硬件加速器 → GPU\")\n",
63
+ "\n",
64
+ "print(\"\\n\" + \"=\"*60)"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": [
71
+ "## 2️⃣ GPU性能基准测试"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "# GPU vs CPU 性能对比\n",
81
+ "import time\n",
82
+ "import numpy as np\n",
83
+ "\n",
84
+ "print(\"=\"*60)\n",
85
+ "print(\"⚡ GPU vs CPU 矩阵运算性能测试\")\n",
86
+ "print(\"=\"*60)\n",
87
+ "\n",
88
+ "# 测试参数\n",
89
+ "matrix_size = 5000\n",
90
+ "\n",
91
+ "# CPU测试\n",
92
+ "print(f\"\\n🔵 CPU测试 (矩阵大小: {matrix_size}x{matrix_size})\")\n",
93
+ "a_cpu = torch.randn(matrix_size, matrix_size)\n",
94
+ "b_cpu = torch.randn(matrix_size, matrix_size)\n",
95
+ "\n",
96
+ "start = time.time()\n",
97
+ "c_cpu = torch.mm(a_cpu, b_cpu)\n",
98
+ "cpu_time = time.time() - start\n",
99
+ "print(f\" CPU时间: {cpu_time:.2f} 秒\")\n",
100
+ "\n",
101
+ "# GPU测试\n",
102
+ "if cuda_available:\n",
103
+ " print(f\"\\n🟢 GPU测试 (矩阵大小: {matrix_size}x{matrix_size})\")\n",
104
+ " a_gpu = torch.randn(matrix_size, matrix_size).cuda()\n",
105
+ " b_gpu = torch.randn(matrix_size, matrix_size).cuda()\n",
106
+ " \n",
107
+ " # 预热GPU\n",
108
+ " _ = torch.mm(a_gpu, b_gpu)\n",
109
+ " torch.cuda.synchronize()\n",
110
+ " \n",
111
+ " start = time.time()\n",
112
+ " c_gpu = torch.mm(a_gpu, b_gpu)\n",
113
+ " torch.cuda.synchronize()\n",
114
+ " gpu_time = time.time() - start\n",
115
+ " print(f\" GPU时间: {gpu_time:.2f} 秒\")\n",
116
+ " \n",
117
+ " speedup = cpu_time / gpu_time\n",
118
+ " print(f\"\\n🚀 加速比: {speedup:.2f}x\")\n",
119
+ " print(f\" GPU比CPU快 {speedup:.1f} 倍!\")\n",
120
+ "else:\n",
121
+ " print(\"\\n⚠️ 跳过GPU测试(GPU不可用)\")\n",
122
+ "\n",
123
+ "print(\"\\n\" + \"=\"*60)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "## 3️⃣ 安装GraphRAG依赖"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "# 克隆项目(如果需要)\n",
140
+ "import os\n",
141
+ "\n",
142
+ "print(\"📦 安装GraphRAG依赖...\\n\")\n",
143
+ "\n",
144
+ "# 安装核心依赖\n",
145
+ "!pip install -q langchain langchain-community langchain-core langgraph\n",
146
+ "!pip install -q chromadb sentence-transformers transformers\n",
147
+ "!pip install -q tiktoken beautifulsoup4 requests\n",
148
+ "!pip install -q tavily-python python-dotenv\n",
149
+ "!pip install -q networkx python-louvain\n",
150
+ "!pip install -q torch --index-url https://download.pytorch.org/whl/cu118\n",
151
+ "\n",
152
+ "print(\"\\n✅ 依赖安装完成!\")"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "markdown",
157
+ "metadata": {},
158
+ "source": [
159
+ "## 4️⃣ 上传项目文件\n",
160
+ "\n",
161
+ "**选项A**: 从GitHub克隆\n",
162
+ "```python\n",
163
+ "!git clone https://github.com/your-repo/adaptive_RAG.git\n",
164
+ "%cd adaptive_RAG\n",
165
+ "```\n",
166
+ "\n",
167
+ "**选项B**: 手动上传文件到Colab\n",
168
+ "- 使用左侧文件浏览器上传以下核心文件:\n",
169
+ " - `config.py`\n",
170
+ " - `entity_extractor.py`\n",
171
+ " - `knowledge_graph.py`\n",
172
+ " - `graph_indexer.py`\n",
173
+ " - `graph_retriever.py`\n",
174
+ " - `.env` (包含API密钥)"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "# 创建必要的目录\n",
184
+ "!mkdir -p data\n",
185
+ "\n",
186
+ "# 如果使用选项A,运行下面的命令\n",
187
+ "# !git clone YOUR_REPO_URL\n",
188
+ "# %cd adaptive_RAG\n",
189
+ "\n",
190
+ "print(\"✅ 目录准备完成\")"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "metadata": {},
196
+ "source": [
197
+ "## 5️⃣ 配置API密钥"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "# 设置API密钥(替换为您的真实密钥)\n",
207
+ "import os\n",
208
+ "from getpass import getpass\n",
209
+ "\n",
210
+ "print(\"🔑 配置API密钥\\n\")\n",
211
+ "\n",
212
+ "# 方式1: 直接设置(不安全,仅用于测试)\n",
213
+ "# os.environ['TAVILY_API_KEY'] = 'your_tavily_api_key_here'\n",
214
+ "\n",
215
+ "# 方式2: 安全输入\n",
216
+ "if 'TAVILY_API_KEY' not in os.environ:\n",
217
+ " os.environ['TAVILY_API_KEY'] = getpass('输入 TAVILY_API_KEY: ')\n",
218
+ " print(\"✅ TAVILY_API_KEY 已设置\")\n",
219
+ "else:\n",
220
+ " print(\"✅ TAVILY_API_KEY 已存在\")\n",
221
+ "\n",
222
+ "print(\"\\n注意: GraphRAG在Colab上使用HuggingFace嵌入,不需要NOMIC_API_KEY\")"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "markdown",
227
+ "metadata": {},
228
+ "source": [
229
+ "## 6️⃣ 简化版GraphRAG测试代码"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "# 简化版GraphRAG核心组件\n",
239
+ "# 适用于Colab快速测试,无需完整项目文件\n",
240
+ "\n",
241
+ "from typing import List, Dict\n",
242
+ "import networkx as nx\n",
243
+ "from sentence_transformers import SentenceTransformer\n",
244
+ "import torch\n",
245
+ "\n",
246
+ "class SimpleGraphRAG:\n",
247
+ " \"\"\"简化版GraphRAG用于GPU性能测试\"\"\"\n",
248
+ " \n",
249
+ " def __init__(self, use_gpu=True):\n",
250
+ " print(\"🚀 初始化SimpleGraphRAG...\")\n",
251
+ " \n",
252
+ " # 检测设备\n",
253
+ " self.device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'\n",
254
+ " print(f\" 设备: {self.device.upper()}\")\n",
255
+ " \n",
256
+ " # 加载嵌入模型\n",
257
+ " print(f\" 加载嵌入模型...\")\n",
258
+ " self.embedder = SentenceTransformer(\n",
259
+ " 'sentence-transformers/all-MiniLM-L6-v2',\n",
260
+ " device=self.device\n",
261
+ " )\n",
262
+ " \n",
263
+ " # 知识图谱\n",
264
+ " self.graph = nx.Graph()\n",
265
+ " self.entities = {}\n",
266
+ " \n",
267
+ " print(\"✅ 初始化完成!\")\n",
268
+ " \n",
269
+ " def add_sample_data(self):\n",
270
+ " \"\"\"添加示例数据\"\"\"\n",
271
+ " print(\"\\n📊 添加示例数据...\")\n",
272
+ " \n",
273
+ " # 示例实体\n",
274
+ " entities = [\n",
275
+ " {\"name\": \"LLM\", \"type\": \"CONCEPT\", \"desc\": \"大语言模型\"},\n",
276
+ " {\"name\": \"GPT\", \"type\": \"TECHNOLOGY\", \"desc\": \"生成式预训练转换器\"},\n",
277
+ " {\"name\": \"Transformer\", \"type\": \"CONCEPT\", \"desc\": \"注意力机制架构\"},\n",
278
+ " {\"name\": \"OpenAI\", \"type\": \"ORGANIZATION\", \"desc\": \"人工智能研究公司\"},\n",
279
+ " {\"name\": \"Attention\", \"type\": \"CONCEPT\", \"desc\": \"注意力机制\"},\n",
280
+ " ]\n",
281
+ " \n",
282
+ " for entity in entities:\n",
283
+ " self.graph.add_node(\n",
284
+ " entity[\"name\"],\n",
285
+ " type=entity[\"type\"],\n",
286
+ " description=entity[\"desc\"]\n",
287
+ " )\n",
288
+ " self.entities[entity[\"name\"]] = entity\n",
289
+ " \n",
290
+ " # 示例关系\n",
291
+ " relations = [\n",
292
+ " (\"GPT\", \"LLM\", \"IS_A\"),\n",
293
+ " (\"GPT\", \"Transformer\", \"USES\"),\n",
294
+ " (\"Transformer\", \"Attention\", \"CONTAINS\"),\n",
295
+ " (\"OpenAI\", \"GPT\", \"DEVELOPS\"),\n",
296
+ " ]\n",
297
+ " \n",
298
+ " for source, target, rel_type in relations:\n",
299
+ " self.graph.add_edge(source, target, relation=rel_type)\n",
300
+ " \n",
301
+ " print(f\" ✅ 添加了 {len(entities)} 个实体\")\n",
302
+ " print(f\" ✅ 添加了 {len(relations)} 个关系\")\n",
303
+ " \n",
304
+ " def test_gpu_embedding(self, texts: List[str]):\n",
305
+ " \"\"\"测试GPU嵌入性能\"\"\"\n",
306
+ " print(f\"\\n⚡ 测试嵌入性能 ({len(texts)} 个文本)...\")\n",
307
+ " \n",
308
+ " import time\n",
309
+ " \n",
310
+ " start = time.time()\n",
311
+ " embeddings = self.embedder.encode(\n",
312
+ " texts,\n",
313
+ " show_progress_bar=True,\n",
314
+ " batch_size=32\n",
315
+ " )\n",
316
+ " elapsed = time.time() - start\n",
317
+ " \n",
318
+ " print(f\" ✅ 完成! 耗时: {elapsed:.2f}秒\")\n",
319
+ " print(f\" 📊 嵌入维度: {embeddings.shape}\")\n",
320
+ " print(f\" 🚀 速度: {len(texts)/elapsed:.1f} 文本/秒\")\n",
321
+ " \n",
322
+ " return embeddings\n",
323
+ " \n",
324
+ " def query(self, question: str):\n",
325
+ " \"\"\"简单查询\"\"\"\n",
326
+ " print(f\"\\n🔍 查询: {question}\")\n",
327
+ " \n",
328
+ " # 简单的关键词匹配\n",
329
+ " results = []\n",
330
+ " for entity_name in self.entities:\n",
331
+ " if entity_name.lower() in question.lower():\n",
332
+ " neighbors = list(self.graph.neighbors(entity_name))\n",
333
+ " results.append({\n",
334
+ " \"entity\": entity_name,\n",
335
+ " \"info\": self.entities[entity_name],\n",
336
+ " \"neighbors\": neighbors\n",
337
+ " })\n",
338
+ " \n",
339
+ " print(f\"\\n📋 找到 {len(results)} 个相关实体:\")\n",
340
+ " for r in results:\n",
341
+ " print(f\" • {r['entity']} ({r['info']['type']})\")\n",
342
+ " print(f\" 描述: {r['info']['desc']}\")\n",
343
+ " print(f\" 关联: {', '.join(r['neighbors'])}\")\n",
344
+ " \n",
345
+ " return results\n",
346
+ "\n",
347
+ "print(\"✅ SimpleGraphRAG类定义完成\")"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "markdown",
352
+ "metadata": {},
353
+ "source": [
354
+ "## 7️⃣ 运行GPU性能测试"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "# 初始化GraphRAG(GPU版本)\n",
364
+ "print(\"=\"*60)\n",
365
+ "print(\"🎯 GraphRAG GPU性能测试\")\n",
366
+ "print(\"=\"*60)\n",
367
+ "\n",
368
+ "graph_rag = SimpleGraphRAG(use_gpu=True)\n",
369
+ "\n",
370
+ "# 添加示例数据\n",
371
+ "graph_rag.add_sample_data()\n",
372
+ "\n",
373
+ "# 准备测试文本\n",
374
+ "test_texts = [\n",
375
+ " \"Large Language Models are transforming AI\",\n",
376
+ " \"GPT uses Transformer architecture\",\n",
377
+ " \"Attention mechanism is key to modern NLP\",\n",
378
+ " \"OpenAI develops cutting-edge AI models\",\n",
379
+ "] * 25 # 100个文本\n",
380
+ "\n",
381
+ "print(f\"\\n准备了 {len(test_texts)} 个测试文本\")\n",
382
+ "\n",
383
+ "# GPU嵌入测试\n",
384
+ "embeddings = graph_rag.test_gpu_embedding(test_texts)\n",
385
+ "\n",
386
+ "# 测试查询\n",
387
+ "graph_rag.query(\"What is GPT?\")\n",
388
+ "graph_rag.query(\"Tell me about Transformer\")\n",
389
+ "\n",
390
+ "print(\"\\n\" + \"=\"*60)\n",
391
+ "print(\"✅ GPU性能测试完成!\")\n",
392
+ "print(\"=\"*60)"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "markdown",
397
+ "metadata": {},
398
+ "source": [
399
+ "## 8️⃣ CPU vs GPU 性能对比"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "code",
404
+ "execution_count": null,
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "# CPU vs GPU 嵌入性能对比\n",
409
+ "import time\n",
410
+ "\n",
411
+ "print(\"=\"*60)\n",
412
+ "print(\"📊 CPU vs GPU 嵌入性能对比\")\n",
413
+ "print(\"=\"*60)\n",
414
+ "\n",
415
+ "# 准备大量测试文本\n",
416
+ "large_test_texts = test_texts * 10 # 1000个文本\n",
417
+ "print(f\"\\n测试数据: {len(large_test_texts)} 个文本\\n\")\n",
418
+ "\n",
419
+ "# CPU测试\n",
420
+ "print(\"🔵 CPU测试...\")\n",
421
+ "graph_rag_cpu = SimpleGraphRAG(use_gpu=False)\n",
422
+ "start = time.time()\n",
423
+ "embeddings_cpu = graph_rag_cpu.embedder.encode(\n",
424
+ " large_test_texts,\n",
425
+ " show_progress_bar=False,\n",
426
+ " batch_size=32\n",
427
+ ")\n",
428
+ "cpu_time = time.time() - start\n",
429
+ "print(f\" CPU时间: {cpu_time:.2f}秒\")\n",
430
+ "print(f\" 速度: {len(large_test_texts)/cpu_time:.1f} 文本/秒\")\n",
431
+ "\n",
432
+ "# GPU测试\n",
433
+ "if cuda_available:\n",
434
+ " print(\"\\n🟢 GPU测试...\")\n",
435
+ " graph_rag_gpu = SimpleGraphRAG(use_gpu=True)\n",
436
+ " start = time.time()\n",
437
+ " embeddings_gpu = graph_rag_gpu.embedder.encode(\n",
438
+ " large_test_texts,\n",
439
+ " show_progress_bar=False,\n",
440
+ " batch_size=32\n",
441
+ " )\n",
442
+ " gpu_time = time.time() - start\n",
443
+ " print(f\" GPU时间: {gpu_time:.2f}秒\")\n",
444
+ " print(f\" 速度: {len(large_test_texts)/gpu_time:.1f} 文本/秒\")\n",
445
+ " \n",
446
+ " speedup = cpu_time / gpu_time\n",
447
+ " print(f\"\\n🚀 加速比: {speedup:.2f}x\")\n",
448
+ " print(f\" GPU比CPU快 {speedup:.1f} 倍!\")\n",
449
+ " \n",
450
+ " # 节省的时间\n",
451
+ " time_saved = cpu_time - gpu_time\n",
452
+ " print(f\" ⏱️ 节省时间: {time_saved:.2f}秒\")\n",
453
+ "else:\n",
454
+ " print(\"\\n⚠️ GPU不可用,跳过GPU测试\")\n",
455
+ "\n",
456
+ "print(\"\\n\" + \"=\"*60)"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "markdown",
461
+ "metadata": {},
462
+ "source": [
463
+ "## 9️⃣ 显存使用监控"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": null,
469
+ "metadata": {},
470
+ "outputs": [],
471
+ "source": [
472
+ "# 监控GPU显存使用\n",
473
+ "if cuda_available:\n",
474
+ " print(\"=\"*60)\n",
475
+ " print(\"💾 GPU显存使用情况\")\n",
476
+ " print(\"=\"*60)\n",
477
+ " \n",
478
+ " allocated = torch.cuda.memory_allocated(0) / (1024**3)\n",
479
+ " reserved = torch.cuda.memory_reserved(0) / (1024**3)\n",
480
+ " total = torch.cuda.get_device_properties(0).total_memory / (1024**3)\n",
481
+ " \n",
482
+ " print(f\"\\n已分配: {allocated:.2f} GB\")\n",
483
+ " print(f\"已保留: {reserved:.2f} GB\")\n",
484
+ " print(f\"总显存: {total:.2f} GB\")\n",
485
+ " print(f\"使用率: {(allocated/total)*100:.1f}%\")\n",
486
+ " \n",
487
+ " print(\"\\n详细信息:\")\n",
488
+ " print(torch.cuda.memory_summary(0, abbreviated=True))\n",
489
+ " \n",
490
+ " print(\"\\n\" + \"=\"*60)\n",
491
+ "else:\n",
492
+ " print(\"⚠️ GPU不可用\")"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "markdown",
497
+ "metadata": {},
498
+ "source": [
499
+ "## 🔟 性能总结报告"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "metadata": {},
506
+ "outputs": [],
507
+ "source": [
508
+ "# 生成性能报告\n",
509
+ "print(\"=\"*60)\n",
510
+ "print(\"📈 GraphRAG GPU性能测试报告\")\n",
511
+ "print(\"=\"*60)\n",
512
+ "\n",
513
+ "print(\"\\n🖥️ 硬件信息:\")\n",
514
+ "if cuda_available:\n",
515
+ " print(f\" GPU型号: {torch.cuda.get_device_name(0)}\")\n",
516
+ " print(f\" 显存: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB\")\n",
517
+ " print(f\" CUDA版本: {torch.version.cuda}\")\n",
518
+ "else:\n",
519
+ " print(\" ⚠️ GPU不可用\")\n",
520
+ "\n",
521
+ "print(f\"\\n PyTorch版本: {torch.__version__}\")\n",
522
+ "print(f\" Python版本: {sys.version.split()[0]}\")\n",
523
+ "\n",
524
+ "print(\"\\n⚡ 性能测试结果:\")\n",
525
+ "print(f\" 矩阵运算加速: ~{speedup if cuda_available else 'N/A'}x\")\n",
526
+ "print(f\" 文本嵌入加速: ~{cpu_time/gpu_time if cuda_available else 'N/A'}x\")\n",
527
+ "\n",
528
+ "print(\"\\n💡 建议:\")\n",
529
+ "if cuda_available:\n",
530
+ " print(\" ✅ GPU运行良好!建议在Colab上运行完整的GraphRAG索引构建\")\n",
531
+ " print(\" ✅ 预计索引构建时间将大幅缩短\")\n",
532
+ " print(\" ✅ 可以处理更大规模的文档集\")\n",
533
+ "else:\n",
534
+ " print(\" ⚠️ 建议启用GPU以获得最佳性能\")\n",
535
+ " print(\" ⚠️ 路径: 运行时 → 更改运行时类型 → GPU\")\n",
536
+ "\n",
537
+ "print(\"\\n\" + \"=\"*60)\n",
538
+ "print(\"✅ 测试完成!\")\n",
539
+ "print(\"=\"*60)"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "markdown",
544
+ "metadata": {},
545
+ "source": [
546
+ "---\n",
547
+ "\n",
548
+ "## 📚 下一步\n",
549
+ "\n",
550
+ "如果GPU测试成功,您可以:\n",
551
+ "\n",
552
+ "1. **上传完整项目**: 将整个adaptive_RAG项目上传到Colab\n",
553
+ "2. **运行GraphRAG索引**: 使用GPU加速构建知识图谱\n",
554
+ "3. **保存结果**: 将构建好的图谱下载到本地\n",
555
+ "\n",
556
+ "### 运行完整GraphRAG的命令:\n",
557
+ "\n",
558
+ "```python\n",
559
+ "# 上传项目后运行\n",
560
+ "!python main_graphrag.py\n",
561
+ "```\n",
562
+ "\n",
563
+ "### 预期加速效果:\n",
564
+ "\n",
565
+ "- 实体提取: 使用GPU的LLM推理会更快\n",
566
+ "- 文本嵌入: **5-10倍加速**\n",
567
+ "- 向量相似度计算: **10-20倍加速**\n",
568
+ "- 总体索引构建时间: **3-5倍加速**\n",
569
+ "\n",
570
+ "---"
571
+ ]
572
+ }
573
+ ],
574
+ "metadata": {
575
+ "accelerator": "GPU",
576
+ "kernelspec": {
577
+ "display_name": "Python 3",
578
+ "language": "python",
579
+ "name": "python3"
580
+ },
581
+ "language_info": {
582
+ "name": "python",
583
+ "version": "3.10.0"
584
+ }
585
+ },
586
+ "nbformat": 4,
587
+ "nbformat_minor": 0
588
+ }
colab_gpu_test.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Google Colab GPU检测和GraphRAG性能测试脚本
4
+ 可以直接在Colab中运行:python colab_gpu_test.py
5
+ """
6
+
7
+ import sys
8
+ import time
9
+ import torch
10
+ import numpy as np
11
+ from typing import List, Dict
12
+
13
+ def print_section(title: str):
14
+ """打印分节标题"""
15
+ print("\n" + "="*60)
16
+ print(f"{title}")
17
+ print("="*60 + "\n")
18
+
19
+
20
+ def test_gpu_availability():
21
+ """测试GPU可用性"""
22
+ print_section("🔍 GPU环境检测")
23
+
24
+ cuda_available = torch.cuda.is_available()
25
+ print(f"✅ CUDA可用: {cuda_available}")
26
+
27
+ if cuda_available:
28
+ print(f" GPU数量: {torch.cuda.device_count()}")
29
+ print(f" 当前GPU: {torch.cuda.current_device()}")
30
+ print(f" GPU名称: {torch.cuda.get_device_name(0)}")
31
+ print(f" CUDA版本: {torch.version.cuda}")
32
+
33
+ total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
34
+ print(f" 总显存: {total_memory:.2f} GB")
35
+
36
+ return True
37
+ else:
38
+ print("\n⚠️ 警告: 未检测到GPU")
39
+ print(" 在Colab中启用GPU: 运行时 → 更改运行时类型 → GPU")
40
+ return False
41
+
42
+
43
+ def benchmark_matrix_multiplication(matrix_size=5000):
44
+ """GPU vs CPU 矩阵运算性能测试"""
45
+ print_section("⚡ GPU vs CPU 矩阵运算性能测试")
46
+
47
+ print(f"矩阵大小: {matrix_size}x{matrix_size}\n")
48
+
49
+ # CPU测试
50
+ print("🔵 CPU测试...")
51
+ a_cpu = torch.randn(matrix_size, matrix_size)
52
+ b_cpu = torch.randn(matrix_size, matrix_size)
53
+
54
+ start = time.time()
55
+ c_cpu = torch.mm(a_cpu, b_cpu)
56
+ cpu_time = time.time() - start
57
+ print(f" CPU时间: {cpu_time:.2f} 秒")
58
+
59
+ # GPU测试
60
+ if torch.cuda.is_available():
61
+ print("\n🟢 GPU测试...")
62
+ a_gpu = torch.randn(matrix_size, matrix_size).cuda()
63
+ b_gpu = torch.randn(matrix_size, matrix_size).cuda()
64
+
65
+ # 预热GPU
66
+ _ = torch.mm(a_gpu, b_gpu)
67
+ torch.cuda.synchronize()
68
+
69
+ start = time.time()
70
+ c_gpu = torch.mm(a_gpu, b_gpu)
71
+ torch.cuda.synchronize()
72
+ gpu_time = time.time() - start
73
+ print(f" GPU时间: {gpu_time:.2f} 秒")
74
+
75
+ speedup = cpu_time / gpu_time
76
+ print(f"\n🚀 加速比: {speedup:.2f}x")
77
+ print(f" GPU比CPU快 {speedup:.1f} 倍!")
78
+
79
+ return speedup
80
+ else:
81
+ print("\n⚠️ 跳过GPU测试(GPU不可用)")
82
+ return 1.0
83
+
84
+
85
+ def test_text_embedding_performance():
86
+ """测试文本嵌入性能(需要sentence-transformers)"""
87
+ print_section("📝 文本嵌入性能测试")
88
+
89
+ try:
90
+ from sentence_transformers import SentenceTransformer
91
+
92
+ # 准备测试数据
93
+ test_texts = [
94
+ "Large Language Models are transforming AI",
95
+ "GraphRAG combines knowledge graphs with retrieval",
96
+ "GPU acceleration significantly improves performance",
97
+ "Natural language processing is advancing rapidly",
98
+ ] * 250 # 1000个文本
99
+
100
+ print(f"测试数据: {len(test_texts)} 个文本\n")
101
+
102
+ # CPU测试
103
+ print("🔵 CPU嵌入测试...")
104
+ model_cpu = SentenceTransformer(
105
+ 'sentence-transformers/all-MiniLM-L6-v2',
106
+ device='cpu'
107
+ )
108
+ start = time.time()
109
+ embeddings_cpu = model_cpu.encode(test_texts, show_progress_bar=False, batch_size=32)
110
+ cpu_time = time.time() - start
111
+ print(f" CPU时间: {cpu_time:.2f}秒")
112
+ print(f" 速度: {len(test_texts)/cpu_time:.1f} 文本/秒")
113
+
114
+ # GPU测试
115
+ if torch.cuda.is_available():
116
+ print("\n🟢 GPU嵌入测试...")
117
+ model_gpu = SentenceTransformer(
118
+ 'sentence-transformers/all-MiniLM-L6-v2',
119
+ device='cuda'
120
+ )
121
+ start = time.time()
122
+ embeddings_gpu = model_gpu.encode(test_texts, show_progress_bar=False, batch_size=32)
123
+ gpu_time = time.time() - start
124
+ print(f" GPU时间: {gpu_time:.2f}秒")
125
+ print(f" 速度: {len(test_texts)/gpu_time:.1f} 文本/秒")
126
+
127
+ speedup = cpu_time / gpu_time
128
+ print(f"\n🚀 加速比: {speedup:.2f}x")
129
+ print(f" 节省时间: {cpu_time - gpu_time:.2f}秒")
130
+
131
+ return speedup
132
+ else:
133
+ print("\n⚠️ 跳过GPU测试")
134
+ return 1.0
135
+
136
+ except ImportError:
137
+ print("⚠️ sentence-transformers未安装")
138
+ print(" 安装: pip install sentence-transformers")
139
+ return None
140
+
141
+
142
+ def monitor_gpu_memory():
143
+ """监控GPU显存使用"""
144
+ if not torch.cuda.is_available():
145
+ return
146
+
147
+ print_section("💾 GPU显存使用情况")
148
+
149
+ allocated = torch.cuda.memory_allocated(0) / (1024**3)
150
+ reserved = torch.cuda.memory_reserved(0) / (1024**3)
151
+ total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
152
+
153
+ print(f"已分配: {allocated:.2f} GB")
154
+ print(f"已保留: {reserved:.2f} GB")
155
+ print(f"总显存: {total:.2f} GB")
156
+ print(f"使用率: {(allocated/total)*100:.1f}%")
157
+
158
+
159
+ def generate_performance_report(matrix_speedup, embedding_speedup):
160
+ """生成性能报告"""
161
+ print_section("📈 性能测试总结报告")
162
+
163
+ print("🖥️ 硬件信息:")
164
+ if torch.cuda.is_available():
165
+ print(f" GPU型号: {torch.cuda.get_device_name(0)}")
166
+ print(f" 显存: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
167
+ print(f" CUDA版本: {torch.version.cuda}")
168
+ else:
169
+ print(" ⚠️ GPU不可用")
170
+
171
+ print(f"\n PyTorch版本: {torch.__version__}")
172
+ print(f" Python版本: {sys.version.split()[0]}")
173
+
174
+ print("\n⚡ 性能测试结果:")
175
+ print(f" 矩阵运算加速: {matrix_speedup:.2f}x")
176
+ if embedding_speedup:
177
+ print(f" 文本嵌入加速: {embedding_speedup:.2f}x")
178
+
179
+ print("\n💡 建议:")
180
+ if torch.cuda.is_available():
181
+ print(" ✅ GPU运行良好!")
182
+ print(" ✅ 建议在Colab上运行完整的GraphRAG索引构建")
183
+ print(" ✅ 预计索引构建时间将缩短 3-5 倍")
184
+
185
+ # 估算时间节省
186
+ if embedding_speedup and embedding_speedup > 1:
187
+ print(f"\n⏱️ 时间节省估算:")
188
+ print(f" 100文档CPU耗时: ~15分钟")
189
+ print(f" 100文档GPU耗时: ~{15/embedding_speedup:.1f}分钟")
190
+ print(f" 节省: ~{15 - 15/embedding_speedup:.1f}分钟")
191
+ else:
192
+ print(" ⚠️ 建议启用GPU以获得最佳性能")
193
+ print(" ⚠️ Colab启用GPU: 运行时 → 更改运行时类型 → GPU")
194
+
195
+
196
+ def install_dependencies():
197
+ """安装必要的依赖(仅在Colab中)"""
198
+ try:
199
+ import google.colab
200
+ is_colab = True
201
+ except:
202
+ is_colab = False
203
+
204
+ if is_colab:
205
+ print_section("📦 安装依赖")
206
+ print("检测到Colab环境,安装必要的包...\n")
207
+
208
+ import subprocess
209
+ packages = [
210
+ 'sentence-transformers',
211
+ 'networkx',
212
+ 'python-louvain',
213
+ ]
214
+
215
+ for package in packages:
216
+ try:
217
+ __import__(package.replace('-', '_'))
218
+ print(f"✅ {package} 已安装")
219
+ except ImportError:
220
+ print(f"📥 安装 {package}...")
221
+ subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
222
+ print(f"✅ {package} 安装完成")
223
+
224
+
225
+ def main():
226
+ """主函数"""
227
+ print("\n" + "="*60)
228
+ print("🚀 Google Colab GPU检测和GraphRAG性能测试")
229
+ print("="*60)
230
+
231
+ # 检查是否在Colab中运行
232
+ try:
233
+ import google.colab
234
+ print("\n✅ 运行环境: Google Colab")
235
+ except:
236
+ print("\n⚠️ 警告: 未检测到Colab环境")
237
+ print(" 本脚本专为Google Colab设计")
238
+
239
+ # 安装依赖
240
+ install_dependencies()
241
+
242
+ # 1. GPU检测
243
+ gpu_available = test_gpu_availability()
244
+
245
+ # 2. 矩阵运算性能测试
246
+ matrix_speedup = benchmark_matrix_multiplication(matrix_size=5000)
247
+
248
+ # 3. 文本嵌入性能测试
249
+ embedding_speedup = test_text_embedding_performance()
250
+
251
+ # 4. 显存监控
252
+ if gpu_available:
253
+ monitor_gpu_memory()
254
+
255
+ # 5. 生成报告
256
+ generate_performance_report(matrix_speedup, embedding_speedup)
257
+
258
+ print("\n" + "="*60)
259
+ print("✅ 测试完成!")
260
+ print("="*60)
261
+
262
+ print("\n📚 下一步:")
263
+ print(" 1. 如果GPU测试成功,可以上传完整的adaptive_RAG项目")
264
+ print(" 2. 运行 main_graphrag.py 进行完整的知识图谱构建")
265
+ print(" 3. 享受GPU带来的3-5倍速度提升!")
266
+
267
+
268
+ if __name__ == "__main__":
269
+ main()
colab_quick_test.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Google Colab一键GPU测试脚本
4
+ 复制此文件内容到Colab单元格中直接运行
5
+
6
+ 使用方法:
7
+ 1. 在Colab中创建新笔记本
8
+ 2. 启用GPU (运行时 → 更改运行时类型 → GPU)
9
+ 3. 复制并运行此脚本
10
+ """
11
+
12
+ # ============================================================
13
+ # 🔧 自动安装依赖
14
+ # ============================================================
15
+ print("📦 检查并安装依赖...")
16
+ import subprocess
17
+ import sys
18
+
19
+ def install(package):
20
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])
21
+
22
+ # 检查必要的包
23
+ required_packages = {
24
+ 'torch': 'torch',
25
+ 'sentence_transformers': 'sentence-transformers',
26
+ 'networkx': 'networkx',
27
+ 'numpy': 'numpy'
28
+ }
29
+
30
+ for import_name, package_name in required_packages.items():
31
+ try:
32
+ __import__(import_name)
33
+ print(f"✅ {package_name} 已安装")
34
+ except ImportError:
35
+ print(f"📥 安装 {package_name}...")
36
+ install(package_name)
37
+
38
+ print("\n" + "="*70)
39
+ print("🚀 Google Colab GPU性能测试 - GraphRAG加速验证")
40
+ print("="*70)
41
+
42
+ # ============================================================
43
+ # 1️⃣ GPU检测
44
+ # ============================================================
45
+ import torch
46
+ import time
47
+
48
+ print("\n" + "="*70)
49
+ print("🔍 步骤1: GPU环境检测")
50
+ print("="*70)
51
+
52
+ cuda_available = torch.cuda.is_available()
53
+ print(f"\n{'✅' if cuda_available else '❌'} CUDA可用: {cuda_available}")
54
+
55
+ if cuda_available:
56
+ print(f" 📊 GPU型号: {torch.cuda.get_device_name(0)}")
57
+ print(f" 💾 显存大小: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
58
+ print(f" 🔢 CUDA版本: {torch.version.cuda}")
59
+ print(f" 📈 PyTorch版本: {torch.__version__}")
60
+ else:
61
+ print("\n⚠️ GPU未启用!")
62
+ print(" 请按照以下步骤启用GPU:")
63
+ print(" 1. 点击顶部菜单 '运行时'")
64
+ print(" 2. 选择 '更改运行时类型'")
65
+ print(" 3. 硬件加速器选择 'GPU'")
66
+ print(" 4. 点击 '保存'")
67
+ print(" 5. 重新运行此单元格")
68
+ print("\n⚠️ 测试将继续,但GPU相关测试会被跳过")
69
+
70
+ # ============================================================
71
+ # 2️⃣ 矩阵运算性能测试
72
+ # ============================================================
73
+ print("\n" + "="*70)
74
+ print("⚡ 步骤2: 矩阵运算性能测试")
75
+ print("="*70)
76
+
77
+ matrix_size = 5000
78
+ print(f"\n测试配置: {matrix_size}x{matrix_size} 矩阵乘法\n")
79
+
80
+ # CPU测试
81
+ print("🔵 CPU性能测试...")
82
+ a_cpu = torch.randn(matrix_size, matrix_size)
83
+ b_cpu = torch.randn(matrix_size, matrix_size)
84
+
85
+ start = time.time()
86
+ c_cpu = torch.mm(a_cpu, b_cpu)
87
+ cpu_time = time.time() - start
88
+
89
+ print(f" ⏱️ CPU耗时: {cpu_time:.3f}秒")
90
+
91
+ # GPU测试
92
+ if cuda_available:
93
+ print("\n🟢 GPU性能测试...")
94
+ a_gpu = torch.randn(matrix_size, matrix_size).cuda()
95
+ b_gpu = torch.randn(matrix_size, matrix_size).cuda()
96
+
97
+ # 预热
98
+ _ = torch.mm(a_gpu, b_gpu)
99
+ torch.cuda.synchronize()
100
+
101
+ start = time.time()
102
+ c_gpu = torch.mm(a_gpu, b_gpu)
103
+ torch.cuda.synchronize()
104
+ gpu_time = time.time() - start
105
+
106
+ print(f" ⏱️ GPU耗时: {gpu_time:.3f}秒")
107
+
108
+ speedup = cpu_time / gpu_time
109
+ print(f"\n 🚀 性能提升: {speedup:.1f}x")
110
+ print(f" 💡 GPU比CPU快 {speedup:.1f} 倍!")
111
+
112
+ matrix_speedup = speedup
113
+ else:
114
+ print("\n⚠️ 跳过GPU测试")
115
+ matrix_speedup = 1.0
116
+
117
+ # ============================================================
118
+ # 3️⃣ 文本嵌入性能测试
119
+ # ============================================================
120
+ print("\n" + "="*70)
121
+ print("📝 步骤3: 文本嵌入性能测试 (GraphRAG核心组件)")
122
+ print("="*70)
123
+
124
+ try:
125
+ from sentence_transformers import SentenceTransformer
126
+
127
+ # 准备测试数据
128
+ test_texts = [
129
+ "GraphRAG combines knowledge graphs with retrieval augmented generation",
130
+ "GPU acceleration significantly improves machine learning performance",
131
+ "Large language models benefit from efficient embedding computation",
132
+ "Knowledge graph construction requires entity and relation extraction",
133
+ ] * 250 # 1000条文本
134
+
135
+ print(f"\n测试配置: {len(test_texts)}条文本嵌入\n")
136
+
137
+ # CPU嵌入
138
+ print("🔵 CPU嵌入测试...")
139
+ model_cpu = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cpu')
140
+
141
+ start = time.time()
142
+ embeddings_cpu = model_cpu.encode(test_texts, show_progress_bar=False, batch_size=32)
143
+ cpu_emb_time = time.time() - start
144
+
145
+ print(f" ⏱️ CPU耗时: {cpu_emb_time:.2f}秒")
146
+ print(f" 📊 处理速度: {len(test_texts)/cpu_emb_time:.1f} 文本/秒")
147
+
148
+ # GPU嵌入
149
+ if cuda_available:
150
+ print("\n🟢 GPU嵌入测试...")
151
+ model_gpu = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cuda')
152
+
153
+ start = time.time()
154
+ embeddings_gpu = model_gpu.encode(test_texts, show_progress_bar=False, batch_size=32)
155
+ gpu_emb_time = time.time() - start
156
+
157
+ print(f" ⏱️ GPU耗时: {gpu_emb_time:.2f}秒")
158
+ print(f" 📊 处理速度: {len(test_texts)/gpu_emb_time:.1f} 文本/秒")
159
+
160
+ emb_speedup = cpu_emb_time / gpu_emb_time
161
+ print(f"\n 🚀 性能提升: {emb_speedup:.1f}x")
162
+ print(f" ⏱️ 节省时间: {cpu_emb_time - gpu_emb_time:.2f}秒")
163
+ else:
164
+ print("\n⚠️ 跳过GPU测试")
165
+ emb_speedup = 1.0
166
+
167
+ except ImportError:
168
+ print("\n⚠️ sentence-transformers未安装,跳过此测试")
169
+ emb_speedup = None
170
+
171
+ # ============================================================
172
+ # 4️⃣ GraphRAG场景模拟
173
+ # ============================================================
174
+ print("\n" + "="*70)
175
+ print("🔍 步骤4: GraphRAG实际场景模拟")
176
+ print("="*70)
177
+
178
+ if cuda_available and emb_speedup:
179
+ print("\n模拟GraphRAG索引构建过程...\n")
180
+
181
+ # 假设100个文档块的索引构建
182
+ documents_count = 100
183
+
184
+ # 实体提取时间 (每个文档约1秒)
185
+ entity_extraction_time = documents_count * 1.0
186
+
187
+ # 文本嵌入时间 (基于实际测试)
188
+ # 假设每个文档平均产生10个实体,共1000个实体需要嵌入
189
+ entities_count = documents_count * 10
190
+
191
+ cpu_total_time = entity_extraction_time + (entities_count / (len(test_texts)/cpu_emb_time))
192
+ gpu_total_time = entity_extraction_time + (entities_count / (len(test_texts)/gpu_emb_time))
193
+
194
+ print(f"📊 场景: {documents_count}个文档的GraphRAG索引构建\n")
195
+ print(f"🔵 CPU预计时间:")
196
+ print(f" - 实体提取: {entity_extraction_time/60:.1f}分钟")
197
+ print(f" - 向量嵌入: {(entities_count / (len(test_texts)/cpu_emb_time))/60:.1f}分钟")
198
+ print(f" - 总计: {cpu_total_time/60:.1f}分钟")
199
+
200
+ print(f"\n🟢 GPU预计时间:")
201
+ print(f" - 实体提取: {entity_extraction_time/60:.1f}分钟 (相同)")
202
+ print(f" - 向量嵌入: {(entities_count / (len(test_texts)/gpu_emb_time))/60:.1f}分钟")
203
+ print(f" - 总计: {gpu_total_time/60:.1f}分钟")
204
+
205
+ total_speedup = cpu_total_time / gpu_total_time
206
+ time_saved = (cpu_total_time - gpu_total_time) / 60
207
+
208
+ print(f"\n🚀 整体加速: {total_speedup:.1f}x")
209
+ print(f"⏱️ 节省时间: {time_saved:.1f}分钟")
210
+
211
+ # ============================================================
212
+ # 5️⃣ GPU显存监控
213
+ # ============================================================
214
+ if cuda_available:
215
+ print("\n" + "="*70)
216
+ print("💾 步骤5: GPU显存使用监控")
217
+ print("="*70)
218
+
219
+ allocated = torch.cuda.memory_allocated(0) / (1024**3)
220
+ reserved = torch.cuda.memory_reserved(0) / (1024**3)
221
+ total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
222
+
223
+ print(f"\n 已分配: {allocated:.2f} GB")
224
+ print(f" 已保留: {reserved:.2f} GB")
225
+ print(f" 总显存: {total:.2f} GB")
226
+ print(f" 使用率: {(allocated/total)*100:.1f}%")
227
+
228
+ # ============================================================
229
+ # 6️⃣ 性能总结
230
+ # ============================================================
231
+ print("\n" + "="*70)
232
+ print("📈 最终性能报告")
233
+ print("="*70)
234
+
235
+ print("\n🖥️ 硬件配置:")
236
+ if cuda_available:
237
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
238
+ print(f" 显存: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
239
+ print(f" CUDA: {torch.version.cuda}")
240
+ else:
241
+ print(" ⚠️ GPU未启用")
242
+
243
+ print(f"\n⚡ 性能测试结果:")
244
+ print(f" 矩阵运算加速: {matrix_speedup:.1f}x")
245
+ if emb_speedup:
246
+ print(f" 文本嵌入加速: {emb_speedup:.1f}x")
247
+ if cuda_available:
248
+ print(f" GraphRAG整体加速: {total_speedup:.1f}x")
249
+
250
+ print("\n💡 结论和建议:")
251
+ if cuda_available:
252
+ print(" ✅ GPU性能测试成功!")
253
+ print(" ✅ 强烈建议在Colab GPU环境运行GraphRAG")
254
+ print(f" ✅ 预计可节省 {time_saved:.0f}+ 分钟的索引构建时间")
255
+ print("\n📚 下一步:")
256
+ print(" 1. 上传adaptive_RAG项目文件到Colab")
257
+ print(" 2. 运行 main_graphrag.py 构建完整知识图谱")
258
+ print(" 3. 下载结果到本地使用")
259
+ else:
260
+ print(" ⚠️ 请启用GPU以获得最佳性能")
261
+ print(" ⚠️ 路径: 运行时 → 更改运行时类型 → GPU")
262
+
263
+ print("\n" + "="*70)
264
+ print("✅ 测试完成! 感谢使用GraphRAG GPU测试工具")
265
+ print("="*70)
266
+
267
+ # ============================================================
268
+ # 7️⃣ 可选: 显示nvidia-smi
269
+ # ============================================================
270
+ if cuda_available:
271
+ print("\n📊 nvidia-smi 详细信息:")
272
+ print("="*70)
273
+ import subprocess
274
+ try:
275
+ result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
276
+ print(result.stdout)
277
+ except:
278
+ print("⚠️ 无法执行nvidia-smi命令")
colab_setup_and_run.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Google Colab环境下的GraphRAG完整运行脚本
4
+ 解决Ollama服务启动和GraphRAG运行的问题
5
+
6
+ 使用方法:
7
+ 1. 在Colab中启用GPU
8
+ 2. 复制此文件到Colab
9
+ 3. 运行: !python colab_setup_and_run.py
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import time
15
+ import subprocess
16
+ import signal
17
+ from pathlib import Path
18
+
19
+ print("="*70)
20
+ print("🚀 GraphRAG Colab 自动化部署脚本")
21
+ print("="*70)
22
+
23
+ # ============================================================
24
+ # 1️⃣ 检测Colab环境
25
+ # ============================================================
26
+ def check_colab_environment():
27
+ """检测是否在Colab环境中"""
28
+ try:
29
+ import google.colab
30
+ print("\n✅ 运行环境: Google Colab")
31
+ return True
32
+ except ImportError:
33
+ print("\n⚠️ 警告: 未检测到Colab环境")
34
+ print(" 本脚本为Colab优化,在其他环境可能需要调整")
35
+ return False
36
+
37
+ # ============================================================
38
+ # 2️⃣ 安装Ollama
39
+ # ============================================================
40
+ def install_ollama():
41
+ """在Colab中安装Ollama"""
42
+ print("\n" + "="*70)
43
+ print("📦 步骤1: 安装Ollama")
44
+ print("="*70)
45
+
46
+ # 检查是否已安装
47
+ if os.path.exists("/usr/local/bin/ollama"):
48
+ print("✅ Ollama已安装")
49
+ return True
50
+
51
+ print("\n📥 下载并安装Ollama...")
52
+ try:
53
+ # 下载Ollama安装脚本
54
+ subprocess.run(
55
+ ["curl", "-fsSL", "https://ollama.com/install.sh", "-o", "/tmp/install_ollama.sh"],
56
+ check=True,
57
+ capture_output=True
58
+ )
59
+
60
+ # 执行安装
61
+ subprocess.run(
62
+ ["sh", "/tmp/install_ollama.sh"],
63
+ check=True,
64
+ capture_output=True
65
+ )
66
+
67
+ print("✅ Ollama安装成功")
68
+ return True
69
+
70
+ except subprocess.CalledProcessError as e:
71
+ print(f"❌ Ollama安装失败: {e}")
72
+ return False
73
+
74
+ # ============================================================
75
+ # 3️⃣ 后台启动Ollama服务
76
+ # ============================================================
77
+ def start_ollama_service():
78
+ """在后台启动Ollama服务"""
79
+ print("\n" + "="*70)
80
+ print("🔧 步骤2: 启动Ollama服务")
81
+ print("="*70)
82
+
83
+ print("\n🔄 在后台启动Ollama服务...")
84
+
85
+ # 方法1: 使用subprocess后台运行
86
+ try:
87
+ # 启动Ollama服务(后台)
88
+ ollama_process = subprocess.Popen(
89
+ ["ollama", "serve"],
90
+ stdout=subprocess.PIPE,
91
+ stderr=subprocess.PIPE,
92
+ preexec_fn=os.setpgrp # 创建新的进程组
93
+ )
94
+
95
+ # 等待服务启动
96
+ print("⏳ 等待Ollama服务启动...")
97
+ time.sleep(5)
98
+
99
+ # 检查服务是否运行
100
+ try:
101
+ result = subprocess.run(
102
+ ["curl", "-s", "http://localhost:11434/api/tags"],
103
+ capture_output=True,
104
+ timeout=3
105
+ )
106
+
107
+ if result.returncode == 0:
108
+ print("✅ Ollama服务已启动 (PID: {})".format(ollama_process.pid))
109
+
110
+ # 保存进程ID以便后续管理
111
+ with open("/tmp/ollama.pid", "w") as f:
112
+ f.write(str(ollama_process.pid))
113
+
114
+ return ollama_process
115
+ else:
116
+ print("⚠️ 服务启动可能有问题,继续尝试...")
117
+
118
+ except subprocess.TimeoutExpired:
119
+ print("⚠️ 服务检查超时,但进程已启动")
120
+ return ollama_process
121
+
122
+ except Exception as e:
123
+ print(f"❌ 启动Ollama失败: {e}")
124
+ return None
125
+
126
+ # ============================================================
127
+ # 4️⃣ 下载Mistral模型
128
+ # ============================================================
129
+ def pull_mistral_model():
130
+ """下载Mistral模型"""
131
+ print("\n" + "="*70)
132
+ print("📥 步骤3: 下载Mistral模型")
133
+ print("="*70)
134
+
135
+ print("\n🔄 拉取mistral模型(这可能需要几分钟)...")
136
+
137
+ try:
138
+ # 检查模型是否已存在
139
+ result = subprocess.run(
140
+ ["ollama", "list"],
141
+ capture_output=True,
142
+ text=True,
143
+ timeout=10
144
+ )
145
+
146
+ if "mistral" in result.stdout:
147
+ print("✅ Mistral模型已存在")
148
+ return True
149
+
150
+ # 下载模型
151
+ print("📥 开始下载Mistral模型...")
152
+ process = subprocess.Popen(
153
+ ["ollama", "pull", "mistral"],
154
+ stdout=subprocess.PIPE,
155
+ stderr=subprocess.STDOUT,
156
+ text=True
157
+ )
158
+
159
+ # 实时显示下载进度
160
+ for line in process.stdout:
161
+ print(f" {line.strip()}")
162
+
163
+ process.wait()
164
+
165
+ if process.returncode == 0:
166
+ print("✅ Mistral模型下载完成")
167
+ return True
168
+ else:
169
+ print("❌ 模型下载失败")
170
+ return False
171
+
172
+ except Exception as e:
173
+ print(f"❌ 下载Mistral模型失败: {e}")
174
+ return False
175
+
176
+ # ============================================================
177
+ # 5️⃣ 安装Python依赖
178
+ # ============================================================
179
+ def install_python_dependencies():
180
+ """安装GraphRAG所需的Python包"""
181
+ print("\n" + "="*70)
182
+ print("📦 步骤4: 安装Python依赖")
183
+ print("="*70)
184
+
185
+ packages = [
186
+ "langchain",
187
+ "langchain-community",
188
+ "langchain-core",
189
+ "langgraph",
190
+ "langchain-ollama",
191
+ "chromadb",
192
+ "sentence-transformers",
193
+ "tiktoken",
194
+ "beautifulsoup4",
195
+ "requests",
196
+ "tavily-python",
197
+ "python-dotenv",
198
+ "networkx",
199
+ "python-louvain",
200
+ "torch",
201
+ "transformers"
202
+ ]
203
+
204
+ print("\n📥 安装必要的Python包...")
205
+ for package in packages:
206
+ try:
207
+ __import__(package.replace("-", "_"))
208
+ print(f"✅ {package} 已安装")
209
+ except ImportError:
210
+ print(f"📥 安装 {package}...")
211
+ subprocess.run(
212
+ [sys.executable, "-m", "pip", "install", "-q", package],
213
+ check=True
214
+ )
215
+
216
+ print("\n✅ 所有依赖安装完成")
217
+
218
+ # ============================================================
219
+ # 6️⃣ 配置环境变量
220
+ # ============================================================
221
+ def setup_environment():
222
+ """配置环境变量"""
223
+ print("\n" + "="*70)
224
+ print("🔑 步骤5: 配置环境变量")
225
+ print("="*70)
226
+
227
+ # 检查.env文件
228
+ if os.path.exists(".env"):
229
+ print("\n✅ 发现.env文件,加载配置...")
230
+ from dotenv import load_dotenv
231
+ load_dotenv()
232
+ else:
233
+ print("\n⚠️ 未找到.env文件")
234
+
235
+ # 交互式输入API密钥
236
+ if "TAVILY_API_KEY" not in os.environ:
237
+ from getpass import getpass
238
+ api_key = getpass("请输入TAVILY_API_KEY (或按Enter跳过): ")
239
+ if api_key:
240
+ os.environ["TAVILY_API_KEY"] = api_key
241
+ print("✅ TAVILY_API_KEY已设置")
242
+ else:
243
+ print("⚠️ 跳过TAVILY_API_KEY设置(网络搜索功能将不可用)")
244
+
245
+ print("\n📋 当前环境变量:")
246
+ print(f" TAVILY_API_KEY: {'已设置' if os.environ.get('TAVILY_API_KEY') else '未设置'}")
247
+
248
+ # ============================================================
249
+ # 7️⃣ 运行GraphRAG
250
+ # ============================================================
251
+ def run_graphrag():
252
+ """运行GraphRAG主程序"""
253
+ print("\n" + "="*70)
254
+ print("🚀 步骤6: 运行GraphRAG")
255
+ print("="*70)
256
+
257
+ # 检查main_graphrag.py是否存在
258
+ if not os.path.exists("main_graphrag.py"):
259
+ print("\n❌ 未找到main_graphrag.py文件")
260
+ print(" 请确保已上传项目文件到Colab")
261
+ return False
262
+
263
+ print("\n🔄 启动GraphRAG索引构建...\n")
264
+
265
+ try:
266
+ # 运行GraphRAG
267
+ result = subprocess.run(
268
+ [sys.executable, "main_graphrag.py"],
269
+ capture_output=False, # 实时输出
270
+ text=True
271
+ )
272
+
273
+ if result.returncode == 0:
274
+ print("\n✅ GraphRAG运行成功!")
275
+ return True
276
+ else:
277
+ print(f"\n❌ GraphRAG运行失败 (返回码: {result.returncode})")
278
+ return False
279
+
280
+ except KeyboardInterrupt:
281
+ print("\n⚠️ 用户中断执行")
282
+ return False
283
+ except Exception as e:
284
+ print(f"\n❌ 运行GraphRAG时出错: {e}")
285
+ return False
286
+
287
+ # ============================================================
288
+ # 8️⃣ 清理函数
289
+ # ============================================================
290
+ def cleanup():
291
+ """清理后台进程"""
292
+ print("\n" + "="*70)
293
+ print("🧹 清理后台进程")
294
+ print("="*70)
295
+
296
+ # 停止Ollama服务
297
+ if os.path.exists("/tmp/ollama.pid"):
298
+ try:
299
+ with open("/tmp/ollama.pid", "r") as f:
300
+ pid = int(f.read().strip())
301
+
302
+ os.kill(pid, signal.SIGTERM)
303
+ print(f"✅ Ollama服务已停止 (PID: {pid})")
304
+ os.remove("/tmp/ollama.pid")
305
+
306
+ except Exception as e:
307
+ print(f"⚠️ 停止Ollama服务失败: {e}")
308
+
309
+ # ============================================================
310
+ # 主函数
311
+ # ============================================================
312
+ def main():
313
+ """主执行流程"""
314
+ ollama_process = None
315
+
316
+ try:
317
+ # 1. 检测环境
318
+ is_colab = check_colab_environment()
319
+
320
+ # 2. 安装Ollama
321
+ if not install_ollama():
322
+ print("\n❌ Ollama安装失败,无法继续")
323
+ return
324
+
325
+ # 3. 启动Ollama服务
326
+ ollama_process = start_ollama_service()
327
+ if not ollama_process:
328
+ print("\n❌ Ollama服务启动失败,无法继续")
329
+ return
330
+
331
+ # 4. 下载模型
332
+ if not pull_mistral_model():
333
+ print("\n❌ Mistral模型下载失败,无法继续")
334
+ return
335
+
336
+ # 5. 安装Python依赖
337
+ install_python_dependencies()
338
+
339
+ # 6. 配置环境
340
+ setup_environment()
341
+
342
+ # 7. 运行GraphRAG
343
+ success = run_graphrag()
344
+
345
+ if success:
346
+ print("\n" + "="*70)
347
+ print("✅ 所有任务完成!")
348
+ print("="*70)
349
+
350
+ print("\n📊 生成的文件:")
351
+ if os.path.exists("data/knowledge_graph.json"):
352
+ print(" ✅ data/knowledge_graph.json")
353
+
354
+ # 提供下载选项
355
+ if is_colab:
356
+ print("\n💾 下载结果:")
357
+ print(" from google.colab import files")
358
+ print(" files.download('data/knowledge_graph.json')")
359
+
360
+ except KeyboardInterrupt:
361
+ print("\n\n⚠️ 用户中断执行")
362
+
363
+ except Exception as e:
364
+ print(f"\n❌ 执行过程中出错: {e}")
365
+ import traceback
366
+ traceback.print_exc()
367
+
368
+ finally:
369
+ # 清理
370
+ print("\n⚠️ 注意: Ollama服务仍在后台运行")
371
+ print(" 如需停止: !pkill -f 'ollama serve'")
372
+ print(" 或运行: cleanup()")
373
+
374
+ if __name__ == "__main__":
375
+ main()
config.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 配置和环境设置模块
3
+ 包含API密钥管理、模型配置和URL配置
4
+ """
5
+
6
+ import getpass
7
+ import os
8
+
9
+ # 尝试加载.env文件,如果没有安装dotenv则跳过
10
+ try:
11
+ from dotenv import load_dotenv
12
+ load_dotenv()
13
+ print("✅ .env文件已加载")
14
+ except ImportError:
15
+ print("⚠️ python-dotenv未安装,将使用系统环境变量")
16
+ print("提示:运行 'pip install python-dotenv' 来安装")
17
+
18
+
19
+ def _set_env(var: str):
20
+ """设置环境变量,优先从.env文件读取,如果不存在则提示用户输入"""
21
+ if not os.environ.get(var):
22
+ os.environ[var] = getpass.getpass(f"{var}: ")
23
+
24
+
25
+ def setup_environment():
26
+ """设置所有必需的环境变量"""
27
+ _set_env("TAVILY_API_KEY")
28
+ # 不再需要NOMIC_API_KEY,使用HuggingFace本地嵌入
29
+
30
+ # 验证API密钥是否已设置
31
+ tavily_key = os.environ.get("TAVILY_API_KEY")
32
+
33
+ if tavily_key:
34
+ print("✅ TAVILY_API_KEY 已从环境变量中加载")
35
+ else:
36
+ print("⚠️ TAVILY_API_KEY 未找到")
37
+
38
+
39
+ # 模型配置
40
+ LOCAL_LLM = "mistral"
41
+
42
+ # 知识库URL配置
43
+ KNOWLEDGE_BASE_URLS = [
44
+ "https://lilianweng.github.io/posts/2023-06-23-agent/",
45
+ "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
46
+ "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
47
+ ]
48
+
49
+ # 文档分块配置
50
+ CHUNK_SIZE = 250
51
+ CHUNK_OVERLAP = 0
52
+
53
+ # 向量数据库配置
54
+ COLLECTION_NAME = "rag-chroma"
55
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # HuggingFace嵌入模型
56
+
57
+ # 搜索配置
58
+ WEB_SEARCH_RESULTS_COUNT = 3
59
+
60
+ # GraphRAG配置
61
+ ENABLE_GRAPHRAG = True # 是否启用GraphRAG功能
62
+ GRAPHRAG_INDEX_PATH = "./data/knowledge_graph.json" # 图谱索引保存路径
63
+ GRAPHRAG_COMMUNITY_ALGORITHM = "louvain" # 社区检测算法: louvain, greedy, label_propagation
64
+ GRAPHRAG_MAX_HOPS = 2 # 本地查询最大跳数
65
+ GRAPHRAG_TOP_K_COMMUNITIES = 5 # 全局查询使用的社区数量
66
+ GRAPHRAG_BATCH_SIZE = 10 # 实体提取批处理大小
67
+
68
+
69
+ def get_api_keys():
70
+ """获取API密钥并返回字典"""
71
+ return {
72
+ "tavily": os.environ.get("TAVILY_API_KEY")
73
+ }
74
+
75
+
76
+ def validate_api_keys():
77
+ """验证API密钥是否已设置"""
78
+ keys = get_api_keys()
79
+ missing_keys = []
80
+
81
+ if not keys["tavily"]:
82
+ missing_keys.append("TAVILY_API_KEY")
83
+
84
+ if missing_keys:
85
+ raise ValueError(f"缺少必需的API密钥: {', '.join(missing_keys)}\n请在.env文件中设置这些密钥")
86
+
87
+ return True
deploy_gpu.sh ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # GPU部署脚本 - 一键部署自适应RAG系统到Linux RTX 4090环境
3
+
4
+ set -e # 遇到错误立即退出
5
+
6
+ echo "🚀 开始部署自适应RAG系统到GPU环境..."
7
+
8
+ # 颜色定义
9
+ RED='\033[0;31m'
10
+ GREEN='\033[0;32m'
11
+ YELLOW='\033[1;33m'
12
+ NC='\033[0m' # No Color
13
+
14
+ # 检查是否为root用户
15
+ if [[ $EUID -eq 0 ]]; then
16
+ echo -e "${RED}请不要使用root用户运行此脚本${NC}"
17
+ exit 1
18
+ fi
19
+
20
+ # 检查GPU
21
+ check_gpu() {
22
+ echo "🔍 检查GPU环境..."
23
+ if command -v nvidia-smi &> /dev/null; then
24
+ echo -e "${GREEN}✅ 发现NVIDIA GPU:${NC}"
25
+ nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
26
+ else
27
+ echo -e "${RED}❌ 未发现NVIDIA GPU或驱动未安装${NC}"
28
+ exit 1
29
+ fi
30
+ }
31
+
32
+ # 检查CUDA
33
+ check_cuda() {
34
+ echo "🔍 检查CUDA环境..."
35
+ if command -v nvcc &> /dev/null; then
36
+ echo -e "${GREEN}✅ CUDA版本:${NC}"
37
+ nvcc --version | grep "release"
38
+ else
39
+ echo -e "${YELLOW}⚠️ CUDA未安装或未添加到PATH${NC}"
40
+ fi
41
+ }
42
+
43
+ # 安装Docker
44
+ install_docker() {
45
+ if ! command -v docker &> /dev/null; then
46
+ echo "📦 安装Docker..."
47
+ curl -fsSL https://get.docker.com -o get-docker.sh
48
+ sudo sh get-docker.sh
49
+ sudo usermod -aG docker $USER
50
+ rm get-docker.sh
51
+ echo -e "${GREEN}✅ Docker安装完成${NC}"
52
+ else
53
+ echo -e "${GREEN}✅ Docker已安装${NC}"
54
+ fi
55
+ }
56
+
57
+ # 安装NVIDIA Container Toolkit
58
+ install_nvidia_docker() {
59
+ if ! docker run --rm --gpus all nvidia/cuda:11.0.3-base-ubuntu20.04 nvidia-smi &> /dev/null; then
60
+ echo "🐳 安装NVIDIA Container Toolkit..."
61
+ distribution=$(. /etc/os-release;echo $ID$VERSION_ID)
62
+ curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
63
+ curl -s -L https://nvidia.github.io/libnvidia-container/$distribution/libnvidia-container.list | \
64
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
65
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
66
+
67
+ sudo apt-get update
68
+ sudo apt-get install -y nvidia-container-toolkit
69
+ sudo systemctl restart docker
70
+ echo -e "${GREEN}✅ NVIDIA Container Toolkit安装完成${NC}"
71
+ else
72
+ echo -e "${GREEN}✅ NVIDIA Container Toolkit已配置${NC}"
73
+ fi
74
+ }
75
+
76
+ # 创建环境配置
77
+ setup_env() {
78
+ echo "⚙️ 配置环境变量..."
79
+ if [ ! -f .env ]; then
80
+ cp .env.example .env
81
+ echo -e "${YELLOW}⚠️ 请编辑 .env 文件并设置您的API密钥${NC}"
82
+ echo " - TAVILY_API_KEY: 从 https://tavily.com/ 获取"
83
+ read -p "按回车键继续..."
84
+ fi
85
+ }
86
+
87
+ # 选择部署方式
88
+ choose_deployment() {
89
+ echo "🎯 选择部署方式:"
90
+ echo "1) Docker Compose部署 (推荐)"
91
+ echo "2) 直接Python部署"
92
+ read -p "请选择 (1-2): " choice
93
+
94
+ case $choice in
95
+ 1)
96
+ deploy_docker
97
+ ;;
98
+ 2)
99
+ deploy_python
100
+ ;;
101
+ *)
102
+ echo -e "${RED}无效选择${NC}"
103
+ exit 1
104
+ ;;
105
+ esac
106
+ }
107
+
108
+ # Docker部署
109
+ deploy_docker() {
110
+ echo "🐳 使用Docker Compose部署..."
111
+
112
+ # 安装docker-compose
113
+ if ! command -v docker-compose &> /dev/null; then
114
+ echo "安装Docker Compose..."
115
+ sudo curl -L "https://github.com/docker/compose/releases/download/v2.20.0/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
116
+ sudo chmod +x /usr/local/bin/docker-compose
117
+ fi
118
+
119
+ # 构建并启动服务
120
+ echo "构建镜像..."
121
+ docker-compose -f docker-compose.gpu.yml build
122
+
123
+ echo "启动服务..."
124
+ docker-compose -f docker-compose.gpu.yml up -d
125
+
126
+ echo -e "${GREEN}✅ Docker部署完成!${NC}"
127
+ echo "访问: http://localhost:8000"
128
+ echo "监控: http://localhost:9445 (GPU监控)"
129
+ echo "日志: docker-compose -f docker-compose.gpu.yml logs -f"
130
+ }
131
+
132
+ # Python直接部署
133
+ deploy_python() {
134
+ echo "🐍 使用Python直接部署..."
135
+
136
+ # 检查Python
137
+ if ! command -v python3 &> /dev/null; then
138
+ echo "安装Python3..."
139
+ sudo apt-get update
140
+ sudo apt-get install -y python3 python3-pip python3-venv
141
+ fi
142
+
143
+ # 创建虚拟环境
144
+ if [ ! -d "rag_env" ]; then
145
+ echo "创建Python虚拟环境..."
146
+ python3 -m venv rag_env
147
+ fi
148
+
149
+ # 激活虚拟环境并安装依赖
150
+ source rag_env/bin/activate
151
+ pip install --upgrade pip
152
+ pip install -r requirements_gpu.txt
153
+
154
+ # 安装Ollama
155
+ if ! command -v ollama &> /dev/null; then
156
+ echo "安装Ollama..."
157
+ curl -fsSL https://ollama.ai/install.sh | sh
158
+ fi
159
+
160
+ # 启动Ollama服务
161
+ echo "启动Ollama服务..."
162
+ ollama serve &
163
+ sleep 5
164
+
165
+ # 下载模型
166
+ echo "下载Mistral模型..."
167
+ ollama pull mistral
168
+
169
+ # 创建启动脚本
170
+ cat > start_gpu.sh << 'EOF'
171
+ #!/bin/bash
172
+ export CUDA_VISIBLE_DEVICES=0
173
+ export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
174
+ export TOKENIZERS_PARALLELISM=false
175
+ source rag_env/bin/activate
176
+ python main.py
177
+ EOF
178
+ chmod +x start_gpu.sh
179
+
180
+ echo -e "${GREEN}✅ Python部署完成!${NC}"
181
+ echo "启动命令: ./start_gpu.sh"
182
+ }
183
+
184
+ # 验证部署
185
+ verify_deployment() {
186
+ echo "🔍 验证部署..."
187
+ sleep 10
188
+
189
+ if curl -f http://localhost:8000/health 2>/dev/null; then
190
+ echo -e "${GREEN}✅ 服务运行正常${NC}"
191
+ else
192
+ echo -e "${YELLOW}⚠️ 服务可能还在启动中,请稍后检查${NC}"
193
+ fi
194
+
195
+ # 显示GPU使用情况
196
+ echo "📊 GPU状态:"
197
+ nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits
198
+ }
199
+
200
+ # 显示部署信息
201
+ show_info() {
202
+ echo ""
203
+ echo "🎉 部署完成!"
204
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
205
+ echo "📊 服务地址:"
206
+ echo " - 主服务: http://localhost:8000"
207
+ echo " - Ollama: http://localhost:11434"
208
+ echo ""
209
+ echo "🔧 常用命令:"
210
+ echo " - 查看日志: docker-compose -f docker-compose.gpu.yml logs -f"
211
+ echo " - 重启服务: docker-compose -f docker-compose.gpu.yml restart"
212
+ echo " - 停止服务: docker-compose -f docker-compose.gpu.yml down"
213
+ echo " - GPU监控: watch -n 1 nvidia-smi"
214
+ echo ""
215
+ echo "📚 文档位置:"
216
+ echo " - 部署指南: DEPLOYMENT_GUIDE.md"
217
+ echo " - 快速开始: QUICKSTART.md"
218
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
219
+ }
220
+
221
+ # 主函数
222
+ main() {
223
+ echo "🤖 自适应RAG系统 GPU部署脚本"
224
+ echo "适用于: Linux + RTX 4090"
225
+ echo ""
226
+
227
+ check_gpu
228
+ check_cuda
229
+ install_docker
230
+ install_nvidia_docker
231
+ setup_env
232
+ choose_deployment
233
+ verify_deployment
234
+ show_info
235
+ }
236
+
237
+ # 运行主函数
238
+ main "$@"
239
+
240
+ reactive
docker-compose.gpu.yml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker Compose配置文件 - GPU部署
2
+ version: '3.8'
3
+
4
+ services:
5
+ adaptive-rag:
6
+ build:
7
+ context: .
8
+ dockerfile: Dockerfile.gpu
9
+ container_name: adaptive-rag-gpu
10
+ restart: unless-stopped
11
+ environment:
12
+ - CUDA_VISIBLE_DEVICES=0
13
+ - PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
14
+ - TOKENIZERS_PARALLELISM=false
15
+ - HF_HOME=/app/models
16
+ - TRANSFORMERS_CACHE=/app/models
17
+ env_file:
18
+ - .env
19
+ ports:
20
+ - "8000:8000"
21
+ - "8001:8001" # 可选:监控端口
22
+ volumes:
23
+ - ./data:/app/data
24
+ - ./models:/app/models
25
+ - ./logs:/app/logs
26
+ deploy:
27
+ resources:
28
+ reservations:
29
+ devices:
30
+ - driver: nvidia
31
+ count: 1
32
+ capabilities: [gpu]
33
+ depends_on:
34
+ - ollama
35
+
36
+ ollama:
37
+ image: ollama/ollama:latest
38
+ container_name: ollama-gpu
39
+ restart: unless-stopped
40
+ ports:
41
+ - "11434:11434"
42
+ volumes:
43
+ - ollama-data:/root/.ollama
44
+ deploy:
45
+ resources:
46
+ reservations:
47
+ devices:
48
+ - driver: nvidia
49
+ count: 1
50
+ capabilities: [gpu]
51
+ command: ["ollama", "serve"]
52
+
53
+ # 可选:监控服务
54
+ nvidia-smi-exporter:
55
+ image: mindprince/nvidia_gpu_prometheus_exporter:0.1
56
+ container_name: gpu-monitor
57
+ restart: unless-stopped
58
+ ports:
59
+ - "9445:9445"
60
+ deploy:
61
+ resources:
62
+ reservations:
63
+ devices:
64
+ - driver: nvidia
65
+ count: 1
66
+ capabilities: [gpu]
67
+
68
+ volumes:
69
+ ollama-data:
70
+ driver: local
document_processor.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文档处理和向量化模块
3
+ 负责文档加载、文本分块、向量化和向量数据库初始化
4
+ """
5
+
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain_community.document_loaders import WebBaseLoader
8
+ from langchain_community.vectorstores import Chroma
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+
11
+ from config import (
12
+ KNOWLEDGE_BASE_URLS,
13
+ CHUNK_SIZE,
14
+ CHUNK_OVERLAP,
15
+ COLLECTION_NAME,
16
+ EMBEDDING_MODEL
17
+ )
18
+ from reranker import create_reranker
19
+
20
+
21
+ class DocumentProcessor:
22
+ """文档处理器类,负责文档加载、处理和向量化"""
23
+
24
+ def __init__(self):
25
+ self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
26
+ chunk_size=CHUNK_SIZE,
27
+ chunk_overlap=CHUNK_OVERLAP
28
+ )
29
+
30
+ # Try to initialize embeddings with error handling
31
+ try:
32
+ import torch
33
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
+ print(f"✅ 检测到设备: {device}")
35
+ if device == 'cuda':
36
+ print(f" GPU型号: {torch.cuda.get_device_name(0)}")
37
+ print(f" GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
38
+
39
+ self.embeddings = HuggingFaceEmbeddings(
40
+ model_name="sentence-transformers/all-MiniLM-L6-v2", # 轻量级嵌入模型
41
+ model_kwargs={'device': device}, # 自动选择GPU或CPU
42
+ encode_kwargs={'normalize_embeddings': True} # 标准化嵌入向量
43
+ )
44
+ print(f"✅ HuggingFace嵌入模型初始化成功 (设备: {device})")
45
+ except Exception as e:
46
+ print(f"⚠️ HuggingFace嵌入初始化失败: {e}")
47
+ print("正在尝试备用嵌入方案...")
48
+ # Fallback to OpenAI embeddings or other alternatives
49
+ from langchain_community.embeddings import FakeEmbeddings
50
+ self.embeddings = FakeEmbeddings(size=384) # For testing purposes
51
+ print("✅ 使用测试嵌入模型")
52
+
53
+ self.vectorstore = None
54
+ self.retriever = None
55
+
56
+ # 初始化重排器
57
+ self.reranker = None
58
+ self._setup_reranker()
59
+
60
+ def _setup_reranker(self):
61
+ """设置重排器"""
62
+ try:
63
+ # 使用混合重排器获得最佳效果
64
+ self.reranker = create_reranker('hybrid', self.embeddings)
65
+ print("✅ 重排器初始化成功")
66
+ except Exception as e:
67
+ print(f"⚠️ 重排器初始化失败: {e}")
68
+ print("将使用基础检索,不进行重排")
69
+
70
+ def load_documents(self, urls=None):
71
+ """从URL加载文档"""
72
+ if urls is None:
73
+ urls = KNOWLEDGE_BASE_URLS
74
+
75
+ print(f"正在加载 {len(urls)} 个URL的文档...")
76
+ docs = [WebBaseLoader(url).load() for url in urls]
77
+ docs_list = [item for sublist in docs for item in sublist]
78
+ print(f"成功加载 {len(docs_list)} 个文档")
79
+ return docs_list
80
+
81
+ def split_documents(self, docs):
82
+ """将文档分割成块"""
83
+ print("正在分割文档...")
84
+ doc_splits = self.text_splitter.split_documents(docs)
85
+ print(f"文档分割完成,共 {len(doc_splits)} 个文档块")
86
+ return doc_splits
87
+
88
+ def create_vectorstore(self, doc_splits):
89
+ """创建向量数据库"""
90
+ print("正在创建向量数据库...")
91
+ self.vectorstore = Chroma.from_documents(
92
+ documents=doc_splits,
93
+ collection_name=COLLECTION_NAME,
94
+ embedding=self.embeddings,
95
+ )
96
+ self.retriever = self.vectorstore.as_retriever()
97
+ print("向量数据库创建完成")
98
+ return self.vectorstore, self.retriever
99
+
100
+ def setup_knowledge_base(self, urls=None, enable_graphrag=False):
101
+ """设置完整的知识库(加载、分割、向量化)
102
+
103
+ Args:
104
+ urls: 文档URL列表
105
+ enable_graphrag: 是否启用GraphRAG索引
106
+
107
+ Returns:
108
+ vectorstore, retriever, doc_splits
109
+ """
110
+ docs = self.load_documents(urls)
111
+ doc_splits = self.split_documents(docs)
112
+ vectorstore, retriever = self.create_vectorstore(doc_splits)
113
+
114
+ # 返回doc_splits用于GraphRAG索引
115
+ return vectorstore, retriever, doc_splits
116
+
117
+ def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20):
118
+ """增强检索:先检索更多候选,然后重排"""
119
+ if not self.retriever:
120
+ print("⚠️ 检索器未初始化")
121
+ return []
122
+
123
+ # 1. 初始检索:获取更多候选文档
124
+ initial_docs = self.retriever.get_relevant_documents(query)
125
+
126
+ # 获取更多候选(如果可能)
127
+ if hasattr(self.retriever, 'search_kwargs'):
128
+ # 修改检索参数以获取更多结果
129
+ original_k = self.retriever.search_kwargs.get('k', 4)
130
+ self.retriever.search_kwargs['k'] = min(rerank_candidates, len(initial_docs))
131
+ candidate_docs = self.retriever.get_relevant_documents(query)
132
+ self.retriever.search_kwargs['k'] = original_k # 恢复原设置
133
+ else:
134
+ candidate_docs = initial_docs
135
+
136
+ print(f"初始检索获得 {len(candidate_docs)} 个候选文档")
137
+
138
+ # 2. 重排(如果重排器可用)
139
+ if self.reranker and len(candidate_docs) > top_k:
140
+ try:
141
+ reranked_results = self.reranker.rerank(query, candidate_docs, top_k)
142
+ final_docs = [doc for doc, score in reranked_results]
143
+ scores = [score for doc, score in reranked_results]
144
+
145
+ print(f"重排后返回 {len(final_docs)} 个文档")
146
+ print(f"重排分数范围: {min(scores):.4f} - {max(scores):.4f}")
147
+
148
+ return final_docs
149
+ except Exception as e:
150
+ print(f"⚠️ 重排失败: {e},使用原始检索结果")
151
+ return candidate_docs[:top_k]
152
+ else:
153
+ # 不重排或候选数量不足
154
+ return candidate_docs[:top_k]
155
+
156
+ def compare_retrieval_methods(self, query: str, top_k: int = 5):
157
+ """比较不同检索方法的效果"""
158
+ if not self.retriever:
159
+ return {}
160
+
161
+ # 原始检索
162
+ original_docs = self.retriever.get_relevant_documents(query)[:top_k]
163
+
164
+ # 增强检索(带重排)
165
+ enhanced_docs = self.enhanced_retrieve(query, top_k)
166
+
167
+ return {
168
+ 'query': query,
169
+ 'original_retrieval': {
170
+ 'count': len(original_docs),
171
+ 'documents': [{
172
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
173
+ 'metadata': getattr(doc, 'metadata', {})
174
+ } for doc in original_docs]
175
+ },
176
+ 'enhanced_retrieval': {
177
+ 'count': len(enhanced_docs),
178
+ 'documents': [{
179
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
180
+ 'metadata': getattr(doc, 'metadata', {})
181
+ } for doc in enhanced_docs]
182
+ },
183
+ 'reranker_used': self.reranker is not None
184
+ }
185
+
186
+ def format_docs(self, docs):
187
+ """格式化文档用于生成"""
188
+ return "\n\n".join(doc.page_content for doc in docs)
189
+
190
+
191
+ def initialize_document_processor():
192
+ """初始化文档处理器并设置知识库"""
193
+ processor = DocumentProcessor()
194
+ vectorstore, retriever, doc_splits = processor.setup_knowledge_base()
195
+ return processor, vectorstore, retriever, doc_splits
entity_extractor.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 实体和关系提取模块
3
+ 使用LLM从文档中提取实体、关系和属性,构建知识图谱的基础
4
+ """
5
+
6
+ from typing import List, Dict, Tuple
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain_community.chat_models import ChatOllama
9
+ from langchain_core.output_parsers import JsonOutputParser
10
+ from config import LOCAL_LLM
11
+
12
+
13
+ class EntityExtractor:
14
+ """实体提取器 - 使用LLM从文本中提取实体"""
15
+
16
+ def __init__(self):
17
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
18
+
19
+ # 实体提取提示模板
20
+ self.entity_prompt = PromptTemplate(
21
+ template="""你是一个专业的实体识别专家。从以下文本中提取所有重要的实体。
22
+
23
+ 实体类型包括:
24
+ - PERSON: 人物、作者、研究者
25
+ - ORGANIZATION: 组织、机构、公司
26
+ - CONCEPT: 技术概念、算法、方法论
27
+ - TECHNOLOGY: 具体技术、工具、框架
28
+ - PAPER: 论文、出版物
29
+ - EVENT: 事件、会议
30
+
31
+ 文本内容:
32
+ {text}
33
+
34
+ 请以JSON格式返回,包含以下字段:
35
+ {{
36
+ "entities": [
37
+ {{
38
+ "name": "实体名称",
39
+ "type": "实体类型",
40
+ "description": "简短描述"
41
+ }}
42
+ ]
43
+ }}
44
+
45
+ 不要包含前言或解释,只返回JSON。
46
+ """,
47
+ input_variables=["text"]
48
+ )
49
+
50
+ # 关系提取提示模板
51
+ self.relation_prompt = PromptTemplate(
52
+ template="""你是一个关系抽取专家。从文本中识别实体之间的关系。
53
+
54
+ 已识别的实体:
55
+ {entities}
56
+
57
+ 文本内容:
58
+ {text}
59
+
60
+ 请识别实体之间的关系,以JSON格式返回:
61
+ {{
62
+ "relations": [
63
+ {{
64
+ "source": "源实体名称",
65
+ "target": "目标实体名称",
66
+ "relation_type": "关系类型",
67
+ "description": "关系描述"
68
+ }}
69
+ ]
70
+ }}
71
+
72
+ 关系类型包括: AUTHOR_OF, USES, BASED_ON, RELATED_TO, PART_OF, APPLIES_TO, IMPROVES, CITES
73
+
74
+ 不要包含前言或解释,只返回JSON。
75
+ """,
76
+ input_variables=["text", "entities"]
77
+ )
78
+
79
+ self.entity_chain = self.entity_prompt | self.llm | JsonOutputParser()
80
+ self.relation_chain = self.relation_prompt | self.llm | JsonOutputParser()
81
+
82
+ def extract_entities(self, text: str) -> List[Dict]:
83
+ """
84
+ 从文本中提取实体
85
+
86
+ Args:
87
+ text: 输入文本
88
+
89
+ Returns:
90
+ 实体列表
91
+ """
92
+ try:
93
+ result = self.entity_chain.invoke({"text": text[:2000]}) # 限制长度
94
+ entities = result.get("entities", [])
95
+ print(f"✅ 提取到 {len(entities)} 个实体")
96
+ return entities
97
+ except Exception as e:
98
+ print(f"❌ 实体提取失败: {e}")
99
+ return []
100
+
101
+ def extract_relations(self, text: str, entities: List[Dict]) -> List[Dict]:
102
+ """
103
+ 从文本中提取实体关系
104
+
105
+ Args:
106
+ text: 输入文本
107
+ entities: 已识别的实体列表
108
+
109
+ Returns:
110
+ 关系列表
111
+ """
112
+ try:
113
+ entity_names = [e["name"] for e in entities]
114
+ result = self.relation_chain.invoke({
115
+ "text": text[:2000],
116
+ "entities": ", ".join(entity_names)
117
+ })
118
+ relations = result.get("relations", [])
119
+ print(f"✅ 提取到 {len(relations)} 个关系")
120
+ return relations
121
+ except Exception as e:
122
+ print(f"❌ 关系提取失败: {e}")
123
+ return []
124
+
125
+ def extract_from_document(self, document_text: str) -> Dict:
126
+ """
127
+ 从单个文档中提取实体和关系
128
+
129
+ Args:
130
+ document_text: 文档文本
131
+
132
+ Returns:
133
+ 包含实体和关系的字典
134
+ """
135
+ print("🔍 开始提取实体...")
136
+ entities = self.extract_entities(document_text)
137
+
138
+ print("🔍 开始提取关系...")
139
+ relations = self.extract_relations(document_text, entities)
140
+
141
+ return {
142
+ "entities": entities,
143
+ "relations": relations
144
+ }
145
+
146
+
147
+ class EntityDeduplicator:
148
+ """实体去重和合并"""
149
+
150
+ def __init__(self):
151
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
152
+
153
+ self.merge_prompt = PromptTemplate(
154
+ template="""判断以下两个实体是否指向同一个对象:
155
+
156
+ 实体1: {entity1_name} - {entity1_desc}
157
+ 实体2: {entity2_name} - {entity2_desc}
158
+
159
+ 如果是同一个对象,返回:
160
+ {{
161
+ "is_same": true,
162
+ "canonical_name": "标准名称",
163
+ "reason": "原因"
164
+ }}
165
+
166
+ 如果不是,返回:
167
+ {{
168
+ "is_same": false,
169
+ "reason": "原因"
170
+ }}
171
+
172
+ 只返回JSON,不要其他内容。
173
+ """,
174
+ input_variables=["entity1_name", "entity1_desc", "entity2_name", "entity2_desc"]
175
+ )
176
+
177
+ self.merge_chain = self.merge_prompt | self.llm | JsonOutputParser()
178
+
179
+ def deduplicate_entities(self, entities: List[Dict]) -> Dict:
180
+ """
181
+ 去重实体列表
182
+
183
+ Args:
184
+ entities: 实体列表
185
+
186
+ Returns:
187
+ 包含entities和mapping的字典
188
+ """
189
+ if len(entities) <= 1:
190
+ # 返回字典格式,保持一致性
191
+ entity_mapping = {entity["name"]: entity["name"] for entity in entities} if entities else {}
192
+ return {
193
+ "entities": entities,
194
+ "mapping": entity_mapping
195
+ }
196
+
197
+ print(f"🔄 开始去重 {len(entities)} 个实体...")
198
+
199
+ # 简单的基于名称的去重
200
+ unique_entities = {}
201
+ entity_mapping = {} # 映射别名到标准名称
202
+
203
+ for entity in entities:
204
+ name = entity["name"].lower().strip()
205
+
206
+ # 查找是否有相似实体
207
+ merged = False
208
+ for canonical_name, canonical_entity in unique_entities.items():
209
+ # 简单的字符串匹配(可以用LLM做更智能的判断)
210
+ if name in canonical_name or canonical_name in name:
211
+ entity_mapping[entity["name"]] = canonical_name
212
+ merged = True
213
+ break
214
+
215
+ if not merged:
216
+ unique_entities[name] = entity
217
+ entity_mapping[entity["name"]] = name
218
+
219
+ print(f"✅ 去重完成,剩余 {len(unique_entities)} 个唯一实体")
220
+
221
+ return {
222
+ "entities": list(unique_entities.values()),
223
+ "mapping": entity_mapping
224
+ }
225
+
226
+
227
+ def initialize_entity_extractor():
228
+ """初始化实体提取器"""
229
+ return EntityExtractor()
graph_indexer.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GraphRAG索引器
3
+ 负责构建层次化的知识图谱索引,包括实体提取、图谱构建、社区检测和摘要生成
4
+ """
5
+
6
+ from typing import List, Dict
7
+ from langchain.schema import Document
8
+
9
+ from entity_extractor import EntityExtractor, EntityDeduplicator
10
+ from knowledge_graph import KnowledgeGraph, CommunitySummarizer
11
+
12
+
13
+ class GraphRAGIndexer:
14
+ """GraphRAG索引器 - 实现Microsoft GraphRAG的索引流程"""
15
+
16
+ def __init__(self):
17
+ print("🚀 初始化GraphRAG索引器...")
18
+
19
+ self.entity_extractor = EntityExtractor()
20
+ self.entity_deduplicator = EntityDeduplicator()
21
+ self.knowledge_graph = KnowledgeGraph()
22
+ self.community_summarizer = CommunitySummarizer()
23
+
24
+ self.indexed = False
25
+
26
+ print("✅ GraphRAG索引器初始化完成")
27
+
28
+ def index_documents(self, documents: List[Document],
29
+ batch_size: int = 10,
30
+ save_path: str = None) -> KnowledgeGraph:
31
+ """
32
+ 对文档集合建立GraphRAG索引
33
+
34
+ 工作流程(遵循Microsoft GraphRAG):
35
+ 1. 文档分块(已在document_processor中完成)
36
+ 2. 实体和关系提取
37
+ 3. 实体去重和合并
38
+ 4. 构建知识图谱
39
+ 5. 社区检测
40
+ 6. 生成社区摘要
41
+
42
+ Args:
43
+ documents: 文档列表
44
+ batch_size: 批处理大小
45
+ save_path: 保存路径
46
+
47
+ Returns:
48
+ 构建好的知识图谱
49
+ """
50
+ print(f"\n{'='*50}")
51
+ print(f"📊 开始GraphRAG索引流程")
52
+ print(f" 文档数量: {len(documents)}")
53
+ print(f"{'='*50}\n")
54
+
55
+ # 步骤1: 实体和关系提取
56
+ print("📍 步骤 1/5: 实体和关系提取")
57
+ extraction_results = []
58
+
59
+ for i in range(0, len(documents), batch_size):
60
+ batch = documents[i:i+batch_size]
61
+ print(f" 处理批次 {i//batch_size + 1}/{(len(documents)-1)//batch_size + 1}...")
62
+
63
+ for doc in batch:
64
+ result = self.entity_extractor.extract_from_document(doc.page_content)
65
+ extraction_results.append(result)
66
+
67
+ # 步骤2: 实体去重
68
+ print("\n📍 步骤 2/5: 实体去重和合并")
69
+ all_entities = []
70
+ all_relations = []
71
+
72
+ for result in extraction_results:
73
+ all_entities.extend(result.get("entities", []))
74
+ all_relations.extend(result.get("relations", []))
75
+
76
+ dedup_result = self.entity_deduplicator.deduplicate_entities(all_entities)
77
+ unique_entities = dedup_result["entities"]
78
+ entity_mapping = dedup_result["mapping"]
79
+
80
+ # 更新关系中的实体名称
81
+ mapped_relations = []
82
+ for relation in all_relations:
83
+ source = entity_mapping.get(relation["source"], relation["source"])
84
+ target = entity_mapping.get(relation["target"], relation["target"])
85
+ mapped_relations.append({
86
+ **relation,
87
+ "source": source,
88
+ "target": target
89
+ })
90
+
91
+ # 步骤3: 构建知识图谱
92
+ print("\n📍 步骤 3/5: 构建知识图谱")
93
+ cleaned_results = [{
94
+ "entities": unique_entities,
95
+ "relations": mapped_relations
96
+ }]
97
+ self.knowledge_graph.build_from_extractions(cleaned_results)
98
+
99
+ # 步骤4: 社区检测
100
+ print("\n📍 步骤 4/5: 社区检测")
101
+ self.knowledge_graph.detect_communities(algorithm="louvain")
102
+
103
+ # 步骤5: 生成社区摘要
104
+ print("\n📍 步骤 5/5: 生成社区摘要")
105
+ self.community_summarizer.summarize_all_communities(self.knowledge_graph)
106
+
107
+ # 保存图谱
108
+ if save_path:
109
+ self.knowledge_graph.save_to_file(save_path)
110
+
111
+ self.indexed = True
112
+
113
+ # 打印统计信息
114
+ print(f"\n{'='*50}")
115
+ print("✅ GraphRAG索引构建完成!")
116
+ stats = self.knowledge_graph.get_statistics()
117
+ print(f"\n📊 统计信息:")
118
+ print(f" - 节点数: {stats['num_nodes']}")
119
+ print(f" - 边数: {stats['num_edges']}")
120
+ print(f" - 社区数: {stats['num_communities']}")
121
+ print(f" - 图密度: {stats['density']:.4f}")
122
+ print(f"\n 实体类型分布:")
123
+ for etype, count in stats['entity_types'].items():
124
+ print(f" • {etype}: {count}")
125
+ print(f"{'='*50}\n")
126
+
127
+ return self.knowledge_graph
128
+
129
+ def get_graph(self) -> KnowledgeGraph:
130
+ """获取知识图谱"""
131
+ if not self.indexed:
132
+ print("⚠️ 图谱尚未构建,请先调用 index_documents()")
133
+ return self.knowledge_graph
134
+
135
+ def load_index(self, filepath: str) -> KnowledgeGraph:
136
+ """加载已有的图谱索引"""
137
+ print(f"📂 从文件加载图谱索引: {filepath}")
138
+ self.knowledge_graph.load_from_file(filepath)
139
+ self.indexed = True
140
+ return self.knowledge_graph
141
+
142
+
143
+ def initialize_graph_indexer():
144
+ """初始化GraphRAG索引器"""
145
+ return GraphRAGIndexer()
graph_retriever.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GraphRAG检索器
3
+ 实现基于知识图谱的检索策略,包括本地查询和全局查询
4
+ """
5
+
6
+ from typing import List, Dict, Set, Tuple
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain_community.chat_models import ChatOllama
9
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
10
+
11
+ from knowledge_graph import KnowledgeGraph
12
+ from config import LOCAL_LLM
13
+
14
+
15
+ class GraphRetriever:
16
+ """基于知识图谱的检索器"""
17
+
18
+ def __init__(self, knowledge_graph: KnowledgeGraph):
19
+ self.kg = knowledge_graph
20
+ self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)
21
+
22
+ # 实体识别提示
23
+ self.entity_recognition_prompt = PromptTemplate(
24
+ template="""从以下问题中识别关键实体和概念:
25
+
26
+ 问题: {question}
27
+
28
+ 已知实体示例: {sample_entities}
29
+
30
+ 请识别问题中提到的实体,返回JSON格式:
31
+ {{
32
+ "entities": ["实体1", "实体2", ...]
33
+ }}
34
+
35
+ 只返回JSON,不要其他内容。
36
+ """,
37
+ input_variables=["question", "sample_entities"]
38
+ )
39
+
40
+ # 全局查询生成提示
41
+ self.global_query_prompt = PromptTemplate(
42
+ template="""你是一个知识图谱分析专家。基于以下社区摘要,回答用户问题。
43
+
44
+ 用户问题: {question}
45
+
46
+ 相关社区摘要:
47
+ {community_summaries}
48
+
49
+ 请基于这些摘要提供一个综合性的答案。如果摘要中没有相关信息,请说明。
50
+
51
+ 答案:
52
+ """,
53
+ input_variables=["question", "community_summaries"]
54
+ )
55
+
56
+ # 本地查询生成提示
57
+ self.local_query_prompt = PromptTemplate(
58
+ template="""基于以下实体及其关系信息,回答用户问题。
59
+
60
+ 用户问题: {question}
61
+
62
+ 相关实体信息:
63
+ {entity_info}
64
+
65
+ 实体间的关系:
66
+ {relations}
67
+
68
+ 请基于这些信息提供答案。
69
+
70
+ 答案:
71
+ """,
72
+ input_variables=["question", "entity_info", "relations"]
73
+ )
74
+
75
+ self.entity_recognition_chain = self.entity_recognition_prompt | self.llm | JsonOutputParser()
76
+ self.global_query_chain = self.global_query_prompt | self.llm | StrOutputParser()
77
+ self.local_query_chain = self.local_query_prompt | self.llm | StrOutputParser()
78
+
79
+ def recognize_entities(self, question: str) -> List[str]:
80
+ """
81
+ 从问题中识别实体
82
+
83
+ Args:
84
+ question: 用户问题
85
+
86
+ Returns:
87
+ 识别到的实体列表
88
+ """
89
+ # 获取一些示例实体
90
+ sample_entities = list(self.kg.entities.keys())[:10]
91
+ sample_text = ", ".join(sample_entities)
92
+
93
+ try:
94
+ result = self.entity_recognition_chain.invoke({
95
+ "question": question,
96
+ "sample_entities": sample_text
97
+ })
98
+ entities = result.get("entities", [])
99
+
100
+ # 匹配到图谱中的实体
101
+ matched_entities = []
102
+ for entity in entities:
103
+ # 精确匹配
104
+ if entity in self.kg.entities:
105
+ matched_entities.append(entity)
106
+ else:
107
+ # 模糊匹配
108
+ for kg_entity in self.kg.entities.keys():
109
+ if entity.lower() in kg_entity.lower() or kg_entity.lower() in entity.lower():
110
+ matched_entities.append(kg_entity)
111
+ break
112
+
113
+ print(f"🔍 识别到实体: {matched_entities}")
114
+ return matched_entities
115
+
116
+ except Exception as e:
117
+ print(f"❌ 实体识别失败: {e}")
118
+ return []
119
+
120
+ def local_query(self, question: str, max_hops: int = 2, top_k: int = 10) -> str:
121
+ """
122
+ 本地查询 - 基于问题中的实体及其邻域进行检索
123
+
124
+ 适用场景: 针对特定实体的详细问题
125
+ 例如: "AlphaCodium的作者是谁?"
126
+
127
+ Args:
128
+ question: 用户问题
129
+ max_hops: 最大跳数
130
+ top_k: 返回的最大实体数
131
+
132
+ Returns:
133
+ 答案文本
134
+ """
135
+ print(f"\n🔎 执行本地查询...")
136
+
137
+ # 1. 识别问题中的实体
138
+ mentioned_entities = self.recognize_entities(question)
139
+
140
+ if not mentioned_entities:
141
+ return "未能在知识图谱中找到相关实体。"
142
+
143
+ # 2. 获取实体的邻域
144
+ relevant_entities = set()
145
+ for entity in mentioned_entities:
146
+ neighbors = self.kg.get_node_neighbors(entity, depth=max_hops)
147
+ relevant_entities.update(neighbors)
148
+
149
+ relevant_entities = list(relevant_entities)[:top_k]
150
+
151
+ # 3. 收集实体信息
152
+ entity_info_list = []
153
+ for entity in relevant_entities:
154
+ info = self.kg.get_entity_info(entity)
155
+ if info:
156
+ entity_info_list.append(
157
+ f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', '无描述')}"
158
+ )
159
+
160
+ # 4. 收集关系信息
161
+ relation_list = []
162
+ for u, v, data in self.kg.graph.edges(data=True):
163
+ if u in relevant_entities and v in relevant_entities:
164
+ relation_list.append(
165
+ f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}: {data.get('description', '')}"
166
+ )
167
+
168
+ entity_info_text = "\n".join(entity_info_list) if entity_info_list else "无相关实体信息"
169
+ relations_text = "\n".join(relation_list[:20]) if relation_list else "无相关关系"
170
+
171
+ # 5. 生成答案
172
+ try:
173
+ answer = self.local_query_chain.invoke({
174
+ "question": question,
175
+ "entity_info": entity_info_text,
176
+ "relations": relations_text
177
+ })
178
+ print(f"✅ 本地查询完成")
179
+ return answer.strip()
180
+ except Exception as e:
181
+ print(f"❌ 本地查询失败: {e}")
182
+ return "查询失败,请重试。"
183
+
184
+ def global_query(self, question: str, top_k_communities: int = 5) -> str:
185
+ """
186
+ 全局查询 - 基于社区摘要进行高层次查询
187
+
188
+ 适用场景: 需要整体理解的概括性问题
189
+ 例如: "这些文档主要讨论什么主题?"
190
+
191
+ Args:
192
+ question: 用户问题
193
+ top_k_communities: 使用的社区数量
194
+
195
+ Returns:
196
+ 答案文本
197
+ """
198
+ print(f"\n🌍 执行全局查询...")
199
+
200
+ if not self.kg.community_summaries:
201
+ return "知识图谱尚未生成社区摘要,请先运行索引流程。"
202
+
203
+ # 获取社区摘要
204
+ community_summaries = []
205
+ for cid, summary in list(self.kg.community_summaries.items())[:top_k_communities]:
206
+ community_summaries.append(f"社区 {cid}:\n{summary}\n")
207
+
208
+ summaries_text = "\n".join(community_summaries)
209
+
210
+ # 生成答案
211
+ try:
212
+ answer = self.global_query_chain.invoke({
213
+ "question": question,
214
+ "community_summaries": summaries_text
215
+ })
216
+ print(f"✅ 全局查询完成")
217
+ return answer.strip()
218
+ except Exception as e:
219
+ print(f"❌ 全局查询失败: {e}")
220
+ return "查询失败,请重试。"
221
+
222
+ def hybrid_query(self, question: str) -> Dict[str, str]:
223
+ """
224
+ 混合查询 - 同时执行本地和全局查询,返回两种结果
225
+
226
+ Args:
227
+ question: 用户问题
228
+
229
+ Returns:
230
+ 包含本地和全局查询结果的字典
231
+ """
232
+ print(f"\n🔀 执行混合查询...")
233
+
234
+ local_answer = self.local_query(question)
235
+ global_answer = self.global_query(question)
236
+
237
+ return {
238
+ "local": local_answer,
239
+ "global": global_answer,
240
+ "question": question
241
+ }
242
+
243
+ def smart_query(self, question: str) -> str:
244
+ """
245
+ 智能查询 - 根据问题类型自动选择查询策略
246
+
247
+ Args:
248
+ question: 用户问题
249
+
250
+ Returns:
251
+ 答案文本
252
+ """
253
+ # 判断问题类型
254
+ question_lower = question.lower()
255
+
256
+ # 包含具体实体名称的问题 -> 本地查询
257
+ mentioned_entities = self.recognize_entities(question)
258
+ if mentioned_entities:
259
+ print("📍 检测到具体实体,使用本地查询")
260
+ return self.local_query(question)
261
+
262
+ # 概括性问题 -> 全局查询
263
+ global_keywords = ["主要", "总体", "概述", "整体", "主题", "讨论", "内容", "what", "overview", "main", "topics"]
264
+ if any(keyword in question_lower for keyword in global_keywords):
265
+ print("🌐 检测到概括性问题,使用全局查询")
266
+ return self.global_query(question)
267
+
268
+ # 默认使用本地查询
269
+ print("📍 使用本地查询作为默认策略")
270
+ return self.local_query(question)
271
+
272
+
273
+ def initialize_graph_retriever(knowledge_graph: KnowledgeGraph):
274
+ """初始化GraphRAG检索器"""
275
+ return GraphRetriever(knowledge_graph)
knowledge_graph.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 知识图谱模块
3
+ 实现GraphRAG的核心功能:图谱构建、社区检测、层次化摘要
4
+ """
5
+
6
+ import networkx as nx
7
+ from typing import List, Dict, Set, Tuple, Optional
8
+ from collections import defaultdict
9
+ import json
10
+
11
+ try:
12
+ from community import community_louvain # python-louvain
13
+ LOUVAIN_AVAILABLE = True
14
+ except ImportError:
15
+ LOUVAIN_AVAILABLE = False
16
+ print("⚠️ python-louvain未安装,社区检测功能受限")
17
+
18
+ from langchain.prompts import PromptTemplate
19
+ from langchain_community.chat_models import ChatOllama
20
+ from langchain_core.output_parsers import StrOutputParser
21
+ from config import LOCAL_LLM
22
+
23
+
24
+ class KnowledgeGraph:
25
+ """知识图谱类 - 使用NetworkX构建和管理图谱"""
26
+
27
+ def __init__(self):
28
+ self.graph = nx.Graph() # 无向图
29
+ self.entities = {} # 实体详细信息
30
+ self.communities = {} # 社区划分结果
31
+ self.community_summaries = {} # 社区摘要
32
+
33
+ def add_entity(self, name: str, entity_type: str, description: str = "", **kwargs):
34
+ """添加实体节点"""
35
+ self.graph.add_node(
36
+ name,
37
+ type=entity_type,
38
+ description=description,
39
+ **kwargs
40
+ )
41
+ self.entities[name] = {
42
+ "name": name,
43
+ "type": entity_type,
44
+ "description": description,
45
+ **kwargs
46
+ }
47
+
48
+ def add_relation(self, source: str, target: str, relation_type: str,
49
+ description: str = "", weight: float = 1.0):
50
+ """添加关系边"""
51
+ self.graph.add_edge(
52
+ source,
53
+ target,
54
+ relation_type=relation_type,
55
+ description=description,
56
+ weight=weight
57
+ )
58
+
59
+ def build_from_extractions(self, extraction_results: List[Dict]):
60
+ """
61
+ 从实体提取结果构建图谱
62
+
63
+ Args:
64
+ extraction_results: 实体和关系提取结果列表
65
+ """
66
+ print("🔨 开始构建知识图谱...")
67
+
68
+ total_entities = 0
69
+ total_relations = 0
70
+
71
+ for result in extraction_results:
72
+ # 添加实体
73
+ entities = result.get("entities", [])
74
+ for entity in entities:
75
+ self.add_entity(
76
+ name=entity["name"],
77
+ entity_type=entity.get("type", "UNKNOWN"),
78
+ description=entity.get("description", "")
79
+ )
80
+ total_entities += 1
81
+
82
+ # 添加关系
83
+ relations = result.get("relations", [])
84
+ for relation in relations:
85
+ source = relation.get("source")
86
+ target = relation.get("target")
87
+
88
+ # 确保节点存在
89
+ if source in self.graph and target in self.graph:
90
+ self.add_relation(
91
+ source=source,
92
+ target=target,
93
+ relation_type=relation.get("relation_type", "RELATED_TO"),
94
+ description=relation.get("description", "")
95
+ )
96
+ total_relations += 1
97
+
98
+ print(f"✅ 图谱构建完成: {total_entities} 个实体, {total_relations} 个关系")
99
+ print(f" 实际节点数: {self.graph.number_of_nodes()}")
100
+ print(f" 实际边数: {self.graph.number_of_edges()}")
101
+
102
+ def detect_communities(self, algorithm: str = "louvain") -> Dict[str, int]:
103
+ """
104
+ 社区检测 - GraphRAG的核心组件
105
+
106
+ Args:
107
+ algorithm: 社区检测算法 ('louvain', 'greedy', 'label_propagation')
108
+
109
+ Returns:
110
+ 节点到社区ID的映射
111
+ """
112
+ print(f"🔍 开始社区检测 (算法: {algorithm})...")
113
+
114
+ if self.graph.number_of_nodes() == 0:
115
+ print("⚠️ 图谱为空,跳过社区检测")
116
+ return {}
117
+
118
+ try:
119
+ if algorithm == "louvain" and LOUVAIN_AVAILABLE:
120
+ communities = community_louvain.best_partition(self.graph)
121
+ elif algorithm == "greedy":
122
+ communities_generator = nx.community.greedy_modularity_communities(self.graph)
123
+ communities = {}
124
+ for idx, community_set in enumerate(communities_generator):
125
+ for node in community_set:
126
+ communities[node] = idx
127
+ elif algorithm == "label_propagation":
128
+ communities_generator = nx.community.label_propagation_communities(self.graph)
129
+ communities = {}
130
+ for idx, community_set in enumerate(communities_generator):
131
+ for node in community_set:
132
+ communities[node] = idx
133
+ else:
134
+ print(f"⚠️ 未知算法 {algorithm},使用贪婪算法")
135
+ communities_generator = nx.community.greedy_modularity_communities(self.graph)
136
+ communities = {}
137
+ for idx, community_set in enumerate(communities_generator):
138
+ for node in community_set:
139
+ communities[node] = idx
140
+
141
+ self.communities = communities
142
+ num_communities = len(set(communities.values()))
143
+ print(f"✅ 检测到 {num_communities} 个社区")
144
+
145
+ return communities
146
+
147
+ except Exception as e:
148
+ print(f"❌ 社区检测失败: {e}")
149
+ return {}
150
+
151
+ def get_community_members(self, community_id: int) -> List[str]:
152
+ """获取指定社区的所有成员"""
153
+ return [node for node, cid in self.communities.items() if cid == community_id]
154
+
155
+ def get_community_subgraph(self, community_id: int) -> nx.Graph:
156
+ """获取指定社区的子图"""
157
+ members = self.get_community_members(community_id)
158
+ return self.graph.subgraph(members)
159
+
160
+ def get_node_neighbors(self, node: str, depth: int = 1) -> Set[str]:
161
+ """获取节点的邻居(支持多跳)"""
162
+ if node not in self.graph:
163
+ return set()
164
+
165
+ neighbors = {node}
166
+ current_layer = {node}
167
+
168
+ for _ in range(depth):
169
+ next_layer = set()
170
+ for n in current_layer:
171
+ next_layer.update(self.graph.neighbors(n))
172
+ neighbors.update(next_layer)
173
+ current_layer = next_layer
174
+
175
+ return neighbors
176
+
177
+ def get_entity_info(self, entity_name: str) -> Optional[Dict]:
178
+ """获取实体详细信息"""
179
+ return self.entities.get(entity_name)
180
+
181
+ def search_entities_by_type(self, entity_type: str) -> List[str]:
182
+ """按类型搜索实体"""
183
+ return [
184
+ name for name, data in self.entities.items()
185
+ if data.get("type") == entity_type
186
+ ]
187
+
188
+ def get_statistics(self) -> Dict:
189
+ """获取图谱统计信息"""
190
+ stats = {
191
+ "num_nodes": self.graph.number_of_nodes(),
192
+ "num_edges": self.graph.number_of_edges(),
193
+ "num_communities": len(set(self.communities.values())) if self.communities else 0,
194
+ "density": nx.density(self.graph) if self.graph.number_of_nodes() > 0 else 0,
195
+ "entity_types": {}
196
+ }
197
+
198
+ # 统计实体类型分布
199
+ for entity in self.entities.values():
200
+ etype = entity.get("type", "UNKNOWN")
201
+ stats["entity_types"][etype] = stats["entity_types"].get(etype, 0) + 1
202
+
203
+ return stats
204
+
205
+ def save_to_file(self, filepath: str):
206
+ """保存图谱到文件"""
207
+ data = {
208
+ "entities": self.entities,
209
+ "edges": [
210
+ {
211
+ "source": u,
212
+ "target": v,
213
+ "data": data
214
+ }
215
+ for u, v, data in self.graph.edges(data=True)
216
+ ],
217
+ "communities": self.communities,
218
+ "community_summaries": self.community_summaries
219
+ }
220
+
221
+ with open(filepath, 'w', encoding='utf-8') as f:
222
+ json.dump(data, f, ensure_ascii=False, indent=2)
223
+
224
+ print(f"✅ 图谱已保存到: {filepath}")
225
+
226
+ def load_from_file(self, filepath: str):
227
+ """从文件加载图谱"""
228
+ with open(filepath, 'r', encoding='utf-8') as f:
229
+ data = json.load(f)
230
+
231
+ self.entities = data.get("entities", {})
232
+ self.communities = data.get("communities", {})
233
+ self.community_summaries = data.get("community_summaries", {})
234
+
235
+ # 重建图
236
+ self.graph.clear()
237
+ for name, entity in self.entities.items():
238
+ self.add_entity(**entity)
239
+
240
+ for edge in data.get("edges", []):
241
+ self.graph.add_edge(
242
+ edge["source"],
243
+ edge["target"],
244
+ **edge["data"]
245
+ )
246
+
247
+ print(f"✅ 图谱已从文件加载: {filepath}")
248
+
249
+
250
+ class CommunitySummarizer:
251
+ """社区摘要生成器 - GraphRAG的关键组件"""
252
+
253
+ def __init__(self):
254
+ self.llm = ChatOllama(model=LOCAL_LLM, temperature=0.3)
255
+
256
+ self.summary_prompt = PromptTemplate(
257
+ template="""你是一个知识图谱分析专家。请为以下社区生成一个综合摘要。
258
+
259
+ 社区成员(实体):
260
+ {entities}
261
+
262
+ 实体间的关系:
263
+ {relations}
264
+
265
+ 请生成一个简洁的摘要,描述:
266
+ 1. 这个社区的主题是什么
267
+ 2. 主要包含哪些核心概念
268
+ 3. 实体之间的关键关系
269
+
270
+ 摘要(2-3句话):
271
+ """,
272
+ input_variables=["entities", "relations"]
273
+ )
274
+
275
+ self.summary_chain = self.summary_prompt | self.llm | StrOutputParser()
276
+
277
+ def summarize_community(self, kg: KnowledgeGraph, community_id: int) -> str:
278
+ """
279
+ 为指定社区生成摘要
280
+
281
+ Args:
282
+ kg: 知识图谱对象
283
+ community_id: 社区ID
284
+
285
+ Returns:
286
+ 社区摘要文本
287
+ """
288
+ members = kg.get_community_members(community_id)
289
+ subgraph = kg.get_community_subgraph(community_id)
290
+
291
+ # 准备实体信息
292
+ entity_info = []
293
+ for member in members[:20]: # 限制数量
294
+ info = kg.get_entity_info(member)
295
+ if info:
296
+ entity_info.append(
297
+ f"- {info['name']} ({info.get('type', 'UNKNOWN')}): {info.get('description', '无描述')}"
298
+ )
299
+
300
+ # 准备关系信息
301
+ relation_info = []
302
+ for u, v, data in subgraph.edges(data=True):
303
+ relation_info.append(
304
+ f"- {u} --[{data.get('relation_type', 'RELATED')}]--> {v}"
305
+ )
306
+
307
+ entities_text = "\n".join(entity_info) if entity_info else "无实体"
308
+ relations_text = "\n".join(relation_info[:15]) if relation_info else "无关系"
309
+
310
+ try:
311
+ summary = self.summary_chain.invoke({
312
+ "entities": entities_text,
313
+ "relations": relations_text
314
+ })
315
+ return summary.strip()
316
+ except Exception as e:
317
+ print(f"❌ 社区 {community_id} 摘要生成失败: {e}")
318
+ return f"社区{community_id}: 包含{len(members)}个实体"
319
+
320
+ def summarize_all_communities(self, kg: KnowledgeGraph) -> Dict[int, str]:
321
+ """为所有社区生成摘要"""
322
+ if not kg.communities:
323
+ print("⚠️ 未检测到社区,请先运行社区检测")
324
+ return {}
325
+
326
+ community_ids = set(kg.communities.values())
327
+ print(f"📝 开始为 {len(community_ids)} 个社区生成摘要...")
328
+
329
+ summaries = {}
330
+ for cid in community_ids:
331
+ print(f" 处理社区 {cid}...")
332
+ summary = self.summarize_community(kg, cid)
333
+ summaries[cid] = summary
334
+ kg.community_summaries[cid] = summary
335
+
336
+ print("✅ 所有社区摘要生成完成")
337
+ return summaries
338
+
339
+
340
+ def initialize_knowledge_graph():
341
+ """初始化知识图谱"""
342
+ return KnowledgeGraph()
343
+
344
+
345
+ def initialize_community_summarizer():
346
+ """初始化社区摘要生成器"""
347
+ return CommunitySummarizer()
local_llm_rag.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ """
4
+ 源文件:local_llm_rag.py
5
+ This is a simple example of how to use LangChain to build a local LLM RAG system.
6
+ """
7
+ import getpass
8
+ import os
9
+
10
+
11
+ def _set_env(var: str):
12
+ if not os.environ.get(var):
13
+ os.environ[var] = getpass.getpass(f"{var}: ")
14
+
15
+
16
+ _set_env("TAVILY_API_KEY")
17
+ _set_env("NOMIC_API_KEY")
18
+
19
+ # Ollama model name
20
+ local_llm = "mistral"
21
+
22
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
23
+ from langchain_community.document_loaders import WebBaseLoader
24
+ from langchain_community.vectorstores import Chroma
25
+ from langchain_nomic.embeddings import NomicEmbeddings
26
+
27
+ urls = [
28
+ "https://lilianweng.github.io/posts/2023-06-23-agent/",
29
+ "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
30
+ "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
31
+ ]
32
+
33
+ docs = [WebBaseLoader(url).load() for url in urls]
34
+ docs_list = [item for sublist in docs for item in sublist]
35
+
36
+ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
37
+ chunk_size=250, chunk_overlap=0
38
+ )
39
+ doc_splits = text_splitter.split_documents(docs_list)
40
+
41
+ # Add to vectorDB
42
+ vectorstore = Chroma.from_documents(
43
+ documents=doc_splits,
44
+ collection_name="rag-chroma",
45
+ embedding=NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local"),
46
+ )
47
+ retriever = vectorstore.as_retriever()
48
+
49
+ ### Router
50
+
51
+ from langchain.prompts import PromptTemplate
52
+ from langchain_community.chat_models import ChatOllama
53
+ from langchain_core.output_parsers import JsonOutputParser
54
+
55
+ # LLM
56
+ llm = ChatOllama(model=local_llm, format="json", temperature=0)
57
+
58
+ prompt = PromptTemplate(
59
+ template="""You are an expert at routing a user question to a vectorstore or web search. \n
60
+ Use the vectorstore for questions on LLM agents, prompt engineering, and adversarial attacks. \n
61
+ You do not need to be stringent with the keywords in the question related to these topics. \n
62
+ Otherwise, use web-search. Give a binary choice 'web_search' or 'vectorstore' based on the question. \n
63
+ Return the a JSON with a single key 'datasource' and no premable or explanation. \n
64
+ Question to route: {question}""",
65
+ input_variables=["question"],
66
+ )
67
+
68
+ question_router = prompt | llm | JsonOutputParser()
69
+ question = "llm agent memory"
70
+ docs = retriever.get_relevant_documents(question)
71
+ doc_txt = docs[1].page_content
72
+ print(question_router.invoke({"question": question}))
73
+
74
+
75
+ ### Generate
76
+
77
+ from langchain import hub
78
+ from langchain_community.chat_models import ChatOllama
79
+ from langchain_core.output_parsers import StrOutputParser
80
+
81
+ # Prompt
82
+ prompt = hub.pull("rlm/rag-prompt")
83
+
84
+ # LLM
85
+ llm = ChatOllama(model=local_llm, temperature=0)
86
+
87
+
88
+ # Post-processing
89
+ def format_docs(docs):
90
+ return "\n\n".join(doc.page_content for doc in docs)
91
+
92
+
93
+ # Chain
94
+ rag_chain = prompt | llm | StrOutputParser()
95
+
96
+ # Run
97
+ question = "agent memory"
98
+ generation = rag_chain.invoke({"context": docs, "question": question})
99
+ print(generation)
100
+
101
+ ### Answer Grader
102
+
103
+ # LLM
104
+ llm = ChatOllama(model=local_llm, format="json", temperature=0)
105
+
106
+ # Prompt
107
+ prompt = PromptTemplate(
108
+ template="""You are a grader assessing whether an answer is useful to resolve a question. \n
109
+ Here is the answer:
110
+ \n ------- \n
111
+ {generation}
112
+ \n ------- \n
113
+ Here is the question: {question}
114
+ Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question. \n
115
+ Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
116
+ input_variables=["generation", "question"],
117
+ )
118
+
119
+ answer_grader = prompt | llm | JsonOutputParser()
120
+ answer_grader.invoke({"question": question, "generation": generation})
121
+
122
+ ### Question Re-writer
123
+
124
+ # LLM
125
+ llm = ChatOllama(model=local_llm, temperature=0)
126
+
127
+ # Prompt
128
+ re_write_prompt = PromptTemplate(
129
+ template="""You a question re-writer that converts an input question to a better version that is optimized \n
130
+ for vectorstore retrieval. Look at the initial and formulate an improved question. \n
131
+ Here is the initial question: \n\n {question}. Improved question with no preamble: \n """,
132
+ input_variables=["generation", "question"],
133
+ )
134
+
135
+ question_rewriter = re_write_prompt | llm | StrOutputParser()
136
+ question_rewriter.invoke({"question": question})
137
+
138
+
139
+ ### Search
140
+
141
+ from langchain_community.tools.tavily_search import TavilySearchResults
142
+
143
+ web_search_tool = TavilySearchResults(k=3)
144
+
145
+ from typing import List
146
+
147
+ from typing_extensions import TypedDict
148
+
149
+
150
+ class GraphState(TypedDict):
151
+ """
152
+ Represents the state of our graph.
153
+
154
+ Attributes:
155
+ question: question
156
+ generation: LLM generation
157
+ documents: list of documents
158
+ """
159
+
160
+ question: str
161
+ generation: str
162
+ documents: List[str]
163
+
164
+ from langchain.schema import Document
165
+
166
+
167
+ def retrieve(state):
168
+ """
169
+ Retrieve documents
170
+
171
+ Args:
172
+ state (dict): The current graph state
173
+
174
+ Returns:
175
+ state (dict): New key added to state, documents, that contains retrieved documents
176
+ """
177
+ print("---RETRIEVE---")
178
+ question = state["question"]
179
+
180
+ # Retrieval
181
+ documents = retriever.get_relevant_documents(question)
182
+ return {"documents": documents, "question": question}
183
+
184
+
185
+ def generate(state):
186
+ """
187
+ Generate answer
188
+
189
+ Args:
190
+ state (dict): The current graph state
191
+
192
+ Returns:
193
+ state (dict): New key added to state, generation, that contains LLM generation
194
+ """
195
+ print("---GENERATE---")
196
+ question = state["question"]
197
+ documents = state["documents"]
198
+
199
+ # RAG generation
200
+ generation = rag_chain.invoke({"context": documents, "question": question})
201
+ return {"documents": documents, "question": question, "generation": generation}
202
+
203
+
204
+ def grade_documents(state):
205
+ """
206
+ Determines whether the retrieved documents are relevant to the question.
207
+
208
+ Args:
209
+ state (dict): The current graph state
210
+
211
+ Returns:
212
+ state (dict): Updates documents key with only filtered relevant documents
213
+ """
214
+
215
+ print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
216
+ question = state["question"]
217
+ documents = state["documents"]
218
+
219
+ # Score each doc
220
+ filtered_docs = []
221
+ for d in documents:
222
+ score = retrieval_grader.invoke(
223
+ {"question": question, "document": d.page_content}
224
+ )
225
+ grade = score["score"]
226
+ if grade == "yes":
227
+ print("---GRADE: DOCUMENT RELEVANT---")
228
+ filtered_docs.append(d)
229
+ else:
230
+ print("---GRADE: DOCUMENT NOT RELEVANT---")
231
+ continue
232
+ return {"documents": filtered_docs, "question": question}
233
+
234
+
235
+ def transform_query(state):
236
+ """
237
+ Transform the query to produce a better question.
238
+
239
+ Args:
240
+ state (dict): The current graph state
241
+
242
+ Returns:
243
+ state (dict): Updates question key with a re-phrased question
244
+ """
245
+
246
+ print("---TRANSFORM QUERY---")
247
+ question = state["question"]
248
+ documents = state["documents"]
249
+
250
+ # Re-write question
251
+ better_question = question_rewriter.invoke({"question": question})
252
+ return {"documents": documents, "question": better_question}
253
+
254
+
255
+ def web_search(state):
256
+ """
257
+ Web search based on the re-phrased question.
258
+
259
+ Args:
260
+ state (dict): The current graph state
261
+
262
+ Returns:
263
+ state (dict): Updates documents key with appended web results
264
+ """
265
+
266
+ print("---WEB SEARCH---")
267
+ question = state["question"]
268
+
269
+ # Web search
270
+ docs = web_search_tool.invoke({"query": question})
271
+ web_results = "\n".join([d["content"] for d in docs])
272
+ web_results = Document(page_content=web_results)
273
+
274
+ return {"documents": web_results, "question": question}
275
+
276
+
277
+ ### Edges ###
278
+
279
+
280
+ def route_question(state):
281
+ """
282
+ Route question to web search or RAG.
283
+
284
+ Args:
285
+ state (dict): The current graph state
286
+
287
+ Returns:
288
+ str: Next node to call
289
+ """
290
+
291
+ print("---ROUTE QUESTION---")
292
+ question = state["question"]
293
+ print(question)
294
+ source = question_router.invoke({"question": question})
295
+ print(source)
296
+ print(source["datasource"])
297
+ if source["datasource"] == "web_search":
298
+ print("---ROUTE QUESTION TO WEB SEARCH---")
299
+ return "web_search"
300
+ elif source["datasource"] == "vectorstore":
301
+ print("---ROUTE QUESTION TO RAG---")
302
+ return "vectorstore"
303
+
304
+
305
+ def decide_to_generate(state):
306
+ """
307
+ Determines whether to generate an answer, or re-generate a question.
308
+
309
+ Args:
310
+ state (dict): The current graph state
311
+
312
+ Returns:
313
+ str: Binary decision for next node to call
314
+ """
315
+
316
+ print("---ASSESS GRADED DOCUMENTS---")
317
+ state["question"]
318
+ filtered_documents = state["documents"]
319
+
320
+ if not filtered_documents:
321
+ # All documents have been filtered check_relevance
322
+ # We will re-generate a new query
323
+ print(
324
+ "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
325
+ )
326
+ return "transform_query"
327
+ else:
328
+ # We have relevant documents, so generate answer
329
+ print("---DECISION: GENERATE---")
330
+ return "generate"
331
+
332
+
333
+ def grade_generation_v_documents_and_question(state):
334
+ """
335
+ Determines whether the generation is grounded in the document and answers question.
336
+
337
+ Args:
338
+ state (dict): The current graph state
339
+
340
+ Returns:
341
+ str: Decision for next node to call
342
+ """
343
+
344
+ print("---CHECK HALLUCINATIONS---")
345
+ question = state["question"]
346
+ documents = state["documents"]
347
+ generation = state["generation"]
348
+
349
+ score = hallucination_grader.invoke(
350
+ {"documents": documents, "generation": generation}
351
+ )
352
+ grade = score["score"]
353
+
354
+ # Check hallucination
355
+ if grade == "yes":
356
+ print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
357
+ # Check question-answering
358
+ print("---GRADE GENERATION vs QUESTION---")
359
+ score = answer_grader.invoke({"question": question, "generation": generation})
360
+ grade = score["score"]
361
+ if grade == "yes":
362
+ print("---DECISION: GENERATION ADDRESSES QUESTION---")
363
+ return "useful"
364
+ else:
365
+ print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
366
+ return "not useful"
367
+ else:
368
+ pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
369
+ return "not supported"
370
+
371
+ from langgraph.graph import END, StateGraph, START
372
+
373
+ workflow = StateGraph(GraphState)
374
+
375
+ # Define the nodes
376
+ workflow.add_node("web_search", web_search) # web search
377
+ workflow.add_node("retrieve", retrieve) # retrieve
378
+ workflow.add_node("grade_documents", grade_documents) # grade documents
379
+ workflow.add_node("generate", generate) # generate
380
+ workflow.add_node("transform_query", transform_query) # transform_query
381
+
382
+ # Build graph
383
+ workflow.add_conditional_edges(
384
+ START,
385
+ route_question,
386
+ {
387
+ "web_search": "web_search",
388
+ "vectorstore": "retrieve",
389
+ },
390
+ )
391
+ workflow.add_edge("web_search", "generate")
392
+ workflow.add_edge("retrieve", "grade_documents")
393
+ workflow.add_conditional_edges(
394
+ "grade_documents",
395
+ decide_to_generate,
396
+ {
397
+ "transform_query": "transform_query",
398
+ "generate": "generate",
399
+ },
400
+ )
401
+ workflow.add_edge("transform_query", "retrieve")
402
+ workflow.add_conditional_edges(
403
+ "generate",
404
+ grade_generation_v_documents_and_question,
405
+ {
406
+ "not supported": "generate",
407
+ "useful": END,
408
+ "not useful": "transform_query",
409
+ },
410
+ )
411
+
412
+ # Compile
413
+ app = workflow.compile()
414
+
415
+ from pprint import pprint
416
+
417
+ # Run
418
+ inputs = {"question": "What is the AlphaCodium paper about?"}
419
+ for output in app.stream(inputs):
420
+ for key, value in output.items():
421
+ # Node
422
+ pprint(f"Node '{key}':")
423
+ # Optional: print full state at each node
424
+ # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
425
+ pprint("\n---\n")
426
+
427
+ # Final generation
428
+ pprint(value["generation"])
main.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 主应用程序入口
3
+ 集成所有模块,构建工作流并运行自适应RAG系统
4
+ """
5
+
6
+ from langgraph.graph import END, StateGraph, START
7
+ from pprint import pprint
8
+
9
+ from config import setup_environment
10
+ from document_processor import initialize_document_processor
11
+ from routers_and_graders import initialize_graders_and_router
12
+ from workflow_nodes import WorkflowNodes, GraphState
13
+
14
+
15
+ class AdaptiveRAGSystem:
16
+ """自适应RAG系统主类"""
17
+
18
+ def __init__(self):
19
+ print("初始化自适应RAG系统...")
20
+
21
+ # 设置环境和验证API密钥
22
+ try:
23
+ setup_environment()
24
+
25
+ print("✅ API密钥验证成功")
26
+ except ValueError as e:
27
+ print(f"❌ {e}")
28
+ raise
29
+
30
+ # 初始化文档处理器
31
+ print("设置文档处理器...")
32
+ self.doc_processor, self.vectorstore, self.retriever = initialize_document_processor()
33
+
34
+ # 初始化评分器和路由器
35
+ print("初始化评分器和路由器...")
36
+ self.graders = initialize_graders_and_router()
37
+
38
+ # 初始化工作流节点
39
+ print("设置工作流节点...")
40
+ self.workflow_nodes = WorkflowNodes(self.retriever, self.graders)
41
+
42
+ # 构建工作流
43
+ print("构建工作流图...")
44
+ self.app = self._build_workflow()
45
+
46
+ print("✅ 自适应RAG系统初始化完成!")
47
+
48
+ def _build_workflow(self):
49
+ """构建工作流图"""
50
+ workflow = StateGraph(GraphState)
51
+
52
+ # 定义节点
53
+ workflow.add_node("web_search", self.workflow_nodes.web_search)
54
+ workflow.add_node("retrieve", self.workflow_nodes.retrieve)
55
+ workflow.add_node("grade_documents", self.workflow_nodes.grade_documents)
56
+ workflow.add_node("generate", self.workflow_nodes.generate)
57
+ workflow.add_node("transform_query", self.workflow_nodes.transform_query)
58
+
59
+ # 构建图
60
+ workflow.add_conditional_edges(
61
+ START,
62
+ self.workflow_nodes.route_question,
63
+ {
64
+ "web_search": "web_search",
65
+ "vectorstore": "retrieve",
66
+ },
67
+ )
68
+ workflow.add_edge("web_search", "generate")
69
+ workflow.add_edge("retrieve", "grade_documents")
70
+ workflow.add_conditional_edges(
71
+ "grade_documents",
72
+ self.workflow_nodes.decide_to_generate,
73
+ {
74
+ "transform_query": "transform_query",
75
+ "generate": "generate",
76
+ },
77
+ )
78
+ workflow.add_edge("transform_query", "retrieve")
79
+ workflow.add_conditional_edges(
80
+ "generate",
81
+ self.workflow_nodes.grade_generation_v_documents_and_question,
82
+ {
83
+ "not supported": "generate",
84
+ "useful": END,
85
+ "not useful": "transform_query",
86
+ },
87
+ )
88
+
89
+ # 编译
90
+ return workflow.compile()
91
+
92
+ def query(self, question: str, verbose: bool = True):
93
+ """
94
+ 处理查询
95
+
96
+ Args:
97
+ question (str): 用户问题
98
+ verbose (bool): 是否显示详细输出
99
+
100
+ Returns:
101
+ str: 最终答案
102
+ """
103
+ print(f"\n🔍 处理问题: {question}")
104
+ print("=" * 50)
105
+
106
+ inputs = {"question": question}
107
+ final_generation = None
108
+
109
+ for output in self.app.stream(inputs):
110
+ for key, value in output.items():
111
+ if verbose:
112
+ pprint(f"节点 '{key}':")
113
+ # 可选:在每个节点打印完整状态
114
+ # pprint(value, indent=2, width=80, depth=None)
115
+ final_generation = value.get("generation", final_generation)
116
+ if verbose:
117
+ pprint("\n---\n")
118
+
119
+ print("🎯 最终答案:")
120
+ print("-" * 30)
121
+ print(final_generation)
122
+ print("=" * 50)
123
+
124
+ return final_generation
125
+
126
+ def interactive_mode(self):
127
+ """交互模式,允许用户持续提问"""
128
+ print("\n🤖 欢迎使用自适应RAG系统!")
129
+ print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
130
+ print("-" * 50)
131
+
132
+ while True:
133
+ try:
134
+ question = input("\n❓ 请输入您的问题: ").strip()
135
+
136
+ if question.lower() in ['quit', 'exit', '退出', 'q']:
137
+ print("👋 感谢使用,再见!")
138
+ break
139
+
140
+ if not question:
141
+ print("⚠️ 请输入一个有效的问题")
142
+ continue
143
+
144
+ self.query(question)
145
+
146
+ except KeyboardInterrupt:
147
+ print("\n👋 感谢使用,再见!")
148
+ break
149
+ except Exception as e:
150
+ print(f"❌ 发生错误: {e}")
151
+ print("请重试或输入 'quit' 退出")
152
+
153
+
154
+ def main():
155
+ """主函数"""
156
+ try:
157
+ # 初始化系统
158
+ rag_system = AdaptiveRAGSystem()
159
+
160
+ # 测试查询
161
+ # test_question = "AlphaCodium论文讲的是什么?"
162
+ test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
163
+ rag_system.query(test_question)
164
+
165
+ # 启动交互模式
166
+ rag_system.interactive_mode()
167
+
168
+ except Exception as e:
169
+ print(f"❌ 系统初始化失败: {e}")
170
+ print("请检查配置和依赖是否正确安装")
171
+
172
+
173
+ if __name__ == "__main__":
174
+ main()
main_graphrag.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GraphRAG集成示例
3
+ 展示如何在自适应RAG系统中使用知识图谱功能
4
+ """
5
+
6
+ import os
7
+ from pprint import pprint
8
+
9
+ from config import (
10
+ setup_environment,
11
+ ENABLE_GRAPHRAG,
12
+ GRAPHRAG_INDEX_PATH,
13
+ GRAPHRAG_BATCH_SIZE
14
+ )
15
+ from document_processor import initialize_document_processor
16
+ from graph_indexer import initialize_graph_indexer
17
+ from graph_retriever import initialize_graph_retriever
18
+
19
+
20
+ class AdaptiveRAGWithGraph:
21
+ """集成GraphRAG的自适应RAG系统"""
22
+
23
+ def __init__(self, enable_graphrag=True, rebuild_index=False):
24
+ print("🚀 初始化集成GraphRAG的自适应RAG系统...")
25
+ print("="*60)
26
+
27
+ # 设置环境
28
+ try:
29
+ setup_environment()
30
+ print("✅ 环境配置完成")
31
+ except ValueError as e:
32
+ print(f"❌ {e}")
33
+ raise
34
+
35
+ # 初始化文档处理器
36
+ print("\n📚 初始化文档处理器...")
37
+ self.doc_processor, self.vectorstore, self.retriever, self.doc_splits = \
38
+ initialize_document_processor()
39
+
40
+ # GraphRAG组件
41
+ self.enable_graphrag = enable_graphrag and ENABLE_GRAPHRAG
42
+ self.graph_indexer = None
43
+ self.graph_retriever = None
44
+ self.knowledge_graph = None
45
+
46
+ if self.enable_graphrag:
47
+ self._setup_graphrag(rebuild_index)
48
+
49
+ print("\n" + "="*60)
50
+ print("✅ 系统初始化完成!")
51
+ print("="*60)
52
+
53
+ def _setup_graphrag(self, rebuild_index=False):
54
+ """设置GraphRAG组件"""
55
+ print("\n🔷 设置GraphRAG组件...")
56
+
57
+ # 初始化索引器
58
+ self.graph_indexer = initialize_graph_indexer()
59
+
60
+ # 检查是否已有索引
61
+ index_exists = os.path.exists(GRAPHRAG_INDEX_PATH)
62
+
63
+ if index_exists and not rebuild_index:
64
+ print(f"📂 发现现有索引: {GRAPHRAG_INDEX_PATH}")
65
+ print(" 加载现有索引...")
66
+ self.knowledge_graph = self.graph_indexer.load_index(GRAPHRAG_INDEX_PATH)
67
+ else:
68
+ if rebuild_index:
69
+ print("🔄 重新构建索引...")
70
+ else:
71
+ print("📝 首次构建索引...")
72
+
73
+ # 构建索引
74
+ self.knowledge_graph = self.graph_indexer.index_documents(
75
+ documents=self.doc_splits,
76
+ batch_size=GRAPHRAG_BATCH_SIZE,
77
+ save_path=GRAPHRAG_INDEX_PATH
78
+ )
79
+
80
+ # 初始化检索器
81
+ self.graph_retriever = initialize_graph_retriever(self.knowledge_graph)
82
+ print("✅ GraphRAG组件设置完成")
83
+
84
+ def query_vector_only(self, question: str) -> str:
85
+ """仅使用向量检索"""
86
+ print(f"\n{'='*60}")
87
+ print(f"🔍 向量检索模式")
88
+ print(f"问题: {question}")
89
+ print(f"{'='*60}")
90
+
91
+ docs = self.retriever.get_relevant_documents(question)
92
+
93
+ print(f"\n📄 检索到 {len(docs)} 个文档片段:")
94
+ for i, doc in enumerate(docs[:3], 1):
95
+ print(f"\n片段 {i}:")
96
+ print(f"{doc.page_content[:200]}...")
97
+
98
+ return self.doc_processor.format_docs(docs)
99
+
100
+ def query_graph_local(self, question: str) -> str:
101
+ """使用图谱本地查询"""
102
+ if not self.enable_graphrag:
103
+ return "GraphRAG未启用"
104
+
105
+ print(f"\n{'='*60}")
106
+ print(f"🔎 图谱本地查询模式")
107
+ print(f"问题: {question}")
108
+ print(f"{'='*60}")
109
+
110
+ answer = self.graph_retriever.local_query(question)
111
+
112
+ print(f"\n💡 答案:")
113
+ print(answer)
114
+
115
+ return answer
116
+
117
+ def query_graph_global(self, question: str) -> str:
118
+ """使用图谱全局查询"""
119
+ if not self.enable_graphrag:
120
+ return "GraphRAG未启用"
121
+
122
+ print(f"\n{'='*60}")
123
+ print(f"🌍 图谱全局查询模式")
124
+ print(f"问题: {question}")
125
+ print(f"{'='*60}")
126
+
127
+ answer = self.graph_retriever.global_query(question)
128
+
129
+ print(f"\n💡 答案:")
130
+ print(answer)
131
+
132
+ return answer
133
+
134
+ def query_hybrid(self, question: str) -> dict:
135
+ """混合查询:向量 + 图谱"""
136
+ if not self.enable_graphrag:
137
+ return {"error": "GraphRAG未启用"}
138
+
139
+ print(f"\n{'='*60}")
140
+ print(f"🔀 混合查询模式")
141
+ print(f"问题: {question}")
142
+ print(f"{'='*60}")
143
+
144
+ # 向量检索
145
+ vector_docs = self.retriever.get_relevant_documents(question)
146
+ vector_context = self.doc_processor.format_docs(vector_docs[:3])
147
+
148
+ # 图谱查询
149
+ graph_results = self.graph_retriever.hybrid_query(question)
150
+
151
+ result = {
152
+ "question": question,
153
+ "vector_retrieval": {
154
+ "doc_count": len(vector_docs),
155
+ "context": vector_context[:500] + "..." if len(vector_context) > 500 else vector_context
156
+ },
157
+ "graph_local": graph_results["local"],
158
+ "graph_global": graph_results["global"]
159
+ }
160
+
161
+ print("\n📊 结果汇总:")
162
+ print(f" • 向量检索: {len(vector_docs)} 个文档")
163
+ print(f" • 图谱本地查询完成")
164
+ print(f" • 图谱全局查询完成")
165
+
166
+ return result
167
+
168
+ def query_smart(self, question: str) -> str:
169
+ """智能查询:自动选择最佳策略"""
170
+ if not self.enable_graphrag:
171
+ return self.query_vector_only(question)
172
+
173
+ print(f"\n{'='*60}")
174
+ print(f"🧠 智能查询模式")
175
+ print(f"问题: {question}")
176
+ print(f"{'='*60}")
177
+
178
+ answer = self.graph_retriever.smart_query(question)
179
+
180
+ print(f"\n💡 答案:")
181
+ print(answer)
182
+
183
+ return answer
184
+
185
+ def get_graph_statistics(self):
186
+ """获取知识图谱统计信息"""
187
+ if not self.enable_graphrag or not self.knowledge_graph:
188
+ print("GraphRAG未启用或图谱未构建")
189
+ return
190
+
191
+ stats = self.knowledge_graph.get_statistics()
192
+
193
+ print("\n" + "="*60)
194
+ print("📊 知识图谱统计信息")
195
+ print("="*60)
196
+ print(f"节点数: {stats['num_nodes']}")
197
+ print(f"边数: {stats['num_edges']}")
198
+ print(f"社区数: {stats['num_communities']}")
199
+ print(f"图密度: {stats['density']:.4f}")
200
+ print("\n实体类型分布:")
201
+ for etype, count in stats['entity_types'].items():
202
+ print(f" • {etype}: {count}")
203
+ print("="*60)
204
+
205
+ return stats
206
+
207
+ def interactive_mode(self):
208
+ """交互模式"""
209
+ print("\n" + "="*60)
210
+ print("🤖 欢迎使用GraphRAG增强的自适应RAG系统!")
211
+ print("="*60)
212
+ print("\n查询模式:")
213
+ print(" 1️⃣ vector - 仅向量检索")
214
+ print(" 2️⃣ local - 图谱本地查询")
215
+ print(" 3️⃣ global - 图谱全局查询")
216
+ print(" 4️⃣ hybrid - 混合查询")
217
+ print(" 5️⃣ smart - 智能查询(推荐)")
218
+ print(" 6️⃣ stats - 显示图谱统计")
219
+ print(" 7️⃣ quit - 退出")
220
+ print("-"*60)
221
+
222
+ while True:
223
+ try:
224
+ mode = input("\n选择模式 (1-7): ").strip()
225
+
226
+ if mode in ['7', 'quit', 'exit', '退出', 'q']:
227
+ print("👋 感谢使用,再见!")
228
+ break
229
+
230
+ if mode in ['6', 'stats']:
231
+ self.get_graph_statistics()
232
+ continue
233
+
234
+ question = input("❓ 请输入问题: ").strip()
235
+
236
+ if not question:
237
+ print("⚠️ 请输入有效问题")
238
+ continue
239
+
240
+ if mode in ['1', 'vector']:
241
+ self.query_vector_only(question)
242
+ elif mode in ['2', 'local']:
243
+ self.query_graph_local(question)
244
+ elif mode in ['3', 'global']:
245
+ self.query_graph_global(question)
246
+ elif mode in ['4', 'hybrid']:
247
+ result = self.query_hybrid(question)
248
+ pprint(result)
249
+ else: # 默认智能模式
250
+ self.query_smart(question)
251
+
252
+ except KeyboardInterrupt:
253
+ print("\n👋 感谢使用,再见!")
254
+ break
255
+ except Exception as e:
256
+ print(f"❌ 发生错误: {e}")
257
+ print("请重试或输入 'quit' 退出")
258
+
259
+
260
+ def main():
261
+ """主函数"""
262
+ try:
263
+ # 初始化系统(首次运行设置rebuild_index=True)
264
+ rag_system = AdaptiveRAGWithGraph(
265
+ enable_graphrag=True,
266
+ rebuild_index=False # 设为True重新构建索引
267
+ )
268
+
269
+ # 显示图谱统计
270
+ rag_system.get_graph_statistics()
271
+
272
+ # 测试查询
273
+ print("\n" + "="*60)
274
+ print("🧪 测试查询示例")
275
+ print("="*60)
276
+
277
+ # 示例1: 本地查询
278
+ rag_system.query_graph_local("LLM Agent的主要组成部分是什么?")
279
+
280
+ # 示例2: 全局查询
281
+ rag_system.query_graph_global("这些文档主要讨论了什么主题?")
282
+
283
+ # 启动交互模式
284
+ rag_system.interactive_mode()
285
+
286
+ except Exception as e:
287
+ print(f"❌ 系统初始化失败: {e}")
288
+ import traceback
289
+ traceback.print_exc()
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()
requirements.txt ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 自适应RAG系统依赖文件
2
+ # 运行以下命令安装: pip install -r requirements.txt
3
+
4
+ # 核心框架
5
+ langchain>=0.1.0
6
+ langgraph>=0.0.40
7
+ langchain-community>=0.0.20
8
+ langchain-core>=0.1.0
9
+
10
+ # LLM集成
11
+ langchain-ollama>=0.1.0
12
+
13
+ # 向量数据库和嵌入
14
+ chromadb>=0.4.0
15
+ sentence-transformers>=2.2.0
16
+ torch>=2.0.0
17
+ transformers>=4.30.0
18
+
19
+ # 文档处理
20
+ tiktoken>=0.5.0
21
+ beautifulsoup4>=4.12.0
22
+ requests>=2.31.0
23
+
24
+ # 网络搜索
25
+ tavily-python>=0.3.0
26
+
27
+ # 数据处理
28
+ numpy>=1.24.0,<2.0 # 避免NumPy 2.x兼容性问题
29
+ pandas>=2.0.0
30
+
31
+ # 工具库
32
+ python-dotenv>=1.0.0
33
+ pydantic>=2.0.0
34
+ typing-extensions>=4.0.0
35
+
36
+ # 开发工具(可选)
37
+ jupyter>=1.0.0
38
+ ipykernel>=6.0.0
39
+ matplotlib>=3.7.0
40
+ seaborn>=0.12.0
41
+
42
+ # GraphRAG相关(可选)
43
+ networkx>=3.1 # 图结构处理
44
+ python-louvain>=0.16 # 社区检测
requirements_gpu.txt ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GPU优化版依赖文件 - 针对RTX 4090 Linux环境
2
+ # 安装命令: pip install -r requirements_gpu.txt
3
+
4
+ # 核心框架
5
+ langchain>=0.1.0
6
+ langgraph>=0.0.40
7
+ langchain-community>=0.0.20
8
+ langchain-core>=0.1.0
9
+
10
+ # LLM集成
11
+ langchain-ollama>=0.1.0
12
+
13
+ # 向量数据库和嵌入(GPU优化版本)
14
+ chromadb>=0.4.0
15
+ sentence-transformers>=2.2.0
16
+
17
+ # PyTorch GPU版本(CUDA 11.8支持)
18
+ --index-url https://download.pytorch.org/whl/cu118
19
+ torch>=2.0.0+cu118
20
+ torchvision>=0.15.0+cu118
21
+ torchaudio>=2.0.0+cu118
22
+
23
+ # Transformers和加速库
24
+ transformers>=4.30.0
25
+ accelerate>=0.20.0
26
+
27
+ # GPU加速向量搜索
28
+ faiss-gpu>=1.7.4
29
+
30
+ # CUDA支持库
31
+ cupy-cuda12x>=12.0.0
32
+
33
+ # 文档处理
34
+ tiktoken>=0.5.0
35
+ beautifulsoup4>=4.12.0
36
+ requests>=2.31.0
37
+
38
+ # 网络搜索
39
+ tavily-python>=0.3.0
40
+
41
+ # 数据处理(兼容版本)
42
+ numpy>=1.24.0,<2.0
43
+ pandas>=2.0.0
44
+
45
+ # 工具库
46
+ python-dotenv>=1.0.0
47
+ pydantic>=2.0.0
48
+ typing-extensions>=4.0.0
49
+
50
+ # 系统监控
51
+ gputil>=1.4.0
52
+ psutil>=5.9.0
53
+
54
+ # 可选:Jupyter支持
55
+ jupyter>=1.0.0
56
+ ipykernel>=6.0.0
requirements_graphrag.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GraphRAG额外依赖
2
+ # 在原有requirements.txt基础上添加这些包
3
+ # 安装命令: pip install -r requirements_graphrag.txt
4
+
5
+ # 图数据库
6
+ neo4j>=5.14.0
7
+ py2neo>=2021.2.3
8
+
9
+ # 或者使用更轻量的选项
10
+ networkx>=3.1
11
+ python-louvain>=0.16 # 社区检测
12
+
13
+ # 图谱处理
14
+ graspologic>=3.3.0 # 层次化社区检测
15
+ leidenalg>=0.10.0 # 更好的社区检测算法
16
+
17
+ # GraphRAG核心(可选,如果使用Microsoft官方实现)
18
+ # graphrag>=0.1.0
19
+
20
+ # 实体识别增强
21
+ spacy>=3.7.0
22
+ # 下载模型: python -m spacy download zh_core_web_sm
23
+ # 下载模型: python -m spacy download en_core_web_sm
24
+
25
+ # 图可视化(可选)
26
+ pyvis>=0.3.2
27
+ plotly>=5.18.0
28
+
29
+ # 缓存和性能优化
30
+ diskcache>=5.6.0
31
+ joblib>=1.3.0
reranker.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 向量重排模块
3
+ 实现多种重排策略以提高检索质量
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from typing import List, Tuple, Dict
9
+ from sklearn.feature_extraction.text import TfidfVectorizer
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ import re
12
+ from collections import Counter
13
+ import math
14
+
15
+
16
+ class DocumentReranker:
17
+ """文档重排器基类"""
18
+
19
+ def __init__(self):
20
+ self.name = "BaseReranker"
21
+
22
+ def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
23
+ """重排文档并返回top_k结果"""
24
+ raise NotImplementedError
25
+
26
+
27
+ class TFIDFReranker(DocumentReranker):
28
+ """基于TF-IDF的重排器"""
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.name = "TFIDFReranker"
33
+ self.vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
34
+
35
+ def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
36
+ """使用TF-IDF重新排序文档"""
37
+ if not documents:
38
+ return []
39
+
40
+ # 提取文档内容
41
+ doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
42
+ all_texts = [query] + doc_texts
43
+
44
+ # 计算TF-IDF矩阵
45
+ tfidf_matrix = self.vectorizer.fit_transform(all_texts)
46
+ query_vec = tfidf_matrix[0]
47
+ doc_vecs = tfidf_matrix[1:]
48
+
49
+ # 计算相似度
50
+ similarities = cosine_similarity(query_vec, doc_vecs).flatten()
51
+
52
+ # 排序并返回top_k
53
+ ranked_indices = np.argsort(similarities)[::-1]
54
+ results = []
55
+ for i in ranked_indices[:top_k]:
56
+ results.append((documents[i], float(similarities[i])))
57
+
58
+ return results
59
+
60
+
61
+ class BM25Reranker(DocumentReranker):
62
+ """基于BM25算法的重排器"""
63
+
64
+ def __init__(self, k1: float = 1.5, b: float = 0.75):
65
+ super().__init__()
66
+ self.name = "BM25Reranker"
67
+ self.k1 = k1
68
+ self.b = b
69
+
70
+ def _tokenize(self, text: str) -> List[str]:
71
+ """简单分词"""
72
+ return re.findall(r'\b\w+\b', text.lower())
73
+
74
+ def _compute_idf(self, documents: List[str], query_terms: List[str]) -> Dict[str, float]:
75
+ """计算IDF值"""
76
+ N = len(documents)
77
+ idf = {}
78
+
79
+ for term in query_terms:
80
+ df = sum(1 for doc in documents if term in self._tokenize(doc))
81
+ idf[term] = math.log((N - df + 0.5) / (df + 0.5))
82
+
83
+ return idf
84
+
85
+ def _bm25_score(self, query_terms: List[str], document: str, avg_doc_len: float, idf: Dict[str, float]) -> float:
86
+ """计算BM25分数"""
87
+ doc_terms = self._tokenize(document)
88
+ doc_len = len(doc_terms)
89
+ term_freq = Counter(doc_terms)
90
+
91
+ score = 0.0
92
+ for term in query_terms:
93
+ if term in term_freq:
94
+ tf = term_freq[term]
95
+ score += idf.get(term, 0) * (tf * (self.k1 + 1)) / (
96
+ tf + self.k1 * (1 - self.b + self.b * doc_len / avg_doc_len)
97
+ )
98
+
99
+ return score
100
+
101
+ def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
102
+ """使用BM25重新排序文档"""
103
+ if not documents:
104
+ return []
105
+
106
+ query_terms = self._tokenize(query)
107
+ doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
108
+
109
+ # 计算平均文档长度
110
+ avg_doc_len = sum(len(self._tokenize(doc)) for doc in doc_texts) / len(doc_texts)
111
+
112
+ # 计算IDF
113
+ idf = self._compute_idf(doc_texts, query_terms)
114
+
115
+ # 计算BM25分数
116
+ scores = []
117
+ for doc_text in doc_texts:
118
+ score = self._bm25_score(query_terms, doc_text, avg_doc_len, idf)
119
+ scores.append(score)
120
+
121
+ # 排序并返回top_k
122
+ ranked_indices = np.argsort(scores)[::-1]
123
+ results = []
124
+ for i in ranked_indices[:top_k]:
125
+ results.append((documents[i], float(scores[i])))
126
+
127
+ return results
128
+
129
+
130
+ class SemanticReranker(DocumentReranker):
131
+ """基于语义相似度的重排器"""
132
+
133
+ def __init__(self, embeddings_model):
134
+ super().__init__()
135
+ self.name = "SemanticReranker"
136
+ self.embeddings_model = embeddings_model
137
+
138
+ def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
139
+ """使用语义相似度重新排序文档"""
140
+ if not documents:
141
+ return []
142
+
143
+ # 获取查询嵌入
144
+ query_embedding = self.embeddings_model.embed_query(query)
145
+
146
+ # 获取文档嵌入
147
+ doc_texts = [doc.page_content if hasattr(doc, 'page_content') else str(doc) for doc in documents]
148
+ doc_embeddings = self.embeddings_model.embed_documents(doc_texts)
149
+
150
+ # 计算余弦相似度
151
+ similarities = []
152
+ for doc_emb in doc_embeddings:
153
+ sim = cosine_similarity([query_embedding], [doc_emb])[0][0]
154
+ similarities.append(sim)
155
+
156
+ # 排序并返回top_k
157
+ ranked_indices = np.argsort(similarities)[::-1]
158
+ results = []
159
+ for i in ranked_indices[:top_k]:
160
+ results.append((documents[i], float(similarities[i])))
161
+
162
+ return results
163
+
164
+
165
+ class HybridReranker(DocumentReranker):
166
+ """混合重排器,融合多种策略"""
167
+
168
+ def __init__(self, embeddings_model, weights: Dict[str, float] = None):
169
+ super().__init__()
170
+ self.name = "HybridReranker"
171
+
172
+ # 初始化各种重排器
173
+ self.tfidf_reranker = TFIDFReranker()
174
+ self.bm25_reranker = BM25Reranker()
175
+ self.semantic_reranker = SemanticReranker(embeddings_model)
176
+
177
+ # 设置权重
178
+ self.weights = weights or {
179
+ 'tfidf': 0.3,
180
+ 'bm25': 0.3,
181
+ 'semantic': 0.4
182
+ }
183
+
184
+ def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
185
+ """使用混合策略重新排序文档"""
186
+ if not documents:
187
+ return []
188
+
189
+ # 获取各种重排结果
190
+ tfidf_results = self.tfidf_reranker.rerank(query, documents, len(documents))
191
+ bm25_results = self.bm25_reranker.rerank(query, documents, len(documents))
192
+ semantic_results = self.semantic_reranker.rerank(query, documents, len(documents))
193
+
194
+ # 创建文档到分数的映射
195
+ doc_scores = {}
196
+ for doc in documents:
197
+ doc_id = id(doc)
198
+ doc_scores[doc_id] = {'doc': doc, 'tfidf': 0, 'bm25': 0, 'semantic': 0}
199
+
200
+ # 填充各种分数
201
+ for doc, score in tfidf_results:
202
+ doc_scores[id(doc)]['tfidf'] = score
203
+
204
+ for doc, score in bm25_results:
205
+ doc_scores[id(doc)]['bm25'] = score
206
+
207
+ for doc, score in semantic_results:
208
+ doc_scores[id(doc)]['semantic'] = score
209
+
210
+ # 归一化分数
211
+ for score_type in ['tfidf', 'bm25', 'semantic']:
212
+ scores = [info[score_type] for info in doc_scores.values()]
213
+ if max(scores) > 0:
214
+ max_score = max(scores)
215
+ for doc_id in doc_scores:
216
+ doc_scores[doc_id][score_type] /= max_score
217
+
218
+ # 计算综合分数
219
+ final_scores = []
220
+ for doc_id, info in doc_scores.items():
221
+ combined_score = (
222
+ self.weights['tfidf'] * info['tfidf'] +
223
+ self.weights['bm25'] * info['bm25'] +
224
+ self.weights['semantic'] * info['semantic']
225
+ )
226
+ final_scores.append((info['doc'], combined_score))
227
+
228
+ # 排序并返回top_k
229
+ final_scores.sort(key=lambda x: x[1], reverse=True)
230
+ return final_scores[:top_k]
231
+
232
+
233
+ class DiversityReranker(DocumentReranker):
234
+ """多样性重排器,避免结果重复"""
235
+
236
+ def __init__(self, embeddings_model, diversity_lambda: float = 0.5):
237
+ super().__init__()
238
+ self.name = "DiversityReranker"
239
+ self.embeddings_model = embeddings_model
240
+ self.diversity_lambda = diversity_lambda
241
+
242
+ def _calculate_diversity_penalty(self, candidate_doc: str, selected_docs: List[str]) -> float:
243
+ """计算多样性惩罚"""
244
+ if not selected_docs:
245
+ return 0.0
246
+
247
+ candidate_emb = self.embeddings_model.embed_documents([candidate_doc])[0]
248
+ selected_embs = self.embeddings_model.embed_documents(selected_docs)
249
+
250
+ max_similarity = 0.0
251
+ for selected_emb in selected_embs:
252
+ sim = cosine_similarity([candidate_emb], [selected_emb])[0][0]
253
+ max_similarity = max(max_similarity, sim)
254
+
255
+ return max_similarity
256
+
257
+ def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[Tuple[dict, float]]:
258
+ """使用多样性策略重新排序文档"""
259
+ if not documents:
260
+ return []
261
+
262
+ # 首先使用语义相似度获取初始排序
263
+ semantic_results = SemanticReranker(self.embeddings_model).rerank(
264
+ query, documents, len(documents)
265
+ )
266
+
267
+ # MMR (Maximal Marginal Relevance) 算法
268
+ selected_docs = []
269
+ selected_texts = []
270
+ remaining_docs = [doc for doc, _ in semantic_results]
271
+ relevance_scores = {id(doc): score for doc, score in semantic_results}
272
+
273
+ while len(selected_docs) < top_k and remaining_docs:
274
+ best_score = -1
275
+ best_doc = None
276
+ best_idx = -1
277
+
278
+ for i, doc in enumerate(remaining_docs):
279
+ doc_text = doc.page_content if hasattr(doc, 'page_content') else str(doc)
280
+ relevance = relevance_scores[id(doc)]
281
+ diversity_penalty = self._calculate_diversity_penalty(doc_text, selected_texts)
282
+
283
+ # MMR分数 = λ * 相关性 - (1-λ) * 多样性惩罚
284
+ mmr_score = (
285
+ self.diversity_lambda * relevance -
286
+ (1 - self.diversity_lambda) * diversity_penalty
287
+ )
288
+
289
+ if mmr_score > best_score:
290
+ best_score = mmr_score
291
+ best_doc = doc
292
+ best_idx = i
293
+
294
+ if best_doc is not None:
295
+ selected_docs.append((best_doc, best_score))
296
+ selected_texts.append(
297
+ best_doc.page_content if hasattr(best_doc, 'page_content') else str(best_doc)
298
+ )
299
+ remaining_docs.pop(best_idx)
300
+
301
+ return selected_docs
302
+
303
+
304
+ def create_reranker(reranker_type: str, embeddings_model=None, **kwargs) -> DocumentReranker:
305
+ """工厂函数:创建指定类型的重排器"""
306
+
307
+ if reranker_type.lower() == 'tfidf':
308
+ return TFIDFReranker()
309
+ elif reranker_type.lower() == 'bm25':
310
+ return BM25Reranker(**kwargs)
311
+ elif reranker_type.lower() == 'semantic':
312
+ if embeddings_model is None:
313
+ raise ValueError("SemanticReranker requires embeddings_model")
314
+ return SemanticReranker(embeddings_model)
315
+ elif reranker_type.lower() == 'hybrid':
316
+ if embeddings_model is None:
317
+ raise ValueError("HybridReranker requires embeddings_model")
318
+ return HybridReranker(embeddings_model, **kwargs)
319
+ elif reranker_type.lower() == 'diversity':
320
+ if embeddings_model is None:
321
+ raise ValueError("DiversityReranker requires embeddings_model")
322
+ return DiversityReranker(embeddings_model, **kwargs)
323
+ else:
324
+ raise ValueError(f"Unknown reranker type: {reranker_type}")
325
+
326
+
327
+ # 使用示例
328
+ if __name__ == "__main__":
329
+ # 模拟文档
330
+ class MockDoc:
331
+ def __init__(self, content):
332
+ self.page_content = content
333
+
334
+ docs = [
335
+ MockDoc("人工智能是计算机科学的一个分支"),
336
+ MockDoc("机器学习是人工智能的子领域"),
337
+ MockDoc("深度学习使用神经网络"),
338
+ MockDoc("自然语言处理处理文本数据"),
339
+ MockDoc("今天天气很好")
340
+ ]
341
+
342
+ query = "什么是人工智能?"
343
+
344
+ # 测试TF-IDF重排
345
+ tfidf_reranker = TFIDFReranker()
346
+ results = tfidf_reranker.rerank(query, docs, top_k=3)
347
+
348
+ print("TF-IDF重排结果:")
349
+ for doc, score in results:
350
+ print(f"分数: {score:.4f} - 内容: {doc.page_content}")
routers_and_graders.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 路由器和评分器模块
3
+ 包含查询路由、文档相关性评分、答案质量评分和幻觉检测
4
+ """
5
+
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain_community.chat_models import ChatOllama
8
+ from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
9
+
10
+ from config import LOCAL_LLM
11
+
12
+
13
+ class QueryRouter:
14
+ """查询路由器,决定使用向量存储还是网络搜索"""
15
+
16
+ def __init__(self):
17
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
18
+ self.prompt = PromptTemplate(
19
+ template="""你是一个专家,负责将用户问题路由到向量存储或网络搜索。
20
+ 对于关于LLM智能体、提示工程和对抗性攻击的问题,使用向量存储。
21
+ 你不需要严格匹配问题中与这些主题相关的关键词。
22
+ 否则,使用网络搜索。根据问题给出二进制选择'web_search'或'vectorstore'。
23
+ 返回一个只包含'datasource'键的JSON,不要前言或解释。
24
+ 要路由的问题:{question}""",
25
+ input_variables=["question"],
26
+ )
27
+ self.router = self.prompt | self.llm | JsonOutputParser()
28
+
29
+ def route(self, question: str) -> str:
30
+ """路由问题到相应的数据源"""
31
+ result = self.router.invoke({"question": question})
32
+ return result.get("datasource", "web_search")
33
+
34
+
35
+ class DocumentGrader:
36
+ """文档相关性评分器"""
37
+
38
+ def __init__(self):
39
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
40
+ self.prompt = PromptTemplate(
41
+ template="""你是一个评分员,评估检索到的文档是否与用户问题相关。
42
+ 如果文档包含与用户问题相关的关键词或语义,请给出'yes'分数。
43
+ 给出二进制分数'yes'或'no',以表明文档是否与问题相关。
44
+ 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
45
+
46
+ 检索到的文档:
47
+
48
+ {document}
49
+
50
+
51
+ 用户问题:{question}""",
52
+ input_variables=["question", "document"],
53
+ )
54
+ self.grader = self.prompt | self.llm | JsonOutputParser()
55
+
56
+ def grade(self, question: str, document: str) -> str:
57
+ """评估文档与问题的相关性"""
58
+ result = self.grader.invoke({"question": question, "document": document})
59
+ return result.get("score", "no")
60
+
61
+
62
+ class AnswerGrader:
63
+ """答案质量评分器"""
64
+
65
+ def __init__(self):
66
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
67
+ self.prompt = PromptTemplate(
68
+ template="""你是一个评分员,评估答案是否有助于解决问题。
69
+ 这里是答案:
70
+ \n ------- \n
71
+ {generation}
72
+ \n ------- \n
73
+ 这里是问题:{question}
74
+ 给出二进制分数'yes'或'no',表示答案是否有助于解决问题。
75
+ 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。""",
76
+ input_variables=["generation", "question"],
77
+ )
78
+ self.grader = self.prompt | self.llm | JsonOutputParser()
79
+
80
+ def grade(self, question: str, generation: str) -> str:
81
+ """评估答案质量"""
82
+ result = self.grader.invoke({"question": question, "generation": generation})
83
+ return result.get("score", "no")
84
+
85
+
86
+ class HallucinationGrader:
87
+ """幻觉检测器"""
88
+
89
+ def __init__(self):
90
+ self.llm = ChatOllama(model=LOCAL_LLM, format="json", temperature=0)
91
+ self.prompt = PromptTemplate(
92
+ template="""你是一个评分员,评估LLM生成是否基于/支持一组检索到的事实。
93
+ 给出二进制分数'yes'或'no'。'yes'意味着答案基于/支持文档。
94
+ 将二进制分数作为JSON提供,只包含'score'键,不要前言或解释。
95
+
96
+ 检索到的文档:
97
+
98
+ {documents}
99
+
100
+
101
+ LLM生成:{generation}""",
102
+ input_variables=["generation", "documents"],
103
+ )
104
+ self.grader = self.prompt | self.llm | JsonOutputParser()
105
+
106
+ def grade(self, generation: str, documents) -> str:
107
+ """检测生成内容是否存在幻觉"""
108
+ result = self.grader.invoke({"generation": generation, "documents": documents})
109
+ return result.get("score", "no")
110
+
111
+
112
+ class QueryRewriter:
113
+ """查询重写器,优化查询以获得更好的检索结果"""
114
+
115
+ def __init__(self):
116
+ self.llm = ChatOllama(model=LOCAL_LLM, temperature=0)
117
+ self.prompt = PromptTemplate(
118
+ template="""你是一个问题重写器,将输入问题转换为更适合向量存储检索的更好版本。
119
+ 查看初始问题并制定一个改进的问题。
120
+ 这里是初始问题:\n\n {question}。改进的问题(无前言):\n """,
121
+ input_variables=["question"],
122
+ )
123
+ self.rewriter = self.prompt | self.llm | StrOutputParser()
124
+
125
+ def rewrite(self, question: str) -> str:
126
+ """重写查询以获得更好的检索效果"""
127
+ print(f"---原始查询: {question}---")
128
+ rewritten_query = self.rewriter.invoke({"question": question})
129
+ print(f"---重写查询: {rewritten_query}---")
130
+ return rewritten_query
131
+
132
+
133
+ def initialize_graders_and_router():
134
+ """初始化所有评分器和路由器"""
135
+ query_router = QueryRouter()
136
+ document_grader = DocumentGrader()
137
+ answer_grader = AnswerGrader()
138
+ hallucination_grader = HallucinationGrader()
139
+ query_rewriter = QueryRewriter()
140
+
141
+ return {
142
+ "query_router": query_router,
143
+ "document_grader": document_grader,
144
+ "answer_grader": answer_grader,
145
+ "hallucination_grader": hallucination_grader,
146
+ "query_rewriter": query_rewriter
147
+ }
test_reranking.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 重排功能测试脚本
4
+ 演示不同重排策略的效果
5
+ """
6
+
7
+ import sys
8
+ import os
9
+ sys.path.append(os.path.dirname(__file__))
10
+
11
+ from document_processor import DocumentProcessor
12
+ from reranker import *
13
+ from langchain.schema import Document
14
+ import time
15
+
16
+
17
+ def create_test_documents():
18
+ """创建测试文档"""
19
+ return [
20
+ Document(
21
+ page_content="人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。",
22
+ metadata={"source": "ai_intro.txt", "category": "AI基础"}
23
+ ),
24
+ Document(
25
+ page_content="机器学习是人工智能的一个重要子领域,通过算法让计算机从数据中学习模式和规律。",
26
+ metadata={"source": "ml_basics.txt", "category": "机器学习"}
27
+ ),
28
+ Document(
29
+ page_content="深度学习是机器学习的一个分支,使用多层神经网络来模拟人脑的学习过程。",
30
+ metadata={"source": "dl_guide.txt", "category": "深度学习"}
31
+ ),
32
+ Document(
33
+ page_content="自然语言处理(NLP)是人工智能领域的一个重要分支,专注于使计算机理解和处理人类语言。",
34
+ metadata={"source": "nlp_overview.txt", "category": "自然语言处理"}
35
+ ),
36
+ Document(
37
+ page_content="计算机视觉是人工智能的另一个重要领域,使计算机能够识别和理解图像和视频内容。",
38
+ metadata={"source": "cv_intro.txt", "category": "计算机视觉"}
39
+ ),
40
+ Document(
41
+ page_content="强化学习是机器学习的一种类型,通过与环境交互来学习最优的行为策略。",
42
+ metadata={"source": "rl_basics.txt", "category": "强化学习"}
43
+ ),
44
+ Document(
45
+ page_content="今天的天气非常好,阳光明媚,适合外出游玩和运动。",
46
+ metadata={"source": "weather.txt", "category": "天气"}
47
+ ),
48
+ Document(
49
+ page_content="区块链是一种分布式账本技术,具有去中心化、不可篡改等特点。",
50
+ metadata={"source": "blockchain.txt", "category": "区块链"}
51
+ )
52
+ ]
53
+
54
+
55
+ def test_reranker_comparison():
56
+ """比较不同重排器的效果"""
57
+ print("🔍 重排器效果比较测试")
58
+ print("=" * 60)
59
+
60
+ # 创建测试数据
61
+ query = "什么是人工智能和机器学习?"
62
+ documents = create_test_documents()
63
+
64
+ # 创建一个简单的嵌入模型(用于测试)
65
+ try:
66
+ from langchain_community.embeddings import HuggingFaceEmbeddings
67
+ embeddings = HuggingFaceEmbeddings(
68
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
69
+ model_kwargs={'device': 'cpu'}
70
+ )
71
+ print("✅ 成功加载嵌入模型")
72
+ except Exception as e:
73
+ print(f"❌ 嵌入模型加载失败: {e}")
74
+ print("将使用基础重排器进行测试")
75
+ embeddings = None
76
+
77
+ # 测试不同的重排器
78
+ rerankers = []
79
+
80
+ # TF-IDF重排器
81
+ rerankers.append(("TF-IDF", TFIDFReranker()))
82
+
83
+ # BM25重排器
84
+ rerankers.append(("BM25", BM25Reranker()))
85
+
86
+ if embeddings:
87
+ # 语义重排器
88
+ rerankers.append(("语义相似度", SemanticReranker(embeddings)))
89
+
90
+ # 混合重排器
91
+ rerankers.append(("混合策略", HybridReranker(embeddings)))
92
+
93
+ # 多样性重排器
94
+ rerankers.append(("多样性优化", DiversityReranker(embeddings)))
95
+
96
+ # 执行测试
97
+ for name, reranker in rerankers:
98
+ print(f"\n📊 {name} 重排结果:")
99
+ print("-" * 40)
100
+
101
+ start_time = time.time()
102
+ try:
103
+ results = reranker.rerank(query, documents, top_k=5)
104
+ end_time = time.time()
105
+
106
+ print(f"⏱️ 处理时间: {(end_time - start_time)*1000:.2f}ms")
107
+
108
+ for i, (doc, score) in enumerate(results, 1):
109
+ content = doc.page_content[:80] + "..." if len(doc.page_content) > 80 else doc.page_content
110
+ category = doc.metadata.get('category', '未知')
111
+ print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
112
+
113
+ except Exception as e:
114
+ print(f"❌ 重排失败: {e}")
115
+
116
+
117
+ def test_reranking_with_embeddings():
118
+ """测试带嵌入的重排功能"""
119
+ print("\n\n🧠 嵌入模型重排测试")
120
+ print("=" * 60)
121
+
122
+ try:
123
+ # 创建文档处理器
124
+ processor = DocumentProcessor()
125
+
126
+ # 创建测试文档
127
+ test_docs = create_test_documents()
128
+
129
+ # 测试查询
130
+ queries = [
131
+ "人工智能的定义是什么?",
132
+ "机器学习和深度学习的区别",
133
+ "自然语言处理的应用",
134
+ "今天天气怎么样?"
135
+ ]
136
+
137
+ for query in queries:
138
+ print(f"\n🔍 查询: {query}")
139
+ print("-" * 30)
140
+
141
+ if processor.reranker:
142
+ # 使用重排功能
143
+ results = processor.reranker.rerank(query, test_docs, top_k=3)
144
+
145
+ for i, (doc, score) in enumerate(results, 1):
146
+ content = doc.page_content[:60] + "..." if len(doc.page_content) > 60 else doc.page_content
147
+ category = doc.metadata.get('category', '未知')
148
+ print(f"{i}. [分数: {score:.4f}] [{category}] {content}")
149
+ else:
150
+ print("❌ 重排器未初始化")
151
+
152
+ except Exception as e:
153
+ print(f"❌ 测试失败: {e}")
154
+
155
+
156
+ def test_performance_comparison():
157
+ """性能对比测试"""
158
+ print("\n\n⚡ 性能对比测试")
159
+ print("=" * 60)
160
+
161
+ documents = create_test_documents() * 10 # 增加文档数量
162
+ query = "人工智能技术的发展趋势"
163
+
164
+ # 测试不同重排器的性能
165
+ rerankers_config = [
166
+ ("无重排", None),
167
+ ("TF-IDF", TFIDFReranker()),
168
+ ("BM25", BM25Reranker())
169
+ ]
170
+
171
+ for name, reranker in rerankers_config:
172
+ times = []
173
+
174
+ # 多次测试取平均值
175
+ for _ in range(5):
176
+ start_time = time.time()
177
+
178
+ if reranker:
179
+ results = reranker.rerank(query, documents, top_k=5)
180
+ else:
181
+ # 模拟无重排的情况
182
+ results = documents[:5]
183
+
184
+ end_time = time.time()
185
+ times.append((end_time - start_time) * 1000)
186
+
187
+ avg_time = sum(times) / len(times)
188
+ print(f"{name}: 平均处理时间 {avg_time:.2f}ms (文档数: {len(documents)})")
189
+
190
+
191
+ def main():
192
+ """主测试函数"""
193
+ print("🚀 向量重排功能综合测试")
194
+ print("=" * 80)
195
+
196
+ try:
197
+ # 基础重排器比较
198
+ test_reranker_comparison()
199
+
200
+ # 嵌入模型重排测试
201
+ test_reranking_with_embeddings()
202
+
203
+ # 性能对比测试
204
+ test_performance_comparison()
205
+
206
+ print("\n\n✅ 所有测试完成!")
207
+ print("=" * 80)
208
+
209
+ except KeyboardInterrupt:
210
+ print("\n❌ 测试被用户中断")
211
+ except Exception as e:
212
+ print(f"\n❌ 测试过程中发生错误: {e}")
213
+ import traceback
214
+ traceback.print_exc()
215
+
216
+
217
+ if __name__ == "__main__":
218
+ main()
workflow_nodes.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 工作流节点模块
3
+ 包含所有工作流节点函数和状态管理
4
+ """
5
+
6
+ from typing import List
7
+ from typing_extensions import TypedDict
8
+ from langchain.schema import Document
9
+ from langchain_community.chat_models import ChatOllama
10
+ from langchain_core.output_parsers import StrOutputParser
11
+ from langchain_community.tools.tavily_search import TavilySearchResults
12
+ from langchain.prompts import PromptTemplate
13
+
14
+ from config import LOCAL_LLM, WEB_SEARCH_RESULTS_COUNT
15
+ from pprint import pprint
16
+
17
+
18
+ class GraphState(TypedDict):
19
+ """
20
+ 表示图的状态
21
+
22
+ 属性:
23
+ question: 问题
24
+ generation: LLM生成
25
+ documents: 文档列表
26
+ """
27
+ question: str
28
+ generation: str
29
+ documents: List[str]
30
+
31
+
32
+ class WorkflowNodes:
33
+ """工作流节点类,包含所有节点函数"""
34
+
35
+ def __init__(self, retriever, graders):
36
+ self.retriever = retriever
37
+ self.graders = graders
38
+
39
+ # 设置RAG链 - 使用本地提示模板
40
+ rag_prompt_template = PromptTemplate(
41
+ template="""你是一个问答助手。使用以下检索到的上下文来回答问题。
42
+ 如果你不知道答案,就说你不知道。最多使用三句话并保持答案简洁。
43
+
44
+ 问题: {question}
45
+
46
+ 上下文: {context}
47
+
48
+ 答案:""",
49
+ input_variables=["question", "context"]
50
+ )
51
+ llm = ChatOllama(model=LOCAL_LLM, temperature=0)
52
+ self.rag_chain = rag_prompt_template | llm | StrOutputParser()
53
+
54
+ # 设置网络搜索
55
+ self.web_search_tool = TavilySearchResults(k=WEB_SEARCH_RESULTS_COUNT)
56
+
57
+ def retrieve(self, state):
58
+ """
59
+ 检索文档
60
+
61
+ Args:
62
+ state (dict): 当前图状态
63
+
64
+ Returns:
65
+ state (dict): 添加了documents键的新状态,包含检索到的文档
66
+ """
67
+ print("---检索---")
68
+ question = state["question"]
69
+
70
+ # 检索
71
+ documents = self.retriever.get_relevant_documents(question)
72
+ return {"documents": documents, "question": question}
73
+
74
+ def generate(self, state):
75
+ """
76
+ 生成答案
77
+
78
+ Args:
79
+ state (dict): 当前图状态
80
+
81
+ Returns:
82
+ state (dict): 添加了generation键的新状态,包含LLM生成
83
+ """
84
+ print("---生成---")
85
+ question = state["question"]
86
+ documents = state["documents"]
87
+
88
+ # RAG生成
89
+ generation = self.rag_chain.invoke({"context": documents, "question": question})
90
+ return {"documents": documents, "question": question, "generation": generation}
91
+
92
+ def grade_documents(self, state):
93
+ """
94
+ 确定检索到的文档是否与问题相关
95
+
96
+ Args:
97
+ state (dict): 当前图状态
98
+
99
+ Returns:
100
+ state (dict): 更新documents键,只包含过滤后的相关文档
101
+ """
102
+ print("---检查文档与问题的相关性---")
103
+ question = state["question"]
104
+ documents = state["documents"]
105
+
106
+ # 为每个文档评分
107
+ filtered_docs = []
108
+ for d in documents:
109
+ score = self.graders["document_grader"].grade(question, d.page_content)
110
+ grade = score
111
+ if grade == "yes":
112
+ print("---评分:文档相关---")
113
+ filtered_docs.append(d)
114
+ else:
115
+ print("---评分:文档不相关---")
116
+ continue
117
+ return {"documents": filtered_docs, "question": question}
118
+
119
+ def transform_query(self, state):
120
+ """
121
+ 转换查询以产生更好的问题
122
+
123
+ Args:
124
+ state (dict): 当前图状态
125
+
126
+ Returns:
127
+ state (dict): 用重新表述的问题更新question键
128
+ """
129
+ print("---转换查询---")
130
+ question = state["question"]
131
+ documents = state["documents"]
132
+
133
+ # 重写问题
134
+ better_question = self.graders["query_rewriter"].rewrite(question)
135
+ return {"documents": documents, "question": better_question}
136
+
137
+ def web_search(self, state):
138
+ """
139
+ 基于重新表述的问题进行网络搜索
140
+
141
+ Args:
142
+ state (dict): 当前图状态
143
+
144
+ Returns:
145
+ state (dict): 用附加的网络结果更新documents键
146
+ """
147
+ print("---网络搜索---")
148
+ question = state["question"]
149
+
150
+ # 网络搜索
151
+ docs = self.web_search_tool.invoke({"query": question})
152
+ web_results = "\n".join([d["content"] for d in docs])
153
+ web_results = Document(page_content=web_results)
154
+
155
+ return {"documents": web_results, "question": question}
156
+
157
+ def route_question(self, state):
158
+ """
159
+ 将问题路由到网络搜索或RAG
160
+
161
+ Args:
162
+ state (dict): 当前图状态
163
+
164
+ Returns:
165
+ str: 要调用的下一个节点
166
+ """
167
+ print("---路由问题---")
168
+ question = state["question"]
169
+ print(question)
170
+ source = self.graders["query_router"].route(question)
171
+ print(source)
172
+ if source == "web_search":
173
+ print("---将问题路由到网络搜索---")
174
+ return "web_search"
175
+ elif source == "vectorstore":
176
+ print("---将问题路由到RAG---")
177
+ return "vectorstore"
178
+
179
+ def decide_to_generate(self, state):
180
+ """
181
+ 确定是生成答案还是重新生成问题
182
+
183
+ Args:
184
+ state (dict): 当前图状态
185
+
186
+ Returns:
187
+ str: 要调用的下一个节点的二进制决策
188
+ """
189
+ print("---评估已评分的文档---")
190
+ filtered_documents = state["documents"]
191
+
192
+ if not filtered_documents:
193
+ # 所有文档都被过滤掉了
194
+ # 我们将重新生成一个新查询
195
+ print("---决策:所有文档都与问题不相关,转换查询---")
196
+ return "transform_query"
197
+ else:
198
+ # 我们有相关文档,所以生成答案
199
+ print("---决策:生成---")
200
+ return "generate"
201
+
202
+ def grade_generation_v_documents_and_question(self, state):
203
+ """
204
+ 确定生成是否基于文档并回答问题
205
+
206
+ Args:
207
+ state (dict): 当前图状态
208
+
209
+ Returns:
210
+ str: 要调用的下一个节点的决策
211
+ """
212
+ print("---检查幻觉---")
213
+ question = state["question"]
214
+ documents = state["documents"]
215
+ generation = state["generation"]
216
+
217
+ score = self.graders["hallucination_grader"].grade(generation, documents)
218
+ grade = score
219
+
220
+ # 检查幻觉
221
+ if grade == "yes":
222
+ print("---决策:生成基于文档---")
223
+ # 检查问题回答
224
+ print("---评分生成 vs 问题---")
225
+ score = self.graders["answer_grader"].grade(question, generation)
226
+ grade = score
227
+ if grade == "yes":
228
+ print("---决策:生成解决了问题---")
229
+ return "useful"
230
+ else:
231
+ print("---决策:生成没有解决问题---")
232
+ return "not useful"
233
+ else:
234
+ print("---决策:生成不基于文档,重试---")
235
+ return "not supported"
236
+
237
+
238
+ def format_docs(docs):
239
+ """格式化文档用于显示"""
240
+ return "\n\n".join(doc.page_content for doc in docs)